✨ feat: add GPU-only XAttention BSA sparse attention support
- Implement compute_prefill() in XAttentionBSAPolicy for GPU-only mode - Uses xattn_estimate to compute sparse block mask - Uses block_sparse_attn_func for efficient sparse attention - Handles GQA by expanding K/V heads - Falls back to flash_attn for paged KV cache (prefix cache) - Implement compute_decode() by delegating to FullAttentionPolicy - Add --policy xattn option to bench.py Verified: RULER 32k niah_single_1 5/5 samples passed (100%) Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
8
bench.py
8
bench.py
@@ -51,6 +51,9 @@ def main():
|
|||||||
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
||||||
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
||||||
# Sparse policy option (GPU-only mode now supports policy routing)
|
# Sparse policy option (GPU-only mode now supports policy routing)
|
||||||
|
parser.add_argument("--policy", type=str, default=None,
|
||||||
|
choices=["full", "xattn"],
|
||||||
|
help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
|
||||||
parser.add_argument("--enable-policy", action="store_true",
|
parser.add_argument("--enable-policy", action="store_true",
|
||||||
help="Enable sparse policy routing (FullAttentionPolicy by default)")
|
help="Enable sparse policy routing (FullAttentionPolicy by default)")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -59,7 +62,10 @@ def main():
|
|||||||
max_len = args.max_len
|
max_len = args.max_len
|
||||||
|
|
||||||
# Configure sparse policy
|
# Configure sparse policy
|
||||||
if args.enable_policy:
|
if args.policy == "xattn":
|
||||||
|
sparse_policy = SparsePolicyType.XATTN_BSA
|
||||||
|
print(f"\n[nanovllm GPU + XAttention BSA] max_len={max_len}")
|
||||||
|
elif args.policy == "full" or args.enable_policy:
|
||||||
sparse_policy = SparsePolicyType.FULL
|
sparse_policy = SparsePolicyType.FULL
|
||||||
print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}")
|
print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -122,6 +122,206 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
self._stats_total_selected_blocks = 0
|
self._stats_total_selected_blocks = 0
|
||||||
self._stats_num_chunks = 0
|
self._stats_num_chunks = 0
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# GPU-only methods (non-chunked)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cu_seqlens_q: torch.Tensor,
|
||||||
|
cu_seqlens_k: torch.Tensor,
|
||||||
|
max_seqlen_q: int,
|
||||||
|
max_seqlen_k: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
GPU-only prefill attention using XAttention + BSA.
|
||||||
|
|
||||||
|
This method implements sparse attention for GPU-only mode:
|
||||||
|
1. Estimate block importance using xattn_estimate
|
||||||
|
2. Compute sparse attention using block_sparse_attn_func
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [total_q, num_heads, head_dim] (varlen packed)
|
||||||
|
k: Key tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
|
||||||
|
v: Value tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
|
||||||
|
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
|
||||||
|
cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
|
||||||
|
max_seqlen_q: Maximum Q sequence length
|
||||||
|
max_seqlen_k: Maximum K sequence length
|
||||||
|
softmax_scale: Softmax scaling factor
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
block_tables: Paged attention block tables (not used for XAttention)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [total_q, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
# When block_tables is provided (paged KV cache / prefix cache),
|
||||||
|
# fallback to flash_attn as XAttention expects contiguous K, V
|
||||||
|
if block_tables is not None:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
block_table=block_tables,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not BSA_AVAILABLE:
|
||||||
|
# Fallback to flash attention if BSA not available
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not XATTN_AVAILABLE:
|
||||||
|
# Fallback to flash attention if xattn not available
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
from nanovllm.ops.xattn import xattn_estimate
|
||||||
|
|
||||||
|
# Get dimensions
|
||||||
|
total_q, num_heads, head_dim = q.shape
|
||||||
|
total_kv, num_kv_heads, _ = k.shape
|
||||||
|
|
||||||
|
# For now, assume batch_size = 1 (single sequence)
|
||||||
|
# TODO: Support batched varlen format
|
||||||
|
batch_size = cu_seqlens_q.shape[0] - 1
|
||||||
|
if batch_size != 1:
|
||||||
|
# Fallback to flash attention for batched input
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
logger.warning(f"[XAttn] batch_size={batch_size} > 1, falling back to flash attention")
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
q_len = max_seqlen_q
|
||||||
|
k_len = max_seqlen_k
|
||||||
|
|
||||||
|
# Convert from varlen format [total, heads, dim] to [batch, heads, seq, dim]
|
||||||
|
# q: [q_len, num_heads, head_dim] -> [1, num_heads, q_len, head_dim]
|
||||||
|
Q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, q_len, head_dim]
|
||||||
|
K = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
|
||||||
|
V = v.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
|
||||||
|
|
||||||
|
# Expand KV for GQA (xattn_estimate requires matching heads)
|
||||||
|
K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
|
||||||
|
|
||||||
|
# Estimate block importance and get sparse mask
|
||||||
|
_, mask = xattn_estimate(
|
||||||
|
Q, K_exp,
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
block_size=self.BSA_BLOCK_SIZE,
|
||||||
|
threshold=self.threshold,
|
||||||
|
use_triton=self.use_triton,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute block counts
|
||||||
|
q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||||
|
k_block_num = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||||
|
|
||||||
|
# Prepare tensors for BSA
|
||||||
|
# q, k, v need to be [seq_len, num_heads, head_dim]
|
||||||
|
q_bsa = q # Already [q_len, num_heads, head_dim]
|
||||||
|
|
||||||
|
# For GQA with BSA, we need to expand k, v to match num_heads
|
||||||
|
# k, v: [k_len, num_kv_heads, head_dim] -> [k_len, num_heads, head_dim]
|
||||||
|
if num_heads != num_kv_heads:
|
||||||
|
num_groups = num_heads // num_kv_heads
|
||||||
|
k_bsa = k.repeat_interleave(num_groups, dim=1)
|
||||||
|
v_bsa = v.repeat_interleave(num_groups, dim=1)
|
||||||
|
else:
|
||||||
|
k_bsa = k
|
||||||
|
v_bsa = v
|
||||||
|
|
||||||
|
# Prepare BSA inputs
|
||||||
|
cu_seqlens_q_bsa = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
|
||||||
|
cu_seqlens_k_bsa = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
|
||||||
|
head_groups = torch.ones(num_heads, dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
# Trim mask to actual block counts
|
||||||
|
mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
|
||||||
|
|
||||||
|
# Compute sparse attention using BSA
|
||||||
|
output = block_sparse_attn_func(
|
||||||
|
q_bsa, k_bsa, v_bsa,
|
||||||
|
cu_seqlens_q_bsa,
|
||||||
|
cu_seqlens_k_bsa,
|
||||||
|
head_groups,
|
||||||
|
None, # key_padding_mask
|
||||||
|
mask_trimmed,
|
||||||
|
q_len, k_len,
|
||||||
|
p_dropout=0.0,
|
||||||
|
deterministic=True,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update statistics (layer 0 only to avoid overcounting)
|
||||||
|
if layer_id == 0:
|
||||||
|
selected_blocks = mask_trimmed.sum().item()
|
||||||
|
total_blocks = q_block_num * k_block_num * num_heads
|
||||||
|
density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
|
||||||
|
logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
|
||||||
|
f"k_blocks={k_block_num}, density={density:.1%}")
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
cache_seqlens: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
GPU-only decode attention - delegates to FullAttentionPolicy.
|
||||||
|
|
||||||
|
XAttention is designed for long prefill sequences. For decode (single token),
|
||||||
|
we use FullAttentionPolicy which calls flash_attn_with_kvcache.
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||||
|
return FullAttentionPolicy().compute_decode(
|
||||||
|
q, k_cache, v_cache, cache_seqlens, softmax_scale, layer_id, block_tables
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Chunked offload methods
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
|
|||||||
Reference in New Issue
Block a user