feat: add GPU-only XAttention BSA sparse attention support

- Implement compute_prefill() in XAttentionBSAPolicy for GPU-only mode
  - Uses xattn_estimate to compute sparse block mask
  - Uses block_sparse_attn_func for efficient sparse attention
  - Handles GQA by expanding K/V heads
  - Falls back to flash_attn for paged KV cache (prefix cache)
- Implement compute_decode() by delegating to FullAttentionPolicy
- Add --policy xattn option to bench.py

Verified: RULER 32k niah_single_1 5/5 samples passed (100%)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 05:19:24 +08:00
parent b6b59b50ed
commit 076656c9c2
2 changed files with 207 additions and 1 deletions

View File

@@ -51,6 +51,9 @@ def main():
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
# Sparse policy option (GPU-only mode now supports policy routing)
parser.add_argument("--policy", type=str, default=None,
choices=["full", "xattn"],
help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
parser.add_argument("--enable-policy", action="store_true",
help="Enable sparse policy routing (FullAttentionPolicy by default)")
args = parser.parse_args()
@@ -59,7 +62,10 @@ def main():
max_len = args.max_len
# Configure sparse policy
if args.enable_policy:
if args.policy == "xattn":
sparse_policy = SparsePolicyType.XATTN_BSA
print(f"\n[nanovllm GPU + XAttention BSA] max_len={max_len}")
elif args.policy == "full" or args.enable_policy:
sparse_policy = SparsePolicyType.FULL
print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}")
else: