feat: add xattn kernels test and update testing rules
- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape and softmax_fuse_block_sum Triton kernels with structured data - Update testing.md with new test code style guidelines - Update xattn.py and xattn_bsa.py with improvements Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -419,7 +419,9 @@ def flat_group_gemm_fuse_reshape(
|
||||
assert key_states.shape[1] == num_heads
|
||||
assert key_states.shape[3] == head_dim
|
||||
|
||||
output = torch.empty(
|
||||
# Use zeros instead of empty to handle causal early-exit in kernel
|
||||
# (some blocks may not be written due to causal mask optimization)
|
||||
output = torch.zeros(
|
||||
(batch_size, num_heads, q_len // stride, kv_len // stride),
|
||||
dtype=query_states.dtype,
|
||||
device=query_states.device
|
||||
@@ -1067,6 +1069,7 @@ def xattn_estimate_chunked(
|
||||
)
|
||||
|
||||
# Softmax + block sum
|
||||
# segment_size should match the standard xattn_estimate for consistency
|
||||
attn_sum = softmax_fuse_block_sum(
|
||||
attn_weights,
|
||||
reshaped_block_size,
|
||||
@@ -1082,6 +1085,14 @@ def xattn_estimate_chunked(
|
||||
attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
|
||||
else:
|
||||
# PyTorch fallback implementation
|
||||
# Match Triton kernel exactly for consistency
|
||||
#
|
||||
# Triton uses:
|
||||
# 1. exp2 (base-2 exponential) for softmax
|
||||
# 2. scale factor includes log2(e) = 1.4426950408889634
|
||||
# 3. causal mask: q_pos >= k_pos (not q_pos + 1 > k_pos)
|
||||
# 4. chunk_start for global Q position tracking
|
||||
|
||||
# Reshape K: interleave positions and concatenate head dims
|
||||
reshaped_key = torch.cat(
|
||||
[(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
|
||||
@@ -1093,49 +1104,58 @@ def xattn_estimate_chunked(
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Use same scale as Triton: includes log2(e) for exp2 compatibility
|
||||
# Triton: scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
|
||||
scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm
|
||||
|
||||
# Convert to float32 for numerical stability (matching Triton)
|
||||
reshaped_query_f32 = reshaped_query.to(torch.float32)
|
||||
reshaped_key_f32 = reshaped_key.to(torch.float32)
|
||||
|
||||
# Compute attention weights: (B, H, q_len/stride, k_len/stride)
|
||||
attn_weights = torch.matmul(
|
||||
reshaped_query, reshaped_key.transpose(2, 3)
|
||||
) / math.sqrt(head_dim) / stride / norm
|
||||
reshaped_query_f32, reshaped_key_f32.transpose(2, 3)
|
||||
) * scale
|
||||
|
||||
# Apply causal mask
|
||||
# Apply causal mask (matching Triton's logic exactly)
|
||||
if causal:
|
||||
reshaped_q_positions = reshaped_q_len
|
||||
causal_mask = torch.zeros(
|
||||
(batch_size, num_heads, reshaped_q_positions, reshaped_k_len),
|
||||
device=key_states.device,
|
||||
dtype=attn_weights.dtype,
|
||||
# Triton uses: offs_q = chunk_start + block_id * block_size + arange(0, block_size)
|
||||
# chunk_start = q_start_block * reshaped_block_size
|
||||
chunk_start = q_start_block * reshaped_block_size
|
||||
|
||||
# Create position indices in reshaped space
|
||||
q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
|
||||
k_positions = torch.arange(reshaped_k_len, device=attn_weights.device)
|
||||
|
||||
# Triton causal mask: q_pos >= k_pos
|
||||
causal_mask = q_positions[:, None] >= k_positions[None, :] # (reshaped_q_len, reshaped_k_len)
|
||||
|
||||
# Apply causal mask: set future positions to -1e6 (matching Triton)
|
||||
attn_weights = attn_weights.masked_fill(
|
||||
~causal_mask.unsqueeze(0).unsqueeze(0), -1e6
|
||||
)
|
||||
|
||||
# Mask out padding in K
|
||||
if k_pad > 0:
|
||||
causal_mask[:, :, :, -(k_pad // stride):] = float("-inf")
|
||||
# Softmax using exp2 (matching Triton exactly)
|
||||
# Triton: X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||
# All computation in float32
|
||||
attn_max = attn_weights.max(dim=-1, keepdim=True).values
|
||||
attn_weights_shifted = attn_weights - attn_max
|
||||
attn_exp2 = torch.exp2(attn_weights_shifted)
|
||||
attn_sum_exp2 = attn_exp2.sum(dim=-1, keepdim=True)
|
||||
attn_weights = attn_exp2 / attn_sum_exp2
|
||||
|
||||
# Mask out future positions
|
||||
q_start_reshaped = q_start_pos // stride
|
||||
for q_idx in range(reshaped_q_positions):
|
||||
q_pos_reshaped = q_start_reshaped + q_idx
|
||||
if q_pos_reshaped + 1 < reshaped_k_len:
|
||||
causal_mask[:, :, q_idx, q_pos_reshaped + 1:] = float("-inf")
|
||||
# Mask for valid Q positions (matching Triton's sum_mask)
|
||||
# Triton: sum_mask = offs_q[:, None] < real_q_len
|
||||
# real_q_len = chunk_start + valid_q_reshaped
|
||||
chunk_start = q_start_block * reshaped_block_size
|
||||
real_q_len = chunk_start + valid_q_reshaped
|
||||
q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
|
||||
valid_q_mask = q_positions < real_q_len # (reshaped_q_len,)
|
||||
|
||||
# Handle padding in Q
|
||||
if q_pad > 0:
|
||||
q_pad_reshaped = q_pad // stride
|
||||
if q_pad_reshaped > 0:
|
||||
causal_mask[:, :, -q_pad_reshaped:, :] = float("-inf")
|
||||
# Zero out invalid Q positions
|
||||
attn_weights = attn_weights * valid_q_mask.view(1, 1, -1, 1).float()
|
||||
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# Apply softmax
|
||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
|
||||
# Zero out padded Q positions
|
||||
if q_pad > 0:
|
||||
q_pad_reshaped = q_pad // stride
|
||||
if q_pad_reshaped > 0:
|
||||
attn_weights[:, :, -q_pad_reshaped:, :] = 0
|
||||
|
||||
# Aggregate to block level
|
||||
# Aggregate to block level (keep in float32)
|
||||
attn_sum = attn_weights.view(
|
||||
batch_size,
|
||||
num_heads,
|
||||
@@ -1145,6 +1165,9 @@ def xattn_estimate_chunked(
|
||||
reshaped_block_size,
|
||||
).sum(dim=-1).sum(dim=-2)
|
||||
|
||||
# Convert back to input dtype for consistency
|
||||
attn_sum = attn_sum.to(query_states.dtype)
|
||||
|
||||
# Find blocks that exceed threshold
|
||||
simple_mask = find_blocks_chunked(
|
||||
attn_sum,
|
||||
|
||||
Reference in New Issue
Block a user