- Add xattn_estimate_chunked function ported from COMPASS - Support chunked prefill with q_start_pos parameter - Ensure 100% consistency with standard xattn_estimate when using matching chunk_size parameter - Add test and documentation Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2.9 KiB
2.9 KiB
XAttention Chunked Prefill
概述
xattn_estimate_chunked 提供了 XAttention 的 chunked prefill 支持,允许将长序列分块处理,适用于显存受限或需要与 decode 请求交错执行的场景。
核心设计
Chunked Prefill 模式
Full Prefill: Q[0:N] × K[0:N] → Output[0:N]
Chunked Prefill: Q[0:C] × K[0:C] → Output[0:C]
Q[C:2C] × K[0:2C] → Output[C:2C]
Q[2C:3C] × K[0:3C] → Output[2C:3C]
...
关键特点:
- Q 分块处理:每次只处理一个 Q chunk
- K/V 累积:K/V cache 随着 chunk 处理逐步累积
- 位置感知:通过
q_start_pos参数传递当前 chunk 在原序列中的位置
API
xattn_estimate_chunked
def xattn_estimate_chunked(
query_states: torch.Tensor, # (B, H, q_chunk_len, D) - 当前 Q chunk
key_states: torch.Tensor, # (B, H, k_len, D) - 累积的完整 K
q_start_pos: int, # 当前 chunk 在原序列中的起始位置
block_size: int = 128, # 稀疏 attention 的 block 大小
stride: int = 8, # 估计时的下采样步长
threshold: float = 0.9, # block 选择阈值
chunk_size: int = 16384, # Triton kernel 对齐大小
use_triton: bool = True,
causal: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns:
attn_sums: (B, H, q_blocks, k_blocks) - 每个 block 的 attention 分数
simple_mask: (B, H, q_blocks, k_blocks) - 选中的 block mask
"""
使用方式
外部分块(生产部署推荐)
由 LLM 框架控制 chunk 划分:
# 在 attention forward 中
def forward(self, query, key, value, position_ids, kv_cache, ...):
q_start_pos = position_ids[0].item()
# 估计 sparse pattern
attn_sum, mask = xattn_estimate_chunked(
query, kv_cache.key,
q_start_pos=q_start_pos,
block_size=128,
stride=4,
threshold=0.9,
chunk_size=4096, # 必须与外部 chunk 大小匹配
)
# 使用 mask 进行 sparse attention
...
一致性要求
重要:要实现 chunked 与 standard 版本 100% 一致,必须:
- 标准版和 chunked 版使用相同的
chunk_size参数 - 例如:
xattn_estimate(..., chunk_size=4096)和xattn_estimate_chunked(..., chunk_size=4096)
与标准版的关系
| 函数 | 用途 |
|---|---|
xattn_estimate |
Full prefill 的 pattern 估计 |
xattn_estimate_chunked |
Chunked prefill 的 pattern 估计 |
一致性保证:当 chunk_size 参数匹配时,xattn_estimate_chunked 与 xattn_estimate 产生完全相同的 mask。
测试
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_chunked.py
验证结果
使用真实 QKV 数据(8K-64K 序列长度)测试:
- 所有 chunk_size (2048, 4096, 8192) 均达到 100% 匹配