[WIP] fixing attention compute error.
@@ -538,7 +538,7 @@ class OffloadEngine:

     def sync_indices(self) -> None:
         """Synchronize to ensure all index updates are complete."""
-        torch.cuda.current_stream().synchronize()
+        torch.cuda.default_stream().synchronize()

     # ========== Cache access methods ==========

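A minimal sketch (assumed setup, not code from this repo) of the hazard sync_indices() guards against: an index update queued on the default stream must be made visible before work on another stream consumes it.

import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    idx_cpu = torch.arange(8).pin_memory()
    idx_gpu = torch.empty(8, dtype=torch.long, device="cuda")

    # The async H2D copy is queued on the default stream and may still be in flight.
    idx_gpu.copy_(idx_cpu, non_blocking=True)

    # Make the update visible before another stream reads idx_gpu.
    torch.cuda.default_stream().synchronize()

    with torch.cuda.stream(side_stream):
        doubled = idx_gpu * 2  # safe: the copy above has completed
    side_stream.synchronize()
    print(doubled.tolist())  # [0, 2, 4, 6, 8, 10, 12, 14]

A lighter-weight alternative would be side_stream.wait_stream(torch.cuda.default_stream()), which orders the two streams without stalling the host; the full synchronize used by sync_indices() also makes the indices safe to read from Python.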
@@ -682,8 +682,9 @@ class OffloadEngine:
         Async load a single CPU block to a ring buffer slot for one layer.

         This is the core building block for ring buffer pipelining.
-        Before starting the transfer, waits for any previous compute on this slot
-        to complete (using compute_done event).
+        Before starting the transfer, waits for:
+        1. Any previous compute on this slot to complete
+        2. Any pending offload of this slot to complete

         Args:
             slot_idx: Target GPU slot index
@@ -701,6 +702,10 @@ class OffloadEngine:
         # This prevents data race: transfer must not start until attention finishes reading
         stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])

+        # Also wait for any pending offload of this slot to complete
+        # This prevents race: load must not write GPU slot while offload is reading from it
+        stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
+
         self.k_cache_gpu[layer_id, slot_idx].copy_(
             self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
         )
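For context, a hedged sketch of the compute_done half of this protocol; the event names mirror the diff, but the streams, shapes, and stand-in ops are assumptions:

import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.Stream()
    transfer_stream = torch.cuda.Stream()
    num_slots, num_layers = 2, 4
    ring_slot_compute_done = [
        [torch.cuda.Event() for _ in range(num_layers)] for _ in range(num_slots)
    ]

    k_slot = torch.randn(256, device="cuda")
    slot_idx, layer_id = 0, 1

    with torch.cuda.stream(compute_stream):
        _ = k_slot.sum()  # stands in for attention reading the slot
        # Recorded on compute_stream: fires once the read has finished.
        ring_slot_compute_done[slot_idx][layer_id].record(compute_stream)

    with torch.cuda.stream(transfer_stream):
        # As in the diff: the next load may not overwrite the slot any earlier.
        transfer_stream.wait_event(ring_slot_compute_done[slot_idx][layer_id])
        k_slot.fill_(0.0)  # stands in for the CPU -> GPU block copy
    transfer_stream.synchronize()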
@@ -763,7 +768,11 @@ class OffloadEngine:

         torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
         with torch.cuda.stream(self.transfer_stream_main):
+            # Wait for both compute_stream and default stream
+            # - compute_stream: for flash attention operations
+            # - default_stream: for store_kvcache which runs on default stream
             self.transfer_stream_main.wait_stream(self.compute_stream)
+            self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
             memcpy_2d_async(
                 self.k_cache_cpu[:, cpu_block_id],
                 self.k_cache_gpu[:, slot_idx],
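The same ordering can be shown self-contained (memcpy_2d_async is a repo helper not reproduced here; a plain copy_ stands in, and everything else is assumed):

import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.Stream()
    transfer_stream = torch.cuda.Stream()

    gpu_slot = torch.zeros(1024, device="cuda")
    cpu_block = torch.empty(1024).pin_memory()

    # Producers: the default stream writes KV entries (standing in for
    # store_kvcache), the compute stream reads them (standing in for attention).
    gpu_slot.fill_(1.0)
    with torch.cuda.stream(compute_stream):
        _ = gpu_slot.sum()

    with torch.cuda.stream(transfer_stream):
        # Order the D2H copy after both producers without blocking the host.
        transfer_stream.wait_stream(compute_stream)
        transfer_stream.wait_stream(torch.cuda.default_stream())
        cpu_block.copy_(gpu_slot, non_blocking=True)

    transfer_stream.synchronize()  # required before the host reads cpu_block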
@@ -793,7 +802,9 @@ class OffloadEngine:
             cpu_block_id: Target CPU block ID
         """
         with torch.cuda.stream(self.transfer_stream_main):
+            # Wait for both compute_stream and default stream
             self.transfer_stream_main.wait_stream(self.compute_stream)
+            self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
             self.k_cache_cpu[layer_id, cpu_block_id].copy_(
                 self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
             )
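The load path earlier in this diff waits on ring_slot_all_layers_offload_done; here is a sketch of where that event would plausibly be recorded after an offload like this one (the bookkeeping is assumed, the diff does not show it):

import torch

if torch.cuda.is_available():
    transfer_stream_main = torch.cuda.Stream()
    ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(2)]
    slot_idx = 0

    k_gpu_slot = torch.randn(256, device="cuda")
    k_cpu_block = torch.empty(256).pin_memory()

    with torch.cuda.stream(transfer_stream_main):
        k_cpu_block.copy_(k_gpu_slot, non_blocking=True)
        # Recorded on the transfer stream: fires only once the copy is done,
        # so a later load of this slot can stream.wait_event(...) on it.
        ring_slot_all_layers_offload_done[slot_idx].record(transfer_stream_main)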