📝 docs: add layer offload planning notes and task plan

Add planning documents for layer-wise offload implementation:
- notes.md: Implementation notes and findings
- task_plan.md: Detailed task breakdown and progress tracking

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-22 06:04:36 +08:00
parent e313dd795a
commit 69b779e252
2 changed files with 679 additions and 0 deletions

130
notes.md Normal file
View File

@@ -0,0 +1,130 @@
# Notes: SparsePolicy Refactoring Research
## Sources
### Source 1: tzj/minference branch - policy.py
- 路径: `nanovllm/kvcache/sparse/policy.py`
- 关键设计:
- `PolicyContext` 数据类包含 query_chunk_idx, num_query_chunks, layer_id, query, is_prefill 等
- `select_blocks()` 需要 offload_engine 参数
- `compute_chunked_prefill()``compute_chunked_decode()` 是完整的 attention 流程
- `on_prefill_offload()` / `on_decode_offload()` hooks 用于收集元数据
### Source 2: tzj/minference branch - full_policy.py
- 路径: `nanovllm/kvcache/sparse/full_policy.py`
- 关键实现:
- `compute_chunked_prefill()` 内部使用 ring buffer pipeline 加载 blocks
- 使用 `flash_attn_with_lse``merge_attention_outputs` 合并多个 chunk 的 attention
- `compute_chunked_decode()` 处理 prefilled blocks + decode buffer
### Source 3: tzj/layer-offload branch - model_runner.py
- 路径: `nanovllm/engine/model_runner.py`
- 关键设计:
- `run_layerwise_offload_prefill()` 逐层处理,每层计算完整 attention
- `sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)` 简单接口
- FULL policy 通过 `if sparse_prefill_policy is None` 走 else 分支
### Source 4: tzj/layer-offload branch - xattn.py
- 路径: `nanovllm/kvcache/sparse/xattn.py`
- 关键实现:
- `sparse_prefill_attention()` 直接使用 FlashAttention因为 chunked prefill 架构限制)
- 保留 Triton kernels 供未来 GPU-only 模式
## Synthesized Findings
### 架构差异总结
| 方面 | Chunked Offload | Layerwise Offload |
|------|-----------------|-------------------|
| **Prefill 流程** | chunk-by-chunk跨层 | layer-by-layer完整序列 |
| **KV 存储** | 每 chunk 立即 offload | 每层计算后 offload |
| **Attention 计算** | 分多次计算+合并 | 一次完整计算 |
| **Block 加载** | 需要从 CPU 加载历史 | 不需要,已在 GPU |
| **Policy 责任** | 完整 attention 流程 | 仅 attention kernel 选择 |
### Layerwise Offload 的简化点
1. **不需要 block selection**: 整层 KV 都在 GPU无需选择
2. **不需要 offload_engine 参数**: Policy 不负责加载 KV
3. **不需要 merge_attention_outputs**: 一次计算完整 attention
4. **不需要 offload hooks**: offload 在 model_runner 统一处理
### 设计建议
1. **保持接口简单**: 只需要 `compute_prefill_attention()``compute_decode_attention()`
2. **FULL 也实现方法**: 不再通过 `is None` 判断,所有 policy 统一调用
3. **移除不必要的参数**: 不需要 offload_engine, kvcache_manager, seq 等
4. **统一命名**: 使用 `compute_*_attention` 而不是 `sparse_prefill_attention`
## Code Examples
### 当前调用方式 (model_runner.py:876-891)
```python
# Sparse or Full attention
if self.sparse_prefill_policy is not None:
# MInference or other sparse prefill policy
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
# Full attention using FlashAttention
attn_output = flash_attn_varlen_func(
q, k, v, ...
)
```
### 建议的新调用方式
```python
# 所有 policy 统一调用
attn_output = self.attention_policy.compute_prefill_attention(
q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```
## Questions Resolved
- Q: 是否需要 PolicyContext?
- A: 可以简化,因为 layerwise 模式下不需要 chunk 信息
- Q: decode 阶段如何处理?
- A: **Decode 不需要 policy**!当前 `run_layerwise_offload_decode()` 使用标准 `layer(positions, hidden_states, residual)` 调用,走 Attention.forward() 路径
- Q: 为什么 decode 不需要 sparse?
- A: 因为 decode 每次只有 1 个 token没有稀疏化的意义。KV 从 ring buffer 加载后直接用 flash_attn_with_kvcache
## Key Insight
**Layerwise Offload 的 Policy 设计应该只关注 Prefill**
```
Prefill: 需要 Policy
- 整个序列一次计算 attention
- 可以使用 sparse attention 方法(如 MInference 的 vertical+slash pattern
- Policy 接收 q, k, v, layer_id, softmax_scale
Decode: 不需要 Policy
- 每次只有 1 个 token query
- KV 从 ring buffer 加载
- 使用标准 flash_attn_with_kvcache
```
## Interface Comparison Summary
| 方面 | tzj/minference | tzj/layer-offload (新设计) |
|------|----------------|---------------------------|
| 类名 | SparsePolicy | AttentionPolicy |
| Prefill 方法 | compute_chunked_prefill() | compute_attention() |
| Decode 方法 | compute_chunked_decode() | 不需要(用标准路径) |
| 需要 offload_engine | 是 | 否 |
| 需要 kvcache_manager | 是 | 否 |
| 需要 seq | 是 | 否 |
| 支持 FULL | 是 | 是 |
## Migration Path
1. 保留 `SparsePolicy` 作为 `AttentionPolicy` 的别名
2. 保留 `PolicyContext` 供未来扩展
3. 保留 `select_blocks()` 方法签名(虽然不使用)
4. 移除 `requires_block_selection` 属性(不需要)

549
task_plan.md Normal file
View File

@@ -0,0 +1,549 @@
# Task Plan: Refactor SparsePolicy for Layerwise Offload
## Goal
重构 SparsePolicy 接口,参考 tzj/minference 分支的设计模式,使所有 attention 都可以抽象成 policy并按统一规范编写。适配当前 layerwise offload 架构特点(整层 KV 在 GPU 上)。
## Background
### 两种 Offload 架构对比
| 特性 | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|------|----------------------------------|---------------------------------------|
| 处理粒度 | 每次一个 chunk (block_size tokens) | 每次一整层 (所有 tokens) |
| KV 位置 | 历史 chunks 在 CPU需要加载 | 整层 KV 都在 GPU |
| Policy 入口 | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
| 需要 offload_engine | 是(加载 blocks | 否KV 已在 GPU |
| Mask 计算 | `select_blocks()` 返回 block IDs | `estimate()` 返回 sparse mask |
### tzj/minference 的 Policy 接口
```python
class SparsePolicy(ABC):
supports_prefill: bool
supports_decode: bool
@abstractmethod
def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]
@abstractmethod
def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor
@abstractmethod
def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor
```
### 当前 branch 的 Policy 接口(重构前)
```python
class SparsePolicy(ABC):
supports_prefill: bool
supports_decode: bool
@abstractmethod
def select_blocks(self, available_blocks, ctx) -> List[int]
def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor
```
## Phases
- [x] Phase 1: 分析差异并设计新接口
- [x] **Phase 0: 创建 nanovllm.ops 模块** ✅ 测试通过
- [ ] Phase 2: 重构 AttentionPolicy 基类
- [ ] Phase 3: 重构 FullAttentionPolicy
- [ ] Phase 4: 重构 XAttentionPolicy (含 estimate 方法)
- [ ] Phase 5: 更新 model_runner 调用方式
- [ ] Phase 6: 测试验证
---
## Phase 0: 创建 nanovllm.ops 模块
### 目标
从 tzj/minference 分支提取 ops 模块,为 XAttention estimate 提供底层算子支持。
### 步骤
1. **创建目录结构**
```
nanovllm/ops/
├── __init__.py
├── xattn.py # xattn_estimate, xattn_estimate_chunked, Triton kernels
└── chunked_attention.py # flash_attn_with_lse, merge_attention_outputs (备用)
```
2. **从 tzj/minference 提取文件**
```bash
git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
```
3. **Cherry-pick 测试文件**
```bash
git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
```
4. **运行测试验证**
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_chunked.py
```
### nanovllm/ops 模块内容
| 文件 | 核心函数 | 用途 |
|------|----------|------|
| `xattn.py` | `xattn_estimate()` | 标准 XAttention estimation |
| `xattn.py` | `xattn_estimate_chunked()` | Chunked prefill 版本 |
| `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
| `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
| `xattn.py` | `find_blocks_chunked()` | Block selection based on threshold |
| `chunked_attention.py` | `flash_attn_with_lse()` | Flash attention with LSE output |
| `chunked_attention.py` | `merge_attention_outputs()` | Merge multiple attention chunks |
### 与 Policy 的关系
```
XAttentionPolicy.estimate()
└── 调用 nanovllm.ops.xattn.xattn_estimate()
├── flat_group_gemm_fuse_reshape() (Triton)
├── softmax_fuse_block_sum() (Triton)
└── find_blocks_chunked()
```
---
## Key Questions
1. **`select_blocks` 改为什么?**
- 改名为 `estimate()`:用于计算 sparse mask
- 对于 XAttention对应 COMPASS 的 `xattn_estimate()` 函数
- FullAttentionPolicy 的 `estimate()` 返回 None表示 full attention
2. **Policy 接口应该如何设计?**
- Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
- Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
- Estimate: `estimate(q, k, layer_id)` - 计算 sparse mask
3. **FULL policy 如何处理?**
- FULL 也实现 `compute_prefill/decode`,使用 FlashAttention
- `estimate()` 返回 None表示不进行稀疏化
## Proposed New Interface
```python
from abc import ABC, abstractmethod
from typing import Optional
import torch
class AttentionPolicy(ABC):
"""Layerwise Offload 模式下的 Attention Policy
所有 attention 计算都通过 policy 进行,包括 Full 和 Sparse。
支持 prefill 和 decode 两个阶段。
"""
supports_prefill: bool = True
supports_decode: bool = True
def estimate(
self,
q: torch.Tensor, # [seq_len, num_heads, head_dim]
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
layer_id: int,
) -> Optional[torch.Tensor]:
"""
估算 sparse attention mask。
对于 sparse policy如 XAttention计算哪些 blocks 需要 attend。
对于 full policy返回 None 表示使用完整 attention。
对应 COMPASS 的 xattn_estimate() 函数。
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
Returns:
sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, 或 None
"""
return None # 默认为 full attention
@abstractmethod
def compute_prefill(
self,
q: torch.Tensor, # [seq_len, num_heads, head_dim]
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
v: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
计算 prefill attention。
整层 KV 都在 GPU 上,一次计算完整 attention。
可以先调用 estimate() 获取 sparse mask然后应用 block sparse attention。
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
pass
def compute_decode(
self,
q: torch.Tensor, # [1, num_heads, head_dim]
k: torch.Tensor, # [context_len+1, num_kv_heads, head_dim]
v: torch.Tensor, # [context_len+1, num_kv_heads, head_dim]
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
计算 decode attention。
KV 从 ring buffer 提供,包含 prefill tokens + 已 decode 的 tokens。
Args:
q: Query tensor [1, num_heads, head_dim]
k: Key tensor [context_len+1, num_kv_heads, head_dim]
v: Value tensor [context_len+1, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor
Returns:
Attention output [1, num_heads, head_dim]
"""
# 默认实现:使用 FlashAttention
from flash_attn.flash_attn_interface import flash_attn_varlen_func
context_len = k.shape[0]
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len,
softmax_scale=softmax_scale,
causal=False,
)
def reset(self) -> None:
"""Reset policy state between sequences."""
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
# 保留旧名称作为别名
SparsePolicy = AttentionPolicy
```
## Implementation Plan
### Phase 2: 重构 policy.py
```python
# nanovllm/kvcache/sparse/policy.py
from abc import ABC, abstractmethod
from typing import Optional
import torch
class AttentionPolicy(ABC):
"""Base class for attention policies in layerwise offload mode."""
supports_prefill: bool = True
supports_decode: bool = True
def estimate(
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Estimate sparse attention mask.
For sparse policies (e.g., XAttention), computes block-level importance.
For full policy, returns None.
Corresponds to xattn_estimate() in COMPASS.
Returns:
sparse_mask: [num_heads, q_blocks, k_blocks] or None
"""
return None
@abstractmethod
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""Compute prefill attention."""
pass
def compute_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""Compute decode attention (default: FlashAttention)."""
from flash_attn.flash_attn_interface import flash_attn_varlen_func
context_len = k.shape[0]
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len,
softmax_scale=softmax_scale,
causal=False,
)
def reset(self) -> None:
pass
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
# Backward compatibility alias
SparsePolicy = AttentionPolicy
```
### Phase 3: 重构 FullAttentionPolicy
```python
# nanovllm/kvcache/sparse/full_policy.py
import torch
from .policy import AttentionPolicy
class FullAttentionPolicy(AttentionPolicy):
"""Full attention using FlashAttention (no sparsity)."""
supports_prefill = True
supports_decode = True
def estimate(self, q, k, layer_id):
"""Full attention - no sparse mask needed."""
return None
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
from flash_attn.flash_attn_interface import flash_attn_varlen_func
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
def __repr__(self):
return "FullAttentionPolicy()"
```
### Phase 4: 重构 XAttentionPolicy
```python
# nanovllm/kvcache/sparse/xattn.py
import torch
from typing import Optional
from .policy import AttentionPolicy
class XAttentionPolicy(AttentionPolicy):
"""
XAttention sparse prefill policy.
Uses chunked estimation to compute sparse attention mask,
then applies block sparse attention.
"""
supports_prefill = True
supports_decode = True
def __init__(
self,
stride: int = 8,
threshold: float = 0.9,
block_size: int = 128,
chunk_size: int = 16384,
use_triton: bool = True,
):
self.stride = stride
self.threshold = threshold
self.block_size = block_size
self.chunk_size = chunk_size
self.use_triton = use_triton
def estimate(
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
XAttention estimation (xattn_estimate).
Uses chunked GEMM + softmax to estimate block-level importance,
then selects important blocks based on threshold.
对应 COMPASS 的 xattn_estimate() 函数:
1. Pad inputs to chunk_size multiples
2. Reshape with stride
3. Compute QK^T in chunks (Triton)
4. Block-wise softmax + aggregation
5. Threshold-based selection
Args:
q: [seq_len, num_heads, head_dim]
k: [seq_len, num_kv_heads, head_dim]
layer_id: transformer layer index
Returns:
sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask
or None (fallback to full attention)
"""
# TODO: 实现真正的 xattn_estimate
# 当前返回 None 使用 full attention
return None
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute XAttention sparse prefill.
Flow:
1. Call estimate() to get sparse mask
2. If mask is None, use full attention
3. Otherwise, apply block sparse attention with mask
"""
# Step 1: Estimate sparse mask
sparse_mask = self.estimate(q, k, layer_id)
# Step 2: Compute attention
if sparse_mask is None:
# Fallback to full attention
from flash_attn.flash_attn_interface import flash_attn_varlen_func
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
else:
# Apply block sparse attention with mask
# 使用 block_sparse_attn_func(q, k, v, sparse_mask, block_size)
raise NotImplementedError("Block sparse attention not yet implemented")
def __repr__(self):
return (f"XAttentionPolicy("
f"stride={self.stride}, "
f"threshold={self.threshold}, "
f"block_size={self.block_size})")
```
### Phase 5: 更新 model_runner.py
```python
# model_runner.py - allocate_kv_cache()
# 改为总是创建 policy包括 FULL
from nanovllm.kvcache.sparse import create_attention_policy
self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")
# run_layerwise_offload_prefill() 和 run_gpu_only_prefill()
# 旧代码:
if self.sparse_prefill_policy is not None:
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
else:
attn_output = flash_attn_varlen_func(...)
# 新代码:
attn_output = self.attention_policy.compute_prefill(
q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```
## Method Mapping
| 旧方法 | 新方法 | 说明 |
|--------|--------|------|
| `select_blocks()` | `estimate()` | 计算 sparse mask对应 xattn_estimate |
| `sparse_prefill_attention()` | `compute_prefill()` | Prefill attention |
| (无) | `compute_decode()` | Decode attention默认实现 |
| `on_prefill_offload()` | (移除) | Offload 在 model_runner 处理 |
## Files to Modify
| File | Changes |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | 新接口estimate, compute_prefill, compute_decode |
| `nanovllm/kvcache/sparse/full_policy.py` | 实现 compute_prefill(), estimate() 返回 None |
| `nanovllm/kvcache/sparse/xattn.py` | estimate() 对应 xattn_estimate, compute_prefill() |
| `nanovllm/kvcache/sparse/__init__.py` | 更新工厂函数 |
| `nanovllm/engine/model_runner.py` | 统一调用 attention_policy.compute_prefill() |
| `nanovllm/config.py` | 可选:重命名配置项 |
## Decisions Made
1. **方法命名**: `compute_prefill` / `compute_decode` 对应 chunked 版本的命名风格
2. **estimate 方法**: 替代 `select_blocks`,返回 sparse mask 而不是 block IDs
3. **XAttention**: `estimate()` 对应 COMPASS 的 `xattn_estimate()`
4. **Full Policy**: `estimate()` 返回 None 表示使用完整 attention
5. **Decode 默认实现**: 基类提供默认的 FlashAttention 实现
## Errors Encountered
- (无)
## Status
**Currently in Phase 1** - 完成分析和接口设计,等待用户确认后进入 Phase 2