[fix] Fixed kvcache offload problem.
@@ -155,6 +155,11 @@ class OffloadEngine:
        self.ping_offload_done = torch.cuda.Event()
        self.pong_offload_done = torch.cuda.Event()

        # ========== Per-layer events for chunked attention ==========
        # Each layer has its own event for synchronization
        self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
        self.prefetch_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]

        # ========== Event tracking for async transfers ==========
        self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}

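For context, the per-layer events added above implement the standard CUDA record/wait handshake between the transfer stream and the compute stream: the producer records an event after queuing its copies, and the consumer stream waits on that event rather than on the whole stream. A minimal self-contained sketch of the pattern (placeholder tensors and stream names, not the engine's cache layout):

    import torch

    transfer_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    ready = torch.cuda.Event()

    src_cpu = torch.randn(4, 1024, pin_memory=True)  # pinned memory enables a truly async H2D copy
    dst_gpu = torch.empty(4, 1024, device="cuda")

    with torch.cuda.stream(transfer_stream):
        dst_gpu.copy_(src_cpu, non_blocking=True)    # copy queued on the transfer stream
        ready.record(transfer_stream)                # event fires once the copy completes

    compute_stream.wait_event(ready)                 # compute stream blocks only on this one copy
    with torch.cuda.stream(compute_stream):
        out = dst_gpu.sum()                          # safe: ordered after the copy

Keeping one event per layer (rather than one global event) is what lets a layer start attending as soon as its own KV blocks arrive.
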
@@ -836,10 +841,80 @@ class OffloadEngine:
        """Wait for Compute region loading to complete."""
        self.compute_stream.wait_event(self.compute_ready)

    def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Load CPU blocks into the Compute region for a single layer only.

        This is used for per-layer chunked attention, where each layer
        independently loads its KV data.

        Args:
            layer_id: Layer index to load.
            cpu_block_ids: List of CPU block IDs to load.
        """
        if not cpu_block_ids:
            # Nothing to copy; record immediately so waiters do not block.
            self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
            return

        num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
        logger.debug(
            f"Compute load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} "
            f"-> GPU compute slots {self.compute_slots[:num_to_load]}"
        )

        with torch.cuda.stream(self.transfer_stream_main):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.compute_slots[i]
                # Copy only this layer (not all layers)
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
            self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)

    def wait_compute_layer(self, layer_id: int) -> None:
        """Wait for the specified layer's Compute region loading to complete."""
        self.compute_stream.wait_event(self.compute_ready_per_layer[layer_id])

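A hedged usage sketch of the per-layer pair above: each layer issues its own load and then waits only on its own event, so one layer's attention is not serialized behind another layer's transfer. `engine`, `blocks_for_layer`, and `attend_layer` are illustrative stand-ins, not names from this commit:

    # Hypothetical driver: issue all per-layer loads up front, then
    # let each layer's attention wait only on its own event.
    for layer_id in range(num_layers):
        engine.load_to_compute_layer(layer_id, blocks_for_layer[layer_id])

    for layer_id in range(num_layers):
        engine.wait_compute_layer(layer_id)  # compute_stream waits on this layer's event only
        attend_layer(layer_id)               # kernel launched on compute_stream
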
    def wait_prefetch(self) -> None:
        """Wait for Prefetch region loading to complete."""
        self.compute_stream.wait_event(self.prefetch_ready)

    def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
        """
        Load CPU blocks into the Prefetch region for a single layer only.

        This is used for per-layer chunked attention, where each layer
        independently loads its KV data.

        Args:
            layer_id: Layer index to load.
            cpu_block_ids: List of CPU block IDs to load.
        """
        if not cpu_block_ids:
            # Nothing to copy; record immediately so waiters do not block.
            self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
            return

        num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
        logger.debug(
            f"Prefetch load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} "
            f"-> GPU prefetch slots {self.prefetch_slots[:num_to_load]}"
        )

        with torch.cuda.stream(self.transfer_stream_main):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.prefetch_slots[i]
                # Copy only this layer (not all layers)
                self.k_cache_gpu[layer_id, gpu_slot].copy_(
                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[layer_id, gpu_slot].copy_(
                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
                )
            self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)

    def wait_prefetch_layer(self, layer_id: int) -> None:
        """Wait for the specified layer's Prefetch region loading to complete."""
        self.compute_stream.wait_event(self.prefetch_ready_per_layer[layer_id])

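The prefetch variants mirror the compute-region loaders so the next chunk's H2D copies can be in flight while the current chunk is attended; giving prefetch its own per-layer event defers the sync point until the prefetched data is actually consumed. A minimal overlap sketch (same illustrative names as above):

    engine.load_to_prefetch_layer(layer_id, next_chunk_blocks)  # copies queued on transfer_stream_main
    attend_current_chunk(layer_id)                              # runs concurrently on compute_stream
    engine.wait_prefetch_layer(layer_id)                        # sync only when the next chunk is needed
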
    def swap_compute_prefetch(self) -> None:
        """Swap the roles of the Compute region and the Prefetch region."""
        self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
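Putting the pieces together, a plausible per-layer chunked-attention loop over CPU-resident KV blocks might look like the following; `chunks` and `attend_chunk` are assumptions for illustration, and cross-layer sharing of the slot lists is ignored for brevity:

    # Hypothetical ping-pong loop for one layer: attend on one region
    # while the other region is being filled, then swap.
    engine.load_to_compute_layer(layer_id, chunks[0])
    engine.wait_compute_layer(layer_id)                             # first chunk ready
    for i in range(len(chunks)):
        if i + 1 < len(chunks):
            engine.load_to_prefetch_layer(layer_id, chunks[i + 1])  # overlap next H2D load
        attend_chunk(layer_id, engine.compute_slots)                # consume the current chunk
        if i + 1 < len(chunks):
            engine.wait_prefetch_layer(layer_id)                    # compute_stream waits for next chunk
            engine.swap_compute_prefetch()                          # prefetch slots become compute slots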