Files
nano-vllm/tests/test_hierarchical_estimate.py
Zijie Tian f049971f84 test: add hierarchical block sum estimation validation
Validate the hierarchical estimation approach for XAttention:
- Test 1: Math equivalence (diff = 0.0) between hierarchical and direct
- Test 2: Score + threshold selection strategy (replaces mask + voting)
- Test 3: Performance benchmark (41x speedup)

Uses pure torch + xattn kernels, independent of nanovllm framework.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-01-28 06:24:35 +08:00

443 lines
14 KiB
Python

"""
Test: Hierarchical Block Sum Estimation for XAttention
Verify that hierarchical estimation (small estimate_block_size + aggregation)
produces equivalent results to direct estimation (large block_size), while
being significantly faster.
Key changes validated:
1. Hierarchical block sum: estimate_block_size=1024 → aggregate to cpu_block_size=4096
2. Selection strategy: score + threshold (NOT mask + majority voting)
This test uses pure torch + xattn kernels, independent of nanovllm framework.
"""
import sys
import os
import torch
import math
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# Configuration
# ============================================================
# Model dimensions (Llama-3.1-8B-Instruct style)
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
STRIDE = 8
# Block sizes
CPU_BLOCK_SIZE = 4096 # External CPU block size (fixed, for overlap)
ESTIMATE_BLOCK_SIZE = 1024 # Internal estimate block size (optimized)
# Selection parameters
THRESHOLD = 0.95 # Cumulative attention threshold
# ============================================================
# Hierarchical Estimation Implementation
# ============================================================
def compute_attention_scores(Q, K_blocks, stride):
"""
Compute attention scores for Q against multiple K blocks.
Args:
Q: [1, num_heads, q_len, head_dim]
K_blocks: List of K tensors, each [1, num_heads, block_size, head_dim]
stride: Stride for reshape
Returns:
attn_scores: [1, num_heads, q_reshaped, total_k_reshaped]
"""
q_len = Q.shape[2]
q_reshaped = q_len // stride
attn_chunks = []
for K_block in K_blocks:
# flat_group_gemm_fuse_reshape
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_block, stride,
chunk_start=0,
chunk_end=q_reshaped,
is_causal=False,
)
attn_chunks.append(attn_chunk)
# Concatenate along K dimension
attn_scores = torch.cat(attn_chunks, dim=-1)
return attn_scores
def hierarchical_block_sum(
attn_scores,
estimate_block_size,
cpu_block_size,
stride,
head_dim,
):
"""
Compute hierarchical block sums: fine-grained → aggregated to CPU block level.
Args:
attn_scores: [batch, heads, q_reshaped, k_reshaped]
estimate_block_size: Small block size for efficient softmax (e.g., 1024)
cpu_block_size: External CPU block size (e.g., 4096)
stride: Stride used in reshape
head_dim: Head dimension for scale computation
Returns:
cpu_block_scores: [batch, heads, num_cpu_blocks] - attention score per CPU block
"""
batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape
# Compute reshaped block sizes
reshaped_est_bs = estimate_block_size // stride # 1024/8 = 128
reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512
# Scale factor
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
# Segment size for softmax kernel
segment_size = min(4096, reshaped_est_bs)
# Step 1: Fine-grained softmax + block sum
block_sums_fine = softmax_fuse_block_sum(
attn_scores,
reshaped_est_bs,
segment_size,
chunk_start=0,
chunk_end=q_reshaped,
real_q_len=q_reshaped,
scale=scale,
is_causal=False,
)
# block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks]
q_est_blocks = block_sums_fine.shape[2]
k_est_blocks = block_sums_fine.shape[3]
# Step 2: Aggregate to CPU block level
# ratio = cpu_block_size / estimate_block_size = 4
ratio = cpu_block_size // estimate_block_size
num_cpu_blocks = k_est_blocks // ratio
# Reshape and sum along K dimension
# [batch, heads, q_est, k_est] → [batch, heads, q_est, num_cpu, ratio]
block_sums_coarse = block_sums_fine.view(
batch_size, num_heads, q_est_blocks, num_cpu_blocks, ratio
).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks]
# Step 3: Sum over Q dimension (total attention from Q chunk to each K block)
cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks]
return cpu_block_scores, block_sums_fine
def direct_block_sum(
attn_scores,
cpu_block_size,
stride,
head_dim,
):
"""
Compute block sums directly with CPU block size (baseline for comparison).
Args:
attn_scores: [batch, heads, q_reshaped, k_reshaped]
cpu_block_size: Block size (e.g., 4096)
stride: Stride used in reshape
head_dim: Head dimension for scale computation
Returns:
cpu_block_scores: [batch, heads, num_cpu_blocks]
"""
batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape
reshaped_cpu_bs = cpu_block_size // stride # 4096/8 = 512
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
segment_size = min(4096, reshaped_cpu_bs)
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_cpu_bs,
segment_size,
chunk_start=0,
chunk_end=q_reshaped,
real_q_len=q_reshaped,
scale=scale,
is_causal=False,
)
# block_sums: [batch, heads, q_cpu_blocks, k_cpu_blocks]
# Sum over Q dimension
cpu_block_scores = block_sums.sum(dim=2) # [batch, heads, num_cpu_blocks]
return cpu_block_scores
def select_blocks_by_score(
cpu_block_scores,
threshold=0.95,
always_include_first=True,
always_include_last=True,
):
"""
Select CPU blocks based on score + threshold.
⚠️ IMPORTANT: This replaces the original mask + majority voting strategy.
This change should be documented in the final implementation.
Args:
cpu_block_scores: [batch, heads, num_cpu_blocks]
threshold: Cumulative attention threshold (e.g., 0.95)
always_include_first: Always include first block (sink)
always_include_last: Always include last block (safety)
Returns:
selected_block_ids: List of selected block indices
density: Fraction of blocks selected
"""
# Average scores across heads
scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks]
num_blocks = scores_per_block.shape[0]
# Normalize to get attention distribution
total_score = scores_per_block.sum()
score_ratio = scores_per_block / total_score
# Sort by score (descending)
sorted_indices = torch.argsort(score_ratio, descending=True)
# Select blocks until cumulative threshold is reached
cumsum = 0.0
selected = set()
for idx in sorted_indices.tolist():
selected.add(idx)
cumsum += score_ratio[idx].item()
if cumsum >= threshold:
break
# Always include first and last blocks
if always_include_first:
selected.add(0)
if always_include_last:
selected.add(num_blocks - 1)
selected_block_ids = sorted(list(selected))
density = len(selected_block_ids) / num_blocks
return selected_block_ids, density
# ============================================================
# Test Cases
# ============================================================
def test_equivalence():
"""
Test that hierarchical estimation produces equivalent scores to direct estimation.
"""
print("=" * 60)
print("Test 1: Hierarchical vs Direct - Equivalence")
print("=" * 60)
# Create random Q and multiple K blocks
q_len = CPU_BLOCK_SIZE # 4096
num_k_blocks = 4
# Q: [1, num_heads, q_len, head_dim]
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
# K blocks: each [1, num_heads, cpu_block_size, head_dim]
K_blocks = [
torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
for _ in range(num_k_blocks)
]
# Compute attention scores
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
print(f"attn_scores shape: {attn_scores.shape}")
# Method 1: Hierarchical (fast)
scores_hier, _ = hierarchical_block_sum(
attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
print(f"scores_hier shape: {scores_hier.shape}")
# Method 2: Direct (slow)
scores_direct = direct_block_sum(
attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
print(f"scores_direct shape: {scores_direct.shape}")
# Compare
diff = (scores_hier - scores_direct).abs().max().item()
print(f"\nMax difference: {diff:.6f}")
# Per-block comparison
print("\nPer-block scores comparison:")
for i in range(num_k_blocks):
h_val = scores_hier[0, 0, i].item()
d_val = scores_direct[0, 0, i].item()
print(f" Block {i}: hierarchical={h_val:.4f}, direct={d_val:.4f}, diff={abs(h_val-d_val):.6f}")
passed = diff < 0.01
print(f"\nTest 1: {'PASSED' if passed else 'FAILED'}")
return passed
def test_selection():
"""
Test the score + threshold selection strategy.
"""
print("\n" + "=" * 60)
print("Test 2: Score + Threshold Selection")
print("=" * 60)
# Create Q and K blocks with varying importance
q_len = CPU_BLOCK_SIZE
num_k_blocks = 8
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
# Create K blocks - make some more important than others
K_blocks = []
for i in range(num_k_blocks):
# First and middle blocks are more important (higher values)
importance = 2.0 if i in [0, 3, 4] else 1.0
K = torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
K = K * importance
K_blocks.append(K)
# Compute scores
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
scores, _ = hierarchical_block_sum(
attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
)
# Print scores per block
print("\nCPU block scores (head 0):")
for i in range(num_k_blocks):
print(f" Block {i}: {scores[0, 0, i].item():.4f}")
# Select blocks with different thresholds
for thresh in [0.9, 0.95, 0.99]:
selected, density = select_blocks_by_score(scores, threshold=thresh)
print(f"\nThreshold {thresh}: selected {len(selected)}/{num_k_blocks} blocks ({density:.1%})")
print(f" Selected: {selected}")
print("\nTest 2: PASSED (visual inspection)")
return True
def test_performance():
"""
Benchmark hierarchical vs direct estimation performance.
"""
print("\n" + "=" * 60)
print("Test 3: Performance Benchmark")
print("=" * 60)
import time
NUM_WARMUP = 3
NUM_RUNS = 10
# Larger test case
q_len = CPU_BLOCK_SIZE
num_k_blocks = 16 # 64K context
Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
K_blocks = [
torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
for _ in range(num_k_blocks)
]
# Compute attention scores (shared)
attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
print(f"attn_scores shape: {attn_scores.shape}")
print(f"Context: {num_k_blocks * CPU_BLOCK_SIZE // 1024}K tokens")
# Warmup and benchmark hierarchical
for _ in range(NUM_WARMUP):
_ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NUM_RUNS):
_ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
hier_time = (time.perf_counter() - start) / NUM_RUNS * 1000
# Warmup and benchmark direct
for _ in range(NUM_WARMUP):
_ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(NUM_RUNS):
_ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
torch.cuda.synchronize()
direct_time = (time.perf_counter() - start) / NUM_RUNS * 1000
speedup = direct_time / hier_time
print(f"\nResults:")
print(f" Hierarchical (bs=1024): {hier_time:.2f} ms")
print(f" Direct (bs=4096): {direct_time:.2f} ms")
print(f" Speedup: {speedup:.2f}x")
passed = speedup > 5.0 # Expect at least 5x speedup
print(f"\nTest 3: {'PASSED' if passed else 'FAILED'} (speedup > 5x expected)")
return passed
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Hierarchical Block Sum Estimation Test")
print("=" * 60)
print(f"\nConfiguration:")
print(f" NUM_HEADS: {NUM_HEADS}")
print(f" NUM_KV_HEADS: {NUM_KV_HEADS}")
print(f" HEAD_DIM: {HEAD_DIM}")
print(f" STRIDE: {STRIDE}")
print(f" CPU_BLOCK_SIZE: {CPU_BLOCK_SIZE}")
print(f" ESTIMATE_BLOCK_SIZE: {ESTIMATE_BLOCK_SIZE}")
print(f" THRESHOLD: {THRESHOLD}")
print()
results = []
results.append(("Equivalence", test_equivalence()))
results.append(("Selection", test_selection()))
results.append(("Performance", test_performance()))
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for name, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" {name}: {status}")
all_passed = all(p for _, p in results)
print("=" * 60)
if all_passed:
print("test_hierarchical_estimate: ALL PASSED")
sys.exit(0)
else:
print("test_hierarchical_estimate: SOME FAILED")
sys.exit(1)