diff --git a/tests/test_xattn_estimate_alignment.py b/tests/test_xattn_estimate_alignment.py index ae6487b..021d3ae 100644 --- a/tests/test_xattn_estimate_alignment.py +++ b/tests/test_xattn_estimate_alignment.py @@ -5,7 +5,7 @@ Test: 验证 xattn_estimate 与底层 kernel 调用的一致性 1. xattn_estimate (高层 API) 2. flat_group_gemm_fuse_reshape + softmax_fuse_block_sum + find_blocks_chunked (底层 kernels) -验证两种方式的 density 是否一致。 +底层 kernels 按 Q 分 chunk,与 xattn_estimate 内部逻辑一致,减少峰值内存占用。 Usage: CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \ @@ -21,7 +21,6 @@ from nanovllm.ops.xattn import ( flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked, - compute_sparsity, ) # ============================================================ @@ -29,7 +28,7 @@ from nanovllm.ops.xattn import ( # ============================================================ DATA_FILE = "/home/zijie/Code/nano-vllm/results/kvcache/qkv_32485.pt" BSA_BLOCK_SIZE = 128 -# STRIDE 和 THRESHOLD 从保存的数据中读取 +CHUNK_SIZE = 16384 # xattn_estimate 默认值 USE_SAVED_PARAMS = True # 设为 False 则使用默认值 device = "cuda" @@ -58,7 +57,7 @@ else: print(f"Q shape: {Q.shape}") print(f"K shape: {K.shape}") print(f"Data layer_id: {data['layer_id']}, saved density: {data['density']:.4f}") -print(f"使用参数: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}") +print(f"使用参数: STRIDE={STRIDE}, THRESHOLD={THRESHOLD}, CHUNK_SIZE={CHUNK_SIZE}") print() # ============================================================ @@ -68,16 +67,12 @@ print("=" * 60) print("Step 2: 调用 xattn_estimate (高层 API)") print("=" * 60) -# 使用与底层计算一致的 chunk_size (seq_len 对齐到 alignment) -alignment = STRIDE * 128 -chunk_size_aligned = ((seq_len + alignment - 1) // alignment) * alignment - attn_sums_api, mask_api = xattn_estimate( Q, K, block_size=BSA_BLOCK_SIZE, stride=STRIDE, threshold=THRESHOLD, - chunk_size=chunk_size_aligned, # 保持一致 + chunk_size=CHUNK_SIZE, causal=True, ) @@ -98,110 +93,111 @@ print(f"[xattn_estimate] density: {density_api:.6f} 
(selected={selected_api}, total={total_api})") print() # ============================================================ -# Step 3: 使用底层 kernels 手动计算 +# Step 3: 使用底层 kernels 手动计算 (按 Q 分 chunk) # ============================================================ print("=" * 60) -print("Step 3: 使用底层 kernels 手动计算") +print("Step 3: 使用底层 kernels 手动计算 (按 Q 分 chunk)") print("=" * 60) -# 3.1 Padding -BLOCK_M = 128 -BLOCK_N = 128 -alignment = STRIDE * BLOCK_M -k_alignment = STRIDE * BLOCK_N +# 3.1 计算 padding 参数 (与 xattn_estimate 内部一致) +k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len +q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len +k_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE +q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE -padded_q_len = ((seq_len + alignment - 1) // alignment) * alignment -padded_k_len = ((seq_len + k_alignment - 1) // k_alignment) * k_alignment +k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE +q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE + +reshaped_chunk_size = CHUNK_SIZE // STRIDE +reshaped_block_size = BSA_BLOCK_SIZE // STRIDE +k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE +k_reshaped_num_to_pad = k_num_to_pad // STRIDE +num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size print(f"原始 seq_len: {seq_len}") -print(f"Padded Q len: {padded_q_len}") -print(f"Padded K len: {padded_k_len}") +print(f"q_chunk_num: {q_chunk_num}, k_chunk_num: {k_chunk_num}") +print(f"q_block_num: {q_block_num}, k_block_num: {k_block_num}") +print(f"reshaped_chunk_size: {reshaped_chunk_size}, reshaped_block_size: {reshaped_block_size}") +print(f"num_blocks_per_chunk: {num_blocks_per_chunk}") +print() -if padded_q_len != seq_len: - Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, padded_q_len - seq_len), value=0) -else: - Q_padded = Q - -if padded_k_len != seq_len: - K_padded = torch.nn.functional.pad(K, (0, 0, 0, padded_k_len - seq_len), value=0) +# 3.2 Padding +if k_num_to_pad > 0: + 
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0) else: K_padded = K +if q_num_to_pad > 0: + Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0) +else: + Q_padded = Q + print(f"Q_padded shape: {Q_padded.shape}") print(f"K_padded shape: {K_padded.shape}") print() -# 3.2 计算 reshaped 维度 -q_reshaped_len = padded_q_len // STRIDE -k_reshaped_len = padded_k_len // STRIDE -reshaped_block_size = BSA_BLOCK_SIZE // STRIDE - -q_block_num = padded_q_len // BSA_BLOCK_SIZE -k_block_num = padded_k_len // BSA_BLOCK_SIZE - -print(f"q_reshaped_len: {q_reshaped_len}") -print(f"k_reshaped_len: {k_reshaped_len}") -print(f"reshaped_block_size: {reshaped_block_size}") -print(f"q_block_num: {q_block_num}, k_block_num: {k_block_num}") -print() - -# 3.3 调用 flat_group_gemm_fuse_reshape -print("3.3 调用 flat_group_gemm_fuse_reshape...") -chunk_start = (k_block_num - q_block_num) * reshaped_block_size # 对于 q_len=k_len, offset=0 -chunk_end = chunk_start + q_reshaped_len - -attn_scores = flat_group_gemm_fuse_reshape( - Q_padded, K_padded, STRIDE, - chunk_start=chunk_start, - chunk_end=chunk_end, - is_causal=True, -) -print(f"attn_scores shape: {attn_scores.shape}") -print() - -# 3.4 调用 softmax_fuse_block_sum -print("3.4 调用 softmax_fuse_block_sum...") +# 3.3 按 Q chunk 处理 (与 xattn_estimate 内部逻辑一致) norm = 1.0 scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm -segment_size = min(4096, reshaped_block_size) -# 计算 real_q_len (排除 padding) -k_reshaped_num_to_pad = (padded_k_len - seq_len) // STRIDE -real_q_len = k_reshaped_len - k_reshaped_num_to_pad +simple_mask_list = [] -block_sums = softmax_fuse_block_sum( - attn_scores, - reshaped_block_size, - segment_size, - chunk_start=chunk_start, - chunk_end=chunk_end, - real_q_len=real_q_len, - scale=scale, - is_causal=True, -) -print(f"block_sums shape: {block_sums.shape}") -print() +print(f"按 Q 分 {q_chunk_num} 个 chunk 处理...") +for chunk_idx in range(q_chunk_num): + # 提取当前 Q chunk (与 xattn_estimate line 
811-816 一致) + q_start = chunk_idx * reshaped_chunk_size * STRIDE + q_end = q_start + reshaped_chunk_size * STRIDE + Q_chunk = Q_padded[:, :, q_start:q_end, :] -# 3.5 调用 find_blocks_chunked -print("3.5 调用 find_blocks_chunked...") -mask_manual = find_blocks_chunked( - block_sums, - current_index=0, # Q 从位置 0 开始 (因为 q_len = k_len) - threshold=THRESHOLD, - num_to_choose=None, - decoding=False, - mode="prefill", - causal=True, -) + # 计算 chunk_start/chunk_end (与 xattn_estimate line 819-820 一致) + chunk_start = (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + chunk_end = chunk_start + reshaped_chunk_size + + # flat_group_gemm_fuse_reshape (与 xattn_estimate line 810-822 一致) + attn_weights_slice = flat_group_gemm_fuse_reshape( + Q_chunk, K_padded, STRIDE, + chunk_start=chunk_start, + chunk_end=chunk_end, + is_causal=True, + ) + + # softmax_fuse_block_sum (与 xattn_estimate line 827-836 一致) + attn_sum = softmax_fuse_block_sum( + attn_weights_slice, + reshaped_block_size, + min(4096, reshaped_block_size), + chunk_start=chunk_start, + chunk_end=chunk_end, + real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad, + scale=scale, + is_causal=True, + ) + + # find_blocks_chunked (与 xattn_estimate line 887-895 一致) + simple_mask = find_blocks_chunked( + attn_sum, + current_index=k_block_num - q_block_num + chunk_idx * num_blocks_per_chunk, + threshold=THRESHOLD, + num_to_choose=None, + decoding=False, + mode="prefill", + causal=True, + ) + + simple_mask_list.append(simple_mask) + print(f" Chunk {chunk_idx}: Q[{q_start}:{q_end}], attn shape={attn_weights_slice.shape}, mask shape={simple_mask.shape}") + +# 3.4 合并所有 chunks 的 mask (与 xattn_estimate line 901-905 一致) +mask_manual = torch.cat(simple_mask_list, dim=2) +print(f"\n合并后 mask_manual shape: {mask_manual.shape}") # 裁剪到有效区域 mask_manual_valid = mask_manual[:, :, :q_blocks, :k_blocks] -print(f"mask_manual shape (padded): {mask_manual.shape}") -print(f"mask_manual_valid shape: 
{mask_manual_valid.shape}") +print(f"mask_manual_valid shape: {mask_manual_valid.shape}") # 计算 density selected_manual = (mask_manual_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() -total_manual = total_api # 相同的 total +total_manual = total_api density_manual = selected_manual / total_manual print(f"[底层 kernels] density: {density_manual:.6f} (selected={selected_manual}, total={total_manual})") @@ -222,10 +218,10 @@ print() # 对比 mask mask_diff = (mask_api_valid != mask_manual_valid).sum().item() mask_total = mask_api_valid.numel() -print(f"Mask 不同的元素数: {mask_diff} / {mask_total} ({100*mask_diff/mask_total:.4f}%)") +mask_diff_ratio = mask_diff / mask_total +print(f"Mask 不同的元素数: {mask_diff} / {mask_total} ({100*mask_diff_ratio:.4f}%)") print() -mask_diff_ratio = mask_diff / mask_total if abs(density_api - density_manual) < 1e-6 and mask_diff_ratio < 0.001: print("✅ xattn_estimate 与底层 kernels 对齐! (mask 差异 < 0.1%)") elif abs(density_api - density_manual) < 0.01: @@ -248,4 +244,4 @@ print(f"差异: {abs(saved_density - density_api):.6f}") if abs(saved_density - density_api) < 0.01: print("✅ 与保存的 density 基本一致!") else: - print("⚠️ 与保存的 density 有差异,可能是 threshold 不同") + print("⚠️ 与保存的 density 有差异,可能是参数不同")