feat: add KV chunking support for XAttention softmax kernels
Implement three-phase KV chunking for sparse attention estimation: 1. softmax_compute_partial_stats: compute (m, l) per KV chunk 2. merge_softmax_stats: merge partial stats on host 3. softmax_normalize_and_block_sum: normalize with global stats This allows computing sparse attention masks without storing full raw attention scores in GPU memory, reducing peak memory usage from O(q_len * k_full_len) to O(q_len * k_chunk_len). Key changes: - Add softmax_partial_stats_kernel with causal mask support - Add softmax_normalize_block_sum_kernel with kv_offset parameter - Add Python wrappers for new kernels - Update test script to validate KV chunking alignment - Add documentation for the new kernels Test results show perfect alignment with xattn_estimate API: - Density difference: 0.000000 - Mask difference: 0.0044% Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
"""
|
||||
Test: 验证 xattn_estimate 与底层 kernel 调用的一致性
|
||||
Test: 验证 xattn_estimate 与 KV chunking kernels 的一致性
|
||||
|
||||
使用真实 KV cache 数据,分别调用:
|
||||
使用真实 KV cache 数据,对比:
|
||||
1. xattn_estimate (高层 API)
|
||||
2. flat_group_gemm_fuse_reshape + softmax_fuse_block_sum + find_blocks_chunked (底层 kernels)
|
||||
2. 三阶段 KV chunking (softmax_compute_partial_stats + merge + normalize)
|
||||
|
||||
底层 kernels 按 Q 分 chunk,与 xattn_estimate 内部逻辑一致,减少峰值内存占用。
|
||||
三阶段 KV chunking 流程:
|
||||
1. softmax_compute_partial_stats: 计算每个 KV chunk 的 (m, l)
|
||||
2. merge_softmax_stats: Host 端合并所有 chunks 的 stats
|
||||
3. softmax_normalize_and_block_sum: 使用全局 stats 归一化
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
@@ -19,7 +22,9 @@ import math
|
||||
from nanovllm.ops.xattn import (
|
||||
xattn_estimate,
|
||||
flat_group_gemm_fuse_reshape,
|
||||
softmax_fuse_block_sum,
|
||||
softmax_compute_partial_stats,
|
||||
softmax_normalize_and_block_sum,
|
||||
merge_softmax_stats,
|
||||
find_blocks_chunked,
|
||||
)
|
||||
|
||||
@@ -93,17 +98,21 @@ print(f"[xattn_estimate] density: {density_api:.6f} (selected={selected_api}, to
|
||||
print()
|
||||
|
||||
# ============================================================
|
||||
# Step 3: 使用底层 kernels 手动计算 (按 Q 分 chunk)
|
||||
# Step 3: 三阶段 KV Chunking
|
||||
# ============================================================
|
||||
print("=" * 60)
|
||||
print("Step 3: 使用底层 kernels 手动计算 (按 Q 分 chunk)")
|
||||
print("Step 3: 三阶段 KV Chunking")
|
||||
print("=" * 60)
|
||||
print(" 1) 每个 KV chunk 计算 partial stats")
|
||||
print(" 2) Host 端合并 stats")
|
||||
print(" 3) 使用全局 stats 归一化并计算 block sums")
|
||||
print()
|
||||
|
||||
# 3.1 计算 padding 参数 (与 xattn_estimate 内部一致)
|
||||
# 计算 padding 参数
|
||||
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||
k_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
|
||||
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
|
||||
kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
|
||||
|
||||
k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
|
||||
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
|
||||
@@ -113,15 +122,12 @@ reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
|
||||
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
|
||||
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
|
||||
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
|
||||
kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE
|
||||
|
||||
print(f"原始 seq_len: {seq_len}")
|
||||
print(f"q_chunk_num: {q_chunk_num}, k_chunk_num: {k_chunk_num}")
|
||||
print(f"q_block_num: {q_block_num}, k_block_num: {k_block_num}")
|
||||
print(f"reshaped_chunk_size: {reshaped_chunk_size}, reshaped_block_size: {reshaped_block_size}")
|
||||
print(f"num_blocks_per_chunk: {num_blocks_per_chunk}")
|
||||
print(f"seq_len: {seq_len}, q_chunk_num: {q_chunk_num}, kv_chunk_num: {kv_chunk_num}")
|
||||
print()
|
||||
|
||||
# 3.2 Padding
|
||||
# Padding
|
||||
if k_num_to_pad > 0:
|
||||
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
|
||||
else:
|
||||
@@ -132,75 +138,100 @@ if q_num_to_pad > 0:
|
||||
else:
|
||||
Q_padded = Q
|
||||
|
||||
print(f"Q_padded shape: {Q_padded.shape}")
|
||||
print(f"K_padded shape: {K_padded.shape}")
|
||||
print()
|
||||
|
||||
# 3.3 按 Q chunk 处理 (与 xattn_estimate 内部逻辑一致)
|
||||
# Softmax scale
|
||||
norm = 1.0
|
||||
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
|
||||
|
||||
simple_mask_list = []
|
||||
|
||||
print(f"按 Q 分 {q_chunk_num} 个 chunk 处理...")
|
||||
for chunk_idx in range(q_chunk_num):
|
||||
# 提取当前 Q chunk (与 xattn_estimate line 811-816 一致)
|
||||
q_start = chunk_idx * reshaped_chunk_size * STRIDE
|
||||
for q_chunk_idx in range(q_chunk_num):
|
||||
q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
|
||||
q_end = q_start + reshaped_chunk_size * STRIDE
|
||||
Q_chunk = Q_padded[:, :, q_start:q_end, :]
|
||||
|
||||
# 计算 chunk_start/chunk_end (与 xattn_estimate line 819-820 一致)
|
||||
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size
|
||||
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
|
||||
chunk_end = chunk_start + reshaped_chunk_size
|
||||
|
||||
# flat_group_gemm_fuse_reshape (与 xattn_estimate line 810-822 一致)
|
||||
attn_weights_slice = flat_group_gemm_fuse_reshape(
|
||||
Q_chunk, K_padded, STRIDE,
|
||||
chunk_start=chunk_start,
|
||||
chunk_end=chunk_end,
|
||||
is_causal=True,
|
||||
)
|
||||
# 阶段 1: 每个 KV chunk 计算 partial stats 和 raw scores
|
||||
m_chunks = []
|
||||
l_chunks = []
|
||||
attn_weights_chunks = []
|
||||
|
||||
# softmax_fuse_block_sum (与 xattn_estimate line 827-836 一致)
|
||||
attn_sum = softmax_fuse_block_sum(
|
||||
attn_weights_slice,
|
||||
reshaped_block_size,
|
||||
min(4096, reshaped_block_size),
|
||||
chunk_start=chunk_start,
|
||||
chunk_end=chunk_end,
|
||||
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||
scale=scale,
|
||||
is_causal=True,
|
||||
)
|
||||
for kv_chunk_idx in range(kv_chunk_num):
|
||||
kv_start = kv_chunk_idx * CHUNK_SIZE
|
||||
kv_end = kv_start + CHUNK_SIZE
|
||||
K_chunk = K_padded[:, :, kv_start:kv_end, :]
|
||||
|
||||
# find_blocks_chunked (与 xattn_estimate line 887-895 一致)
|
||||
# KV offset in reshaped space
|
||||
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
|
||||
|
||||
# 计算 raw attention scores
|
||||
attn_weights_kv = flat_group_gemm_fuse_reshape(
|
||||
Q_chunk, K_chunk, STRIDE,
|
||||
chunk_start=chunk_start,
|
||||
chunk_end=chunk_end,
|
||||
is_causal=False, # K 不完整,不能在这里用 causal
|
||||
)
|
||||
attn_weights_chunks.append(attn_weights_kv)
|
||||
|
||||
# 计算 partial stats (带 causal mask)
|
||||
m_partial, l_partial = softmax_compute_partial_stats(
|
||||
attn_weights_kv,
|
||||
reshaped_block_size,
|
||||
min(4096, reshaped_block_size),
|
||||
scale,
|
||||
chunk_start=chunk_start,
|
||||
kv_offset=kv_offset_reshaped,
|
||||
is_causal=True,
|
||||
)
|
||||
m_chunks.append(m_partial)
|
||||
l_chunks.append(l_partial)
|
||||
|
||||
# 阶段 2: Host 端合并 stats
|
||||
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
|
||||
|
||||
# 阶段 3: 使用全局 stats 归一化并计算 block sums
|
||||
attn_sum_per_kv = []
|
||||
for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
|
||||
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
|
||||
attn_sum_kv = softmax_normalize_and_block_sum(
|
||||
attn_weights_kv,
|
||||
m_global,
|
||||
l_global,
|
||||
reshaped_block_size,
|
||||
min(4096, reshaped_block_size),
|
||||
chunk_start=chunk_start,
|
||||
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||
scale=scale,
|
||||
kv_offset=kv_offset_reshaped,
|
||||
is_causal=True,
|
||||
)
|
||||
attn_sum_per_kv.append(attn_sum_kv)
|
||||
|
||||
# 拼接各 KV chunk 的 block sums
|
||||
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
|
||||
|
||||
# 选择 blocks
|
||||
simple_mask = find_blocks_chunked(
|
||||
attn_sum,
|
||||
current_index=k_block_num - q_block_num + chunk_idx * num_blocks_per_chunk,
|
||||
attn_sum_concat,
|
||||
current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
|
||||
threshold=THRESHOLD,
|
||||
num_to_choose=None,
|
||||
decoding=False,
|
||||
mode="prefill",
|
||||
causal=True,
|
||||
)
|
||||
|
||||
simple_mask_list.append(simple_mask)
|
||||
print(f" Chunk {chunk_idx}: Q[{q_start}:{q_end}], attn shape={attn_weights_slice.shape}, mask shape={simple_mask.shape}")
|
||||
|
||||
# 3.4 合并所有 chunks 的 mask (与 xattn_estimate line 901-905 一致)
|
||||
mask_manual = torch.cat(simple_mask_list, dim=2)
|
||||
print(f"\n合并后 mask_manual shape: {mask_manual.shape}")
|
||||
print(f" Q chunk {q_chunk_idx}: merged {kv_chunk_num} KV chunks, attn_sum shape={attn_sum_concat.shape}")
|
||||
|
||||
# 裁剪到有效区域
|
||||
mask_manual_valid = mask_manual[:, :, :q_blocks, :k_blocks]
|
||||
print(f"mask_manual_valid shape: {mask_manual_valid.shape}")
|
||||
mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
|
||||
mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
|
||||
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
density_kv = selected_kv / total_api
|
||||
|
||||
# 计算 density
|
||||
selected_manual = (mask_manual_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
total_manual = total_api
|
||||
density_manual = selected_manual / total_manual
|
||||
|
||||
print(f"[底层 kernels] density: {density_manual:.6f} (selected={selected_manual}, total={total_manual})")
|
||||
print()
|
||||
print(f"[KV chunking] density: {density_kv:.6f} (selected={selected_kv}, total={total_api})")
|
||||
print()
|
||||
|
||||
# ============================================================
|
||||
@@ -209,39 +240,18 @@ print()
|
||||
print("=" * 60)
|
||||
print("Step 4: 对比结果")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"xattn_estimate density: {density_api:.6f}")
|
||||
print(f"底层 kernels density: {density_manual:.6f}")
|
||||
print(f"差异: {abs(density_api - density_manual):.6f}")
|
||||
print()
|
||||
|
||||
# 对比 mask
|
||||
mask_diff = (mask_api_valid != mask_manual_valid).sum().item()
|
||||
mask_total = mask_api_valid.numel()
|
||||
mask_diff_ratio = mask_diff / mask_total
|
||||
print(f"Mask 不同的元素数: {mask_diff} / {mask_total} ({100*mask_diff_ratio:.4f}%)")
|
||||
mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()
|
||||
|
||||
print("| 方法 | density | 与 API 差异 | Mask 差异 |")
|
||||
print("|------|---------|-------------|-----------|")
|
||||
print(f"| xattn_estimate API | {density_api:.6f} | - | - |")
|
||||
print(f"| KV chunking | {density_kv:.6f} | {abs(density_api - density_kv):.6f} | {100*mask_diff/mask_total:.4f}% |")
|
||||
print()
|
||||
|
||||
if abs(density_api - density_manual) < 1e-6 and mask_diff_ratio < 0.001:
|
||||
print("✅ xattn_estimate 与底层 kernels 对齐! (mask 差异 < 0.1%)")
|
||||
elif abs(density_api - density_manual) < 0.01:
|
||||
print("⚠️ Density 基本一致,但 mask 有差异")
|
||||
if abs(density_api - density_kv) < 1e-6 and mask_diff / mask_total < 0.001:
|
||||
print("test_xattn_estimate_alignment: PASSED")
|
||||
else:
|
||||
print("❌ Density 不一致,需要检查参数")
|
||||
|
||||
# ============================================================
|
||||
# Step 5: 额外验证 - 与保存的 density 对比
|
||||
# ============================================================
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("Step 5: 与保存的 density 对比")
|
||||
print("=" * 60)
|
||||
saved_density = data["density"]
|
||||
print(f"保存的 density: {saved_density:.6f}")
|
||||
print(f"xattn_estimate density: {density_api:.6f}")
|
||||
print(f"差异: {abs(saved_density - density_api):.6f}")
|
||||
|
||||
if abs(saved_density - density_api) < 0.01:
|
||||
print("✅ 与保存的 density 基本一致!")
|
||||
else:
|
||||
print("⚠️ 与保存的 density 有差异,可能是参数不同")
|
||||
print("test_xattn_estimate_alignment: FAILED")
|
||||
|
||||
Reference in New Issue
Block a user