[feat] Add sparse KV cache feature (needs verification).

Zijie Tian
2025-12-22 08:51:02 +08:00
parent 8df0c7517b
commit 051f2295c9
14 changed files with 1215 additions and 12 deletions

View File

@@ -25,6 +25,11 @@ from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.policies.base_policy import EvictionPolicy
from nanovllm.kvcache.policies.lru_policy import LRUPolicy
# Type checking import for sparse policy
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from nanovllm.kvcache.sparse.policy import SparsePolicy
class BlockLocation(Enum):
"""Where a logical block's data currently resides."""
@@ -142,6 +147,9 @@ class HybridKVCacheManager(KVCacheManager):
# Key: sequence id, Value: starting position where decode began in current block
self._decode_start_pos: Dict[int, int] = {}
# Sparse attention policy (optional)
self.sparse_policy: Optional["SparsePolicy"] = None
@property
def block_size(self) -> int:
return self._block_size
@@ -174,6 +182,24 @@ class HybridKVCacheManager(KVCacheManager):
assert self.offload_engine is not None
return self.offload_engine.get_layer_cache(layer_id)
def set_sparse_policy(self, policy: "SparsePolicy") -> None:
"""
Set sparse attention policy for block selection.
The sparse policy determines which KV blocks to load from CPU
for each query chunk during chunked attention computation.
Args:
policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
Example:
from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
manager.set_sparse_policy(policy)
"""
self.sparse_policy = policy
logger.info(f"Sparse attention policy set: {policy}")
def _allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
"""
Get a free GPU slot, evicting if necessary.

View File

@@ -0,0 +1,90 @@
"""
Sparse Attention Policy module.
Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.
Usage:
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy
# Use built-in policy
policy = VerticalSlashPolicy(VerticalSlashConfig())
# Or create custom policy
class MyPolicy(SparsePolicy):
def select_blocks(self, available_blocks, ctx):
return available_blocks[:5] # Just first 5 blocks
"""
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
# Built-in policy registry
BUILTIN_SPARSE_POLICIES = {
"full": FullAttentionPolicy,
"vertical_slash": VerticalSlashPolicy,
"streaming_llm": StreamingLLMPolicy,
}
def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
"""
Get a sparse attention policy instance by name.
Args:
policy_name: Policy name ("full", "vertical_slash", "streaming_llm").
"quest" is recognized but raises ValueError; construct
QuestPolicy(config, metadata_manager) directly instead.
**kwargs: Policy-specific configuration
Returns:
SparsePolicy instance
"""
policy_name = policy_name.lower()
if policy_name == "full":
return FullAttentionPolicy()
elif policy_name == "vertical_slash":
config = VerticalSlashConfig(
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
local_window_blocks=kwargs.get("local_window_blocks", 2),
threshold_blocks=kwargs.get("threshold_blocks", 4),
)
return VerticalSlashPolicy(config)
elif policy_name == "streaming_llm":
config = StreamingLLMConfig(
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
num_recent_blocks=kwargs.get("num_recent_blocks", 3),
)
return StreamingLLMPolicy(config)
elif policy_name == "quest":
# Quest requires metadata_manager to be passed separately
raise ValueError(
"Quest policy requires BlockMetadataManager. "
"Use QuestPolicy(config, metadata_manager) directly."
)
else:
raise ValueError(
f"Unknown sparse policy '{policy_name}'. "
f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
)
__all__ = [
"SparsePolicy",
"PolicyContext",
"FullAttentionPolicy",
"VerticalSlashPolicy",
"VerticalSlashConfig",
"QuestPolicy",
"QuestConfig",
"BlockMetadataManager",
"StreamingLLMPolicy",
"StreamingLLMConfig",
"HybridPolicy",
"get_sparse_policy",
"BUILTIN_SPARSE_POLICIES",
]
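
A minimal usage sketch for the factory above; the kwargs map onto the config fields read via `kwargs.get(...)`, and unrecognized kwargs are silently ignored:

```python
from nanovllm.kvcache.sparse import get_sparse_policy

# Build a built-in policy by name; kwargs fill the matching config fields
policy = get_sparse_policy("vertical_slash", num_sink_blocks=2, local_window_blocks=3)
print(policy)  # VerticalSlashPolicy(sink=2, window=3, threshold=4)
```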

View File

@@ -0,0 +1,34 @@
"""
Full attention policy - loads all blocks (no sparsity).
This serves as a baseline and default policy when sparse
attention is not needed.
"""
from typing import List
from .policy import SparsePolicy, PolicyContext
class FullAttentionPolicy(SparsePolicy):
"""
Full attention policy that loads all available blocks.
This is the default behavior with no sparsity - all previous
KV cache blocks are loaded for each query chunk.
Use this as:
- A baseline for comparing sparse policies
- When you need full attention accuracy
- For short sequences where sparsity isn't beneficial
"""
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""Return all blocks - no sparsity."""
return available_blocks
def __repr__(self) -> str:
return "FullAttentionPolicy()"

View File

@@ -0,0 +1,93 @@
"""
Hybrid sparse attention policy.
Allows using different policies for prefill vs decode phases.
This is useful because optimal sparsity patterns often differ:
- Prefill: fixed patterns work well (e.g., VerticalSlash)
- Decode: query-aware selection helps (e.g., Quest)
"""
from typing import List
import torch
from .policy import SparsePolicy, PolicyContext
class HybridPolicy(SparsePolicy):
"""
Hybrid policy that uses different policies for prefill and decode.
Example usage:
```python
from nanovllm.kvcache.sparse import (
HybridPolicy, VerticalSlashPolicy, QuestPolicy,
VerticalSlashConfig, QuestConfig, BlockMetadataManager
)
# Prefill: use fast fixed pattern
prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
num_sink_blocks=1,
local_window_blocks=3,
))
# Decode: use query-aware selection
metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)
# Combine
policy = HybridPolicy(prefill_policy, decode_policy)
```
"""
def __init__(
self,
prefill_policy: SparsePolicy,
decode_policy: SparsePolicy,
):
"""
Initialize hybrid policy.
Args:
prefill_policy: Policy to use during prefill phase
decode_policy: Policy to use during decode phase
"""
self.prefill_policy = prefill_policy
self.decode_policy = decode_policy
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""Delegate to appropriate policy based on phase."""
if ctx.is_prefill:
return self.prefill_policy.select_blocks(available_blocks, ctx)
else:
return self.decode_policy.select_blocks(available_blocks, ctx)
def on_block_offloaded(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""Forward to both policies (both may need metadata updates)."""
self.prefill_policy.on_block_offloaded(
cpu_block_id, layer_id, k_cache, num_valid_tokens
)
self.decode_policy.on_block_offloaded(
cpu_block_id, layer_id, k_cache, num_valid_tokens
)
def reset(self) -> None:
"""Reset both policies."""
self.prefill_policy.reset()
self.decode_policy.reset()
def __repr__(self) -> str:
return (
f"HybridPolicy(\n"
f" prefill={self.prefill_policy},\n"
f" decode={self.decode_policy}\n"
f")"
)
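
A small sketch of the delegation, pairing two built-in policies so the same block list is routed differently by phase (StreamingLLM as the decode policy is illustrative, not a recommendation):

```python
from nanovllm.kvcache.sparse import (
    HybridPolicy, FullAttentionPolicy, PolicyContext,
    StreamingLLMPolicy, StreamingLLMConfig,
)

policy = HybridPolicy(
    prefill_policy=FullAttentionPolicy(),
    decode_policy=StreamingLLMPolicy(StreamingLLMConfig(num_sink_blocks=1, num_recent_blocks=2)),
)
blocks = list(range(8))

prefill_ctx = PolicyContext(query_chunk_idx=2, num_query_chunks=4, layer_id=0, query=None, is_prefill=True)
decode_ctx = PolicyContext(query_chunk_idx=0, num_query_chunks=1, layer_id=0, query=None, is_prefill=False)

print(policy.select_blocks(blocks, prefill_ctx))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(policy.select_blocks(blocks, decode_ctx))   # [0, 6, 7]
```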

View File

@@ -0,0 +1,124 @@
"""
Base class for sparse attention policies.
Sparse attention policies determine which KV cache blocks to load
from CPU for each query chunk during chunked attention computation.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Any
import torch
@dataclass
class PolicyContext:
"""
Context passed to sparse policy for block selection.
This dataclass contains all information needed by a sparse policy
to decide which blocks to load for the current query chunk.
"""
query_chunk_idx: int
"""Index of the current query chunk (0-indexed)."""
num_query_chunks: int
"""Total number of query chunks in this prefill."""
layer_id: int
"""Current transformer layer index."""
query: Optional[torch.Tensor]
"""
Query tensor for current chunk.
Shape: [1, num_heads, head_dim] for decode, [1, seq_len, num_heads, head_dim] for prefill.
May be None if not available (e.g., some prefill scenarios).
"""
is_prefill: bool
"""True if in prefill phase, False if in decode phase."""
block_size: int = 4096
"""Number of tokens per block."""
total_kv_len: int = 0
"""Total KV sequence length so far (for reference)."""
class SparsePolicy(ABC):
"""
Abstract base class for sparse attention policies.
Subclass this and implement select_blocks() to create custom
sparse attention patterns. The policy receives context about
the current query chunk and returns which KV blocks to load.
Example:
class MySparsePolicy(SparsePolicy):
def select_blocks(self, available_blocks, ctx):
# Load first block and last 2 blocks
if len(available_blocks) <= 3:
return available_blocks
return [available_blocks[0]] + available_blocks[-2:]
"""
@abstractmethod
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select which KV blocks to load for the current query chunk.
This is the core method that defines the sparse attention pattern.
The returned blocks will be loaded from CPU to GPU for attention
computation against the current query chunk.
Args:
available_blocks: List of CPU block IDs that contain KV cache
from previous chunks. These are ordered by
their position in the sequence.
ctx: PolicyContext with information about the current query
chunk, layer, phase (prefill/decode), etc.
Returns:
List of block IDs to load (must be a subset of available_blocks).
The order may affect performance (sequential access is faster).
Returning [] means no previous blocks will be loaded.
"""
pass
def on_block_offloaded(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""
Hook called when a block is offloaded from GPU to CPU.
Override this to collect metadata about blocks (e.g., min/max keys
for Quest-style selection). Default implementation does nothing.
Args:
cpu_block_id: The CPU block ID that was written
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
num_valid_tokens: Number of valid tokens in this block
"""
pass
def reset(self) -> None:
"""
Reset policy state.
Called when starting a new sequence or clearing state.
Default implementation does nothing.
"""
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"

View File

@@ -0,0 +1,284 @@
"""
Quest-style sparse attention policy.
Uses min/max key bounds per block to estimate attention scores
and select Top-K blocks most relevant to the current query.
Reference: Quest paper on query-aware KV cache selection.
"""
import logging
import torch
from dataclasses import dataclass
from typing import List, Tuple, Optional
from .policy import SparsePolicy, PolicyContext
logger = logging.getLogger(__name__)
class BlockMetadataManager:
"""
Manages per-block metadata for Quest-style sparse selection.
Stores min/max key values for each block, which are used to
compute upper bounds on attention scores without loading the
full KV cache.
Memory usage: 2 * num_blocks * num_layers * num_kv_heads * head_dim * dtype_size
Example: 1000 blocks, 28 layers, 4 heads, 128 dim, bf16 = ~57 MB
"""
def __init__(
self,
num_blocks: int,
num_layers: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.bfloat16,
):
"""
Initialize metadata storage.
Args:
num_blocks: Maximum number of CPU blocks
num_layers: Number of transformer layers
num_kv_heads: Number of KV attention heads
head_dim: Dimension per head
dtype: Data type for metadata storage
"""
self.num_blocks = num_blocks
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.dtype = dtype
# Per-block min/max key values: [num_blocks, num_layers, num_heads, head_dim]
shape = (num_blocks, num_layers, num_kv_heads, head_dim)
self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True)
self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True)
# Track which blocks have valid metadata
self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)
def update_metadata(
self,
block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""
Update min/max key bounds for a block.
Called when a block is offloaded to CPU.
Args:
block_id: CPU block ID
layer_id: Layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
num_valid_tokens: Number of valid tokens in this block
"""
if num_valid_tokens == 0:
return
# Get valid keys only
k_valid = k_cache[:num_valid_tokens].cpu() # [num_tokens, heads, dim]
# Compute min/max across token dimension
self.key_min[block_id, layer_id] = k_valid.min(dim=0).values
self.key_max[block_id, layer_id] = k_valid.max(dim=0).values
self.valid_blocks[block_id] = True
def get_block_metadata(
self,
block_ids: List[int],
layer_id: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get min/max keys for specified blocks.
Args:
block_ids: List of CPU block IDs
layer_id: Layer index
Returns:
Tuple of (key_min, key_max) tensors
Shape: [num_blocks, num_kv_heads, head_dim]
"""
key_min = self.key_min[block_ids, layer_id]
key_max = self.key_max[block_ids, layer_id]
return key_min, key_max
def reset(self) -> None:
"""Reset all metadata."""
self.key_min.zero_()
self.key_max.zero_()
self.valid_blocks.zero_()
@dataclass
class QuestConfig:
"""Configuration for QuestPolicy."""
topk_blocks: int = 8
"""Number of top blocks to select based on estimated attention scores."""
threshold_blocks: int = 4
"""If total blocks <= threshold, load all (no scoring needed)."""
include_sink_blocks: int = 0
"""Always include this many sink blocks (first N blocks), in addition to Top-K."""
include_recent_blocks: int = 0
"""Always include this many recent blocks (last N blocks), in addition to Top-K."""
class QuestPolicy(SparsePolicy):
"""
Quest-style Top-K block selection using min/max key bounds.
For each query, computes an upper bound on attention scores for
each block using the stored min/max keys, then selects the Top-K
blocks with highest estimated scores.
Score computation (element-wise upper bound, as in Quest):
score(q, block) = sum_d max(q_d * key_min_d, q_d * key_max_d)
Since every key k in the block satisfies key_min <= k <= key_max
(element-wise), each term q_d * k_d is at most
max(q_d * key_min_d, q_d * key_max_d), so the sum upper-bounds
q · k for every key in the block.
"""
def __init__(
self,
config: QuestConfig,
metadata_manager: BlockMetadataManager,
):
"""
Initialize Quest policy.
Args:
config: QuestConfig with selection parameters
metadata_manager: BlockMetadataManager for min/max key storage
"""
self.config = config
self.metadata = metadata_manager
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select Top-K blocks based on query-key similarity bounds.
If query is not available (some prefill scenarios), falls back
to loading all blocks.
"""
n = len(available_blocks)
# If below threshold or no query, load all
if n <= self.config.threshold_blocks:
return available_blocks
if ctx.query is None:
# No query available - cannot compute scores
return available_blocks
# Get metadata for available blocks
key_min, key_max = self.metadata.get_block_metadata(
available_blocks, ctx.layer_id
)
# Move to query device for computation
device = ctx.query.device
key_min = key_min.to(device, non_blocking=True)
key_max = key_max.to(device, non_blocking=True)
# Compute upper bound scores
# query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
q = ctx.query
if q.dim() == 4:
# Prefill: use mean over sequence length
q = q.mean(dim=1) # [1, num_heads, head_dim]
q = q.squeeze(0) # [num_q_heads, head_dim]
# Handle GQA: query may have more heads than KV
# key_min/key_max: [num_blocks, num_kv_heads, head_dim]
num_q_heads = q.shape[0]
num_kv_heads = key_min.shape[1]
if num_q_heads != num_kv_heads:
# GQA: group query heads and average per KV group
# Reshape q: [num_q_heads, head_dim] -> [num_kv_heads, group_size, head_dim]
group_size = num_q_heads // num_kv_heads
q = q.view(num_kv_heads, group_size, -1).mean(dim=1) # [num_kv_heads, head_dim]
# Score: sum_d max(q_d·k_min_d, q_d·k_max_d), averaged over heads.
# The max must be taken element-wise *before* summing over head_dim;
# max(q·k_min, q·k_max) is not a valid upper bound on q·k.
# key_min/key_max: [num_blocks, num_kv_heads, head_dim]
# q: [num_kv_heads, head_dim]
prod_min = q.unsqueeze(0) * key_min # [num_blocks, kv_heads, head_dim]
prod_max = q.unsqueeze(0) * key_max
scores = torch.maximum(prod_min, prod_max).sum(dim=-1).mean(dim=-1) # [num_blocks]
# Build selection set
selected_indices = set()
# Always include sink blocks
for i in range(min(self.config.include_sink_blocks, n)):
selected_indices.add(i)
# Always include recent blocks
for i in range(max(0, n - self.config.include_recent_blocks), n):
selected_indices.add(i)
# Top-K selection from remaining
remaining_k = max(0, self.config.topk_blocks - len(selected_indices))
if remaining_k > 0:
# Mask out already selected
mask = torch.ones(n, dtype=torch.bool, device=device)
for idx in selected_indices:
mask[idx] = False
if mask.any():
masked_scores = scores.clone()
masked_scores[~mask] = float('-inf')
topk_count = min(remaining_k, mask.sum().item())
if topk_count > 0:
topk_indices = masked_scores.topk(topk_count).indices.cpu().tolist()
selected_indices.update(topk_indices)
# Return in sequential order for better memory access
result = [available_blocks[i] for i in sorted(selected_indices)]
# Log selection info (only for layer 0 to avoid spam)
if ctx.layer_id == 0:
logger.debug(
f"Quest select: {len(result)}/{n} blocks "
f"(topk={self.config.topk_blocks}, sink={self.config.include_sink_blocks}, "
f"recent={self.config.include_recent_blocks})"
)
return result
def on_block_offloaded(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""Update min/max key metadata when block is offloaded."""
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
def reset(self) -> None:
"""Reset metadata."""
self.metadata.reset()
def __repr__(self) -> str:
return (
f"QuestPolicy(topk={self.config.topk_blocks}, "
f"threshold={self.config.threshold_blocks}, "
f"sink={self.config.include_sink_blocks}, "
f"recent={self.config.include_recent_blocks})"
)
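
End-to-end wiring of the two classes above; a sketch that simulates the offload hook with random keys (tensor shapes follow the docstrings; assumes a CUDA build, since the metadata tensors are allocated with `pin_memory=True`):

```python
import torch
from nanovllm.kvcache.sparse import (
    QuestPolicy, QuestConfig, BlockMetadataManager, PolicyContext,
)

num_blocks, num_layers, num_kv_heads, head_dim = 16, 2, 4, 64
metadata = BlockMetadataManager(num_blocks, num_layers, num_kv_heads, head_dim)
policy = QuestPolicy(QuestConfig(topk_blocks=4, include_sink_blocks=1), metadata)

# Simulate offloading 8 full blocks at layer 0
for block_id in range(8):
    k = torch.randn(32, num_kv_heads, head_dim, dtype=torch.bfloat16)
    policy.on_block_offloaded(block_id, layer_id=0, k_cache=k, num_valid_tokens=32)

# Decode-phase selection with a query of shape [1, num_heads, head_dim]
ctx = PolicyContext(
    query_chunk_idx=0, num_query_chunks=1, layer_id=0,
    query=torch.randn(1, num_kv_heads, head_dim, dtype=torch.bfloat16),
    is_prefill=False,
)
selected = policy.select_blocks(list(range(8)), ctx)
print(selected)  # 4 block IDs; block 0 is always included as a sink
```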

View File

@@ -0,0 +1,84 @@
"""
StreamingLLM sparse attention policy.
Keeps only sink tokens (beginning) + recent tokens (end);
everything in between is discarded. This bounds memory for
arbitrarily long generation at the cost of intermediate context.
Reference: StreamingLLM paper on attention sinks.
"""
from dataclasses import dataclass
from typing import List, Optional
from .policy import SparsePolicy, PolicyContext
@dataclass
class StreamingLLMConfig:
"""Configuration for StreamingLLMPolicy."""
num_sink_blocks: int = 1
"""Number of blocks at the beginning to always include (attention sinks)."""
num_recent_blocks: int = 3
"""Number of most recent blocks to include (sliding window)."""
class StreamingLLMPolicy(SparsePolicy):
"""
StreamingLLM pattern: sink tokens + recent tokens only.
This is the most aggressive sparsity pattern - only keeps a small
fixed window of context. Suitable for:
- Very long streaming generation
- When intermediate context can be safely discarded
- Maximizing throughput over accuracy
Pattern visualization:
```
Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
         ↑   ×   ×   ×   ×   ×   ↑   ↑   ↑
       sink      (discarded)    recent window
```
Warning: This loses information from intermediate blocks!
Use only when this trade-off is acceptable.
"""
def __init__(self, config: Optional[StreamingLLMConfig] = None):
self.config = config or StreamingLLMConfig()
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select sink blocks + recent blocks only.
Intermediate blocks are not loaded (effectively discarded).
"""
n = len(available_blocks)
# If total blocks fit in sink + recent, load all
total_keep = self.config.num_sink_blocks + self.config.num_recent_blocks
if n <= total_keep:
return available_blocks
selected_indices = set()
# Sink blocks (first N)
for i in range(min(self.config.num_sink_blocks, n)):
selected_indices.add(i)
# Recent blocks (last M)
for i in range(max(0, n - self.config.num_recent_blocks), n):
selected_indices.add(i)
return [available_blocks[i] for i in sorted(selected_indices)]
def __repr__(self) -> str:
return (
f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
f"recent={self.config.num_recent_blocks})"
)
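
The pattern from the visualization above, reproduced with the defaults (1 sink block, 3 recent blocks):

```python
from nanovllm.kvcache.sparse import StreamingLLMPolicy, PolicyContext

policy = StreamingLLMPolicy()  # defaults: num_sink_blocks=1, num_recent_blocks=3
ctx = PolicyContext(query_chunk_idx=0, num_query_chunks=1, layer_id=0, query=None, is_prefill=False)
print(policy.select_blocks(list(range(9)), ctx))  # [0, 6, 7, 8] -- blocks 1-5 discarded
```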

View File

@@ -0,0 +1,95 @@
"""
Vertical-Slash sparse attention policy (MInference-style).
Selects sink blocks (beginning of sequence) + local window blocks
(near the current query position). This pattern captures:
- Important initial context (system prompt, instructions)
- Recent context (relevant for local dependencies)
"""
from dataclasses import dataclass
from typing import List, Optional
from .policy import SparsePolicy, PolicyContext
@dataclass
class VerticalSlashConfig:
"""Configuration for VerticalSlashPolicy."""
num_sink_blocks: int = 1
"""Number of blocks at the beginning to always include (sink tokens)."""
local_window_blocks: int = 2
"""Number of blocks in the local window near current query position."""
threshold_blocks: int = 4
"""If total blocks <= threshold, load all (no sparsity applied)."""
class VerticalSlashPolicy(SparsePolicy):
"""
Vertical-Slash pattern: sink tokens + local window.
This pattern is inspired by MInference and observations that:
1. Initial tokens (sink) often receive high attention
2. Local context (recent tokens) is important for dependencies
Pattern visualization:
```
Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
         ↑                       ↑   ↑   ↑
       sink         local window (for query at block 9)
```
For prefill chunk K, the local window is blocks [K-window, K-1].
For decode, the local window is the last N blocks.
"""
def __init__(self, config: Optional[VerticalSlashConfig] = None):
self.config = config or VerticalSlashConfig()
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select sink blocks + local window blocks.
For prefill: local window is relative to current chunk position.
For decode: local window is the most recent blocks.
"""
n = len(available_blocks)
# If below threshold, load all
if n <= self.config.threshold_blocks:
return available_blocks
selected_indices = set()
# Sink blocks (first N blocks)
for i in range(min(self.config.num_sink_blocks, n)):
selected_indices.add(i)
# Local window
if ctx.is_prefill:
# For prefill chunk K, local window is blocks [K-window, K-1]
# (blocks before current chunk, not including current)
window_end = min(ctx.query_chunk_idx, n)
window_start = max(0, window_end - self.config.local_window_blocks)
for i in range(window_start, window_end):
selected_indices.add(i)
else:
# For decode, local window is the last M blocks
for i in range(max(0, n - self.config.local_window_blocks), n):
selected_indices.add(i)
# Return blocks in order (maintains sequential access pattern)
return [available_blocks[i] for i in sorted(selected_indices)]
def __repr__(self) -> str:
return (
f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, "
f"window={self.config.local_window_blocks}, "
f"threshold={self.config.threshold_blocks})"
)
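
A quick check of the window placement for both phases with the defaults (sink=1, window=2, threshold=4):

```python
from nanovllm.kvcache.sparse import VerticalSlashPolicy, PolicyContext

policy = VerticalSlashPolicy()
blocks = list(range(9))

# Prefill chunk 6: window covers the two blocks just before the chunk
prefill_ctx = PolicyContext(query_chunk_idx=6, num_query_chunks=9, layer_id=0, query=None, is_prefill=True)
print(policy.select_blocks(blocks, prefill_ctx))  # [0, 4, 5]

# Decode: window covers the two most recent blocks
decode_ctx = PolicyContext(query_chunk_idx=0, num_query_chunks=1, layer_id=0, query=None, is_prefill=False)
print(policy.select_blocks(blocks, decode_ctx))   # [0, 7, 8]
```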