diff --git a/bench.py b/bench.py
index 514721b..ac2c42a 100644
--- a/bench.py
+++ b/bench.py
@@ -43,17 +43,17 @@ def main():
     print("=" * 60)
     print("Prefill Benchmark")
     print("=" * 60)
-    bench_prefill(llm, num_seqs=1, input_len=1024)
+    # bench_prefill(llm, num_seqs=1, input_len=1024)
     # bench_prefill(llm, num_seqs=1, input_len=2048)
-    # bench_prefill(llm, num_seqs=1, input_len=4095)
+    bench_prefill(llm, num_seqs=1, input_len=4095)
     # bench_prefill(llm, num_seqs=16, input_len=1024)
     # bench_prefill(llm, num_seqs=64, input_len=1024)
 
     print("=" * 60)
     print("Decode Benchmark")
     print("=" * 60)
-    bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
-    # bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
+    # bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
+    bench_decode(llm, num_seqs=1, max_input_len=4072, max_output_len=16)
 
 
 if __name__ == "__main__":
diff --git a/bench_offload.py b/bench_offload.py
index ed1b953..8055141 100644
--- a/bench_offload.py
+++ b/bench_offload.py
@@ -3,12 +3,17 @@ import time
 from random import randint, seed
 from nanovllm import LLM, SamplingParams
 
+# Import sparse policy classes
+from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
+from nanovllm.kvcache.sparse.hybrid import HybridPolicy
+from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
+
 
 def bench_decode(llm, num_seqs, input_len, max_output_len):
     """Benchmark decode performance (original test)"""
     seed(0)
-    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, input_len))] for _ in range(num_seqs)]
-    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
+    prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
+    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_output_len) for _ in range(num_seqs)]
 
     t = time.time()
     llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
@@ -33,7 +38,67 @@ def bench_prefill(llm, num_seqs, input_len):
     print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
 
 
+def setup_quest_policy(llm, topk_blocks=8, threshold_blocks=4):
+    """
+    Set up the Quest sparse policy for the decode phase.
+
+    Uses HybridPolicy: full attention for prefill, Quest Top-K for decode.
+    """
+    kvcache_manager = llm.model_runner.kvcache_manager
+    offload_engine = kvcache_manager.offload_engine
+
+    # Get model parameters from the offload engine
+    num_layers = offload_engine.num_layers
+    num_kv_heads = offload_engine.num_kv_heads
+    head_dim = offload_engine.head_dim
+    num_cpu_blocks = kvcache_manager.num_cpu_blocks
+    dtype = offload_engine.k_cache_cpu.dtype
+
+    print("Setting up Quest policy:")
+    print(f"  num_layers={num_layers}, num_kv_heads={num_kv_heads}, head_dim={head_dim}")
+    print(f"  num_cpu_blocks={num_cpu_blocks}, dtype={dtype}")
+    print(f"  topk_blocks={topk_blocks}, threshold_blocks={threshold_blocks}")
+
+    # Create BlockMetadataManager for storing min/max keys
+    metadata = BlockMetadataManager(
+        num_blocks=num_cpu_blocks,
+        num_layers=num_layers,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        dtype=dtype,
+    )
+
+    # Create Quest policy for decode
+    quest_config = QuestConfig(
+        topk_blocks=topk_blocks,
+        threshold_blocks=threshold_blocks,
+    )
+    quest_policy = QuestPolicy(quest_config, metadata)
+
+    # Create Hybrid policy: Full for prefill, Quest for decode
+    hybrid_policy = HybridPolicy(
+        prefill_policy=FullAttentionPolicy(),
+        decode_policy=quest_policy,
+    )
+
+    # Set the policy
+    kvcache_manager.set_sparse_policy(hybrid_policy)
+    print("  Policy set: HybridPolicy(prefill=Full, decode=Quest)")
+
+    return hybrid_policy
+
+
 def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--no-sparse", action="store_true", help="Disable sparse attention (baseline)")
+    parser.add_argument("--topk", type=int, default=8, help="Top-K blocks for Quest")
+    parser.add_argument("--input-len", type=int, default=128 * 1024, help="Input length in tokens")
+    parser.add_argument("--output-len", type=int, default=128, help="Output length in tokens")
+    args = parser.parse_args()
+
     path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
     llm = LLM(
         path,
@@ -45,22 +110,25 @@ def main():
         num_prefetch_blocks=4,
     )
 
+    if not args.no_sparse:
+        # Setup Quest policy for decode (Top-K blocks, apply when > 4 blocks)
+        setup_quest_policy(llm, topk_blocks=args.topk, threshold_blocks=4)
+        print(f"\n[Quest Sparse Attention] topk={args.topk}")
+    else:
+        print("\n[Full Attention] No sparse policy (baseline)")
+
     # Warmup
     llm.generate(["Benchmark: "], SamplingParams())
 
     print("=" * 60)
     print("Prefill Benchmark (CPU Offload)")
     print("=" * 60)
-    # bench_prefill(llm, num_seqs=1, input_len=1024)
-    # bench_prefill(llm, num_seqs=1, input_len=2048)
-    # bench_prefill(llm, num_seqs=1, input_len=4096)
-    bench_prefill(llm, num_seqs=1, input_len=128 * 1024)
+    bench_prefill(llm, num_seqs=1, input_len=args.input_len)
 
     print("=" * 60)
     print("Decode Benchmark (CPU Offload)")
     print("=" * 60)
-    bench_decode(llm, num_seqs=1, input_len=128 * 1024, max_output_len=128)
-    # bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)
+    bench_decode(llm, num_seqs=1, input_len=args.input_len, max_output_len=args.output_len)
 
 
 if __name__ == "__main__":
+ """ + import torch + + kvcache_manager = llm.model_runner.kvcache_manager + offload_engine = kvcache_manager.offload_engine + + # Get model parameters from offload engine + num_layers = offload_engine.num_layers + num_kv_heads = offload_engine.num_kv_heads + head_dim = offload_engine.head_dim + num_cpu_blocks = kvcache_manager.num_cpu_blocks + dtype = offload_engine.k_cache_cpu.dtype + + print(f"Setting up Quest policy:") + print(f" num_layers={num_layers}, num_kv_heads={num_kv_heads}, head_dim={head_dim}") + print(f" num_cpu_blocks={num_cpu_blocks}, dtype={dtype}") + print(f" topk_blocks={topk_blocks}, threshold_blocks={threshold_blocks}") + + # Create BlockMetadataManager for storing min/max keys + metadata = BlockMetadataManager( + num_blocks=num_cpu_blocks, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + dtype=dtype, + ) + + # Create Quest policy for decode + quest_config = QuestConfig( + topk_blocks=topk_blocks, + threshold_blocks=threshold_blocks, + ) + quest_policy = QuestPolicy(quest_config, metadata) + + # Create Hybrid policy: Full for prefill, Quest for decode + hybrid_policy = HybridPolicy( + prefill_policy=FullAttentionPolicy(), + decode_policy=quest_policy, + ) + + # Set the policy + kvcache_manager.set_sparse_policy(hybrid_policy) + print(f" Policy set: HybridPolicy(prefill=Full, decode=Quest)") + + return hybrid_policy + + def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--no-sparse", action="store_true", help="Disable sparse attention (baseline)") + parser.add_argument("--topk", type=int, default=8, help="Top-K blocks for Quest") + parser.add_argument("--input-len", type=int, default=128 * 1024, help="Input length in tokens") + parser.add_argument("--output-len", type=int, default=128, help="Output length in tokens") + args = parser.parse_args() + path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") llm = LLM( path, @@ -45,22 +110,25 @@ def main(): num_prefetch_blocks=4, ) + if not args.no_sparse: + # Setup Quest policy for decode (Top-K blocks, apply when > 4 blocks) + setup_quest_policy(llm, topk_blocks=args.topk, threshold_blocks=4) + print(f"\n[Quest Sparse Attention] topk={args.topk}") + else: + print("\n[Full Attention] No sparse policy (baseline)") + # Warmup llm.generate(["Benchmark: "], SamplingParams()) print("=" * 60) print("Prefill Benchmark (CPU Offload)") print("=" * 60) - # bench_prefill(llm, num_seqs=1, input_len=1024) - # bench_prefill(llm, num_seqs=1, input_len=2048) - # bench_prefill(llm, num_seqs=1, input_len=4096) - bench_prefill(llm, num_seqs=1, input_len=128 * 1024) + bench_prefill(llm, num_seqs=1, input_len=args.input_len) print("=" * 60) print("Decode Benchmark (CPU Offload)") print("=" * 60) - bench_decode(llm, num_seqs=1, input_len=128 * 1024, max_output_len=128) - # bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128) + bench_decode(llm, num_seqs=1, input_len=args.input_len, max_output_len=args.output_len) if __name__ == "__main__": diff --git a/nanovllm/config.py b/nanovllm/config.py index 5b9f98d..da2aee5 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -28,6 +28,13 @@ class Config: num_gpu_kvcache_blocks: int = -1 num_cpu_kvcache_blocks: int = -1 + # Sparse attention configuration + sparse_policy: str | None = None # "vertical_slash", "quest", "streaming_llm", or None + sparse_num_sink_blocks: int = 1 # Number of sink blocks for sparse patterns + sparse_local_window_blocks: int = 2 # Local window size for VerticalSlash + 
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 38cd684..e39e6c5 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -700,6 +700,20 @@ class ModelRunner:
                 # Offload this chunk's ring buffer slot to CPU (async)
                 if block_idx < len(cpu_block_ids):
                     cpu_block_id = cpu_block_ids[block_idx]
+
+                    # Call the sparse policy hook before offload (to capture metadata)
+                    sparse_policy = self.kvcache_manager.sparse_policy
+                    if sparse_policy is not None:
+                        num_tokens = chunk_end - chunk_start
+                        for layer_id in range(offload_engine.num_layers):
+                            k_cache = offload_engine.k_cache_gpu[layer_id, write_slot, :num_tokens]
+                            sparse_policy.on_block_offloaded(
+                                cpu_block_id=cpu_block_id,
+                                layer_id=layer_id,
+                                k_cache=k_cache,
+                                num_valid_tokens=num_tokens,
+                            )
+
                     offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
 
                     # Wait for offload to complete before next chunk
diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py
index c915acd..f4dde34 100644
--- a/nanovllm/kvcache/hybrid_manager.py
+++ b/nanovllm/kvcache/hybrid_manager.py
@@ -25,6 +25,11 @@ from nanovllm.kvcache.offload_engine import OffloadEngine
 from nanovllm.kvcache.policies.base_policy import EvictionPolicy
 from nanovllm.kvcache.policies.lru_policy import LRUPolicy
 
+# Type-checking-only import for the sparse policy (avoids a runtime dependency)
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from nanovllm.kvcache.sparse.policy import SparsePolicy
+
 
 class BlockLocation(Enum):
     """Where a logical block's data currently resides."""
@@ -142,6 +147,9 @@ class HybridKVCacheManager(KVCacheManager):
         # Key: sequence id, Value: starting position where decode began in current block
         self._decode_start_pos: Dict[int, int] = {}
 
+        # Sparse attention policy (optional)
+        self.sparse_policy: Optional["SparsePolicy"] = None
+
     @property
     def block_size(self) -> int:
         return self._block_size
@@ -174,6 +182,24 @@ class HybridKVCacheManager(KVCacheManager):
         assert self.offload_engine is not None
         return self.offload_engine.get_layer_cache(layer_id)
 
+    def set_sparse_policy(self, policy: "SparsePolicy") -> None:
+        """
+        Set the sparse attention policy for block selection.
+
+        The sparse policy determines which KV blocks to load from CPU
+        for each query chunk during chunked attention computation.
+
+        Args:
+            policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
+
+        Example:
+            from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
+            policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
+            manager.set_sparse_policy(policy)
+        """
+        self.sparse_policy = policy
+        logger.info(f"Sparse attention policy set: {policy}")
+
     def _allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
         """
         Get a free GPU slot, evicting if necessary.
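The `on_block_offloaded` hook above gives every policy a per-layer view of each block's keys as the block leaves the GPU. For illustration, a toy policy (hypothetical, not part of this diff) shows the minimal contract a metadata-collecting policy must satisfy:

```python
# Toy metadata-collecting policy (illustrative sketch only): ranks blocks by
# the mean L2 norm of their keys, recorded via the on_block_offloaded() hook.
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext


class KeyNormPolicy(SparsePolicy):
    def __init__(self, topk_blocks: int = 8):
        self.topk_blocks = topk_blocks
        self.block_norms: dict[int, float] = {}

    def on_block_offloaded(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        # One layer is enough for a toy score; k_cache is [tokens, kv_heads, head_dim]
        if layer_id == 0 and num_valid_tokens > 0:
            self.block_norms[cpu_block_id] = (
                k_cache[:num_valid_tokens].float().norm(dim=-1).mean().item()
            )

    def select_blocks(self, available_blocks, ctx: PolicyContext):
        if len(available_blocks) <= self.topk_blocks:
            return available_blocks
        keep = set(sorted(available_blocks,
                          key=lambda b: self.block_norms.get(b, 0.0),
                          reverse=True)[:self.topk_blocks])
        return [b for b in available_blocks if b in keep]  # keep sequential order
```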
diff --git a/nanovllm/kvcache/sparse/__init__.py b/nanovllm/kvcache/sparse/__init__.py
new file mode 100644
index 0000000..d397b0f
--- /dev/null
+++ b/nanovllm/kvcache/sparse/__init__.py
@@ -0,0 +1,90 @@
+"""
+Sparse Attention Policy module.
+
+Provides pluggable policies for selecting which KV blocks to load
+during chunked attention with CPU offload.
+
+Usage:
+    from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
+    from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig, QuestPolicy
+
+    # Use a built-in policy
+    policy = VerticalSlashPolicy(VerticalSlashConfig())
+
+    # Or create a custom policy
+    class MyPolicy(SparsePolicy):
+        def select_blocks(self, available_blocks, ctx):
+            return available_blocks[:5]  # Just the first 5 blocks
+"""
+
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
+from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
+from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
+from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
+from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
+from nanovllm.kvcache.sparse.hybrid import HybridPolicy
+
+# Built-in policy registry ("quest" is absent: it cannot be built by name alone)
+BUILTIN_SPARSE_POLICIES = {
+    "full": FullAttentionPolicy,
+    "vertical_slash": VerticalSlashPolicy,
+    "streaming_llm": StreamingLLMPolicy,
+}
+
+
+def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
+    """
+    Get a sparse attention policy instance by name.
+
+    Args:
+        policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest")
+        **kwargs: Policy-specific configuration
+
+    Returns:
+        SparsePolicy instance
+    """
+    policy_name = policy_name.lower()
+
+    if policy_name == "full":
+        return FullAttentionPolicy()
+    elif policy_name == "vertical_slash":
+        config = VerticalSlashConfig(
+            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
+            local_window_blocks=kwargs.get("local_window_blocks", 2),
+            threshold_blocks=kwargs.get("threshold_blocks", 4),
+        )
+        return VerticalSlashPolicy(config)
+    elif policy_name == "streaming_llm":
+        config = StreamingLLMConfig(
+            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
+            num_recent_blocks=kwargs.get("num_recent_blocks", 3),
+        )
+        return StreamingLLMPolicy(config)
+    elif policy_name == "quest":
+        # Quest requires a metadata manager to be passed separately
+        raise ValueError(
+            "Quest policy requires a BlockMetadataManager. "
+            "Use QuestPolicy(config, metadata_manager) directly."
+        )
+    else:
+        raise ValueError(
+            f"Unknown sparse policy '{policy_name}'. "
+            f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())} "
+            f"(plus 'quest' via QuestPolicy(config, metadata_manager))"
+        )
+
+
+__all__ = [
+    "SparsePolicy",
+    "PolicyContext",
+    "FullAttentionPolicy",
+    "VerticalSlashPolicy",
+    "VerticalSlashConfig",
+    "QuestPolicy",
+    "QuestConfig",
+    "BlockMetadataManager",
+    "StreamingLLMPolicy",
+    "StreamingLLMConfig",
+    "HybridPolicy",
+    "get_sparse_policy",
+    "BUILTIN_SPARSE_POLICIES",
+]
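For reference, the registry can be exercised directly; the expected repr below follows from the `__repr__` implementations later in this diff:

```python
from nanovllm.kvcache.sparse import get_sparse_policy

policy = get_sparse_policy("streaming_llm", num_sink_blocks=1, num_recent_blocks=3)
print(policy)  # StreamingLLMPolicy(sink=1, recent=3)

get_sparse_policy("quest")  # raises ValueError: needs a BlockMetadataManager
```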
diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py
new file mode 100644
index 0000000..6e57d5c
--- /dev/null
+++ b/nanovllm/kvcache/sparse/full_policy.py
@@ -0,0 +1,34 @@
+"""
+Full attention policy - loads all blocks (no sparsity).
+
+This serves as a baseline and default policy when sparse
+attention is not needed.
+"""
+
+from typing import List
+from .policy import SparsePolicy, PolicyContext
+
+
+class FullAttentionPolicy(SparsePolicy):
+    """
+    Full attention policy that loads all available blocks.
+
+    This is the default behavior with no sparsity - all previous
+    KV cache blocks are loaded for each query chunk.
+
+    Use this as:
+    - A baseline for comparing sparse policies
+    - When you need full attention accuracy
+    - For short sequences where sparsity isn't beneficial
+    """
+
+    def select_blocks(
+        self,
+        available_blocks: List[int],
+        ctx: PolicyContext,
+    ) -> List[int]:
+        """Return all blocks - no sparsity."""
+        return available_blocks
+
+    def __repr__(self) -> str:
+        return "FullAttentionPolicy()"
diff --git a/nanovllm/kvcache/sparse/hybrid.py b/nanovllm/kvcache/sparse/hybrid.py
new file mode 100644
index 0000000..1c2492a
--- /dev/null
+++ b/nanovllm/kvcache/sparse/hybrid.py
@@ -0,0 +1,93 @@
+"""
+Hybrid sparse attention policy.
+
+Allows using different policies for the prefill and decode phases.
+This is useful because optimal sparsity patterns often differ:
+- Prefill: fixed patterns work well (e.g., VerticalSlash)
+- Decode: query-aware selection helps (e.g., Quest)
+"""
+
+from typing import List
+import torch
+from .policy import SparsePolicy, PolicyContext
+
+
+class HybridPolicy(SparsePolicy):
+    """
+    Hybrid policy that uses different policies for prefill and decode.
+
+    Example usage:
+    ```python
+    from nanovllm.kvcache.sparse import (
+        HybridPolicy, VerticalSlashPolicy, QuestPolicy,
+        VerticalSlashConfig, QuestConfig, BlockMetadataManager
+    )
+
+    # Prefill: use a fast fixed pattern
+    prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
+        num_sink_blocks=1,
+        local_window_blocks=3,
+    ))
+
+    # Decode: use query-aware selection
+    metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
+    decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)
+
+    # Combine
+    policy = HybridPolicy(prefill_policy, decode_policy)
+    ```
+    """
+
+    def __init__(
+        self,
+        prefill_policy: SparsePolicy,
+        decode_policy: SparsePolicy,
+    ):
+        """
+        Initialize the hybrid policy.
+
+        Args:
+            prefill_policy: Policy to use during the prefill phase
+            decode_policy: Policy to use during the decode phase
+        """
+        self.prefill_policy = prefill_policy
+        self.decode_policy = decode_policy
+
+    def select_blocks(
+        self,
+        available_blocks: List[int],
+        ctx: PolicyContext,
+    ) -> List[int]:
+        """Delegate to the appropriate policy based on phase."""
+        if ctx.is_prefill:
+            return self.prefill_policy.select_blocks(available_blocks, ctx)
+        else:
+            return self.decode_policy.select_blocks(available_blocks, ctx)
+
+    def on_block_offloaded(
+        self,
+        cpu_block_id: int,
+        layer_id: int,
+        k_cache: torch.Tensor,
+        num_valid_tokens: int,
+    ) -> None:
+        """Forward to both policies (both may need metadata updates)."""
+        self.prefill_policy.on_block_offloaded(
+            cpu_block_id, layer_id, k_cache, num_valid_tokens
+        )
+        self.decode_policy.on_block_offloaded(
+            cpu_block_id, layer_id, k_cache, num_valid_tokens
+        )
+
+    def reset(self) -> None:
+        """Reset both policies."""
+        self.prefill_policy.reset()
+        self.decode_policy.reset()
+
+    def __repr__(self) -> str:
+        return (
+            f"HybridPolicy(\n"
+            f"  prefill={self.prefill_policy},\n"
+            f"  decode={self.decode_policy}\n"
+            f")"
+        )
+""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional, Any +import torch + + +@dataclass +class PolicyContext: + """ + Context passed to sparse policy for block selection. + + This dataclass contains all information needed by a sparse policy + to decide which blocks to load for the current query chunk. + """ + + query_chunk_idx: int + """Index of the current query chunk (0-indexed).""" + + num_query_chunks: int + """Total number of query chunks in this prefill.""" + + layer_id: int + """Current transformer layer index.""" + + query: Optional[torch.Tensor] + """ + Query tensor for current chunk. + Shape: [1, num_heads, head_dim] for decode, [1, seq_len, num_heads, head_dim] for prefill. + May be None if not available (e.g., some prefill scenarios). + """ + + is_prefill: bool + """True if in prefill phase, False if in decode phase.""" + + block_size: int = 4096 + """Number of tokens per block.""" + + total_kv_len: int = 0 + """Total KV sequence length so far (for reference).""" + + +class SparsePolicy(ABC): + """ + Abstract base class for sparse attention policies. + + Subclass this and implement select_blocks() to create custom + sparse attention patterns. The policy receives context about + the current query chunk and returns which KV blocks to load. + + Example: + class MySparsePolicy(SparsePolicy): + def select_blocks(self, available_blocks, ctx): + # Load first block and last 2 blocks + if len(available_blocks) <= 3: + return available_blocks + return [available_blocks[0]] + available_blocks[-2:] + """ + + @abstractmethod + def select_blocks( + self, + available_blocks: List[int], + ctx: PolicyContext, + ) -> List[int]: + """ + Select which KV blocks to load for the current query chunk. + + This is the core method that defines the sparse attention pattern. + The returned blocks will be loaded from CPU to GPU for attention + computation against the current query chunk. + + Args: + available_blocks: List of CPU block IDs that contain KV cache + from previous chunks. These are ordered by + their position in the sequence. + ctx: PolicyContext with information about the current query + chunk, layer, phase (prefill/decode), etc. + + Returns: + List of block IDs to load (must be a subset of available_blocks). + The order may affect performance (sequential access is faster). + Returning [] means no previous blocks will be loaded. + """ + pass + + def on_block_offloaded( + self, + cpu_block_id: int, + layer_id: int, + k_cache: torch.Tensor, + num_valid_tokens: int, + ) -> None: + """ + Hook called when a block is offloaded from GPU to CPU. + + Override this to collect metadata about blocks (e.g., min/max keys + for Quest-style selection). Default implementation does nothing. + + Args: + cpu_block_id: The CPU block ID that was written + layer_id: Transformer layer index + k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] + num_valid_tokens: Number of valid tokens in this block + """ + pass + + def reset(self) -> None: + """ + Reset policy state. + + Called when starting a new sequence or clearing state. + Default implementation does nothing. + """ + pass + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/nanovllm/kvcache/sparse/quest.py b/nanovllm/kvcache/sparse/quest.py new file mode 100644 index 0000000..7439256 --- /dev/null +++ b/nanovllm/kvcache/sparse/quest.py @@ -0,0 +1,284 @@ +""" +Quest-style sparse attention policy. 
diff --git a/nanovllm/kvcache/sparse/quest.py b/nanovllm/kvcache/sparse/quest.py
new file mode 100644
index 0000000..7439256
--- /dev/null
+++ b/nanovllm/kvcache/sparse/quest.py
@@ -0,0 +1,284 @@
+"""
+Quest-style sparse attention policy.
+
+Uses min/max key bounds per block to estimate attention scores
+and select the Top-K blocks most relevant to the current query.
+
+Reference: Quest paper on query-aware KV cache selection.
+"""
+
+import logging
+import torch
+from dataclasses import dataclass
+from typing import List, Tuple
+from .policy import SparsePolicy, PolicyContext
+
+logger = logging.getLogger(__name__)
+
+
+class BlockMetadataManager:
+    """
+    Manages per-block metadata for Quest-style sparse selection.
+
+    Stores min/max key values for each block, which are used to
+    compute upper bounds on attention scores without loading the
+    full KV cache.
+
+    Memory usage: 2 * num_blocks * num_layers * num_kv_heads * head_dim * dtype_size
+    Example: 1000 blocks, 28 layers, 4 heads, 128 dim, bf16 = ~57 MB
+    """
+
+    def __init__(
+        self,
+        num_blocks: int,
+        num_layers: int,
+        num_kv_heads: int,
+        head_dim: int,
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        """
+        Initialize metadata storage.
+
+        Args:
+            num_blocks: Maximum number of CPU blocks
+            num_layers: Number of transformer layers
+            num_kv_heads: Number of KV attention heads
+            head_dim: Dimension per head
+            dtype: Data type for metadata storage
+        """
+        self.num_blocks = num_blocks
+        self.num_layers = num_layers
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = head_dim
+        self.dtype = dtype
+
+        # Per-block min/max key values: [num_blocks, num_layers, num_heads, head_dim]
+        shape = (num_blocks, num_layers, num_kv_heads, head_dim)
+        self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True)
+        self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True)
+
+        # Track which blocks have valid metadata
+        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)
+
+    def update_metadata(
+        self,
+        block_id: int,
+        layer_id: int,
+        k_cache: torch.Tensor,
+        num_valid_tokens: int,
+    ) -> None:
+        """
+        Update min/max key bounds for a block.
+
+        Called when a block is offloaded to CPU.
+
+        Args:
+            block_id: CPU block ID
+            layer_id: Layer index
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
+            num_valid_tokens: Number of valid tokens in this block
+        """
+        if num_valid_tokens == 0:
+            return
+
+        # Get valid keys only
+        k_valid = k_cache[:num_valid_tokens].cpu()  # [num_tokens, heads, dim]
+
+        # Compute min/max across the token dimension
+        self.key_min[block_id, layer_id] = k_valid.min(dim=0).values
+        self.key_max[block_id, layer_id] = k_valid.max(dim=0).values
+        self.valid_blocks[block_id] = True
+
+    def get_block_metadata(
+        self,
+        block_ids: List[int],
+        layer_id: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get min/max keys for the specified blocks.
+
+        Args:
+            block_ids: List of CPU block IDs
+            layer_id: Layer index
+
+        Returns:
+            Tuple of (key_min, key_max) tensors.
+            Shape: [num_blocks, num_kv_heads, head_dim]
+        """
+        key_min = self.key_min[block_ids, layer_id]
+        key_max = self.key_max[block_ids, layer_id]
+        return key_min, key_max
+
+    def reset(self) -> None:
+        """Reset all metadata."""
+        self.key_min.zero_()
+        self.key_max.zero_()
+        self.valid_blocks.zero_()
+
+
+@dataclass
+class QuestConfig:
+    """Configuration for QuestPolicy."""
+
+    topk_blocks: int = 8
+    """Number of top blocks to select based on estimated attention scores."""
+
+    threshold_blocks: int = 4
+    """If total blocks <= threshold, load all (no scoring needed)."""
+
+    include_sink_blocks: int = 0
+    """Always include this many sink blocks (first N blocks), in addition to Top-K."""
+
+    include_recent_blocks: int = 0
+    """Always include this many recent blocks (last N blocks), in addition to Top-K."""
+
+
+class QuestPolicy(SparsePolicy):
+    """
+    Quest-style Top-K block selection using min/max key bounds.
+
+    For each query, computes an upper bound on attention scores for
+    each block using the stored min/max keys, then selects the Top-K
+    blocks with the highest estimated scores.
+
+    Score computation (per KV head, summed over head_dim):
+        score(q, block) = sum_d max(q_d * key_min_d, q_d * key_max_d)
+
+    This is an upper bound on q . k for any key k in the block: since
+    key_min <= k <= key_max element-wise, each term q_d * k_d is at most
+    q_d * key_max_d when q_d >= 0 and at most q_d * key_min_d when q_d < 0.
+    """
+
+    def __init__(
+        self,
+        config: QuestConfig,
+        metadata_manager: BlockMetadataManager,
+    ):
+        """
+        Initialize the Quest policy.
+
+        Args:
+            config: QuestConfig with selection parameters
+            metadata_manager: BlockMetadataManager for min/max key storage
+        """
+        self.config = config
+        self.metadata = metadata_manager
+
+    def select_blocks(
+        self,
+        available_blocks: List[int],
+        ctx: PolicyContext,
+    ) -> List[int]:
+        """
+        Select Top-K blocks based on query-key similarity bounds.
+
+        If the query is not available (some prefill scenarios), falls
+        back to loading all blocks.
+        """
+        n = len(available_blocks)
+
+        # If below threshold, load all
+        if n <= self.config.threshold_blocks:
+            return available_blocks
+
+        if ctx.query is None:
+            # No query available - cannot compute scores
+            return available_blocks
+
+        # Get metadata for the available blocks
+        key_min, key_max = self.metadata.get_block_metadata(
+            available_blocks, ctx.layer_id
+        )
+
+        # Move to the query device for computation
+        device = ctx.query.device
+        key_min = key_min.to(device, non_blocking=True)
+        key_max = key_max.to(device, non_blocking=True)
+
+        # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
+        q = ctx.query
+        if q.dim() == 4:
+            # Prefill: use the mean over sequence length
+            q = q.mean(dim=1)  # [1, num_heads, head_dim]
+        q = q.squeeze(0)  # [num_q_heads, head_dim]
+
+        # Handle GQA: the query may have more heads than KV
+        # key_min/key_max: [num_blocks, num_kv_heads, head_dim]
+        num_q_heads = q.shape[0]
+        num_kv_heads = key_min.shape[1]
+
+        if num_q_heads != num_kv_heads:
+            # GQA: group query heads and average per KV group
+            # Reshape q: [num_q_heads, head_dim] -> [num_kv_heads, group_size, head_dim]
+            group_size = num_q_heads // num_kv_heads
+            q = q.view(num_kv_heads, group_size, -1).mean(dim=1)  # [num_kv_heads, head_dim]
+
+        # Upper-bound score: for any key k with key_min <= k <= key_max
+        # (element-wise), q . k <= sum_d max(q_d * key_min_d, q_d * key_max_d).
+        # q: [num_kv_heads, head_dim], key_min/key_max: [num_blocks, num_kv_heads, head_dim]
+        bound = torch.maximum(q.unsqueeze(0) * key_min, q.unsqueeze(0) * key_max)
+        scores = bound.sum(dim=-1).mean(dim=-1)  # average over KV heads -> [num_blocks]
+
+        # Build the selection set
+        selected_indices = set()
+
+        # Always include sink blocks
+        for i in range(min(self.config.include_sink_blocks, n)):
+            selected_indices.add(i)
+
+        # Always include recent blocks
+        for i in range(max(0, n - self.config.include_recent_blocks), n):
+            selected_indices.add(i)
+
+        # Top-K selection from the remaining blocks
+        remaining_k = max(0, self.config.topk_blocks - len(selected_indices))
+        if remaining_k > 0:
+            # Mask out already selected
+            mask = torch.ones(n, dtype=torch.bool, device=device)
+            for idx in selected_indices:
+                mask[idx] = False
+
+            if mask.any():
+                masked_scores = scores.clone()
+                masked_scores[~mask] = float('-inf')
+                topk_count = min(remaining_k, mask.sum().item())
+                if topk_count > 0:
+                    topk_indices = masked_scores.topk(topk_count).indices.cpu().tolist()
+                    selected_indices.update(topk_indices)
+
+        # Return in sequential order for better memory access
+        result = [available_blocks[i] for i in sorted(selected_indices)]
+
+        # Log selection info (only for layer 0 to avoid spam)
+        if ctx.layer_id == 0:
+            logger.debug(
+                f"Quest select: {len(result)}/{n} blocks "
+                f"(topk={self.config.topk_blocks}, sink={self.config.include_sink_blocks}, "
+                f"recent={self.config.include_recent_blocks})"
+            )
+
+        return result
+
+    def on_block_offloaded(
+        self,
+        cpu_block_id: int,
+        layer_id: int,
+        k_cache: torch.Tensor,
+        num_valid_tokens: int,
+    ) -> None:
+        """Update min/max key metadata when a block is offloaded."""
+        self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
+
+    def reset(self) -> None:
+        """Reset metadata."""
+        self.metadata.reset()
+
+    def __repr__(self) -> str:
+        return (
+            f"QuestPolicy(topk={self.config.topk_blocks}, "
+            f"threshold={self.config.threshold_blocks}, "
+            f"sink={self.config.include_sink_blocks}, "
+            f"recent={self.config.include_recent_blocks})"
+        )
diff --git a/nanovllm/kvcache/sparse/streaming_llm.py b/nanovllm/kvcache/sparse/streaming_llm.py
new file mode 100644
index 0000000..29606cb
--- /dev/null
+++ b/nanovllm/kvcache/sparse/streaming_llm.py
@@ -0,0 +1,84 @@
+"""
+StreamingLLM sparse attention policy.
+
+Only keeps sink tokens (beginning) + recent tokens (end).
+Intermediate context is discarded. This enables infinite-length
+generation but loses intermediate context.
+
+Reference: StreamingLLM paper on attention sinks.
+"""
+
+from dataclasses import dataclass
+from typing import List
+from .policy import SparsePolicy, PolicyContext
+
+
+@dataclass
+class StreamingLLMConfig:
+    """Configuration for StreamingLLMPolicy."""
+
+    num_sink_blocks: int = 1
+    """Number of blocks at the beginning to always include (attention sinks)."""
+
+    num_recent_blocks: int = 3
+    """Number of most recent blocks to include (sliding window)."""
+
+
+class StreamingLLMPolicy(SparsePolicy):
+    """
+    StreamingLLM pattern: sink tokens + recent tokens only.
+
+    This is the most aggressive sparsity pattern - it only keeps a small
+    fixed window of context. Suitable for:
+    - Very long streaming generation
+    - When intermediate context can be safely discarded
+    - Maximizing throughput over accuracy
+
+    Pattern visualization (sink=1, recent=3):
+    ```
+    Blocks:  [0] [1] [2] [3] [4] [5] [6] [7] [8]
+              ^   x   x   x   x   x   ^   ^   ^
+            sink      (discarded)     recent window
+    ```
+
+    Warning: This loses information from intermediate blocks!
+    Use only when this trade-off is acceptable.
+    """
+
+    def __init__(self, config: StreamingLLMConfig | None = None):
+        self.config = config or StreamingLLMConfig()
+
+    def select_blocks(
+        self,
+        available_blocks: List[int],
+        ctx: PolicyContext,
+    ) -> List[int]:
+        """
+        Select sink blocks + recent blocks only.
+
+        Intermediate blocks are not loaded (effectively discarded).
+        """
+        n = len(available_blocks)
+
+        # If the total blocks fit in sink + recent, load all
+        total_keep = self.config.num_sink_blocks + self.config.num_recent_blocks
+        if n <= total_keep:
+            return available_blocks
+
+        selected_indices = set()
+
+        # Sink blocks (first N)
+        for i in range(min(self.config.num_sink_blocks, n)):
+            selected_indices.add(i)
+
+        # Recent blocks (last M)
+        for i in range(max(0, n - self.config.num_recent_blocks), n):
+            selected_indices.add(i)
+
+        return [available_blocks[i] for i in sorted(selected_indices)]
+
+    def __repr__(self) -> str:
+        return (
+            f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
+            f"recent={self.config.num_recent_blocks})"
+        )
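With the defaults above (sink=1, recent=3) and the 4096-token block size used elsewhere in this diff, the loaded KV budget is constant regardless of context length:

```python
block_size = 4096                     # PolicyContext default in this diff
num_sink_blocks, num_recent_blocks = 1, 3

tokens_loaded = (num_sink_blocks + num_recent_blocks) * block_size
print(tokens_loaded)  # 16384 tokens per decode step, at 32K or 1M context alike
```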
diff --git a/nanovllm/kvcache/sparse/vertical_slash.py b/nanovllm/kvcache/sparse/vertical_slash.py
new file mode 100644
index 0000000..372b4b6
--- /dev/null
+++ b/nanovllm/kvcache/sparse/vertical_slash.py
@@ -0,0 +1,95 @@
+"""
+Vertical-Slash sparse attention policy (MInference-style).
+
+Selects sink blocks (beginning of sequence) + local window blocks
+(near the current query position). This pattern captures:
+- Important initial context (system prompt, instructions)
+- Recent context (relevant for local dependencies)
+"""
+
+from dataclasses import dataclass
+from typing import List
+from .policy import SparsePolicy, PolicyContext
+
+
+@dataclass
+class VerticalSlashConfig:
+    """Configuration for VerticalSlashPolicy."""
+
+    num_sink_blocks: int = 1
+    """Number of blocks at the beginning to always include (sink tokens)."""
+
+    local_window_blocks: int = 2
+    """Number of blocks in the local window near the current query position."""
+
+    threshold_blocks: int = 4
+    """If total blocks <= threshold, load all (no sparsity applied)."""
+
+
+class VerticalSlashPolicy(SparsePolicy):
+    """
+    Vertical-Slash pattern: sink tokens + local window.
+
+    This pattern is inspired by MInference and observations that:
+    1. Initial tokens (sink) often receive high attention
+    2. Local context (recent tokens) is important for dependencies
+
+    Pattern visualization (sink=1, window=3, query at block 9):
+    ```
+    Blocks:  [0] [1] [2] [3] [4] [5] [6] [7] [8]
+              ^                       ^   ^   ^
+            sink                   local window
+    ```
+
+    For prefill chunk K, the local window is blocks [K-window, K-1].
+    For decode, the local window is the last N blocks.
+    """
+
+    def __init__(self, config: VerticalSlashConfig | None = None):
+        self.config = config or VerticalSlashConfig()
+
+    def select_blocks(
+        self,
+        available_blocks: List[int],
+        ctx: PolicyContext,
+    ) -> List[int]:
+        """
+        Select sink blocks + local window blocks.
+
+        For prefill: the local window is relative to the current chunk position.
+        For decode: the local window is the most recent blocks.
+        """
+        n = len(available_blocks)
+
+        # If below threshold, load all
+        if n <= self.config.threshold_blocks:
+            return available_blocks
+
+        selected_indices = set()
+
+        # Sink blocks (first N blocks)
+        for i in range(min(self.config.num_sink_blocks, n)):
+            selected_indices.add(i)
+
+        # Local window
+        if ctx.is_prefill:
+            # For prefill chunk K, the local window is blocks [K-window, K-1]
+            # (blocks before the current chunk, not including the current one)
+            window_end = min(ctx.query_chunk_idx, n)
+            window_start = max(0, window_end - self.config.local_window_blocks)
+            for i in range(window_start, window_end):
+                selected_indices.add(i)
+        else:
+            # For decode, the local window is the last M blocks
+            for i in range(max(0, n - self.config.local_window_blocks), n):
+                selected_indices.add(i)
+
+        # Return blocks in order (maintains a sequential access pattern)
+        return [available_blocks[i] for i in sorted(selected_indices)]
+
+    def __repr__(self) -> str:
+        return (
+            f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, "
+            f"window={self.config.local_window_blocks}, "
+            f"threshold={self.config.threshold_blocks})"
+        )
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index b99b679..ef6c1f5 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -6,6 +6,7 @@ import triton.language as tl
 from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 
 from nanovllm.utils.context import get_context
+from nanovllm.kvcache.sparse.policy import PolicyContext
 
 logger = logging.getLogger(__name__)
 
@@ -133,6 +134,22 @@ class Attention(nn.Module):
             # Get prefilled CPU blocks (blocks from previous chunks)
             cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
 
+            # Apply the sparse policy if enabled
+            if cpu_block_table and kvcache_manager.sparse_policy is not None:
+                num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
+                policy_ctx = PolicyContext(
+                    query_chunk_idx=current_chunk_idx,
+                    num_query_chunks=num_chunks,
+                    layer_id=self.layer_id,
+                    query=None,  # Prefill typically doesn't use the query for selection
+                    is_prefill=True,
+                    block_size=kvcache_manager.block_size,
+                    total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+                )
+                cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
+                    cpu_block_table, policy_ctx
+                )
+
             if cpu_block_table:
                 offload_engine = kvcache_manager.offload_engine
 
@@ -344,6 +361,21 @@ class Attention(nn.Module):
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
 
+        # Apply the sparse policy if enabled
+        if kvcache_manager.sparse_policy is not None:
+            policy_ctx = PolicyContext(
+                query_chunk_idx=0,
+                num_query_chunks=1,
+                layer_id=self.layer_id,
+                query=q_batched,  # Decode provides the query for query-aware selection
+                is_prefill=False,
+                block_size=kvcache_manager.block_size,
+                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+            )
+            cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
+                cpu_block_table, policy_ctx
+            )
+
         offload_engine = kvcache_manager.offload_engine
 
         # Use prefetch_size as chunk size for double buffering
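A back-of-envelope view of what the decode-side selection above buys, using numbers from this diff (128K-token prompt in bench_offload.py, 4096-token blocks, Quest topk=8):

```python
context_tokens = 128 * 1024
block_size = 4096
total_blocks = context_tokens // block_size   # 32 CPU blocks of KV
quest_topk = 8

print(total_blocks / quest_topk)  # 4.0x less CPU->GPU KV traffic per decode step
```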
diff --git a/tests/test_sparse_policy.py b/tests/test_sparse_policy.py
new file mode 100644
index 0000000..e8c4b51
--- /dev/null
+++ b/tests/test_sparse_policy.py
@@ -0,0 +1,252 @@
+"""
+Test sparse attention policies.
+
+Usage:
+    CUDA_VISIBLE_DEVICES=4,5 python tests/test_sparse_policy.py [policy_name]
+
+Policy names: full, vertical_slash, streaming_llm, quest, custom
+"""
+
+import sys
+import os
+
+os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
+
+import torch
+from typing import List
+
+# Test the sparse policy implementations
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
+from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
+from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
+from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
+from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
+
+
+def test_full_attention_policy():
+    """Test that FullAttentionPolicy returns all blocks."""
+    print("\n=== Testing FullAttentionPolicy ===")
+    policy = FullAttentionPolicy()
+
+    available_blocks = list(range(10))
+    ctx = PolicyContext(
+        query_chunk_idx=5,
+        num_query_chunks=10,
+        layer_id=0,
+        query=None,
+        is_prefill=True,
+    )
+
+    selected = policy.select_blocks(available_blocks, ctx)
+    assert selected == available_blocks, f"Expected all blocks, got {selected}"
+    print(f"  Prefill: input={available_blocks}, selected={selected} [PASS]")
+
+    # Test decode
+    ctx.is_prefill = False
+    selected = policy.select_blocks(available_blocks, ctx)
+    assert selected == available_blocks, f"Expected all blocks, got {selected}"
+    print(f"  Decode: input={available_blocks}, selected={selected} [PASS]")
+
+
+def test_vertical_slash_policy():
+    """Test that VerticalSlashPolicy selects sink + local window."""
+    print("\n=== Testing VerticalSlashPolicy ===")
+    config = VerticalSlashConfig(
+        num_sink_blocks=2,
+        local_window_blocks=3,
+        threshold_blocks=4,
+    )
+    policy = VerticalSlashPolicy(config)
+
+    # Test with 10 blocks, chunk 7 (should select sink [0, 1] + local [4, 5, 6])
+    available_blocks = list(range(10))
+    ctx = PolicyContext(
+        query_chunk_idx=7,
+        num_query_chunks=10,
+        layer_id=0,
+        query=None,
+        is_prefill=True,
+    )
+
+    selected = policy.select_blocks(available_blocks, ctx)
+    expected = [0, 1, 4, 5, 6]  # sink + local window before chunk 7
+    assert selected == expected, f"Expected {expected}, got {selected}"
+    print(f"  Prefill chunk 7: input={available_blocks}, selected={selected} [PASS]")
+
+    # Test with a small number of blocks (below threshold)
+    available_blocks = [0, 1, 2]
+    selected = policy.select_blocks(available_blocks, ctx)
+    assert selected == [0, 1, 2], f"Expected all blocks for small input, got {selected}"
+    print(f"  Below threshold: input={[0, 1, 2]}, selected={selected} [PASS]")
+
+    # Test decode (the local window is the last M blocks)
+    available_blocks = list(range(10))
+    ctx.is_prefill = False
+    selected = policy.select_blocks(available_blocks, ctx)
+    expected = [0, 1, 7, 8, 9]  # sink + last 3 blocks
+    assert selected == expected, f"Expected {expected}, got {selected}"
+    print(f"  Decode: input={available_blocks}, selected={selected} [PASS]")
+
+
+def test_streaming_llm_policy():
+    """Test that StreamingLLMPolicy selects sink + recent only."""
+    print("\n=== Testing StreamingLLMPolicy ===")
+    config = StreamingLLMConfig(
+        num_sink_blocks=1,
+        num_recent_blocks=2,
+    )
+    policy = StreamingLLMPolicy(config)
+
+    available_blocks = list(range(10))
+    ctx = PolicyContext(
+        query_chunk_idx=0,
+        num_query_chunks=1,
+        layer_id=0,
+        query=None,
+        is_prefill=False,
+    )
+
+    selected = policy.select_blocks(available_blocks, ctx)
+    expected = [0, 8, 9]  # sink [0] + recent [8, 9]
+    assert selected == expected, f"Expected {expected}, got {selected}"
+    print(f"  10 blocks: selected={selected} [PASS]")
+
+    # Test with 3 blocks (all fit in sink + recent)
+    available_blocks = [0, 1, 2]
+    selected = policy.select_blocks(available_blocks, ctx)
+    assert selected == [0, 1, 2], f"Expected all blocks, got {selected}"
+    print(f"  3 blocks: selected={selected} [PASS]")
+
+
+def test_quest_policy():
+    """Test QuestPolicy with mock metadata."""
+    print("\n=== Testing QuestPolicy ===")
+
+    # Create the metadata manager
+    num_blocks = 10
+    num_layers = 2
+    num_kv_heads = 4
+    head_dim = 64
+    dtype = torch.float32
+
+    metadata = BlockMetadataManager(
+        num_blocks=num_blocks,
+        num_layers=num_layers,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        dtype=dtype,
+    )
+
+    # Simulate offloading blocks with different key patterns.
+    # Blocks 0, 5, 9 will have high scores (keys aligned with the query).
+    for block_id in range(num_blocks):
+        for layer_id in range(num_layers):
+            k_cache = torch.randn(100, num_kv_heads, head_dim)  # 100 tokens per block
+            if block_id in [0, 5, 9]:
+                # Make these blocks have keys that score high
+                k_cache = k_cache.abs()  # All positive
+            else:
+                k_cache = -k_cache.abs()  # All negative
+            metadata.update_metadata(block_id, layer_id, k_cache, 100)
+
+    config = QuestConfig(
+        topk_blocks=4,
+        threshold_blocks=3,
+    )
+    policy = QuestPolicy(config, metadata)
+
+    available_blocks = list(range(10))
+
+    # Create a query that scores high with positive keys
+    query = torch.ones(1, num_kv_heads, head_dim, device='cuda')
+
+    ctx = PolicyContext(
+        query_chunk_idx=0,
+        num_query_chunks=1,
+        layer_id=0,
+        query=query,
+        is_prefill=False,
+    )
+
+    selected = policy.select_blocks(available_blocks, ctx)
+    print(f"  Top-4 selection: input={available_blocks}, selected={selected}")
+
+    # High-scoring blocks [0, 5, 9] should be in the selection
+    for expected_block in [0, 5, 9]:
+        assert expected_block in selected, f"Expected block {expected_block} in selection"
+    print("  High-score blocks [0, 5, 9] in selection [PASS]")
+
+    # Test below threshold (should return all)
+    available_blocks = [0, 1, 2]
+    selected = policy.select_blocks(available_blocks, ctx)
+    assert selected == [0, 1, 2], f"Expected all blocks below threshold, got {selected}"
+    print(f"  Below threshold: selected={selected} [PASS]")
+
+    # Test without a query (should return all)
+    ctx.query = None
+    available_blocks = list(range(10))
+    selected = policy.select_blocks(available_blocks, ctx)
+    assert selected == available_blocks, f"Expected all blocks without query, got {selected}"
+    print("  No query: selected all [PASS]")
+
+
+def test_custom_policy():
+    """Test creating a custom policy."""
+    print("\n=== Testing Custom Policy ===")
+
+    class EveryOtherPolicy(SparsePolicy):
+        """Select every other block."""
+
+        def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
+            return [available_blocks[i] for i in range(0, len(available_blocks), 2)]
+
+    policy = EveryOtherPolicy()
+    available_blocks = list(range(10))
+    ctx = PolicyContext(
+        query_chunk_idx=0,
+        num_query_chunks=1,
+        layer_id=0,
+        query=None,
+        is_prefill=True,
+    )
+
+    selected = policy.select_blocks(available_blocks, ctx)
+    expected = [0, 2, 4, 6, 8]
+    assert selected == expected, f"Expected {expected}, got {selected}"
+    print(f"  Every other: input={available_blocks}, selected={selected} [PASS]")
+
+
+def run_all_tests():
+    """Run all policy tests."""
+    print("Running Sparse Policy Tests...")
+
+    test_full_attention_policy()
+    test_vertical_slash_policy()
+    test_streaming_llm_policy()
+    test_quest_policy()
+    test_custom_policy()
+
+    print("\n" + "=" * 50)
+    print("All tests passed!")
+    print("=" * 50)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        policy_name = sys.argv[1].lower()
+        if policy_name == "full":
+            test_full_attention_policy()
+        elif policy_name == "vertical_slash":
+            test_vertical_slash_policy()
+        elif policy_name == "streaming_llm":
+            test_streaming_llm_policy()
+        elif policy_name == "quest":
+            test_quest_policy()
+        elif policy_name == "custom":
+            test_custom_policy()
+        else:
+            print(f"Unknown policy: {policy_name}")
+            print("Available: full, vertical_slash, streaming_llm, quest, custom")
+            sys.exit(1)
+    else:
+        run_all_tests()