[WIP] Before refactor policies.

This commit is contained in:
Zijie Tian
2026-01-06 20:47:55 +08:00
parent 7cc8a394a5
commit 690492e074
6 changed files with 112 additions and 237 deletions

View File

@@ -5,86 +5,68 @@ Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload. during chunked attention with CPU offload.
Usage: Usage:
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy
# Use built-in policy # Create policy using factory function
policy = VerticalSlashPolicy(VerticalSlashConfig()) policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
# Or create custom policy # Or create custom policy
class MyPolicy(SparsePolicy): class MyPolicy(SparsePolicy):
supports_prefill = True
supports_decode = True
def select_blocks(self, available_blocks, ctx): def select_blocks(self, available_blocks, ctx):
return available_blocks[:5] # Just first 5 blocks return available_blocks[:5] # Just first 5 blocks
""" """
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
from nanovllm.kvcache.sparse.hybrid import HybridPolicy from nanovllm.kvcache.sparse.hybrid import HybridPolicy
# Built-in policy registry
BUILTIN_SPARSE_POLICIES = {
"full": FullAttentionPolicy,
"vertical_slash": VerticalSlashPolicy,
"streaming_llm": StreamingLLMPolicy,
}
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
""" """
Get a sparse attention policy instance by name. Create a sparse policy instance from an enum type.
The returned policy is not yet initialized. Call policy.initialize()
or let the framework call it during KV cache allocation.
Args: Args:
policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest") policy_type: SparsePolicyType enum value
**kwargs: Policy-specific configuration **kwargs: Policy-specific configuration options
Returns: Returns:
SparsePolicy instance SparsePolicy instance (not initialized)
"""
policy_name = policy_name.lower()
if policy_name == "full": Example:
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
policy.initialize(num_layers=28, num_kv_heads=8, ...)
"""
if policy_type == SparsePolicyType.FULL:
return FullAttentionPolicy() return FullAttentionPolicy()
elif policy_name == "vertical_slash":
config = VerticalSlashConfig( elif policy_type == SparsePolicyType.QUEST:
num_sink_blocks=kwargs.get("num_sink_blocks", 1), config = QuestConfig(
local_window_blocks=kwargs.get("local_window_blocks", 2), topk_blocks=kwargs.get("topk_blocks", 8),
threshold_blocks=kwargs.get("threshold_blocks", 4), threshold_blocks=kwargs.get("threshold_blocks", 4),
include_sink_blocks=kwargs.get("include_sink_blocks", 0),
include_recent_blocks=kwargs.get("include_recent_blocks", 0),
) )
return VerticalSlashPolicy(config) return QuestPolicy(config)
elif policy_name == "streaming_llm":
config = StreamingLLMConfig(
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
num_recent_blocks=kwargs.get("num_recent_blocks", 3),
)
return StreamingLLMPolicy(config)
elif policy_name == "quest":
# Quest requires metadata_manager to be passed separately
raise ValueError(
"Quest policy requires BlockMetadataManager. "
"Use QuestPolicy(config, metadata_manager) directly."
)
else: else:
raise ValueError( raise ValueError(f"Unknown policy type: {policy_type}")
f"Unknown sparse policy '{policy_name}'. "
f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
)
__all__ = [ __all__ = [
"SparsePolicy", "SparsePolicy",
"PolicyContext", "PolicyContext",
"SparsePolicyType",
"FullAttentionPolicy", "FullAttentionPolicy",
"VerticalSlashPolicy",
"VerticalSlashConfig",
"QuestPolicy", "QuestPolicy",
"QuestConfig", "QuestConfig",
"BlockMetadataManager", "BlockMetadataManager",
"StreamingLLMPolicy",
"StreamingLLMConfig",
"HybridPolicy", "HybridPolicy",
"get_sparse_policy", "create_sparse_policy",
"BUILTIN_SPARSE_POLICIES",
] ]

View File

@@ -22,6 +22,10 @@ class FullAttentionPolicy(SparsePolicy):
- For short sequences where sparsity isn't beneficial - For short sequences where sparsity isn't beneficial
""" """
# Full attention supports both prefill and decode
supports_prefill = True
supports_decode = True
def select_blocks( def select_blocks(
self, self,
available_blocks: List[int], available_blocks: List[int],

View File

@@ -7,10 +7,17 @@ from CPU for each query chunk during chunked attention computation.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional, Any from typing import List, Optional, Any
import torch import torch
class SparsePolicyType(Enum):
"""Built-in sparse attention policy types."""
FULL = auto() # prefill + decode
QUEST = auto() # decode only
@dataclass @dataclass
class PolicyContext: class PolicyContext:
""" """
@@ -54,8 +61,15 @@ class SparsePolicy(ABC):
sparse attention patterns. The policy receives context about sparse attention patterns. The policy receives context about
the current query chunk and returns which KV blocks to load. the current query chunk and returns which KV blocks to load.
Attributes:
supports_prefill: Whether this policy can be used for prefill phase.
supports_decode: Whether this policy can be used for decode phase.
Example: Example:
class MySparsePolicy(SparsePolicy): class MySparsePolicy(SparsePolicy):
supports_prefill = False # decode-only policy
supports_decode = True
def select_blocks(self, available_blocks, ctx): def select_blocks(self, available_blocks, ctx):
# Load first block and last 2 blocks # Load first block and last 2 blocks
if len(available_blocks) <= 3: if len(available_blocks) <= 3:
@@ -63,6 +77,34 @@ class SparsePolicy(ABC):
return [available_blocks[0]] + available_blocks[-2:] return [available_blocks[0]] + available_blocks[-2:]
""" """
# Compatibility flags - override in subclasses
supports_prefill: bool = True
supports_decode: bool = True
def initialize(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
) -> None:
"""
Initialize policy resources.
Called by the framework after KV cache is allocated. Override this
to create metadata structures (e.g., BlockMetadataManager for Quest).
Default implementation does nothing.
Args:
num_layers: Number of transformer layers
num_kv_heads: Number of KV attention heads
head_dim: Dimension per head
num_cpu_blocks: Number of CPU blocks allocated
dtype: Data type for tensors
"""
pass
@abstractmethod @abstractmethod
def select_blocks( def select_blocks(
self, self,

View File

@@ -147,22 +147,40 @@ class QuestPolicy(SparsePolicy):
This upper bound is derived from the fact that for any key k in This upper bound is derived from the fact that for any key k in
the block: min_k <= k <= max_k (element-wise), so the actual the block: min_k <= k <= max_k (element-wise), so the actual
attention score is bounded by the maximum of the two extremes. attention score is bounded by the maximum of the two extremes.
Note: This is a decode-only policy. For prefill, use FullAttentionPolicy.
""" """
def __init__( # Quest is decode-only
self, supports_prefill = False
config: QuestConfig, supports_decode = True
metadata_manager: BlockMetadataManager,
): def __init__(self, config: QuestConfig):
""" """
Initialize Quest policy. Initialize Quest policy.
Args: Args:
config: QuestConfig with selection parameters config: QuestConfig with selection parameters
metadata_manager: BlockMetadataManager for min/max key storage
""" """
self.config = config self.config = config
self.metadata = metadata_manager self.metadata: Optional[BlockMetadataManager] = None
def initialize(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
num_cpu_blocks: int,
dtype: torch.dtype,
) -> None:
"""Create BlockMetadataManager for storing min/max keys."""
self.metadata = BlockMetadataManager(
num_blocks=num_cpu_blocks,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
)
def select_blocks( def select_blocks(
self, self,
@@ -175,6 +193,12 @@ class QuestPolicy(SparsePolicy):
If query is not available (some prefill scenarios), falls back If query is not available (some prefill scenarios), falls back
to loading all blocks. to loading all blocks.
""" """
if self.metadata is None:
raise RuntimeError(
"QuestPolicy not initialized. Call initialize() first or "
"let the framework call it during KV cache allocation."
)
n = len(available_blocks) n = len(available_blocks)
# If below threshold or no query, load all # If below threshold or no query, load all
@@ -269,10 +293,12 @@ class QuestPolicy(SparsePolicy):
num_valid_tokens: int, num_valid_tokens: int,
) -> None: ) -> None:
"""Update min/max key metadata when block is offloaded.""" """Update min/max key metadata when block is offloaded."""
if self.metadata is not None:
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens) self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
def reset(self) -> None: def reset(self) -> None:
"""Reset metadata.""" """Reset metadata."""
if self.metadata is not None:
self.metadata.reset() self.metadata.reset()
def __repr__(self) -> str: def __repr__(self) -> str:

View File

@@ -1,84 +0,0 @@
"""
StreamingLLM sparse attention policy.
Only keeps sink tokens (beginning) + recent tokens (end).
Intermediate context is discarded. This enables infinite-length
generation but loses intermediate context.
Reference: StreamingLLM paper on attention sinks.
"""
from dataclasses import dataclass
from typing import List
from .policy import SparsePolicy, PolicyContext
@dataclass
class StreamingLLMConfig:
"""Configuration for StreamingLLMPolicy."""
num_sink_blocks: int = 1
"""Number of blocks at the beginning to always include (attention sinks)."""
num_recent_blocks: int = 3
"""Number of most recent blocks to include (sliding window)."""
class StreamingLLMPolicy(SparsePolicy):
"""
StreamingLLM pattern: sink tokens + recent tokens only.
This is the most aggressive sparsity pattern - only keeps a small
fixed window of context. Suitable for:
- Very long streaming generation
- When intermediate context can be safely discarded
- Maximizing throughput over accuracy
Pattern visualization:
```
Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
× × × ↑ ↑ ↑
sink (discarded) recent window
```
Warning: This loses information from intermediate blocks!
Use only when this trade-off is acceptable.
"""
def __init__(self, config: StreamingLLMConfig = None):
self.config = config or StreamingLLMConfig()
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select sink blocks + recent blocks only.
Intermediate blocks are not loaded (effectively discarded).
"""
n = len(available_blocks)
# If total blocks fit in sink + recent, load all
total_keep = self.config.num_sink_blocks + self.config.num_recent_blocks
if n <= total_keep:
return available_blocks
selected_indices = set()
# Sink blocks (first N)
for i in range(min(self.config.num_sink_blocks, n)):
selected_indices.add(i)
# Recent blocks (last M)
for i in range(max(0, n - self.config.num_recent_blocks), n):
selected_indices.add(i)
return [available_blocks[i] for i in sorted(selected_indices)]
def __repr__(self) -> str:
return (
f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
f"recent={self.config.num_recent_blocks})"
)

View File

@@ -1,95 +0,0 @@
"""
Vertical-Slash sparse attention policy (MInference-style).
Selects sink blocks (beginning of sequence) + local window blocks
(near the current query position). This pattern captures:
- Important initial context (system prompt, instructions)
- Recent context (relevant for local dependencies)
"""
from dataclasses import dataclass
from typing import List
from .policy import SparsePolicy, PolicyContext
@dataclass
class VerticalSlashConfig:
"""Configuration for VerticalSlashPolicy."""
num_sink_blocks: int = 1
"""Number of blocks at the beginning to always include (sink tokens)."""
local_window_blocks: int = 2
"""Number of blocks in the local window near current query position."""
threshold_blocks: int = 4
"""If total blocks <= threshold, load all (no sparsity applied)."""
class VerticalSlashPolicy(SparsePolicy):
"""
Vertical-Slash pattern: sink tokens + local window.
This pattern is inspired by MInference and observations that:
1. Initial tokens (sink) often receive high attention
2. Local context (recent tokens) is important for dependencies
Pattern visualization:
```
Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
↑ ↑ ↑ ↑
sink local window (for query at block 9)
```
For prefill chunk K, the local window is blocks [K-window, K-1].
For decode, the local window is the last N blocks.
"""
def __init__(self, config: VerticalSlashConfig = None):
self.config = config or VerticalSlashConfig()
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select sink blocks + local window blocks.
For prefill: local window is relative to current chunk position.
For decode: local window is the most recent blocks.
"""
n = len(available_blocks)
# If below threshold, load all
if n <= self.config.threshold_blocks:
return available_blocks
selected_indices = set()
# Sink blocks (first N blocks)
for i in range(min(self.config.num_sink_blocks, n)):
selected_indices.add(i)
# Local window
if ctx.is_prefill:
# For prefill chunk K, local window is blocks [K-window, K-1]
# (blocks before current chunk, not including current)
window_end = min(ctx.query_chunk_idx, n)
window_start = max(0, window_end - self.config.local_window_blocks)
for i in range(window_start, window_end):
selected_indices.add(i)
else:
# For decode, local window is the last M blocks
for i in range(max(0, n - self.config.local_window_blocks), n):
selected_indices.add(i)
# Return blocks in order (maintains sequential access pattern)
return [available_blocks[i] for i in sorted(selected_indices)]
def __repr__(self) -> str:
return (
f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, "
f"window={self.config.local_window_blocks}, "
f"threshold={self.config.threshold_blocks})"
)