♻️ refactor: use Q-chunked processing in xattn alignment test
Match xattn_estimate internal logic by processing Q in chunks: - Reduces peak memory for attn_scores tensor - Enables testing 64K sequences without OOM - All 5 test files pass (3.6K to 64K) Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -5,7 +5,7 @@ Test: 验证 xattn_estimate 与底层 kernel 调用的一致性
|
|||||||
1. xattn_estimate (高层 API)
|
1. xattn_estimate (高层 API)
|
||||||
2. flat_group_gemm_fuse_reshape + softmax_fuse_block_sum + find_blocks_chunked (底层 kernels)
|
2. flat_group_gemm_fuse_reshape + softmax_fuse_block_sum + find_blocks_chunked (底层 kernels)
|
||||||
|
|
||||||
验证两种方式的 density 是否一致。
|
底层 kernels 按 Q 分 chunk,与 xattn_estimate 内部逻辑一致,减少峰值内存占用。
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
@@ -21,7 +21,6 @@ from nanovllm.ops.xattn import (
|
|||||||
flat_group_gemm_fuse_reshape,
|
flat_group_gemm_fuse_reshape,
|
||||||
softmax_fuse_block_sum,
|
softmax_fuse_block_sum,
|
||||||
find_blocks_chunked,
|
find_blocks_chunked,
|
||||||
compute_sparsity,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -29,7 +28,7 @@ from nanovllm.ops.xattn import (
|
|||||||
# ============================================================
|
# ============================================================
|
||||||
DATA_FILE = "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt"
|
DATA_FILE = "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt"
|
||||||
BSA_BLOCK_SIZE = 128
|
BSA_BLOCK_SIZE = 128
|
||||||
# STRIDE 和 THRESHOLD 从保存的数据中读取
|
CHUNK_SIZE = 16384 # xattn_estimate 默认值
|
||||||
USE_SAVED_PARAMS = True # 设为 False 则使用默认值
|
USE_SAVED_PARAMS = True # 设为 False 则使用默认值
|
||||||
|
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
@@ -58,7 +57,7 @@ else:
|
|||||||
print(f"Q shape: {Q.shape}")
|
print(f"Q shape: {Q.shape}")
|
||||||
print(f"K shape: {K.shape}")
|
print(f"K shape: {K.shape}")
|
||||||
print(f"Data layer_id: {data['layer_id']}, saved density: {data['density']:.4f}")
|
print(f"Data layer_id: {data['layer_id']}, saved density: {data['density']:.4f}")
|
||||||
print(f"使用参数: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}")
|
print(f"使用参数: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}, CHUNK_SIZE={CHUNK_SIZE}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -68,16 +67,12 @@ print("=" * 60)
|
|||||||
print("Step 2: 调用 xattn_estimate (高层 API)")
|
print("Step 2: 调用 xattn_estimate (高层 API)")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# 使用与底层计算一致的 chunk_size (seq_len 对齐到 alignment)
|
|
||||||
alignment = STRIDE * 128
|
|
||||||
chunk_size_aligned = ((seq_len + alignment - 1) // alignment) * alignment
|
|
||||||
|
|
||||||
attn_sums_api, mask_api = xattn_estimate(
|
attn_sums_api, mask_api = xattn_estimate(
|
||||||
Q, K,
|
Q, K,
|
||||||
block_size=BSA_BLOCK_SIZE,
|
block_size=BSA_BLOCK_SIZE,
|
||||||
stride=STRIDE,
|
stride=STRIDE,
|
||||||
threshold=THRESHOLD,
|
threshold=THRESHOLD,
|
||||||
chunk_size=chunk_size_aligned, # 保持一致
|
chunk_size=CHUNK_SIZE,
|
||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -98,110 +93,111 @@ print(f"[xattn_estimate] density: {density_api:.6f} (selected={selected_api}, to
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Step 3: 使用底层 kernels 手动计算
|
# Step 3: 使用底层 kernels 手动计算 (按 Q 分 chunk)
|
||||||
# ============================================================
|
# ============================================================
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("Step 3: 使用底层 kernels 手动计算")
|
print("Step 3: 使用底层 kernels 手动计算 (按 Q 分 chunk)")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
# 3.1 Padding
|
# 3.1 计算 padding 参数 (与 xattn_estimate 内部一致)
|
||||||
BLOCK_M = 128
|
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||||
BLOCK_N = 128
|
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||||
alignment = STRIDE * BLOCK_M
|
k_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
|
||||||
k_alignment = STRIDE * BLOCK_N
|
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
|
||||||
|
|
||||||
padded_q_len = ((seq_len + alignment - 1) // alignment) * alignment
|
k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
|
||||||
padded_k_len = ((seq_len + k_alignment - 1) // k_alignment) * k_alignment
|
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
|
||||||
|
|
||||||
|
reshaped_chunk_size = CHUNK_SIZE // STRIDE
|
||||||
|
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
|
||||||
|
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
|
||||||
|
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
|
||||||
|
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
|
||||||
|
|
||||||
print(f"原始 seq_len: {seq_len}")
|
print(f"原始 seq_len: {seq_len}")
|
||||||
print(f"Padded Q len: {padded_q_len}")
|
print(f"q_chunk_num: {q_chunk_num}, k_chunk_num: {k_chunk_num}")
|
||||||
print(f"Padded K len: {padded_k_len}")
|
print(f"q_block_num: {q_block_num}, k_block_num: {k_block_num}")
|
||||||
|
print(f"reshaped_chunk_size: {reshaped_chunk_size}, reshaped_block_size: {reshaped_block_size}")
|
||||||
|
print(f"num_blocks_per_chunk: {num_blocks_per_chunk}")
|
||||||
|
print()
|
||||||
|
|
||||||
if padded_q_len != seq_len:
|
# 3.2 Padding
|
||||||
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, padded_q_len - seq_len), value=0)
|
if k_num_to_pad > 0:
|
||||||
else:
|
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
|
||||||
Q_padded = Q
|
|
||||||
|
|
||||||
if padded_k_len != seq_len:
|
|
||||||
K_padded = torch.nn.functional.pad(K, (0, 0, 0, padded_k_len - seq_len), value=0)
|
|
||||||
else:
|
else:
|
||||||
K_padded = K
|
K_padded = K
|
||||||
|
|
||||||
|
if q_num_to_pad > 0:
|
||||||
|
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
|
||||||
|
else:
|
||||||
|
Q_padded = Q
|
||||||
|
|
||||||
print(f"Q_padded shape: {Q_padded.shape}")
|
print(f"Q_padded shape: {Q_padded.shape}")
|
||||||
print(f"K_padded shape: {K_padded.shape}")
|
print(f"K_padded shape: {K_padded.shape}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# 3.2 计算 reshaped 维度
|
# 3.3 按 Q chunk 处理 (与 xattn_estimate 内部逻辑一致)
|
||||||
q_reshaped_len = padded_q_len // STRIDE
|
|
||||||
k_reshaped_len = padded_k_len // STRIDE
|
|
||||||
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
|
|
||||||
|
|
||||||
q_block_num = padded_q_len // BSA_BLOCK_SIZE
|
|
||||||
k_block_num = padded_k_len // BSA_BLOCK_SIZE
|
|
||||||
|
|
||||||
print(f"q_reshaped_len: {q_reshaped_len}")
|
|
||||||
print(f"k_reshaped_len: {k_reshaped_len}")
|
|
||||||
print(f"reshaped_block_size: {reshaped_block_size}")
|
|
||||||
print(f"q_block_num: {q_block_num}, k_block_num: {k_block_num}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 3.3 调用 flat_group_gemm_fuse_reshape
|
|
||||||
print("3.3 调用 flat_group_gemm_fuse_reshape...")
|
|
||||||
chunk_start = (k_block_num - q_block_num) * reshaped_block_size # 对于 q_len=k_len, offset=0
|
|
||||||
chunk_end = chunk_start + q_reshaped_len
|
|
||||||
|
|
||||||
attn_scores = flat_group_gemm_fuse_reshape(
|
|
||||||
Q_padded, K_padded, STRIDE,
|
|
||||||
chunk_start=chunk_start,
|
|
||||||
chunk_end=chunk_end,
|
|
||||||
is_causal=True,
|
|
||||||
)
|
|
||||||
print(f"attn_scores shape: {attn_scores.shape}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 3.4 调用 softmax_fuse_block_sum
|
|
||||||
print("3.4 调用 softmax_fuse_block_sum...")
|
|
||||||
norm = 1.0
|
norm = 1.0
|
||||||
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
|
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
|
||||||
segment_size = min(4096, reshaped_block_size)
|
|
||||||
|
|
||||||
# 计算 real_q_len (排除 padding)
|
simple_mask_list = []
|
||||||
k_reshaped_num_to_pad = (padded_k_len - seq_len) // STRIDE
|
|
||||||
real_q_len = k_reshaped_len - k_reshaped_num_to_pad
|
|
||||||
|
|
||||||
block_sums = softmax_fuse_block_sum(
|
print(f"按 Q 分 {q_chunk_num} 个 chunk 处理...")
|
||||||
attn_scores,
|
for chunk_idx in range(q_chunk_num):
|
||||||
reshaped_block_size,
|
# 提取当前 Q chunk (与 xattn_estimate line 811-816 一致)
|
||||||
segment_size,
|
q_start = chunk_idx * reshaped_chunk_size * STRIDE
|
||||||
|
q_end = q_start + reshaped_chunk_size * STRIDE
|
||||||
|
Q_chunk = Q_padded[:, :, q_start:q_end, :]
|
||||||
|
|
||||||
|
# 计算 chunk_start/chunk_end (与 xattn_estimate line 819-820 一致)
|
||||||
|
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size
|
||||||
|
chunk_end = chunk_start + reshaped_chunk_size
|
||||||
|
|
||||||
|
# flat_group_gemm_fuse_reshape (与 xattn_estimate line 810-822 一致)
|
||||||
|
attn_weights_slice = flat_group_gemm_fuse_reshape(
|
||||||
|
Q_chunk, K_padded, STRIDE,
|
||||||
chunk_start=chunk_start,
|
chunk_start=chunk_start,
|
||||||
chunk_end=chunk_end,
|
chunk_end=chunk_end,
|
||||||
real_q_len=real_q_len,
|
is_causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# softmax_fuse_block_sum (与 xattn_estimate line 827-836 一致)
|
||||||
|
attn_sum = softmax_fuse_block_sum(
|
||||||
|
attn_weights_slice,
|
||||||
|
reshaped_block_size,
|
||||||
|
min(4096, reshaped_block_size),
|
||||||
|
chunk_start=chunk_start,
|
||||||
|
chunk_end=chunk_end,
|
||||||
|
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
)
|
)
|
||||||
print(f"block_sums shape: {block_sums.shape}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 3.5 调用 find_blocks_chunked
|
# find_blocks_chunked (与 xattn_estimate line 887-895 一致)
|
||||||
print("3.5 调用 find_blocks_chunked...")
|
simple_mask = find_blocks_chunked(
|
||||||
mask_manual = find_blocks_chunked(
|
attn_sum,
|
||||||
block_sums,
|
current_index=k_block_num - q_block_num + chunk_idx * num_blocks_per_chunk,
|
||||||
current_index=0, # Q 从位置 0 开始 (因为 q_len = k_len)
|
|
||||||
threshold=THRESHOLD,
|
threshold=THRESHOLD,
|
||||||
num_to_choose=None,
|
num_to_choose=None,
|
||||||
decoding=False,
|
decoding=False,
|
||||||
mode="prefill",
|
mode="prefill",
|
||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
simple_mask_list.append(simple_mask)
|
||||||
|
print(f" Chunk {chunk_idx}: Q[{q_start}:{q_end}], attn shape={attn_weights_slice.shape}, mask shape={simple_mask.shape}")
|
||||||
|
|
||||||
|
# 3.4 合并所有 chunks 的 mask (与 xattn_estimate line 901-905 一致)
|
||||||
|
mask_manual = torch.cat(simple_mask_list, dim=2)
|
||||||
|
print(f"\n合并后 mask_manual shape: {mask_manual.shape}")
|
||||||
|
|
||||||
# 裁剪到有效区域
|
# 裁剪到有效区域
|
||||||
mask_manual_valid = mask_manual[:, :, :q_blocks, :k_blocks]
|
mask_manual_valid = mask_manual[:, :, :q_blocks, :k_blocks]
|
||||||
print(f"mask_manual shape (padded): {mask_manual.shape}")
|
|
||||||
print(f"mask_manual_valid shape: {mask_manual_valid.shape}")
|
print(f"mask_manual_valid shape: {mask_manual_valid.shape}")
|
||||||
|
|
||||||
# 计算 density
|
# 计算 density
|
||||||
selected_manual = (mask_manual_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
selected_manual = (mask_manual_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||||
total_manual = total_api # 相同的 total
|
total_manual = total_api
|
||||||
density_manual = selected_manual / total_manual
|
density_manual = selected_manual / total_manual
|
||||||
|
|
||||||
print(f"[底层 kernels] density: {density_manual:.6f} (selected={selected_manual}, total={total_manual})")
|
print(f"[底层 kernels] density: {density_manual:.6f} (selected={selected_manual}, total={total_manual})")
|
||||||
@@ -222,10 +218,10 @@ print()
|
|||||||
# 对比 mask
|
# 对比 mask
|
||||||
mask_diff = (mask_api_valid != mask_manual_valid).sum().item()
|
mask_diff = (mask_api_valid != mask_manual_valid).sum().item()
|
||||||
mask_total = mask_api_valid.numel()
|
mask_total = mask_api_valid.numel()
|
||||||
print(f"Mask 不同的元素数: {mask_diff} / {mask_total} ({100*mask_diff/mask_total:.4f}%)")
|
mask_diff_ratio = mask_diff / mask_total
|
||||||
|
print(f"Mask 不同的元素数: {mask_diff} / {mask_total} ({100*mask_diff_ratio:.4f}%)")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
mask_diff_ratio = mask_diff / mask_total
|
|
||||||
if abs(density_api - density_manual) < 1e-6 and mask_diff_ratio < 0.001:
|
if abs(density_api - density_manual) < 1e-6 and mask_diff_ratio < 0.001:
|
||||||
print("✅ xattn_estimate 与底层 kernels 对齐! (mask 差异 < 0.1%)")
|
print("✅ xattn_estimate 与底层 kernels 对齐! (mask 差异 < 0.1%)")
|
||||||
elif abs(density_api - density_manual) < 0.01:
|
elif abs(density_api - density_manual) < 0.01:
|
||||||
@@ -248,4 +244,4 @@ print(f"差异: {abs(saved_density - density_api):.6f}")
|
|||||||
if abs(saved_density - density_api) < 0.01:
|
if abs(saved_density - density_api) < 0.01:
|
||||||
print("✅ 与保存的 density 基本一致!")
|
print("✅ 与保存的 density 基本一致!")
|
||||||
else:
|
else:
|
||||||
print("⚠️ 与保存的 density 有差异,可能是 threshold 不同")
|
print("⚠️ 与保存的 density 有差异,可能是参数不同")
|
||||||
|
|||||||
Reference in New Issue
Block a user