[WIP] Before refactoring compute_chunked_prefill.
259
tests/test_xattn_chunked.py
Normal file
@@ -0,0 +1,259 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked

Verify that chunked estimation with EXTERNAL chunking produces the same
mask as standard estimation.

Uses real QKV data captured from model inference.
"""

import sys
import os
import torch
import warnings

from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked

# ============================================================
# Configuration
# ============================================================

BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096

# Default QKV data directory (relative to project root)
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")

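# Sanity check: the external-chunking loop in run_chunked_externally below
# stitches per-chunk masks together by block row, which only lines up when
# CHUNK_SIZE is a multiple of BLOCK_SIZE (here 4096 / 64 == 64 mask rows per
# full chunk), so chunk boundaries always align with block boundaries.
assert CHUNK_SIZE % BLOCK_SIZE == 0, "external chunking assumes block-aligned chunks"
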
# ============================================================
# Utility Functions
# ============================================================


def load_qkv(path):
    """Load saved QKV data."""
    data = torch.load(path, map_location="cpu", weights_only=False)
    print(f"Loaded: {path}")
    print(f"  Query shape: {data['query'].shape}")
    print(f"  Key shape: {data['key'].shape}")
    print(f"  Layer: {data['layer_id']}, Density: {data['density']:.2%}")
    return data

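# Assumed schema of each .pt payload, inferred from the fields load_qkv
# accesses (the capture script itself is not part of this commit):
#   {"query": [batch, heads, q_len, head_dim] tensor,
#    "key":   [batch, heads, k_len, head_dim] tensor,
#    "layer_id": int, "density": float}
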

def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks and report differences."""
    if mask1.shape != mask2.shape:
        print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    diff = (mask1 != mask2).sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100

    print(f"  Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        diff_indices = torch.where(mask1 != mask2)
        print(f"  First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")

    return diff == 0

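# Note: compare_masks is strict; a single block-level disagreement fails the
# comparison (diff == 0 is required). The printed match rate and the first
# few differing indices are diagnostics only.
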

def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.
    This simulates how chunked prefill should be used in practice.
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape

    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly
    if q_len <= chunk_size:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return xattn_estimate_chunked(
                query, key,
                q_start_pos=q_start_pos,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f"  External chunking: {num_q_chunks} chunks")

    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device,
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device,
    )

    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to current Q position
        k_end = q_start_pos + q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
                q_chunk, k_chunk,
                q_start_pos=q_start_pos + q_chunk_start,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

        # Place chunk results into combined output
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset + chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset + chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask

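# A minimal sketch of how run_chunked_externally is meant to be driven, with
# synthetic tensors (illustrative shapes only; the tests below use captured
# QKV data instead):
#
#   q = torch.randn(1, 32, 8192, 128, dtype=torch.bfloat16, device="cuda")
#   k = torch.randn(1, 32, 8192, 128, dtype=torch.bfloat16, device="cuda")
#   attn_sum, mask = run_chunked_externally(
#       q, k, q_start_pos=0, block_size=BLOCK_SIZE, stride=STRIDE,
#       threshold=THRESHOLD, chunk_size=CHUNK_SIZE,
#   )
#   # 8192 tokens at BLOCK_SIZE=64 -> mask is [1, 32, 128, 128], dtype=torch.bool
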

def test_single_qkv(qkv_path):
    """Test a single QKV file."""
    data = load_qkv(qkv_path)
    query = data["query"].cuda().to(torch.bfloat16)
    key = data["key"].cuda().to(torch.bfloat16)

    seq_len = query.shape[2]
    print(f"\nTesting with seq_len={seq_len}")
    print("=" * 60)

    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
        )
        print(f"  mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            q_start_pos=0,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        print(f"  mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    # Extract comparable region from standard mask
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]

    # Compare masks
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # Compare attn_sums
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f"  Attn sum max diff: {attn_diff:.6f}")
    else:
        print("  Attn sum shape mismatch")

    # Clean up GPU memory
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match


# ============================================================
# Main Test
# ============================================================

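# Example invocation (the default --qkv-dir resolves to <repo>/results/kvcache;
# point it at wherever the captured QKV dumps live):
#   python tests/test_xattn_chunked.py --qkv-dir results/kvcache
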
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
    parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
                        help="Directory containing QKV files")
    args = parser.parse_args()

    # QKV files to test
    qkv_files = [
        os.path.join(args.qkv_dir, "qkv_3688.pt"),   # ~4K
        os.path.join(args.qkv_dir, "qkv_7888.pt"),   # ~8K
        os.path.join(args.qkv_dir, "qkv_15685.pt"),  # ~16K
        os.path.join(args.qkv_dir, "qkv_32485.pt"),  # ~32K
        os.path.join(args.qkv_dir, "qkv_64891.pt"),  # ~64K
    ]

    available_files = [p for p in qkv_files if os.path.exists(p)]

    if not available_files:
        print(f"No QKV files found in {args.qkv_dir}.")
        print("Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
        sys.exit(1)

    print(f"Found {len(available_files)} QKV files to test")
    print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
    print("Using Triton kernels")

    all_passed = True
    results = []

    for qkv_path in available_files:
        passed = test_single_qkv(qkv_path)
        seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
        results.append((seq_len, passed))
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for seq_len, passed in results:
        status = "PASSED" if passed else "FAILED"
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        print(f"  seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")

    print("=" * 60)
    if all_passed:
        print("test_xattn_chunked: PASSED")
        sys.exit(0)
    else:
        print("test_xattn_chunked: FAILED")
        sys.exit(1)