"""
Test: Hierarchical Block Sum Estimation for XAttention

Verify that hierarchical estimation (small estimate_block_size + aggregation)
produces equivalent results to direct estimation (large block_size), while
being significantly faster.

Key changes validated:
1. Hierarchical block sum: estimate_block_size=1024 → aggregate to cpu_block_size=4096
2. Selection strategy: score + threshold (NOT mask + majority voting)

This test uses pure torch + xattn kernels, independent of nanovllm framework.
"""

import sys
import os
import math
import time

import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# NOTE: must come after the sys.path insert above so the repo-local package wins.
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum

# ============================================================
# Configuration
# ============================================================

# Model dimensions (Llama-3.1-8B-Instruct style)
NUM_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
STRIDE = 8

# Block sizes
CPU_BLOCK_SIZE = 4096  # External CPU block size (fixed, for overlap)
ESTIMATE_BLOCK_SIZE = 1024  # Internal estimate block size (optimized)

# Selection parameters
THRESHOLD = 0.95  # Cumulative attention threshold

# ============================================================
# Hierarchical Estimation Implementation
# ============================================================


def _softmax_scale(head_dim, stride):
    """Return the scale fed to softmax_fuse_block_sum.

    log2(e) / sqrt(head_dim) / stride — shared by the hierarchical and the
    direct paths so the two estimates stay numerically comparable.
    (The original code also divided by a constant norm = 1.0; dropped as a no-op.)
    """
    return 1.4426950408889634 / math.sqrt(head_dim) / stride


def compute_attention_scores(Q, K_blocks, stride):
    """
    Compute attention scores for Q against multiple K blocks.

    Args:
        Q: [1, num_heads, q_len, head_dim]
        K_blocks: List of K tensors, each [1, num_heads, block_size, head_dim]
        stride: Stride for reshape

    Returns:
        attn_scores: [1, num_heads, q_reshaped, total_k_reshaped]
    """
    q_len = Q.shape[2]
    q_reshaped = q_len // stride

    attn_chunks = []
    for K_block in K_blocks:
        # One fused GEMM + reshape per K block; chunks are concatenated below.
        attn_chunk = flat_group_gemm_fuse_reshape(
            Q, K_block, stride,
            chunk_start=0,
            chunk_end=q_reshaped,
            is_causal=False,
        )
        attn_chunks.append(attn_chunk)

    # Concatenate along K dimension
    attn_scores = torch.cat(attn_chunks, dim=-1)
    return attn_scores


def hierarchical_block_sum(
    attn_scores,
    estimate_block_size,
    cpu_block_size,
    stride,
    head_dim,
):
    """
    Compute hierarchical block sums: fine-grained → aggregated to CPU block level.

    Args:
        attn_scores: [batch, heads, q_reshaped, k_reshaped]
        estimate_block_size: Small block size for efficient softmax (e.g., 1024)
        cpu_block_size: External CPU block size (e.g., 4096)
        stride: Stride used in reshape
        head_dim: Head dimension for scale computation

    Returns:
        cpu_block_scores: [batch, heads, num_cpu_blocks] - attention score per CPU block
        block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks] - pre-aggregation sums
    """
    batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape

    # Reshaped fine-grained block size, e.g. 1024/8 = 128.
    reshaped_est_bs = estimate_block_size // stride

    scale = _softmax_scale(head_dim, stride)

    # Segment size for softmax kernel
    segment_size = min(4096, reshaped_est_bs)

    # Step 1: Fine-grained softmax + block sum
    block_sums_fine = softmax_fuse_block_sum(
        attn_scores,
        reshaped_est_bs,
        segment_size,
        chunk_start=0,
        chunk_end=q_reshaped,
        real_q_len=q_reshaped,
        scale=scale,
        is_causal=False,
    )
    # block_sums_fine: [batch, heads, q_est_blocks, k_est_blocks]

    q_est_blocks = block_sums_fine.shape[2]
    k_est_blocks = block_sums_fine.shape[3]

    # Step 2: Aggregate to CPU block level
    # ratio = cpu_block_size / estimate_block_size = 4
    ratio = cpu_block_size // estimate_block_size
    num_cpu_blocks = k_est_blocks // ratio

    # Reshape and sum along K dimension
    # [batch, heads, q_est, k_est] → [batch, heads, q_est, num_cpu, ratio]
    block_sums_coarse = block_sums_fine.view(
        batch_size, num_heads, q_est_blocks, num_cpu_blocks, ratio
    ).sum(dim=-1)  # [batch, heads, q_est_blocks, num_cpu_blocks]

    # Step 3: Sum over Q dimension (total attention from Q chunk to each K block)
    cpu_block_scores = block_sums_coarse.sum(dim=2)  # [batch, heads, num_cpu_blocks]

    return cpu_block_scores, block_sums_fine


def direct_block_sum(
    attn_scores,
    cpu_block_size,
    stride,
    head_dim,
):
    """
    Compute block sums directly with CPU block size (baseline for comparison).

    Args:
        attn_scores: [batch, heads, q_reshaped, k_reshaped]
        cpu_block_size: Block size (e.g., 4096)
        stride: Stride used in reshape
        head_dim: Head dimension for scale computation

    Returns:
        cpu_block_scores: [batch, heads, num_cpu_blocks]
    """
    batch_size, num_heads, q_reshaped, k_reshaped = attn_scores.shape

    # Reshaped coarse block size, e.g. 4096/8 = 512.
    reshaped_cpu_bs = cpu_block_size // stride

    scale = _softmax_scale(head_dim, stride)
    segment_size = min(4096, reshaped_cpu_bs)

    block_sums = softmax_fuse_block_sum(
        attn_scores,
        reshaped_cpu_bs,
        segment_size,
        chunk_start=0,
        chunk_end=q_reshaped,
        real_q_len=q_reshaped,
        scale=scale,
        is_causal=False,
    )
    # block_sums: [batch, heads, q_cpu_blocks, k_cpu_blocks]

    # Sum over Q dimension
    cpu_block_scores = block_sums.sum(dim=2)  # [batch, heads, num_cpu_blocks]

    return cpu_block_scores


def select_blocks_by_score(
    cpu_block_scores,
    threshold=0.95,
    always_include_first=True,
    always_include_last=True,
):
    """
    Select CPU blocks based on score + threshold.

    ⚠️ IMPORTANT: This replaces the original mask + majority voting strategy.
    This change should be documented in the final implementation.

    Args:
        cpu_block_scores: [batch, heads, num_cpu_blocks]
        threshold: Cumulative attention threshold (e.g., 0.95)
        always_include_first: Always include first block (sink)
        always_include_last: Always include last block (safety)

    Returns:
        selected_block_ids: List of selected block indices
        density: Fraction of blocks selected
    """
    # Average scores across heads
    scores_per_block = cpu_block_scores.mean(dim=(0, 1))  # [num_cpu_blocks]
    num_blocks = scores_per_block.shape[0]

    # Normalize to get attention distribution.
    # NOTE(review): assumes total_score > 0 (softmax sums are positive);
    # an all-zero input would produce NaN ratios here.
    total_score = scores_per_block.sum()
    score_ratio = scores_per_block / total_score

    # Sort by score (descending)
    sorted_indices = torch.argsort(score_ratio, descending=True)

    # Greedily take highest-scoring blocks until the cumulative mass
    # reaches the threshold (the block that crosses it is included).
    cumsum = 0.0
    selected = set()

    for idx in sorted_indices.tolist():
        selected.add(idx)
        cumsum += score_ratio[idx].item()
        if cumsum >= threshold:
            break

    # Always include first and last blocks
    if always_include_first:
        selected.add(0)
    if always_include_last:
        selected.add(num_blocks - 1)

    selected_block_ids = sorted(selected)
    density = len(selected_block_ids) / num_blocks

    return selected_block_ids, density


# ============================================================
# Test Cases
# ============================================================


def test_equivalence():
    """
    Test that hierarchical estimation produces equivalent scores to direct estimation.
    """
    print("=" * 60)
    print("Test 1: Hierarchical vs Direct - Equivalence")
    print("=" * 60)

    # Create random Q and multiple K blocks
    q_len = CPU_BLOCK_SIZE  # 4096
    num_k_blocks = 4

    # Q: [1, num_heads, q_len, head_dim]
    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')

    # K blocks: each [1, num_heads, cpu_block_size, head_dim]
    K_blocks = [
        torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        for _ in range(num_k_blocks)
    ]

    # Compute attention scores
    attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
    print(f"attn_scores shape: {attn_scores.shape}")

    # Method 1: Hierarchical (fast)
    scores_hier, _ = hierarchical_block_sum(
        attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
    )
    print(f"scores_hier shape: {scores_hier.shape}")

    # Method 2: Direct (slow)
    scores_direct = direct_block_sum(
        attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
    )
    print(f"scores_direct shape: {scores_direct.shape}")

    # Compare
    diff = (scores_hier - scores_direct).abs().max().item()
    print(f"\nMax difference: {diff:.6f}")

    # Per-block comparison
    print("\nPer-block scores comparison:")
    for i in range(num_k_blocks):
        h_val = scores_hier[0, 0, i].item()
        d_val = scores_direct[0, 0, i].item()
        print(f"  Block {i}: hierarchical={h_val:.4f}, direct={d_val:.4f}, diff={abs(h_val-d_val):.6f}")

    passed = diff < 0.01
    print(f"\nTest 1: {'PASSED' if passed else 'FAILED'}")
    return passed


def test_selection():
    """
    Test the score + threshold selection strategy.
    """
    print("\n" + "=" * 60)
    print("Test 2: Score + Threshold Selection")
    print("=" * 60)

    # Create Q and K blocks with varying importance
    q_len = CPU_BLOCK_SIZE
    num_k_blocks = 8

    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')

    # Create K blocks - make some more important than others
    K_blocks = []
    for i in range(num_k_blocks):
        # First and middle blocks are more important (higher values)
        importance = 2.0 if i in [0, 3, 4] else 1.0
        K = torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        K = K * importance
        K_blocks.append(K)

    # Compute scores
    attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
    scores, _ = hierarchical_block_sum(
        attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM
    )

    # Print scores per block
    print("\nCPU block scores (head 0):")
    for i in range(num_k_blocks):
        print(f"  Block {i}: {scores[0, 0, i].item():.4f}")

    # Select blocks with different thresholds
    for thresh in [0.9, 0.95, 0.99]:
        selected, density = select_blocks_by_score(scores, threshold=thresh)
        print(f"\nThreshold {thresh}: selected {len(selected)}/{num_k_blocks} blocks ({density:.1%})")
        print(f"  Selected: {selected}")

    print("\nTest 2: PASSED (visual inspection)")
    return True


def test_performance():
    """
    Benchmark hierarchical vs direct estimation performance.
    """
    print("\n" + "=" * 60)
    print("Test 3: Performance Benchmark")
    print("=" * 60)

    NUM_WARMUP = 3
    NUM_RUNS = 10

    # Larger test case
    q_len = CPU_BLOCK_SIZE
    num_k_blocks = 16  # 64K context

    Q = torch.randn(1, NUM_HEADS, q_len, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
    K_blocks = [
        torch.randn(1, NUM_HEADS, CPU_BLOCK_SIZE, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        for _ in range(num_k_blocks)
    ]

    # Compute attention scores (shared)
    attn_scores = compute_attention_scores(Q, K_blocks, STRIDE)
    print(f"attn_scores shape: {attn_scores.shape}")
    print(f"Context: {num_k_blocks * CPU_BLOCK_SIZE // 1024}K tokens")

    # Warmup and benchmark hierarchical
    for _ in range(NUM_WARMUP):
        _ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(NUM_RUNS):
        _ = hierarchical_block_sum(attn_scores, ESTIMATE_BLOCK_SIZE, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
    torch.cuda.synchronize()
    hier_time = (time.perf_counter() - start) / NUM_RUNS * 1000

    # Warmup and benchmark direct
    for _ in range(NUM_WARMUP):
        _ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(NUM_RUNS):
        _ = direct_block_sum(attn_scores, CPU_BLOCK_SIZE, STRIDE, HEAD_DIM)
    torch.cuda.synchronize()
    direct_time = (time.perf_counter() - start) / NUM_RUNS * 1000

    speedup = direct_time / hier_time

    print(f"\nResults:")
    print(f"  Hierarchical (bs=1024): {hier_time:.2f} ms")
    print(f"  Direct (bs=4096):       {direct_time:.2f} ms")
    print(f"  Speedup: {speedup:.2f}x")

    passed = speedup > 5.0  # Expect at least 5x speedup
    print(f"\nTest 3: {'PASSED' if passed else 'FAILED'} (speedup > 5x expected)")
    return passed


# ============================================================
# Main
# ============================================================

if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Hierarchical Block Sum Estimation Test")
    print(banner)

    # Echo the configuration so the log output is self-describing.
    print("\nConfiguration:")
    config = [
        ("NUM_HEADS", NUM_HEADS),
        ("NUM_KV_HEADS", NUM_KV_HEADS),
        ("HEAD_DIM", HEAD_DIM),
        ("STRIDE", STRIDE),
        ("CPU_BLOCK_SIZE", CPU_BLOCK_SIZE),
        ("ESTIMATE_BLOCK_SIZE", ESTIMATE_BLOCK_SIZE),
        ("THRESHOLD", THRESHOLD),
    ]
    for label, value in config:
        print(f"  {label}: {value}")
    print()

    # Run the three test cases in order; each returns a pass/fail bool.
    outcomes = [
        ("Equivalence", test_equivalence()),
        ("Selection", test_selection()),
        ("Performance", test_performance()),
    ]

    print("\n" + banner)
    print("SUMMARY")
    print(banner)
    for name, ok in outcomes:
        print(f"  {name}: {'PASSED' if ok else 'FAILED'}")

    print(banner)
    if all(ok for _, ok in outcomes):
        print("test_hierarchical_estimate: ALL PASSED")
        sys.exit(0)
    print("test_hierarchical_estimate: SOME FAILED")
    sys.exit(1)