Files
nano-vllm/task_plan.md
2026-01-23 03:36:12 +08:00

9.2 KiB
Raw Blame History

Task Plan: XAttention BSA 真正的 Sparse 实现

Goal

实现 XAttentionBSAPolicy 的真正 sparse attentionselect_blocks 中使用 xattn_estimate_chunked 选择重要的 blocks然后复用 FullAttentionPolicy 的 ring buffer pipeline。

验收标准:

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
    python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --datasets niah_single_1 \
    --sample-indices 0,1,2,3,4
# 期望: 5/5 PASS并且真正使用 sparse selection

当前状态: Phase 1 - 代码分析完成

核心设计理解

1. Block Size 关系

参数 说明
BSA block_size 128 tokens XAttention 的 block 粒度
kvcache_block_size 1024 tokens CPU offload 的 block 粒度
比例 1:8 1 CPU block = 8 BSA blocks

2. 特化条件(用户要求)

  • BSA chunk_size = 外部 chunk_size
  • 这样 xattn_estimate_chunked 返回的 mask 可以直接映射到 CPU block selection
  • 复用现有的 flash_attn_with_lse + merge_attention_outputs

3. select_blocks 设计

select_blocks(available_blocks, offload_engine, ctx) -> List[int]
    │
    ├─ 1. 从 metadata cache 获取下采样的 K
    │      (在 on_prefill_offload 中收集)
    │
    ├─ 2. 调用 xattn_estimate_chunked(Q, K_downsampled, q_start_pos)
    │      返回 mask: [B, H, q_blocks, k_blocks]
    │
    ├─ 3. 将 BSA k_blocks 映射到 CPU block IDs
    │      每 8 个 BSA blocks = 1 CPU block
    │      只要 8 个中有任意一个被选中,就保留该 CPU block
    │
    └─ 4. 返回 selected_cpu_blocks

4. Metadata 存储策略

方案 A: 存储下采样的 K内存友好

# on_prefill_offload 中:
k_downsampled = k_cache[::stride]  # [block_size/stride, H, D]
self._k_cache[layer_id][cpu_block_id] = k_downsampled

内存计算 (stride=8):

  • 每 block: (1024/8) * 8 * 128 * 2 bytes = 256 KB
  • 256 blocks * 32 layers = 2 GB (GPU 上用于快速估计)

方案 B: 存储 min/max metadata (更省内存)

# on_prefill_offload 中:
k_min = k_cache[:num_valid].min(dim=0).values  # [H, D]
k_max = k_cache[:num_valid].max(dim=0).values  # [H, D]
  • 但这需要不同的估计算法,不能直接用 xattn_estimate

决定: 使用方案 A下采样 K因为可以直接复用 xattn_estimate_chunked

Phases

  • Phase 1: 代码分析,理解当前实现
  • Phase 2: 实现 on_prefill_offload 收集 K metadata
  • Phase 3: 实现 select_blocks 中的 xattn estimation
  • Phase 4: 实现 BSA block → CPU block 的映射
  • Phase 5: 测试验证

Phase 2: on_prefill_offload 实现

需要修改的文件

  • nanovllm/kvcache/sparse/xattn_bsa.py

实现细节

class XAttentionBSAPolicy(SparsePolicy):
    def __init__(self, threshold=0.9, stride=8, ...):
        self.threshold = threshold
        self.stride = stride
        self._k_cache: Dict[int, Dict[int, torch.Tensor]] = {}
        # _k_cache[layer_id][cpu_block_id] = k_downsampled

    def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
        """初始化 K cache 结构"""
        self._k_cache = {layer_id: {} for layer_id in range(num_layers)}
        self._num_kv_heads = num_kv_heads
        self._head_dim = head_dim

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        """收集下采样的 K 用于后续估计"""
        # k_cache: [block_size, num_kv_heads, head_dim]
        k_downsampled = k_cache[:num_valid_tokens:self.stride].clone()
        # k_downsampled: [num_valid_tokens//stride, num_kv_heads, head_dim]
        self._k_cache[layer_id][cpu_block_id] = k_downsampled

Phase 3: select_blocks 实现

关键问题

  1. Q 从哪里来?

    • ctx.query 需要在调用 select_blocks 时传入
    • 当前 FullAttentionPolicy 传递 query=None
    • 需要修改 compute_chunked_prefill 传递真实的 Q
  2. Q 的格式转换

    • 输入 Q: [seq_len, num_heads, head_dim]
    • xattn 需要: [B, H, q_len, D]
    • 转换: q.unsqueeze(0).transpose(1, 2)
  3. K 的组装

    • _k_cache[layer_id] 获取各 block 的下采样 K
    • available_blocks 顺序 cat 起来
    • 结果: [B, H, total_k_downsampled, D]

实现草案

def select_blocks(self, available_blocks, offload_engine, ctx):
    if not available_blocks or ctx.query is None:
        return available_blocks

    layer_id = ctx.layer_id

    # 1. 组装下采样的 K
    k_list = []
    for cpu_block_id in available_blocks:
        if cpu_block_id in self._k_cache[layer_id]:
            k_list.append(self._k_cache[layer_id][cpu_block_id])

    if not k_list:
        return available_blocks

    k_hist = torch.cat(k_list, dim=0)  # [total_tokens/stride, H, D]
    k_hist = k_hist.unsqueeze(0).transpose(1, 2)  # [1, H, k_len, D]

    # 2. 准备 Q
    q = ctx.query  # [seq_len, num_heads, head_dim]
    q = q.unsqueeze(0).transpose(1, 2)  # [1, H, q_len, D]

    # GQA 扩展(如果需要)
    if q.shape[1] != k_hist.shape[1]:
        num_groups = q.shape[1] // k_hist.shape[1]
        k_hist = k_hist.repeat_interleave(num_groups, dim=1)

    # 3. 计算 q_start_pos
    q_start_pos = len(available_blocks) * ctx.block_size

    # 4. 调用 xattn_estimate_chunked
    # 注意K 已经是下采样的,需要调整参数
    attn_sum, mask = xattn_estimate_chunked(
        q, k_hist,
        q_start_pos=q_start_pos // self.stride,  # 调整到下采样空间
        block_size=self.BSA_BLOCK_SIZE // self.stride,  # 16
        stride=1,  # K 已经下采样
        threshold=self.threshold,
        chunk_size=q.shape[2],  # 与 Q 长度一致
        use_triton=self.use_triton,
    )

    # 5. 从 mask 提取 CPU block IDs
    # mask: [1, H, q_blocks, k_blocks]
    # 对所有 heads 取 OR
    selected_mask = mask.any(dim=1).squeeze(0)  # [q_blocks, k_blocks]
    # 对所有 q_blocks 取 OR只要任意 Q 位置需要这个 K block
    selected_k_mask = selected_mask.any(dim=0)  # [k_blocks]

    # 6. 映射 BSA blocks → CPU blocks
    # 每个 CPU block = 8 BSA blocks (block_size=1024, BSA_block=128)
    bsa_to_cpu_ratio = ctx.block_size // self.BSA_BLOCK_SIZE  # 8
    num_cpu_blocks = len(available_blocks)

    selected_cpu_indices = set()
    for bsa_idx in selected_k_mask.nonzero(as_tuple=True)[0].tolist():
        cpu_idx = bsa_idx // bsa_to_cpu_ratio
        if cpu_idx < num_cpu_blocks:
            selected_cpu_indices.add(cpu_idx)

    selected_blocks = [available_blocks[i] for i in sorted(selected_cpu_indices)]

    logger.info(f"[XAttn] select_blocks: {len(available_blocks)} -> {len(selected_blocks)} "
                f"({100*len(selected_blocks)/len(available_blocks):.1f}%)")

    return selected_blocks

Phase 4: compute_chunked_prefill

关键修改

  1. 传递真实的 Q 给 select_blocks

    • 修改 PolicyContext 构造,设置 query=q
  2. 复用 FullAttentionPolicy 的 pipeline

    • 继承 FullAttentionPolicy 而不是 SparsePolicy
    • 或者直接调用父类方法

方案对比

方案 A: XAttentionBSAPolicy 继承 FullAttentionPolicy

class XAttentionBSAPolicy(FullAttentionPolicy):
    # 只需要 override select_blocks 和 on_prefill_offload
    # compute_chunked_prefill 直接用父类的

方案 B: 独立实现,调用相同的 pipeline 代码

class XAttentionBSAPolicy(SparsePolicy):
    def compute_chunked_prefill(self, q, k, v, ...):
        # 复制 FullAttentionPolicy 的代码
        # 但修改 PolicyContext 传递 query=q

决定: 使用方案 B因为需要在 compute_chunked_prefill 中修改 PolicyContext

Phase 5: 测试

单元测试

# 测试 select_blocks 的 sparsity
python -c "
from nanovllm.kvcache.sparse.xattn_bsa import XAttentionBSAPolicy
policy = XAttentionBSAPolicy(threshold=0.9)
# ... 测试代码
"

集成测试

CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
    python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --datasets niah_single_1 \
    --sample-indices 0,1,2,3,4

Key Decisions

决策 理由
使用下采样 K 作为 metadata 可以直接复用 xattn_estimate_chunked
stride=8 平衡内存和精度
BSA blocks → CPU blocks 映射用 OR 只要有一个 BSA block 被选中就保留
继承 FullAttentionPolicy 的 pipeline 复用已验证的 ring buffer 流程

Files to Modify

文件 修改
nanovllm/kvcache/sparse/xattn_bsa.py 主要实现initialize, on_prefill_offload, select_blocks

注意事项

  1. GQA 处理: Llama-3.1-8B 有 32 query heads, 8 kv heads需要在估计时扩展 K
  2. 内存管理: _k_cache 存储在 GPU需要在 reset() 时清理
  3. Triton 兼容性: xattn_estimate_chunked 有 Triton bug可能需要用 PyTorch fallback
  4. 边界条件: 第一个 chunk (available_blocks=[]) 时直接返回空列表

Errors Encountered

(待填充)

Status

Currently in Phase 1 - 代码分析完成,准备开始 Phase 2 实现