✨ feat: add GPU-only XAttention BSA sparse attention support
- Implement compute_prefill() in XAttentionBSAPolicy for GPU-only mode
- Uses xattn_estimate to compute the sparse block mask
- Uses block_sparse_attn_func for efficient sparse attention
- Handles GQA by expanding K/V heads
- Falls back to flash_attn for paged KV cache (prefix cache)
- Implement compute_decode() by delegating to FullAttentionPolicy
- Add --policy xattn option to bench.py

Verified: RULER 32k niah_single_1 5/5 samples passed (100%)

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
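For orientation, here is a hedged caller-side sketch of the new prefill entry point. It uses only the signature added in the diff below; `policy` stands for an already-constructed XAttentionBSAPolicy (its constructor arguments are not part of this commit), and the tensor sizes are illustrative.

import torch

# Caller-side sketch (illustrative sizes; `policy` is an XAttentionBSAPolicy
# instance constructed elsewhere). block_tables=None means contiguous K/V,
# so the XAttention + BSA path is taken rather than the flash_attn fallback.
seq_len, num_heads, num_kv_heads, head_dim = 4096, 32, 8, 128
q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")

out = policy.compute_prefill(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seq_len,
    max_seqlen_k=seq_len,
    softmax_scale=head_dim ** -0.5,
    layer_id=0,
    block_tables=None,
)
# out: [seq_len, num_heads, head_dim], same layout as flash_attn_varlen_func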
bench.py
@@ -51,6 +51,9 @@ def main():
    parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
    parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
    # Sparse policy option (GPU-only mode now supports policy routing)
    parser.add_argument("--policy", type=str, default=None,
                        choices=["full", "xattn"],
                        help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
    parser.add_argument("--enable-policy", action="store_true",
                        help="Enable sparse policy routing (FullAttentionPolicy by default)")
    args = parser.parse_args()
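With these options, XAttention routing is enabled for a run via `python bench.py --enable-policy --policy xattn` (or `--policy full` for full attention); the remaining bench.py flags are unchanged.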
@@ -59,7 +62,10 @@ def main():
    max_len = args.max_len

    # Configure sparse policy
    if args.enable_policy:
        if args.policy == "xattn":
            sparse_policy = SparsePolicyType.XATTN_BSA
            print(f"\n[nanovllm GPU + XAttention BSA] max_len={max_len:,}")
        elif args.policy == "full" or args.enable_policy:
            sparse_policy = SparsePolicyType.FULL
            print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len:,}")
    else:
@@ -122,6 +122,206 @@ class XAttentionBSAPolicy(SparsePolicy):
        self._stats_total_selected_blocks = 0
        self._stats_num_chunks = 0

    # =========================================================================
    # GPU-only methods (non-chunked)
    # =========================================================================

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        layer_id: int,
        block_tables: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        GPU-only prefill attention using XAttention + BSA.

        This method implements sparse attention for GPU-only mode:
        1. Estimate block importance using xattn_estimate.
        2. Compute sparse attention using block_sparse_attn_func.

        Args:
            q: Query tensor [total_q, num_heads, head_dim] (varlen packed)
            k: Key tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
            v: Value tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
            cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
            cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
            max_seqlen_q: Maximum Q sequence length
            max_seqlen_k: Maximum K sequence length
            softmax_scale: Softmax scaling factor
            layer_id: Transformer layer index
            block_tables: Paged attention block tables (not used for XAttention)

        Returns:
            Attention output [total_q, num_heads, head_dim]
        """
        # When block_tables is provided (paged KV cache / prefix cache),
        # fall back to flash_attn, as XAttention expects contiguous K/V.
        if block_tables is not None:
            from flash_attn import flash_attn_varlen_func
            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=softmax_scale,
                causal=True,
                block_table=block_tables,
            )

        if not BSA_AVAILABLE:
            # Fall back to flash attention if BSA is not available
            from flash_attn import flash_attn_varlen_func
            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=softmax_scale,
                causal=True,
            )

        if not XATTN_AVAILABLE:
            # Fall back to flash attention if xattn is not available
            from flash_attn import flash_attn_varlen_func
            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=softmax_scale,
                causal=True,
            )
        from nanovllm.ops.xattn import xattn_estimate

        # Get dimensions
        total_q, num_heads, head_dim = q.shape
        total_kv, num_kv_heads, _ = k.shape

        # For now, assume batch_size = 1 (single sequence).
        # TODO: Support batched varlen format.
        batch_size = cu_seqlens_q.shape[0] - 1
        if batch_size != 1:
            # Fall back to flash attention for batched input
            from flash_attn import flash_attn_varlen_func
            logger.warning(f"[XAttn] batch_size={batch_size} > 1, falling back to flash attention")
            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=softmax_scale,
                causal=True,
            )
        q_len = max_seqlen_q
        k_len = max_seqlen_k

        # Convert from varlen format [total, heads, dim] to [batch, heads, seq, dim]
        # q: [q_len, num_heads, head_dim] -> [1, num_heads, q_len, head_dim]
        Q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, q_len, head_dim]
        K = k.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, k_len, head_dim]
        V = v.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, k_len, head_dim]

        # Expand KV for GQA (xattn_estimate requires matching head counts)
        K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)

        # Estimate block importance and get the sparse block mask
        _, mask = xattn_estimate(
            Q, K_exp,
            chunk_size=self.chunk_size,
            block_size=self.BSA_BLOCK_SIZE,
            threshold=self.threshold,
            use_triton=self.use_triton,
            causal=True,
        )
        # Compute block counts
        q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
        k_block_num = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
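        # For example (illustrative numbers, not from this diff): with
        # BSA_BLOCK_SIZE = 128 and a 32k-token prefill, the ceiling division
        # above gives q_block_num = k_block_num = ceil(32768 / 128) = 256,
        # i.e. the mask describes a 256 x 256 block grid per head.
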
        # Prepare tensors for BSA
        # q, k, v need to be [seq_len, num_heads, head_dim]
        q_bsa = q  # Already [q_len, num_heads, head_dim]

        # For GQA with BSA, expand k, v to match num_heads
        # k, v: [k_len, num_kv_heads, head_dim] -> [k_len, num_heads, head_dim]
        if num_heads != num_kv_heads:
            num_groups = num_heads // num_kv_heads
            k_bsa = k.repeat_interleave(num_groups, dim=1)
            v_bsa = v.repeat_interleave(num_groups, dim=1)
        else:
            k_bsa = k
            v_bsa = v
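        # For example (illustrative head counts): with num_heads = 32 and
        # num_kv_heads = 8, num_groups = 4, so each KV head is repeated 4
        # times along dim=1 to line up with the 32 query heads.
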
        # Prepare BSA inputs
        cu_seqlens_q_bsa = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
        cu_seqlens_k_bsa = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
        head_groups = torch.ones(num_heads, dtype=torch.int32, device=q.device)

        # Trim the mask to the actual block counts
        mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
        # Compute sparse attention using BSA
        output = block_sparse_attn_func(
            q_bsa, k_bsa, v_bsa,
            cu_seqlens_q_bsa,
            cu_seqlens_k_bsa,
            head_groups,
            None,  # key_padding_mask
            mask_trimmed,
            q_len, k_len,
            p_dropout=0.0,
            deterministic=True,
            is_causal=True,
        )

        # Update statistics (layer 0 only, to avoid overcounting)
        if layer_id == 0:
            selected_blocks = mask_trimmed.sum().item()
            total_blocks = q_block_num * k_block_num * num_heads
            density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
            logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
                         f"k_blocks={k_block_num}, density={density:.1%}")
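            # Note: total_blocks counts the full q_block_num x k_block_num
            # rectangle per head, so a dense causal mask would already report
            # roughly 50% density; values well below that reflect the block
            # sparsity introduced by the threshold.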

        return output

    def compute_decode(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        cache_seqlens: torch.Tensor,
        softmax_scale: float,
        layer_id: int,
        block_tables: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        GPU-only decode attention - delegates to FullAttentionPolicy.

        XAttention is designed for long prefill sequences. For decode (single token),
        we use FullAttentionPolicy, which calls flash_attn_with_kvcache.
        """
        from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
        return FullAttentionPolicy().compute_decode(
            q, k_cache, v_cache, cache_seqlens, softmax_scale, layer_id, block_tables
        )

    # =========================================================================
    # Chunked offload methods
    # =========================================================================

    def select_blocks(
        self,
        available_blocks: List[int],