diff --git a/bench.py b/bench.py
index 7140b82..7fa0d69 100644
--- a/bench.py
+++ b/bench.py
@@ -51,6 +51,9 @@ def main():
     parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
     parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
     # Sparse policy option (GPU-only mode now supports policy routing)
+    parser.add_argument("--policy", type=str, default=None,
+                        choices=["full", "xattn"],
+                        help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
     parser.add_argument("--enable-policy", action="store_true", help="Enable sparse policy routing (FullAttentionPolicy by default)")
     args = parser.parse_args()
 
@@ -59,7 +62,10 @@ def main():
     max_len = args.max_len
 
     # Configure sparse policy
-    if args.enable_policy:
+    if args.policy == "xattn":
+        sparse_policy = SparsePolicyType.XATTN_BSA
+        print(f"\n[nanovllm GPU + XAttention BSA] max_len={max_len:,}")
+    elif args.policy == "full" or args.enable_policy:
         sparse_policy = SparsePolicyType.FULL
         print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len=16,016")
     else:
diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py
index 70dcd1d..277f6d6 100644
--- a/nanovllm/kvcache/sparse/xattn_bsa.py
+++ b/nanovllm/kvcache/sparse/xattn_bsa.py
@@ -122,6 +122,206 @@ class XAttentionBSAPolicy(SparsePolicy):
         self._stats_total_selected_blocks = 0
         self._stats_num_chunks = 0
 
+    # =========================================================================
+    # GPU-only methods (non-chunked)
+    # =========================================================================
+
+    def compute_prefill(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        softmax_scale: float,
+        layer_id: int,
+        block_tables: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        GPU-only prefill attention using XAttention + BSA.
+
+        This method implements sparse attention for GPU-only mode:
+        1. Estimate block importance using xattn_estimate
+        2. Compute sparse attention using block_sparse_attn_func
+
+        Args:
+            q: Query tensor [total_q, num_heads, head_dim] (varlen packed)
+            k: Key tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
+            v: Value tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
+            cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
+            cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
+            max_seqlen_q: Maximum Q sequence length
+            max_seqlen_k: Maximum K sequence length
+            softmax_scale: Softmax scaling factor
+            layer_id: Transformer layer index
+            block_tables: Paged attention block tables (not used for XAttention)
+
+        Returns:
+            Attention output [total_q, num_heads, head_dim]
+        """
+        # When block_tables is provided (paged KV cache / prefix cache),
+        # fallback to flash_attn as XAttention expects contiguous K, V
+        if block_tables is not None:
+            from flash_attn import flash_attn_varlen_func
+            return flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=softmax_scale,
+                causal=True,
+                block_table=block_tables,
+            )
+
+        if not BSA_AVAILABLE:
+            # Fallback to flash attention if BSA not available
+            from flash_attn import flash_attn_varlen_func
+            return flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
+
+        if not XATTN_AVAILABLE:
+            # Fallback to flash attention if xattn not available
+            from flash_attn import flash_attn_varlen_func
+            return flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
+
+        from nanovllm.ops.xattn import xattn_estimate
+
+        # Get dimensions
+        total_q, num_heads, head_dim = q.shape
+        total_kv, num_kv_heads, _ = k.shape
+
+        # For now, assume batch_size = 1 (single sequence)
+        # TODO: Support batched varlen format
+        batch_size = cu_seqlens_q.shape[0] - 1
+        if batch_size != 1:
+            # Fallback to flash attention for batched input
+            from flash_attn import flash_attn_varlen_func
+            logger.warning(f"[XAttn] batch_size={batch_size} > 1, falling back to flash attention")
+            return flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=softmax_scale,
+                causal=True,
+            )
+
+        q_len = max_seqlen_q
+        k_len = max_seqlen_k
+
+        # Convert from varlen format [total, heads, dim] to [batch, heads, seq, dim]
+        # q: [q_len, num_heads, head_dim] -> [1, num_heads, q_len, head_dim]
+        Q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, q_len, head_dim]
+        K = k.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, k_len, head_dim]
+        V = v.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, k_len, head_dim]
+
+        # Expand KV for GQA (xattn_estimate requires matching heads)
+        K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
+
+        # Estimate block importance and get sparse mask
+        _, mask = xattn_estimate(
+            Q, K_exp,
+            chunk_size=self.chunk_size,
+            block_size=self.BSA_BLOCK_SIZE,
+            threshold=self.threshold,
+            use_triton=self.use_triton,
+            causal=True,
+        )
+
+        # Compute block counts
+        q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
+        k_block_num = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
+
+        # Prepare tensors for BSA
+        # q, k, v need to be [seq_len, num_heads, head_dim]
+        q_bsa = q  # Already [q_len, num_heads, head_dim]
+
+        # For GQA with BSA, we need to expand k, v to match num_heads
+        # k, v: [k_len, num_kv_heads, head_dim] -> [k_len, num_heads, head_dim]
+        if num_heads != num_kv_heads:
+            num_groups = num_heads // num_kv_heads
+            k_bsa = k.repeat_interleave(num_groups, dim=1)
+            v_bsa = v.repeat_interleave(num_groups, dim=1)
+        else:
+            k_bsa = k
+            v_bsa = v
+
+        # Prepare BSA inputs
+        cu_seqlens_q_bsa = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
+        cu_seqlens_k_bsa = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
+        head_groups = torch.ones(num_heads, dtype=torch.int32, device=q.device)
+
+        # Trim mask to actual block counts
+        mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
+
+        # Compute sparse attention using BSA
+        output = block_sparse_attn_func(
+            q_bsa, k_bsa, v_bsa,
+            cu_seqlens_q_bsa,
+            cu_seqlens_k_bsa,
+            head_groups,
+            None,  # key_padding_mask
+            mask_trimmed,
+            q_len, k_len,
+            p_dropout=0.0,
+            deterministic=True,
+            is_causal=True,
+        )
+
+        # Update statistics (layer 0 only to avoid overcounting)
+        if layer_id == 0:
+            selected_blocks = mask_trimmed.sum().item()
+            total_blocks = q_block_num * k_block_num * num_heads
+            density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
+            logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
+                         f"k_blocks={k_block_num}, density={density:.1%}")
+
+        return output
+
+    def compute_decode(
+        self,
+        q: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+        softmax_scale: float,
+        layer_id: int,
+        block_tables: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        GPU-only decode attention - delegates to FullAttentionPolicy.
+
+        XAttention is designed for long prefill sequences. For decode (single token),
+        we use FullAttentionPolicy which calls flash_attn_with_kvcache.
+        """
+        from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
+        return FullAttentionPolicy().compute_decode(
+            q, k_cache, v_cache, cache_seqlens, softmax_scale, layer_id, block_tables
+        )
+
+    # =========================================================================
+    # Chunked offload methods
+    # =========================================================================
+
     def select_blocks(
         self,
         available_blocks: List[int],