# Task Plan: Refactor SparsePolicy for Layerwise Offload
## Goal
Refactor the SparsePolicy interface, following the design pattern of the tzj/minference branch, so that every attention variant can be abstracted as a policy and implemented against one uniform contract. Adapt it to the current layerwise offload architecture, where each layer's full KV resides on the GPU.
## Background
### Comparison of the Two Offload Architectures
| Aspect | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|--------|----------------------------------|---------------------------------------|
| Granularity | One chunk at a time (block_size tokens) | One full layer at a time (all tokens) |
| KV location | Historical chunks on CPU, must be loaded | Entire layer's KV on GPU |
| Policy entry points | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
| Needs offload_engine | Yes (loads blocks) | No (KV already on GPU) |
| Mask computation | `select_blocks()` returns block IDs | `estimate()` returns a sparse mask |
### Policy Interface in tzj/minference
```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]: ...

    @abstractmethod
    def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor: ...

    @abstractmethod
    def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor: ...
```
### Policy Interface in the Current Branch (Before Refactor)
```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, ctx) -> List[int]: ...

    def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor: ...
```
## Phases
- [x] Phase 1: Analyze the differences and design the new interface
- [x] **Phase 0: Create the nanovllm.ops module** ✅ tests passing
- [ ] Phase 2: Refactor the AttentionPolicy base class
- [ ] Phase 3: Refactor FullAttentionPolicy
- [ ] Phase 4: Refactor XAttentionPolicy (including the estimate method)
- [ ] Phase 5: Update the model_runner call sites
- [ ] Phase 6: Test and validate
---
## Phase 0: Create the nanovllm.ops Module
### Goal
Extract the ops module from the tzj/minference branch to provide the low-level operators that XAttention's estimate step relies on.
### Steps
1. **Create the directory structure**
```
nanovllm/ops/
├── __init__.py
├── xattn.py # xattn_estimate, xattn_estimate_chunked, Triton kernels
└── chunked_attention.py # flash_attn_with_lse, merge_attention_outputs (reserved for later use)
```
2. **Extract the files from tzj/minference**
```bash
git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
```
3. **Copy over the test file**
```bash
git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
```
4. **Run the test to verify**
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_chunked.py
```
### Contents of nanovllm/ops
| File | Key Function | Purpose |
|------|--------------|---------|
| `xattn.py` | `xattn_estimate()` | Standard XAttention estimation |
| `xattn.py` | `xattn_estimate_chunked()` | Chunked-prefill variant |
| `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
| `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
| `xattn.py` | `find_blocks_chunked()` | Threshold-based block selection |
| `chunked_attention.py` | `flash_attn_with_lse()` | Flash attention that also returns the LSE |
| `chunked_attention.py` | `merge_attention_outputs()` | Merge partial attention outputs (sketched below) |
### Relationship to the Policy Layer
```
XAttentionPolicy.estimate()
└── calls nanovllm.ops.xattn.xattn_estimate()
    ├── flat_group_gemm_fuse_reshape() (Triton)
    ├── softmax_fuse_block_sum() (Triton)
    └── find_blocks_chunked()
```
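
As background for the two `chunked_attention.py` helpers: per-chunk attention results can be merged exactly as long as each chunk returns its partial output together with the row-wise log-sum-exp (LSE). Below is a minimal pure-PyTorch sketch of that merge for two chunks. The shapes and the two-chunk signature are assumptions for illustration; the real `merge_attention_outputs()` may differ.

```python
import torch

def merge_two_chunks(o1, lse1, o2, lse2):
    """Merge two partial attention results computed over disjoint KV chunks.

    Assumed shapes (illustrative): o*: [seq_len, num_heads, head_dim],
    lse*: [seq_len, num_heads]. Each o_i = softmax(S_i) @ V_i with
    lse_i = logsumexp(S_i, dim=-1) over that chunk's keys.
    """
    lse = torch.logaddexp(lse1, lse2)          # combined normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # rescale weight for chunk 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # rescale weight for chunk 2
    return w1 * o1 + w2 * o2, lse
```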
---
## Key Questions
1. **What does `select_blocks` become?**
   - Renamed to `estimate()`: it computes the sparse mask
   - For XAttention it corresponds to COMPASS's `xattn_estimate()` function
   - `FullAttentionPolicy.estimate()` returns None, meaning full attention
2. **How should the policy interface be designed?**
   - Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
   - Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
   - Estimate: `estimate(q, k, layer_id)` - computes the sparse mask
3. **How is the FULL policy handled?**
   - FULL also implements `compute_prefill/decode`, using FlashAttention
   - Its `estimate()` returns None, meaning no sparsification
## Proposed New Interface
```python
from abc import ABC, abstractmethod
from typing import Optional

import torch


class AttentionPolicy(ABC):
    """Attention policy for layerwise offload mode.

    All attention computation, Full as well as Sparse, goes through a policy.
    Both the prefill and decode phases are supported.
    """

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate the sparse attention mask.

        Sparse policies (e.g. XAttention) compute which blocks to attend to.
        The full policy returns None, meaning full attention.
        Corresponds to COMPASS's xattn_estimate() function.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, or None
        """
        return None  # default: full attention

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        v: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute prefill attention.

        The whole layer's KV is on the GPU, so attention is computed in one
        pass. A policy may call estimate() first to obtain a sparse mask and
        then apply block sparse attention.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        pass

    def compute_decode(
        self,
        q: torch.Tensor,  # [1, num_heads, head_dim]
        k: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        v: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute decode attention.

        KV is served from the ring buffer and contains the prefill tokens plus
        the tokens decoded so far.

        Args:
            q: Query tensor [1, num_heads, head_dim]
            k: Key tensor [context_len+1, num_kv_heads, head_dim]
            v: Value tensor [context_len+1, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [1, num_heads, head_dim]
        """
        # Default implementation: FlashAttention over the full context.
        # causal=False is correct here: the single query row is the newest
        # token and may attend to every cached position.
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        """Reset policy state between sequences."""
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Keep the old name as an alias
SparsePolicy = AttentionPolicy
```
## Implementation Plan
### Phase 2: Refactor policy.py
```python
# nanovllm/kvcache/sparse/policy.py
from abc import ABC, abstractmethod
from typing import Optional

import torch


class AttentionPolicy(ABC):
    """Base class for attention policies in layerwise offload mode."""

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate sparse attention mask.

        For sparse policies (e.g., XAttention), computes block-level importance.
        For full policy, returns None.
        Corresponds to xattn_estimate() in COMPASS.

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] or None
        """
        return None

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute prefill attention."""
        pass

    def compute_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute decode attention (default: FlashAttention)."""
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Backward compatibility alias
SparsePolicy = AttentionPolicy
```
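
A quick sanity check of the alias and the ABC contract (a sketch, assuming the module layout above):

```python
from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy

assert SparsePolicy is AttentionPolicy  # old name still resolves

# compute_prefill is abstract, so the base class cannot be instantiated:
try:
    AttentionPolicy()
except TypeError as e:
    print(e)  # Can't instantiate abstract class AttentionPolicy ...
```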
### Phase 3: Refactor FullAttentionPolicy
```python
# nanovllm/kvcache/sparse/full_policy.py
import torch

from .policy import AttentionPolicy


class FullAttentionPolicy(AttentionPolicy):
    """Full attention using FlashAttention (no sparsity)."""

    supports_prefill = True
    supports_decode = True

    def estimate(self, q, k, layer_id):
        """Full attention - no sparse mask needed."""
        return None

    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        seq_len = q.shape[0]
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def __repr__(self):
        return "FullAttentionPolicy()"
```
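
A minimal smoke-test sketch for this policy, assuming a CUDA device with flash-attn installed (the GQA head counts below are arbitrary example values, not project constants):

```python
import torch
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy

q = torch.randn(1024, 32, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1024, 8, 128, device="cuda", dtype=torch.float16)  # GQA: 8 KV heads
v = torch.randn_like(k)

policy = FullAttentionPolicy()
out = policy.compute_prefill(q, k, v, layer_id=0, softmax_scale=128 ** -0.5)
assert out.shape == q.shape  # [seq_len, num_heads, head_dim]
```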
### Phase 4: Refactor XAttentionPolicy
```python
# nanovllm/kvcache/sparse/xattn.py
from typing import Optional

import torch

from .policy import AttentionPolicy


class XAttentionPolicy(AttentionPolicy):
    """
    XAttention sparse prefill policy.

    Uses chunked estimation to compute a sparse attention mask,
    then applies block sparse attention.
    """

    supports_prefill = True
    supports_decode = True

    def __init__(
        self,
        stride: int = 8,
        threshold: float = 0.9,
        block_size: int = 128,
        chunk_size: int = 16384,
        use_triton: bool = True,
    ):
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size
        self.chunk_size = chunk_size
        self.use_triton = use_triton

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        XAttention estimation (xattn_estimate).

        Uses chunked GEMM + softmax to estimate block-level importance,
        then selects important blocks based on a threshold.

        Corresponds to COMPASS's xattn_estimate() function:
        1. Pad inputs to chunk_size multiples
        2. Reshape with stride
        3. Compute QK^T in chunks (Triton)
        4. Block-wise softmax + aggregation
        5. Threshold-based selection

        Args:
            q: [seq_len, num_heads, head_dim]
            k: [seq_len, num_kv_heads, head_dim]
            layer_id: transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask,
            or None (fall back to full attention)
        """
        # TODO: implement the real xattn_estimate.
        # For now, return None so full attention is used.
        return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute XAttention sparse prefill.

        Flow:
        1. Call estimate() to get the sparse mask
        2. If the mask is None, use full attention
        3. Otherwise, apply block sparse attention with the mask
        """
        # Step 1: Estimate the sparse mask
        sparse_mask = self.estimate(q, k, layer_id)

        # Step 2: Compute attention
        if sparse_mask is None:
            # Fall back to full attention
            from flash_attn.flash_attn_interface import flash_attn_varlen_func

            seq_len = q.shape[0]
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                softmax_scale=softmax_scale,
                causal=True,
            )
        else:
            # Apply block sparse attention with the mask,
            # e.g. block_sparse_attn_func(q, k, v, sparse_mask, block_size)
            raise NotImplementedError("Block sparse attention not yet implemented")

    def __repr__(self):
        return (f"XAttentionPolicy("
                f"stride={self.stride}, "
                f"threshold={self.threshold}, "
                f"block_size={self.block_size})")
```
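
Until the real `xattn_estimate` is wired in, the five-step flow in the docstring can be illustrated with a simplified pure-PyTorch sketch. This is not the COMPASS/Triton implementation: it mean-pools Q and K per block instead of using strided reshaping, assumes num_kv_heads == num_heads, and exists only to show the threshold-based selection and the `[num_heads, q_blocks, k_blocks]` shape contract.

```python
import torch

def estimate_block_mask_sketch(q, k, block_size=128, threshold=0.9):
    """Toy block-mask estimate. q, k: [seq_len, num_heads, head_dim]."""
    seq_len, num_heads, head_dim = q.shape
    pad = (-seq_len) % block_size
    if pad:  # step 1 (simplified): pad to a block_size multiple
        q = torch.nn.functional.pad(q, (0, 0, 0, 0, 0, pad))
        k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, pad))
    n = q.shape[0] // block_size
    # Steps 2-3 (simplified): mean-pool each block, then block-level QK^T
    qb = q.view(n, block_size, num_heads, head_dim).mean(dim=1)
    kb = k.view(n, block_size, num_heads, head_dim).mean(dim=1)
    scores = torch.einsum("qhd,khd->hqk", qb, kb) / head_dim ** 0.5
    # Step 4: causal masking at block granularity + softmax
    causal = torch.tril(torch.ones(n, n, dtype=torch.bool, device=q.device))
    probs = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1)
    # Step 5: per query block, keep the top blocks covering `threshold` mass
    sorted_p, order = probs.sort(dim=-1, descending=True)
    keep = sorted_p.cumsum(dim=-1) - sorted_p < threshold
    mask = torch.zeros_like(keep).scatter(-1, order, keep)
    return mask & causal  # [num_heads, q_blocks, k_blocks]
```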
### Phase 5: Update model_runner.py
```python
# model_runner.py - allocate_kv_cache()
# Always create a policy, including FULL:
from nanovllm.kvcache.sparse import create_attention_policy

self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")

# run_layerwise_offload_prefill() and run_gpu_only_prefill()
# Old code:
if self.sparse_prefill_policy is not None:
    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
else:
    attn_output = flash_attn_varlen_func(...)

# New code:
attn_output = self.attention_policy.compute_prefill(
    q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```
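
The factory used here lives in `nanovllm/kvcache/sparse/__init__.py` (see "Files to Modify" below). A minimal sketch of what it could look like; the registry keys and the exact signature are assumptions, not the implemented API:

```python
# nanovllm/kvcache/sparse/__init__.py (sketch)
from .policy import AttentionPolicy, SparsePolicy
from .full_policy import FullAttentionPolicy
from .xattn import XAttentionPolicy

# Hypothetical registry; the actual config values may differ.
_POLICIES = {
    "full": FullAttentionPolicy,
    "xattn": XAttentionPolicy,
}

def create_attention_policy(name: str, **kwargs) -> AttentionPolicy:
    """Build the policy named in the config, forwarding policy kwargs."""
    try:
        return _POLICIES[name.lower()](**kwargs)
    except KeyError:
        raise ValueError(
            f"Unknown attention policy {name!r}; expected one of {sorted(_POLICIES)}"
        ) from None
```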
## Method Mapping
| Old Method | New Method | Notes |
|------------|------------|-------|
| `select_blocks()` | `estimate()` | Computes the sparse mask; corresponds to xattn_estimate |
| `sparse_prefill_attention()` | `compute_prefill()` | Prefill attention |
| (none) | `compute_decode()` | Decode attention, with a default implementation |
| `on_prefill_offload()` | (removed) | Offloading is handled in model_runner |
## Files to Modify
| File | Changes |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | New interface: estimate, compute_prefill, compute_decode |
| `nanovllm/kvcache/sparse/full_policy.py` | Implement compute_prefill(); estimate() returns None |
| `nanovllm/kvcache/sparse/xattn.py` | estimate() maps to xattn_estimate; compute_prefill() |
| `nanovllm/kvcache/sparse/__init__.py` | Update the factory function |
| `nanovllm/engine/model_runner.py` | Call attention_policy.compute_prefill() uniformly |
| `nanovllm/config.py` | Optional: rename the config option |
## Decisions Made
1. **Method naming**: `compute_prefill` / `compute_decode` mirror the naming style of the chunked variants
2. **estimate method**: replaces `select_blocks`; returns a sparse mask instead of block IDs
3. **XAttention**: `estimate()` corresponds to COMPASS's `xattn_estimate()`
4. **Full policy**: `estimate()` returns None, meaning full attention is used
5. **Default decode implementation**: the base class provides a FlashAttention default
## Errors Encountered
- (none)
## Status
**Currently in Phase 1** - Analysis and interface design are complete; awaiting user confirmation before moving to Phase 2.