✨ feat: add XAttention BSA support to bench_offload.py
- Add --model parameter (default: ~/models/Llama-3.1-8B-Instruct)
- Add --enable-xattn flag for XAttention BSA sparse prefill
- Add --xattn-threshold and --xattn-stride parameters
- Change default num-gpu-blocks from 6 to 4
- Add benchmark results doc with Full vs XAttn comparison (32K/128K)

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -46,24 +46,41 @@ def main():
|
||||
from nanovllm.config import SparsePolicyType
|
||||
|
||||
parser = argparse.ArgumentParser(description="Benchmark CPU offload performance")
|
||||
parser.add_argument("--enable-quest", action="store_true", help="Enable Quest sparse attention for decode")
|
||||
parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
|
||||
help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
|
||||
# Sparse policy selection (mutually exclusive)
|
||||
sparse_group = parser.add_mutually_exclusive_group()
|
||||
sparse_group.add_argument("--enable-quest", action="store_true",
|
||||
help="Enable Quest sparse attention (decode only, prefill uses full)")
|
||||
sparse_group.add_argument("--enable-xattn", action="store_true",
|
||||
help="Enable XAttention BSA (prefill only, decode uses full)")
|
||||
# Quest parameters
|
||||
parser.add_argument("--topk", type=int, default=16, help="Top-K blocks for Quest (default: 16)")
|
||||
parser.add_argument("--threshold", type=int, default=4, help="Apply sparse only when blocks > threshold (default: 4)")
|
||||
# XAttention parameters
|
||||
parser.add_argument("--xattn-threshold", type=float, default=0.95,
|
||||
help="XAttention cumulative attention threshold (default: 0.95)")
|
||||
parser.add_argument("--xattn-stride", type=int, default=8,
|
||||
help="XAttention Q/K downsampling stride (default: 8)")
|
||||
# General parameters
|
||||
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
|
||||
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
|
||||
parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (default: 6)")
|
||||
parser.add_argument("--num-gpu-blocks", type=int, default=4, help="Number of GPU blocks (default: 4)")
|
||||
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
|
||||
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
||||
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
||||
args = parser.parse_args()
|
||||
|
||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||
path = os.path.expanduser(args.model)
|
||||
max_len = args.max_len
|
||||
|
||||
# Setup policy configuration
|
||||
if args.enable_quest:
|
||||
sparse_policy = SparsePolicyType.QUEST
|
||||
print(f"\n[Quest Sparse Attention] topk={args.topk}, threshold={args.threshold}")
|
||||
print(f"\n[Quest Sparse Attention] decode: Quest (topk={args.topk}, threshold={args.threshold}), prefill: Full")
|
||||
elif args.enable_xattn:
|
||||
sparse_policy = SparsePolicyType.XATTN_BSA
|
||||
print(f"\n[XAttention BSA] prefill: XAttn (tau={args.xattn_threshold}, stride={args.xattn_stride}), decode: Full")
|
||||
else:
|
||||
sparse_policy = SparsePolicyType.FULL
|
||||
print("\n[Full Attention] baseline (no sparse)")
|
||||
@@ -78,8 +95,12 @@ def main():
|
||||
enable_cpu_offload=True,
|
||||
num_gpu_blocks=args.num_gpu_blocks,
|
||||
sparse_policy=sparse_policy,
|
||||
# Quest parameters
|
||||
sparse_topk_blocks=args.topk,
|
||||
sparse_threshold_blocks=args.threshold,
|
||||
# XAttention parameters
|
||||
sparse_threshold=args.xattn_threshold,
|
||||
sparse_stride=args.xattn_stride,
|
||||
)
|
||||
|
||||
# Warmup
|
||||
|
||||
Reference in New Issue
Block a user