Files
nano-vllm/tests/test_xattn_estimate_alignment.py
Zijie Tian 193ef55d18 ♻️ refactor: use Q-chunked processing in xattn alignment test
Match xattn_estimate internal logic by processing Q in chunks:
- Reduces peak memory for attn_scores tensor
- Enables testing 64K sequences without OOM
- All 5 test files pass (3.6K to 64K)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 18:08:15 +08:00

248 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Test: 验证 xattn_estimate 与底层 kernel 调用的一致性
使用真实 KV cache 数据,分别调用:
1. xattn_estimate (高层 API)
2. flat_group_gemm_fuse_reshape + softmax_fuse_block_sum + find_blocks_chunked (底层 kernels)
底层 kernels 按 Q 分 chunk与 xattn_estimate 内部逻辑一致,减少峰值内存占用。
Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_alignment.py
"""
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
import torch
import math
from nanovllm.ops.xattn import (
xattn_estimate,
flat_group_gemm_fuse_reshape,
softmax_fuse_block_sum,
find_blocks_chunked,
)
# ============================================================
# Configuration
# ============================================================
# NOTE(review): machine-specific absolute path to the saved QKV dump.
DATA_FILE = "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt"
BSA_BLOCK_SIZE = 128
CHUNK_SIZE = 16384  # default value used by xattn_estimate
USE_SAVED_PARAMS = True  # set to False to use the hard-coded defaults below
device = "cuda"
# ============================================================
# Step 1: load the real KV-cache dump
# ============================================================
print("=" * 60)
print("Step 1: 加载真实 KV cache 数据")
print("=" * 60)
data = torch.load(DATA_FILE, map_location="cpu")
Q = data["query"].to(device)  # expected layout [1, 32, seq_len, 128]
K = data["key"].to(device)    # same layout as Q
batch_size, num_heads, seq_len, head_dim = Q.shape

# Read STRIDE / THRESHOLD from the dump itself, or fall back to defaults.
if USE_SAVED_PARAMS:
    STRIDE = data["stride"]
    raw_threshold = data["threshold"]
    THRESHOLD = raw_threshold[0].item() if isinstance(raw_threshold, torch.Tensor) else raw_threshold
else:
    STRIDE, THRESHOLD = 8, 0.9

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"Data layer_id: {data['layer_id']}, saved density: {data['density']:.4f}")
print(f"使用参数: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}, CHUNK_SIZE={CHUNK_SIZE}")
print()
# ============================================================
# Step 2: run the high-level xattn_estimate API
# ============================================================
print("=" * 60)
print("Step 2: 调用 xattn_estimate (高层 API)")
print("=" * 60)
attn_sums_api, mask_api = xattn_estimate(
    Q,
    K,
    block_size=BSA_BLOCK_SIZE,
    stride=STRIDE,
    threshold=THRESHOLD,
    chunk_size=CHUNK_SIZE,
    causal=True,
)

# Crop the padded block mask back to the blocks covering the real sequence.
# Ceil division written as negated floor division.
q_blocks = -(-seq_len // BSA_BLOCK_SIZE)
k_blocks = q_blocks  # Q and K share the same sequence length here
mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]

# Density is measured over the causal (lower-triangular) block region only.
causal_mask = torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool).tril()
total_api = causal_mask.sum().item() * batch_size * num_heads
selected_api = (mask_api_valid & causal_mask[None, None]).sum().item()
density_api = selected_api / total_api
print(f"mask_api shape (padded): {mask_api.shape}")
print(f"mask_api_valid shape: {mask_api_valid.shape}")
print(f"[xattn_estimate] density: {density_api:.6f} (selected={selected_api}, total={total_api})")
print()
# ============================================================
# Step 3: manual computation with the low-level kernels (Q chunked)
# ============================================================
print("=" * 60)
print("Step 3: 使用底层 kernels 手动计算 (按 Q 分 chunk)")
print("=" * 60)

# 3.1 Padding parameters (must match xattn_estimate's internals).
padded_len = -(-seq_len // CHUNK_SIZE) * CHUNK_SIZE  # round up to a CHUNK_SIZE multiple
k_num_to_pad = padded_len - seq_len
q_num_to_pad = padded_len - seq_len  # same length for Q and K, hence same padding
k_chunk_num = padded_len // CHUNK_SIZE
q_chunk_num = padded_len // CHUNK_SIZE
k_block_num = padded_len // BSA_BLOCK_SIZE
q_block_num = padded_len // BSA_BLOCK_SIZE
reshaped_chunk_size = CHUNK_SIZE // STRIDE
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
k_reshaped_seq_len = padded_len // STRIDE
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
print(f"原始 seq_len: {seq_len}")
print(f"q_chunk_num: {q_chunk_num}, k_chunk_num: {k_chunk_num}")
print(f"q_block_num: {q_block_num}, k_block_num: {k_block_num}")
print(f"reshaped_chunk_size: {reshaped_chunk_size}, reshaped_block_size: {reshaped_block_size}")
print(f"num_blocks_per_chunk: {num_blocks_per_chunk}")
print()

# 3.2 Zero-pad Q and K along the sequence dimension up to padded_len.
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0) if k_num_to_pad > 0 else K
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0) if q_num_to_pad > 0 else Q
print(f"Q_padded shape: {Q_padded.shape}")
print(f"K_padded shape: {K_padded.shape}")
print()
# 3.3 Process Q chunk by chunk (mirrors xattn_estimate's internal logic).
norm = 1.0
# 1.4426950408889634 == log2(e); folding it into the scale suggests the
# softmax kernel exponentiates in base 2 — NOTE(review): confirm against
# the kernel implementation.
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
simple_mask_list = []
print(f"按 Q 分 {q_chunk_num} 个 chunk 处理...")
for chunk_idx in range(q_chunk_num):
    # Slice out the current Q chunk (matches xattn_estimate lines 811-816).
    q_start = chunk_idx * reshaped_chunk_size * STRIDE
    q_end = q_start + reshaped_chunk_size * STRIDE
    Q_chunk = Q_padded[:, :, q_start:q_end, :]
    # chunk_start/chunk_end in stride-compressed coordinates
    # (matches xattn_estimate lines 819-820).
    chunk_start = (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size
    chunk_end = chunk_start + reshaped_chunk_size
    # Strided Q·K^T scores for this chunk against all of K
    # (matches xattn_estimate lines 810-822).
    attn_weights_slice = flat_group_gemm_fuse_reshape(
        Q_chunk, K_padded, STRIDE,
        chunk_start=chunk_start,
        chunk_end=chunk_end,
        is_causal=True,
    )
    # Softmax fused with per-block summation
    # (matches xattn_estimate lines 827-836).
    attn_sum = softmax_fuse_block_sum(
        attn_weights_slice,
        reshaped_block_size,
        min(4096, reshaped_block_size),
        chunk_start=chunk_start,
        chunk_end=chunk_end,
        real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
        scale=scale,
        is_causal=True,
    )
    # Select blocks for this chunk given the attention-mass THRESHOLD
    # (matches xattn_estimate lines 887-895).
    simple_mask = find_blocks_chunked(
        attn_sum,
        current_index=k_block_num - q_block_num + chunk_idx * num_blocks_per_chunk,
        threshold=THRESHOLD,
        num_to_choose=None,
        decoding=False,
        mode="prefill",
        causal=True,
    )
    simple_mask_list.append(simple_mask)
    print(f" Chunk {chunk_idx}: Q[{q_start}:{q_end}], attn shape={attn_weights_slice.shape}, mask shape={simple_mask.shape}")
# 3.4 Concatenate the per-chunk masks along the Q-block dimension
# (matches xattn_estimate lines 901-905).
mask_manual = torch.cat(simple_mask_list, dim=2)
print(f"\n合并后 mask_manual shape: {mask_manual.shape}")
# Crop to the valid (unpadded) block region.
mask_manual_valid = mask_manual[:, :, :q_blocks, :k_blocks]
print(f"mask_manual_valid shape: {mask_manual_valid.shape}")
# Density over the causal region; reuse the API path's denominator so the
# two densities are directly comparable.
selected_manual = (mask_manual_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
total_manual = total_api
density_manual = selected_manual / total_manual
print(f"[底层 kernels] density: {density_manual:.6f} (selected={selected_manual}, total={total_manual})")
print()
# ============================================================
# Step 4: compare the two paths
# ============================================================
print("=" * 60)
print("Step 4: 对比结果")
print("=" * 60)
print(f"xattn_estimate density: {density_api:.6f}")
print(f"底层 kernels density: {density_manual:.6f}")
print(f"差异: {abs(density_api - density_manual):.6f}")
print()

# Element-wise disagreement between the two masks.
mask_diff = (mask_api_valid != mask_manual_valid).sum().item()
mask_total = mask_api_valid.numel()
mask_diff_ratio = mask_diff / mask_total
print(f"Mask 不同的元素数: {mask_diff} / {mask_total} ({100*mask_diff_ratio:.4f}%)")
print()

density_gap = abs(density_api - density_manual)
if density_gap < 1e-6 and mask_diff_ratio < 0.001:
    print("✅ xattn_estimate 与底层 kernels 对齐! (mask 差异 < 0.1%)")
elif density_gap < 0.01:
    print("⚠️ Density 基本一致,但 mask 有差异")
else:
    print("❌ Density 不一致,需要检查参数")
# ============================================================
# Step 5: sanity check against the density stored in the dump
# ============================================================
print()
print("=" * 60)
print("Step 5: 与保存的 density 对比")
print("=" * 60)
saved_density = data["density"]
print(f"保存的 density: {saved_density:.6f}")
print(f"xattn_estimate density: {density_api:.6f}")
print(f"差异: {abs(saved_density - density_api):.6f}")
matches_saved = abs(saved_density - density_api) < 0.01
print("✅ 与保存的 density 基本一致!" if matches_saved else "⚠️ 与保存的 density 有差异,可能是参数不同")