[refactor] Implement real chunked prefill mechanism.

Zijie Tian
2025-12-10 18:34:01 +08:00
parent 0b6f19242d
commit 87055cc5ce
4 changed files with 313 additions and 85 deletions


@@ -308,6 +308,112 @@ class OffloadEngine:
        events.append(event)
        return events

    # ========== Chunked Decode: Load CPU blocks to GPU slots ==========

    def load_cpu_blocks_to_gpu_slots(
        self,
        layer_id: int,
        cpu_block_ids: List[int],
        gpu_slot_ids: List[int],
    ) -> None:
        """
        Load CPU blocks to specific GPU slots for chunked decode.

        Uses the main GPU KV cache slots, not a separate temp buffer.
        This is the same mechanism that chunked prefill uses.

        Args:
            layer_id: Layer index
            cpu_block_ids: List of CPU block IDs to load
            gpu_slot_ids: List of GPU slot IDs to load into (must be the same length)
        """
        assert len(cpu_block_ids) == len(gpu_slot_ids)

        stream = self._get_next_stream()
        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
                # Copy from pinned CPU memory to GPU KV cache slot
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True,
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True,
                )

        # Wait for transfer to complete
        stream.synchronize()
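
A minimal call-site sketch of the synchronous path, assuming the engine has already been constructed; `engine`, `layer_id`, `chunk_cpu_blocks`, and `free_gpu_slots` are illustrative names supplied by the caller, not part of this diff:

    # Illustrative call site (names are hypothetical): stage one chunk's
    # offloaded KV blocks into reserved GPU slots before that layer's attention.
    chunk_cpu_blocks = [12, 13, 14]   # CPU block IDs holding earlier KV
    free_gpu_slots = [0, 1, 2]        # GPU slots reserved for this chunk
    engine.load_cpu_blocks_to_gpu_slots(
        layer_id=layer_id,
        cpu_block_ids=chunk_cpu_blocks,
        gpu_slot_ids=free_gpu_slots,
    )
    # On return, k_cache_gpu/v_cache_gpu[layer_id, slot] hold the staged blocks
    # and can be indexed by the attention kernel like any resident block.
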
    def load_cpu_blocks_to_gpu_slots_async(
        self,
        layer_id: int,
        cpu_block_ids: List[int],
        gpu_slot_ids: List[int],
    ) -> torch.cuda.Event:
        """
        Async version: Load CPU blocks to GPU slots.

        Args:
            layer_id: Layer index
            cpu_block_ids: List of CPU block IDs to load
            gpu_slot_ids: List of GPU slot IDs to load into

        Returns:
            CUDA event to wait on
        """
        assert len(cpu_block_ids) == len(gpu_slot_ids)

        stream = self._get_next_stream()
        event = torch.cuda.Event()
        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True,
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_block_id],
                    non_blocking=True,
                )
            event.record()
        return event
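
The overlap pattern behind the async variant can be shown in isolation. The sketch below is self-contained and independent of OffloadEngine (tensor shapes are arbitrary): a pinned CPU buffer is copied on a side stream with non_blocking=True, an event marks completion, and the compute stream waits on that event only when it actually needs the data.

    import torch

    assert torch.cuda.is_available()
    cpu_block = torch.randn(16, 8, 128, pin_memory=True)      # pinned => truly async H2D copy
    gpu_slot = torch.empty_like(cpu_block, device="cuda")

    copy_stream = torch.cuda.Stream()
    done = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        gpu_slot.copy_(cpu_block, non_blocking=True)           # enqueued on copy_stream
        done.record()                                          # records on copy_stream

    torch.cuda.current_stream().wait_event(done)               # compute waits only here
    print(torch.allclose(gpu_slot.cpu(), cpu_block))           # True once the copy lands
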
    def load_cpu_blocks_to_gpu_slots_all_layers(
        self,
        cpu_block_ids: List[int],
        gpu_slot_ids: List[int],
    ) -> None:
        """
        Load CPU blocks to GPU slots for ALL layers at once.

        More efficient than per-layer loading when we know the mapping upfront.

        Args:
            cpu_block_ids: List of CPU block IDs to load
            gpu_slot_ids: List of GPU slot IDs to load into
        """
        assert len(cpu_block_ids) == len(gpu_slot_ids)

        stream = self._get_next_stream()
        with torch.cuda.stream(stream):
            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
                # Copy all layers at once
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_block_id],
                    non_blocking=True,
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_block_id],
                    non_blocking=True,
                )

        stream.synchronize()
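
A possible driver loop for chunked prefill on top of the all-layers loader is sketched below; `prompt_chunks`, `block_table`, `gpu_slot_pool`, and `run_chunk_forward` are hypothetical names used only for illustration, not APIs introduced by this diff:

    # Hypothetical chunked-prefill loop: before each chunk, stage the KV blocks
    # of earlier tokens (offloaded to CPU) into a reusable pool of GPU slots.
    for chunk in prompt_chunks:
        prior_cpu_blocks = block_table.cpu_blocks_before(chunk)    # illustrative helper
        gpu_slots = gpu_slot_pool[: len(prior_cpu_blocks)]         # fixed, reusable slots
        engine.load_cpu_blocks_to_gpu_slots_all_layers(prior_cpu_blocks, gpu_slots)
        run_chunk_forward(chunk, kv_slots=gpu_slots)               # attention reads staged KV
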
    # ========== Synchronization methods ==========

    def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None: