260 lines
8.7 KiB
Python
260 lines
8.7 KiB
Python
"""
|
|
Test: Compare xattn_estimate vs xattn_estimate_chunked

Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation.

Uses real QKV data captured from model inference.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import torch
|
|
import warnings
|
|
|
|
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
|
|
|
# ============================================================
|
|
# Configuration
|
|
# ============================================================
|
|
|
|
# Estimation parameters. The same values are handed to both the standard and
# the chunked estimator so that their outputs can be compared directly.
BLOCK_SIZE = 64    # passed as ``block_size``; also used to size the combined outputs
STRIDE = 4         # passed as ``stride``
THRESHOLD = 0.9    # passed as ``threshold``
CHUNK_SIZE = 4096  # passed as ``chunk_size``; also the external Q split length

# Default QKV data directory: "<parent of this file's directory>/results/kvcache".
DEFAULT_QKV_DIR = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "results",
    "kvcache",
)
|
|
|
|
# ============================================================
|
|
# Utility Functions
|
|
# ============================================================
|
|
|
|
def load_qkv(path):
    """Read a saved QKV capture from disk and print a short summary of it.

    The file is deserialized onto the CPU with ``torch.load``. The loaded
    object is indexed with the keys ``query``, ``key``, ``layer_id`` and
    ``density`` to build the summary and is then returned unchanged.

    NOTE(review): ``weights_only=False`` lets ``torch.load`` unpickle
    arbitrary objects, so only point this at capture files you trust.
    """
    qkv = torch.load(path, map_location="cpu", weights_only=False)

    # Summary lines are printed one by one so that a missing key still leaves
    # the earlier lines on screen before the KeyError surfaces.
    print(f"Loaded: {path}")
    print(f" Query shape: {qkv['query'].shape}")
    print(f" Key shape: {qkv['key'].shape}")
    print(f" Layer: {qkv['layer_id']}, Density: {qkv['density']:.2%}")

    return qkv
|
|
|
|
|
|
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks element-wise and print a short report.

    Args:
        mask1: First mask tensor.
        mask2: Second mask tensor.
        name1: Label used for ``mask1`` in the shape-mismatch message.
        name2: Label used for ``mask2`` in the shape-mismatch message.

    Returns:
        ``True`` when both masks have the same shape and every element is
        equal (two empty masks of equal shape count as equal); ``False``
        otherwise, including on a shape mismatch.
    """
    if mask1.shape != mask2.shape:
        print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    # Build the element-wise difference once; it feeds both the mismatch count
    # and the example positions printed below.
    mismatch = mask1 != mask2
    diff = int(mismatch.sum().item())
    total = mask1.numel()

    # Masks with zero elements match trivially; guard the division so they do
    # not raise ZeroDivisionError.
    match_rate = (total - diff) / total * 100 if total else 100.0

    print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        diff_indices = torch.where(mismatch)
        print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")

    return diff == 0
|
|
|
|
|
|
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
    """Drive ``xattn_estimate_chunked`` one Q chunk at a time and stitch the results.

    This simulates how chunked prefill should be used in practice: the caller
    splits Q into pieces of at most ``chunk_size`` positions and calls the
    estimator once per piece.

    Args:
        query: 4-D tensor; its third dimension is treated as the Q length.
        key: 4-D tensor; its third dimension is treated as the K length.
        q_start_pos: Offset added to Q positions when forwarding
            ``q_start_pos`` and when choosing how much of K each chunk sees.
        block_size, stride, threshold, chunk_size: Forwarded unchanged to
            ``xattn_estimate_chunked``; ``block_size`` and ``chunk_size`` are
            also used here to size the outputs and to split Q.

    Returns:
        If Q fits in a single chunk, whatever ``xattn_estimate_chunked``
        returns for the full inputs. Otherwise a ``(combined_attn_sum,
        combined_mask)`` pair of zero-initialised block-level tensors filled
        in with each chunk's result.
    """

    def _estimate(q_part, k_part, start_pos):
        # Warnings raised by the estimator are silenced for this call only.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return xattn_estimate_chunked(
                q_part, k_part,
                q_start_pos=start_pos,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )

    batch_size, num_heads, q_len, _head_dim = query.shape
    _, _, k_len, _ = key.shape

    # Ceil-divide sequence lengths into block counts.
    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # A single chunk needs no stitching: forward the inputs as they are.
    if q_len <= chunk_size:
        return _estimate(query, key, q_start_pos)

    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f" External chunking: {num_q_chunks} chunks")

    out_shape = (batch_size, num_heads, q_block_num, k_block_num)
    combined_attn_sum = torch.zeros(out_shape, dtype=query.dtype, device=query.device)
    combined_mask = torch.zeros(out_shape, dtype=torch.bool, device=query.device)

    q_block_offset = 0
    for chunk_index in range(num_q_chunks):
        q_chunk_start = chunk_index * chunk_size
        q_chunk_end = min(q_chunk_start + chunk_size, q_len)

        # For causal attention, K accumulates up to the current Q position.
        k_end = q_start_pos + q_chunk_end

        attn_sum_chunk, mask_chunk = _estimate(
            query[:, :, q_chunk_start:q_chunk_end, :],
            key[:, :, :k_end, :],
            q_start_pos + q_chunk_start,
        )

        # Write this chunk's rows directly below the rows already filled in;
        # only the first ``chunk_k_blocks`` columns are overwritten.
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        rows = slice(q_block_offset, q_block_offset + chunk_q_blocks)
        combined_attn_sum[:, :, rows, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, rows, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask
|
|
|
|
|
|
def _estimate_and_compare(qkv_path):
    """Run both estimators on one QKV capture and report whether the masks agree.

    Private helper for ``test_single_qkv``. Every GPU tensor it creates is a
    local of this function, so all of them are released when it returns.

    Returns:
        ``True`` if the masks match on the compared region, ``False`` if they
        differ or if either estimator raised.
    """
    import traceback

    data = load_qkv(qkv_path)
    # Both tensors are moved to the GPU, so a CUDA device is required.
    query = data["query"].cuda().to(torch.bfloat16)
    key = data["key"].cuda().to(torch.bfloat16)

    seq_len = query.shape[2]
    print(f"\nTesting with seq_len={seq_len}")
    print("=" * 60)

    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
        )
        print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
    except Exception as e:
        # Test boundary: report the failure and mark this file as failed
        # instead of aborting the whole run.
        print(f" ERROR: {e}")
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            q_start_pos=0,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        traceback.print_exc()
        return False

    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    # Crop the standard outputs to the block extent of the chunked outputs so
    # that both are compared over the same region.
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # The attn_sum difference is informational only; it does not affect the
    # pass/fail result.
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f" Attn sum max diff: {attn_diff:.6f}")
    else:
        print(f" Attn sum shape mismatch")

    return masks_match


def test_single_qkv(qkv_path):
    """Test a single QKV file.

    Loads the capture at ``qkv_path``, runs ``xattn_estimate`` and the
    externally chunked ``xattn_estimate_chunked`` on it, and compares the
    resulting masks.

    Args:
        qkv_path: Path to a saved QKV capture readable by ``load_qkv``.

    Returns:
        ``True`` if the masks match, ``False`` if they differ or if either
        estimator raised.
    """
    try:
        return _estimate_and_compare(qkv_path)
    finally:
        # Clean up GPU memory. This runs after the helper has returned, so no
        # tensor or slice view is still holding its storage, and it runs on
        # the failure paths as well.
        torch.cuda.empty_cache()
|
|
|
|
|
|
# ============================================================
|
|
# Main Test
|
|
# ============================================================
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
    parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
                        help="Directory containing QKV files")
    args = parser.parse_args()

    # Capture files to look for; the number in each name is parsed below as
    # the sequence length.
    qkv_names = (
        "qkv_3688.pt",   # ~4K
        "qkv_7888.pt",   # ~8K
        "qkv_15685.pt",  # ~16K
        "qkv_32485.pt",  # ~32K
        "qkv_64891.pt",  # ~64K
    )
    qkv_files = [os.path.join(args.qkv_dir, name) for name in qkv_names]

    # Only files that actually exist are tested; missing ones are skipped.
    available_files = [path for path in qkv_files if os.path.exists(path)]

    if not available_files:
        print(f"No QKV file found in {args.qkv_dir}.")
        print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
        sys.exit(1)

    print(f"Found {len(available_files)} QKV files to test")
    print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
    print(f"Using Triton kernels")

    # Collect one (seq_len, passed) pair per tested file.
    results = []
    for qkv_path in available_files:
        passed = test_single_qkv(qkv_path)
        length_text = os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", "")
        results.append((int(length_text), passed))

    all_passed = all(ok for _, ok in results)

    # Summary
    separator = "=" * 60
    print("\n" + separator)
    print("SUMMARY")
    print(separator)
    for seq_len, passed in results:
        status = "PASSED" if passed else "FAILED"
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        plural = "s" if chunks > 1 else ""
        print(f" seq_len={seq_len} ({chunks} chunk{plural}): {status}")

    print(separator)
    if not all_passed:
        print("test_xattn_chunked: FAILED")
        sys.exit(1)

    print("test_xattn_chunked: PASSED")
    sys.exit(0)
|