[feat] Need to optimize with async prefetch.

Zijie Tian
2025-12-15 06:58:40 +08:00
parent 1081ab51ea
commit b8b6478506
9 changed files with 556 additions and 404 deletions


@@ -65,34 +65,30 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
# ========== Three-region GPU Buffer configuration ==========
# ========== Unified Ring Buffer configuration ==========
# Constraint checks
assert num_gpu_blocks >= 3, \
f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
assert num_prefetch_blocks >= 1, \
f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"
assert num_gpu_blocks >= 2, \
f"Need at least 2 GPU blocks for ring buffer, got {num_gpu_blocks}"
# Three-region configuration
# Decode region: [0] - Fixed 1 block for writing new KV
# Unified Ring Buffer: all slots cycle for prefill
# Prefill: use ALL slots as ring buffer (slot[chunk_idx % N])
# Decode: slot[0] as decode_slot, slots[1:] for loading previous chunks
self.num_ring_slots = num_gpu_blocks
self.ring_slots = list(range(num_gpu_blocks))
# Decode phase uses slot[0] for writing new token's KV
self.decode_slot = 0
# Decode phase uses slots[1:] for loading previous chunks from CPU
self.decode_load_slots = list(range(1, num_gpu_blocks))
self.num_decode_load_slots = len(self.decode_load_slots)
# Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
compute_start = 1
compute_end = num_gpu_blocks - num_prefetch_blocks
self.compute_slots = list(range(compute_start, compute_end))
self.num_compute_blocks = len(self.compute_slots)
# Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
prefetch_start = compute_end
self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
# Keep num_prefetch_blocks for compatibility (used as chunk size for loading)
self.num_prefetch_blocks = num_prefetch_blocks
self.num_gpu_slots = num_gpu_blocks # alias
logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
logger.info(f"Unified Ring Buffer: {self.num_ring_slots} slots total")
logger.info(f" Prefill: all slots as ring buffer [0..{num_gpu_blocks-1}]")
logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
@@ -134,18 +130,27 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0
# ========== Three-region dedicated stream and events ==========
# ========== Ring Buffer dedicated stream and events ==========
self.transfer_stream_main = torch.cuda.Stream() # Main transfer stream
# Sync events - three-region loading completion
self.compute_ready = torch.cuda.Event()
self.prefetch_ready = torch.cuda.Event()
# Decode offload event
self.decode_offload_done = torch.cuda.Event()
# ========== Per-layer events for chunked attention ==========
# Each layer has its own event for synchronization
self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
self.prefetch_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
# ========== Per-slot Per-layer events for ring buffer ==========
# ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
# ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
self.ring_slot_ready = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
self.ring_slot_offload_done = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
# Per-slot events for all-layer operations (used in some legacy paths)
self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
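The per-(slot, layer) event matrix is the synchronization backbone: the transfer stream records an event right after each H2D copy, and the compute stream waits on exactly that event. A self-contained sketch of the pattern, with illustrative sizes and tensors that are not from this commit:

    import torch

    if torch.cuda.is_available():
        num_layers, num_slots = 2, 3  # illustrative sizes
        ready = [[torch.cuda.Event() for _ in range(num_layers)]
                 for _ in range(num_slots)]
        transfer = torch.cuda.Stream()
        compute = torch.cuda.current_stream()

        src = torch.randn(num_slots, num_layers, 1024, pin_memory=True)
        dst = torch.empty(num_slots, num_layers, 1024, device="cuda")

        with torch.cuda.stream(transfer):
            dst[1, 0].copy_(src[1, 0], non_blocking=True)
            ready[1][0].record(transfer)   # H2D done for (slot=1, layer=0)

        compute.wait_event(ready[1][0])    # gate only on that (slot, layer)
        out = dst[1, 0] * 2                # safe: ordered after the copy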
@@ -560,7 +565,7 @@ class OffloadEngine:
f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n"
f" three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
f" ring_buffer: {self.num_ring_slots} slots, decode_slot={self.decode_slot}, decode_load_slots={self.decode_load_slots},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
@@ -570,174 +575,207 @@ class OffloadEngine:
"""Wait for all offload operations to complete."""
self.transfer_stream_main.synchronize()
# ========== Unified Ring Buffer methods ==========
# ----- Prefill: Ring Buffer slot management -----
def get_write_slot_for_prefill(self, chunk_idx: int) -> int:
"""
Get ring buffer slot for writing prefill chunk.
For prefill, ALL slots are used as ring buffer, cycling through.
Args:
chunk_idx: Current chunk index (0, 1, 2, ...)
Returns:
GPU slot index for writing
"""
return chunk_idx % self.num_ring_slots
def get_load_slots_for_prefill(self, write_slot_idx: int) -> List[int]:
"""
Get available slots for loading previous chunks during prefill.
Excludes the current write slot to avoid conflict.
Args:
write_slot_idx: Current write slot index
Returns:
List of slot indices available for loading (N-1 slots)
"""
return [i for i in range(self.num_ring_slots) if i != write_slot_idx]
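Taken together, the two helpers give each chunk one write slot and leave every other slot free for loading history. A quick sketch of the resulting schedule, assuming 4 ring slots (engine construction elided):

    num_ring_slots = 4  # assumed for illustration

    for chunk_idx in range(5):
        write_slot = chunk_idx % num_ring_slots
        load_slots = [i for i in range(num_ring_slots) if i != write_slot]
        print(f"chunk {chunk_idx}: write slot {write_slot}, load slots {load_slots}")
    # chunk 0: write slot 0, load slots [1, 2, 3]
    # chunk 4: write slot 0, load slots [1, 2, 3]   (ring wrapped around)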
# ----- Decode: slot management -----
def get_load_slots_for_decode(self) -> List[int]:
"""
Get slots available for loading during decode.
Excludes decode_slot (slot[0]) since it's used for writing new token's KV.
Returns:
List of slot indices for loading (slots[1:])
"""
return self.decode_load_slots
# ----- Per-slot Per-layer loading methods -----
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.
This is the core building block for ring buffer pipelining.
Args:
slot_idx: Target GPU slot index
layer_id: Layer index to load
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
self.k_cache_gpu[layer_id, slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[layer_id, slot_idx].copy_(
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
"""
Wait for a slot's loading to complete for a specific layer.
Args:
slot_idx: GPU slot index to wait for
layer_id: Layer index to wait for
"""
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])
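These two calls are the building blocks of the async prefetch pipeline: issue the H2D copy for the next chunk on the transfer stream, and block the compute stream only when the current chunk's event is actually needed. A hedged sketch of that loop; the chunk bookkeeping and attention call are hypothetical, only the engine methods come from this commit:

    def attend_over_chunks(engine, layer_id, cpu_block_ids, write_slot):
        """Hedged sketch: overlap H2D loads with per-chunk attention."""
        load_slots = engine.get_load_slots_for_prefill(write_slot)
        # Issue the first load before entering the loop.
        engine.load_to_slot_layer(load_slots[0], layer_id, cpu_block_ids[0])
        for i, _ in enumerate(cpu_block_ids):
            slot = load_slots[i % len(load_slots)]
            if i + 1 < len(cpu_block_ids):
                # Kick off the next transfer before consuming this chunk.
                # (Real code must also ensure the target slot's previous
                # occupant has been consumed before overwriting it.)
                nxt = load_slots[(i + 1) % len(load_slots)]
                engine.load_to_slot_layer(nxt, layer_id, cpu_block_ids[i + 1])
            engine.wait_slot_layer(slot, layer_id)   # compute waits for H2D
            k, v = engine.get_kv_for_slot(slot, layer_id)
            # ... run chunked attention against (k, v) here ...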
def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async load a CPU block to a ring buffer slot for ALL layers.
Args:
slot_idx: Target GPU slot index
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
self.k_cache_gpu[:, slot_idx].copy_(
self.k_cache_cpu[:, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[:, slot_idx].copy_(
self.v_cache_cpu[:, cpu_block_id], non_blocking=True
)
self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)
def wait_slot_all_layers(self, slot_idx: int) -> None:
"""Wait for a slot's loading to complete for ALL layers."""
self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])
# ----- Slot offload methods -----
def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU (all layers).
Args:
slot_idx: Source GPU slot index
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, slot_idx], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, slot_idx], non_blocking=True
)
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""
self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU for one layer.
Args:
slot_idx: Source GPU slot index
layer_id: Layer index to offload
cpu_block_id: Target CPU block ID
"""
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
)
self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)
def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
"""Wait for slot offload to complete for a specific layer."""
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])
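Because the offload first does transfer_stream_main.wait_stream(compute_stream), and both copies are queued on the same transfer stream, an offload-then-reload of one slot is ordered correctly without extra events. A sketch of recycling a slot under those assumptions; the control flow is hypothetical:

    def recycle_slot(engine, slot_idx, layer_id, old_cpu_block, new_cpu_block):
        # 1) D2H: flush the slot's current contents. The method internally
        #    waits on the compute stream, so pending writes land first.
        engine.offload_slot_layer_to_cpu(slot_idx, layer_id, old_cpu_block)
        # 2) H2D: reload the slot with the next chunk. Same stream, so the
        #    load cannot overtake the offload.
        engine.load_to_slot_layer(slot_idx, layer_id, new_cpu_block)
        # 3) Compute only needs the H2D event before reading the slot.
        engine.wait_slot_layer(slot_idx, layer_id)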
# ----- KV access methods for ring buffer -----
def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
Get KV for a single ring buffer slot.
Args:
slot_idx: GPU slot index
layer_id: Layer ID
Returns:
(k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
return k, v
def get_kv_for_slots(
self,
layer_id: int,
gpu_slots: List[int],
slot_indices: List[int],
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified GPU slots.
Get KV for multiple ring buffer slots.
Args:
layer_id: Layer ID
gpu_slots: List of GPU slot IDs
slot_indices: List of GPU slot indices
Returns:
(k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
"""
if not gpu_slots:
if not slot_indices:
return None, None
k = self.k_cache_gpu[layer_id, gpu_slots]
v = self.v_cache_gpu[layer_id, gpu_slots]
k = self.k_cache_gpu[layer_id, slot_indices]
v = self.v_cache_gpu[layer_id, slot_indices]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
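The reshape contract of get_kv_for_slots can be checked in isolation. A runnable sketch with illustrative sizes; the names mirror the engine's fields but nothing here touches the GPU:

    import torch

    num_layers, num_slots, block_size, kv_heads, head_dim = 2, 4, 16, 8, 64
    k_cache_gpu = torch.zeros(num_layers, num_slots, block_size, kv_heads, head_dim)

    slot_indices = [1, 3]                       # gather two ring slots
    k = k_cache_gpu[0, slot_indices]            # [2, 16, 8, 64]
    k = k.reshape(1, -1, kv_heads, head_dim)    # [1, 32, 8, 64]
    assert k.shape == (1, len(slot_indices) * block_size, kv_heads, head_dim)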
# ========== Three-region GPU Buffer methods ==========
def load_to_compute(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Compute region.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready.record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.compute_ready.record(self.transfer_stream_main)
def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Prefetch region.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready.record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.prefetch_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.prefetch_ready.record(self.transfer_stream_main)
def wait_compute(self) -> None:
"""Wait for Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready)
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Load CPU blocks to Compute region for a single layer only.
This is used for per-layer chunked attention where each layer
independently loads its KV data.
Args:
layer_id: Layer index to load
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy only this layer (not all layers)
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
def wait_compute_layer(self, layer_id: int) -> None:
"""Wait for specific layer's Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready_per_layer[layer_id])
def wait_prefetch(self) -> None:
"""Wait for Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready)
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Load CPU blocks to Prefetch region for a single layer only.
This is used for per-layer chunked attention where each layer
independently loads its KV data.
Args:
layer_id: Layer index to load
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
logger.debug(f"Prefetch load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.prefetch_slots[i]
# Copy only this layer (not all layers)
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
def wait_prefetch_layer(self, layer_id: int) -> None:
"""Wait for specific layer's Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready_per_layer[layer_id])
def swap_compute_prefetch(self) -> None:
"""Swap roles of Compute region and Prefetch region."""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
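For reference, the retired three-region path ping-pongs the two regions via this swap. A hedged sketch of that double-buffering loop; the chunk grouping and attention step are hypothetical:

    def three_region_pass(engine, chunk_groups):
        """Hedged sketch of the retired compute/prefetch double-buffering."""
        engine.load_to_compute(chunk_groups[0])
        for i, _ in enumerate(chunk_groups):
            if i == 0:
                engine.wait_compute()
            else:
                engine.wait_prefetch()       # the swapped-in region's event
            if i + 1 < len(chunk_groups):
                # Overlap the next H2D load with attention on this group.
                # (Real code must ensure the GPU is done reading the target
                # region before it is overwritten.)
                engine.load_to_prefetch(chunk_groups[i + 1])
            # ... attend over the region that just became ready ...
            engine.swap_compute_prefetch()   # prefetch region becomes compute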
# ----- Decode slot methods (kept for decode phase) -----
def offload_decode_slot(self, cpu_block_id: int) -> None:
"""
Offload KV from Decode region to CPU.
Offload KV from decode slot (slot[0]) to CPU.
Args:
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
@@ -750,61 +788,16 @@ class OffloadEngine:
self.decode_offload_done.record(self.transfer_stream_main)
def wait_decode_offload(self) -> None:
"""Wait for Decode region offload to complete."""
"""Wait for decode slot offload to complete."""
self.compute_stream.wait_event(self.decode_offload_done)
def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of blocks in Compute region.
Args:
layer_id: Layer ID
num_blocks: Number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.compute_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots] # [num_blocks, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of blocks in Prefetch region.
Args:
layer_id: Layer ID
num_blocks: Number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.prefetch_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_decode_slot(
self,
layer_id: int,
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV at specified position in Decode region (for new token during decode).
Get KV at specified position in decode slot.
Args:
layer_id: Layer ID
@@ -813,9 +806,9 @@ class OffloadEngine:
Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] # [1, heads, dim]
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
k = k.unsqueeze(0) # [1, 1, heads, dim]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
@@ -825,10 +818,7 @@ class OffloadEngine:
num_tokens: int,
) -> Tuple[Tensor, Tensor]:
"""
Get accumulated KV in Decode region (all tokens from position 0 to num_tokens-1).
Used when batching decode offloads: attend to all accumulated tokens,
not just the current one.
Get accumulated KV in decode slot (positions 0 to num_tokens-1).
Args:
layer_id: Layer ID
@@ -837,35 +827,102 @@ class OffloadEngine:
Returns:
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens] # [num_tokens, heads, dim]
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]
v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
k = k.unsqueeze(0) # [1, num_tokens, heads, dim]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
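During decode, new tokens accumulate in slot 0 until the block fills. A hedged sketch of one decode step built on these accessors; the token loop and attention call are hypothetical:

    def decode_token_step(engine, pos_in_block, block_size, cpu_block_id, num_layers):
        """Hedged sketch of one decode token's KV lifecycle."""
        for layer_id in range(num_layers):
            # The new token's K/V was written into slot 0 at pos_in_block;
            # attention sees everything accumulated in the slot so far.
            k_acc, v_acc = engine.get_kv_for_decode_accumulated(
                layer_id, pos_in_block + 1
            )
            # ... plus history chunks loaded into slots[1:] ...
        if pos_in_block + 1 == block_size:
            # Block full: flush slot 0 to CPU before the next token reuses it.
            engine.offload_decode_slot(cpu_block_id)
            engine.wait_decode_offload()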
def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
"""
Offload KV from Compute region to CPU.
Args:
cpu_block_ids: Target CPU block IDs list
"""
if not cpu_block_ids:
return
num_to_offload = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for compute to complete
self.transfer_stream_main.wait_stream(self.compute_stream)
for i in range(num_to_offload):
gpu_slot = self.compute_slots[i]
cpu_id = cpu_block_ids[i]
self.k_cache_cpu[:, cpu_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
# ----- Legacy compatibility methods (for decode double-buffering) -----
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses first half of decode_load_slots as 'compute' region.
"""
if not cpu_block_ids:
return
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half]
num_to_load = min(len(cpu_block_ids), len(slots))
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
def wait_compute_layer(self, layer_id: int) -> None:
"""Legacy: Wait for 'compute' region loading."""
if self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses second half of decode_load_slots as 'prefetch' region.
"""
if not cpu_block_ids:
return
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots # Fallback if only 1-2 slots
num_to_load = min(len(cpu_block_ids), len(slots))
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
def wait_prefetch_layer(self, layer_id: int) -> None:
"""Legacy: Wait for 'prefetch' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if slots:
self.wait_slot_layer(slots[0], layer_id)
elif self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half][:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots
slots = slots[:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
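The halving logic used throughout the legacy shim is easiest to see with concrete numbers. A standalone sketch with illustrative slot values:

    # How the legacy shim splits decode_load_slots (illustrative sizes).
    decode_load_slots = [1, 2, 3, 4, 5]          # e.g. 6 GPU slots, slot 0 reserved
    half = max(1, len(decode_load_slots) // 2)   # -> 2
    compute_half = decode_load_slots[:half]      # [1, 2]    legacy 'compute'
    prefetch_half = decode_load_slots[half:]     # [3, 4, 5] legacy 'prefetch'
    print(compute_half, prefetch_half)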