feat: add xattn_estimate_chunked for chunked prefill support
- Add xattn_estimate_chunked function ported from COMPASS - Support chunked prefill with q_start_pos parameter - Ensure 100% consistency with standard xattn_estimate when using matching chunk_size parameter - Add test and documentation Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
99
docs/xattn_chunked_prefill.md
Normal file
99
docs/xattn_chunked_prefill.md
Normal file
@@ -0,0 +1,99 @@
|
||||
# XAttention Chunked Prefill
|
||||
|
||||
## 概述
|
||||
|
||||
`xattn_estimate_chunked` 提供了 XAttention 的 chunked prefill 支持,允许将长序列分块处理,适用于显存受限或需要与 decode 请求交错执行的场景。
|
||||
|
||||
## 核心设计
|
||||
|
||||
### Chunked Prefill 模式
|
||||
|
||||
```
|
||||
Full Prefill: Q[0:N] × K[0:N] → Output[0:N]
|
||||
|
||||
Chunked Prefill: Q[0:C] × K[0:C] → Output[0:C]
|
||||
Q[C:2C] × K[0:2C] → Output[C:2C]
|
||||
Q[2C:3C] × K[0:3C] → Output[2C:3C]
|
||||
...
|
||||
```
|
||||
|
||||
关键特点:
|
||||
- **Q 分块处理**:每次只处理一个 Q chunk
|
||||
- **K/V 累积**:K/V cache 随着 chunk 处理逐步累积
|
||||
- **位置感知**:通过 `q_start_pos` 参数传递当前 chunk 在原序列中的位置
|
||||
|
||||
## API
|
||||
|
||||
### xattn_estimate_chunked
|
||||
|
||||
```python
|
||||
def xattn_estimate_chunked(
|
||||
query_states: torch.Tensor, # (B, H, q_chunk_len, D) - 当前 Q chunk
|
||||
key_states: torch.Tensor, # (B, H, k_len, D) - 累积的完整 K
|
||||
q_start_pos: int, # 当前 chunk 在原序列中的起始位置
|
||||
block_size: int = 128, # 稀疏 attention 的 block 大小
|
||||
stride: int = 8, # 估计时的下采样步长
|
||||
threshold: float = 0.9, # block 选择阈值
|
||||
chunk_size: int = 16384, # Triton kernel 对齐大小
|
||||
use_triton: bool = True,
|
||||
causal: bool = True,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Returns:
|
||||
attn_sums: (B, H, q_blocks, k_blocks) - 每个 block 的 attention 分数
|
||||
simple_mask: (B, H, q_blocks, k_blocks) - 选中的 block mask
|
||||
"""
|
||||
```
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 外部分块(生产部署推荐)
|
||||
|
||||
由 LLM 框架控制 chunk 划分:
|
||||
|
||||
```python
|
||||
# 在 attention forward 中
|
||||
def forward(self, query, key, value, position_ids, kv_cache, ...):
|
||||
q_start_pos = position_ids[0].item()
|
||||
|
||||
# 估计 sparse pattern
|
||||
attn_sum, mask = xattn_estimate_chunked(
|
||||
query, kv_cache.key,
|
||||
q_start_pos=q_start_pos,
|
||||
block_size=128,
|
||||
stride=4,
|
||||
threshold=0.9,
|
||||
chunk_size=4096, # 必须与外部 chunk 大小匹配
|
||||
)
|
||||
|
||||
# 使用 mask 进行 sparse attention
|
||||
...
|
||||
```
|
||||
|
||||
### 一致性要求
|
||||
|
||||
**重要**:要实现 chunked 与 standard 版本 100% 一致,必须:
|
||||
|
||||
1. 标准版和 chunked 版使用**相同的 `chunk_size`** 参数
|
||||
2. 例如:`xattn_estimate(..., chunk_size=4096)` 和 `xattn_estimate_chunked(..., chunk_size=4096)`
|
||||
|
||||
## 与标准版的关系
|
||||
|
||||
| 函数 | 用途 |
|
||||
|------|------|
|
||||
| `xattn_estimate` | Full prefill 的 pattern 估计 |
|
||||
| `xattn_estimate_chunked` | Chunked prefill 的 pattern 估计 |
|
||||
|
||||
**一致性保证**:当 `chunk_size` 参数匹配时,`xattn_estimate_chunked` 与 `xattn_estimate` 产生**完全相同**的 mask。
|
||||
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_estimate_chunked.py
|
||||
```
|
||||
|
||||
## 验证结果
|
||||
|
||||
使用真实 QKV 数据(8K-64K 序列长度)测试:
|
||||
- 所有 chunk_size (2048, 4096, 8192) 均达到 100% 匹配
|
||||
Reference in New Issue
Block a user