feat: add KV chunking support for XAttention softmax kernels
Implement three-phase KV chunking for sparse attention estimation: 1. softmax_compute_partial_stats: compute (m, l) per KV chunk 2. merge_softmax_stats: merge partial stats on host 3. softmax_normalize_and_block_sum: normalize with global stats This allows computing sparse attention masks without storing full raw attention scores in GPU memory, reducing peak memory usage from O(q_len * k_full_len) to O(q_len * k_chunk_len). Key changes: - Add softmax_partial_stats_kernel with causal mask support - Add softmax_normalize_block_sum_kernel with kv_offset parameter - Add Python wrappers for new kernels - Update test script to validate KV chunking alignment - Add documentation for the new kernels Test results show perfect alignment with xattn_estimate API: - Density difference: 0.000000 - Mask difference: 0.0044% Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -218,6 +218,209 @@ def softmax_fuse_block_sum_kernel_non_causal(
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
|
||||
# ============================================================
|
||||
# KV Chunking Support Kernels
|
||||
# ============================================================
|
||||
|
||||
@triton.jit
|
||||
def softmax_partial_stats_kernel(
|
||||
In,
|
||||
M_out, # max per row
|
||||
L_out, # sum per row (normalized by M_out)
|
||||
scale,
|
||||
input_stride_0,
|
||||
input_stride_1,
|
||||
input_stride_2,
|
||||
stats_stride_0,
|
||||
stats_stride_1,
|
||||
k_len,
|
||||
chunk_start, # Q start position (for causal)
|
||||
kv_offset, # KV chunk offset (for causal)
|
||||
segment_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
is_causal: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute partial softmax statistics for a KV chunk.
|
||||
|
||||
For each query row, computes:
|
||||
- m: max value in this chunk
|
||||
- l: sum of exp(x - m) in this chunk
|
||||
|
||||
These can be merged across chunks using online softmax formula.
|
||||
|
||||
Input shape: [batch, heads, q_len, k_chunk_len]
|
||||
Output shapes: M[batch, heads, q_len], L[batch, heads, q_len]
|
||||
"""
|
||||
block_id = tl.program_id(0)
|
||||
head_id = tl.program_id(1)
|
||||
batch_id = tl.program_id(2)
|
||||
|
||||
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||
offs_k = tl.arange(0, segment_size)
|
||||
|
||||
num_iters = k_len // segment_size
|
||||
|
||||
# For causal: compute boundary
|
||||
if is_causal:
|
||||
# causal boundary: Q position where this KV chunk starts to be valid
|
||||
# Q[i] can attend K[j] if i >= j
|
||||
# For KV chunk at kv_offset, Q[i] can attend if i >= kv_offset
|
||||
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
|
||||
num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
|
||||
num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
|
||||
else:
|
||||
num_iters_before_causal = num_iters
|
||||
|
||||
# Online softmax state
|
||||
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([block_size], dtype=tl.float32)
|
||||
|
||||
# Input pointer
|
||||
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||
|
||||
# Compute max and sum (before causal boundary)
|
||||
for iter in range(0, num_iters_before_causal):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
m_local = tl.max(X, 1)
|
||||
m_new = tl.maximum(m_i, m_local)
|
||||
alpha = tl.math.exp2(m_i - m_new)
|
||||
|
||||
X = X - m_new[:, None]
|
||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||
l_i = l_i * alpha + l_local
|
||||
|
||||
m_i = m_new
|
||||
|
||||
# Handle causal boundary
|
||||
if is_causal:
|
||||
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||
if iter < num_iters:
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
# causal mask: Q[i] >= K[j] + kv_offset
|
||||
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
|
||||
X = tl.where(mask, X, -1.0e6)
|
||||
m_local = tl.max(X, 1)
|
||||
m_new = tl.maximum(m_i, m_local)
|
||||
alpha = tl.math.exp2(m_i - m_new)
|
||||
|
||||
X = X - m_new[:, None]
|
||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||
l_i = l_i * alpha + l_local
|
||||
|
||||
m_i = m_new
|
||||
|
||||
# Output pointers
|
||||
m_ptr = M_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
|
||||
l_ptr = L_out + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
|
||||
|
||||
offs = tl.arange(0, block_size)
|
||||
tl.store(m_ptr + offs, m_i.to(M_out.type.element_ty))
|
||||
tl.store(l_ptr + offs, l_i.to(L_out.type.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def softmax_normalize_block_sum_kernel(
|
||||
In,
|
||||
Out,
|
||||
M_global, # global max per row
|
||||
L_global, # global sum per row
|
||||
scale,
|
||||
input_stride_0,
|
||||
input_stride_1,
|
||||
input_stride_2,
|
||||
output_stride_0,
|
||||
output_stride_1,
|
||||
output_stride_2,
|
||||
stats_stride_0,
|
||||
stats_stride_1,
|
||||
real_q_len,
|
||||
k_len,
|
||||
chunk_start,
|
||||
kv_offset, # KV chunk offset (for causal)
|
||||
segment_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
is_causal: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Normalize with global stats and compute block sums for a KV chunk.
|
||||
|
||||
Uses pre-computed global m and l to correctly normalize softmax
|
||||
across all KV chunks.
|
||||
|
||||
Input shape: [batch, heads, q_len, k_chunk_len]
|
||||
Output shape: [batch, heads, q_blocks, k_chunk_blocks]
|
||||
"""
|
||||
block_id = tl.program_id(0)
|
||||
head_id = tl.program_id(1)
|
||||
batch_id = tl.program_id(2)
|
||||
|
||||
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||
offs_k = tl.arange(0, segment_size)
|
||||
|
||||
num_iters = k_len // segment_size
|
||||
|
||||
# For causal: compute boundary
|
||||
if is_causal:
|
||||
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1 - kv_offset) // segment_size
|
||||
num_iters_before_causal = tl.minimum(num_iters_before_causal, num_iters)
|
||||
num_iters_before_causal = tl.maximum(num_iters_before_causal, 0)
|
||||
else:
|
||||
num_iters_before_causal = num_iters
|
||||
|
||||
# Load global stats
|
||||
m_ptr = M_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
|
||||
l_ptr = L_global + batch_id * stats_stride_0 + head_id * stats_stride_1 + block_id * block_size
|
||||
|
||||
offs = tl.arange(0, block_size)
|
||||
m_global = tl.load(m_ptr + offs).to(tl.float32)
|
||||
l_global = tl.load(l_ptr + offs).to(tl.float32)
|
||||
# Handle l_global = 0 (when all positions are masked)
|
||||
l_global_safe = tl.where(l_global > 0, l_global, 1.0)
|
||||
l_global_inv = 1.0 / l_global_safe
|
||||
|
||||
# Input pointer
|
||||
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||
|
||||
# Output pointer
|
||||
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
||||
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
||||
|
||||
sum_mask = offs_q[:, None] < real_q_len
|
||||
|
||||
# Normalize and compute block sums (before causal boundary)
|
||||
for iter in range(0, num_iters_before_causal):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
|
||||
X = tl.where(sum_mask, X, 0)
|
||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||
X = tl.sum(X, 2)
|
||||
X = tl.sum(X, 0)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
# Handle causal boundary
|
||||
if is_causal:
|
||||
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||
if iter < num_iters:
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
# causal mask: Q[i] >= K[j] + kv_offset
|
||||
causal_mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size + kv_offset)
|
||||
X = tl.where(causal_mask, X, -1.0e6)
|
||||
X = tl.exp2(X - m_global[:, None]) * l_global_inv[:, None]
|
||||
X = tl.where(sum_mask, X, 0)
|
||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||
X = tl.sum(X, 2)
|
||||
X = tl.sum(X, 0)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
# Zero out future blocks
|
||||
for iter in range(num_iters_before_causal + 1, num_iters):
|
||||
X = tl.zeros([segment_size // block_size], dtype=tl.float32)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def flat_group_gemm_fuse_reshape_kernel(
|
||||
Q, K, Out,
|
||||
@@ -380,6 +583,194 @@ def softmax_fuse_block_sum(
|
||||
return output
|
||||
|
||||
|
||||
def softmax_compute_partial_stats(
|
||||
attn_weights_slice: torch.Tensor,
|
||||
reshaped_block_size: int,
|
||||
segment_size: int,
|
||||
scale: float,
|
||||
chunk_start: int = 0,
|
||||
kv_offset: int = 0,
|
||||
is_causal: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute partial softmax statistics for a KV chunk.
|
||||
|
||||
This is the first step for KV-chunked softmax computation.
|
||||
For each query row, computes:
|
||||
- m: max value in this chunk
|
||||
- l: sum of exp(x - m) in this chunk
|
||||
|
||||
These partial stats can be merged across KV chunks using
|
||||
`merge_softmax_stats()`, then used with `softmax_normalize_and_block_sum()`.
|
||||
|
||||
Args:
|
||||
attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
|
||||
reshaped_block_size: Block size in reshaped space
|
||||
segment_size: Processing segment size
|
||||
scale: Softmax scale factor
|
||||
chunk_start: Q chunk start position (in reshaped space)
|
||||
kv_offset: KV chunk offset (in reshaped space, for causal masking)
|
||||
is_causal: Whether to apply causal masking
|
||||
|
||||
Returns:
|
||||
Tuple of (m, l) where:
|
||||
- m: [batch, heads, q_len] max values per row
|
||||
- l: [batch, heads, q_len] partial sums per row
|
||||
"""
|
||||
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
|
||||
|
||||
assert q_len % reshaped_block_size == 0
|
||||
assert k_len % segment_size == 0
|
||||
assert attn_weights_slice.stride(-1) == 1
|
||||
|
||||
m_out = torch.empty(
|
||||
(batch_size, num_heads, q_len),
|
||||
dtype=torch.float32,
|
||||
device=attn_weights_slice.device
|
||||
)
|
||||
l_out = torch.empty(
|
||||
(batch_size, num_heads, q_len),
|
||||
dtype=torch.float32,
|
||||
device=attn_weights_slice.device
|
||||
)
|
||||
|
||||
grid = (q_len // reshaped_block_size, num_heads, batch_size)
|
||||
|
||||
softmax_partial_stats_kernel[grid](
|
||||
attn_weights_slice,
|
||||
m_out,
|
||||
l_out,
|
||||
scale,
|
||||
attn_weights_slice.stride(0),
|
||||
attn_weights_slice.stride(1),
|
||||
attn_weights_slice.stride(2),
|
||||
m_out.stride(0),
|
||||
m_out.stride(1),
|
||||
k_len,
|
||||
chunk_start,
|
||||
kv_offset,
|
||||
segment_size,
|
||||
reshaped_block_size,
|
||||
is_causal,
|
||||
)
|
||||
|
||||
return m_out, l_out
|
||||
|
||||
|
||||
def merge_softmax_stats(
|
||||
m_chunks: list,
|
||||
l_chunks: list,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Merge partial softmax statistics from multiple KV chunks.
|
||||
|
||||
Uses the online softmax merging formula:
|
||||
m_new = max(m1, m2)
|
||||
l_new = l1 * exp(m1 - m_new) + l2 * exp(m2 - m_new)
|
||||
|
||||
Args:
|
||||
m_chunks: List of max tensors [batch, heads, q_len] from each chunk
|
||||
l_chunks: List of sum tensors [batch, heads, q_len] from each chunk
|
||||
|
||||
Returns:
|
||||
Tuple of (m_global, l_global) with same shape as inputs
|
||||
"""
|
||||
assert len(m_chunks) == len(l_chunks)
|
||||
assert len(m_chunks) > 0
|
||||
|
||||
# Use log2 scale to match kernel (exp2)
|
||||
LOG2E = 1.4426950408889634
|
||||
|
||||
m_global = m_chunks[0].clone()
|
||||
l_global = l_chunks[0].clone()
|
||||
|
||||
for i in range(1, len(m_chunks)):
|
||||
m_chunk = m_chunks[i]
|
||||
l_chunk = l_chunks[i]
|
||||
|
||||
m_new = torch.maximum(m_global, m_chunk)
|
||||
# exp2(m - m_new) = 2^(m - m_new)
|
||||
l_global = l_global * torch.pow(2.0, m_global - m_new) + l_chunk * torch.pow(2.0, m_chunk - m_new)
|
||||
m_global = m_new
|
||||
|
||||
return m_global, l_global
|
||||
|
||||
|
||||
def softmax_normalize_and_block_sum(
|
||||
attn_weights_slice: torch.Tensor,
|
||||
m_global: torch.Tensor,
|
||||
l_global: torch.Tensor,
|
||||
reshaped_block_size: int,
|
||||
segment_size: int,
|
||||
chunk_start: int,
|
||||
real_q_len: int,
|
||||
scale: float,
|
||||
kv_offset: int = 0,
|
||||
is_causal: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Normalize with global stats and compute block sums for a KV chunk.
|
||||
|
||||
This is the second step for KV-chunked softmax computation.
|
||||
Uses pre-computed global m and l (from `merge_softmax_stats()`)
|
||||
to correctly normalize softmax values and compute block sums.
|
||||
|
||||
Args:
|
||||
attn_weights_slice: Raw attention scores [batch, heads, q_len, k_chunk_len]
|
||||
m_global: Global max values [batch, heads, q_len]
|
||||
l_global: Global sum values [batch, heads, q_len]
|
||||
reshaped_block_size: Block size in reshaped space
|
||||
segment_size: Processing segment size
|
||||
chunk_start: Start position for this chunk (for masking)
|
||||
real_q_len: Actual Q length (before padding)
|
||||
scale: Softmax scale factor
|
||||
kv_offset: KV chunk offset (in reshaped space, for causal masking)
|
||||
is_causal: Whether to apply causal masking
|
||||
|
||||
Returns:
|
||||
Block-level attention sums [batch, heads, q_blocks, k_chunk_blocks]
|
||||
"""
|
||||
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
|
||||
|
||||
assert q_len % reshaped_block_size == 0
|
||||
assert k_len % segment_size == 0
|
||||
assert segment_size % reshaped_block_size == 0
|
||||
assert attn_weights_slice.stride(-1) == 1
|
||||
|
||||
output = torch.empty(
|
||||
(batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
|
||||
dtype=attn_weights_slice.dtype,
|
||||
device=attn_weights_slice.device
|
||||
)
|
||||
|
||||
grid = (q_len // reshaped_block_size, num_heads, batch_size)
|
||||
|
||||
softmax_normalize_block_sum_kernel[grid](
|
||||
attn_weights_slice,
|
||||
output,
|
||||
m_global,
|
||||
l_global,
|
||||
scale,
|
||||
attn_weights_slice.stride(0),
|
||||
attn_weights_slice.stride(1),
|
||||
attn_weights_slice.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
output.stride(2),
|
||||
m_global.stride(0),
|
||||
m_global.stride(1),
|
||||
real_q_len,
|
||||
k_len,
|
||||
chunk_start,
|
||||
kv_offset,
|
||||
segment_size,
|
||||
reshaped_block_size,
|
||||
is_causal,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def flat_group_gemm_fuse_reshape(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user