# File: nano-vllm/nanovllm/kvcache/offload_engine.py
# Snapshot: 2025-12-31 23:35:25 +08:00 (1021 lines, 38 KiB, Python)

"""
High-performance CPU-GPU KV cache transfer engine.
Key design principles for CUDA Graph compatibility:
1. All tensor addresses are fixed at initialization
2. Only index tensor contents change between graph replays
3. Supports both async transfer (for prefill) and graph-based transfer (for decode)
"""
import torch
import torch.cuda.nvtx
from torch import Tensor
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from nanovllm.kvcache.kernels import gathered_copy_kv
from nanovllm.comm import memcpy_2d_async
from nanovllm.utils.logger import get_logger
# Module-level logger shared by OffloadEngine for its info/debug/warning messages.
logger = get_logger("offload_engine")
@dataclass
class TransferEvent:
    """
    Tracks a pending async transfer.

    Fields (declaration order is the positional order of the generated __init__):
        event: CUDA event associated with the transfer (presumably recorded
            right after the copy is enqueued, so it signals completion).
        layer_id: Layer the transferred block belongs to.
        src_block_id: Block/slot id the data is copied from.
        dst_block_id: Block/slot id the data is copied to.
        direction: "h2d" (CPU -> GPU) or "d2h" (GPU -> CPU).
    """
    event: torch.cuda.Event
    layer_id: int
    src_block_id: int
    dst_block_id: int
    direction: str  # "h2d" or "d2h"
class OffloadEngine:
    """
    High-performance CPU-GPU async transfer engine for KV cache offloading.

    Memory layout:
    - GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim]
      There is NO layer dimension: the GPU slots are shared by all layers,
      which execute sequentially and reuse the same slots.
    - CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned)
    - Gather indices: [num_layers, num_gpu_blocks] (fixed address, variable content)

    Unified ring buffer over the GPU slots:
    - Prefill: every slot cycles as a ring buffer (slot = chunk_idx % num_gpu_blocks)
    - Decode: slot[0] receives the new token's KV, slots[1:] load previous blocks

    CUDA Graph compatibility:
    - gathered_h2d_layer() is meant to be captured into CUDA graphs
      (NOTE: it has not been updated for the shared-slot GPU layout yet)
    - update_gather_indices() is called outside graphs to prepare indices
    - All tensor addresses remain fixed across graph replays
    """

    def __init__(
        self,
        num_layers: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        num_streams: int = 4,
    ):
        """
        Allocate GPU slots, the pinned CPU cache, gather indices, streams and events.

        Args:
            num_layers: Number of layers stored in the CPU cache.
            num_gpu_blocks: Number of GPU slots (ring buffer size); must be >= 2.
            num_cpu_blocks: Number of CPU blocks per layer.
            block_size: Tokens per block.
            num_kv_heads: Number of KV heads.
            head_dim: Per-head dimension.
            dtype: Element dtype of both caches.
            num_streams: Number of round-robin streams used by the legacy async
                transfer methods; must be >= 1.
        """
        self.num_layers = num_layers
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks
        self.block_size = block_size
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.kv_dim = num_kv_heads * head_dim
        self.block_numel = block_size * self.kv_dim

        # ========== sgDMA pitch parameters for strided transfers ==========
        # CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
        # GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dim)
        # For CPU-to-GPU transfer (H2D): copy single layer, single block at a time
        # For all-layer CPU operations (D2H offload to all layers): use sgDMA
        self.dtype_size = dtype.itemsize
        # CPU pitch: byte stride between consecutive layers in the CPU cache
        self.cpu_pitch = num_cpu_blocks * self.block_numel * self.dtype_size
        # GPU has no layer dimension, so a single block is one contiguous range
        self.gpu_block_bytes = self.block_numel * self.dtype_size
        self.height = num_layers  # number of rows for CPU all-layer operations
        logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, "
                    f"gpu_block_bytes={self.gpu_block_bytes}, height={self.height}")

        # ========== Unified Ring Buffer configuration ==========
        assert num_gpu_blocks >= 2, \
            f"Need at least 2 GPU blocks for ring buffer, got {num_gpu_blocks}"
        # _get_next_stream() takes the index modulo len(transfer_streams),
        # so an empty stream list would fail on the first transfer.
        assert num_streams >= 1, \
            f"Need at least 1 transfer stream, got {num_streams}"

        # Prefill: use ALL slots as ring buffer (slot[chunk_idx % N])
        # Decode: slot[0] as decode_slot, slots[1:] for loading previous chunks
        self.num_ring_slots = num_gpu_blocks
        self.ring_slots = list(range(num_gpu_blocks))
        # Decode phase writes the new token's KV into slot[0]
        self.decode_slot = 0
        # Decode phase loads previous chunks from CPU into slots[1:]
        self.decode_load_slots = list(range(1, num_gpu_blocks))
        self.num_decode_load_slots = len(self.decode_load_slots)
        self.num_gpu_slots = num_gpu_blocks  # alias
        logger.info(f"Unified Ring Buffer: {self.num_ring_slots} slots total")
        logger.info(f" Prefill: all slots as ring buffer [0..{num_gpu_blocks-1}]")
        logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")

        # ========== Fixed-address GPU KV cache ==========
        # Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
        # No num_layers dimension: every layer reuses the same slots, which
        # saves a factor of num_layers in GPU memory versus per-layer slots.
        self.k_cache_gpu = torch.zeros(
            num_gpu_blocks, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cuda"
        )
        self.v_cache_gpu = torch.zeros(
            num_gpu_blocks, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cuda"
        )

        # ========== Fixed-address CPU KV cache (pinned memory) ==========
        self.k_cache_cpu = torch.zeros(
            num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cpu", pin_memory=True
        )
        self.v_cache_cpu = torch.zeros(
            num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cpu", pin_memory=True
        )

        # ========== Fixed-address gather indices (content is variable) ==========
        # gather_indices[layer][i] = CPU block id to copy to GPU slot i
        # -1 means no-op (skip this slot)
        self.gather_indices_cpu = torch.empty(
            num_layers, num_gpu_blocks,
            dtype=torch.int64, device="cpu", pin_memory=True
        )
        self.gather_indices_cpu.fill_(-1)
        self.gather_indices_gpu = torch.full(
            (num_layers, num_gpu_blocks), -1,
            dtype=torch.int64, device="cuda"
        )

        gpu_mem_mb = self.gpu_memory_bytes() / (1024 * 1024)
        cpu_mem_mb = self.cpu_memory_bytes() / (1024 * 1024)
        logger.info(f" GPU memory: {gpu_mem_mb:.1f} MB, CPU memory: {cpu_mem_mb:.1f} MB")

        # ========== Transfer streams for async operations ==========
        self.transfer_streams = [torch.cuda.Stream() for _ in range(num_streams)]
        # Dedicated compute stream (instead of the default stream) so that
        # attention compute can overlap with transfers.
        self.compute_stream = torch.cuda.Stream()
        self._stream_idx = 0

        # ========== Per-slot transfer streams for parallel H2D ==========
        # One stream per slot so that several slots can load simultaneously.
        self.slot_transfer_streams = [torch.cuda.Stream() for _ in range(self.num_ring_slots)]
        logger.info(f" Created {self.num_ring_slots} per-slot transfer streams")

        # ========== Ring Buffer dedicated stream and events ==========
        # Main transfer stream: ring-buffer D2H offloads and legacy batch loads
        self.transfer_stream_main = torch.cuda.Stream()
        self.decode_offload_done = torch.cuda.Event()

        # ========== Per-slot events for ring buffer ==========
        # Layers execute sequentially on shared slots, so per-slot events suffice.
        # ring_slot_ready[slot_idx]        = H2D completion
        # ring_slot_offload_done[slot_idx] = D2H completion
        # ring_slot_compute_done[slot_idx] = attention finished reading the slot
        self.ring_slot_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
        self.ring_slot_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
        self.ring_slot_compute_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
        # Record every compute_done event once so the first load_to_slot_layer
        # call waits on an already-recorded event.
        for compute_done in self.ring_slot_compute_done:
            compute_done.record()
        torch.cuda.synchronize()  # ensure all events are recorded

        # ========== Event tracking for async transfers ==========
        # Keyed by (layer_id, gpu_block_id); filled by prefetch_block_async.
        self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}

        # ========== Debug hook mode ==========
        self._debug_mode = False
        self._debug_hooks: List = []  # external hooks for debug events

    def _get_next_stream(self) -> torch.cuda.Stream:
        """Return the next transfer stream in round-robin order."""
        stream = self.transfer_streams[self._stream_idx]
        self._stream_idx = (self._stream_idx + 1) % len(self.transfer_streams)
        return stream

    # ========== CUDA Graph compatible methods ==========
    # NOTE: These methods need to be updated for the new GPU cache architecture.
    # GPU cache no longer has a layer dimension, so gathered copy semantics change.
    # They are kept for reference but should not be used without updating.

    def gathered_h2d_layer(self, layer_id: int) -> None:
        """
        Execute gathered H2D copy for a single layer.

        WARNING: This method needs updating for the new GPU cache architecture;
        the GPU cache no longer has a layer dimension.

        Args:
            layer_id: Layer whose CPU blocks (per gather_indices_gpu[layer_id])
                are copied into the shared GPU slots.
        """
        # Source is CPU[layer_id], destination is the shared GPU cache
        gathered_copy_kv(
            k_src=self.k_cache_cpu[layer_id],
            v_src=self.v_cache_cpu[layer_id],
            k_dst=self.k_cache_gpu,  # no layer indexing
            v_dst=self.v_cache_gpu,  # no layer indexing
            indices=self.gather_indices_gpu[layer_id],
        )

    def gathered_h2d_all_layers(self) -> None:
        """
        Execute gathered H2D copy for all layers.

        WARNING: GPU slots are shared across layers, so each layer's copy
        overwrites the previous one. Not recommended.
        """
        for layer_id in range(self.num_layers):
            self.gathered_h2d_layer(layer_id)

    def update_gather_indices(
        self,
        layer_id: int,
        mappings: List[Tuple[int, int]],
    ) -> None:
        """
        Update gather indices for a layer (call OUTSIDE CUDA graph).

        Args:
            layer_id: Layer index
            mappings: List of (cpu_block_id, gpu_slot) tuples.
                Only these slots are updated; others keep their values.
        """
        for cpu_block_id, gpu_slot in mappings:
            self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
        # Async copy from the pinned buffer.
        # NOTE(review): rewriting gather_indices_cpu before this copy executes
        # could race; confirm callers synchronize (e.g. sync_indices()).
        self.gather_indices_gpu[layer_id].copy_(
            self.gather_indices_cpu[layer_id],
            non_blocking=True
        )

    def update_gather_indices_all_layers(
        self,
        mappings_per_layer: List[List[Tuple[int, int]]],
    ) -> None:
        """
        Update gather indices for all layers with a single batched copy.

        Args:
            mappings_per_layer: mappings_per_layer[layer_id] = [(cpu_block_id, gpu_slot), ...]
        """
        for layer_id, mappings in enumerate(mappings_per_layer):
            for cpu_block_id, gpu_slot in mappings:
                self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
        self.gather_indices_gpu.copy_(self.gather_indices_cpu, non_blocking=True)

    def clear_gather_indices(self, layer_id: Optional[int] = None) -> None:
        """
        Clear gather indices (set all to -1, meaning no-op).

        Args:
            layer_id: If provided, clear only this layer; otherwise clear all.
        """
        if layer_id is not None:
            self.gather_indices_cpu[layer_id].fill_(-1)
            self.gather_indices_gpu[layer_id].fill_(-1)
        else:
            self.gather_indices_cpu.fill_(-1)
            self.gather_indices_gpu.fill_(-1)

    # ========== Async transfer methods (for prefill, outside CUDA graph) ==========

    def prefetch_block_async(
        self,
        layer_id: int,
        cpu_block_id: int,
        gpu_block_id: int,
    ) -> torch.cuda.Event:
        """
        Async prefetch a single block from CPU to GPU.

        The copy runs on a round-robin transfer stream and is NOT ordered after
        compute_stream; the caller must ensure the destination slot is no
        longer being read. The returned event is also stored in pending_events
        under (layer_id, gpu_block_id) for wait_for_block().

        Args:
            layer_id: Layer index (for CPU cache indexing only)
            cpu_block_id: Source block in CPU cache
            gpu_block_id: Destination slot in GPU cache

        Returns:
            CUDA event that signals completion
        """
        stream = self._get_next_stream()
        event = torch.cuda.Event()
        logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
        with torch.cuda.stream(stream):
            # GPU: no layer dimension, CPU: has layer dimension
            self.k_cache_gpu[gpu_block_id].copy_(
                self.k_cache_cpu[layer_id, cpu_block_id],
                non_blocking=True
            )
            self.v_cache_gpu[gpu_block_id].copy_(
                self.v_cache_cpu[layer_id, cpu_block_id],
                non_blocking=True
            )
            event.record(stream)
        self.pending_events[(layer_id, gpu_block_id)] = event
        return event

    def prefetch_blocks_batch_async(
        self,
        transfers: List[Tuple[int, int, int]],  # [(layer_id, cpu_block_id, gpu_block_id), ...]
    ) -> List[torch.cuda.Event]:
        """
        Batch async prefetch multiple blocks.

        Args:
            transfers: List of (layer_id, cpu_block_id, gpu_block_id) tuples

        Returns:
            List of CUDA events, one per transfer, in input order
        """
        return [
            self.prefetch_block_async(layer_id, cpu_block_id, gpu_block_id)
            for layer_id, cpu_block_id, gpu_block_id in transfers
        ]

    def offload_block_async(
        self,
        layer_id: int,
        gpu_block_id: int,
        cpu_block_id: int,
    ) -> torch.cuda.Event:
        """
        Async offload a block from GPU to CPU.

        Before copying, the transfer stream waits for both compute_stream
        (attention) and the default stream (where store_kvcache writes the
        GPU cache), matching offload_slot_layer_to_cpu(). PyTorch side streams
        are not implicitly ordered after the default stream, so without this
        wait the copy could read stale KV data.

        Args:
            layer_id: Layer index (for CPU cache indexing only)
            gpu_block_id: Source slot in GPU cache
            cpu_block_id: Destination block in CPU cache

        Returns:
            CUDA event that signals completion
        """
        stream = self._get_next_stream()
        event = torch.cuda.Event()
        logger.debug(f"D2H offload: layer={layer_id}, GPU[{gpu_block_id}] -> CPU[{cpu_block_id}]")
        with torch.cuda.stream(stream):
            stream.wait_stream(self.compute_stream)
            stream.wait_stream(torch.cuda.default_stream())
            # GPU: no layer dimension, CPU: has layer dimension
            self.k_cache_cpu[layer_id, cpu_block_id].copy_(
                self.k_cache_gpu[gpu_block_id],
                non_blocking=True
            )
            self.v_cache_cpu[layer_id, cpu_block_id].copy_(
                self.v_cache_gpu[gpu_block_id],
                non_blocking=True
            )
            event.record(stream)
        return event

    def offload_blocks_batch_async(
        self,
        transfers: List[Tuple[int, int, int]],  # [(layer_id, gpu_block_id, cpu_block_id), ...]
    ) -> List[torch.cuda.Event]:
        """
        Batch async offload multiple blocks.

        Args:
            transfers: List of (layer_id, gpu_block_id, cpu_block_id) tuples

        Returns:
            List of CUDA events, one per transfer, in input order
        """
        return [
            self.offload_block_async(layer_id, gpu_block_id, cpu_block_id)
            for layer_id, gpu_block_id, cpu_block_id in transfers
        ]

    # ========== Chunked Decode: Load CPU blocks to GPU slots ==========

    def load_cpu_blocks_to_gpu_slots(
        self,
        layer_id: int,
        cpu_block_ids: List[int],
        gpu_slot_ids: List[int],
    ) -> None:
        """
        Load CPU blocks to specific GPU slots for chunked decode (blocking).

        The copies run on a round-robin transfer stream, and the host then
        blocks until that stream finishes. Like prefetch_block_async(), this
        does not wait for compute on the destination slots.

        Args:
            layer_id: Layer index (for CPU cache indexing only)
            cpu_block_ids: CPU block IDs to load
            gpu_slot_ids: GPU slot IDs to load into (must be same length)
        """
        assert len(cpu_block_ids) == len(gpu_slot_ids)
        if cpu_block_ids:
            logger.debug(f"H2D chunked load: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
        stream = self._get_next_stream()
        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
                # GPU: no layer dimension, CPU: has layer dimension
                self.k_cache_gpu[gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
                self.v_cache_gpu[gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
        stream.synchronize()

    def load_cpu_blocks_to_gpu_slots_async(
        self,
        layer_id: int,
        cpu_block_ids: List[int],
        gpu_slot_ids: List[int],
    ) -> torch.cuda.Event:
        """
        Async version of load_cpu_blocks_to_gpu_slots().

        Args:
            layer_id: Layer index (for CPU cache indexing only)
            cpu_block_ids: CPU block IDs to load
            gpu_slot_ids: GPU slot IDs to load into (must be same length)

        Returns:
            CUDA event recorded after the copies; wait on it before reading
        """
        assert len(cpu_block_ids) == len(gpu_slot_ids)
        if cpu_block_ids:
            logger.debug(f"H2D chunked load async: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
        stream = self._get_next_stream()
        event = torch.cuda.Event()
        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
                # GPU: no layer dimension, CPU: has layer dimension
                self.k_cache_gpu[gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
                self.v_cache_gpu[gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
            event.record(stream)
        return event

    # NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
    # layer dimension. Each GPU slot holds data for ONE layer at a time.

    # ========== Synchronization methods ==========

    def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
        """Block the host until a prefetch_block_async() transfer completes (no-op if unknown)."""
        event = self.pending_events.pop((layer_id, gpu_block_id), None)
        if event is not None:
            event.synchronize()

    def wait_all_transfers(self) -> None:
        """
        Block the host until every transfer stream is idle.

        Covers the round-robin transfer_streams (legacy prefetch/offload/load),
        the per-slot ring-buffer load streams and transfer_stream_main (ring
        offloads and legacy batch loads), then forgets all pending events.
        """
        streams = [*self.transfer_streams, *self.slot_transfer_streams, self.transfer_stream_main]
        for stream in streams:
            stream.synchronize()
        self.pending_events.clear()

    def sync_indices(self) -> None:
        """Synchronize the default stream so that pending index updates are complete."""
        torch.cuda.default_stream().synchronize()

    # ========== Cache access methods ==========

    def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
        """
        Get GPU K/V cache tensors for an attention layer.

        NOTE: GPU cache has no layer dimension - all layers share the same slots.
        layer_id is kept for API compatibility but not used.

        Returns:
            (k_cache, v_cache), shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
        """
        return self.k_cache_gpu, self.v_cache_gpu

    def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
        """
        Get full GPU K/V cache tensors (no layer dimension).

        Returns:
            (k_cache, v_cache), shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
        """
        return self.k_cache_gpu, self.v_cache_gpu

    def get_cpu_block(
        self,
        layer_id: int,
        cpu_block_id: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get views of a specific CPU block's K/V cache.

        Returns:
            (k_cache, v_cache), shape: [block_size, kv_heads, head_dim]
        """
        return (
            self.k_cache_cpu[layer_id, cpu_block_id],
            self.v_cache_cpu[layer_id, cpu_block_id],
        )

    # ========== Memory info ==========

    def gpu_memory_bytes(self) -> int:
        """Total GPU bytes held by the K/V slots and the GPU gather indices."""
        return (
            self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() +
            self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() +
            self.gather_indices_gpu.numel() * self.gather_indices_gpu.element_size()
        )

    def cpu_memory_bytes(self) -> int:
        """Total pinned CPU bytes held by the K/V cache and the CPU gather indices."""
        return (
            self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
            self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() +
            self.gather_indices_cpu.numel() * self.gather_indices_cpu.element_size()
        )

    def __repr__(self) -> str:
        return (
            f"OffloadEngine(\n"
            f" num_layers={self.num_layers},\n"
            f" num_gpu_blocks={self.num_gpu_blocks},\n"
            f" num_cpu_blocks={self.num_cpu_blocks},\n"
            f" block_size={self.block_size},\n"
            f" kv_heads={self.num_kv_heads},\n"
            f" head_dim={self.head_dim},\n"
            f" dtype={self.dtype},\n"
            f" ring_buffer: {self.num_ring_slots} slots, decode_slot={self.decode_slot}, decode_load_slots={self.decode_load_slots},\n"
            f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
            f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
            f")"
        )

    def wait_all_offload_done(self) -> None:
        """
        Block the host until transfer_stream_main is idle.

        This covers ring-buffer offloads (offload_slot_layer_to_cpu /
        offload_decode_slot_layer). Offloads issued via offload_block_async()
        run on transfer_streams and are covered by wait_all_transfers().
        """
        self.transfer_stream_main.synchronize()

    # ========== Unified Ring Buffer methods ==========

    # ----- Prefill: Ring Buffer slot management -----

    def get_write_slot_for_prefill(self, chunk_idx: int) -> int:
        """
        Get the ring buffer slot for writing a prefill chunk.

        Args:
            chunk_idx: Current chunk index (0, 1, 2, ...)

        Returns:
            chunk_idx modulo the number of ring slots
        """
        return chunk_idx % self.num_ring_slots

    def get_load_slots_for_prefill(self, write_slot_idx: int) -> List[int]:
        """
        Get slots available for loading previous chunks during prefill.

        Args:
            write_slot_idx: Current write slot index (excluded to avoid conflict)

        Returns:
            All ring slot indices except write_slot_idx, in ascending order
        """
        return [i for i in range(self.num_ring_slots) if i != write_slot_idx]

    # ----- Decode: slot management -----

    def get_load_slots_for_decode(self) -> List[int]:
        """
        Get slots available for loading during decode.

        Returns:
            decode_load_slots (slots[1:]); decode_slot is reserved for the
            new token's KV. The internal list is returned, not a copy.
        """
        return self.decode_load_slots

    # ----- Per-slot Per-layer loading methods -----

    def record_slot_compute_done(self, slot_idx: int) -> None:
        """
        Record (on the current stream) that compute reading this slot is done.

        load_to_slot_layer() waits on this event so it does not overwrite data
        before attention has read it.

        Args:
            slot_idx: GPU slot index that was just used for computation
        """
        self.ring_slot_compute_done[slot_idx].record()

    def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
        """
        Async load one CPU block of one layer into a ring buffer slot.

        This is the core building block for ring buffer pipelining. The copy
        runs on the slot's own stream, which first waits for:
          1. the last recorded compute on this slot (ring_slot_compute_done), and
          2. the last recorded offload of this slot (ring_slot_offload_done).
        Completion is recorded in ring_slot_ready[slot_idx]; see wait_slot_layer().

        Args:
            slot_idx: Target GPU slot index
            layer_id: Layer index to load (for CPU cache indexing)
            cpu_block_id: Source CPU block ID
        """
        logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
        stream = self.slot_transfer_streams[slot_idx]
        torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
        with torch.cuda.stream(stream):
            # Don't overwrite until attention has finished reading this slot
            stream.wait_event(self.ring_slot_compute_done[slot_idx])
            # Don't overwrite while an offload is still reading this slot
            stream.wait_event(self.ring_slot_offload_done[slot_idx])
            # GPU: no layer dimension, CPU: has layer dimension
            self.k_cache_gpu[slot_idx].copy_(
                self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
            )
            self.v_cache_gpu[slot_idx].copy_(
                self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
            )
            self.ring_slot_ready[slot_idx].record(stream)
        torch.cuda.nvtx.range_pop()

    def wait_slot_layer(self, slot_idx: int) -> None:
        """
        Make compute_stream wait until the slot's load has completed.

        Args:
            slot_idx: GPU slot index to wait for
        """
        self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])

    # NOTE: load_to_slot_all_layers removed - GPU cache no longer has layer dimension.
    # Each GPU slot holds data for ONE layer at a time. Layers execute sequentially,
    # reusing the same GPU slots.

    # ----- Slot offload methods -----
    # NOTE: offload_slot_to_cpu (all-layers) removed - GPU cache no longer has layer dimension.
    # Use offload_slot_layer_to_cpu for per-layer offloading.

    def wait_slot_offload(self, slot_idx: int) -> None:
        """Make compute_stream wait until the slot's offload has completed."""
        self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])

    def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
        """
        Async offload a ring buffer slot to CPU for one layer.

        The copy runs on transfer_stream_main after it waits for compute_stream
        and the default stream. Completion is recorded in
        ring_slot_offload_done[slot_idx].

        Args:
            slot_idx: Source GPU slot index
            layer_id: Target layer in CPU cache
            cpu_block_id: Target CPU block ID
        """
        logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
        torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
        with torch.cuda.stream(self.transfer_stream_main):
            # - compute_stream: flash attention operations
            # - default_stream: store_kvcache, which runs on the default stream
            self.transfer_stream_main.wait_stream(self.compute_stream)
            self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
            # GPU: no layer dimension, CPU: has layer dimension
            self.k_cache_cpu[layer_id, cpu_block_id].copy_(
                self.k_cache_gpu[slot_idx], non_blocking=True
            )
            self.v_cache_cpu[layer_id, cpu_block_id].copy_(
                self.v_cache_gpu[slot_idx], non_blocking=True
            )
            self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
        torch.cuda.nvtx.range_pop()

    # ----- KV access methods for ring buffer -----

    def get_kv_for_slot(self, slot_idx: int) -> Tuple[Tensor, Tensor]:
        """
        Get views of the KV in a single ring buffer slot.

        The slot holds whatever layer was most recently loaded into it.

        Args:
            slot_idx: GPU slot index

        Returns:
            (k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
        """
        k = self.k_cache_gpu[slot_idx].unsqueeze(0)
        v = self.v_cache_gpu[slot_idx].unsqueeze(0)
        return k, v

    def get_kv_for_slots(
        self,
        slot_indices: List[int],
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """
        Get KV for multiple ring buffer slots, concatenated along tokens.

        Args:
            slot_indices: GPU slot indices, in the order they should appear

        Returns:
            (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
            (gathered copies, not views), or (None, None) if slot_indices is empty
        """
        if not slot_indices:
            return None, None
        k = self.k_cache_gpu[slot_indices]
        v = self.v_cache_gpu[slot_indices]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    # ----- Decode slot methods (kept for decode phase) -----
    # NOTE: For decode with CPU offload, the flow is per-layer:
    # 1. Each layer stores to decode_slot (same GPU memory, reused)
    # 2. Each layer offloads its data to CPU[layer_id, block_id]
    # 3. Each layer loads prev blocks from CPU[layer_id] when needed

    def offload_decode_slot_layer(self, layer_id: int, cpu_block_id: int) -> None:
        """
        Offload KV from the decode slot (slot[0]) to CPU for one layer.

        Args:
            layer_id: Layer ID
            cpu_block_id: Target CPU block ID
        """
        self.offload_slot_layer_to_cpu(self.decode_slot, layer_id, cpu_block_id)

    def wait_decode_offload(self) -> None:
        """Make compute_stream wait for the decode slot's offload."""
        self.wait_slot_offload(self.decode_slot)

    def get_kv_for_decode_slot(
        self,
        pos_in_block: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get views of the KV at one position of the decode slot.

        Args:
            pos_in_block: Token position within block (0 to block_size-1)

        Returns:
            (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
        """
        k = self.k_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
        v = self.v_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
        return k.unsqueeze(0), v.unsqueeze(0)

    def get_kv_for_decode_slot_accumulated(
        self,
        num_tokens: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get views of the accumulated KV in the decode slot (positions 0..num_tokens-1).

        Args:
            num_tokens: Number of accumulated tokens (1 to block_size)

        Returns:
            (k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
        """
        k = self.k_cache_gpu[self.decode_slot, :num_tokens]
        v = self.v_cache_gpu[self.decode_slot, :num_tokens]
        return k.unsqueeze(0), v.unsqueeze(0)

    # ----- Legacy compatibility methods (for decode double-buffering) -----
    # NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.

    def _decode_compute_slots(self) -> List[int]:
        """Legacy 'compute' region: first half of decode_load_slots (at least one slot)."""
        half = max(1, len(self.decode_load_slots) // 2)
        return self.decode_load_slots[:half]

    def _decode_prefetch_slots(self) -> List[int]:
        """Legacy 'prefetch' region: second half of decode_load_slots, or all of them if that half is empty."""
        half = max(1, len(self.decode_load_slots) // 2)
        return self.decode_load_slots[half:] or self.decode_load_slots

    def _load_blocks_on_main_stream(
        self,
        layer_id: int,
        cpu_block_ids: List[int],
        slots: List[int],
    ) -> None:
        """Legacy: copy cpu_block_ids[i] -> slots[i] on transfer_stream_main, then mark slots[0] ready."""
        num_to_load = min(len(cpu_block_ids), len(slots))
        with torch.cuda.stream(self.transfer_stream_main):
            # zip stops at the shorter list, i.e. after num_to_load pairs
            for cpu_id, gpu_slot in zip(cpu_block_ids, slots):
                # GPU: no layer dimension, CPU: has layer dimension
                self.k_cache_gpu[gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
        if num_to_load > 0:
            # All copies share one stream, so one event after them covers all.
            self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)

    def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Legacy: load CPU blocks into the 'compute' region of decode_load_slots.

        Loads at most len(region) blocks. Does not wait for compute on the
        destination slots.
        """
        if not cpu_block_ids:
            return
        self._load_blocks_on_main_stream(layer_id, cpu_block_ids, self._decode_compute_slots())

    def wait_compute_layer(self) -> None:
        """Legacy: make compute_stream wait for 'compute' region loading."""
        slots = self._decode_compute_slots()
        if slots:
            self.wait_slot_layer(slots[0])

    def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Legacy: load CPU blocks into the 'prefetch' region of decode_load_slots.

        With only one load slot, the prefetch region falls back to all
        decode_load_slots and therefore overlaps the compute region.
        """
        if not cpu_block_ids:
            return
        self._load_blocks_on_main_stream(layer_id, cpu_block_ids, self._decode_prefetch_slots())

    def wait_prefetch_layer(self) -> None:
        """Legacy: make compute_stream wait for 'prefetch' region loading."""
        slots = self._decode_prefetch_slots()
        if slots:
            self.wait_slot_layer(slots[0])

    def get_kv_for_compute(
        self,
        num_blocks: int,
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Legacy: get KV from the first num_blocks slots of the 'compute' region."""
        return self.get_kv_for_slots(self._decode_compute_slots()[:num_blocks])

    def get_kv_for_prefetch(
        self,
        num_blocks: int,
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Legacy: get KV from the first num_blocks slots of the 'prefetch' region."""
        return self.get_kv_for_slots(self._decode_prefetch_slots()[:num_blocks])

    # ========== Debug Hook Interface ==========
    #
    # Minimal generic hook system for debugging.
    # Framework only provides hook registration and tensor access.
    # All verification logic is external.

    def enable_debug_mode(self) -> None:
        """Enable debug mode."""
        self._debug_mode = True
        logger.info("OffloadEngine debug mode ENABLED")

    def disable_debug_mode(self) -> None:
        """Disable debug mode and clear all hooks."""
        self._debug_mode = False
        self._debug_hooks.clear()
        logger.info("OffloadEngine debug mode DISABLED")

    @property
    def debug_mode(self) -> bool:
        """Whether debug mode is enabled."""
        return self._debug_mode

    def register_debug_hook(self, hook_fn) -> None:
        """
        Register a debug hook.

        Hooks run only while debug mode is enabled, when _call_debug_hooks()
        is invoked (after wait_slot_layer), and receive the loaded slot tensors.

        Args:
            hook_fn: Callable with signature:
                (slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None
                - k, v: GPU tensor views for the loaded slot

        Example:
            def my_hook(slot_idx, layer_id, cpu_block_id, k, v):
                if layer_id == 0:
                    k_val = k.float().mean().item()
                    print(f"Loaded block {cpu_block_id}, K mean = {k_val}")
            offload_engine.register_debug_hook(my_hook)
        """
        self._debug_hooks.append(hook_fn)

    def remove_debug_hook(self, hook_fn) -> None:
        """Remove a registered debug hook (no-op if it is not registered)."""
        if hook_fn in self._debug_hooks:
            self._debug_hooks.remove(hook_fn)

    def _call_debug_hooks(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
        """
        Call all registered debug hooks with the loaded slot tensors (internal use).

        Does nothing unless debug mode is on and at least one hook exists.
        Hook exceptions are logged and swallowed, except pdb's BdbQuit.
        """
        if not self._debug_mode or not self._debug_hooks:
            return
        k = self.k_cache_gpu[slot_idx]
        v = self.v_cache_gpu[slot_idx]
        for hook in self._debug_hooks:
            try:
                hook(slot_idx, layer_id, cpu_block_id, k, v)
            except Exception as e:
                # Allow pdb quit to propagate
                if e.__class__.__name__ == 'BdbQuit':
                    raise
                logger.warning(f"Debug hook error: {e}")