[feat] Needs to be optimized with async prefetch.

This commit is contained in:
Zijie Tian
2025-12-15 06:58:40 +08:00
parent 1081ab51ea
commit b8b6478506
9 changed files with 556 additions and 404 deletions


@@ -95,16 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU)
cpu_primary: If True, use CPU as primary storage with three-region GPU buffer.
cpu_primary: If True, use CPU as primary storage with ring buffer GPU design.
If False, use GPU as primary with CPU as overflow (legacy mode).
num_prefetch_blocks: Number of prefetch blocks for three-region GPU buffer design
num_prefetch_blocks: Number of blocks for ring buffer pipeline (deprecated, ring_slots = num_gpu_slots)
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
self.cpu_primary = cpu_primary # Three-region mode flag
self.num_prefetch_blocks = num_prefetch_blocks # Three-region design parameter
self.cpu_primary = cpu_primary # Ring buffer mode flag
self.num_prefetch_blocks = num_prefetch_blocks # Ring buffer design parameter (deprecated)
# Eviction policy
self.policy = policy or LRUPolicy()
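For reference, a minimal construction sketch for the new CPU-primary mode, assuming the constructor takes exactly the keyword arguments documented in the docstring above (the values and the call shape are illustrative, not taken from this commit):

# Hypothetical usage sketch of ring buffer (CPU-primary) mode.
manager = HybridKVCacheManager(
    num_gpu_slots=8,         # GPU ring buffer slots (slot 0 doubles as the decode slot)
    num_cpu_blocks=1024,     # primary KV storage lives on CPU
    block_size=16,           # tokens per block
    policy=None,             # defaults to LRUPolicy()
    cpu_primary=True,        # enable ring buffer mode
    num_prefetch_blocks=4,   # deprecated here; ring slots = num_gpu_slots
)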
@@ -341,7 +341,7 @@ class HybridKVCacheManager(KVCacheManager):
"""
assert not seq.block_table, "Sequence already has blocks"
# Three-region mode: all blocks are allocated to CPU
# Ring buffer mode: all blocks are allocated to CPU
if self.cpu_primary:
return self.allocate_cpu_only(seq)
@@ -471,7 +471,7 @@ class HybridKVCacheManager(KVCacheManager):
block.token_ids = []
if self.cpu_primary:
# Three-region mode: new block allocated to CPU
# Ring buffer mode: new block allocated to CPU
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks for decode")
cpu_block_id = self.free_cpu_blocks.popleft()
@@ -1025,14 +1025,14 @@ class HybridKVCacheManager(KVCacheManager):
break
return pos
# ========== Three-region double buffering support ==========
# ========== Ring Buffer CPU-primary support ==========
def allocate_cpu_only(self, seq: Sequence) -> None:
"""
Allocate CPU blocks for sequence (for three-region mode).
Allocate CPU blocks for sequence (for ring buffer mode).
Unlike allocate(), here all blocks are allocated to CPU,
GPU is only used as working buffer.
GPU is only used as ring buffer for computation.
Args:
seq: Sequence to allocate
@@ -1092,10 +1092,10 @@ class HybridKVCacheManager(KVCacheManager):
cpu_blocks.append(block.cpu_block_id)
else:
# If block is on GPU, it should have a corresponding CPU block
# In three-region mode, all data ultimately resides on CPU
# In ring buffer mode, all data ultimately resides on CPU
raise RuntimeError(
f"Block {logical_id} not on CPU (location={block.location}). "
f"In three-region mode, all blocks should be on CPU."
f"In ring buffer mode, all blocks should be on CPU."
)
return cpu_blocks
@@ -1171,8 +1171,8 @@ class HybridKVCacheManager(KVCacheManager):
"""
Get GPU slot for writing new KV during chunked offload decode.
In three-region design, always use Decode region (slot 0) to write new KV.
This avoids conflicts with Compute/Prefetch region loading operations.
In ring buffer design, always use decode_slot (slot[0]) to write new KV.
This avoids conflicts with loading operations which use slots[1:].
Args:
seq: Sequence


@@ -65,34 +65,30 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
# ========== Three-region GPU Buffer configuration ==========
# ========== Unified Ring Buffer configuration ==========
# Constraint checks
assert num_gpu_blocks >= 3, \
f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
assert num_prefetch_blocks >= 1, \
f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"
assert num_gpu_blocks >= 2, \
f"Need at least 2 GPU blocks for ring buffer, got {num_gpu_blocks}"
# Three-region configuration
# Decode region: [0] - Fixed 1 block for writing new KV
# Unified Ring Buffer: all slots cycle for prefill
# Prefill: use ALL slots as ring buffer (slot[chunk_idx % N])
# Decode: slot[0] as decode_slot, slots[1:] for loading previous chunks
self.num_ring_slots = num_gpu_blocks
self.ring_slots = list(range(num_gpu_blocks))
# Decode phase uses slot[0] for writing new token's KV
self.decode_slot = 0
# Decode phase uses slots[1:] for loading previous chunks from CPU
self.decode_load_slots = list(range(1, num_gpu_blocks))
self.num_decode_load_slots = len(self.decode_load_slots)
# Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
compute_start = 1
compute_end = num_gpu_blocks - num_prefetch_blocks
self.compute_slots = list(range(compute_start, compute_end))
self.num_compute_blocks = len(self.compute_slots)
# Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
prefetch_start = compute_end
self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
# Keep num_prefetch_blocks for compatibility (used as chunk size for loading)
self.num_prefetch_blocks = num_prefetch_blocks
self.num_gpu_slots = num_gpu_blocks # alias
logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
logger.info(f"Unified Ring Buffer: {self.num_ring_slots} slots total")
logger.info(f" Prefill: all slots as ring buffer [0..{num_gpu_blocks-1}]")
logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
@@ -134,18 +130,27 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0
# ========== Three-region dedicated stream and events ==========
# ========== Ring Buffer dedicated stream and events ==========
self.transfer_stream_main = torch.cuda.Stream() # Main transfer stream
# Sync events - three-region loading completion
self.compute_ready = torch.cuda.Event()
self.prefetch_ready = torch.cuda.Event()
# Decode offload event
self.decode_offload_done = torch.cuda.Event()
# ========== Per-layer events for chunked attention ==========
# Each layer has its own event for synchronization
self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
self.prefetch_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
# ========== Per-slot Per-layer events for ring buffer ==========
# ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
# ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
self.ring_slot_ready = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
self.ring_slot_offload_done = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
# Per-slot events for all-layer operations (used in some legacy paths)
self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -560,7 +565,7 @@ class OffloadEngine:
f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n"
f" three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
f" ring_buffer: {self.num_ring_slots} slots, decode_slot={self.decode_slot}, decode_load_slots={self.decode_load_slots},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
@@ -570,174 +575,207 @@ class OffloadEngine:
"""Wait for all offload operations to complete."""
self.transfer_stream_main.synchronize()
# ========== Unified Ring Buffer methods ==========
# ----- Prefill: Ring Buffer slot management -----
def get_write_slot_for_prefill(self, chunk_idx: int) -> int:
"""
Get ring buffer slot for writing prefill chunk.
For prefill, ALL slots are used as ring buffer, cycling through.
Args:
chunk_idx: Current chunk index (0, 1, 2, ...)
Returns:
GPU slot index for writing
"""
return chunk_idx % self.num_ring_slots
def get_load_slots_for_prefill(self, write_slot_idx: int) -> List[int]:
"""
Get available slots for loading previous chunks during prefill.
Excludes the current write slot to avoid conflict.
Args:
write_slot_idx: Current write slot index
Returns:
List of slot indices available for loading (N-1 slots)
"""
return [i for i in range(self.num_ring_slots) if i != write_slot_idx]
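A hedged usage sketch of the two helpers above, assuming `engine` is an already constructed OffloadEngine with 4 ring slots (values illustrative):

for chunk_idx in range(6):
    write_slot = engine.get_write_slot_for_prefill(chunk_idx)    # 0, 1, 2, 3, 0, 1
    load_slots = engine.get_load_slots_for_prefill(write_slot)   # the other 3 slots
    # chunk `chunk_idx` is written into write_slot; earlier chunks can be loaded into load_slots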
# ----- Decode: slot management -----
def get_load_slots_for_decode(self) -> List[int]:
"""
Get slots available for loading during decode.
Excludes decode_slot (slot[0]) since it's used for writing new token's KV.
Returns:
List of slot indices for loading (slots[1:])
"""
return self.decode_load_slots
# ----- Per-slot Per-layer loading methods -----
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.
This is the core building block for ring buffer pipelining.
Args:
slot_idx: Target GPU slot index
layer_id: Layer index to load
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
self.k_cache_gpu[layer_id, slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[layer_id, slot_idx].copy_(
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
"""
Wait for a slot's loading to complete for a specific layer.
Args:
slot_idx: GPU slot index to wait for
layer_id: Layer index to wait for
"""
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])
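A hedged sketch of the overlap these two primitives enable for one layer during decode: issue all H2D copies up front on the transfer stream, then attend to each slot as soon as its copy has landed. `engine`, `layer_id`, `cpu_block_ids`, and `attend_to_block` are assumptions of the sketch, not names from this commit, and at most one block per load slot is shown so no slot is reused while a copy is in flight:

load_slots = engine.get_load_slots_for_decode()
batch = list(zip(load_slots, cpu_block_ids))             # at most len(load_slots) blocks
for slot, cpu_id in batch:
    engine.load_to_slot_layer(slot, layer_id, cpu_id)    # async H2D per slot/layer
for slot, _ in batch:
    engine.wait_slot_layer(slot, layer_id)               # compute stream waits on this slot only
    attend_to_block(slot)                                # overlaps with copies still in flight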
def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async load a CPU block to a ring buffer slot for ALL layers.
Args:
slot_idx: Target GPU slot index
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
self.k_cache_gpu[:, slot_idx].copy_(
self.k_cache_cpu[:, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[:, slot_idx].copy_(
self.v_cache_cpu[:, cpu_block_id], non_blocking=True
)
self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)
def wait_slot_all_layers(self, slot_idx: int) -> None:
"""Wait for a slot's loading to complete for ALL layers."""
self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])
# ----- Slot offload methods -----
def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU (all layers).
Args:
slot_idx: Source GPU slot index
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, slot_idx], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, slot_idx], non_blocking=True
)
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""
self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU for one layer.
Args:
slot_idx: Source GPU slot index
layer_id: Layer index to offload
cpu_block_id: Target CPU block ID
"""
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
)
self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)
def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
"""Wait for slot offload to complete for a specific layer."""
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])
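A hedged sketch of the prefill write/offload cycle these offload primitives support: each chunk's slot is flushed to CPU right after it is computed, and before the ring wraps around, the previous tenant's offload must have completed. `engine`, `cpu_block_of`, and the elided write/attention step are placeholders:

for chunk_idx, cpu_block in enumerate(cpu_block_of):
    slot = engine.get_write_slot_for_prefill(chunk_idx)
    if chunk_idx >= engine.num_ring_slots:
        engine.wait_slot_offload(slot)            # previous tenant must already be on CPU
    # ... write this chunk's KV into `slot` and run attention ...
    engine.offload_slot_to_cpu(slot, cpu_block)   # async D2H on the transfer stream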
# ----- KV access methods for ring buffer -----
def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
Get KV for a single ring buffer slot.
Args:
slot_idx: GPU slot index
layer_id: Layer ID
Returns:
(k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
return k, v
def get_kv_for_slots(
self,
layer_id: int,
gpu_slots: List[int],
slot_indices: List[int],
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified GPU slots.
Get KV for multiple ring buffer slots.
Args:
layer_id: Layer ID
gpu_slots: List of GPU slot IDs
slot_indices: List of GPU slot indices
Returns:
(k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
"""
if not gpu_slots:
if not slot_indices:
return None, None
k = self.k_cache_gpu[layer_id, gpu_slots]
v = self.v_cache_gpu[layer_id, gpu_slots]
k = self.k_cache_gpu[layer_id, slot_indices]
v = self.v_cache_gpu[layer_id, slot_indices]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
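For example, gathering three already-loaded slots for one layer (shapes follow the docstring above; `engine`, `layer_id`, and block_size=16 are assumed):

k, v = engine.get_kv_for_slots(layer_id, [1, 2, 3])
# k.shape == v.shape == (1, 3 * 16, num_kv_heads, head_dim)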
# ========== Three-region GPU Buffer methods ==========
def load_to_compute(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Compute region.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready.record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.compute_ready.record(self.transfer_stream_main)
def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Prefetch region.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready.record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.prefetch_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.prefetch_ready.record(self.transfer_stream_main)
def wait_compute(self) -> None:
"""Wait for Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready)
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Load CPU blocks to Compute region for a single layer only.
This is used for per-layer chunked attention where each layer
independently loads its KV data.
Args:
layer_id: Layer index to load
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy only this layer (not all layers)
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
def wait_compute_layer(self, layer_id: int) -> None:
"""Wait for specific layer's Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready_per_layer[layer_id])
def wait_prefetch(self) -> None:
"""Wait for Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready)
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Load CPU blocks to Prefetch region for a single layer only.
This is used for per-layer chunked attention where each layer
independently loads its KV data.
Args:
layer_id: Layer index to load
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
logger.debug(f"Prefetch load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.prefetch_slots[i]
# Copy only this layer (not all layers)
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
def wait_prefetch_layer(self, layer_id: int) -> None:
"""Wait for specific layer's Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready_per_layer[layer_id])
def swap_compute_prefetch(self) -> None:
"""Swap roles of Compute region and Prefetch region."""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
# ----- Decode slot methods (kept for decode phase) -----
def offload_decode_slot(self, cpu_block_id: int) -> None:
"""
Offload KV from Decode region to CPU.
Offload KV from decode slot (slot[0]) to CPU.
Args:
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
@@ -750,61 +788,16 @@ class OffloadEngine:
self.decode_offload_done.record(self.transfer_stream_main)
def wait_decode_offload(self) -> None:
"""Wait for Decode region offload to complete."""
"""Wait for decode slot offload to complete."""
self.compute_stream.wait_event(self.decode_offload_done)
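A hedged sketch of the decode-slot cycle: once slot[0] has accumulated block_size new tokens, it is flushed to a fresh CPU block and then reused from position 0. `engine` and `new_cpu_block_id` are placeholders:

engine.offload_decode_slot(new_cpu_block_id)   # async D2H of the whole decode slot
engine.wait_decode_offload()                   # compute stream waits before slot[0] is overwritten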
def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of blocks in Compute region.
Args:
layer_id: Layer ID
num_blocks: Number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.compute_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots] # [num_blocks, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of blocks in Prefetch region.
Args:
layer_id: Layer ID
num_blocks: Number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.prefetch_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_decode_slot(
self,
layer_id: int,
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV at specified position in Decode region (for new token during decode).
Get KV at specified position in decode slot.
Args:
layer_id: Layer ID
@@ -813,9 +806,9 @@ class OffloadEngine:
Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] # [1, heads, dim]
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
k = k.unsqueeze(0) # [1, 1, heads, dim]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
@@ -825,10 +818,7 @@ class OffloadEngine:
num_tokens: int,
) -> Tuple[Tensor, Tensor]:
"""
Get accumulated KV in Decode region (all tokens from position 0 to num_tokens-1).
Used when batching decode offloads - attend to all accumulated tokens,
not just the current one.
Get accumulated KV in decode slot (positions 0 to num_tokens-1).
Args:
layer_id: Layer ID
@@ -837,35 +827,102 @@ class OffloadEngine:
Returns:
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens] # [num_tokens, heads, dim]
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]
v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
k = k.unsqueeze(0) # [1, num_tokens, heads, dim]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
"""
Offload KV from Compute region to CPU.
Args:
cpu_block_ids: Target CPU block IDs list
"""
if not cpu_block_ids:
return
num_to_offload = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for compute to complete
self.transfer_stream_main.wait_stream(self.compute_stream)
for i in range(num_to_offload):
gpu_slot = self.compute_slots[i]
cpu_id = cpu_block_ids[i]
self.k_cache_cpu[:, cpu_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
# ----- Legacy compatibility methods (for decode double-buffering) -----
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses first half of decode_load_slots as 'compute' region.
"""
if not cpu_block_ids:
return
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half]
num_to_load = min(len(cpu_block_ids), len(slots))
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
def wait_compute_layer(self, layer_id: int) -> None:
"""Legacy: Wait for 'compute' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
if self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses second half of decode_load_slots as 'prefetch' region.
"""
if not cpu_block_ids:
return
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots # Fallback if only 1-2 slots
num_to_load = min(len(cpu_block_ids), len(slots))
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
def wait_prefetch_layer(self, layer_id: int) -> None:
"""Legacy: Wait for 'prefetch' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if slots:
self.wait_slot_layer(slots[0], layer_id)
elif self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half][:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots
slots = slots[:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
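A worked example of the legacy half split used by these compatibility methods, for a hypothetical decode_load_slots of five slots:

decode_load_slots = [1, 2, 3, 4, 5]
half = max(1, len(decode_load_slots) // 2)                       # 2
compute_like = decode_load_slots[:half]                          # [1, 2]    -> legacy 'compute' region
prefetch_like = decode_load_slots[half:] or decode_load_slots    # [3, 4, 5] -> legacy 'prefetch' region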