# Changelog:
# - Add xattn_estimate_chunked function ported from COMPASS
# - Support chunked prefill with q_start_pos parameter
# - Ensure 100% consistency with standard xattn_estimate when using matching
#   chunk_size parameter
# - Add test and documentation
"""
|
|
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
|
|
|
Verify that chunked estimation with EXTERNAL chunking produces the same mask
|
|
as standard estimation. This ensures the chunked version can be used in
|
|
chunked prefill scenarios without accuracy loss.
|
|
|
|
Usage:
|
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
|
python tests/test_xattn_estimate_chunked.py
|
|
"""
|
|
|
|
import sys
|
|
import traceback
|
|
import torch
|
|
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
|
|
|
# ============================================================
|
|
# Configuration
|
|
# ============================================================
|
|
|
|
# Configuration for xattn_estimate_chunked consistency test.
|
|
# Key requirements for 100% match:
|
|
# 1. Use matching chunk_size for both standard and chunked versions
|
|
# 2. Use same random seed for reproducibility
|
|
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
|
|
# floating point precision in cumulative sum calculations.
|
|
BLOCK_SIZE = 64
|
|
STRIDE = 4
|
|
THRESHOLD = 0.9
|
|
CHUNK_SIZE = 4096 # External chunking size
|
|
|
|
# Test sequence lengths
|
|
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
|
|
|
|
# ============================================================
|
|
# Utility Functions
|
|
# ============================================================
|
|
|
|
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks element-wise and report differences.

    Prints the match rate and, when there are mismatches, the first few
    differing positions.

    Args:
        mask1: First mask tensor.
        mask2: Second mask tensor; must have the same shape as ``mask1``.
        name1: Label used for ``mask1`` in printed output.
        name2: Label used for ``mask2`` in printed output.

    Returns:
        True if both masks have the same shape and are element-wise
        identical, False otherwise.
    """
    if mask1.shape != mask2.shape:
        print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    # Count differing elements; .item() pulls the scalar back to Python.
    diff = (mask1 != mask2).sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100

    print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        # torch.where with one argument returns per-dimension index tensors.
        diff_indices = torch.where(mask1 != mask2)
        print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")

    return diff == 0
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.

    This simulates how chunked prefill should be used in practice: Q is split
    into ``chunk_size``-token chunks and the estimator is invoked once per
    chunk with the causal K prefix seen so far.

    Args:
        query: (batch, heads, q_len, head_dim) query tensor.
        key: (batch, heads, k_len, head_dim) key tensor.
        block_size: Block granularity of the estimated block mask.
        stride: Stride parameter forwarded to the estimator.
        threshold: Threshold parameter forwarded to the estimator.
        chunk_size: External chunk size (in tokens) used to split Q.

    Returns:
        Tuple ``(attn_sum, mask)`` of block-level tensors with shape
        (batch, heads, q_block_num, k_block_num).
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape

    # Ceil-divide token counts into block counts.
    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly
    if q_len <= chunk_size:
        return xattn_estimate_chunked(
            query, key,
            q_start_pos=0,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f" External chunking: {num_q_chunks} chunks")

    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )

    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to current Q position
        # q_start_pos=0 means Q starts at position 0 in the full sequence
        # K is [0, q_chunk_end) for causal attention
        k_end = q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
            q_chunk, k_chunk,
            q_start_pos=q_chunk_start,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

        # Place chunk results into combined output. K blocks beyond this
        # chunk's causal horizon stay zero / False from the initialization.
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
    """Test a single sequence length.

    Runs standard xattn_estimate and externally-chunked estimation on the
    same random Q/K, then compares the resulting masks and attention sums.

    Args:
        seq_len: Sequence length used for both Q and K.
        num_heads: Number of attention heads.
        head_dim: Per-head dimension.

    Returns:
        True if the masks match exactly; False on mismatch or on any error
        raised by either estimator.
    """
    print(f"\nTesting seq_len={seq_len}")
    print("=" * 60)

    # Fixed seed for reproducibility across runs (required by the config
    # notes at the top of this file; previously missing).
    torch.manual_seed(0)

    # Generate random Q/K
    query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
            causal=True,
        )
        density_std = mask_std.float().mean().item()
        print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        density_chunked = mask_chunked.float().mean().item()
        print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        traceback.print_exc()
        return False

    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    # Extract comparable region from standard mask (block counts may differ
    # between the two code paths).
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]

    # Compare masks
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # Compare attn_sums (informational only; mask equality decides pass/fail)
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f" Attn sum max diff: {attn_diff:.6f}")
    else:
        print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")

    # Clean up GPU memory
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match
# ============================================================
# Main Test
# ============================================================

if __name__ == "__main__":
    print("XAttention Chunked vs Standard Test")
    print("=" * 60)
    print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
    print(f"External chunk_size={CHUNK_SIZE}")
    print()

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available!")
        sys.exit(1)

    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print("✓ xattn_estimate imported")
    print("✓ xattn_estimate_chunked imported")

    # Run tests over every configured sequence length.
    all_passed = True
    results = []

    for seq_len in TEST_SEQ_LENS:
        passed = test_single_seq_len(seq_len)
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        results.append((seq_len, chunks, passed))
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for seq_len, chunks, passed in results:
        status = "PASSED" if passed else "FAILED"
        # Trailing space in the singular branch keeps columns aligned
        # with the plural "chunks".
        print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")

    print("=" * 60)
    if all_passed:
        print("ALL TESTS PASSED!")
        sys.exit(0)
    else:
        print("SOME TESTS FAILED!")
        sys.exit(1)
|