[WIP] Needs refactoring.
This commit is contained in:
@@ -1,49 +1,56 @@
|
||||
"""
|
||||
Sparse Attention Policy module.
|
||||
Attention Policy module for layerwise offload mode.
|
||||
|
||||
Provides pluggable policies for selecting which KV blocks to load
|
||||
during chunked attention with CPU offload.
|
||||
Provides pluggable policies for attention computation:
|
||||
- FullAttentionPolicy: Standard FlashAttention (no sparsity)
|
||||
- XAttentionPolicy: Sparse prefill using XAttention algorithm
|
||||
- MInferencePolicy: MInference sparse attention
|
||||
- QuestPolicy: Quest block selection (for chunked offload)
|
||||
|
||||
Usage:
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
||||
from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
|
||||
|
||||
# Create policy using factory function
|
||||
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
|
||||
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
|
||||
|
||||
# Use policy for attention
|
||||
attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
|
||||
|
||||
# Or create custom policy
|
||||
class MyPolicy(SparsePolicy):
|
||||
class MyPolicy(AttentionPolicy):
|
||||
supports_prefill = True
|
||||
supports_decode = True
|
||||
|
||||
def select_blocks(self, available_blocks, ctx):
|
||||
return available_blocks[:5] # Just first 5 blocks
|
||||
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
|
||||
# Custom attention computation
|
||||
...
|
||||
"""
|
||||
|
||||
from nanovllm.config import SparsePolicyType
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
from nanovllm.kvcache.sparse.minference import MInferencePolicy
|
||||
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
|
||||
|
||||
|
||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||
def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
|
||||
"""
|
||||
Create a sparse policy instance from an enum type.
|
||||
Create an attention policy instance from an enum type.
|
||||
|
||||
The returned policy is not yet initialized. Call policy.initialize()
|
||||
or let the framework call it during KV cache allocation.
|
||||
All attention (including full attention) goes through a policy in layerwise
|
||||
offload mode. The policy is responsible for computing prefill/decode attention.
|
||||
|
||||
Args:
|
||||
policy_type: SparsePolicyType enum value
|
||||
policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
|
||||
**kwargs: Policy-specific configuration options
|
||||
|
||||
Returns:
|
||||
SparsePolicy instance (not initialized)
|
||||
AttentionPolicy instance
|
||||
|
||||
Example:
|
||||
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
|
||||
policy.initialize(num_layers=28, num_kv_heads=8, ...)
|
||||
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
|
||||
attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
|
||||
"""
|
||||
if policy_type == SparsePolicyType.FULL:
|
||||
return FullAttentionPolicy()
|
||||
@@ -75,21 +82,32 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
|
||||
keep_sink=kwargs.get("keep_sink", False),
|
||||
keep_recent=kwargs.get("keep_recent", False),
|
||||
norm=kwargs.get("norm", 1.0),
|
||||
use_bsa=kwargs.get("use_bsa", True),
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown policy type: {policy_type}")
|
||||
|
||||
|
||||
# Backward-compatibility alias: the old factory name resolves to create_attention_policy
|
||||
create_sparse_policy = create_attention_policy
|
||||
|
||||
|
||||
__all__ = [
|
||||
# New interface
|
||||
"AttentionPolicy",
|
||||
"create_attention_policy",
|
||||
# Backward compatibility
|
||||
"SparsePolicy",
|
||||
"create_sparse_policy",
|
||||
# Common types
|
||||
"PolicyContext",
|
||||
"SparsePolicyType",
|
||||
# Policy implementations
|
||||
"FullAttentionPolicy",
|
||||
"QuestPolicy",
|
||||
"QuestConfig",
|
||||
"BlockMetadataManager",
|
||||
"MInferencePolicy",
|
||||
"XAttentionPolicy",
|
||||
"create_sparse_policy",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user