nano-vllm/task_plan.md
2026-01-23 03:36:12 +08:00

# Task Plan: True Sparse Implementation of XAttention BSA
## Goal
Implement true sparse attention in XAttentionBSAPolicy: in `select_blocks`, use `xattn_estimate_chunked` to pick the important blocks, then reuse FullAttentionPolicy's ring buffer pipeline.
**Acceptance criteria**:
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--sparse-policy XATTN_BSA \
--datasets niah_single_1 \
--sample-indices 0,1,2,3,4
# Expected: 5/5 PASS, with sparse selection actually in effect
```
## Current Status: Phase 1 - Code Analysis Complete
## Core Design Understanding
### 1. Block Size Relationship
| Parameter | Value | Description |
|------|-----|------|
| BSA block_size | 128 tokens | Block granularity of XAttention |
| kvcache_block_size | 1024 tokens | Block granularity of CPU offload |
| Ratio | 1:8 | 1 CPU block = 8 BSA blocks |
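A minimal sketch of the 1:8 relationship in plain Python. `locate` and `bsa_blocks_of` are hypothetical helpers, and positions are assumed to be counted from the start of the offloaded history.

```python
from typing import Tuple

BSA_BLOCK_SIZE = 128       # XAttention block granularity (tokens)
KVCACHE_BLOCK_SIZE = 1024  # CPU offload block granularity (tokens)
RATIO = KVCACHE_BLOCK_SIZE // BSA_BLOCK_SIZE  # 8 BSA blocks per CPU block


def locate(token_pos: int) -> Tuple[int, int]:
    """Map a token position to (BSA block index, CPU block index)."""
    bsa_idx = token_pos // BSA_BLOCK_SIZE
    cpu_idx = bsa_idx // RATIO  # equals token_pos // KVCACHE_BLOCK_SIZE
    return bsa_idx, cpu_idx


def bsa_blocks_of(cpu_idx: int) -> range:
    """BSA block indices covered by one CPU block."""
    return range(cpu_idx * RATIO, (cpu_idx + 1) * RATIO)


print(locate(1023), locate(1024))  # (7, 0) (8, 1)
print(list(bsa_blocks_of(1)))      # [8, 9, 10, 11, 12, 13, 14, 15]
```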
### 2. Specialization Conditions (User Requirement)
- BSA chunk_size = the external chunk_size
- This way, the mask returned by `xattn_estimate_chunked` maps directly onto CPU block selection
- Reuse the existing `flash_attn_with_lse` + `merge_attention_outputs`
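Why dropping CPU blocks combines cleanly with `flash_attn_with_lse` + `merge_attention_outputs`: attention over disjoint key sets can be merged exactly by log-sum-exp (LSE) rescaling, so skipping unselected blocks simply removes their terms. The sketch below shows that standard rescaling for a single query vector, with numpy standing in for torch. `partial_attention` and `merge_partial` are hypothetical names, not the project's `merge_attention_outputs` (whose signature is not shown in this plan).

```python
import numpy as np


def partial_attention(q, k, v):
    """Softmax attention of one query over a subset of keys.

    Returns the normalized output over that subset and its log-sum-exp (LSE).
    """
    scores = k @ q / np.sqrt(q.shape[-1])   # [n_keys]
    lse = np.logaddexp.reduce(scores)       # log(sum(exp(scores)))
    out = np.exp(scores - lse) @ v          # [d]
    return out, lse


def merge_partial(o1, lse1, o2, lse2):
    """Merge two partial results computed over disjoint key sets."""
    lse = np.logaddexp(lse1, lse2)
    out = np.exp(lse1 - lse) * o1 + np.exp(lse2 - lse) * o2
    return out, lse


rng = np.random.default_rng(0)
q = rng.standard_normal(64)
k = rng.standard_normal((10, 64))
v = rng.standard_normal((10, 64))

full, full_lse = partial_attention(q, k, v)
o1, l1 = partial_attention(q, k[:6], v[:6])   # e.g. one selected CPU block
o2, l2 = partial_attention(q, k[6:], v[6:])   # e.g. another selected CPU block
merged, merged_lse = merge_partial(o1, l1, o2, l2)
print(np.allclose(merged, full), np.isclose(merged_lse, full_lse))  # True True
```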
### 3. select_blocks Design
```
select_blocks(available_blocks, offload_engine, ctx) -> List[int]
├─ 1. Fetch the downsampled K from the metadata cache
│    (collected in on_prefill_offload)
├─ 2. Call xattn_estimate_chunked(Q, K_downsampled, q_start_pos)
│    Returns mask: [B, H, q_blocks, k_blocks]
├─ 3. Map BSA k_blocks to CPU block IDs
│    Every 8 BSA blocks = 1 CPU block
│    Keep a CPU block if any of its 8 BSA blocks is selected
└─ 4. Return selected_cpu_blocks
```
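Steps 3 and 4, together with the head and query-block reduction that precedes them, can be sketched in numpy as below. `mask_to_cpu_blocks` is a hypothetical helper; it assumes a boolean mask laid out as [1, H, q_blocks, k_blocks], with BSA key blocks ordered as in `available_blocks`.

```python
from typing import List

import numpy as np


def mask_to_cpu_blocks(mask: np.ndarray, available_blocks: List[int], ratio: int = 8) -> List[int]:
    """Reduce a boolean BSA mask [1, H, q_blocks, k_blocks] to CPU block IDs.

    OR over heads and query blocks, then keep a CPU block if any of its
    `ratio` BSA key blocks is selected.
    """
    k_selected = mask.any(axis=(0, 1, 2))        # [k_blocks]
    keep = set()
    for bsa_idx in np.flatnonzero(k_selected):
        cpu_idx = int(bsa_idx) // ratio
        if cpu_idx < len(available_blocks):      # ignore indices past the history
            keep.add(cpu_idx)
    return [available_blocks[i] for i in sorted(keep)]


# 3 CPU blocks of history -> 24 BSA key blocks; 2 heads, 3 query blocks.
mask = np.zeros((1, 2, 3, 24), dtype=bool)
mask[0, 1, 2, 3] = True    # BSA block 3  -> CPU index 0
mask[0, 0, 0, 17] = True   # BSA block 17 -> CPU index 2
print(mask_to_cpu_blocks(mask, available_blocks=[10, 11, 12]))  # [10, 12]
```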
### 4. Metadata Storage Strategy
**Option A**: Store the downsampled K (memory-friendly)
```python
# In on_prefill_offload:
k_downsampled = k_cache[::stride] # [block_size/stride, H, D]
self._k_cache[layer_id][cpu_block_id] = k_downsampled
```
**Memory estimate** (stride=8):
- Per block: (1024/8 tokens) * 8 KV heads * 128 head_dim * 2 bytes = 256 KB
- 256 blocks * 32 layers * 256 KB = 2 GB (kept on the GPU for fast estimation)
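The arithmetic above can be checked with a few lines of Python. The defaults are the plan's figures; 8 KV heads and head_dim=128 match Llama-3.1-8B, and 2 bytes per element assumes fp16/bf16 K. `k_metadata_bytes` is a hypothetical helper.

```python
def k_metadata_bytes(block_size=1024, stride=8, num_kv_heads=8, head_dim=128,
                     bytes_per_elem=2, num_blocks=256, num_layers=32):
    """Bytes of downsampled-K metadata per CPU block and in total (Option A)."""
    per_block = (block_size // stride) * num_kv_heads * head_dim * bytes_per_elem
    return per_block, per_block * num_blocks * num_layers


per_block, total = k_metadata_bytes()
print(per_block / 2**10, "KiB per block")  # 256.0 KiB per block
print(total / 2**30, "GiB total")          # 2.0 GiB total
```

Doubling `stride` to 16 halves both figures, at the cost of a coarser estimate.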
**Option B**: Store min/max metadata (even more memory-efficient)
```python
# In on_prefill_offload:
k_min = k_cache[:num_valid].min(dim=0).values # [H, D]
k_max = k_cache[:num_valid].max(dim=0).values # [H, D]
```
- However, this requires a different estimation algorithm and cannot use xattn_estimate directly
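For comparison, one common estimator for min/max metadata is a Quest-style upper bound on q·k per block; naming it here is an assumption about what "a different estimation algorithm" would look like, not something this plan specifies. It needs only [H, D] per block, but it yields a bound per block instead of consuming K tokens, so it cannot be plugged into xattn_estimate_chunked. `minmax_upper_bound` is a hypothetical helper, with numpy standing in for torch.

```python
import numpy as np


def minmax_upper_bound(q: np.ndarray, k_min: np.ndarray, k_max: np.ndarray) -> float:
    """Upper bound on q·k for every key k in a block with k_min <= k <= k_max.

    q, k_min, k_max: [D] for one head. Per dimension, q_d * k_d is linear in
    k_d, so its maximum over [k_min_d, k_max_d] is at one of the two bounds.
    """
    return float(np.maximum(q * k_min, q * k_max).sum())


rng = np.random.default_rng(0)
keys = rng.standard_normal((1024, 128))   # one block, one head
q = rng.standard_normal(128)
bound = minmax_upper_bound(q, keys.min(axis=0), keys.max(axis=0))
print(bound >= (keys @ q).max())          # True
```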
**Decision**: Use Option A (downsampled K), because it allows reusing xattn_estimate_chunked directly
## Phases
- [x] Phase 1: Code analysis; understand the current implementation
- [ ] Phase 2: Implement K metadata collection in on_prefill_offload
- [ ] Phase 3: Implement xattn estimation in select_blocks
- [ ] Phase 4: Implement the BSA block → CPU block mapping
- [ ] Phase 5: Testing and validation
## Phase 2: on_prefill_offload Implementation
### Files to Modify
- `nanovllm/kvcache/sparse/xattn_bsa.py`
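### Implementation Details
```python
from typing import Dict

import torch


class XAttentionBSAPolicy(SparsePolicy):
    BSA_BLOCK_SIZE = 128  # XAttention block granularity (tokens)

    def __init__(self, threshold=0.9, stride=8, ...):
        self.threshold = threshold
        self.stride = stride
        # _k_cache[layer_id][cpu_block_id] = k_downsampled
        self._k_cache: Dict[int, Dict[int, torch.Tensor]] = {}

    def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
        """Initialize the per-layer K cache structure."""
        self._k_cache = {layer_id: {} for layer_id in range(num_layers)}
        self._num_kv_heads = num_kv_heads
        self._head_dim = head_dim

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        """Collect the downsampled K for later estimation."""
        # k_cache: [block_size, num_kv_heads, head_dim]
        # The strided slice is a view; clone so the stored metadata does not
        # alias k_cache, which may be overwritten later.
        k_downsampled = k_cache[:num_valid_tokens:self.stride].clone()
        # k_downsampled: [ceil(num_valid_tokens / stride), num_kv_heads, head_dim]
        self._k_cache[layer_id][cpu_block_id] = k_downsampled

    def reset(self):
        """Drop cached K metadata (it lives on the GPU)."""
        for layer_cache in self._k_cache.values():
            layer_cache.clear()
```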
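Strided slicing keeps rows 0, stride, 2*stride, ..., so a partially filled block produces ceil(num_valid_tokens / stride) rows rather than num_valid_tokens // stride. Only a full 1024-token block yields exactly 128 rows (8 BSA blocks of 16 downsampled tokens). The check below uses a Python list in place of the tensor, since positive-step slicing behaves the same way; `downsampled_rows` is a hypothetical helper.

```python
import math


def downsampled_rows(num_valid_tokens: int, stride: int, block_size: int = 1024) -> int:
    """Rows kept by k_cache[:num_valid_tokens:stride], using a list as the token axis."""
    token_axis = list(range(block_size))          # stand-in for k_cache's first dim
    kept = token_axis[:num_valid_tokens:stride]   # same slice as on_prefill_offload
    return len(kept)


for n in (1024, 1000, 1001):
    print(n, downsampled_rows(n, 8), math.ceil(n / 8))
# 1024 128 128
# 1000 125 125
# 1001 126 126
```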
## Phase 3: select_blocks Implementation
### Key Questions
1. **Where does Q come from?**
   - `ctx.query` must be passed in when select_blocks is called
   - FullAttentionPolicy currently passes `query=None`
   - compute_chunked_prefill must be modified to pass the real Q
2. **Q format conversion**
   - Input Q: [seq_len, num_heads, head_dim]
   - xattn expects: [B, H, q_len, D]
   - Conversion: `q.unsqueeze(0).transpose(1, 2)`
3. **Assembling K**
   - Fetch each block's downsampled K from `_k_cache[layer_id]`
   - Concatenate them in `available_blocks` order
   - Result: [B, H, total_k_downsampled, D]
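The shape bookkeeping for Q, the concatenated K, and the GQA expansion can be traced with numpy stand-ins: `np.swapaxes` plays the role of `transpose(1, 2)`, and `np.repeat` along an axis matches `torch.repeat_interleave`. Head counts follow Llama-3.1-8B (32 query heads, 8 KV heads); the three history blocks are an arbitrary example.

```python
import numpy as np

seq_len, num_heads, num_kv_heads, head_dim = 1024, 32, 8, 128
stride, num_hist_blocks, block_size = 8, 3, 1024

# Q: [seq_len, num_heads, head_dim] -> [1, H, q_len, D]
q = np.zeros((seq_len, num_heads, head_dim), dtype=np.float16)
q4 = np.swapaxes(q[None], 1, 2)

# Per-block downsampled K, concatenated along the token axis in available_blocks order
k_blocks = [np.zeros((block_size // stride, num_kv_heads, head_dim), dtype=np.float16)
            for _ in range(num_hist_blocks)]
k_hist = np.swapaxes(np.concatenate(k_blocks, axis=0)[None], 1, 2)  # [1, H_kv, k_len, D]

# GQA: repeat each KV head for its group of query heads
k_hist = np.repeat(k_hist, num_heads // num_kv_heads, axis=1)

print(q4.shape, k_hist.shape)  # (1, 32, 1024, 128) (1, 32, 384, 128)
```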
### Draft Implementation
```python
import logging

import torch

# xattn_estimate_chunked comes from the project's XAttention module
# (its import path is not specified in this plan).
logger = logging.getLogger(__name__)


def select_blocks(self, available_blocks, offload_engine, ctx):
    # First chunk (no history) or no Q available: nothing to estimate.
    if not available_blocks or ctx.query is None:
        return available_blocks
    layer_id = ctx.layer_id

    # 1. Assemble the downsampled K.
    # Every block must have metadata: skipping a missing block would shift the
    # BSA-index -> CPU-block mapping in step 6, so fall back to all blocks.
    layer_cache = self._k_cache.get(layer_id, {})
    if any(cpu_block_id not in layer_cache for cpu_block_id in available_blocks):
        return available_blocks
    k_list = [layer_cache[cpu_block_id] for cpu_block_id in available_blocks]
    k_hist = torch.cat(k_list, dim=0)              # [total_tokens/stride, H_kv, D]
    k_hist = k_hist.unsqueeze(0).transpose(1, 2)   # [1, H_kv, k_len, D]

    # 2. Prepare Q.
    q = ctx.query                                  # [seq_len, num_heads, head_dim]
    q = q.unsqueeze(0).transpose(1, 2)             # [1, H, q_len, D]

    # GQA expansion (if needed): repeat each KV head for its query-head group.
    if q.shape[1] != k_hist.shape[1]:
        num_groups = q.shape[1] // k_hist.shape[1]
        k_hist = k_hist.repeat_interleave(num_groups, dim=1)

    # 3. Compute q_start_pos (assumes every historical block is full).
    q_start_pos = len(available_blocks) * ctx.block_size

    # 4. Call xattn_estimate_chunked.
    # Note: K is already downsampled, so the parameters are adjusted accordingly.
    attn_sum, mask = xattn_estimate_chunked(
        q, k_hist,
        q_start_pos=q_start_pos // self.stride,        # shift into downsampled space
        block_size=self.BSA_BLOCK_SIZE // self.stride,  # 128 // 8 = 16
        stride=1,                                       # K is already downsampled
        threshold=self.threshold,
        chunk_size=q.shape[2],                          # match the Q length
        use_triton=self.use_triton,
    )

    # 5. Extract CPU block IDs from the mask.
    # mask: [1, H, q_blocks, k_blocks]
    # OR across all heads
    selected_mask = mask.any(dim=1).squeeze(0)     # [q_blocks, k_blocks]
    # OR across all q_blocks: keep a K block if any Q position needs it
    selected_k_mask = selected_mask.any(dim=0)     # [k_blocks]

    # 6. Map BSA blocks -> CPU blocks.
    # Each CPU block = 8 BSA blocks (block_size=1024, BSA block=128)
    bsa_to_cpu_ratio = ctx.block_size // self.BSA_BLOCK_SIZE  # 8
    num_cpu_blocks = len(available_blocks)
    selected_cpu_indices = set()
    for bsa_idx in selected_k_mask.nonzero(as_tuple=True)[0].tolist():
        cpu_idx = bsa_idx // bsa_to_cpu_ratio
        if cpu_idx < num_cpu_blocks:
            selected_cpu_indices.add(cpu_idx)

    selected_blocks = [available_blocks[i] for i in sorted(selected_cpu_indices)]
    logger.info(f"[XAttn] select_blocks: {len(available_blocks)} -> {len(selected_blocks)} "
                f"({100 * len(selected_blocks) / len(available_blocks):.1f}%)")
    return selected_blocks
```
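How `threshold` shapes the mask: assuming xattn_estimate_chunked follows the XAttention paper's rule (for each query block, keep the smallest set of key blocks whose estimated scores cover a `threshold` fraction of the total), the selection for one row looks like the sketch below. It illustrates the rule only; it is not the code inside xattn_estimate_chunked, and `select_by_threshold` is a hypothetical name. A higher threshold keeps more blocks, so threshold=0.9 gives up some sparsity in exchange for accuracy.

```python
import numpy as np


def select_by_threshold(block_scores: np.ndarray, threshold: float) -> np.ndarray:
    """Keep the highest-scoring blocks until they cover `threshold` of the total.

    block_scores: [k_blocks] non-negative estimated attention mass for one
    (head, query block) row. Returns a boolean keep-mask over k_blocks.
    """
    probs = block_scores / block_scores.sum()
    order = np.argsort(-probs, kind="stable")          # highest score first
    cum = np.cumsum(probs[order])
    n_keep = int(np.searchsorted(cum, threshold)) + 1  # smallest prefix with cum >= threshold
    keep = np.zeros(len(block_scores), dtype=bool)
    keep[order[:n_keep]] = True                        # slicing clamps n_keep to len(order)
    return keep


scores = np.array([1.0, 8.0, 2.0, 4.0, 1.0])
print(select_by_threshold(scores, 0.7).tolist())  # [False, True, False, True, False]
print(select_by_threshold(scores, 1.0).tolist())  # [True, True, True, True, True]
```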
## Phase 4: compute_chunked_prefill
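### Key Changes
1. **Pass the real Q to select_blocks**
   - Modify the PolicyContext construction to set `query=q`
2. **Reuse FullAttentionPolicy's pipeline**
   - Inherit from FullAttentionPolicy instead of SparsePolicy
   - Or call the parent-class methods directly
### Option Comparison
**Option A**: XAttentionBSAPolicy inherits from FullAttentionPolicy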
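```python
class XAttentionBSAPolicy(FullAttentionPolicy):
    # Only select_blocks and on_prefill_offload need to be overridden;
    # compute_chunked_prefill is inherited unchanged from the parent class.
    def select_blocks(self, available_blocks, offload_engine, ctx): ...

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens): ...
```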
**Option B**: Standalone implementation that runs the same pipeline code
```python
class XAttentionBSAPolicy(SparsePolicy):
def compute_chunked_prefill(self, q, k, v, ...):
        # Copy FullAttentionPolicy's code,
        # but change the PolicyContext so that it passes query=q
```
**Decision**: Use Option B, because the PolicyContext construction inside compute_chunked_prefill has to change (to pass `query=q`)
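Under Option B, the only difference from the copied pipeline is how the context is built. The dataclass below is a hypothetical stand-in for PolicyContext with just the fields the select_blocks draft reads (`layer_id`, `block_size`, `query`); the real class and its constructor are not shown in this plan, so its fields may differ. `with_query` is likewise a hypothetical helper.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass
class PolicyContext:
    """Hypothetical stand-in: only the fields the select_blocks draft reads."""
    layer_id: int
    block_size: int
    query: Optional[Any] = None  # FullAttentionPolicy currently passes None


def with_query(ctx: PolicyContext, q: Any) -> PolicyContext:
    """Return a copy of ctx carrying the current chunk's Q for select_blocks."""
    return replace(ctx, query=q)


base = PolicyContext(layer_id=0, block_size=1024)
ctx = with_query(base, q="Q tensor for this chunk")
print(base.query, ctx.query is not None, ctx.layer_id)  # None True 0
```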
## Phase 5: Testing
### Unit Tests
```bash
# Test the sparsity of select_blocks
python -c "
from nanovllm.kvcache.sparse.xattn_bsa import XAttentionBSAPolicy
policy = XAttentionBSAPolicy(threshold=0.9)
# ... test code
"
```
### Integration Test
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--sparse-policy XATTN_BSA \
--datasets niah_single_1 \
--sample-indices 0,1,2,3,4
```
## Key Decisions
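| Decision | Rationale |
|------|------|
| Use downsampled K as metadata | Allows reusing xattn_estimate_chunked directly |
| stride=8 | Balances memory and accuracy |
| Map BSA blocks → CPU blocks with OR | Keep a CPU block if any of its BSA blocks is selected |
| Reuse FullAttentionPolicy's pipeline | Reuses the already-validated ring buffer flow |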
## Files to Modify
| File | Changes |
|------|------|
| `nanovllm/kvcache/sparse/xattn_bsa.py` | Main implementation: initialize, on_prefill_offload, select_blocks, compute_chunked_prefill (Option B) |
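## Notes
1. **GQA handling**: Llama-3.1-8B has 32 query heads and 8 KV heads, so K must be expanded during estimation
2. **Memory management**: `_k_cache` lives on the GPU and must be cleared in reset()
3. **Triton compatibility**: xattn_estimate_chunked has a Triton bug; a PyTorch fallback may be needed
4. **Boundary condition**: for the first chunk (available_blocks=[]), return an empty list directly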
## Errors Encountered
(To be filled in)
## Status
**Currently in Phase 1** - code analysis complete; ready to start the Phase 2 implementation