📝 docs: add layer offload planning notes and task plan
Add planning documents for layer-wise offload implementation: - notes.md: Implementation notes and findings - task_plan.md: Detailed task breakdown and progress tracking Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
130
notes.md
Normal file
130
notes.md
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
# Notes: SparsePolicy Refactoring Research
|
||||||
|
|
||||||
|
## Sources
|
||||||
|
|
||||||
|
### Source 1: tzj/minference branch - policy.py
|
||||||
|
- 路径: `nanovllm/kvcache/sparse/policy.py`
|
||||||
|
- 关键设计:
|
||||||
|
- `PolicyContext` 数据类包含 query_chunk_idx, num_query_chunks, layer_id, query, is_prefill 等
|
||||||
|
- `select_blocks()` 需要 offload_engine 参数
|
||||||
|
- `compute_chunked_prefill()` 和 `compute_chunked_decode()` 是完整的 attention 流程
|
||||||
|
- `on_prefill_offload()` / `on_decode_offload()` hooks 用于收集元数据
|
||||||
|
|
||||||
|
### Source 2: tzj/minference branch - full_policy.py
|
||||||
|
- 路径: `nanovllm/kvcache/sparse/full_policy.py`
|
||||||
|
- 关键实现:
|
||||||
|
- `compute_chunked_prefill()` 内部使用 ring buffer pipeline 加载 blocks
|
||||||
|
- 使用 `flash_attn_with_lse` 和 `merge_attention_outputs` 合并多个 chunk 的 attention
|
||||||
|
- `compute_chunked_decode()` 处理 prefilled blocks + decode buffer
|
||||||
|
|
||||||
|
### Source 3: tzj/layer-offload branch - model_runner.py
|
||||||
|
- 路径: `nanovllm/engine/model_runner.py`
|
||||||
|
- 关键设计:
|
||||||
|
- `run_layerwise_offload_prefill()` 逐层处理,每层计算完整 attention
|
||||||
|
- `sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)` 简单接口
|
||||||
|
- FULL policy 通过 `if sparse_prefill_policy is None` 走 else 分支
|
||||||
|
|
||||||
|
### Source 4: tzj/layer-offload branch - xattn.py
|
||||||
|
- 路径: `nanovllm/kvcache/sparse/xattn.py`
|
||||||
|
- 关键实现:
|
||||||
|
- `sparse_prefill_attention()` 直接使用 FlashAttention(因为 chunked prefill 架构限制)
|
||||||
|
- 保留 Triton kernels 供未来 GPU-only 模式
|
||||||
|
|
||||||
|
## Synthesized Findings
|
||||||
|
|
||||||
|
### 架构差异总结
|
||||||
|
|
||||||
|
| 方面 | Chunked Offload | Layerwise Offload |
|
||||||
|
|------|-----------------|-------------------|
|
||||||
|
| **Prefill 流程** | chunk-by-chunk,跨层 | layer-by-layer,完整序列 |
|
||||||
|
| **KV 存储** | 每 chunk 立即 offload | 每层计算后 offload |
|
||||||
|
| **Attention 计算** | 分多次计算+合并 | 一次完整计算 |
|
||||||
|
| **Block 加载** | 需要从 CPU 加载历史 | 不需要,已在 GPU |
|
||||||
|
| **Policy 责任** | 完整 attention 流程 | 仅 attention kernel 选择 |
|
||||||
|
|
||||||
|
### Layerwise Offload 的简化点
|
||||||
|
|
||||||
|
1. **不需要 block selection**: 整层 KV 都在 GPU,无需选择
|
||||||
|
2. **不需要 offload_engine 参数**: Policy 不负责加载 KV
|
||||||
|
3. **不需要 merge_attention_outputs**: 一次计算完整 attention
|
||||||
|
4. **不需要 offload hooks**: offload 在 model_runner 统一处理
|
||||||
|
|
||||||
|
### 设计建议
|
||||||
|
|
||||||
|
1. **保持接口简单**: 只需要 `compute_prefill_attention()` 和 `compute_decode_attention()`
|
||||||
|
2. **FULL 也实现方法**: 不再通过 `is None` 判断,所有 policy 统一调用
|
||||||
|
3. **移除不必要的参数**: 不需要 offload_engine, kvcache_manager, seq 等
|
||||||
|
4. **统一命名**: 使用 `compute_*_attention` 而不是 `sparse_prefill_attention`
|
||||||
|
|
||||||
|
## Code Examples
|
||||||
|
|
||||||
|
### 当前调用方式 (model_runner.py:876-891)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Sparse or Full attention
|
||||||
|
if self.sparse_prefill_policy is not None:
|
||||||
|
# MInference or other sparse prefill policy
|
||||||
|
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
||||||
|
q, k, v, layer_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Full attention using FlashAttention
|
||||||
|
attn_output = flash_attn_varlen_func(
|
||||||
|
q, k, v, ...
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 建议的新调用方式
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 所有 policy 统一调用
|
||||||
|
attn_output = self.attention_policy.compute_prefill_attention(
|
||||||
|
q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Questions Resolved
|
||||||
|
|
||||||
|
- Q: 是否需要 PolicyContext?
|
||||||
|
- A: 可以简化,因为 layerwise 模式下不需要 chunk 信息
|
||||||
|
|
||||||
|
- Q: decode 阶段如何处理?
|
||||||
|
- A: **Decode 不需要 policy**!当前 `run_layerwise_offload_decode()` 使用标准 `layer(positions, hidden_states, residual)` 调用,走 Attention.forward() 路径
|
||||||
|
|
||||||
|
- Q: 为什么 decode 不需要 sparse?
|
||||||
|
- A: 因为 decode 每次只有 1 个 token,没有稀疏化的意义。KV 从 ring buffer 加载后直接用 flash_attn_with_kvcache
|
||||||
|
|
||||||
|
## Key Insight
|
||||||
|
|
||||||
|
**Layerwise Offload 的 Policy 设计应该只关注 Prefill**:
|
||||||
|
|
||||||
|
```
|
||||||
|
Prefill: 需要 Policy
|
||||||
|
- 整个序列一次计算 attention
|
||||||
|
- 可以使用 sparse attention 方法(如 MInference 的 vertical+slash pattern)
|
||||||
|
- Policy 接收 q, k, v, layer_id, softmax_scale
|
||||||
|
|
||||||
|
Decode: 不需要 Policy
|
||||||
|
- 每次只有 1 个 token query
|
||||||
|
- KV 从 ring buffer 加载
|
||||||
|
- 使用标准 flash_attn_with_kvcache
|
||||||
|
```
|
||||||
|
|
||||||
|
## Interface Comparison Summary
|
||||||
|
|
||||||
|
| 方面 | tzj/minference | tzj/layer-offload (新设计) |
|
||||||
|
|------|----------------|---------------------------|
|
||||||
|
| 类名 | SparsePolicy | AttentionPolicy |
|
||||||
|
| Prefill 方法 | compute_chunked_prefill() | compute_attention() |
|
||||||
|
| Decode 方法 | compute_chunked_decode() | 不需要(用标准路径) |
|
||||||
|
| 需要 offload_engine | 是 | 否 |
|
||||||
|
| 需要 kvcache_manager | 是 | 否 |
|
||||||
|
| 需要 seq | 是 | 否 |
|
||||||
|
| 支持 FULL | 是 | 是 |
|
||||||
|
|
||||||
|
## Migration Path
|
||||||
|
|
||||||
|
1. 保留 `SparsePolicy` 作为 `AttentionPolicy` 的别名
|
||||||
|
2. 保留 `PolicyContext` 供未来扩展
|
||||||
|
3. 保留 `select_blocks()` 方法签名(虽然不使用)
|
||||||
|
4. 移除 `requires_block_selection` 属性(不需要)
|
||||||
549
task_plan.md
Normal file
549
task_plan.md
Normal file
@@ -0,0 +1,549 @@
|
|||||||
|
# Task Plan: Refactor SparsePolicy for Layerwise Offload
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
重构 SparsePolicy 接口,参考 tzj/minference 分支的设计模式,使所有 attention 都可以抽象成 policy,并按统一规范编写。适配当前 layerwise offload 架构特点(整层 KV 在 GPU 上)。
|
||||||
|
|
||||||
|
## Background
|
||||||
|
|
||||||
|
### 两种 Offload 架构对比
|
||||||
|
|
||||||
|
| 特性 | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|
||||||
|
|------|----------------------------------|---------------------------------------|
|
||||||
|
| 处理粒度 | 每次一个 chunk (block_size tokens) | 每次一整层 (所有 tokens) |
|
||||||
|
| KV 位置 | 历史 chunks 在 CPU,需要加载 | 整层 KV 都在 GPU |
|
||||||
|
| Policy 入口 | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
|
||||||
|
| 需要 offload_engine | 是(加载 blocks) | 否(KV 已在 GPU) |
|
||||||
|
| Mask 计算 | `select_blocks()` 返回 block IDs | `estimate()` 返回 sparse mask |
|
||||||
|
|
||||||
|
### tzj/minference 的 Policy 接口
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SparsePolicy(ABC):
|
||||||
|
supports_prefill: bool
|
||||||
|
supports_decode: bool
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor
|
||||||
|
```
|
||||||
|
|
||||||
|
### 当前 branch 的 Policy 接口(重构前)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class SparsePolicy(ABC):
|
||||||
|
supports_prefill: bool
|
||||||
|
supports_decode: bool
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def select_blocks(self, available_blocks, ctx) -> List[int]
|
||||||
|
|
||||||
|
def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor
|
||||||
|
```
|
||||||
|
|
||||||
|
## Phases
|
||||||
|
|
||||||
|
- [x] Phase 1: 分析差异并设计新接口
|
||||||
|
- [x] **Phase 0: 创建 nanovllm.ops 模块** ✅ 测试通过
|
||||||
|
- [ ] Phase 2: 重构 AttentionPolicy 基类
|
||||||
|
- [ ] Phase 3: 重构 FullAttentionPolicy
|
||||||
|
- [ ] Phase 4: 重构 XAttentionPolicy (含 estimate 方法)
|
||||||
|
- [ ] Phase 5: 更新 model_runner 调用方式
|
||||||
|
- [ ] Phase 6: 测试验证
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 0: 创建 nanovllm.ops 模块
|
||||||
|
|
||||||
|
### 目标
|
||||||
|
从 tzj/minference 分支提取 ops 模块,为 XAttention estimate 提供底层算子支持。
|
||||||
|
|
||||||
|
### 步骤
|
||||||
|
|
||||||
|
1. **创建目录结构**
|
||||||
|
```
|
||||||
|
nanovllm/ops/
|
||||||
|
├── __init__.py
|
||||||
|
├── xattn.py # xattn_estimate, xattn_estimate_chunked, Triton kernels
|
||||||
|
└── chunked_attention.py # flash_attn_with_lse, merge_attention_outputs (备用)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **从 tzj/minference 提取文件**
|
||||||
|
```bash
|
||||||
|
git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
|
||||||
|
git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
|
||||||
|
git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Cherry-pick 测试文件**
|
||||||
|
```bash
|
||||||
|
git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **运行测试验证**
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_xattn_estimate_chunked.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### nanovllm/ops 模块内容
|
||||||
|
|
||||||
|
| 文件 | 核心函数 | 用途 |
|
||||||
|
|------|----------|------|
|
||||||
|
| `xattn.py` | `xattn_estimate()` | 标准 XAttention estimation |
|
||||||
|
| `xattn.py` | `xattn_estimate_chunked()` | Chunked prefill 版本 |
|
||||||
|
| `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
|
||||||
|
| `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
|
||||||
|
| `xattn.py` | `find_blocks_chunked()` | Block selection based on threshold |
|
||||||
|
| `chunked_attention.py` | `flash_attn_with_lse()` | Flash attention with LSE output |
|
||||||
|
| `chunked_attention.py` | `merge_attention_outputs()` | Merge multiple attention chunks |
|
||||||
|
|
||||||
|
### 与 Policy 的关系
|
||||||
|
|
||||||
|
```
|
||||||
|
XAttentionPolicy.estimate()
|
||||||
|
└── 调用 nanovllm.ops.xattn.xattn_estimate()
|
||||||
|
├── flat_group_gemm_fuse_reshape() (Triton)
|
||||||
|
├── softmax_fuse_block_sum() (Triton)
|
||||||
|
└── find_blocks_chunked()
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Questions
|
||||||
|
|
||||||
|
1. **`select_blocks` 改为什么?**
|
||||||
|
- 改名为 `estimate()`:用于计算 sparse mask
|
||||||
|
- 对于 XAttention,对应 COMPASS 的 `xattn_estimate()` 函数
|
||||||
|
- FullAttentionPolicy 的 `estimate()` 返回 None(表示 full attention)
|
||||||
|
|
||||||
|
2. **Policy 接口应该如何设计?**
|
||||||
|
- Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
|
||||||
|
- Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
|
||||||
|
- Estimate: `estimate(q, k, layer_id)` - 计算 sparse mask
|
||||||
|
|
||||||
|
3. **FULL policy 如何处理?**
|
||||||
|
- FULL 也实现 `compute_prefill/decode`,使用 FlashAttention
|
||||||
|
- `estimate()` 返回 None(表示不进行稀疏化)
|
||||||
|
|
||||||
|
## Proposed New Interface
|
||||||
|
|
||||||
|
```python
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPolicy(ABC):
|
||||||
|
"""Layerwise Offload 模式下的 Attention Policy
|
||||||
|
|
||||||
|
所有 attention 计算都通过 policy 进行,包括 Full 和 Sparse。
|
||||||
|
支持 prefill 和 decode 两个阶段。
|
||||||
|
"""
|
||||||
|
|
||||||
|
supports_prefill: bool = True
|
||||||
|
supports_decode: bool = True
|
||||||
|
|
||||||
|
def estimate(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor, # [seq_len, num_heads, head_dim]
|
||||||
|
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
估算 sparse attention mask。
|
||||||
|
|
||||||
|
对于 sparse policy(如 XAttention),计算哪些 blocks 需要 attend。
|
||||||
|
对于 full policy,返回 None 表示使用完整 attention。
|
||||||
|
|
||||||
|
对应 COMPASS 的 xattn_estimate() 函数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, 或 None
|
||||||
|
"""
|
||||||
|
return None # 默认为 full attention
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor, # [seq_len, num_heads, head_dim]
|
||||||
|
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
计算 prefill attention。
|
||||||
|
|
||||||
|
整层 KV 都在 GPU 上,一次计算完整 attention。
|
||||||
|
可以先调用 estimate() 获取 sparse mask,然后应用 block sparse attention。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor, # [1, num_heads, head_dim]
|
||||||
|
k: torch.Tensor, # [context_len+1, num_kv_heads, head_dim]
|
||||||
|
v: torch.Tensor, # [context_len+1, num_kv_heads, head_dim]
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
计算 decode attention。
|
||||||
|
|
||||||
|
KV 从 ring buffer 提供,包含 prefill tokens + 已 decode 的 tokens。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [1, num_heads, head_dim]
|
||||||
|
k: Key tensor [context_len+1, num_kv_heads, head_dim]
|
||||||
|
v: Value tensor [context_len+1, num_kv_heads, head_dim]
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
softmax_scale: Softmax scaling factor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [1, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
# 默认实现:使用 FlashAttention
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
context_len = k.shape[0]
|
||||||
|
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
|
||||||
|
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
max_seqlen_k=context_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset policy state between sequences."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}()"
|
||||||
|
|
||||||
|
|
||||||
|
# 保留旧名称作为别名
|
||||||
|
SparsePolicy = AttentionPolicy
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation Plan
|
||||||
|
|
||||||
|
### Phase 2: 重构 policy.py
|
||||||
|
|
||||||
|
```python
|
||||||
|
# nanovllm/kvcache/sparse/policy.py
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPolicy(ABC):
|
||||||
|
"""Base class for attention policies in layerwise offload mode."""
|
||||||
|
|
||||||
|
supports_prefill: bool = True
|
||||||
|
supports_decode: bool = True
|
||||||
|
|
||||||
|
def estimate(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Estimate sparse attention mask.
|
||||||
|
|
||||||
|
For sparse policies (e.g., XAttention), computes block-level importance.
|
||||||
|
For full policy, returns None.
|
||||||
|
|
||||||
|
Corresponds to xattn_estimate() in COMPASS.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sparse_mask: [num_heads, q_blocks, k_blocks] or None
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute prefill attention."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute decode attention (default: FlashAttention)."""
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
context_len = k.shape[0]
|
||||||
|
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
|
||||||
|
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=1,
|
||||||
|
max_seqlen_k=context_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}()"
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility alias
|
||||||
|
SparsePolicy = AttentionPolicy
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 3: 重构 FullAttentionPolicy
|
||||||
|
|
||||||
|
```python
|
||||||
|
# nanovllm/kvcache/sparse/full_policy.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from .policy import AttentionPolicy
|
||||||
|
|
||||||
|
|
||||||
|
class FullAttentionPolicy(AttentionPolicy):
|
||||||
|
"""Full attention using FlashAttention (no sparsity)."""
|
||||||
|
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = True
|
||||||
|
|
||||||
|
def estimate(self, q, k, layer_id):
|
||||||
|
"""Full attention - no sparse mask needed."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
seq_len = q.shape[0]
|
||||||
|
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=seq_len,
|
||||||
|
max_seqlen_k=seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "FullAttentionPolicy()"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 4: 重构 XAttentionPolicy
|
||||||
|
|
||||||
|
```python
|
||||||
|
# nanovllm/kvcache/sparse/xattn.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
from .policy import AttentionPolicy
|
||||||
|
|
||||||
|
|
||||||
|
class XAttentionPolicy(AttentionPolicy):
|
||||||
|
"""
|
||||||
|
XAttention sparse prefill policy.
|
||||||
|
|
||||||
|
Uses chunked estimation to compute sparse attention mask,
|
||||||
|
then applies block sparse attention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stride: int = 8,
|
||||||
|
threshold: float = 0.9,
|
||||||
|
block_size: int = 128,
|
||||||
|
chunk_size: int = 16384,
|
||||||
|
use_triton: bool = True,
|
||||||
|
):
|
||||||
|
self.stride = stride
|
||||||
|
self.threshold = threshold
|
||||||
|
self.block_size = block_size
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.use_triton = use_triton
|
||||||
|
|
||||||
|
def estimate(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
XAttention estimation (xattn_estimate).
|
||||||
|
|
||||||
|
Uses chunked GEMM + softmax to estimate block-level importance,
|
||||||
|
then selects important blocks based on threshold.
|
||||||
|
|
||||||
|
对应 COMPASS 的 xattn_estimate() 函数:
|
||||||
|
1. Pad inputs to chunk_size multiples
|
||||||
|
2. Reshape with stride
|
||||||
|
3. Compute QK^T in chunks (Triton)
|
||||||
|
4. Block-wise softmax + aggregation
|
||||||
|
5. Threshold-based selection
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: [seq_len, num_heads, head_dim]
|
||||||
|
k: [seq_len, num_kv_heads, head_dim]
|
||||||
|
layer_id: transformer layer index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask
|
||||||
|
or None (fallback to full attention)
|
||||||
|
"""
|
||||||
|
# TODO: 实现真正的 xattn_estimate
|
||||||
|
# 当前返回 None 使用 full attention
|
||||||
|
return None
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute XAttention sparse prefill.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Call estimate() to get sparse mask
|
||||||
|
2. If mask is None, use full attention
|
||||||
|
3. Otherwise, apply block sparse attention with mask
|
||||||
|
"""
|
||||||
|
# Step 1: Estimate sparse mask
|
||||||
|
sparse_mask = self.estimate(q, k, layer_id)
|
||||||
|
|
||||||
|
# Step 2: Compute attention
|
||||||
|
if sparse_mask is None:
|
||||||
|
# Fallback to full attention
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
|
||||||
|
seq_len = q.shape[0]
|
||||||
|
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens,
|
||||||
|
cu_seqlens_k=cu_seqlens,
|
||||||
|
max_seqlen_q=seq_len,
|
||||||
|
max_seqlen_k=seq_len,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Apply block sparse attention with mask
|
||||||
|
# 使用 block_sparse_attn_func(q, k, v, sparse_mask, block_size)
|
||||||
|
raise NotImplementedError("Block sparse attention not yet implemented")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (f"XAttentionPolicy("
|
||||||
|
f"stride={self.stride}, "
|
||||||
|
f"threshold={self.threshold}, "
|
||||||
|
f"block_size={self.block_size})")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 5: 更新 model_runner.py
|
||||||
|
|
||||||
|
```python
|
||||||
|
# model_runner.py - allocate_kv_cache()
|
||||||
|
|
||||||
|
# 改为总是创建 policy(包括 FULL)
|
||||||
|
from nanovllm.kvcache.sparse import create_attention_policy
|
||||||
|
self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
|
||||||
|
logger.info(f"Attention policy: {self.attention_policy}")
|
||||||
|
|
||||||
|
# run_layerwise_offload_prefill() 和 run_gpu_only_prefill()
|
||||||
|
|
||||||
|
# 旧代码:
|
||||||
|
if self.sparse_prefill_policy is not None:
|
||||||
|
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
|
||||||
|
else:
|
||||||
|
attn_output = flash_attn_varlen_func(...)
|
||||||
|
|
||||||
|
# 新代码:
|
||||||
|
attn_output = self.attention_policy.compute_prefill(
|
||||||
|
q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Method Mapping
|
||||||
|
|
||||||
|
| 旧方法 | 新方法 | 说明 |
|
||||||
|
|--------|--------|------|
|
||||||
|
| `select_blocks()` | `estimate()` | 计算 sparse mask(对应 xattn_estimate) |
|
||||||
|
| `sparse_prefill_attention()` | `compute_prefill()` | Prefill attention |
|
||||||
|
| (无) | `compute_decode()` | Decode attention(默认实现) |
|
||||||
|
| `on_prefill_offload()` | (移除) | Offload 在 model_runner 处理 |
|
||||||
|
|
||||||
|
## Files to Modify
|
||||||
|
|
||||||
|
| File | Changes |
|
||||||
|
|------|---------|
|
||||||
|
| `nanovllm/kvcache/sparse/policy.py` | 新接口:estimate, compute_prefill, compute_decode |
|
||||||
|
| `nanovllm/kvcache/sparse/full_policy.py` | 实现 compute_prefill(), estimate() 返回 None |
|
||||||
|
| `nanovllm/kvcache/sparse/xattn.py` | estimate() 对应 xattn_estimate, compute_prefill() |
|
||||||
|
| `nanovllm/kvcache/sparse/__init__.py` | 更新工厂函数 |
|
||||||
|
| `nanovllm/engine/model_runner.py` | 统一调用 attention_policy.compute_prefill() |
|
||||||
|
| `nanovllm/config.py` | 可选:重命名配置项 |
|
||||||
|
|
||||||
|
## Decisions Made
|
||||||
|
|
||||||
|
1. **方法命名**: `compute_prefill` / `compute_decode` 对应 chunked 版本的命名风格
|
||||||
|
2. **estimate 方法**: 替代 `select_blocks`,返回 sparse mask 而不是 block IDs
|
||||||
|
3. **XAttention**: `estimate()` 对应 COMPASS 的 `xattn_estimate()`
|
||||||
|
4. **Full Policy**: `estimate()` 返回 None 表示使用完整 attention
|
||||||
|
5. **Decode 默认实现**: 基类提供默认的 FlashAttention 实现
|
||||||
|
|
||||||
|
## Errors Encountered
|
||||||
|
- (无)
|
||||||
|
|
||||||
|
## Status
|
||||||
|
**Currently in Phase 1** - 完成分析和接口设计,等待用户确认后进入 Phase 2
|
||||||
Reference in New Issue
Block a user