- Update task_plan.md with 6-phase segmented graph implementation plan
- Add findings.md documenting 7 key discoveries about current implementation
- Add progress.md for tracking implementation progress
- Add test_chunk_attention_graph_reuse.py validating 2-graph reuse strategy

Key architecture decision: split the transformer layer into 3 segments:
- PRE-ATTENTION GRAPH: norm → qkv_proj → rotary (1 graph, reused)
- CHUNKED ATTENTION: H2D (eager) + flash_attn (2 graphs) + merge (eager)
- POST-ATTENTION GRAPH: o_proj → norm → FFN (1 graph, reused)

Total: 4 graphs serving all layers via copy_() tensor updates.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
# Task Plan: CUDA Graph Optimization for Offload Mode Decode

## Goal

Add CUDA Graph support to nanovllm's CPU offload mode to accelerate decode-phase computation.
## Problem Analysis

### Full structure of a transformer layer

```
Qwen3DecoderLayer.forward:
├── input_layernorm (RMSNorm)           # ✅ pure GPU
├── self_attn:
│   ├── qkv_proj (Linear)               # ✅ pure GPU
│   ├── q_norm, k_norm (RMSNorm)        # ✅ pure GPU
│   ├── rotary_emb                      # ✅ pure GPU
│   ├── attn._chunked_decode_attention: # ⚠️ contains CPU→GPU
│   │   ├── H2D transfer                # ❌ cannot be graphed
│   │   ├── flash_attn_with_lse         # ✅ graphable
│   │   └── merge                       # ✅ pure GPU
│   └── o_proj (Linear)                 # ✅ pure GPU
├── post_attention_layernorm            # ✅ pure GPU
└── mlp (FFN: gate, up, down)           # ✅ pure GPU
```

Core problem: the H2D transfer is embedded in the middle of attention, which breaks CUDA graph capture for the layer as a whole.
## Candidate Approaches

| Approach | Description | Pros | Cons |
|---|---|---|---|
| A. Segmented graphs | Split each layer into pre/post-attention segments | Broad coverage | Large change; layer execution must be split |
| B. Graph attention only | Optimize only flash_attn_with_lse | Small change | Limited benefit |
| C. Rework the execution flow | Fully rewrite the model forward | Best result | Huge effort |
## Recommendation: Approach A (segmented graphs)

Split each layer into two graphs:

- pre_attention_graph: norm → qkv_proj → q/k_norm → rotary
- post_attention_graph: o_proj → norm → FFN

The _chunked_decode_attention in between stays eager (it contains the H2D transfer), but the flash_attn_with_lse calls inside it use graph replay.
## Current State

### Existing CUDA Graph implementation

File: nanovllm/engine/model_runner.py

| Method | Lines | Purpose |
|---|---|---|
| capture_cudagraph() | 682-717 | Capture the full model forward for each batch size |
| run_model() | 415-436 | Decide between eager execution and graph replay |

Key logic (run_model):

```python
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
```

Problem: run_chunked_offload_decode sets is_chunked_prefill=True, so offload decode always takes the eager path.
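The dispatch condition above can be restated as a standalone function to make the failure mode explicit. This is a schematic restatement for illustration only; the real check lives inside `run_model`, and the names are taken from the snippet above:

```python
def use_eager(is_prefill: bool, enforce_eager: bool, batch_size: int,
              is_chunked_prefill: bool) -> bool:
    """Mirror of the run_model dispatch: any of these conditions forces eager."""
    return is_prefill or enforce_eager or batch_size > 512 or is_chunked_prefill

# Offload decode currently sets is_chunked_prefill=True, so even a
# single-token decode step falls through to eager:
assert use_eager(False, False, 1, True) is True
# A plain small-batch decode without that flag would take the graph path:
assert use_eager(False, False, 1, False) is False
```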
## Offload Decode Flow

File: nanovllm/kvcache/sparse/full_policy.py

_decode_ring_buffer_pipeline() (L304-379):

```
for block in cpu_blocks:
    1. wait_slot_layer(slot)           # wait for H2D to finish
    2. k, v = get_kv_for_slot(slot)    # fetch KV for the slot
    3. o, lse = flash_attn_with_lse()  # ⭐ pure GPU compute
    4. record_slot_compute_done(slot)  # mark compute done
    5. load_next_block()               # kick off the next H2D
    6. merge_attention_outputs()       # ⭐ pure GPU compute
```

Graphable: flash_attn_with_lse() (pure GPU compute).
Not graphable: the H2D transfers and the dynamic merge.
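For intuition, the merge in step 6 combines per-chunk partial attention outputs using their log-sum-exp (LSE) statistics. A minimal scalar sketch of the math (single head, one-dimensional values; `merge_with_lse` is an illustrative name, not the actual nanovllm API):

```python
import math

def merge_with_lse(o1: float, lse1: float, o2: float, lse2: float):
    """Merge two partial softmax-attention outputs given their log-sum-exps."""
    m = max(lse1, lse2)                  # subtract the max for numerical stability
    w1 = math.exp(lse1 - m)
    w2 = math.exp(lse2 - m)
    o = (o1 * w1 + o2 * w2) / (w1 + w2)  # renormalized weighted average
    lse = m + math.log(w1 + w2)          # combined log-sum-exp
    return o, lse

# Two chunks with equal LSE contribute equally, so the merge is a plain average:
o, lse = merge_with_lse(1.0, 0.0, 3.0, 0.0)
assert abs(o - 2.0) < 1e-12
assert abs(lse - math.log(2.0)) < 1e-12
```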
## Validation Results

Test file: tests/test_chunk_attention_graph_reuse.py

| Test | Result |
|---|---|
| 2 graphs reused across all layers and all chunks | ✅ PASSED |
| copy_() updates of static tensors | ✅ works |
| Eager merge | ✅ accepted by the user |

Conclusion: only 2 attention graphs are needed (causal + non-causal), reused via copy_().
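The copy_()-reuse strategy can be modeled on CPU to show why one captured graph serves every layer: a captured graph behaves like a frozen function whose inputs and outputs live in fixed buffers, so overwriting the input buffers in place and replaying produces fresh results. A torch-free sketch (`FakeGraph` is a stand-in for illustration, not the CUDA Graph API):

```python
class FakeGraph:
    """Stand-in for a captured CUDA graph: fixed buffers plus a frozen kernel."""
    def __init__(self, kernel, in_buf, out_buf):
        self.kernel = kernel
        self.in_buf = in_buf      # "static" input buffer, updated in place
        self.out_buf = out_buf    # "static" output buffer, read after replay

    def copy_in(self, values):
        self.in_buf[:] = values   # analogue of static_tensor.copy_(x)

    def replay(self):
        self.out_buf[:] = self.kernel(self.in_buf)

# One graph serves every "layer": only the buffer contents change per replay.
graph = FakeGraph(lambda xs: [2 * x for x in xs], [0.0, 0.0], [0.0, 0.0])
outputs = []
for layer_input in ([1.0, 2.0], [3.0, 4.0]):
    graph.copy_in(layer_input)
    graph.replay()
    outputs.append(list(graph.out_buf))   # clone before the next replay overwrites
assert outputs == [[2.0, 4.0], [6.0, 8.0]]
```

Note the clone on read: because the output buffer is reused, results must be copied out before the next replay, which is exactly why the real plan calls `.clone()` after each replay.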
## Implementation Plan (Approach A: segmented graphs)

### Architecture

Per-layer execution flow (offload decode):

```
┌─────────────────────────────────────────────────────────────┐
│ PRE-ATTENTION GRAPH (reused by all layers)                  │
│   input_layernorm → qkv_proj → q/k_norm → rotary → split Q  │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│ CHUNKED ATTENTION (eager + partial graph)                   │
│   for block in cpu_blocks:                                  │
│       H2D transfer (eager)                                  │
│       flash_attn_with_lse (GRAPH - 2 reusable graphs)       │
│       merge (eager)                                         │
│   decode_buffer attention (eager)                           │
└─────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────┐
│ POST-ATTENTION GRAPH (reused by all layers)                 │
│   o_proj → post_layernorm → gate_proj → up_proj → SiLU      │
│   → down_proj → residual                                    │
└─────────────────────────────────────────────────────────────┘
```

Total number of graphs required:

- 1 pre_attention_graph (shared by all layers)
- 2 attention graphs (causal + non-causal, shared by all layers)
- 1 post_attention_graph (shared by all layers)
- Total: 4 graphs
## Phase 1: Split the DecoderLayer execution

Goal: split Qwen3DecoderLayer.forward into three independently callable segments.

File to modify: nanovllm/models/qwen3.py

New methods:

```python
class Qwen3DecoderLayer:
    def forward_pre_attention(self, positions, hidden_states, residual):
        """Pre-attention: norm → qkv → rotary; returns q, k, v."""
        if residual is None:
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        qkv = self.self_attn.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q = self.self_attn.q_norm(q)
        k = self.self_attn.k_norm(k)
        q, k = self.self_attn.rotary_emb(positions, q, k)
        return q, k, v, hidden_states, residual

    def forward_post_attention(self, attn_output, hidden_states, residual):
        """Post-attention: o_proj → norm → FFN."""
        output = self.self_attn.o_proj(attn_output.flatten(1, -1))
        hidden_states, residual = self.post_attention_layernorm(output, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
```

Status: pending
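One invariant Phase 1 must preserve: pre-attention + attention + post-attention must compose to exactly the original layer forward. A toy scalar check of that split-equivalence property (the "layer" below is an arithmetic stand-in, not Qwen3DecoderLayer):

```python
def full_layer(x: float) -> float:
    """Monolithic toy layer: norm -> 'attention' -> residual -> FFN."""
    h = x * 0.5            # stand-in for input_layernorm
    attn = h + 1.0         # stand-in for the attention block
    y = x + attn           # residual add
    return y + (y * 2.0)   # stand-in for post-norm + FFN + residual

def pre_attention(x: float):
    """First segment: returns (normed hidden, residual)."""
    return x * 0.5, x

def post_attention(attn: float, residual: float) -> float:
    """Last segment: residual add, then FFN stand-in."""
    y = residual + attn
    return y + (y * 2.0)

# The attention in the middle stays eager between the two graphed segments.
x = 3.0
h, res = pre_attention(x)
attn = h + 1.0
assert post_attention(attn, res) == full_layer(x)
```

In the real plan the same check can be run with random tensors: forward_pre_attention + eager attention + forward_post_attention must match Qwen3DecoderLayer.forward bit-for-bit before any graph capture is attempted.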
## Phase 2: Capture the Pre/Post-Attention Graphs

Goal: capture the pre_attention and post_attention graphs.

File to modify: nanovllm/engine/model_runner.py

New method: capture_offload_layer_graphs()

```python
def capture_offload_layer_graphs(self):
    """Capture the layer graphs for offload mode."""
    # Use any one layer as the template (all layers share the same structure)
    layer = self.model.model.layers[0]
    # Static tensors
    static_hidden = torch.zeros(1, self.hidden_size, ...)
    static_residual = torch.zeros(1, self.hidden_size, ...)
    static_positions = torch.zeros(1, ...)
    static_attn_output = torch.zeros(1, self.num_heads, self.head_dim, ...)
    # Pre-attention graph
    self.pre_attn_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.pre_attn_graph):
        static_q, static_k, static_v, _, _ = layer.forward_pre_attention(
            static_positions, static_hidden, static_residual
        )
    # Post-attention graph
    self.post_attn_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.post_attn_graph):
        _, _ = layer.forward_post_attention(
            static_attn_output, static_hidden, static_residual
        )
```

Status: pending
## Phase 3: Capture the Attention Graphs

Goal: capture the 2 attention graphs (causal + non-causal).

File to modify: nanovllm/kvcache/offload_engine.py

```python
class OffloadEngine:
    def capture_attention_graphs(self):
        """Capture the attention graphs (reused by all layers)."""
        self.attn_graph_causal = self._capture_attn_graph(causal=True)
        self.attn_graph_non_causal = self._capture_attn_graph(causal=False)

    def _capture_attn_graph(self, causal: bool):
        static_q = torch.zeros(1, 1, num_heads, head_dim, ...)
        static_k = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
        static_v = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            output, lse = flash_attn_with_lse(static_q, static_k, static_v,
                                              self.scale, causal)
        return AttentionGraph(graph, static_q, static_k, static_v, output, lse)
```

Status: pending
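Why two variants are enough: fully-past KV chunks never contain the current query position, so they use non-causal attention, while the decode-buffer chunk that does contain the query needs the causal mask. A tiny pure-Python attention-weight sketch illustrates the masking difference (illustrative only, scalar q/k, no values or scaling):

```python
import math

def attention_weights(q, ks, causal, q_pos=None):
    """Softmax weights of one query over keys; causal masks positions j > q_pos."""
    scores = []
    for j, k in enumerate(ks):
        if causal and q_pos is not None and j > q_pos:
            scores.append(float("-inf"))   # future position: masked out
        else:
            scores.append(q * k)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

# Non-causal (a fully-past KV chunk): every key is attended.
w = attention_weights(1.0, [0.0, 0.0, 0.0], causal=False)
assert all(abs(x - 1 / 3) < 1e-12 for x in w)

# Causal (decode-buffer chunk, query at position 1): position 2 gets zero weight.
w = attention_weights(1.0, [0.0, 0.0, 0.0], causal=True, q_pos=1)
assert w[2] == 0.0 and abs(w[0] - 0.5) < 1e-12
```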
## Phase 4: Rework the Offload Decode Execution Flow

Goal: run offload decode via graph replay.

File to modify: nanovllm/engine/model_runner.py

Method to modify: run_chunked_offload_decode()

```python
def run_chunked_offload_decode_with_graph(self, seqs):
    """Offload decode accelerated with graph replay."""
    seq = seqs[0]
    # Prepare inputs
    input_ids = torch.tensor([seq.last_token], ...)
    positions = torch.tensor([len(seq) - 1], ...)
    # Embedding
    hidden_states = self.model.model.embed_tokens(input_ids)
    residual = None
    for layer_id, layer in enumerate(self.model.model.layers):
        # Phase 1: pre-attention (GRAPH)
        self.pre_attn_vars["hidden"].copy_(hidden_states)
        if residual is not None:
            self.pre_attn_vars["residual"].copy_(residual)
        self.pre_attn_vars["positions"].copy_(positions)
        self.pre_attn_graph.replay()
        q = self.pre_attn_vars["q"].clone()
        k = self.pre_attn_vars["k"].clone()
        v = self.pre_attn_vars["v"].clone()
        # Phase 2: chunked attention (eager + graph)
        attn_output = self._chunked_attention_with_graph(q, k, v, layer_id, ...)
        # Phase 3: post-attention (GRAPH)
        self.post_attn_vars["attn_output"].copy_(attn_output)
        self.post_attn_graph.replay()
        hidden_states = self.post_attn_vars["hidden"].clone()
        residual = self.post_attn_vars["residual"].clone()
    # LM head
    logits = self.model.compute_logits(hidden_states)
    return logits
```

Status: pending
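The buffer-threading in the loop above can be exercised on CPU: two stand-in "graphs" share their buffers across all layers, and each layer's result must be read out before the next replay overwrites it. This is an arithmetic stand-in for the dataflow, not the real ModelRunner:

```python
def run_layers(x: float, num_layers: int) -> float:
    """Thread a hidden state through num_layers, reusing two shared buffers."""
    pre_buf = {"hidden": 0.0, "out": 0.0}     # shared by every layer's pre step
    post_buf = {"attn": 0.0, "hidden": 0.0}   # shared by every layer's post step
    hidden = x
    for _ in range(num_layers):
        pre_buf["hidden"] = hidden                    # copy_ into static input
        pre_buf["out"] = pre_buf["hidden"] * 0.5      # "pre-attention graph replay"
        attn = pre_buf["out"] + 1.0                   # eager chunked attention stand-in
        post_buf["attn"] = attn                       # copy_ into static input
        post_buf["hidden"] = post_buf["attn"] * 2.0   # "post-attention graph replay"
        hidden = post_buf["hidden"]                   # read out before next replay
    return hidden

# Each toy layer maps h -> (0.5 * h + 1) * 2 = h + 2, so 3 layers add 6:
assert run_layers(1.0, 3) == 7.0
```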
## Phase 5: Rework the Ring Buffer Pipeline

Goal: use graph replay inside attention.

File to modify: nanovllm/kvcache/sparse/full_policy.py

Change: the flash_attn_with_lse call inside _decode_ring_buffer_pipeline()

```python
# Current: eager
prev_o, prev_lse = flash_attn_with_lse(q, k, v, scale, causal=False)

# Change to: graph replay
graph = offload_engine.attn_graph_non_causal
graph.static_q.copy_(q)
graph.static_k.copy_(k)
graph.static_v.copy_(v)
graph.graph.replay()
prev_o = graph.static_output.clone()
prev_lse = graph.static_lse.clone()
```

Status: pending
## Phase 6: Add a Config Switch

File to modify: nanovllm/config.py

```python
enable_offload_graph: bool = True  # enabled by default
```

Status: pending
## Files to Modify

| File | Change | Notes |
|---|---|---|
| nanovllm/engine/model_runner.py | New method | capture_offload_layer_graphs() |
| nanovllm/kvcache/offload_engine.py | New attributes + methods | Graph storage and access |
| nanovllm/kvcache/sparse/full_policy.py | Modified method | Use graph replay |
| nanovllm/config.py | New config | enable_offload_graph |
## Risks and Caveats

- Capture timing: graphs must be captured after the KV cache is allocated and before the first decode
- Chunk size match: the graph's chunk_size must equal block_size
- Multi-GPU: graphs must be captured separately on each GPU
- Memory: the extra memory overhead of the 4 graphs is small
## Test Plan

- Unit test: verify that graph replay produces correct results
- Integration test: run test_needle.py --enable-offload --input-len 32768
- Performance test: compare eager vs graph decode latency
## Expected Benefits

- Faster decode-phase attention (reduced kernel launch overhead)
- Compatible with the existing ring buffer pipeline
- Minimal memory overhead (only 4 extra graphs)