[feat] Optimized with ASYNC offload.

This commit is contained in:
Zijie Tian
2025-12-15 07:21:35 +08:00
parent b8b6478506
commit 91a0f09a24
3 changed files with 93 additions and 20 deletions

View File

@@ -152,6 +152,14 @@ class OffloadEngine:
self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# ========== Per-slot Per-layer compute_done events for async pipeline ==========
# ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion
# This is used to ensure we don't overwrite data before it's been read by attention
self.ring_slot_compute_done = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -622,11 +630,26 @@ class OffloadEngine:
# ----- Per-slot Per-layer loading methods -----
def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None:
    """
    Record that computation using this slot's data is done.

    The recorded event is the one load_to_slot_layer waits on before it
    overwrites the slot, so the transfer cannot start until the attention
    computation that reads this slot has finished.

    Args:
        slot_idx: GPU slot index that was just used for computation
        layer_id: Layer index
    """
    # One event exists per (slot, layer) pair; pick the one for this call.
    compute_done_event = self.ring_slot_compute_done[slot_idx][layer_id]
    # record() is called without a stream argument, so the event is captured
    # on whichever CUDA stream is current at the call site.
    compute_done_event.record()
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.
This is the core building block for ring buffer pipelining.
Before starting the transfer, waits for any previous compute on this slot
to complete (using compute_done event).
Args:
slot_idx: Target GPU slot index
@@ -636,6 +659,10 @@ class OffloadEngine:
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading
self.transfer_stream_main.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
self.k_cache_gpu[layer_id, slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)