# nano-vllm/nanovllm/kvcache/offload_engine.py
"""
High-performance CPU-GPU KV cache transfer engine for layer-wise offload.
Key design principles:
1. Layer-wise processing: process entire sequence through one layer at a time
2. Ring-buffered GPU KV cache for decode phase (configurable num_kv_buffers)
3. Async D2H offload during prefill with per-layer streams
4. Async H2D load during decode with ring buffer pipeline
"""
import torch
import torch.cuda.nvtx
from torch import Tensor
from typing import Dict, List, Tuple, Optional
from nanovllm.utils.logger import get_logger
# Import for type hints only (avoid circular import)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from nanovllm.kvcache.sparse import SparsePolicy
logger = get_logger("offload_engine")
class OffloadEngine:
"""
High-performance CPU-GPU async transfer engine for layer-wise KV cache offloading.
Memory layout:
- CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned)
- GPU layer buffers: [num_kv_buffers, max_seq_tokens, kv_heads, head_dim] (ring buffer)
- Decode KV buffer: [num_layers, block_size, kv_heads, head_dim] (per-layer decode)
Features:
- Ring buffer for decode H2D pipeline (configurable depth)
- Per-layer async D2H offload during prefill
- Stream-based synchronization (no global synchronize)
"""
def __init__(
self,
num_layers: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
block_size: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype = torch.float16,
num_kv_buffers: int = 4,
max_seq_len: int = 131072,
sparse_policy: "SparsePolicy" = None,
):
self.num_layers = num_layers
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.block_size = block_size
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.dtype = dtype
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
self.num_kv_buffers = num_kv_buffers
self.max_seq_len = max_seq_len
logger.info(f"OffloadEngine initializing: num_layers={num_layers}, "
f"num_kv_buffers={num_kv_buffers}, max_seq_len={max_seq_len}")
# ========== Ring-Buffered GPU KV Cache for Layer-wise Decode ==========
#
# Ring buffer pipeline (illustrated with 4 buffers):
# Buffer 0: [Load L0] → [Compute L0] → [Load L4] → ...
# Buffer 1: [Load L1] → [Compute L1] → [Load L5] → ...
# Buffer 2: [Load L2] → [Compute L2] → ...
# Buffer 3: [Load L3] → [Compute L3] → ...
#
# Shape: [num_kv_buffers, max_seq_len, kv_heads, head_dim]
self.layer_k_cache = torch.zeros(
num_kv_buffers, max_seq_len, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.layer_v_cache = torch.zeros(
num_kv_buffers, max_seq_len, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
layer_cache_mb = 2 * num_kv_buffers * max_seq_len * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Ring buffer GPU cache: {layer_cache_mb:.1f} MB "
f"({num_kv_buffers} buffers × {max_seq_len} tokens)")
# ========== Per-layer Decode Buffer ==========
# During decode, accumulate new tokens' KV per layer until block is full
# Shape: [num_layers, block_size, kv_heads, head_dim]
self.decode_k_buffer = torch.zeros(
num_layers, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.decode_v_buffer = torch.zeros(
num_layers, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
# ========== Fixed-address CPU KV cache (pinned memory) ==========
self.k_cache_cpu = torch.zeros(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cpu", pin_memory=True
)
self.v_cache_cpu = torch.zeros(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cpu", pin_memory=True
)
cpu_mem_mb = 2 * num_layers * num_cpu_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" CPU cache: {cpu_mem_mb:.1f} MB "
f"({num_layers} layers × {num_cpu_blocks} blocks)")
# ========== Compute Stream ==========
# IMPORTANT: Create a dedicated compute stream (not default stream!)
# Default stream has implicit synchronization with other streams,
# which prevents overlap between transfer and compute.
self.compute_stream = torch.cuda.Stream()
# ========== Prefill: Per-layer D2H offload streams and events ==========
# Each layer has its own stream for parallel offloads
self.prefill_offload_streams = [torch.cuda.Stream() for _ in range(num_layers)]
self.prefill_offload_events = [torch.cuda.Event() for _ in range(num_layers)]
# ========== Decode: Ring buffer H2D load streams and events ==========
# Per-buffer streams for parallel loading
self.layer_load_streams = [torch.cuda.Stream() for _ in range(num_kv_buffers)]
self.buffer_load_events = [torch.cuda.Event() for _ in range(num_kv_buffers)]
self.buffer_compute_done_events = [torch.cuda.Event() for _ in range(num_kv_buffers)]
# Initialize: mark all buffers as "compute done" (allows first load)
for event in self.buffer_compute_done_events:
event.record()
# ========== Decode offload stream ==========
self.decode_offload_stream = torch.cuda.Stream()
self.decode_offload_event = torch.cuda.Event()
# ========== Sparse attention policy ==========
self.sparse_policy = sparse_policy
logger.info(f"OffloadEngine initialized: GPU={self.gpu_memory_bytes()/(1024**2):.1f}MB, "
f"CPU={self.cpu_memory_bytes()/(1024**2):.1f}MB")
# ========== Memory info ==========
def gpu_memory_bytes(self) -> int:
"""Total GPU memory used by KV caches."""
return (
self.layer_k_cache.numel() * self.layer_k_cache.element_size() +
self.layer_v_cache.numel() * self.layer_v_cache.element_size() +
self.decode_k_buffer.numel() * self.decode_k_buffer.element_size() +
self.decode_v_buffer.numel() * self.decode_v_buffer.element_size()
)
def cpu_memory_bytes(self) -> int:
"""Total CPU memory used by KV caches."""
return (
self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
self.v_cache_cpu.numel() * self.v_cache_cpu.element_size()
)
def __repr__(self) -> str:
return (
f"OffloadEngine(\n"
f" num_layers={self.num_layers},\n"
f" num_kv_buffers={self.num_kv_buffers},\n"
f" max_seq_len={self.max_seq_len},\n"
f" num_cpu_blocks={self.num_cpu_blocks},\n"
f" block_size={self.block_size},\n"
f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
)
# ========== State Reset ==========
def on_sequence_finished(self):
"""
Clear state after sequence completion to prevent pollution between requests.
Called by HybridKVCacheManager.deallocate() when a sequence finishes.
"""
# Clear decode buffer to prevent residual KV from affecting next request
self.decode_k_buffer.zero_()
self.decode_v_buffer.zero_()
# Re-record buffer_compute_done_events to mark all buffers as available
for event in self.buffer_compute_done_events:
event.record()
logger.debug("OffloadEngine: state cleared for next sequence")
# ========== Prefill: Async D2H Offload API ==========
def offload_layer_kv_async(
self,
layer_id: int,
k: Tensor,
v: Tensor,
cpu_block_ids: List[int],
total_tokens: int,
) -> None:
"""
Async offload layer KV to CPU using per-layer stream.
This enables overlap: layer N offload overlaps with layer N+1 compute.
Args:
layer_id: Layer index
k: Key tensor [seq_len, kv_heads, head_dim]
v: Value tensor [seq_len, kv_heads, head_dim]
cpu_block_ids: List of CPU block IDs to offload to
total_tokens: Total number of valid tokens in k/v (the last block may be partially filled)
"""
stream = self.prefill_offload_streams[layer_id]
torch.cuda.nvtx.range_push(f"D2H: L{layer_id}")
with torch.cuda.stream(stream):
# Wait for compute to finish
stream.wait_stream(self.compute_stream)
# Copy to CPU in blocks
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * self.block_size
end = min(start + self.block_size, total_tokens)
actual_size = end - start
self.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(
k[start:end], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(
v[start:end], non_blocking=True
)
# Record completion event
self.prefill_offload_events[layer_id].record(stream)
torch.cuda.nvtx.range_pop()
def wait_layer_offload(self, layer_id: int) -> None:
"""
Wait for specific layer's offload to complete on compute_stream.
Call this before reusing the layer's GPU buffer.
"""
self.compute_stream.wait_event(self.prefill_offload_events[layer_id])
def wait_all_prefill_offloads(self) -> None:
"""Wait for all prefill offloads to complete."""
for stream in self.prefill_offload_streams:
stream.synchronize()
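# Illustrative prefill loop (sketch only; `run_layer` and the block-ID
# bookkeeping are hypothetical stand-ins for the model runner):
#
#   with torch.cuda.stream(engine.compute_stream):
#       for layer_id in range(engine.num_layers):
#           k, v = run_layer(layer_id, hidden)          # this layer's KV on GPU
#           engine.offload_layer_kv_async(layer_id, k, v, cpu_block_ids, seq_len)
#   engine.wait_all_prefill_offloads()                  # drain before decode
#
# The per-layer streams let layer N's D2H copy overlap layer N+1's compute;
# call wait_layer_offload(N) only if layer N's k/v buffers will be reused.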
# ========== Decode: Ring-Buffered H2D Load API ==========
def load_layer_kv_to_buffer(
self,
buffer_idx: int,
layer_id: int,
cpu_block_ids: List[int],
valid_tokens_per_block: List[int],
) -> None:
"""
Async load layer KV from CPU to specified ring buffer slot.
Args:
buffer_idx: Ring buffer slot index (0 to num_kv_buffers-1)
layer_id: Which layer's KV to load
cpu_block_ids: CPU block IDs containing this layer's KV
valid_tokens_per_block: Number of valid tokens in each block
"""
stream = self.layer_load_streams[buffer_idx]
torch.cuda.nvtx.range_push(f"H2D: L{layer_id}->Buf{buffer_idx}")
with torch.cuda.stream(stream):
# Wait for previous compute on this buffer to complete
stream.wait_event(self.buffer_compute_done_events[buffer_idx])
offset = 0
for i, cpu_block_id in enumerate(cpu_block_ids):
valid_tokens = valid_tokens_per_block[i]
self.layer_k_cache[buffer_idx, offset:offset+valid_tokens].copy_(
self.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
non_blocking=True
)
self.layer_v_cache[buffer_idx, offset:offset+valid_tokens].copy_(
self.v_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
non_blocking=True
)
offset += valid_tokens
self.buffer_load_events[buffer_idx].record(stream)
torch.cuda.nvtx.range_pop()
def wait_buffer_load(self, buffer_idx: int) -> None:
"""Wait for buffer load to complete on compute_stream."""
self.compute_stream.wait_event(self.buffer_load_events[buffer_idx])
def get_buffer_kv(self, buffer_idx: int, total_tokens: int) -> Tuple[Tensor, Tensor]:
"""Get KV from specified ring buffer slot."""
return (
self.layer_k_cache[buffer_idx, :total_tokens],
self.layer_v_cache[buffer_idx, :total_tokens]
)
def record_buffer_compute_done(self, buffer_idx: int) -> None:
"""Record that compute on this buffer is done (allows next load to reuse it)."""
self.buffer_compute_done_events[buffer_idx].record(self.compute_stream)
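# Illustrative decode pipeline over the ring buffer (sketch only; `attend`,
# `cpu_blocks`, and `valid_tokens` are hypothetical). Slot = layer % depth,
# so the load for layer N waits only on compute for layer N - num_kv_buffers:
#
#   depth = engine.num_kv_buffers
#   for layer_id in range(min(depth, engine.num_layers)):       # prefetch
#       engine.load_layer_kv_to_buffer(layer_id, layer_id, cpu_blocks, valid_tokens)
#   for layer_id in range(engine.num_layers):
#       buf = layer_id % depth
#       engine.wait_buffer_load(buf)                            # H2D finished?
#       k, v = engine.get_buffer_kv(buf, total_tokens)
#       out = attend(layer_id, q, k, v)                         # on compute_stream
#       engine.record_buffer_compute_done(buf)                  # release the slot
#       nxt = layer_id + depth
#       if nxt < engine.num_layers:                             # refill the slot
#           engine.load_layer_kv_to_buffer(buf, nxt, cpu_blocks, valid_tokens)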
# ========== Decode Buffer API ==========
def get_decode_kv(self, layer_id: int, start_pos: int, end_pos: int) -> Tuple[Tensor, Tensor]:
"""
Get accumulated decode KV for a layer.
Args:
layer_id: Layer index
start_pos: Start position in block
end_pos: End position in block (exclusive)
Returns:
(k, v) tensors with shape [end_pos - start_pos, kv_heads, head_dim]
"""
return (
self.decode_k_buffer[layer_id, start_pos:end_pos],
self.decode_v_buffer[layer_id, start_pos:end_pos]
)
def store_decode_kv(
self,
layer_id: int,
pos_in_block: int,
k: Tensor,
v: Tensor,
) -> None:
"""
Store new decode token's KV to decode buffer.
Args:
layer_id: Layer index
pos_in_block: Position within block (0 to block_size-1)
k: Key tensor [1, kv_heads, head_dim]
v: Value tensor [1, kv_heads, head_dim]
"""
self.decode_k_buffer[layer_id, pos_in_block].copy_(k.squeeze(0))
self.decode_v_buffer[layer_id, pos_in_block].copy_(v.squeeze(0))
def offload_decode_buffer_async(self, cpu_block_id: int) -> None:
"""
Async offload entire decode buffer to CPU.
Called when a decode block is full. Also calls sparse policy hooks
to update metadata (e.g., Quest min/max keys).
Args:
cpu_block_id: Target CPU block ID
"""
torch.cuda.nvtx.range_push(f"D2H: DecBuf->CPU[{cpu_block_id}]")
with torch.cuda.stream(self.decode_offload_stream):
self.decode_offload_stream.wait_stream(self.compute_stream)
for layer_id in range(self.num_layers):
# Hook: notify sparse policy BEFORE offload (k still on GPU)
if self.sparse_policy is not None:
self.sparse_policy.on_decode_offload(
cpu_block_id, layer_id,
self.decode_k_buffer[layer_id],
self.block_size # Full block
)
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.decode_k_buffer[layer_id], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.decode_v_buffer[layer_id], non_blocking=True
)
self.decode_offload_event.record(self.decode_offload_stream)
torch.cuda.nvtx.range_pop()
def wait_decode_offload(self) -> None:
"""Wait for decode buffer offload to complete."""
self.compute_stream.wait_event(self.decode_offload_event)
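# Illustrative per-token decode flow (sketch only; cpu_block_id allocation is
# hypothetical). Each layer appends the new token's KV during its forward;
# once the block fills, all layers are offloaded in one async pass:
#
#   pos = seq_len % engine.block_size
#   for layer_id in range(engine.num_layers):
#       engine.store_decode_kv(layer_id, pos, k_new[layer_id], v_new[layer_id])
#   if pos == engine.block_size - 1:                    # block just filled
#       engine.offload_decode_buffer_async(cpu_block_id)
#       engine.wait_decode_offload()                    # before pos wraps to 0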
# ========== Encapsulated Prefill Offload API (with sparse hooks) ==========
def offload_layer_kv_sync(
self,
layer_id: int,
k: Tensor,
v: Tensor,
cpu_block_ids: List[int],
total_tokens: int,
) -> None:
"""
Synchronously offload layer KV to CPU with sparse policy hooks.
This method encapsulates:
1. Block-wise copy to CPU cache
2. Sparse policy hooks (on_prefill_offload for Quest metadata)
Args:
layer_id: Layer index
k: Key tensor [seq_len, kv_heads, head_dim]
v: Value tensor [seq_len, kv_heads, head_dim]
cpu_block_ids: List of CPU block IDs to offload to
total_tokens: Total number of valid tokens in k/v (the last block may be partially filled)
"""
for i, cpu_block_id in enumerate(cpu_block_ids):
start = i * self.block_size
end = min(start + self.block_size, total_tokens)
actual_size = end - start
# Hook: notify sparse policy BEFORE offload (k still on GPU)
if self.sparse_policy is not None:
self.sparse_policy.on_prefill_offload(
cpu_block_id, layer_id, k[start:end], actual_size
)
# Synchronous copy to CPU
self.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
self.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
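# Illustrative round-trip check (sketch only; tiny hypothetical sizes; needs
# a CUDA device). Offloads one layer's KV to CPU, loads it back into ring
# buffer slot 0, and compares bit-for-bit:
#
#   engine = OffloadEngine(num_layers=2, num_gpu_blocks=4, num_cpu_blocks=8,
#                          block_size=16, num_kv_heads=2, head_dim=8,
#                          num_kv_buffers=2, max_seq_len=64)
#   k = torch.randn(24, 2, 8, dtype=torch.float16, device="cuda")
#   v = torch.randn_like(k)
#   engine.offload_layer_kv_sync(0, k, v, cpu_block_ids=[0, 1], total_tokens=24)
#   engine.load_layer_kv_to_buffer(0, 0, [0, 1], valid_tokens_per_block=[16, 8])
#   engine.wait_buffer_load(0)
#   torch.cuda.synchronize()
#   k2, v2 = engine.get_buffer_kv(0, 24)
#   assert torch.equal(k2, k) and torch.equal(v2, v)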