♻️ refactor: use Q-chunked processing in xattn alignment test

Match xattn_estimate internal logic by processing Q in chunks:
- Reduces peak memory for attn_scores tensor
- Enables testing 64K sequences without OOM
- All 5 test files pass (3.6K to 64K)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-02-01 18:08:15 +08:00
parent f173a3f7f5
commit 193ef55d18

View File

@@ -5,7 +5,7 @@ Test: 验证 xattn_estimate 与底层 kernel 调用的一致性
1. xattn_estimate (高层 API) 1. xattn_estimate (高层 API)
2. flat_group_gemm_fuse_reshape + softmax_fuse_block_sum + find_blocks_chunked (底层 kernels) 2. flat_group_gemm_fuse_reshape + softmax_fuse_block_sum + find_blocks_chunked (底层 kernels)
验证两种方式的 density 是否一致 底层 kernels 按 Q 分 chunk,与 xattn_estimate 内部逻辑一致,减少峰值内存占用
Usage: Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
@@ -21,7 +21,6 @@ from nanovllm.ops.xattn import (
flat_group_gemm_fuse_reshape, flat_group_gemm_fuse_reshape,
softmax_fuse_block_sum, softmax_fuse_block_sum,
find_blocks_chunked, find_blocks_chunked,
compute_sparsity,
) )
# ============================================================ # ============================================================
@@ -29,7 +28,7 @@ from nanovllm.ops.xattn import (
# ============================================================ # ============================================================
DATA_FILE = "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt" DATA_FILE = "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt"
BSA_BLOCK_SIZE = 128 BSA_BLOCK_SIZE = 128
# STRIDE 和 THRESHOLD 从保存的数据中读取 CHUNK_SIZE = 16384 # xattn_estimate 默认值
USE_SAVED_PARAMS = True # 设为 False 则使用默认值 USE_SAVED_PARAMS = True # 设为 False 则使用默认值
device = "cuda" device = "cuda"
@@ -58,7 +57,7 @@ else:
print(f"Q shape: {Q.shape}") print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}") print(f"K shape: {K.shape}")
print(f"Data layer_id: {data['layer_id']}, saved density: {data['density']:.4f}") print(f"Data layer_id: {data['layer_id']}, saved density: {data['density']:.4f}")
print(f"使用参数: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}") print(f"使用参数: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}, CHUNK_SIZE={CHUNK_SIZE}")
print() print()
# ============================================================ # ============================================================
@@ -68,16 +67,12 @@ print("=" * 60)
print("Step 2: 调用 xattn_estimate (高层 API)") print("Step 2: 调用 xattn_estimate (高层 API)")
print("=" * 60) print("=" * 60)
# 使用与底层计算一致的 chunk_size (seq_len 对齐到 alignment)
alignment = STRIDE * 128
chunk_size_aligned = ((seq_len + alignment - 1) // alignment) * alignment
attn_sums_api, mask_api = xattn_estimate( attn_sums_api, mask_api = xattn_estimate(
Q, K, Q, K,
block_size=BSA_BLOCK_SIZE, block_size=BSA_BLOCK_SIZE,
stride=STRIDE, stride=STRIDE,
threshold=THRESHOLD, threshold=THRESHOLD,
chunk_size=chunk_size_aligned, # 保持一致 chunk_size=CHUNK_SIZE,
causal=True, causal=True,
) )
@@ -98,110 +93,111 @@ print(f"[xattn_estimate] density: {density_api:.6f} (selected={selected_api}, to
print() print()
# ============================================================ # ============================================================
# Step 3: 使用底层 kernels 手动计算 # Step 3: 使用底层 kernels 手动计算 (按 Q 分 chunk)
# ============================================================ # ============================================================
print("=" * 60) print("=" * 60)
print("Step 3: 使用底层 kernels 手动计算") print("Step 3: 使用底层 kernels 手动计算 (按 Q 分 chunk)")
print("=" * 60) print("=" * 60)
# 3.1 Padding # 3.1 计算 padding 参数 (与 xattn_estimate 内部一致)
BLOCK_M = 128 k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
BLOCK_N = 128 q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
alignment = STRIDE * BLOCK_M k_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
k_alignment = STRIDE * BLOCK_N q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
padded_q_len = ((seq_len + alignment - 1) // alignment) * alignment k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
padded_k_len = ((seq_len + k_alignment - 1) // k_alignment) * k_alignment q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
reshaped_chunk_size = CHUNK_SIZE // STRIDE
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
print(f"原始 seq_len: {seq_len}") print(f"原始 seq_len: {seq_len}")
print(f"Padded Q len: {padded_q_len}") print(f"q_chunk_num: {q_chunk_num}, k_chunk_num: {k_chunk_num}")
print(f"Padded K len: {padded_k_len}") print(f"q_block_num: {q_block_num}, k_block_num: {k_block_num}")
print(f"reshaped_chunk_size: {reshaped_chunk_size}, reshaped_block_size: {reshaped_block_size}")
print(f"num_blocks_per_chunk: {num_blocks_per_chunk}")
print()
if padded_q_len != seq_len: # 3.2 Padding
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, padded_q_len - seq_len), value=0) if k_num_to_pad > 0:
else: K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
Q_padded = Q
if padded_k_len != seq_len:
K_padded = torch.nn.functional.pad(K, (0, 0, 0, padded_k_len - seq_len), value=0)
else: else:
K_padded = K K_padded = K
if q_num_to_pad > 0:
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
else:
Q_padded = Q
print(f"Q_padded shape: {Q_padded.shape}") print(f"Q_padded shape: {Q_padded.shape}")
print(f"K_padded shape: {K_padded.shape}") print(f"K_padded shape: {K_padded.shape}")
print() print()
# 3.2 计算 reshaped 维度 # 3.3 按 Q chunk 处理 (与 xattn_estimate 内部逻辑一致)
q_reshaped_len = padded_q_len // STRIDE
k_reshaped_len = padded_k_len // STRIDE
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
q_block_num = padded_q_len // BSA_BLOCK_SIZE
k_block_num = padded_k_len // BSA_BLOCK_SIZE
print(f"q_reshaped_len: {q_reshaped_len}")
print(f"k_reshaped_len: {k_reshaped_len}")
print(f"reshaped_block_size: {reshaped_block_size}")
print(f"q_block_num: {q_block_num}, k_block_num: {k_block_num}")
print()
# 3.3 调用 flat_group_gemm_fuse_reshape
print("3.3 调用 flat_group_gemm_fuse_reshape...")
chunk_start = (k_block_num - q_block_num) * reshaped_block_size # 对于 q_len=k_len, offset=0
chunk_end = chunk_start + q_reshaped_len
attn_scores = flat_group_gemm_fuse_reshape(
Q_padded, K_padded, STRIDE,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=True,
)
print(f"attn_scores shape: {attn_scores.shape}")
print()
# 3.4 调用 softmax_fuse_block_sum
print("3.4 调用 softmax_fuse_block_sum...")
norm = 1.0 norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
segment_size = min(4096, reshaped_block_size)
# 计算 real_q_len (排除 padding) simple_mask_list = []
k_reshaped_num_to_pad = (padded_k_len - seq_len) // STRIDE
real_q_len = k_reshaped_len - k_reshaped_num_to_pad
block_sums = softmax_fuse_block_sum( print(f"按 Q 分 {q_chunk_num} 个 chunk 处理...")
attn_scores, for chunk_idx in range(q_chunk_num):
reshaped_block_size, # 提取当前 Q chunk (与 xattn_estimate line 811-816 一致)
segment_size, q_start = chunk_idx * reshaped_chunk_size * STRIDE
q_end = q_start + reshaped_chunk_size * STRIDE
Q_chunk = Q_padded[:, :, q_start:q_end, :]
# 计算 chunk_start/chunk_end (与 xattn_estimate line 819-820 一致)
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size
chunk_end = chunk_start + reshaped_chunk_size
# flat_group_gemm_fuse_reshape (与 xattn_estimate line 810-822 一致)
attn_weights_slice = flat_group_gemm_fuse_reshape(
Q_chunk, K_padded, STRIDE,
chunk_start=chunk_start, chunk_start=chunk_start,
chunk_end=chunk_end, chunk_end=chunk_end,
real_q_len=real_q_len, is_causal=True,
)
# softmax_fuse_block_sum (与 xattn_estimate line 827-836 一致)
attn_sum = softmax_fuse_block_sum(
attn_weights_slice,
reshaped_block_size,
min(4096, reshaped_block_size),
chunk_start=chunk_start,
chunk_end=chunk_end,
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
scale=scale, scale=scale,
is_causal=True, is_causal=True,
) )
print(f"block_sums shape: {block_sums.shape}")
print()
# 3.5 调用 find_blocks_chunked # find_blocks_chunked (与 xattn_estimate line 887-895 一致)
print("3.5 调用 find_blocks_chunked...") simple_mask = find_blocks_chunked(
mask_manual = find_blocks_chunked( attn_sum,
block_sums, current_index=k_block_num - q_block_num + chunk_idx * num_blocks_per_chunk,
current_index=0, # Q 从位置 0 开始 (因为 q_len = k_len)
threshold=THRESHOLD, threshold=THRESHOLD,
num_to_choose=None, num_to_choose=None,
decoding=False, decoding=False,
mode="prefill", mode="prefill",
causal=True, causal=True,
) )
simple_mask_list.append(simple_mask)
print(f" Chunk {chunk_idx}: Q[{q_start}:{q_end}], attn shape={attn_weights_slice.shape}, mask shape={simple_mask.shape}")
# 3.4 合并所有 chunks 的 mask (与 xattn_estimate line 901-905 一致)
mask_manual = torch.cat(simple_mask_list, dim=2)
print(f"\n合并后 mask_manual shape: {mask_manual.shape}")
# 裁剪到有效区域 # 裁剪到有效区域
mask_manual_valid = mask_manual[:, :, :q_blocks, :k_blocks] mask_manual_valid = mask_manual[:, :, :q_blocks, :k_blocks]
print(f"mask_manual shape (padded): {mask_manual.shape}")
print(f"mask_manual_valid shape: {mask_manual_valid.shape}") print(f"mask_manual_valid shape: {mask_manual_valid.shape}")
# 计算 density # 计算 density
selected_manual = (mask_manual_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() selected_manual = (mask_manual_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
total_manual = total_api # 相同的 total total_manual = total_api
density_manual = selected_manual / total_manual density_manual = selected_manual / total_manual
print(f"[底层 kernels] density: {density_manual:.6f} (selected={selected_manual}, total={total_manual})") print(f"[底层 kernels] density: {density_manual:.6f} (selected={selected_manual}, total={total_manual})")
@@ -222,10 +218,10 @@ print()
# 对比 mask # 对比 mask
mask_diff = (mask_api_valid != mask_manual_valid).sum().item() mask_diff = (mask_api_valid != mask_manual_valid).sum().item()
mask_total = mask_api_valid.numel() mask_total = mask_api_valid.numel()
print(f"Mask 不同的元素数: {mask_diff} / {mask_total} ({100*mask_diff/mask_total:.4f}%)") mask_diff_ratio = mask_diff / mask_total
print(f"Mask 不同的元素数: {mask_diff} / {mask_total} ({100*mask_diff_ratio:.4f}%)")
print() print()
mask_diff_ratio = mask_diff / mask_total
if abs(density_api - density_manual) < 1e-6 and mask_diff_ratio < 0.001: if abs(density_api - density_manual) < 1e-6 and mask_diff_ratio < 0.001:
print("✅ xattn_estimate 与底层 kernels 对齐! (mask 差异 < 0.1%)") print("✅ xattn_estimate 与底层 kernels 对齐! (mask 差异 < 0.1%)")
elif abs(density_api - density_manual) < 0.01: elif abs(density_api - density_manual) < 0.01:
@@ -248,4 +244,4 @@ print(f"差异: {abs(saved_density - density_api):.6f}")
if abs(saved_density - density_api) < 0.01: if abs(saved_density - density_api) < 0.01:
print("✅ 与保存的 density 基本一致!") print("✅ 与保存的 density 基本一致!")
else: else:
print("⚠️ 与保存的 density 有差异,可能是 threshold 不同") print("⚠️ 与保存的 density 有差异,可能是参数不同")