Files
nano-vllm/nanovllm/kvcache/sparse/quest.py
2025-12-22 08:51:02 +08:00

285 lines
9.3 KiB
Python

"""
Quest-style sparse attention policy.
Uses min/max key bounds per block to estimate attention scores
and select Top-K blocks most relevant to the current query.
Reference: Tang et al., "Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference" (ICML 2024).
"""
import logging
import torch
from dataclasses import dataclass
from typing import List, Tuple, Optional
from .policy import SparsePolicy, PolicyContext
# Module logger; QuestPolicy.select_blocks emits a debug summary per call (layer 0 only).
logger = logging.getLogger(__name__)
class BlockMetadataManager:
    """
    Manages per-block metadata for Quest-style sparse selection.

    Stores element-wise min/max key values for each (block, layer) pair,
    which are used to compute upper bounds on attention scores without
    loading the full KV cache.

    Memory usage: 2 * num_blocks * num_layers * num_kv_heads * head_dim * dtype_size
    Example: 1000 blocks, 28 layers, 4 heads, 128 dim, bf16 = ~57 MB
    """

    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Initialize metadata storage (all bounds start at zero).

        Args:
            num_blocks: Maximum number of CPU blocks
            num_layers: Number of transformer layers
            num_kv_heads: Number of KV attention heads
            head_dim: Dimension per head
            dtype: Data type for metadata storage
        """
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype

        # Per-block min/max key values: [num_blocks, num_layers, num_kv_heads, head_dim]
        shape = (num_blocks, num_layers, num_kv_heads, head_dim)
        # Pinned host memory speeds up host->device transfers, but pinning
        # needs a CUDA runtime; fall back to pageable memory on CPU-only
        # hosts instead of failing at construction time.
        pin = torch.cuda.is_available()
        self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=pin)
        self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=pin)

        # One flag per block, set once any layer of that block has been updated.
        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)

    def update_metadata(
        self,
        block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """
        Record element-wise min/max key bounds for one (block, layer).

        Called when a block is offloaded to CPU.

        Args:
            block_id: CPU block ID
            layer_id: Layer index
            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim];
                may live on any device.
            num_valid_tokens: Number of valid tokens at the start of this
                block. Values <= 0 leave the metadata untouched.
        """
        if num_valid_tokens <= 0:
            # Nothing to summarize. This also guards against a negative count
            # turning the slice below into k_cache[:-k].
            return

        k_valid = k_cache[:num_valid_tokens]  # [num_tokens, heads, dim]

        # Reduce over tokens on the cache's own device, then copy only the
        # two [heads, dim] results to host. copy_() handles the device
        # transfer and the cast to the metadata dtype.
        self.key_min[block_id, layer_id].copy_(k_valid.amin(dim=0))
        self.key_max[block_id, layer_id].copy_(k_valid.amax(dim=0))
        self.valid_blocks[block_id] = True

    def get_block_metadata(
        self,
        block_ids: List[int],
        layer_id: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get min/max keys for the specified blocks.

        Args:
            block_ids: List of CPU block IDs
            layer_id: Layer index

        Returns:
            Tuple of (key_min, key_max) tensors, each
            [len(block_ids), num_kv_heads, head_dim]. These are copies
            (advanced indexing), so mutating them does not affect storage.
        """
        key_min = self.key_min[block_ids, layer_id]
        key_max = self.key_max[block_ids, layer_id]
        return key_min, key_max

    def reset(self) -> None:
        """Reset all metadata: zero both bounds and clear every valid flag."""
        self.key_min.zero_()
        self.key_max.zero_()
        self.valid_blocks.zero_()
@dataclass
class QuestConfig:
    """
    Configuration for QuestPolicy.

    Selection budget: sink and recent blocks are always kept and count
    toward ``topk_blocks``; score-based Top-K picks only fill whatever of
    the ``topk_blocks`` budget is left. If the sink and recent blocks
    already reach ``topk_blocks``, no score-based picks are made.
    """

    # Total selection budget (at most this many blocks, unless the forced
    # sink/recent blocks alone exceed it); the remainder after forced blocks
    # is filled by highest estimated score.
    topk_blocks: int = 8
    # If the number of candidate blocks is <= this, load all of them (no scoring).
    threshold_blocks: int = 4
    # Always keep this many sink blocks (the first N candidates); counts toward topk_blocks.
    include_sink_blocks: int = 0
    # Always keep this many recent blocks (the last N candidates); counts toward topk_blocks.
    include_recent_blocks: int = 0
class QuestPolicy(SparsePolicy):
    """
    Quest-style Top-K block selection using min/max key bounds.

    For the current query, computes an upper bound on the attention logit
    of every candidate block from that block's stored element-wise min/max
    keys, then keeps the blocks with the highest bounds.

    Score computation, for query head h attending to KV head g(h):
        bound_h(block) = sum_d max(q_hd * key_min_gd, q_hd * key_max_gd)
                       = q_pos_h . key_max_g + q_neg_h . key_min_g
    where q_pos = max(q, 0) and q_neg = min(q, 0) element-wise.

    For any key k in the block, key_min <= k <= key_max element-wise, so
    each product q_d * k_d is at most the larger of q_d * key_min_d and
    q_d * key_max_d; summing those per-dimension maxima bounds q . k.
    (Taking max(q . key_min, q . key_max) over whole dot products is NOT an
    upper bound when q has mixed signs.) The per-head bounds are averaged
    into one score per block.

    Selection budget: the first include_sink_blocks and last
    include_recent_blocks candidates are always kept and count toward
    topk_blocks; score-based picks fill the remaining slots, if any.
    """

    def __init__(
        self,
        config: QuestConfig,
        metadata_manager: BlockMetadataManager,
    ):
        """
        Initialize Quest policy.

        Args:
            config: QuestConfig with selection parameters
            metadata_manager: BlockMetadataManager for min/max key storage
        """
        self.config = config
        self.metadata = metadata_manager

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select blocks to load based on query-key upper bounds.

        Args:
            available_blocks: Candidate CPU block IDs. Sink/recent selection
                is positional, so this is presumably in sequence order
                (first = oldest) -- confirm against callers.
            ctx: Policy context; ctx.query and ctx.layer_id are read.

        Returns:
            Selected block IDs in the same relative order as
            available_blocks. All blocks are returned unchanged when there
            are at most threshold_blocks candidates or ctx.query is None.

        Raises:
            ValueError: the query head count is not a multiple of the KV
                head count.
        """
        n = len(available_blocks)
        if n <= self.config.threshold_blocks or ctx.query is None:
            # Too few blocks to be worth scoring, or no query to score with.
            return available_blocks

        key_min, key_max = self.metadata.get_block_metadata(
            available_blocks, ctx.layer_id
        )

        # Score on the query's device in float32 so the accumulated bounds
        # are not subject to bf16/fp16 rounding.
        device = ctx.query.device
        key_min = key_min.to(device=device, dtype=torch.float32, non_blocking=True)
        key_max = key_max.to(device=device, dtype=torch.float32, non_blocking=True)

        q = self._mean_query(ctx.query)
        scores = self._upper_bound_scores(q, key_min, key_max)
        indices = self._select_indices(scores, n)

        # Sorted positions keep the original order for better memory access.
        result = [available_blocks[i] for i in indices]

        # Log selection info (only for layer 0 to avoid spam).
        if ctx.layer_id == 0:
            logger.debug(
                "Quest select: %s/%s blocks (topk=%s, sink=%s, recent=%s)",
                len(result),
                n,
                self.config.topk_blocks,
                self.config.include_sink_blocks,
                self.config.include_recent_blocks,
            )
        return result

    @staticmethod
    def _mean_query(query: torch.Tensor) -> torch.Tensor:
        """
        Reduce a query to one float32 vector per head: [num_q_heads, head_dim].

        Every dimension before the last two (batch and/or token positions)
        is averaged, so [heads, dim], [1, heads, dim], [num_tokens, heads, dim]
        and [1, seq_len, heads, dim] are all accepted; prefill queries are
        represented by their mean over positions.
        """
        num_heads, head_dim = query.shape[-2], query.shape[-1]
        return query.float().reshape(-1, num_heads, head_dim).mean(dim=0)

    @staticmethod
    def _upper_bound_scores(
        q: torch.Tensor,
        key_min: torch.Tensor,
        key_max: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the Quest upper-bound score of every block.

        Args:
            q: [num_q_heads, head_dim] query.
            key_min, key_max: [num_blocks, num_kv_heads, head_dim] bounds,
                same dtype/device as q. Requires key_min <= key_max
                element-wise (holds for stored min/max values and for
                all-zero entries).

        Returns:
            [num_blocks] bound, averaged over all query heads.

        Raises:
            ValueError: num_q_heads is not a multiple of num_kv_heads.
        """
        num_q_heads, head_dim = q.shape
        num_kv_heads = key_min.shape[1]
        if num_q_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_q_heads ({num_q_heads}) must be a multiple of "
                f"num_kv_heads ({num_kv_heads})"
            )
        # GQA: consecutive query heads share one KV head -> [kv, group, dim].
        q = q.reshape(num_kv_heads, num_q_heads // num_kv_heads, head_dim)

        # Per dimension, max(q*min, q*max) equals q*max where q >= 0 and
        # q*min where q < 0 (since min <= max), so the bound splits into two
        # contractions without materializing a [blocks, q_heads, dim] tensor.
        bound = torch.einsum("hgd,bhd->bhg", q.clamp(min=0), key_max)
        bound = bound + torch.einsum("hgd,bhd->bhg", q.clamp(max=0), key_min)
        return bound.flatten(1).mean(dim=1)  # [num_blocks]

    def _select_indices(self, scores: torch.Tensor, n: int) -> List[int]:
        """
        Choose positions into the candidate list.

        Sink (first N) and recent (last N) positions are forced in first;
        the highest-scoring remaining positions then fill the rest of the
        topk_blocks budget. Returns the chosen positions sorted ascending.
        """
        cfg = self.config
        selected = set(range(min(cfg.include_sink_blocks, n)))
        selected.update(range(max(0, n - cfg.include_recent_blocks), n))

        # Forced blocks consume the budget; only unforced blocks compete,
        # and there are n - len(selected) of them.
        remaining_k = min(cfg.topk_blocks, n) - len(selected)
        if remaining_k > 0:
            masked = scores.clone()
            if selected:
                # One indexed write instead of a per-element loop on device.
                forced = torch.tensor(
                    sorted(selected), dtype=torch.long, device=scores.device
                )
                masked[forced] = float("-inf")
            selected.update(masked.topk(remaining_k).indices.tolist())

        return sorted(selected)

    def on_block_offloaded(
        self,
        cpu_block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """Update min/max key metadata when a block is offloaded."""
        self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)

    def reset(self) -> None:
        """Reset all stored block metadata."""
        self.metadata.reset()

    def __repr__(self) -> str:
        return (
            f"QuestPolicy(topk={self.config.topk_blocks}, "
            f"threshold={self.config.threshold_blocks}, "
            f"sink={self.config.include_sink_blocks}, "
            f"recent={self.config.include_recent_blocks})"
        )