[WIP] Needs refactoring.

This commit is contained in:
Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions

View File

@@ -1,49 +1,56 @@
"""
Sparse Attention Policy module.
Attention Policy module for layerwise offload mode.
Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.
Provides pluggable policies for attention computation:
- FullAttentionPolicy: Standard FlashAttention (no sparsity)
- XAttentionPolicy: Sparse prefill using XAttention algorithm
- MInferencePolicy: MInference sparse attention
- QuestPolicy: Quest block selection (for chunked offload)
Usage:
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
# Create policy using factory function
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
# Use policy for attention
attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
# Or create custom policy
class MyPolicy(SparsePolicy):
class MyPolicy(AttentionPolicy):
supports_prefill = True
supports_decode = True
def select_blocks(self, available_blocks, ctx):
return available_blocks[:5] # Just first 5 blocks
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
# Custom attention computation
...
"""
from nanovllm.config import SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.minference import MInferencePolicy
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
"""
Create a sparse policy instance from an enum type.
Create an attention policy instance from an enum type.
The returned policy is not yet initialized. Call policy.initialize()
or let the framework call it during KV cache allocation.
All attention (including full attention) goes through a policy in layerwise
offload mode. The policy is responsible for computing prefill/decode attention.
Args:
policy_type: SparsePolicyType enum value
policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
**kwargs: Policy-specific configuration options
Returns:
SparsePolicy instance (not initialized)
AttentionPolicy instance
Example:
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
policy.initialize(num_layers=28, num_kv_heads=8, ...)
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
"""
if policy_type == SparsePolicyType.FULL:
return FullAttentionPolicy()
@@ -75,21 +82,32 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
keep_sink=kwargs.get("keep_sink", False),
keep_recent=kwargs.get("keep_recent", False),
norm=kwargs.get("norm", 1.0),
use_bsa=kwargs.get("use_bsa", True),
)
else:
raise ValueError(f"Unknown policy type: {policy_type}")
# Backward-compatibility alias: `create_sparse_policy` was the factory's
# public name before the rename; keep it bound to `create_attention_policy`
# so existing callers continue to work unchanged.
create_sparse_policy = create_attention_policy
# Public API of the sparse/attention-policy package.
# Fix: "create_sparse_policy" was listed twice (once under the
# backward-compatibility group and again at the end); duplicates in
# __all__ are harmless at runtime but confuse linters and readers,
# so keep the single entry in its documented group.
__all__ = [
    # New interface
    "AttentionPolicy",
    "create_attention_policy",
    # Backward compatibility
    "SparsePolicy",
    "create_sparse_policy",
    # Common types
    "PolicyContext",
    "SparsePolicyType",
    # Policy implementations
    "FullAttentionPolicy",
    "QuestPolicy",
    "QuestConfig",
    "BlockMetadataManager",
    "MInferencePolicy",
    "XAttentionPolicy",
]