feat: add xattn_estimate_chunked for chunked prefill support

- Add xattn_estimate_chunked function ported from COMPASS
- Support chunked prefill with q_start_pos parameter
- Ensure 100% consistency with standard xattn_estimate when
  using matching chunk_size parameter
- Add test and documentation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-22 01:13:17 +08:00
parent 2866d4fd88
commit bc92c1fdb8
5 changed files with 561 additions and 0 deletions

View File

@@ -950,3 +950,218 @@ def compute_sparsity(mask: torch.Tensor, causal: bool = True) -> float:
selected_blocks = mask.sum().item()
return 1.0 - (selected_blocks / total_blocks)
# ============================================================
# Chunked Estimation Function (for Chunked Prefill)
# ============================================================
def xattn_estimate_chunked(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    q_start_pos: int,
    block_size: int = 128,
    stride: int = 8,
    norm: float = 1.0,
    threshold: float = 0.9,
    chunk_size: int = 16384,
    use_triton: bool = True,
    causal: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimate block importance for XAttention in chunked prefill mode.

    This function is designed for chunked prefill scenarios where:
    - Q is processed in chunks while K accumulates across chunks
    - q_start_pos indicates the position of the current Q chunk in the full sequence
    - K length can be >= Q length (accumulated KV cache)

    Ported from COMPASS project (compass/src/Xattn_chunked.py).

    Args:
        query_states: Q tensor [batch, heads, q_chunk_len, head_dim] - current Q chunk
        key_states: K tensor [batch, heads, k_len, head_dim] - accumulated K
            (k_len >= q_chunk_len)
        q_start_pos: Start position of this Q chunk in the full sequence
        block_size: Block size in tokens (typically 128 for BSA compatibility)
        stride: Stride for Q/K reshape (typically 8)
        norm: Normalization factor for attention scores
        threshold: Cumulative attention threshold (0.0-1.0)
        chunk_size: Processing chunk size for Triton kernel alignment
        use_triton: Whether to use Triton kernels (requires SM 80+)
        causal: Whether to apply causal masking

    Returns:
        attn_sums: Block-level attention scores [batch, heads, q_blocks, k_blocks]
        simple_masks: Boolean mask for sparse attention [batch, heads, q_blocks, k_blocks]

    Example:
        >>> # Chunk 0: Q[0:C] attends to K[0:C]
        >>> attn_sums, mask = xattn_estimate_chunked(q_chunk0, k_chunk0, q_start_pos=0)
        >>>
        >>> # Chunk 1: Q[C:2C] attends to K[0:2C]
        >>> attn_sums, mask = xattn_estimate_chunked(q_chunk1, k_accum, q_start_pos=C)
    """
    batch_size, num_heads, q_len, head_dim = query_states.shape
    _, _, k_len, _ = key_states.shape

    # Validate inputs before padding mutates the tensors.
    assert k_len >= q_len, f"K length ({k_len}) must be >= Q length ({q_len})"
    assert q_start_pos + q_len <= k_len, f"Q end position ({q_start_pos + q_len}) exceeds K length ({k_len})"

    # Block counts over the *unpadded* lengths (ceil division).
    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size
    q_start_block = q_start_pos // block_size

    # Triton kernels require SM 80+ (Ampere or newer); fall back otherwise.
    if use_triton:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        if props.major < 8:
            use_triton = False

    # Pad Q and K so downstream reshapes/kernels see aligned lengths.
    if use_triton:
        # Triton path: pad to chunk_size alignment.
        padded_q_len = ((q_len + chunk_size - 1) // chunk_size) * chunk_size
        padded_k_len = ((k_len + chunk_size - 1) // chunk_size) * chunk_size
    else:
        # PyTorch fallback: pad to block_size alignment.
        padded_q_len = q_block_num * block_size
        padded_k_len = k_block_num * block_size
    q_pad = padded_q_len - q_len
    k_pad = padded_k_len - k_len
    if q_pad > 0:
        query_states = F.pad(query_states, (0, 0, 0, q_pad), value=0)
    if k_pad > 0:
        key_states = F.pad(key_states, (0, 0, 0, k_pad), value=0)

    # Dimensions in the stride-reshaped space.
    reshaped_block_size = block_size // stride
    reshaped_q_len = padded_q_len // stride
    reshaped_k_len = padded_k_len // stride

    # Number of reshaped Q positions that contain real (non-padding) tokens.
    valid_q_reshaped = (q_len + stride - 1) // stride

    if use_triton:
        # Chunk boundaries expressed in reshaped coordinates.
        chunk_start = q_start_block * reshaped_block_size
        chunk_end = chunk_start + reshaped_q_len  # padded end for computation
        real_q_len = chunk_start + valid_q_reshaped  # valid end, masks padding

        # Fused reshape + grouped GEMM producing stride-level scores.
        attn_weights = flat_group_gemm_fuse_reshape(
            query_states,
            key_states,
            stride,
            chunk_start,  # q_start in reshaped space
            chunk_end,  # q_end in reshaped space (padded)
            is_causal=causal,
        )
        # Fused softmax + block-level summation.
        attn_sum = softmax_fuse_block_sum(
            attn_weights,
            reshaped_block_size,
            min(4096, reshaped_block_size),
            chunk_start,
            chunk_end,
            real_q_len,
            1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
            is_causal=causal,
        )
        # Keep only the block region that corresponds to real tokens.
        attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
    else:
        # ---- PyTorch fallback implementation ----
        # Reshape K: take every stride-th position and concatenate along the
        # head dim -> (B, H, k_len/stride, D*stride).
        reshaped_key = torch.cat(
            [(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
        )
        # Reshape Q the same way but with reversed ("inverse") interleaving.
        reshaped_query = torch.cat(
            [(query_states[:, :, (stride - 1 - q)::stride, :]) for q in range(stride)],
            dim=-1,
        )
        # Attention weights: (B, H, q_len/stride, k_len/stride).
        attn_weights = torch.matmul(
            reshaped_query, reshaped_key.transpose(2, 3)
        ) / math.sqrt(head_dim) / stride / norm

        if causal:
            causal_mask = torch.zeros(
                (batch_size, num_heads, reshaped_q_len, reshaped_k_len),
                device=key_states.device,
                dtype=attn_weights.dtype,
            )
            # Mask out padding in K.  BUGFIX: when 0 < k_pad < stride the
            # previous slice `[-(k_pad // stride):]` evaluated to `[-0:]`,
            # which covers the ENTIRE K axis, -inf'ing every score and
            # yielding NaNs after softmax.  Guard like the Q-pad path does.
            k_pad_reshaped = k_pad // stride
            if k_pad_reshaped > 0:
                causal_mask[:, :, :, -k_pad_reshaped:] = float("-inf")
            # Mask out future positions.  Vectorized: equivalent to the
            # per-row Python loop (mask where k_pos > q_pos) but O(1) in
            # Python overhead.
            q_start_reshaped = q_start_pos // stride
            q_positions = q_start_reshaped + torch.arange(
                reshaped_q_len, device=key_states.device
            )
            k_positions = torch.arange(reshaped_k_len, device=key_states.device)
            causal_mask.masked_fill_(
                k_positions[None, :] > q_positions[:, None], float("-inf")
            )
            # Mask out padded Q rows entirely.
            q_pad_reshaped = q_pad // stride
            if q_pad_reshaped > 0:
                causal_mask[:, :, -q_pad_reshaped:, :] = float("-inf")
            attn_weights = attn_weights + causal_mask

        # Softmax in fp32 for numerical stability, then back to input dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )
        # Zero out padded Q positions so they don't contribute to block sums.
        if q_pad > 0:
            q_pad_reshaped = q_pad // stride
            if q_pad_reshaped > 0:
                attn_weights[:, :, -q_pad_reshaped:, :] = 0
        # Aggregate stride-level scores to (q_block, k_block) level.
        attn_sum = attn_weights.view(
            batch_size,
            num_heads,
            q_block_num,
            reshaped_block_size,
            k_block_num,
            reshaped_block_size,
        ).sum(dim=-1).sum(dim=-2)

    # Select blocks whose cumulative attention exceeds the threshold.
    simple_mask = find_blocks_chunked(
        attn_sum,
        q_start_block,  # offset for causal mask in find_blocks_chunked
        threshold,
        None,
        decoding=False,
        mode="prefill",
        causal=causal,
    )

    # Block-level causal constraint: Q block i may only attend to K blocks j
    # with j <= q_start_block + i.  Vectorized equivalent of the per-block
    # slice-assignment loop.
    if causal:
        q_blocks = q_start_block + torch.arange(
            q_block_num, device=simple_mask.device
        )
        k_blocks = torch.arange(k_block_num, device=simple_mask.device)
        simple_mask.masked_fill_(k_blocks[None, :] > q_blocks[:, None], False)

    return attn_sum, simple_mask