From 690492e074bd02d3773936028ee27435fe4350f6 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Tue, 6 Jan 2026 20:47:55 +0800 Subject: [PATCH] [WIP] Before refactor policies. --- nanovllm/kvcache/sparse/__init__.py | 80 ++++++++----------- nanovllm/kvcache/sparse/full_policy.py | 4 + nanovllm/kvcache/sparse/policy.py | 42 ++++++++++ nanovllm/kvcache/sparse/quest.py | 44 ++++++++--- nanovllm/kvcache/sparse/streaming_llm.py | 84 -------------------- nanovllm/kvcache/sparse/vertical_slash.py | 95 ----------------------- 6 files changed, 112 insertions(+), 237 deletions(-) delete mode 100644 nanovllm/kvcache/sparse/streaming_llm.py delete mode 100644 nanovllm/kvcache/sparse/vertical_slash.py diff --git a/nanovllm/kvcache/sparse/__init__.py b/nanovllm/kvcache/sparse/__init__.py index d397b0f..cfc08f0 100644 --- a/nanovllm/kvcache/sparse/__init__.py +++ b/nanovllm/kvcache/sparse/__init__.py @@ -5,86 +5,68 @@ Provides pluggable policies for selecting which KV blocks to load during chunked attention with CPU offload. Usage: - from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext - from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy + from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType - # Use built-in policy - policy = VerticalSlashPolicy(VerticalSlashConfig()) + # Create policy using factory function + policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8) # Or create custom policy class MyPolicy(SparsePolicy): + supports_prefill = True + supports_decode = True + def select_blocks(self, available_blocks, ctx): return available_blocks[:5] # Just first 5 blocks """ -from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext +from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext, SparsePolicyType from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy -from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager -from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig from nanovllm.kvcache.sparse.hybrid import HybridPolicy -# Built-in policy registry -BUILTIN_SPARSE_POLICIES = { - "full": FullAttentionPolicy, - "vertical_slash": VerticalSlashPolicy, - "streaming_llm": StreamingLLMPolicy, -} - -def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy: +def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy: """ - Get a sparse attention policy instance by name. + Create a sparse policy instance from an enum type. + + The returned policy is not yet initialized. Call policy.initialize() + or let the framework call it during KV cache allocation. Args: - policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest") - **kwargs: Policy-specific configuration + policy_type: SparsePolicyType enum value + **kwargs: Policy-specific configuration options Returns: - SparsePolicy instance - """ - policy_name = policy_name.lower() + SparsePolicy instance (not initialized) - if policy_name == "full": + Example: + policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4) + policy.initialize(num_layers=28, num_kv_heads=8, ...) + """ + if policy_type == SparsePolicyType.FULL: return FullAttentionPolicy() - elif policy_name == "vertical_slash": - config = VerticalSlashConfig( - num_sink_blocks=kwargs.get("num_sink_blocks", 1), - local_window_blocks=kwargs.get("local_window_blocks", 2), + + elif policy_type == SparsePolicyType.QUEST: + config = QuestConfig( + topk_blocks=kwargs.get("topk_blocks", 8), threshold_blocks=kwargs.get("threshold_blocks", 4), + include_sink_blocks=kwargs.get("include_sink_blocks", 0), + include_recent_blocks=kwargs.get("include_recent_blocks", 0), ) - return VerticalSlashPolicy(config) - elif policy_name == "streaming_llm": - config = StreamingLLMConfig( - num_sink_blocks=kwargs.get("num_sink_blocks", 1), - num_recent_blocks=kwargs.get("num_recent_blocks", 3), - ) - return StreamingLLMPolicy(config) - elif policy_name == "quest": - # Quest requires metadata_manager to be passed separately - raise ValueError( - "Quest policy requires BlockMetadataManager. " - "Use QuestPolicy(config, metadata_manager) directly." - ) + return QuestPolicy(config) + else: - raise ValueError( - f"Unknown sparse policy '{policy_name}'. " - f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}" - ) + raise ValueError(f"Unknown policy type: {policy_type}") __all__ = [ "SparsePolicy", "PolicyContext", + "SparsePolicyType", "FullAttentionPolicy", - "VerticalSlashPolicy", - "VerticalSlashConfig", "QuestPolicy", "QuestConfig", "BlockMetadataManager", - "StreamingLLMPolicy", - "StreamingLLMConfig", "HybridPolicy", - "get_sparse_policy", - "BUILTIN_SPARSE_POLICIES", + "create_sparse_policy", ] diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py index 6e57d5c..a6cff50 100644 --- a/nanovllm/kvcache/sparse/full_policy.py +++ b/nanovllm/kvcache/sparse/full_policy.py @@ -22,6 +22,10 @@ class FullAttentionPolicy(SparsePolicy): - For short sequences where sparsity isn't beneficial """ + # Full attention supports both prefill and decode + supports_prefill = True + supports_decode = True + def select_blocks( self, available_blocks: List[int], diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py index 41a0f87..fab87ca 100644 --- a/nanovllm/kvcache/sparse/policy.py +++ b/nanovllm/kvcache/sparse/policy.py @@ -7,10 +7,17 @@ from CPU for each query chunk during chunked attention computation. from abc import ABC, abstractmethod from dataclasses import dataclass +from enum import Enum, auto from typing import List, Optional, Any import torch +class SparsePolicyType(Enum): + """Built-in sparse attention policy types.""" + FULL = auto() # prefill + decode + QUEST = auto() # decode only + + @dataclass class PolicyContext: """ @@ -54,8 +61,15 @@ class SparsePolicy(ABC): sparse attention patterns. The policy receives context about the current query chunk and returns which KV blocks to load. + Attributes: + supports_prefill: Whether this policy can be used for prefill phase. + supports_decode: Whether this policy can be used for decode phase. + Example: class MySparsePolicy(SparsePolicy): + supports_prefill = False # decode-only policy + supports_decode = True + def select_blocks(self, available_blocks, ctx): # Load first block and last 2 blocks if len(available_blocks) <= 3: @@ -63,6 +77,34 @@ class SparsePolicy(ABC): return [available_blocks[0]] + available_blocks[-2:] """ + # Compatibility flags - override in subclasses + supports_prefill: bool = True + supports_decode: bool = True + + def initialize( + self, + num_layers: int, + num_kv_heads: int, + head_dim: int, + num_cpu_blocks: int, + dtype: torch.dtype, + ) -> None: + """ + Initialize policy resources. + + Called by the framework after KV cache is allocated. Override this + to create metadata structures (e.g., BlockMetadataManager for Quest). + Default implementation does nothing. + + Args: + num_layers: Number of transformer layers + num_kv_heads: Number of KV attention heads + head_dim: Dimension per head + num_cpu_blocks: Number of CPU blocks allocated + dtype: Data type for tensors + """ + pass + @abstractmethod def select_blocks( self, diff --git a/nanovllm/kvcache/sparse/quest.py b/nanovllm/kvcache/sparse/quest.py index 7439256..3583905 100644 --- a/nanovllm/kvcache/sparse/quest.py +++ b/nanovllm/kvcache/sparse/quest.py @@ -147,22 +147,40 @@ class QuestPolicy(SparsePolicy): This upper bound is derived from the fact that for any key k in the block: min_k <= k <= max_k (element-wise), so the actual attention score is bounded by the maximum of the two extremes. + + Note: This is a decode-only policy. For prefill, use FullAttentionPolicy. """ - def __init__( - self, - config: QuestConfig, - metadata_manager: BlockMetadataManager, - ): + # Quest is decode-only + supports_prefill = False + supports_decode = True + + def __init__(self, config: QuestConfig): """ Initialize Quest policy. Args: config: QuestConfig with selection parameters - metadata_manager: BlockMetadataManager for min/max key storage """ self.config = config - self.metadata = metadata_manager + self.metadata: Optional[BlockMetadataManager] = None + + def initialize( + self, + num_layers: int, + num_kv_heads: int, + head_dim: int, + num_cpu_blocks: int, + dtype: torch.dtype, + ) -> None: + """Create BlockMetadataManager for storing min/max keys.""" + self.metadata = BlockMetadataManager( + num_blocks=num_cpu_blocks, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + dtype=dtype, + ) def select_blocks( self, @@ -175,6 +193,12 @@ class QuestPolicy(SparsePolicy): If query is not available (some prefill scenarios), falls back to loading all blocks. """ + if self.metadata is None: + raise RuntimeError( + "QuestPolicy not initialized. Call initialize() first or " + "let the framework call it during KV cache allocation." + ) + n = len(available_blocks) # If below threshold or no query, load all @@ -269,11 +293,13 @@ class QuestPolicy(SparsePolicy): num_valid_tokens: int, ) -> None: """Update min/max key metadata when block is offloaded.""" - self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens) + if self.metadata is not None: + self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens) def reset(self) -> None: """Reset metadata.""" - self.metadata.reset() + if self.metadata is not None: + self.metadata.reset() def __repr__(self) -> str: return ( diff --git a/nanovllm/kvcache/sparse/streaming_llm.py b/nanovllm/kvcache/sparse/streaming_llm.py deleted file mode 100644 index 29606cb..0000000 --- a/nanovllm/kvcache/sparse/streaming_llm.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -StreamingLLM sparse attention policy. - -Only keeps sink tokens (beginning) + recent tokens (end). -Intermediate context is discarded. This enables infinite-length -generation but loses intermediate context. - -Reference: StreamingLLM paper on attention sinks. -""" - -from dataclasses import dataclass -from typing import List -from .policy import SparsePolicy, PolicyContext - - -@dataclass -class StreamingLLMConfig: - """Configuration for StreamingLLMPolicy.""" - - num_sink_blocks: int = 1 - """Number of blocks at the beginning to always include (attention sinks).""" - - num_recent_blocks: int = 3 - """Number of most recent blocks to include (sliding window).""" - - -class StreamingLLMPolicy(SparsePolicy): - """ - StreamingLLM pattern: sink tokens + recent tokens only. - - This is the most aggressive sparsity pattern - only keeps a small - fixed window of context. Suitable for: - - Very long streaming generation - - When intermediate context can be safely discarded - - Maximizing throughput over accuracy - - Pattern visualization: - ``` - Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8] - ↑ × × × ↑ ↑ ↑ - sink (discarded) recent window - ``` - - Warning: This loses information from intermediate blocks! - Use only when this trade-off is acceptable. - """ - - def __init__(self, config: StreamingLLMConfig = None): - self.config = config or StreamingLLMConfig() - - def select_blocks( - self, - available_blocks: List[int], - ctx: PolicyContext, - ) -> List[int]: - """ - Select sink blocks + recent blocks only. - - Intermediate blocks are not loaded (effectively discarded). - """ - n = len(available_blocks) - - # If total blocks fit in sink + recent, load all - total_keep = self.config.num_sink_blocks + self.config.num_recent_blocks - if n <= total_keep: - return available_blocks - - selected_indices = set() - - # Sink blocks (first N) - for i in range(min(self.config.num_sink_blocks, n)): - selected_indices.add(i) - - # Recent blocks (last M) - for i in range(max(0, n - self.config.num_recent_blocks), n): - selected_indices.add(i) - - return [available_blocks[i] for i in sorted(selected_indices)] - - def __repr__(self) -> str: - return ( - f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, " - f"recent={self.config.num_recent_blocks})" - ) diff --git a/nanovllm/kvcache/sparse/vertical_slash.py b/nanovllm/kvcache/sparse/vertical_slash.py deleted file mode 100644 index 372b4b6..0000000 --- a/nanovllm/kvcache/sparse/vertical_slash.py +++ /dev/null @@ -1,95 +0,0 @@ -""" -Vertical-Slash sparse attention policy (MInference-style). - -Selects sink blocks (beginning of sequence) + local window blocks -(near the current query position). This pattern captures: -- Important initial context (system prompt, instructions) -- Recent context (relevant for local dependencies) -""" - -from dataclasses import dataclass -from typing import List -from .policy import SparsePolicy, PolicyContext - - -@dataclass -class VerticalSlashConfig: - """Configuration for VerticalSlashPolicy.""" - - num_sink_blocks: int = 1 - """Number of blocks at the beginning to always include (sink tokens).""" - - local_window_blocks: int = 2 - """Number of blocks in the local window near current query position.""" - - threshold_blocks: int = 4 - """If total blocks <= threshold, load all (no sparsity applied).""" - - -class VerticalSlashPolicy(SparsePolicy): - """ - Vertical-Slash pattern: sink tokens + local window. - - This pattern is inspired by MInference and observations that: - 1. Initial tokens (sink) often receive high attention - 2. Local context (recent tokens) is important for dependencies - - Pattern visualization: - ``` - Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8] - ↑ ↑ ↑ ↑ - sink local window (for query at block 9) - ``` - - For prefill chunk K, the local window is blocks [K-window, K-1]. - For decode, the local window is the last N blocks. - """ - - def __init__(self, config: VerticalSlashConfig = None): - self.config = config or VerticalSlashConfig() - - def select_blocks( - self, - available_blocks: List[int], - ctx: PolicyContext, - ) -> List[int]: - """ - Select sink blocks + local window blocks. - - For prefill: local window is relative to current chunk position. - For decode: local window is the most recent blocks. - """ - n = len(available_blocks) - - # If below threshold, load all - if n <= self.config.threshold_blocks: - return available_blocks - - selected_indices = set() - - # Sink blocks (first N blocks) - for i in range(min(self.config.num_sink_blocks, n)): - selected_indices.add(i) - - # Local window - if ctx.is_prefill: - # For prefill chunk K, local window is blocks [K-window, K-1] - # (blocks before current chunk, not including current) - window_end = min(ctx.query_chunk_idx, n) - window_start = max(0, window_end - self.config.local_window_blocks) - for i in range(window_start, window_end): - selected_indices.add(i) - else: - # For decode, local window is the last M blocks - for i in range(max(0, n - self.config.local_window_blocks), n): - selected_indices.add(i) - - # Return blocks in order (maintains sequential access pattern) - return [available_blocks[i] for i in sorted(selected_indices)] - - def __repr__(self) -> str: - return ( - f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, " - f"window={self.config.local_window_blocks}, " - f"threshold={self.config.threshold_blocks})" - )