diff --git a/bench_offload.py b/bench_offload.py
index 48f23ee..34edafe 100644
--- a/bench_offload.py
+++ b/bench_offload.py
@@ -3,11 +3,6 @@ import time
 from random import randint, seed
 from nanovllm import LLM, SamplingParams
 
-# Import sparse policy classes
-from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
-from nanovllm.kvcache.sparse.hybrid import HybridPolicy
-from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
-
 
 def bench_decode(llm, num_seqs, input_len, output_len):
     """Benchmark decode performance (original test)"""
@@ -38,58 +33,6 @@ def bench_prefill(llm, num_seqs, input_len):
     print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
 
 
-def setup_quest_policy(llm, topk_blocks=8, threshold_blocks=4):
-    """
-    Setup Quest sparse policy for decode phase.
-
-    Uses HybridPolicy: Full attention for prefill, Quest Top-K for decode.
-    """
-    import torch
-
-    kvcache_manager = llm.model_runner.kvcache_manager
-    offload_engine = kvcache_manager.offload_engine
-
-    # Get model parameters from offload engine
-    num_layers = offload_engine.num_layers
-    num_kv_heads = offload_engine.num_kv_heads
-    head_dim = offload_engine.head_dim
-    num_cpu_blocks = kvcache_manager.num_cpu_blocks
-    dtype = offload_engine.k_cache_cpu.dtype
-
-    print(f"Setting up Quest policy:")
-    print(f"  num_layers={num_layers}, num_kv_heads={num_kv_heads}, head_dim={head_dim}")
-    print(f"  num_cpu_blocks={num_cpu_blocks}, dtype={dtype}")
-    print(f"  topk_blocks={topk_blocks}, threshold_blocks={threshold_blocks}")
-
-    # Create BlockMetadataManager for storing min/max keys
-    metadata = BlockMetadataManager(
-        num_blocks=num_cpu_blocks,
-        num_layers=num_layers,
-        num_kv_heads=num_kv_heads,
-        head_dim=head_dim,
-        dtype=dtype,
-    )
-
-    # Create Quest policy for decode
-    quest_config = QuestConfig(
-        topk_blocks=topk_blocks,
-        threshold_blocks=threshold_blocks,
-    )
-    quest_policy = QuestPolicy(quest_config, metadata)
-
-    # Create Hybrid policy: Full for prefill, Quest for decode
-    hybrid_policy = HybridPolicy(
-        prefill_policy=FullAttentionPolicy(),
-        decode_policy=quest_policy,
-    )
-
-    # Set the policy
-    kvcache_manager.set_sparse_policy(hybrid_policy)
-    print(f"  Policy set: HybridPolicy(prefill=Full, decode=Quest)")
-
-    return hybrid_policy
-
-
 def main():
     import argparse
     parser = argparse.ArgumentParser()
@@ -101,7 +44,18 @@
     path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
 
     # Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
-    max_len = 131072  # 128K tokens
+    max_len = 32 * 1024  # 32K tokens
+
+    # Set up policy configuration
+    if not args.no_sparse:
+        prefill_policy = "full"   # Full attention for prefill
+        decode_policy = "quest"   # Quest Top-K for decode
+        print(f"\n[Quest Sparse Attention] prefill={prefill_policy}, decode={decode_policy}, topk={args.topk}")
+    else:
+        prefill_policy = "full"   # Full attention for both phases
+        decode_policy = "full"
+        print("\n[Full Attention] No sparse policy (baseline)")
+
     llm = LLM(
         path,
         enforce_eager=False,
@@ -109,15 +63,12 @@
         max_num_batched_tokens=max_len,
         enable_cpu_offload=True,
         num_gpu_blocks=6,  # Small GPU buffer for offload testing
+        prefill_policy=prefill_policy,
+        decode_policy=decode_policy,
+        sparse_topk_blocks=args.topk,
+        sparse_threshold_blocks=4,
     )
 
-    if not args.no_sparse:
-        # Setup Quest policy for decode (Top-K blocks, apply when > 4 blocks)
-        setup_quest_policy(llm, topk_blocks=args.topk, threshold_blocks=4)
-        print(f"\n[Quest Sparse Attention] topk={args.topk}")
-    else:
-        print("\n[Full Attention] No sparse policy (baseline)")
-
     # Warmup
     llm.generate(["Benchmark: "], SamplingParams())
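With this change, sparse attention is configured entirely through `LLM` constructor arguments instead of post-hoc policy injection. A minimal sketch of the new usage, mirroring the benchmark above (the model path and block counts are illustrative, not required values):

```python
import os
from nanovllm import LLM, SamplingParams

llm = LLM(
    os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),  # illustrative path
    enable_cpu_offload=True,
    num_gpu_blocks=6,
    prefill_policy="full",    # dense attention while the KV cache is built
    decode_policy="quest",    # query-aware Top-K block selection at decode
    sparse_topk_blocks=8,
    sparse_threshold_blocks=4,
)
print(llm.generate(["Benchmark: "], SamplingParams()))
```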
print(f"\n[Quest Sparse Attention] topk={args.topk}") - else: - print("\n[Full Attention] No sparse policy (baseline)") - # Warmup llm.generate(["Benchmark: "], SamplingParams()) diff --git a/nanovllm/config.py b/nanovllm/config.py index 97758e0..6582f16 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -29,8 +29,9 @@ class Config: num_gpu_kvcache_blocks: int = -1 num_cpu_kvcache_blocks: int = -1 - # Sparse attention configuration - sparse_policy: str | None = None # "vertical_slash", "quest", "streaming_llm", or None + # Sparse attention configuration (dual policy architecture) + prefill_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm" + decode_policy: str = "full" # "full", "quest", "vertical_slash", "streaming_llm" sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash sparse_topk_blocks: int = 8 # Top-K blocks for Quest diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 131ede9..7073ce5 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -156,6 +156,25 @@ class ModelRunner: dtype=hf_config.torch_dtype, ) + # Initialize sparse policies if manager has them (CPU offload mode) + if hasattr(self.kvcache_manager, 'prefill_policy') and hasattr(self.kvcache_manager, 'decode_policy'): + # Initialize both policies with model config + for policy in [self.kvcache_manager.prefill_policy, self.kvcache_manager.decode_policy]: + if policy is not None: + policy.initialize( + num_layers=hf_config.num_hidden_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + num_cpu_blocks=config.num_cpu_kvcache_blocks, + dtype=hf_config.torch_dtype, + device=torch.device("cuda"), + ) + + logger.info( + f"Sparse policies initialized: prefill={config.prefill_policy}, decode={config.decode_policy} " + f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})" + ) + # Log KV cache allocation info with detailed per-token breakdown gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2) cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2) diff --git a/nanovllm/kvcache/__init__.py b/nanovllm/kvcache/__init__.py index 02de400..c312b73 100644 --- a/nanovllm/kvcache/__init__.py +++ b/nanovllm/kvcache/__init__.py @@ -56,14 +56,36 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager: # Need CPU offload: use hybrid manager from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager from nanovllm.kvcache.policies import get_policy + from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType - policy = get_policy(getattr(config, 'offload_policy', 'lru')) + eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru')) + + # Create sparse policies from config + prefill_policy_type = getattr(config, 'prefill_policy', 'full') + decode_policy_type = getattr(config, 'decode_policy', 'full') + + def create_policy(policy_type_str): + """Create a sparse policy from config string.""" + if policy_type_str.lower() == 'full': + return create_sparse_policy(SparsePolicyType.FULL) + policy_type = SparsePolicyType[policy_type_str.upper()] + return create_sparse_policy( + policy_type, + topk_blocks=getattr(config, 'sparse_topk_blocks', 8), + threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4), + include_sink_blocks=getattr(config, 'sparse_num_sink_blocks', 1), + ) + + prefill_policy = create_policy(prefill_policy_type) + decode_policy 
diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py
index a1f7209..5c050df 100644
--- a/nanovllm/kvcache/hybrid_manager.py
+++ b/nanovllm/kvcache/hybrid_manager.py
@@ -90,6 +90,8 @@
         num_cpu_blocks: int,
         block_size: int,
         policy: Optional[EvictionPolicy] = None,
+        prefill_policy: Optional["SparsePolicy"] = None,
+        decode_policy: Optional["SparsePolicy"] = None,
     ):
         """
         Initialize hybrid manager with CPU-primary ring buffer design.
@@ -102,6 +104,8 @@
             num_cpu_blocks: Number of CPU pool blocks (primary storage)
             block_size: Tokens per block
             policy: Eviction policy (default: LRU, used for prefix cache management)
+            prefill_policy: Sparse attention policy for prefill phase
+            decode_policy: Sparse attention policy for decode phase
         """
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
@@ -113,6 +117,10 @@
         # Eviction policy
         self.policy = policy or LRUPolicy()
 
+        # Sparse attention policies (set at construction time, immutable)
+        self.prefill_policy = prefill_policy
+        self.decode_policy = decode_policy
+
         # Logical blocks (what sequences reference) - one per CPU block
         self.logical_blocks: List[LogicalBlock] = [
             LogicalBlock(i) for i in range(self.total_blocks)
@@ -153,9 +161,6 @@
         # Key: sequence id, Value: number of tokens from prefill (before decode started)
         self._prefill_len: Dict[int, int] = {}
 
-        # Sparse attention policy (optional)
-        self.sparse_policy: Optional["SparsePolicy"] = None
-
     @property
     def block_size(self) -> int:
         return self._block_size
@@ -180,6 +185,8 @@
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=dtype,
+            prefill_policy=self.prefill_policy,
+            decode_policy=self.decode_policy,
         )
 
     def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -187,23 +194,17 @@
         """Get GPU cache view for a layer (for ring buffer, returns full cache)."""
         assert self.offload_engine is not None
         return self.offload_engine.get_layer_cache(layer_id)
 
-    def set_sparse_policy(self, policy: "SparsePolicy") -> None:
+    def get_policy_for_phase(self, is_prefill: bool) -> Optional["SparsePolicy"]:
         """
-        Set sparse attention policy for block selection.
-
-        The sparse policy determines which KV blocks to load from CPU
-        for each query chunk during chunked attention computation.
+        Get sparse policy for the specified phase.
 
         Args:
-            policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
+            is_prefill: True for prefill phase, False for decode phase
 
-        Example:
-            from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
-            policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
-            manager.set_sparse_policy(policy)
+        Returns:
+            SparsePolicy for the phase, or None if not set
         """
-        self.sparse_policy = policy
-        logger.info(f"Sparse attention policy set: {policy}")
+        return self.prefill_policy if is_prefill else self.decode_policy
 
     def can_allocate(self, seq: Sequence) -> bool:
         """Check if we can allocate blocks for a new sequence."""
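The manager now receives both policies at construction and only exposes them through `get_policy_for_phase`; there is no setter anymore. A small wiring sketch (block counts and block size are illustrative, and the `QUEST` enum member is assumed as above):

```python
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType

manager = HybridKVCacheManager(
    num_gpu_slots=6,
    num_cpu_blocks=1024,
    block_size=256,
    prefill_policy=create_sparse_policy(SparsePolicyType.FULL),
    decode_policy=create_sparse_policy(SparsePolicyType["QUEST"], topk_blocks=8),
)
assert manager.get_policy_for_phase(is_prefill=True) is manager.prefill_policy
assert manager.get_policy_for_phase(is_prefill=False) is manager.decode_policy
```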
diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py
index 6ca2ab7..f3431de 100644
--- a/nanovllm/kvcache/offload_engine.py
+++ b/nanovllm/kvcache/offload_engine.py
@@ -17,6 +17,11 @@
 from nanovllm.kvcache.kernels import gathered_copy_kv
 from nanovllm.comm import memcpy_2d_async
 from nanovllm.utils.logger import get_logger
 
+# Import for type hints only (avoid circular import)
+from typing import TYPE_CHECKING, Optional
+if TYPE_CHECKING:
+    from nanovllm.kvcache.sparse import SparsePolicy
+
 logger = get_logger("offload_engine")
 
@@ -55,6 +60,8 @@
         head_dim: int,
         dtype: torch.dtype = torch.float16,
         num_streams: int = 4,
+        prefill_policy: Optional["SparsePolicy"] = None,
+        decode_policy: Optional["SparsePolicy"] = None,
     ):
         self.num_layers = num_layers
         self.num_gpu_blocks = num_gpu_blocks
@@ -210,6 +217,10 @@
         self._debug_mode = False
         self._debug_hooks: List = []  # External hooks for debug events
 
+        # ========== Sparse attention policies (set at construction time) ==========
+        self.prefill_policy = prefill_policy
+        self.decode_policy = decode_policy
+
     def _get_next_stream(self) -> torch.cuda.Stream:
         """Round-robin stream selection for parallel transfers."""
         stream = self.transfer_streams[self._stream_idx]
@@ -730,7 +741,14 @@
         """Wait for slot offload to complete."""
         self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
 
-    def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
+    def offload_slot_layer_to_cpu(
+        self,
+        slot_idx: int,
+        layer_id: int,
+        cpu_block_id: int,
+        num_valid_tokens: int = -1,
+        is_prefill: bool = True,
+    ) -> None:
         """
         Async offload a ring buffer slot to CPU for one layer.
 
@@ -741,9 +759,27 @@
             slot_idx: Source GPU slot index
            layer_id: Target layer in CPU cache
             cpu_block_id: Target CPU block ID
+            num_valid_tokens: Number of valid tokens in this block (-1 = use block_size)
+            is_prefill: True if in prefill phase, False if in decode phase
         """
         logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
 
+        # Collect metadata BEFORE offload (while k_cache is still on GPU).
+        # Both policies' callbacks are called - each decides whether to respond.
+        valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
+        k_cache = self.k_cache_gpu[slot_idx]
+
+        if is_prefill:
+            if self.prefill_policy is not None:
+                self.prefill_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+            if self.decode_policy is not None:
+                self.decode_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+        else:
+            if self.prefill_policy is not None:
+                self.prefill_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+            if self.decode_policy is not None:
+                self.decode_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+
         torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
         with torch.cuda.stream(self.transfer_stream_main):
             # Wait for both compute_stream and default stream
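Note the fan-out: every offload invokes the phase-appropriate hook on *both* policies, so a decode-only policy like Quest still sees blocks written during prefill. A CPU-only sketch of that contract using stub policies (no GPU or engine required):

```python
import torch

class RecordingPolicy:
    """Stub standing in for a SparsePolicy; records hook invocations."""
    def __init__(self):
        self.events = []
    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        self.events.append(("prefill", cpu_block_id, layer_id))
    def on_decode_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        self.events.append(("decode", cpu_block_id, layer_id))

prefill_p, decode_p = RecordingPolicy(), RecordingPolicy()
k = torch.zeros(256, 8, 128)  # [block_size, num_kv_heads, head_dim], sizes illustrative

# What offload_slot_layer_to_cpu does before the D2H copy when is_prefill=True:
for p in (prefill_p, decode_p):
    p.on_prefill_offload(cpu_block_id=3, layer_id=0, k_cache=k, num_valid_tokens=256)

assert prefill_p.events == decode_p.events == [("prefill", 3, 0)]
```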
diff --git a/nanovllm/kvcache/sparse/__init__.py b/nanovllm/kvcache/sparse/__init__.py
index cfc08f0..ae9473c 100644
--- a/nanovllm/kvcache/sparse/__init__.py
+++ b/nanovllm/kvcache/sparse/__init__.py
@@ -22,7 +22,6 @@ Usage:
 from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
 from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
 from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
-from nanovllm.kvcache.sparse.hybrid import HybridPolicy
 
 
 def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
@@ -67,6 +66,5 @@ __all__ = [
     "QuestPolicy",
     "QuestConfig",
     "BlockMetadataManager",
-    "HybridPolicy",
     "create_sparse_policy",
 ]
diff --git a/nanovllm/kvcache/sparse/hybrid.py b/nanovllm/kvcache/sparse/hybrid.py
deleted file mode 100644
index 1c2492a..0000000
--- a/nanovllm/kvcache/sparse/hybrid.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""
-Hybrid sparse attention policy.
-
-Allows using different policies for prefill vs decode phases.
-This is useful because optimal sparsity patterns often differ:
-- Prefill: fixed patterns work well (e.g., VerticalSlash)
-- Decode: query-aware selection helps (e.g., Quest)
-"""
-
-from typing import List
-import torch
-from .policy import SparsePolicy, PolicyContext
-
-
-class HybridPolicy(SparsePolicy):
-    """
-    Hybrid policy that uses different policies for prefill and decode.
-
-    Example usage:
-        ```python
-        from nanovllm.kvcache.sparse import (
-            HybridPolicy, VerticalSlashPolicy, QuestPolicy,
-            VerticalSlashConfig, QuestConfig, BlockMetadataManager
-        )
-
-        # Prefill: use fast fixed pattern
-        prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
-            num_sink_blocks=1,
-            local_window_blocks=3,
-        ))
-
-        # Decode: use query-aware selection
-        metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
-        decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)
-
-        # Combine
-        policy = HybridPolicy(prefill_policy, decode_policy)
-        ```
-    """
-
-    def __init__(
-        self,
-        prefill_policy: SparsePolicy,
-        decode_policy: SparsePolicy,
-    ):
-        """
-        Initialize hybrid policy.
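For anyone migrating: `HybridPolicy` has no direct replacement class; the same prefill/decode split is now declared up front, and the engine wires and initializes both sides. A before/after sketch (`model_path` is a placeholder):

```python
from nanovllm import LLM

# Before (removed in this diff):
#   policy = HybridPolicy(FullAttentionPolicy(), quest_policy)
#   kvcache_manager.set_sparse_policy(policy)

# After:
llm = LLM(
    model_path,
    enable_cpu_offload=True,
    prefill_policy="full",
    decode_policy="quest",
    sparse_topk_blocks=8,
)
```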
-
-        Args:
-            prefill_policy: Policy to use during prefill phase
-            decode_policy: Policy to use during decode phase
-        """
-        self.prefill_policy = prefill_policy
-        self.decode_policy = decode_policy
-
-    def select_blocks(
-        self,
-        available_blocks: List[int],
-        ctx: PolicyContext,
-    ) -> List[int]:
-        """Delegate to appropriate policy based on phase."""
-        if ctx.is_prefill:
-            return self.prefill_policy.select_blocks(available_blocks, ctx)
-        else:
-            return self.decode_policy.select_blocks(available_blocks, ctx)
-
-    def on_block_offloaded(
-        self,
-        cpu_block_id: int,
-        layer_id: int,
-        k_cache: torch.Tensor,
-        num_valid_tokens: int,
-    ) -> None:
-        """Forward to both policies (both may need metadata updates)."""
-        self.prefill_policy.on_block_offloaded(
-            cpu_block_id, layer_id, k_cache, num_valid_tokens
-        )
-        self.decode_policy.on_block_offloaded(
-            cpu_block_id, layer_id, k_cache, num_valid_tokens
-        )
-
-    def reset(self) -> None:
-        """Reset both policies."""
-        self.prefill_policy.reset()
-        self.decode_policy.reset()
-
-    def __repr__(self) -> str:
-        return (
-            f"HybridPolicy(\n"
-            f"  prefill={self.prefill_policy},\n"
-            f"  decode={self.decode_policy}\n"
-            f")"
-        )
diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py
index c935d06..ee1f64b 100644
--- a/nanovllm/kvcache/sparse/policy.py
+++ b/nanovllm/kvcache/sparse/policy.py
@@ -134,7 +134,7 @@
         """
         pass
 
-    def on_block_offloaded(
+    def on_prefill_offload(
         self,
         cpu_block_id: int,
         layer_id: int,
@@ -142,15 +142,38 @@
         num_valid_tokens: int,
     ) -> None:
         """
-        Hook called when a block is offloaded from GPU to CPU.
+        Hook called when a block is offloaded during prefill phase.
 
+        Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
         Override this to collect metadata about blocks (e.g., min/max keys
         for Quest-style selection). Default implementation does nothing.
 
         Args:
-            cpu_block_id: The CPU block ID that was written
+            cpu_block_id: The CPU block ID that will be written
             layer_id: Transformer layer index
-            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
+            num_valid_tokens: Number of valid tokens in this block
+        """
+        pass
+
+    def on_decode_offload(
+        self,
+        cpu_block_id: int,
+        layer_id: int,
+        k_cache: torch.Tensor,
+        num_valid_tokens: int,
+    ) -> None:
+        """
+        Hook called when a block is offloaded during decode phase.
+
+        Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
+        Override this to update metadata about blocks. Default implementation
+        does nothing.
+
+        Args:
+            cpu_block_id: The CPU block ID that will be written
+            layer_id: Transformer layer index
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
             num_valid_tokens: Number of valid tokens in this block
         """
         pass
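To make the split hooks concrete, here is a minimal custom-policy sketch that records per-block key min/max at offload time, in the spirit of Quest. The `initialize` signature follows the `ModelRunner` call site above; the selection logic is elided, and this is an illustration rather than the repo's implementation:

```python
import torch
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext

class KeyRangePolicy(SparsePolicy):
    def initialize(self, num_layers, num_kv_heads, head_dim,
                   num_cpu_blocks, dtype, device):
        # Called once by ModelRunner after the model config is known.
        shape = (num_layers, num_cpu_blocks, num_kv_heads, head_dim)
        self.k_min = torch.full(shape, float("inf"), dtype=dtype, device=device)
        self.k_max = torch.full(shape, float("-inf"), dtype=dtype, device=device)

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        # k_cache is still on GPU here: [block_size, num_kv_heads, head_dim].
        valid = k_cache[:num_valid_tokens]
        self.k_min[layer_id, cpu_block_id] = valid.amin(dim=0)
        self.k_max[layer_id, cpu_block_id] = valid.amax(dim=0)

    def on_decode_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        # Same bookkeeping for blocks completed during decode.
        self.on_prefill_offload(cpu_block_id, layer_id, k_cache, num_valid_tokens)

    def select_blocks(self, available_blocks, ctx: PolicyContext):
        return available_blocks  # scoring/Top-K elided in this sketch
```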
diff --git a/nanovllm/kvcache/sparse/quest.py b/nanovllm/kvcache/sparse/quest.py
index d038832..42b96fc 100644
--- a/nanovllm/kvcache/sparse/quest.py
+++ b/nanovllm/kvcache/sparse/quest.py
@@ -289,14 +289,25 @@
 
         return result
 
-    def on_block_offloaded(
+    def on_prefill_offload(
         self,
         cpu_block_id: int,
         layer_id: int,
         k_cache: torch.Tensor,
         num_valid_tokens: int,
     ) -> None:
-        """Update min/max key metadata when block is offloaded."""
+        """Update min/max key metadata during prefill offload."""
+        if self.metadata is not None:
+            self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
+
+    def on_decode_offload(
+        self,
+        cpu_block_id: int,
+        layer_id: int,
+        k_cache: torch.Tensor,
+        num_valid_tokens: int,
+    ) -> None:
+        """Update min/max key metadata during decode offload (for new blocks)."""
         if self.metadata is not None:
             self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index de966b7..0a36141 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -189,7 +189,8 @@
             cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
 
             # Apply sparse policy if enabled
-            if cpu_block_table and kvcache_manager.sparse_policy is not None:
+            prefill_policy = kvcache_manager.get_policy_for_phase(is_prefill=True)
+            if cpu_block_table and prefill_policy is not None:
                 num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
                 policy_ctx = PolicyContext(
                     query_chunk_idx=current_chunk_idx,
@@ -200,7 +201,7 @@
                     block_size=kvcache_manager.block_size,
                     total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
                 )
-                cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
+                cpu_block_table = prefill_policy.select_blocks(
                     cpu_block_table, policy_ctx
                 )
 
@@ -279,7 +280,11 @@
                 cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
                 if current_chunk_idx < len(cpu_block_ids):
                     cpu_block_id = cpu_block_ids[current_chunk_idx]
-                    offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
+                    # k.shape[0] = number of tokens in current chunk
+                    num_valid_tokens = k.shape[0]
+                    offload_engine.offload_slot_layer_to_cpu(
+                        write_slot, self.layer_id, cpu_block_id, num_valid_tokens
+                    )
 
             # CRITICAL: compute_stream must wait for offload to complete
             # before the next layer's store_kvcache can overwrite the GPU slot.
@@ -508,7 +513,8 @@
             last_block_valid_tokens = block_size  # Last block was exactly full
 
         # Apply sparse policy if enabled
-        if kvcache_manager.sparse_policy is not None:
+        decode_policy = kvcache_manager.get_policy_for_phase(is_prefill=False)
+        if decode_policy is not None:
             policy_ctx = PolicyContext(
                 query_chunk_idx=0,
                 num_query_chunks=1,
@@ -518,7 +524,7 @@
                 block_size=kvcache_manager.block_size,
                 total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
             )
-            cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
+            cpu_block_table = decode_policy.select_blocks(
                 cpu_block_table, policy_ctx
             )
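The decode-side payoff comes from `QuestPolicy.select_blocks`: following the Quest approach, each block's attention score is upper-bounded from the stored min/max keys and only the Top-K blocks are loaded from CPU. A self-contained sketch of that scoring rule (shapes and names illustrative, not the repo's exact kernel):

```python
import torch

def quest_block_scores(q, k_min, k_max):
    # q: [num_kv_heads, head_dim]; k_min/k_max: [num_blocks, num_kv_heads, head_dim]
    # Per dimension, q*k over a block is maximized by k_max when q >= 0, else by k_min.
    upper = torch.where(q >= 0, q * k_max, q * k_min)  # [num_blocks, heads, dim]
    return upper.sum(dim=-1).amax(dim=-1)              # [num_blocks]

q = torch.randn(8, 128)
k_min = torch.randn(32, 8, 128)
k_max = k_min + torch.rand(32, 8, 128)                 # ensure k_max >= k_min
keep = torch.topk(quest_block_scores(q, k_min, k_max), k=8).indices
print(sorted(keep.tolist()))                           # CPU block IDs worth loading
```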