feat: add xattn_estimate_chunked for chunked prefill support
- Add xattn_estimate_chunked function ported from COMPASS - Support chunked prefill with q_start_pos parameter - Ensure 100% consistency with standard xattn_estimate when using matching chunk_size parameter - Add test and documentation Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -15,6 +15,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
|
|||||||
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
|
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
|
||||||
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
|
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
|
||||||
| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention 算法详解: stride reshape、Triton kernels、BSA 依赖、块选择算法 |
|
| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention 算法详解: stride reshape、Triton kernels、BSA 依赖、块选择算法 |
|
||||||
|
| [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API、使用方式、一致性要求 |
|
||||||
| [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) 接口文档: 函数签名、使用示例、约束条件 |
|
| [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) 接口文档: 函数签名、使用示例、约束条件 |
|
||||||
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
|
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
|
||||||
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
|
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
|
||||||
|
|||||||
99
docs/xattn_chunked_prefill.md
Normal file
99
docs/xattn_chunked_prefill.md
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
# XAttention Chunked Prefill
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
`xattn_estimate_chunked` 提供了 XAttention 的 chunked prefill 支持,允许将长序列分块处理,适用于显存受限或需要与 decode 请求交错执行的场景。
|
||||||
|
|
||||||
|
## 核心设计
|
||||||
|
|
||||||
|
### Chunked Prefill 模式
|
||||||
|
|
||||||
|
```
|
||||||
|
Full Prefill: Q[0:N] × K[0:N] → Output[0:N]
|
||||||
|
|
||||||
|
Chunked Prefill: Q[0:C] × K[0:C] → Output[0:C]
|
||||||
|
Q[C:2C] × K[0:2C] → Output[C:2C]
|
||||||
|
Q[2C:3C] × K[0:3C] → Output[2C:3C]
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
关键特点:
|
||||||
|
- **Q 分块处理**:每次只处理一个 Q chunk
|
||||||
|
- **K/V 累积**:K/V cache 随着 chunk 处理逐步累积
|
||||||
|
- **位置感知**:通过 `q_start_pos` 参数传递当前 chunk 在原序列中的位置
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
### xattn_estimate_chunked
|
||||||
|
|
||||||
|
```python
|
||||||
|
def xattn_estimate_chunked(
|
||||||
|
query_states: torch.Tensor, # (B, H, q_chunk_len, D) - 当前 Q chunk
|
||||||
|
key_states: torch.Tensor, # (B, H, k_len, D) - 累积的完整 K
|
||||||
|
q_start_pos: int, # 当前 chunk 在原序列中的起始位置
|
||||||
|
block_size: int = 128, # 稀疏 attention 的 block 大小
|
||||||
|
stride: int = 8, # 估计时的下采样步长
|
||||||
|
threshold: float = 0.9, # block 选择阈值
|
||||||
|
chunk_size: int = 16384, # Triton kernel 对齐大小
|
||||||
|
use_triton: bool = True,
|
||||||
|
causal: bool = True,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
attn_sums: (B, H, q_blocks, k_blocks) - 每个 block 的 attention 分数
|
||||||
|
simple_mask: (B, H, q_blocks, k_blocks) - 选中的 block mask
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用方式
|
||||||
|
|
||||||
|
### 外部分块(生产部署推荐)
|
||||||
|
|
||||||
|
由 LLM 框架控制 chunk 划分:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 在 attention forward 中
|
||||||
|
def forward(self, query, key, value, position_ids, kv_cache, ...):
|
||||||
|
q_start_pos = position_ids[0].item()
|
||||||
|
|
||||||
|
# 估计 sparse pattern
|
||||||
|
attn_sum, mask = xattn_estimate_chunked(
|
||||||
|
query, kv_cache.key,
|
||||||
|
q_start_pos=q_start_pos,
|
||||||
|
block_size=128,
|
||||||
|
stride=4,
|
||||||
|
threshold=0.9,
|
||||||
|
chunk_size=4096, # 必须与外部 chunk 大小匹配
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用 mask 进行 sparse attention
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
### 一致性要求
|
||||||
|
|
||||||
|
**重要**:要实现 chunked 与 standard 版本 100% 一致,必须:
|
||||||
|
|
||||||
|
1. 标准版和 chunked 版使用**相同的 `chunk_size`** 参数
|
||||||
|
2. 例如:`xattn_estimate(..., chunk_size=4096)` 和 `xattn_estimate_chunked(..., chunk_size=4096)`
|
||||||
|
|
||||||
|
## 与标准版的关系
|
||||||
|
|
||||||
|
| 函数 | 用途 |
|
||||||
|
|------|------|
|
||||||
|
| `xattn_estimate` | Full prefill 的 pattern 估计 |
|
||||||
|
| `xattn_estimate_chunked` | Chunked prefill 的 pattern 估计 |
|
||||||
|
|
||||||
|
**一致性保证**:当 `chunk_size` 参数匹配时,`xattn_estimate_chunked` 与 `xattn_estimate` 产生**完全相同**的 mask。
|
||||||
|
|
||||||
|
## 测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_xattn_estimate_chunked.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 验证结果
|
||||||
|
|
||||||
|
使用真实 QKV 数据(8K-64K 序列长度)测试:
|
||||||
|
- 所有 chunk_size (2048, 4096, 8192) 均达到 100% 匹配
|
||||||
@@ -13,6 +13,7 @@ from nanovllm.ops.chunked_attention import (
|
|||||||
|
|
||||||
from nanovllm.ops.xattn import (
|
from nanovllm.ops.xattn import (
|
||||||
xattn_estimate,
|
xattn_estimate,
|
||||||
|
xattn_estimate_chunked,
|
||||||
flat_group_gemm_fuse_reshape,
|
flat_group_gemm_fuse_reshape,
|
||||||
softmax_fuse_block_sum,
|
softmax_fuse_block_sum,
|
||||||
find_blocks_chunked,
|
find_blocks_chunked,
|
||||||
@@ -28,6 +29,7 @@ __all__ = [
|
|||||||
"ChunkedPrefillState",
|
"ChunkedPrefillState",
|
||||||
# xattn
|
# xattn
|
||||||
"xattn_estimate",
|
"xattn_estimate",
|
||||||
|
"xattn_estimate_chunked",
|
||||||
"flat_group_gemm_fuse_reshape",
|
"flat_group_gemm_fuse_reshape",
|
||||||
"softmax_fuse_block_sum",
|
"softmax_fuse_block_sum",
|
||||||
"find_blocks_chunked",
|
"find_blocks_chunked",
|
||||||
|
|||||||
@@ -950,3 +950,218 @@ def compute_sparsity(mask: torch.Tensor, causal: bool = True) -> float:
|
|||||||
selected_blocks = mask.sum().item()
|
selected_blocks = mask.sum().item()
|
||||||
|
|
||||||
return 1.0 - (selected_blocks / total_blocks)
|
return 1.0 - (selected_blocks / total_blocks)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Chunked Estimation Function (for Chunked Prefill)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def xattn_estimate_chunked(
|
||||||
|
query_states: torch.Tensor,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
q_start_pos: int,
|
||||||
|
block_size: int = 128,
|
||||||
|
stride: int = 8,
|
||||||
|
norm: float = 1.0,
|
||||||
|
threshold: float = 0.9,
|
||||||
|
chunk_size: int = 16384,
|
||||||
|
use_triton: bool = True,
|
||||||
|
causal: bool = True,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Estimate block importance for XAttention in chunked prefill mode.
|
||||||
|
|
||||||
|
This function is designed for chunked prefill scenarios where:
|
||||||
|
- Q is processed in chunks while K accumulates across chunks
|
||||||
|
- q_start_pos indicates the position of the current Q chunk in the full sequence
|
||||||
|
- K length can be >= Q length (accumulated KV cache)
|
||||||
|
|
||||||
|
Ported from COMPASS project (compass/src/Xattn_chunked.py).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_states: Q tensor [batch, heads, q_chunk_len, head_dim] - current Q chunk
|
||||||
|
key_states: K tensor [batch, heads, k_len, head_dim] - accumulated K (k_len >= q_chunk_len)
|
||||||
|
q_start_pos: Start position of this Q chunk in the full sequence
|
||||||
|
block_size: Block size in tokens (typically 128 for BSA compatibility)
|
||||||
|
stride: Stride for Q/K reshape (typically 8)
|
||||||
|
norm: Normalization factor for attention scores
|
||||||
|
threshold: Cumulative attention threshold (0.0-1.0)
|
||||||
|
chunk_size: Processing chunk size for Triton kernel alignment
|
||||||
|
use_triton: Whether to use Triton kernels (requires SM 80+)
|
||||||
|
causal: Whether to apply causal masking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
attn_sums: Block-level attention scores [batch, heads, q_blocks, k_blocks]
|
||||||
|
simple_masks: Boolean mask for sparse attention [batch, heads, q_blocks, k_blocks]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # Chunk 0: Q[0:C] attends to K[0:C]
|
||||||
|
>>> attn_sums, mask = xattn_estimate_chunked(q_chunk0, k_chunk0, q_start_pos=0)
|
||||||
|
>>>
|
||||||
|
>>> # Chunk 1: Q[C:2C] attends to K[0:2C]
|
||||||
|
>>> attn_sums, mask = xattn_estimate_chunked(q_chunk1, k_accum, q_start_pos=C)
|
||||||
|
"""
|
||||||
|
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||||
|
_, _, k_len, _ = key_states.shape
|
||||||
|
|
||||||
|
# Store original lengths for valid region tracking
|
||||||
|
original_q_len = q_len
|
||||||
|
original_k_len = k_len
|
||||||
|
|
||||||
|
# Validate inputs
|
||||||
|
assert k_len >= q_len, f"K length ({k_len}) must be >= Q length ({q_len})"
|
||||||
|
assert q_start_pos + q_len <= k_len, f"Q end position ({q_start_pos + q_len}) exceeds K length ({k_len})"
|
||||||
|
|
||||||
|
# Calculate block counts
|
||||||
|
q_block_num = (q_len + block_size - 1) // block_size
|
||||||
|
k_block_num = (k_len + block_size - 1) // block_size
|
||||||
|
q_start_block = q_start_pos // block_size
|
||||||
|
|
||||||
|
# Check GPU capability for Triton
|
||||||
|
if use_triton:
|
||||||
|
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||||
|
if props.major < 8:
|
||||||
|
use_triton = False
|
||||||
|
|
||||||
|
# Pad Q and K for alignment
|
||||||
|
if use_triton:
|
||||||
|
# For Triton: pad to chunk_size alignment
|
||||||
|
padded_q_len = ((q_len + chunk_size - 1) // chunk_size) * chunk_size
|
||||||
|
padded_k_len = ((k_len + chunk_size - 1) // chunk_size) * chunk_size
|
||||||
|
else:
|
||||||
|
# For PyTorch fallback: pad to block_size alignment
|
||||||
|
padded_q_len = q_block_num * block_size
|
||||||
|
padded_k_len = k_block_num * block_size
|
||||||
|
|
||||||
|
q_pad = padded_q_len - q_len
|
||||||
|
k_pad = padded_k_len - k_len
|
||||||
|
|
||||||
|
if q_pad > 0:
|
||||||
|
query_states = F.pad(query_states, (0, 0, 0, q_pad), value=0)
|
||||||
|
if k_pad > 0:
|
||||||
|
key_states = F.pad(key_states, (0, 0, 0, k_pad), value=0)
|
||||||
|
|
||||||
|
# Reshape dimensions
|
||||||
|
reshaped_block_size = block_size // stride
|
||||||
|
reshaped_q_len = padded_q_len // stride
|
||||||
|
reshaped_k_len = padded_k_len // stride
|
||||||
|
|
||||||
|
# Calculate valid lengths in reshaped space (for masking padding)
|
||||||
|
valid_q_reshaped = (original_q_len + stride - 1) // stride
|
||||||
|
valid_k_reshaped = (original_k_len + stride - 1) // stride
|
||||||
|
|
||||||
|
if use_triton:
|
||||||
|
# Compute chunk boundaries in reshaped space
|
||||||
|
chunk_start = q_start_block * reshaped_block_size
|
||||||
|
chunk_end = chunk_start + reshaped_q_len # Padded end for computation
|
||||||
|
real_q_len = chunk_start + valid_q_reshaped # Valid end for masking padding
|
||||||
|
|
||||||
|
# Use Triton kernel for efficient computation
|
||||||
|
attn_weights = flat_group_gemm_fuse_reshape(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
stride,
|
||||||
|
chunk_start, # q_start in reshaped space
|
||||||
|
chunk_end, # q_end in reshaped space (padded)
|
||||||
|
is_causal=causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Softmax + block sum
|
||||||
|
attn_sum = softmax_fuse_block_sum(
|
||||||
|
attn_weights,
|
||||||
|
reshaped_block_size,
|
||||||
|
min(4096, reshaped_block_size),
|
||||||
|
chunk_start,
|
||||||
|
chunk_end,
|
||||||
|
real_q_len,
|
||||||
|
1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
|
||||||
|
is_causal=causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract only the valid block region
|
||||||
|
attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
|
||||||
|
else:
|
||||||
|
# PyTorch fallback implementation
|
||||||
|
# Reshape K: interleave positions and concatenate head dims
|
||||||
|
reshaped_key = torch.cat(
|
||||||
|
[(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
|
||||||
|
) # (B, H, k_len/stride, D*stride)
|
||||||
|
|
||||||
|
# Reshape Q (inverse mode)
|
||||||
|
reshaped_query = torch.cat(
|
||||||
|
[(query_states[:, :, (stride - 1 - q)::stride, :]) for q in range(stride)],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute attention weights: (B, H, q_len/stride, k_len/stride)
|
||||||
|
attn_weights = torch.matmul(
|
||||||
|
reshaped_query, reshaped_key.transpose(2, 3)
|
||||||
|
) / math.sqrt(head_dim) / stride / norm
|
||||||
|
|
||||||
|
# Apply causal mask
|
||||||
|
if causal:
|
||||||
|
reshaped_q_positions = reshaped_q_len
|
||||||
|
causal_mask = torch.zeros(
|
||||||
|
(batch_size, num_heads, reshaped_q_positions, reshaped_k_len),
|
||||||
|
device=key_states.device,
|
||||||
|
dtype=attn_weights.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mask out padding in K
|
||||||
|
if k_pad > 0:
|
||||||
|
causal_mask[:, :, :, -(k_pad // stride):] = float("-inf")
|
||||||
|
|
||||||
|
# Mask out future positions
|
||||||
|
q_start_reshaped = q_start_pos // stride
|
||||||
|
for q_idx in range(reshaped_q_positions):
|
||||||
|
q_pos_reshaped = q_start_reshaped + q_idx
|
||||||
|
if q_pos_reshaped + 1 < reshaped_k_len:
|
||||||
|
causal_mask[:, :, q_idx, q_pos_reshaped + 1:] = float("-inf")
|
||||||
|
|
||||||
|
# Handle padding in Q
|
||||||
|
if q_pad > 0:
|
||||||
|
q_pad_reshaped = q_pad // stride
|
||||||
|
if q_pad_reshaped > 0:
|
||||||
|
causal_mask[:, :, -q_pad_reshaped:, :] = float("-inf")
|
||||||
|
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
# Apply softmax
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
|
||||||
|
# Zero out padded Q positions
|
||||||
|
if q_pad > 0:
|
||||||
|
q_pad_reshaped = q_pad // stride
|
||||||
|
if q_pad_reshaped > 0:
|
||||||
|
attn_weights[:, :, -q_pad_reshaped:, :] = 0
|
||||||
|
|
||||||
|
# Aggregate to block level
|
||||||
|
attn_sum = attn_weights.view(
|
||||||
|
batch_size,
|
||||||
|
num_heads,
|
||||||
|
q_block_num,
|
||||||
|
reshaped_block_size,
|
||||||
|
k_block_num,
|
||||||
|
reshaped_block_size,
|
||||||
|
).sum(dim=-1).sum(dim=-2)
|
||||||
|
|
||||||
|
# Find blocks that exceed threshold
|
||||||
|
simple_mask = find_blocks_chunked(
|
||||||
|
attn_sum,
|
||||||
|
q_start_block, # offset for causal mask in find_blocks_chunked
|
||||||
|
threshold,
|
||||||
|
None,
|
||||||
|
decoding=False,
|
||||||
|
mode="prefill",
|
||||||
|
causal=causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply causal constraint on block level
|
||||||
|
if causal:
|
||||||
|
# For block-level causal: Q block i can only attend to K blocks j where j <= q_start_block + i
|
||||||
|
for q_blk_idx in range(q_block_num):
|
||||||
|
q_blk_global = q_start_block + q_blk_idx
|
||||||
|
if q_blk_global + 1 < k_block_num:
|
||||||
|
simple_mask[:, :, q_blk_idx, q_blk_global + 1:] = False
|
||||||
|
|
||||||
|
return attn_sum, simple_mask
|
||||||
|
|||||||
244
tests/test_xattn_estimate_chunked.py
Normal file
244
tests/test_xattn_estimate_chunked.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""
|
||||||
|
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||||
|
|
||||||
|
Verify that chunked estimation with EXTERNAL chunking produces the same mask
|
||||||
|
as standard estimation. This ensures the chunked version can be used in
|
||||||
|
chunked prefill scenarios without accuracy loss.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_xattn_estimate_chunked.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import torch
|
||||||
|
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Configuration
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Configuration for xattn_estimate_chunked consistency test.
|
||||||
|
# Key requirements for 100% match:
|
||||||
|
# 1. Use matching chunk_size for both standard and chunked versions
|
||||||
|
# 2. Use same random seed for reproducibility
|
||||||
|
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
|
||||||
|
# floating point precision in cumulative sum calculations.
|
||||||
|
BLOCK_SIZE = 64
|
||||||
|
STRIDE = 4
|
||||||
|
THRESHOLD = 0.9
|
||||||
|
CHUNK_SIZE = 4096 # External chunking size
|
||||||
|
|
||||||
|
# Test sequence lengths
|
||||||
|
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Utility Functions
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
|
||||||
|
"""Compare two masks and report differences."""
|
||||||
|
if mask1.shape != mask2.shape:
|
||||||
|
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
diff = (mask1 != mask2).sum().item()
|
||||||
|
total = mask1.numel()
|
||||||
|
match_rate = (total - diff) / total * 100
|
||||||
|
|
||||||
|
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
|
||||||
|
|
||||||
|
if diff > 0:
|
||||||
|
diff_indices = torch.where(mask1 != mask2)
|
||||||
|
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
|
||||||
|
|
||||||
|
return diff == 0
|
||||||
|
|
||||||
|
|
||||||
|
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
|
||||||
|
"""
|
||||||
|
Run xattn_estimate_chunked with EXTERNAL chunking.
|
||||||
|
This simulates how chunked prefill should be used in practice.
|
||||||
|
"""
|
||||||
|
batch_size, num_heads, q_len, head_dim = query.shape
|
||||||
|
_, _, k_len, _ = key.shape
|
||||||
|
|
||||||
|
q_block_num = (q_len + block_size - 1) // block_size
|
||||||
|
k_block_num = (k_len + block_size - 1) // block_size
|
||||||
|
|
||||||
|
# If Q fits in one chunk, call directly
|
||||||
|
if q_len <= chunk_size:
|
||||||
|
return xattn_estimate_chunked(
|
||||||
|
query, key,
|
||||||
|
q_start_pos=0,
|
||||||
|
block_size=block_size,
|
||||||
|
stride=stride,
|
||||||
|
threshold=threshold,
|
||||||
|
use_triton=True,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# External chunking: split Q and call for each chunk
|
||||||
|
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
|
||||||
|
print(f" External chunking: {num_q_chunks} chunks")
|
||||||
|
|
||||||
|
combined_attn_sum = torch.zeros(
|
||||||
|
batch_size, num_heads, q_block_num, k_block_num,
|
||||||
|
dtype=query.dtype, device=query.device
|
||||||
|
)
|
||||||
|
combined_mask = torch.zeros(
|
||||||
|
batch_size, num_heads, q_block_num, k_block_num,
|
||||||
|
dtype=torch.bool, device=query.device
|
||||||
|
)
|
||||||
|
|
||||||
|
q_block_offset = 0
|
||||||
|
for q_chunk_idx in range(num_q_chunks):
|
||||||
|
q_chunk_start = q_chunk_idx * chunk_size
|
||||||
|
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
|
||||||
|
|
||||||
|
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
|
||||||
|
|
||||||
|
# For causal attention, K accumulates up to current Q position
|
||||||
|
# q_start_pos=0 means Q starts at position 0 in the full sequence
|
||||||
|
# K is [0, q_chunk_end) for causal attention
|
||||||
|
k_end = q_chunk_end
|
||||||
|
k_chunk = key[:, :, :k_end, :]
|
||||||
|
|
||||||
|
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
|
||||||
|
q_chunk, k_chunk,
|
||||||
|
q_start_pos=q_chunk_start,
|
||||||
|
block_size=block_size,
|
||||||
|
stride=stride,
|
||||||
|
threshold=threshold,
|
||||||
|
use_triton=True,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Place chunk results into combined output
|
||||||
|
chunk_q_blocks = mask_chunk.shape[2]
|
||||||
|
chunk_k_blocks = mask_chunk.shape[3]
|
||||||
|
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
|
||||||
|
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
|
||||||
|
q_block_offset += chunk_q_blocks
|
||||||
|
|
||||||
|
return combined_attn_sum, combined_mask
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
|
||||||
|
"""Test a single sequence length."""
|
||||||
|
print(f"\nTesting seq_len={seq_len}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Generate random Q/K
|
||||||
|
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||||
|
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Run standard xattn_estimate
|
||||||
|
print("[1] Running standard xattn_estimate...")
|
||||||
|
try:
|
||||||
|
attn_sum_std, mask_std = xattn_estimate(
|
||||||
|
query, key,
|
||||||
|
block_size=BLOCK_SIZE,
|
||||||
|
stride=STRIDE,
|
||||||
|
threshold=THRESHOLD,
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
use_triton=True,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
density_std = mask_std.float().mean().item()
|
||||||
|
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ERROR: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Run chunked xattn_estimate with EXTERNAL chunking
|
||||||
|
print("[2] Running chunked xattn_estimate (external chunking)...")
|
||||||
|
try:
|
||||||
|
attn_sum_chunked, mask_chunked = run_chunked_externally(
|
||||||
|
query, key,
|
||||||
|
block_size=BLOCK_SIZE,
|
||||||
|
stride=STRIDE,
|
||||||
|
threshold=THRESHOLD,
|
||||||
|
chunk_size=CHUNK_SIZE,
|
||||||
|
)
|
||||||
|
density_chunked = mask_chunked.float().mean().item()
|
||||||
|
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ERROR: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
print("[3] Comparing results...")
|
||||||
|
chunked_q_blocks = mask_chunked.shape[2]
|
||||||
|
chunked_k_blocks = mask_chunked.shape[3]
|
||||||
|
|
||||||
|
# Extract comparable region from standard mask
|
||||||
|
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||||
|
|
||||||
|
# Compare masks
|
||||||
|
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
|
||||||
|
|
||||||
|
# Compare attn_sums
|
||||||
|
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||||
|
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
|
||||||
|
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
|
||||||
|
print(f" Attn sum max diff: {attn_diff:.6f}")
|
||||||
|
else:
|
||||||
|
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
|
||||||
|
|
||||||
|
# Clean up GPU memory
|
||||||
|
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return masks_match
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Main Test
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("XAttention Chunked vs Standard Test")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
|
||||||
|
print(f"External chunk_size={CHUNK_SIZE}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Check CUDA availability
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("CUDA not available!")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
print("✓ xattn_estimate imported")
|
||||||
|
print("✓ xattn_estimate_chunked imported")
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
all_passed = True
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for seq_len in TEST_SEQ_LENS:
|
||||||
|
passed = test_single_seq_len(seq_len)
|
||||||
|
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
|
||||||
|
results.append((seq_len, chunks, passed))
|
||||||
|
if not passed:
|
||||||
|
all_passed = False
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("SUMMARY")
|
||||||
|
print("=" * 60)
|
||||||
|
for seq_len, chunks, passed in results:
|
||||||
|
status = "PASSED" if passed else "FAILED"
|
||||||
|
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
if all_passed:
|
||||||
|
print("ALL TESTS PASSED!")
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
print("SOME TESTS FAILED!")
|
||||||
|
sys.exit(1)
|
||||||
Reference in New Issue
Block a user