📝 docs: add SparsePolicy implementation guide and update rules
- Create docs/sparse_policy_implementation_guide.md with comprehensive guide
- Rewrite .claude/rules/sparse-policy.md with mandatory base class requirements
- Add new doc reference to CLAUDE.md documentation index

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@@ -1,34 +1,28 @@
|
|||||||
# Sparse Policy Coding Standards
|
# Sparse Policy Coding Standards
|
||||||
|
|
||||||
## supports_prefill / supports_decode Flags
|
## Base Class Requirements (MANDATORY)
|
||||||
|
|
||||||
Every SparsePolicy subclass must set these two flags correctly:
|
Every `SparsePolicy` subclass **must** meet the following requirements:
|
||||||
|
|
||||||
|
### 1. Declare the supports_prefill / supports_decode flags
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class MyPolicy(SparsePolicy):
|
class MyPolicy(SparsePolicy):
|
||||||
supports_prefill = True # whether the prefill phase is supported
|
supports_prefill = True # whether the prefill phase is supported
|
||||||
supports_decode = False # whether the decode phase is supported
|
supports_decode = True # whether the decode phase is supported
|
||||||
```
|
```
|
||||||
|
|
||||||
## Method Implementation Rules
|
### 2. Implement the three abstract methods
|
||||||
|
|
||||||
### Rule: unsupported phases must assert False
|
| Method | Required | Description |
|
||||||
|
|------|---------|------|
|
||||||
|
| `select_blocks()` | ✅ | Select the blocks to load |
| `compute_chunked_prefill()` | ✅ | Prefill attention computation |
| `compute_chunked_decode()` | ✅ | Decode attention computation |
|
||||||
|
|
||||||
If a policy does not support a phase, the corresponding `compute_chunked_*` method **must** `assert False` in its body:
|
### 3. Unsupported phases must assert False
|
||||||
|
|
||||||
```python
|
If `supports_prefill = False`, then `compute_chunked_prefill()` **must** `assert False` in its body:
|
||||||
class PrefillOnlyPolicy(SparsePolicy):
|
|
||||||
supports_prefill = True
|
|
||||||
supports_decode = False
|
|
||||||
|
|
||||||
def compute_chunked_prefill(self, ...):
|
|
||||||
# Implement the prefill logic as usual
|
|
||||||
...
|
|
||||||
|
|
||||||
def compute_chunked_decode(self, ...):
|
|
||||||
# Decode is not supported, so this must assert False
|
|
||||||
assert False, "PrefillOnlyPolicy does not support decode phase"
|
|
||||||
```
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class DecodeOnlyPolicy(SparsePolicy):
|
class DecodeOnlyPolicy(SparsePolicy):
|
||||||
@@ -36,17 +30,29 @@ class DecodeOnlyPolicy(SparsePolicy):
|
|||||||
supports_decode = True
|
supports_decode = True
|
||||||
|
|
||||||
def compute_chunked_prefill(self, ...):
|
def compute_chunked_prefill(self, ...):
|
||||||
# Prefill is not supported, so this must assert False
|
|
||||||
assert False, "DecodeOnlyPolicy does not support prefill phase"
|
assert False, "DecodeOnlyPolicy does not support prefill phase"
|
||||||
|
|
||||||
def compute_chunked_decode(self, ...):
|
def compute_chunked_decode(self, ...):
|
||||||
# Implement the decode logic as usual
|
# Normal implementation
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
### Rule: FullPolicy must support both phases
|
Likewise, if `supports_decode = False`:
|
||||||
|
|
||||||
As the default policy, `FullAttentionPolicy` must support both prefill and decode:
|
```python
|
||||||
|
class PrefillOnlyPolicy(SparsePolicy):
|
||||||
|
supports_prefill = True
|
||||||
|
supports_decode = False
|
||||||
|
|
||||||
|
def compute_chunked_prefill(self, ...):
|
||||||
|
# Normal implementation
|
||||||
|
...
|
||||||
|
|
||||||
|
def compute_chunked_decode(self, ...):
|
||||||
|
assert False, "PrefillOnlyPolicy does not support decode phase"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. FullAttentionPolicy must support both phases
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class FullAttentionPolicy(SparsePolicy):
|
class FullAttentionPolicy(SparsePolicy):
|
||||||
@@ -60,35 +66,27 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
# Full implementation
|
# Full implementation
|
||||||
```
|
```
|
||||||
|
|
||||||
## Caller-Side Checks
|
---
|
||||||
|
|
||||||
`attention.py` should check whether the policy supports the current phase before calling it:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Prefill path
|
|
||||||
if not sparse_policy.supports_prefill:
|
|
||||||
raise RuntimeError(f"{sparse_policy} does not support prefill")
|
|
||||||
|
|
||||||
# Decode path
|
|
||||||
if not sparse_policy.supports_decode:
|
|
||||||
raise RuntimeError(f"{sparse_policy} does not support decode")
|
|
||||||
```
|
|
||||||
|
|
||||||
This provides two layers of protection:
1. Caller-side check → provides a clear error message
2. In-method assert → guards against calls that bypass the check
|
|
||||||
|
|
||||||
## CPU-GPU Communication Rules
|
## CPU-GPU Communication Rules
|
||||||
|
|
||||||
### Rule: all communication must go through OffloadEngine
|
### Rule: all communication must go through OffloadEngine
|
||||||
|
|
||||||
In a SparsePolicy's `compute_chunked_*` methods, all CPU-GPU data transfers **must** go through `OffloadEngine`; using `torch.Tensor.copy_()` or `.to(device)` directly is **forbidden**:
|
In `compute_chunked_*` methods, using `torch.Tensor.copy_()` or `.to(device)` directly is **forbidden**:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# ✅ Correct: use OffloadEngine's ring buffer methods
|
# ✅ Correct: use OffloadEngine's ring buffer methods
|
||||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
||||||
offload_engine.wait_slot_layer(slot)
|
offload_engine.wait_slot_layer(slot)
|
||||||
k, v = offload_engine.get_kv_for_slot(slot)
|
k, v = offload_engine.get_kv_for_slot(slot)
|
||||||
|
offload_engine.record_slot_compute_done(slot)
|
||||||
|
|
||||||
|
# ✅ Correct: use the prefill buffer
|
||||||
|
k, v = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
|
||||||
|
|
||||||
|
# ✅ Correct: use the decode buffer
|
||||||
|
decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
|
||||||
|
decode_v = offload_engine.decode_v_buffer[layer_id, start:end]
|
||||||
|
|
||||||
# ❌ Wrong: direct torch transfers
|
# ❌ Wrong: direct torch transfers
|
||||||
gpu_tensor.copy_(cpu_tensor)
|
gpu_tensor.copy_(cpu_tensor)
|
||||||
@@ -102,3 +100,67 @@ gpu_tensor = cpu_tensor.cuda()
|
|||||||
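2. **Pipeline optimization**: OffloadEngine implements a ring buffer pipeline
|
2. **Pipeline optimization**: OffloadEngine implements a ring buffer pipeline
|
||||||
3. **Resource management**: OffloadEngine manages GPU buffer slots, avoiding memory fragmentation
|
3. **Resource management**: OffloadEngine manages GPU buffer slots, avoiding memory fragmentation
|
||||||
4. **Consistency**: a unified interface simplifies debugging and maintenance
|
4. **Consistency**: a unified interface simplifies debugging and maintenance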
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Method Signature Requirements
|
||||||
|
|
||||||
|
### select_blocks()
|
||||||
|
|
||||||
|
```python
|
||||||
|
def select_blocks(
    self,
    available_blocks: List[int],      # available CPU block IDs
    offload_engine: "OffloadEngine",  # used to load data
    ctx: PolicyContext,               # context information
) -> List[int]:                       # returns the block IDs to load
|
||||||
|
```
|
||||||
|
|
||||||
|
### compute_chunked_prefill()
|
||||||
|
|
||||||
|
```python
|
||||||
|
def compute_chunked_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor, # [seq_len, num_heads, head_dim]
|
||||||
|
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim] (unused)
|
||||||
|
v: torch.Tensor, # [seq_len, num_kv_heads, head_dim] (unused)
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
|
kvcache_manager: "KVCacheManager",
|
||||||
|
current_chunk_idx: int,
|
||||||
|
seq: "Sequence",
|
||||||
|
num_tokens: int,
|
||||||
|
) -> torch.Tensor: # [seq_len, num_heads, head_dim]
|
||||||
|
```
|
||||||
|
|
||||||
|
### compute_chunked_decode()
|
||||||
|
|
||||||
|
```python
|
||||||
|
def compute_chunked_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor, # [batch_size, num_heads, head_dim]
|
||||||
|
layer_id: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
offload_engine: "OffloadEngine",
|
||||||
|
kvcache_manager: "KVCacheManager",
|
||||||
|
seq: "Sequence",
|
||||||
|
) -> torch.Tensor: # [batch_size, 1, num_heads, head_dim]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Optional Hook Methods
|
||||||
|
|
||||||
|
| Method | When called | Purpose |
|
||||||
|
|------|---------|------|
|
||||||
|
| `initialize()` | After KV cache allocation | Initialize metadata structures |
| `on_prefill_offload()` | Before the GPU→CPU copy (prefill) | Collect block metadata |
| `on_decode_offload()` | Before the GPU→CPU copy (decode) | Update block metadata |
| `reset()` | When a new sequence starts | Reset policy state |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Detailed Implementation Guide
|
||||||
|
|
||||||
|
Reference: [`docs/sparse_policy_implementation_guide.md`](../docs/sparse_policy_implementation_guide.md)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
|
|||||||
|----------|---------|
|
|----------|---------|
|
||||||
| [`docs/architecture_guide.md`](docs/architecture_guide.md) | Core components, CPU offload system design, ring buffer architecture, stream configuration |
|
| [`docs/architecture_guide.md`](docs/architecture_guide.md) | Core components, CPU offload system design, ring buffer architecture, stream configuration |
|
||||||
| [`docs/sparse_policy_architecture.md`](docs/sparse_policy_architecture.md) | SparsePolicy abstraction: prefill/decode delegation, pipeline modes, policy implementations |
|
| [`docs/sparse_policy_architecture.md`](docs/sparse_policy_architecture.md) | SparsePolicy abstraction: prefill/decode delegation, pipeline modes, policy implementations |
|
||||||
|
| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
|
||||||
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
|
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
|
||||||
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
|
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
|
||||||
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
|
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
|
||||||
|
|||||||
317
docs/sparse_policy_implementation_guide.md
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
# SparsePolicy Implementation Guide
|
||||||
|
|
||||||
|
This guide describes how to implement a custom `SparsePolicy` for sparse attention in CPU offload mode.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
`SparsePolicy` is an abstract base class that controls:
|
||||||
|
1. **Block Selection**: Which KV cache blocks to load from CPU for each query
|
||||||
|
2. **Attention Computation**: How to compute chunked prefill and decode attention
|
||||||
|
|
||||||
|
All computation happens in the policy, with `attention.py` only delegating to the policy methods.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Base Class Structure
|
||||||
|
|
||||||
|
```python
class SparsePolicy(ABC):
    # Phase support flags (REQUIRED to override)
    supports_prefill: bool = True
    supports_decode: bool = True

    # Abstract methods (MUST implement)
    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]
    def compute_chunked_prefill(self, q, k, v, layer_id, ...) -> torch.Tensor
    def compute_chunked_decode(self, q, layer_id, ...) -> torch.Tensor

    # Optional hooks (CAN override)
    def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device)
    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens)
    def on_decode_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens)
    def reset(self)
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Required Implementations
|
||||||
|
|
||||||
|
### 1. Phase Support Flags
|
||||||
|
|
||||||
|
Every policy MUST declare which phases it supports:
|
||||||
|
|
||||||
|
```python
class MyPolicy(SparsePolicy):
    supports_prefill = True  # Can be used in prefill phase?
    supports_decode = True   # Can be used in decode phase?
```
|
||||||
|
|
||||||
|
| Policy Type | supports_prefill | supports_decode | Example |
|
||||||
|
|-------------|------------------|-----------------|---------|
|
||||||
|
| Full support | True | True | `FullAttentionPolicy` |
|
||||||
|
| Decode-only | False | True | `QuestPolicy` |
|
||||||
|
| Prefill-only | True | False | (hypothetical) |
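
Because Python strips `assert` statements when run with `-O`, a caller such as `attention.py` may want to consult the flags before delegating. The following is a minimal sketch of such a guard; the stand-in `SparsePolicy` class and the `check_phase_support` helper are hypothetical and are not part of the code base shown here.

```python
class SparsePolicy:
    """Stand-in for the real base class; only the phase flags are modeled."""
    supports_prefill = True
    supports_decode = True


class DecodeOnlyPolicy(SparsePolicy):
    supports_prefill = False
    supports_decode = True


def check_phase_support(policy: SparsePolicy, is_prefill: bool) -> None:
    """Hypothetical helper: raise a clear error before delegating to the policy."""
    phase = "prefill" if is_prefill else "decode"
    supported = policy.supports_prefill if is_prefill else policy.supports_decode
    if not supported:
        raise RuntimeError(f"{type(policy).__name__} does not support {phase}")


# A decode-only policy passes the decode check silently.
check_phase_support(DecodeOnlyPolicy(), is_prefill=False)
```

With this guard in place, a misconfigured policy fails with a readable `RuntimeError` even when assertions are disabled, while the in-method `assert False` still catches calls that bypass the check.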
|
||||||
|
|
||||||
|
### 2. select_blocks() - Block Selection
|
||||||
|
|
||||||
|
```python
@abstractmethod
def select_blocks(
    self,
    available_blocks: List[int],      # CPU block IDs with historical KV
    offload_engine: "OffloadEngine",
    ctx: PolicyContext,               # Context about current query
) -> List[int]:
    """Return subset of available_blocks to load."""
```
|
||||||
|
|
||||||
|
**PolicyContext fields:**
|
||||||
|
- `query_chunk_idx`: Current chunk index (0-indexed)
|
||||||
|
- `num_query_chunks`: Total number of chunks
|
||||||
|
- `layer_id`: Transformer layer index
|
||||||
|
- `query`: Query tensor (available for decode)
|
||||||
|
- `is_prefill`: True if prefill phase
|
||||||
|
- `block_size`: Tokens per block
|
||||||
|
- `total_kv_len`: Total KV length so far
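
To make the field list concrete, here is an illustrative dataclass with the same field names. It is only a sketch: the real `PolicyContext` lives in `nanovllm/kvcache/sparse/policy.py`, and the class name `PolicyContextSketch`, the field types, and the sample values below are assumptions.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PolicyContextSketch:
    """Illustrative stand-in; field types are assumptions."""
    query_chunk_idx: int      # current chunk index (0-indexed)
    num_query_chunks: int     # total number of chunks
    layer_id: int             # transformer layer index
    query: Optional[Any]      # query tensor (available for decode)
    is_prefill: bool          # True during the prefill phase
    block_size: int           # tokens per block
    total_kv_len: int         # total KV length so far


# Example: first of four prefill chunks at layer 0 (values are made up).
ctx = PolicyContextSketch(
    query_chunk_idx=0, num_query_chunks=4, layer_id=0,
    query=None, is_prefill=True, block_size=1024, total_kv_len=1024,
)
```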
|
||||||
|
|
||||||
|
**Example implementations:**
|
||||||
|
|
||||||
|
```python
# Full attention: load all blocks
def select_blocks(self, available_blocks, offload_engine, ctx):
    return available_blocks

# Top-K sparse: load K most important blocks
def select_blocks(self, available_blocks, offload_engine, ctx):
    scores = self.compute_block_scores(available_blocks, ctx.query)
    topk_indices = scores.topk(self.config.topk).indices
    return [available_blocks[i] for i in sorted(topk_indices.tolist())]
```
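
The top-K selection can also be traced without tensors. The sketch below uses plain Python lists and a hypothetical helper `select_topk_blocks`; it adds a length guard because `torch.topk` raises an error when `k` exceeds the number of elements, and it keeps the selected blocks in their original order, as the `sorted(...)` call above does.

```python
from typing import List, Sequence


def select_topk_blocks(available_blocks: List[int], scores: Sequence[float], k: int) -> List[int]:
    """Return the k highest-scoring block IDs, keeping their original order."""
    if len(available_blocks) <= k:
        return list(available_blocks)
    # Rank positions by score (highest first), then restore positional order.
    ranked = sorted(range(len(available_blocks)), key=lambda i: scores[i], reverse=True)
    keep = sorted(ranked[:k])
    return [available_blocks[i] for i in keep]


# Blocks 11 and 13 have the two highest scores (0.9 and 0.7).
selected = select_topk_blocks([10, 11, 12, 13], [0.1, 0.9, 0.3, 0.7], k=2)
```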
|
||||||
|
|
||||||
|
### 3. compute_chunked_prefill() - Prefill Attention
|
||||||
|
|
||||||
|
```python
@abstractmethod
def compute_chunked_prefill(
    self,
    q: torch.Tensor,                    # [seq_len, num_heads, head_dim]
    k: torch.Tensor,                    # [seq_len, num_kv_heads, head_dim] (unused)
    v: torch.Tensor,                    # [seq_len, num_kv_heads, head_dim] (unused)
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    current_chunk_idx: int,
    seq: "Sequence",
    num_tokens: int,
) -> torch.Tensor:                      # [seq_len, num_heads, head_dim]
```
|
||||||
|
|
||||||
|
**Required flow:**
|
||||||
|
1. Get historical blocks: `kvcache_manager.get_prefilled_cpu_blocks(seq)`
|
||||||
|
2. Call `select_blocks()` to filter blocks
|
||||||
|
3. Load blocks via ring buffer pipeline
|
||||||
|
4. Get current chunk KV: `offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)`
|
||||||
|
5. Compute attention with `flash_attn_with_lse()` (historical: causal=False, current: causal=True)
|
||||||
|
6. Merge results with `merge_attention_outputs()`
|
||||||
|
7. Return output with shape `[seq_len, num_heads, head_dim]`
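
Steps 5 and 6 rely on the fact that attention over disjoint key sets can be combined exactly if each partial result carries its log-sum-exp (LSE). The sketch below checks this for a single query with scalar values; `attend` and `merge_partial` are hypothetical stand-ins written for this illustration, not the tensor-based `flash_attn_with_lse()` and `merge_attention_outputs()`, and causal masking is ignored.

```python
import math
from typing import List, Tuple


def attend(scores: List[float], values: List[float]) -> Tuple[float, float]:
    """Softmax-weighted sum of scalar values plus the log-sum-exp of the scores."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    out = sum(math.exp(s - lse) * v for s, v in zip(scores, values))
    return out, lse


def merge_partial(o_a: float, lse_a: float, o_b: float, lse_b: float) -> Tuple[float, float]:
    """Combine two partial attention results computed over disjoint key sets."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    # Each partial output is re-weighted by its share of the total softmax mass.
    out = math.exp(lse_a - lse) * o_a + math.exp(lse_b - lse) * o_b
    return out, lse


scores = [0.5, -1.0, 2.0, 0.3]
values = [1.0, 2.0, 3.0, 4.0]
o_hist, lse_hist = attend(scores[:2], values[:2])  # "historical" keys
o_cur, lse_cur = attend(scores[2:], values[2:])    # "current chunk" keys
o_merged, lse_merged = merge_partial(o_hist, lse_hist, o_cur, lse_cur)
o_full, lse_full = attend(scores, values)          # reference: all keys at once
```

The merged result matches attention over all keys up to floating-point error, which is why historical blocks can be processed one at a time and accumulated.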
|
||||||
|
|
||||||
|
**If policy doesn't support prefill:**
|
||||||
|
```python
def compute_chunked_prefill(self, ...):
    assert False, "MyPolicy does not support prefill phase"
```
|
||||||
|
|
||||||
|
### 4. compute_chunked_decode() - Decode Attention
|
||||||
|
|
||||||
|
```python
@abstractmethod
def compute_chunked_decode(
    self,
    q: torch.Tensor,                    # [batch_size, num_heads, head_dim]
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    seq: "Sequence",
) -> torch.Tensor:                      # [batch_size, 1, num_heads, head_dim]
```
|
||||||
|
|
||||||
|
**Required flow:**
|
||||||
|
1. Get prefilled blocks: `kvcache_manager.get_prefilled_cpu_blocks(seq)`
|
||||||
|
2. Calculate last block valid tokens from `kvcache_manager.get_prefill_len(seq)`
|
||||||
|
3. Call `select_blocks()` to filter blocks
|
||||||
|
4. Load blocks via `_decode_ring_buffer_pipeline()` helper
|
||||||
|
5. Read decode buffer: `offload_engine.decode_k_buffer[layer_id, ...]`
|
||||||
|
6. Merge results with `merge_attention_outputs()`
|
||||||
|
7. Return output with shape `[batch_size, 1, num_heads, head_dim]`
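
Step 2 amounts to a small piece of arithmetic. A minimal sketch, assuming blocks are filled in order so that only the last prefilled block can be partially full; `last_block_valid_tokens` is a hypothetical helper, and the real code may compute this differently.

```python
def last_block_valid_tokens(prefill_len: int, block_size: int) -> int:
    """Number of valid tokens in the final prefilled block (hypothetical helper)."""
    if prefill_len <= 0:
        return 0
    remainder = prefill_len % block_size
    # A remainder of zero means the last block is completely full.
    return remainder if remainder != 0 else block_size


partial = last_block_valid_tokens(prefill_len=2500, block_size=1024)  # 2 full blocks + 452 tokens
full = last_block_valid_tokens(prefill_len=2048, block_size=1024)     # exactly 2 full blocks
```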
|
||||||
|
|
||||||
|
**If policy doesn't support decode:**
|
||||||
|
```python
def compute_chunked_decode(self, ...):
    assert False, "MyPolicy does not support decode phase"
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Optional Hooks
|
||||||
|
|
||||||
|
### initialize()
|
||||||
|
|
||||||
|
Called after KV cache allocation. Use it to create metadata structures.
|
||||||
|
|
||||||
|
```python
def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
    self.metadata = BlockMetadataManager(
        num_blocks=num_cpu_blocks,
        num_layers=num_layers,
        ...
    )
```
|
||||||
|
|
||||||
|
### on_prefill_offload() / on_decode_offload()
|
||||||
|
|
||||||
|
Called BEFORE the GPU→CPU copy. Use these hooks to collect block metadata while the data is still on the GPU.
|
||||||
|
|
||||||
|
```python
def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
    # k_cache is still on GPU here
    self.metadata.update_min_max(cpu_block_id, layer_id, k_cache, num_valid_tokens)
```
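
The name `update_min_max` suggests that the metadata keeps per-block minimum and maximum key values. What `BlockMetadataManager` actually stores is not shown here, so the following is only a sketch of one plausible scheme, in the style of query-aware methods such as Quest: record per-dimension min/max over the valid keys, then bound the best possible query-key dot product for the block. Both helper names are hypothetical, and at least one valid token is assumed.

```python
from typing import List, Tuple

Vector = List[float]


def block_min_max(k_block: List[Vector], num_valid_tokens: int) -> Tuple[Vector, Vector]:
    """Per-dimension min and max over the valid key vectors of one block."""
    valid = k_block[:num_valid_tokens]  # assumes num_valid_tokens >= 1
    dims = range(len(valid[0]))
    k_min = [min(tok[d] for tok in valid) for d in dims]
    k_max = [max(tok[d] for tok in valid) for d in dims]
    return k_min, k_max


def score_upper_bound(query: Vector, k_min: Vector, k_max: Vector) -> float:
    """Upper bound on q.k for any key whose entries lie within [k_min, k_max]."""
    return sum(max(q * lo, q * hi) for q, lo, hi in zip(query, k_min, k_max))


# The third row is padding beyond num_valid_tokens and must be ignored.
k_block = [[1.0, -2.0], [3.0, 0.5], [9.0, 9.0]]
k_min, k_max = block_min_max(k_block, num_valid_tokens=2)
bound = score_upper_bound([1.0, -1.0], k_min, k_max)
```

The bound is never smaller than the true dot product with any valid key in the block, so ranking blocks by it cannot underestimate a block's best key.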
|
||||||
|
|
||||||
|
### reset()
|
||||||
|
|
||||||
|
Called when a new sequence starts. Use it to clear policy state.
|
||||||
|
|
||||||
|
```python
def reset(self):
    if self.metadata is not None:
        self.metadata.reset()
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CPU-GPU Communication Rules
|
||||||
|
|
||||||
|
**MUST use OffloadEngine methods:**
|
||||||
|
```python
# Loading blocks
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
k, v = offload_engine.get_kv_for_slot(slot)
offload_engine.record_slot_compute_done(slot)

# Current chunk KV
k, v = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)

# Decode buffer
decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
decode_v = offload_engine.decode_v_buffer[layer_id, start:end]
```
|
||||||
|
|
||||||
|
**NEVER do direct transfers:**
|
||||||
|
```python
# WRONG!
gpu_tensor.copy_(cpu_tensor)
gpu_tensor = cpu_tensor.to("cuda")
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Ring Buffer Pipeline Pattern
|
||||||
|
|
||||||
|
The standard pattern for loading blocks:
|
||||||
|
|
||||||
|
```python
def _decode_ring_buffer_pipeline(self, q_batched, cpu_block_table, load_slots, ...):
    from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

    num_blocks = len(cpu_block_table)
    num_slots = len(load_slots)
    o_acc, lse_acc = None, None

    # Phase 1: Pre-load up to num_slots blocks
    for i in range(min(num_slots, num_blocks)):
        offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])

    # Phase 2: Process with pipeline
    for block_idx in range(num_blocks):
        slot = load_slots[block_idx % num_slots]

        # Wait for H2D transfer
        offload_engine.wait_slot_layer(slot)

        with torch.cuda.stream(offload_engine.compute_stream):
            # Get KV and compute attention
            k, v = offload_engine.get_kv_for_slot(slot)
            o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale, causal=False)
            offload_engine.record_slot_compute_done(slot)

        # Pipeline: start next block transfer
        next_idx = block_idx + num_slots
        if next_idx < num_blocks:
            offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_table[next_idx])

        # Merge results
        with torch.cuda.stream(offload_engine.compute_stream):
            if o_acc is None:
                o_acc, lse_acc = o, lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)

    return o_acc, lse_acc
```
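
The scheduling in the pattern above can be traced without a GPU. The sketch below replays only the order of loads and computes for a given block table and slot list; `simulate_ring_buffer` is a hypothetical helper written for this illustration and performs no transfers or attention.

```python
from typing import List, Tuple


def simulate_ring_buffer(cpu_blocks: List[int], slots: List[int]) -> List[Tuple[str, int, int]]:
    """Return the ordered (event, slot, block) log of the pipeline schedule."""
    log: List[Tuple[str, int, int]] = []
    num_blocks, num_slots = len(cpu_blocks), len(slots)

    # Phase 1: pre-load up to num_slots blocks.
    for i in range(min(num_slots, num_blocks)):
        log.append(("load", slots[i], cpu_blocks[i]))

    # Phase 2: compute on each block, then reuse its slot for a later block.
    for idx in range(num_blocks):
        slot = slots[idx % num_slots]
        log.append(("compute", slot, cpu_blocks[idx]))
        nxt = idx + num_slots
        if nxt < num_blocks:
            log.append(("load", slot, cpu_blocks[nxt]))
    return log


# Three CPU blocks streamed through two GPU slots.
events = simulate_ring_buffer(cpu_blocks=[7, 8, 9], slots=[0, 1])
```

In this trace, slot 0 is refilled with block 9 as soon as block 7 has been consumed, so in the real pipeline that transfer can overlap with the compute on block 8 in slot 1.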
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Complete Example: Decode-Only Policy
|
||||||
|
|
||||||
|
```python
class TopKPolicy(SparsePolicy):
    """Load only top-K blocks based on query-key similarity."""

    supports_prefill = False  # Use FullAttentionPolicy for prefill
    supports_decode = True

    def __init__(self, topk: int = 8):
        self.topk = topk
        self.metadata = None

    def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
        self.metadata = BlockMetadataManager(num_cpu_blocks, num_layers, num_kv_heads, head_dim)

    def select_blocks(self, available_blocks, offload_engine, ctx):
        if len(available_blocks) <= self.topk:
            return available_blocks

        # Compute scores and select top-K
        scores = self.metadata.compute_scores(available_blocks, ctx.layer_id, ctx.query)
        topk_indices = scores.topk(self.topk).indices.cpu().tolist()
        return [available_blocks[i] for i in sorted(topk_indices)]

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        self.metadata.update(cpu_block_id, layer_id, k_cache, num_valid_tokens)

    def compute_chunked_prefill(self, ...):
        assert False, "TopKPolicy does not support prefill phase"

    def compute_chunked_decode(self, q, layer_id, softmax_scale, offload_engine, kvcache_manager, seq):
        # Copy implementation from FullAttentionPolicy.compute_chunked_decode
        # The only difference is select_blocks() will filter to top-K
        ...

    def reset(self):
        if self.metadata:
            self.metadata.reset()
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## File Locations
|
||||||
|
|
||||||
|
| File | Purpose |
|
||||||
|
|------|---------|
|
||||||
|
| `nanovllm/kvcache/sparse/policy.py` | Base class and PolicyContext |
|
||||||
|
| `nanovllm/kvcache/sparse/full_policy.py` | FullAttentionPolicy (reference implementation) |
|
||||||
|
| `nanovllm/kvcache/sparse/quest.py` | QuestPolicy (decode-only example) |
|
||||||
|
| `nanovllm/kvcache/chunked_attention.py` | `flash_attn_with_lse`, `merge_attention_outputs` |
|
||||||