📝 docs: add CUDA Graph optimization plan for offload mode decode

- Update task_plan.md with 6-phase segmented graph implementation plan
- Add findings.md documenting 7 key discoveries about current implementation
- Add progress.md for tracking implementation progress
- Add test_chunk_attention_graph_reuse.py validating 2-graph reuse strategy

Key architecture decision: Split transformer layer into 3 segments:
- PRE-ATTENTION GRAPH: norm → qkv_proj → rotary (1 graph, reused)
- CHUNKED ATTENTION: H2D (eager) + flash_attn (2 graphs) + merge (eager)
- POST-ATTENTION GRAPH: o_proj → norm → FFN (1 graph, reused)

Total: 4 graphs serving all layers via copy_() tensor updates.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Zijie Tian
2026-01-22 02:12:24 +08:00
parent d808970f2f
commit a5307fb124
4 changed files with 651 additions and 64 deletions

findings.md

@@ -0,0 +1,109 @@
# Findings: CUDA Graph for Offload Mode
## Discovery 1: Why Offload Mode Does Not Use CUDA Graph
**Location**: `nanovllm/engine/model_runner.py:421`
```python
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
```
**Reason**: `run_chunked_offload_decode()` sets `is_chunked_prefill=True`, which forces eager mode.
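The effect of that flag can be reproduced in a small stand-alone sketch. It mirrors the boolean expression above; the `Context` dataclass and the name `should_use_eager` are illustrative stand-ins and not the actual nanovllm classes.

```python
from dataclasses import dataclass


@dataclass
class Context:
    """Hypothetical stand-in for the runtime context object."""
    is_chunked_prefill: bool = False


def should_use_eager(is_prefill: bool, enforce_eager: bool,
                     batch_size: int, context: Context) -> bool:
    """Same decision as the `use_eager` expression in `run_model`."""
    return (is_prefill or enforce_eager
            or batch_size > 512 or context.is_chunked_prefill)


# Regular decode: eligible for graph replay.
regular_decode = should_use_eager(False, False, 1, Context())
# Offload decode: the flag is set, so the eager path is always taken.
offload_decode = should_use_eager(False, False, 1, Context(is_chunked_prefill=True))
```

Under these assumptions `regular_decode` is `False` and `offload_decode` is `True`, which is why the offload path never reaches graph replay today.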
---
## Discovery 2: Current CUDA Graph Architecture
**File**: `model_runner.py:682-717`
```python
def capture_cudagraph(self):
    # Capture a full model forward pass for each batch size
    for bs in [1, 2, 4, 8, 16, ...]:
        with torch.cuda.graph(graph):
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
```
**Characteristics**:
- Captures the complete `model()` call (all layers included)
- Shares memory through a graph pool
- Used for decode only; prefill always runs eagerly
---
## Discovery 3: Attention Flow in Offload Decode
**File**: `nanovllm/kvcache/sparse/full_policy.py:304-379`
**Ring Buffer Pipeline**:
```
1. Preload the first N blocks into GPU slots
2. For each block:
   a. wait_slot_layer()          # wait for H2D to finish
   b. get_kv_for_slot()          # fetch KV
   c. flash_attn_with_lse()      # ⭐ graphable
   d. record_slot_compute_done()
   e. load_next_block()          # start the next H2D
   f. merge_attention_outputs()  # ⭐ graphable, but dynamic
```
**Key point**: H2D transfers cannot be captured in a graph, but the attention computation can.
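To make the ordering of loads and attention calls concrete, here is a CPU-only simulation of the schedule described above. It is a sketch under one assumption that the source does not state: block `b` is placed in slot `b % num_slots`. The function name `ring_buffer_schedule` is hypothetical.

```python
def ring_buffer_schedule(num_blocks: int, num_slots: int):
    """Return the ordered (event, block, slot) list of a ring buffer pipeline.

    Assumption for illustration: block b always lands in slot b % num_slots.
    """
    events = []
    # Step 1: preload the first N blocks, one per GPU slot.
    preload = min(num_blocks, num_slots)
    for block in range(preload):
        events.append(("load", block, block % num_slots))
    next_to_load = preload
    # Step 2: consume blocks in order; a slot is refilled right after
    # its attention call has been issued.
    for block in range(num_blocks):
        slot = block % num_slots
        events.append(("attend", block, slot))
        if next_to_load < num_blocks:
            events.append(("load", next_to_load, slot))
            next_to_load += 1
    return events


for event in ring_buffer_schedule(num_blocks=5, num_slots=2):
    print(event)
```

Only the `attend` events correspond to work that a graph could replay; every `load` event stands for an H2D copy that stays outside the graph.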
---
## Discovery 4: Verifying That Graph Reuse Is Feasible
**Test**: `tests/test_chunk_attention_graph_reuse.py`
**Conclusions**:
- Only 2 graphs are needed (causal + non-causal)
- Static tensors are updated via `copy_()`
- The graphs can be reused for all layers and all chunk pairs
**Test results**:
```
Layer 0: max_diff=3.91e-03 ✅
Layer 1: max_diff=7.81e-03 ✅
Layer 2: max_diff=3.91e-03 ✅
✅ PASSED
```
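The saving from reuse is easy to quantify for the test configuration (4 chunks, 3 layers). The sketch below counts the (Q_i, K_j) pairs that would each need their own graph without reuse, and shows the selection rule the test applies. The helper names are illustrative only.

```python
def count_chunk_pairs(num_chunks: int) -> int:
    """Number of (q_idx, k_idx) pairs with k_idx <= q_idx."""
    return num_chunks * (num_chunks + 1) // 2


def select_graph_kind(q_idx: int, k_idx: int) -> str:
    """Diagonal pairs need the causal mask; earlier K chunks do not."""
    return "causal" if k_idx == q_idx else "non_causal"


num_chunks, num_layers = 4, 3
pairs_per_layer = count_chunk_pairs(num_chunks)      # 10 pairs per layer
graphs_without_reuse = pairs_per_layer * num_layers  # 30 graphs, one per (layer, pair)
kinds_needed = {
    select_graph_kind(q, k)
    for q in range(num_chunks)
    for k in range(q + 1)
}                                                    # only 2 distinct kinds
```

Since every pair maps to one of two kinds and the tensor shapes are identical across layers, two captured graphs cover all 30 combinations.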
---
## Discovery 5: Relationship Between Chunk Size and Block Size
**Observations**:
- KV size of prefilled blocks = `block_size`
- KV size of the decode buffer = `1` to `block_size` (dynamic)
**Graph strategy**:
- Prefilled blocks: fixed size = block_size, well suited to graphs
- Decode buffer: dynamic size, recommended to stay eager
---
## Discovery 6: Triton Operators in Use
**File**: `nanovllm/ops/chunked_attention.py`
| Operator | Function | Graphable |
|------|------|----------|
| `flash_attn_with_lse()` | Attention + LSE | ✅ |
| `merge_attention_outputs()` | Merges two attention outputs | ✅ |
Both operators are pure GPU computation and can be captured by CUDA Graph.
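For readers unfamiliar with LSE-based merging, the following pure-Python sketch shows the standard log-sum-exp rescaling for one query with scalar values. It is assumed, not confirmed from the source, that `merge_attention_outputs()` applies the same identity per head; the functions `attend` and `merge` here are illustrative and unrelated to the Triton kernels.

```python
import math


def attend(scores, values):
    """Softmax-weighted average of `values`, plus the log-sum-exp of `scores`."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = sum(w * v for w, v in zip(weights, values)) / total
    return out, m + math.log(total)


def merge(o1, lse1, o2, lse2):
    """Combine two partial attention results computed over disjoint keys."""
    m = max(lse1, lse2)
    lse = m + math.log(math.exp(lse1 - m) + math.exp(lse2 - m))
    out = o1 * math.exp(lse1 - lse) + o2 * math.exp(lse2 - lse)
    return out, lse


scores = [0.1, 1.2, -0.5, 2.0]
values = [1.0, 2.0, 3.0, 4.0]
full_out, full_lse = attend(scores, values)
part_a = attend(scores[:2], values[:2])   # first "block" of keys
part_b = attend(scores[2:], values[2:])   # second "block" of keys
merged_out, merged_lse = merge(*part_a, *part_b)
```

Up to floating-point rounding, `merged_out` and `merged_lse` equal `full_out` and `full_lse`, which is the property that lets attention be computed block by block.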
---
## Discovery 7: Data Dependency Analysis
**Attention inputs**:
- `q`: from the current layer's QKV projection, fixed shape
- `k, v`: from a GPU slot (after H2D transfer), shape = [1, block_size, heads, dim]
**Dependency chain**:
```
H2D(block) → wait() → get_kv() → copy_to_static() → graph.replay() → clone_output()
```
**Key point**: The graph wraps only the attention computation and does not include data transfer.
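The last two steps of the chain matter because a replayed graph always writes into the same output buffer. The CPU-only sketch below imitates that behaviour with plain lists; `FakeGraph` is a hypothetical stand-in for a captured CUDA graph and its "kernel" simply doubles the input.

```python
class FakeGraph:
    """Stand-in for a captured graph with fixed input and output buffers."""

    def __init__(self, size: int):
        self.static_in = [0.0] * size
        self.static_out = [0.0] * size

    def replay(self):
        # The recorded "kernel": reads static_in, writes static_out in place.
        for i, x in enumerate(self.static_in):
            self.static_out[i] = 2.0 * x


def replay_with_copy(graph: FakeGraph, data):
    if len(data) != len(graph.static_in):
        raise ValueError("input shape must match the captured shape")
    graph.static_in[:] = data        # plays the role of copy_()
    graph.replay()
    return list(graph.static_out)    # plays the role of clone()


graph = FakeGraph(size=2)
first = replay_with_copy(graph, [1.0, 2.0])
alias = graph.static_out             # kept without copying
second = replay_with_copy(graph, [3.0, 4.0])
```

After the second replay, `first` still holds the first result while `alias` has been overwritten, which is why the dependency chain ends with `clone_output()`.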

progress.md

@@ -0,0 +1,55 @@
# Progress: CUDA Graph for Offload Mode
## Session: 2026-01-22
### Research Phase ✅ Done
**Completed research**:
1. ✅ Analyzed the CUDA Graph implementation in `model_runner.py`
   - `capture_cudagraph()`: captures a full model forward pass for each batch size
   - `run_model()`: chooses eager or graph based on `is_chunked_prefill`
2. ✅ Analyzed the offload decode flow
   - `run_chunked_offload_decode()` sets `is_chunked_prefill=True`
   - As a result, eager mode is always used
3. ✅ Analyzed the ring buffer pipeline
   - `_decode_ring_buffer_pipeline()` contains H2D transfers + attention computation
   - H2D cannot be graphed; attention can
4. ✅ Verified the graph reuse strategy
   - Created `test_chunk_attention_graph_reuse.py`
   - Confirmed that 2 graphs can be reused across all layers
### Plan Writing ✅ Done
- ✅ Created `task_plan.md`
- ✅ Created `findings.md`
- ✅ Created `progress.md`
### Next Step: Implementation
**Phase 1**: Add graph capture to OffloadEngine
- [ ] Add `capture_attention_graphs()` to `offload_engine.py`
- [ ] Add the `attention_graph_causal` and `attention_graph_non_causal` attributes
**Phase 2**: Modify the ring buffer pipeline
- [ ] Use graph replay in `_decode_ring_buffer_pipeline()`
- [ ] Keep H2D and merge eager
**Phase 3**: Testing
- [ ] Run the needle test to verify correctness
- [ ] Compare performance
---
## File List
| File | Status | Notes |
|------|------|------|
| `tests/test_chunk_attention_graph.py` | ✅ Committed | Test for preallocated chunk-pair graphs |
| `tests/test_chunk_attention_graph_reuse.py` | Pending commit | Graph reuse verification |
| `task_plan.md` | ✅ Created | Implementation plan |
| `findings.md` | ✅ Created | Research findings |
| `progress.md` | ✅ Created | Progress log |


@@ -1,90 +1,357 @@
# Task Plan: Integrating XAttention BSA into nanovllm
# Task Plan: CUDA Graph Optimization for Offload Mode Decode
## Goal
## Objective
Run `test_ruler.py` with `--sparse-policy XATTN_BSA` and pass the first 5 samples of `niah_single_1`
Add CUDA Graph support to nanovllm's CPU offload mode to speed up computation in the decode stage
**Acceptance criteria**:
```bash
CUDA_VISIBLE_DEVICES=X PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--sparse-policy XATTN_BSA \
--task niah_single_1 \
--sample-ids 0,1,2,3,4
# Expected: 5/5 PASS
```
## Problem Analysis
### Full Structure of a Transformer Layer
```
Qwen3DecoderLayer.forward:
├── input_layernorm (RMSNorm)              # ✅ pure GPU
├── self_attn:
│   ├── qkv_proj (Linear)                  # ✅ pure GPU
│   ├── q_norm, k_norm (RMSNorm)           # ✅ pure GPU
│   ├── rotary_emb                         # ✅ pure GPU
│   ├── attn._chunked_decode_attention:    # ⚠️ includes CPU→GPU
│   │   ├── H2D transfer                   # ❌ cannot be graphed
│   │   ├── flash_attn_with_lse            # ✅ can be graphed
│   │   └── merge                          # ✅ pure GPU
│   └── o_proj (Linear)                    # ✅ pure GPU
├── post_attention_layernorm               # ✅ pure GPU
└── mlp (FFN: gate, up, down)              # ✅ pure GPU
```
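## Current Status
**Core problem**: The H2D transfer is embedded in the middle of attention, which breaks graph capture of the whole layer.
- The `XAttentionBSAPolicy.compute_chunked_prefill` implementation = `FullAttentionPolicy` (no sparsity)
- `xattn_estimate_chunked` is implemented and verified
- The BSA kernel (`block_sparse_attn`) is available
### Possible Approaches
## Phases
| Approach | Description | Pros | Cons |
|------|------|------|------|
| A. Segmented graphs | Split the layer into pre-attention and post-attention segments | Broad coverage | Large change; layer execution must be split |
| B. Graph attention only | Optimize only flash_attn_with_lse | Small change | Limited benefit |
| C. Restructure the execution flow | Rewrite model forward entirely | Best result | Very large amount of work |
- [ ] Phase 1: Understand the current code path
- [ ] Phase 2: Implement sparse mask estimation
- [ ] Phase 3: Implement BSA sparse computation
- [ ] Phase 4: Testing and verification
### Recommended: Approach A (Segmented Graphs)
## Phase 1: Understand the Current Code Path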
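Split each layer into two graphs:
1. **pre_attention_graph**: `norm → qkv_proj → q/k_norm → rotary`
2. **post_attention_graph**: `o_proj → norm → FFN`
### 1.1 Confirm That the XATTN_BSA Policy Is Loaded Correctly
- [ ] Check how `test_ruler.py` parses `--sparse-policy XATTN_BSA`
- [ ] Check how `KVCacheManager` instantiates sparse_policy
- [ ] Run the baseline test (`--sparse-policy FULL`) to confirm that basic functionality works
The `_chunked_decode_attention` in between stays eager (it contains the H2D transfer), but the `flash_attn_with_lse` call inside it uses a graph.
### 1.2 Confirm the Data Flow
- [ ] Meaning of the input parameters of `compute_chunked_prefill`
- [ ] Data access interfaces provided by `offload_engine`
- [ ] How to obtain the K/V of the current chunk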
---
## Phase 2: Implement Sparse Mask Estimation
## Current State Analysis
### 2.1 Call xattn_estimate_chunked
- [ ] Load historical K in `compute_chunked_prefill`
- [ ] Concatenate historical K + current K
- [ ] Call `xattn_estimate_chunked(q, k_full, q_start_pos=...)`
- [ ] Obtain the block mask
### Existing CUDA Graph Implementation
### 2.2 Handle Parameter Alignment
- [ ] BSA block_size = 128
- [ ] Relationship between chunk_size and kvcache_block_size
- [ ] q_start_pos computation
**File**: `nanovllm/engine/model_runner.py`
## Phase 3: Implement BSA Sparse Computation
| Method | Lines | Function |
|------|------|------|
| `capture_cudagraph()` | 682-717 | Captures a full model forward pass for each batch size |
| `run_model()` | 415-436 | Decides between eager execution and graph replay |
### 3.1 Choosing an Approach
- Option A: Compute history and current chunk separately, then merge
- Option B: Compute everything together with BSA
**Key logic** (`run_model`):
```python
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
```
### 3.2 Implementation
- [ ] Build the input format that BSA requires
- [ ] Call `block_sparse_attn_func`
- [ ] Handle the output format
**Problem**: `run_chunked_offload_decode` sets `is_chunked_prefill=True`, so **eager mode is always used**.
## Phase 4: Testing and Verification
### Offload Decode Flow
### 4.1 Unit Tests
- [ ] Verify that the sparse mask matches `test_xattn_estimate_chunked.py`
**File**: `nanovllm/kvcache/sparse/full_policy.py`
### 4.2 Integration Tests
- [ ] Run the acceptance command
- [ ] 5/5 PASS
`_decode_ring_buffer_pipeline()` (L304-379):
```
for block in cpu_blocks:
    1. wait_slot_layer(slot)             # wait for H2D to finish
    2. k, v = get_kv_for_slot(slot)      # fetch KV
    3. o, lse = flash_attn_with_lse()    # ⭐ pure GPU computation
    4. record_slot_compute_done(slot)    # mark compute as done
    5. load_next_block()                 # start the next H2D
    6. merge_attention_outputs()         # ⭐ pure GPU computation
```
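## Key Questions
**Graphable parts**:
- `flash_attn_with_lse()` - pure GPU computation
- Not graphable: H2D transfers, dynamic merge
1. How can historical K be loaded efficiently? (all at once vs. on demand)
2. How should the BSA causal mask be handled? (history non-causal + current causal)
## Verification Results
## Status
**Test file**: `tests/test_chunk_attention_graph_reuse.py`
**Currently in Phase 1** - waiting for user confirmation before starting
| Test | Result |
|------|------|
| 2 graphs reused across all layers and all chunks | ✅ PASSED |
| copy_() updates static tensors | ✅ Works |
| Eager merge | ✅ Accepted by user |
## To Discuss
**Conclusion**: Only 2 graphs are needed (causal + non-causal), reused via copy_().
Please confirm:
1. Are this goal and these acceptance criteria correct?
2. Which GPU should I use to run the tests?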
---
## Modification Plan (Approach A: Segmented Graphs)
### Architecture Design
```
Per-layer execution flow (Offload Decode):
┌─────────────────────────────────────────────────────────────┐
│ PRE-ATTENTION GRAPH (reusable across all layers)            │
│ input_layernorm → qkv_proj → q/k_norm → rotary → split Q    │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ CHUNKED ATTENTION (eager + partial graph)                   │
│   for block in cpu_blocks:                                  │
│     H2D transfer (eager)                                    │
│     flash_attn_with_lse (GRAPH - 2 reusable)                │
│     merge (eager)                                           │
│   decode_buffer attention (eager)                           │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ POST-ATTENTION GRAPH (reusable across all layers)           │
│ o_proj → post_layernorm → gate_proj → up_proj → SiLU        │
│ → down_proj → residual                                      │
└─────────────────────────────────────────────────────────────┘
```
**Total number of graphs required**:
- 1 pre_attention_graph (reused by all layers)
- 2 attention_graphs (causal + non-causal, reused by all layers)
- 1 post_attention_graph (reused by all layers)
- **Total: 4 graphs**
---
### Phase 1: Split DecoderLayer Execution
**Goal**: Split `Qwen3DecoderLayer.forward` into three independently callable segments
**File to modify**: `nanovllm/models/qwen3.py`
**New methods**:
```python
class Qwen3DecoderLayer:
    def forward_pre_attention(self, positions, hidden_states, residual):
        """Pre-attention: norm → qkv → rotary → returns q, k, v"""
        if residual is None:
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        qkv = self.self_attn.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q = self.self_attn.q_norm(q)
        k = self.self_attn.k_norm(k)
        q, k = self.self_attn.rotary_emb(positions, q, k)
        return q, k, v, hidden_states, residual

    def forward_post_attention(self, attn_output, hidden_states, residual):
        """Post-attention: o_proj → norm → FFN"""
        output = self.self_attn.o_proj(attn_output.flatten(1, -1))
        hidden_states, residual = self.post_attention_layernorm(output, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
```
**Status**: `pending`
---
### Phase 2: Capture the Pre/Post Attention Graphs
**Goal**: Capture the pre_attention and post_attention graphs
**File to modify**: `nanovllm/engine/model_runner.py`
**New method**: `capture_offload_layer_graphs()`
```python
def capture_offload_layer_graphs(self):
    """Capture the layer graphs for offload mode"""
    # Use any layer as the template (all layers share the same structure).
    # NOTE: a captured graph records the weight addresses of the layer it was
    # captured from, so reusing it for other layers also requires copying each
    # layer's weights into static buffers (or capturing once per layer).
    layer = self.model.model.layers[0]
    # Static tensors
    static_hidden = torch.zeros(1, self.hidden_size, ...)
    static_residual = torch.zeros(1, self.hidden_size, ...)
    static_positions = torch.zeros(1, ...)
    # Pre-attention graph
    self.pre_attn_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.pre_attn_graph):
        static_q, static_k, static_v, _, _ = layer.forward_pre_attention(
            static_positions, static_hidden, static_residual
        )
    # Post-attention graph
    # (static_attn_output must be allocated before this capture)
    self.post_attn_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(self.post_attn_graph):
        _, _ = layer.forward_post_attention(
            static_attn_output, static_hidden, static_residual
        )
```
**Status**: `pending`
---
### Phase 3: Capture the Attention Graphs
**Goal**: Capture 2 attention graphs (causal + non-causal)
**File to modify**: `nanovllm/kvcache/offload_engine.py`
```python
class OffloadEngine:
    def capture_attention_graphs(self):
        """Capture the attention graphs (reused by all layers)"""
        self.attn_graph_causal = self._capture_attn_graph(causal=True)
        self.attn_graph_non_causal = self._capture_attn_graph(causal=False)

    def _capture_attn_graph(self, causal: bool):
        static_q = torch.zeros(1, 1, num_heads, head_dim, ...)
        static_k = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
        static_v = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            output, lse = flash_attn_with_lse(static_q, static_k, static_v,
                                              self.scale, causal)
        return AttentionGraph(graph, static_q, static_k, static_v, output, lse)
```
**Status**: `pending`
---
### Phase 4: Modify the Offload Decode Execution Flow
**Goal**: Execute offload decode with graph replay
**File to modify**: `nanovllm/engine/model_runner.py`
**Method to modify**: `run_chunked_offload_decode()`
```python
def run_chunked_offload_decode_with_graph(self, seqs):
    """Offload decode accelerated with graphs"""
    seq = seqs[0]
    # Prepare inputs
    input_ids = torch.tensor([seq.last_token], ...)
    positions = torch.tensor([len(seq) - 1], ...)
    # Embedding
    hidden_states = self.model.model.embed_tokens(input_ids)
    residual = None
    for layer_id, layer in enumerate(self.model.model.layers):
        # Phase 1: Pre-attention (GRAPH)
        self.pre_attn_vars["hidden"].copy_(hidden_states)
        # Compare against None: the truth value of a multi-element tensor is ambiguous
        if residual is not None:
            self.pre_attn_vars["residual"].copy_(residual)
        self.pre_attn_vars["positions"].copy_(positions)
        self.pre_attn_graph.replay()
        q = self.pre_attn_vars["q"].clone()
        k = self.pre_attn_vars["k"].clone()
        v = self.pre_attn_vars["v"].clone()
        # Phase 2: Chunked Attention (Eager + Graph)
        attn_output = self._chunked_attention_with_graph(q, k, v, layer_id, ...)
        # Phase 3: Post-attention (GRAPH)
        self.post_attn_vars["attn_output"].copy_(attn_output)
        self.post_attn_graph.replay()
        hidden_states = self.post_attn_vars["hidden"].clone()
        residual = self.post_attn_vars["residual"].clone()
    # LM head
    logits = self.model.compute_logits(hidden_states)
    return logits
```
**Status**: `pending`
---
### Phase 5: Modify the Ring Buffer Pipeline
**Goal**: Use a graph inside attention
**File to modify**: `nanovllm/kvcache/sparse/full_policy.py`
**Change**: the `flash_attn_with_lse` call in `_decode_ring_buffer_pipeline()`
```python
# Current: eager
prev_o, prev_lse = flash_attn_with_lse(q, k, v, scale, causal=False)
# Changed to: graph replay
graph = offload_engine.attn_graph_non_causal
graph.static_q.copy_(q)
graph.static_k.copy_(k)
graph.static_v.copy_(v)
graph.graph.replay()
prev_o = graph.static_output.clone()
prev_lse = graph.static_lse.clone()
```
**Status**: `pending`
---
### Phase 6: Add a Configuration Switch
**File to modify**: `nanovllm/config.py`
```python
enable_offload_graph: bool = True  # enabled by default
```
**Status**: `pending`
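One way the switch could interact with the fixed-size requirement for prefilled blocks is sketched below. This is an assumption-laden illustration, not the planned implementation: the `Config` dataclass is reduced to two fields, and `use_attention_graph` is a hypothetical helper name.

```python
from dataclasses import dataclass


@dataclass
class Config:
    """Reduced, hypothetical view of the config with only the relevant fields."""
    enforce_eager: bool = False
    enable_offload_graph: bool = True


def use_attention_graph(config: Config, kv_len: int, block_size: int) -> bool:
    """Replay a graph only when it is enabled and the KV length matches capture."""
    if config.enforce_eager or not config.enable_offload_graph:
        return False
    # Graphs are captured for exactly block_size keys; the decode buffer
    # (1 to block_size keys) is assumed to stay on the eager path.
    return kv_len == block_size


full_block = use_attention_graph(Config(), kv_len=1024, block_size=1024)
partial_buffer = use_attention_graph(Config(), kv_len=17, block_size=1024)
switched_off = use_attention_graph(Config(enable_offload_graph=False), 1024, 1024)
```

With these inputs only `full_block` is `True`; a partially filled decode buffer or a disabled switch both fall back to eager execution.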
---
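## Files to Modify
| File | Change type | Notes |
|------|----------|------|
| `nanovllm/models/qwen3.py` | New methods | `forward_pre_attention()`, `forward_post_attention()` |
| `nanovllm/engine/model_runner.py` | New methods | `capture_offload_layer_graphs()`, `run_chunked_offload_decode_with_graph()` |
| `nanovllm/kvcache/offload_engine.py` | New attributes + methods | Graph storage and access |
| `nanovllm/kvcache/sparse/full_policy.py` | Modified method | Use graph replay |
| `nanovllm/config.py` | New option | `enable_offload_graph` |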
---
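## Risks and Caveats
1. **Graph capture timing**: Graphs must be captured after the KV cache is allocated and before the first decode
2. **Chunk size match**: The chunk_size used for the graph must equal block_size
3. **Multi-GPU**: Graphs must be captured separately on each GPU
4. **Memory**: The 2 attention graphs add very little memory overhead
---
## Test Plan
1. **Unit tests**: Verify that graph replay produces correct results
2. **Integration tests**: Run `test_needle.py --enable-offload --input-len 32768`
3. **Performance tests**: Compare decode latency, eager vs. graph
---
## Expected Benefits
- Faster attention computation in the decode stage (less kernel launch overhead)
- Compatible with the existing ring buffer pipeline
- Minimal memory overhead (only 2 extra attention graphs)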


@@ -0,0 +1,156 @@
#!/usr/bin/env python3
"""
Test: Reuse a single CUDA Graph across all layers and all chunk pairs.

Key insight: LLM layers have identical computation structure.
We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations.

Usage:
    CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py
"""
from dataclasses import dataclass

import torch

from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs


@dataclass
class ReusableChunkGraph:
    """A single graph that can be reused with copy_() updates."""
    graph: torch.cuda.CUDAGraph
    static_q: torch.Tensor
    static_k: torch.Tensor
    static_v: torch.Tensor
    static_output: torch.Tensor
    static_lse: torch.Tensor


def capture_reusable_graph(
    chunk_size: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    scale: float,
    device: torch.device,
    dtype: torch.dtype,
    causal: bool,
) -> ReusableChunkGraph:
    """Capture ONE graph to be reused for all chunk pairs."""
    static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
    static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
    static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
    static_q.normal_()
    static_k.normal_()
    static_v.normal_()
    # Warmup
    with torch.inference_mode():
        for _ in range(3):
            _ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
    torch.cuda.synchronize()
    # Capture
    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode():
        with torch.cuda.graph(graph):
            static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
    torch.cuda.synchronize()
    return ReusableChunkGraph(
        graph=graph,
        static_q=static_q,
        static_k=static_k,
        static_v=static_v,
        static_output=static_output,
        static_lse=static_lse,
    )


def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Replay graph after updating static tensors with copy_()."""
    graph.static_q.copy_(q)
    graph.static_k.copy_(k)
    graph.static_v.copy_(v)
    graph.graph.replay()
    return graph.static_output.clone(), graph.static_lse.clone()


def main():
    device = torch.device("cuda")
    dtype = torch.bfloat16
    chunk_size = 64
    num_chunks = 4
    num_layers = 3  # Simulate multiple layers
    num_heads = 8
    num_kv_heads = 8
    head_dim = 64
    scale = 1.0 / (head_dim ** 0.5)
    seq_len = chunk_size * num_chunks
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}")
    print("Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations")
    # Capture only 2 graphs
    graph_causal = capture_reusable_graph(
        chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True
    )
    graph_non_causal = capture_reusable_graph(
        chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False
    )
    print("2 graphs captured (causal + non-causal)")
    all_pass = True
    for layer_id in range(num_layers):
        # Different Q/K/V for each layer (simulating different layer outputs)
        full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
        full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
        full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
        # Reference: full causal attention
        with torch.inference_mode():
            full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
        # Chunked with graph reuse
        chunked_output = torch.zeros_like(full_output)
        for q_idx in range(num_chunks):
            q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]
            acc_out, acc_lse = None, None
            for k_idx in range(q_idx + 1):
                k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
                v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
                # Reuse graph with copy_(): the diagonal pair is causal, earlier K chunks are not
                graph = graph_causal if k_idx == q_idx else graph_non_causal
                out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk)
                if acc_out is None:
                    acc_out, acc_lse = out, lse
                else:
                    with torch.inference_mode():
                        acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
            chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
        torch.cuda.synchronize()
        # Compare
        max_diff = (full_output - chunked_output).abs().max().item()
        status = "✅" if max_diff < 1e-2 else "❌"
        print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}")
        if max_diff >= 1e-2:
            all_pass = False
    print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED")


if __name__ == "__main__":
    main()