feat: add GPU-only XAttention BSA sparse attention support

- Implement compute_prefill() in XAttentionBSAPolicy for GPU-only mode
  - Uses xattn_estimate to compute sparse block mask
  - Uses block_sparse_attn_func for efficient sparse attention
  - Handles GQA by expanding K/V heads
  - Falls back to flash_attn for paged KV cache (prefix cache)
- Implement compute_decode() by delegating to FullAttentionPolicy
- Add --policy xattn option to bench.py

Verified: RULER 32k niah_single_1 5/5 samples passed (100%)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 05:19:24 +08:00
parent b6b59b50ed
commit 076656c9c2
2 changed files with 207 additions and 1 deletions

View File

@@ -51,6 +51,9 @@ def main():
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)") parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks") parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
# Sparse policy option (GPU-only mode now supports policy routing) # Sparse policy option (GPU-only mode now supports policy routing)
parser.add_argument("--policy", type=str, default=None,
choices=["full", "xattn"],
help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
parser.add_argument("--enable-policy", action="store_true", parser.add_argument("--enable-policy", action="store_true",
help="Enable sparse policy routing (FullAttentionPolicy by default)") help="Enable sparse policy routing (FullAttentionPolicy by default)")
args = parser.parse_args() args = parser.parse_args()
@@ -59,7 +62,10 @@ def main():
max_len = args.max_len max_len = args.max_len
# Configure sparse policy # Configure sparse policy
if args.enable_policy: if args.policy == "xattn":
sparse_policy = SparsePolicyType.XATTN_BSA
print(f"\n[nanovllm GPU + XAttention BSA] max_len={max_len}")
elif args.policy == "full" or args.enable_policy:
sparse_policy = SparsePolicyType.FULL sparse_policy = SparsePolicyType.FULL
print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}") print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}")
else: else:

View File

@@ -122,6 +122,206 @@ class XAttentionBSAPolicy(SparsePolicy):
self._stats_total_selected_blocks = 0 self._stats_total_selected_blocks = 0
self._stats_num_chunks = 0 self._stats_num_chunks = 0
# =========================================================================
# GPU-only methods (non-chunked)
# =========================================================================
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
layer_id: int,
block_tables: torch.Tensor = None,
) -> torch.Tensor:
"""
GPU-only prefill attention using XAttention + BSA.
This method implements sparse attention for GPU-only mode:
1. Estimate block importance using xattn_estimate
2. Compute sparse attention using block_sparse_attn_func
Args:
q: Query tensor [total_q, num_heads, head_dim] (varlen packed)
k: Key tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
v: Value tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
max_seqlen_q: Maximum Q sequence length
max_seqlen_k: Maximum K sequence length
softmax_scale: Softmax scaling factor
layer_id: Transformer layer index
block_tables: Paged attention block tables (not used for XAttention)
Returns:
Attention output [total_q, num_heads, head_dim]
"""
# When block_tables is provided (paged KV cache / prefix cache),
# fallback to flash_attn as XAttention expects contiguous K, V
if block_tables is not None:
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
block_table=block_tables,
)
if not BSA_AVAILABLE:
# Fallback to flash attention if BSA not available
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
)
if not XATTN_AVAILABLE:
# Fallback to flash attention if xattn not available
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
)
from nanovllm.ops.xattn import xattn_estimate
# Get dimensions
total_q, num_heads, head_dim = q.shape
total_kv, num_kv_heads, _ = k.shape
# For now, assume batch_size = 1 (single sequence)
# TODO: Support batched varlen format
batch_size = cu_seqlens_q.shape[0] - 1
if batch_size != 1:
# Fallback to flash attention for batched input
from flash_attn import flash_attn_varlen_func
logger.warning(f"[XAttn] batch_size={batch_size} > 1, falling back to flash attention")
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
)
q_len = max_seqlen_q
k_len = max_seqlen_k
# Convert from varlen format [total, heads, dim] to [batch, heads, seq, dim]
# q: [q_len, num_heads, head_dim] -> [1, num_heads, q_len, head_dim]
Q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, q_len, head_dim]
K = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
V = v.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
# Expand KV for GQA (xattn_estimate requires matching heads)
K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
# Estimate block importance and get sparse mask
_, mask = xattn_estimate(
Q, K_exp,
chunk_size=self.chunk_size,
block_size=self.BSA_BLOCK_SIZE,
threshold=self.threshold,
use_triton=self.use_triton,
causal=True,
)
# Compute block counts
q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
k_block_num = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
# Prepare tensors for BSA
# q, k, v need to be [seq_len, num_heads, head_dim]
q_bsa = q # Already [q_len, num_heads, head_dim]
# For GQA with BSA, we need to expand k, v to match num_heads
# k, v: [k_len, num_kv_heads, head_dim] -> [k_len, num_heads, head_dim]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
k_bsa = k.repeat_interleave(num_groups, dim=1)
v_bsa = v.repeat_interleave(num_groups, dim=1)
else:
k_bsa = k
v_bsa = v
# Prepare BSA inputs
cu_seqlens_q_bsa = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
cu_seqlens_k_bsa = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
head_groups = torch.ones(num_heads, dtype=torch.int32, device=q.device)
# Trim mask to actual block counts
mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
# Compute sparse attention using BSA
output = block_sparse_attn_func(
q_bsa, k_bsa, v_bsa,
cu_seqlens_q_bsa,
cu_seqlens_k_bsa,
head_groups,
None, # key_padding_mask
mask_trimmed,
q_len, k_len,
p_dropout=0.0,
deterministic=True,
is_causal=True,
)
# Update statistics (layer 0 only to avoid overcounting)
if layer_id == 0:
selected_blocks = mask_trimmed.sum().item()
total_blocks = q_block_num * k_block_num * num_heads
density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
f"k_blocks={k_block_num}, density={density:.1%}")
return output
def compute_decode(
self,
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
cache_seqlens: torch.Tensor,
softmax_scale: float,
layer_id: int,
block_tables: torch.Tensor = None,
) -> torch.Tensor:
"""
GPU-only decode attention - delegates to FullAttentionPolicy.
XAttention is designed for long prefill sequences. For decode (single token),
we use FullAttentionPolicy which calls flash_attn_with_kvcache.
"""
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
return FullAttentionPolicy().compute_decode(
q, k_cache, v_cache, cache_seqlens, softmax_scale, layer_id, block_tables
)
# =========================================================================
# Chunked offload methods
# =========================================================================
def select_blocks( def select_blocks(
self, self,
available_blocks: List[int], available_blocks: List[int],