# Source file: nano-vllm/nanovllm/kvcache/hybrid_manager.py
"""
Hybrid CPU-GPU KV cache manager with CUDA Graph support.
Key design for CUDA Graph compatibility:
1. GPU buffer has fixed addresses (allocated once)
2. CPU pool has fixed addresses (pinned memory)
3. gather_indices tensor has fixed address, variable content
4. H2D transfer uses gathered_copy kernel inside CUDA graphs
5. Graph replay only needs index updates (tiny overhead)
"""
import logging
from collections import deque
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Tuple, Dict, Set, Optional
import torch
from torch import Tensor
logger = logging.getLogger(__name__)
from nanovllm.engine.sequence import Sequence
from nanovllm.kvcache.base_manager import KVCacheManager
from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.policies.base_policy import EvictionPolicy
from nanovllm.kvcache.policies.lru_policy import LRUPolicy
# Type checking import for sparse policy
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from nanovllm.kvcache.sparse.policy import SparsePolicy
class BlockLocation(Enum):
    """Where a logical block's data currently resides."""
    GPU = auto()      # Data lives in a GPU buffer slot
    CPU = auto()      # Data lives in a pinned CPU pool block (primary storage)
    INVALID = auto()  # Not yet written / deallocated
@dataclass
class LogicalBlock:
    """
    Logical block that can be mapped to GPU or CPU physical storage.

    Sequences reference logical blocks; the physical storage behind each one
    is either a GPU buffer slot or a CPU pool block, selected by ``location``.
    """
    logical_id: int
    location: BlockLocation = BlockLocation.INVALID
    gpu_slot: int = -1       # GPU buffer slot ID (meaningful only when location == GPU)
    cpu_block_id: int = -1   # CPU pool block ID (meaningful only when location == CPU)
    ref_count: int = 0
    hash: int = -1
    token_ids: List[int] = field(default_factory=list)

    def reset(self) -> None:
        """Return this block to its pristine, unallocated state."""
        self.location = BlockLocation.INVALID
        # All three physical identifiers share the same "unset" sentinel.
        self.gpu_slot = self.cpu_block_id = self.hash = -1
        self.ref_count = 0
        self.token_ids = []
class HybridKVCacheManager(KVCacheManager):
    """
    Hybrid CPU-GPU KV cache manager with layer-wise offload design.

    Architecture (CPU-primary mode):
    - CPU pool: Primary storage for all KV cache (num_cpu_blocks)
    - GPU ring buffer: For decode H2D pipeline (num_kv_buffers)
    - Decode buffer: Per-layer accumulation of decode tokens (block_size)

    Design:
    - All KV cache is stored on CPU as primary storage
    - GPU ring buffer enables pipelined H2D transfers during decode
    - During prefill: KV is computed and offloaded layer-by-layer to CPU
    - During decode: Previous KV is loaded from CPU via ring buffer pipeline

    Note:
    - Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
    - GPU ring buffer is for decode pipeline, not persistent storage
    """

    def __init__(
        self,
        num_gpu_slots: int,
        num_cpu_blocks: int,
        block_size: int,
        policy: Optional[EvictionPolicy] = None,
        sparse_policy: Optional["SparsePolicy"] = None,
        num_kv_buffers: int = 4,
        max_seq_len: int = 131072,
    ):
        """
        Initialize hybrid manager with layer-wise offload design.

        All KV cache is stored on CPU as primary storage. GPU ring buffer is used
        for decode H2D pipeline.

        Args:
            num_gpu_slots: Number of GPU buffer slots (kept for backward compat, not used)
            num_cpu_blocks: Number of CPU pool blocks (primary storage)
            block_size: Tokens per block
            policy: Eviction policy (default: LRU, used for prefix cache management)
            sparse_policy: Sparse attention policy (Quest for decode-only sparse)
            num_kv_buffers: Ring buffer size for decode H2D pipeline
            max_seq_len: Maximum sequence length for GPU buffer allocation
        """
        self._block_size = block_size
        self.num_gpu_slots = num_gpu_slots
        self.num_cpu_blocks = num_cpu_blocks
        self.num_kv_buffers = num_kv_buffers
        self.max_seq_len = max_seq_len
        # In CPU-primary mode, logical blocks map 1:1 with CPU blocks.
        # GPU ring buffer is for decode pipeline, not persistent storage.
        self.total_blocks = num_cpu_blocks
        # Eviction policy (only exercised for the GPU branch in deallocate()).
        self.policy = policy or LRUPolicy()
        # Sparse attention policy (set at construction time, immutable).
        self.sparse_policy = sparse_policy
        # Logical blocks (what sequences reference) - one per CPU block.
        self.logical_blocks: List[LogicalBlock] = [
            LogicalBlock(i) for i in range(self.total_blocks)
        ]
        self.free_logical_ids: deque[int] = deque(range(self.total_blocks))
        # GPU slot management (kept for potential future use, but not used in CPU-primary mode).
        self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots))
        self.gpu_slot_to_logical: Dict[int, int] = {}  # gpu_slot -> logical_id (unused in CPU-primary mode)
        # CPU block management.
        self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks))
        self.cpu_block_to_logical: Dict[int, int] = {}  # cpu_block -> logical_id
        # Prefix cache (uses logical block IDs).
        # NOTE: Currently WRITE-ONLY in offload mode - hashes are stored but never
        # used for cache hit detection. This is intentional: offload mode always
        # allocates new blocks and doesn't reuse existing ones.
        self.hash_to_logical_id: Dict[int, int] = {}
        # Step counter for policy.
        self.current_step = 0
        # Offload engine (set by allocate_cache).
        self.offload_engine: Optional[OffloadEngine] = None
        # Track blocks pending GPU load (for decode graph).
        self.pending_gpu_loads: Set[int] = set()  # logical_ids
        # Track blocks that have been prefilled (KV offloaded to CPU).
        self.prefilled_blocks: Set[int] = set()  # logical_ids
        # Track decode starting position within block (for batched offload optimization).
        # Key: sequence id, Value: starting position where decode began in current block.
        # NOTE(review): keys are id(seq); id() values can be recycled after a Sequence
        # is garbage-collected. Safe only because deallocate() calls
        # clear_decode_tracking() before the sequence can be freed — keep that invariant.
        self._decode_start_pos: Dict[int, int] = {}
        # Track original prefill length (for correct last_block_valid_tokens calculation).
        # Key: sequence id, Value: number of tokens from prefill (before decode started).
        self._prefill_len: Dict[int, int] = {}

    @property
    def block_size(self) -> int:
        """Tokens stored per block."""
        return self._block_size

    @property
    def num_free_blocks(self) -> int:
        """Number of logical blocks currently available for allocation."""
        return len(self.free_logical_ids)

    def allocate_cache(
        self,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype,
    ) -> None:
        """Initialize the offload engine with actual cache storage."""
        self.offload_engine = OffloadEngine(
            num_layers=num_layers,
            num_gpu_blocks=self.num_gpu_slots,
            num_cpu_blocks=self.num_cpu_blocks,
            block_size=self._block_size,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            dtype=dtype,
            num_kv_buffers=self.num_kv_buffers,
            max_seq_len=self.max_seq_len,
            sparse_policy=self.sparse_policy,
        )

    def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
        """
        Get GPU K/V cache tensors for a layer.

        Note: In layer-wise offload mode, this returns empty tensors as KV
        is managed directly by the offload engine's ring buffer.
        """
        assert self.offload_engine is not None
        # Return empty tensors - actual KV is in offload_engine's ring buffer.
        return torch.empty(0), torch.empty(0)

    def can_allocate(self, seq: Sequence) -> bool:
        """Check if we can allocate blocks for a new sequence."""
        return len(self.free_logical_ids) >= seq.num_blocks

    def allocate(self, seq: Sequence) -> None:
        """
        Allocate logical blocks for prefill.

        All blocks are allocated to CPU (primary storage).
        GPU is used as ring buffer for computation only.
        """
        return self.allocate_cpu_only(seq)

    def deallocate(self, seq: Sequence) -> None:
        """Release all blocks for a sequence."""
        # Walk the block table back-to-front so the most recently allocated
        # blocks are freed (and re-queued) first.
        for logical_id in reversed(seq.block_table):
            block = self.logical_blocks[logical_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                # Free physical block based on location.
                # Note: In CPU-primary mode, blocks are always on CPU.
                # GPU branch kept for potential future hybrid mode support.
                if block.location == BlockLocation.GPU:
                    self.free_gpu_slots.append(block.gpu_slot)
                    del self.gpu_slot_to_logical[block.gpu_slot]
                    self.policy.on_block_deallocated(block.gpu_slot)
                elif block.location == BlockLocation.CPU:
                    self.free_cpu_blocks.append(block.cpu_block_id)
                    del self.cpu_block_to_logical[block.cpu_block_id]
                # Free logical block.
                block.reset()
                self.free_logical_ids.append(logical_id)
                # Remove from prefilled tracking.
                self.prefilled_blocks.discard(logical_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()
        # Clear decode tracking to prevent state pollution between requests.
        self.clear_decode_tracking(seq)
        # Clear offload engine state (decode buffer, events).
        if self.offload_engine is not None:
            self.offload_engine.on_sequence_finished()

    def can_append(self, seq: Sequence) -> bool:
        """Check if we can append a token."""
        # A new block is needed when the length is one past a block boundary
        # (mirrors the pos_in_block == 1 branch in may_append).
        need_new_block = (len(seq) % self._block_size == 1)
        return len(self.free_logical_ids) >= int(need_new_block)

    def may_append(self, seq: Sequence) -> None:
        """Handle potential new block allocation during decode."""
        block_table = seq.block_table
        last_logical_id = block_table[-1]
        last_block = self.logical_blocks[last_logical_id]
        seq_len = len(seq)
        pos_in_block = seq_len % self._block_size
        if pos_in_block == 1:
            # Need new block (previous block is full).
            logical_id = self.free_logical_ids.popleft()
            block = self.logical_blocks[logical_id]
            block.ref_count = 1
            # Allocate new block to CPU (ring buffer mode).
            if not self.free_cpu_blocks:
                raise RuntimeError("No free CPU blocks for decode")
            cpu_block_id = self.free_cpu_blocks.popleft()
            block.location = BlockLocation.CPU
            block.cpu_block_id = cpu_block_id
            block.gpu_slot = -1
            self.cpu_block_to_logical[cpu_block_id] = logical_id
            block_table.append(logical_id)
        elif pos_in_block == 0:
            # Block is full.
            # NOTE: Prefix cache disabled in offload mode.
            # If enabled, would compute hash and update:
            #   h = self.compute_hash(seq.block(seq.num_blocks - 1), prefix_hash)
            #   last_block.hash = h
            #   self.hash_to_logical_id[h] = last_logical_id
            pass

    def prepare_for_attention(
        self,
        seqs: List[Sequence],
        is_prefill: bool,
    ) -> None:
        """
        Prepare KV cache for attention computation.

        In layer-wise offload mode, this is a no-op because KV transfers
        are handled directly in model_runner's layer-by-layer methods.
        """
        pass

    def get_gpu_block_tables(
        self,
        seqs: List[Sequence],
    ) -> List[List[int]]:
        """
        Get GPU slot tables for sequences.

        In layer-wise offload mode, all blocks are on CPU, so this raises an error
        if called. Use run_layerwise_offload_* methods instead.

        Raises:
            RuntimeError: always, in layer-wise offload mode.
        """
        raise RuntimeError(
            "get_gpu_block_tables should not be called in layer-wise offload mode. "
            "Use run_layerwise_offload_prefill/decode instead."
        )

    def post_attention_cleanup(
        self,
        seqs: List[Sequence],
        is_prefill: bool,
    ) -> None:
        """
        Cleanup after attention.

        In layer-wise offload mode, this is a no-op because offload is handled
        directly in model_runner's layer-by-layer methods.
        """
        pass

    # ========== Layer-wise Offload Support ==========

    def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]:
        """
        Get list of CPU block IDs for blocks that have been prefilled.

        Used for loading prefilled KV during decode.

        Returns:
            List of CPU block IDs in sequence order
        """
        cpu_blocks = []
        for logical_id in seq.block_table:
            if logical_id in self.prefilled_blocks:
                block = self.logical_blocks[logical_id]
                if block.location == BlockLocation.CPU:
                    cpu_blocks.append(block.cpu_block_id)
        # DEBUG: Log on first decode call.
        logger.debug(
            f"[DEBUG] get_prefilled_cpu_blocks: block_table={list(seq.block_table)}, "
            f"prefilled_blocks={list(self.prefilled_blocks)}, "
            f"returned cpu_blocks={cpu_blocks}"
        )
        return cpu_blocks

    # ========== CPU Block Allocation ==========

    def allocate_cpu_only(self, seq: Sequence) -> None:
        """
        Allocate CPU blocks for sequence (for layer-wise offload mode).

        Unlike allocate(), here all blocks are allocated to CPU,
        GPU is only used as ring buffer for computation.

        Args:
            seq: Sequence to allocate

        Raises:
            RuntimeError: if the CPU pool runs out of free blocks.
        """
        assert not seq.block_table, "Sequence already has blocks"
        for i in range(seq.num_blocks):
            # Allocate CPU block.
            if not self.free_cpu_blocks:
                raise RuntimeError(
                    f"No free CPU blocks. Need {seq.num_blocks}, "
                    f"available: {len(self.free_cpu_blocks)}"
                )
            cpu_block_id = self.free_cpu_blocks.popleft()
            # Allocate logical block.
            logical_id = self.free_logical_ids.popleft()
            block = self.logical_blocks[logical_id]
            block.ref_count = 1
            block.location = BlockLocation.CPU
            block.cpu_block_id = cpu_block_id
            block.gpu_slot = -1
            self.cpu_block_to_logical[cpu_block_id] = logical_id
            seq.block_table.append(logical_id)
            # NOTE: Prefix cache disabled in offload mode.
            # If enabled, would compute hash and update:
            #   h = self.compute_hash(seq.block(i), prefix_hash)
            #   block.hash = h
            #   self.hash_to_logical_id[h] = logical_id
        # DEBUG: Log allocated CPU blocks.
        cpu_blocks = [self.logical_blocks[lid].cpu_block_id for lid in seq.block_table]
        logger.debug(f"[DEBUG] allocate_cpu_only: allocated cpu_blocks={cpu_blocks}")

    def get_cpu_block_table(self, seq: Sequence) -> List[int]:
        """
        Get CPU block ID list for sequence.

        Args:
            seq: Sequence

        Returns:
            List of CPU block IDs in sequence order

        Raises:
            RuntimeError: if any block is not resident on CPU.
        """
        cpu_blocks = []
        for logical_id in seq.block_table:
            block = self.logical_blocks[logical_id]
            if block.location == BlockLocation.CPU:
                cpu_blocks.append(block.cpu_block_id)
            else:
                # If block is on GPU, it should have a corresponding CPU block.
                # In ring buffer mode, all data ultimately resides on CPU.
                raise RuntimeError(
                    f"Block {logical_id} not on CPU (location={block.location}). "
                    f"In ring buffer mode, all blocks should be on CPU."
                )
        return cpu_blocks

    def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]:
        """
        Get all CPU blocks and their logical IDs for sequence.

        Unlike get_cpu_block_table(), non-CPU blocks are silently skipped
        rather than raising.

        Args:
            seq: Sequence

        Returns:
            (cpu_block_ids, logical_ids)
        """
        cpu_block_ids = []
        logical_ids = []
        for logical_id in seq.block_table:
            block = self.logical_blocks[logical_id]
            if block.location == BlockLocation.CPU:
                cpu_block_ids.append(block.cpu_block_id)
                logical_ids.append(logical_id)
        # DEBUG: Log during prefill.
        logger.debug(f"[DEBUG] get_all_cpu_blocks: returned cpu_block_ids={cpu_block_ids}")
        return cpu_block_ids, logical_ids

    def allocate_next_cpu_block(self, seq: Sequence) -> int:
        """
        Allocate next CPU block for sequence (for new token during decode).

        Args:
            seq: Sequence

        Returns:
            Newly allocated CPU block ID

        Raises:
            RuntimeError: if the CPU pool is exhausted.
        """
        if not self.free_cpu_blocks:
            raise RuntimeError("No free CPU blocks")
        cpu_block_id = self.free_cpu_blocks.popleft()
        logical_id = self.free_logical_ids.popleft()
        block = self.logical_blocks[logical_id]
        block.ref_count = 1
        block.location = BlockLocation.CPU
        block.cpu_block_id = cpu_block_id
        block.gpu_slot = -1
        self.cpu_block_to_logical[cpu_block_id] = logical_id
        seq.block_table.append(logical_id)
        return cpu_block_id

    def get_last_cpu_block(self, seq: Sequence) -> int:
        """
        Get CPU block ID of the last block in sequence.

        Returns -1 if the last block is not on CPU.

        Args:
            seq: Sequence

        Returns:
            CPU block ID, or -1 if not on CPU (or the block table is empty)
        """
        if not seq.block_table:
            return -1
        last_logical_id = seq.block_table[-1]
        block = self.logical_blocks[last_logical_id]
        if block.location == BlockLocation.CPU:
            return block.cpu_block_id
        return -1

    def get_decode_start_pos(self, seq: Sequence) -> int:
        """
        Get the starting position within block where decode tokens began.

        This is used for batched offload optimization - we need to attend to all
        accumulated tokens in decode slot, not just the current one.

        The value is computed lazily on first access (which is assumed to happen
        during the first decode step — TODO confirm callers honor this) and
        cached under id(seq) until clear_decode_tracking()/reset_decode_start_pos().

        Args:
            seq: Sequence

        Returns:
            Starting position within block (0 to block_size-1)
        """
        seq_id = id(seq)
        if seq_id not in self._decode_start_pos:
            # First decode step - compute starting position.
            # After prefill, the last block has some tokens filled;
            # decode starts at the next position.
            prefill_len = len(seq) - 1  # Current len includes the new decode token
            self._decode_start_pos[seq_id] = prefill_len % self._block_size
            # DEBUG: Log first access.
            logger.debug(
                f"[DEBUG] get_decode_start_pos FIRST ACCESS: seq_id={seq_id}, "
                f"len(seq)={len(seq)}, prefill_len={prefill_len}, "
                f"stored decode_start_pos={self._decode_start_pos[seq_id]}"
            )
        return self._decode_start_pos[seq_id]

    def reset_decode_start_pos(self, seq: Sequence) -> None:
        """
        Reset decode start position for sequence.

        Called when block is full and offloaded - next decode starts at position 0.

        Args:
            seq: Sequence
        """
        seq_id = id(seq)
        self._decode_start_pos[seq_id] = 0

    def get_prefill_len(self, seq: Sequence) -> int:
        """
        Get the original prefill length for a sequence.

        This is cached on first call to ensure correct last_block_valid_tokens
        calculation during decode (the CPU blocks don't change after prefill).

        NOTE(review): like get_decode_start_pos, correctness depends on the first
        call happening during the first decode step (len(seq) - 1 == prefill length).

        Args:
            seq: Sequence

        Returns:
            Number of tokens from prefill (before decode started)
        """
        seq_id = id(seq)
        if seq_id not in self._prefill_len:
            # First decode step - store the prefill length.
            # len(seq) - 1 because current len includes the first decode token.
            self._prefill_len[seq_id] = len(seq) - 1
            # DEBUG: Log first access.
            logger.debug(
                f"[DEBUG] get_prefill_len FIRST ACCESS: seq_id={seq_id}, "
                f"len(seq)={len(seq)}, stored prefill_len={self._prefill_len[seq_id]}"
            )
        return self._prefill_len[seq_id]

    def clear_decode_tracking(self, seq: Sequence) -> None:
        """
        Clear decode position tracking for sequence.

        Called when sequence is deallocated. Must run before the Sequence object
        can be garbage-collected, since the tracking dicts are keyed by id(seq).

        Args:
            seq: Sequence
        """
        seq_id = id(seq)
        # DEBUG: Log clearing and CPU blocks.
        cpu_blocks = [self.logical_blocks[lid].cpu_block_id for lid in seq.block_table
                      if self.logical_blocks[lid].location == BlockLocation.CPU]
        logger.debug(
            f"[DEBUG] clear_decode_tracking: seq_id={seq_id}, "
            f"clearing decode_start_pos={self._decode_start_pos.get(seq_id, 'N/A')}, "
            f"prefill_len={self._prefill_len.get(seq_id, 'N/A')}, "
            f"cpu_blocks={cpu_blocks}"
        )
        self._decode_start_pos.pop(seq_id, None)
        self._prefill_len.pop(seq_id, None)

    def __repr__(self) -> str:
        return (
            f"HybridKVCacheManager(\n"
            f"  num_gpu_slots={self.num_gpu_slots},\n"
            f"  num_cpu_blocks={self.num_cpu_blocks},\n"
            f"  block_size={self._block_size},\n"
            f"  free_logical={len(self.free_logical_ids)},\n"
            f"  free_gpu={len(self.free_gpu_slots)},\n"
            f"  free_cpu={len(self.free_cpu_blocks)},\n"
            f"  policy={self.policy}\n"
            f")"
        )