📝 docs: add SparsePolicy implementation guide and update rules
- Create docs/sparse_policy_implementation_guide.md with comprehensive guide - Rewrite .claude/rules/sparse-policy.md with mandatory base class requirements - Add new doc reference to CLAUDE.md documentation index Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
317
docs/sparse_policy_implementation_guide.md
Normal file
317
docs/sparse_policy_implementation_guide.md
Normal file
@@ -0,0 +1,317 @@
|
||||
# SparsePolicy Implementation Guide
|
||||
|
||||
This guide describes how to implement a custom `SparsePolicy` for sparse attention in CPU offload mode.
|
||||
|
||||
## Overview
|
||||
|
||||
`SparsePolicy` is an abstract base class that controls:
|
||||
1. **Block Selection**: Which KV cache blocks to load from CPU for each query
|
||||
2. **Attention Computation**: How to compute chunked prefill and decode attention
|
||||
|
||||
All computation happens in the policy, with `attention.py` only delegating to the policy methods.
|
||||
|
||||
---
|
||||
|
||||
## Base Class Structure
|
||||
|
||||
```python
|
||||
class SparsePolicy(ABC):
|
||||
# Phase support flags (REQUIRED to override)
|
||||
supports_prefill: bool = True
|
||||
supports_decode: bool = True
|
||||
|
||||
# Abstract methods (MUST implement)
|
||||
def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]
|
||||
def compute_chunked_prefill(self, q, k, v, layer_id, ...) -> torch.Tensor
|
||||
def compute_chunked_decode(self, q, layer_id, ...) -> torch.Tensor
|
||||
|
||||
# Optional hooks (CAN override)
|
||||
def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device)
|
||||
def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
def on_decode_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
def reset(self)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Required Implementations
|
||||
|
||||
### 1. Phase Support Flags
|
||||
|
||||
Every policy MUST declare which phases it supports:
|
||||
|
||||
```python
|
||||
class MyPolicy(SparsePolicy):
|
||||
supports_prefill = True # Can be used in prefill phase?
|
||||
supports_decode = True # Can be used in decode phase?
|
||||
```
|
||||
|
||||
| Policy Type | supports_prefill | supports_decode | Example |
|
||||
|-------------|------------------|-----------------|---------|
|
||||
| Full support | True | True | `FullAttentionPolicy` |
|
||||
| Decode-only | False | True | `QuestPolicy` |
|
||||
| Prefill-only | True | False | (hypothetical) |
|
||||
|
||||
### 2. select_blocks() - Block Selection
|
||||
|
||||
```python
|
||||
@abstractmethod
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int], # CPU block IDs with historical KV
|
||||
offload_engine: "OffloadEngine",
|
||||
ctx: PolicyContext, # Context about current query
|
||||
) -> List[int]:
|
||||
"""Return subset of available_blocks to load."""
|
||||
```
|
||||
|
||||
**PolicyContext fields:**
|
||||
- `query_chunk_idx`: Current chunk index (0-indexed)
|
||||
- `num_query_chunks`: Total number of chunks
|
||||
- `layer_id`: Transformer layer index
|
||||
- `query`: Query tensor (available for decode)
|
||||
- `is_prefill`: True if prefill phase
|
||||
- `block_size`: Tokens per block
|
||||
- `total_kv_len`: Total KV length so far
|
||||
|
||||
**Example implementations:**
|
||||
|
||||
```python
|
||||
# Full attention: load all blocks
|
||||
def select_blocks(self, available_blocks, offload_engine, ctx):
|
||||
return available_blocks
|
||||
|
||||
# Top-K sparse: load K most important blocks
|
||||
def select_blocks(self, available_blocks, offload_engine, ctx):
|
||||
scores = self.compute_block_scores(available_blocks, ctx.query)
|
||||
topk_indices = scores.topk(self.config.topk).indices
|
||||
return [available_blocks[i] for i in sorted(topk_indices.tolist())]
|
||||
```
|
||||
|
||||
### 3. compute_chunked_prefill() - Prefill Attention
|
||||
|
||||
```python
|
||||
@abstractmethod
|
||||
def compute_chunked_prefill(
|
||||
self,
|
||||
q: torch.Tensor, # [seq_len, num_heads, head_dim]
|
||||
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim] (unused)
|
||||
v: torch.Tensor, # [seq_len, num_kv_heads, head_dim] (unused)
|
||||
layer_id: int,
|
||||
softmax_scale: float,
|
||||
offload_engine: "OffloadEngine",
|
||||
kvcache_manager: "KVCacheManager",
|
||||
current_chunk_idx: int,
|
||||
seq: "Sequence",
|
||||
num_tokens: int,
|
||||
) -> torch.Tensor: # [seq_len, num_heads, head_dim]
|
||||
```
|
||||
|
||||
**Required flow:**
|
||||
1. Get historical blocks: `kvcache_manager.get_prefilled_cpu_blocks(seq)`
|
||||
2. Call `select_blocks()` to filter blocks
|
||||
3. Load blocks via ring buffer pipeline
|
||||
4. Get current chunk KV: `offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)`
|
||||
5. Compute attention with `flash_attn_with_lse()` (historical: causal=False, current: causal=True)
|
||||
6. Merge results with `merge_attention_outputs()`
|
||||
7. Return output with shape `[seq_len, num_heads, head_dim]`
|
||||
|
||||
**If policy doesn't support prefill:**
|
||||
```python
|
||||
def compute_chunked_prefill(self, ...):
|
||||
assert False, "MyPolicy does not support prefill phase"
|
||||
```
|
||||
|
||||
### 4. compute_chunked_decode() - Decode Attention
|
||||
|
||||
```python
|
||||
@abstractmethod
|
||||
def compute_chunked_decode(
|
||||
self,
|
||||
q: torch.Tensor, # [batch_size, num_heads, head_dim]
|
||||
layer_id: int,
|
||||
softmax_scale: float,
|
||||
offload_engine: "OffloadEngine",
|
||||
kvcache_manager: "KVCacheManager",
|
||||
seq: "Sequence",
|
||||
) -> torch.Tensor: # [batch_size, 1, num_heads, head_dim]
|
||||
```
|
||||
|
||||
**Required flow:**
|
||||
1. Get prefilled blocks: `kvcache_manager.get_prefilled_cpu_blocks(seq)`
|
||||
2. Calculate last block valid tokens from `kvcache_manager.get_prefill_len(seq)`
|
||||
3. Call `select_blocks()` to filter blocks
|
||||
4. Load blocks via `_decode_ring_buffer_pipeline()` helper
|
||||
5. Read decode buffer: `offload_engine.decode_k_buffer[layer_id, ...]`
|
||||
6. Merge results with `merge_attention_outputs()`
|
||||
7. Return output with shape `[batch_size, 1, num_heads, head_dim]`
|
||||
|
||||
**If policy doesn't support decode:**
|
||||
```python
|
||||
def compute_chunked_decode(self, ...):
|
||||
assert False, "MyPolicy does not support decode phase"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Optional Hooks
|
||||
|
||||
### initialize()
|
||||
|
||||
Called after KV cache allocation. Use to create metadata structures.
|
||||
|
||||
```python
|
||||
def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
|
||||
self.metadata = BlockMetadataManager(
|
||||
num_blocks=num_cpu_blocks,
|
||||
num_layers=num_layers,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
### on_prefill_offload() / on_decode_offload()
|
||||
|
||||
Called BEFORE GPU→CPU copy. Use to collect block metadata while data is still on GPU.
|
||||
|
||||
```python
|
||||
def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
|
||||
# k_cache is still on GPU here
|
||||
self.metadata.update_min_max(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
```
|
||||
|
||||
### reset()
|
||||
|
||||
Called when starting new sequence. Use to clear state.
|
||||
|
||||
```python
|
||||
def reset(self):
|
||||
if self.metadata is not None:
|
||||
self.metadata.reset()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## CPU-GPU Communication Rules
|
||||
|
||||
**MUST use OffloadEngine methods:**
|
||||
```python
|
||||
# Loading blocks
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
||||
offload_engine.wait_slot_layer(slot)
|
||||
k, v = offload_engine.get_kv_for_slot(slot)
|
||||
offload_engine.record_slot_compute_done(slot)
|
||||
|
||||
# Current chunk KV
|
||||
k, v = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
|
||||
|
||||
# Decode buffer
|
||||
decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
|
||||
decode_v = offload_engine.decode_v_buffer[layer_id, start:end]
|
||||
```
|
||||
|
||||
**NEVER do direct transfers:**
|
||||
```python
|
||||
# WRONG!
|
||||
gpu_tensor.copy_(cpu_tensor)
|
||||
gpu_tensor = cpu_tensor.to("cuda")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Ring Buffer Pipeline Pattern
|
||||
|
||||
The standard pattern for loading blocks:
|
||||
|
||||
```python
|
||||
def _decode_ring_buffer_pipeline(self, q_batched, cpu_block_table, load_slots, ...):
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
num_blocks = len(cpu_block_table)
|
||||
num_slots = len(load_slots)
|
||||
o_acc, lse_acc = None, None
|
||||
|
||||
# Phase 1: Pre-load up to num_slots blocks
|
||||
for i in range(min(num_slots, num_blocks)):
|
||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
|
||||
|
||||
# Phase 2: Process with pipeline
|
||||
for block_idx in range(num_blocks):
|
||||
slot = load_slots[block_idx % num_slots]
|
||||
|
||||
# Wait for H2D transfer
|
||||
offload_engine.wait_slot_layer(slot)
|
||||
|
||||
with torch.cuda.stream(offload_engine.compute_stream):
|
||||
# Get KV and compute attention
|
||||
k, v = offload_engine.get_kv_for_slot(slot)
|
||||
o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale, causal=False)
|
||||
offload_engine.record_slot_compute_done(slot)
|
||||
|
||||
# Pipeline: start next block transfer
|
||||
next_idx = block_idx + num_slots
|
||||
if next_idx < num_blocks:
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_table[next_idx])
|
||||
|
||||
# Merge results
|
||||
with torch.cuda.stream(offload_engine.compute_stream):
|
||||
if o_acc is None:
|
||||
o_acc, lse_acc = o, lse
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)
|
||||
|
||||
return o_acc, lse_acc
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Complete Example: Decode-Only Policy
|
||||
|
||||
```python
|
||||
class TopKPolicy(SparsePolicy):
|
||||
"""Load only top-K blocks based on query-key similarity."""
|
||||
|
||||
supports_prefill = False # Use FullAttentionPolicy for prefill
|
||||
supports_decode = True
|
||||
|
||||
def __init__(self, topk: int = 8):
|
||||
self.topk = topk
|
||||
self.metadata = None
|
||||
|
||||
def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
|
||||
self.metadata = BlockMetadataManager(num_cpu_blocks, num_layers, num_kv_heads, head_dim)
|
||||
|
||||
def select_blocks(self, available_blocks, offload_engine, ctx):
|
||||
if len(available_blocks) <= self.topk:
|
||||
return available_blocks
|
||||
|
||||
# Compute scores and select top-K
|
||||
scores = self.metadata.compute_scores(available_blocks, ctx.layer_id, ctx.query)
|
||||
topk_indices = scores.topk(self.topk).indices.cpu().tolist()
|
||||
return [available_blocks[i] for i in sorted(topk_indices)]
|
||||
|
||||
def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
|
||||
self.metadata.update(cpu_block_id, layer_id, k_cache, num_valid_tokens)
|
||||
|
||||
def compute_chunked_prefill(self, ...):
|
||||
assert False, "TopKPolicy does not support prefill phase"
|
||||
|
||||
def compute_chunked_decode(self, q, layer_id, softmax_scale, offload_engine, kvcache_manager, seq):
|
||||
# Copy implementation from FullAttentionPolicy.compute_chunked_decode
|
||||
# The only difference is select_blocks() will filter to top-K
|
||||
...
|
||||
|
||||
def reset(self):
|
||||
if self.metadata:
|
||||
self.metadata.reset()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## File Locations
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `nanovllm/kvcache/sparse/policy.py` | Base class and PolicyContext |
|
||||
| `nanovllm/kvcache/sparse/full_policy.py` | FullAttentionPolicy (reference implementation) |
|
||||
| `nanovllm/kvcache/sparse/quest.py` | QuestPolicy (decode-only example) |
|
||||
| `nanovllm/kvcache/chunked_attention.py` | `flash_attn_with_lse`, `merge_attention_outputs` |
|
||||
Reference in New Issue
Block a user