feat: add xattn kernels test and update testing rules

- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape
  and softmax_fuse_block_sum Triton kernels with structured data
- Update testing.md with new test code style guidelines
- Update xattn.py and xattn_bsa.py with improvements

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-23 03:01:25 +08:00
parent d808970f2f
commit 999858e82f
4 changed files with 508 additions and 124 deletions

View File

@@ -0,0 +1,86 @@
"""
Test: XAttention Triton kernels
演示 XAttention 的两个核心 Triton kernel:
1. flat_group_gemm_fuse_reshape: 计算 stride reshape 后的 attention scores (反对角线求和)
2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
数据流:
Q, K [batch, heads, seq_len, head_dim]
↓ flat_group_gemm_fuse_reshape
attn_scores [batch, heads, seq_len/stride, seq_len/stride]
↓ softmax_fuse_block_sum
block_sums [batch, heads, q_blocks, k_blocks]
"""
import torch
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
# ============================================================
# 参数配置
# ============================================================
seq_len = 512 # Triton 要求 seq_len >= stride * BLOCK_M = 4 * 128 = 512
head_dim = 128
stride = 4
block_size = 128 # softmax block size (in reshaped space)
segment_size = 128 # Triton kernel 要求 segment_size >= block_size
# ============================================================
# 构造输入: 偶数位置=1, 奇数位置=2
# ============================================================
Q = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(seq_len):
if i % 2 == 0:
Q[0, 0, i, :] = 1
K[0, 0, i, :] = 1
else:
Q[0, 0, i, :] = 2
K[0, 0, i, :] = 2
# ============================================================
# Step 1: flat_group_gemm_fuse_reshape
# ============================================================
attn_scores = flat_group_gemm_fuse_reshape(
Q, K, stride,
chunk_start=0,
chunk_end=seq_len // stride,
is_causal=False
)
# 验证: 反对角线求和
# 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4
# 反对角线有 stride/2 对,再乘以 head_dim
expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim
actual_gemm = attn_scores[0, 0, 0, 0].item()
assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}"
# ============================================================
# Step 2: softmax_fuse_block_sum
# ============================================================
reshaped_len = seq_len // stride
scale = 1.4426950408889634 # log2(e) for exp2
block_sums = softmax_fuse_block_sum(
attn_scores,
block_size,
segment_size,
chunk_start=0,
chunk_end=reshaped_len,
real_q_len=reshaped_len,
scale=scale,
is_causal=False
)
# 验证: 每个 block 的 softmax 结果求和
# 所有 attn_scores 相同 → softmax 均匀分布 → block_sum = block_size^2 / reshaped_len
expected_sum = block_size * block_size / reshaped_len
actual_sum = block_sums[0, 0, 0, 0].item()
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"
print("test_xattn_kernels: PASSED")