114 lines
3.9 KiB
Python
114 lines
3.9 KiB
Python
"""
|
|
Attention Policy module for layerwise offload mode.
|
|
|
|
Provides pluggable policies for attention computation:
|
|
- FullAttentionPolicy: Standard FlashAttention (no sparsity)
|
|
- XAttentionPolicy: Sparse prefill using XAttention algorithm
|
|
- MInferencePolicy: MInference sparse attention
|
|
- QuestPolicy: Quest block selection (for chunked offload)
|
|
|
|
Usage:
|
|
from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
|
|
|
|
# Create policy using factory function
|
|
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
|
|
|
|
# Use policy for attention
|
|
attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
|
|
|
|
# Or create custom policy
|
|
class MyPolicy(AttentionPolicy):
|
|
supports_prefill = True
|
|
supports_decode = True
|
|
|
|
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
|
|
# Custom attention computation
|
|
...
|
|
"""
|
|
|
|
from nanovllm.config import SparsePolicyType
|
|
from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
|
|
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
|
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
|
from nanovllm.kvcache.sparse.minference import MInferencePolicy
|
|
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
|
|
|
|
|
|
def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
    """
    Build an attention policy instance from an enum type.

    All attention (including full attention) goes through a policy in layerwise
    offload mode. The policy is responsible for computing prefill/decode attention.

    Args:
        policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
        **kwargs: Policy-specific configuration options; unrecognized keys are
            ignored and missing keys fall back to the defaults below.

    Returns:
        AttentionPolicy instance

    Raises:
        ValueError: If policy_type is not a known SparsePolicyType member.

    Example:
        policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
        attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
    """

    def opt(name, fallback):
        # Read one policy option from kwargs, falling back to its default.
        return kwargs.get(name, fallback)

    if policy_type == SparsePolicyType.FULL:
        # Full attention takes no configuration.
        return FullAttentionPolicy()

    if policy_type == SparsePolicyType.QUEST:
        quest_cfg = QuestConfig(
            topk_blocks=opt("topk_blocks", 8),
            threshold_blocks=opt("threshold_blocks", 4),
            include_sink_blocks=opt("include_sink_blocks", 0),
            include_recent_blocks=opt("include_recent_blocks", 0),
        )
        return QuestPolicy(quest_cfg)

    if policy_type == SparsePolicyType.MINFERENCE:
        return MInferencePolicy(
            vertical_size=opt("vertical_size", 1000),
            slash_size=opt("slash_size", 6096),
            adaptive_budget=opt("adaptive_budget", 0.3),
            num_sink_tokens=opt("num_sink_tokens", 30),
            num_recent_diags=opt("num_recent_diags", 100),
        )

    if policy_type == SparsePolicyType.XATTN:
        return XAttentionPolicy(
            stride=opt("stride", 8),
            threshold=opt("threshold", 0.9),
            chunk_size=opt("chunk_size", 16384),
            use_triton=opt("use_triton", True),
            keep_sink=opt("keep_sink", False),
            keep_recent=opt("keep_recent", False),
            norm=opt("norm", 1.0),
            use_bsa=opt("use_bsa", True),
        )

    raise ValueError(f"Unknown policy type: {policy_type}")
|
|
|
|
|
|
# Backward compatibility alias: older call sites imported create_sparse_policy;
# it is the same callable object as create_attention_policy.
create_sparse_policy = create_attention_policy
|
|
|
|
|
|
# Explicit public API of this package (consumed by `from ... import *` and
# documentation tooling). Entries mirror the imports above.
__all__ = [
    # New interface
    "AttentionPolicy",
    "create_attention_policy",
    # Backward compatibility
    "SparsePolicy",
    "create_sparse_policy",
    # Common types
    "PolicyContext",
    "SparsePolicyType",
    # Policy implementations
    "FullAttentionPolicy",
    "QuestPolicy",
    "QuestConfig",
    "BlockMetadataManager",
    "MInferencePolicy",
    "XAttentionPolicy",
]
|