""" Quest-style sparse attention policy. Uses min/max key bounds per block to estimate attention scores and select Top-K blocks most relevant to the current query. Reference: Quest paper on query-aware KV cache selection. """ import logging import torch from dataclasses import dataclass from typing import List, Tuple, Optional from .policy import SparsePolicy, PolicyContext logger = logging.getLogger(__name__) class BlockMetadataManager: """ Manages per-block metadata for Quest-style sparse selection. Stores min/max key values for each block, which are used to compute upper bounds on attention scores without loading the full KV cache. Memory usage: 2 * num_blocks * num_layers * num_kv_heads * head_dim * dtype_size Example: 1000 blocks, 28 layers, 4 heads, 128 dim, bf16 = ~57 MB """ def __init__( self, num_blocks: int, num_layers: int, num_kv_heads: int, head_dim: int, dtype: torch.dtype = torch.bfloat16, ): """ Initialize metadata storage. Args: num_blocks: Maximum number of CPU blocks num_layers: Number of transformer layers num_kv_heads: Number of KV attention heads head_dim: Dimension per head dtype: Data type for metadata storage """ self.num_blocks = num_blocks self.num_layers = num_layers self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.dtype = dtype # Per-block min/max key values: [num_blocks, num_layers, num_heads, head_dim] shape = (num_blocks, num_layers, num_kv_heads, head_dim) self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True) self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True) # Track which blocks have valid metadata self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool) def update_metadata( self, block_id: int, layer_id: int, k_cache: torch.Tensor, num_valid_tokens: int, ) -> None: """ Update min/max key bounds for a block. Called when a block is offloaded to CPU. Args: block_id: CPU block ID layer_id: Layer index k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] num_valid_tokens: Number of valid tokens in this block """ if num_valid_tokens == 0: return # Get valid keys only k_valid = k_cache[:num_valid_tokens].cpu() # [num_tokens, heads, dim] # Compute min/max across token dimension self.key_min[block_id, layer_id] = k_valid.min(dim=0).values self.key_max[block_id, layer_id] = k_valid.max(dim=0).values self.valid_blocks[block_id] = True def get_block_metadata( self, block_ids: List[int], layer_id: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Get min/max keys for specified blocks. Args: block_ids: List of CPU block IDs layer_id: Layer index Returns: Tuple of (key_min, key_max) tensors Shape: [num_blocks, num_kv_heads, head_dim] """ key_min = self.key_min[block_ids, layer_id] key_max = self.key_max[block_ids, layer_id] return key_min, key_max def reset(self) -> None: """Reset all metadata.""" self.key_min.zero_() self.key_max.zero_() self.valid_blocks.zero_() @dataclass class QuestConfig: """Configuration for QuestPolicy.""" topk_blocks: int = 8 """Number of top blocks to select based on estimated attention scores.""" threshold_blocks: int = 4 """If total blocks <= threshold, load all (no scoring needed).""" include_sink_blocks: int = 0 """Always include this many sink blocks (first N blocks), in addition to Top-K.""" include_recent_blocks: int = 0 """Always include this many recent blocks (last N blocks), in addition to Top-K.""" class QuestPolicy(SparsePolicy): """ Quest-style Top-K block selection using min/max key bounds. For each query, computes an upper bound on attention scores for each block using the stored min/max keys, then selects the Top-K blocks with highest estimated scores. Score computation: score(q, block) = max(q · key_min, q · key_max) This upper bound is derived from the fact that for any key k in the block: min_k <= k <= max_k (element-wise), so the actual attention score is bounded by the maximum of the two extremes. """ def __init__( self, config: QuestConfig, metadata_manager: BlockMetadataManager, ): """ Initialize Quest policy. Args: config: QuestConfig with selection parameters metadata_manager: BlockMetadataManager for min/max key storage """ self.config = config self.metadata = metadata_manager def select_blocks( self, available_blocks: List[int], ctx: PolicyContext, ) -> List[int]: """ Select Top-K blocks based on query-key similarity bounds. If query is not available (some prefill scenarios), falls back to loading all blocks. """ n = len(available_blocks) # If below threshold or no query, load all if n <= self.config.threshold_blocks: return available_blocks if ctx.query is None: # No query available - cannot compute scores return available_blocks # Get metadata for available blocks key_min, key_max = self.metadata.get_block_metadata( available_blocks, ctx.layer_id ) # Move to query device for computation device = ctx.query.device key_min = key_min.to(device, non_blocking=True) key_max = key_max.to(device, non_blocking=True) # Compute upper bound scores # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim] q = ctx.query if q.dim() == 4: # Prefill: use mean over sequence length q = q.mean(dim=1) # [1, num_heads, head_dim] q = q.squeeze(0) # [num_q_heads, head_dim] # Handle GQA: query may have more heads than KV # key_min/key_max: [num_blocks, num_kv_heads, head_dim] num_q_heads = q.shape[0] num_kv_heads = key_min.shape[1] if num_q_heads != num_kv_heads: # GQA: group query heads and average per KV group # Reshape q: [num_q_heads, head_dim] -> [num_kv_heads, group_size, head_dim] group_size = num_q_heads // num_kv_heads q = q.view(num_kv_heads, group_size, -1).mean(dim=1) # [num_kv_heads, head_dim] # Score: max(q·k_min, q·k_max) averaged over heads # key_min/key_max: [num_blocks, num_kv_heads, head_dim] # q: [num_kv_heads, head_dim] score_min = torch.einsum('hd,bhd->bh', q, key_min) # [num_blocks, kv_heads] score_max = torch.einsum('hd,bhd->bh', q, key_max) scores = torch.maximum(score_min, score_max).mean(dim=-1) # [num_blocks] # Build selection set selected_indices = set() # Always include sink blocks for i in range(min(self.config.include_sink_blocks, n)): selected_indices.add(i) # Always include recent blocks for i in range(max(0, n - self.config.include_recent_blocks), n): selected_indices.add(i) # Top-K selection from remaining remaining_k = max(0, self.config.topk_blocks - len(selected_indices)) if remaining_k > 0: # Mask out already selected mask = torch.ones(n, dtype=torch.bool, device=device) for idx in selected_indices: mask[idx] = False if mask.any(): masked_scores = scores.clone() masked_scores[~mask] = float('-inf') topk_count = min(remaining_k, mask.sum().item()) if topk_count > 0: topk_indices = masked_scores.topk(topk_count).indices.cpu().tolist() selected_indices.update(topk_indices) # Return in sequential order for better memory access result = [available_blocks[i] for i in sorted(selected_indices)] # Log selection info (only for layer 0 to avoid spam) if ctx.layer_id == 0: logger.debug( f"Quest select: {len(result)}/{n} blocks " f"(topk={self.config.topk_blocks}, sink={self.config.include_sink_blocks}, " f"recent={self.config.include_recent_blocks})" ) return result def on_block_offloaded( self, cpu_block_id: int, layer_id: int, k_cache: torch.Tensor, num_valid_tokens: int, ) -> None: """Update min/max key metadata when block is offloaded.""" self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens) def reset(self) -> None: """Reset metadata.""" self.metadata.reset() def __repr__(self) -> str: return ( f"QuestPolicy(topk={self.config.topk_blocks}, " f"threshold={self.config.threshold_blocks}, " f"sink={self.config.include_sink_blocks}, " f"recent={self.config.include_recent_blocks})" )