"""
XAttention Block Sparse Attention (BSA) Policy for nano-vllm.

This module implements XAttention-inspired block sparse attention for chunked
prefill.

Key design:
1. Estimate a sparse block mask from strided (downsampled) Q·K scores
   (``xattn_estimate`` in GPU-only mode, ``flat_group_gemm_fuse_reshape`` +
   ``softmax_fuse_block_sum`` + ``find_blocks_chunked`` in offload mode).
2. Use the BSA kernel for efficient sparse attention computation (GPU-only
   mode).
3. Support chunked prefill with correct global position handling (offload
   mode).

Note: Decode phase is not supported - use FullAttentionPolicy for decode.
"""

import logging
import math
from typing import TYPE_CHECKING, List, Tuple

import torch
import torch.cuda.nvtx as nvtx

from nanovllm.kvcache.sparse.policy import PolicyContext, SparsePolicy
from nanovllm.utils.density_observer import DensityObserver

if TYPE_CHECKING:
    from nanovllm.engine.sequence import Sequence
    from nanovllm.kvcache.manager import KVCacheManager
    from nanovllm.kvcache.offload_engine import OffloadEngine

logger = logging.getLogger(__name__)

# Check BSA availability
try:
    from block_sparse_attn import block_sparse_attn_func

    BSA_AVAILABLE = True
except ImportError:
    BSA_AVAILABLE = False
    logger.warning("block_sparse_attn not available, XAttentionBSAPolicy will fallback to dense")

# Check xattn_estimate_chunked availability
try:
    from nanovllm.ops.xattn import xattn_estimate_chunked  # noqa: F401

    XATTN_AVAILABLE = True
except ImportError:
    XATTN_AVAILABLE = False
    logger.warning("xattn_estimate_chunked not available")

# log2(e): the estimation kernels work in base-2 exponent space.
_LOG2_E = 1.4426950408889634


def _repeat_heads(states: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Repeat the head axis (dim 1) of ``states`` so it has ``num_heads`` heads.

    Args:
        states: Tensor whose dim 1 is the (KV) head axis, e.g.
            [B, num_kv_heads, seq_len, head_dim].
        num_heads: Target number of heads.

    Returns:
        ``states`` unchanged if it already has ``num_heads`` heads, otherwise a
        new tensor with each head repeated ``num_heads // num_kv_heads`` times.

    Raises:
        ValueError: If ``num_heads`` is not a multiple of the current head count
            (repeat_interleave would otherwise silently produce the wrong shape).
    """
    num_kv_heads = states.shape[1]
    if num_heads == num_kv_heads:
        return states
    if num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_heads={num_heads} is not a multiple of num_kv_heads={num_kv_heads}"
        )
    return states.repeat_interleave(num_heads // num_kv_heads, dim=1)


def expand_kv_for_gqa(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    num_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Expand KV for Grouped Query Attention.

    Args:
        key_states: [B, num_kv_heads, seq_len, head_dim]
        value_states: [B, num_kv_heads, seq_len, head_dim]
        num_heads: Number of query heads

    Returns:
        Expanded (key, value) with shape [B, num_heads, seq_len, head_dim].
        The inputs are returned as-is when no expansion is needed.

    Raises:
        ValueError: If num_heads is not a multiple of num_kv_heads.
    """
    return _repeat_heads(key_states, num_heads), _repeat_heads(value_states, num_heads)


class XAttentionBSAPolicy(SparsePolicy):
    """
    XAttention Block Sparse Attention policy for chunked prefill.

    Uses strided attention-score estimation to build a sparse block mask, then
    the BSA kernel (GPU-only mode) or selective CPU-block loading (offload mode)
    for the actual attention computation.

    Note:
        - Only supports prefill phase (decode uses FullAttentionPolicy)
        - BSA block size is fixed at 128 tokens
    """

    supports_prefill = True
    supports_decode = False  # Decode uses FullAttentionPolicy
    requires_block_selection = False  # Selection happens internally

    # BSA requires 128-token blocks
    BSA_BLOCK_SIZE = 128

    def __init__(
        self,
        threshold: float = 0.95,  # High threshold for accuracy testing
        stride: int = 8,
        chunk_size: int = 16384,
        block_size: int = 128,
        samples_per_chunk: int = 128,
        use_triton: bool = True,
        estimate_block_size: int = 1024,
    ):
        """
        Initialize XAttention BSA policy.

        Args:
            threshold: Cumulative attention threshold for block selection (0-1).
                Higher values = more blocks selected = less sparse.
            stride: Stride for Q/K reshape in estimation (typically 8).
            chunk_size: Processing chunk size for xattn_estimate (Triton alignment).
            block_size: BSA block size. Only 128 is supported; other values are
                ignored with a warning.
            samples_per_chunk: Samples per chunk for estimation (unused).
            use_triton: Whether to use Triton kernels.
            estimate_block_size: Stored for configuration compatibility. Note that
                select_blocks currently aggregates scores at BSA_BLOCK_SIZE
                granularity and does not read this value.
        """
        if block_size != self.BSA_BLOCK_SIZE:
            logger.warning(
                "[XAttn] block_size=%d ignored; BSA always uses %d-token blocks",
                block_size,
                self.BSA_BLOCK_SIZE,
            )

        self.threshold = threshold
        self.stride = stride
        self.chunk_size = chunk_size
        self.use_triton = use_triton
        self.estimate_block_size = estimate_block_size
        self._num_heads = None  # Set during first forward

        # Sparse metadata: stores attention scores per layer
        # Dict[layer_id, Tensor[num_q_blocks, num_k_blocks]]
        self.sparse_metadata: dict = {}

        # Statistics for density tracking (offload mode, layer 0 only)
        self._stats_total_available_blocks = 0
        self._stats_total_selected_blocks = 0
        self._stats_num_chunks = 0

        # Pre-allocated GQA expansion buffers (GPU-only mode).
        # Set by alloc_policy_metadata(); None if not pre-allocated.
        self._k_expanded: torch.Tensor | None = None
        self._v_expanded: torch.Tensor | None = None
        self._max_seq_len: int = 0

        # Pre-allocated mask buffer for chunked prefill (offload mode).
        # Holds the historical part of the BSA-level mask from select_blocks.
        # Shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
        self._prefill_mask_buffer: torch.Tensor | None = None
        self._current_mask_q_bsa: int = 0  # Valid Q BSA rows in buffer
        self._current_mask_k_bsa: int = 0  # Valid K BSA columns in buffer

        # Indices (into available_blocks) of the CPU blocks chosen by the last
        # select_blocks call, and the number of BSA blocks per CPU block.
        self._selected_cpu_indices: List[int] = []
        self._bsa_per_cpu: int = 0

    def alloc_policy_metadata(
        self,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_seq_len: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        """
        Pre-allocate the offload mask buffer and GPU-only GQA expansion buffers.

        The GQA buffers are used by compute_prefill() to avoid dynamic allocation
        during the forward pass. They are sized for max_seq_len and sliced to the
        actual seq_len during use.

        GQA memory usage: 2 * num_heads * max_seq_len * head_dim * dtype_size
        For 64K seq, 32 heads, 128 dim, fp16: 2 * 32 * 65536 * 128 * 2 = 1 GB

        Args:
            num_heads: Number of query heads
            num_kv_heads: Number of KV heads (for GQA)
            head_dim: Dimension per head
            max_seq_len: Maximum sequence length
            dtype: Data type
            device: Target device
        """
        bsa = self.BSA_BLOCK_SIZE

        # Mask buffer for chunked prefill (offload mode); needed regardless of GQA.
        # Ceil division so a partial trailing block still fits.
        max_q_bsa_blocks = (self.chunk_size + bsa - 1) // bsa
        max_k_bsa_blocks = (max_seq_len + bsa - 1) // bsa
        mask_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
        self._prefill_mask_buffer = torch.empty(mask_shape, dtype=torch.bool, device=device)
        mask_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks / (1024 * 1024)
        logger.info(f"[XAttn] Pre-allocated mask buffer: shape={mask_shape}, memory={mask_memory_mb:.1f} MB")

        # GQA expansion buffers are only needed when num_heads != num_kv_heads.
        if num_heads == num_kv_heads:
            logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
            return

        # Shape [1, num_heads, max_seq_len, head_dim] matches the xattn_estimate
        # layout; compute_prefill transposes a slice of it for BSA.
        shape = (1, num_heads, max_seq_len, head_dim)
        self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
        self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
        self._max_seq_len = max_seq_len

        memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
        logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")

    # =========================================================================
    # GPU-only methods (non-chunked)
    # =========================================================================

    @staticmethod
    def _dense_prefill(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        block_tables: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Dense causal flash-attention fallback (optionally over a paged KV cache)."""
        from flash_attn import flash_attn_varlen_func

        # Only forward block_table when given, matching the original non-paged calls.
        extra = {} if block_tables is None else {"block_table": block_tables}
        return flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=True,
            **extra,
        )

    def _expand_kv_gpu_only(
        self,
        K: torch.Tensor,
        V: torch.Tensor,
        num_heads: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """GQA-expand [1, num_kv_heads, k_len, head_dim] K/V to num_heads heads.

        Writes into the pre-allocated buffers when they exist and are large
        enough; otherwise falls back to a freshly allocated expansion.
        """
        _, num_kv_heads, k_len, head_dim = K.shape
        if num_heads == num_kv_heads:
            return K, V
        if self._k_expanded is None or k_len > self._max_seq_len:
            return expand_kv_for_gqa(K, V, num_heads)

        num_groups = num_heads // num_kv_heads
        K_exp = self._k_expanded[:, :, :k_len, :]
        V_exp = self._v_expanded[:, :, :k_len, :]
        # View the head axis as (num_kv_heads, num_groups) and broadcast each KV
        # head across its group; equivalent to repeat_interleave(dim=1).
        K_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
            K.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
        )
        V_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
            V.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
        )
        return K_exp, V_exp

    @staticmethod
    def _log_layer0_density(mask: torch.Tensor) -> None:
        """Debug-log the causal density of a [B, H, q_blocks, k_blocks] block mask.

        Forces GPU->host syncs, so callers should only invoke it when DEBUG
        logging is enabled.
        """
        batch, heads, q_bk, k_bk = mask.shape
        q_idx = torch.arange(q_bk, device=mask.device).unsqueeze(1)
        k_idx = torch.arange(k_bk, device=mask.device).unsqueeze(0)
        # Q block i sits at K block position i + (k_bk - q_bk) when Q is the tail.
        causal = k_idx <= q_idx + (k_bk - q_bk)
        total = int(causal.sum().item()) * batch * heads
        selected = int((mask & causal).sum().item())
        density = selected / total if total else 0.0
        logger.debug(
            "[XAttn GPU-only Layer0] mask_shape=%s, density=%.6f, selected=%d, total=%d",
            tuple(mask.shape), density, selected, total,
        )

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        layer_id: int,
        block_tables: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        GPU-only prefill attention using XAttention + BSA.

        1. Estimate block importance using xattn_estimate.
        2. Compute sparse attention using block_sparse_attn_func.

        Falls back to dense flash attention when a paged KV cache is used,
        when BSA/xattn kernels are unavailable, or for batch_size != 1.

        Args:
            q: Query tensor [total_q, num_heads, head_dim] (varlen packed)
            k: Key tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
            v: Value tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
            cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
            cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
            max_seqlen_q: Maximum Q sequence length
            max_seqlen_k: Maximum K sequence length
            softmax_scale: Softmax scaling factor
            layer_id: Transformer layer index
            block_tables: Paged attention block tables (triggers dense fallback)

        Returns:
            Attention output [total_q, num_heads, head_dim]
        """
        dense_args = (q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, softmax_scale)

        # XAttention needs contiguous K/V; paged KV cache goes through flash_attn.
        if block_tables is not None:
            return self._dense_prefill(*dense_args, block_tables=block_tables)
        if not BSA_AVAILABLE or not XATTN_AVAILABLE:
            return self._dense_prefill(*dense_args)

        from nanovllm.ops.xattn import xattn_estimate

        if layer_id == 0:
            DensityObserver.set_mode("gpu_only")

        # TODO: Support batched varlen format
        batch_size = cu_seqlens_q.shape[0] - 1
        if batch_size != 1:
            logger.warning(f"[XAttn] batch_size={batch_size} > 1, falling back to flash attention")
            return self._dense_prefill(*dense_args)

        num_heads = q.shape[1]
        num_kv_heads = k.shape[1]
        q_len = max_seqlen_q
        k_len = max_seqlen_k

        # Varlen [seq, heads, dim] -> [1, heads, seq, dim] for xattn_estimate.
        Q = q.unsqueeze(0).transpose(1, 2)
        K = k.unsqueeze(0).transpose(1, 2)
        V = v.unsqueeze(0).transpose(1, 2)
        K_exp, V_exp = self._expand_kv_gpu_only(K, V, num_heads)

        with nvtx.range("xattn_estimate"):
            _, mask = xattn_estimate(
                Q,
                K_exp,
                chunk_size=self.chunk_size,
                block_size=self.BSA_BLOCK_SIZE,
                stride=self.stride,
                threshold=self.threshold,
                use_triton=self.use_triton,
                causal=True,
            )

        bsa = self.BSA_BLOCK_SIZE
        q_block_num = (q_len + bsa - 1) // bsa
        k_block_num = (k_len + bsa - 1) // bsa

        # BSA expects [seq_len, num_heads, head_dim]; reuse the expanded K/V for GQA.
        if num_heads != num_kv_heads:
            k_bsa = K_exp.squeeze(0).transpose(0, 1)
            v_bsa = V_exp.squeeze(0).transpose(0, 1)
        else:
            k_bsa, v_bsa = k, v

        cu_seqlens_q_bsa = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
        cu_seqlens_k_bsa = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
        head_groups = torch.ones(num_heads, dtype=torch.int32, device=q.device)
        # The estimator may return padded block counts; trim to the real ones.
        mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()

        with nvtx.range("xattn_bsa_compute"):
            output = block_sparse_attn_func(
                q,
                k_bsa,
                v_bsa,
                cu_seqlens_q_bsa,
                cu_seqlens_k_bsa,
                head_groups,
                None,  # key_padding_mask
                mask_trimmed,
                q_len,
                k_len,
                p_dropout=0.0,
                deterministic=True,
                is_causal=True,
            )

        if layer_id == 0 and logger.isEnabledFor(logging.DEBUG):
            self._log_layer0_density(mask_trimmed)

        # Record density for all layers via DensityObserver
        DensityObserver.record(layer_id, mask_trimmed, causal=True)
        return output

    def compute_decode(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        layer_id: int,
        block_tables: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        GPU-only decode attention - delegates to FullAttentionPolicy.

        XAttention targets long prefill sequences; single-token decode is
        forwarded unchanged to FullAttentionPolicy.compute_decode.
        """
        from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy

        return FullAttentionPolicy().compute_decode(
            q, k_cache, v_cache, cache_seqlens, softmax_scale, layer_id, block_tables
        )

    # =========================================================================
    # Chunked offload methods
    # =========================================================================

    def select_blocks(
        self,
        available_blocks: List[int],
        offload_engine: "OffloadEngine",
        ctx: PolicyContext,
        q: torch.Tensor,
        k: torch.Tensor,
    ) -> List[int]:
        """
        Select historical CPU blocks worth loading for the current Q chunk.

        Mirrors GPU-only xattn estimation:
        1. Loads each historical K block from CPU and scores it against Q
           (no causal mask: all historical K precedes Q).
        2. Scores the current chunk K (from the prefill buffer) with a causal mask.
        3. Concatenates [historical, current] scores so softmax normalization
           matches full-sequence attention, using Q's global chunk_start.
        4. Selects BSA blocks by cumulative-attention threshold and unions them
           (over heads, Q rows and BSA sub-blocks) up to CPU-block granularity.
        Only historical blocks are selectable; the current chunk always gets
        full causal attention in compute_chunked_prefill.

        Args:
            available_blocks: CPU block IDs of historical blocks, in sequence order
            offload_engine: OffloadEngine for loading blocks
            ctx: PolicyContext with metadata (layer_id, block_size, query_chunk_idx)
            q: Query tensor [seq_len, num_heads, head_dim] for the current chunk
            k: Current-chunk key tensor (unused; current K is read from the
               prefill buffer)

        Returns:
            Selected block IDs (subset of available_blocks, original order). The
            first and last historical blocks are always included.
        """
        if q is None:
            return available_blocks

        from nanovllm.ops.xattn import (
            find_blocks_chunked,
            flat_group_gemm_fuse_reshape,
            softmax_fuse_block_sum,
        )

        layer_id = ctx.layer_id
        if layer_id == 0:
            DensityObserver.set_mode("offload")

        if not available_blocks:
            # First chunk: nothing to select from, so skip the estimation kernels.
            self._current_mask_q_bsa = 0
            self._current_mask_k_bsa = 0
            self._selected_cpu_indices = []
            return available_blocks

        # [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim]
        Q = q.unsqueeze(0).transpose(1, 2)
        _, num_heads, q_len, head_dim = Q.shape

        # flat_group_gemm requires q_len divisible by stride * BLOCK_M (Triton tile).
        BLOCK_M = 128
        BLOCK_N = 128
        alignment = self.stride * BLOCK_M
        if q_len < alignment:
            logger.debug(f"[XAttn] select_blocks: q_len={q_len} < alignment={alignment}, skipping estimation")
            return available_blocks
        padded_q_len = ((q_len + alignment - 1) // alignment) * alignment
        if padded_q_len != q_len:
            Q = torch.nn.functional.pad(Q, (0, 0, 0, padded_q_len - q_len), value=0)
        q_reshaped_len = padded_q_len // self.stride

        # ------------------------------------------------------------------
        # Step 1: Q's global position in reshaped (stride-downsampled) space.
        # Q starts right after num_historical_blocks * block_size tokens.
        # ------------------------------------------------------------------
        block_size = ctx.block_size  # tokens per CPU block (e.g., 4096)
        num_historical_blocks = len(available_blocks)
        historical_k_len = num_historical_blocks * block_size
        chunk_start = historical_k_len // self.stride
        chunk_end = chunk_start + q_reshaped_len
        # Last valid (non-padding) Q row in reshaped space.
        real_q_len = chunk_start + (q_len + self.stride - 1) // self.stride

        k_alignment = self.stride * BLOCK_N
        attn_scores_list = []

        # ------------------------------------------------------------------
        # Step 2: Stream historical K blocks through one slot and score each.
        # Only K is loaded (V is not needed for estimation), and each block is
        # released right after scoring so full K never sits on the GPU.
        # ------------------------------------------------------------------
        slot = 0
        with nvtx.range("xattn_estimate_historical"):
            for cpu_block_id in available_blocks:
                offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
                offload_engine.wait_slot_layer(slot)

                # [1, block_size, num_kv_heads, head_dim] -> [1, num_heads, block_size, head_dim]
                k_block = offload_engine.get_k_for_slot(slot)
                K_chunk = _repeat_heads(k_block.transpose(1, 2), num_heads)

                k_len = K_chunk.shape[2]
                if k_len < k_alignment:
                    K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, k_alignment - k_len), value=0)

                # Historical K precedes Q entirely, so no causal mask; local
                # coordinates because K_chunk holds only this block.
                attn_scores_list.append(
                    flat_group_gemm_fuse_reshape(
                        Q,
                        K_chunk,
                        self.stride,
                        chunk_start=0,
                        chunk_end=q_reshaped_len,
                        is_causal=False,
                    )
                )
                # NOTE(review): scoring runs on the current stream; confirm that
                # record_slot_compute_done orders against it before the slot is
                # overwritten by the next load.
                offload_engine.record_slot_compute_done(slot)

        # ------------------------------------------------------------------
        # Step 3: Score the current chunk's K (already on GPU in prefill buffer).
        # ------------------------------------------------------------------
        with nvtx.range("xattn_estimate_current"):
            k_curr, _ = offload_engine.get_prefill_buffer_slice(layer_id, q_len)
            # [1, q_len, num_kv_heads, head_dim] -> [1, num_heads, q_len, head_dim]
            K_current = _repeat_heads(k_curr.transpose(1, 2), num_heads)

            curr_k_len = K_current.shape[2]
            padded_curr_k_len = ((curr_k_len + k_alignment - 1) // k_alignment) * k_alignment
            if padded_curr_k_len != curr_k_len:
                K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0)

            # IMPORTANT: LOCAL coordinates. K_current holds only the current
            # chunk, so the kernel's block_n starts at 0; a global chunk_start
            # would produce a wrong causal mask.
            attn_scores_list.append(
                flat_group_gemm_fuse_reshape(
                    Q,
                    K_current,
                    self.stride,
                    chunk_start=0,
                    chunk_end=q_reshaped_len,
                    is_causal=True,
                )
            )
            del K_current

        # ------------------------------------------------------------------
        # Step 4: Concatenate [historical..., current] along K.
        # ------------------------------------------------------------------
        attn_scores = torch.cat(attn_scores_list, dim=-1)
        del attn_scores_list

        # ------------------------------------------------------------------
        # Step 5: Softmax + per-BSA-block sums using Q's GLOBAL position.
        # ------------------------------------------------------------------
        bsa_per_cpu = block_size // self.BSA_BLOCK_SIZE  # e.g., 4096/128 = 32
        reshaped_bsa_bs = self.BSA_BLOCK_SIZE // self.stride  # e.g., 128/8 = 16
        scale = _LOG2_E / math.sqrt(head_dim) / self.stride
        segment_size = min(4096, reshaped_bsa_bs)

        with nvtx.range("xattn_estimate_softmax"):
            block_sums = softmax_fuse_block_sum(
                attn_scores,
                reshaped_bsa_bs,
                segment_size,
                chunk_start=chunk_start,
                chunk_end=chunk_end,
                real_q_len=real_q_len,
                scale=scale,
                is_causal=True,  # Consistent with GPU-only
            )
        # block_sums: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]

        # ------------------------------------------------------------------
        # Step 6: Threshold selection into a BSA-level mask.
        # causal=False / current_index=0: causality was already applied above,
        # and a nonzero index would overrun block_sums' K dimension.
        # ------------------------------------------------------------------
        with nvtx.range("xattn_find_blocks"):
            mask = find_blocks_chunked(
                block_sums,
                current_index=0,
                threshold=self.threshold,
                num_to_choose=None,
                decoding=False,
                mode="both",
                causal=False,
            )

        B, H, Q_bsa, _ = mask.shape
        historical_k_bsa_blocks = num_historical_blocks * bsa_per_cpu
        mask_historical = mask[:, :, :, :historical_k_bsa_blocks]

        # TODO: per-BSA-block density recording (DensityObserver) is not wired
        # up in offload mode; only CPU-block communication density is recorded.

        # ------------------------------------------------------------------
        # Step 7: Keep the historical mask for compute_chunked_prefill (TODO:
        # consumed by a future BSA-kernel path). Padded Q rows are included.
        # ------------------------------------------------------------------
        buf = self._prefill_mask_buffer
        if buf is not None:
            if H == buf.shape[1] and Q_bsa <= buf.shape[2] and historical_k_bsa_blocks <= buf.shape[3]:
                buf[:, :, :Q_bsa, :historical_k_bsa_blocks].copy_(mask_historical)
                self._current_mask_q_bsa = Q_bsa
                self._current_mask_k_bsa = historical_k_bsa_blocks
            else:
                logger.debug(
                    "[XAttn] mask %s does not fit buffer %s; not saved",
                    (B, H, Q_bsa, historical_k_bsa_blocks),
                    tuple(buf.shape),
                )
                self._current_mask_q_bsa = 0
                self._current_mask_k_bsa = 0

        # ------------------------------------------------------------------
        # Step 8: Union to CPU-block granularity (over BSA sub-blocks, Q rows,
        # heads) and map back to block IDs.
        # ------------------------------------------------------------------
        with nvtx.range("xattn_aggregate_mask"):
            mask_per_cpu = mask_historical.view(B, H, Q_bsa, num_historical_blocks, bsa_per_cpu)
            cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1)  # [B, num_cpu]
            selected = set(cpu_needed[0].nonzero(as_tuple=True)[0].tolist())

        # Always keep the first (attention sink) and last (most recent) block.
        selected.update((0, num_historical_blocks - 1))
        selected_indices = sorted(selected)
        selected_block_ids = [available_blocks[i] for i in selected_indices]

        self._selected_cpu_indices = selected_indices
        self._bsa_per_cpu = bsa_per_cpu

        DensityObserver.record_comm_density(
            layer_id,
            selected_cpu_blocks=len(selected_block_ids),
            total_cpu_blocks=len(available_blocks),
        )

        # Update statistics (only for layer 0 to avoid overcounting)
        if layer_id == 0:
            self._stats_total_available_blocks += len(available_blocks)
            self._stats_total_selected_blocks += len(selected_block_ids)
            self._stats_num_chunks += 1
            chunk_density = len(selected_block_ids) / len(available_blocks)
            logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "
                         f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")

        # Free intermediates before the attention pass allocates its own.
        del attn_scores, block_sums, mask, mask_historical, mask_per_cpu
        return selected_block_ids

    def compute_chunked_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        current_chunk_idx: int,
        seq: "Sequence",
        num_tokens: int,
        selected_blocks: List[int],
    ) -> torch.Tensor:
        """
        Compute attention for one prefill chunk over selected history + itself.

        1. Stream the selected historical blocks through the offload ring slots
           (transfers overlap with compute) and attend to each non-causally.
        2. Attend causally to the current chunk from the prefill buffer.
        3. Merge all partial outputs via their log-sum-exp.

        The BSA-level mask saved by select_blocks() is not consumed yet.
        TODO: Use it with the BSA kernel for per-head sparse attention.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor (unused; current K is read from the prefill buffer)
            v: Value tensor (unused; current V is read from the prefill buffer)
            layer_id: Current layer index
            softmax_scale: Softmax scaling factor
            offload_engine: OffloadEngine for loading blocks
            kvcache_manager: KVCacheManager (unused here)
            current_chunk_idx: Current chunk index (unused here)
            seq: Sequence object (unused here)
            num_tokens: Number of tokens in current chunk
            selected_blocks: CPU block IDs chosen by select_blocks

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        # Use FlashInfer-based implementations (more optimized)
        from nanovllm.ops.chunked_attention import (
            flash_attn_with_lse_flashinfer as flash_attn_with_lse,
            merge_attention_outputs_flashinfer as merge_attention_outputs,
        )

        q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
        compute_stream = offload_engine.compute_stream
        o_acc = None
        lse_acc = None

        if selected_blocks:
            with nvtx.range("xattn_compute_historical"):
                num_slots = offload_engine.num_ring_slots
                num_blocks = len(selected_blocks)

                # Prime the ring: start up to num_slots transfers up front.
                # With a single slot this degenerates to load -> compute -> load.
                for slot in range(min(num_slots, num_blocks)):
                    block_id = selected_blocks[slot]
                    offload_engine.load_to_slot_layer(slot, layer_id, block_id, chunk_idx=block_id)

                for block_idx in range(num_blocks):
                    slot = block_idx % num_slots
                    offload_engine.wait_slot_layer(slot)
                    with torch.cuda.stream(compute_stream):
                        prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
                        prev_o, prev_lse = flash_attn_with_lse(
                            q_batched, prev_k, prev_v,
                            softmax_scale=softmax_scale,
                            causal=False,
                        )
                    offload_engine.record_slot_compute_done(slot)

                    # Merge on compute_stream, where prev_o/prev_lse were produced.
                    with torch.cuda.stream(compute_stream):
                        if o_acc is None:
                            o_acc, lse_acc = prev_o, prev_lse
                        else:
                            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

                    # Reuse the just-freed slot for the next pending block.
                    next_block_idx = block_idx + num_slots
                    if next_block_idx < num_blocks:
                        next_block_id = selected_blocks[next_block_idx]
                        offload_engine.load_to_slot_layer(slot, layer_id, next_block_id, chunk_idx=next_block_id)

        # Compute attention to current chunk (causal mask)
        with nvtx.range("xattn_compute_current"):
            with torch.cuda.stream(compute_stream):
                k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
                current_o, current_lse = flash_attn_with_lse(
                    q_batched, k_curr, v_curr,
                    softmax_scale=softmax_scale,
                    causal=True,
                )

        # Merge historical and current attention
        with nvtx.range("xattn_compute_merge"):
            with torch.cuda.stream(compute_stream):
                if o_acc is None:
                    final_o = current_o
                else:
                    final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        # Make the default stream wait for compute_stream before the caller uses final_o.
        torch.cuda.default_stream().wait_stream(compute_stream)

        # [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
        return final_o.squeeze(0)

    def compute_chunked_decode(
        self,
        q: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        seq: "Sequence",
        selected_blocks: List[int],
    ) -> torch.Tensor:
        """
        XAttention does not support decode phase.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "XAttentionBSAPolicy does not support decode phase. "
            "Use FullAttentionPolicy for decode."
        )

    def reset(self) -> None:
        """Reset per-request policy state and clear sparse metadata.

        Density statistics are intentionally kept (they accumulate across the
        entire prefill); use reset_stats() to clear them.
        """
        self.sparse_metadata.clear()
        self._current_mask_q_bsa = 0
        self._current_mask_k_bsa = 0
        self._selected_cpu_indices = []

    def reset_stats(self) -> None:
        """Reset density statistics."""
        self._stats_total_available_blocks = 0
        self._stats_total_selected_blocks = 0
        self._stats_num_chunks = 0

    def get_density_stats(self) -> dict:
        """Return accumulated layer-0 block-selection statistics.

        ``overall_density`` is selected / available CPU blocks, or 0.0 when
        nothing has been recorded.
        """
        available = self._stats_total_available_blocks
        if available == 0:
            return {
                "total_available_blocks": 0,
                "total_selected_blocks": 0,
                "num_chunks": 0,
                "overall_density": 0.0,
            }
        return {
            "total_available_blocks": available,
            "total_selected_blocks": self._stats_total_selected_blocks,
            "num_chunks": self._stats_num_chunks,
            "overall_density": self._stats_total_selected_blocks / available,
        }

    def print_density_stats(self) -> None:
        """Log a density statistics summary at INFO level."""
        stats = self.get_density_stats()
        logger.info(f"[XAttn BSA] Density Stats: chunks={stats['num_chunks']}, "
                    f"available={stats['total_available_blocks']}, "
                    f"selected={stats['total_selected_blocks']}, "
                    f"density={stats['overall_density']:.1%}")

    def __repr__(self) -> str:
        return f"XAttentionBSAPolicy(threshold={self.threshold}, stride={self.stride})"