🔀 merge: integrate remote changes (exec-plan command, CUDA graph plan)
Resolve task_plan.md conflict by keeping remote version (CUDA Graph optimization plan). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
158
.claude/commands/exec-plan.md
Normal file
158
.claude/commands/exec-plan.md
Normal file
@@ -0,0 +1,158 @@
|
||||
---
|
||||
allowed-tools: Bash(CUDA_VISIBLE_DEVICES=*), Bash(PYTHONPATH=*), Bash(python*), Bash(git*), Bash(rm*), Bash(ls*), Bash(cat*), Bash(nvidia-smi*), Read, Edit, Write, Glob, Grep, TodoWrite, Task
|
||||
argument-hint: --gpu <id> [--no-interrupt]
|
||||
description: Execute task_plan.md refactoring with specified GPU, optionally without user interruption
|
||||
---
|
||||
|
||||
# Execute Task Plan (exec-plan)
|
||||
|
||||
按照 `task_plan.md` 的要求执行代码重构,确保计划中的最终目标圆满实现。
|
||||
|
||||
## 参数说明
|
||||
|
||||
命令格式: `/exec-plan --gpu <id> [--no-interrupt]`
|
||||
|
||||
| 参数 | 说明 | 示例 |
|
||||
|------|------|------|
|
||||
| `--gpu <id>` | **必需**。指定可用的 GPU ID,只能使用此 GPU 进行调试 | `--gpu 0`, `--gpu 2` |
|
||||
| `--no-interrupt` | 可选。禁止中断执行,遇到问题不与用户交互,自动解决或跳过 | `--no-interrupt` |
|
||||
|
||||
## 当前参数
|
||||
|
||||
```
|
||||
$ARGUMENTS
|
||||
```
|
||||
|
||||
## 执行前准备
|
||||
|
||||
### 1. 解析参数
|
||||
|
||||
从 `$ARGUMENTS` 中解析:
|
||||
- `GPU_ID`: 从 `--gpu <id>` 或 `-g <id>` 提取
|
||||
- `NO_INTERRUPT`: 是否存在 `--no-interrupt` 或 `-n` 标志
|
||||
|
||||
### 2. 参数验证
|
||||
|
||||
**必须验证**:
|
||||
- GPU_ID 必须是有效的数字
|
||||
- 运行 `nvidia-smi -i <GPU_ID>` 验证 GPU 存在
|
||||
|
||||
### 3. 读取 task_plan.md
|
||||
|
||||
读取项目根目录下的 `task_plan.md` 文件,理解:
|
||||
- 总体目标
|
||||
- 分阶段计划 (Phase 1, 2, 3...)
|
||||
- 文件修改清单
|
||||
- 风险和注意事项
|
||||
- 测试计划
|
||||
|
||||
## 执行流程
|
||||
|
||||
### Step 1: 创建执行计划
|
||||
|
||||
使用 TodoWrite 工具创建详细的执行计划,包括:
|
||||
- 从 task_plan.md 提取的所有 Phase
|
||||
- 每个 Phase 的子任务
|
||||
- 测试验证步骤
|
||||
|
||||
### Step 2: 按 Phase 执行重构
|
||||
|
||||
对于 task_plan.md 中的每个 Phase:
|
||||
|
||||
1. **读取当前代码**: 使用 Read/Grep 理解现有实现
|
||||
2. **实施修改**: 使用 Edit/Write 进行代码修改
|
||||
3. **验证修改**: 运行相关测试
|
||||
|
||||
### Step 3: 运行测试验证
|
||||
|
||||
执行 task_plan.md 中定义的测试计划,验证重构成功。
|
||||
|
||||
## GPU 限制规则
|
||||
|
||||
**严格限制**: 只能使用指定的 GPU,所有涉及 GPU 的命令必须加 `CUDA_VISIBLE_DEVICES` 前缀:
|
||||
|
||||
```bash
|
||||
# 正确
|
||||
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python test.py
|
||||
|
||||
# 错误 - 禁止使用其他 GPU
|
||||
python test.py # 可能使用默认 GPU 0
|
||||
CUDA_VISIBLE_DEVICES=0,1 python test.py # 使用多个 GPU
|
||||
```
|
||||
|
||||
## 中断模式规则
|
||||
|
||||
### 当 `--no-interrupt` 生效时
|
||||
|
||||
遇到以下情况**不停下来询问用户**,而是:
|
||||
|
||||
| 情况 | 处理方式 |
|
||||
|------|----------|
|
||||
| 测试失败 | 记录失败原因,尝试自动修复,继续下一步 |
|
||||
| 代码冲突 | 尝试合理解决,记录解决方案 |
|
||||
| 不确定的实现细节 | 选择最合理的方案继续 |
|
||||
| 执行错误 | 分析错误,尝试修复,记录问题 |
|
||||
|
||||
**自动决策原则**:
|
||||
1. 优先保证功能正确性
|
||||
2. 遵循现有代码风格
|
||||
3. 选择简单直接的实现
|
||||
4. 记录所有自动决策到 `progress.md`
|
||||
|
||||
### 当未指定 `--no-interrupt` 时
|
||||
|
||||
遇到以下情况**可以询问用户**:
|
||||
- 多个实现方案需要选择
|
||||
- 测试持续失败无法自动修复
|
||||
- 发现 task_plan.md 中的问题或矛盾
|
||||
|
||||
## 执行记录
|
||||
|
||||
### 进度文件: progress.md
|
||||
|
||||
实时更新 `progress.md` 记录:
|
||||
|
||||
```markdown
|
||||
## 执行进度
|
||||
|
||||
### Phase X: [名称]
|
||||
- 状态: [进行中/完成/失败]
|
||||
- 开始时间: [时间]
|
||||
- 完成时间: [时间]
|
||||
- 修改文件: [文件列表]
|
||||
- 自动决策: [如果有]
|
||||
- 问题记录: [如果有]
|
||||
```
|
||||
|
||||
### 发现记录: findings.md
|
||||
|
||||
记录执行过程中的重要发现到 `findings.md`。
|
||||
|
||||
## 示例用法
|
||||
|
||||
```bash
|
||||
# 使用 GPU 2,允许中断
|
||||
/exec-plan --gpu 2
|
||||
|
||||
# 使用 GPU 0,不中断执行
|
||||
/exec-plan --gpu 0 --no-interrupt
|
||||
|
||||
# 简短形式
|
||||
/exec-plan -g 1 -n
|
||||
```
|
||||
|
||||
## 完成标准
|
||||
|
||||
执行完成后,确保:
|
||||
|
||||
1. **所有 Phase 完成**: task_plan.md 中的所有 Phase 都已实施
|
||||
2. **测试通过**: task_plan.md 中的测试计划全部通过
|
||||
3. **代码质量**: 修改符合项目代码规范
|
||||
4. **文档更新**: progress.md 包含完整执行记录
|
||||
|
||||
## 重要约束
|
||||
|
||||
1. **GPU 隔离**: 绝对不能使用指定 GPU 以外的设备
|
||||
2. **遵循计划**: 严格按照 task_plan.md 执行,不做计划外的修改
|
||||
3. **渐进式修改**: 每个 Phase 完成后验证,而不是最后一起验证
|
||||
4. **回滚准备**: 重大修改前考虑是否需要 git commit 保存点
|
||||
109
findings.md
Normal file
109
findings.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# Findings: CUDA Graph for Offload Mode
|
||||
|
||||
## Discovery 1: 为什么 Offload Mode 不使用 CUDA Graph
|
||||
|
||||
**位置**: `nanovllm/engine/model_runner.py:421`
|
||||
|
||||
```python
|
||||
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
|
||||
```
|
||||
|
||||
**原因**: `run_chunked_offload_decode()` 设置 `is_chunked_prefill=True`,强制使用 eager mode。
|
||||
|
||||
---
|
||||
|
||||
## Discovery 2: 当前 CUDA Graph 架构
|
||||
|
||||
**文件**: `model_runner.py:682-717`
|
||||
|
||||
```python
|
||||
def capture_cudagraph(self):
|
||||
# 为不同 batch size 捕获完整 model forward
|
||||
for bs in [1, 2, 4, 8, 16, ...]:
|
||||
with torch.cuda.graph(graph):
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs])
|
||||
```
|
||||
|
||||
**特点**:
|
||||
- 捕获完整的 `model()` 调用(包含所有层)
|
||||
- 使用 graph pool 共享内存
|
||||
- 只用于 decode(prefill 始终 eager)
|
||||
|
||||
---
|
||||
|
||||
## Discovery 3: Offload Decode 的 Attention 流程
|
||||
|
||||
**文件**: `nanovllm/kvcache/sparse/full_policy.py:304-379`
|
||||
|
||||
**Ring Buffer Pipeline**:
|
||||
```
|
||||
1. 预加载前 N 个 blocks 到 GPU slots
|
||||
2. 对每个 block:
|
||||
a. wait_slot_layer() # 等待 H2D
|
||||
b. get_kv_for_slot() # 获取 KV
|
||||
c. flash_attn_with_lse() # ⭐ 可 graph
|
||||
d. record_slot_compute_done()
|
||||
e. load_next_block() # 启动下一个 H2D
|
||||
f. merge_attention_outputs() # ⭐ 可 graph(但动态)
|
||||
```
|
||||
|
||||
**关键**: H2D 传输不能 graph,但 attention 计算可以。
|
||||
|
||||
---
|
||||
|
||||
## Discovery 4: 验证 Graph 复用可行性
|
||||
|
||||
**测试**: `tests/test_chunk_attention_graph_reuse.py`
|
||||
|
||||
**结论**:
|
||||
- 只需 2 个 graph(causal + non-causal)
|
||||
- 通过 `copy_()` 更新 static tensors
|
||||
- 可复用于所有层和所有 chunk pairs
|
||||
|
||||
**测试结果**:
|
||||
```
|
||||
Layer 0: max_diff=3.91e-03 ✅
|
||||
Layer 1: max_diff=7.81e-03 ✅
|
||||
Layer 2: max_diff=3.91e-03 ✅
|
||||
✅ PASSED
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Discovery 5: Chunk Size 和 Block Size 关系
|
||||
|
||||
**观察**:
|
||||
- Prefilled blocks 的 KV size = `block_size`
|
||||
- Decode buffer 的 KV size = `1` 到 `block_size`(动态)
|
||||
|
||||
**Graph 策略**:
|
||||
- Prefilled blocks: 固定 size = block_size,适合 graph
|
||||
- Decode buffer: 动态 size,建议保持 eager
|
||||
|
||||
---
|
||||
|
||||
## Discovery 6: 使用的 Triton 算子
|
||||
|
||||
**文件**: `nanovllm/ops/chunked_attention.py`
|
||||
|
||||
| 算子 | 功能 | 可 Graph |
|
||||
|------|------|----------|
|
||||
| `flash_attn_with_lse()` | Attention + LSE | ✅ |
|
||||
| `merge_attention_outputs()` | 合并两个 attention 输出 | ✅ |
|
||||
|
||||
这两个算子是纯 GPU 计算,可以被 CUDA Graph 捕获。
|
||||
|
||||
---
|
||||
|
||||
## Discovery 7: 数据依赖分析
|
||||
|
||||
**Attention 输入**:
|
||||
- `q`: 来自当前层的 QKV projection,shape 固定
|
||||
- `k, v`: 来自 GPU slot(H2D 传输后),shape = [1, block_size, heads, dim]
|
||||
|
||||
**依赖链**:
|
||||
```
|
||||
H2D(block) → wait() → get_kv() → copy_to_static() → graph.replay() → clone_output()
|
||||
```
|
||||
|
||||
**关键**: Graph 只封装 attention 计算,不包含数据传输。
|
||||
55
progress.md
Normal file
55
progress.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Progress: CUDA Graph for Offload Mode
|
||||
|
||||
## Session: 2026-01-22
|
||||
|
||||
### 调研阶段 ✅ 完成
|
||||
|
||||
**完成的调研**:
|
||||
|
||||
1. ✅ 分析 `model_runner.py` 中的 CUDA Graph 实现
|
||||
- `capture_cudagraph()`: 为不同 batch size 捕获完整 model forward
|
||||
- `run_model()`: 通过 `is_chunked_prefill` 决定 eager/graph
|
||||
|
||||
2. ✅ 分析 offload decode 流程
|
||||
- `run_chunked_offload_decode()` 设置 `is_chunked_prefill=True`
|
||||
- 导致永远使用 eager mode
|
||||
|
||||
3. ✅ 分析 ring buffer pipeline
|
||||
- `_decode_ring_buffer_pipeline()` 包含 H2D 传输 + attention 计算
|
||||
- H2D 不能 graph,attention 可以 graph
|
||||
|
||||
4. ✅ 验证 graph 复用策略
|
||||
- 创建 `test_chunk_attention_graph_reuse.py`
|
||||
- 确认 2 个 graph 可复用于所有层
|
||||
|
||||
### 计划编写 ✅ 完成
|
||||
|
||||
- ✅ 创建 `task_plan.md`
|
||||
- ✅ 创建 `findings.md`
|
||||
- ✅ 创建 `progress.md`
|
||||
|
||||
### 下一步: 实现
|
||||
|
||||
**Phase 1**: 添加 graph 捕获到 OffloadEngine
|
||||
- [ ] 在 `offload_engine.py` 添加 `capture_attention_graphs()`
|
||||
- [ ] 添加 `attention_graph_causal` 和 `attention_graph_non_causal` 属性
|
||||
|
||||
**Phase 2**: 修改 ring buffer pipeline
|
||||
- [ ] 在 `_decode_ring_buffer_pipeline()` 使用 graph replay
|
||||
- [ ] 保持 H2D 和 merge 为 eager
|
||||
|
||||
**Phase 3**: 测试
|
||||
- [ ] 运行 needle test 验证正确性
|
||||
- [ ] 对比性能
|
||||
|
||||
---
|
||||
|
||||
## 文件清单
|
||||
|
||||
| 文件 | 状态 | 说明 |
|
||||
|------|------|------|
|
||||
| `tests/test_chunk_attention_graph.py` | ✅ 已提交 | 预分配 chunk pair graphs 测试 |
|
||||
| `tests/test_chunk_attention_graph_reuse.py` | 待提交 | Graph 复用验证 |
|
||||
| `task_plan.md` | ✅ 创建 | 实现计划 |
|
||||
| `findings.md` | ✅ 创建 | 调研发现 |
|
||||
| `progress.md` | ✅ 创建 | 进度日志 |
|
||||
587
task_plan.md
587
task_plan.md
@@ -1,286 +1,357 @@
|
||||
# Task Plan: XAttention BSA 真正的 Sparse 实现
|
||||
# Task Plan: CUDA Graph 优化 Offload Mode Decode
|
||||
|
||||
## Goal
|
||||
## 目标
|
||||
|
||||
实现 XAttentionBSAPolicy 的真正 sparse attention,在 `select_blocks` 中使用 `xattn_estimate_chunked` 选择重要的 blocks,然后复用 FullAttentionPolicy 的 ring buffer pipeline。
|
||||
为 nanovllm 的 CPU offload 模式添加 CUDA Graph 支持,加速 decode 阶段的计算。
|
||||
|
||||
**验收标准**:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_ruler.py \
|
||||
--model ~/models/Llama-3.1-8B-Instruct \
|
||||
--enable-offload \
|
||||
--sparse-policy XATTN_BSA \
|
||||
--datasets niah_single_1 \
|
||||
--sample-indices 0,1,2,3,4
|
||||
# 期望: 5/5 PASS,并且真正使用 sparse selection
|
||||
```
|
||||
## 问题分析
|
||||
|
||||
## 当前状态: Phase 1 - 代码分析完成
|
||||
|
||||
## 核心设计理解
|
||||
|
||||
### 1. Block Size 关系
|
||||
|
||||
| 参数 | 值 | 说明 |
|
||||
|------|-----|------|
|
||||
| BSA block_size | 128 tokens | XAttention 的 block 粒度 |
|
||||
| kvcache_block_size | 1024 tokens | CPU offload 的 block 粒度 |
|
||||
| 比例 | 1:8 | 1 CPU block = 8 BSA blocks |
|
||||
|
||||
### 2. 特化条件(用户要求)
|
||||
|
||||
- BSA chunk_size = 外部 chunk_size
|
||||
- 这样 `xattn_estimate_chunked` 返回的 mask 可以直接映射到 CPU block selection
|
||||
- 复用现有的 `flash_attn_with_lse` + `merge_attention_outputs`
|
||||
|
||||
### 3. select_blocks 设计
|
||||
### Transformer 层的完整结构
|
||||
|
||||
```
|
||||
select_blocks(available_blocks, offload_engine, ctx) -> List[int]
|
||||
│
|
||||
├─ 1. 从 metadata cache 获取下采样的 K
|
||||
│ (在 on_prefill_offload 中收集)
|
||||
│
|
||||
├─ 2. 调用 xattn_estimate_chunked(Q, K_downsampled, q_start_pos)
|
||||
│ 返回 mask: [B, H, q_blocks, k_blocks]
|
||||
│
|
||||
├─ 3. 将 BSA k_blocks 映射到 CPU block IDs
|
||||
│ 每 8 个 BSA blocks = 1 CPU block
|
||||
│ 只要 8 个中有任意一个被选中,就保留该 CPU block
|
||||
│
|
||||
└─ 4. 返回 selected_cpu_blocks
|
||||
Qwen3DecoderLayer.forward:
|
||||
├── input_layernorm (RMSNorm) # ✅ 纯 GPU
|
||||
├── self_attn:
|
||||
│ ├── qkv_proj (Linear) # ✅ 纯 GPU
|
||||
│ ├── q_norm, k_norm (RMSNorm) # ✅ 纯 GPU
|
||||
│ ├── rotary_emb # ✅ 纯 GPU
|
||||
│ ├── attn._chunked_decode_attention: # ⚠️ 包含 CPU→GPU
|
||||
│ │ ├── H2D transfer # ❌ 不能 graph
|
||||
│ │ ├── flash_attn_with_lse # ✅ 可以 graph
|
||||
│ │ └── merge # ✅ 纯 GPU
|
||||
│ └── o_proj (Linear) # ✅ 纯 GPU
|
||||
├── post_attention_layernorm # ✅ 纯 GPU
|
||||
└── mlp (FFN: gate, up, down) # ✅ 纯 GPU
|
||||
```
|
||||
|
||||
### 4. Metadata 存储策略
|
||||
**核心问题**:H2D 传输被嵌在 attention 中间,打断了整层的 graph 捕获。
|
||||
|
||||
**方案 A**: 存储下采样的 K(内存友好)
|
||||
### 可能的方案
|
||||
|
||||
| 方案 | 描述 | 优点 | 缺点 |
|
||||
|------|------|------|------|
|
||||
| A. 分段 Graph | 将层拆分为 pre/post attention 两段 | 覆盖面广 | 改动大,需拆分层执行 |
|
||||
| B. 只 Graph Attention | 只优化 flash_attn_with_lse | 改动小 | 优化效果有限 |
|
||||
| C. 重构执行流程 | 完全重写 model forward | 最优效果 | 工作量巨大 |
|
||||
|
||||
### 推荐:方案 A(分段 Graph)
|
||||
|
||||
将每层拆分为两个 graph:
|
||||
1. **pre_attention_graph**: `norm → qkv_proj → q/k_norm → rotary`
|
||||
2. **post_attention_graph**: `o_proj → norm → FFN`
|
||||
|
||||
中间的 `_chunked_decode_attention` 保持 eager(包含 H2D),但内部的 `flash_attn_with_lse` 使用 graph。
|
||||
|
||||
---
|
||||
|
||||
## 当前状态分析
|
||||
|
||||
### 现有 CUDA Graph 实现
|
||||
|
||||
**文件**: `nanovllm/engine/model_runner.py`
|
||||
|
||||
| 方法 | 行号 | 功能 |
|
||||
|------|------|------|
|
||||
| `capture_cudagraph()` | 682-717 | 为不同 batch size 捕获完整 model forward |
|
||||
| `run_model()` | 415-436 | 决定使用 eager 还是 graph replay |
|
||||
|
||||
**关键逻辑** (`run_model`):
|
||||
```python
|
||||
# on_prefill_offload 中:
|
||||
k_downsampled = k_cache[::stride] # [block_size/stride, H, D]
|
||||
self._k_cache[layer_id][cpu_block_id] = k_downsampled
|
||||
use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
|
||||
```
|
||||
|
||||
**内存计算** (stride=8):
|
||||
- 每 block: (1024/8) * 8 * 128 * 2 bytes = 256 KB
|
||||
- 256 blocks * 32 layers = 2 GB (GPU 上用于快速估计)
|
||||
**问题**: `run_chunked_offload_decode` 设置 `is_chunked_prefill=True`,导致**永远使用 eager mode**。
|
||||
|
||||
**方案 B**: 存储 min/max metadata (更省内存)
|
||||
```python
|
||||
# on_prefill_offload 中:
|
||||
k_min = k_cache[:num_valid].min(dim=0).values # [H, D]
|
||||
k_max = k_cache[:num_valid].max(dim=0).values # [H, D]
|
||||
### Offload Decode 流程
|
||||
|
||||
**文件**: `nanovllm/kvcache/sparse/full_policy.py`
|
||||
|
||||
`_decode_ring_buffer_pipeline()` (L304-379):
|
||||
```
|
||||
- 但这需要不同的估计算法,不能直接用 xattn_estimate
|
||||
|
||||
**决定**: 使用方案 A(下采样 K),因为可以直接复用 xattn_estimate_chunked
|
||||
|
||||
## Phases
|
||||
|
||||
- [x] Phase 1: 代码分析,理解当前实现
|
||||
- [ ] Phase 2: 实现 on_prefill_offload 收集 K metadata
|
||||
- [ ] Phase 3: 实现 select_blocks 中的 xattn estimation
|
||||
- [ ] Phase 4: 实现 BSA block → CPU block 的映射
|
||||
- [ ] Phase 5: 测试验证
|
||||
|
||||
## Phase 2: on_prefill_offload 实现
|
||||
|
||||
### 需要修改的文件
|
||||
- `nanovllm/kvcache/sparse/xattn_bsa.py`
|
||||
|
||||
### 实现细节
|
||||
|
||||
```python
|
||||
class XAttentionBSAPolicy(SparsePolicy):
|
||||
def __init__(self, threshold=0.9, stride=8, ...):
|
||||
self.threshold = threshold
|
||||
self.stride = stride
|
||||
self._k_cache: Dict[int, Dict[int, torch.Tensor]] = {}
|
||||
# _k_cache[layer_id][cpu_block_id] = k_downsampled
|
||||
|
||||
def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
|
||||
"""初始化 K cache 结构"""
|
||||
self._k_cache = {layer_id: {} for layer_id in range(num_layers)}
|
||||
self._num_kv_heads = num_kv_heads
|
||||
self._head_dim = head_dim
|
||||
|
||||
def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
|
||||
"""收集下采样的 K 用于后续估计"""
|
||||
# k_cache: [block_size, num_kv_heads, head_dim]
|
||||
k_downsampled = k_cache[:num_valid_tokens:self.stride].clone()
|
||||
# k_downsampled: [num_valid_tokens//stride, num_kv_heads, head_dim]
|
||||
self._k_cache[layer_id][cpu_block_id] = k_downsampled
|
||||
for block in cpu_blocks:
|
||||
1. wait_slot_layer(slot) # 等待 H2D 完成
|
||||
2. k, v = get_kv_for_slot(slot) # 获取 KV
|
||||
3. o, lse = flash_attn_with_lse() # ⭐ 纯 GPU 计算
|
||||
4. record_slot_compute_done(slot) # 标记计算完成
|
||||
5. load_next_block() # 启动下一个 H2D
|
||||
6. merge_attention_outputs() # ⭐ 纯 GPU 计算
|
||||
```
|
||||
|
||||
## Phase 3: select_blocks 实现
|
||||
**可 Graph 化的部分**:
|
||||
- `flash_attn_with_lse()` - 纯 GPU 计算
|
||||
- 不可 Graph 化: H2D 传输、动态 merge
|
||||
|
||||
### 关键问题
|
||||
## 验证结果
|
||||
|
||||
1. **Q 从哪里来?**
|
||||
- `ctx.query` 需要在调用 select_blocks 时传入
|
||||
- 当前 FullAttentionPolicy 传递 `query=None`
|
||||
- 需要修改 compute_chunked_prefill 传递真实的 Q
|
||||
**测试文件**: `tests/test_chunk_attention_graph_reuse.py`
|
||||
|
||||
2. **Q 的格式转换**
|
||||
- 输入 Q: [seq_len, num_heads, head_dim]
|
||||
- xattn 需要: [B, H, q_len, D]
|
||||
- 转换: `q.unsqueeze(0).transpose(1, 2)`
|
||||
|
||||
3. **K 的组装**
|
||||
- 从 `_k_cache[layer_id]` 获取各 block 的下采样 K
|
||||
- 按 `available_blocks` 顺序 cat 起来
|
||||
- 结果: [B, H, total_k_downsampled, D]
|
||||
|
||||
### 实现草案
|
||||
|
||||
```python
|
||||
def select_blocks(self, available_blocks, offload_engine, ctx):
|
||||
if not available_blocks or ctx.query is None:
|
||||
return available_blocks
|
||||
|
||||
layer_id = ctx.layer_id
|
||||
|
||||
# 1. 组装下采样的 K
|
||||
k_list = []
|
||||
for cpu_block_id in available_blocks:
|
||||
if cpu_block_id in self._k_cache[layer_id]:
|
||||
k_list.append(self._k_cache[layer_id][cpu_block_id])
|
||||
|
||||
if not k_list:
|
||||
return available_blocks
|
||||
|
||||
k_hist = torch.cat(k_list, dim=0) # [total_tokens/stride, H, D]
|
||||
k_hist = k_hist.unsqueeze(0).transpose(1, 2) # [1, H, k_len, D]
|
||||
|
||||
# 2. 准备 Q
|
||||
q = ctx.query # [seq_len, num_heads, head_dim]
|
||||
q = q.unsqueeze(0).transpose(1, 2) # [1, H, q_len, D]
|
||||
|
||||
# GQA 扩展(如果需要)
|
||||
if q.shape[1] != k_hist.shape[1]:
|
||||
num_groups = q.shape[1] // k_hist.shape[1]
|
||||
k_hist = k_hist.repeat_interleave(num_groups, dim=1)
|
||||
|
||||
# 3. 计算 q_start_pos
|
||||
q_start_pos = len(available_blocks) * ctx.block_size
|
||||
|
||||
# 4. 调用 xattn_estimate_chunked
|
||||
# 注意:K 已经是下采样的,需要调整参数
|
||||
attn_sum, mask = xattn_estimate_chunked(
|
||||
q, k_hist,
|
||||
q_start_pos=q_start_pos // self.stride, # 调整到下采样空间
|
||||
block_size=self.BSA_BLOCK_SIZE // self.stride, # 16
|
||||
stride=1, # K 已经下采样
|
||||
threshold=self.threshold,
|
||||
chunk_size=q.shape[2], # 与 Q 长度一致
|
||||
use_triton=self.use_triton,
|
||||
)
|
||||
|
||||
# 5. 从 mask 提取 CPU block IDs
|
||||
# mask: [1, H, q_blocks, k_blocks]
|
||||
# 对所有 heads 取 OR
|
||||
selected_mask = mask.any(dim=1).squeeze(0) # [q_blocks, k_blocks]
|
||||
# 对所有 q_blocks 取 OR(只要任意 Q 位置需要这个 K block)
|
||||
selected_k_mask = selected_mask.any(dim=0) # [k_blocks]
|
||||
|
||||
# 6. 映射 BSA blocks → CPU blocks
|
||||
# 每个 CPU block = 8 BSA blocks (block_size=1024, BSA_block=128)
|
||||
bsa_to_cpu_ratio = ctx.block_size // self.BSA_BLOCK_SIZE # 8
|
||||
num_cpu_blocks = len(available_blocks)
|
||||
|
||||
selected_cpu_indices = set()
|
||||
for bsa_idx in selected_k_mask.nonzero(as_tuple=True)[0].tolist():
|
||||
cpu_idx = bsa_idx // bsa_to_cpu_ratio
|
||||
if cpu_idx < num_cpu_blocks:
|
||||
selected_cpu_indices.add(cpu_idx)
|
||||
|
||||
selected_blocks = [available_blocks[i] for i in sorted(selected_cpu_indices)]
|
||||
|
||||
logger.info(f"[XAttn] select_blocks: {len(available_blocks)} -> {len(selected_blocks)} "
|
||||
f"({100*len(selected_blocks)/len(available_blocks):.1f}%)")
|
||||
|
||||
return selected_blocks
|
||||
```
|
||||
|
||||
## Phase 4: compute_chunked_prefill
|
||||
|
||||
### 关键修改
|
||||
|
||||
1. **传递真实的 Q 给 select_blocks**
|
||||
- 修改 PolicyContext 构造,设置 `query=q`
|
||||
|
||||
2. **复用 FullAttentionPolicy 的 pipeline**
|
||||
- 继承 FullAttentionPolicy 而不是 SparsePolicy
|
||||
- 或者直接调用父类方法
|
||||
|
||||
### 方案对比
|
||||
|
||||
**方案 A**: XAttentionBSAPolicy 继承 FullAttentionPolicy
|
||||
```python
|
||||
class XAttentionBSAPolicy(FullAttentionPolicy):
|
||||
# 只需要 override select_blocks 和 on_prefill_offload
|
||||
# compute_chunked_prefill 直接用父类的
|
||||
```
|
||||
|
||||
**方案 B**: 独立实现,调用相同的 pipeline 代码
|
||||
```python
|
||||
class XAttentionBSAPolicy(SparsePolicy):
|
||||
def compute_chunked_prefill(self, q, k, v, ...):
|
||||
# 复制 FullAttentionPolicy 的代码
|
||||
# 但修改 PolicyContext 传递 query=q
|
||||
```
|
||||
|
||||
**决定**: 使用方案 B,因为需要在 compute_chunked_prefill 中修改 PolicyContext
|
||||
|
||||
## Phase 5: 测试
|
||||
|
||||
### 单元测试
|
||||
```bash
|
||||
# 测试 select_blocks 的 sparsity
|
||||
python -c "
|
||||
from nanovllm.kvcache.sparse.xattn_bsa import XAttentionBSAPolicy
|
||||
policy = XAttentionBSAPolicy(threshold=0.9)
|
||||
# ... 测试代码
|
||||
"
|
||||
```
|
||||
|
||||
### 集成测试
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_ruler.py \
|
||||
--model ~/models/Llama-3.1-8B-Instruct \
|
||||
--enable-offload \
|
||||
--sparse-policy XATTN_BSA \
|
||||
--datasets niah_single_1 \
|
||||
--sample-indices 0,1,2,3,4
|
||||
```
|
||||
|
||||
## Key Decisions
|
||||
|
||||
| 决策 | 理由 |
|
||||
| 测试 | 结果 |
|
||||
|------|------|
|
||||
| 使用下采样 K 作为 metadata | 可以直接复用 xattn_estimate_chunked |
|
||||
| stride=8 | 平衡内存和精度 |
|
||||
| BSA blocks → CPU blocks 映射用 OR | 只要有一个 BSA block 被选中就保留 |
|
||||
| 继承 FullAttentionPolicy 的 pipeline | 复用已验证的 ring buffer 流程 |
|
||||
| 2 个 Graph 复用于所有层和所有 chunk | ✅ PASSED |
|
||||
| copy_() 更新 static tensors | ✅ 有效 |
|
||||
| Eager merge | ✅ 用户已接受 |
|
||||
|
||||
## Files to Modify
|
||||
**结论**: 只需 2 个 graph(causal + non-causal),通过 copy_() 复用。
|
||||
|
||||
| 文件 | 修改 |
|
||||
|------|------|
|
||||
| `nanovllm/kvcache/sparse/xattn_bsa.py` | 主要实现:initialize, on_prefill_offload, select_blocks |
|
||||
---
|
||||
|
||||
## 注意事项
|
||||
## 修改计划(方案 A:分段 Graph)
|
||||
|
||||
1. **GQA 处理**: Llama-3.1-8B 有 32 query heads, 8 kv heads,需要在估计时扩展 K
|
||||
2. **内存管理**: `_k_cache` 存储在 GPU,需要在 reset() 时清理
|
||||
3. **Triton 兼容性**: xattn_estimate_chunked 有 Triton bug,可能需要用 PyTorch fallback
|
||||
4. **边界条件**: 第一个 chunk (available_blocks=[]) 时直接返回空列表
|
||||
### 架构设计
|
||||
|
||||
## Errors Encountered
|
||||
```
|
||||
每层执行流程(Offload Decode):
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ PRE-ATTENTION GRAPH (可复用于所有层) │
|
||||
│ input_layernorm → qkv_proj → q/k_norm → rotary → split Q │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ CHUNKED ATTENTION (Eager + 部分 Graph) │
|
||||
│ for block in cpu_blocks: │
|
||||
│ H2D transfer (eager) │
|
||||
│ flash_attn_with_lse (GRAPH - 2个可复用) │
|
||||
│ merge (eager) │
|
||||
│ decode_buffer attention (eager) │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
↓
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ POST-ATTENTION GRAPH (可复用于所有层) │
|
||||
│ o_proj → post_layernorm → gate_proj → up_proj → SiLU │
|
||||
│ → down_proj → residual │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
(待填充)
|
||||
**总共需要的 Graph 数量**:
|
||||
- 1 个 pre_attention_graph(所有层复用)
|
||||
- 2 个 attention_graph(causal + non-causal,所有层复用)
|
||||
- 1 个 post_attention_graph(所有层复用)
|
||||
- **总计: 4 个 graph**
|
||||
|
||||
## Status
|
||||
---
|
||||
|
||||
**Currently in Phase 1** - 代码分析完成,准备开始 Phase 2 实现
|
||||
### Phase 1: 拆分 DecoderLayer 执行
|
||||
|
||||
**目标**: 将 `Qwen3DecoderLayer.forward` 拆分为可独立调用的三段
|
||||
|
||||
**修改文件**: `nanovllm/models/qwen3.py`
|
||||
|
||||
**新增方法**:
|
||||
```python
|
||||
class Qwen3DecoderLayer:
|
||||
def forward_pre_attention(self, positions, hidden_states, residual):
|
||||
"""Pre-attention: norm → qkv → rotary → 返回 q, k, v"""
|
||||
if residual is None:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
qkv = self.self_attn.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = q.view(-1, self.num_heads, self.head_dim)
|
||||
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||
q = self.self_attn.q_norm(q)
|
||||
k = self.self_attn.k_norm(k)
|
||||
q, k = self.self_attn.rotary_emb(positions, q, k)
|
||||
return q, k, v, hidden_states, residual
|
||||
|
||||
def forward_post_attention(self, attn_output, hidden_states, residual):
|
||||
"""Post-attention: o_proj → norm → FFN"""
|
||||
output = self.self_attn.o_proj(attn_output.flatten(1, -1))
|
||||
hidden_states, residual = self.post_attention_layernorm(output, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
```
|
||||
|
||||
**状态**: `pending`
|
||||
|
||||
---
|
||||
|
||||
### Phase 2: 捕获 Pre/Post Attention Graph
|
||||
|
||||
**目标**: 捕获 pre_attention 和 post_attention 的 graph
|
||||
|
||||
**修改文件**: `nanovllm/engine/model_runner.py`
|
||||
|
||||
**新增方法**: `capture_offload_layer_graphs()`
|
||||
|
||||
```python
|
||||
def capture_offload_layer_graphs(self):
|
||||
"""捕获 offload mode 的 layer graphs"""
|
||||
# 获取任意一层作为模板(所有层结构相同)
|
||||
layer = self.model.model.layers[0]
|
||||
|
||||
# Static tensors
|
||||
static_hidden = torch.zeros(1, self.hidden_size, ...)
|
||||
static_residual = torch.zeros(1, self.hidden_size, ...)
|
||||
static_positions = torch.zeros(1, ...)
|
||||
|
||||
# Pre-attention graph
|
||||
self.pre_attn_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.pre_attn_graph):
|
||||
static_q, static_k, static_v, _, _ = layer.forward_pre_attention(
|
||||
static_positions, static_hidden, static_residual
|
||||
)
|
||||
|
||||
# Post-attention graph
|
||||
self.post_attn_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.post_attn_graph):
|
||||
_, _ = layer.forward_post_attention(
|
||||
static_attn_output, static_hidden, static_residual
|
||||
)
|
||||
```
|
||||
|
||||
**状态**: `pending`
|
||||
|
||||
---
|
||||
|
||||
### Phase 3: 捕获 Attention Graph
|
||||
|
||||
**目标**: 捕获 2 个 attention graph(causal + non-causal)
|
||||
|
||||
**修改文件**: `nanovllm/kvcache/offload_engine.py`
|
||||
|
||||
```python
|
||||
class OffloadEngine:
|
||||
def capture_attention_graphs(self):
|
||||
"""捕获 attention graphs(复用于所有层)"""
|
||||
self.attn_graph_causal = self._capture_attn_graph(causal=True)
|
||||
self.attn_graph_non_causal = self._capture_attn_graph(causal=False)
|
||||
|
||||
def _capture_attn_graph(self, causal: bool):
|
||||
static_q = torch.zeros(1, 1, num_heads, head_dim, ...)
|
||||
static_k = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
|
||||
static_v = torch.zeros(1, block_size, num_kv_heads, head_dim, ...)
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
output, lse = flash_attn_with_lse(static_q, static_k, static_v,
|
||||
self.scale, causal)
|
||||
return AttentionGraph(graph, static_q, static_k, static_v, output, lse)
|
||||
```
|
||||
|
||||
**状态**: `pending`
|
||||
|
||||
---
|
||||
|
||||
### Phase 4: 修改 Offload Decode 执行流程
|
||||
|
||||
**目标**: 使用 graph replay 执行 offload decode
|
||||
|
||||
**修改文件**: `nanovllm/engine/model_runner.py`
|
||||
|
||||
**修改方法**: `run_chunked_offload_decode()`
|
||||
|
||||
```python
|
||||
def run_chunked_offload_decode_with_graph(self, seqs):
|
||||
"""使用 graph 加速的 offload decode"""
|
||||
seq = seqs[0]
|
||||
|
||||
# 准备输入
|
||||
input_ids = torch.tensor([seq.last_token], ...)
|
||||
positions = torch.tensor([len(seq) - 1], ...)
|
||||
|
||||
# Embedding
|
||||
hidden_states = self.model.model.embed_tokens(input_ids)
|
||||
residual = None
|
||||
|
||||
for layer_id, layer in enumerate(self.model.model.layers):
|
||||
# Phase 1: Pre-attention (GRAPH)
|
||||
self.pre_attn_vars["hidden"].copy_(hidden_states)
|
||||
self.pre_attn_vars["residual"].copy_(residual) if residual else None
|
||||
self.pre_attn_vars["positions"].copy_(positions)
|
||||
self.pre_attn_graph.replay()
|
||||
q = self.pre_attn_vars["q"].clone()
|
||||
k = self.pre_attn_vars["k"].clone()
|
||||
v = self.pre_attn_vars["v"].clone()
|
||||
|
||||
# Phase 2: Chunked Attention (Eager + Graph)
|
||||
attn_output = self._chunked_attention_with_graph(q, k, v, layer_id, ...)
|
||||
|
||||
# Phase 3: Post-attention (GRAPH)
|
||||
self.post_attn_vars["attn_output"].copy_(attn_output)
|
||||
self.post_attn_graph.replay()
|
||||
hidden_states = self.post_attn_vars["hidden"].clone()
|
||||
residual = self.post_attn_vars["residual"].clone()
|
||||
|
||||
# LM head
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
return logits
|
||||
```
|
||||
|
||||
**状态**: `pending`
|
||||
|
||||
---
|
||||
|
||||
### Phase 5: 修改 Ring Buffer Pipeline
|
||||
|
||||
**目标**: 在 attention 内部使用 graph
|
||||
|
||||
**修改文件**: `nanovllm/kvcache/sparse/full_policy.py`
|
||||
|
||||
**修改**: `_decode_ring_buffer_pipeline()` 中的 `flash_attn_with_lse` 调用
|
||||
|
||||
```python
|
||||
# 当前:eager
|
||||
prev_o, prev_lse = flash_attn_with_lse(q, k, v, scale, causal=False)
|
||||
|
||||
# 修改为:graph replay
|
||||
graph = offload_engine.attn_graph_non_causal
|
||||
graph.static_q.copy_(q)
|
||||
graph.static_k.copy_(k)
|
||||
graph.static_v.copy_(v)
|
||||
graph.graph.replay()
|
||||
prev_o = graph.static_output.clone()
|
||||
prev_lse = graph.static_lse.clone()
|
||||
```
|
||||
|
||||
**状态**: `pending`
|
||||
|
||||
---
|
||||
|
||||
### Phase 6: 添加配置开关
|
||||
|
||||
**修改文件**: `nanovllm/config.py`
|
||||
|
||||
```python
|
||||
enable_offload_graph: bool = True # 默认启用
|
||||
```
|
||||
|
||||
**状态**: `pending`
|
||||
|
||||
---
|
||||
|
||||
## 文件修改清单
|
||||
|
||||
| 文件 | 修改类型 | 说明 |
|
||||
|------|----------|------|
|
||||
| `nanovllm/engine/model_runner.py` | 新增方法 | `capture_offload_attention_graph()` |
|
||||
| `nanovllm/kvcache/offload_engine.py` | 新增属性+方法 | Graph 存储和访问 |
|
||||
| `nanovllm/kvcache/sparse/full_policy.py` | 修改方法 | 使用 graph replay |
|
||||
| `nanovllm/config.py` | 新增配置 | `enable_offload_graph` |
|
||||
|
||||
---
|
||||
|
||||
## 风险和注意事项
|
||||
|
||||
1. **Graph 捕获时机**: 需要在 KV cache 分配后、第一次 decode 前捕获
|
||||
2. **Chunk size 匹配**: Graph 的 chunk_size 必须和 block_size 一致
|
||||
3. **多 GPU**: Graph 需要在每个 GPU 上分别捕获
|
||||
4. **内存**: 2 个 graph 的额外内存开销很小
|
||||
|
||||
---
|
||||
|
||||
## 测试计划
|
||||
|
||||
1. **单元测试**: 验证 graph replay 结果正确
|
||||
2. **集成测试**: 运行 `test_needle.py --enable-offload --input-len 32768`
|
||||
3. **性能测试**: 对比 eager vs graph 的 decode 延迟
|
||||
|
||||
---
|
||||
|
||||
## 预期收益
|
||||
|
||||
- Decode 阶段 attention 计算加速(减少 kernel launch overhead)
|
||||
- 与现有 ring buffer pipeline 兼容
|
||||
- 内存开销极小(只有 2 个额外 graph)
|
||||
|
||||
156
tests/test_chunk_attention_graph_reuse.py
Normal file
156
tests/test_chunk_attention_graph_reuse.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test: Reuse a single CUDA Graph across all layers and all chunk pairs.
|
||||
|
||||
Key insight: LLM layers have identical computation structure.
|
||||
We only need 2 graphs (causal + non-causal), reused for all (layer, Q_i, K_j) combinations.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph_reuse.py
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReusableChunkGraph:
|
||||
"""A single graph that can be reused with copy_() updates."""
|
||||
graph: torch.cuda.CUDAGraph
|
||||
static_q: torch.Tensor
|
||||
static_k: torch.Tensor
|
||||
static_v: torch.Tensor
|
||||
static_output: torch.Tensor
|
||||
static_lse: torch.Tensor
|
||||
|
||||
|
||||
def capture_reusable_graph(
|
||||
chunk_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
scale: float,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
causal: bool,
|
||||
) -> ReusableChunkGraph:
|
||||
"""Capture ONE graph to be reused for all chunk pairs."""
|
||||
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
|
||||
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
|
||||
static_q.normal_()
|
||||
static_k.normal_()
|
||||
static_v.normal_()
|
||||
|
||||
# Warmup
|
||||
with torch.inference_mode():
|
||||
for _ in range(3):
|
||||
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.inference_mode():
|
||||
with torch.cuda.graph(graph):
|
||||
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return ReusableChunkGraph(
|
||||
graph=graph,
|
||||
static_q=static_q,
|
||||
static_k=static_k,
|
||||
static_v=static_v,
|
||||
static_output=static_output,
|
||||
static_lse=static_lse,
|
||||
)
|
||||
|
||||
|
||||
def replay_with_copy(graph: ReusableChunkGraph, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Replay graph after updating static tensors with copy_()."""
|
||||
graph.static_q.copy_(q)
|
||||
graph.static_k.copy_(k)
|
||||
graph.static_v.copy_(v)
|
||||
graph.graph.replay()
|
||||
return graph.static_output.clone(), graph.static_lse.clone()
|
||||
|
||||
|
||||
def main():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
chunk_size = 64
|
||||
num_chunks = 4
|
||||
num_layers = 3 # Simulate multiple layers
|
||||
num_heads = 8
|
||||
num_kv_heads = 8
|
||||
head_dim = 64
|
||||
scale = 1.0 / (head_dim ** 0.5)
|
||||
seq_len = chunk_size * num_chunks
|
||||
|
||||
print(f"Device: {torch.cuda.get_device_name()}")
|
||||
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}, Num layers: {num_layers}")
|
||||
print(f"Only 2 graphs (causal + non-causal) for ALL layer × chunk combinations")
|
||||
|
||||
# Capture only 2 graphs
|
||||
graph_causal = capture_reusable_graph(
|
||||
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=True
|
||||
)
|
||||
graph_non_causal = capture_reusable_graph(
|
||||
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, causal=False
|
||||
)
|
||||
print("2 graphs captured (causal + non-causal)")
|
||||
|
||||
all_pass = True
|
||||
|
||||
for layer_id in range(num_layers):
|
||||
# Different Q/K/V for each layer (simulating different layer outputs)
|
||||
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
|
||||
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Reference: full causal attention
|
||||
with torch.inference_mode():
|
||||
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
|
||||
|
||||
# Chunked with graph reuse
|
||||
chunked_output = torch.zeros_like(full_output)
|
||||
|
||||
for q_idx in range(num_chunks):
|
||||
q_chunk = full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]
|
||||
acc_out, acc_lse = None, None
|
||||
|
||||
for k_idx in range(q_idx + 1):
|
||||
k_chunk = full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
|
||||
v_chunk = full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]
|
||||
|
||||
# Reuse graph with copy_()
|
||||
graph = graph_causal if k_idx == q_idx else graph_non_causal
|
||||
out, lse = replay_with_copy(graph, q_chunk, k_chunk, v_chunk)
|
||||
|
||||
if acc_out is None:
|
||||
acc_out, acc_lse = out, lse
|
||||
else:
|
||||
with torch.inference_mode():
|
||||
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
|
||||
|
||||
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Compare
|
||||
max_diff = (full_output - chunked_output).abs().max().item()
|
||||
status = "✅" if max_diff < 1e-2 else "❌"
|
||||
print(f"Layer {layer_id}: max_diff={max_diff:.2e} {status}")
|
||||
if max_diff >= 1e-2:
|
||||
all_pass = False
|
||||
|
||||
print("✅ PASSED - Single graph reuse across layers works!" if all_pass else "❌ FAILED")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user