feat: add xattn_estimate_chunked for chunked prefill support
- Add xattn_estimate_chunked function ported from COMPASS
- Support chunked prefill with q_start_pos parameter
- Ensure 100% consistency with standard xattn_estimate when using matching chunk_size parameter
- Add test and documentation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -950,3 +950,218 @@ def compute_sparsity(mask: torch.Tensor, causal: bool = True) -> float:
|
||||
selected_blocks = mask.sum().item()
|
||||
|
||||
return 1.0 - (selected_blocks / total_blocks)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Chunked Estimation Function (for Chunked Prefill)
|
||||
# ============================================================
|
||||
|
||||
def xattn_estimate_chunked(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    q_start_pos: int,
    block_size: int = 128,
    stride: int = 8,
    norm: float = 1.0,
    threshold: float = 0.9,
    chunk_size: int = 16384,
    use_triton: bool = True,
    causal: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimate block importance for XAttention in chunked prefill mode.

    Designed for chunked prefill scenarios where:
    - Q is processed in chunks while K accumulates across chunks
    - q_start_pos indicates the position of the current Q chunk in the full sequence
    - K length can be >= Q length (accumulated KV cache)

    Ported from COMPASS project (compass/src/Xattn_chunked.py).

    Args:
        query_states: Q tensor [batch, heads, q_chunk_len, head_dim] - current Q chunk
        key_states: K tensor [batch, heads, k_len, head_dim] - accumulated K (k_len >= q_chunk_len)
        q_start_pos: Start position of this Q chunk in the full sequence
        block_size: Block size in tokens (typically 128 for BSA compatibility)
        stride: Stride for Q/K reshape (typically 8)
        norm: Normalization factor for attention scores
        threshold: Cumulative attention threshold (0.0-1.0)
        chunk_size: Processing chunk size for Triton kernel alignment
        use_triton: Whether to use Triton kernels (requires CUDA SM 80+;
            silently falls back to the PyTorch path when unavailable)
        causal: Whether to apply causal masking

    Returns:
        attn_sums: Block-level attention scores [batch, heads, q_blocks, k_blocks]
        simple_masks: Boolean mask for sparse attention [batch, heads, q_blocks, k_blocks]

    Raises:
        AssertionError: If k_len < q_len or the Q chunk extends past the K length.

    Example:
        >>> # Chunk 0: Q[0:C] attends to K[0:C]
        >>> attn_sums, mask = xattn_estimate_chunked(q_chunk0, k_chunk0, q_start_pos=0)
        >>>
        >>> # Chunk 1: Q[C:2C] attends to K[0:2C]
        >>> attn_sums, mask = xattn_estimate_chunked(q_chunk1, k_accum, q_start_pos=C)
    """
    batch_size, num_heads, q_len, head_dim = query_states.shape
    _, _, k_len, _ = key_states.shape

    # Validate inputs: the accumulated K cache must cover the whole Q chunk.
    assert k_len >= q_len, f"K length ({k_len}) must be >= Q length ({q_len})"
    assert q_start_pos + q_len <= k_len, f"Q end position ({q_start_pos + q_len}) exceeds K length ({k_len})"

    # Block counts (ceil division) and the global block offset of this Q chunk.
    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size
    q_start_block = q_start_pos // block_size

    # Triton kernels need CUDA with compute capability >= 8.0 (Ampere).
    # Guard with is_available() so CPU-only hosts fall back instead of
    # crashing inside get_device_properties().
    if use_triton:
        if not torch.cuda.is_available():
            use_triton = False
        else:
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            if props.major < 8:
                use_triton = False

    # Pad Q and K along the sequence axis so lengths are kernel-friendly:
    # chunk_size-aligned for Triton, block_size-aligned for the fallback.
    if use_triton:
        padded_q_len = ((q_len + chunk_size - 1) // chunk_size) * chunk_size
        padded_k_len = ((k_len + chunk_size - 1) // chunk_size) * chunk_size
    else:
        padded_q_len = q_block_num * block_size
        padded_k_len = k_block_num * block_size

    q_pad = padded_q_len - q_len
    k_pad = padded_k_len - k_len

    if q_pad > 0:
        query_states = F.pad(query_states, (0, 0, 0, q_pad), value=0)
    if k_pad > 0:
        key_states = F.pad(key_states, (0, 0, 0, k_pad), value=0)

    # Dimensions after the stride-wise reshape (stride tokens fold into the
    # feature dimension, shrinking the sequence axis by `stride`).
    reshaped_block_size = block_size // stride
    reshaped_q_len = padded_q_len // stride
    reshaped_k_len = padded_k_len // stride

    if use_triton:
        # Chunk boundaries in reshaped space.
        chunk_start = q_start_block * reshaped_block_size
        chunk_end = chunk_start + reshaped_q_len  # padded end for computation
        # Valid (non-padding) end in reshaped space, used to mask Q padding.
        real_q_len = chunk_start + (q_len + stride - 1) // stride

        # Use Triton kernel for efficient computation
        attn_weights = flat_group_gemm_fuse_reshape(
            query_states,
            key_states,
            stride,
            chunk_start,  # q_start in reshaped space
            chunk_end,    # q_end in reshaped space (padded)
            is_causal=causal,
        )

        # Fused softmax + per-block summation. The scale folds together the
        # log2(e) base change (1.4426...), 1/sqrt(head_dim) attention scaling,
        # the stride normalization, and the caller-provided norm.
        attn_sum = softmax_fuse_block_sum(
            attn_weights,
            reshaped_block_size,
            min(4096, reshaped_block_size),
            chunk_start,
            chunk_end,
            real_q_len,
            1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
            is_causal=causal,
        )

        # Keep only the valid (unpadded) block region.
        attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
    else:
        # ---- Pure-PyTorch fallback ----
        # Reshape K: interleave positions and concatenate head dims, so
        # reshaped position j holds tokens j*stride .. j*stride+stride-1.
        reshaped_key = torch.cat(
            [(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
        )  # (B, H, k_len/stride, D*stride)

        # Reshape Q in inverse (reversed within each stride group) order.
        reshaped_query = torch.cat(
            [(query_states[:, :, (stride - 1 - q)::stride, :]) for q in range(stride)],
            dim=-1,
        )

        # Compute attention weights: (B, H, q_len/stride, k_len/stride)
        attn_weights = torch.matmul(
            reshaped_query, reshaped_key.transpose(2, 3)
        ) / math.sqrt(head_dim) / stride / norm

        if causal:
            device = key_states.device
            # Vectorized causal mask in reshaped coordinates: a Q position may
            # attend to K positions up to and including its own global index.
            # Built as (1, 1, q, k) and broadcast over batch/heads to avoid a
            # full (B, H, q, k) materialization.
            q_start_reshaped = q_start_pos // stride
            q_positions = q_start_reshaped + torch.arange(reshaped_q_len, device=device)
            k_positions = torch.arange(reshaped_k_len, device=device)
            causal_mask = torch.zeros(
                (1, 1, reshaped_q_len, reshaped_k_len),
                device=device,
                dtype=attn_weights.dtype,
            )
            causal_mask.masked_fill_(
                k_positions.view(1, 1, 1, -1) > q_positions.view(1, 1, -1, 1),
                float("-inf"),
            )

            # Mask out fully-padded K positions. Guard against k_pad < stride:
            # the unguarded slice `[-(k_pad // stride):]` becomes `[-0:]` (the
            # WHOLE K axis) when k_pad // stride == 0, which would -inf every
            # column and produce NaN softmax rows. Partially-padded positions
            # are left unmasked: their zero-padded features contribute nothing
            # to the dot product anyway.
            k_pad_reshaped = k_pad // stride
            if k_pad_reshaped > 0:
                causal_mask[:, :, :, -k_pad_reshaped:] = float("-inf")

            # Mask out fully-padded Q rows.
            q_pad_reshaped = q_pad // stride
            if q_pad_reshaped > 0:
                causal_mask[:, :, -q_pad_reshaped:, :] = float("-inf")

            attn_weights = attn_weights + causal_mask

        # Softmax in fp32 for numerical stability, then back to input dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # Zero out padded Q rows so they contribute nothing to block sums.
        if q_pad > 0:
            q_pad_reshaped = q_pad // stride
            if q_pad_reshaped > 0:
                attn_weights[:, :, -q_pad_reshaped:, :] = 0

        # Aggregate reshaped-position scores to block level:
        # (B, H, q_blocks, rbs, k_blocks, rbs) summed over both rbs axes.
        attn_sum = attn_weights.view(
            batch_size,
            num_heads,
            q_block_num,
            reshaped_block_size,
            k_block_num,
            reshaped_block_size,
        ).sum(dim=-1).sum(dim=-2)

    # Select blocks whose cumulative attention mass reaches the threshold.
    simple_mask = find_blocks_chunked(
        attn_sum,
        q_start_block,  # offset for causal mask in find_blocks_chunked
        threshold,
        None,
        decoding=False,
        mode="prefill",
        causal=causal,
    )

    # Block-level causal constraint, vectorized: Q block i (global index
    # q_start_block + i) may only keep K blocks j with j <= q_start_block + i.
    if causal:
        k_idx = torch.arange(k_block_num, device=simple_mask.device)
        q_idx = q_start_block + torch.arange(q_block_num, device=simple_mask.device)
        simple_mask = simple_mask & (k_idx.view(1, 1, 1, -1) <= q_idx.view(1, 1, -1, 1))

    return attn_sum, simple_mask
|
||||
|
||||
Reference in New Issue
Block a user