📝 docs: add layer offload planning notes and task plan
Add planning documents for the layer-wise offload implementation:

- notes.md: implementation notes and findings
- task_plan.md: detailed task breakdown and progress tracking

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
notes.md (new file, 130 lines)
# Notes: SparsePolicy Refactoring Research

## Sources

### Source 1: tzj/minference branch - policy.py

- Path: `nanovllm/kvcache/sparse/policy.py`
- Key design:
  - The `PolicyContext` dataclass carries query_chunk_idx, num_query_chunks, layer_id, query, is_prefill, etc.
  - `select_blocks()` requires an offload_engine argument
  - `compute_chunked_prefill()` and `compute_chunked_decode()` implement the complete attention flow
  - The `on_prefill_offload()` / `on_decode_offload()` hooks collect metadata
### Source 2: tzj/minference branch - full_policy.py

- Path: `nanovllm/kvcache/sparse/full_policy.py`
- Key implementation:
  - `compute_chunked_prefill()` uses a ring-buffer pipeline internally to load blocks
  - Uses `flash_attn_with_lse` and `merge_attention_outputs` to merge attention results across chunks
  - `compute_chunked_decode()` handles prefilled blocks plus the decode buffer
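The LSE-based merge that `merge_attention_outputs` performs can be illustrated with a small framework-agnostic sketch (NumPy stand-in; the function name and shapes here are illustrative, not the branch's actual signature): attention over two disjoint KV chunks is recombined exactly using each chunk's log-sum-exp.

```python
import numpy as np

def merge_attention_outputs(o1, lse1, o2, lse2):
    """Merge two partial attention outputs using their log-sum-exp stats.

    o1, o2: [q, head_dim] partial outputs over disjoint KV chunks.
    lse1, lse2: [q] log-sum-exp of the scores for each chunk.
    """
    lse = np.logaddexp(lse1, lse2)      # combined normalizer
    w1 = np.exp(lse1 - lse)[:, None]    # weight of chunk 1's output
    w2 = np.exp(lse2 - lse)[:, None]    # weight of chunk 2's output
    return w1 * o1 + w2 * o2, lse
```

The merge is exact, not approximate: splitting KV into chunks and recombining with these weights reproduces full-softmax attention, which is what lets the chunked pipeline stream blocks from CPU.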
### Source 3: tzj/layer-offload branch - model_runner.py

- Path: `nanovllm/engine/model_runner.py`
- Key design:
  - `run_layerwise_offload_prefill()` processes layer by layer; each layer computes full attention
  - `sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)` is a simple interface
  - The FULL policy takes the else branch via `if sparse_prefill_policy is None`
### Source 4: tzj/layer-offload branch - xattn.py

- Path: `nanovllm/kvcache/sparse/xattn.py`
- Key implementation:
  - `sparse_prefill_attention()` uses FlashAttention directly (due to chunked-prefill architectural constraints)
  - Triton kernels are kept for a future GPU-only mode
## Synthesized Findings

### Architecture Differences

| Aspect | Chunked Offload | Layerwise Offload |
|--------|-----------------|-------------------|
| **Prefill flow** | chunk by chunk, across layers | layer by layer, full sequence |
| **KV storage** | each chunk offloaded immediately | offloaded after each layer's compute |
| **Attention compute** | multiple partial computes + merge | one full compute |
| **Block loading** | history loaded from CPU | not needed, already on GPU |
| **Policy responsibility** | full attention flow | attention kernel selection only |
### What Layerwise Offload Simplifies

1. **No block selection**: the whole layer's KV is on GPU, so there is nothing to select
2. **No offload_engine argument**: the policy is not responsible for loading KV
3. **No merge_attention_outputs**: attention is computed once, in full
4. **No offload hooks**: offload is handled centrally in model_runner
### Design Recommendations

1. **Keep the interface simple**: only `compute_prefill_attention()` and `compute_decode_attention()` are needed
2. **FULL implements the methods too**: drop the `is None` check; all policies are called uniformly
3. **Remove unnecessary parameters**: no offload_engine, kvcache_manager, seq, etc.
4. **Unify naming**: use `compute_*_attention` instead of `sparse_prefill_attention`
## Code Examples

### Current call site (model_runner.py:876-891)

```python
# Sparse or full attention
if self.sparse_prefill_policy is not None:
    # MInference or another sparse prefill policy
    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
        q, k, v, layer_id
    )
else:
    # Full attention using FlashAttention
    attn_output = flash_attn_varlen_func(
        q, k, v, ...
    )
```
### Proposed new call site

```python
# All policies are called uniformly
attn_output = self.attention_policy.compute_prefill_attention(
    q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```
## Questions Resolved

- Q: Is PolicyContext still needed?
  - A: It can be simplified, since layerwise mode does not need chunk information

- Q: How is the decode phase handled?
  - A: **Decode needs no policy!** The current `run_layerwise_offload_decode()` uses the standard `layer(positions, hidden_states, residual)` call, going through the Attention.forward() path

- Q: Why does decode not need sparsity?
  - A: Decode processes only 1 token at a time, so sparsification buys nothing. After KV is loaded from the ring buffer, it is used directly with flash_attn_with_kvcache
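The decode path described above reduces to plain softmax attention with a single query. A minimal single-head NumPy sketch (illustrative shapes and names, not the flash_attn API):

```python
import numpy as np

def decode_attention(q, k_cache, v_cache, softmax_scale):
    """One decode step, single head.

    q: [head_dim] query for the current token.
    k_cache, v_cache: [context_len, head_dim] from the ring buffer.
    No causal mask is needed: the one query may see every cached position.
    """
    s = (k_cache @ q) * softmax_scale   # [context_len] scores
    w = np.exp(s - s.max())
    w /= w.sum()                        # softmax weights
    return w @ v_cache                  # [head_dim]
```

With a single query there is no sparsity structure to exploit, which is exactly why the decode path skips the policy.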
## Key Insight

**The policy design for layerwise offload should focus on prefill only**:

```
Prefill: needs a policy
  - Attention is computed once over the whole sequence
  - Sparse attention methods can be used (e.g., MInference's vertical+slash pattern)
  - The policy receives q, k, v, layer_id, softmax_scale

Decode: needs no policy
  - Only 1 query token per step
  - KV is loaded from the ring buffer
  - Uses the standard flash_attn_with_kvcache
```
## Interface Comparison Summary

| Aspect | tzj/minference | tzj/layer-offload (new design) |
|--------|----------------|--------------------------------|
| Class name | SparsePolicy | AttentionPolicy |
| Prefill method | compute_chunked_prefill() | compute_attention() |
| Decode method | compute_chunked_decode() | not needed (standard path) |
| Needs offload_engine | yes | no |
| Needs kvcache_manager | yes | no |
| Needs seq | yes | no |
| Supports FULL | yes | yes |
## Migration Path

1. Keep `SparsePolicy` as an alias of `AttentionPolicy`
2. Keep `PolicyContext` for future extension
3. Keep the `select_blocks()` method signature (even though it is unused)
4. Remove the `requires_block_selection` property (not needed)
task_plan.md (new file, 549 lines)
# Task Plan: Refactor SparsePolicy for Layerwise Offload

## Goal

Refactor the SparsePolicy interface, following the design pattern of the tzj/minference branch, so that every attention variant can be abstracted as a policy and written to a single convention. Adapt it to the current layerwise offload architecture (the whole layer's KV lives on GPU).
## Background

### Comparing the Two Offload Architectures

| Property | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|----------|----------------------------------|---------------------------------------|
| Processing granularity | one chunk at a time (block_size tokens) | one whole layer at a time (all tokens) |
| KV location | historical chunks on CPU, must be loaded | whole layer's KV on GPU |
| Policy entry points | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
| Needs offload_engine | yes (to load blocks) | no (KV already on GPU) |
| Mask computation | `select_blocks()` returns block IDs | `estimate()` returns a sparse mask |
### The tzj/minference Policy Interface

```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]

    @abstractmethod
    def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor

    @abstractmethod
    def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor
```
### The Current Branch's Policy Interface (pre-refactor)

```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, ctx) -> List[int]

    def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor
```
## Phases

- [x] Phase 1: Analyze the differences and design the new interface
- [x] **Phase 0: Create the nanovllm.ops module** ✅ tests pass
- [ ] Phase 2: Refactor the AttentionPolicy base class
- [ ] Phase 3: Refactor FullAttentionPolicy
- [ ] Phase 4: Refactor XAttentionPolicy (including the estimate method)
- [ ] Phase 5: Update the model_runner call sites
- [ ] Phase 6: Test and validate

---
## Phase 0: Create the nanovllm.ops Module

### Goal

Extract the ops module from the tzj/minference branch to provide the low-level operators that XAttention estimation needs.

### Steps
1. **Create the directory structure**
   ```
   nanovllm/ops/
   ├── __init__.py
   ├── xattn.py              # xattn_estimate, xattn_estimate_chunked, Triton kernels
   └── chunked_attention.py  # flash_attn_with_lse, merge_attention_outputs (reserve)
   ```

2. **Extract the files from tzj/minference**
   ```bash
   git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
   git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
   git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
   ```

3. **Cherry-pick the test file**
   ```bash
   git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
   ```

4. **Run the tests to verify**
   ```bash
   CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
       python tests/test_xattn_estimate_chunked.py
   ```
### Contents of nanovllm/ops

| File | Core function | Purpose |
|------|---------------|---------|
| `xattn.py` | `xattn_estimate()` | standard XAttention estimation |
| `xattn.py` | `xattn_estimate_chunked()` | chunked-prefill variant |
| `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
| `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
| `xattn.py` | `find_blocks_chunked()` | block selection based on a threshold |
| `chunked_attention.py` | `flash_attn_with_lse()` | flash attention with LSE output |
| `chunked_attention.py` | `merge_attention_outputs()` | merge multiple attention chunks |
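`find_blocks_chunked()` is listed as threshold-based block selection. The general pattern behind such selection (a hypothetical simplification, not the actual kernel) is to keep, per query block, the smallest set of key blocks covering a fraction `threshold` of the importance mass:

```python
import numpy as np

def find_blocks_by_threshold(block_scores, threshold=0.9):
    """Per row, keep top-scoring K blocks until `threshold` of the mass is covered.

    block_scores: [num_q_blocks, num_k_blocks], non-negative importance.
    Returns a boolean mask of the same shape.
    """
    order = np.argsort(-block_scores, axis=-1)                 # descending
    sorted_scores = np.take_along_axis(block_scores, order, axis=-1)
    cum = np.cumsum(sorted_scores, axis=-1)
    cum /= sorted_scores.sum(-1, keepdims=True)
    keep_sorted = np.zeros(block_scores.shape, dtype=bool)
    first = (cum >= threshold).argmax(-1)                      # first index reaching threshold
    for row, f in enumerate(first):
        keep_sorted[row, : f + 1] = True
    mask = np.zeros_like(keep_sorted)
    np.put_along_axis(mask, order, keep_sorted, axis=-1)       # undo the sort
    return mask
```

A higher `threshold` keeps more blocks, trading sparsity for accuracy; `threshold=1.0` degenerates to full attention.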
### Relationship to the Policy

```
XAttentionPolicy.estimate()
  └── calls nanovllm.ops.xattn.xattn_estimate()
        ├── flat_group_gemm_fuse_reshape()  (Triton)
        ├── softmax_fuse_block_sum()        (Triton)
        └── find_blocks_chunked()
```

---
## Key Questions

1. **What replaces `select_blocks`?**
   - Rename it to `estimate()`: it computes a sparse mask
   - For XAttention it corresponds to COMPASS's `xattn_estimate()` function
   - FullAttentionPolicy's `estimate()` returns None (meaning full attention)

2. **How should the policy interface look?**
   - Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
   - Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
   - Estimate: `estimate(q, k, layer_id)` computes the sparse mask

3. **How is the FULL policy handled?**
   - FULL also implements `compute_prefill/decode`, using FlashAttention
   - Its `estimate()` returns None (no sparsification)
## Proposed New Interface

```python
from abc import ABC, abstractmethod
from typing import Optional

import torch


class AttentionPolicy(ABC):
    """Attention policy for layerwise offload mode.

    All attention computation, full and sparse alike, goes through a policy.
    Both the prefill and decode phases are supported.
    """

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate the sparse attention mask.

        Sparse policies (e.g., XAttention) compute which blocks to attend to.
        The full policy returns None, meaning full attention.

        Corresponds to COMPASS's xattn_estimate() function.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, or None
        """
        return None  # default: full attention

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        v: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute prefill attention.

        The whole layer's KV is on GPU, so attention is computed once, in full.
        A policy may first call estimate() to get a sparse mask and then apply
        block sparse attention.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        pass

    def compute_decode(
        self,
        q: torch.Tensor,  # [1, num_heads, head_dim]
        k: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        v: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute decode attention.

        KV comes from the ring buffer and contains the prefill tokens plus the
        tokens decoded so far.

        Args:
            q: Query tensor [1, num_heads, head_dim]
            k: Key tensor [context_len+1, num_kv_heads, head_dim]
            v: Value tensor [context_len+1, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [1, num_heads, head_dim]
        """
        # Default implementation: FlashAttention
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        """Reset policy state between sequences."""
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Keep the old name as an alias
SparsePolicy = AttentionPolicy
```
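To sanity-check the contract, here is a toy subclass with a NumPy reference implementation of causal prefill (assumes num_kv_heads == num_heads; purely illustrative, not the real FullAttentionPolicy):

```python
import numpy as np

class NaivePolicy:
    """Toy policy matching the interface above, on NumPy arrays."""

    def estimate(self, q, k, layer_id):
        return None  # full attention

    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        # q, k, v: [seq_len, num_heads, head_dim] (num_kv_heads == num_heads)
        s = np.einsum("qhd,khd->hqk", q, k) * softmax_scale
        s = s + np.triu(np.full(s.shape[-2:], -np.inf), k=1)  # causal mask
        p = np.exp(s - s.max(-1, keepdims=True))
        p /= p.sum(-1, keepdims=True)                         # row softmax
        return np.einsum("hqk,khd->qhd", p, v)
```

A reference like this is also handy as the ground truth when testing the real FlashAttention-backed policies.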
## Implementation Plan

### Phase 2: Refactor policy.py

```python
# nanovllm/kvcache/sparse/policy.py

from abc import ABC, abstractmethod
from typing import Optional

import torch


class AttentionPolicy(ABC):
    """Base class for attention policies in layerwise offload mode."""

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate sparse attention mask.

        For sparse policies (e.g., XAttention), computes block-level importance.
        For the full policy, returns None.

        Corresponds to xattn_estimate() in COMPASS.

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] or None
        """
        return None

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute prefill attention."""
        pass

    def compute_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute decode attention (default: FlashAttention)."""
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Backward compatibility alias
SparsePolicy = AttentionPolicy
```
### Phase 3: Refactor FullAttentionPolicy

```python
# nanovllm/kvcache/sparse/full_policy.py

import torch

from .policy import AttentionPolicy


class FullAttentionPolicy(AttentionPolicy):
    """Full attention using FlashAttention (no sparsity)."""

    supports_prefill = True
    supports_decode = True

    def estimate(self, q, k, layer_id):
        """Full attention - no sparse mask needed."""
        return None

    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        seq_len = q.shape[0]
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def __repr__(self):
        return "FullAttentionPolicy()"
```
### Phase 4: Refactor XAttentionPolicy

```python
# nanovllm/kvcache/sparse/xattn.py

from typing import Optional

import torch

from .policy import AttentionPolicy


class XAttentionPolicy(AttentionPolicy):
    """
    XAttention sparse prefill policy.

    Uses chunked estimation to compute a sparse attention mask,
    then applies block sparse attention.
    """

    supports_prefill = True
    supports_decode = True

    def __init__(
        self,
        stride: int = 8,
        threshold: float = 0.9,
        block_size: int = 128,
        chunk_size: int = 16384,
        use_triton: bool = True,
    ):
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size
        self.chunk_size = chunk_size
        self.use_triton = use_triton

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        XAttention estimation (xattn_estimate).

        Uses chunked GEMM + softmax to estimate block-level importance,
        then selects important blocks based on the threshold.

        Corresponds to COMPASS's xattn_estimate() function:
        1. Pad inputs to chunk_size multiples
        2. Reshape with stride
        3. Compute QK^T in chunks (Triton)
        4. Block-wise softmax + aggregation
        5. Threshold-based selection

        Args:
            q: [seq_len, num_heads, head_dim]
            k: [seq_len, num_kv_heads, head_dim]
            layer_id: transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask,
            or None (fall back to full attention)
        """
        # TODO: implement the real xattn_estimate
        # For now, return None to use full attention
        return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute XAttention sparse prefill.

        Flow:
        1. Call estimate() to get the sparse mask
        2. If the mask is None, use full attention
        3. Otherwise, apply block sparse attention with the mask
        """
        # Step 1: Estimate sparse mask
        sparse_mask = self.estimate(q, k, layer_id)

        # Step 2: Compute attention
        if sparse_mask is None:
            # Fall back to full attention
            from flash_attn.flash_attn_interface import flash_attn_varlen_func

            seq_len = q.shape[0]
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                softmax_scale=softmax_scale,
                causal=True,
            )
        else:
            # Apply block sparse attention with the mask,
            # e.g. block_sparse_attn_func(q, k, v, sparse_mask, block_size)
            raise NotImplementedError("Block sparse attention not yet implemented")

    def __repr__(self):
        return (f"XAttentionPolicy("
                f"stride={self.stride}, "
                f"threshold={self.threshold}, "
                f"block_size={self.block_size})")
```
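Steps 2–4 of the estimate() docstring above can be sketched in NumPy for a single head (plain stride subsampling instead of the real strided reshape, no Triton, no padding; illustrative only):

```python
import numpy as np

def block_importance(q, k, block_size=4, stride=2):
    """Estimate per-block importance: a simplified take on steps 2-4.

    q, k: [seq_len, head_dim]; seq_len divisible by block_size,
    block_size divisible by stride.
    Returns [num_blocks, num_blocks] summed softmax mass per (q, k) block pair.
    """
    seq_len, head_dim = q.shape
    qs, ks = q[::stride], k[::stride]                 # strided subsample
    s = (qs @ ks.T) / np.sqrt(head_dim)               # QK^T on the subsample
    p = np.exp(s - s.max(-1, keepdims=True))
    p /= p.sum(-1, keepdims=True)                     # row softmax
    nb = seq_len // block_size
    t = block_size // stride                          # subsampled rows per block
    return p.reshape(nb, t, nb, t).sum(axis=(1, 3))   # aggregate per block
```

Step 5 would then turn these scores into a boolean mask via threshold-based selection; the real kernels fuse the GEMM, softmax, and block sums instead of materializing `p`.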
### Phase 5: Update model_runner.py

```python
# model_runner.py - allocate_kv_cache()

# Always create a policy now (including FULL)
from nanovllm.kvcache.sparse import create_attention_policy
self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")

# run_layerwise_offload_prefill() and run_gpu_only_prefill()

# Old code:
if self.sparse_prefill_policy is not None:
    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
else:
    attn_output = flash_attn_varlen_func(...)

# New code:
attn_output = self.attention_policy.compute_prefill(
    q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```
## Method Mapping

| Old method | New method | Notes |
|------------|------------|-------|
| `select_blocks()` | `estimate()` | computes the sparse mask (maps to xattn_estimate) |
| `sparse_prefill_attention()` | `compute_prefill()` | prefill attention |
| (none) | `compute_decode()` | decode attention (default implementation) |
| `on_prefill_offload()` | (removed) | offload is handled in model_runner |
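If any call sites still use the old names during the migration, a small forwarding shim keeps them working (a hypothetical helper, not in either branch; the `AttentionPolicy` below is a stand-in for the real base class):

```python
class AttentionPolicy:
    """Stand-in for the real base class, for illustration only."""

    def compute_prefill(self, q, k, v, layer_id, softmax_scale=1.0):
        return ("prefill", layer_id)  # placeholder result

    def estimate(self, q, k, layer_id):
        return None

    # Forward the old name to the new one during the migration window.
    def sparse_prefill_attention(self, q, k, v, layer_id, softmax_scale=1.0):
        return self.compute_prefill(q, k, v, layer_id, softmax_scale)

    def select_blocks(self, available_blocks, ctx=None):
        # Kept only for signature compatibility; block selection is gone.
        return list(available_blocks)


# Keep the old class name importable as well.
SparsePolicy = AttentionPolicy
```

Once all call sites are migrated, the shim methods and the alias can be deleted in one sweep.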
## Files to Modify

| File | Changes |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | new interface: estimate, compute_prefill, compute_decode |
| `nanovllm/kvcache/sparse/full_policy.py` | implement compute_prefill(); estimate() returns None |
| `nanovllm/kvcache/sparse/xattn.py` | estimate() maps to xattn_estimate; compute_prefill() |
| `nanovllm/kvcache/sparse/__init__.py` | update the factory function |
| `nanovllm/engine/model_runner.py` | call attention_policy.compute_prefill() uniformly |
| `nanovllm/config.py` | optional: rename the config option |
## Decisions Made

1. **Method naming**: `compute_prefill` / `compute_decode` mirror the naming style of the chunked variants
2. **estimate method**: replaces `select_blocks`; returns a sparse mask rather than block IDs
3. **XAttention**: `estimate()` maps to COMPASS's `xattn_estimate()`
4. **Full policy**: `estimate()` returns None, meaning full attention
5. **Default decode**: the base class provides a default FlashAttention implementation
## Errors Encountered

- (none)

## Status

**Currently in Phase 1**: analysis and interface design are complete; awaiting user confirmation before starting Phase 2