[WIP] Before refactor policies.
This commit is contained in:
@@ -5,86 +5,68 @@ Provides pluggable policies for selecting which KV blocks to load
|
|||||||
during chunked attention with CPU offload.
|
during chunked attention with CPU offload.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
|
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
||||||
from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy
|
|
||||||
|
|
||||||
# Use built-in policy
|
# Create policy using factory function
|
||||||
policy = VerticalSlashPolicy(VerticalSlashConfig())
|
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
|
||||||
|
|
||||||
# Or create custom policy
|
# Or create custom policy
|
||||||
class MyPolicy(SparsePolicy):
|
class MyPolicy(SparsePolicy):
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = True
|
||||||
|
|
||||||
def select_blocks(self, available_blocks, ctx):
|
def select_blocks(self, available_blocks, ctx):
|
||||||
return available_blocks[:5] # Just first 5 blocks
|
return available_blocks[:5] # Just first 5 blocks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType
|
||||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||||
from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
|
|
||||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||||
from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
|
|
||||||
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
|
from nanovllm.kvcache.sparse.hybrid import HybridPolicy
|
||||||
|
|
||||||
# Built-in policy registry
|
|
||||||
BUILTIN_SPARSE_POLICIES = {
|
|
||||||
"full": FullAttentionPolicy,
|
|
||||||
"vertical_slash": VerticalSlashPolicy,
|
|
||||||
"streaming_llm": StreamingLLMPolicy,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||||
def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
|
|
||||||
"""
|
"""
|
||||||
Get a sparse attention policy instance by name.
|
Create a sparse policy instance from an enum type.
|
||||||
|
|
||||||
|
The returned policy is not yet initialized. Call policy.initialize()
|
||||||
|
or let the framework call it during KV cache allocation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest")
|
policy_type: SparsePolicyType enum value
|
||||||
**kwargs: Policy-specific configuration
|
**kwargs: Policy-specific configuration options
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SparsePolicy instance
|
SparsePolicy instance (not initialized)
|
||||||
"""
|
|
||||||
policy_name = policy_name.lower()
|
|
||||||
|
|
||||||
if policy_name == "full":
|
Example:
|
||||||
|
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
|
||||||
|
policy.initialize(num_layers=28, num_kv_heads=8, ...)
|
||||||
|
"""
|
||||||
|
if policy_type == SparsePolicyType.FULL:
|
||||||
return FullAttentionPolicy()
|
return FullAttentionPolicy()
|
||||||
elif policy_name == "vertical_slash":
|
|
||||||
config = VerticalSlashConfig(
|
elif policy_type == SparsePolicyType.QUEST:
|
||||||
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
|
config = QuestConfig(
|
||||||
local_window_blocks=kwargs.get("local_window_blocks", 2),
|
topk_blocks=kwargs.get("topk_blocks", 8),
|
||||||
threshold_blocks=kwargs.get("threshold_blocks", 4),
|
threshold_blocks=kwargs.get("threshold_blocks", 4),
|
||||||
|
include_sink_blocks=kwargs.get("include_sink_blocks", 0),
|
||||||
|
include_recent_blocks=kwargs.get("include_recent_blocks", 0),
|
||||||
)
|
)
|
||||||
return VerticalSlashPolicy(config)
|
return QuestPolicy(config)
|
||||||
elif policy_name == "streaming_llm":
|
|
||||||
config = StreamingLLMConfig(
|
|
||||||
num_sink_blocks=kwargs.get("num_sink_blocks", 1),
|
|
||||||
num_recent_blocks=kwargs.get("num_recent_blocks", 3),
|
|
||||||
)
|
|
||||||
return StreamingLLMPolicy(config)
|
|
||||||
elif policy_name == "quest":
|
|
||||||
# Quest requires metadata_manager to be passed separately
|
|
||||||
raise ValueError(
|
|
||||||
"Quest policy requires BlockMetadataManager. "
|
|
||||||
"Use QuestPolicy(config, metadata_manager) directly."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(f"Unknown policy type: {policy_type}")
|
||||||
f"Unknown sparse policy '{policy_name}'. "
|
|
||||||
f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SparsePolicy",
|
"SparsePolicy",
|
||||||
"PolicyContext",
|
"PolicyContext",
|
||||||
|
"SparsePolicyType",
|
||||||
"FullAttentionPolicy",
|
"FullAttentionPolicy",
|
||||||
"VerticalSlashPolicy",
|
|
||||||
"VerticalSlashConfig",
|
|
||||||
"QuestPolicy",
|
"QuestPolicy",
|
||||||
"QuestConfig",
|
"QuestConfig",
|
||||||
"BlockMetadataManager",
|
"BlockMetadataManager",
|
||||||
"StreamingLLMPolicy",
|
|
||||||
"StreamingLLMConfig",
|
|
||||||
"HybridPolicy",
|
"HybridPolicy",
|
||||||
"get_sparse_policy",
|
"create_sparse_policy",
|
||||||
"BUILTIN_SPARSE_POLICIES",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -22,6 +22,10 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
- For short sequences where sparsity isn't beneficial
|
- For short sequences where sparsity isn't beneficial
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Full attention supports both prefill and decode
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = True
|
||||||
|
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
|
|||||||
@@ -7,10 +7,17 @@ from CPU for each query chunk during chunked attention computation.
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, auto
|
||||||
from typing import List, Optional, Any
|
from typing import List, Optional, Any
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class SparsePolicyType(Enum):
    """Built-in sparse attention policy types.

    Members are numbered automatically with ``auto()``. The trailing
    comment on each member records which phases that policy is meant for.
    """

    FULL = auto()   # prefill + decode
    QUEST = auto()  # decode only
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PolicyContext:
|
class PolicyContext:
|
||||||
"""
|
"""
|
||||||
@@ -54,8 +61,15 @@ class SparsePolicy(ABC):
|
|||||||
sparse attention patterns. The policy receives context about
|
sparse attention patterns. The policy receives context about
|
||||||
the current query chunk and returns which KV blocks to load.
|
the current query chunk and returns which KV blocks to load.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
supports_prefill: Whether this policy can be used for prefill phase.
|
||||||
|
supports_decode: Whether this policy can be used for decode phase.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
class MySparsePolicy(SparsePolicy):
|
class MySparsePolicy(SparsePolicy):
|
||||||
|
supports_prefill = False # decode-only policy
|
||||||
|
supports_decode = True
|
||||||
|
|
||||||
def select_blocks(self, available_blocks, ctx):
|
def select_blocks(self, available_blocks, ctx):
|
||||||
# Load first block and last 2 blocks
|
# Load first block and last 2 blocks
|
||||||
if len(available_blocks) <= 3:
|
if len(available_blocks) <= 3:
|
||||||
@@ -63,6 +77,34 @@ class SparsePolicy(ABC):
|
|||||||
return [available_blocks[0]] + available_blocks[-2:]
|
return [available_blocks[0]] + available_blocks[-2:]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Compatibility flags - override in subclasses
|
||||||
|
supports_prefill: bool = True
|
||||||
|
supports_decode: bool = True
|
||||||
|
|
||||||
|
def initialize(
    self,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    num_cpu_blocks: int,
    dtype: torch.dtype,
) -> None:
    """
    Initialize policy resources.

    Called by the framework after KV cache is allocated. Override this
    to create metadata structures (e.g., BlockMetadataManager for Quest).
    Default implementation does nothing.

    Args:
        num_layers: Number of transformer layers
        num_kv_heads: Number of KV attention heads
        head_dim: Dimension per head
        num_cpu_blocks: Number of CPU blocks allocated
        dtype: Data type for tensors
    """
    # Intentionally a no-op: policies without per-block state need no setup.
    pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -147,22 +147,40 @@ class QuestPolicy(SparsePolicy):
|
|||||||
This upper bound is derived from the fact that for any key k in
|
This upper bound is derived from the fact that for any key k in
|
||||||
the block: min_k <= k <= max_k (element-wise), so the actual
|
the block: min_k <= k <= max_k (element-wise), so the actual
|
||||||
attention score is bounded by the maximum of the two extremes.
|
attention score is bounded by the maximum of the two extremes.
|
||||||
|
|
||||||
|
Note: This is a decode-only policy. For prefill, use FullAttentionPolicy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
# Quest is decode-only
|
||||||
self,
|
supports_prefill = False
|
||||||
config: QuestConfig,
|
supports_decode = True
|
||||||
metadata_manager: BlockMetadataManager,
|
|
||||||
):
|
def __init__(self, config: QuestConfig):
|
||||||
"""
|
"""
|
||||||
Initialize Quest policy.
|
Initialize Quest policy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: QuestConfig with selection parameters
|
config: QuestConfig with selection parameters
|
||||||
metadata_manager: BlockMetadataManager for min/max key storage
|
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
self.metadata = metadata_manager
|
self.metadata: Optional[BlockMetadataManager] = None
|
||||||
|
|
||||||
|
def initialize(
    self,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    num_cpu_blocks: int,
    dtype: torch.dtype,
) -> None:
    """
    Create BlockMetadataManager for storing min/max keys.

    Replaces ``self.metadata`` with a freshly constructed manager sized
    for ``num_cpu_blocks`` blocks.

    Args:
        num_layers: Number of transformer layers
        num_kv_heads: Number of KV attention heads
        head_dim: Dimension per head
        num_cpu_blocks: Number of CPU blocks allocated (passed to the
            manager as ``num_blocks``)
        dtype: Data type for tensors
    """
    self.metadata = BlockMetadataManager(
        num_blocks=num_cpu_blocks,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        dtype=dtype,
    )
|
||||||
|
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
@@ -175,6 +193,12 @@ class QuestPolicy(SparsePolicy):
|
|||||||
If query is not available (some prefill scenarios), falls back
|
If query is not available (some prefill scenarios), falls back
|
||||||
to loading all blocks.
|
to loading all blocks.
|
||||||
"""
|
"""
|
||||||
|
if self.metadata is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"QuestPolicy not initialized. Call initialize() first or "
|
||||||
|
"let the framework call it during KV cache allocation."
|
||||||
|
)
|
||||||
|
|
||||||
n = len(available_blocks)
|
n = len(available_blocks)
|
||||||
|
|
||||||
# If below threshold or no query, load all
|
# If below threshold or no query, load all
|
||||||
@@ -269,11 +293,13 @@ class QuestPolicy(SparsePolicy):
|
|||||||
num_valid_tokens: int,
|
num_valid_tokens: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update min/max key metadata when block is offloaded."""
|
"""Update min/max key metadata when block is offloaded."""
|
||||||
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
if self.metadata is not None:
|
||||||
|
self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset metadata."""
|
"""Reset metadata."""
|
||||||
self.metadata.reset()
|
if self.metadata is not None:
|
||||||
|
self.metadata.reset()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -1,84 +0,0 @@
|
|||||||
"""
|
|
||||||
StreamingLLM sparse attention policy.
|
|
||||||
|
|
||||||
Only keeps sink tokens (beginning) + recent tokens (end).
|
|
||||||
Intermediate context is discarded. This enables infinite-length
|
|
||||||
generation but loses intermediate context.
|
|
||||||
|
|
||||||
Reference: StreamingLLM paper on attention sinks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List
|
|
||||||
from .policy import SparsePolicy, PolicyContext
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class StreamingLLMConfig:
    """Configuration for StreamingLLMPolicy."""

    num_sink_blocks: int = 1
    """Number of blocks at the beginning to always include (attention sinks)."""

    num_recent_blocks: int = 3
    """Number of most recent blocks to include (sliding window)."""
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingLLMPolicy(SparsePolicy):
    """
    StreamingLLM pattern: sink tokens + recent tokens only.

    This is the most aggressive sparsity pattern - only keeps a small
    fixed window of context. Suitable for:
    - Very long streaming generation
    - When intermediate context can be safely discarded
    - Maximizing throughput over accuracy

    Pattern visualization (default config: 1 sink block, 3 recent blocks):
    ```
    Blocks:   [0] [1] [2] [3] [4] [5] [6] [7] [8]
               ↑   ×   ×   ×   ×   ×   ↑   ↑   ↑
              sink     (discarded)     recent window
    ```

    Warning: This loses information from intermediate blocks!
    Use only when this trade-off is acceptable.
    """

    def __init__(self, config: StreamingLLMConfig = None):
        """
        Store the policy configuration.

        Args:
            config: Selection parameters. When omitted (or falsy), a
                default-constructed StreamingLLMConfig is used.
        """
        self.config = config or StreamingLLMConfig()

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select sink blocks + recent blocks only.

        Intermediate blocks are not loaded (effectively discarded).

        Args:
            available_blocks: Candidate block ids, in sequence order.
            ctx: Policy context (not consulted by this policy).

        Returns:
            ``available_blocks`` itself when it has at most
            ``num_sink_blocks + num_recent_blocks`` entries; otherwise a
            new list with the first ``num_sink_blocks`` and the last
            ``num_recent_blocks`` entries, in their original order.
        """
        n = len(available_blocks)
        num_sink = self.config.num_sink_blocks
        num_recent = self.config.num_recent_blocks

        # If total blocks fit in sink + recent, load all
        if n <= num_sink + num_recent:
            return available_blocks

        # Work on positions (not block ids) so overlapping ranges dedupe
        # and the result can be emitted in sequence order.
        selected_indices = set(range(min(num_sink, n)))              # sink blocks (first N)
        selected_indices.update(range(max(0, n - num_recent), n))    # recent blocks (last M)

        return [available_blocks[i] for i in sorted(selected_indices)]

    def __repr__(self) -> str:
        """Return a short description showing the sink/recent block counts."""
        return (
            f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
            f"recent={self.config.num_recent_blocks})"
        )
|
|
||||||
@@ -1,95 +0,0 @@
|
|||||||
"""
|
|
||||||
Vertical-Slash sparse attention policy (MInference-style).
|
|
||||||
|
|
||||||
Selects sink blocks (beginning of sequence) + local window blocks
|
|
||||||
(near the current query position). This pattern captures:
|
|
||||||
- Important initial context (system prompt, instructions)
|
|
||||||
- Recent context (relevant for local dependencies)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List
|
|
||||||
from .policy import SparsePolicy, PolicyContext
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class VerticalSlashConfig:
    """Configuration for VerticalSlashPolicy."""

    num_sink_blocks: int = 1
    """Number of blocks at the beginning to always include (sink tokens)."""

    local_window_blocks: int = 2
    """Number of blocks in the local window near current query position."""

    threshold_blocks: int = 4
    """If total blocks <= threshold, load all (no sparsity applied)."""
|
|
||||||
|
|
||||||
|
|
||||||
class VerticalSlashPolicy(SparsePolicy):
    """
    Vertical-Slash pattern: sink tokens + local window.

    This pattern is inspired by MInference and observations that:
    1. Initial tokens (sink) often receive high attention
    2. Local context (recent tokens) is important for dependencies

    Pattern visualization (illustrated with 1 sink block and a 3-block window):
    ```
    Blocks:   [0] [1] [2] [3] [4] [5] [6] [7] [8]
               ↑                       ↑   ↑   ↑
              sink                     local window (for query at block 9)
    ```

    For prefill chunk K, the local window is blocks [K-window, K-1].
    For decode, the local window is the last N blocks.
    """

    def __init__(self, config: VerticalSlashConfig = None):
        """
        Store the policy configuration.

        Args:
            config: Selection parameters. When omitted (or falsy), a
                default-constructed VerticalSlashConfig is used.
        """
        self.config = config or VerticalSlashConfig()

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select sink blocks + local window blocks.

        For prefill: local window is relative to current chunk position.
        For decode: local window is the most recent blocks.

        Args:
            available_blocks: Candidate block ids, in sequence order.
            ctx: Policy context; ``ctx.is_prefill`` picks the window mode
                and ``ctx.query_chunk_idx`` positions the prefill window.

        Returns:
            ``available_blocks`` itself when it has at most
            ``threshold_blocks`` entries; otherwise a new list holding the
            selected entries in their original order.
        """
        n = len(available_blocks)

        # If below threshold, load all
        if n <= self.config.threshold_blocks:
            return available_blocks

        window = self.config.local_window_blocks

        # Work on positions (not block ids) so overlapping sink/window
        # ranges dedupe. Sink blocks (first N blocks):
        selected_indices = set(range(min(self.config.num_sink_blocks, n)))

        # Local window
        if ctx.is_prefill:
            # For prefill chunk K, local window is blocks [K-window, K-1]
            # (blocks before current chunk, not including current)
            window_end = min(ctx.query_chunk_idx, n)
            window_start = max(0, window_end - window)
        else:
            # For decode, local window is the last M blocks
            window_end = n
            window_start = max(0, n - window)
        selected_indices.update(range(window_start, window_end))

        # Return blocks in order (maintains sequential access pattern)
        return [available_blocks[i] for i in sorted(selected_indices)]

    def __repr__(self) -> str:
        """Return a short description showing sink/window/threshold sizes."""
        return (
            f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, "
            f"window={self.config.local_window_blocks}, "
            f"threshold={self.config.threshold_blocks})"
        )
|
|
||||||
Reference in New Issue
Block a user