[WIP] Need to modify communication.
@@ -8,6 +8,7 @@ Key design principles for CUDA Graph compatibility:
 """

 import torch
+import torch.cuda.nvtx
 from torch import Tensor
 from typing import Dict, List, Tuple, Optional
 from dataclasses import dataclass
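
The new import backs the NVTX ranges added in the hunks below. As a minimal standalone sketch of the bracketing pattern (the helper name traced_h2d is hypothetical, not from this codebase), each transfer is wrapped in a named range so it shows up as a labeled span in Nsight Systems:

import torch
import torch.cuda.nvtx

def traced_h2d(dst_gpu: torch.Tensor, src_cpu: torch.Tensor) -> None:
    # Named NVTX range: visible as a span in Nsight Systems timelines.
    torch.cuda.nvtx.range_push("H2D: example copy")
    # non_blocking=True only overlaps with compute if src_cpu is pinned.
    dst_gpu.copy_(src_cpu, non_blocking=True)
    torch.cuda.nvtx.range_pop()

range_push/range_pop must stay balanced on the host thread; note that the hunks below pop outside the stream context manager, pairing with the push issued just before it.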
@@ -660,6 +661,7 @@ class OffloadEngine:
         """
         logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")

+        torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
         with torch.cuda.stream(self.transfer_stream_main):
             # Wait for previous compute on this slot to complete before overwriting.
             # This prevents a data race: the transfer must not start until attention finishes reading.
@@ -672,6 +674,7 @@ class OffloadEngine:
                 self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
             )
             self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)
+        torch.cuda.nvtx.range_pop()

     def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
         """
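
The event recorded on transfer_stream_main pairs with a wait on the compute side. The body of wait_slot_layer sits outside this diff; a plausible implementation consistent with the event recorded above (an assumption, not necessarily the author's code) would be:

def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
    # Queue a wait on the current (compute) stream: kernels launched after
    # this point cannot start until the H2D copy into this slot/layer has
    # finished. Nothing blocks on the host, which keeps the call legal
    # inside CUDA Graph capture.
    self.ring_slot_ready[slot_idx][layer_id].wait(torch.cuda.current_stream())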
@@ -718,6 +721,7 @@ class OffloadEngine:
         """
         logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")

+        torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
         with torch.cuda.stream(self.transfer_stream_main):
             self.transfer_stream_main.wait_stream(self.compute_stream)
             self.k_cache_cpu[:, cpu_block_id].copy_(
@@ -727,6 +731,7 @@ class OffloadEngine:
                 self.v_cache_gpu[:, slot_idx], non_blocking=True
             )
             self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
+        torch.cuda.nvtx.range_pop()

     def wait_slot_offload(self, slot_idx: int) -> None:
         """Wait for slot offload to complete."""
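
Taken together, the events give a per-slot handshake: the transfer stream records readiness after each H2D copy, the compute stream waits on it before attention reads the slot, and offload completion is tracked with a single all-layers event. A hypothetical sketch of one slot's lifecycle (only wait_slot_layer and wait_slot_offload appear in this diff; the other method names are assumptions about the surrounding API):

def ring_step(engine, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
    # Hypothetical driver for one ring-buffer step; load_slot_layer,
    # offload_slot, and attention_forward are assumed names.
    engine.load_slot_layer(slot_idx, layer_id, cpu_block_id)  # H2D, records ready event
    engine.wait_slot_layer(slot_idx, layer_id)                # compute stream waits on it
    engine.attention_forward(layer_id, slot_idx)              # reads the GPU slot
    engine.offload_slot(slot_idx, cpu_block_id)               # D2H, all layers of the slot
    engine.wait_slot_offload(slot_idx)                        # before the CPU block is reused

Note the asymmetry: the D2H path orders itself after compute with wait_stream rather than a per-layer event, which is why one all-layers event suffices on the offload side.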