[feat] Optimize KV-cache offload with async transfers.
This commit is contained in:
@@ -152,6 +152,14 @@ class OffloadEngine:
|
||||
self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
|
||||
self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
|
||||
|
||||
# ========== Per-slot Per-layer compute_done events for async pipeline ==========
|
||||
# ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion
|
||||
# This is used to ensure we don't overwrite data before it's been read by attention
|
||||
self.ring_slot_compute_done = [
|
||||
[torch.cuda.Event() for _ in range(num_layers)]
|
||||
for _ in range(self.num_ring_slots)
|
||||
]
|
||||
|
||||
# ========== Event tracking for async transfers ==========
|
||||
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
|
||||
|
||||
@@ -622,11 +630,26 @@ class OffloadEngine:
|
||||
|
||||
# ----- Per-slot Per-layer loading methods -----
|
||||
|
||||
def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None:
    """
    Record that computation using this slot's data is done.

    Records a CUDA event on the current (compute) stream. The transfer
    stream later waits on this event in load_to_slot_layer, so the slot's
    data is not overwritten by a new host-to-device copy before the
    attention computation has finished reading it.

    Args:
        slot_idx: GPU slot index that was just used for computation
        layer_id: Layer index
    """
    # Event is recorded on whatever stream is current here; the paired
    # wait happens on transfer_stream_main before the next copy into
    # this (slot, layer) entry.
    self.ring_slot_compute_done[slot_idx][layer_id].record()
|
||||
|
||||
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
"""
|
||||
Async load a single CPU block to a ring buffer slot for one layer.
|
||||
|
||||
This is the core building block for ring buffer pipelining.
|
||||
Before starting the transfer, waits for any previous compute on this slot
|
||||
to complete (using compute_done event).
|
||||
|
||||
Args:
|
||||
slot_idx: Target GPU slot index
|
||||
@@ -636,6 +659,10 @@ class OffloadEngine:
|
||||
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
|
||||
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
# Wait for previous compute on this slot to complete before overwriting
|
||||
# This prevents data race: transfer must not start until attention finishes reading
|
||||
self.transfer_stream_main.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
|
||||
|
||||
self.k_cache_gpu[layer_id, slot_idx].copy_(
|
||||
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user