# SparsePolicy Architecture Guide
This document describes the SparsePolicy abstraction for chunked attention computation in CPU offload mode.
## Overview
SparsePolicy is an abstract base class that defines how attention is computed during chunked prefill and decode phases. All attention computation logic is delegated to the policy, allowing different sparse attention strategies to be implemented without modifying the core attention layer.
```
attention.py                        SparsePolicy
     |                                   |
     |  _chunked_prefill_attention       |
     | ────────────────────────────────> | compute_chunked_prefill()
     |                                   |
     |  _chunked_decode_attention        |
     | ────────────────────────────────> | compute_chunked_decode()
     |                                   |
```
### Key Design Principles

- **Delegation Pattern**: `attention.py` only validates and delegates; all computation is in the policy
- **No Direct Imports**: `attention.py` does not import `flash_attn_with_lse` or `merge_attention_outputs`
- **Pipeline Encapsulation**: Ring buffer and cross-layer pipelines are internal to the policy
- **Phase Support Flags**: Policies declare which phases they support via `supports_prefill` and `supports_decode`
## SparsePolicy Base Class

**File**: `nanovllm/kvcache/sparse/policy.py`
### Class Attributes

| Attribute | Type | Description |
|---|---|---|
| `supports_prefill` | `bool` | Whether the policy supports the prefill phase |
| `supports_decode` | `bool` | Whether the policy supports the decode phase |
### Abstract Methods

```python
@abstractmethod
def select_blocks(
    self,
    available_blocks: List[int],
    offload_engine: "OffloadEngine",
    ctx: PolicyContext,
) -> List[int]:
    """Select which KV blocks to load for the current query chunk."""
    pass

@abstractmethod
def compute_chunked_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    current_chunk_idx: int,
    seq: "Sequence",
    num_tokens: int,
) -> torch.Tensor:
    """Compute chunked prefill attention (complete flow)."""
    pass

@abstractmethod
def compute_chunked_decode(
    self,
    q: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    seq: "Sequence",
) -> torch.Tensor:
    """Compute chunked decode attention (complete flow)."""
    pass
```
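
To make the contract concrete, here is a minimal pure-Python sketch of how an abstract base class with these three methods behaves. `SparsePolicySketch`, `KeepAllPolicy`, and `Incomplete` are hypothetical names, and the tensor arguments are dropped, so this shows only the `abc` mechanics rather than the real `SparsePolicy` signatures: a subclass that leaves any abstract method unimplemented cannot be instantiated.

```python
from abc import ABC, abstractmethod
from typing import List


class SparsePolicySketch(ABC):
    """Hypothetical, simplified stand-in for SparsePolicy (tensor arguments elided)."""

    supports_prefill: bool = True
    supports_decode: bool = True

    @abstractmethod
    def select_blocks(self, available_blocks: List[int]) -> List[int]:
        """Return the subset of CPU block ids to load."""

    @abstractmethod
    def compute_chunked_prefill(self, *args, **kwargs):
        """Full prefill flow for one chunk."""

    @abstractmethod
    def compute_chunked_decode(self, *args, **kwargs):
        """Full decode flow for one step."""


class KeepAllPolicy(SparsePolicySketch):
    """Toy policy: keeps every block, like FullAttentionPolicy's select_blocks."""

    def select_blocks(self, available_blocks: List[int]) -> List[int]:
        return list(available_blocks)

    def compute_chunked_prefill(self, *args, **kwargs):
        return None  # placeholder: a real policy returns the attention output tensor

    def compute_chunked_decode(self, *args, **kwargs):
        return None  # placeholder: a real policy returns the attention output tensor


class Incomplete(SparsePolicySketch):
    """Implements only select_blocks, so it stays abstract."""

    def select_blocks(self, available_blocks: List[int]) -> List[int]:
        return available_blocks


try:
    Incomplete()
    incomplete_rejected = False
except TypeError:  # abc refuses to instantiate a class with unimplemented abstract methods
    incomplete_rejected = True

print(KeepAllPolicy().select_blocks([3, 1, 4]))  # [3, 1, 4]
print(incomplete_rejected)  # True
```
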
### Hook Methods

| Method | When Called | Purpose |
|---|---|---|
| `initialize()` | After KV cache allocation | Initialize policy resources (e.g., metadata) |
| `on_prefill_offload()` | Before GPU→CPU copy during prefill | Collect block metadata |
| `on_decode_offload()` | Before GPU→CPU copy during decode | Update block metadata |
| `reset()` | New sequence / clear state | Reset policy state |
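
The sketch below illustrates the hook lifecycle with a toy policy that keeps one number of metadata per CPU block. The hook arguments (`num_cpu_blocks`, `cpu_block_id`, `key_summary`) are hypothetical; the table above specifies only the hook names and when they fire.

```python
from typing import Dict, List


class MetadataTrackingPolicy:
    """Toy policy showing when each hook fires (hook signatures are hypothetical)."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.block_meta: Dict[int, float] = {}
        self.capacity = 0

    def initialize(self, num_cpu_blocks: int) -> None:
        # Called once after KV cache allocation: size per-block storage.
        self.calls.append("initialize")
        self.capacity = num_cpu_blocks

    def on_prefill_offload(self, cpu_block_id: int, key_summary: float) -> None:
        # Called before a prefilled block is copied GPU->CPU: record its metadata.
        self.calls.append("on_prefill_offload")
        self.block_meta[cpu_block_id] = key_summary

    def on_decode_offload(self, cpu_block_id: int, key_summary: float) -> None:
        # Called before a decode block is copied GPU->CPU: refresh its metadata.
        self.calls.append("on_decode_offload")
        self.block_meta[cpu_block_id] = key_summary

    def reset(self) -> None:
        # Called for a new sequence: drop all per-sequence state.
        self.calls.append("reset")
        self.block_meta.clear()


policy = MetadataTrackingPolicy()
policy.initialize(num_cpu_blocks=8)
policy.on_prefill_offload(0, 0.5)
policy.on_prefill_offload(1, 0.7)
policy.on_decode_offload(1, 0.9)  # decode updates block 1's entry in place
before_reset = dict(policy.block_meta)
policy.reset()
print(before_reset, policy.block_meta)  # {0: 0.5, 1: 0.9} {}
```
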
## FullAttentionPolicy

**File**: `nanovllm/kvcache/sparse/full_policy.py`

The default policy: it loads all blocks (no sparsity) and serves as the baseline implementation.
### Flags

```python
supports_prefill = True
supports_decode = True
```
### Prefill Flow (`compute_chunked_prefill`)

```
1. Get historical blocks from kvcache_manager
   └── cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

2. Apply select_blocks (returns all for FullPolicy)
   └── cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, ctx)

3. Load and compute historical blocks via ring buffer
   └── For each block:
       a. load_to_slot_layer(slot, layer_id, cpu_block_id)
       b. wait_slot_layer(slot)
       c. prev_k, prev_v = get_kv_for_slot(slot)
       d. prev_o, prev_lse = flash_attn_with_lse(q, prev_k, prev_v, causal=False)
       e. merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

4. Compute current chunk attention (causal)
   └── k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
   └── current_o, current_lse = flash_attn_with_lse(q, k_curr, v_curr, causal=True)

5. Merge historical and current attention
   └── merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
```
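
Steps 3 to 5 rely on the fact that attention over disjoint sets of keys can be recombined from each part's output and log-sum-exp (LSE). The pure-Python sketch below checks that identity for a single query vector, assuming `merge_attention_outputs` performs the standard LSE rescaling. `attn_with_lse` and `merge` are hypothetical scalar stand-ins for the batched GPU kernels, and the causal mask is omitted (the single query sees every key).

```python
import math
from typing import List, Tuple

Vector = List[float]


def attn_with_lse(q: Vector, keys: List[Vector], values: List[Vector], scale: float) -> Tuple[Vector, float]:
    """Softmax attention for one query; also returns the log-sum-exp of the scores."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge(o1: Vector, lse1: float, o2: Vector, lse2: float) -> Tuple[Vector, float]:
    """Combine attention results computed over two disjoint sets of keys."""
    m = max(lse1, lse2)
    lse = m + math.log(math.exp(lse1 - m) + math.exp(lse2 - m))
    # Each partial output is reweighted by its share of the total softmax mass.
    w1, w2 = math.exp(lse1 - lse), math.exp(lse2 - lse)
    return [w1 * a + w2 * b for a, b in zip(o1, o2)], lse


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.0, -0.3], [0.4, 0.8], [-0.5, 0.6]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0]]
scale = 1.0 / math.sqrt(len(q))

# Treat the first two keys as "historical" blocks and the last two as the current chunk.
o_hist, lse_hist = attn_with_lse(q, keys[:2], values[:2], scale)
o_curr, lse_curr = attn_with_lse(q, keys[2:], values[2:], scale)
o_merged, lse_merged = merge(o_hist, lse_hist, o_curr, lse_curr)

o_full, lse_full = attn_with_lse(q, keys, values, scale)
print(all(abs(a - b) < 1e-9 for a, b in zip(o_merged, o_full)))  # True
```

Because the weights are `exp(lse_i - lse)`, the merge never needs the raw keys again, which is what lets each block be loaded, used once, and evicted from its slot.
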
### Decode Flow (`compute_chunked_decode`)

```
1. Get prefilled CPU blocks
   └── cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

2. Calculate last block valid tokens
   └── total_prefill_tokens = kvcache_manager.get_prefill_len(seq)
   └── last_block_valid_tokens = total_prefill_tokens % block_size

3. Apply select_blocks for block filtering
   └── cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, ctx)

4. Load prefilled blocks via ring buffer pipeline
   └── _decode_ring_buffer_pipeline()

5. Read accumulated decode tokens from decode buffer
   └── decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
   └── decode_v = offload_engine.decode_v_buffer[layer_id, start:end]
   └── decode_o, decode_lse = flash_attn_with_lse(q, decode_k, decode_v, causal=False)

6. Merge all results
   └── merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
```
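
As a worked example of step 2, the helper below computes the number of prefilled blocks and the valid-token count of the last one. `prefill_block_layout` and the block size of 1024 are hypothetical. Note that `total_prefill_tokens % block_size` evaluates to 0 when the prefill length is an exact multiple of the block size; this sketch assumes that case means the last block is full, which may differ from how the real code interprets it.

```python
from typing import Tuple


def prefill_block_layout(total_prefill_tokens: int, block_size: int) -> Tuple[int, int]:
    """Return (number of prefilled blocks, valid tokens in the last block).

    Assumption: a remainder of 0 means the last block is completely full.
    """
    num_blocks = -(-total_prefill_tokens // block_size)  # ceiling division
    last_block_valid_tokens = total_prefill_tokens % block_size
    if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
        last_block_valid_tokens = block_size
    return num_blocks, last_block_valid_tokens


print(prefill_block_layout(32000, 1024))  # (32, 256): 31 full blocks + 256 tokens
print(prefill_block_layout(32768, 1024))  # (32, 1024): exact multiple, last block full
```
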
### Ring Buffer Pipeline

The ring buffer pipeline (`_decode_ring_buffer_pipeline`) loads blocks one at a time into GPU ring buffer slots. This approach is memory-efficient and works well for both short and long sequences.
```
Slot[0]: Block A ──> Compute ──> Block C ──> Compute
Slot[1]: Block B ──> Compute ──> Block D ──> Compute
```
**Advantages:**
- Memory efficient (only needs a few GPU slots)
- Fine-grained overlap between H2D transfer and compute
- Works well for long sequences
**Flow:**

```python
# Phase 1: Pre-load up to num_slots blocks
num_blocks = len(cpu_block_table)
num_slots = len(load_slots)
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
    offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])

# Phase 2: Process blocks with pipeline
o_acc, lse_acc = None, None
for block_idx in range(num_blocks):
    current_slot = load_slots[block_idx % num_slots]

    # Wait for transfer
    offload_engine.wait_slot_layer(current_slot)

    # Compute attention
    prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
    prev_o, prev_lse = flash_attn_with_lse(q, prev_k, prev_v, causal=False)
    offload_engine.record_slot_compute_done(current_slot)

    # Pipeline: start loading the block that will reuse this slot
    next_block_idx = block_idx + num_slots
    if next_block_idx < num_blocks:
        offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])

    # Merge results (the first block initializes the accumulator)
    if o_acc is None:
        o_acc, lse_acc = prev_o, prev_lse
    else:
        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
```
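
The slot-reuse pattern can be checked without a GPU. The sketch below replays the flow above against `FakeOffloadEngine`, a hypothetical stand-in that only records which CPU block sits in which slot; it assumes, as the diagram shows, that block `i + num_slots` is loaded into the slot just freed by block `i`. `ring_buffer_schedule` is likewise a hypothetical name.

```python
from typing import List, Optional, Tuple

Event = Tuple[str, int, int]  # (kind, slot, cpu_block_id)


class FakeOffloadEngine:
    """Records which CPU block lands in which slot; no real transfers happen."""

    def __init__(self, num_slots: int) -> None:
        self.slot_contents: List[Optional[int]] = [None] * num_slots
        self.events: List[Event] = []

    def load_to_slot_layer(self, slot: int, layer_id: int, cpu_block_id: int) -> None:
        self.slot_contents[slot] = cpu_block_id
        self.events.append(("load", slot, cpu_block_id))

    def wait_slot_layer(self, slot: int) -> None:
        pass  # a real engine would wait for the slot's H2D transfer here

    def get_kv_for_slot(self, slot: int) -> Optional[int]:
        return self.slot_contents[slot]


def ring_buffer_schedule(cpu_block_table: List[int], load_slots: List[int], layer_id: int = 0) -> List[Event]:
    engine = FakeOffloadEngine(num_slots=max(load_slots) + 1)
    num_blocks, num_slots = len(cpu_block_table), len(load_slots)
    for i in range(min(num_slots, num_blocks)):  # Phase 1: pre-load
        engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
    for block_idx in range(num_blocks):  # Phase 2: compute, then refill the same slot
        slot = load_slots[block_idx % num_slots]
        engine.wait_slot_layer(slot)
        engine.events.append(("compute", slot, engine.get_kv_for_slot(slot)))
        next_block_idx = block_idx + num_slots
        if next_block_idx < num_blocks:
            engine.load_to_slot_layer(slot, layer_id, cpu_block_table[next_block_idx])
    return engine.events


# Blocks A, B, C, D from the diagram, as CPU block ids 10..13, with two slots.
events = ring_buffer_schedule(cpu_block_table=[10, 11, 12, 13], load_slots=[0, 1])
for event in events:
    print(event)
```

Every block is computed exactly once and in order, and a slot is refilled only after the block it held has been consumed.
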
## Code Conventions

### Unsupported Phases Must Assert False
If a policy doesn't support a phase, the corresponding method must assert False:
```python
class PrefillOnlyPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = False

    def compute_chunked_prefill(self, ...):
        # Normal prefill implementation
        ...

    def compute_chunked_decode(self, ...):
        assert False, "PrefillOnlyPolicy does not support decode phase"
```
### Caller Must Check Support Flags

`attention.py` checks support flags before calling:
```python
if not sparse_policy.supports_decode:
    raise RuntimeError(f"{sparse_policy} does not support decode phase")
```
This provides double protection:
- Caller check → Clear error message
- Method assert → Prevents bypassing the check
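
Both layers of protection can be exercised in isolation. In this minimal sketch, `run_decode` is a hypothetical stand-in for the check in `attention.py`, and `SparsePolicy` is reduced to its two flags.

```python
class SparsePolicy:
    """Reduced to the two support flags for this sketch."""

    supports_prefill = True
    supports_decode = True


class PrefillOnlyPolicy(SparsePolicy):
    supports_prefill = True
    supports_decode = False

    def compute_chunked_decode(self, *args, **kwargs):
        assert False, "PrefillOnlyPolicy does not support decode phase"


def run_decode(sparse_policy: SparsePolicy):
    # Caller-side check (hypothetical stand-in for attention.py): clear error before dispatch.
    if not sparse_policy.supports_decode:
        raise RuntimeError(f"{sparse_policy} does not support decode phase")
    return sparse_policy.compute_chunked_decode()


policy = PrefillOnlyPolicy()
caller_error = method_error = None

try:
    run_decode(policy)
except RuntimeError as exc:
    caller_error = str(exc)

try:
    policy.compute_chunked_decode()  # bypasses the caller check
except AssertionError as exc:
    method_error = str(exc)

print(method_error)  # PrefillOnlyPolicy does not support decode phase
```

One caveat of relying on `assert`: Python strips assert statements when run with `-O`, so in optimized mode only the caller-side `RuntimeError` check remains.
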
### CPU-GPU Communication via OffloadEngine Only

All CPU-GPU data transfers must go through `OffloadEngine` methods:
```python
# Correct: Use OffloadEngine methods
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
k, v = offload_engine.get_kv_for_slot(slot)

# Incorrect: Direct torch operations
gpu_tensor.copy_(cpu_tensor)  # DON'T DO THIS
gpu_tensor = cpu_tensor.to("cuda")  # DON'T DO THIS
```
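
One way to see why this rule keeps policies portable: a policy needs only the three methods above, which can be written down as a structural type. `SlotTransferEngine`, `load_block`, and `RecordingEngine` are hypothetical names used for illustration; the real `OffloadEngine` has more methods and returns tensors rather than strings.

```python
from typing import Dict, List, Protocol, Tuple


class SlotTransferEngine(Protocol):
    """Hypothetical structural type: the OffloadEngine calls a policy relies on."""

    def load_to_slot_layer(self, slot: int, layer_id: int, cpu_block_id: int) -> None: ...

    def wait_slot_layer(self, slot: int) -> None: ...

    def get_kv_for_slot(self, slot: int) -> Tuple[str, str]: ...


def load_block(engine: SlotTransferEngine, slot: int, layer_id: int, cpu_block_id: int) -> Tuple[str, str]:
    """Load one CPU block into a GPU slot using only engine methods, then return its K/V."""
    engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
    engine.wait_slot_layer(slot)
    return engine.get_kv_for_slot(slot)


class RecordingEngine:
    """Test double that satisfies SlotTransferEngine without touching a GPU."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.slots: Dict[int, Tuple[str, str]] = {}

    def load_to_slot_layer(self, slot: int, layer_id: int, cpu_block_id: int) -> None:
        self.calls.append("load")
        self.slots[slot] = (f"K{layer_id}:{cpu_block_id}", f"V{layer_id}:{cpu_block_id}")

    def wait_slot_layer(self, slot: int) -> None:
        self.calls.append("wait")

    def get_kv_for_slot(self, slot: int) -> Tuple[str, str]:
        self.calls.append("get")
        return self.slots[slot]


engine = RecordingEngine()
k, v = load_block(engine, slot=0, layer_id=3, cpu_block_id=7)
print(k, v, engine.calls)  # K3:7 V3:7 ['load', 'wait', 'get']
```

Because the policy never touches tensors across devices directly, swapping in a recording double like this is enough to unit-test its transfer ordering.
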
## File Structure

| File | Purpose |
|---|---|
| `nanovllm/kvcache/sparse/policy.py` | Base class, `PolicyContext`, abstract methods |
| `nanovllm/kvcache/sparse/full_policy.py` | `FullAttentionPolicy` implementation |
| `nanovllm/kvcache/sparse/quest.py` | `QuestPolicy` (decode-only Top-K selection) |
| `nanovllm/layers/attention.py` | Attention layer, delegates to policy |
## Policy Implementations

| Policy | `supports_prefill` | `supports_decode` | Description |
|---|---|---|---|
| `FullAttentionPolicy` | True | True | Loads all blocks (baseline) |
| `QuestPolicy` | False | True | Decode-only Top-K selection |
| `XAttentionBSAPolicy` | False | False | Placeholder for future BSA |
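
For a sense of what a non-trivial `select_blocks` might look like, the sketch below ranks blocks with the query-aware criterion from the Quest paper: each block keeps the element-wise minimum and maximum of its keys, and its score is an upper bound on `q·k` for any key in that block. Whether nano-vllm's `QuestPolicy` uses exactly this rule is an assumption; `quest_upper_bound`, `select_top_k_blocks`, and the metadata layout are hypothetical. Per-block bounds of this kind are the sort of metadata the `on_prefill_offload()` hook is described as collecting.

```python
from typing import Dict, List, Tuple

Vector = List[float]


def quest_upper_bound(q: Vector, k_min: Vector, k_max: Vector) -> float:
    """Largest possible q.k for any key k with k_min <= k <= k_max element-wise."""
    # Per channel, q_i * k_i is maximized at one end of [k_min_i, k_max_i].
    return sum(max(qi * lo, qi * hi) for qi, lo, hi in zip(q, k_min, k_max))


def select_top_k_blocks(
    available_blocks: List[int],
    block_bounds: Dict[int, Tuple[Vector, Vector]],
    q: Vector,
    top_k: int,
) -> List[int]:
    """Keep the top_k highest-scoring blocks, preserving their original order."""
    scores = {b: quest_upper_bound(q, *block_bounds[b]) for b in available_blocks}
    chosen = set(sorted(available_blocks, key=lambda b: scores[b], reverse=True)[:top_k])
    return [b for b in available_blocks if b in chosen]


q = [1.0, -1.0]
block_bounds = {
    0: ([0.0, 0.0], [0.1, 0.1]),    # bound 0.1
    1: ([0.5, -0.5], [1.0, 0.0]),   # bound 1.5
    2: ([-1.0, 0.5], [-0.5, 1.0]),  # bound -1.0
    3: ([0.2, -1.0], [0.8, -0.2]),  # bound 1.8
}
print(select_top_k_blocks([0, 1, 2, 3], block_bounds, q, top_k=2))  # [1, 3]
```

Returning the survivors in their original order keeps the ring buffer loading blocks sequentially; the softmax merge itself does not depend on block order.
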
## Testing
Run needle-in-haystack test with offload:
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
    python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct --enable-offload
```
Expected output:
```
Needle-in-Haystack Test
Model: Llama-3.1-8B-Instruct
CPU offload: True
Sparse policy: FULL
Result: PASSED
```