[WIP] Before needle fix.

Zijie Tian
2025-12-31 23:35:25 +08:00
parent ccd1b3d4ab
commit 30462fe89a
5 changed files with 212 additions and 290 deletions

@@ -67,14 +67,19 @@ class OffloadEngine:
self.block_numel = block_size * self.kv_dim
# ========== sgDMA pitch parameters for strided transfers ==========
# CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
# GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dim)
# For CPU-to-GPU transfer (H2D): copy single layer, single block at a time
# For all-layer CPU operations (D2H offload to all layers): use sgDMA
self.dtype_size = dtype.itemsize
# CPU pitch: stride between layers in CPU cache (for all-layer operations)
self.cpu_pitch = num_cpu_blocks * self.block_numel * self.dtype_size
self.gpu_pitch = num_gpu_blocks * self.block_numel * self.dtype_size
self.width = self.block_numel * self.dtype_size
self.height = num_layers
# GPU has no layer dimension, so single block transfer is contiguous
self.gpu_block_bytes = self.block_numel * self.dtype_size
self.height = num_layers # For CPU all-layer operations
logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, gpu_pitch={self.gpu_pitch}, "
f"width={self.width}, height={self.height}")
logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, "
f"gpu_block_bytes={self.gpu_block_bytes}, height={self.height}")
# ========== Unified Ring Buffer configuration ==========
# Constraint checks
@@ -100,14 +105,16 @@ class OffloadEngine:
logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
# Use zeros initialization to avoid uninitialized memory issues
# Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
# NOTE: No num_layers dimension! GPU slots are shared across layers.
# Each layer reuses the same slots (layers execute sequentially).
# This uses 28x less GPU memory than per-layer allocation.
self.k_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.v_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
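A quick back-of-the-envelope check of the 28x figure above; num_layers=28 is the only value implied by the comment, the remaining sizes are assumptions.
# K and V slot memory, old layout (per-layer slots) vs. new layout (shared slots).
num_layers, num_gpu_blocks = 28, 8           # num_layers=28 implied by the comment
block_size, num_kv_heads, head_dim = 256, 8, 128
bytes_per_elem = 2                            # fp16
per_layer = num_gpu_blocks * block_size * num_kv_heads * head_dim * bytes_per_elem
old_total = 2 * num_layers * per_layer        # K + V, one slot set per layer
new_total = 2 * per_layer                     # K + V, one shared slot set
print(f"old={old_total / 2**20:.0f} MiB, new={new_total / 2**20:.0f} MiB, "
      f"saving={old_total // new_total}x")    # -> 28x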
@@ -159,35 +166,23 @@ class OffloadEngine:
# Decode offload event
self.decode_offload_done = torch.cuda.Event()
# ========== Per-slot Per-layer events for ring buffer ==========
# ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
# ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
self.ring_slot_ready = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
self.ring_slot_offload_done = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
# ========== Per-slot events for ring buffer ==========
# Since GPU cache has no layer dimension and layers execute sequentially,
# we only need per-slot events (not per-slot per-layer).
# ring_slot_ready[slot_idx] = CUDA Event for H2D completion
# ring_slot_offload_done[slot_idx] = CUDA Event for D2H completion
self.ring_slot_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
self.ring_slot_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# Per-slot events for all-layer operations (used in some legacy paths)
self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# ========== Per-slot Per-layer compute_done events for async pipeline ==========
# ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion
# This is used to ensure we don't overwrite data before it's been read by attention
self.ring_slot_compute_done = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
# ========== Per-slot compute_done events for async pipeline ==========
# ring_slot_compute_done[slot_idx] = CUDA Event for compute completion
# This ensures we don't overwrite data before it's been read by attention
self.ring_slot_compute_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# Initialize all compute_done events (record them once)
# This prevents undefined behavior on first load_to_slot_layer call
for slot_idx in range(self.num_ring_slots):
for layer_id in range(num_layers):
self.ring_slot_compute_done[slot_idx][layer_id].record()
self.ring_slot_compute_done[slot_idx].record()
torch.cuda.synchronize() # Ensure all events are recorded
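The per-slot events implement a small handshake between the transfer and compute streams. The standalone sketch below shows that ordering in isolation; the stream and tensor names are local to the sketch, not engine attributes.
import torch
assert torch.cuda.is_available()
transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()
slot_ready = torch.cuda.Event()     # H2D finished, slot safe to read
compute_done = torch.cuda.Event()   # attention finished, slot safe to overwrite
compute_done.record()               # pre-record once, mirroring the init loop above
gpu_slot = torch.empty(16, 8, 128, dtype=torch.float16, device="cuda")
cpu_block = torch.randn(16, 8, 128).half().pin_memory()
with torch.cuda.stream(transfer_stream):
    transfer_stream.wait_event(compute_done)      # never clobber data still being read
    gpu_slot.copy_(cpu_block, non_blocking=True)  # H2D
    slot_ready.record(transfer_stream)
with torch.cuda.stream(compute_stream):
    compute_stream.wait_event(slot_ready)         # attention may now read the slot
    _ = gpu_slot.float().sum()                    # stand-in for the attention kernel
    compute_done.record(compute_stream)           # slot may be reloaded afterwards
torch.cuda.synchronize()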
# ========== Event tracking for async transfers ==========
@@ -204,23 +199,24 @@ class OffloadEngine:
return stream
# ========== CUDA Graph compatible methods ==========
# NOTE: These methods need to be updated for the new GPU cache architecture.
# GPU cache no longer has layer dimension, so gathered copy semantics change.
# For now, these are kept for reference but should not be used without updating.
def gathered_h2d_layer(self, layer_id: int) -> None:
"""
Execute gathered H2D copy for a single layer.
This method is CUDA Graph compatible - can be captured into a graph.
Before calling, update_gather_indices() must be called to set up
which CPU blocks to copy to which GPU slots.
Args:
layer_id: Layer index to transfer
WARNING: This method needs updating for new GPU cache architecture.
GPU cache no longer has layer dimension.
"""
# GPU cache has no layer dimension - use flat indexing
# Source is CPU[layer_id], dest is GPU (shared across layers)
gathered_copy_kv(
k_src=self.k_cache_cpu[layer_id],
v_src=self.v_cache_cpu[layer_id],
k_dst=self.k_cache_gpu[layer_id],
v_dst=self.v_cache_gpu[layer_id],
k_dst=self.k_cache_gpu, # No layer indexing
v_dst=self.v_cache_gpu, # No layer indexing
indices=self.gather_indices_gpu[layer_id],
)
@@ -228,7 +224,8 @@ class OffloadEngine:
"""
Execute gathered H2D copy for all layers.
CUDA Graph compatible - can be captured into a single graph.
WARNING: In the new architecture, GPU slots are shared across layers, so this
method would overwrite each slot once per layer. Not recommended.
"""
for layer_id in range(self.num_layers):
self.gathered_h2d_layer(layer_id)
@@ -297,10 +294,10 @@ class OffloadEngine:
"""
Async prefetch a single block from CPU to GPU.
For use in prefill phase where CUDA graphs are not used.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index
layer_id: Layer index (for CPU cache)
cpu_block_id: Source block in CPU cache
gpu_block_id: Destination slot in GPU cache
@@ -313,13 +310,12 @@ class OffloadEngine:
logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
with torch.cuda.stream(stream):
# K cache
self.k_cache_gpu[layer_id, gpu_block_id].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_block_id].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
# V cache
self.v_cache_gpu[layer_id, gpu_block_id].copy_(
self.v_cache_gpu[gpu_block_id].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
@@ -356,8 +352,10 @@ class OffloadEngine:
"""
Async offload a block from GPU to CPU.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index
layer_id: Layer index (for CPU cache)
gpu_block_id: Source slot in GPU cache
cpu_block_id: Destination block in CPU cache
@@ -373,14 +371,13 @@ class OffloadEngine:
# Wait for any compute using this block
stream.wait_stream(self.compute_stream)
# K cache
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, gpu_block_id],
self.k_cache_gpu[gpu_block_id],
non_blocking=True
)
# V cache
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[layer_id, gpu_block_id],
self.v_cache_gpu[gpu_block_id],
non_blocking=True
)
event.record()
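One prerequisite worth noting: the non_blocking D2H copies above only overlap with compute when the CPU cache lives in pinned (page-locked) memory; from pageable memory the copy degrades to a staged, host-synchronous transfer. A minimal allocation sketch with placeholder shapes (the engine's actual CPU-cache allocation is outside this diff):
import torch
# Placeholder shapes, kept small so the allocation is cheap.
num_layers, num_cpu_blocks = 4, 8
block_size, num_kv_heads, head_dim = 16, 2, 64
# pin_memory=True page-locks the buffer so copy_(..., non_blocking=True)
# on a side stream can run as a true asynchronous DMA transfer.
k_cache_cpu = torch.zeros(
    num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
    dtype=torch.float16, pin_memory=True,
)
assert k_cache_cpu.is_pinned()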
@@ -417,11 +414,10 @@ class OffloadEngine:
"""
Load CPU blocks to specific GPU slots for chunked decode.
Uses the main GPU KV cache slots, not a separate temp buffer.
This is the same mechanism that chunked prefill uses.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index
layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
"""
@@ -434,12 +430,12 @@ class OffloadEngine:
with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
# Copy from pinned CPU memory to GPU KV cache slot
self.k_cache_gpu[layer_id, gpu_slot].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
@@ -456,8 +452,10 @@ class OffloadEngine:
"""
Async version: Load CPU blocks to GPU slots.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
Args:
layer_id: Layer index
layer_id: Layer index (for CPU cache)
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into
@@ -474,11 +472,12 @@ class OffloadEngine:
with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
self.k_cache_gpu[layer_id, gpu_slot].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_block_id],
non_blocking=True
)
@@ -486,44 +485,8 @@ class OffloadEngine:
return event
def load_cpu_blocks_to_gpu_slots_all_layers(
self,
cpu_block_ids: List[int],
gpu_slot_ids: List[int],
) -> None:
"""
Load CPU blocks to GPU slots for ALL layers at once.
More efficient than per-layer loading when we know the mapping upfront.
Args:
cpu_block_ids: List of CPU block IDs to load
gpu_slot_ids: List of GPU slot IDs to load into
"""
assert len(cpu_block_ids) == len(gpu_slot_ids)
if cpu_block_ids:
logger.debug(f"H2D all layers: CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
stream = self._get_next_stream()
with torch.cuda.stream(stream):
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
# Copy all layers at once using sgDMA
memcpy_2d_async(
self.k_cache_gpu[:, gpu_slot],
self.k_cache_cpu[:, cpu_block_id],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=stream
)
memcpy_2d_async(
self.v_cache_gpu[:, gpu_slot],
self.v_cache_cpu[:, cpu_block_id],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=stream
)
stream.synchronize()
# NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
# layer dimension. Each GPU slot holds data for ONE layer at a time.
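For reference, the removed path relied on a 2D pitched copy in the style of cudaMemcpy2DAsync: height rows of width bytes, with rows src_pitch and dst_pitch bytes apart in source and destination. The pure-Python sketch below illustrates that access pattern only; it is not the actual memcpy_2d_async implementation.
import torch
def memcpy_2d_reference(dst, src, dst_pitch, src_pitch, width, height):
    """Copy height rows of width bytes between flat byte buffers (reference only)."""
    for row in range(height):                  # one row per layer in the KV-cache case
        d0, s0 = row * dst_pitch, row * src_pitch
        dst[d0:d0 + width] = src[s0:s0 + width]
# Gather one block across 3 "layers" from a source holding 4 blocks per layer.
width, height = 8, 3                           # 8 bytes per block, 3 layers
src_pitch, dst_pitch = 4 * width, width        # strided source, contiguous destination
src = torch.arange(src_pitch * height, dtype=torch.uint8)
dst = torch.zeros(dst_pitch * height, dtype=torch.uint8)
memcpy_2d_reference(dst, src, dst_pitch, src_pitch, width, height)
print(dst.reshape(height, width))              # each row picks block 0 of one layer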
# ========== Synchronization methods ==========
@@ -548,21 +511,27 @@ class OffloadEngine:
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
Get GPU K/V cache tensors for a specific layer.
Get GPU K/V cache tensors for an attention layer.
NOTE: GPU cache has no layer dimension - all layers share the same slots.
The layer_id parameter is kept for API compatibility but not used.
Returns:
(k_cache, v_cache) tensors for the layer
(k_cache, v_cache) tensors
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu[layer_id], self.v_cache_gpu[layer_id]
# GPU cache is shared across all layers (no layer dimension)
return self.k_cache_gpu, self.v_cache_gpu
def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
"""
Get full GPU K/V cache tensors.
NOTE: GPU cache has no layer dimension in the new architecture.
Returns:
(k_cache, v_cache) tensors
Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
"""
return self.k_cache_gpu, self.v_cache_gpu
@@ -668,7 +637,7 @@ class OffloadEngine:
# ----- Per-slot Per-layer loading methods -----
def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None:
def record_slot_compute_done(self, slot_idx: int) -> None:
"""
Record that computation using this slot's data is done.
@@ -677,22 +646,23 @@ class OffloadEngine:
Args:
slot_idx: GPU slot index that was just used for computation
layer_id: Layer index
"""
self.ring_slot_compute_done[slot_idx][layer_id].record()
self.ring_slot_compute_done[slot_idx].record()
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.
This is the core building block for ring buffer pipelining.
GPU cache has no layer dimension - slots are shared across all layers.
CPU cache still has layer dimension for persistent storage.
Before starting the transfer, waits for:
1. Any previous compute on this slot to complete
2. Any pending offload of this slot to complete
Args:
slot_idx: Target GPU slot index
layer_id: Layer index to load
layer_id: Layer index to load (for CPU cache indexing)
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
@@ -704,150 +674,105 @@ class OffloadEngine:
with torch.cuda.stream(stream):
# Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading
stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
stream.wait_event(self.ring_slot_compute_done[slot_idx])
# Also wait for any pending offload of this slot to complete
# This prevents race: load must not write GPU slot while offload is reading from it
stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
stream.wait_event(self.ring_slot_offload_done[slot_idx])
self.k_cache_gpu[layer_id, slot_idx].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[layer_id, slot_idx].copy_(
self.v_cache_gpu[slot_idx].copy_(
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(stream)
self.ring_slot_ready[slot_idx].record(stream)
torch.cuda.nvtx.range_pop()
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
def wait_slot_layer(self, slot_idx: int) -> None:
"""
Wait for a slot's loading to complete for a specific layer.
Wait for a slot's loading to complete.
Args:
slot_idx: GPU slot index to wait for
layer_id: Layer index to wait for
"""
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async load a CPU block to a ring buffer slot for ALL layers.
Args:
slot_idx: Target GPU slot index
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
memcpy_2d_async(
self.k_cache_gpu[:, slot_idx],
self.k_cache_cpu[:, cpu_block_id],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=self.transfer_stream_main
)
memcpy_2d_async(
self.v_cache_gpu[:, slot_idx],
self.v_cache_cpu[:, cpu_block_id],
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
"h2d", stream=self.transfer_stream_main
)
self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)
def wait_slot_all_layers(self, slot_idx: int) -> None:
"""Wait for a slot's loading to complete for ALL layers."""
self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])
# NOTE: load_to_slot_all_layers removed - GPU cache no longer has layer dimension.
# Each GPU slot holds data for ONE layer at a time. Layers execute sequentially,
# reusing the same GPU slots.
# ----- Slot offload methods -----
def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
# NOTE: offload_slot_to_cpu (all-layers) removed - GPU cache no longer has layer dimension.
# Use offload_slot_layer_to_cpu for per-layer offloading.
def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU (all layers).
Async offload a ring buffer slot to CPU for one layer.
GPU cache has no layer dimension, so we copy from GPU slot to the
specific layer in CPU cache.
Args:
slot_idx: Source GPU slot index
layer_id: Target layer in CPU cache
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
# - compute_stream: for flash attention operations
# - default_stream: for store_kvcache which runs on default stream
self.transfer_stream_main.wait_stream(self.compute_stream)
self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
memcpy_2d_async(
self.k_cache_cpu[:, cpu_block_id],
self.k_cache_gpu[:, slot_idx],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
memcpy_2d_async(
self.v_cache_cpu[:, cpu_block_id],
self.v_cache_gpu[:, slot_idx],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""
self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU for one layer.
Args:
slot_idx: Source GPU slot index
layer_id: Layer index to offload
cpu_block_id: Target CPU block ID
"""
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
self.transfer_stream_main.wait_stream(self.compute_stream)
self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
self.k_cache_gpu[slot_idx], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
self.v_cache_gpu[slot_idx], non_blocking=True
)
self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)
def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
"""Wait for slot offload to complete for a specific layer."""
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])
self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
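Taken together, these per-slot methods support a software-pipelined host loop. The sketch below is a hypothetical caller, not the project's attention code; attend_to_slot and the loop structure are assumptions, while the engine methods are the ones defined above.
from typing import Callable, List
def stream_blocks_through_slots(engine, layer_id: int,
                                cpu_block_ids: List[int],
                                ring_slots: List[int],
                                attend_to_slot: Callable[[int], None]) -> None:
    """Hypothetical driver: pipeline CPU blocks through ring slots for one layer."""
    if not cpu_block_ids or not ring_slots:
        return
    depth = min(len(ring_slots), len(cpu_block_ids))
    # Prologue: fill the ring with the first depth blocks (async H2D).
    for i in range(depth):
        engine.load_to_slot_layer(ring_slots[i], layer_id, cpu_block_ids[i])
    for i, cpu_block_id in enumerate(cpu_block_ids):
        slot = ring_slots[i % depth]
        engine.wait_slot_layer(slot)              # compute stream waits for H2D
        attend_to_slot(slot)                      # e.g. consume engine.get_kv_for_slot(slot)
        engine.record_slot_compute_done(slot)     # slot is now safe to overwrite
        nxt = i + depth
        if nxt < len(cpu_block_ids):              # prefetch the next block into this slot
            engine.load_to_slot_layer(slot, layer_id, cpu_block_ids[nxt])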
# ----- KV access methods for ring buffer -----
def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
def get_kv_for_slot(self, slot_idx: int) -> Tuple[Tensor, Tensor]:
"""
Get KV for a single ring buffer slot.
GPU cache has no layer dimension - slots contain data for whatever
layer was most recently loaded.
Args:
slot_idx: GPU slot index
layer_id: Layer ID
Returns:
(k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
k = self.k_cache_gpu[slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
v = self.v_cache_gpu[slot_idx].unsqueeze(0)
return k, v
def get_kv_for_slots(
self,
layer_id: int,
slot_indices: List[int],
) -> Tuple[Tensor, Tensor]:
"""
Get KV for multiple ring buffer slots.
GPU cache has no layer dimension - returns data from specified slots.
Args:
layer_id: Layer ID
slot_indices: List of GPU slot indices
Returns:
@@ -855,92 +780,86 @@ class OffloadEngine:
"""
if not slot_indices:
return None, None
k = self.k_cache_gpu[layer_id, slot_indices]
v = self.v_cache_gpu[layer_id, slot_indices]
k = self.k_cache_gpu[slot_indices]
v = self.v_cache_gpu[slot_indices]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
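The shapes these accessors return can be checked with a small standalone sketch (placeholder sizes, runs on CPU):
import torch
num_gpu_blocks, block_size, kv_heads, head_dim = 6, 16, 2, 64
k_cache_gpu = torch.zeros(num_gpu_blocks, block_size, kv_heads, head_dim)
# get_kv_for_slot(slot): one block, batched -> [1, block_size, kv_heads, head_dim]
k_one = k_cache_gpu[1].unsqueeze(0)
assert k_one.shape == (1, block_size, kv_heads, head_dim)
# get_kv_for_slots([2, 4]): gather slots, then flatten them into one token axis
# so attention sees a single contiguous KV sequence.
slots = [2, 4]
k_seq = k_cache_gpu[slots].reshape(1, -1, kv_heads, head_dim)
assert k_seq.shape == (1, len(slots) * block_size, kv_heads, head_dim)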
# ----- Decode slot methods (kept for decode phase) -----
# NOTE: For decode with CPU offload, the flow is per-layer:
# 1. Each layer stores to decode_slot (same GPU memory, reused)
# 2. Each layer offloads its data to CPU[layer_id, block_id]
# 3. Each layer loads prev blocks from CPU[layer_id] when needed
def offload_decode_slot(self, cpu_block_id: int) -> None:
def offload_decode_slot_layer(self, layer_id: int, cpu_block_id: int) -> None:
"""
Offload KV from decode slot (slot[0]) to CPU.
Offload KV from decode slot (slot[0]) to CPU for one layer.
Args:
layer_id: Layer ID
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
memcpy_2d_async(
self.k_cache_cpu[:, cpu_block_id],
self.k_cache_gpu[:, self.decode_slot],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
memcpy_2d_async(
self.v_cache_cpu[:, cpu_block_id],
self.v_cache_gpu[:, self.decode_slot],
self.cpu_pitch, self.gpu_pitch, self.width, self.height,
"d2h", stream=self.transfer_stream_main
)
self.decode_offload_done.record(self.transfer_stream_main)
# Reuse the existing per-layer offload method
self.offload_slot_layer_to_cpu(self.decode_slot, layer_id, cpu_block_id)
def wait_decode_offload(self) -> None:
"""Wait for decode slot offload to complete."""
self.compute_stream.wait_event(self.decode_offload_done)
self.wait_slot_offload(self.decode_slot)
def get_kv_for_decode_slot(
self,
layer_id: int,
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV at specified position in decode slot.
GPU cache has no layer dimension - decode slot contains data for
whatever layer was most recently stored.
Args:
layer_id: Layer ID
pos_in_block: Token position within block (0 to block_size-1)
Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
k = self.k_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
v = self.v_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
def get_kv_for_decode_slot_accumulated(
self,
layer_id: int,
num_tokens: int,
) -> Tuple[Tensor, Tensor]:
"""
Get accumulated KV in decode slot (positions 0 to num_tokens-1).
GPU cache has no layer dimension - decode slot contains data for
whatever layer was most recently stored.
Args:
layer_id: Layer ID
num_tokens: Number of accumulated tokens (1 to block_size)
Returns:
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]
v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
k = self.k_cache_gpu[self.decode_slot, :num_tokens]
v = self.v_cache_gpu[self.decode_slot, :num_tokens]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
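Putting the decode-slot methods together, a per-layer decode step might look like the hypothetical sketch below; run_attention and the surrounding bookkeeping are placeholders, and only the engine calls come from the methods above.
def decode_layer_step(engine, layer_id: int, num_accumulated: int,
                      cpu_block_id: int, block_is_full: bool, run_attention) -> None:
    """Hypothetical per-layer decode step (sketch only)."""
    # The new token's K/V for this layer has already been written into decode_slot
    # by the store-kvcache path; read back everything accumulated so far.
    k, v = engine.get_kv_for_decode_slot_accumulated(layer_id, num_accumulated)
    run_attention(layer_id, k, v)   # k, v: [1, num_accumulated, kv_heads, head_dim]
    # When the block fills up, persist this layer's slice to CPU and wait before
    # the slot is reused for the next decode block.
    if block_is_full:
        engine.offload_decode_slot_layer(layer_id, cpu_block_id)
        engine.wait_decode_offload()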
# ----- Legacy compatibility methods (for decode double-buffering) -----
# NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses first half of decode_load_slots as 'compute' region.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
"""
if not cpu_block_ids:
return
@@ -953,26 +872,27 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
def wait_compute_layer(self, layer_id: int) -> None:
def wait_compute_layer(self) -> None:
"""Legacy: Wait for 'compute' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
if self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
self.wait_slot_layer(self.decode_load_slots[0])
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses second half of decode_load_slots as 'prefetch' region.
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
"""
if not cpu_block_ids:
return
@@ -987,37 +907,36 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
# GPU: no layer dimension, CPU: has layer dimension
self.k_cache_gpu[gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_gpu[gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
def wait_prefetch_layer(self, layer_id: int) -> None:
def wait_prefetch_layer(self) -> None:
"""Legacy: Wait for 'prefetch' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if slots:
self.wait_slot_layer(slots[0], layer_id)
self.wait_slot_layer(slots[0])
elif self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
self.wait_slot_layer(self.decode_load_slots[0])
def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half][:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
return self.get_kv_for_slots(slots)
def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
@@ -1026,7 +945,7 @@ class OffloadEngine:
if not slots:
slots = self.decode_load_slots
slots = slots[:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
return self.get_kv_for_slots(slots)
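For clarity, the compute/prefetch split used by these legacy helpers is just a halving of decode_load_slots; illustrative values:
# Slot 0 is the decode_slot; the remaining slots are split in half.
decode_load_slots = [1, 2, 3, 4, 5]
half = max(1, len(decode_load_slots) // 2)
compute_slots = decode_load_slots[:half]      # [1, 2]    -> get_kv_for_compute
prefetch_slots = decode_load_slots[half:]     # [3, 4, 5] -> get_kv_for_prefetch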
# ========== Debug Hook Interface ==========
#
@@ -1082,12 +1001,15 @@ class OffloadEngine:
Call all registered debug hooks with the loaded K/V tensors (internal use).
Called by attention.py after wait_slot_layer completes.
GPU cache has no layer dimension - slot contains data for the layer
that was just loaded.
"""
if not self._debug_mode or not self._debug_hooks:
return
k = self.k_cache_gpu[layer_id, slot_idx]
v = self.v_cache_gpu[layer_id, slot_idx]
# GPU cache has no layer dimension
k = self.k_cache_gpu[slot_idx]
v = self.v_cache_gpu[slot_idx]
for hook in self._debug_hooks:
try: