feat: add nanovllm.ops module with XAttention estimation kernels

Add ops module ported from tzj/minference branch containing:
- xattn.py: XAttention block importance estimation with Triton kernels
  - xattn_estimate(): standard estimation for sparse attention mask
  - xattn_estimate_chunked(): chunked prefill compatible version
  - flat_group_gemm_fuse_reshape(): fused stride reshape + GEMM kernel
  - softmax_fuse_block_sum(): online softmax + block-wise sum kernel
- chunked_attention.py: Flash attention with LSE output for chunk merging
- test_xattn_estimate_chunked.py: verification test (all seq_lens pass)

This prepares the foundation for AttentionPolicy refactoring where
XAttentionPolicy.estimate() will call these ops.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-22 06:00:42 +08:00
parent 2826a649de
commit 9f3ee9279e
4 changed files with 2073 additions and 0 deletions

38
nanovllm/ops/__init__.py Normal file
View File

@@ -0,0 +1,38 @@
"""
Operators module for nano-vLLM.
This module contains low-level attention operators and kernels.
"""
from nanovllm.ops.chunked_attention import (
flash_attn_with_lse,
merge_attention_outputs,
chunked_attention_varlen,
ChunkedPrefillState,
)
from nanovllm.ops.xattn import (
xattn_estimate,
xattn_estimate_chunked,
flat_group_gemm_fuse_reshape,
softmax_fuse_block_sum,
find_blocks_chunked,
create_causal_mask,
compute_sparsity,
)
__all__ = [
# chunked_attention
"flash_attn_with_lse",
"merge_attention_outputs",
"chunked_attention_varlen",
"ChunkedPrefillState",
# xattn
"xattn_estimate",
"xattn_estimate_chunked",
"flat_group_gemm_fuse_reshape",
"softmax_fuse_block_sum",
"find_blocks_chunked",
"create_causal_mask",
"compute_sparsity",
]

View File

@@ -0,0 +1,624 @@
"""
Chunked attention implementation for CPU KV cache offloading.
This module implements flash attention with LSE (log-sum-exp) output,
enabling proper online softmax merging for chunked prefill.
Key functions:
- flash_attn_with_lse: Flash attention that returns output and LSE
- merge_attention_outputs: Merge outputs from multiple KV chunks
- chunked_prefill_attention: High-level interface for chunked attention
"""
import math
import torch
import triton
import triton.language as tl
from typing import Tuple, List, Optional
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_kernel_with_lse(
Q,
K,
V,
Out,
Lse,
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""
Flash attention forward kernel with LSE output.
Implements standard Flash Attention online softmax algorithm:
- m_i: running max of attention scores
- l_i: running sum of exp(scores - m_i)
- acc_o: running sum of softmax(scores) @ V (unnormalized)
Final output: acc_o / l_i
Final LSE: m_i + log(l_i)
"""
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Pointers
q_ptrs = (
Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
)
# Initialize running statistics
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # running max
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # running sum of exp
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # running output (unnormalized)
# Load Q (once per block)
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
)
# Loop over K, V blocks
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# Load K
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# Compute QK^T * scale
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= softmax_scale
# Apply masks
if not EVEN_N:
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
# Online softmax: compute block max
m_ij = tl.max(qk, 1) # [BLOCK_M]
# New running max
m_new = tl.maximum(m_i, m_ij) # [BLOCK_M]
# Rescale factor for previous accumulator
alpha = tl.exp(m_i - m_new) # [BLOCK_M]
# Compute P = exp(qk - m_new)
p = tl.exp(qk - m_new[:, None]) # [BLOCK_M, BLOCK_N]
# Sum of current block
l_ij = tl.sum(p, 1) # [BLOCK_M]
# Update running sum: l_new = l_i * alpha + l_ij
l_new = l_i * alpha + l_ij
# Rescale previous output and add new contribution
acc_o = acc_o * alpha[:, None]
# Load V
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# acc_o += P @ V
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# Update running statistics
m_i = m_new
l_i = l_new
# Final normalization: output = acc_o / l_i
acc_o = acc_o / l_i[:, None]
# Compute LSE = m_i + log(l_i)
lse_i = m_i + tl.log(l_i)
# Store LSE
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
if EVEN_M:
tl.store(lse_ptrs, lse_i)
else:
tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)
# Store output
out_ptrs = (
Out
+ off_b * stride_ob
+ off_h * stride_oh
+ (offs_m[:, None] * stride_om + offs_d[None, :])
)
if EVEN_M:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o)
else:
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
else:
tl.store(
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
def flash_attn_with_lse(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Flash attention forward pass that returns both output and LSE.
Uses flash_attn library which natively supports GQA without memory overhead.
Args:
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
softmax_scale: Scaling factor (default: 1/sqrt(headdim))
causal: Whether to apply causal masking
Returns:
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
"""
from flash_attn.flash_attn_interface import flash_attn_func
batch, seqlen_q, nheads_q, headdim = q.shape
_, seqlen_k, nheads_kv, _ = k.shape
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
# Use flash_attn_func which natively supports GQA (no memory overhead)
# It returns (output, softmax_lse) when return_attn_probs=True is not set
# We need to use the internal function to get LSE
out, lse, _ = flash_attn_func(
q, k, v,
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=True, # This makes it return (out, softmax_lse, S_dmask)
)
# lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
# Trim to actual seqlen_q
lse = lse[:, :, :seqlen_q]
return out, lse
@triton.jit
def _merge_lse_kernel(
lse1_ptr, lse2_ptr, lse_out_ptr,
num_elements: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging LSE values.
IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
"""
# Each program handles BLOCK_SIZE elements
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < num_elements
# Load lse values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
# Compute max for numerical stability (in fp32)
max_lse = tl.maximum(lse1, lse2)
# Compute exp(lse - max_lse) in fp32
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
lse_merged = max_lse + tl.log(exp1 + exp2)
# Store result (convert back to original dtype)
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
@triton.jit
def _merge_output_kernel(
o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging attention outputs.
IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
This is critical for numerical accuracy in chunked attention.
"""
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
pid_batch = tl.program_id(0)
pid_seq = tl.program_id(1)
pid_head = tl.program_id(2)
# Compute LSE index: [batch, nheads, seqlen_q]
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
# Load LSE values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
# Compute max and scaling factors in fp32
max_lse = tl.maximum(lse1, lse2)
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
sum_exp = exp1 + exp2
# Process headdim in chunks
for d_offset in range(0, headdim, BLOCK_SIZE):
d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
mask = d_idx < headdim
# Compute output index: [batch, seqlen_q, nheads, headdim]
base_idx = (pid_batch * seqlen_q * nheads * headdim +
pid_seq * nheads * headdim +
pid_head * headdim)
o_idx = base_idx + d_idx
# Load o1, o2 and convert to fp32 for weighted sum
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
# Store result (Triton will convert back to original dtype)
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
def merge_attention_outputs(
o1: torch.Tensor,
lse1: torch.Tensor,
o2: torch.Tensor,
lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Merge two attention outputs using online softmax (Triton fused kernel).
This implements the online softmax merging formula:
- m_new = max(lse1, lse2)
- o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
- lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))
Args:
o1: First output [batch, seqlen_q, nheads, headdim]
lse1: First LSE [batch, nheads, seqlen_q]
o2: Second output [batch, seqlen_q, nheads, headdim]
lse2: Second LSE [batch, nheads, seqlen_q]
Returns:
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
lse_merged: Merged LSE [batch, nheads, seqlen_q]
"""
batch, seqlen_q, nheads, headdim = o1.shape
# Allocate output tensors
o_merged = torch.empty_like(o1)
lse_merged = torch.empty_like(lse1)
# Launch LSE merge kernel
num_lse_elements = batch * nheads * seqlen_q
BLOCK_SIZE_LSE = 256
grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
_merge_lse_kernel[grid_lse](
lse1, lse2, lse_merged,
num_lse_elements,
BLOCK_SIZE=BLOCK_SIZE_LSE,
)
# Launch output merge kernel
BLOCK_SIZE = 128
grid_output = (batch, seqlen_q, nheads)
_merge_output_kernel[grid_output](
o1, o2, lse1, lse2, o_merged,
batch, seqlen_q, nheads, headdim,
BLOCK_SIZE=BLOCK_SIZE,
)
return o_merged, lse_merged
def chunked_attention_varlen(
q: torch.Tensor,
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
cu_seqlens_q: torch.Tensor,
cu_seqlens_k_list: List[torch.Tensor],
max_seqlen_q: int,
max_seqlen_k_list: List[int],
softmax_scale: Optional[float] = None,
causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
"""
Compute attention with KV split across multiple chunks.
This is the core function for chunked prefill. It computes attention
against each KV chunk and merges results using online softmax.
For causal attention with chunked KV:
- First chunk (current tokens): Apply causal mask
- Previous chunks: No causal mask (all previous tokens are valid context)
Args:
q: Query tensor [total_q_tokens, nheads, headdim]
kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
max_seqlen_q: Maximum query sequence length
max_seqlen_k_list: List of maximum key sequence lengths for each chunk
softmax_scale: Scaling factor
causal_mask_per_chunk: Whether to apply causal mask for each chunk
Returns:
out: Output tensor [total_q_tokens, nheads, headdim]
"""
if len(kv_chunks) == 0:
raise ValueError("Need at least one KV chunk")
nheads = q.shape[1]
headdim = q.shape[2]
batch = cu_seqlens_q.shape[0] - 1
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
if causal_mask_per_chunk is None:
# Default: causal for last chunk only
causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]
# Initialize accumulated output and LSE
accumulated_o = None
accumulated_lse = None
for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
is_causal = causal_mask_per_chunk[chunk_idx]
# Reshape Q for batch processing
# For varlen, we need to handle each sequence separately
# For simplicity, assume single sequence (batch=1) for now
q_batched = q.unsqueeze(0) # [1, total_q, nheads, headdim]
# Compute attention for this chunk
chunk_o, chunk_lse = flash_attn_with_lse(
q_batched,
k_chunk,
v_chunk,
softmax_scale=softmax_scale,
causal=is_causal,
)
# Merge with accumulated
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse,
)
# Remove batch dimension
return accumulated_o.squeeze(0)
class ChunkedPrefillState:
"""
State for tracking chunked prefill progress.
This class maintains the accumulated attention output and LSE
across multiple prefill chunks.
"""
def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
self.num_layers = num_layers
self.dtype = dtype
self.device = device
# Per-layer accumulated outputs
# Each entry: (accumulated_output, accumulated_lse) or None
self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
None for _ in range(num_layers)
]
# Track which chunks have been processed
self.processed_chunks: int = 0
def update_layer(
self,
layer_id: int,
chunk_output: torch.Tensor,
chunk_lse: torch.Tensor,
):
"""Update accumulated state for a layer with a new chunk's output."""
if self.layer_states[layer_id] is None:
self.layer_states[layer_id] = (chunk_output, chunk_lse)
else:
acc_o, acc_lse = self.layer_states[layer_id]
merged_o, merged_lse = merge_attention_outputs(
acc_o, acc_lse,
chunk_output, chunk_lse,
)
self.layer_states[layer_id] = (merged_o, merged_lse)
def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
"""Get the final accumulated output for a layer."""
if self.layer_states[layer_id] is None:
return None
return self.layer_states[layer_id][0]
def clear(self):
"""Clear all accumulated state."""
self.layer_states = [None for _ in range(self.num_layers)]
self.processed_chunks = 0
# Test function
def _test_chunked_attention():
"""Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
from flash_attn.flash_attn_interface import flash_attn_func
torch.manual_seed(42)
print("=" * 70)
print("Test: Chunked attention vs flash_attn_func (non-causal)")
print("=" * 70)
print("Splitting K,V into chunks, computing attention per chunk, then merging")
print()
for dtype in [torch.float16, torch.bfloat16]:
for num_chunks in [64, 128, 256]:
for batch, seqlen, nheads, headdim in [
(1, 1024, 32, 128),
(1, 2048, 32, 128),
(1, 4096, 32, 128),
(1, 8192, 32, 128),
]:
# Generate random Q, K, V
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
# Reference: full attention (non-causal)
out_ref = flash_attn_func(q, k, v, causal=False)
# Chunked attention: split K, V into chunks
chunk_size = seqlen // num_chunks
accumulated_o = None
accumulated_lse = None
for i in range(num_chunks):
start = i * chunk_size
end = (i + 1) * chunk_size
k_chunk = k[:, start:end, :, :]
v_chunk = v[:, start:end, :, :]
# Q attends to this K,V chunk (non-causal)
chunk_o, chunk_lse = flash_attn_with_lse(
q, k_chunk, v_chunk, causal=False
)
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
# Merge with previous chunks
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse
)
# Compare
out_diff = (out_ref - accumulated_o).abs()
out_max_diff = out_diff.max().item()
out_mean_diff = out_diff.mean().item()
status = "PASS" if out_max_diff < 1e-2 else "FAIL"
print(
f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
)
print()
print("=" * 70)
print("Test completed!")
if __name__ == "__main__":
_test_chunked_attention()

1167
nanovllm/ops/xattn.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,244 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.
Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_chunked.py
"""
import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
# Configuration for xattn_estimate_chunked consistency test.
# Key requirements for 100% match:
# 1. Use matching chunk_size for both standard and chunked versions
# 2. Use same random seed for reproducibility
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
# floating point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096 # External chunking size
# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
# ============================================================
# Utility Functions
# ============================================================
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
"""Compare two masks and report differences."""
if mask1.shape != mask2.shape:
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
return False
diff = (mask1 != mask2).sum().item()
total = mask1.numel()
match_rate = (total - diff) / total * 100
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
if diff > 0:
diff_indices = torch.where(mask1 != mask2)
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
return diff == 0
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
"""
Run xattn_estimate_chunked with EXTERNAL chunking.
This simulates how chunked prefill should be used in practice.
"""
batch_size, num_heads, q_len, head_dim = query.shape
_, _, k_len, _ = key.shape
q_block_num = (q_len + block_size - 1) // block_size
k_block_num = (k_len + block_size - 1) // block_size
# If Q fits in one chunk, call directly
if q_len <= chunk_size:
return xattn_estimate_chunked(
query, key,
q_start_pos=0,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# External chunking: split Q and call for each chunk
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
print(f" External chunking: {num_q_chunks} chunks")
combined_attn_sum = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=query.dtype, device=query.device
)
combined_mask = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=torch.bool, device=query.device
)
q_block_offset = 0
for q_chunk_idx in range(num_q_chunks):
q_chunk_start = q_chunk_idx * chunk_size
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
# For causal attention, K accumulates up to current Q position
# q_start_pos=0 means Q starts at position 0 in the full sequence
# K is [0, q_chunk_end) for causal attention
k_end = q_chunk_end
k_chunk = key[:, :, :k_end, :]
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
q_chunk, k_chunk,
q_start_pos=q_chunk_start,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# Place chunk results into combined output
chunk_q_blocks = mask_chunk.shape[2]
chunk_k_blocks = mask_chunk.shape[3]
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
q_block_offset += chunk_q_blocks
return combined_attn_sum, combined_mask
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
"""Test a single sequence length."""
print(f"\nTesting seq_len={seq_len}")
print("=" * 60)
# Generate random Q/K
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
# Run standard xattn_estimate
print("[1] Running standard xattn_estimate...")
try:
attn_sum_std, mask_std = xattn_estimate(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
use_triton=True,
causal=True,
)
density_std = mask_std.float().mean().item()
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Run chunked xattn_estimate with EXTERNAL chunking
print("[2] Running chunked xattn_estimate (external chunking)...")
try:
attn_sum_chunked, mask_chunked = run_chunked_externally(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
)
density_chunked = mask_chunked.float().mean().item()
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Compare results
print("[3] Comparing results...")
chunked_q_blocks = mask_chunked.shape[2]
chunked_k_blocks = mask_chunked.shape[3]
# Extract comparable region from standard mask
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
# Compare masks
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
# Compare attn_sums
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
print(f" Attn sum max diff: {attn_diff:.6f}")
else:
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
# Clean up GPU memory
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
torch.cuda.empty_cache()
return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
print("XAttention Chunked vs Standard Test")
print("=" * 60)
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
print(f"External chunk_size={CHUNK_SIZE}")
print()
# Check CUDA availability
if not torch.cuda.is_available():
print("CUDA not available!")
sys.exit(1)
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
print("✓ xattn_estimate imported")
print("✓ xattn_estimate_chunked imported")
# Run tests
all_passed = True
results = []
for seq_len in TEST_SEQ_LENS:
passed = test_single_seq_len(seq_len)
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
results.append((seq_len, chunks, passed))
if not passed:
all_passed = False
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for seq_len, chunks, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
print("=" * 60)
if all_passed:
print("ALL TESTS PASSED!")
sys.exit(0)
else:
print("SOME TESTS FAILED!")
sys.exit(1)