""" Attention Policy module for layerwise offload mode. Provides pluggable policies for attention computation: - FullAttentionPolicy: Standard FlashAttention (no sparsity) - XAttentionPolicy: Sparse prefill using XAttention algorithm - MInferencePolicy: MInference sparse attention - QuestPolicy: Quest block selection (for chunked offload) Usage: from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType # Create policy using factory function policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9) # Use policy for attention attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale) # Or create custom policy class MyPolicy(AttentionPolicy): supports_prefill = True supports_decode = True def compute_prefill(self, q, k, v, layer_id, softmax_scale): # Custom attention computation ... """ from nanovllm.config import SparsePolicyType from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager from nanovllm.kvcache.sparse.minference import MInferencePolicy from nanovllm.kvcache.sparse.xattn import XAttentionPolicy def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy: """ Create an attention policy instance from an enum type. All attention (including full attention) goes through a policy in layerwise offload mode. The policy is responsible for computing prefill/decode attention. Args: policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST) **kwargs: Policy-specific configuration options Returns: AttentionPolicy instance Example: policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9) attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale) """ if policy_type == SparsePolicyType.FULL: return FullAttentionPolicy() elif policy_type == SparsePolicyType.QUEST: config = QuestConfig( topk_blocks=kwargs.get("topk_blocks", 8), threshold_blocks=kwargs.get("threshold_blocks", 4), include_sink_blocks=kwargs.get("include_sink_blocks", 0), include_recent_blocks=kwargs.get("include_recent_blocks", 0), ) return QuestPolicy(config) elif policy_type == SparsePolicyType.MINFERENCE: return MInferencePolicy( vertical_size=kwargs.get("vertical_size", 1000), slash_size=kwargs.get("slash_size", 6096), adaptive_budget=kwargs.get("adaptive_budget", 0.3), num_sink_tokens=kwargs.get("num_sink_tokens", 30), num_recent_diags=kwargs.get("num_recent_diags", 100), ) elif policy_type == SparsePolicyType.XATTN: return XAttentionPolicy( stride=kwargs.get("stride", 8), threshold=kwargs.get("threshold", 0.9), chunk_size=kwargs.get("chunk_size", 16384), use_triton=kwargs.get("use_triton", True), keep_sink=kwargs.get("keep_sink", False), keep_recent=kwargs.get("keep_recent", False), norm=kwargs.get("norm", 1.0), use_bsa=kwargs.get("use_bsa", True), ) else: raise ValueError(f"Unknown policy type: {policy_type}") # Backward compatibility alias create_sparse_policy = create_attention_policy __all__ = [ # New interface "AttentionPolicy", "create_attention_policy", # Backward compatibility "SparsePolicy", "create_sparse_policy", # Common types "PolicyContext", "SparsePolicyType", # Policy implementations "FullAttentionPolicy", "QuestPolicy", "QuestConfig", "BlockMetadataManager", "MInferencePolicy", "XAttentionPolicy", ]