[WIP] NEED to modify communication.

This commit is contained in:
Zijie Tian
2025-12-24 21:57:51 +08:00
parent 782437c486
commit 6ec1b23982
9 changed files with 462 additions and 2 deletions

View File

@@ -8,6 +8,7 @@ Key design principles for CUDA Graph compatibility:
"""
import torch
import torch.cuda.nvtx
from torch import Tensor
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
@@ -660,6 +661,7 @@ class OffloadEngine:
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading
@@ -672,6 +674,7 @@ class OffloadEngine:
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
"""
@@ -718,6 +721,7 @@ class OffloadEngine:
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
@@ -727,6 +731,7 @@ class OffloadEngine:
self.v_cache_gpu[:, slot_idx], non_blocking=True
)
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""