[feat] Add Quest sparsity policy (decode-only Top-K block selection).

Zijie Tian
2026-01-07 03:29:21 +08:00
parent c99a6f3d3f
commit 2a6e0a2c02
9 changed files with 92 additions and 92 deletions

View File

@@ -1,9 +1,16 @@
 import os
 from dataclasses import dataclass
+from enum import Enum, auto
 from transformers import AutoConfig
 import torch
+
+
+class SparsePolicyType(Enum):
+    """Sparse attention policy types."""
+    FULL = auto()   # No sparse attention (load all blocks)
+    QUEST = auto()  # Query-aware Top-K block selection (decode only)
+
+
 @dataclass
 class Config:
     model: str
@@ -29,11 +36,10 @@ class Config:
     num_gpu_kvcache_blocks: int = -1
     num_cpu_kvcache_blocks: int = -1
 
-    # Sparse attention configuration (dual policy architecture)
-    prefill_policy: str = "full"  # "full", "quest", "vertical_slash", "streaming_llm"
-    decode_policy: str = "full"   # "full", "quest", "vertical_slash", "streaming_llm"
-    sparse_num_sink_blocks: int = 1      # Number of sink blocks for sparse patterns
-    sparse_local_window_blocks: int = 2  # Local window size for VerticalSlash
+    # Sparse attention configuration
+    # Quest: decode-only sparse attention with Top-K block selection
+    # FULL: no sparse attention (load all blocks)
+    sparse_policy: SparsePolicyType = SparsePolicyType.FULL
     sparse_topk_blocks: int = 8       # Top-K blocks for Quest
     sparse_threshold_blocks: int = 4  # Apply sparse only when blocks > threshold
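
For reference, a minimal sketch of turning Quest on through the new field (assuming Config's remaining fields keep their defaults; the model path is a placeholder):

    from nanovllm.config import Config, SparsePolicyType

    # Quest is opt-in; SparsePolicyType.FULL remains the default policy.
    cfg = Config(
        model="/path/to/model",                # placeholder path
        sparse_policy=SparsePolicyType.QUEST,
        sparse_topk_blocks=8,                  # keep the 8 highest-scoring KV blocks per decode step
        sparse_threshold_blocks=4,             # stay dense until a sequence spans more than 4 blocks
    )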

View File

@@ -156,12 +156,9 @@ class ModelRunner:
                 dtype=hf_config.torch_dtype,
             )
 
-            # Initialize sparse policies if manager has them (CPU offload mode)
-            if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'):
-                # Initialize both policies with model config
-                for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]:
-                    if policy is not None:
-                        policy.initialize(
+            # Initialize sparse policy if manager has one (CPU offload mode)
+            if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
+                self.kvcache_manager.sparse_policy.initialize(
                     num_layers=hf_config.num_hidden_layers,
                     num_kv_heads=num_kv_heads,
                     head_dim=head_dim,
@@ -171,7 +168,7 @@ class ModelRunner:
                 )
                 logger.info(
-                    f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} "
+                    f"Sparse policy initialized: {config.sparse_policy.name} "
                     f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
                 )

View File

@@ -56,36 +56,26 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
     # Need CPU offload: use hybrid manager
     from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
     from nanovllm.kvcache.policies import get_policy
-    from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
+    from nanovllm.kvcache.sparse import create_sparse_policy
+    from nanovllm.config import SparsePolicyType
 
     eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
 
-    # Create sparse policies from config
-    prefill_policy_type = getattr(config, 'prefill_policy', 'full')
-    decode_policy_type = getattr(config, 'decode_policy', 'full')
-
-    def create_policy(policy_type_str):
-        """Create a sparse policy from config string."""
-        if policy_type_str.lower() == 'full':
-            return create_sparse_policy(SparsePolicyType.FULL)
-        policy_type = SparsePolicyType[policy_type_str.upper()]
-        return create_sparse_policy(
-            policy_type,
-            topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
-            threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
-            include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1),
-        )
-
-    prefill_policy = create_policy(prefill_policy_type)
-    decode_policy = create_policy(decode_policy_type)
+    # Create sparse policy from config enum
+    # Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
+    sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
+    sparse_policy = create_sparse_policy(
+        sparse_policy_type,
+        topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
+        threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
+    )
 
     return HybridKVCacheManager(
         num_gpu_slots=num_gpu_blocks,
         num_cpu_blocks=num_cpu_blocks,
         block_size=config.kvcache_block_size,
         policy=eviction_policy,
-        prefill_policy=prefill_policy,
-        decode_policy=decode_policy,
+        sparse_policy=sparse_policy,
     )
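
The create_sparse_policy factory body is not part of this diff; a plausible sketch of the dispatch it performs, using the class names exported from nanovllm.kvcache.sparse (the QuestConfig field names are assumptions):

    from nanovllm.config import SparsePolicyType
    from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
    from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig

    def create_sparse_policy(policy_type, topk_blocks=8, threshold_blocks=4):
        # FULL maps to a no-op policy that returns every block unchanged.
        if policy_type == SparsePolicyType.FULL:
            return FullAttentionPolicy()
        if policy_type == SparsePolicyType.QUEST:
            return QuestPolicy(QuestConfig(topk_blocks=topk_blocks,
                                           threshold_blocks=threshold_blocks))
        raise ValueError(f"Unknown sparse policy type: {policy_type}")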

View File

@@ -90,8 +90,7 @@ class HybridKVCacheManager(KVCacheManager):
         num_cpu_blocks: int,
         block_size: int,
         policy: Optional[EvictionPolicy] = None,
-        prefill_policy: "SparsePolicy" = None,
-        decode_policy: "SparsePolicy" = None,
+        sparse_policy: "SparsePolicy" = None,
     ):
         """
         Initialize hybrid manager with CPU-primary ring buffer design.
@@ -104,8 +103,7 @@ class HybridKVCacheManager(KVCacheManager):
             num_cpu_blocks: Number of CPU pool blocks (primary storage)
             block_size: Tokens per block
             policy: Eviction policy (default: LRU, used for prefix cache management)
-            prefill_policy: Sparse attention policy for prefill phase
-            decode_policy: Sparse attention policy for decode phase
+            sparse_policy: Sparse attention policy (Quest for decode-only sparse)
         """
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
@@ -117,9 +115,8 @@ class HybridKVCacheManager(KVCacheManager):
         # Eviction policy
         self.policy = policy or LRUPolicy()
 
-        # Sparse attention policies (set at construction time, immutable)
-        self.prefill_policy = prefill_policy
-        self.decode_policy = decode_policy
+        # Sparse attention policy (set at construction time, immutable)
+        self.sparse_policy = sparse_policy
 
         # Logical blocks (what sequences reference) - one per CPU block
         self.logical_blocks: List[LogicalBlock] = [
@@ -185,8 +182,7 @@ class HybridKVCacheManager(KVCacheManager):
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=dtype,
-            prefill_policy=self.prefill_policy,
-            decode_policy=self.decode_policy,
+            sparse_policy=self.sparse_policy,
         )
 
     def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -194,18 +190,6 @@ class HybridKVCacheManager(KVCacheManager):
         assert self.offload_engine is not None
         return self.offload_engine.get_layer_cache(layer_id)
 
-    def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
-        """
-        Get sparse policy for the specified phase.
-
-        Args:
-            is_prefill: True for prefill phase, False for decode phase
-
-        Returns:
-            SparsePolicy for the phase, or None if not set
-        """
-        return self.prefill_policy if is_prefill else self.decode_policy
-
     def can_allocate(self, seq: Sequence) -> bool:
         """Check if we can allocate blocks for a new sequence."""
         return len(self.free_logical_ids) >= seq.num_blocks

View File

@@ -60,8 +60,7 @@ class OffloadEngine:
         head_dim: int,
         dtype: torch.dtype = torch.float16,
         num_streams: int = 4,
-        prefill_policy: "SparsePolicy" = None,
-        decode_policy: "SparsePolicy" = None,
+        sparse_policy: "SparsePolicy" = None,
     ):
         self.num_layers = num_layers
         self.num_gpu_blocks = num_gpu_blocks
@@ -217,9 +216,8 @@ class OffloadEngine:
         self._debug_mode = False
         self._debug_hooks: List = []  # External hooks for debug events
 
-        # ========== Sparse attention policies (set at construction time) ==========
-        self.prefill_policy = prefill_policy
-        self.decode_policy = decode_policy
+        # ========== Sparse attention policy (set at construction time) ==========
+        self.sparse_policy = sparse_policy
 
     def _get_next_stream(self) -> torch.cuda.Stream:
         """Round-robin stream selection for parallel transfers."""
@@ -765,20 +763,14 @@ class OffloadEngine:
             logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
 
         # Collect metadata BEFORE offload (while k_cache is still on GPU)
-        # Both policies' callbacks are called - each decides whether to respond
         valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
         k_cache = self.k_cache_gpu[slot_idx]
-        if is_prefill:
-            if self.prefill_policy is not None:
-                self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
-            if self.decode_policy is not None:
-                self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
-        else:
-            if self.prefill_policy is not None:
-                self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
-            if self.decode_policy is not None:
-                self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+        if self.sparse_policy is not None:
+            if is_prefill:
+                self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+            else:
+                self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
 
         torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
         with torch.cuda.stream(self.transfer_stream_main):
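
For context on what these callbacks buy Quest: the policy summarizes each block's keys while they are still on the GPU, so decode-time block selection never has to read CPU-resident keys. A minimal sketch of such a callback, assuming a [block_size, num_kv_heads, head_dim] slot layout and an illustrative metadata store (not the repository's BlockMetadataManager):

    class QuestMetadataSketch:
        """Illustrative only: per-(layer, block) min/max key summaries."""

        def __init__(self):
            self.key_min = {}  # (layer_id, cpu_block_id) -> [num_kv_heads, head_dim]
            self.key_max = {}

        def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, valid_tokens):
            # Only the first `valid_tokens` rows of the slot hold real keys.
            k = k_cache[:valid_tokens]
            self.key_min[(layer_id, cpu_block_id)] = k.amin(dim=0)
            self.key_max[(layer_id, cpu_block_id)] = k.amax(dim=0)

        # on_decode_offload would update the same summaries for blocks filled during decode.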

View File

@@ -19,7 +19,8 @@ Usage:
         return available_blocks[:5]  # Just first 5 blocks
 """
 
-from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
+from nanovllm.config import SparsePolicyType
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
 from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
 from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager

View File

@@ -7,15 +7,11 @@ from CPU for each query chunk during chunked attention computation.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from enum import Enum, auto
 from typing import List, Optional, Any
 
 import torch
 
-
-class SparsePolicyType(Enum):
-    """Built-in sparse attention policy types."""
-    FULL = auto()   # prefill + decode
-    QUEST = auto()  # decode only
+# Import SparsePolicyType from config to avoid circular imports
+from nanovllm.config import SparsePolicyType
 
 @dataclass
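
Pulled together from the call sites elsewhere in this commit (ModelRunner, OffloadEngine, and the attention layer), the policy surface that matters downstream looks roughly like this. It is reconstructed from those calls rather than copied from policy.py, and the remaining initialize() arguments are folded into **kwargs:

    from abc import ABC, abstractmethod

    class SparsePolicySketch(ABC):
        def initialize(self, num_layers, num_kv_heads, head_dim, **kwargs):
            """Receive model geometry once, before any KV blocks exist."""

        def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, valid_tokens):
            """Observe a block's keys just before it is offloaded to CPU."""

        def on_decode_offload(self, cpu_block_id, layer_id, k_cache, valid_tokens):
            """Same hook for blocks filled during decode."""

        @abstractmethod
        def select_blocks(self, cpu_block_table, ctx):
            """Return the subset of cpu_block_table to load for this step."""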

View File

@@ -188,9 +188,9 @@ class Attention(nn.Module):
             # Get prefilled CPU blocks (blocks from previous chunks)
             cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
 
-            # Apply sparse policy if enabled
-            prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
-            if cpu_block_table and prefill_policy is not None:
+            # Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
+            sparse_policy = kvcache_manager.sparse_policy
+            if cpu_block_table and sparse_policy is not None:
                 num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
                 policy_ctx = PolicyContext(
                     query_chunk_idx=current_chunk_idx,
@@ -201,7 +201,7 @@ class Attention(nn.Module):
                     block_size=kvcache_manager.block_size,
                     total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
                 )
-                cpu_block_table = prefill_policy.select_blocks(
+                cpu_block_table = sparse_policy.select_blocks(
                     cpu_block_table, policy_ctx
                 )
@@ -512,9 +512,9 @@ class Attention(nn.Module):
             if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
                 last_block_valid_tokens = block_size  # Last block was exactly full
 
-            # Apply sparse policy if enabled
-            decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
-            if decode_policy is not None:
+            # Apply sparse policy if enabled (Quest does Top-K selection for decode)
+            sparse_policy = kvcache_manager.sparse_policy
+            if sparse_policy is not None:
                 policy_ctx = PolicyContext(
                     query_chunk_idx=0,
                     num_query_chunks=1,
@@ -524,7 +524,7 @@ class Attention(nn.Module):
                     block_size=kvcache_manager.block_size,
                     total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
                 )
-                cpu_block_table = decode_policy.select_blocks(
+                cpu_block_table = sparse_policy.select_blocks(
                     cpu_block_table, policy_ctx
                 )
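
Decode is where Quest actually prunes: given the current query and each block's min/max key summary, it upper-bounds the block's attention logits and keeps only the top-K blocks. A self-contained sketch of that scoring rule in plain PyTorch (independent of the repository's QuestPolicy; the shapes and the max-over-heads reduction are simplifying assumptions):

    import torch

    def quest_topk_blocks(query, key_min, key_max, topk_blocks=8, threshold_blocks=4):
        # query:   [num_kv_heads, head_dim] for the current decode token
        # key_min: [num_blocks, num_kv_heads, head_dim] element-wise min of each block's keys
        # key_max: [num_blocks, num_kv_heads, head_dim] element-wise max of each block's keys
        num_blocks = key_min.shape[0]
        if num_blocks <= threshold_blocks:
            return list(range(num_blocks))  # short context: keep everything (mirrors sparse_threshold_blocks)
        q = query.unsqueeze(0)                           # [1, H, D], broadcasts over blocks
        # Quest's bound: per channel, the largest possible q*k is max(q*k_min, q*k_max);
        # summing channels upper-bounds the block's attention logit.
        bound = torch.maximum(q * key_min, q * key_max)  # [B, H, D]
        scores = bound.sum(dim=-1).amax(dim=-1)          # [B], max over heads
        keep = torch.topk(scores, min(topk_blocks, num_blocks)).indices
        return keep.sort().values.tolist()               # ascending block order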

View File

@@ -12,6 +12,7 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
 import argparse
 
 from nanovllm import LLM, SamplingParams
+from nanovllm.config import SparsePolicyType
 from utils import generate_needle_prompt, check_needle_answer
@@ -29,6 +30,9 @@ def run_needle_test(
     needle_value: str = "7492",
     max_new_tokens: int = 32,
     enable_cpu_offload: bool = False,
+    enable_quest: bool = False,
+    sparse_topk: int = 8,
+    sparse_threshold: int = 4,
     verbose: bool = True,
 ) -> bool:
     """
@@ -44,11 +48,16 @@ def run_needle_test(
         needle_value: The secret value to find
         max_new_tokens: Maximum tokens to generate
         enable_cpu_offload: Enable CPU offload mode
+        enable_quest: Enable Quest sparse attention (decode-only Top-K)
+        sparse_topk: Top-K blocks for Quest
+        sparse_threshold: Apply sparse only when blocks > threshold
         verbose: Print detailed output
 
     Returns:
         True if test passed, False otherwise
     """
+    sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
+
     if verbose:
         print(f"\n{'='*60}")
         print(f"Needle-in-Haystack Test")
@@ -60,6 +69,8 @@ def run_needle_test(
         print(f"Needle position: {needle_position:.0%}")
         print(f"Needle value: {needle_value}")
         print(f"CPU offload: {enable_cpu_offload}")
+        if enable_cpu_offload:
+            print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
         print(f"{'='*60}\n")
 
     # 1. Initialize LLM
@@ -72,6 +83,9 @@ def run_needle_test(
     }
     if enable_cpu_offload:
         llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
+        llm_kwargs["sparse_policy"] = sparse_policy
+        llm_kwargs["sparse_topk_blocks"] = sparse_topk
+        llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
 
     llm = LLM(model_path, **llm_kwargs)
@@ -167,6 +181,23 @@ if __name__ == "__main__":
         action="store_true",
         help="Enable CPU offload (has known bug for long sequences)"
     )
+    parser.add_argument(
+        "--enable-quest",
+        action="store_true",
+        help="Enable Quest sparse attention (decode-only Top-K selection)"
+    )
+    parser.add_argument(
+        "--sparse-topk",
+        type=int,
+        default=8,
+        help="Top-K blocks for Quest sparse attention"
+    )
+    parser.add_argument(
+        "--sparse-threshold",
+        type=int,
+        default=4,
+        help="Apply sparse only when blocks > threshold"
+    )
 
     args = parser.parse_args()
 
     passed = run_needle_test(
@@ -179,6 +210,9 @@ if __name__ == "__main__":
         needle_value=args.needle_value,
         max_new_tokens=args.max_new_tokens,
         enable_cpu_offload=args.enable_offload,
+        enable_quest=args.enable_quest,
+        sparse_topk=args.sparse_topk,
+        sparse_threshold=args.sparse_threshold,
         verbose=True,
     )
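
To exercise the new path outside the test script, the same keyword arguments can be passed straight to LLM, mirroring what run_needle_test assembles above; the model path and num_gpu_blocks value are placeholders, and whatever base kwargs enable CPU offload are elided in this hunk, so they are omitted here as well:

    from nanovllm import LLM
    from nanovllm.config import SparsePolicyType

    llm = LLM(
        "/path/to/model",                       # placeholder path
        num_gpu_blocks=64,                      # placeholder value
        sparse_policy=SparsePolicyType.QUEST,
        sparse_topk_blocks=8,
        sparse_threshold_blocks=4,
    )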