feat: add XAttention BSA support to bench_offload.py

- Add --model parameter (default: Llama-3.1-8B-Instruct)
- Add --enable-xattn flag for XAttention BSA sparse prefill
- Add --xattn-threshold and --xattn-stride parameters
- Change default num-gpu-blocks from 6 to 4
- Add benchmark results doc with Full vs XAttn comparison (32K/128K)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 04:20:16 +08:00
parent 924a0d2bfa
commit 73c9dc46ff
3 changed files with 115 additions and 4 deletions

View File

@@ -46,24 +46,41 @@ def main():
from nanovllm.config import SparsePolicyType
parser = argparse.ArgumentParser(description="Benchmark CPU offload performance")
parser.add_argument("--enable-quest", action="store_true", help="Enable Quest sparse attention for decode")
parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
# Sparse policy selection (mutually exclusive)
sparse_group = parser.add_mutually_exclusive_group()
sparse_group.add_argument("--enable-quest", action="store_true",
help="Enable Quest sparse attention (decode only, prefill uses full)")
sparse_group.add_argument("--enable-xattn", action="store_true",
help="Enable XAttention BSA (prefill only, decode uses full)")
# Quest parameters
parser.add_argument("--topk", type=int, default=16, help="Top-K blocks for Quest (default: 16)")
parser.add_argument("--threshold", type=int, default=4, help="Apply sparse only when blocks > threshold (default: 4)")
# XAttention parameters
parser.add_argument("--xattn-threshold", type=float, default=0.95,
help="XAttention cumulative attention threshold (default: 0.95)")
parser.add_argument("--xattn-stride", type=int, default=8,
help="XAttention Q/K downsampling stride (default: 8)")
# General parameters
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (default: 6)")
parser.add_argument("--num-gpu-blocks", type=int, default=4, help="Number of GPU blocks (default: 4)")
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
args = parser.parse_args()
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
path = os.path.expanduser(args.model)
max_len = args.max_len
# Setup policy configuration
if args.enable_quest:
sparse_policy = SparsePolicyType.QUEST
print(f"\n[Quest Sparse Attention] topk={args.topk}, threshold={args.threshold}")
print(f"\n[Quest Sparse Attention] decode: Quest (topk={args.topk}, threshold={args.threshold}), prefill: Full")
elif args.enable_xattn:
sparse_policy = SparsePolicyType.XATTN_BSA
print(f"\n[XAttention BSA] prefill: XAttn (tau={args.xattn_threshold}, stride={args.xattn_stride}), decode: Full")
else:
sparse_policy = SparsePolicyType.FULL
print("\n[Full Attention] baseline (no sparse)")
@@ -78,8 +95,12 @@ def main():
enable_cpu_offload=True,
num_gpu_blocks=args.num_gpu_blocks,
sparse_policy=sparse_policy,
# Quest parameters
sparse_topk_blocks=args.topk,
sparse_threshold_blocks=args.threshold,
# XAttention parameters
sparse_threshold=args.xattn_threshold,
sparse_stride=args.xattn_stride,
)
# Warmup