""" Test: Compare xattn_estimate vs xattn_estimate_chunked Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation. This ensures the chunked version can be used in chunked prefill scenarios without accuracy loss. Usage: CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \ python tests/test_xattn_estimate_chunked.py """ import sys import traceback import torch from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked # ============================================================ # Configuration # ============================================================ # Configuration for xattn_estimate_chunked consistency test. # Key requirements for 100% match: # 1. Use matching chunk_size for both standard and chunked versions # 2. Use same random seed for reproducibility # Note: Tiny differences (~0.000001) may occur at boundary cases due to # floating point precision in cumulative sum calculations. BLOCK_SIZE = 64 STRIDE = 4 THRESHOLD = 0.9 CHUNK_SIZE = 4096 # External chunking size # Test sequence lengths TEST_SEQ_LENS = [4096, 8192, 16384, 32768] # ============================================================ # Utility Functions # ============================================================ def compare_masks(mask1, mask2, name1="standard", name2="chunked"): """Compare two masks and report differences.""" if mask1.shape != mask2.shape: print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}") return False diff = (mask1 != mask2).sum().item() total = mask1.numel() match_rate = (total - diff) / total * 100 print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})") if diff > 0: diff_indices = torch.where(mask1 != mask2) print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}") return diff == 0 def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size): """ Run xattn_estimate_chunked with EXTERNAL chunking. This simulates how chunked prefill should be used in practice. """ batch_size, num_heads, q_len, head_dim = query.shape _, _, k_len, _ = key.shape q_block_num = (q_len + block_size - 1) // block_size k_block_num = (k_len + block_size - 1) // block_size # If Q fits in one chunk, call directly if q_len <= chunk_size: return xattn_estimate_chunked( query, key, q_start_pos=0, block_size=block_size, stride=stride, threshold=threshold, use_triton=True, chunk_size=chunk_size, ) # External chunking: split Q and call for each chunk num_q_chunks = (q_len + chunk_size - 1) // chunk_size print(f" External chunking: {num_q_chunks} chunks") combined_attn_sum = torch.zeros( batch_size, num_heads, q_block_num, k_block_num, dtype=query.dtype, device=query.device ) combined_mask = torch.zeros( batch_size, num_heads, q_block_num, k_block_num, dtype=torch.bool, device=query.device ) q_block_offset = 0 for q_chunk_idx in range(num_q_chunks): q_chunk_start = q_chunk_idx * chunk_size q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len) q_chunk = query[:, :, q_chunk_start:q_chunk_end, :] # For causal attention, K accumulates up to current Q position # q_start_pos=0 means Q starts at position 0 in the full sequence # K is [0, q_chunk_end) for causal attention k_end = q_chunk_end k_chunk = key[:, :, :k_end, :] attn_sum_chunk, mask_chunk = xattn_estimate_chunked( q_chunk, k_chunk, q_start_pos=q_chunk_start, block_size=block_size, stride=stride, threshold=threshold, use_triton=True, chunk_size=chunk_size, ) # Place chunk results into combined output chunk_q_blocks = mask_chunk.shape[2] chunk_k_blocks = mask_chunk.shape[3] combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk q_block_offset += chunk_q_blocks return combined_attn_sum, combined_mask def test_single_seq_len(seq_len, num_heads=32, head_dim=128): """Test a single sequence length.""" print(f"\nTesting seq_len={seq_len}") print("=" * 60) # Generate random Q/K query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16) key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16) # Run standard xattn_estimate print("[1] Running standard xattn_estimate...") try: attn_sum_std, mask_std = xattn_estimate( query, key, block_size=BLOCK_SIZE, stride=STRIDE, threshold=THRESHOLD, chunk_size=CHUNK_SIZE, use_triton=True, causal=True, ) density_std = mask_std.float().mean().item() print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}") except Exception as e: print(f" ERROR: {e}") traceback.print_exc() return False # Run chunked xattn_estimate with EXTERNAL chunking print("[2] Running chunked xattn_estimate (external chunking)...") try: attn_sum_chunked, mask_chunked = run_chunked_externally( query, key, block_size=BLOCK_SIZE, stride=STRIDE, threshold=THRESHOLD, chunk_size=CHUNK_SIZE, ) density_chunked = mask_chunked.float().mean().item() print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}") except Exception as e: print(f" ERROR: {e}") traceback.print_exc() return False # Compare results print("[3] Comparing results...") chunked_q_blocks = mask_chunked.shape[2] chunked_k_blocks = mask_chunked.shape[3] # Extract comparable region from standard mask mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks] # Compare masks masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked") # Compare attn_sums attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks] if attn_sum_std_comparable.shape == attn_sum_chunked.shape: attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item() print(f" Attn sum max diff: {attn_diff:.6f}") else: print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}") # Clean up GPU memory del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked torch.cuda.empty_cache() return masks_match # ============================================================ # Main Test # ============================================================ if __name__ == "__main__": print("XAttention Chunked vs Standard Test") print("=" * 60) print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}") print(f"External chunk_size={CHUNK_SIZE}") print() # Check CUDA availability if not torch.cuda.is_available(): print("CUDA not available!") sys.exit(1) print(f"Using GPU: {torch.cuda.get_device_name(0)}") print("✓ xattn_estimate imported") print("✓ xattn_estimate_chunked imported") # Run tests all_passed = True results = [] for seq_len in TEST_SEQ_LENS: passed = test_single_seq_len(seq_len) chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE results.append((seq_len, chunks, passed)) if not passed: all_passed = False # Summary print("\n" + "=" * 60) print("SUMMARY") print("=" * 60) for seq_len, chunks, passed in results: status = "PASSED" if passed else "FAILED" print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}") print("=" * 60) if all_passed: print("ALL TESTS PASSED!") sys.exit(0) else: print("SOME TESTS FAILED!") sys.exit(1)