# Sparse Policy Coding Conventions

## Policy Must Never Be None (CRITICAL)

**Mandatory rule**: the `sparse_policy` parameter **must never be None**; at minimum it must be a `FullAttentionPolicy`.
```python
# ❌ Wrong: allows None to leak through
sparse_policy = getattr(config, 'sparse_policy', None)

# ✅ Correct: handle None explicitly, defaulting to FULL
sparse_policy_type = getattr(config, 'sparse_policy', None)
if sparse_policy_type is None:
    sparse_policy_type = SparsePolicyType.FULL
```
**Rationale**:

1. Uniform API: every code path runs attention computation through a policy
2. No null pointers: eliminates None checks at every `policy.xxx` call site
3. Simpler logic: no `if policy is not None` branches
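To make the rule concrete, the config-to-policy resolution can be centralized in one place. This is a minimal sketch: the `resolve_policy` helper, the registry dict, and the stub class bodies are hypothetical; only `SparsePolicyType.FULL` and `FullAttentionPolicy` are names taken from this document.

```python
from enum import Enum, auto

class SparsePolicyType(Enum):
    FULL = auto()
    # other policy types would be registered here

class SparsePolicy:
    """Stub standing in for the real base class."""

class FullAttentionPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = True

def resolve_policy(config) -> SparsePolicy:
    """Map the config value to a concrete policy instance; never return None."""
    policy_type = getattr(config, "sparse_policy", None)
    if policy_type is None:
        policy_type = SparsePolicyType.FULL  # default to full attention
    registry = {SparsePolicyType.FULL: FullAttentionPolicy}
    policy = registry[policy_type]()
    assert policy is not None  # downstream code never checks for None again
    return policy
```

With a helper like this, every caller receives a usable policy object, so the three guarantees above hold by construction.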
**The only exception: the warmup phase**

During `model_runner.warmup_model()`, the kvcache_manager has not yet been allocated. In that case `attention.py` falls back to flash_attn:
```python
# Warmup handling in attention.py
if context.kvcache_manager is None:
    # Warmup phase: use flash_attn directly
    return flash_attn_varlen_func(...) if context.is_prefill else flash_attn_with_kvcache(...)
```

This is the only situation in which kvcache_manager may be None. During normal inference, a policy must always be present.
---

## Base Class Requirements (MANDATORY)

Every `SparsePolicy` subclass **must** satisfy the following requirements:

### 1. Declare the supports_prefill / supports_decode flags
```python
class MyPolicy(SparsePolicy):
    supports_prefill = True  # whether the prefill phase is supported
    supports_decode = True   # whether the decode phase is supported
```
### 2. Implement the three abstract methods

| Method | Required | Description |
|--------|----------|-------------|
| `select_blocks()` | ✅ | Select which blocks to load |
| `compute_chunked_prefill()` | ✅ | Prefill attention computation |
| `compute_chunked_decode()` | ✅ | Decode attention computation |
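Taken together, requirements 1 and 2 imply a base class shaped roughly like the following sketch. The parameter lists are abbreviated here; the actual base class may declare more parameters, as listed under "Method Signature Requirements" below.

```python
from abc import ABC, abstractmethod
from typing import List

class SparsePolicy(ABC):
    """Sketch of the base class implied by the rules above (simplified)."""

    # Phase support flags (requirement 1)
    supports_prefill: bool = True
    supports_decode: bool = True

    @abstractmethod
    def select_blocks(self, available_blocks: List[int], offload_engine, ctx) -> List[int]:
        """Select which CPU blocks to load."""

    @abstractmethod
    def compute_chunked_prefill(self, q, k, v, layer_id, softmax_scale,
                                offload_engine, kvcache_manager,
                                current_chunk_idx, seq, num_tokens):
        """Prefill attention computation."""

    @abstractmethod
    def compute_chunked_decode(self, q, layer_id, softmax_scale,
                               offload_engine, kvcache_manager, seq):
        """Decode attention computation."""
```

Declaring the three methods as `@abstractmethod` means an incomplete subclass fails at instantiation time rather than mid-inference.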
### 3. Unsupported phases must assert False

If `supports_prefill = False`, then `compute_chunked_prefill()` **must** `assert False` internally:
```python
class DecodeOnlyPolicy(SparsePolicy):
    supports_prefill = False
    supports_decode = True

    def compute_chunked_prefill(self, ...):
        assert False, "DecodeOnlyPolicy does not support prefill phase"

    def compute_chunked_decode(self, ...):
        # Actual implementation
        ...
```
Likewise, if `supports_decode = False`:
```python
class PrefillOnlyPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = False

    def compute_chunked_prefill(self, ...):
        # Actual implementation
        ...

    def compute_chunked_decode(self, ...):
        assert False, "PrefillOnlyPolicy does not support decode phase"
```
### 4. FullAttentionPolicy must support both phases
```python
class FullAttentionPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = True

    def compute_chunked_prefill(self, ...):
        # Full implementation
        ...

    def compute_chunked_decode(self, ...):
        # Full implementation
        ...
```

---
## CPU-GPU Communication Conventions

### Rule: all transfers must go through OffloadEngine

Inside `compute_chunked_*` methods, directly calling `torch.Tensor.copy_()` or `.to(device)` is **forbidden**:
```python
# ✅ Correct: use OffloadEngine's ring buffer methods
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
k, v = offload_engine.get_kv_for_slot(slot)
offload_engine.record_slot_compute_done(slot)

# ✅ Correct: use the prefill buffer
k, v = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)

# ✅ Correct: use the decode buffer
decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
decode_v = offload_engine.decode_v_buffer[layer_id, start:end]

# ❌ Wrong: direct torch transfers
gpu_tensor.copy_(cpu_tensor)
gpu_tensor = cpu_tensor.to("cuda")
gpu_tensor = cpu_tensor.cuda()
```
### Rationale

1. **Stream synchronization**: OffloadEngine manages CUDA streams internally, guaranteeing correct ordering
2. **Pipeline optimization**: OffloadEngine implements a ring buffer pipeline
3. **Resource management**: OffloadEngine owns the GPU buffer slots, avoiding memory fragmentation
4. **Consistency**: a single interface is easier to debug and maintain
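To show the intended call order of the ring-buffer methods, here is a hedged sketch. `MockOffloadEngine` is a stand-in for the real engine, the per-block loop is a simplification (a real pipeline would issue the next slot's load before waiting on the current one), and the attention partial itself is elided.

```python
class MockOffloadEngine:
    """Records calls so the protocol ordering can be observed."""
    def __init__(self):
        self.log = []
    def load_to_slot_layer(self, slot, layer_id, cpu_block_id):
        self.log.append(("load", slot, cpu_block_id))   # issue async H2D copy
    def wait_slot_layer(self, slot):
        self.log.append(("wait", slot))                  # block until copy done
    def get_kv_for_slot(self, slot):
        self.log.append(("get", slot))
        return f"k{slot}", f"v{slot}"                    # GPU-resident K/V views
    def record_slot_compute_done(self, slot):
        self.log.append(("done", slot))                  # slot may now be reused

def attend_over_blocks(engine, layer_id, block_ids, num_slots=2):
    """Walk the selected blocks through ring-buffer slots in protocol order:
    load -> wait -> get -> compute -> record_done.
    No raw copy_()/.to() calls ever touch the tensors."""
    partials = []
    for i, cpu_block_id in enumerate(block_ids):
        slot = i % num_slots                             # ring buffer slot
        engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
        engine.wait_slot_layer(slot)
        k, v = engine.get_kv_for_slot(slot)
        partials.append((k, v))                          # attention partial would go here
        engine.record_slot_compute_done(slot)
    return partials
```

Because every transfer is funneled through the engine, the stream synchronization and slot reuse described above stay invisible to the policy author.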
---

## Method Signature Requirements

### select_blocks()
```python
def select_blocks(
    self,
    available_blocks: List[int],      # candidate CPU block IDs
    offload_engine: "OffloadEngine",  # used to load data
    ctx: PolicyContext,               # contextual information
) -> List[int]:                       # returns the block IDs to load
```
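As an illustration of this contract (not a policy that exists in this codebase), a hypothetical recency-based policy might implement `select_blocks()` like this:

```python
from typing import List

class RecencyPolicy:
    """Hypothetical example: keep only the most recent window_blocks blocks.
    Illustrates the select_blocks() contract; offload_engine and ctx are
    accepted but unused here."""
    def __init__(self, window_blocks: int = 4):
        assert window_blocks > 0
        self.window_blocks = window_blocks

    def select_blocks(self, available_blocks: List[int], offload_engine, ctx) -> List[int]:
        # Return a subset of available_blocks, preserving their order.
        return available_blocks[-self.window_blocks:]
```

The key property is that the return value is always a subset of `available_blocks`; the engine then loads exactly those blocks.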
### compute_chunked_prefill()
```python
def compute_chunked_prefill(
    self,
    q: torch.Tensor,  # [seq_len, num_heads, head_dim]
    k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim] (unused)
    v: torch.Tensor,  # [seq_len, num_kv_heads, head_dim] (unused)
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    current_chunk_idx: int,
    seq: "Sequence",
    num_tokens: int,
) -> torch.Tensor:  # [seq_len, num_heads, head_dim]
```
### compute_chunked_decode()
```python
def compute_chunked_decode(
    self,
    q: torch.Tensor,  # [batch_size, num_heads, head_dim]
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    seq: "Sequence",
) -> torch.Tensor:  # [batch_size, 1, num_heads, head_dim]
```

---
## Optional Hook Methods

| Method | When called | Purpose |
|--------|-------------|---------|
| `initialize()` | After KV cache allocation | Initialize metadata structures |
| `on_prefill_offload()` | Before GPU→CPU copy (prefill) | Collect block metadata |
| `on_decode_offload()` | Before GPU→CPU copy (decode) | Update block metadata |
| `reset()` | When a new sequence starts | Reset policy state |
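A hypothetical tracing policy makes the lifecycle order of these hooks concrete. The hook names come from the table above; the argument lists are elided, and the class itself is illustrative only.

```python
class TracingPolicy:
    """Records when each optional hook fires, in order."""
    def __init__(self):
        self.events = []
    def initialize(self):
        self.events.append("initialize")       # after KV cache allocation
    def on_prefill_offload(self, *args):
        self.events.append("prefill_offload")  # before GPU->CPU copy (prefill)
    def on_decode_offload(self, *args):
        self.events.append("decode_offload")   # before GPU->CPU copy (decode)
    def reset(self):
        self.events.append("reset")            # new sequence starts
```

Since all four hooks have no-op defaults in the base class, a subclass overrides only the ones whose lifecycle events it actually needs.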
---

## Detailed Implementation Guide

Reference: [`docs/sparse_policy_implementation_guide.md`](../docs/sparse_policy_implementation_guide.md)