[feat] Added chunked prefill and KV cache offload mechanism.

Zijie Tian
2025-12-10 03:47:37 +08:00
parent 204fe2b38f
commit 0b6f19242d
25 changed files with 4414 additions and 61 deletions

View File

@@ -17,6 +17,16 @@ class Config:
kvcache_block_size: int = 256
num_kvcache_blocks: int = -1
# CPU Offload configuration
enable_cpu_offload: bool = False
cpu_memory_gb: float = 16.0 # CPU memory budget for the KV cache, in GiB
offload_policy: str = "lru" # "lru", "fifo", or full class path
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
# Computed fields for offload (set in __post_init__ or by ModelRunner)
num_gpu_kvcache_blocks: int = -1
num_cpu_kvcache_blocks: int = -1
def __post_init__(self):
assert os.path.isdir(self.model)
assert self.kvcache_block_size % 256 == 0
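
A minimal usage sketch of the new offload fields (the model path and values below are placeholders, and Config is assumed to be the dataclass above with its usual model field; everything else keeps its defaults):

from nanovllm.config import Config

config = Config(
    model="/path/to/model",        # placeholder; __post_init__ asserts this is a directory
    enable_cpu_offload=True,
    cpu_memory_gb=16.0,            # host-side KV cache budget, in GiB
    offload_policy="lru",          # or "fifo", or a full class path
    num_transfer_streams=4,        # CUDA streams for async GPU<->CPU copies
)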

View File

@@ -31,7 +31,7 @@ class LLMEngine:
self.model_runner = ModelRunner(config, 0, self.events)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
config.eos = self.tokenizer.eos_token_id
self.scheduler = Scheduler(config)
self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
atexit.register(self.exit)
def exit(self):

View File

@@ -10,6 +10,7 @@ from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
from nanovllm.kvcache import create_kvcache_manager, KVCacheManager
class ModelRunner:
@@ -107,14 +108,45 @@ class ModelRunner:
num_kv_heads = hf_config.num_key_value_heads // self.world_size
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
assert config.num_kvcache_blocks > 0
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
# Calculate GPU block count
num_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
assert num_gpu_blocks > 0
if config.enable_cpu_offload:
# Calculate CPU blocks based on cpu_memory_gb
cpu_bytes = int(config.cpu_memory_gb * 1024**3)
num_cpu_blocks = cpu_bytes // block_bytes
config.num_gpu_kvcache_blocks = num_gpu_blocks
config.num_cpu_kvcache_blocks = num_cpu_blocks
# For backward compatibility
config.num_kvcache_blocks = num_gpu_blocks + num_cpu_blocks
else:
config.num_kvcache_blocks = num_gpu_blocks
config.num_gpu_kvcache_blocks = num_gpu_blocks
config.num_cpu_kvcache_blocks = 0
# Create KV cache manager using factory
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
# Allocate cache through manager
self.kvcache_manager.allocate_cache(
num_layers=hf_config.num_hidden_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=hf_config.torch_dtype,
)
# Bind layer caches to attention modules and set layer_id
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
module.k_cache = self.kv_cache[0, layer_id]
module.v_cache = self.kv_cache[1, layer_id]
k_cache, v_cache = self.kvcache_manager.get_layer_cache(layer_id)
module.k_cache = k_cache
module.v_cache = v_cache
# Set layer_id for chunked prefill support
if hasattr(module, "layer_id"):
module.layer_id = layer_id
layer_id += 1
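
A worked example of the sizing arithmetic above (values are hypothetical, loosely Qwen3-8B-shaped; the real numbers come from hf_config and the measured GPU memory headroom):

num_hidden_layers = 36
num_kv_heads = 8            # key/value heads per tensor-parallel rank
head_dim = 128
block_size = 256
itemsize = 2                # fp16 / bf16

# 2 (K and V) * layers * tokens-per-block * kv_heads * head_dim * bytes-per-element
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
# = 37,748,736 bytes, i.e. 36 MiB per block

free_gpu_bytes = 8 * 1024**3                      # assume ~8 GiB of budget left for KV cache
num_gpu_blocks = free_gpu_bytes // block_bytes    # 227 blocks

cpu_bytes = int(16.0 * 1024**3)                   # cpu_memory_gb = 16.0
num_cpu_blocks = cpu_bytes // block_bytes         # 455 blocks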
def prepare_block_tables(self, seqs: list[Sequence]):
@@ -123,7 +155,30 @@ class ModelRunner:
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
return block_tables
def prepare_prefill(self, seqs: list[Sequence]):
def prepare_prefill(self, seqs: list[Sequence], chunk_info: list[tuple] | None = None):
"""
Prepare inputs for prefill.
Args:
seqs: List of sequences to prefill
chunk_info: Optional chunked prefill info from get_gpu_block_tables_partial().
If provided, only process blocks in the chunk.
Format: [(gpu_block_ids, start_block_idx, end_block_idx), ...]
"""
# Check if any sequence has blocks (not warmup)
has_blocks = any(seq.block_table for seq in seqs)
gpu_block_tables = None
if has_blocks and hasattr(self, 'kvcache_manager'):
if chunk_info is None:
# Standard prefill - try to get all blocks
# This may fail if GPU doesn't have enough capacity
self.kvcache_manager.prepare_for_attention(seqs, is_prefill=True)
gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs)
else:
# Chunked prefill - use provided chunk info
gpu_block_tables = [info[0] for info in chunk_info]
input_ids = []
positions = []
cu_seqlens_q = [0]
@@ -132,27 +187,67 @@ class ModelRunner:
max_seqlen_k = 0
slot_mapping = []
block_tables = None
for seq in seqs:
seqlen = len(seq)
input_ids.extend(seq[seq.num_cached_tokens:])
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
seqlen_q = seqlen - seq.num_cached_tokens
seqlen_k = seqlen
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
if not seq.block_table: # warmup
continue
for i in range(seq.num_cached_blocks, seq.num_blocks):
start = seq.block_table[i] * self.block_size
if i != seq.num_blocks - 1:
end = start + self.block_size
else:
end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
block_tables = self.prepare_block_tables(seqs)
for seq_idx, seq in enumerate(seqs):
if chunk_info is not None:
# Chunked prefill: only process blocks in the chunk
gpu_blocks, start_block_idx, end_block_idx = chunk_info[seq_idx]
if not gpu_blocks:
continue
# Calculate token range for this chunk
start_token = start_block_idx * self.block_size
end_token = min(end_block_idx * self.block_size, len(seq))
if end_block_idx == seq.num_blocks:
# Last chunk includes partial last block
end_token = len(seq)
# Input tokens for this chunk
chunk_tokens = seq[start_token:end_token]
input_ids.extend(chunk_tokens)
positions.extend(list(range(start_token, end_token)))
seqlen_q = end_token - start_token
seqlen_k = end_token # Context includes all tokens up to this point
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
# Slot mapping for blocks in this chunk
for i, gpu_block_id in enumerate(gpu_blocks):
block_idx = start_block_idx + i
start = gpu_block_id * self.block_size
if block_idx != seq.num_blocks - 1:
end = start + self.block_size
else:
end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
else:
# Standard prefill
seqlen = len(seq)
input_ids.extend(seq[seq.num_cached_tokens:])
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
seqlen_q = seqlen - seq.num_cached_tokens
seqlen_k = seqlen
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
if not seq.block_table: # warmup
continue
# Use GPU physical block IDs for slot mapping
gpu_blocks = gpu_block_tables[seq_idx]
for i in range(seq.num_cached_blocks, seq.num_blocks):
start = gpu_blocks[i] * self.block_size
if i != seq.num_blocks - 1:
end = start + self.block_size
else:
end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
if cu_seqlens_k[-1] > cu_seqlens_q[-1] and gpu_block_tables: # prefix cache
block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -162,23 +257,40 @@ class ModelRunner:
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
# Prepare KV cache (updates gather_indices for hybrid manager)
if hasattr(self, 'kvcache_manager'):
self.kvcache_manager.prepare_for_attention(seqs, is_prefill=False)
# Get GPU physical block tables
gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs)
else:
gpu_block_tables = [list(seq.block_table) for seq in seqs]
input_ids = []
positions = []
slot_mapping = []
context_lens = []
for seq in seqs:
for seq_idx, seq in enumerate(seqs):
input_ids.append(seq.last_token)
positions.append(len(seq) - 1)
context_lens.append(len(seq))
slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
# Use GPU physical block ID for slot mapping
gpu_blocks = gpu_block_tables[seq_idx]
slot_mapping.append(gpu_blocks[-1] * self.block_size + seq.last_block_num_tokens - 1)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
block_tables = self.prepare_block_tables(seqs)
# Use GPU physical block tables for attention
block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
return input_ids, positions
def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
"""Prepare block tables tensor from GPU physical block IDs."""
max_len = max(len(bt) for bt in gpu_block_tables)
padded = [bt + [-1] * (max_len - len(bt)) for bt in gpu_block_tables]
return torch.tensor(padded, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
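
For illustration, a tiny worked example of the -1 padding done by _prepare_gpu_block_tables and of the decode slot-mapping arithmetic above (block IDs and last-block occupancies are made up):

gpu_block_tables = [[5, 9, 2], [7]]
# Padded to a rectangular int32 tensor:
#   [[5,  9,  2],
#    [7, -1, -1]]

block_size = 256
last_block_num_tokens = [10, 3]
slots = [bt[-1] * block_size + n - 1
         for bt, n in zip(gpu_block_tables, last_block_num_tokens)]
# -> [2 * 256 + 9, 7 * 256 + 2] == [521, 1794]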
def prepare_sample(self, seqs: list[Sequence]):
temperatures = []
for seq in seqs:
@@ -206,6 +318,26 @@ class ModelRunner:
return self.model.compute_logits(graph_vars["outputs"][:bs])
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
# Check if chunked prefill is needed
if is_prefill and hasattr(self, 'kvcache_manager'):
needs_chunked = any(
hasattr(self.kvcache_manager, 'needs_chunked_prefill') and
self.kvcache_manager.needs_chunked_prefill(seq)
for seq in seqs if seq.block_table
)
if needs_chunked:
return self.run_chunked_prefill(seqs)
# Check if chunked decode is needed
if not is_prefill and hasattr(self, 'kvcache_manager'):
needs_chunked = any(
hasattr(self.kvcache_manager, 'needs_chunked_decode') and
self.kvcache_manager.needs_chunked_decode(seq)
for seq in seqs if seq.block_table
)
if needs_chunked:
return self.run_chunked_decode(seqs)
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
@@ -213,6 +345,204 @@ class ModelRunner:
reset_context()
return token_ids
def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill in chunks when sequences exceed GPU capacity.
For each chunk:
1. Process tokens through model forward pass
2. At each attention layer:
- Load previous KV from CPU (handled by attention layer)
- Compute attention with online softmax merging
- Store current KV to GPU cache
3. After chunk completes, offload KV to CPU
4. Load next chunk's blocks to GPU
"""
import sys
# Currently only supporting single sequence for chunked prefill
assert len(seqs) == 1, "Chunked prefill only supports single sequence"
seq = seqs[0]
total_blocks = seq.num_blocks
print(f"[Chunked Prefill] Starting: {total_blocks} total blocks, "
f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
chunk_num = 0
logits = None
while True:
# Get chunk info (which blocks are on GPU and not yet prefilled)
chunk_info = self.kvcache_manager.get_gpu_block_tables_partial(seqs)
gpu_blocks, start_block_idx, end_block_idx = chunk_info[0]
if not gpu_blocks:
# No more blocks to process
break
chunk_num += 1
chunk_tokens = (end_block_idx - start_block_idx) * self.block_size
if end_block_idx == seq.num_blocks:
# Last block may be partial
chunk_tokens = len(seq) - start_block_idx * self.block_size
print(f"[Chunked Prefill] Chunk {chunk_num}: blocks {start_block_idx}-{end_block_idx-1}, "
f"~{chunk_tokens} tokens", file=sys.stderr)
# Prepare inputs for this chunk
input_ids, positions = self._prepare_chunked_prefill(seq, gpu_blocks, start_block_idx, end_block_idx)
if input_ids.numel() == 0:
print(f"[Chunked Prefill] No input tokens, breaking", file=sys.stderr)
break
print(f"[Chunked Prefill] Running model with {input_ids.numel()} tokens...", file=sys.stderr)
# Run model forward pass
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()
print(f"[Chunked Prefill] Model forward complete", file=sys.stderr)
# Mark the current chunk as prefilled and offload it to CPU
self.kvcache_manager.complete_prefill_chunk(seq)
# Check if more chunks needed
if not self.kvcache_manager.needs_chunked_prefill(seq):
print(f"[Chunked Prefill] All chunks done, sampling", file=sys.stderr)
break
print(f"[Chunked Prefill] Chunk transfer complete, loading next...", file=sys.stderr)
# Sample from the last chunk's logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
if logits is not None:
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
else:
token_ids = [0] if self.rank == 0 else None
return token_ids
def run_chunked_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with chunked attention when sequence exceeds GPU capacity.
For decode, we need attention over ALL previous tokens. With CPU offload,
we load KV chunks and compute attention incrementally.
"""
import sys
# Currently only supporting single sequence for chunked decode
assert len(seqs) == 1, "Chunked decode only supports single sequence"
seq = seqs[0]
total_blocks = len(seq.block_table)
print(f"[Chunked Decode] Sequence has {total_blocks} blocks, "
f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
# Compute slot mapping for the new token
# Get the last block's GPU slot if it's on GPU, otherwise we need to handle it
last_logical_id = seq.block_table[-1]
last_block = self.kvcache_manager.logical_blocks[last_logical_id]
if last_block.location.name == "GPU":
slot = last_block.gpu_slot * self.block_size + seq.last_block_num_tokens - 1
else:
# Last block is on CPU - we need to bring it to GPU for writing the new token
# This is a special case - allocate a temporary GPU slot
# For simplicity, use a fixed slot (this might conflict, but for decode
# we only write 1 token so it should be ok)
print(f"[Chunked Decode] Warning: last block on CPU, using temp slot", file=sys.stderr)
slot = 0 # Use first slot temporarily
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Set up context for chunked decode
set_context(
is_prefill=False, # Decode mode
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention
offload_engine=self.kvcache_manager,
chunked_seq=seq,
)
# Run model forward pass
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
return token_ids
def _prepare_chunked_prefill(
self,
seq: Sequence,
gpu_blocks: list[int],
start_block_idx: int,
end_block_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Prepare inputs for a single chunk in chunked prefill.
Sets up context with is_chunked_prefill=True so attention layers
know to load previous KV from CPU.
"""
# Calculate token range for this chunk
start_token = start_block_idx * self.block_size
end_token = min(end_block_idx * self.block_size, len(seq))
# Input tokens for this chunk
input_ids = seq[start_token:end_token]
positions = list(range(start_token, end_token))
# Slot mapping for storing KV cache
slot_mapping = []
for i, gpu_block_id in enumerate(gpu_blocks):
block_idx = start_block_idx + i
start = gpu_block_id * self.block_size
if block_idx != seq.num_blocks - 1:
end = start + self.block_size
else:
end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
# Trim slot_mapping to match actual token count
actual_tokens = end_token - start_token
slot_mapping = slot_mapping[:actual_tokens]
# Convert to tensors
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Set up context for chunked prefill
seqlen = actual_tokens
cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=seqlen,
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
offload_engine=self.kvcache_manager, # Pass manager for loading previous KV
chunked_seq=seq, # Pass sequence for loading previous KV
)
return input_ids, positions
@torch.inference_mode()
def capture_cudagraph(self):
config = self.config

View File

@@ -1,19 +1,22 @@
from collections import deque
from time import perf_counter_ns
from typing import TYPE_CHECKING
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence, SequenceStatus
from nanovllm.engine.block_manager import BlockManager
from nanovllm.utils.observer import Observer
if TYPE_CHECKING:
from nanovllm.kvcache import KVCacheManager
class Scheduler:
def __init__(self, config: Config):
def __init__(self, config: Config, kvcache_manager: "KVCacheManager"):
self.max_num_seqs = config.max_num_seqs
self.max_num_batched_tokens = config.max_num_batched_tokens
self.eos = config.eos
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
self.kvcache_manager = kvcache_manager
self.waiting: deque[Sequence] = deque()
self.running: deque[Sequence] = deque()
@@ -32,10 +35,10 @@ class Scheduler:
if Observer.ttft_start == 0:
Observer.ttft_start = perf_counter_ns()
seq = self.waiting[0]
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.kvcache_manager.can_allocate(seq):
break
num_seqs += 1
self.block_manager.allocate(seq)
self.kvcache_manager.allocate(seq)
num_batched_tokens += len(seq) - seq.num_cached_tokens
seq.status = SequenceStatus.RUNNING
self.waiting.popleft()
@@ -47,7 +50,7 @@ class Scheduler:
# decode
while self.running and num_seqs < self.max_num_seqs:
seq = self.running.popleft()
while not self.block_manager.can_append(seq):
while not self.kvcache_manager.can_append(seq):
if self.running:
self.preempt(self.running.pop())
else:
@@ -55,7 +58,7 @@ class Scheduler:
break
else:
num_seqs += 1
self.block_manager.may_append(seq)
self.kvcache_manager.may_append(seq)
scheduled_seqs.append(seq)
assert scheduled_seqs
self.running.extendleft(reversed(scheduled_seqs))
@@ -63,7 +66,7 @@ class Scheduler:
def preempt(self, seq: Sequence):
seq.status = SequenceStatus.WAITING
self.block_manager.deallocate(seq)
self.kvcache_manager.deallocate(seq)
self.waiting.appendleft(seq)
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
@@ -71,5 +74,5 @@ class Scheduler:
seq.append_token(token_id)
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED
self.block_manager.deallocate(seq)
self.kvcache_manager.deallocate(seq)
self.running.remove(seq)

View File

@@ -0,0 +1,74 @@
"""
KV Cache management module.
This module provides pluggable KV cache management strategies:
- GPUOnlyManager: Pure GPU (default, current nano-vllm behavior)
- HybridKVCacheManager: CPU offload with CUDA Graph support
Usage:
from nanovllm.kvcache import create_kvcache_manager
manager = create_kvcache_manager(config)
"""
from typing import TYPE_CHECKING
from nanovllm.kvcache.base_manager import KVCacheManager
from nanovllm.kvcache.gpu_manager import GPUOnlyManager
if TYPE_CHECKING:
from nanovllm.config import Config
def create_kvcache_manager(config: "Config") -> KVCacheManager:
"""
Factory function to create the appropriate KV cache manager.
Decision logic:
1. If enable_cpu_offload=False: use GPUOnlyManager
2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager
3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager
Args:
config: Model configuration with offload settings
Returns:
KVCacheManager instance
"""
if not getattr(config, 'enable_cpu_offload', False):
# Default: pure GPU mode
return GPUOnlyManager(
num_blocks=config.num_kvcache_blocks,
block_size=config.kvcache_block_size,
)
# CPU offload is enabled
num_gpu_blocks = config.num_gpu_kvcache_blocks
num_cpu_blocks = config.num_cpu_kvcache_blocks
if num_cpu_blocks <= 0:
# All blocks fit in GPU, use pure GPU mode
return GPUOnlyManager(
num_blocks=num_gpu_blocks,
block_size=config.kvcache_block_size,
)
# Need CPU offload: use hybrid manager
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.policies import get_policy
policy = get_policy(getattr(config, 'offload_policy', 'lru'))
return HybridKVCacheManager(
num_gpu_slots=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=config.kvcache_block_size,
policy=policy,
)
__all__ = [
"KVCacheManager",
"GPUOnlyManager",
"create_kvcache_manager",
]
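
A usage sketch of the factory's decision logic; SimpleNamespace stands in for a real Config (the attribute names match what the factory reads) and the block counts are arbitrary:

from types import SimpleNamespace
from nanovllm.kvcache import create_kvcache_manager

# Offload disabled -> GPUOnlyManager
cfg = SimpleNamespace(
    enable_cpu_offload=False,
    num_kvcache_blocks=512,
    kvcache_block_size=256,
)
manager = create_kvcache_manager(cfg)   # GPUOnlyManager

# Offload enabled with CPU overflow -> HybridKVCacheManager
cfg = SimpleNamespace(
    enable_cpu_offload=True,
    num_gpu_kvcache_blocks=128,
    num_cpu_kvcache_blocks=384,
    kvcache_block_size=256,
    offload_policy="lru",
)
manager = create_kvcache_manager(cfg)   # HybridKVCacheManager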

View File

@@ -0,0 +1,260 @@
"""
Abstract base class for KV cache managers.
This interface allows pluggable implementations:
- GPUOnlyManager: Pure GPU (current nano-vllm behavior)
- HybridKVCacheManager: CPU offload with CUDA Graph support
- Future: Disk offload, distributed cache, etc.
"""
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional
import torch
from torch import Tensor
from nanovllm.engine.sequence import Sequence
class KVCacheManager(ABC):
"""
Abstract base class for KV cache management strategies.
A KVCacheManager handles:
1. Physical memory allocation (GPU and optionally CPU)
2. Logical block management (allocation, deallocation, prefix caching)
3. Data transfer between devices (for hybrid managers)
4. Integration with CUDA graphs
Key design principles:
- Sequences reference logical block IDs
- Physical block IDs (GPU slots) may differ from logical IDs
- CUDA Graph compatibility requires fixed tensor addresses
"""
@property
@abstractmethod
def block_size(self) -> int:
"""Number of tokens per block."""
pass
@property
@abstractmethod
def num_free_blocks(self) -> int:
"""Number of free logical blocks available for allocation."""
pass
@abstractmethod
def allocate_cache(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
) -> None:
"""
Allocate KV cache storage.
Called once during initialization to allocate GPU (and optionally CPU)
memory for the KV cache.
Args:
num_layers: Number of transformer layers
num_kv_heads: Number of key-value heads per layer
head_dim: Dimension per head
dtype: Data type for cache (e.g., torch.float16)
"""
pass
@abstractmethod
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
Get K and V cache tensors for a specific layer.
The returned tensors must be on GPU and have fixed addresses
for CUDA Graph compatibility.
Args:
layer_id: Layer index
Returns:
(k_cache, v_cache) tensors
Shape depends on implementation, typically:
[num_blocks, block_size, kv_heads, head_dim]
"""
pass
@abstractmethod
def can_allocate(self, seq: Sequence) -> bool:
"""
Check if blocks can be allocated for a new sequence.
Called before allocate() to ensure sufficient resources.
Args:
seq: Sequence to check
Returns:
True if allocation is possible
"""
pass
@abstractmethod
def allocate(self, seq: Sequence) -> None:
"""
Allocate blocks for a sequence during prefill.
This method:
1. Checks prefix cache for matching blocks
2. Allocates new blocks as needed
3. Updates seq.block_table with logical block IDs
4. Updates seq.num_cached_tokens for prefix cache hits
Args:
seq: Sequence to allocate blocks for
"""
pass
@abstractmethod
def deallocate(self, seq: Sequence) -> None:
"""
Release blocks for a finished sequence.
This method:
1. Decrements reference counts
2. Frees blocks with zero references
3. Clears seq.block_table
Args:
seq: Sequence whose blocks to release
"""
pass
@abstractmethod
def can_append(self, seq: Sequence) -> bool:
"""
Check if a new block can be allocated for decode.
Called before may_append() to check if resources are available.
Args:
seq: Sequence to check
Returns:
True if append is possible (or no new block needed)
"""
pass
@abstractmethod
def may_append(self, seq: Sequence) -> None:
"""
Potentially allocate a new block during decode.
Called after each decode step. If the current block is full,
allocates a new block and updates seq.block_table.
Args:
seq: Sequence that may need a new block
"""
pass
@abstractmethod
def prepare_for_attention(
self,
seqs: List[Sequence],
is_prefill: bool,
) -> None:
"""
Prepare KV cache for attention computation.
For GPU-only managers: typically a no-op.
For hybrid managers: ensures all needed blocks are on GPU,
may trigger prefetching from CPU.
Called before attention computation. For decode with CUDA graphs,
this should update gather_indices but not perform actual transfers
(transfers happen inside the graph).
Args:
seqs: Sequences that will be processed
is_prefill: True for prefill phase, False for decode
"""
pass
@abstractmethod
def get_gpu_block_tables(
self,
seqs: List[Sequence],
) -> List[List[int]]:
"""
Get GPU physical block tables for sequences.
For GPU-only managers: returns seq.block_table directly.
For hybrid managers: returns GPU slot IDs (may differ from logical IDs).
The returned block tables are used to compute slot_mapping
in ModelRunner.prepare_prefill/decode.
Args:
seqs: Sequences to get block tables for
Returns:
List of GPU block tables, one per sequence
"""
pass
def post_attention_cleanup(
self,
seqs: List[Sequence],
is_prefill: bool,
) -> None:
"""
Cleanup after attention computation.
Optional hook for managers to perform post-attention tasks:
- Offloading cold blocks to CPU
- Updating access statistics
- etc.
Default implementation does nothing.
Args:
seqs: Sequences that were processed
is_prefill: True for prefill phase, False for decode
"""
pass
def get_num_blocks_needed(self, num_tokens: int) -> int:
"""
Calculate number of blocks needed for given token count.
Args:
num_tokens: Number of tokens
Returns:
Number of blocks needed
"""
return (num_tokens + self.block_size - 1) // self.block_size
@staticmethod
def compute_hash(token_ids: list, prefix: int = -1) -> int:
"""
Compute hash for prefix caching.
Uses xxhash for fast hashing. The hash includes the prefix hash
to create a chain of hashes for multi-block sequences.
Args:
token_ids: Token IDs in the block
prefix: Hash of previous block, or -1 for first block
Returns:
Hash value
"""
import xxhash
import numpy as np
h = xxhash.xxh64()
if prefix != -1:
h.update(prefix.to_bytes(8, "little"))
h.update(np.array(token_ids).tobytes())
return h.intdigest()
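
A small standalone sketch tying the two helpers together: ceil-division block counting and the chained prefix hash used for prefix caching (toy token IDs; requires xxhash and numpy):

from nanovllm.kvcache.base_manager import KVCacheManager

block_size = 256
prompt_ids = list(range(1000))                      # toy token IDs

# get_num_blocks_needed: (1000 + 255) // 256 == 4 blocks
# (3 full blocks of 256 tokens plus a partial block of 232 tokens).

# Hash chain over the full blocks only:
h = -1
for i in range(len(prompt_ids) // block_size):
    block = prompt_ids[i * block_size:(i + 1) * block_size]
    h = KVCacheManager.compute_hash(block, prefix=h)
# Two prompts that share the same leading blocks produce the same hash chain,
# which is what the managers key their prefix-cache lookups on.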

View File

@@ -0,0 +1,555 @@
"""
Chunked attention implementation for CPU KV cache offloading.
This module implements flash attention with LSE (log-sum-exp) output,
enabling proper online softmax merging for chunked prefill.
Key functions:
- flash_attn_with_lse: Flash attention that returns output and LSE
- merge_attention_outputs: Merge outputs from multiple KV chunks
- chunked_prefill_attention: High-level interface for chunked attention
"""
import math
import torch
import triton
import triton.language as tl
from typing import Tuple, List, Optional
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_kernel_with_lse(
Q,
K,
V,
Out,
Lse,
TMP,
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Flash attention forward kernel with LSE output for online softmax."""
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
q_ptrs = (
Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
)
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# Load Q
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
)
# Loop over k, v and update accumulator
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# Load K
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# Compute QK
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
# Masking
if not EVEN_N:
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
# Online softmax
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
p = tl.exp(qk * softmax_scale - m_ij[:, None])
l_ij = tl.sum(p, 1)
# Scale acc_o
acc_o_scale = tl.exp(m_i - m_ij)
tl.store(t_ptrs, acc_o_scale)
acc_o_scale = tl.load(t_ptrs)
acc_o = acc_o * acc_o_scale[:, None]
# Load V and update output
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# Update statistics
m_i = m_ij
l_i_new = tl.exp(lse_i - m_ij) + l_ij
lse_i = m_ij + tl.log(l_i_new)
# Final scaling
o_scale = tl.exp(m_i - lse_i)
tl.store(t_ptrs, o_scale)
o_scale = tl.load(t_ptrs)
acc_o = acc_o * o_scale[:, None]
# Store LSE
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
tl.store(lse_ptrs, lse_i)
# Store output
out_ptrs = (
Out
+ off_b * stride_ob
+ off_h * stride_oh
+ (offs_m[:, None] * stride_om + offs_d[None, :])
)
if EVEN_M:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o)
else:
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
else:
tl.store(
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
def flash_attn_with_lse(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Flash attention forward pass that returns both output and LSE.
Supports GQA (grouped query attention) where num_kv_heads < num_q_heads.
Args:
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
softmax_scale: Scaling factor (default: 1/sqrt(headdim))
causal: Whether to apply causal masking
Returns:
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
"""
# Ensure contiguous
if not q.is_contiguous():
q = q.contiguous()
if not k.is_contiguous():
k = k.contiguous()
if not v.is_contiguous():
v = v.contiguous()
batch, seqlen_q, nheads_q, headdim = q.shape
_, seqlen_k, nheads_kv, _ = k.shape
assert k.shape == (batch, seqlen_k, nheads_kv, headdim)
assert v.shape == (batch, seqlen_k, nheads_kv, headdim)
assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
assert q.dtype == k.dtype == v.dtype
# Handle GQA by repeating K/V heads
if nheads_kv != nheads_q:
assert nheads_q % nheads_kv == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_kv ({nheads_kv})"
repeat_factor = nheads_q // nheads_kv
# [batch, seqlen_k, nheads_kv, headdim] -> [batch, seqlen_k, nheads_q, headdim]
k = k.repeat_interleave(repeat_factor, dim=2)
v = v.repeat_interleave(repeat_factor, dim=2)
nheads = nheads_q
else:
nheads = nheads_q
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
out = torch.empty_like(q)
BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
BLOCK = 128
num_warps = 4 if headdim <= 64 else 8
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
_fwd_kernel_with_lse[grid](
q,
k,
v,
out,
lse,
tmp,
softmax_scale,
q.stride(0),
q.stride(2),
q.stride(1),
k.stride(0),
k.stride(2),
k.stride(1),
v.stride(0),
v.stride(2),
v.stride(1),
out.stride(0),
out.stride(2),
out.stride(1),
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
seqlen_q // 32,
seqlen_k // 32,
causal,
BLOCK_HEADDIM,
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
# Trim LSE to actual seqlen_q
lse = lse[:, :, :seqlen_q]
# Ensure output has same dtype as input
out = out.to(q.dtype)
return out, lse
def merge_attention_outputs(
o1: torch.Tensor,
lse1: torch.Tensor,
o2: torch.Tensor,
lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Merge two attention outputs using online softmax.
This implements the online softmax merging formula:
- m_new = max(lse1, lse2)
- o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
- lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))
Args:
o1: First output [batch, seqlen_q, nheads, headdim]
lse1: First LSE [batch, nheads, seqlen_q]
o2: Second output [batch, seqlen_q, nheads, headdim]
lse2: Second LSE [batch, nheads, seqlen_q]
Returns:
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
lse_merged: Merged LSE [batch, nheads, seqlen_q]
"""
# lse shape: [batch, nheads, seqlen_q]
# o shape: [batch, seqlen_q, nheads, headdim]
# Compute max for numerical stability
max_lse = torch.maximum(lse1, lse2)
# Compute scaling factors
# exp1, exp2 shape: [batch, nheads, seqlen_q]
exp1 = torch.exp(lse1 - max_lse)
exp2 = torch.exp(lse2 - max_lse)
# Reshape for broadcasting with output
# [batch, nheads, seqlen_q] -> [batch, seqlen_q, nheads, 1]
exp1_broad = exp1.transpose(1, 2).unsqueeze(-1)
exp2_broad = exp2.transpose(1, 2).unsqueeze(-1)
# Merge outputs
sum_exp = exp1_broad + exp2_broad
o_merged = (o1 * exp1_broad + o2 * exp2_broad) / sum_exp
# Compute merged LSE
lse_merged = max_lse + torch.log(exp1 + exp2)
# Ensure output has same dtype as input
o_merged = o_merged.to(o1.dtype)
return o_merged, lse_merged
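
A CPU-only sanity check of the merge formula above, sketched with a hypothetical naive_attn_with_lse helper in plain PyTorch (the Triton kernel is not needed for this check):

import torch

def naive_attn_with_lse(q, k, v, scale):
    # q: [B, Sq, H, D], k/v: [B, Sk, H, D] -> out: [B, Sq, H, D], lse: [B, H, Sq]
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
    lse = torch.logsumexp(scores, dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(-1), v)
    return out, lse

B, Sq, Sk, H, D = 1, 4, 16, 2, 8
q, k, v = (torch.randn(B, S, H, D) for S in (Sq, Sk, Sk))
scale = D ** -0.5

ref, _ = naive_attn_with_lse(q, k, v, scale)                     # full attention
o1, lse1 = naive_attn_with_lse(q, k[:, :8], v[:, :8], scale)     # first KV chunk
o2, lse2 = naive_attn_with_lse(q, k[:, 8:], v[:, 8:], scale)     # second KV chunk
merged, _ = merge_attention_outputs(o1, lse1, o2, lse2)
assert torch.allclose(merged, ref, atol=1e-5)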
def chunked_attention_varlen(
q: torch.Tensor,
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
cu_seqlens_q: torch.Tensor,
cu_seqlens_k_list: List[torch.Tensor],
max_seqlen_q: int,
max_seqlen_k_list: List[int],
softmax_scale: Optional[float] = None,
causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
"""
Compute attention with KV split across multiple chunks.
This is the core function for chunked prefill. It computes attention
against each KV chunk and merges results using online softmax.
For causal attention with chunked KV:
- Current chunk (the last entry in kv_chunks): apply causal mask
- Previous chunks: no causal mask (all earlier tokens are valid context)
Args:
q: Query tensor [total_q_tokens, nheads, headdim]
kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
max_seqlen_q: Maximum query sequence length
max_seqlen_k_list: List of maximum key sequence lengths for each chunk
softmax_scale: Scaling factor
causal_mask_per_chunk: Whether to apply causal mask for each chunk
Returns:
out: Output tensor [total_q_tokens, nheads, headdim]
"""
if len(kv_chunks) == 0:
raise ValueError("Need at least one KV chunk")
nheads = q.shape[1]
headdim = q.shape[2]
batch = cu_seqlens_q.shape[0] - 1
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(headdim)
if causal_mask_per_chunk is None:
# Default: causal for last chunk only
causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]
# Initialize accumulated output and LSE
accumulated_o = None
accumulated_lse = None
for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
is_causal = causal_mask_per_chunk[chunk_idx]
# Reshape Q for batch processing
# For varlen, we need to handle each sequence separately
# For simplicity, assume single sequence (batch=1) for now
q_batched = q.unsqueeze(0) # [1, total_q, nheads, headdim]
# Compute attention for this chunk
chunk_o, chunk_lse = flash_attn_with_lse(
q_batched,
k_chunk,
v_chunk,
softmax_scale=softmax_scale,
causal=is_causal,
)
# Merge with accumulated
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse,
)
# Remove batch dimension
return accumulated_o.squeeze(0)
class ChunkedPrefillState:
"""
State for tracking chunked prefill progress.
This class maintains the accumulated attention output and LSE
across multiple prefill chunks.
"""
def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
self.num_layers = num_layers
self.dtype = dtype
self.device = device
# Per-layer accumulated outputs
# Each entry: (accumulated_output, accumulated_lse) or None
self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
None for _ in range(num_layers)
]
# Track which chunks have been processed
self.processed_chunks: int = 0
def update_layer(
self,
layer_id: int,
chunk_output: torch.Tensor,
chunk_lse: torch.Tensor,
):
"""Update accumulated state for a layer with a new chunk's output."""
if self.layer_states[layer_id] is None:
self.layer_states[layer_id] = (chunk_output, chunk_lse)
else:
acc_o, acc_lse = self.layer_states[layer_id]
merged_o, merged_lse = merge_attention_outputs(
acc_o, acc_lse,
chunk_output, chunk_lse,
)
self.layer_states[layer_id] = (merged_o, merged_lse)
def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
"""Get the final accumulated output for a layer."""
if self.layer_states[layer_id] is None:
return None
return self.layer_states[layer_id][0]
def clear(self):
"""Clear all accumulated state."""
self.layer_states = [None for _ in range(self.num_layers)]
self.processed_chunks = 0
# Test function
def _test_chunked_attention():
"""Test chunked attention correctness against full attention."""
from flash_attn import flash_attn_func
torch.manual_seed(42)
batch, seqlen, nheads, headdim = 1, 1024, 32, 128
# Generate random Q, K, V
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
# Full attention (reference)
out_ref = flash_attn_func(q, k, v, causal=True)
# Chunked attention
chunk_size = 256
num_chunks = seqlen // chunk_size
accumulated_o = None
accumulated_lse = None
for i in range(num_chunks):
start = i * chunk_size
end = (i + 1) * chunk_size
# Q for this chunk
q_chunk = q[:, start:end, :, :]
# K, V up to current position (for causal)
k_context = k[:, :end, :, :]
v_context = v[:, :end, :, :]
# Compute attention
chunk_o, chunk_lse = flash_attn_with_lse(
q_chunk, k_context, v_context, causal=True
)
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
# For chunked prefill, we need to concatenate outputs, not merge
# Because each chunk's Q attends to different K positions
accumulated_o = torch.cat([accumulated_o, chunk_o], dim=1)
# Compare
max_diff = (out_ref - accumulated_o).abs().max().item()
print(f"Max difference: {max_diff}")
assert max_diff < 1e-2, f"Chunked attention differs from reference: {max_diff}"
print("Test passed!")
if __name__ == "__main__":
_test_chunked_attention()

View File

@@ -0,0 +1,262 @@
"""
GPU-only KV cache manager.
This is the default manager when CPU offload is disabled.
Refactored from the original block_manager.py to implement
the KVCacheManager interface.
"""
from collections import deque
from typing import List, Tuple, Dict, Optional
import torch
from torch import Tensor
from nanovllm.engine.sequence import Sequence
from nanovllm.kvcache.base_manager import KVCacheManager
class Block:
"""Physical block in GPU memory."""
def __init__(self, block_id: int):
self.block_id = block_id
self.ref_count = 0
self.hash = -1
self.token_ids: List[int] = []
def update(self, hash: int, token_ids: List[int]):
self.hash = hash
self.token_ids = token_ids
def reset(self):
self.ref_count = 1
self.hash = -1
self.token_ids = []
class GPUOnlyManager(KVCacheManager):
"""
Pure GPU KV cache manager.
This is the default implementation when enable_cpu_offload=False.
All KV cache resides in GPU memory.
Features:
- Paged attention with configurable block size
- Prefix caching via xxhash
- Reference counting for block sharing
This manager is fully compatible with CUDA graphs since
all data stays on GPU at fixed addresses.
"""
def __init__(self, num_blocks: int, block_size: int):
"""
Initialize GPU-only manager.
Args:
num_blocks: Total number of blocks to manage
block_size: Tokens per block (default 256)
"""
self._block_size = block_size
self._num_blocks = num_blocks
# Block metadata
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
# Prefix cache: hash -> block_id
self.hash_to_block_id: Dict[int, int] = {}
# Free/used tracking
self.free_block_ids: deque[int] = deque(range(num_blocks))
self.used_block_ids: set[int] = set()
# KV cache tensors (set by allocate_cache)
self.kv_cache: Optional[Tensor] = None
self.num_layers: int = 0
self.num_kv_heads: int = 0
self.head_dim: int = 0
@property
def block_size(self) -> int:
return self._block_size
@property
def num_free_blocks(self) -> int:
return len(self.free_block_ids)
def allocate_cache(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
) -> None:
"""Allocate GPU KV cache tensor."""
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
# Shape: [2, num_layers, num_blocks, block_size, kv_heads, head_dim]
# 2 for K and V
self.kv_cache = torch.empty(
2, num_layers, self._num_blocks, self._block_size,
num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
"""Get K/V cache for a layer."""
assert self.kv_cache is not None, "Cache not allocated"
return self.kv_cache[0, layer_id], self.kv_cache[1, layer_id]
def _allocate_block(self, block_id: int) -> Block:
"""Internal: allocate a specific block."""
block = self.blocks[block_id]
assert block.ref_count == 0, f"Block {block_id} is not free"
block.reset()
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
return block
def _deallocate_block(self, block_id: int) -> None:
"""Internal: deallocate a block."""
assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
def can_allocate(self, seq: Sequence) -> bool:
"""Check if we have enough blocks for the sequence."""
return len(self.free_block_ids) >= seq.num_blocks
def allocate(self, seq: Sequence) -> None:
"""
Allocate blocks for a sequence during prefill.
Implements prefix caching: if a block's content matches
a previously cached block, reuse it instead of allocating new.
"""
assert not seq.block_table, "Sequence already has blocks allocated"
h = -1 # Hash chain
cache_miss = False
for i in range(seq.num_blocks):
token_ids = seq.block(i)
# Only compute hash for full blocks
if len(token_ids) == self._block_size:
h = self.compute_hash(token_ids, h)
else:
h = -1
# Try prefix cache lookup
block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
if cache_miss:
# Cache miss: allocate new block
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
else:
# Cache hit: reuse existing block
seq.num_cached_tokens += self._block_size
if block_id in self.used_block_ids:
# Block is in use, increment ref count
block = self.blocks[block_id]
block.ref_count += 1
else:
# Block was freed but hash still valid
block = self._allocate_block(block_id)
# Update hash mapping for full blocks
if h != -1:
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
seq.block_table.append(block_id)
def deallocate(self, seq: Sequence) -> None:
"""Release all blocks for a sequence."""
for block_id in reversed(seq.block_table):
block = self.blocks[block_id]
block.ref_count -= 1
if block.ref_count == 0:
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.block_table.clear()
def can_append(self, seq: Sequence) -> bool:
"""Check if we can append a token (may need new block)."""
# A new block is needed when the last appended token started a fresh block
need_new_block = (len(seq) % self._block_size == 1)
return len(self.free_block_ids) >= int(need_new_block)
def may_append(self, seq: Sequence) -> None:
"""Handle potential new block allocation during decode."""
block_table = seq.block_table
last_block = self.blocks[block_table[-1]]
seq_len = len(seq)
pos_in_block = seq_len % self._block_size
if pos_in_block == 1:
# Just crossed into new block, need to allocate
assert last_block.hash != -1, "Previous block should be complete"
block_id = self.free_block_ids[0]
self._allocate_block(block_id)
block_table.append(block_id)
elif pos_in_block == 0:
# Just filled a block, compute hash for prefix cache
assert last_block.hash == -1, "Block should not have hash yet"
token_ids = seq.block(seq.num_blocks - 1)
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
h = self.compute_hash(token_ids, prefix)
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id
else:
# Middle of block, nothing to do
assert last_block.hash == -1
def prepare_for_attention(
self,
seqs: List[Sequence],
is_prefill: bool,
) -> None:
"""
No-op for GPU-only manager.
All blocks are already on GPU, no preparation needed.
"""
pass
def get_gpu_block_tables(
self,
seqs: List[Sequence],
) -> List[List[int]]:
"""
Return block tables directly (logical = physical for GPU-only).
"""
return [list(seq.block_table) for seq in seqs]
def post_attention_cleanup(
self,
seqs: List[Sequence],
is_prefill: bool,
) -> None:
"""No-op for GPU-only manager."""
pass
def __repr__(self) -> str:
return (
f"GPUOnlyManager("
f"num_blocks={self._num_blocks}, "
f"block_size={self._block_size}, "
f"free={len(self.free_block_ids)}, "
f"used={len(self.used_block_ids)}"
f")"
)
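
A small worked example of the block-boundary arithmetic that can_append/may_append rely on; len(seq) here already includes the token whose KV will be written in the current decode step:

block_size = 256

for seq_len in (255, 256, 257, 258):
    pos_in_block = seq_len % block_size
    if pos_in_block == 1:
        action = "allocate a new block for the token that just started it"
    elif pos_in_block == 0:
        action = "block just filled: hash it and register it in the prefix cache"
    else:
        action = "nothing to do"
    print(f"{seq_len}: {action}")
# 255: nothing to do
# 256: block just filled: hash it and register it in the prefix cache
# 257: allocate a new block for the token that just started it
# 258: nothing to do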

View File

@@ -0,0 +1,906 @@
"""
Hybrid CPU-GPU KV cache manager with CUDA Graph support.
Key design for CUDA Graph compatibility:
1. GPU buffer has fixed addresses (allocated once)
2. CPU pool has fixed addresses (pinned memory)
3. gather_indices tensor has fixed address, variable content
4. H2D transfer uses gathered_copy kernel inside CUDA graphs
5. Graph replay only needs index updates (tiny overhead)
"""
from collections import deque
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Tuple, Dict, Set, Optional
import torch
from torch import Tensor
from nanovllm.engine.sequence import Sequence
from nanovllm.kvcache.base_manager import KVCacheManager
from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.policies.base_policy import EvictionPolicy
from nanovllm.kvcache.policies.lru_policy import LRUPolicy
class BlockLocation(Enum):
"""Where a logical block's data currently resides."""
GPU = auto()
CPU = auto()
INVALID = auto() # Not yet written / deallocated
@dataclass
class LogicalBlock:
"""
Logical block that can be mapped to GPU or CPU physical storage.
Sequences reference logical blocks. Physical blocks are the actual
storage locations (GPU slots or CPU blocks).
"""
logical_id: int
location: BlockLocation = BlockLocation.INVALID
gpu_slot: int = -1 # GPU buffer slot ID (if on GPU)
cpu_block_id: int = -1 # CPU pool block ID (if on CPU)
ref_count: int = 0
hash: int = -1
token_ids: List[int] = field(default_factory=list)
def reset(self):
self.location = BlockLocation.INVALID
self.gpu_slot = -1
self.cpu_block_id = -1
self.ref_count = 0
self.hash = -1
self.token_ids = []
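
For orientation, a sketch of the bookkeeping a single LogicalBlock goes through as it moves between devices (slot and block IDs are illustrative; the transitions mirror _evict_to_cpu and _ensure_on_gpu below):

from nanovllm.kvcache.hybrid_manager import BlockLocation, LogicalBlock

blk = LogicalBlock(logical_id=42)

# Written during prefill: resident in GPU slot 7.
blk.location, blk.gpu_slot, blk.ref_count = BlockLocation.GPU, 7, 1

# Evicted by _evict_to_cpu(): data copied to CPU block 130, GPU slot released.
blk.location, blk.gpu_slot, blk.cpu_block_id = BlockLocation.CPU, -1, 130

# Prefetched back by _ensure_on_gpu(): assigned a (possibly different) GPU slot.
blk.location, blk.gpu_slot, blk.cpu_block_id = BlockLocation.GPU, 3, -1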
class HybridKVCacheManager(KVCacheManager):
"""
Hybrid CPU-GPU KV cache manager with CUDA Graph support.
Architecture:
- GPU buffer: Fixed-size working set (num_gpu_slots)
- CPU pool: Overflow storage (num_cpu_blocks)
- Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks)
CUDA Graph compatibility:
- All tensor addresses fixed at init time
- prepare_for_attention() updates gather_indices (outside graph)
- gathered_h2d_layer() executes transfer (inside graph)
Strategy:
1. New KV data written to GPU slots
2. Cold blocks evicted to CPU using configurable policy
3. Needed blocks prefetched back to GPU before attention
"""
def __init__(
self,
num_gpu_slots: int,
num_cpu_blocks: int,
block_size: int,
policy: Optional[EvictionPolicy] = None,
):
"""
Initialize hybrid manager.
Args:
num_gpu_slots: Number of GPU buffer slots (working set)
num_cpu_blocks: Number of CPU pool blocks (overflow)
block_size: Tokens per block
policy: Eviction policy (default: LRU)
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
# Eviction policy
self.policy = policy or LRUPolicy()
# Logical blocks (what sequences reference)
self.logical_blocks: List[LogicalBlock] = [
LogicalBlock(i) for i in range(self.total_blocks)
]
self.free_logical_ids: deque[int] = deque(range(self.total_blocks))
# GPU slot management (slots are fixed, mapping is variable)
self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots))
self.gpu_slot_to_logical: Dict[int, int] = {} # gpu_slot -> logical_id
# CPU block management
self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks))
self.cpu_block_to_logical: Dict[int, int] = {} # cpu_block -> logical_id
# Prefix cache (uses logical block IDs)
self.hash_to_logical_id: Dict[int, int] = {}
# Step counter for policy
self.current_step = 0
# Offload engine (set by allocate_cache)
self.offload_engine: Optional[OffloadEngine] = None
# Track blocks pending GPU load (for decode graph)
self.pending_gpu_loads: Set[int] = set() # logical_ids
# Track blocks that have been prefilled (KV written) for chunked prefill
self.prefilled_blocks: Set[int] = set() # logical_ids
@property
def block_size(self) -> int:
return self._block_size
@property
def num_free_blocks(self) -> int:
return len(self.free_logical_ids)
def allocate_cache(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
) -> None:
"""Initialize the offload engine with actual cache storage."""
self.offload_engine = OffloadEngine(
num_layers=num_layers,
num_gpu_blocks=self.num_gpu_slots,
num_cpu_blocks=self.num_cpu_blocks,
block_size=self._block_size,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
)
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
"""Get GPU K/V cache tensors for a layer."""
assert self.offload_engine is not None
return self.offload_engine.get_layer_cache(layer_id)
def _allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
"""
Get a free GPU slot, evicting if necessary.
Args:
protected_logical_ids: Logical block IDs that cannot be evicted
Returns:
GPU slot ID
Raises:
RuntimeError: If no GPU slot is available
"""
if self.free_gpu_slots:
return self.free_gpu_slots.popleft()
# Need to evict - find victim using policy
return self._evict_to_cpu(protected_logical_ids)
def _try_allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> Optional[int]:
"""
Try to get a free GPU slot, evicting if necessary.
Unlike _allocate_gpu_slot(), returns None instead of raising if no eviction possible.
Args:
protected_logical_ids: Logical block IDs that cannot be evicted
Returns:
GPU slot ID, or None if no slot available
"""
if self.free_gpu_slots:
return self.free_gpu_slots.popleft()
# Check if we can evict
protected = protected_logical_ids or set()
for gpu_slot, logical_id in self.gpu_slot_to_logical.items():
if logical_id not in protected:
block = self.logical_blocks[logical_id]
if block.ref_count > 0:
# Found evictable block
return self._evict_to_cpu(protected_logical_ids)
# No evictable blocks
return None
def _evict_to_cpu(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
"""
Evict a GPU block to CPU to make room.
Args:
protected_logical_ids: Logical block IDs that cannot be evicted
Returns:
The freed GPU slot ID
"""
protected = protected_logical_ids or set()
# Find candidates (blocks currently on GPU with ref_count > 0, excluding protected)
candidates: Set[int] = set()
for gpu_slot, logical_id in self.gpu_slot_to_logical.items():
if logical_id in protected:
continue # Skip protected blocks
block = self.logical_blocks[logical_id]
if block.ref_count > 0: # Only evict blocks still in use
candidates.add(gpu_slot)
if not candidates:
raise RuntimeError(
f"No GPU slots available for eviction. "
f"GPU slots: {self.num_gpu_slots}, protected: {len(protected)}, "
f"need more GPU memory or reduce sequence length"
)
# Use policy to select victim
victim_gpu_slot = self.policy.select_victim(candidates)
logical_id = self.gpu_slot_to_logical[victim_gpu_slot]
block = self.logical_blocks[logical_id]
# Allocate CPU block
if not self.free_cpu_blocks:
raise RuntimeError("Both GPU and CPU are full")
cpu_block_id = self.free_cpu_blocks.popleft()
# Async offload GPU -> CPU
self.offload_engine.offload_block_async(
layer_id=0, # TODO: handle per-layer offloading
gpu_block_id=victim_gpu_slot,
cpu_block_id=cpu_block_id,
)
# Update mappings
del self.gpu_slot_to_logical[victim_gpu_slot]
self.cpu_block_to_logical[cpu_block_id] = logical_id
block.location = BlockLocation.CPU
block.gpu_slot = -1
block.cpu_block_id = cpu_block_id
# Notify policy
self.policy.on_block_evicted(victim_gpu_slot)
return victim_gpu_slot
def _ensure_on_gpu(
self,
logical_id: int,
protected_logical_ids: Optional[Set[int]] = None,
) -> int:
"""
Ensure a logical block is on GPU.
Args:
logical_id: Logical block ID
protected_logical_ids: Logical block IDs that cannot be evicted
Returns:
GPU slot ID where the block is/will be
"""
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.GPU:
# Already on GPU, update policy
self.policy.on_block_access(block.gpu_slot, self.current_step)
return block.gpu_slot
if block.location == BlockLocation.CPU:
# Need to prefetch from CPU
gpu_slot = self._allocate_gpu_slot(protected_logical_ids)
# Async prefetch CPU -> GPU for every layer of this block
for layer_id in range(self.offload_engine.num_layers):
    self.offload_engine.prefetch_block_async(
        layer_id=layer_id,
        cpu_block_id=block.cpu_block_id,
        gpu_block_id=gpu_slot,
    )
# Update mappings
self.free_cpu_blocks.append(block.cpu_block_id)
del self.cpu_block_to_logical[block.cpu_block_id]
self.gpu_slot_to_logical[gpu_slot] = logical_id
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
block.cpu_block_id = -1
# Notify policy
self.policy.on_block_prefetched(gpu_slot, self.current_step)
return gpu_slot
raise RuntimeError(f"Block {logical_id} is in invalid state")
def can_allocate(self, seq: Sequence) -> bool:
"""Check if we can allocate blocks for a new sequence."""
return len(self.free_logical_ids) >= seq.num_blocks
def allocate(self, seq: Sequence) -> None:
"""
Allocate logical blocks for prefill.
New blocks are allocated on GPU when possible. If GPU is full and all
GPU blocks belong to this sequence (can't evict), remaining blocks
are allocated to CPU for chunked prefill.
"""
assert not seq.block_table, "Sequence already has blocks"
h = -1
cache_miss = False
# Track blocks allocated for this sequence to protect them from eviction
allocated_for_seq: Set[int] = set()
for i in range(seq.num_blocks):
token_ids = seq.block(i)
# Hash for full blocks only
if len(token_ids) == self._block_size:
h = self.compute_hash(token_ids, h)
else:
h = -1
# Check prefix cache
cached_logical_id = self.hash_to_logical_id.get(h, -1)
if cached_logical_id != -1:
cached_block = self.logical_blocks[cached_logical_id]
if cached_block.token_ids == token_ids and cached_block.ref_count > 0:
# Cache hit
cached_block.ref_count += 1
seq.num_cached_tokens += self._block_size
seq.block_table.append(cached_logical_id)
allocated_for_seq.add(cached_logical_id)
# Ensure block is on GPU (protect already allocated blocks)
if cached_block.location == BlockLocation.CPU:
self._ensure_on_gpu(cached_logical_id, allocated_for_seq)
continue
cache_miss = True
# Allocate new logical block
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.hash = h
block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else []
# Try to allocate GPU slot
gpu_slot = self._try_allocate_gpu_slot(allocated_for_seq)
if gpu_slot is not None:
# Got GPU slot
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
block.cpu_block_id = -1
self.gpu_slot_to_logical[gpu_slot] = logical_id
else:
# GPU full and can't evict (all protected) - allocate to CPU
# This block will be written via chunked prefill
if not self.free_cpu_blocks:
raise RuntimeError(
f"Both GPU and CPU are full. Need {seq.num_blocks} blocks, "
f"GPU has {self.num_gpu_slots}, CPU has {self.num_cpu_blocks}"
)
cpu_block_id = self.free_cpu_blocks.popleft()
block.location = BlockLocation.CPU
block.gpu_slot = -1
block.cpu_block_id = cpu_block_id
self.cpu_block_to_logical[cpu_block_id] = logical_id
allocated_for_seq.add(logical_id)
# Update prefix cache
if h != -1:
self.hash_to_logical_id[h] = logical_id
# Notify policy only if the block landed on a GPU slot
if gpu_slot is not None:
    self.policy.on_block_allocated(gpu_slot, self.current_step)
seq.block_table.append(logical_id)
def deallocate(self, seq: Sequence) -> None:
"""Release all blocks for a sequence."""
for logical_id in reversed(seq.block_table):
block = self.logical_blocks[logical_id]
block.ref_count -= 1
if block.ref_count == 0:
# Free physical block
if block.location == BlockLocation.GPU:
self.free_gpu_slots.append(block.gpu_slot)
del self.gpu_slot_to_logical[block.gpu_slot]
self.policy.on_block_deallocated(block.gpu_slot)
elif block.location == BlockLocation.CPU:
self.free_cpu_blocks.append(block.cpu_block_id)
del self.cpu_block_to_logical[block.cpu_block_id]
# Free logical block
block.reset()
self.free_logical_ids.append(logical_id)
# Remove from prefilled tracking
self.prefilled_blocks.discard(logical_id)
seq.num_cached_tokens = 0
seq.block_table.clear()
def can_append(self, seq: Sequence) -> bool:
"""Check if we can append a token."""
need_new_block = (len(seq) % self._block_size == 1)
return len(self.free_logical_ids) >= int(need_new_block)
def may_append(self, seq: Sequence) -> None:
"""Handle potential new block allocation during decode."""
block_table = seq.block_table
last_logical_id = block_table[-1]
last_block = self.logical_blocks[last_logical_id]
seq_len = len(seq)
pos_in_block = seq_len % self._block_size
if pos_in_block == 1:
# Need new block
assert last_block.hash != -1
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.hash = -1
block.token_ids = []
# New decode blocks go to GPU
gpu_slot = self._allocate_gpu_slot()
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
self.gpu_slot_to_logical[gpu_slot] = logical_id
self.policy.on_block_allocated(gpu_slot, self.current_step)
block_table.append(logical_id)
elif pos_in_block == 0:
# Block is full, update hash for prefix cache
assert last_block.hash == -1
token_ids = seq.block(seq.num_blocks - 1)
prefix_hash = (
self.logical_blocks[block_table[-2]].hash
if len(block_table) > 1 else -1
)
h = self.compute_hash(token_ids, prefix_hash)
last_block.hash = h
last_block.token_ids = token_ids.copy()
self.hash_to_logical_id[h] = last_logical_id
def prepare_for_attention(
self,
seqs: List[Sequence],
is_prefill: bool,
) -> None:
"""
Prepare KV cache for attention computation.
For prefill: async prefetch blocks from CPU to GPU.
For decode: update gather_indices for CUDA graph.
"""
self.current_step += 1
# Collect all needed logical blocks
needed_logical_ids: Set[int] = set()
for seq in seqs:
needed_logical_ids.update(seq.block_table)
if is_prefill:
# Prefill: ensure all blocks on GPU (async prefetch)
# Pass needed_logical_ids as protected to prevent evicting blocks we need
for logical_id in needed_logical_ids:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
self._ensure_on_gpu(logical_id, needed_logical_ids)
# Wait for all prefetches to complete
self.offload_engine.wait_all_transfers()
else:
# Decode: Check if we need chunked decode
cpu_blocks_count = sum(
1 for lid in needed_logical_ids
if self.logical_blocks[lid].location == BlockLocation.CPU
)
if cpu_blocks_count > self.num_gpu_slots:
# Too many blocks on CPU - will use chunked decode
# Don't try to load all blocks now
return
# Standard decode: prepare gather_indices for CUDA graph
# Identify blocks needing transfer
self.pending_gpu_loads.clear()
mappings_per_layer: List[List[Tuple[int, int]]] = [
[] for _ in range(self.offload_engine.num_layers)
]
for logical_id in needed_logical_ids:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
# Allocate GPU slot (protect needed blocks from eviction)
gpu_slot = self._allocate_gpu_slot(needed_logical_ids)
# Record mapping for each layer
for layer_id in range(self.offload_engine.num_layers):
mappings_per_layer[layer_id].append(
(block.cpu_block_id, gpu_slot)
)
# Update block state
self.free_cpu_blocks.append(block.cpu_block_id)
del self.cpu_block_to_logical[block.cpu_block_id]
self.gpu_slot_to_logical[gpu_slot] = logical_id
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
block.cpu_block_id = -1
self.pending_gpu_loads.add(logical_id)
self.policy.on_block_prefetched(gpu_slot, self.current_step)
elif block.location == BlockLocation.GPU:
self.policy.on_block_access(block.gpu_slot, self.current_step)
# Update gather indices (outside graph)
self.offload_engine.update_gather_indices_all_layers(mappings_per_layer)
self.offload_engine.sync_indices()
def needs_chunked_decode(self, seq: Sequence) -> bool:
"""
Check if sequence needs chunked decode.
Returns True if there are blocks on CPU and total blocks exceed GPU capacity.
"""
cpu_blocks = 0
for logical_id in seq.block_table:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_blocks += 1
return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots
def load_all_kv_for_layer(
self,
seq: Sequence,
layer_id: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Load ALL KV for a sequence from both GPU and CPU for a layer.
Used during chunked decode to compute full attention.
Returns:
(k, v) tensors with shape [1, total_tokens, kv_heads, head_dim]
"""
k_chunks = []
v_chunks = []
for logical_id in seq.block_table:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.GPU:
# Get from GPU cache
k, v = self.offload_engine.get_layer_cache(layer_id)
# k, v shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
k_block = k[block.gpu_slot] # [block_size, kv_heads, head_dim]
v_block = v[block.gpu_slot]
k_chunks.append(k_block)
v_chunks.append(v_block)
elif block.location == BlockLocation.CPU:
# Get from CPU cache
k_block, v_block = self.offload_engine.get_cpu_block(layer_id, block.cpu_block_id)
# Already [block_size, kv_heads, head_dim]
k_chunks.append(k_block.to("cuda", non_blocking=True))
v_chunks.append(v_block.to("cuda", non_blocking=True))
# Concatenate all chunks
k_all = torch.cat(k_chunks, dim=0) # [total_tokens, kv_heads, head_dim]
v_all = torch.cat(v_chunks, dim=0)
# Add batch dimension
k_all = k_all.unsqueeze(0) # [1, total_tokens, kv_heads, head_dim]
v_all = v_all.unsqueeze(0)
return k_all, v_all
def get_gpu_block_tables(
self,
seqs: List[Sequence],
) -> List[List[int]]:
"""
Get GPU slot tables for sequences.
Returns GPU slot IDs, which may differ from logical block IDs.
"""
result = []
for seq in seqs:
gpu_table = []
for logical_id in seq.block_table:
block = self.logical_blocks[logical_id]
assert block.location == BlockLocation.GPU, (
f"Block {logical_id} not on GPU (location={block.location})"
)
gpu_table.append(block.gpu_slot)
result.append(gpu_table)
return result
def post_attention_cleanup(
self,
seqs: List[Sequence],
is_prefill: bool,
) -> None:
"""
Cleanup after attention.
Clear pending loads and optionally proactive offload.
"""
self.pending_gpu_loads.clear()
# ========== Chunked Prefill Support ==========
def needs_chunked_prefill(self, seq: Sequence) -> bool:
"""
Check if sequence needs chunked prefill.
Returns True if there are unprefilled blocks that are on CPU.
This indicates we need to process in chunks because not all blocks fit on GPU.
"""
for logical_id in seq.block_table:
if logical_id not in self.prefilled_blocks:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
return True
return False
def get_gpu_block_count(self, seq: Sequence) -> int:
"""Get number of blocks currently on GPU for this sequence."""
count = 0
for logical_id in seq.block_table:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.GPU:
count += 1
return count
def get_prefill_chunk_info(self, seq: Sequence) -> Tuple[int, int, List[int]]:
"""
Get information for current prefill chunk.
Returns:
(start_block_idx, end_block_idx, gpu_block_ids)
- start_block_idx: First block index in this chunk
- end_block_idx: Last block index (exclusive) in this chunk
- gpu_block_ids: GPU slot IDs for blocks in this chunk
"""
start_idx = -1
end_idx = -1
gpu_block_ids = []
for i, logical_id in enumerate(seq.block_table):
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.GPU:
if start_idx == -1:
start_idx = i
end_idx = i + 1
gpu_block_ids.append(block.gpu_slot)
elif start_idx != -1:
# Found CPU block after GPU blocks - stop here
break
if start_idx == -1:
return (0, 0, [])
return (start_idx, end_idx, gpu_block_ids)
def complete_prefill_chunk(self, seq: Sequence) -> bool:
"""
Complete a prefill chunk: mark blocks as prefilled, offload to CPU, load next chunk.
Returns:
True if there are more chunks to process, False if done.
"""
# Find blocks currently on GPU that were just prefilled
gpu_blocks_to_offload = []
for logical_id in seq.block_table:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.GPU and logical_id not in self.prefilled_blocks:
# Mark as prefilled
self.prefilled_blocks.add(logical_id)
gpu_blocks_to_offload.append(logical_id)
# Offload prefilled GPU blocks to CPU
for logical_id in gpu_blocks_to_offload:
block = self.logical_blocks[logical_id]
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks for offload")
cpu_block_id = self.free_cpu_blocks.popleft()
# Async offload all layers
for layer_id in range(self.offload_engine.num_layers):
self.offload_engine.offload_block_async(
layer_id=layer_id,
gpu_block_id=block.gpu_slot,
cpu_block_id=cpu_block_id,
)
# Update mappings
self.free_gpu_slots.append(block.gpu_slot)
del self.gpu_slot_to_logical[block.gpu_slot]
self.cpu_block_to_logical[cpu_block_id] = logical_id
block.location = BlockLocation.CPU
block.cpu_block_id = cpu_block_id
block.gpu_slot = -1
# Wait for offload to complete
self.offload_engine.wait_all_transfers()
# Find next UNPREFILLED CPU blocks and bring them to GPU
cpu_blocks_to_load = []
for logical_id in seq.block_table:
if logical_id in self.prefilled_blocks:
continue # Skip already prefilled
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
if len(cpu_blocks_to_load) >= self.num_gpu_slots:
break # GPU is full
cpu_blocks_to_load.append(logical_id)
if not cpu_blocks_to_load:
return False # All blocks have been prefilled
# Load unprefilled CPU blocks to GPU
for logical_id in cpu_blocks_to_load:
block = self.logical_blocks[logical_id]
gpu_slot = self.free_gpu_slots.popleft()
# Note: We're NOT prefetching existing data - these blocks are being
# loaded for the first time, so we just need to assign GPU slots
# The model will write new KV cache data to these slots
# Update mappings
self.free_cpu_blocks.append(block.cpu_block_id)
del self.cpu_block_to_logical[block.cpu_block_id]
self.gpu_slot_to_logical[gpu_slot] = logical_id
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
block.cpu_block_id = -1
return True # More chunks to process
def get_gpu_block_tables_partial(
self,
seqs: List[Sequence],
) -> List[Tuple[List[int], int, int]]:
"""
Get GPU block tables for chunked prefill.
Returns list of (gpu_block_ids, start_block_idx, end_block_idx) per sequence.
Only includes blocks that are currently on GPU AND haven't been prefilled yet.
"""
result = []
for seq in seqs:
gpu_table = []
start_idx = -1
end_idx = -1
for i, logical_id in enumerate(seq.block_table):
# Skip already prefilled blocks
if logical_id in self.prefilled_blocks:
continue
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.GPU:
if start_idx == -1:
start_idx = i
end_idx = i + 1
gpu_table.append(block.gpu_slot)
elif start_idx != -1:
# Stop at first non-GPU block after GPU blocks
break
if start_idx == -1:
start_idx = 0
end_idx = 0
result.append((gpu_table, start_idx, end_idx))
return result
def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]:
"""
Get list of CPU block IDs for blocks that have been prefilled.
Used for loading previous KV during chunked prefill.
Returns:
List of CPU block IDs in sequence order
"""
cpu_blocks = []
for logical_id in seq.block_table:
if logical_id in self.prefilled_blocks:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_blocks.append(block.cpu_block_id)
return cpu_blocks
def load_prev_kv_for_layer(
self,
seq: Sequence,
layer_id: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Load previous prefilled KV from CPU for a specific layer.
This concatenates KV from all previously prefilled blocks for use
during chunked prefill attention.
Args:
seq: Sequence to load KV for
layer_id: Layer index
Returns:
(k, v) tensors with shape [1, total_prev_tokens, kv_heads, head_dim]
or (None, None) if no previous KV exists
"""
cpu_blocks = self.get_prefilled_cpu_blocks(seq)
if not cpu_blocks:
return None, None
k_chunks = []
v_chunks = []
for cpu_block_id in cpu_blocks:
k, v = self.offload_engine.get_cpu_block(layer_id, cpu_block_id)
# k, v shape: [block_size, kv_heads, head_dim]
k_chunks.append(k)
v_chunks.append(v)
# Concatenate all chunks
k_prev = torch.cat(k_chunks, dim=0) # [total_prev_tokens, kv_heads, head_dim]
v_prev = torch.cat(v_chunks, dim=0)
# Move to GPU and add batch dimension
k_prev = k_prev.to("cuda", non_blocking=True).unsqueeze(0) # [1, tokens, heads, dim]
v_prev = v_prev.to("cuda", non_blocking=True).unsqueeze(0)
return k_prev, v_prev
def get_chunk_start_position(self, seq: Sequence) -> int:
"""
Get the starting token position for the current chunk.
This is the total number of tokens in previously prefilled blocks.
Returns:
Token position offset for current chunk
"""
pos = 0
for logical_id in seq.block_table:
if logical_id in self.prefilled_blocks:
# Full block's worth of tokens
pos += self._block_size
else:
break
return pos
def __repr__(self) -> str:
return (
f"HybridKVCacheManager(\n"
f" num_gpu_slots={self.num_gpu_slots},\n"
f" num_cpu_blocks={self.num_cpu_blocks},\n"
f" block_size={self._block_size},\n"
f" free_logical={len(self.free_logical_ids)},\n"
f" free_gpu={len(self.free_gpu_slots)},\n"
f" free_cpu={len(self.free_cpu_blocks)},\n"
f" policy={self.policy}\n"
f")"
)
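
For orientation, a minimal driver sketch of how the chunked-prefill methods above fit together. `run_prefill_chunk` is a hypothetical callback standing in for the model forward over one chunk, and `manager` is assumed to be the instance produced by `create_kvcache_manager(config)`; the loop only makes sense for a prompt that does not fit into the GPU slots at once.

```python
# Hypothetical chunked-prefill driver, sketched from the manager API above.
def chunked_prefill(manager, seq, run_prefill_chunk):
    assert manager.can_allocate(seq)
    manager.allocate(seq)  # blocks land on GPU first, overflow goes to CPU
    while True:
        start, end, gpu_slots = manager.get_prefill_chunk_info(seq)
        if end > start:
            run_prefill_chunk(seq, start, end, gpu_slots)  # model forward for this chunk
        # Offloads the just-prefilled blocks to CPU and stages the next unprefilled ones.
        if not manager.complete_prefill_chunk(seq):
            break  # every block has been prefilled
```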

190
nanovllm/kvcache/kernels.py Normal file
View File

@@ -0,0 +1,190 @@
"""
Triton kernels for CPU-GPU KV cache transfer.
These kernels are designed to be CUDA Graph compatible:
- All tensor addresses are fixed at graph capture time
- Only the content of index tensors changes between replays
"""
import torch
import triton
import triton.language as tl
@triton.jit
def gathered_copy_kernel(
src_ptr, # Source tensor base pointer (CPU pinned or GPU)
dst_ptr, # Destination tensor base pointer (GPU)
indices_ptr, # Gather indices [num_dst_blocks]
num_dst_blocks, # Number of destination blocks
block_numel: tl.constexpr, # Elements per block (block_size * kv_heads * head_dim)
BLOCK_SIZE: tl.constexpr = 1024,
):
"""
Gathered copy kernel: dst[i] = src[indices[i]]
Each program instance handles one destination block.
The indices tensor specifies which source block to copy from.
This kernel is CUDA Graph compatible because:
- src_ptr, dst_ptr, indices_ptr addresses are fixed
- Only indices content changes between graph replays
Args:
src_ptr: Base pointer to source blocks [num_src_blocks, block_numel]
dst_ptr: Base pointer to destination blocks [num_dst_blocks, block_numel]
indices_ptr: Gather indices [num_dst_blocks], each value is a source block index
num_dst_blocks: Number of destination blocks to copy
block_numel: Number of elements per block
BLOCK_SIZE: Triton block size for parallelization
"""
dst_block_idx = tl.program_id(0)
# Skip if out of range
if dst_block_idx >= num_dst_blocks:
return
# Load source block index from indices tensor
src_block_idx = tl.load(indices_ptr + dst_block_idx)
# Skip if index is -1 (invalid/no-op marker)
if src_block_idx < 0:
return
# Calculate base offsets
src_base = src_block_idx * block_numel
dst_base = dst_block_idx * block_numel
# Copy block data in chunks of BLOCK_SIZE
for start in range(0, block_numel, BLOCK_SIZE):
offsets = start + tl.arange(0, BLOCK_SIZE)
mask = offsets < block_numel
# Load from source and store to destination
data = tl.load(src_ptr + src_base + offsets, mask=mask)
tl.store(dst_ptr + dst_base + offsets, data, mask=mask)
@triton.jit
def gathered_copy_kv_kernel(
k_src_ptr, # K cache source [num_src_blocks, block_size, kv_heads, head_dim]
v_src_ptr, # V cache source
k_dst_ptr, # K cache destination
v_dst_ptr, # V cache destination
indices_ptr, # Gather indices [num_dst_blocks]
num_dst_blocks, # Number of destination blocks
block_numel: tl.constexpr, # Elements per block
BLOCK_SIZE: tl.constexpr = 1024,
):
"""
Gathered copy for both K and V caches simultaneously.
More efficient than calling gathered_copy_kernel twice because:
- One kernel launch instead of two
- Better memory access patterns when K and V are accessed together
"""
dst_block_idx = tl.program_id(0)
if dst_block_idx >= num_dst_blocks:
return
src_block_idx = tl.load(indices_ptr + dst_block_idx)
if src_block_idx < 0:
return
src_base = src_block_idx * block_numel
dst_base = dst_block_idx * block_numel
for start in range(0, block_numel, BLOCK_SIZE):
offsets = start + tl.arange(0, BLOCK_SIZE)
mask = offsets < block_numel
# Copy K cache
k_data = tl.load(k_src_ptr + src_base + offsets, mask=mask)
tl.store(k_dst_ptr + dst_base + offsets, k_data, mask=mask)
# Copy V cache
v_data = tl.load(v_src_ptr + src_base + offsets, mask=mask)
tl.store(v_dst_ptr + dst_base + offsets, v_data, mask=mask)
def gathered_copy(
src: torch.Tensor,
dst: torch.Tensor,
indices: torch.Tensor,
) -> None:
"""
Perform gathered copy: dst[i] = src[indices[i]]
Args:
src: Source tensor [num_src_blocks, ...]
dst: Destination tensor [num_dst_blocks, ...]
indices: Index tensor [num_dst_blocks], dtype=int64
-1 means skip (no-op)
Note:
- src can be on CPU (pinned memory) or GPU
- dst must be on GPU
- indices must be on GPU
- All shapes after first dimension must match
"""
assert dst.is_cuda, "Destination must be on GPU"
assert indices.is_cuda, "Indices must be on GPU"
assert src.shape[1:] == dst.shape[1:], "Shape mismatch after first dimension"
num_dst_blocks = dst.shape[0]
block_numel = dst[0].numel()
# Flatten for kernel
src_flat = src.view(src.shape[0], -1)
dst_flat = dst.view(dst.shape[0], -1)
grid = (num_dst_blocks,)
gathered_copy_kernel[grid](
src_flat,
dst_flat,
indices,
num_dst_blocks,
block_numel=block_numel,
)
def gathered_copy_kv(
k_src: torch.Tensor,
v_src: torch.Tensor,
k_dst: torch.Tensor,
v_dst: torch.Tensor,
indices: torch.Tensor,
) -> None:
"""
Perform gathered copy for both K and V caches.
Args:
k_src, v_src: Source K/V caches [num_src_blocks, block_size, kv_heads, head_dim]
k_dst, v_dst: Destination K/V caches [num_dst_blocks, block_size, kv_heads, head_dim]
indices: Index tensor [num_dst_blocks], dtype=int64
"""
assert k_dst.is_cuda and v_dst.is_cuda, "Destinations must be on GPU"
assert indices.is_cuda, "Indices must be on GPU"
assert k_src.shape[1:] == k_dst.shape[1:], "K shape mismatch"
assert v_src.shape[1:] == v_dst.shape[1:], "V shape mismatch"
num_dst_blocks = k_dst.shape[0]
block_numel = k_dst[0].numel()
k_src_flat = k_src.view(k_src.shape[0], -1)
v_src_flat = v_src.view(v_src.shape[0], -1)
k_dst_flat = k_dst.view(k_dst.shape[0], -1)
v_dst_flat = v_dst.view(v_dst.shape[0], -1)
grid = (num_dst_blocks,)
gathered_copy_kv_kernel[grid](
k_src_flat,
v_src_flat,
k_dst_flat,
v_dst_flat,
indices,
num_dst_blocks,
block_numel=block_numel,
)
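
A small self-contained check of `gathered_copy_kv` semantics. For simplicity the source tensors here live on the GPU; in the offload engine the source is the pinned CPU cache, which the docstring above also allows. Sizes are illustrative only.

```python
import torch
from nanovllm.kvcache.kernels import gathered_copy_kv

blocks, slots, block_size, kv_heads, head_dim = 8, 4, 16, 2, 64
k_src = torch.randn(blocks, block_size, kv_heads, head_dim, dtype=torch.float16, device="cuda")
v_src = torch.randn_like(k_src)
k_dst = torch.zeros(slots, block_size, kv_heads, head_dim, dtype=torch.float16, device="cuda")
v_dst = torch.zeros_like(k_dst)

# Fill GPU slot 0 from source block 5 and slot 2 from block 1; -1 leaves a slot untouched.
indices = torch.tensor([5, -1, 1, -1], dtype=torch.int64, device="cuda")
gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices)
torch.cuda.synchronize()
assert torch.equal(k_dst[0], k_src[5]) and torch.equal(v_dst[2], v_src[1])
```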

View File

@@ -0,0 +1,400 @@
"""
High-performance CPU-GPU KV cache transfer engine.
Key design principles for CUDA Graph compatibility:
1. All tensor addresses are fixed at initialization
2. Only index tensor contents change between graph replays
3. Supports both async transfer (for prefill) and graph-based transfer (for decode)
"""
import torch
from torch import Tensor
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from nanovllm.kvcache.kernels import gathered_copy_kv
@dataclass
class TransferEvent:
"""Tracks a pending async transfer."""
event: torch.cuda.Event
layer_id: int
src_block_id: int
dst_block_id: int
direction: str # "h2d" or "d2h"
class OffloadEngine:
"""
High-performance CPU-GPU async transfer engine for KV cache offloading.
Memory layout:
- GPU cache: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
- CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned)
- Gather indices: [num_layers, num_gpu_blocks] (fixed address, variable content)
CUDA Graph compatibility:
- gathered_h2d_layer() can be captured into CUDA graphs
- update_gather_indices() is called outside graphs to prepare indices
- All tensor addresses remain fixed across graph replays
"""
def __init__(
self,
num_layers: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
block_size: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.float16,
num_streams: int = 4,
):
self.num_layers = num_layers
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.block_size = block_size
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.dtype = dtype
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
self.k_cache_gpu = torch.empty(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.v_cache_gpu = torch.empty(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
# ========== Fixed-address CPU KV cache (pinned memory) ==========
self.k_cache_cpu = torch.empty(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cpu", pin_memory=True
)
self.v_cache_cpu = torch.empty(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cpu", pin_memory=True
)
# ========== Fixed-address gather indices (content is variable) ==========
# gather_indices[layer][i] = CPU block id to copy to GPU slot i
# -1 means no-op (skip this slot)
self.gather_indices_cpu = torch.empty(
num_layers, num_gpu_blocks,
dtype=torch.int64, device="cpu", pin_memory=True
)
self.gather_indices_cpu.fill_(-1)
self.gather_indices_gpu = torch.full(
(num_layers, num_gpu_blocks), -1,
dtype=torch.int64, device="cuda"
)
# ========== Transfer streams for async operations ==========
self.transfer_streams = [torch.cuda.Stream() for _ in range(num_streams)]
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
def _get_next_stream(self) -> torch.cuda.Stream:
"""Round-robin stream selection for parallel transfers."""
stream = self.transfer_streams[self._stream_idx]
self._stream_idx = (self._stream_idx + 1) % len(self.transfer_streams)
return stream
# ========== CUDA Graph compatible methods ==========
def gathered_h2d_layer(self, layer_id: int) -> None:
"""
Execute gathered H2D copy for a single layer.
This method is CUDA Graph compatible - can be captured into a graph.
update_gather_indices() must be called first (outside the graph) to set up
which CPU blocks are copied to which GPU slots.
Args:
layer_id: Layer index to transfer
"""
gathered_copy_kv(
k_src=self.k_cache_cpu[layer_id],
v_src=self.v_cache_cpu[layer_id],
k_dst=self.k_cache_gpu[layer_id],
v_dst=self.v_cache_gpu[layer_id],
indices=self.gather_indices_gpu[layer_id],
)
def gathered_h2d_all_layers(self) -> None:
"""
Execute gathered H2D copy for all layers.
CUDA Graph compatible - can be captured into a single graph.
"""
for layer_id in range(self.num_layers):
self.gathered_h2d_layer(layer_id)
def update_gather_indices(
self,
layer_id: int,
mappings: List[Tuple[int, int]],
) -> None:
"""
Update gather indices for a layer (call OUTSIDE CUDA graph).
Args:
layer_id: Layer index
mappings: List of (cpu_block_id, gpu_slot) tuples
Only these slots will be updated; others keep their values
"""
for cpu_block_id, gpu_slot in mappings:
self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
# Async copy to GPU
self.gather_indices_gpu[layer_id].copy_(
self.gather_indices_cpu[layer_id],
non_blocking=True
)
def update_gather_indices_all_layers(
self,
mappings_per_layer: List[List[Tuple[int, int]]],
) -> None:
"""
Update gather indices for all layers.
Args:
mappings_per_layer: mappings_per_layer[layer_id] = [(cpu_block_id, gpu_slot), ...]
"""
for layer_id, mappings in enumerate(mappings_per_layer):
for cpu_block_id, gpu_slot in mappings:
self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
# Batch copy all layers
self.gather_indices_gpu.copy_(self.gather_indices_cpu, non_blocking=True)
def clear_gather_indices(self, layer_id: Optional[int] = None) -> None:
"""
Clear gather indices (set all to -1, meaning no-op).
Args:
layer_id: If provided, clear only this layer; otherwise clear all
"""
if layer_id is not None:
self.gather_indices_cpu[layer_id].fill_(-1)
self.gather_indices_gpu[layer_id].fill_(-1)
else:
self.gather_indices_cpu.fill_(-1)
self.gather_indices_gpu.fill_(-1)
# ========== Async transfer methods (for prefill, outside CUDA graph) ==========
def prefetch_block_async(
self,
layer_id: int,
cpu_block_id: int,
gpu_block_id: int,
) -> torch.cuda.Event:
"""
Async prefetch a single block from CPU to GPU.
For use in prefill phase where CUDA graphs are not used.
Args:
layer_id: Layer index
cpu_block_id: Source block in CPU cache
gpu_block_id: Destination slot in GPU cache
Returns:
CUDA event that signals completion
"""
stream = self._get_next_stream()
event = torch.cuda.Event()
with torch.cuda.stream(stream):
# K cache
self.k_cache_gpu[layer_id, gpu_block_id].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
# V cache
self.v_cache_gpu[layer_id, gpu_block_id].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
event.record()
self.pending_events[(layer_id, gpu_block_id)] = event
return event
def prefetch_blocks_batch_async(
self,
transfers: List[Tuple[int, int, int]], # [(layer_id, cpu_block_id, gpu_block_id), ...]
) -> List[torch.cuda.Event]:
"""
Batch async prefetch multiple blocks.
Args:
transfers: List of (layer_id, cpu_block_id, gpu_block_id) tuples
Returns:
List of CUDA events for each transfer
"""
events = []
for layer_id, cpu_block_id, gpu_block_id in transfers:
event = self.prefetch_block_async(layer_id, cpu_block_id, gpu_block_id)
events.append(event)
return events
def offload_block_async(
self,
layer_id: int,
gpu_block_id: int,
cpu_block_id: int,
) -> torch.cuda.Event:
"""
Async offload a block from GPU to CPU.
Args:
layer_id: Layer index
gpu_block_id: Source slot in GPU cache
cpu_block_id: Destination block in CPU cache
Returns:
CUDA event that signals completion
"""
stream = self._get_next_stream()
event = torch.cuda.Event()
with torch.cuda.stream(stream):
# Wait for any compute using this block
stream.wait_stream(self.compute_stream)
# K cache
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, gpu_block_id],
non_blocking=True
)
# V cache
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[layer_id, gpu_block_id],
non_blocking=True
)
event.record()
return event
def offload_blocks_batch_async(
self,
transfers: List[Tuple[int, int, int]], # [(layer_id, gpu_block_id, cpu_block_id), ...]
) -> List[torch.cuda.Event]:
"""
Batch async offload multiple blocks.
Args:
transfers: List of (layer_id, gpu_block_id, cpu_block_id) tuples
Returns:
List of CUDA events
"""
events = []
for layer_id, gpu_block_id, cpu_block_id in transfers:
event = self.offload_block_async(layer_id, gpu_block_id, cpu_block_id)
events.append(event)
return events
# ========== Synchronization methods ==========
def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
"""Wait for a specific block's transfer to complete."""
key = (layer_id, gpu_block_id)
if key in self.pending_events:
self.pending_events[key].synchronize()
del self.pending_events[key]
def wait_all_transfers(self) -> None:
"""Wait for all pending transfers to complete."""
for stream in self.transfer_streams:
stream.synchronize()
self.pending_events.clear()
def sync_indices(self) -> None:
"""Synchronize to ensure all index updates are complete."""
torch.cuda.current_stream().synchronize()
# ========== Cache access methods ==========
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
Get GPU K/V cache tensors for a specific layer.
Returns:
(k_cache, v_cache) tensors for the layer
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu[layer_id], self.v_cache_gpu[layer_id]
def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
"""
Get full GPU K/V cache tensors.
Returns:
(k_cache, v_cache) tensors
Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu, self.v_cache_gpu
def get_cpu_block(
self,
layer_id: int,
cpu_block_id: int,
) -> Tuple[Tensor, Tensor]:
"""
Get a specific CPU block's K/V cache.
Returns:
(k_cache, v_cache) for the block
Shape: [block_size, kv_heads, head_dim]
"""
return (
self.k_cache_cpu[layer_id, cpu_block_id],
self.v_cache_cpu[layer_id, cpu_block_id],
)
# ========== Memory info ==========
def gpu_memory_bytes(self) -> int:
"""Total GPU memory used by KV caches."""
return (
self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() +
self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() +
self.gather_indices_gpu.numel() * self.gather_indices_gpu.element_size()
)
def cpu_memory_bytes(self) -> int:
"""Total CPU memory used by KV caches."""
return (
self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() +
self.gather_indices_cpu.numel() * self.gather_indices_cpu.element_size()
)
def __repr__(self) -> str:
return (
f"OffloadEngine(\n"
f" num_layers={self.num_layers},\n"
f" num_gpu_blocks={self.num_gpu_blocks},\n"
f" num_cpu_blocks={self.num_cpu_blocks},\n"
f" block_size={self.block_size},\n"
f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
)
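
A rough usage sketch of the two-phase pattern the engine is built around: update indices outside the graph, run the gather inside it. The module path is assumed from the class name, the sizes are illustrative, and whether capture succeeds in practice also depends on the Triton and CUDA versions, so treat this as a sketch rather than a test.

```python
import torch
from nanovllm.kvcache.offload_engine import OffloadEngine  # module path assumed

engine = OffloadEngine(num_layers=2, num_gpu_blocks=4, num_cpu_blocks=16,
                       block_size=16, num_kv_heads=2, head_dim=64,
                       dtype=torch.float16, num_streams=2)

# Phase 1 (outside any CUDA graph): choose which CPU blocks land in which GPU slots.
engine.update_gather_indices_all_layers([[(3, 0), (7, 1)], [(3, 0), (7, 1)]])
engine.sync_indices()

# Warm up the Triton kernel so compilation happens outside the graph.
engine.gathered_h2d_all_layers()
torch.cuda.synchronize()

# Phase 2: the gather itself is graph-capturable; only index contents change per replay.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    engine.gathered_h2d_all_layers()
graph.replay()
torch.cuda.synchronize()
```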

View File

@@ -0,0 +1,51 @@
"""
Eviction policy plugins for KV cache offloading.
Users can create custom policies by subclassing EvictionPolicy
and specifying the full class path in config.offload_policy.
"""
from nanovllm.kvcache.policies.base_policy import EvictionPolicy
from nanovllm.kvcache.policies.lru_policy import LRUPolicy
from nanovllm.kvcache.policies.fifo_policy import FIFOPolicy
# Built-in policy registry
BUILTIN_POLICIES = {
"lru": LRUPolicy,
"fifo": FIFOPolicy,
}
def get_policy(policy_name: str) -> EvictionPolicy:
"""
Get an eviction policy instance by name or class path.
Args:
policy_name: Either a built-in name ("lru", "fifo") or
a full class path ("mymodule.MyPolicy")
Returns:
EvictionPolicy instance
"""
# Check built-in policies first
if policy_name.lower() in BUILTIN_POLICIES:
return BUILTIN_POLICIES[policy_name.lower()]()
# Try to import custom policy
try:
module_path, class_name = policy_name.rsplit(".", 1)
import importlib
module = importlib.import_module(module_path)
policy_class = getattr(module, class_name)
if not issubclass(policy_class, EvictionPolicy):
raise TypeError(f"{policy_name} is not a subclass of EvictionPolicy")
return policy_class()
except (ValueError, ImportError, AttributeError) as e:
raise ValueError(
f"Unknown policy '{policy_name}'. "
f"Available built-in policies: {list(BUILTIN_POLICIES.keys())}. "
f"For custom policies, use full class path: 'mymodule.MyPolicy'"
) from e
__all__ = ["EvictionPolicy", "LRUPolicy", "FIFOPolicy", "get_policy", "BUILTIN_POLICIES"]
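
A short usage sketch of the registry above, resolving `config.offload_policy` to a policy instance; the custom class path in the comment is hypothetical.

```python
from nanovllm.kvcache.policies import get_policy

policy = get_policy("lru")                           # built-in name
# policy = get_policy("my_pkg.policies.MyPolicy")    # hypothetical custom class path

policy.on_block_allocated(block_id=0, step=1)
policy.on_block_allocated(block_id=1, step=2)
policy.on_block_access(block_id=0, step=3)           # block 0 becomes most recently used
assert policy.select_victim({0, 1}) == 1             # LRU evicts the coldest block
```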

View File

@@ -0,0 +1,156 @@
"""
Base class for eviction policies.
Users can implement custom policies by subclassing EvictionPolicy
and overriding the abstract methods.
"""
from abc import ABC, abstractmethod
from typing import Set, Optional
class EvictionPolicy(ABC):
"""
Abstract base class for KV cache eviction policies.
An eviction policy determines which GPU blocks to evict to CPU
when GPU memory is full and new blocks need to be allocated.
Lifecycle:
1. on_block_allocated() - called when a new block is allocated
2. on_block_access() - called each time a block is accessed (e.g., in attention)
3. select_victim() - called when a block needs to be evicted
4. on_block_evicted() - called after a block is evicted
Example custom policy:
```python
class MyCustomPolicy(EvictionPolicy):
def __init__(self):
self.priorities = {}
def on_block_allocated(self, block_id: int, step: int):
self.priorities[block_id] = step
def on_block_access(self, block_id: int, step: int):
# Custom access tracking
pass
def select_victim(self, candidates: Set[int]) -> int:
# Return block with lowest priority
return min(candidates, key=lambda b: self.priorities.get(b, 0))
def on_block_evicted(self, block_id: int):
self.priorities.pop(block_id, None)
```
"""
@abstractmethod
def on_block_allocated(self, block_id: int, step: int) -> None:
"""
Called when a new block is allocated on GPU.
Args:
block_id: The GPU block ID that was allocated
step: Current inference step (monotonically increasing)
"""
pass
@abstractmethod
def on_block_access(self, block_id: int, step: int) -> None:
"""
Called when a block is accessed during attention computation.
Args:
block_id: The GPU block ID being accessed
step: Current inference step
"""
pass
@abstractmethod
def select_victim(self, candidates: Set[int]) -> int:
"""
Select a block to evict from the candidate set.
This is called when GPU memory is full and a new block
needs to be allocated. The returned block will be evicted
to CPU.
Args:
candidates: Set of GPU block IDs that can be evicted
(blocks not currently being used)
Returns:
Block ID to evict
Raises:
ValueError: If candidates is empty
"""
pass
@abstractmethod
def on_block_evicted(self, block_id: int) -> None:
"""
Called after a block is evicted from GPU to CPU.
Args:
block_id: The GPU block ID that was evicted
"""
pass
def on_block_prefetched(self, block_id: int, step: int) -> None:
"""
Called when a block is prefetched from CPU back to GPU.
Default implementation calls on_block_allocated().
Override for custom behavior.
Args:
block_id: The GPU block ID that was prefetched to
step: Current inference step
"""
self.on_block_allocated(block_id, step)
def on_block_deallocated(self, block_id: int) -> None:
"""
Called when a block is fully deallocated (sequence finished).
Default implementation calls on_block_evicted().
Override for custom behavior.
Args:
block_id: The GPU block ID being deallocated
"""
self.on_block_evicted(block_id)
def reset(self) -> None:
"""
Reset policy state.
Called when the inference engine is reset.
Default implementation does nothing.
"""
pass
def get_eviction_order(self, candidates: Set[int], count: int) -> list:
"""
Get multiple blocks to evict in order of priority.
Default implementation calls select_victim() repeatedly.
Override for more efficient batch selection.
Args:
candidates: Set of candidate block IDs
count: Number of blocks to evict
Returns:
List of block IDs to evict, in order
"""
result = []
remaining = set(candidates)
for _ in range(min(count, len(remaining))):
if not remaining:
break
victim = self.select_victim(remaining)
result.append(victim)
remaining.remove(victim)
return result
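
To make the plugin surface concrete, here is a deliberately trivial policy that implements only the four abstract hooks; everything else falls back to the defaults above. It would be selected via `config.offload_policy = "my_pkg.biggest_first.BiggestFirstPolicy"`, a hypothetical module path.

```python
from typing import Set
from nanovllm.kvcache.policies.base_policy import EvictionPolicy

class BiggestFirstPolicy(EvictionPolicy):
    """Evicts the candidate with the largest GPU slot id; keeps no history."""

    def on_block_allocated(self, block_id: int, step: int) -> None:
        pass  # no state needed

    def on_block_access(self, block_id: int, step: int) -> None:
        pass

    def select_victim(self, candidates: Set[int]) -> int:
        if not candidates:
            raise ValueError("Cannot select victim from empty candidate set")
        return max(candidates)

    def on_block_evicted(self, block_id: int) -> None:
        pass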

View File

@@ -0,0 +1,101 @@
"""
FIFO (First In, First Out) eviction policy.
Evicts the block that was allocated earliest.
Simple policy that ignores access patterns.
"""
from collections import OrderedDict
from typing import Set
from nanovllm.kvcache.policies.base_policy import EvictionPolicy
class FIFOPolicy(EvictionPolicy):
"""
First In, First Out (FIFO) eviction policy.
Evicts blocks in the order they were allocated,
regardless of access patterns.
Properties:
- O(1) bookkeeping on allocation and eviction; victim selection is O(n) over tracked blocks in the worst case
- Simple and predictable behavior
- Good for streaming workloads where older data
is naturally less relevant
- Does not adapt to access patterns (unlike LRU)
"""
def __init__(self):
# OrderedDict maintains insertion order
# Key: block_id, Value: allocation_step
# Oldest (first allocated) is at the front
self.allocation_order: OrderedDict[int, int] = OrderedDict()
def on_block_allocated(self, block_id: int, step: int) -> None:
"""Record allocation order (does not change on access)."""
if block_id not in self.allocation_order:
self.allocation_order[block_id] = step
def on_block_access(self, block_id: int, step: int) -> None:
"""
FIFO ignores access patterns.
This is the key difference from LRU - we don't
update the position based on access.
"""
pass # Intentionally empty
def select_victim(self, candidates: Set[int]) -> int:
"""
Select the earliest allocated block from candidates.
"""
if not candidates:
raise ValueError("Cannot select victim from empty candidate set")
# Iterate from oldest (front) to newest (back)
for block_id in self.allocation_order:
if block_id in candidates:
return block_id
# Fallback: return any candidate
return next(iter(candidates))
def on_block_evicted(self, block_id: int) -> None:
"""Remove block from tracking."""
self.allocation_order.pop(block_id, None)
def on_block_prefetched(self, block_id: int, step: int) -> None:
"""
When prefetched, treat as new allocation.
This moves the block to the end of the queue,
giving it more time before eviction.
"""
# Remove old entry if exists
self.allocation_order.pop(block_id, None)
# Add as new allocation
self.allocation_order[block_id] = step
def on_block_deallocated(self, block_id: int) -> None:
"""Remove block from tracking."""
self.allocation_order.pop(block_id, None)
def reset(self) -> None:
"""Clear all tracking data."""
self.allocation_order.clear()
def get_eviction_order(self, candidates: Set[int], count: int) -> list:
"""
Get multiple blocks to evict in FIFO order.
"""
result = []
for block_id in self.allocation_order:
if block_id in candidates:
result.append(block_id)
if len(result) >= count:
break
return result
def __repr__(self) -> str:
return f"FIFOPolicy(tracked_blocks={len(self.allocation_order)})"

View File

@@ -0,0 +1,93 @@
"""
LRU (Least Recently Used) eviction policy.
Evicts the block that was accessed least recently.
This is the default and recommended policy for most use cases.
"""
from collections import OrderedDict
from typing import Set
from nanovllm.kvcache.policies.base_policy import EvictionPolicy
class LRUPolicy(EvictionPolicy):
"""
Least Recently Used (LRU) eviction policy.
Maintains an ordered dictionary of block access times.
When eviction is needed, selects the block that was
accessed least recently.
Properties:
- O(1) access tracking
- O(n) victim selection in worst case, but typically fast
due to OrderedDict iteration order
- Good for workloads with temporal locality
"""
def __init__(self):
# OrderedDict maintains insertion/update order
# Key: block_id, Value: last_access_step
# Oldest (least recently used) is at the front
self.access_order: OrderedDict[int, int] = OrderedDict()
def on_block_allocated(self, block_id: int, step: int) -> None:
"""Record allocation as an access."""
# Move to end (most recently used)
self.access_order[block_id] = step
self.access_order.move_to_end(block_id)
def on_block_access(self, block_id: int, step: int) -> None:
"""Update access time and move to end."""
if block_id in self.access_order:
self.access_order[block_id] = step
self.access_order.move_to_end(block_id)
def select_victim(self, candidates: Set[int]) -> int:
"""
Select the least recently used block from candidates.
Iterates from oldest to newest in access order,
returns the first one that's in the candidate set.
"""
if not candidates:
raise ValueError("Cannot select victim from empty candidate set")
# Iterate from oldest (front) to newest (back)
for block_id in self.access_order:
if block_id in candidates:
return block_id
# Fallback: return any candidate (shouldn't happen normally)
return next(iter(candidates))
def on_block_evicted(self, block_id: int) -> None:
"""Remove block from tracking."""
self.access_order.pop(block_id, None)
def on_block_deallocated(self, block_id: int) -> None:
"""Remove block from tracking."""
self.access_order.pop(block_id, None)
def reset(self) -> None:
"""Clear all tracking data."""
self.access_order.clear()
def get_eviction_order(self, candidates: Set[int], count: int) -> list:
"""
Efficiently get multiple blocks to evict in LRU order.
Optimized for batch eviction - iterates through access_order
once instead of calling select_victim() multiple times.
"""
result = []
for block_id in self.access_order:
if block_id in candidates:
result.append(block_id)
if len(result) >= count:
break
return result
def __repr__(self) -> str:
return f"LRUPolicy(tracked_blocks={len(self.access_order)})"

View File

@@ -55,21 +55,164 @@ class Attention(nn.Module):
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
# Layer ID set by model_runner after model creation
self.layer_id: int = -1
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.block_tables is not None: # prefix cache
if context.is_chunked_prefill:
# Chunked prefill: merge attention from previous KV
o = self._chunked_prefill_attention(q, k, v, context)
elif context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else:
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else: # decode
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
if context.is_chunked_prefill:
# Chunked decode: need to load all KV from CPU+GPU
o = self._chunked_decode_attention(q, k, v, context)
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
return o
def _chunked_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
context,
) -> torch.Tensor:
"""
Compute attention with chunked KV from CPU cache.
For chunked prefill:
1. Load previous KV from CPU for this layer
2. Compute attention against previous KV (no causal mask)
3. Compute attention against current chunk's KV (causal)
4. Merge results using online softmax
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# q, k, v shape: [total_tokens, num_heads, head_dim]
total_tokens = q.shape[0]
# Reshape for flash attention: [batch, seq, heads, dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
accumulated_o = None
accumulated_lse = None
# Load previous KV from CPU for this layer
if context.offload_engine is not None and self.layer_id >= 0:
# Get the kvcache_manager from context
kvcache_manager = context.offload_engine
# For each sequence in the chunk, load previous KV
# Currently assuming single sequence
if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
prev_k, prev_v = kvcache_manager.load_prev_kv_for_layer(
context.chunked_seq,
self.layer_id,
)
if prev_k is not None and prev_v is not None:
# Compute attention against previous KV (no causal mask)
prev_o, prev_lse = flash_attn_with_lse(
q_batched,
prev_k,
prev_v,
softmax_scale=self.scale,
causal=False, # No causal mask for previous context
)
accumulated_o = prev_o
accumulated_lse = prev_lse
# Compute attention against current chunk's KV (with causal mask)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True, # Causal mask for current chunk
)
# Merge with accumulated
if accumulated_o is None:
final_o = current_o
else:
final_o, _ = merge_attention_outputs(
accumulated_o, accumulated_lse,
current_o, current_lse,
)
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
def _chunked_decode_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
context,
) -> torch.Tensor:
"""
Compute decode attention with KV spread across CPU and GPU.
For decode with chunked KV:
1. Load all KV for this layer from CPU+GPU
2. Compute attention (1 query token vs all KV)
3. Return output
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
# We need to attend to ALL previous tokens
# Load all KV for this layer
if context.offload_engine is not None and self.layer_id >= 0:
kvcache_manager = context.offload_engine
if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
# Load all KV from both GPU and CPU for this layer
k_all, v_all = kvcache_manager.load_all_kv_for_layer(
context.chunked_seq,
self.layer_id,
)
if k_all is not None and v_all is not None:
# q shape: [batch_size, num_heads, head_dim]
# Need: [batch, seqlen, heads, dim]
# Insert seqlen dimension at position 1
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# k_all, v_all shape: [1, total_kv_tokens, kv_heads, head_dim]
# Compute attention (no causal mask for decode - we want all KV)
out, _ = flash_attn_with_lse(
q_batched,
k_all,
v_all,
softmax_scale=self.scale,
causal=False, # No causal mask for decode
)
# Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
return out.squeeze(1)
# Fallback: shouldn't reach here
raise RuntimeError("Chunked decode attention failed: no KV available")
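
The merge step above relies on `merge_attention_outputs` from `nanovllm.kvcache.chunked_attention`, which is not shown in this diff. As a reference for what that merge is expected to compute, here is the standard log-sum-exp combination of two partial attention outputs over disjoint key sets, assuming flash-attn's `[batch, heads, seq]` LSE layout and `[batch, seq, heads, dim]` outputs.

```python
import torch

def merge_attention_outputs_ref(o1, lse1, o2, lse2):
    # o_*: [batch, seq, heads, dim]; lse_*: [batch, heads, seq] (assumed layout).
    lse = torch.logaddexp(lse1, lse2)                           # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)    # -> [batch, seq, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse
```

Because the previous-context pass uses `causal=False` and the current-chunk pass uses `causal=True`, this weighted combination reproduces exactly the softmax over the full key set, chunk by chunk.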

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Any
import torch
@@ -13,14 +14,60 @@ class Context:
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
# Chunked prefill support
is_chunked_prefill: bool = False
# Previous KV chunks info: List of (start_pos, end_pos) for blocks on CPU
prev_kv_ranges: List[Tuple[int, int]] = field(default_factory=list)
# Current chunk's position offset (for causal mask)
chunk_offset: int = 0
# Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
offload_engine: Any = None
# Current layer's previous K/V chunks (loaded from CPU)
# Set by model_runner before each layer's forward
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
# Current sequence being processed (for chunked prefill to load KV)
chunked_seq: Any = None
_CONTEXT = Context()
def get_context():
return _CONTEXT
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
def set_context(
is_prefill,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seqlen_q=0,
max_seqlen_k=0,
slot_mapping=None,
context_lens=None,
block_tables=None,
is_chunked_prefill=False,
prev_kv_ranges=None,
chunk_offset=0,
offload_engine=None,
chunked_seq=None,
):
global _CONTEXT
_CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
_CONTEXT = Context(
is_prefill=is_prefill,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
is_chunked_prefill=is_chunked_prefill,
prev_kv_ranges=prev_kv_ranges or [],
chunk_offset=chunk_offset,
offload_engine=offload_engine,
chunked_seq=chunked_seq,
)
def reset_context():
global _CONTEXT