nano-vllm/docs/sparse_policy_implementation_guide.md
Zijie Tian 37aecd4d52 📝 docs: add SparsePolicy implementation guide and update rules
- Create docs/sparse_policy_implementation_guide.md with comprehensive guide
- Rewrite .claude/rules/sparse-policy.md with mandatory base class requirements
- Add new doc reference to CLAUDE.md documentation index

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 02:25:46 +08:00


SparsePolicy Implementation Guide

This guide describes how to implement a custom SparsePolicy for sparse attention in CPU offload mode.

Overview

SparsePolicy is an abstract base class that controls:

  1. Block Selection: Which KV cache blocks to load from CPU for each query
  2. Attention Computation: How to compute chunked prefill and decode attention

All attention computation lives in the policy; attention.py only delegates to the policy's methods.
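As a rough illustration of that delegation, the sketch below shows how an attention layer could route a call to the policy after checking its phase flag. The `attention_forward` function, the `EchoPolicy` stub, and their shortened argument lists are hypothetical; the real methods take the full argument lists documented in the sections below.

```python
from typing import Any


class EchoPolicy:
    """Hypothetical stand-in for a SparsePolicy; returns a tag instead of a tensor."""
    supports_prefill = True
    supports_decode = True

    def compute_chunked_prefill(self, q, layer_id):
        return ("prefill", layer_id)

    def compute_chunked_decode(self, q, layer_id):
        return ("decode", layer_id)


def attention_forward(policy: Any, q: Any, layer_id: int, is_prefill: bool):
    # The layer holds no attention math of its own: it checks the flag and delegates.
    if is_prefill:
        if not policy.supports_prefill:
            raise RuntimeError(f"{type(policy).__name__} does not support prefill")
        return policy.compute_chunked_prefill(q, layer_id)
    if not policy.supports_decode:
        raise RuntimeError(f"{type(policy).__name__} does not support decode")
    return policy.compute_chunked_decode(q, layer_id)


print(attention_forward(EchoPolicy(), q=None, layer_id=3, is_prefill=False))  # ('decode', 3)
```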


Base Class Structure

from abc import ABC, abstractmethod
from typing import List

import torch


class SparsePolicy(ABC):
    # Phase support flags (REQUIRED to override)
    supports_prefill: bool = True
    supports_decode: bool = True

    # Abstract methods (MUST implement)
    @abstractmethod
    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]: ...

    @abstractmethod
    def compute_chunked_prefill(self, q, k, v, layer_id, softmax_scale, offload_engine,
                                kvcache_manager, current_chunk_idx, seq, num_tokens) -> torch.Tensor: ...

    @abstractmethod
    def compute_chunked_decode(self, q, layer_id, softmax_scale, offload_engine,
                               kvcache_manager, seq) -> torch.Tensor: ...

    # Optional hooks (CAN override)
    def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device): ...
    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens): ...
    def on_decode_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens): ...
    def reset(self): ...

Required Implementations

1. Phase Support Flags

Every policy MUST declare which phases it supports:

class MyPolicy(SparsePolicy):
    supports_prefill = True   # Can be used in prefill phase?
    supports_decode = True    # Can be used in decode phase?
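
| Policy type  | supports_prefill | supports_decode | Example             |
|--------------|------------------|-----------------|---------------------|
| Full support | True             | True            | FullAttentionPolicy |
| Decode-only  | False            | True            | QuestPolicy         |
| Prefill-only | True             | False           | (hypothetical)      |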
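The flags make it possible to combine policies per phase, as the comment in the TopKPolicy example below suggests ("Use FullAttentionPolicy for prefill"). A minimal sketch of such a per-phase choice; `choose_policy` and the two stub classes are hypothetical and not part of nano-vllm:

```python
class FullStub:
    supports_prefill = True
    supports_decode = True


class DecodeOnlyStub:
    supports_prefill = False
    supports_decode = True


def choose_policy(primary, fallback, is_prefill: bool):
    """Use `primary` if it supports the current phase, otherwise `fallback`."""
    for policy in (primary, fallback):
        supported = policy.supports_prefill if is_prefill else policy.supports_decode
        if supported:
            return policy
    raise ValueError("neither policy supports this phase")


sparse, full = DecodeOnlyStub(), FullStub()
print(type(choose_policy(sparse, full, is_prefill=True)).__name__)   # FullStub
print(type(choose_policy(sparse, full, is_prefill=False)).__name__)  # DecodeOnlyStub
```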

2. select_blocks() - Block Selection

@abstractmethod
def select_blocks(
    self,
    available_blocks: List[int],  # CPU block IDs with historical KV
    offload_engine: "OffloadEngine",
    ctx: PolicyContext,           # Context about current query
) -> List[int]:
    """Return subset of available_blocks to load."""

PolicyContext fields:

  • query_chunk_idx: Current chunk index (0-indexed)
  • num_query_chunks: Total number of chunks
  • layer_id: Transformer layer index
  • query: Query tensor (available for decode)
  • is_prefill: True if prefill phase
  • block_size: Tokens per block
  • total_kv_len: Total KV length so far
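For orientation, PolicyContext can be pictured as a plain record of these fields. This sketch assumes a dataclass and the types shown (see nanovllm/kvcache/sparse/policy.py for the actual definition); the `num_kv_blocks` helper is hypothetical and only shows how `total_kv_len` and `block_size` relate.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PolicyContext:
    query_chunk_idx: int    # current chunk index (0-indexed)
    num_query_chunks: int   # total number of chunks
    layer_id: int           # transformer layer index
    query: Optional[Any]    # query tensor; assumed None when not available (e.g. prefill)
    is_prefill: bool        # True during the prefill phase
    block_size: int         # tokens per block
    total_kv_len: int       # total KV length so far


def num_kv_blocks(ctx: PolicyContext) -> int:
    """Hypothetical helper: blocks needed to hold total_kv_len tokens (ceiling division)."""
    return -(-ctx.total_kv_len // ctx.block_size)


ctx = PolicyContext(query_chunk_idx=0, num_query_chunks=4, layer_id=5, query=None,
                    is_prefill=True, block_size=256, total_kv_len=1000)
print(num_kv_blocks(ctx))  # 4
```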

Example implementations:

# Full attention: load all blocks
def select_blocks(self, available_blocks, offload_engine, ctx):
    return available_blocks

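# Top-K sparse: load the K highest-scoring blocks, kept in their original order
def select_blocks(self, available_blocks, offload_engine, ctx):
    # compute_block_scores is the policy's own scoring helper (one score per block);
    # ctx.query is only available during decode
    scores = self.compute_block_scores(available_blocks, ctx.query)
    k = min(self.config.topk, len(available_blocks))  # topk() raises if k exceeds the number of blocks
    topk_indices = scores.topk(k).indices
    return [available_blocks[i] for i in sorted(topk_indices.tolist())]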
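Stripped of tensors, the Top-K selection is: score each block, keep the K best, and return them in their original order (the role of `sorted()` above). A self-contained sketch with plain Python lists standing in for tensors; `select_topk_blocks` is a hypothetical name, and the scores are supplied by the caller because how they are computed is policy-specific.

```python
import heapq
from typing import List, Sequence


def select_topk_blocks(available_blocks: List[int], scores: Sequence[float], topk: int) -> List[int]:
    """Return the `topk` highest-scoring block IDs, preserving their original order."""
    if len(scores) != len(available_blocks):
        raise ValueError("need exactly one score per block")
    if topk >= len(available_blocks):
        return list(available_blocks)  # nothing to prune
    # Positions of the K largest scores; nlargest is stable, so ties go to the earlier block
    best = heapq.nlargest(topk, range(len(scores)), key=scores.__getitem__)
    return [available_blocks[i] for i in sorted(best)]


print(select_topk_blocks([10, 11, 12, 13, 14], [0.1, 0.9, 0.3, 0.8, 0.2], topk=2))  # [11, 13]
```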

3. compute_chunked_prefill() - Prefill Attention

@abstractmethod
def compute_chunked_prefill(
    self,
    q: torch.Tensor,              # [seq_len, num_heads, head_dim]
    k: torch.Tensor,              # [seq_len, num_kv_heads, head_dim] (unused)
    v: torch.Tensor,              # [seq_len, num_kv_heads, head_dim] (unused)
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    current_chunk_idx: int,
    seq: "Sequence",
    num_tokens: int,
) -> torch.Tensor:  # [seq_len, num_heads, head_dim]

Required flow:

  1. Get historical blocks: kvcache_manager.get_prefilled_cpu_blocks(seq)
  2. Call select_blocks() to filter blocks
  3. Load blocks via ring buffer pipeline
  4. Get current chunk KV: offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
  5. Compute attention with flash_attn_with_lse(): causal=False for historical blocks (every historical token precedes every query in the chunk) and causal=True for the current chunk
  6. Merge results with merge_attention_outputs()
  7. Return output with shape [seq_len, num_heads, head_dim]
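Step 6 works because each partial attention result carries its log-sum-exp (LSE): two partial outputs computed over disjoint key sets can be rescaled and summed to recover attention over their union. The sketch below checks this for a single query with scalar keys and values in plain Python. It illustrates the math only; it is not the implementation of merge_attention_outputs, which operates on batched tensors.

```python
import math
from typing import List, Tuple


def partial_attention(q: float, keys: List[float], values: List[float], scale: float) -> Tuple[float, float]:
    """Softmax attention of one scalar query over a subset of keys; returns (output, lse)."""
    logits = [q * k * scale for k in keys]
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    out = sum(math.exp(x - lse) * v for x, v in zip(logits, values))
    return out, lse


def merge(o_a: float, lse_a: float, o_b: float, lse_b: float) -> Tuple[float, float]:
    """Combine two partial results over disjoint key sets (numerically stable)."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    # Each part is weighted by its share of the total softmax denominator
    return math.exp(lse_a - lse) * o_a + math.exp(lse_b - lse) * o_b, lse


keys, values, scale = [0.5, -1.0, 2.0, 0.3], [1.0, 2.0, 3.0, 4.0], 0.7
full = partial_attention(1.5, keys, values, scale)
hist = partial_attention(1.5, keys[:2], values[:2], scale)  # "historical" blocks
cur = partial_attention(1.5, keys[2:], values[2:], scale)   # "current" chunk
merged = merge(*hist, *cur)
assert math.isclose(merged[0], full[0]) and math.isclose(merged[1], full[1])
```

Because the merge is associative, the same function can fold in any number of block results one at a time, which is what the ring buffer pipeline does with `o_acc` and `lse_acc`.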

If the policy doesn't support prefill:

def compute_chunked_prefill(self, *args, **kwargs):
    raise NotImplementedError("MyPolicy does not support prefill phase")

4. compute_chunked_decode() - Decode Attention

@abstractmethod
def compute_chunked_decode(
    self,
    q: torch.Tensor,              # [batch_size, num_heads, head_dim]
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    seq: "Sequence",
) -> torch.Tensor:  # [batch_size, 1, num_heads, head_dim]

Required flow:

  1. Get prefilled blocks: kvcache_manager.get_prefilled_cpu_blocks(seq)
  2. Compute how many tokens in the last prefilled block are valid, using kvcache_manager.get_prefill_len(seq)
  3. Call select_blocks() to filter blocks
  4. Load blocks via _decode_ring_buffer_pipeline() helper
  5. Read the decode buffer (offload_engine.decode_k_buffer[layer_id, ...]) and compute attention over the tokens generated so far during decode
  6. Merge results with merge_attention_outputs()
  7. Return output with shape [batch_size, 1, num_heads, head_dim]
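Step 2 is plain arithmetic: only the final prefilled block can be partially filled. A sketch, assuming blocks are filled contiguously from the start of the sequence; `last_block_valid_tokens` is a hypothetical helper name.

```python
def last_block_valid_tokens(prefill_len: int, block_size: int) -> int:
    """Number of valid tokens in the last prefilled CPU block."""
    if prefill_len <= 0:
        return 0
    remainder = prefill_len % block_size
    # A remainder of 0 means the last block is completely full
    return remainder if remainder else block_size


print(last_block_valid_tokens(1000, 256))  # 232 (3 full blocks, then 232 tokens)
print(last_block_valid_tokens(1024, 256))  # 256
```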

If the policy doesn't support decode:

def compute_chunked_decode(self, *args, **kwargs):
    raise NotImplementedError("MyPolicy does not support decode phase")

Optional Hooks

initialize()

Called after KV cache allocation. Use it to create metadata structures.

def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
    self.metadata = BlockMetadataManager(
        num_blocks=num_cpu_blocks,
        num_layers=num_layers,
        ...
    )

on_prefill_offload() / on_decode_offload()

Called BEFORE the GPU→CPU copy. Use it to collect block metadata while the data is still on the GPU.

def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
    # k_cache is still on GPU here
    self.metadata.update_min_max(cpu_block_id, layer_id, k_cache, num_valid_tokens)
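The `update_min_max` call above points at a common metadata choice: the per-block, per-dimension minimum and maximum of the keys, the idea behind Quest-style selection. For every key k in a block, q·k is bounded above by summing, over each dimension d, the larger of q_d·min_d and q_d·max_d, so that sum can rank blocks without loading them. A minimal sketch in plain Python (one layer, one head, one key vector per token); `MinMaxKeyMetadata` and its methods are hypothetical and not nano-vllm's BlockMetadataManager API.

```python
from typing import Dict, List, Tuple


class MinMaxKeyMetadata:
    """Hypothetical per-block key bounds for a single layer and head."""

    def __init__(self) -> None:
        self.bounds: Dict[int, Tuple[List[float], List[float]]] = {}

    def update_min_max(self, cpu_block_id: int, k_cache: List[List[float]], num_valid_tokens: int) -> None:
        keys = k_cache[:num_valid_tokens]  # ignore unused slots in a partially filled block
        mins = [min(col) for col in zip(*keys)]
        maxs = [max(col) for col in zip(*keys)]
        self.bounds[cpu_block_id] = (mins, maxs)

    def score(self, cpu_block_id: int, query: List[float]) -> float:
        mins, maxs = self.bounds[cpu_block_id]
        # Upper bound of q.k over every key stored in the block
        return sum(max(q * lo, q * hi) for q, lo, hi in zip(query, mins, maxs))

    def reset(self) -> None:
        self.bounds.clear()


meta = MinMaxKeyMetadata()
meta.update_min_max(0, [[1.0, -2.0], [3.0, 0.5], [9.0, 9.0]], num_valid_tokens=2)
print(meta.bounds[0])              # ([1.0, -2.0], [3.0, 0.5])
print(meta.score(0, [1.0, -1.0]))  # 5.0
```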

reset()

Called when a new sequence starts. Use it to clear per-sequence state.

def reset(self):
    if self.metadata is not None:
        self.metadata.reset()

CPU-GPU Communication Rules

All CPU↔GPU data movement MUST go through OffloadEngine methods:

# Loading blocks
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
k, v = offload_engine.get_kv_for_slot(slot)
offload_engine.record_slot_compute_done(slot)

# Current chunk KV
k, v = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)

# Decode buffer
decode_k = offload_engine.decode_k_buffer[layer_id, start:end]
decode_v = offload_engine.decode_v_buffer[layer_id, start:end]

NEVER do direct transfers:

# WRONG!
gpu_tensor.copy_(cpu_tensor)
gpu_tensor = cpu_tensor.to("cuda")

Ring Buffer Pipeline Pattern

The standard pattern for loading blocks overlaps the H2D transfer of upcoming blocks with attention compute on the current block:

import torch


def _decode_ring_buffer_pipeline(self, q_batched, cpu_block_table, load_slots,
                                 offload_engine, layer_id, softmax_scale):
    from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

    num_blocks = len(cpu_block_table)
    num_slots = len(load_slots)
    o_acc, lse_acc = None, None

    # Phase 1: Pre-load up to num_slots blocks
    for i in range(min(num_slots, num_blocks)):
        offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])

    # Phase 2: Process with pipeline
    for block_idx in range(num_blocks):
        slot = load_slots[block_idx % num_slots]

        # Wait for H2D transfer
        offload_engine.wait_slot_layer(slot)

        with torch.cuda.stream(offload_engine.compute_stream):
            # Get KV and compute attention
            k, v = offload_engine.get_kv_for_slot(slot)
            o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale, causal=False)
            offload_engine.record_slot_compute_done(slot)

        # Pipeline: start next block transfer
        next_idx = block_idx + num_slots
        if next_idx < num_blocks:
            offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_table[next_idx])

        # Merge results
        with torch.cuda.stream(offload_engine.compute_stream):
            if o_acc is None:
                o_acc, lse_acc = o, lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)

    return o_acc, lse_acc
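To see the ordering this loop produces, the sketch below replays the same two-phase structure with plain bookkeeping instead of CUDA calls, recording when each block is loaded into which slot and when it is computed. It models the issue order only (no streams or events); `ring_buffer_schedule` is a hypothetical name.

```python
from typing import List, Tuple


def ring_buffer_schedule(cpu_block_table: List[int], load_slots: List[int]) -> List[Tuple[str, int, int]]:
    """Return (event, slot, cpu_block_id) tuples in the order the pipeline issues them."""
    events: List[Tuple[str, int, int]] = []
    num_blocks, num_slots = len(cpu_block_table), len(load_slots)
    # Phase 1: pre-load up to num_slots blocks
    for i in range(min(num_slots, num_blocks)):
        events.append(("load", load_slots[i], cpu_block_table[i]))
    # Phase 2: compute on a slot, then refill that slot with the block num_slots ahead
    for block_idx in range(num_blocks):
        slot = load_slots[block_idx % num_slots]
        events.append(("compute", slot, cpu_block_table[block_idx]))
        next_idx = block_idx + num_slots
        if next_idx < num_blocks:
            events.append(("load", slot, cpu_block_table[next_idx]))
    return events


for event in ring_buffer_schedule([100, 101, 102, 103, 104], load_slots=[0, 1]):
    print(event)
```

While block i is being computed, the transfers for up to num_slots - 1 later blocks are already in flight; with a single slot, loads and compute fully serialize.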

Complete Example: Decode-Only Policy

class TopKPolicy(SparsePolicy):
    """Load only top-K blocks based on query-key similarity."""

    supports_prefill = False  # Use FullAttentionPolicy for prefill
    supports_decode = True

    def __init__(self, topk: int = 8):
        self.topk = topk
        self.metadata = None

    def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
        self.metadata = BlockMetadataManager(num_cpu_blocks, num_layers, num_kv_heads, head_dim)

    def select_blocks(self, available_blocks, offload_engine, ctx):
        if len(available_blocks) <= self.topk:
            return available_blocks

        # Compute scores and select top-K
        scores = self.metadata.compute_scores(available_blocks, ctx.layer_id, ctx.query)
        topk_indices = scores.topk(self.topk).indices.cpu().tolist()
        return [available_blocks[i] for i in sorted(topk_indices)]

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        self.metadata.update(cpu_block_id, layer_id, k_cache, num_valid_tokens)

    def compute_chunked_prefill(self, *args, **kwargs):
        raise NotImplementedError("TopKPolicy does not support prefill phase")

    def compute_chunked_decode(self, q, layer_id, softmax_scale, offload_engine, kvcache_manager, seq):
        # Copy implementation from FullAttentionPolicy.compute_chunked_decode
        # The only difference is select_blocks() will filter to top-K
        ...

    def reset(self):
        if self.metadata:
            self.metadata.reset()

File Locations

| File                                   | Purpose                                        |
|----------------------------------------|------------------------------------------------|
| nanovllm/kvcache/sparse/policy.py      | Base class and PolicyContext                   |
| nanovllm/kvcache/sparse/full_policy.py | FullAttentionPolicy (reference implementation) |
| nanovllm/kvcache/sparse/quest.py       | QuestPolicy (decode-only example)              |
| nanovllm/kvcache/chunked_attention.py  | flash_attn_with_lse, merge_attention_outputs   |