""" Base class for sparse attention policies. Sparse attention policies determine which KV cache blocks to load from CPU for each query chunk during chunked attention computation. """ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import List, Optional, Any, TYPE_CHECKING import torch # Import SparsePolicyType from config to avoid circular imports from nanovllm.config import SparsePolicyType if TYPE_CHECKING: from nanovllm.kvcache.offload_engine import OffloadEngine from nanovllm.kvcache.manager import KVCacheManager from nanovllm.engine.sequence import Sequence @dataclass class PolicyContext: """ Context passed to sparse policy for block selection. This dataclass contains all information needed by a sparse policy to decide which blocks to load for the current query chunk. """ query_chunk_idx: int """Index of the current query chunk (0-indexed).""" num_query_chunks: int """Total number of query chunks in this prefill.""" layer_id: int """Current transformer layer index.""" query: Optional[torch.Tensor] """ Query tensor for current chunk. Shape: [1, num_heads, head_dim] for decode, [seq_len, num_heads, head_dim] for prefill. Available for both prefill and decode phases. """ is_prefill: bool """True if in prefill phase, False if in decode phase.""" block_size: int = 1024 """Number of tokens per block.""" total_kv_len: int = 0 """Total KV sequence length so far (for reference).""" class SparsePolicy(ABC): """ Abstract base class for sparse attention policies. Subclass this and implement select_blocks() to create custom sparse attention patterns. The policy receives context about the current query chunk and returns which KV blocks to load. Attributes: supports_prefill: Whether this policy can be used for prefill phase. supports_decode: Whether this policy can be used for decode phase. Example: class MySparsePolicy(SparsePolicy): supports_prefill = False # decode-only policy supports_decode = True def select_blocks(self, available_blocks, ctx): # Load first block and last 2 blocks if len(available_blocks) <= 3: return available_blocks return [available_blocks[0]] + available_blocks[-2:] """ # Compatibility flags - override in subclasses supports_prefill: bool = True supports_decode: bool = True def initialize( self, num_layers: int, num_kv_heads: int, head_dim: int, num_cpu_blocks: int, dtype: torch.dtype, device: torch.device = None, ) -> None: """ Initialize policy resources. Called by the framework after KV cache is allocated. Override this to create metadata structures (e.g., BlockMetadataManager for Quest). Default implementation does nothing. Args: num_layers: Number of transformer layers num_kv_heads: Number of KV attention heads head_dim: Dimension per head num_cpu_blocks: Number of CPU blocks allocated dtype: Data type for tensors device: Device for metadata storage (GPU recommended for performance) """ pass @abstractmethod def select_blocks( self, available_blocks: List[int], offload_engine: "OffloadEngine", ctx: PolicyContext, ) -> List[int]: """ Select which KV blocks to load for the current query chunk. This is the core method that defines the sparse attention pattern. The returned blocks will be loaded from CPU to GPU for attention computation against the current query chunk. Args: available_blocks: List of CPU block IDs that contain KV cache from previous chunks. These are ordered by their position in the sequence. offload_engine: OffloadEngine for loading KV (some policies need to load KV to make selection decisions). ctx: PolicyContext with information about the current query chunk, layer, phase (prefill/decode), etc. Returns: List of block IDs to load (must be a subset of available_blocks). The order may affect performance (sequential access is faster). Returning [] means no previous blocks will be loaded. """ pass def on_prefill_offload( self, cpu_block_id: int, layer_id: int, k_cache: torch.Tensor, num_valid_tokens: int, ) -> None: """ Hook called when a block is offloaded during prefill phase. Called BEFORE GPU→CPU copy, while k_cache is still on GPU. Override this to collect metadata about blocks (e.g., min/max keys for Quest-style selection). Default implementation does nothing. Args: cpu_block_id: The CPU block ID that will be written layer_id: Transformer layer index k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU) num_valid_tokens: Number of valid tokens in this block """ pass def on_decode_offload( self, cpu_block_id: int, layer_id: int, k_cache: torch.Tensor, num_valid_tokens: int, ) -> None: """ Hook called when a block is offloaded during decode phase. Called BEFORE GPU→CPU copy, while k_cache is still on GPU. Override this to update metadata about blocks. Default implementation does nothing. Args: cpu_block_id: The CPU block ID that will be written layer_id: Transformer layer index k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU) num_valid_tokens: Number of valid tokens in this block """ pass def reset(self) -> None: """ Reset policy state. Called when starting a new sequence or clearing state. Default implementation does nothing. """ pass @abstractmethod def compute_chunked_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, softmax_scale: float, offload_engine: "OffloadEngine", kvcache_manager: "KVCacheManager", current_chunk_idx: int, seq: "Sequence", num_tokens: int, ) -> torch.Tensor: """ Compute chunked prefill attention (complete flow). This is the main entry point for prefill attention computation. It defines the complete prefill flow: 1. Get historical blocks 2. Select blocks (call select_blocks) 3. Load and compute historical blocks via offload_engine 4. Get current chunk KV from offload_engine, compute attention 5. Merge all results Args: q: [seq_len, num_heads, head_dim] query for current chunk k: [seq_len, num_kv_heads, head_dim] key for current chunk (in prefill buffer) v: [seq_len, num_kv_heads, head_dim] value for current chunk (in prefill buffer) layer_id: transformer layer index softmax_scale: softmax scaling factor offload_engine: OffloadEngine for loading blocks kvcache_manager: KVCacheManager for block management current_chunk_idx: current chunk index seq: Sequence object num_tokens: number of tokens in current chunk Returns: [seq_len, num_heads, head_dim] final attention output """ pass @abstractmethod def compute_chunked_decode( self, q: torch.Tensor, layer_id: int, softmax_scale: float, offload_engine: "OffloadEngine", kvcache_manager: "KVCacheManager", seq: "Sequence", ) -> torch.Tensor: """ Compute chunked decode attention (complete flow). This is the main entry point for decode attention computation. It defines the complete decode flow: 1. Get prefilled blocks from CPU 2. Select blocks (call select_blocks) 3. Load blocks via pipeline (ring buffer or cross-layer) 4. Read accumulated decode tokens from decode buffer 5. Merge all results The decode position information can be computed internally: - decode_start_pos = kvcache_manager.get_decode_start_pos(seq) - decode_pos_in_block = (len(seq) - 1) % kvcache_manager.block_size Args: q: [batch_size, num_heads, head_dim] query for decode token layer_id: transformer layer index softmax_scale: softmax scaling factor offload_engine: OffloadEngine for loading blocks kvcache_manager: KVCacheManager for block management seq: Sequence object Returns: [batch_size, 1, num_heads, head_dim] final attention output """ pass def __repr__(self) -> str: return f"{self.__class__.__name__}()"