[feat] Add sparse KV cache feature (needs verification).
@@ -25,6 +25,11 @@ from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.policies.base_policy import EvictionPolicy
from nanovllm.kvcache.policies.lru_policy import LRUPolicy

# Type checking import for sparse policy
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from nanovllm.kvcache.sparse.policy import SparsePolicy


class BlockLocation(Enum):
    """Where a logical block's data currently resides."""
@@ -142,6 +147,9 @@ class HybridKVCacheManager(KVCacheManager):
        # Key: sequence id, Value: starting position where decode began in current block
        self._decode_start_pos: Dict[int, int] = {}

        # Sparse attention policy (optional)
        self.sparse_policy: Optional["SparsePolicy"] = None

    @property
    def block_size(self) -> int:
        return self._block_size
@@ -174,6 +182,24 @@ class HybridKVCacheManager(KVCacheManager):
        assert self.offload_engine is not None
        return self.offload_engine.get_layer_cache(layer_id)

    def set_sparse_policy(self, policy: "SparsePolicy") -> None:
        """
        Set the sparse attention policy used for block selection.

        The sparse policy determines which KV blocks to load from CPU
        for each query chunk during chunked attention computation.

        Args:
            policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)

        Example:
            from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
            policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
            manager.set_sparse_policy(policy)
        """
        self.sparse_policy = policy
        logger.info(f"Sparse attention policy set: {policy}")

    def _allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
        """
        Get a free GPU slot, evicting if necessary.
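For orientation, here is a minimal sketch of how a chunked-attention loop is expected to consult the policy stored in `manager.sparse_policy`. The helper below is illustrative only (it is not part of this diff) and uses nothing beyond the interfaces introduced in this commit:

```python
# Illustrative helper, not part of this commit: shows how a chunked-attention
# path can consult the policy installed via set_sparse_policy().
from typing import List, Optional

import torch

from nanovllm.kvcache.sparse import PolicyContext, SparsePolicy


def select_context_blocks(
    policy: Optional[SparsePolicy],
    available_blocks: List[int],   # ordered CPU block IDs holding earlier KV
    chunk_idx: int,
    num_chunks: int,
    layer_id: int,
    query: Optional[torch.Tensor],
    is_prefill: bool,
    block_size: int,
) -> List[int]:
    """Return the CPU block IDs to load for the current query chunk."""
    if policy is None:
        return available_blocks  # no policy installed -> full attention
    ctx = PolicyContext(
        query_chunk_idx=chunk_idx,
        num_query_chunks=num_chunks,
        layer_id=layer_id,
        query=query,
        is_prefill=is_prefill,
        block_size=block_size,
        total_kv_len=len(available_blocks) * block_size,
    )
    return policy.select_blocks(available_blocks, ctx)
```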
90  nanovllm/kvcache/sparse/__init__.py  Normal file
@@ -0,0 +1,90 @@
"""
Sparse Attention Policy module.

Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.

Usage:
    from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
    from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig, QuestPolicy

    # Use a built-in policy
    policy = VerticalSlashPolicy(VerticalSlashConfig())

    # Or create a custom policy
    class MyPolicy(SparsePolicy):
        def select_blocks(self, available_blocks, ctx):
            return available_blocks[:5]  # Just the first 5 blocks
"""

from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.hybrid import HybridPolicy

# Built-in policy registry ("quest" is omitted: QuestPolicy needs a BlockMetadataManager)
BUILTIN_SPARSE_POLICIES = {
    "full": FullAttentionPolicy,
    "vertical_slash": VerticalSlashPolicy,
    "streaming_llm": StreamingLLMPolicy,
}


def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
    """
    Get a sparse attention policy instance by name.

    Args:
        policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest")
        **kwargs: Policy-specific configuration

    Returns:
        SparsePolicy instance
    """
    policy_name = policy_name.lower()

    if policy_name == "full":
        return FullAttentionPolicy()
    elif policy_name == "vertical_slash":
        config = VerticalSlashConfig(
            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
            local_window_blocks=kwargs.get("local_window_blocks", 2),
            threshold_blocks=kwargs.get("threshold_blocks", 4),
        )
        return VerticalSlashPolicy(config)
    elif policy_name == "streaming_llm":
        config = StreamingLLMConfig(
            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
            num_recent_blocks=kwargs.get("num_recent_blocks", 3),
        )
        return StreamingLLMPolicy(config)
    elif policy_name == "quest":
        # Quest requires a metadata_manager to be passed separately
        raise ValueError(
            "Quest policy requires BlockMetadataManager. "
            "Use QuestPolicy(config, metadata_manager) directly."
        )
    else:
        raise ValueError(
            f"Unknown sparse policy '{policy_name}'. "
            f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
        )


__all__ = [
    "SparsePolicy",
    "PolicyContext",
    "FullAttentionPolicy",
    "VerticalSlashPolicy",
    "VerticalSlashConfig",
    "QuestPolicy",
    "QuestConfig",
    "BlockMetadataManager",
    "StreamingLLMPolicy",
    "StreamingLLMConfig",
    "HybridPolicy",
    "get_sparse_policy",
    "BUILTIN_SPARSE_POLICIES",
]

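A short usage sketch of the factory above; the printed values follow from the defaults and `__repr__` implementations added in this commit:

```python
from nanovllm.kvcache.sparse import BUILTIN_SPARSE_POLICIES, get_sparse_policy

# Name-based construction with per-policy keyword overrides
policy = get_sparse_policy("vertical_slash", num_sink_blocks=2, local_window_blocks=3)
print(policy)                            # VerticalSlashPolicy(sink=2, window=3, threshold=4)
print(sorted(BUILTIN_SPARSE_POLICIES))   # ['full', 'streaming_llm', 'vertical_slash']

# "quest" is deliberately not constructible by name:
try:
    get_sparse_policy("quest")
except ValueError as err:
    print(err)                           # points the caller to QuestPolicy(config, metadata_manager)
```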
34  nanovllm/kvcache/sparse/full_policy.py  Normal file
@@ -0,0 +1,34 @@
"""
Full attention policy - loads all blocks (no sparsity).

This serves as a baseline and default policy when sparse
attention is not needed.
"""

from typing import List
from .policy import SparsePolicy, PolicyContext


class FullAttentionPolicy(SparsePolicy):
    """
    Full attention policy that loads all available blocks.

    This is the default behavior with no sparsity - all previous
    KV cache blocks are loaded for each query chunk.

    Use this:
    - As a baseline for comparing sparse policies
    - When full attention accuracy is required
    - For short sequences where sparsity isn't beneficial
    """

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """Return all blocks - no sparsity."""
        return available_blocks

    def __repr__(self) -> str:
        return "FullAttentionPolicy()"

93  nanovllm/kvcache/sparse/hybrid.py  Normal file
@@ -0,0 +1,93 @@
"""
Hybrid sparse attention policy.

Allows using different policies for prefill vs decode phases.
This is useful because optimal sparsity patterns often differ:
- Prefill: fixed patterns work well (e.g., VerticalSlash)
- Decode: query-aware selection helps (e.g., Quest)
"""

from typing import List
import torch
from .policy import SparsePolicy, PolicyContext


class HybridPolicy(SparsePolicy):
    """
    Hybrid policy that uses different policies for prefill and decode.

    Example usage:
    ```python
    from nanovllm.kvcache.sparse import (
        HybridPolicy, VerticalSlashPolicy, QuestPolicy,
        VerticalSlashConfig, QuestConfig, BlockMetadataManager
    )

    # Prefill: use fast fixed pattern
    prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
        num_sink_blocks=1,
        local_window_blocks=3,
    ))

    # Decode: use query-aware selection
    metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
    decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)

    # Combine
    policy = HybridPolicy(prefill_policy, decode_policy)
    ```
    """

    def __init__(
        self,
        prefill_policy: SparsePolicy,
        decode_policy: SparsePolicy,
    ):
        """
        Initialize hybrid policy.

        Args:
            prefill_policy: Policy to use during prefill phase
            decode_policy: Policy to use during decode phase
        """
        self.prefill_policy = prefill_policy
        self.decode_policy = decode_policy

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """Delegate to the appropriate policy based on phase."""
        if ctx.is_prefill:
            return self.prefill_policy.select_blocks(available_blocks, ctx)
        else:
            return self.decode_policy.select_blocks(available_blocks, ctx)

    def on_block_offloaded(
        self,
        cpu_block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """Forward to both policies (both may need metadata updates)."""
        self.prefill_policy.on_block_offloaded(
            cpu_block_id, layer_id, k_cache, num_valid_tokens
        )
        self.decode_policy.on_block_offloaded(
            cpu_block_id, layer_id, k_cache, num_valid_tokens
        )

    def reset(self) -> None:
        """Reset both policies."""
        self.prefill_policy.reset()
        self.decode_policy.reset()

    def __repr__(self) -> str:
        return (
            f"HybridPolicy(\n"
            f"  prefill={self.prefill_policy},\n"
            f"  decode={self.decode_policy}\n"
            f")"
        )

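A small, runnable demonstration of the phase-based dispatch. It pairs VerticalSlash (prefill) with StreamingLLM (decode) purely so the example needs no BlockMetadataManager, unlike the Quest pairing shown in the class docstring:

```python
from nanovllm.kvcache.sparse import (
    HybridPolicy, PolicyContext,
    StreamingLLMConfig, StreamingLLMPolicy,
    VerticalSlashConfig, VerticalSlashPolicy,
)

hybrid = HybridPolicy(
    prefill_policy=VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=1, local_window_blocks=2)),
    decode_policy=StreamingLLMPolicy(StreamingLLMConfig(num_sink_blocks=1, num_recent_blocks=2)),
)
blocks = list(range(8))

prefill_ctx = PolicyContext(query_chunk_idx=6, num_query_chunks=10, layer_id=0,
                            query=None, is_prefill=True, block_size=4096, total_kv_len=8 * 4096)
decode_ctx = PolicyContext(query_chunk_idx=0, num_query_chunks=1, layer_id=0,
                           query=None, is_prefill=False, block_size=4096, total_kv_len=8 * 4096)

print(hybrid.select_blocks(blocks, prefill_ctx))  # [0, 4, 5]: sink + the two blocks before chunk 6
print(hybrid.select_blocks(blocks, decode_ctx))   # [0, 6, 7]: sink + the two most recent blocks
```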
124  nanovllm/kvcache/sparse/policy.py  Normal file
@@ -0,0 +1,124 @@
"""
Base class for sparse attention policies.

Sparse attention policies determine which KV cache blocks to load
from CPU for each query chunk during chunked attention computation.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Any
import torch


@dataclass
class PolicyContext:
    """
    Context passed to a sparse policy for block selection.

    This dataclass contains all information needed by a sparse policy
    to decide which blocks to load for the current query chunk.
    """

    query_chunk_idx: int
    """Index of the current query chunk (0-indexed)."""

    num_query_chunks: int
    """Total number of query chunks in this prefill."""

    layer_id: int
    """Current transformer layer index."""

    query: Optional[torch.Tensor]
    """
    Query tensor for the current chunk.
    Shape: [1, num_heads, head_dim] for decode, [1, seq_len, num_heads, head_dim] for prefill.
    May be None if not available (e.g., some prefill scenarios).
    """

    is_prefill: bool
    """True if in prefill phase, False if in decode phase."""

    block_size: int = 4096
    """Number of tokens per block."""

    total_kv_len: int = 0
    """Total KV sequence length so far (for reference)."""


class SparsePolicy(ABC):
    """
    Abstract base class for sparse attention policies.

    Subclass this and implement select_blocks() to create custom
    sparse attention patterns. The policy receives context about
    the current query chunk and returns which KV blocks to load.

    Example:
        class MySparsePolicy(SparsePolicy):
            def select_blocks(self, available_blocks, ctx):
                # Load the first block and the last 2 blocks
                if len(available_blocks) <= 3:
                    return available_blocks
                return [available_blocks[0]] + available_blocks[-2:]
    """

    @abstractmethod
    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select which KV blocks to load for the current query chunk.

        This is the core method that defines the sparse attention pattern.
        The returned blocks will be loaded from CPU to GPU for attention
        computation against the current query chunk.

        Args:
            available_blocks: List of CPU block IDs that contain KV cache
                              from previous chunks. These are ordered by
                              their position in the sequence.
            ctx: PolicyContext with information about the current query
                 chunk, layer, phase (prefill/decode), etc.

        Returns:
            List of block IDs to load (must be a subset of available_blocks).
            The order may affect performance (sequential access is faster).
            Returning [] means no previous blocks will be loaded.
        """
        pass

    def on_block_offloaded(
        self,
        cpu_block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """
        Hook called when a block is offloaded from GPU to CPU.

        Override this to collect metadata about blocks (e.g., min/max keys
        for Quest-style selection). The default implementation does nothing.

        Args:
            cpu_block_id: The CPU block ID that was written
            layer_id: Transformer layer index
            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
            num_valid_tokens: Number of valid tokens in this block
        """
        pass

    def reset(self) -> None:
        """
        Reset policy state.

        Called when starting a new sequence or clearing state.
        The default implementation does nothing.
        """
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

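To make the `SparsePolicy` contract concrete, here is the custom policy sketched in the class docstring, filled out and exercised end-to-end (the block IDs and context values are arbitrary):

```python
from typing import List

from nanovllm.kvcache.sparse import PolicyContext, SparsePolicy


class FirstAndRecentPolicy(SparsePolicy):
    """Keep the first block plus the last two blocks."""

    def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
        if len(available_blocks) <= 3:
            return available_blocks
        return [available_blocks[0]] + available_blocks[-2:]


ctx = PolicyContext(
    query_chunk_idx=3, num_query_chunks=4, layer_id=0,
    query=None, is_prefill=True, block_size=4096, total_kv_len=6 * 4096,
)
print(FirstAndRecentPolicy().select_blocks([10, 11, 12, 13, 14, 15], ctx))  # [10, 14, 15]
```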
284  nanovllm/kvcache/sparse/quest.py  Normal file
@@ -0,0 +1,284 @@
"""
Quest-style sparse attention policy.

Uses min/max key bounds per block to estimate attention scores
and select the Top-K blocks most relevant to the current query.

Reference: Quest paper on query-aware KV cache selection.
"""

import logging
import torch
from dataclasses import dataclass
from typing import List, Tuple, Optional
from .policy import SparsePolicy, PolicyContext

logger = logging.getLogger(__name__)


class BlockMetadataManager:
    """
    Manages per-block metadata for Quest-style sparse selection.

    Stores min/max key values for each block, which are used to
    estimate attention scores without loading the full KV cache.

    Memory usage: 2 * num_blocks * num_layers * num_kv_heads * head_dim * dtype_size
    Example: 1000 blocks, 28 layers, 4 heads, 128 dim, bf16 = ~57 MB
    """

    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Initialize metadata storage.

        Args:
            num_blocks: Maximum number of CPU blocks
            num_layers: Number of transformer layers
            num_kv_heads: Number of KV attention heads
            head_dim: Dimension per head
            dtype: Data type for metadata storage
        """
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype

        # Per-block min/max key values: [num_blocks, num_layers, num_kv_heads, head_dim]
        shape = (num_blocks, num_layers, num_kv_heads, head_dim)
        self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True)
        self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True)

        # Track which blocks have valid metadata
        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)

    def update_metadata(
        self,
        block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """
        Update min/max key bounds for a block.

        Called when a block is offloaded to CPU.

        Args:
            block_id: CPU block ID
            layer_id: Layer index
            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
            num_valid_tokens: Number of valid tokens in this block
        """
        if num_valid_tokens == 0:
            return

        # Get valid keys only
        k_valid = k_cache[:num_valid_tokens].cpu()  # [num_tokens, heads, dim]

        # Compute min/max across the token dimension
        self.key_min[block_id, layer_id] = k_valid.min(dim=0).values
        self.key_max[block_id, layer_id] = k_valid.max(dim=0).values
        self.valid_blocks[block_id] = True

    def get_block_metadata(
        self,
        block_ids: List[int],
        layer_id: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get min/max keys for the specified blocks.

        Args:
            block_ids: List of CPU block IDs
            layer_id: Layer index

        Returns:
            Tuple of (key_min, key_max) tensors,
            each of shape [num_blocks, num_kv_heads, head_dim]
        """
        key_min = self.key_min[block_ids, layer_id]
        key_max = self.key_max[block_ids, layer_id]
        return key_min, key_max

    def reset(self) -> None:
        """Reset all metadata."""
        self.key_min.zero_()
        self.key_max.zero_()
        self.valid_blocks.zero_()


@dataclass
class QuestConfig:
    """Configuration for QuestPolicy."""

    topk_blocks: int = 8
    """Number of top blocks to select based on estimated attention scores."""

    threshold_blocks: int = 4
    """If total blocks <= threshold, load all (no scoring needed)."""

    include_sink_blocks: int = 0
    """Always include this many sink blocks (first N blocks), in addition to Top-K."""

    include_recent_blocks: int = 0
    """Always include this many recent blocks (last N blocks), in addition to Top-K."""


class QuestPolicy(SparsePolicy):
    """
    Quest-style Top-K block selection using min/max key bounds.

    For each query, estimates the relevance of each block from the stored
    min/max keys, then selects the Top-K blocks with the highest estimated
    scores.

    Score computation:
        score(q, block) = max(q · key_min, q · key_max)

    Rationale: every key k in a block satisfies key_min <= k <= key_max
    element-wise, so blocks whose key extremes align well with the query
    receive high scores. This is a cheaper proxy for Quest's exact
    per-channel bound sum_i max(q_i * min_i, q_i * max_i).
    """

    def __init__(
        self,
        config: QuestConfig,
        metadata_manager: BlockMetadataManager,
    ):
        """
        Initialize Quest policy.

        Args:
            config: QuestConfig with selection parameters
            metadata_manager: BlockMetadataManager for min/max key storage
        """
        self.config = config
        self.metadata = metadata_manager

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select Top-K blocks based on query-key similarity bounds.

        If the query is not available (some prefill scenarios), falls back
        to loading all blocks.
        """
        n = len(available_blocks)

        # If below threshold or no query, load all
        if n <= self.config.threshold_blocks:
            return available_blocks

        if ctx.query is None:
            # No query available - cannot compute scores
            return available_blocks

        # Get metadata for the available blocks
        key_min, key_max = self.metadata.get_block_metadata(
            available_blocks, ctx.layer_id
        )

        # Move to the query device for computation
        device = ctx.query.device
        key_min = key_min.to(device, non_blocking=True)
        key_max = key_max.to(device, non_blocking=True)

        # Compute estimated relevance scores from the key min/max bounds
        # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
        q = ctx.query
        if q.dim() == 4:
            # Prefill: use mean over sequence length
            q = q.mean(dim=1)  # [1, num_heads, head_dim]
        q = q.squeeze(0)  # [num_q_heads, head_dim]

        # Handle GQA: the query may have more heads than KV
        # key_min/key_max: [num_blocks, num_kv_heads, head_dim]
        num_q_heads = q.shape[0]
        num_kv_heads = key_min.shape[1]

        if num_q_heads != num_kv_heads:
            # GQA: group query heads and average per KV group
            # Reshape q: [num_q_heads, head_dim] -> [num_kv_heads, group_size, head_dim]
            group_size = num_q_heads // num_kv_heads
            q = q.view(num_kv_heads, group_size, -1).mean(dim=1)  # [num_kv_heads, head_dim]

        # Score: max(q·k_min, q·k_max) averaged over heads
        # key_min/key_max: [num_blocks, num_kv_heads, head_dim]
        # q: [num_kv_heads, head_dim]
        score_min = torch.einsum('hd,bhd->bh', q, key_min)  # [num_blocks, kv_heads]
        score_max = torch.einsum('hd,bhd->bh', q, key_max)
        scores = torch.maximum(score_min, score_max).mean(dim=-1)  # [num_blocks]

        # Build the selection set
        selected_indices = set()

        # Always include sink blocks
        for i in range(min(self.config.include_sink_blocks, n)):
            selected_indices.add(i)

        # Always include recent blocks
        for i in range(max(0, n - self.config.include_recent_blocks), n):
            selected_indices.add(i)

        # Top-K selection from the remaining blocks
        remaining_k = max(0, self.config.topk_blocks - len(selected_indices))
        if remaining_k > 0:
            # Mask out already selected blocks
            mask = torch.ones(n, dtype=torch.bool, device=device)
            for idx in selected_indices:
                mask[idx] = False

            if mask.any():
                masked_scores = scores.clone()
                masked_scores[~mask] = float('-inf')
                topk_count = min(remaining_k, mask.sum().item())
                if topk_count > 0:
                    topk_indices = masked_scores.topk(topk_count).indices.cpu().tolist()
                    selected_indices.update(topk_indices)

        # Return in sequential order for better memory access
        result = [available_blocks[i] for i in sorted(selected_indices)]

        # Log selection info (only for layer 0 to avoid spam)
        if ctx.layer_id == 0:
            logger.debug(
                f"Quest select: {len(result)}/{n} blocks "
                f"(topk={self.config.topk_blocks}, sink={self.config.include_sink_blocks}, "
                f"recent={self.config.include_recent_blocks})"
            )

        return result

    def on_block_offloaded(
        self,
        cpu_block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
        """Update min/max key metadata when a block is offloaded."""
        self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)

    def reset(self) -> None:
        """Reset metadata."""
        self.metadata.reset()

    def __repr__(self) -> str:
        return (
            f"QuestPolicy(topk={self.config.topk_blocks}, "
            f"threshold={self.config.threshold_blocks}, "
            f"sink={self.config.include_sink_blocks}, "
            f"recent={self.config.include_recent_blocks})"
        )

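A self-contained sketch of the Quest flow: metadata capture when blocks are offloaded, then query-aware Top-K selection at decode time. The sizes are illustrative, the selected IDs depend on the random tensors, and the pinned-memory allocation in BlockMetadataManager assumes a CUDA-enabled PyTorch build:

```python
import torch

from nanovllm.kvcache.sparse import (
    BlockMetadataManager, PolicyContext, QuestConfig, QuestPolicy,
)

num_blocks, num_layers, num_kv_heads, head_dim, block_size = 16, 2, 4, 64, 256
meta = BlockMetadataManager(num_blocks, num_layers, num_kv_heads, head_dim)
policy = QuestPolicy(
    QuestConfig(topk_blocks=4, include_sink_blocks=1, include_recent_blocks=1), meta
)

# Simulate offloading 10 full blocks of layer 0: min/max keys are recorded per block.
for block_id in range(10):
    k_cache = torch.randn(block_size, num_kv_heads, head_dim, dtype=torch.bfloat16)
    policy.on_block_offloaded(block_id, layer_id=0, k_cache=k_cache, num_valid_tokens=block_size)

# Decode-time selection with 8 query heads (GQA, group size 2 over 4 KV heads).
query = torch.randn(1, 8, head_dim, dtype=torch.bfloat16)
ctx = PolicyContext(
    query_chunk_idx=0, num_query_chunks=1, layer_id=0,
    query=query, is_prefill=False, block_size=block_size, total_kv_len=10 * block_size,
)
print(policy.select_blocks(list(range(10)), ctx))  # 4 block IDs, always containing 0 (sink) and 9 (recent)
```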
84  nanovllm/kvcache/sparse/streaming_llm.py  Normal file
@@ -0,0 +1,84 @@
"""
StreamingLLM sparse attention policy.

Only keeps sink tokens (beginning) + recent tokens (end).
Intermediate context is discarded. This enables effectively unbounded
generation length at the cost of losing intermediate context.

Reference: StreamingLLM paper on attention sinks.
"""

from dataclasses import dataclass
from typing import List, Optional
from .policy import SparsePolicy, PolicyContext


@dataclass
class StreamingLLMConfig:
    """Configuration for StreamingLLMPolicy."""

    num_sink_blocks: int = 1
    """Number of blocks at the beginning to always include (attention sinks)."""

    num_recent_blocks: int = 3
    """Number of most recent blocks to include (sliding window)."""


class StreamingLLMPolicy(SparsePolicy):
    """
    StreamingLLM pattern: sink tokens + recent tokens only.

    This is the most aggressive sparsity pattern - only a small fixed
    window of context is kept. Suitable for:
    - Very long streaming generation
    - When intermediate context can be safely discarded
    - Maximizing throughput over accuracy

    Pattern visualization:
    ```
    Blocks:  [0] [1] [2] [3] [4] [5] [6] [7] [8]
              ↑   ×   ×   ×   ×   ×   ↑   ↑   ↑
            sink     (discarded)     recent window
    ```

    Warning: This loses information from intermediate blocks!
    Use only when this trade-off is acceptable.
    """

    def __init__(self, config: Optional[StreamingLLMConfig] = None):
        self.config = config or StreamingLLMConfig()

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select sink blocks + recent blocks only.

        Intermediate blocks are not loaded (effectively discarded).
        """
        n = len(available_blocks)

        # If the total block count fits in sink + recent, load all
        total_keep = self.config.num_sink_blocks + self.config.num_recent_blocks
        if n <= total_keep:
            return available_blocks

        selected_indices = set()

        # Sink blocks (first N)
        for i in range(min(self.config.num_sink_blocks, n)):
            selected_indices.add(i)

        # Recent blocks (last M)
        for i in range(max(0, n - self.config.num_recent_blocks), n):
            selected_indices.add(i)

        return [available_blocks[i] for i in sorted(selected_indices)]

    def __repr__(self) -> str:
        return (
            f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
            f"recent={self.config.num_recent_blocks})"
        )

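Concrete selection behaviour with the defaults used above (sink=1, recent=3):

```python
from nanovllm.kvcache.sparse import PolicyContext, StreamingLLMConfig, StreamingLLMPolicy

policy = StreamingLLMPolicy(StreamingLLMConfig(num_sink_blocks=1, num_recent_blocks=3))
ctx = PolicyContext(query_chunk_idx=0, num_query_chunks=1, layer_id=0,
                    query=None, is_prefill=False, block_size=4096, total_kv_len=9 * 4096)

print(policy.select_blocks(list(range(9)), ctx))  # [0, 6, 7, 8] - matches the sketch above
print(policy.select_blocks(list(range(3)), ctx))  # [0, 1, 2]   - within sink+recent, keep everything
```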
95  nanovllm/kvcache/sparse/vertical_slash.py  Normal file
@@ -0,0 +1,95 @@
"""
Vertical-Slash sparse attention policy (MInference-style).

Selects sink blocks (beginning of sequence) + local window blocks
(near the current query position). This pattern captures:
- Important initial context (system prompt, instructions)
- Recent context (relevant for local dependencies)
"""

from dataclasses import dataclass
from typing import List, Optional
from .policy import SparsePolicy, PolicyContext


@dataclass
class VerticalSlashConfig:
    """Configuration for VerticalSlashPolicy."""

    num_sink_blocks: int = 1
    """Number of blocks at the beginning to always include (sink tokens)."""

    local_window_blocks: int = 2
    """Number of blocks in the local window near the current query position."""

    threshold_blocks: int = 4
    """If total blocks <= threshold, load all (no sparsity applied)."""


class VerticalSlashPolicy(SparsePolicy):
    """
    Vertical-Slash pattern: sink tokens + local window.

    This pattern is inspired by MInference and the observations that:
    1. Initial tokens (sink) often receive high attention
    2. Local context (recent tokens) is important for dependencies

    Pattern visualization:
    ```
    Blocks:  [0] [1] [2] [3] [4] [5] [6] [7] [8]
              ↑                       ↑   ↑   ↑
            sink              local window (for query at block 9)
    ```

    For prefill chunk K, the local window is blocks [K-window, K-1].
    For decode, the local window is the last N blocks.
    """

    def __init__(self, config: Optional[VerticalSlashConfig] = None):
        self.config = config or VerticalSlashConfig()

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select sink blocks + local window blocks.

        For prefill: the local window is relative to the current chunk position.
        For decode: the local window is the most recent blocks.
        """
        n = len(available_blocks)

        # If below threshold, load all
        if n <= self.config.threshold_blocks:
            return available_blocks

        selected_indices = set()

        # Sink blocks (first N blocks)
        for i in range(min(self.config.num_sink_blocks, n)):
            selected_indices.add(i)

        # Local window
        if ctx.is_prefill:
            # For prefill chunk K, the local window is blocks [K-window, K-1]
            # (blocks before the current chunk, not including the current one)
            window_end = min(ctx.query_chunk_idx, n)
            window_start = max(0, window_end - self.config.local_window_blocks)
            for i in range(window_start, window_end):
                selected_indices.add(i)
        else:
            # For decode, the local window is the last M blocks
            for i in range(max(0, n - self.config.local_window_blocks), n):
                selected_indices.add(i)

        # Return blocks in order (maintains a sequential access pattern)
        return [available_blocks[i] for i in sorted(selected_indices)]

    def __repr__(self) -> str:
        return (
            f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, "
            f"window={self.config.local_window_blocks}, "
            f"threshold={self.config.threshold_blocks})"
        )

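And the same kind of check for the vertical-slash pattern, showing how the local window tracks the chunk index during prefill and the sequence tail during decode:

```python
from nanovllm.kvcache.sparse import PolicyContext, VerticalSlashConfig, VerticalSlashPolicy

policy = VerticalSlashPolicy(
    VerticalSlashConfig(num_sink_blocks=1, local_window_blocks=2, threshold_blocks=4)
)
blocks = list(range(9))

prefill_ctx = PolicyContext(query_chunk_idx=5, num_query_chunks=9, layer_id=0,
                            query=None, is_prefill=True, block_size=4096, total_kv_len=9 * 4096)
decode_ctx = PolicyContext(query_chunk_idx=0, num_query_chunks=1, layer_id=0,
                           query=None, is_prefill=False, block_size=4096, total_kv_len=9 * 4096)

print(policy.select_blocks(blocks, prefill_ctx))  # [0, 3, 4]: sink + the two blocks before chunk 5
print(policy.select_blocks(blocks, decode_ctx))   # [0, 7, 8]: sink + the two most recent blocks
```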