""" Base class for attention policies in layerwise offload mode. AttentionPolicy defines the interface for all attention computation, including full attention and sparse attention methods like XAttention. Key methods: - estimate(): Compute sparse attention mask (optional, returns None for full attention) - compute_prefill(): Compute prefill attention - compute_decode(): Compute decode attention (default implementation provided) """ from abc import ABC, abstractmethod from dataclasses import dataclass from typing import List, Optional, Tuple import torch # Import SparsePolicyType from config to avoid circular imports from nanovllm.config import SparsePolicyType @dataclass class PolicyContext: """ Context passed to attention policy for block selection. This dataclass contains all information needed by an attention policy for sparse estimation and attention computation. """ query_chunk_idx: int """Index of the current query chunk (0-indexed).""" num_query_chunks: int """Total number of query chunks in this prefill.""" layer_id: int """Current transformer layer index.""" query: Optional[torch.Tensor] """ Query tensor for current chunk. Shape: [1, num_heads, head_dim] for decode, [1, seq_len, num_heads, head_dim] for prefill. May be None if not available (e.g., some prefill scenarios). """ is_prefill: bool """True if in prefill phase, False if in decode phase.""" block_size: int = 1024 """Number of tokens per block.""" total_kv_len: int = 0 """Total KV sequence length so far (for reference).""" class AttentionPolicy(ABC): """ Base class for attention policies in layerwise offload mode. All attention computation goes through a policy, including both full attention and sparse attention methods. The policy interface is designed for layerwise offload where: - The entire KV cache for a layer is on GPU during computation - No need for block loading from CPU during attention - estimate() returns a sparse mask (or None for full attention) - compute_prefill()/compute_decode() perform the actual attention Attributes: supports_prefill: Whether this policy can be used for prefill phase. supports_decode: Whether this policy can be used for decode phase. Example: class MyPolicy(AttentionPolicy): supports_prefill = True supports_decode = True def estimate(self, q, k, layer_id): # Return sparse mask or None return None def compute_prefill(self, q, k, v, layer_id, softmax_scale): # Compute attention return flash_attn_varlen_func(q, k, v, ...) """ # Compatibility flags - override in subclasses supports_prefill: bool = True supports_decode: bool = True def initialize( self, num_layers: int, num_kv_heads: int, head_dim: int, num_cpu_blocks: int, dtype: torch.dtype, device: torch.device = None, ) -> None: """ Initialize policy resources. Called by the framework after KV cache is allocated. Override this to create metadata structures or pre-allocate buffers. Default implementation does nothing. Args: num_layers: Number of transformer layers num_kv_heads: Number of KV attention heads head_dim: Dimension per head num_cpu_blocks: Number of CPU blocks allocated dtype: Data type for tensors device: Device for metadata storage (GPU recommended for performance) """ pass def estimate( self, q: torch.Tensor, k: torch.Tensor, layer_id: int, ) -> Optional[torch.Tensor]: """ Estimate sparse attention mask. For sparse policies (e.g., XAttention), computes block-level importance and returns a boolean mask indicating which blocks to attend. For full attention policy, returns None. This corresponds to xattn_estimate() in COMPASS. Args: q: Query tensor [seq_len, num_heads, head_dim] k: Key tensor [seq_len, num_kv_heads, head_dim] layer_id: Transformer layer index Returns: sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask, or None for full attention """ return None @abstractmethod def compute_prefill( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, softmax_scale: float, ) -> torch.Tensor: """ Compute prefill attention. The entire KV cache for this layer is on GPU. Compute attention between Q and K/V, optionally using sparse mask from estimate(). Args: q: Query tensor [seq_len, num_heads, head_dim] k: Key tensor [seq_len, num_kv_heads, head_dim] v: Value tensor [seq_len, num_kv_heads, head_dim] layer_id: Transformer layer index softmax_scale: Softmax scaling factor (1/sqrt(head_dim)) Returns: Attention output [seq_len, num_heads, head_dim] """ pass def compute_decode( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, softmax_scale: float, ) -> torch.Tensor: """ Compute decode attention. KV is provided from ring buffer, containing prefill tokens + decoded tokens. Default implementation uses FlashAttention. Args: q: Query tensor [1, num_heads, head_dim] k: Key tensor [context_len+1, num_kv_heads, head_dim] v: Value tensor [context_len+1, num_kv_heads, head_dim] layer_id: Transformer layer index softmax_scale: Softmax scaling factor Returns: Attention output [1, num_heads, head_dim] """ from flash_attn.flash_attn_interface import flash_attn_varlen_func context_len = k.shape[0] cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device) cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device) return flash_attn_varlen_func( q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=1, max_seqlen_k=context_len, softmax_scale=softmax_scale, causal=False, ) def reset(self) -> None: """ Reset policy state. Called when starting a new sequence or clearing state. Default implementation does nothing. """ pass def __repr__(self) -> str: return f"{self.__class__.__name__}()" # Backward compatibility alias SparsePolicy = AttentionPolicy