[refactor] Implement real chunked prefill mechanism.
@@ -308,6 +308,112 @@ class OffloadEngine:
            events.append(event)
        return events

    # ========== Chunked Decode: Load CPU blocks to GPU slots ==========

    def load_cpu_blocks_to_gpu_slots(
        self,
        layer_id: int,
        cpu_block_ids: List[int],
        gpu_slot_ids: List[int],
    ) -> None:
        """
        Load CPU blocks to specific GPU slots for chunked decode.

        Uses the main GPU KV cache slots, not a separate temp buffer;
        this is the same mechanism that chunked prefill uses.

        Args:
            layer_id: Layer index
            cpu_block_ids: List of CPU block IDs to load
            gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
        """
        assert len(cpu_block_ids) == len(gpu_slot_ids)

        stream = self._get_next_stream()

        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
                # Copy from pinned CPU memory to GPU KV cache slot
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )

        # Wait for transfer to complete
        stream.synchronize()
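
# --- Editorial usage sketch, not part of this commit; `engine` (an OffloadEngine),
# --- `num_layers`, and the concrete IDs below are assumed names for illustration.
# Per-layer synchronous load during chunked decode: stage the offloaded KV blocks
# of the current chunk into free GPU slots, then run attention for that layer.
cpu_block_ids = [7, 8, 9]   # blocks previously offloaded to pinned CPU memory
gpu_slot_ids = [0, 1, 2]    # free slots in the main GPU KV cache
for layer_id in range(num_layers):
    engine.load_cpu_blocks_to_gpu_slots(layer_id, cpu_block_ids, gpu_slot_ids)
    # attention for `layer_id` can now read k_cache_gpu / v_cache_gpu at slots 0..2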

    def load_cpu_blocks_to_gpu_slots_async(
        self,
        layer_id: int,
        cpu_block_ids: List[int],
        gpu_slot_ids: List[int],
    ) -> torch.cuda.Event:
        """
        Async version: Load CPU blocks to GPU slots.

        Args:
            layer_id: Layer index
            cpu_block_ids: List of CPU block IDs to load
            gpu_slot_ids: List of GPU slot IDs to load into

        Returns:
            CUDA event to wait on
        """
        assert len(cpu_block_ids) == len(gpu_slot_ids)

        stream = self._get_next_stream()
        event = torch.cuda.Event()

        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True
                )
            event.record()

        return event
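
# --- Editorial usage sketch, not part of this commit; `engine`, `layer_id`,
# --- `cpu_block_ids`, and `gpu_slot_ids` are assumed names. The async variant
# --- lets the copy for the next layer overlap attention for the current one;
# --- only the compute stream waits on the returned event, the host never blocks.
event = engine.load_cpu_blocks_to_gpu_slots_async(layer_id + 1, cpu_block_ids, gpu_slot_ids)
# ... run attention for `layer_id` on the current stream ...
torch.cuda.current_stream().wait_event(event)  # gate subsequent GPU work on the copy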

    def load_cpu_blocks_to_gpu_slots_all_layers(
        self,
        cpu_block_ids: List[int],
        gpu_slot_ids: List[int],
    ) -> None:
        """
        Load CPU blocks to GPU slots for ALL layers at once.

        More efficient than per-layer loading when we know the mapping upfront.

        Args:
            cpu_block_ids: List of CPU block IDs to load
            gpu_slot_ids: List of GPU slot IDs to load into
        """
        assert len(cpu_block_ids) == len(gpu_slot_ids)

        stream = self._get_next_stream()

        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
                # Copy all layers at once
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_block_id],
                    non_blocking=True
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_block_id],
                    non_blocking=True
                )

        stream.synchronize()
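
# --- Editorial usage sketch, not part of this commit; `engine`, `num_layers`,
# --- `cpu_block_ids`, and `gpu_slot_ids` are assumed names. When the block -> slot
# --- mapping is known before the forward pass, one call covers every layer:
engine.load_cpu_blocks_to_gpu_slots_all_layers(cpu_block_ids, gpu_slot_ids)
# roughly equivalent to, but with fewer Python-level calls than:
# for layer_id in range(num_layers):
#     engine.load_cpu_blocks_to_gpu_slots(layer_id, cpu_block_ids, gpu_slot_ids)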

    # ========== Synchronization methods ==========

    def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None: