From 2a6e0a2c02542737f16c1e37cfa4a47e826a08fe Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 7 Jan 2026 03:29:21 +0800 Subject: [PATCH] [feat] Added Quest Sparsity Policy. --- nanovllm/config.py | 16 +++++++++----- nanovllm/engine/model_runner.py | 25 ++++++++++----------- nanovllm/kvcache/__init__.py | 32 ++++++++++----------------- nanovllm/kvcache/hybrid_manager.py | 26 +++++----------------- nanovllm/kvcache/offload_engine.py | 24 +++++++------------- nanovllm/kvcache/sparse/__init__.py | 3 ++- nanovllm/kvcache/sparse/policy.py | 8 ++----- nanovllm/layers/attention.py | 16 +++++++------- tests/test_needle.py | 34 +++++++++++++++++++++++++++++ 9 files changed, 92 insertions(+), 92 deletions(-) diff --git a/nanovllm/config.py b/nanovllm/config.py index 6582f16..2be7b8d 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -1,9 +1,16 @@ import os from dataclasses import dataclass +from enum import Enum, auto from transformers import AutoConfig import torch +class SparsePolicyType(Enum): + """Sparse attention policy types.""" + FULL = auto() # No sparse attention (load all blocks) + QUEST = auto() # Query-aware Top-K block selection (decode only) + + @dataclass class Config: model: str @@ -29,11 +36,10 @@ class Config: num_gpu_kvcache_blocks: int = -1 num_cpu_kvcache_blocks: int = -1 - # Sparse attention configuration (dual policy architecture) - prefill_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm" - decode_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm" - sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns - sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash + # Sparse attention configuration + # Quest: decode-only sparse attention with Top-K block selection + # FULL: no sparse attention (load all blocks) + sparse_policy: SparsePolicyType = SparsePolicyType.FULL sparse_topk_blocks: int = 8 # Top-K blocks for Quest sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 7073ce5..3b463f0 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -156,22 +156,19 @@ class ModelRunner: dtype=hf_config.torch_dtype, ) - # Initialize sparse policies if manager has them (CPU offload mode) - if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'): - # Initialize both policies with model config - for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]: - if policy is not None: - policy.initialize( - num_layers=hf_config.num_hidden_layers, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - num_cpu_blocks=config.num_cpu_kvcache_blocks, - dtype=hf_config.torch_dtype, - device=torch.device("cuda"), - ) + # Initialize sparse policy if manager has one (CPU offload mode) + if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None: + self.kvcache_manager.sparse_policy.initialize( + num_layers=hf_config.num_hidden_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + num_cpu_blocks=config.num_cpu_kvcache_blocks, + dtype=hf_config.torch_dtype, + device=torch.device("cuda"), + ) logger.info( - f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} " + f"Sparse policy initialized: {config.sparse_policy.name} " f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})" ) diff --git a/nanovllm/kvcache/__init__.py b/nanovllm/kvcache/__init__.py index c312b73..07ddd61 100644 --- a/nanovllm/kvcache/__init__.py +++ b/nanovllm/kvcache/__init__.py @@ -56,36 +56,26 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager: # Need CPU offload: use hybrid manager from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager from nanovllm.kvcache.policies import get_policy - from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType + from nanovllm.kvcache.sparse import create_sparse_policy + from nanovllm.config import SparsePolicyType eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru')) - # Create sparse policies from config - prefill_policy_type = getattr(config, 'prefill_policy', 'full') - decode_policy_type = getattr(config, 'decode_policy', 'full') - - def create_policy(policy_type_str): - """Create a sparse policy from config string.""" - if policy_type_str.lower() == 'full': - return create_sparse_policy(SparsePolicyType.FULL) - policy_type = SparsePolicyType[policy_type_str.upper()] - return create_sparse_policy( - policy_type, - topk_blocks=getattr(config, 'sparse_topk_blocks', 8), - threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4), - include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1), - ) - - prefill_policy = create_policy(prefill_policy_type) - decode_policy = create_policy(decode_policy_type) + # Create sparse policy from config enum + # Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K + sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL) + sparse_policy = create_sparse_policy( + sparse_policy_type, + topk_blocks=getattr(config, 'sparse_topk_blocks', 8), + threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4), + ) return HybridKVCacheManager( num_gpu_slots=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=config.kvcache_block_size, policy=eviction_policy, - prefill_policy=prefill_policy, - decode_policy=decode_policy, + sparse_policy=sparse_policy, ) diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py index 5c050df..61dd844 100644 --- a/nanovllm/kvcache/hybrid_manager.py +++ b/nanovllm/kvcache/hybrid_manager.py @@ -90,8 +90,7 @@ class HybridKVCacheManager(KVCacheManager): num_cpu_blocks: int, block_size: int, policy: Optional[EvictionPolicy] = None, - prefill_policy: "SparsePolicy" = None, - decode_policy: "SparsePolicy" = None, + sparse_policy: "SparsePolicy" = None, ): """ Initialize hybrid manager with CPU-primary ring buffer design. @@ -104,8 +103,7 @@ class HybridKVCacheManager(KVCacheManager): num_cpu_blocks: Number of CPU pool blocks (primary storage) block_size: Tokens per block policy: Eviction policy (default: LRU, used for prefix cache management) - prefill_policy: Sparse attention policy for prefill phase - decode_policy: Sparse attention policy for decode phase + sparse_policy: Sparse attention policy (Quest for decode-only sparse) """ self._block_size = block_size self.num_gpu_slots = num_gpu_slots @@ -117,9 +115,8 @@ class HybridKVCacheManager(KVCacheManager): # Eviction policy self.policy = policy or LRUPolicy() - # Sparse attention policies (set at construction time, immutable) - self.prefill_policy = prefill_policy - self.decode_policy = decode_policy + # Sparse attention policy (set at construction time, immutable) + self.sparse_policy = sparse_policy # Logical blocks (what sequences reference) - one per CPU block self.logical_blocks: List[LogicalBlock] = [ @@ -185,8 +182,7 @@ class HybridKVCacheManager(KVCacheManager): num_kv_heads=num_kv_heads, head_dim=head_dim, dtype=dtype, - prefill_policy=self.prefill_policy, - decode_policy=self.decode_policy, + sparse_policy=self.sparse_policy, ) def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]: @@ -194,18 +190,6 @@ class HybridKVCacheManager(KVCacheManager): assert self.offload_engine is not None return self.offload_engine.get_layer_cache(layer_id) - def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]: - """ - Get sparse policy for the specified phase. - - Args: - is_prefill: True for prefill phase, False for decode phase - - Returns: - SparsePolicy for the phase, or None if not set - """ - return self.prefill_policy if is_prefill else self.decode_policy - def can_allocate(self, seq: Sequence) -> bool: """Check if we can allocate blocks for a new sequence.""" return len(self.free_logical_ids) >= seq.num_blocks diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index f3431de..5260906 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -60,8 +60,7 @@ class OffloadEngine: head_dim: int, dtype: torch.dtype = torch.float16, num_streams: int = 4, - prefill_policy: "SparsePolicy" = None, - decode_policy: "SparsePolicy" = None, + sparse_policy: "SparsePolicy" = None, ): self.num_layers = num_layers self.num_gpu_blocks = num_gpu_blocks @@ -217,9 +216,8 @@ class OffloadEngine: self._debug_mode = False self._debug_hooks: List = [] # External hooks for debug events - # ========== Sparse attention policies (set at construction time) ========== - self.prefill_policy = prefill_policy - self.decode_policy = decode_policy + # ========== Sparse attention policy (set at construction time) ========== + self.sparse_policy = sparse_policy def _get_next_stream(self) -> torch.cuda.Stream: """Round-robin stream selection for parallel transfers.""" @@ -765,20 +763,14 @@ class OffloadEngine: logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]") # Collect metadata BEFORE offload (while k_cache is still on GPU) - # Both policies' callbacks are called - each decides whether to respond valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size k_cache = self.k_cache_gpu[slot_idx] - if is_prefill: - if self.prefill_policy is not None: - self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens) - if self.decode_policy is not None: - self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens) - else: - if self.prefill_policy is not None: - self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens) - if self.decode_policy is not None: - self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens) + if self.sparse_policy is not None: + if is_prefill: + self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens) + else: + self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens) torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]") with torch.cuda.stream(self.transfer_stream_main): diff --git a/nanovllm/kvcache/sparse/__init__.py b/nanovllm/kvcache/sparse/__init__.py index ae9473c..ae8e922 100644 --- a/nanovllm/kvcache/sparse/__init__.py +++ b/nanovllm/kvcache/sparse/__init__.py @@ -19,7 +19,8 @@ Usage: return available_blocks[:5] # Just first 5 blocks """ -from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType +from nanovllm.config import SparsePolicyType +from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py index ee1f64b..2813745 100644 --- a/nanovllm/kvcache/sparse/policy.py +++ b/nanovllm/kvcache/sparse/policy.py @@ -7,15 +7,11 @@ from CPU for each query chunk during chunked attention computation. from abc import ABC, abstractmethod from dataclasses import dataclass -from enum import Enum, auto from typing import List, Optional, Any import torch - -class SparsePolicyType(Enum): - """Built-in sparse attention policy types.""" - FULL = auto() # prefill + decode - QUEST = auto() # decode only +# Import SparsePolicyType from config to avoid circular imports +from nanovllm.config import SparsePolicyType @dataclass diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 0a36141..1648c0d 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -188,9 +188,9 @@ class Attention(nn.Module): # Get prefilled CPU blocks (blocks from previous chunks) cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) - # Apply sparse policy if enabled - prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True) - if cpu_block_table and prefill_policy is not None: + # Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None) + sparse_policy = kvcache_manager.sparse_policy + if cpu_block_table and sparse_policy is not None: num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1) policy_ctx = PolicyContext( query_chunk_idx=current_chunk_idx, @@ -201,7 +201,7 @@ class Attention(nn.Module): block_size=kvcache_manager.block_size, total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, ) - cpu_block_table = prefill_policy.select_blocks( + cpu_block_table = sparse_policy.select_blocks( cpu_block_table, policy_ctx ) @@ -512,9 +512,9 @@ class Attention(nn.Module): if last_block_valid_tokens == 0 and total_prefill_tokens > 0: last_block_valid_tokens = block_size # Last block was exactly full - # Apply sparse policy if enabled - decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False) - if decode_policy is not None: + # Apply sparse policy if enabled (Quest does Top-K selection for decode) + sparse_policy = kvcache_manager.sparse_policy + if sparse_policy is not None: policy_ctx = PolicyContext( query_chunk_idx=0, num_query_chunks=1, @@ -524,7 +524,7 @@ class Attention(nn.Module): block_size=kvcache_manager.block_size, total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, ) - cpu_block_table = decode_policy.select_blocks( + cpu_block_table = sparse_policy.select_blocks( cpu_block_table, policy_ctx ) diff --git a/tests/test_needle.py b/tests/test_needle.py index bc30f87..7792ddc 100644 --- a/tests/test_needle.py +++ b/tests/test_needle.py @@ -12,6 +12,7 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "INFO" import argparse from nanovllm import LLM, SamplingParams +from nanovllm.config import SparsePolicyType from utils import generate_needle_prompt, check_needle_answer @@ -29,6 +30,9 @@ def run_needle_test( needle_value: str = "7492", max_new_tokens: int = 32, enable_cpu_offload: bool = False, + enable_quest: bool = False, + sparse_topk: int = 8, + sparse_threshold: int = 4, verbose: bool = True, ) -> bool: """ @@ -44,11 +48,16 @@ def run_needle_test( needle_value: The secret value to find max_new_tokens: Maximum tokens to generate enable_cpu_offload: Enable CPU offload mode + enable_quest: Enable Quest sparse attention (decode-only Top-K) + sparse_topk: Top-K blocks for Quest + sparse_threshold: Apply sparse only when blocks > threshold verbose: Print detailed output Returns: True if test passed, False otherwise """ + sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL + if verbose: print(f"\n{'='*60}") print(f"Needle-in-Haystack Test") @@ -60,6 +69,8 @@ def run_needle_test( print(f"Needle position: {needle_position:.0%}") print(f"Needle value: {needle_value}") print(f"CPU offload: {enable_cpu_offload}") + if enable_cpu_offload: + print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})") print(f"{'='*60}\n") # 1. Initialize LLM @@ -72,6 +83,9 @@ def run_needle_test( } if enable_cpu_offload: llm_kwargs["num_gpu_blocks"] = num_gpu_blocks + llm_kwargs["sparse_policy"] = sparse_policy + llm_kwargs["sparse_topk_blocks"] = sparse_topk + llm_kwargs["sparse_threshold_blocks"] = sparse_threshold llm = LLM(model_path, **llm_kwargs) @@ -167,6 +181,23 @@ if __name__ == "__main__": action="store_true", help="Enable CPU offload (has known bug for long sequences)" ) + parser.add_argument( + "--enable-quest", + action="store_true", + help="Enable Quest sparse attention (decode-only Top-K selection)" + ) + parser.add_argument( + "--sparse-topk", + type=int, + default=8, + help="Top-K blocks for Quest sparse attention" + ) + parser.add_argument( + "--sparse-threshold", + type=int, + default=4, + help="Apply sparse only when blocks > threshold" + ) args = parser.parse_args() passed = run_needle_test( @@ -179,6 +210,9 @@ if __name__ == "__main__": needle_value=args.needle_value, max_new_tokens=args.max_new_tokens, enable_cpu_offload=args.enable_offload, + enable_quest=args.enable_quest, + sparse_topk=args.sparse_topk, + sparse_threshold=args.sparse_threshold, verbose=True, )