[feat] Needs to be optimized with async prefetch.
@@ -65,34 +65,30 @@ class OffloadEngine:
        self.kv_dim = num_kv_heads * head_dim
        self.block_numel = block_size * self.kv_dim

        # ========== Three-region GPU Buffer configuration ==========
        # ========== Unified Ring Buffer configuration ==========
        # Constraint checks
        assert num_gpu_blocks >= 3, \
            f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
        assert num_prefetch_blocks >= 1, \
            f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
        assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
            f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"
        assert num_gpu_blocks >= 2, \
            f"Need at least 2 GPU blocks for ring buffer, got {num_gpu_blocks}"

        # Three-region configuration
        # Decode region: [0] - Fixed 1 block for writing new KV
        # Unified Ring Buffer: all slots cycle for prefill
        # Prefill: use ALL slots as ring buffer (slot[chunk_idx % N])
        # Decode: slot[0] as decode_slot, slots[1:] for loading previous chunks
        self.num_ring_slots = num_gpu_blocks
        self.ring_slots = list(range(num_gpu_blocks))

        # Decode phase uses slot[0] for writing new token's KV
        self.decode_slot = 0
        # Decode phase uses slots[1:] for loading previous chunks from CPU
        self.decode_load_slots = list(range(1, num_gpu_blocks))
        self.num_decode_load_slots = len(self.decode_load_slots)

        # Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
        compute_start = 1
        compute_end = num_gpu_blocks - num_prefetch_blocks
        self.compute_slots = list(range(compute_start, compute_end))
        self.num_compute_blocks = len(self.compute_slots)

        # Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
        prefetch_start = compute_end
        self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
        # Keep num_prefetch_blocks for compatibility (used as chunk size for loading)
        self.num_prefetch_blocks = num_prefetch_blocks

        self.num_gpu_slots = num_gpu_blocks  # alias

        logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
                    f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
        logger.info(f"Unified Ring Buffer: {self.num_ring_slots} slots total")
        logger.info(f"  Prefill: all slots as ring buffer [0..{num_gpu_blocks-1}]")
        logger.info(f"  Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")

        # ========== Fixed-address GPU KV cache ==========
        # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
@@ -134,18 +130,27 @@ class OffloadEngine:
        self.compute_stream = torch.cuda.current_stream()
        self._stream_idx = 0

        # ========== Three-region dedicated stream and events ==========
        # ========== Ring Buffer dedicated stream and events ==========
        self.transfer_stream_main = torch.cuda.Stream()  # Main transfer stream

        # Sync events - three-region loading completion
        self.compute_ready = torch.cuda.Event()
        self.prefetch_ready = torch.cuda.Event()
        # Decode offload event
        self.decode_offload_done = torch.cuda.Event()

        # ========== Per-layer events for chunked attention ==========
        # Each layer has its own event for synchronization
        self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
        self.prefetch_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
        # ========== Per-slot Per-layer events for ring buffer ==========
        # ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
        # ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
        self.ring_slot_ready = [
            [torch.cuda.Event() for _ in range(num_layers)]
            for _ in range(self.num_ring_slots)
        ]
        self.ring_slot_offload_done = [
            [torch.cuda.Event() for _ in range(num_layers)]
            for _ in range(self.num_ring_slots)
        ]

        # Per-slot events for all-layer operations (used in some legacy paths)
        self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
        self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]

        # ========== Event tracking for async transfers ==========
        self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -560,7 +565,7 @@ class OffloadEngine:
            f"  kv_heads={self.num_kv_heads},\n"
            f"  head_dim={self.head_dim},\n"
            f"  dtype={self.dtype},\n"
            f"  three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
            f"  ring_buffer: {self.num_ring_slots} slots, decode_slot={self.decode_slot}, decode_load_slots={self.decode_load_slots},\n"
            f"  gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
            f"  cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
            f")"
@@ -570,174 +575,207 @@ class OffloadEngine:
        """Wait for all offload operations to complete."""
        self.transfer_stream_main.synchronize()

    # ========== Unified Ring Buffer methods ==========

    # ----- Prefill: Ring Buffer slot management -----

    def get_write_slot_for_prefill(self, chunk_idx: int) -> int:
        """
        Get ring buffer slot for writing prefill chunk.

        For prefill, ALL slots are used as ring buffer, cycling through.

        Args:
            chunk_idx: Current chunk index (0, 1, 2, ...)

        Returns:
            GPU slot index for writing
        """
        return chunk_idx % self.num_ring_slots

    def get_load_slots_for_prefill(self, write_slot_idx: int) -> List[int]:
        """
        Get available slots for loading previous chunks during prefill.

        Excludes the current write slot to avoid conflict.

        Args:
            write_slot_idx: Current write slot index

        Returns:
            List of slot indices available for loading (N-1 slots)
        """
        return [i for i in range(self.num_ring_slots) if i != write_slot_idx]
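As an illustration of how a caller might combine the two helpers above (a hypothetical sketch; the engine argument and the round-robin mapping of earlier chunks onto the free slots are assumptions, not part of this commit):

def plan_prefill_chunk(engine, chunk_idx: int):
    """Sketch: choose ring-buffer slots for one prefill chunk (engine: OffloadEngine)."""
    write_slot = engine.get_write_slot_for_prefill(chunk_idx)    # slot that receives this chunk's new KV
    load_slots = engine.get_load_slots_for_prefill(write_slot)   # remaining slots, free for earlier chunks
    # Map earlier chunks round-robin onto the remaining slots.
    assignment = {c: load_slots[c % len(load_slots)] for c in range(chunk_idx)} if load_slots else {}
    return write_slot, assignment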
    # ----- Decode: slot management -----

    def get_load_slots_for_decode(self) -> List[int]:
        """
        Get slots available for loading during decode.

        Excludes decode_slot (slot[0]) since it's used for writing new token's KV.

        Returns:
            List of slot indices for loading (slots[1:])
        """
        return self.decode_load_slots

    # ----- Per-slot Per-layer loading methods -----

    def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
        """
        Async load a single CPU block to a ring buffer slot for one layer.

        This is the core building block for ring buffer pipelining.

        Args:
            slot_idx: Target GPU slot index
            layer_id: Layer index to load
            cpu_block_id: Source CPU block ID
        """
        logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")

        with torch.cuda.stream(self.transfer_stream_main):
            self.k_cache_gpu[layer_id, slot_idx].copy_(
                self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
            )
            self.v_cache_gpu[layer_id, slot_idx].copy_(
                self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
            )
            self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)

    def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
        """
        Wait for a slot's loading to complete for a specific layer.

        Args:
            slot_idx: GPU slot index to wait for
            layer_id: Layer index to wait for
        """
        self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])
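A minimal sketch of the intended async-prefetch pattern built on these two methods (hypothetical caller code; assumes attend_fn launches attention on the compute stream and that enough slots are available so none is reused within the loop):

def attend_layer_over_offloaded_blocks(engine, layer_id, cpu_block_ids, slots, attend_fn):
    """Sketch: overlap H2D transfers with attention for one layer.

    Assumes engine is an OffloadEngine, attend_fn(k, v) launches attention on the
    compute stream, and len(slots) >= len(cpu_block_ids) so no slot is reused here.
    """
    for slot, cpu_id in zip(slots, cpu_block_ids):
        engine.load_to_slot_layer(slot, layer_id, cpu_id)    # enqueue copy on the transfer stream
    for slot, _ in zip(slots, cpu_block_ids):
        engine.wait_slot_layer(slot, layer_id)               # compute stream waits on this slot's event
        k, v = engine.get_kv_for_slot(slot, layer_id)
        attend_fn(k, v)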
    def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
        """
        Async load a CPU block to a ring buffer slot for ALL layers.

        Args:
            slot_idx: Target GPU slot index
            cpu_block_id: Source CPU block ID
        """
        logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")

        with torch.cuda.stream(self.transfer_stream_main):
            self.k_cache_gpu[:, slot_idx].copy_(
                self.k_cache_cpu[:, cpu_block_id], non_blocking=True
            )
            self.v_cache_gpu[:, slot_idx].copy_(
                self.v_cache_cpu[:, cpu_block_id], non_blocking=True
            )
            self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)

    def wait_slot_all_layers(self, slot_idx: int) -> None:
        """Wait for a slot's loading to complete for ALL layers."""
        self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])

    # ----- Slot offload methods -----

    def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
        """
        Async offload a ring buffer slot to CPU (all layers).

        Args:
            slot_idx: Source GPU slot index
            cpu_block_id: Target CPU block ID
        """
        logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")

        with torch.cuda.stream(self.transfer_stream_main):
            self.transfer_stream_main.wait_stream(self.compute_stream)
            self.k_cache_cpu[:, cpu_block_id].copy_(
                self.k_cache_gpu[:, slot_idx], non_blocking=True
            )
            self.v_cache_cpu[:, cpu_block_id].copy_(
                self.v_cache_gpu[:, slot_idx], non_blocking=True
            )
            self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)

    def wait_slot_offload(self, slot_idx: int) -> None:
        """Wait for slot offload to complete."""
        self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])

    def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
        """
        Async offload a ring buffer slot to CPU for one layer.

        Args:
            slot_idx: Source GPU slot index
            layer_id: Layer index to offload
            cpu_block_id: Target CPU block ID
        """
        with torch.cuda.stream(self.transfer_stream_main):
            self.transfer_stream_main.wait_stream(self.compute_stream)
            self.k_cache_cpu[layer_id, cpu_block_id].copy_(
                self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
            )
            self.v_cache_cpu[layer_id, cpu_block_id].copy_(
                self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
            )
            self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)

    def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
        """Wait for slot offload to complete for a specific layer."""
        self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])
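A small usage sketch for the offload path (hypothetical helper; the engine argument is assumed): the slot is written back to CPU and fenced before it can be reused as a write or load target.

from typing import Optional

def retire_slot(engine, slot_idx: int, cpu_block_id: int, layer_id: Optional[int] = None) -> None:
    """Sketch: push a ring slot back to CPU, then fence the compute stream on the D2H copy."""
    if layer_id is None:
        engine.offload_slot_to_cpu(slot_idx, cpu_block_id)                # all layers at once
        engine.wait_slot_offload(slot_idx)
    else:
        engine.offload_slot_layer_to_cpu(slot_idx, layer_id, cpu_block_id)  # single layer
        engine.wait_slot_offload_layer(slot_idx, layer_id)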
    # ----- KV access methods for ring buffer -----

    def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
        """
        Get KV for a single ring buffer slot.

        Args:
            slot_idx: GPU slot index
            layer_id: Layer ID

        Returns:
            (k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
        """
        k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0)  # [1, block_size, heads, dim]
        v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
        return k, v

    def get_kv_for_slots(
        self,
        layer_id: int,
        gpu_slots: List[int],
        slot_indices: List[int],
    ) -> Tuple[Tensor, Tensor]:
        """
        Get KV for specified GPU slots.
        Get KV for multiple ring buffer slots.

        Args:
            layer_id: Layer ID
            gpu_slots: List of GPU slot IDs
            slot_indices: List of GPU slot indices

        Returns:
            (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
        """
        if not gpu_slots:
        if not slot_indices:
            return None, None
        k = self.k_cache_gpu[layer_id, gpu_slots]
        v = self.v_cache_gpu[layer_id, gpu_slots]
        k = self.k_cache_gpu[layer_id, slot_indices]
        v = self.v_cache_gpu[layer_id, slot_indices]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    # ========== Three-region GPU Buffer methods ==========
    def load_to_compute(self, cpu_block_ids: List[int]) -> None:
        """
        Async load CPU blocks to Compute region.

        Args:
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.compute_ready.record(self.transfer_stream_main)
            return

        num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
        logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")

        with torch.cuda.stream(self.transfer_stream_main):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.compute_slots[i]
                # Copy all layers together
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_id], non_blocking=True
                )
            self.compute_ready.record(self.transfer_stream_main)

    def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
        """
        Async load CPU blocks to Prefetch region.

        Args:
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.prefetch_ready.record(self.transfer_stream_main)
            return

        num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
        logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")

        with torch.cuda.stream(self.transfer_stream_main):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.prefetch_slots[i]
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_id], non_blocking=True
                )
            self.prefetch_ready.record(self.transfer_stream_main)
    def wait_compute(self) -> None:
        """Wait for Compute region loading to complete."""
        self.compute_stream.wait_event(self.compute_ready)

    def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Load CPU blocks to Compute region for a single layer only.

        This is used for per-layer chunked attention where each layer
        independently loads its KV data.

        Args:
            layer_id: Layer index to load
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
            return

        num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
        logger.debug(f"Compute load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")

        with torch.cuda.stream(self.transfer_stream_main):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.compute_slots[i]
                # Copy only this layer (not all layers)
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
            self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)

    def wait_compute_layer(self, layer_id: int) -> None:
        """Wait for specific layer's Compute region loading to complete."""
        self.compute_stream.wait_event(self.compute_ready_per_layer[layer_id])

    def wait_prefetch(self) -> None:
        """Wait for Prefetch region loading to complete."""
        self.compute_stream.wait_event(self.prefetch_ready)

    def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Load CPU blocks to Prefetch region for a single layer only.

        This is used for per-layer chunked attention where each layer
        independently loads its KV data.

        Args:
            layer_id: Layer index to load
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
            return

        num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
        logger.debug(f"Prefetch load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")

        with torch.cuda.stream(self.transfer_stream_main):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.prefetch_slots[i]
                # Copy only this layer (not all layers)
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
            self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)

    def wait_prefetch_layer(self, layer_id: int) -> None:
        """Wait for specific layer's Prefetch region loading to complete."""
        self.compute_stream.wait_event(self.prefetch_ready_per_layer[layer_id])

    def swap_compute_prefetch(self) -> None:
        """Swap roles of Compute region and Prefetch region."""
        self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
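Putting the per-layer load/wait/swap calls together, a double-buffered chunked-attention loop for one layer might look like the following sketch (hypothetical caller code; in a complete pipeline the transfer stream must additionally be kept from overwriting slots the compute stream is still reading):

def chunked_attention_one_layer(engine, layer_id, chunks, attend_fn):
    """Sketch: cycle CPU chunks through the Compute/Prefetch regions for one layer.

    chunks is a list of lists of CPU block ids; attend_fn(k, v) consumes one chunk's KV.
    """
    if not chunks:
        return
    engine.load_to_compute_layer(layer_id, chunks[0])
    engine.wait_compute_layer(layer_id)
    for i, chunk in enumerate(chunks):
        if i + 1 < len(chunks):
            # Start the next chunk's H2D copy while this chunk is being attended to.
            engine.load_to_prefetch_layer(layer_id, chunks[i + 1])
        k, v = engine.get_kv_for_compute(layer_id, len(chunk))
        attend_fn(k, v)
        if i + 1 < len(chunks):
            engine.wait_prefetch_layer(layer_id)
            engine.swap_compute_prefetch()   # prefetched data plays the compute role next iteration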
    # ----- Decode slot methods (kept for decode phase) -----

    def offload_decode_slot(self, cpu_block_id: int) -> None:
        """
        Offload KV from Decode region to CPU.
        Offload KV from decode slot (slot[0]) to CPU.

        Args:
            cpu_block_id: Target CPU block ID
        """
        logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
        logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]")

        with torch.cuda.stream(self.transfer_stream_main):
            self.transfer_stream_main.wait_stream(self.compute_stream)
@@ -750,61 +788,16 @@ class OffloadEngine:
            self.decode_offload_done.record(self.transfer_stream_main)

    def wait_decode_offload(self) -> None:
        """Wait for Decode region offload to complete."""
        """Wait for decode slot offload to complete."""
        self.compute_stream.wait_event(self.decode_offload_done)

    def get_kv_for_compute(
        self,
        layer_id: int,
        num_blocks: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get KV for specified number of blocks in Compute region.

        Args:
            layer_id: Layer ID
            num_blocks: Number of blocks needed

        Returns:
            (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
        """
        slots = self.compute_slots[:num_blocks]
        k = self.k_cache_gpu[layer_id, slots]  # [num_blocks, block_size, heads, dim]
        v = self.v_cache_gpu[layer_id, slots]
        # Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    def get_kv_for_prefetch(
        self,
        layer_id: int,
        num_blocks: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get KV for specified number of blocks in Prefetch region.

        Args:
            layer_id: Layer ID
            num_blocks: Number of blocks needed

        Returns:
            (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
        """
        slots = self.prefetch_slots[:num_blocks]
        k = self.k_cache_gpu[layer_id, slots]
        v = self.v_cache_gpu[layer_id, slots]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    def get_kv_for_decode_slot(
        self,
        layer_id: int,
        pos_in_block: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get KV at specified position in Decode region (for new token during decode).
        Get KV at specified position in decode slot.

        Args:
            layer_id: Layer ID
@@ -813,9 +806,9 @@ class OffloadEngine:
        Returns:
            (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
        """
        k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]  # [1, heads, dim]
        k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
        v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
        k = k.unsqueeze(0)  # [1, 1, heads, dim]
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        return k, v

@@ -825,10 +818,7 @@ class OffloadEngine:
        num_tokens: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get accumulated KV in Decode region (all tokens from position 0 to num_tokens-1).

        Used when batching decode offloads - attend to all accumulated tokens,
        not just the current one.
        Get accumulated KV in decode slot (positions 0 to num_tokens-1).

        Args:
            layer_id: Layer ID
@@ -837,35 +827,102 @@ class OffloadEngine:
        Returns:
            (k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
        """
        k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]  # [num_tokens, heads, dim]
        k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]
        v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
        k = k.unsqueeze(0)  # [1, num_tokens, heads, dim]
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        return k, v
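A decode-step sketch combining the decode slot with context staged into the load slots (hypothetical caller code; the accumulated-KV getter's def line is outside this hunk, so the name get_kv_for_decode_accumulated is assumed and may differ):

import torch

def decode_step_kv(engine, layer_id: int, num_new_tokens: int, loaded_slots):
    """Sketch: assemble KV for one decode step.

    Earlier context is assumed to have been staged into loaded_slots (a subset of
    engine.get_load_slots_for_decode()); the newest tokens live in the decode slot.
    """
    k_ctx, v_ctx = engine.get_kv_for_slots(layer_id, loaded_slots)
    k_new, v_new = engine.get_kv_for_decode_accumulated(layer_id, num_new_tokens)  # assumed name
    if k_ctx is None:
        return k_new, v_new
    # Concatenate along the token dimension: [1, ctx_tokens + new_tokens, kv_heads, head_dim]
    return torch.cat([k_ctx, k_new], dim=1), torch.cat([v_ctx, v_new], dim=1)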
    def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
        """
        Offload KV from Compute region to CPU.
    # ----- Legacy compatibility methods (for decode double-buffering) -----

        Args:
            cpu_block_ids: Target CPU block IDs list
    def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.

        Uses first half of decode_load_slots as 'compute' region.
        """
        if not cpu_block_ids:
            return

        num_to_offload = min(len(cpu_block_ids), len(self.compute_slots))
        logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
        half = max(1, len(self.decode_load_slots) // 2)
        slots = self.decode_load_slots[:half]
        num_to_load = min(len(cpu_block_ids), len(slots))

        with torch.cuda.stream(self.transfer_stream_main):
            # Wait for compute to complete
            self.transfer_stream_main.wait_stream(self.compute_stream)

            for i in range(num_to_offload):
                gpu_slot = self.compute_slots[i]
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                self.k_cache_cpu[:, cpu_id].copy_(
                    self.k_cache_gpu[:, gpu_slot], non_blocking=True
                gpu_slot = slots[i]
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
                self.v_cache_cpu[:, cpu_id].copy_(
                    self.v_cache_gpu[:, gpu_slot], non_blocking=True
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
            if num_to_load > 0:
                self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)

    def wait_compute_layer(self, layer_id: int) -> None:
        """Legacy: Wait for 'compute' region loading."""
        half = max(1, len(self.decode_load_slots) // 2)
        if self.decode_load_slots:
            self.wait_slot_layer(self.decode_load_slots[0], layer_id)

    def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.

        Uses second half of decode_load_slots as 'prefetch' region.
        """
        if not cpu_block_ids:
            return

        half = max(1, len(self.decode_load_slots) // 2)
        slots = self.decode_load_slots[half:]
        if not slots:
            slots = self.decode_load_slots  # Fallback if only 1-2 slots
        num_to_load = min(len(cpu_block_ids), len(slots))

        with torch.cuda.stream(self.transfer_stream_main):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = slots[i]
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
            if num_to_load > 0:
                self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)

    def wait_prefetch_layer(self, layer_id: int) -> None:
        """Legacy: Wait for 'prefetch' region loading."""
        half = max(1, len(self.decode_load_slots) // 2)
        slots = self.decode_load_slots[half:]
        if slots:
            self.wait_slot_layer(slots[0], layer_id)
        elif self.decode_load_slots:
            self.wait_slot_layer(self.decode_load_slots[0], layer_id)

    def get_kv_for_compute(
        self,
        layer_id: int,
        num_blocks: int,
    ) -> Tuple[Tensor, Tensor]:
        """Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
        half = max(1, len(self.decode_load_slots) // 2)
        slots = self.decode_load_slots[:half][:num_blocks]
        return self.get_kv_for_slots(layer_id, slots)

    def get_kv_for_prefetch(
        self,
        layer_id: int,
        num_blocks: int,
    ) -> Tuple[Tensor, Tensor]:
        """Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
        half = max(1, len(self.decode_load_slots) // 2)
        slots = self.decode_load_slots[half:]
        if not slots:
            slots = self.decode_load_slots
        slots = slots[:num_blocks]
        return self.get_kv_for_slots(layer_id, slots)
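The legacy shims above all derive their slot sets from the same split of decode_load_slots; as a standalone illustration of that split (hypothetical helper, not part of this commit):

def split_decode_load_slots(decode_load_slots):
    """Sketch of the split used by the legacy shims: the first half of decode_load_slots
    plays the 'compute' role, the second half the 'prefetch' role, falling back to all
    slots when there are too few to split."""
    half = max(1, len(decode_load_slots) // 2)
    compute_like = decode_load_slots[:half]
    prefetch_like = decode_load_slots[half:] or decode_load_slots
    return compute_like, prefetch_like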