✨ feat: add GPU-only XAttention BSA sparse attention support
- Implement compute_prefill() in XAttentionBSAPolicy for GPU-only mode - Uses xattn_estimate to compute sparse block mask - Uses block_sparse_attn_func for efficient sparse attention - Handles GQA by expanding K/V heads - Falls back to flash_attn for paged KV cache (prefix cache) - Implement compute_decode() by delegating to FullAttentionPolicy - Add --policy xattn option to bench.py Verified: RULER 32k niah_single_1 5/5 samples passed (100%) Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
8
bench.py
8
bench.py
@@ -51,6 +51,9 @@ def main():
|
||||
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
||||
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
||||
# Sparse policy option (GPU-only mode now supports policy routing)
|
||||
parser.add_argument("--policy", type=str, default=None,
|
||||
choices=["full", "xattn"],
|
||||
help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
|
||||
parser.add_argument("--enable-policy", action="store_true",
|
||||
help="Enable sparse policy routing (FullAttentionPolicy by default)")
|
||||
args = parser.parse_args()
|
||||
@@ -59,7 +62,10 @@ def main():
|
||||
max_len = args.max_len
|
||||
|
||||
# Configure sparse policy
|
||||
if args.enable_policy:
|
||||
if args.policy == "xattn":
|
||||
sparse_policy = SparsePolicyType.XATTN_BSA
|
||||
print(f"\n[nanovllm GPU + XAttention BSA] max_len={max_len}")
|
||||
elif args.policy == "full" or args.enable_policy:
|
||||
sparse_policy = SparsePolicyType.FULL
|
||||
print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user