feat: add xattn kernels test and update testing rules

- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape
  and softmax_fuse_block_sum Triton kernels with structured data
- Update testing.md with new test code style guidelines
- Update xattn.py and xattn_bsa.py with improvements

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-23 03:01:25 +08:00
parent d808970f2f
commit 999858e82f
4 changed files with 508 additions and 124 deletions

View File

@@ -419,7 +419,9 @@ def flat_group_gemm_fuse_reshape(
assert key_states.shape[1] == num_heads
assert key_states.shape[3] == head_dim
output = torch.empty(
# Use zeros instead of empty to handle causal early-exit in kernel
# (some blocks may not be written due to causal mask optimization)
output = torch.zeros(
(batch_size, num_heads, q_len // stride, kv_len // stride),
dtype=query_states.dtype,
device=query_states.device
@@ -1067,6 +1069,7 @@ def xattn_estimate_chunked(
)
# Softmax + block sum
# segment_size should match the standard xattn_estimate for consistency
attn_sum = softmax_fuse_block_sum(
attn_weights,
reshaped_block_size,
@@ -1082,6 +1085,14 @@ def xattn_estimate_chunked(
attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
else:
# PyTorch fallback implementation
# Match Triton kernel exactly for consistency
#
# Triton uses:
# 1. exp2 (base-2 exponential) for softmax
# 2. scale factor includes log2(e) = 1.4426950408889634
# 3. causal mask: q_pos >= k_pos (equivalent to q_pos + 1 > k_pos for integer positions)
# 4. chunk_start for global Q position tracking
# Reshape K: interleave positions and concatenate head dims
reshaped_key = torch.cat(
[(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
@@ -1093,49 +1104,58 @@ def xattn_estimate_chunked(
dim=-1,
)
# Use same scale as Triton: includes log2(e) for exp2 compatibility
# Triton: scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
# Convert to float32 for numerical stability (matching Triton)
reshaped_query_f32 = reshaped_query.to(torch.float32)
reshaped_key_f32 = reshaped_key.to(torch.float32)
# Compute attention weights: (B, H, q_len/stride, k_len/stride)
attn_weights = torch.matmul(
reshaped_query, reshaped_key.transpose(2, 3)
) / math.sqrt(head_dim) / stride / norm
reshaped_query_f32, reshaped_key_f32.transpose(2, 3)
) * scale
# Apply causal mask
# Apply causal mask (matching Triton's logic exactly)
if causal:
reshaped_q_positions = reshaped_q_len
causal_mask = torch.zeros(
(batch_size, num_heads, reshaped_q_positions, reshaped_k_len),
device=key_states.device,
dtype=attn_weights.dtype,
# Triton uses: offs_q = chunk_start + block_id * block_size + arange(0, block_size)
# chunk_start = q_start_block * reshaped_block_size
chunk_start = q_start_block * reshaped_block_size
# Create position indices in reshaped space
q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
k_positions = torch.arange(reshaped_k_len, device=attn_weights.device)
# Triton causal mask: q_pos >= k_pos
causal_mask = q_positions[:, None] >= k_positions[None, :] # (reshaped_q_len, reshaped_k_len)
# Apply causal mask: set future positions to -1e6 (matching Triton)
attn_weights = attn_weights.masked_fill(
~causal_mask.unsqueeze(0).unsqueeze(0), -1e6
)
# Mask out padding in K
if k_pad > 0:
causal_mask[:, :, :, -(k_pad // stride):] = float("-inf")
# Softmax using exp2 (matching Triton exactly)
# Triton: X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
# All computation in float32
attn_max = attn_weights.max(dim=-1, keepdim=True).values
attn_weights_shifted = attn_weights - attn_max
attn_exp2 = torch.exp2(attn_weights_shifted)
attn_sum_exp2 = attn_exp2.sum(dim=-1, keepdim=True)
attn_weights = attn_exp2 / attn_sum_exp2
# Mask out future positions
q_start_reshaped = q_start_pos // stride
for q_idx in range(reshaped_q_positions):
q_pos_reshaped = q_start_reshaped + q_idx
if q_pos_reshaped + 1 < reshaped_k_len:
causal_mask[:, :, q_idx, q_pos_reshaped + 1:] = float("-inf")
# Mask for valid Q positions (matching Triton's sum_mask)
# Triton: sum_mask = offs_q[:, None] < real_q_len
# real_q_len = chunk_start + valid_q_reshaped
chunk_start = q_start_block * reshaped_block_size
real_q_len = chunk_start + valid_q_reshaped
q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
valid_q_mask = q_positions < real_q_len # (reshaped_q_len,)
# Handle padding in Q
if q_pad > 0:
q_pad_reshaped = q_pad // stride
if q_pad_reshaped > 0:
causal_mask[:, :, -q_pad_reshaped:, :] = float("-inf")
# Zero out invalid Q positions
attn_weights = attn_weights * valid_q_mask.view(1, 1, -1, 1).float()
attn_weights = attn_weights + causal_mask
# Apply softmax
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
# Zero out padded Q positions
if q_pad > 0:
q_pad_reshaped = q_pad // stride
if q_pad_reshaped > 0:
attn_weights[:, :, -q_pad_reshaped:, :] = 0
# Aggregate to block level
# Aggregate to block level (keep in float32)
attn_sum = attn_weights.view(
batch_size,
num_heads,
@@ -1145,6 +1165,9 @@ def xattn_estimate_chunked(
reshaped_block_size,
).sum(dim=-1).sum(dim=-2)
# Convert back to input dtype for consistency
attn_sum = attn_sum.to(query_states.dtype)
# Find blocks that exceed threshold
simple_mask = find_blocks_chunked(
attn_sum,