[WIP] Before refactoring policies.

This commit is contained in:
Zijie Tian
2026-01-06 20:47:55 +08:00
parent 7cc8a394a5
commit 690492e074
6 changed files with 112 additions and 237 deletions

View File

@@ -5,86 +5,68 @@ Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.
Usage:
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
# Use built-in policy
policy = VerticalSlashPolicy(VerticalSlashConfig())
# Create policy using factory function
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
# Or create custom policy
class MyPolicy(SparsePolicy):
supports_prefill = True
supports_decode = True
def select_blocks(self, available_blocks, ctx):
return available_blocks[:5] # Just first 5 blocks
"""
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
# Built-in policy registry
BUILTIN_SPARSE_POLICIES = {
"full": FullAttentionPolicy,
"vertical_slash": VerticalSlashPolicy,
"streaming_llm": StreamingLLMPolicy,
}
def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
"""
Get a sparse attention policy instance by name.
Create a sparse policy instance from an enum type.
The returned policy is not yet initialized. Call policy.initialize()
or let the framework call it during KV cache allocation.
Args:
policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest")
**kwargs: Policy-specific configuration
policy_type: SparsePolicyType enum value
**kwargs: Policy-specific configuration options
Returns:
SparsePolicy instance
"""
policy_name = policy_name.lower()
SparsePolicy instance (not initialized)
if policy_name == "full":
Example:
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
policy.initialize(num_layers=28, num_kv_heads=8, ...)
"""
if policy_type == SparsePolicyType.FULL:
return FullAttentionPolicy()
elif policy_name == "vertical_slash":
config = VerticalSlashConfig(
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
local_window_blocks=kwargs.get("local_window_blocks", 2),
elif policy_type == SparsePolicyType.QUEST:
config = QuestConfig(
topk_blocks=kwargs.get("topk_blocks", 8),
threshold_blocks=kwargs.get("threshold_blocks", 4),
include_sink_blocks=kwargs.get("include_sink_blocks", 0),
include_recent_blocks=kwargs.get("include_recent_blocks", 0),
)
return VerticalSlashPolicy(config)
elif policy_name == "streaming_llm":
config = StreamingLLMConfig(
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
num_recent_blocks=kwargs.get("num_recent_blocks", 3),
)
return StreamingLLMPolicy(config)
elif policy_name == "quest":
# Quest requires metadata_manager to be passed separately
raise ValueError(
"Quest policy requires BlockMetadataManager. "
"Use QuestPolicy(config, metadata_manager) directly."
)
return QuestPolicy(config)
else:
raise ValueError(
f"Unknown sparse policy '{policy_name}'. "
f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
)
raise ValueError(f"Unknown policy type: {policy_type}")
__all__ = [
"SparsePolicy",
"PolicyContext",
"SparsePolicyType",
"FullAttentionPolicy",
"VerticalSlashPolicy",
"VerticalSlashConfig",
"QuestPolicy",
"QuestConfig",
"BlockMetadataManager",
"StreamingLLMPolicy",
"StreamingLLMConfig",
"HybridPolicy",
"get_sparse_policy",
"BUILTIN_SPARSE_POLICIES",
"create_sparse_policy",
]
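For reference, a minimal sketch of the call sequence the module docstring describes: build an uninitialized policy through the new factory, then initialize it. The argument values (layer count, head count, head dimension, block count, dtype) are illustrative placeholders; in normal operation the framework calls initialize() during KV cache allocation.

```python
import torch

from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType

# Build an uninitialized policy from the enum, then initialize it explicitly.
# (Values below are placeholders; the framework normally supplies them.)
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
policy.initialize(
    num_layers=28,
    num_kv_heads=8,
    head_dim=128,
    num_cpu_blocks=1024,
    dtype=torch.bfloat16,
)
```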

View File

@@ -22,6 +22,10 @@ class FullAttentionPolicy(SparsePolicy):
- For short sequences where sparsity isn't beneficial
"""
# Full attention supports both prefill and decode
supports_prefill = True
supports_decode = True
def select_blocks(
self,
available_blocks: List[int],

View File

@@ -7,10 +7,17 @@ from CPU for each query chunk during chunked attention computation.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional, Any
import torch
class SparsePolicyType(Enum):
"""Built-in sparse attention policy types."""
FULL = auto() # prefill + decode
QUEST = auto() # decode only
@dataclass
class PolicyContext:
"""
@@ -54,8 +61,15 @@ class SparsePolicy(ABC):
sparse attention patterns. The policy receives context about
the current query chunk and returns which KV blocks to load.
Attributes:
supports_prefill: Whether this policy can be used for the prefill phase.
supports_decode: Whether this policy can be used for the decode phase.
Example:
class MySparsePolicy(SparsePolicy):
supports_prefill = False # decode-only policy
supports_decode = True
def select_blocks(self, available_blocks, ctx):
# Load first block and last 2 blocks
if len(available_blocks) <= 3:
@@ -63,6 +77,34 @@ class SparsePolicy(ABC):
return [available_blocks[0]] + available_blocks[-2:]
"""
# Compatibility flags - override in subclasses
supports_prefill: bool = True
supports_decode: bool = True
def initialize(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
) -> None:
"""
Initialize policy resources.
Called by the framework after the KV cache is allocated. Override this
to create metadata structures (e.g., BlockMetadataManager for Quest).
Default implementation does nothing.
Args:
num_layers: Number of transformer layers
num_kv_heads: Number of KV attention heads
head_dim: Dimension per head
num_cpu_blocks: Number of CPU blocks allocated
dtype: Data type for tensors
"""
pass
@abstractmethod
def select_blocks(
self,

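To make the new compatibility flags and the optional initialize() hook concrete, here is a minimal decode-only policy sketched from the SparsePolicy docstring above. The class name is hypothetical and nothing in this block is part of the commit itself.

```python
from typing import List

from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext

class SinkPlusRecentPolicy(SparsePolicy):
    """Hypothetical policy: load the first block plus the two newest blocks."""
    supports_prefill = False  # decode-only, mirroring the docstring example
    supports_decode = True

    def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
        # With three or fewer blocks there is nothing to drop.
        if len(available_blocks) <= 3:
            return available_blocks
        return [available_blocks[0]] + available_blocks[-2:]
```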
View File

@@ -147,22 +147,40 @@ class QuestPolicy(SparsePolicy):
This upper bound follows from the fact that for any key k in the
block, min_k <= k <= max_k (element-wise), so each per-dimension
term q_i * k_i is at most max(q_i * min_k_i, q_i * max_k_i); summing
these per-dimension maxima bounds the attention score q . k for
every key in the block.
Note: This is a decode-only policy. For prefill, use FullAttentionPolicy.
"""
def __init__(
self,
config: QuestConfig,
metadata_manager: BlockMetadataManager,
):
# Quest is decode-only
supports_prefill = False
supports_decode = True
def __init__(self, config: QuestConfig):
"""
Initialize Quest policy.
Args:
config: QuestConfig with selection parameters
metadata_manager: BlockMetadataManager for min/max key storage
"""
self.config = config
self.metadata = metadata_manager
self.metadata: Optional[BlockMetadataManager] = None
def initialize(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
) -> None:
"""Create BlockMetadataManager for storing min/max keys."""
self.metadata = BlockMetadataManager(
num_blocks=num_cpu_blocks,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
)
def select_blocks(
self,
@@ -175,6 +193,12 @@ class QuestPolicy(SparsePolicy):
If query is not available (some prefill scenarios), falls back
to loading all blocks.
"""
if self.metadata is None:
raise RuntimeError(
"QuestPolicy not initialized. Call initialize() first or "
"let the framework call it during KV cache allocation."
)
n = len(available_blocks)
# If below threshold or no query, load all
@@ -269,10 +293,12 @@ class QuestPolicy(SparsePolicy):
num_valid_tokens: int,
) -> None:
"""Update min/max key metadata when block is offloaded."""
if self.metadata is not None:
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
def reset(self) -> None:
"""Reset metadata."""
if self.metadata is not None:
self.metadata.reset()
def __repr__(self) -> str:

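As a sanity check on the scoring rule described in the QuestPolicy docstring, here is a hedged sketch of the per-block upper bound computed from stored min/max keys. The tensor shapes are assumptions (the real BlockMetadataManager layout may differ); only the arithmetic is the point.

```python
import torch

def block_score_upper_bound(q: torch.Tensor, k_min: torch.Tensor, k_max: torch.Tensor) -> torch.Tensor:
    """Upper bound on q . k over all keys k in a block (Quest-style).

    q, k_min, k_max: assumed shape [num_kv_heads, head_dim].
    Since min_k <= k <= max_k element-wise, each term q_i * k_i is at most
    max(q_i * min_k_i, q_i * max_k_i); summing the per-dimension maxima
    bounds the dot product for every key in the block.
    """
    per_dim = torch.maximum(q * k_min, q * k_max)
    return per_dim.sum(dim=-1)  # one bound per KV head

# Blocks would then be ranked by this bound and the top-k loaded from CPU.
```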
View File

@@ -1,84 +0,0 @@
"""
StreamingLLM sparse attention policy.
Keeps only sink tokens (at the beginning) plus recent tokens (at the end);
everything in between is discarded. This enables infinite-length
generation at the cost of the intermediate context.
Reference: StreamingLLM paper on attention sinks.
"""
from dataclasses import dataclass
from typing import List
from .policy import SparsePolicy, PolicyContext
@dataclass
class StreamingLLMConfig:
"""Configuration for StreamingLLMPolicy."""
num_sink_blocks: int = 1
"""Number of blocks at the beginning to always include (attention sinks)."""
num_recent_blocks: int = 3
"""Number of most recent blocks to include (sliding window)."""
class StreamingLLMPolicy(SparsePolicy):
"""
StreamingLLM pattern: sink tokens + recent tokens only.
This is the most aggressive sparsity pattern: it keeps only a small
fixed window of context. Suitable for:
- Very long streaming generation
- When intermediate context can be safely discarded
- Maximizing throughput over accuracy
Pattern visualization:
```
Blocks:  [0]  [1]  [2]  [3]  [4]  [5]  [6]  [7]  [8]
          ↑    ×    ×    ×    ×    ×    ↑    ↑    ↑
         sink       (discarded)         recent window
```
Warning: This loses information from intermediate blocks!
Use only when this trade-off is acceptable.
"""
def __init__(self, config: StreamingLLMConfig = None):
self.config = config or StreamingLLMConfig()
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select sink blocks + recent blocks only.
Intermediate blocks are not loaded (effectively discarded).
"""
n = len(available_blocks)
# If total blocks fit in sink + recent, load all
total_keep = self.config.num_sink_blocks + self.config.num_recent_blocks
if n <= total_keep:
return available_blocks
selected_indices = set()
# Sink blocks (first N)
for i in range(min(self.config.num_sink_blocks, n)):
selected_indices.add(i)
# Recent blocks (last M)
for i in range(max(0, n - self.config.num_recent_blocks), n):
selected_indices.add(i)
return [available_blocks[i] for i in sorted(selected_indices)]
def __repr__(self) -> str:
return (
f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
f"recent={self.config.num_recent_blocks})"
)
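For a quick read on what the removed StreamingLLM pattern selects, here is a small stand-alone replication of its selection arithmetic under the default config (1 sink block, 3 recent blocks). It does not call the deleted class, and the block IDs are made up.

```python
# Nine hypothetical CPU block IDs, oldest first.
available_blocks = list(range(9))
num_sink_blocks, num_recent_blocks = 1, 3

n = len(available_blocks)
keep = set(range(min(num_sink_blocks, n)))
keep |= set(range(max(0, n - num_recent_blocks), n))

print([available_blocks[i] for i in sorted(keep)])  # -> [0, 6, 7, 8]
```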

View File

@@ -1,95 +0,0 @@
"""
Vertical-Slash sparse attention policy (MInference-style).
Selects sink blocks (beginning of sequence) + local window blocks
(near the current query position). This pattern captures:
- Important initial context (system prompt, instructions)
- Recent context (relevant for local dependencies)
"""
from dataclasses import dataclass
from typing import List
from .policy import SparsePolicy, PolicyContext
@dataclass
class VerticalSlashConfig:
"""Configuration for VerticalSlashPolicy."""
num_sink_blocks: int = 1
"""Number of blocks at the beginning to always include (sink tokens)."""
local_window_blocks: int = 2
"""Number of blocks in the local window near current query position."""
threshold_blocks: int = 4
"""If total blocks <= threshold, load all (no sparsity applied)."""
class VerticalSlashPolicy(SparsePolicy):
"""
Vertical-Slash pattern: sink tokens + local window.
This pattern is inspired by MInference and observations that:
1. Initial tokens (sink) often receive high attention
2. Local context (recent tokens) is important for dependencies
Pattern visualization:
```
Blocks:  [0]  [1]  [2]  [3]  [4]  [5]  [6]  [7]  [8]
          ↑                             ↑    ↑    ↑
         sink              local window (for query at block 9)
```
For prefill chunk K, the local window is blocks [K-window, K-1].
For decode, the local window is the last N blocks.
"""
def __init__(self, config: VerticalSlashConfig = None):
self.config = config or VerticalSlashConfig()
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select sink blocks + local window blocks.
For prefill: local window is relative to current chunk position.
For decode: local window is the most recent blocks.
"""
n = len(available_blocks)
# If below threshold, load all
if n <= self.config.threshold_blocks:
return available_blocks
selected_indices = set()
# Sink blocks (first N blocks)
for i in range(min(self.config.num_sink_blocks, n)):
selected_indices.add(i)
# Local window
if ctx.is_prefill:
# For prefill chunk K, local window is blocks [K-window, K-1]
# (blocks before current chunk, not including current)
window_end = min(ctx.query_chunk_idx, n)
window_start = max(0, window_end - self.config.local_window_blocks)
for i in range(window_start, window_end):
selected_indices.add(i)
else:
# For decode, local window is the last M blocks
for i in range(max(0, n - self.config.local_window_blocks), n):
selected_indices.add(i)
# Return blocks in order (maintains sequential access pattern)
return [available_blocks[i] for i in sorted(selected_indices)]
def __repr__(self) -> str:
return (
f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, "
f"window={self.config.local_window_blocks}, "
f"threshold={self.config.threshold_blocks})"
)
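Similarly, a stand-alone replication of the removed Vertical-Slash selection under its defaults (1 sink block, a 2-block local window, threshold of 4), showing how the window differs between a prefill chunk and decode. The chunk index and block count are made up, and the helper function is not part of the commit.

```python
def vertical_slash_select(n_blocks, is_prefill, query_chunk_idx,
                          num_sink=1, window=2, threshold=4):
    # Replicates the selection arithmetic of the deleted VerticalSlashPolicy.
    if n_blocks <= threshold:
        return list(range(n_blocks))
    keep = set(range(min(num_sink, n_blocks)))
    if is_prefill:
        # Local window: the blocks just before the current chunk.
        end = min(query_chunk_idx, n_blocks)
        keep |= set(range(max(0, end - window), end))
    else:
        # For decode, the window is the most recent blocks.
        keep |= set(range(max(0, n_blocks - window), n_blocks))
    return sorted(keep)

print(vertical_slash_select(9, is_prefill=True, query_chunk_idx=5))   # -> [0, 3, 4]
print(vertical_slash_select(9, is_prefill=False, query_chunk_idx=0))  # -> [0, 7, 8]
```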