✨ feat: add nanovllm.ops module with XAttention estimation kernels
Add ops module ported from tzj/minference branch containing: - xattn.py: XAttention block importance estimation with Triton kernels - xattn_estimate(): standard estimation for sparse attention mask - xattn_estimate_chunked(): chunked prefill compatible version - flat_group_gemm_fuse_reshape(): fused stride reshape + GEMM kernel - softmax_fuse_block_sum(): online softmax + block-wise sum kernel - chunked_attention.py: Flash attention with LSE output for chunk merging - test_xattn_estimate_chunked.py: verification test (all seq_lens pass) This prepares the foundation for AttentionPolicy refactoring where XAttentionPolicy.estimate() will call these ops. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
244
tests/test_xattn_estimate_chunked.py
Normal file
244
tests/test_xattn_estimate_chunked.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||
|
||||
Verify that chunked estimation with EXTERNAL chunking produces the same mask
|
||||
as standard estimation. This ensures the chunked version can be used in
|
||||
chunked prefill scenarios without accuracy loss.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_estimate_chunked.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
# Configuration for xattn_estimate_chunked consistency test.
|
||||
# Key requirements for 100% match:
|
||||
# 1. Use matching chunk_size for both standard and chunked versions
|
||||
# 2. Use same random seed for reproducibility
|
||||
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
|
||||
# floating point precision in cumulative sum calculations.
|
||||
BLOCK_SIZE = 64
|
||||
STRIDE = 4
|
||||
THRESHOLD = 0.9
|
||||
CHUNK_SIZE = 4096 # External chunking size
|
||||
|
||||
# Test sequence lengths
|
||||
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# ============================================================
|
||||
|
||||
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
|
||||
"""Compare two masks and report differences."""
|
||||
if mask1.shape != mask2.shape:
|
||||
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
|
||||
return False
|
||||
|
||||
diff = (mask1 != mask2).sum().item()
|
||||
total = mask1.numel()
|
||||
match_rate = (total - diff) / total * 100
|
||||
|
||||
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
|
||||
|
||||
if diff > 0:
|
||||
diff_indices = torch.where(mask1 != mask2)
|
||||
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
|
||||
|
||||
return diff == 0
|
||||
|
||||
|
||||
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
|
||||
"""
|
||||
Run xattn_estimate_chunked with EXTERNAL chunking.
|
||||
This simulates how chunked prefill should be used in practice.
|
||||
"""
|
||||
batch_size, num_heads, q_len, head_dim = query.shape
|
||||
_, _, k_len, _ = key.shape
|
||||
|
||||
q_block_num = (q_len + block_size - 1) // block_size
|
||||
k_block_num = (k_len + block_size - 1) // block_size
|
||||
|
||||
# If Q fits in one chunk, call directly
|
||||
if q_len <= chunk_size:
|
||||
return xattn_estimate_chunked(
|
||||
query, key,
|
||||
q_start_pos=0,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# External chunking: split Q and call for each chunk
|
||||
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
|
||||
print(f" External chunking: {num_q_chunks} chunks")
|
||||
|
||||
combined_attn_sum = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=query.dtype, device=query.device
|
||||
)
|
||||
combined_mask = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=torch.bool, device=query.device
|
||||
)
|
||||
|
||||
q_block_offset = 0
|
||||
for q_chunk_idx in range(num_q_chunks):
|
||||
q_chunk_start = q_chunk_idx * chunk_size
|
||||
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
|
||||
|
||||
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
|
||||
|
||||
# For causal attention, K accumulates up to current Q position
|
||||
# q_start_pos=0 means Q starts at position 0 in the full sequence
|
||||
# K is [0, q_chunk_end) for causal attention
|
||||
k_end = q_chunk_end
|
||||
k_chunk = key[:, :, :k_end, :]
|
||||
|
||||
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
|
||||
q_chunk, k_chunk,
|
||||
q_start_pos=q_chunk_start,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# Place chunk results into combined output
|
||||
chunk_q_blocks = mask_chunk.shape[2]
|
||||
chunk_k_blocks = mask_chunk.shape[3]
|
||||
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
|
||||
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
|
||||
q_block_offset += chunk_q_blocks
|
||||
|
||||
return combined_attn_sum, combined_mask
|
||||
|
||||
|
||||
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
|
||||
"""Test a single sequence length."""
|
||||
print(f"\nTesting seq_len={seq_len}")
|
||||
print("=" * 60)
|
||||
|
||||
# Generate random Q/K
|
||||
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Run standard xattn_estimate
|
||||
print("[1] Running standard xattn_estimate...")
|
||||
try:
|
||||
attn_sum_std, mask_std = xattn_estimate(
|
||||
query, key,
|
||||
block_size=BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
use_triton=True,
|
||||
causal=True,
|
||||
)
|
||||
density_std = mask_std.float().mean().item()
|
||||
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Run chunked xattn_estimate with EXTERNAL chunking
|
||||
print("[2] Running chunked xattn_estimate (external chunking)...")
|
||||
try:
|
||||
attn_sum_chunked, mask_chunked = run_chunked_externally(
|
||||
query, key,
|
||||
block_size=BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
)
|
||||
density_chunked = mask_chunked.float().mean().item()
|
||||
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Compare results
|
||||
print("[3] Comparing results...")
|
||||
chunked_q_blocks = mask_chunked.shape[2]
|
||||
chunked_k_blocks = mask_chunked.shape[3]
|
||||
|
||||
# Extract comparable region from standard mask
|
||||
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||
|
||||
# Compare masks
|
||||
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
|
||||
|
||||
# Compare attn_sums
|
||||
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
|
||||
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
|
||||
print(f" Attn sum max diff: {attn_diff:.6f}")
|
||||
else:
|
||||
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
|
||||
|
||||
# Clean up GPU memory
|
||||
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return masks_match
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("XAttention Chunked vs Standard Test")
|
||||
print("=" * 60)
|
||||
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
|
||||
print(f"External chunk_size={CHUNK_SIZE}")
|
||||
print()
|
||||
|
||||
# Check CUDA availability
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA not available!")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||
print("✓ xattn_estimate imported")
|
||||
print("✓ xattn_estimate_chunked imported")
|
||||
|
||||
# Run tests
|
||||
all_passed = True
|
||||
results = []
|
||||
|
||||
for seq_len in TEST_SEQ_LENS:
|
||||
passed = test_single_seq_len(seq_len)
|
||||
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
|
||||
results.append((seq_len, chunks, passed))
|
||||
if not passed:
|
||||
all_passed = False
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("SUMMARY")
|
||||
print("=" * 60)
|
||||
for seq_len, chunks, passed in results:
|
||||
status = "PASSED" if passed else "FAILED"
|
||||
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
|
||||
|
||||
print("=" * 60)
|
||||
if all_passed:
|
||||
print("ALL TESTS PASSED!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("SOME TESTS FAILED!")
|
||||
sys.exit(1)
|
||||
Reference in New Issue
Block a user