📝 docs: add SparsePolicy implementation guide and update rules

- Create docs/sparse_policy_implementation_guide.md with comprehensive guide
- Rewrite .claude/rules/sparse-policy.md with mandatory base class requirements
- Add new doc reference to CLAUDE.md documentation index

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-20 02:25:46 +08:00
parent fa7601f4b8
commit 37aecd4d52
3 changed files with 421 additions and 41 deletions

View File

@@ -1,34 +1,28 @@
# Sparse Policy 代码规范
## supports_prefill / supports_decode 标志
## 基类要求 (MANDATORY)
每个 SparsePolicy 子类必须正确设置这两个标志
每个 `SparsePolicy` 子类 **必须** 遵守以下要求
### 1. 声明 supports_prefill / supports_decode 标志
```python
class MyPolicy(SparsePolicy):
supports_prefill = True # 是否支持 prefill 阶段
supports_decode = False # 是否支持 decode 阶段
supports_decode = True # 是否支持 decode 阶段
```
## 方法实现规范
### 2. 实现三个抽象方法
### 规则:不支持的阶段必须 assert False
| 方法 | 必须实现 | 说明 |
|------|---------|------|
| `select_blocks()` | ✅ | 选择要加载的 blocks |
| `compute_chunked_prefill()` | ✅ | Prefill attention 计算 |
| `compute_chunked_decode()` | ✅ | Decode attention 计算 |
如果 policy 不支持某个阶段,对应的 `compute_chunked_*` 方法内部**必须** `assert False`
### 3. 不支持的阶段必须 assert False
```python
class PrefillOnlyPolicy(SparsePolicy):
supports_prefill = True
supports_decode = False
def compute_chunked_prefill(self, ...):
# 正常实现 prefill 逻辑
...
def compute_chunked_decode(self, ...):
# 不支持 decode必须 assert False
assert False, "PrefillOnlyPolicy does not support decode phase"
```
如果 `supports_prefill = False`,则 `compute_chunked_prefill()` 内部 **必须** `assert False`
```python
class DecodeOnlyPolicy(SparsePolicy):
@@ -36,17 +30,29 @@ class DecodeOnlyPolicy(SparsePolicy):
supports_decode = True
def compute_chunked_prefill(self, ...):
# 不支持 prefill必须 assert False
assert False, "DecodeOnlyPolicy does not support prefill phase"
def compute_chunked_decode(self, ...):
# 正常实现 decode 逻辑
# 正常实现
...
```
### 规则FullPolicy 必须同时支持两个阶段
同理,如果 `supports_decode = False`
`FullAttentionPolicy` 作为默认策略,必须同时支持 prefill 和 decode
```python
class PrefillOnlyPolicy(SparsePolicy):
supports_prefill = True
supports_decode = False
def compute_chunked_prefill(self, ...):
# 正常实现
...
def compute_chunked_decode(self, ...):
assert False, "PrefillOnlyPolicy does not support decode phase"
```
### 4. FullAttentionPolicy 必须同时支持两个阶段
```python
class FullAttentionPolicy(SparsePolicy):
@@ -60,35 +66,27 @@ class FullAttentionPolicy(SparsePolicy):
# 完整实现
```
## 调用方检查
`attention.py` 中应在调用前检查 policy 是否支持当前阶段:
```python
# Prefill 路径
if not sparse_policy.supports_prefill:
raise RuntimeError(f"{sparse_policy} does not support prefill")
# Decode 路径
if not sparse_policy.supports_decode:
raise RuntimeError(f"{sparse_policy} does not support decode")
```
这样提供双重保护:
1. 调用方检查 → 提供清晰的错误信息
2. 方法内 assert → 防止绕过检查的调用
---
## CPU-GPU 通信规范
### 规则:所有通信必须通过 OffloadEngine
SparsePolicy 的 `compute_chunked_*` 方法中,所有 CPU-GPU 数据传输**必须**通过 `OffloadEngine` 进行,**禁止**直接使用 `torch.Tensor.copy_()``.to(device)`
`compute_chunked_*` 方法中,**禁止** 直接使用 `torch.Tensor.copy_()``.to(device)`
```python
# ✅ 正确:使用 OffloadEngine 的 ring buffer 方法
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
k, v = offload_engine.get_kv_for_slot(slot)
offload_engine.record_slot_compute_done(slot)
# ✅ 正确:使用 prefill buffer
k, v = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
# ✅ 正确:使用 decode buffer
decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
decode_v = offload_engine.decode_v_buffer[layer_id, start:end]
# ❌ 错误:直接使用 torch 通信
gpu_tensor.copy_(cpu_tensor)
@@ -102,3 +100,67 @@ gpu_tensor = cpu_tensor.cuda()
2. **Pipeline 优化**OffloadEngine 实现了 ring buffer pipeline
3. **资源管理**OffloadEngine 管理 GPU buffer slots避免内存碎片
4. **一致性**:统一的接口便于调试和维护
---
## 方法签名要求
### select_blocks()
```python
def select_blocks(
self,
available_blocks: List[int], # 可用的 CPU block IDs
offload_engine: "OffloadEngine", # 用于加载数据
ctx: PolicyContext, # 上下文信息
) -> List[int]: # 返回要加载的 block IDs
```
### compute_chunked_prefill()
```python
def compute_chunked_prefill(
self,
q: torch.Tensor, # [seq_len, num_heads, head_dim]
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim] (unused)
v: torch.Tensor, # [seq_len, num_kv_heads, head_dim] (unused)
layer_id: int,
softmax_scale: float,
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
) -> torch.Tensor: # [seq_len, num_heads, head_dim]
```
### compute_chunked_decode()
```python
def compute_chunked_decode(
self,
q: torch.Tensor, # [batch_size, num_heads, head_dim]
layer_id: int,
softmax_scale: float,
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
seq: "Sequence",
) -> torch.Tensor: # [batch_size, 1, num_heads, head_dim]
```
---
## 可选钩子方法
| 方法 | 调用时机 | 用途 |
|------|---------|------|
| `initialize()` | KV cache 分配后 | 初始化 metadata 结构 |
| `on_prefill_offload()` | GPU→CPU 复制前prefill | 收集 block metadata |
| `on_decode_offload()` | GPU→CPU 复制前decode | 更新 block metadata |
| `reset()` | 新 sequence 开始时 | 重置 policy 状态 |
---
## 详细实现指南
参考文档:[`docs/sparse_policy_implementation_guide.md`](../docs/sparse_policy_implementation_guide.md)