Files
nano-vllm/task_plan.md
2026-01-19 22:34:44 +08:00

11 KiB
Raw Blame History

Task Plan: Sparse Policy 架构重构 v3

Goal

将 chunked prefill 的 attention 计算逻辑完全从 attention.py 移到 SparsePolicy 内部。attention.py 只负责调用 policy不包含任何计算逻辑。

核心设计原则(强制要求)

  1. Policy 内部完成所有计算:包括 attention 计算和结果合并
  2. select_blocks 传入 offload_enginepolicy 通过 offload_engine 加载 blocks
  3. 强制实现计算函数:所有 policy 必须实现 compute_block_attentionmerge_attention_outputs
  4. chunked_prefill 强制 policy 存在:没有 policy 则报错
  5. 外部默认 FULL policymodel_runner.py 默认创建 FullPolicy
  6. attention.py 零计算逻辑_chunked_prefill_attention 只调用 policy不直接调用 flashattn 或 merge

目标架构

model_runner.py:
  默认创建 FullPolicy如果没有指定 sparse policy

attention.py (_chunked_prefill_attention):
  检查 sparse_policy 是否存在
    ↓
  调用 sparse_policy.compute_prefill_attention(q, k, v, ...)
    ↓
  返回最终输出(不包含任何计算逻辑)

SparsePolicy.compute_prefill_attention():
  1. select_blocks(blocks, offload_engine, ctx) → 筛选 blocks
  2. 加载 blocks通过 offload_engine
  3. 遍历 blocks
     - 调用 self.compute_block_attention(q, k, v, ...)
     - 调用 self.merge_attention_outputs(...)
  4. 计算当前 chunk attention
  5. 合并最终结果
  6. 返回 final_output

关键设计决策

决策 说明
决策 1 compute_block_attention 是抽象方法,所有 policy 必须实现
决策 2 merge_attention_outputs 是抽象方法,所有 policy 必须实现
决策 3 compute_prefill_attention 是抽象方法,定义完整的 prefill 流程
决策 4 select_blocks 接收 offload_engine 参数(为未来准备)
决策 5 chunked_prefill 检查 policy 是否存在,不存在则抛出错误
决策 6 model_runner 默认创建 FullPolicy 作为兜底
决策 7 attention.py 的 _chunked_prefill_attention 不包含任何 flashattn 或 merge 调用

Phases

  • Phase 1: 分析当前架构,理解所有计算逻辑的位置
  • Phase 2: 在 SparsePolicy 基类中添加三个抽象方法
  • Phase 3: 修改 FullPolicy实现三个抽象方法
  • Phase 4: 修改 QuestPolicy实现三个抽象方法
  • Phase 5: 修改 XAttentionBSAPolicy实现三个抽象方法
  • Phase 6: 修改 model_runner.py默认创建 FullPolicy
  • Phase 7: 修改 attention.py移除所有计算逻辑只调用 policy
  • Phase 8: 测试验证

Phase 1: 分析当前架构,理解所有计算逻辑的位置

当前 attention.py 中包含的计算逻辑

  1. _ring_buffer_pipeline_load 方法:

    • 调用 offload_engine.load_to_slot_layer()
    • 调用 offload_engine.wait_slot_layer()
    • 调用 offload_engine.get_kv_for_slot()
    • 调用 flash_attn_with_lse()直接调用
    • 调用 merge_attention_outputs()直接调用
  2. _sync_load_previous_chunks 方法:

    • 同上,直接调用 flashattn 和 merge
  3. _chunked_prefill_attention 方法:

    • 调用 _ring_buffer_pipeline_load_sync_load_previous_chunks
    • 调用 flash_attn_with_lse() 计算当前 chunk
    • 调用 merge_attention_outputs() 合并结果

需要移动的计算逻辑

所有 flash_attn_with_lsemerge_attention_outputs 调用都应该在 SparsePolicy 内部。

Phase 2: 在 SparsePolicy 基类中添加三个抽象方法

2.1 compute_block_attention

@abstractmethod
def compute_block_attention(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    causal: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    计算单个 block 的 attention。

    Args:
        q: [1, seq_len, num_heads, head_dim] 或 [seq_len, num_heads, head_dim]
        k, v: 同上
        layer_id: 层索引
        softmax_scale: softmax 缩放因子
        causal: 是否应用因果掩码

    Returns:
        (o, lse) - attention 输出和 LSE
    """
    pass

2.2 merge_attention_outputs

@abstractmethod
def merge_attention_outputs(
    self,
    o_acc: torch.Tensor,
    lse_acc: Optional[torch.Tensor],
    o_new: torch.Tensor,
    lse_new: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    合并两个 attention 输出。

    Args:
        o_acc: 累积的 attention 输出 [1, seq_len, num_heads, head_dim]
        lse_acc: 累积的 LSE
        o_new: 新的 attention 输出
        lse_new: 新的 LSE

    Returns:
        (merged_o, merged_lse)
    """
    pass

2.3 compute_chunked_attention

@abstractmethod
def compute_chunked_attention(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    offload_engine: OffloadEngine,
    current_chunk_idx: int,
    seq: ChunkedSequence,
    num_tokens: int,
) -> torch.Tensor:
    """
    计算 chunked prefill attention完整流程
    这是 policy 的主入口,定义完整的 prefill 计算流程:
    1. 获取历史 blocks
    2. 筛选 blocks调用 select_blocks
    3. 加载和计算历史 blocks
    4. 计算当前 chunk attention
    5. 合并所有结果

    Args:
        q, k, v: 当前 chunk 的 QKV
        layer_id: 层索引
        softmax_scale: softmax 缩放因子
        offload_engine: offload engine
        current_chunk_idx: 当前 chunk 索引
        seq: chunked 序列
        num_tokens: 当前 chunk 的 token 数

    Returns:
        [seq_len, num_heads, head_dim] 最终 attention输出
    """
    pass

2.4 修改 select_blocks 接口

def select_blocks(
    self,
    available_blocks: List[int],
    offload_engine: OffloadEngine,
    ctx: PolicyContext,
) -> List[int]:
    """
    选择要加载的 blocks。

    Args:
        available_blocks: 所有可用的 block IDs
        offload_engine: offload engine为未来准备当前可能不使用
        ctx: policy context

    Returns:
        选择的 block IDs
    """
    pass

Phase 3: 修改 FullPolicy实现三个抽象方法

3.1 FullPolicy.compute_block_attention

直接调用 flash_attn_with_lse,处理 3D 输入。

3.2 FullPolicy.merge_attention_outputs

调用 chunked_attention.merge_attention_outputs

3.3 FullPolicy.compute_prefill_attention

实现完整的 prefill 流程:

  1. 获取 cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
  2. 调用 select_blocks(cpu_block_table, offload_engine, ctx)
  3. 遍历 blocks
    • offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
    • offload_engine.wait_slot_layer(slot)
    • k, v = offload_engine.get_kv_for_slot(slot)
    • 调用 self.compute_block_attention(q, k, v, layer_id, scale, causal=False)
    • 调用 self.merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
  4. 计算当前 chunk attention
  5. 合并最终结果

需要移动的代码

attention.py_ring_buffer_pipeline_load_sync_load_previous_chunks 移动逻辑:

  • slot 遍历逻辑
  • offload_engine 调用
  • 计算和合并逻辑

attention.py_chunked_prefill_attention 移动逻辑:

  • 当前 chunk 的 attention 计算
  • 最终合并逻辑

Phase 4: 修改 QuestPolicy

QuestPolicy 实现与 FullPolicy 类似,区别在于:

  • select_blocks 返回 Top-K blocks
  • 其他计算逻辑相同

Phase 5: 修改 XAttentionBSAPolicy

当前 XAttentionBSAPolicy 只返回所有 blocks修改后

  • select_blocks 当前返回所有 blocks
  • compute_block_attention 与 FullPolicy 相同
  • merge_attention_outputs 与 FullPolicy 相同
  • compute_prefill_attention 与 FullPolicy 相同

未来可以实现稀疏计算。

Phase 6: 修改 model_runner.py默认创建 FullPolicy

6.1 当前创建 sparse policy 的逻辑

# 当前:只有指定 sparse_policy_type 时才创建
if sparse_policy_type is not None:
    sparse_policy = create_sparse_policy(sparse_policy_type, **kwargs)

6.2 修改后

# 默认创建 FullPolicy
if sparse_policy_type is None:
    sparse_policy_type = SparsePolicyType.FULL

sparse_policy = create_sparse_policy(sparse_policy_type, **kwargs)

6.3 位置

model_runner.py 中的 allocate_kv_cache 方法。

Phase 7: 修改 attention.py移除所有计算逻辑

7.1 _chunked_prefill_attention 简化

当前(伪代码)

# 获取 cpu_block_table
# 调用 select_blocks
# 调用 _ring_buffer_pipeline_load包含计算逻辑
# 计算当前 chunkflash_attn
# 合并结果merge

修改后

sparse_policy = kvcache_manager.sparse_policy
if sparse_policy is None:
    raise RuntimeError("sparse_policy is required for chunked prefill")

o = sparse_policy.compute_prefill_attention(
    q, k, v, self.layer_id, self.scale,
    offload_engine, current_chunk_idx, seq, num_tokens
)

# 直接返回不需要合并policy 内部已完成所有计算)
return o

7.2 删除的方法

删除以下方法(逻辑移到 policy 中):

  • _ring_buffer_pipeline_load - 逻辑移到 FullPolicy.compute_prefill_attention
  • _sync_load_previous_chunks - 逻辑移到 FullPolicy.compute_prefill_attention

7.3 保留的方法

  • _decode_with_layer_pipeline - decode 逻辑保持不变
  • _decode_ring_buffer_pipeline - decode 逻辑保持不变

Phase 8: 测试验证

  • 运行 test_needle.py --enable-offload (FULL policy)
  • 验证输出正确 (needle value: 7492)
  • 验证性能无明显下降

关键文件清单

文件 修改内容
nanovllm/kvcache/sparse/policy.py 添加三个抽象方法,修改 select_blocks 签名
nanovllm/kvcache/sparse/full_policy.py 实现三个抽象方法,移动计算逻辑
nanovllm/kvcache/sparse/quest.py 实现三个抽象方法
nanovllm/kvcache/sparse/xattn_bsa.py 实现三个抽象方法
nanovllm/engine/model_runner.py 默认创建 FullPolicy
nanovllm/layers/attention.py 简化 _chunked_prefill_attention删除计算方法

Decisions Made

  • 决策 1: 三个方法都是抽象方法,强制所有 policy 实现
  • 决策 2: compute_prefill_attention 定义完整的 prefill 流程,是 policy 的主入口
  • 决策 3: attention.py 只调用 policy.compute_prefill_attention零计算逻辑
  • 决策 4: chunked_prefill 检查 policy 是否存在,不存在则抛出错误
  • 决策 5: model_runner 默认创建 FullPolicy 作为兜底
  • 决策 6: _ring_buffer_pipeline_load 和 _sync_load_previous_chunks 删除,逻辑移到 policy

Errors Encountered

(待记录)

Status

Currently in Phase 1 - 分析当前架构,理解所有计算逻辑的位置