Files
nano-vllm/tests/bench_estimate_block_size.py
Zijie Tian c90dc196b2 📝 docs: add estimate block_size performance analysis
Document the performance impact of block_size on softmax_fuse_block_sum:
- Current 4096 (reshaped 512) is the WORST point: 95ms
- Optimal 1024 (reshaped 128): 6ms - 15x faster
- Performance follows U-shaped curve

Add tests/bench_estimate_block_size.py for benchmarking and propose
hierarchical block sum approach for optimization.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 06:24:28 +08:00

315 lines
10 KiB
Python

"""
Benchmark: block_size impact on XAttention estimate phase performance.
This script tests how different block_size values affect the performance of:
1. flat_group_gemm_fuse_reshape (estimate GEMM)
2. softmax_fuse_block_sum (estimate softmax + block aggregation)
Key insight: The current select_blocks uses global kvcache_block_size for estimation,
which may not be optimal for the Triton kernels.
"""
import sys
import os
import torch
import time
import math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# Configuration
# ============================================================
# Test configurations
# Block sizes (in tokens) swept by main(); each one is integer-divided by
# STRIDE to obtain the reshaped block size handed to softmax_fuse_block_sum.
BLOCK_SIZES = [64, 128, 256, 512] # BSA optimal is 128
# Stride passed to flat_group_gemm_fuse_reshape; also the divisor used to map
# token-space lengths and block sizes to their reshaped counterparts.
STRIDE = 8
# Untimed iterations run before each measurement.
NUM_WARMUP = 3
# Timed iterations averaged into each reported latency.
NUM_RUNS = 10
# Model dimensions (Llama-3.1-8B-Instruct)
# NUM_HEADS is the head count of both random Q and K tensors built in
# run_estimate_benchmark.
NUM_HEADS = 32
# NOTE(review): NUM_KV_HEADS is only printed in the configuration banner; K is
# allocated with NUM_HEADS heads, not NUM_KV_HEADS - confirm this is intended.
NUM_KV_HEADS = 8
# Last dimension of Q/K and the head_dim term in the softmax scale factor.
HEAD_DIM = 128
# Context lengths to test
# Used when --ctx-len is not supplied; main() rounds each up to a multiple of
# STRIDE * 128 before benchmarking.
CONTEXT_LENGTHS = [16384, 32768, 65536] # 16K, 32K, 64K
# ============================================================
# Benchmark Functions
# ============================================================
def benchmark_flat_group_gemm(Q, K, stride, block_size, num_warmup=3, num_runs=10):
    """
    Benchmark the flat_group_gemm_fuse_reshape kernel.

    Args:
        Q: [batch, heads, q_len, head_dim]
        K: [batch, heads, k_len, head_dim]
        stride: Stride for reshape
        block_size: Block size under test. It is not forwarded to the kernel
            call made here; the parameter is kept so existing callers keep
            working.
        num_warmup: Number of untimed calls issued before measuring
        num_runs: Number of timed calls averaged into the result (must be >= 1)

    Returns:
        (avg_time_ms, output_tensor) where output_tensor is the result of the
        last timed call.

    Raises:
        ValueError: If num_runs is less than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (or an
    # unbound `output`) after the loops have already run.
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # The kernel is asked to process the whole (reshaped) query range.
    reshaped_q_len = Q.shape[2] // stride

    def run_kernel():
        return flat_group_gemm_fuse_reshape(
            Q, K, stride,
            chunk_start=0,
            chunk_end=reshaped_q_len,
            is_causal=False,
        )

    # Warmup: keep any first-call overhead out of the timed region.
    for _ in range(num_warmup):
        run_kernel()
    # CUDA work is queued asynchronously; drain it before starting the clock.
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(num_runs):
        output = run_kernel()
    # Wait for the queued kernels to finish so the wall-clock time covers them.
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time_ms = (end - start) / num_runs * 1000
    return avg_time_ms, output
def benchmark_softmax_fuse_block_sum(attn_weights, reshaped_block_size, num_warmup=3, num_runs=10):
    """
    Benchmark the softmax_fuse_block_sum kernel.

    Args:
        attn_weights: [batch, heads, q_len, k_len] attention weights
        reshaped_block_size: Block size in reshaped space (must be >= 1)
        num_warmup: Number of untimed calls issued before measuring
        num_runs: Number of timed calls averaged into the result (must be >= 1)

    Returns:
        avg_time_ms

    Raises:
        ValueError: If num_runs or reshaped_block_size is less than 1.
    """
    # Validate first so bad arguments fail with a clear message rather than a
    # ZeroDivisionError further down.
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")
    if reshaped_block_size < 1:
        raise ValueError(
            f"reshaped_block_size must be >= 1, got {reshaped_block_size}"
        )

    # Unpacking four values also rejects inputs that are not 4-D.
    _, _, q_len, k_len = attn_weights.shape

    # segment_size is the reshaped block size capped at 4096.
    segment_size = min(4096, reshaped_block_size)

    # Right-pad the key dimension with zeros so segment_size divides it.
    remainder = k_len % segment_size
    if remainder:
        attn_weights = torch.nn.functional.pad(
            attn_weights, (0, segment_size - remainder), value=0
        )

    # Scale factor: log2(e) / sqrt(head_dim) / stride.
    # NOTE(review): the log2(e) term presumably matches a base-2 exponential
    # inside the kernel - confirm against the kernel source.
    scale = 1.4426950408889634 / math.sqrt(HEAD_DIM) / STRIDE

    def run_kernel():
        return softmax_fuse_block_sum(
            attn_weights,
            reshaped_block_size,
            segment_size,
            chunk_start=0,
            chunk_end=q_len,
            real_q_len=q_len,
            scale=scale,
            is_causal=False,
        )

    # Warmup: keep any first-call overhead out of the timed region.
    for _ in range(num_warmup):
        run_kernel()
    # CUDA work is queued asynchronously; drain it before starting the clock.
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(num_runs):
        run_kernel()
    # Wait for the queued kernels to finish so the wall-clock time covers them.
    torch.cuda.synchronize()
    end = time.perf_counter()

    avg_time_ms = (end - start) / num_runs * 1000
    return avg_time_ms
def run_estimate_benchmark(q_len, k_len, block_size, stride=STRIDE):
    """
    Run full estimate benchmark for given configuration.

    Random bfloat16 Q and K tensors are created on the current CUDA device,
    the GEMM kernel is timed, and its output is fed to the softmax/block-sum
    benchmark.

    Args:
        q_len: Query length
        k_len: Key length (usually same as q_len for current chunk scenario)
        block_size: Block size to test
        stride: Stride for reshape

    Returns:
        dict with timing results (keys: q_len, k_len, block_size,
        reshaped_block_size, gemm_time_ms, softmax_time_ms, total_time_ms)
    """
    # Integer division: a block_size that is not a multiple of stride is
    # rounded down.
    reshaped_block_size = block_size // stride

    # Pre-bind the names so the cleanup below is valid even if allocation or
    # the first benchmark raises.
    Q = K = attn_weights = None
    try:
        # Create random Q and K tensors
        # Shape: [batch, heads, seq_len, head_dim]
        Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
        K = torch.randn(1, NUM_HEADS, k_len, HEAD_DIM, dtype=torch.bfloat16, device="cuda")

        # Benchmark GEMM
        gemm_time, attn_weights = benchmark_flat_group_gemm(
            Q, K, stride, block_size,
            num_warmup=NUM_WARMUP, num_runs=NUM_RUNS,
        )

        # Benchmark softmax + block sum
        softmax_time = benchmark_softmax_fuse_block_sum(
            attn_weights, reshaped_block_size,
            num_warmup=NUM_WARMUP, num_runs=NUM_RUNS,
        )
    finally:
        # Clean up on success and on failure, so a failed configuration does
        # not skip the cache release before the next one is tried.
        del Q, K, attn_weights
        torch.cuda.empty_cache()

    return {
        "q_len": q_len,
        "k_len": k_len,
        "block_size": block_size,
        "reshaped_block_size": reshaped_block_size,
        "gemm_time_ms": gemm_time,
        "softmax_time_ms": softmax_time,
        "total_time_ms": gemm_time + softmax_time,
    }
# ============================================================
# Main Benchmark
# ============================================================
def main():
    """
    Command-line entry point: sweep BLOCK_SIZES over one or more context lengths.

    Parses --gpu and --ctx-len, runs run_estimate_benchmark for every
    (context length, block size) pair, and prints a per-context table, an
    overall table (only when several context lengths are tested), and the
    best/worst block size for the softmax kernel per context length.
    A failing configuration is reported with its traceback and skipped.
    """
    import argparse
    import traceback

    parser = argparse.ArgumentParser(description="Benchmark block_size impact on estimate phase")
    parser.add_argument("--gpu", type=int, default=0, help="GPU to use")
    parser.add_argument("--ctx-len", type=int, default=None,
                        help="Single context length to test (default: test multiple)")
    args = parser.parse_args()
    if args.ctx_len is not None and args.ctx_len <= 0:
        parser.error("--ctx-len must be a positive integer")

    # Set GPU
    torch.cuda.set_device(args.gpu)
    device_name = torch.cuda.get_device_name(args.gpu)
    print(f"Using GPU {args.gpu}: {device_name}")
    print()

    # Determine context lengths to test. Compare against None so an explicit
    # value is never mistaken for "option not given".
    if args.ctx_len is None:
        context_lengths = CONTEXT_LENGTHS
    else:
        context_lengths = [args.ctx_len]

    print("=" * 80)
    print("Benchmark: block_size impact on XAttention estimate phase")
    print("=" * 80)
    print("Configuration:")
    print(f" NUM_HEADS: {NUM_HEADS}")
    print(f" NUM_KV_HEADS: {NUM_KV_HEADS}")
    print(f" HEAD_DIM: {HEAD_DIM}")
    print(f" STRIDE: {STRIDE}")
    print(f" BLOCK_SIZES: {BLOCK_SIZES}")
    print(f" NUM_WARMUP: {NUM_WARMUP}")
    print(f" NUM_RUNS: {NUM_RUNS}")
    print()

    alignment = STRIDE * 128  # Triton BLOCK_M requirement

    # (ctx_len, results) pairs in test order. Keeping results grouped by the
    # requested context length avoids re-deriving the padded length later to
    # match results back to their context.
    results_by_ctx = []

    for ctx_len in context_lengths:
        print(f"\n{'='*80}")
        print(f"Context Length: {ctx_len // 1024}K ({ctx_len} tokens)")
        print(f"{'='*80}")

        # Pad to alignment (round up to the next multiple)
        padded_len = ((ctx_len + alignment - 1) // alignment) * alignment
        print(f"Padded to: {padded_len} tokens (alignment={alignment})")
        print()

        results = []
        for block_size in BLOCK_SIZES:
            print(f"Testing block_size={block_size} (reshaped={block_size // STRIDE})...", end=" ")
            try:
                result = run_estimate_benchmark(padded_len, padded_len, block_size)
            except Exception as e:
                # Best effort: report the failure and carry on with the
                # remaining configurations.
                print(f"ERROR: {e}")
                traceback.print_exc()
            else:
                results.append(result)
                print(f"GEMM={result['gemm_time_ms']:.2f}ms, "
                      f"Softmax={result['softmax_time_ms']:.2f}ms, "
                      f"Total={result['total_time_ms']:.2f}ms")

        results_by_ctx.append((ctx_len, results))
        if not results:
            continue

        # Print summary table for this context length
        print(f"\n--- Summary for {ctx_len // 1024}K context ---")
        print(f"{'block_size':>12} {'reshaped':>10} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12} {'Speedup':>10}")
        print("-" * 74)
        # Speedup is relative to the first block size that ran successfully.
        baseline_total = results[0]["total_time_ms"]
        for r in results:
            speedup = baseline_total / r["total_time_ms"]
            print(f"{r['block_size']:>12} {r['reshaped_block_size']:>10} "
                  f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
                  f"{r['total_time_ms']:>12.2f} {speedup:>9.2f}x")

    # Final summary across all context lengths
    if len(context_lengths) > 1:
        print(f"\n{'='*80}")
        print("OVERALL SUMMARY")
        print(f"{'='*80}")
        print(f"{'ctx_len':>10} {'block_size':>12} {'GEMM (ms)':>12} {'Softmax (ms)':>14} {'Total (ms)':>12}")
        print("-" * 64)
        for _, results in results_by_ctx:
            for r in results:
                # q_len is the padded length actually benchmarked.
                print(f"{r['q_len']//1024:>9}K {r['block_size']:>12} "
                      f"{r['gemm_time_ms']:>12.2f} {r['softmax_time_ms']:>14.2f} "
                      f"{r['total_time_ms']:>12.2f}")

    # Find optimal block_size for softmax
    print(f"\n{'='*80}")
    print("ANALYSIS: Optimal block_size for softmax_fuse_block_sum")
    print(f"{'='*80}")
    for ctx_len, ctx_results in results_by_ctx:
        if not ctx_results:
            continue
        best = min(ctx_results, key=lambda x: x["softmax_time_ms"])
        worst = max(ctx_results, key=lambda x: x["softmax_time_ms"])
        improvement = worst["softmax_time_ms"] / best["softmax_time_ms"]
        print(f"Context {ctx_len // 1024}K:")
        print(f" Best: block_size={best['block_size']} ({best['softmax_time_ms']:.2f}ms)")
        print(f" Worst: block_size={worst['block_size']} ({worst['softmax_time_ms']:.2f}ms)")
        print(f" Potential improvement: {improvement:.2f}x")

    print("\nbench_estimate_block_size: DONE")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()