Compare commits

6 Commits: b1f292cf22 ... 07f5220f40

| SHA1 |
|------|
| 07f5220f40 |
| 37aecd4d52 |
| fa7601f4b8 |
| 6080bf7554 |
| e5a17c832c |
| 4593f42ec3 |
@@ -1,9 +0,0 @@
---
active: true
iteration: 1
max_iterations: 0
completion_promise: "COMPLETE"
started_at: "2026-01-19T17:25:00Z"
---

Please refactor the nanovllm code according to the requirements in task_plan.md, making sure the final goal of the plan is fully achieved. Note that you may only use GPU 0 for debugging; the other GPUs must not be used under any circumstances. Finally, write up the test results in a report. <promise>COMPLETE</promise> -max-iterations 30
@@ -77,6 +77,45 @@ Claude: Runs `python tests/test_needle.py ...` # NO! Missing GPU specification!

---
## Needle Test Requirements (MANDATORY)

When running `test_needle.py`, **ALWAYS** use these settings:

1. **Enable offload**: `--enable-offload` is **REQUIRED**
2. **Use 32K context**: `--input-len 32768` is **REQUIRED**

### Standard Needle Test Command

```bash
CUDA_VISIBLE_DEVICES=X PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_needle.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --enable-offload \
    --input-len 32768
```

### Why These Settings?

| Setting | Reason |
|---------|--------|
| `--enable-offload` | Tests the CPU offload pipeline, which is the main feature being developed |
| `--input-len 32768` | A 32K context properly exercises the chunked prefill/decode paths; 8K is too short to catch many issues |

### Do NOT Use

```bash
# ❌ Wrong: Missing offload
python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct

# ❌ Wrong: Too short (default 8K)
python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct --enable-offload

# ✅ Correct: Offload + 32K
python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct --enable-offload --input-len 32768
```
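The mandatory settings above can be centralized in a small launcher so they are never forgotten. This is a hedged sketch (the helper names `needle_cmd`/`run_needle` are not part of the repo):

```python
import os
import subprocess

# The two mandatory flags from the table above.
REQUIRED_FLAGS = ["--enable-offload", "--input-len", "32768"]

def needle_cmd(model: str) -> list[str]:
    """Build the mandatory needle-test command line."""
    return ["python", "tests/test_needle.py", "--model", model, *REQUIRED_FLAGS]

def run_needle(model: str, gpu: int = 0) -> None:
    """Run the test pinned to a single GPU (GPU 0 by default, per the task rule)."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    subprocess.run(needle_cmd(model), env=env, check=True)
```

Pinning `CUDA_VISIBLE_DEVICES` in the environment rather than in the command string keeps the GPU restriction enforced even if the command is edited.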
---

## Combined Checklist

Before running any GPU test:
@@ -1,5 +1,37 @@
# Planning with Files Rule

## Git Management Policy

**Important**: Planning files are excluded from Git management and will not be committed.

### Configured .gitignore Rules

```gitignore
# Planning-with-files temporary files
task_plan.md
findings.md
progress.md
task_plan_*.md
findings_*.md
progress_*.md
```

### Why These Files Are Excluded

1. **Temporary by nature**: Plan files are session-level temporary files and should not enter version control
2. **Avoid conflicts**: When multiple instances develop in parallel, plan files for different tasks would conflict
3. **Keep the repository clean**: These files are only useful for the current task and need no history

### If They Have Already Been Committed

```bash
# Remove from git (but keep the local files)
git rm --cached task_plan.md findings.md progress.md
git commit -m "chore: remove planning files from git tracking"
```

---

## Automatically Clean Up Old Plan Files

**Important**: Each time a new complex task starts with planning-with-files, delete the old plan files first.
@@ -1,97 +1,92 @@
# Sparse Policy Code Conventions

## Base Class Requirements (MANDATORY)

Every `SparsePolicy` subclass **must** follow these requirements:

### 1. Declare the supports_prefill / supports_decode flags

```python
class MyPolicy(SparsePolicy):
    supports_prefill = True   # Whether the prefill phase is supported
    supports_decode = True    # Whether the decode phase is supported
```

### 2. Implement the three abstract methods

| Method | Required | Description |
|------|---------|------|
| `select_blocks()` | ✅ | Select which blocks to load |
| `compute_chunked_prefill()` | ✅ | Prefill attention computation |
| `compute_chunked_decode()` | ✅ | Decode attention computation |

### 3. Unsupported phases must assert False

If `supports_prefill = False`, then `compute_chunked_prefill()` **must** `assert False` internally:

```python
class DecodeOnlyPolicy(SparsePolicy):
    supports_prefill = False
    supports_decode = True

    def compute_chunked_prefill(self, ...):
        assert False, "DecodeOnlyPolicy does not support prefill phase"

    def compute_chunked_decode(self, ...):
        # Normal implementation
        ...
```

Likewise, if `supports_decode = False`:

```python
class PrefillOnlyPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = False

    def compute_chunked_prefill(self, ...):
        # Normal implementation
        ...

    def compute_chunked_decode(self, ...):
        assert False, "PrefillOnlyPolicy does not support decode phase"
```

### 4. FullAttentionPolicy must support both phases

```python
class FullAttentionPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = True

    def compute_chunked_prefill(self, ...):
        # Full implementation
        ...

    def compute_chunked_decode(self, ...):
        # Full implementation
        ...
```

## Caller-Side Checks

`attention.py` should check that the policy supports the current phase before calling it:

```python
# Prefill path
if not sparse_policy.supports_prefill:
    raise RuntimeError(f"{sparse_policy} does not support prefill")

# Decode path
if not sparse_policy.supports_decode:
    raise RuntimeError(f"{sparse_policy} does not support decode")
```

This provides double protection:
1. Caller check → clear error message
2. In-method assert → prevents calls that bypass the check
---

## CPU-GPU Communication Conventions

### Rule: All Communication Must Go Through OffloadEngine

In the `compute_chunked_*` methods, all CPU-GPU data transfers **must** go through the `OffloadEngine`; directly using `torch.Tensor.copy_()` or `.to(device)` is **forbidden**:

```python
# ✅ Correct: use OffloadEngine's ring buffer methods
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
k, v = offload_engine.get_kv_for_slot(slot)
offload_engine.record_slot_compute_done(slot)

# ✅ Correct: use the prefill buffer
k, v = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)

# ✅ Correct: use the decode buffer
decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
decode_v = offload_engine.decode_v_buffer[layer_id, start:end]

# ❌ Wrong: direct torch transfers
gpu_tensor.copy_(cpu_tensor)
gpu_tensor = cpu_tensor.cuda()
```

@@ -102,6 +97,70 @@

### Reasons

1. **Stream synchronization**: OffloadEngine manages CUDA streams internally, ensuring correct synchronization
2. **Pipeline optimization**: OffloadEngine implements the ring buffer pipeline
3. **Resource management**: OffloadEngine manages GPU buffer slots, avoiding memory fragmentation
4. **Consistency**: A unified interface makes debugging and maintenance easier

---
## Method Signature Requirements

### select_blocks()

```python
def select_blocks(
    self,
    available_blocks: List[int],      # Available CPU block IDs
    offload_engine: "OffloadEngine",  # Used to load data
    ctx: PolicyContext,               # Context information
) -> List[int]:                       # Returns the block IDs to load
```

### compute_chunked_prefill()

```python
def compute_chunked_prefill(
    self,
    q: torch.Tensor,   # [seq_len, num_heads, head_dim]
    k: torch.Tensor,   # [seq_len, num_kv_heads, head_dim] (unused)
    v: torch.Tensor,   # [seq_len, num_kv_heads, head_dim] (unused)
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    current_chunk_idx: int,
    seq: "Sequence",
    num_tokens: int,
) -> torch.Tensor:     # [seq_len, num_heads, head_dim]
```

### compute_chunked_decode()

```python
def compute_chunked_decode(
    self,
    q: torch.Tensor,   # [batch_size, num_heads, head_dim]
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    seq: "Sequence",
) -> torch.Tensor:     # [batch_size, 1, num_heads, head_dim]
```

---

## Optional Hook Methods

| Method | When Called | Purpose |
|------|---------|------|
| `initialize()` | After KV cache allocation | Initialize metadata structures |
| `on_prefill_offload()` | Before GPU→CPU copy (prefill) | Collect block metadata |
| `on_decode_offload()` | Before GPU→CPU copy (decode) | Update block metadata |
| `reset()` | When a new sequence starts | Reset policy state |

---

## Detailed Implementation Guide

Reference: [`docs/sparse_policy_implementation_guide.md`](../docs/sparse_policy_implementation_guide.md)
8  .gitignore (vendored)
@@ -230,3 +230,11 @@ tests/data/

# Serena MCP tool config
.serena/

# Planning-with-files temporary files
task_plan.md
findings.md
progress.md
task_plan_*.md
findings_*.md
progress_*.md
||||
@@ -11,6 +11,8 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
|
||||
| Document | Purpose |
|
||||
|----------|---------|
|
||||
| [`docs/architecture_guide.md`](docs/architecture_guide.md) | Core components, CPU offload system design, ring buffer architecture, stream configuration |
|
||||
| [`docs/sparse_policy_architecture.md`](docs/sparse_policy_architecture.md) | SparsePolicy abstraction: prefill/decode delegation, pipeline modes, policy implementations |
|
||||
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
|
||||
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
|
||||
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
|
||||
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
|
||||
|
||||
288  docs/sparse_policy_architecture.md (new file)
@@ -0,0 +1,288 @@
# SparsePolicy Architecture Guide

This document describes the SparsePolicy abstraction for chunked attention computation in CPU offload mode.

## Overview

SparsePolicy is an abstract base class that defines how attention is computed during the chunked prefill and decode phases. All attention computation logic is delegated to the policy, allowing different sparse attention strategies to be implemented without modifying the core attention layer.

```
attention.py                        SparsePolicy
 |                                   |
 |  _chunked_prefill_attention       |
 | ────────────────────────────>     |  compute_chunked_prefill()
 |                                   |
 |  _chunked_decode_attention        |
 | ────────────────────────────>     |  compute_chunked_decode()
 |                                   |
```

## Key Design Principles

1. **Delegation Pattern**: `attention.py` only validates and delegates; all computation is in the policy
2. **No Direct Imports**: `attention.py` does not import `flash_attn_with_lse` or `merge_attention_outputs`
3. **Pipeline Encapsulation**: Ring buffer and cross-layer pipelines are internal to the policy
4. **Phase Support Flags**: Policies declare which phases they support via `supports_prefill` and `supports_decode`

---
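The delegation pattern can be sketched as follows. This is a minimal illustration of the validate-then-delegate shape described above, not the actual `attention.py` code; argument names follow the signatures in this guide:

```python
class Attention:
    """Sketch: the attention layer holds a policy and delegates to it."""

    def __init__(self, layer_id, softmax_scale, sparse_policy):
        self.layer_id = layer_id
        self.softmax_scale = softmax_scale
        self.sparse_policy = sparse_policy

    def _chunked_prefill_attention(self, q, k, v, offload_engine,
                                   kvcache_manager, chunk_idx, seq, num_tokens):
        # Validate the support flag, then delegate -- no attention math here.
        if not self.sparse_policy.supports_prefill:
            raise RuntimeError(f"{self.sparse_policy} does not support prefill")
        return self.sparse_policy.compute_chunked_prefill(
            q, k, v, self.layer_id, self.softmax_scale,
            offload_engine, kvcache_manager, chunk_idx, seq, num_tokens)

    def _chunked_decode_attention(self, q, offload_engine, kvcache_manager, seq):
        if not self.sparse_policy.supports_decode:
            raise RuntimeError(f"{self.sparse_policy} does not support decode")
        return self.sparse_policy.compute_chunked_decode(
            q, self.layer_id, self.softmax_scale,
            offload_engine, kvcache_manager, seq)
```

Because the layer touches only the two flags and the two `compute_chunked_*` entry points, a new sparse strategy never requires edits to the attention layer itself.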
## SparsePolicy Base Class

**File**: `nanovllm/kvcache/sparse/policy.py`

### Class Attributes

| Attribute | Type | Description |
|-----------|------|-------------|
| `supports_prefill` | bool | Whether the policy supports the prefill phase |
| `supports_decode` | bool | Whether the policy supports the decode phase |

### Abstract Methods

```python
@abstractmethod
def select_blocks(
    self,
    available_blocks: List[int],
    offload_engine: "OffloadEngine",
    ctx: PolicyContext,
) -> List[int]:
    """Select which KV blocks to load for the current query chunk."""
    pass

@abstractmethod
def compute_chunked_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    current_chunk_idx: int,
    seq: "Sequence",
    num_tokens: int,
) -> torch.Tensor:
    """Compute chunked prefill attention (complete flow)."""
    pass

@abstractmethod
def compute_chunked_decode(
    self,
    q: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    seq: "Sequence",
) -> torch.Tensor:
    """Compute chunked decode attention (complete flow)."""
    pass
```

### Hook Methods

| Method | When Called | Purpose |
|--------|-------------|---------|
| `initialize()` | After KV cache allocation | Initialize policy resources (e.g., metadata) |
| `on_prefill_offload()` | Before GPU→CPU copy during prefill | Collect block metadata |
| `on_decode_offload()` | Before GPU→CPU copy during decode | Update block metadata |
| `reset()` | New sequence / clear state | Reset policy state |

---

## FullAttentionPolicy

**File**: `nanovllm/kvcache/sparse/full_policy.py`

The default policy that loads all blocks (no sparsity). Serves as the baseline implementation.

### Flags

```python
supports_prefill = True
supports_decode = True
```

### Prefill Flow (`compute_chunked_prefill`)

```
1. Get historical blocks from kvcache_manager
   └── cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

2. Apply select_blocks (returns all for FullPolicy)
   └── cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, ctx)

3. Load and compute historical blocks via ring buffer
   └── For each block:
       a. load_to_slot_layer(slot, layer_id, cpu_block_id)
       b. wait_slot_layer(slot)
       c. prev_k, prev_v = get_kv_for_slot(slot)
       d. flash_attn_with_lse(q, prev_k, prev_v, causal=False)
       e. merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

4. Compute current chunk attention (causal)
   └── k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
   └── flash_attn_with_lse(q, k_curr, v_curr, causal=True)

5. Merge historical and current attention
   └── merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
```

### Decode Flow (`compute_chunked_decode`)

```
1. Get prefilled CPU blocks
   └── cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

2. Calculate last block valid tokens
   └── total_prefill_tokens = kvcache_manager.get_prefill_len(seq)
   └── last_block_valid_tokens = total_prefill_tokens % block_size

3. Apply select_blocks for block filtering
   └── cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, ctx)

4. Load prefilled blocks via ring buffer pipeline
   └── _decode_ring_buffer_pipeline()

5. Read accumulated decode tokens from decode buffer
   └── decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
   └── decode_v = offload_engine.decode_v_buffer[layer_id, start:end]
   └── flash_attn_with_lse(q, decode_k, decode_v, causal=False)

6. Merge all results
   └── merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
```
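Step 2 of the decode flow has an edge case worth spelling out: when the prefill length is an exact multiple of `block_size`, the modulo yields 0 even though the last block is completely full. A sketch of the computation (a hypothetical helper, assuming the common convention that a remainder of 0 means a full last block; the real handling lives in `FullAttentionPolicy`):

```python
def last_block_valid_tokens(total_prefill_tokens: int, block_size: int) -> int:
    """Number of valid tokens in the final prefilled block.

    Assumption: a remainder of 0 is treated as a completely full block,
    not an empty one.
    """
    rem = total_prefill_tokens % block_size
    return rem if rem != 0 else block_size
```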

---

## Ring Buffer Pipeline

The ring buffer pipeline (`_decode_ring_buffer_pipeline`) loads blocks one by one using GPU ring buffer slots. This approach is memory-efficient and works well for both short and long sequences.

```
Slot[0]: Block A ──> Compute ──> Block C ──> Compute
Slot[1]: Block B ──> Compute ──> Block D ──> Compute
```

**Advantages**:
- Memory efficient (only needs a few GPU slots)
- Fine-grained overlap between H2D transfer and compute
- Works well for long sequences

**Flow**:
```python
# Phase 1: Pre-load up to num_slots blocks
for i in range(num_preload):
    offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])

# Phase 2: Process blocks with pipeline
for block_idx in range(num_blocks):
    current_slot = load_slots[block_idx % num_slots]

    # Wait for transfer
    offload_engine.wait_slot_layer(current_slot)

    # Compute attention
    prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
    prev_o, prev_lse = flash_attn_with_lse(q, prev_k, prev_v, causal=False)
    offload_engine.record_slot_compute_done(current_slot)

    # Pipeline: start loading next block into the slot we just finished with
    next_block_idx = block_idx + num_slots
    if next_block_idx < num_blocks:
        offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])

    # Merge results
    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
```
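The `merge_attention_outputs` calls above combine partial attention results over disjoint KV sets using their log-sum-exp (LSE) values. A minimal NumPy sketch of the underlying identity (an illustration of the standard LSE merge, not the project's Triton implementation):

```python
import numpy as np

def merge_attention_outputs(o1, lse1, o2, lse2):
    """Merge two partial attention outputs computed over disjoint KV sets.

    o1, o2:   [num_q, head_dim] partial outputs (each already softmax-normalized
              over its own KV subset)
    lse1, lse2: [num_q] log-sum-exp of the attention logits for each subset
    """
    m = np.maximum(lse1, lse2)            # shift for numerical stability
    w1 = np.exp(lse1 - m)                 # unnormalized weight of subset 1
    w2 = np.exp(lse2 - m)                 # unnormalized weight of subset 2
    o = (o1 * w1[:, None] + o2 * w2[:, None]) / (w1 + w2)[:, None]
    lse = m + np.log(w1 + w2)             # combined log-sum-exp
    return o, lse
```

Because the merge is associative, the pipeline can fold in one block's partial result at a time and still end up with exactly the full-attention output.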

---

## Code Conventions

### Unsupported Phases Must Assert False

If a policy doesn't support a phase, the corresponding method must `assert False`:

```python
class PrefillOnlyPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = False

    def compute_chunked_prefill(self, ...):
        # Normal prefill implementation
        ...

    def compute_chunked_decode(self, ...):
        assert False, "PrefillOnlyPolicy does not support decode phase"
```

### Caller Must Check Support Flags

`attention.py` checks support flags before calling:

```python
if not sparse_policy.supports_decode:
    raise RuntimeError(f"{sparse_policy} does not support decode phase")
```

This provides double protection:
1. Caller check → clear error message
2. Method assert → prevents bypassing the check

### CPU-GPU Communication via OffloadEngine Only

All CPU-GPU data transfers must go through `OffloadEngine` methods:

```python
# Correct: Use OffloadEngine methods
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
k, v = offload_engine.get_kv_for_slot(slot)

# Incorrect: Direct torch operations
gpu_tensor.copy_(cpu_tensor)        # DON'T DO THIS
gpu_tensor = cpu_tensor.to("cuda")  # DON'T DO THIS
```

---

## File Structure

| File | Purpose |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | Base class, PolicyContext, abstract methods |
| `nanovllm/kvcache/sparse/full_policy.py` | FullAttentionPolicy implementation |
| `nanovllm/kvcache/sparse/quest.py` | QuestPolicy (decode-only Top-K selection) |
| `nanovllm/layers/attention.py` | Attention layer, delegates to policy |

---

## Policy Implementations

| Policy | supports_prefill | supports_decode | Description |
|--------|------------------|-----------------|-------------|
| `FullAttentionPolicy` | True | True | Loads all blocks (baseline) |
| `QuestPolicy` | False | True | Decode-only Top-K selection |
| `XAttentionBSAPolicy` | False | False | Placeholder for future BSA |

---

## Testing

Run the needle-in-haystack test with offload and the mandatory 32K context:

```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct --enable-offload --input-len 32768
```

Expected output:
```
Needle-in-Haystack Test
Model: Llama-3.1-8B-Instruct
CPU offload: True
Sparse policy: FULL
Result: PASSED
```
317  docs/sparse_policy_implementation_guide.md (new file)
@@ -0,0 +1,317 @@
# SparsePolicy Implementation Guide

This guide describes how to implement a custom `SparsePolicy` for sparse attention in CPU offload mode.

## Overview

`SparsePolicy` is an abstract base class that controls:
1. **Block Selection**: Which KV cache blocks to load from CPU for each query
2. **Attention Computation**: How to compute chunked prefill and decode attention

All computation happens in the policy; `attention.py` only delegates to the policy methods.

---

## Base Class Structure

```python
class SparsePolicy(ABC):
    # Phase support flags (REQUIRED to override)
    supports_prefill: bool = True
    supports_decode: bool = True

    # Abstract methods (MUST implement)
    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]
    def compute_chunked_prefill(self, q, k, v, layer_id, ...) -> torch.Tensor
    def compute_chunked_decode(self, q, layer_id, ...) -> torch.Tensor

    # Optional hooks (CAN override)
    def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device)
    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens)
    def on_decode_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens)
    def reset(self)
```

---

## Required Implementations

### 1. Phase Support Flags

Every policy MUST declare which phases it supports:

```python
class MyPolicy(SparsePolicy):
    supports_prefill = True  # Can be used in prefill phase?
    supports_decode = True   # Can be used in decode phase?
```

| Policy Type | supports_prefill | supports_decode | Example |
|-------------|------------------|-----------------|---------|
| Full support | True | True | `FullAttentionPolicy` |
| Decode-only | False | True | `QuestPolicy` |
| Prefill-only | True | False | (hypothetical) |

### 2. select_blocks() - Block Selection

```python
@abstractmethod
def select_blocks(
    self,
    available_blocks: List[int],      # CPU block IDs with historical KV
    offload_engine: "OffloadEngine",
    ctx: PolicyContext,               # Context about current query
) -> List[int]:
    """Return subset of available_blocks to load."""
```

**PolicyContext fields:**
- `query_chunk_idx`: Current chunk index (0-indexed)
- `num_query_chunks`: Total number of chunks
- `layer_id`: Transformer layer index
- `query`: Query tensor (available for decode)
- `is_prefill`: True if prefill phase
- `block_size`: Tokens per block
- `total_kv_len`: Total KV length so far

**Example implementations:**

```python
# Full attention: load all blocks
def select_blocks(self, available_blocks, offload_engine, ctx):
    return available_blocks

# Top-K sparse: load K most important blocks
def select_blocks(self, available_blocks, offload_engine, ctx):
    scores = self.compute_block_scores(available_blocks, ctx.query)
    topk_indices = scores.topk(self.config.topk).indices
    return [available_blocks[i] for i in sorted(topk_indices.tolist())]
```
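The `PolicyContext` fields listed above can be sketched as a dataclass. This is a hypothetical reconstruction for illustration; the real definition lives in `nanovllm/kvcache/sparse/policy.py` and may differ in field order and defaults:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class PolicyContext:
    """Context passed to select_blocks() (field set taken from the list above)."""
    query_chunk_idx: int         # current chunk index (0-indexed)
    num_query_chunks: int        # total number of chunks
    layer_id: int                # transformer layer index
    is_prefill: bool             # True if prefill phase
    block_size: int              # tokens per block
    total_kv_len: int            # total KV length so far
    query: Optional[Any] = None  # query tensor (available for decode)
```

A decode-phase policy can rely on `ctx.query` being populated, while a prefill-phase policy should only depend on the positional fields.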

### 3. compute_chunked_prefill() - Prefill Attention

```python
@abstractmethod
def compute_chunked_prefill(
    self,
    q: torch.Tensor,   # [seq_len, num_heads, head_dim]
    k: torch.Tensor,   # [seq_len, num_kv_heads, head_dim] (unused)
    v: torch.Tensor,   # [seq_len, num_kv_heads, head_dim] (unused)
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    current_chunk_idx: int,
    seq: "Sequence",
    num_tokens: int,
) -> torch.Tensor:     # [seq_len, num_heads, head_dim]
```

**Required flow:**
1. Get historical blocks: `kvcache_manager.get_prefilled_cpu_blocks(seq)`
2. Call `select_blocks()` to filter blocks
3. Load blocks via the ring buffer pipeline
4. Get current chunk KV: `offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)`
5. Compute attention with `flash_attn_with_lse()` (historical: causal=False; current: causal=True)
6. Merge results with `merge_attention_outputs()`
7. Return output with shape `[seq_len, num_heads, head_dim]`

**If the policy doesn't support prefill:**
```python
def compute_chunked_prefill(self, ...):
    assert False, "MyPolicy does not support prefill phase"
```

### 4. compute_chunked_decode() - Decode Attention

```python
@abstractmethod
def compute_chunked_decode(
    self,
    q: torch.Tensor,   # [batch_size, num_heads, head_dim]
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    seq: "Sequence",
) -> torch.Tensor:     # [batch_size, 1, num_heads, head_dim]
```

**Required flow:**
1. Get prefilled blocks: `kvcache_manager.get_prefilled_cpu_blocks(seq)`
2. Calculate last block valid tokens from `kvcache_manager.get_prefill_len(seq)`
3. Call `select_blocks()` to filter blocks
4. Load blocks via the `_decode_ring_buffer_pipeline()` helper
5. Read the decode buffer: `offload_engine.decode_k_buffer[layer_id, ...]`
6. Merge results with `merge_attention_outputs()`
7. Return output with shape `[batch_size, 1, num_heads, head_dim]`

**If the policy doesn't support decode:**
```python
def compute_chunked_decode(self, ...):
    assert False, "MyPolicy does not support decode phase"
```

---

## Optional Hooks

### initialize()

Called after KV cache allocation. Use it to create metadata structures.

```python
def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
    self.metadata = BlockMetadataManager(
        num_blocks=num_cpu_blocks,
        num_layers=num_layers,
        ...
    )
```

### on_prefill_offload() / on_decode_offload()

Called BEFORE the GPU→CPU copy. Use them to collect block metadata while the data is still on the GPU.

```python
def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
    # k_cache is still on GPU here
    self.metadata.update_min_max(cpu_block_id, layer_id, k_cache, num_valid_tokens)
```
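As an illustration of what such metadata buys, here is a NumPy sketch of Quest-style block scoring: the offload hook records a per-block element-wise min/max of the keys, and at decode time each block gets an upper bound on its attention score via `sum(max(q * k_min, q * k_max))`. This sketches the general technique only; `MinMaxMetadata` is a hypothetical class, not the project's `BlockMetadataManager` API:

```python
import numpy as np

class MinMaxMetadata:
    """Per-block key min/max, collected in on_prefill_offload()."""

    def __init__(self, num_blocks: int, head_dim: int):
        self.k_min = np.full((num_blocks, head_dim), np.inf)
        self.k_max = np.full((num_blocks, head_dim), -np.inf)

    def update(self, block_id: int, k_block: np.ndarray, num_valid: int):
        # k_block: [block_size, head_dim]; only the first num_valid rows count.
        valid = k_block[:num_valid]
        self.k_min[block_id] = valid.min(axis=0)
        self.k_max[block_id] = valid.max(axis=0)

    def scores(self, q: np.ndarray, block_ids: list) -> np.ndarray:
        # Per-dimension upper bound on q*k for any k inside the block's
        # min/max box, summed over dimensions -> upper bound on q . k.
        ub = np.maximum(q * self.k_min[block_ids], q * self.k_max[block_ids])
        return ub.sum(axis=-1)
```

Because the score is a provable upper bound on any key's dot product with the query, keeping the Top-K blocks by this score never discards a block whose bound exceeds the kept ones.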
|
||||
|
||||
### reset()
|
||||
|
||||
Called when starting new sequence. Use to clear state.
|
||||
|
||||
```python
|
||||
def reset(self):
|
||||
if self.metadata is not None:
|
||||
self.metadata.reset()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## CPU-GPU Communication Rules
|
||||
|
||||
**MUST use OffloadEngine methods:**
|
||||
```python
|
||||
# Loading blocks
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
||||
offload_engine.wait_slot_layer(slot)
|
||||
k, v = offload_engine.get_kv_for_slot(slot)
|
||||
offload_engine.record_slot_compute_done(slot)
|
||||
|
||||
# Current chunk KV
|
||||
k, v = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
|
||||
|
||||
# Decode buffer
|
||||
decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
|
||||
decode_v = offload_engine.decode_v_buffer[layer_id, start:end]
|
||||
```
|
||||
|
||||
**NEVER do direct transfers:**
|
||||
```python
|
||||
# WRONG!
|
||||
gpu_tensor.copy_(cpu_tensor)
|
||||
gpu_tensor = cpu_tensor.to("cuda")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Ring Buffer Pipeline Pattern
|
||||
|
||||
The standard pattern for loading blocks:
|
||||
|
||||
```python
|
||||
def _decode_ring_buffer_pipeline(self, q_batched, cpu_block_table, load_slots, ...):
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
num_blocks = len(cpu_block_table)
|
||||
num_slots = len(load_slots)
|
||||
o_acc, lse_acc = None, None
|
||||
|
||||
# Phase 1: Pre-load up to num_slots blocks
|
||||
for i in range(min(num_slots, num_blocks)):
|
||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
|
||||
|
||||
# Phase 2: Process with pipeline
|
||||
for block_idx in range(num_blocks):
|
||||
slot = load_slots[block_idx % num_slots]
|
||||
|
||||
# Wait for H2D transfer
|
||||
offload_engine.wait_slot_layer(slot)
|
||||
|
||||
with torch.cuda.stream(offload_engine.compute_stream):
|
||||
# Get KV and compute attention
|
||||
k, v = offload_engine.get_kv_for_slot(slot)
|
||||
o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale, causal=False)
|
||||
offload_engine.record_slot_compute_done(slot)
|
||||
|
||||
# Pipeline: start next block transfer
|
||||
next_idx = block_idx + num_slots
|
||||
if next_idx < num_blocks:
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_table[next_idx])
|
||||
|
||||
# Merge results
|
||||
with torch.cuda.stream(offload_engine.compute_stream):
|
||||
if o_acc is None:
|
||||
o_acc, lse_acc = o, lse
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)
|
||||
|
||||
return o_acc, lse_acc
|
||||
```

---

## Complete Example: Decode-Only Policy

```python
class TopKPolicy(SparsePolicy):
    """Load only top-K blocks based on query-key similarity."""

    supports_prefill = False  # Use FullAttentionPolicy for prefill
    supports_decode = True

    def __init__(self, topk: int = 8):
        self.topk = topk
        self.metadata = None

    def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
        self.metadata = BlockMetadataManager(num_cpu_blocks, num_layers, num_kv_heads, head_dim)

    def select_blocks(self, available_blocks, offload_engine, ctx):
        if len(available_blocks) <= self.topk:
            return available_blocks

        # Compute scores and select top-K
        scores = self.metadata.compute_scores(available_blocks, ctx.layer_id, ctx.query)
        topk_indices = scores.topk(self.topk).indices.cpu().tolist()
        return [available_blocks[i] for i in sorted(topk_indices)]

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        self.metadata.update(cpu_block_id, layer_id, k_cache, num_valid_tokens)

    def compute_chunked_prefill(self, ...):
        assert False, "TopKPolicy does not support prefill phase"

    def compute_chunked_decode(self, q, layer_id, softmax_scale, offload_engine, kvcache_manager, seq):
        # Copy implementation from FullAttentionPolicy.compute_chunked_decode
        # The only difference is select_blocks() will filter to top-K
        ...

    def reset(self):
        if self.metadata:
            self.metadata.reset()
```
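The `select_blocks` logic above returns the chosen blocks in their original positional order, which is why the top-K indices are sorted before gathering. A plain-Python sketch with hypothetical scores (all values here are made up for illustration):

```python
# Hypothetical scores for 6 candidate CPU blocks (higher = more relevant).
available_blocks = [10, 11, 12, 13, 14, 15]  # CPU block IDs, in positional order
scores = [0.1, 0.9, 0.3, 0.8, 0.2, 0.7]
topk = 3

# Indices of the top-k scores; discovery order is by score, not position.
topk_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:topk]

# Sorting the indices before gathering preserves the blocks' positional
# order, which the attention pipeline requires (KV must stay in sequence order).
selected = [available_blocks[i] for i in sorted(topk_indices)]
assert selected == [11, 13, 15]
```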

---

## File Locations

| File | Purpose |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | Base class and PolicyContext |
| `nanovllm/kvcache/sparse/full_policy.py` | FullAttentionPolicy (reference implementation) |
| `nanovllm/kvcache/sparse/quest.py` | QuestPolicy (decode-only example) |
| `nanovllm/kvcache/chunked_attention.py` | `flash_attn_with_lse`, `merge_attention_outputs` |
@@ -644,12 +644,6 @@ class ModelRunner:
         # Get decode start position for accumulated token tracking
         decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)

-        # Get prefilled CPU blocks for pipeline initialization
-        cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
-
-        # Start cross-layer pipeline (preloads Layer 0's data)
-        offload_engine.start_decode_pipeline(cpu_block_table)
-
         # Set up context for chunked decode
         set_context(
             is_prefill=False,
@@ -666,9 +660,6 @@ class ModelRunner:
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()

-        # End cross-layer pipeline
-        offload_engine.end_decode_pipeline()
-
         # Only offload when block is full (pos_in_block == block_size - 1)
         # This avoids unnecessary offloading on every decode step
         if pos_in_block == self.block_size - 1:
@@ -141,40 +141,6 @@ class OffloadEngine:
         decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
         logger.info(f"  Per-layer decode buffer: {decode_buf_mb:.1f} MB")

-        # ========== Cross-layer pipeline buffers for decode ==========
-        # Double-buffered layer cache for pipelined decode:
-        #   - Buffer A: Current layer's prefilled KV being computed
-        #   - Buffer B: Next layer's prefilled KV being loaded
-        # Shape: [max_prefill_blocks, block_size, kv_heads, head_dim]
-        # Memory: 2 * max_prefill_blocks * block_size * kv_heads * head_dim * dtype_size
-        max_prefill_blocks = num_cpu_blocks  # Can hold all prefill blocks
-        self.layer_k_buffer_a = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        self.layer_v_buffer_a = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        self.layer_k_buffer_b = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        self.layer_v_buffer_b = torch.zeros(
-            max_prefill_blocks, block_size, num_kv_heads, head_dim,
-            dtype=dtype, device="cuda"
-        )
-        layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
-        logger.info(f"  Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)")
-
-        # Pipeline state tracking
-        self._pipeline_active = False
-        self._pipeline_current_buffer = 0  # 0 = buffer A, 1 = buffer B
-        self._pipeline_next_layer_event = torch.cuda.Event()
-        self._pipeline_cpu_blocks: list = []  # CPU block IDs to load
-        self._pipeline_num_blocks = 0
-        self._pipeline_layer_stream = torch.cuda.Stream()  # Dedicated stream for layer loading
-
         # ========== Per-layer prefill buffer for async offload ==========
         # During chunked prefill, all layers share the same GPU slot. This means
         # each layer must wait for offload to complete before the next layer can
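For a sense of scale, the `layer_buf_mb` formula in the removed code can be evaluated with assumed model parameters (the numbers below are illustrative, not taken from the repository):

```python
# Back-of-envelope for the removed cross-layer pipeline buffers.
# All parameter values here are assumptions for illustration only.
max_prefill_blocks = 128  # one GPU buffer slot per CPU block
block_size = 256          # tokens per block (assumed)
num_kv_heads = 8          # GQA KV heads (assumed)
head_dim = 128
dtype_size = 2            # bf16 bytes per element

# 4 buffers: K and V, each double-buffered (A/B)
layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype_size / (1024 * 1024)
assert layer_buf_mb == 256.0  # MB of GPU memory held just for the pipeline
```

Removing these buffers frees that memory; the replacement ring-buffer path reuses a small fixed set of load slots instead.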
@@ -702,122 +668,6 @@ class OffloadEngine:
                     raise
                 logger.warning(f"Debug hook error: {e}")

-    # ========== Cross-layer Pipeline Methods for Decode ==========
-
-    def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None:
-        """
-        Start cross-layer pipeline for decode.
-
-        Called at the beginning of a decode step to initialize the pipeline.
-        Preloads Layer 0's data into buffer A.
-
-        Args:
-            cpu_block_ids: List of CPU block IDs for prefilled blocks
-        """
-        if not cpu_block_ids:
-            self._pipeline_active = False
-            return
-
-        self._pipeline_active = True
-        self._pipeline_cpu_blocks = cpu_block_ids
-        self._pipeline_num_blocks = len(cpu_block_ids)
-        self._pipeline_current_buffer = 0
-
-        # Preload Layer 0 into buffer A
-        self._load_layer_to_buffer(0, 0)  # layer_id=0, buffer_idx=0 (A)
-
-    def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]:
-        """
-        Get KV cache for a layer during decode.
-
-        If pipeline is active, returns data from the current buffer.
-        Also triggers preloading of the next layer (if not last layer).
-
-        Args:
-            layer_id: Current layer ID
-            num_blocks: Number of blocks to return
-
-        Returns:
-            (k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim]
-        """
-        if not self._pipeline_active:
-            raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.")
-
-        # Wait for current layer's data to be ready
-        self.compute_stream.wait_event(self._pipeline_next_layer_event)
-
-        # Get current buffer
-        if self._pipeline_current_buffer == 0:
-            k = self.layer_k_buffer_a[:num_blocks]
-            v = self.layer_v_buffer_a[:num_blocks]
-        else:
-            k = self.layer_k_buffer_b[:num_blocks]
-            v = self.layer_v_buffer_b[:num_blocks]
-
-        # Trigger preloading of next layer (if not last layer)
-        next_layer_id = layer_id + 1
-        if next_layer_id < self.num_layers:
-            # Use the other buffer for next layer
-            next_buffer_idx = 1 - self._pipeline_current_buffer
-            self._load_layer_to_buffer(next_layer_id, next_buffer_idx)
-            # Switch to next buffer for next layer
-            self._pipeline_current_buffer = next_buffer_idx
-
-        return k, v
-
-    def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None:
-        """
-        Async load a layer's prefilled blocks to the specified buffer.
-
-        Uses sgDMA for efficient strided transfer from CPU cache.
-
-        Args:
-            layer_id: Layer index to load
-            buffer_idx: 0 for buffer A, 1 for buffer B
-        """
-        num_blocks = self._pipeline_num_blocks
-        cpu_block_ids = self._pipeline_cpu_blocks
-
-        # Select target buffer
-        if buffer_idx == 0:
-            k_buffer = self.layer_k_buffer_a
-            v_buffer = self.layer_v_buffer_a
-        else:
-            k_buffer = self.layer_k_buffer_b
-            v_buffer = self.layer_v_buffer_b
-
-        # Load all blocks for this layer using dedicated stream
-        with torch.cuda.stream(self._pipeline_layer_stream):
-            for i, cpu_block_id in enumerate(cpu_block_ids):
-                # Copy from CPU cache (has layer dimension) to GPU buffer
-                k_buffer[i].copy_(
-                    self.k_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-                v_buffer[i].copy_(
-                    self.v_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-            # Record event when all transfers complete
-            self._pipeline_next_layer_event.record(self._pipeline_layer_stream)
-
-    def end_decode_pipeline(self) -> None:
-        """
-        End the cross-layer pipeline.
-
-        Called at the end of a decode step to clean up pipeline state.
-        """
-        if self._pipeline_active:
-            # Ensure all transfers complete before ending
-            self._pipeline_layer_stream.synchronize()
-        self._pipeline_active = False
-        self._pipeline_cpu_blocks = []
-        self._pipeline_num_blocks = 0
-
-    def is_pipeline_active(self) -> bool:
-        """Check if decode pipeline is currently active."""
-        return self._pipeline_active
-
     # ========== Per-layer Prefill Buffer Methods ==========
     # These methods enable async offload during chunked prefill by using
     # per-layer buffers instead of shared GPU slots.
@@ -46,7 +46,7 @@ class FullAttentionPolicy(SparsePolicy):
         """Return all blocks - no sparsity."""
         return available_blocks

-    def compute_chunked_attention(
+    def compute_chunked_prefill(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
@@ -86,7 +86,7 @@ class FullAttentionPolicy(SparsePolicy):
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

-        logger.debug(f"[DEBUG] FullPolicy.compute_chunked_attention called, "
+        logger.debug(f"[DEBUG] FullPolicy.compute_chunked_prefill called, "
                      f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")

         q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
@@ -192,5 +192,192 @@ class FullAttentionPolicy(SparsePolicy):
         # Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
         return final_o.squeeze(0)

+    def compute_chunked_decode(
+        self,
+        q: torch.Tensor,
+        layer_id: int,
+        softmax_scale: float,
+        offload_engine: "OffloadEngine",
+        kvcache_manager: "KVCacheManager",
+        seq: "Sequence",
+    ) -> torch.Tensor:
+        """
+        Compute full attention for chunked decode.
+
+        This method handles the complete chunked decode flow:
+        1. Get prefilled CPU blocks
+        2. Apply select_blocks for block filtering
+        3. Load blocks via pipeline (ring buffer or cross-layer)
+        4. Read accumulated decode tokens from decode buffer
+        5. Merge all results
+
+        Args:
+            q: Query tensor [batch_size, num_heads, head_dim]
+            layer_id: Current layer index
+            softmax_scale: Softmax scaling factor
+            offload_engine: OffloadEngine for loading blocks
+            kvcache_manager: KVCacheManager for block management
+            seq: Sequence object
+
+        Returns:
+            Attention output [batch_size, 1, num_heads, head_dim]
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
+        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
+
+        # Get only PREFILLED CPU blocks (exclude the current decode block)
+        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
+        if layer_id == 0:
+            logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
+        if not cpu_block_table:
+            raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
+
+        # Calculate valid tokens in the last CPU block
+        # CRITICAL: Use original prefill length, not current seq length!
+        # CPU blocks are fixed after prefill, their content doesn't change during decode.
+        block_size = kvcache_manager.block_size
+        num_prefill_blocks = len(cpu_block_table)
+        total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Original prefill length
+        last_block_valid_tokens = total_prefill_tokens % block_size
+        if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
+            last_block_valid_tokens = block_size  # Last block was exactly full
+
+        # Apply sparse policy (self) for block filtering
+        policy_ctx = PolicyContext(
+            query_chunk_idx=0,
+            num_query_chunks=1,
+            layer_id=layer_id,
+            query=q_batched,
+            is_prefill=False,
+            block_size=kvcache_manager.block_size,
+            total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+        )
+        cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
+
+        # Use ring buffer pipeline for loading prefilled blocks
+        load_slots = offload_engine.decode_load_slots
+        o_acc, lse_acc = self._decode_ring_buffer_pipeline(
+            q_batched, cpu_block_table, load_slots, offload_engine,
+            block_size, last_block_valid_tokens, layer_id, softmax_scale
+        )
+
+        # Now attend to accumulated decode tokens from per-layer decode buffer
+        # Compute decode position information internally
+        seq_len = len(seq)
+        decode_pos_in_block = (seq_len - 1) % block_size
+        decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
+        decode_start_pos_in_block = decode_start_pos % block_size
+        num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1
+
+        # Sync compute_stream with default stream before reading decode_buffer
+        compute_stream = offload_engine.compute_stream
+        compute_stream.wait_stream(torch.cuda.default_stream())
+
+        with torch.cuda.stream(compute_stream):
+            if num_accumulated > 0:
+                # Read from per-layer decode buffer
+                decode_k = offload_engine.decode_k_buffer[layer_id, decode_start_pos_in_block:decode_pos_in_block+1]
+                decode_v = offload_engine.decode_v_buffer[layer_id, decode_start_pos_in_block:decode_pos_in_block+1]
+                decode_k = decode_k.unsqueeze(0)
+                decode_v = decode_v.unsqueeze(0)
+
+                decode_o, decode_lse = flash_attn_with_lse(
+                    q_batched, decode_k, decode_v,
+                    softmax_scale=softmax_scale,
+                    causal=False,
+                )
+
+                if o_acc is None:
+                    o_acc = decode_o
+                else:
+                    o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
+
+        if o_acc is None:
+            raise RuntimeError("Chunked decode attention failed: no KV available")
+
+        # Sync back to default stream before returning
+        torch.cuda.default_stream().wait_stream(compute_stream)
+
+        return o_acc
+
+    def _decode_ring_buffer_pipeline(
+        self,
+        q_batched: torch.Tensor,
+        cpu_block_table: list,
+        load_slots: list,
+        offload_engine: "OffloadEngine",
+        block_size: int,
+        last_block_valid_tokens: int,
+        layer_id: int,
+        softmax_scale: float,
+    ):
+        """
+        Ring buffer pipeline for decode prefill loading.
+
+        Loads one block at a time, computes attention, and merges results.
+        Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+        num_blocks = len(cpu_block_table)
+        if num_blocks == 0:
+            return None, None
+
+        if not load_slots:
+            return None, None
+
+        o_acc, lse_acc = None, None
+        num_slots = len(load_slots)
+        compute_stream = offload_engine.compute_stream
+
+        # Phase 1: Pre-load up to num_slots blocks
+        num_preload = min(num_slots, num_blocks)
+        for i in range(num_preload):
+            offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
+
+        # Phase 2: Process blocks with pipeline
+        for block_idx in range(num_blocks):
+            current_slot = load_slots[block_idx % num_slots]
+            cpu_block_id = cpu_block_table[block_idx]
+
+            # Wait for current slot's transfer to complete
+            offload_engine.wait_slot_layer(current_slot)
+
+            with torch.cuda.stream(compute_stream):
+                # Get KV from slot
+                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
+
+                # Handle partial last block
+                is_last_block = (block_idx == num_blocks - 1)
+                if is_last_block and last_block_valid_tokens < block_size:
+                    prev_k = prev_k[:, :last_block_valid_tokens, :, :]
+                    prev_v = prev_v[:, :last_block_valid_tokens, :, :]
+
+                # Compute attention
+                prev_o, prev_lse = flash_attn_with_lse(
+                    q_batched, prev_k, prev_v,
+                    softmax_scale=softmax_scale,
+                    causal=False,
+                )
+
+                # Record compute done for slot reuse
+                offload_engine.record_slot_compute_done(current_slot)
+
+            # Start loading next block (pipeline)
+            next_block_idx = block_idx + num_slots
+            if next_block_idx < num_blocks:
+                offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])
+
+            # Merge with accumulated
+            with torch.cuda.stream(compute_stream):
+                if o_acc is None:
+                    o_acc, lse_acc = prev_o, prev_lse
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+
+        return o_acc, lse_acc
+
     def __repr__(self) -> str:
         return "FullAttentionPolicy()"
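The `last_block_valid_tokens` computation above reduces to simple modular arithmetic on the original prefill length; a standalone sketch with example lengths (block size and token counts here are illustrative):

```python
# Valid tokens in the last prefilled CPU block, as computed in
# compute_chunked_decode. Example values only.
def last_block_valid(total_prefill_tokens, block_size):
    rem = total_prefill_tokens % block_size
    if rem == 0 and total_prefill_tokens > 0:
        rem = block_size  # last block was exactly full
    return rem

assert last_block_valid(1000, 256) == 232  # 1000 = 3 * 256 + 232
assert last_block_valid(1024, 256) == 256  # exactly full last block
```

Using the current sequence length instead would over-count, since decode tokens live in the decode buffer, not in the frozen CPU blocks.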
@@ -192,7 +192,7 @@ class SparsePolicy(ABC):
         pass

     @abstractmethod
-    def compute_chunked_attention(
+    def compute_chunked_prefill(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
@@ -233,5 +233,43 @@ class SparsePolicy(ABC):
         """
         pass

+    @abstractmethod
+    def compute_chunked_decode(
+        self,
+        q: torch.Tensor,
+        layer_id: int,
+        softmax_scale: float,
+        offload_engine: "OffloadEngine",
+        kvcache_manager: "KVCacheManager",
+        seq: "Sequence",
+    ) -> torch.Tensor:
+        """
+        Compute chunked decode attention (complete flow).
+
+        This is the main entry point for decode attention computation.
+        It defines the complete decode flow:
+        1. Get prefilled blocks from CPU
+        2. Select blocks (call select_blocks)
+        3. Load blocks via pipeline (ring buffer or cross-layer)
+        4. Read accumulated decode tokens from decode buffer
+        5. Merge all results
+
+        The decode position information can be computed internally:
+        - decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
+        - decode_pos_in_block = (len(seq) - 1) % kvcache_manager.block_size
+
+        Args:
+            q: [batch_size, num_heads, head_dim] query for decode token
+            layer_id: transformer layer index
+            softmax_scale: softmax scaling factor
+            offload_engine: OffloadEngine for loading blocks
+            kvcache_manager: KVCacheManager for block management
+            seq: Sequence object
+
+        Returns:
+            [batch_size, 1, num_heads, head_dim] final attention output
+        """
+        pass
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
@@ -5,7 +5,6 @@ from torch import nn

 from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
-from nanovllm.kvcache.sparse.policy import PolicyContext

 logger = logging.getLogger(__name__)

@@ -175,7 +174,7 @@ class Attention(nn.Module):
         Compute attention with per-layer prefill buffer for async offload.

         Simplified design:
-        - All computation logic is delegated to sparse_policy.compute_chunked_attention()
+        - All computation logic is delegated to sparse_policy.compute_chunked_prefill()
         - This method only handles async offload after computation

         The policy handles:
@@ -199,11 +198,11 @@ class Attention(nn.Module):
             raise RuntimeError("sparse_policy is required for chunked prefill")

         # [DEBUG] Verify execution path
-        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_attention, "
+        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
                      f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")

         # Delegate all computation to policy (no flash_attn or merge calls here!)
-        final_o = sparse_policy.compute_chunked_attention(
+        final_o = sparse_policy.compute_chunked_prefill(
             q, k, v,
             self.layer_id,
             self.scale,
@@ -237,240 +236,41 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention using cross-layer pipeline.
+        Compute decode attention by delegating to sparse policy.

-        Optimization: Uses double-buffered layer cache to overlap H2D transfer
-        with computation across layers:
-        - Layer N computes while Layer N+1's data is being loaded
-        - Each layer only waits for its own data, not all layers' data
+        Simplified design:
+        - All computation logic is delegated to sparse_policy.compute_chunked_decode()
+        - This method only validates the policy and delegates

-        This reduces effective latency from O(num_layers * transfer_time) to
-        O(transfer_time + num_layers * compute_time) when transfer < compute.
+        The policy handles:
+        1. Loading prefilled blocks from CPU via pipeline
+        2. Computing attention against prefilled KV
+        3. Reading accumulated decode tokens from decode buffer
+        4. Merging all results
         """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
-
-        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
-        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
-
         kvcache_manager = context.kvcache_manager
         seq = context.chunked_seq

-        # Get only PREFILLED CPU blocks (exclude the current decode block)
-        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
-        if self.layer_id == 0:
-            logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
-        if not cpu_block_table:
-            raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
-
-        # Calculate valid tokens in the last CPU block
-        # CRITICAL: Use original prefill length, not current seq length!
-        # CPU blocks are fixed after prefill, their content doesn't change during decode.
-        block_size = kvcache_manager.block_size
-        num_prefill_blocks = len(cpu_block_table)
-        total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Original prefill length
-        last_block_valid_tokens = total_prefill_tokens % block_size
-        if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
-            last_block_valid_tokens = block_size  # Last block was exactly full
-
         offload_engine = kvcache_manager.offload_engine

-        # Apply sparse policy if enabled (Quest does Top-K selection for decode)
+        # Get sparse policy - required for chunked decode
         sparse_policy = kvcache_manager.sparse_policy
-        if sparse_policy is not None:
-            policy_ctx = PolicyContext(
-                query_chunk_idx=0,
-                num_query_chunks=1,
-                layer_id=self.layer_id,
-                query=q_batched,
-                is_prefill=False,
-                block_size=kvcache_manager.block_size,
-                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
-            )
-            cpu_block_table = sparse_policy.select_blocks(
-                cpu_block_table, offload_engine, policy_ctx
-            )
+        if sparse_policy is None:
+            raise RuntimeError("sparse_policy is required for chunked decode")

-        # Use cross-layer pipeline if active (initialized in model_runner)
-        if offload_engine.is_pipeline_active():
-            o_acc, lse_acc = self._decode_with_layer_pipeline(
-                q_batched, cpu_block_table, offload_engine,
-                block_size, last_block_valid_tokens
-            )
-        else:
-            # Fallback to original ring buffer pipeline
-            load_slots = offload_engine.decode_load_slots
-            o_acc, lse_acc = self._decode_ring_buffer_pipeline(
-                q_batched, cpu_block_table, load_slots, offload_engine,
-                block_size, last_block_valid_tokens
-            )
+        # Check if policy supports decode phase
+        if not sparse_policy.supports_decode:
+            raise RuntimeError(f"{sparse_policy} does not support decode phase")

-        # Now attend to accumulated decode tokens from per-layer decode buffer
-        pos_in_block = context.decode_pos_in_block
-        start_pos = context.decode_start_pos_in_block
-        num_accumulated = pos_in_block - start_pos + 1
+        # [DEBUG] Verify execution path
+        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
+                     f"policy={sparse_policy}, layer={self.layer_id}")

-        # Sync compute_stream with default stream before reading decode_buffer
-        compute_stream = offload_engine.compute_stream
-        compute_stream.wait_stream(torch.cuda.default_stream())
-
-        with torch.cuda.stream(compute_stream):
-            if num_accumulated > 0:
-                # Read from per-layer decode buffer
-                decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block+1]
-                decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block+1]
-                decode_k = decode_k.unsqueeze(0)
-                decode_v = decode_v.unsqueeze(0)
-
-                decode_o, decode_lse = flash_attn_with_lse(
-                    q_batched, decode_k, decode_v,
-                    softmax_scale=self.scale,
-                    causal=False,
-                )
-
-                if o_acc is None:
-                    o_acc = decode_o
-                else:
-                    o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
-
-        if o_acc is None:
-            raise RuntimeError("Chunked decode attention failed: no KV available")
-
-        # Sync back to default stream before returning
-        torch.cuda.default_stream().wait_stream(compute_stream)
-
-        return o_acc
-
-    def _decode_ring_buffer_pipeline(
-        self,
-        q_batched: torch.Tensor,
-        cpu_block_table: list,
-        load_slots: list,
-        offload_engine,
-        block_size: int,
-        last_block_valid_tokens: int,
-    ):
-        """
-        Ring buffer pipeline for decode prefill loading (same mechanism as prefill).
-
-        Loads one block at a time, computes attention, and merges results.
-        Uses the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot
-        methods as prefill for proven correctness.
-        """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
-
-        num_blocks = len(cpu_block_table)
-        if num_blocks == 0:
-            return None, None
-
-        if not load_slots:
-            return None, None
-
-        o_acc, lse_acc = None, None
-        num_slots = len(load_slots)
-        compute_stream = offload_engine.compute_stream
-
-        # Phase 1: Pre-load up to num_slots blocks
-        num_preload = min(num_slots, num_blocks)
-        for i in range(num_preload):
-            offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
-
-        # Phase 2: Process blocks with pipeline
-        for block_idx in range(num_blocks):
-            current_slot = load_slots[block_idx % num_slots]
-            cpu_block_id = cpu_block_table[block_idx]
-
-            # Wait for current slot's transfer to complete
-            offload_engine.wait_slot_layer(current_slot)
-
-            with torch.cuda.stream(compute_stream):
-                # Get KV from slot
-                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
-
-                # Handle partial last block
-                is_last_block = (block_idx == num_blocks - 1)
-                if is_last_block and last_block_valid_tokens < block_size:
-                    prev_k = prev_k[:, :last_block_valid_tokens, :, :]
-                    prev_v = prev_v[:, :last_block_valid_tokens, :, :]
-
-                # Compute attention
-                prev_o, prev_lse = flash_attn_with_lse(
-                    q_batched, prev_k, prev_v,
-                    softmax_scale=self.scale,
-                    causal=False,
-                )
-
-                # Record compute done for slot reuse
-                offload_engine.record_slot_compute_done(current_slot)
-
-            # Start loading next block (pipeline)
-            next_block_idx = block_idx + num_slots
-            if next_block_idx < num_blocks:
-                offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
-
-            # Merge with accumulated
-            with torch.cuda.stream(compute_stream):
-                if o_acc is None:
-                    o_acc, lse_acc = prev_o, prev_lse
-                else:
-                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
-
-        return o_acc, lse_acc
-
-    def _decode_with_layer_pipeline(
-        self,
-        q_batched: torch.Tensor,
-        cpu_block_table: list,
-        offload_engine,
-        block_size: int,
-        last_block_valid_tokens: int,
-    ):
-        """
-        Decode using cross-layer pipeline for optimized H2D transfer.
-
-        This method uses pre-loaded layer buffers instead of loading
-        blocks one by one. The pipeline loads the next layer's data
-        while the current layer computes, achieving transfer/compute overlap.
-
-        The key insight is that each layer needs the SAME blocks but from
-        different layers of CPU cache. By double-buffering and pipelining
-        across layers, we reduce total latency.
-        """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
-
-        num_blocks = len(cpu_block_table)
-        if num_blocks == 0:
-            return None, None
-
-        compute_stream = offload_engine.compute_stream
-
-        # Get KV from pre-loaded layer buffer (triggers next layer loading)
-        prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
-
-        # prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
-        # Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
-        total_tokens = num_blocks * block_size
-
-        # Handle partial last block
-        if last_block_valid_tokens < block_size:
-            # Only use valid tokens from last block
-            actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
-            # Flatten and truncate
-            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
-            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
-        else:
-            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
-            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
-
-        # Add batch dimension: [1, total_tokens, kv_heads, head_dim]
-        prev_k_batched = prev_k_flat.unsqueeze(0)
-        prev_v_batched = prev_v_flat.unsqueeze(0)
-
-        # Compute attention on all prefilled blocks at once
-        with torch.cuda.stream(compute_stream):
-            o_acc, lse_acc = flash_attn_with_lse(
-                q_batched, prev_k_batched, prev_v_batched,
-                softmax_scale=self.scale,
-                causal=False,
-            )
-
-        return o_acc, lse_acc
+        # Delegate all computation to policy (no flash_attn or merge calls here!)
+        return sparse_policy.compute_chunked_decode(
+            q,
+            self.layer_id,
+            self.scale,
+            offload_engine,
+            kvcache_manager,
+            seq,
+        )
467
task_plan.md
467
task_plan.md
@@ -1,467 +0,0 @@
|
||||
# Task Plan: Sparse Policy Architecture Refactor v4 (FullPolicy Only)

## Goal

Move the chunked-prefill attention computation logic entirely out of `attention.py` and into `SparsePolicy`.

### Acceptance Criteria (all must be satisfied)

| # | Criterion | Notes |
|---|-----------|-------|
| **1** | `test_needle.py --enable-offload` passes | Verifies functional correctness |
| **2** | Zero compute calls on the chunked-prefill path in `attention.py` | No direct calls to `flash_attn_*` or `merge_attention_outputs`; the policy does all of it |
| **3** | All KV communication goes through `offload_engine` | No direct `torch.copy_` or `.copy()` calls for KV data transfer |

**Scope**: Implement FullPolicy only; QuestPolicy and XAttentionBSAPolicy are deferred. The decode phase is not touched.

## Current Code State (Key Finding)

**`FullPolicy.compute_prefill_attention` already implements the complete prefill flow!**

However, `attention.py` does not call it. Instead it:
- calls `sparse_policy.select_blocks()` only for block filtering
- implements `_ring_buffer_pipeline_load` and `_sync_load_previous_chunks` itself
- calls `flash_attn_with_lse` and `merge_attention_outputs` itself

**Conclusion**: the current code is redundant; the same logic is implemented in two places.

### Violating calls currently in attention.py (to be removed)

```python
# Direct compute calls (violate criterion 2)
flash_attn_with_lse(...)
merge_attention_outputs(...)

# Direct communication calls (violate criterion 3)
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
```
## Core Design Principles

1. **The policy performs the entire prefill computation internally**: block loading, attention computation, and result merging
2. **select_blocks receives offload_engine**: other policies (Quest/XAttn) may need to load KV to decide
3. **Unified method naming**: use `compute_chunked_attention` (not `compute_prefill_attention`)
4. **chunked_prefill requires a policy**: raise an error if no policy is set
5. **Zero compute logic in attention.py**: `_chunked_prefill_attention` only calls the policy
6. **All KV communication goes through offload_engine**: no direct torch.copy calls

## Target Architecture

```
attention.py (_chunked_prefill_attention):
    Check that sparse_policy exists
        ↓
    Call sparse_policy.compute_chunked_attention(q, k, v, ...)
        ↓
    Handle async offload (via offload_engine)
        ↓
    Return the final output (no compute logic, no direct copy calls)

SparsePolicy.compute_chunked_attention():
    1. Get cpu_block_table
    2. Call select_blocks(blocks, offload_engine, ctx) → filter blocks
    3. Load blocks via offload_engine and compute attention (pipelined or sync)
    4. Get the current chunk's KV via offload_engine and compute attention (causal)
    5. Merge all results
    6. Return final_output
```

## Key Design Decisions

| Decision | Description |
|----------|-------------|
| **1** | `compute_chunked_attention` is the only abstract method; it defines the full prefill flow |
| **2** | Do not add `compute_block_attention` or `merge_attention_outputs` abstract methods (over-engineering) |
| **3** | `select_blocks` receives an `offload_engine` parameter (other policies need it) |
| **4** | `_chunked_prefill_attention` in attention.py contains no flash-attn or merge calls |
| **5** | The decode phase is untouched; existing logic is kept |
| **6** | Async offload logic stays in attention.py (via offload_engine method calls) |
| **7** | Phase 4 adds debug output to verify the execution path |
| **8** | All KV communication must go through offload_engine methods; no direct torch.copy |
## Phases

- [x] Phase 1: Analyze the current architecture ✅ Done
- [ ] Phase 2: Modify the SparsePolicy base class
- [ ] Phase 3: Modify FullPolicy
- [ ] Phase 4: Verify the execution path (add debug output)
- [ ] Phase 5: Modify attention.py
- [ ] Phase 6: Test and validate

## Phase 1: Analyze the Current Architecture ✅ Done

### Compute logic currently in attention.py (to be removed)

1. `_ring_buffer_pipeline_load`: calls flash-attn and merge directly
2. `_sync_load_previous_chunks`: calls flash-attn and merge directly
3. `_chunked_prefill_attention`:
   - calls the two methods above
   - computes the current chunk (flash_attn)
   - merges results (merge)

### Direct copy calls currently in attention.py (to be removed or wrapped)

```python
# attention.py:115-116 - write into the prefill buffer
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
```

**Resolution**: add a wrapper method on offload_engine, or move this logic into the policy.

### Functionality FullPolicy already implements

`compute_prefill_attention` (`full_policy.py:40-162`) already implements:
- ring-buffer pipelined loading
- sync-loading fallback
- current-chunk attention computation
- result merging

**It only needs to be renamed to `compute_chunked_attention`, with minor interface adjustments.**
## Phase 2: Modify the SparsePolicy Base Class

### 2.1 Update the select_blocks interface

```python
@abstractmethod
def select_blocks(
    self,
    available_blocks: List[int],
    offload_engine: "OffloadEngine",  # new parameter
    ctx: PolicyContext,
) -> List[int]:
    """
    Select the blocks to load.

    Args:
        available_blocks: all available block IDs
        offload_engine: offload engine (other policies may need to load KV to decide)
        ctx: policy context

    Returns:
        Selected block IDs
    """
    pass
```

### 2.2 Add the compute_chunked_attention abstract method

```python
@abstractmethod
def compute_chunked_attention(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    current_chunk_idx: int,
    seq: "ChunkedSequence",
    num_tokens: int,
) -> torch.Tensor:
    """
    Compute chunked-prefill attention (full flow).

    This is the policy's main entry point; it defines the complete prefill flow:
    1. Get the historical blocks
    2. Filter blocks (call select_blocks)
    3. Load historical blocks via offload_engine and compute attention over them
    4. Get the current chunk's KV via offload_engine and compute attention
    5. Merge all results

    Args:
        q: [seq_len, num_heads, head_dim] query of the current chunk
        k, v: [seq_len, num_kv_heads, head_dim] KV of the current chunk (already written to the prefill buffer)
        layer_id: layer index
        softmax_scale: softmax scaling factor
        offload_engine: offload engine
        current_chunk_idx: current chunk index
        seq: chunked sequence
        num_tokens: number of tokens in the current chunk

    Returns:
        [seq_len, num_heads, head_dim] final attention output
    """
    pass
```
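As a sanity check on the interface shape, the abstract contract above can be exercised with a stripped-down runnable version of the base class. Tensor arguments are replaced with plain objects and the body is a stand-in; this is purely illustrative, not the real `policy.py`:

```python
from abc import ABC, abstractmethod
from typing import List


class SparsePolicy(ABC):
    """Stripped-down stand-in for the real base class."""

    @abstractmethod
    def select_blocks(self, available_blocks: List[int], offload_engine, ctx) -> List[int]:
        ...

    @abstractmethod
    def compute_chunked_attention(self, q, k, v, layer_id, softmax_scale,
                                  offload_engine, current_chunk_idx, seq, num_tokens):
        ...


class FullPolicy(SparsePolicy):
    def select_blocks(self, available_blocks, offload_engine, ctx):
        return available_blocks  # no sparsity: keep every block

    def compute_chunked_attention(self, q, k, v, layer_id, softmax_scale,
                                  offload_engine, current_chunk_idx, seq, num_tokens):
        # Chunks before the current one form the history (one block per chunk here)
        selected = self.select_blocks(list(range(current_chunk_idx)), offload_engine, None)
        return f"attention(layer={layer_id}, history={len(selected)} blocks)"
```

Instantiating `SparsePolicy` directly raises `TypeError`, which enforces criterion 2's "all compute lives in the policy" split at the type level.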
## Phase 3: Modify FullPolicy

### 3.1 Rename the method

Rename `compute_prefill_attention` to `compute_chunked_attention`.

### 3.2 Update the select_blocks signature

```python
def select_blocks(
    self,
    available_blocks: List[int],
    offload_engine: "OffloadEngine",  # new parameter (unused here)
    ctx: PolicyContext,
) -> List[int]:
    """Return all blocks - no sparsity."""
    return available_blocks
```

### 3.3 Verify the compute_chunked_attention implementation

The current `compute_prefill_attention` already implements the full logic; confirm that it:
- [x] gets cpu_block_table
- [x] does ring-buffer pipelined loading (via offload_engine)
- [x] has a sync-loading fallback (via offload_engine)
- [x] computes current-chunk attention
- [x] merges results

**Note**: the current implementation does not call `select_blocks`; this call must be added.

### 3.4 Ensure all KV communication goes through offload_engine

Check inside `compute_chunked_attention`:
- Historical block loading: already via `offload_engine.load_to_slot_layer()` and related methods ✅
- Current-chunk KV access: already via `offload_engine.get_prefill_buffer_slice()` ✅
## Phase 4: Verify the Execution Path (Add Debug Output)

### 4.1 Verification goals

Confirm the execution path is correct after the changes:

| Checkpoint | Location | Expected behavior |
|------------|----------|-------------------|
| **Policy creation** | `kvcache/__init__.py` | FullAttentionPolicy is created |
| **Policy call** | `attention.py` | `_chunked_prefill_attention` calls `sparse_policy.compute_chunked_attention` |
| **select_blocks call** | `full_policy.py` | `compute_chunked_attention` calls `select_blocks` internally |
| **Old methods unused** | `attention.py` | `_ring_buffer_pipeline_load` and `_sync_load_previous_chunks` are no longer called |
| **No direct copy calls** | `attention.py` | the chunked-prefill path makes no direct `.copy_()` calls |

### 4.2 Where to add debug output

**Location 1: `kvcache/__init__.py` - at policy creation**
```python
sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
logger.info(f"[DEBUG] Created sparse policy: {sparse_policy}")
```

**Location 2: `attention.py` - when calling the policy**
```python
# In _chunked_prefill_attention
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_attention, "
             f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")
```

**Location 3: `full_policy.py` - compute_chunked_attention entry**
```python
def compute_chunked_attention(self, ...):
    logger.debug(f"[DEBUG] FullPolicy.compute_chunked_attention called, "
                 f"layer={layer_id}, chunk={current_chunk_idx}, num_tokens={num_tokens}")
    # ... implementation
```

**Location 4: `full_policy.py` - select_blocks call**
```python
# Inside compute_chunked_attention
selected_blocks = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
logger.debug(f"[DEBUG] select_blocks: input={len(cpu_block_table)} blocks, "
             f"output={len(selected_blocks)} blocks")
```

### 4.3 How to verify

Run the test and check the log output (per the needle-test requirements, `--enable-offload` and `--input-len 32768` are mandatory):
```bash
PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_needle.py --model <model_path> --enable-offload --input-len 32768 2>&1 | grep DEBUG
```

Expected output:
```
[DEBUG] Created sparse policy: FullAttentionPolicy()
[DEBUG] Calling sparse_policy.compute_chunked_attention, policy=FullAttentionPolicy(), layer=0, chunk=0
[DEBUG] FullPolicy.compute_chunked_attention called, layer=0, chunk=0, num_tokens=...
[DEBUG] select_blocks: input=0 blocks, output=0 blocks
[DEBUG] Calling sparse_policy.compute_chunked_attention, policy=FullAttentionPolicy(), layer=0, chunk=1
[DEBUG] FullPolicy.compute_chunked_attention called, layer=0, chunk=1, num_tokens=...
[DEBUG] select_blocks: input=1 blocks, output=1 blocks
...
```

### 4.4 Clean up debug output

After verification, downgrade the logs to debug level (e.g. `logger.debug`) or gate them behind an environment variable:
```python
if os.environ.get('NANOVLLM_DEBUG_POLICY'):
    logger.info(f"[DEBUG] ...")
```
## Phase 5: Modify attention.py

### 5.1 Simplify _chunked_prefill_attention

**After the change**:
```python
def _chunked_prefill_attention(self, q, k, v, context):
    kvcache_manager = context.kvcache_manager
    seq = context.chunked_seq
    offload_engine = kvcache_manager.offload_engine
    current_chunk_idx = context.current_chunk_idx
    num_tokens = k.shape[0]

    # Get the sparse policy
    sparse_policy = kvcache_manager.sparse_policy
    if sparse_policy is None:
        raise RuntimeError("sparse_policy is required for chunked prefill")

    # [DEBUG] Verify the execution path
    logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_attention, "
                 f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")

    # Let the policy compute attention (all compute logic lives inside the policy)
    # Note: no direct flash_attn or merge calls; the policy does everything
    final_o = sparse_policy.compute_chunked_attention(
        q, k, v,
        self.layer_id,
        self.scale,
        offload_engine,
        current_chunk_idx,
        seq,
        num_tokens,
    )

    # Per-layer ASYNC offload (via offload_engine methods, no direct copy)
    if offload_engine is not None and seq is not None:
        cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
        if current_chunk_idx < len(cpu_block_ids):
            cpu_block_id = cpu_block_ids[current_chunk_idx]
            offload_engine.offload_prefill_buffer_async(
                self.layer_id, cpu_block_id, num_tokens
            )

    return final_o
```

### 5.2 Handle the prefill buffer write

The current `forward()` method contains direct copy calls:
```python
# Current code (violates criterion 3)
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
```

**Option A**: add a wrapper method on offload_engine
```python
# offload_engine.py
def write_prefill_buffer(self, layer_id: int, k: Tensor, v: Tensor, num_tokens: int):
    self.prefill_k_buffer[layer_id, :num_tokens].copy_(k)
    self.prefill_v_buffer[layer_id, :num_tokens].copy_(v)

# attention.py
offload_engine.write_prefill_buffer(self.layer_id, k, v, num_tokens)
```

**Option B**: move this logic into the policy (as part of compute_chunked_attention)

**Option A is recommended**: attention.py keeps calling offload_engine methods but never touches the buffers directly.

### 5.3 Methods to delete

Delete the following methods (their logic has moved into FullPolicy):
- `_ring_buffer_pipeline_load`
- `_sync_load_previous_chunks`

### 5.4 Methods to keep

Decode-related methods stay unchanged:
- `_chunked_decode_attention`
- `_decode_with_layer_pipeline`
- `_decode_ring_buffer_pipeline`
## Phase 6: Test and Validate

### 6.1 Functional tests

- [ ] Run `test_needle.py --enable-offload` (FULL policy)
- [ ] Verify the output is correct (the needle value matches)
- [ ] Check the debug logs to confirm the execution path

### 6.2 Code review (acceptance criteria check)

- [ ] **Criterion 1**: test_needle.py passes ✓
- [ ] **Criterion 2**: no `flash_attn` or `merge_attention_outputs` calls inside `_chunked_prefill_attention`
- [ ] **Criterion 3**: no direct `.copy_()` calls inside `_chunked_prefill_attention`

**Note**: criteria 2 and 3 apply only to the chunked-prefill path. The decode path and other paths may still call `flash_attn`.

**Verification methods**:

**Method 1: verify the call chain with the cclsp LSP tools (recommended)**

Use `mcp__cclsp__find_references` to locate call sites of the compute functions and confirm the chunked-prefill path has no direct calls:

```
# Find all calls of flash_attn_with_lse
mcp__cclsp__find_references(file_path="nanovllm/layers/attention.py", symbol_name="flash_attn_with_lse")

# Find all calls of merge_attention_outputs
mcp__cclsp__find_references(file_path="nanovllm/layers/attention.py", symbol_name="merge_attention_outputs")

# Find the implementation of _chunked_prefill_attention
mcp__cclsp__find_definition(file_path="nanovllm/layers/attention.py", symbol_name="_chunked_prefill_attention")
```

The results should show:
- `flash_attn_with_lse` calls appear only on the decode path or in `full_policy.py`
- `_chunked_prefill_attention` internally calls only `sparse_policy.compute_chunked_attention`

**Method 2: manual code review**

Inspect the `_chunked_prefill_attention` implementation and confirm it:
1. only calls `sparse_policy.compute_chunked_attention(...)`
2. only calls offload_engine methods such as `offload_engine.offload_prefill_buffer_async(...)`
3. never calls `flash_attn_*`, `merge_attention_outputs`, or `.copy_()` directly

```bash
# Helper check: locate all flash_attn call sites
grep -n "flash_attn\|merge_attention_outputs" nanovllm/layers/attention.py

# Helper check: locate all copy call sites
grep -n "\.copy_\|\.copy(" nanovllm/layers/attention.py
```

### 6.3 Regression tests

- [ ] Verify the decode phase is unaffected
- [ ] Verify non-offload mode is unaffected (if applicable)
## Key File Checklist

| File | Changes |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | Add the `compute_chunked_attention` abstract method; update the `select_blocks` signature |
| `nanovllm/kvcache/sparse/full_policy.py` | Rename the method; update the `select_blocks` signature; add the `select_blocks` call; add debug output |
| `nanovllm/layers/attention.py` | Simplify `_chunked_prefill_attention`; delete `_ring_buffer_pipeline_load` and `_sync_load_previous_chunks`; add debug output |
| `nanovllm/kvcache/__init__.py` | Add debug output at policy creation |
| `nanovllm/kvcache/offload_engine.py` | (Optional) add a `write_prefill_buffer` wrapper method |

## Decisions Made

- **Decision 1**: Add only one abstract method, `compute_chunked_attention` (no `compute_block_attention` or `merge_attention_outputs`)
- **Decision 2**: `select_blocks` receives an `offload_engine` parameter
- **Decision 3**: Use the name `compute_chunked_attention` consistently
- **Decision 4**: The decode phase is untouched
- **Decision 5**: Async offload logic stays in attention.py (via offload_engine method calls)
- **Decision 6**: Phase 4 adds debug output to verify the execution path; downgrade or remove it after verification
- **Decision 7**: Prefill buffer writes go through an offload_engine wrapper method (Option A)
- **Decision 8**: All KV communication must go through offload_engine methods; no direct torch.copy

## Errors Encountered

(to be recorded)

## Status

**Planning Complete** - the v4 plan is finished, with explicit acceptance criteria and execution-path verification steps
**Deleted file (362 lines)**

# Task Plan: Modular Integration of XAttention BSA

## Goal
Integrate the XAttention BSA policy into nano-vllm's sparse policy framework behind the unified interface, with a modular design.

**Final validation goal**: run `tests/test_ruler.py` on at most 10 samples of the 32K data and obtain reasonable results (not necessarily all PASS, but within the expected accuracy range).

---
## Mandatory Requirement: Use Hive-Mind Swarm Thinking

**Deep reasoning must use the Claude Flow MCP hive-mind swarm to improve implementation precision.**

### Starting the Hive-Mind

Before each complex phase, execute the following steps:

1. **Initialize the hive-mind swarm**:
```python
# Via MCP call
mcp__claude-flow_alpha__hive-mind_init(
    topology="mesh",  # or "hierarchical", "ring", "star"
    maxAgents=5,      # swarm size
)
```

2. **Spawn specialists**:
```python
# Create agents for different task types
mcp__claude-flow_alpha__hive-mind_spawn(
    count=3,
    type="specialist",  # researcher, coder, analyst
)
```

3. **Broadcast the thinking task**:
```python
mcp__claude-flow_alpha__hive-mind_broadcast(
    message="Analyze potential problems in the current architecture design...",
    priority="high"
)
```

4. **Get swarm status and consensus**:
```python
mcp__claude-flow_alpha__hive-mind_status(verbose=True)
mcp__claude-flow_alpha__hive-mind_consensus(
    action="propose",
    type="design",
    value="Modular interface design proposal"
)
```

### Applicable phases

The following phases **must** use hive-mind swarm thinking:

- ✅ Phase 1: Confirm the SparsePolicy base-class interface
- ✅ Phase 2: Align the XAttentionBSAPolicy interface
- ✅ Phase 3: Modularize the OffloadEngine helper methods
- ✅ Phase 5: Verify the attention.py integration points

The other phases (4, 6, 7) may use the standard thinking mode.

### Suggested swarm configuration

```yaml
# Recommended configuration
topology: mesh   # mesh topology, suited to parallel reasoning
maxAgents: 5     # 5 specialist agents
agentTypes:
  - researcher   # architecture analysis
  - coder        # code implementation
  - analyst      # interface verification
  - optimizer    # performance optimization
  - validator    # correctness verification
```

### Output requirements

After using the hive-mind, the plan must record:
1. Key insights produced by the swarm
2. Decisions reached by multi-agent consensus
3. Potential problems found and their solutions

---
## Current Architecture Analysis

### SparsePolicy base-class interface

The base-class definition in `nanovllm/kvcache/sparse/policy.py` needs to be confirmed:

```python
class SparsePolicy:
    # Capability flags
    supports_prefill: bool
    supports_decode: bool
    requires_block_selection: bool

    # Core method
    def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]

    # Optional method (prefill only)
    def sparse_prefill_attention(self, q, k, v, layer_id) -> torch.Tensor

    # Initialization
    def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device)
    def reset(self)
```

### Current XAttentionBSAPolicy implementation

Implemented, but the modular-integration pieces still need confirmation:
- `xattn_bsa.py` - the policy class implementation
- `config.py` - enums and parameters
- `sparse/__init__.py` - the policy factory
- `offload_engine.py` - helper methods
- `attention.py` - integration points
## Detailed Implementation Plan

### Phase 1: Ensure a Unified SparsePolicy Base-Class Interface

**Task**: verify that the `SparsePolicy` base class defines all required methods

**Steps**:
1. Read `nanovllm/kvcache/sparse/policy.py`
2. Confirm the base class defines:
   - the `supports_prefill`, `supports_decode`, `requires_block_selection` class attributes
   - the `select_blocks()` method
   - the `sparse_prefill_attention()` method (optional)
   - the `initialize()` and `reset()` methods
3. If anything is missing, add it to the base-class definition

**Expected result**: the base-class definition is complete and all policy classes can follow the unified interface

---
### Phase 2: Align the XAttentionBSAPolicy Interface

**Task**: ensure XAttentionBSAPolicy fully conforms to the SparsePolicy interface

**Steps**:
1. Confirm the class attributes in `xattn_bsa.py` are correct:
```python
class XAttentionBSAPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = False
    requires_block_selection = False  # note: BSA handles selection internally
```

2. Ensure the method signatures match the base class:
   - `select_blocks(available_blocks, ctx) -> List[int]`
   - `sparse_prefill_attention(q, k, v, layer_id) -> Tensor`
   - `initialize(...)`
   - `reset()`

3. Add documentation: BSA handles block selection internally during prefill, so `select_blocks` returns all available blocks

**Expected result**: XAttentionBSAPolicy fully conforms to the unified SparsePolicy interface

---
### Phase 3: Modularize the OffloadEngine Helper Methods

**Task**: ensure the OffloadEngine helper methods are correctly defined and modular

**Steps**:
1. Confirm where the helper methods live in `offload_engine.py`:
```python
# Add these two methods to the OffloadEngine class
def load_block_sample_from_cpu(self, cpu_block_id, layer_id, num_samples):
    """Load sampled tokens for the estimation stage."""
    ...

def load_block_full_from_cpu(self, cpu_block_id, layer_id):
    """Load a full block for the compute stage."""
    ...
```

2. Make sure the method signatures match the call sites in `xattn_bsa.py`

3. Document the purpose and usage scenarios of these two methods

**Expected result**: OffloadEngine provides a unified block-loading interface

---
### Phase 4: Integrate into the Factory Pattern

**Task**: ensure policies are created through the unified factory

**Steps**:
1. Check the `create_kvcache_manager` function in `nanovllm/kvcache/__init__.py`

2. Confirm the policy-creation logic is clear:
```python
# Build the kwargs for the given policy type
if sparse_policy_type == SparsePolicyType.XATTN_BSA:
    policy_kwargs = {
        'block_size': getattr(config, 'sparse_block_size', 128),
        'samples_per_chunk': getattr(config, 'sparse_samples_per_chunk', 128),
        'threshold': getattr(config, 'sparse_threshold', 0.9),
        'use_triton': getattr(config, 'sparse_use_triton', True),
        'stride': getattr(config, 'sparse_stride', 8),
    }
```

3. Confirm every policy type has corresponding kwargs-building logic

**Expected result**: all policies are created through `create_sparse_policy()`

---
### Phase 5: Verify the attention.py Integration Points

**Task**: ensure the integration points in attention.py call the policy interface correctly

**Steps**:
1. Check the `_chunked_prefill_attention` method in `nanovllm/layers/attention.py`

2. Confirm the integration logic:
```python
# Detect whether the policy provides sparse_prefill_attention
if sparse_policy is not None and hasattr(sparse_policy, 'sparse_prefill_attention'):
    if sparse_policy.supports_prefill:
        # Use the policy's sparse_prefill_attention method
        o = sparse_policy.sparse_prefill_attention(q, k, v, self.layer_id)
        # Handle async offload
        return o

# Otherwise use the standard flow (Quest, etc.)
# ...
```

3. Make sure nothing bypasses the policy interface to call other logic directly

**Expected result**: attention.py invokes BSA through the unified policy interface
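The integration logic above reduces to a small capability check. A minimal runnable sketch of that dispatch follows; the class names and return values are hypothetical stand-ins, not the project's real implementations:

```python
class FullPolicy:
    supports_prefill = False  # no dedicated sparse prefill path


class BSAPolicy:
    supports_prefill = True

    def sparse_prefill_attention(self, q, k, v, layer_id):
        # Stand-in for the real BSA attention computation
        return f"bsa-output(layer={layer_id})"


def dispatch_prefill(policy, q, k, v, layer_id):
    # Prefer the policy's own sparse prefill path when it advertises one
    if (policy is not None
            and hasattr(policy, "sparse_prefill_attention")
            and policy.supports_prefill):
        return policy.sparse_prefill_attention(q, k, v, layer_id)
    # Otherwise fall back to the standard chunked flow
    return "standard-path"
```

The `hasattr` check plus the `supports_prefill` flag keeps attention.py independent of concrete policy classes: a policy without the method, with the flag off, or absent entirely all take the fallback path.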
---

### Phase 6: Modularize the Configuration Parameters

**Task**: ensure the configuration parameter structure is clear and easy to use

**Steps**:
1. Check the configuration structure in `nanovllm/config.py`

2. Confirm the XAttention BSA parameters are organized clearly:
```python
# Common sparse parameters
sparse_policy: SparsePolicyType = SparsePolicyType.FULL
sparse_topk_blocks: int = 8       # Quest
sparse_threshold_blocks: int = 4  # Quest

# XATTN_BSA-specific parameters
sparse_block_size: int = 128
sparse_samples_per_chunk: int = 128
sparse_threshold: float = 0.9
sparse_use_triton: bool = True
sparse_stride: int = 8
```

3. Consider whether parameter grouping or nested configuration is needed

**Expected result**: configuration parameters are clear, easy to understand and use

---
### Phase 7: Modular Validation Tests

**Task**: create a simple validation script to confirm the modular integration is correct

**Steps**:
1. Create the test script `tests/test_xattn_bsa_integration.py`

2. Validate the following:
   - XAttentionBSAPolicy can be created via `create_sparse_policy()`
   - the policy answers `supports_prefill` and `supports_decode` queries correctly
   - the `select_blocks()` method returns correct results
   - the OffloadEngine helper methods can be called normally
   - the policy can be invoked correctly in a mock environment

3. Test cases:
```python
# Test 1: policy creation
from nanovllm.config import Config, SparsePolicyType
from nanovllm.kvcache.sparse import create_sparse_policy

policy = create_sparse_policy(SparsePolicyType.XATTN_BSA)
assert hasattr(policy, 'sparse_prefill_attention')
assert policy.supports_prefill == True
assert policy.supports_decode == False

# Test 2: interface consistency
# verify method signatures
# ...

# Test 3: OffloadEngine helper methods
# ...
```

**Expected result**: all tests pass; the modular integration is validated

---
## Core Design Principles

### 1. Interface uniformity
- All policies expose a unified interface through the `SparsePolicy` base class
- Policy instances are created via the factory pattern
- Policy switching is transparent and does not affect other modules

### 2. Modular independence
- Each policy class is implemented independently
- OffloadEngine provides generic helper methods
- attention.py calls through the policy interface and does not depend on concrete implementations

### 3. Extensibility
- Adding a new policy only requires:
  1. creating a new policy class that inherits `SparsePolicy`
  2. adding it to the `SparsePolicyType` enum
  3. adding creation logic in the factory function
  4. adding the corresponding configuration parameters
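The four extension steps above can be sketched end to end. The enum member, policy class, and factory body below are simplified hypothetical stand-ins for the real `nanovllm` definitions, shown only to illustrate the registration flow:

```python
from enum import Enum
from typing import List


class SparsePolicyType(Enum):              # step 2: extend the enum
    FULL = "full"
    MY_POLICY = "my_policy"                # hypothetical new member


class SparsePolicy:                        # simplified stand-in for the base class
    supports_prefill = True
    supports_decode = False

    def select_blocks(self, available_blocks: List[int], ctx) -> List[int]:
        raise NotImplementedError


class MyPolicy(SparsePolicy):              # step 1: subclass SparsePolicy
    def __init__(self, topk: int = 8):     # step 4: policy-specific config parameter
        self.topk = topk

    def select_blocks(self, available_blocks, ctx):
        return available_blocks[-self.topk:]   # keep only the most recent blocks


def create_sparse_policy(kind: SparsePolicyType, **kwargs) -> SparsePolicy:
    registry = {SparsePolicyType.MY_POLICY: MyPolicy}  # step 3: register in the factory
    return registry[kind](**kwargs)
```

Callers then only ever touch `create_sparse_policy(...)`, so adding a policy never changes attention.py or the offload engine.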
---

## File Modification Checklist

### Files that must be modified
1. `nanovllm/kvcache/sparse/policy.py` - ensure the base-class definition is complete
2. `nanovllm/kvcache/sparse/xattn_bsa.py` - ensure interface alignment
3. `nanovllm/kvcache/offload_engine.py` - add the helper methods
4. `nanovllm/layers/attention.py` - verify the integration points
5. `nanovllm/config.py` - confirm the parameter structure
6. `nanovllm/kvcache/__init__.py` - confirm the factory pattern
7. `nanovllm/kvcache/sparse/__init__.py` - confirm the registration logic

### Optional files to create
- `tests/test_xattn_bsa_integration.py` - integration validation tests

---

## Implementation Status

- [ ] Phase 1: Confirm the SparsePolicy base-class interface
- [ ] Phase 2: Align the XAttentionBSAPolicy interface
- [ ] Phase 3: Modularize the OffloadEngine helper methods
- [ ] Phase 4: Verify the factory-pattern integration
- [ ] Phase 5: Verify the attention.py integration points
- [ ] Phase 6: Modularize the configuration parameters
- [ ] Phase 7: Modular validation tests

---

## Notes

- This plan focuses on modular integration, not algorithm optimization
- All changes follow the existing framework's design patterns
- The emphasis is on interface unification and module decoupling
- The test phase only needs a simple validation script, not a full end-to-end test