"""
XAttention sparse attention policy for nano-vllm.

Implements the XAttention algorithm from COMPASS, using chunked estimation and
block sparse attention for efficient long-context inference.

Architecture:
    XAttention = Estimate (Triton) + Compute (BSA)
    - Estimate: xattn_estimate() computes block-level importance scores
    - Compute: block_sparse_attn_func() executes sparse attention

Reference: COMPASS/compass/src/Xattention.py
"""

import math
from typing import Optional

import torch
import torch.nn.functional as F

from nanovllm.kvcache.sparse.policy import AttentionPolicy

# BSA block size is fixed at 128 (hardcoded in block_sparse_attn)
BSA_BLOCK_SIZE = 128


class XAttentionPolicy(AttentionPolicy):
    """
    XAttention sparse prefill policy using chunked estimation + block sparse attention.

    This policy estimates sparse attention patterns by:
    1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
    2. Block-wise softmax with importance scores
    3. Block selection based on threshold
    4. Block sparse attention computation using MIT-HAN-LAB BSA library

    The key method is estimate() which calls xattn_estimate() from nanovllm.ops
    to compute the sparse attention mask.

    Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
    BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
    """

    supports_prefill = True
    supports_decode = True  # Uses default FlashAttention for decode

    def __init__(
        self,
        stride: int = 8,
        threshold: float = 0.9,
        block_size: int = 128,
        chunk_size: int = 16384,
        use_triton: bool = True,
        keep_sink: bool = False,
        keep_recent: bool = False,
        norm: float = 1.0,
        use_bsa: bool = True,
    ):
        """
        Initialize XAttention policy.

        Optional backends are probed here. If one is unusable, the matching
        flag (``use_triton`` / ``use_bsa``) is switched off and a notice is
        printed instead of raising.

        Args:
            stride: Stride for reorganizing Q/K (default: 8)
            threshold: Block selection threshold, 0-1 (default: 0.9)
            block_size: Block size for sparse attention (default: 128). Forced
                to BSA_BLOCK_SIZE when use_bsa is requested.
            chunk_size: Chunk size for estimation (default: 16384)
            use_triton: Use Triton kernels (requires SM 80+)
            keep_sink: Always keep first block (sink tokens)
            keep_recent: Always keep recent diagonal blocks
            norm: Normalization factor for attention scores
            use_bsa: Use Block Sparse Attention library (default: True)
        """
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size
        self.chunk_size = chunk_size
        self.use_triton = use_triton
        self.keep_sink = keep_sink
        self.keep_recent = keep_recent
        self.norm = norm
        self.use_bsa = use_bsa

        # BSA requires block_size = 128
        if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
            print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
            self.block_size = BSA_BLOCK_SIZE

        # Check Triton availability
        if self.use_triton:
            try:
                import triton  # noqa: F401  (availability probe only)
            except ImportError:
                self.use_triton = False
                print("XAttention: Triton not available. Falling back to PyTorch.")
            else:
                # Querying device properties without a usable CUDA device raises
                # (not ImportError), so check availability explicitly first.
                if not torch.cuda.is_available():
                    self.use_triton = False
                    print("XAttention: CUDA not available for Triton. Falling back to PyTorch.")
                else:
                    props = torch.cuda.get_device_properties(torch.cuda.current_device())
                    if props.major < 8:
                        self.use_triton = False
                        print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")

        # Check BSA availability
        if self.use_bsa:
            try:
                from block_sparse_attn import block_sparse_attn_func  # noqa: F401  (availability probe only)
            except ImportError:
                self.use_bsa = False
                print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")

    @staticmethod
    def _expand_kv_heads(t: torch.Tensor, num_heads: int, dim: int) -> torch.Tensor:
        """
        Repeat KV heads along ``dim`` so that ``t`` carries ``num_heads`` heads.

        Uses grouped (interleaved) expansion: with ``group = num_heads // num_kv_heads``,
        query head ``h`` reads KV head ``h // group``. The same mapping is used for
        estimation and for sparse compute, so a mask estimated for head ``h`` is
        applied against the same keys it was computed from.

        Args:
            t: Tensor whose head axis is ``dim``.
            num_heads: Target (query) head count.
            dim: Index of the head axis in ``t``.

        Returns:
            ``t`` unchanged if it already has ``num_heads`` heads, otherwise a new
            tensor with ``num_heads`` heads along ``dim``.

        Raises:
            ValueError: if ``num_heads`` is not a multiple of the KV head count.
        """
        num_kv_heads = t.shape[dim]
        if num_kv_heads == num_heads:
            return t
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a multiple of num_kv_heads ({num_kv_heads})"
            )
        return t.repeat_interleave(num_heads // num_kv_heads, dim=dim)

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate sparse attention mask using XAttention algorithm.

        Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
        importance scores and generate a sparse boolean mask.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index (currently unused)

        Returns:
            sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
            or None if estimation fails (fallback to full attention)
        """
        try:
            from nanovllm.ops.xattn import xattn_estimate

            num_heads = q.shape[1]

            # Convert to [batch, heads, seq, dim] format expected by xattn_estimate
            q_bhsd = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, seq_len, head_dim]
            k_bhsd = k.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, seq_len, head_dim]

            # GQA: expand K with the same grouped mapping used in
            # _block_sparse_attention (query head h <-> KV head h // group).
            # Tiling with .repeat() would pair head h with KV head
            # h % num_kv_heads and estimate each mask against the wrong keys.
            k_bhsd = self._expand_kv_heads(k_bhsd, num_heads, dim=1)

            _attn_sums, sparse_mask = xattn_estimate(
                q_bhsd,
                k_bhsd,
                block_size=self.block_size,
                stride=self.stride,
                norm=self.norm,
                threshold=self.threshold,
                chunk_size=self.chunk_size,
                use_triton=self.use_triton,
                causal=True,
                keep_sink=self.keep_sink,
                keep_recent=self.keep_recent,
            )
            return sparse_mask
        except Exception as e:
            # Deliberate best-effort: any estimation failure falls back to
            # full attention in compute_prefill rather than aborting prefill.
            print(f"XAttention estimate failed: {e}, falling back to full attention")
            return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute XAttention sparse prefill attention.

        Flow:
        1. If BSA is unavailable/disabled, use full FlashAttention (no estimation)
        2. Call estimate() to get sparse mask
        3. If mask is None, use full FlashAttention
        4. Otherwise, use block_sparse_attn_func with mask

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        # Without BSA the mask could not be used, so skip estimation entirely.
        if not self.use_bsa:
            return self._full_attention(q, k, v, softmax_scale)

        sparse_mask = self.estimate(q, k, layer_id)
        if sparse_mask is None:
            # Estimation failed, fallback to full FlashAttention
            return self._full_attention(q, k, v, softmax_scale)

        return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)

    def _block_sparse_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sparse_mask: torch.Tensor,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute block sparse attention using MIT-HAN-LAB BSA library.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        from block_sparse_attn import block_sparse_attn_func

        seq_len, num_heads, _head_dim = q.shape

        # GQA: expand K/V to match Q heads (same mapping as estimate()).
        k = self._expand_kv_heads(k, num_heads, dim=1)
        v = self._expand_kv_heads(v, num_heads, dim=1)

        # Cumulative sequence lengths (batch=1)
        cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        # Head mask type: 1 for all heads using block sparse
        head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)

        # Trim sparse_mask to the block counts actually covered by seq_len
        num_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
        block_mask = sparse_mask[:, :, :num_blocks, :num_blocks].contiguous()

        return block_sparse_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            head_mask_type,
            None,  # streaming_info (left_mask)
            block_mask,
            seq_len, seq_len,
            p_dropout=0.0,
            deterministic=True,
            softmax_scale=softmax_scale,
            is_causal=True,
        )

    def _full_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute full causal attention using FlashAttention.

        K/V are passed with their original KV head count; no expansion is done here.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        seq_len = q.shape[0]
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def reset(self) -> None:
        """Reset policy state (no state to reset for XAttention)."""
        pass

    def __repr__(self) -> str:
        return (f"XAttentionPolicy("
                f"stride={self.stride}, "
                f"threshold={self.threshold}, "
                f"block_size={self.block_size}, "
                f"use_triton={self.use_triton}, "
                f"use_bsa={self.use_bsa})")