[WIP] fixing attention compute error.

Zijie Tian
2025-12-30 00:31:48 +08:00
parent bf4c63c7ec
commit 89f8020d38
12 changed files with 2175 additions and 103 deletions

@@ -538,7 +538,7 @@ class OffloadEngine:
     def sync_indices(self) -> None:
         """Synchronize to ensure all index updates are complete."""
-        torch.cuda.current_stream().synchronize()
+        torch.cuda.default_stream().synchronize()
     # ========== Cache access methods ==========
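For readers unfamiliar with the distinction this hunk turns on: torch.cuda.current_stream() resolves to whichever stream is active in the surrounding context, while torch.cuda.default_stream() always names the device's default stream, which is where work launched outside any stream context (presumably including the index updates this method guards) ends up. A minimal standalone sketch of that semantics, using an illustrative side stream that is not part of the repo:

    import torch

    side_stream = torch.cuda.Stream()   # illustrative extra stream, not from the repo
    with torch.cuda.stream(side_stream):
        # Inside the context manager, "current" means the side stream.
        assert torch.cuda.current_stream() == side_stream
    # Kernels and copies launched outside any stream context run on the default stream,
    # so blocking on *their* completion requires naming the default stream explicitly.
    torch.cuda.default_stream().synchronize()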
@@ -682,8 +682,9 @@ class OffloadEngine:
         Async load a single CPU block to a ring buffer slot for one layer.
         This is the core building block for ring buffer pipelining.
-        Before starting the transfer, waits for any previous compute on this slot
-        to complete (using compute_done event).
+        Before starting the transfer, waits for:
+        1. Any previous compute on this slot to complete
+        2. Any pending offload of this slot to complete
         Args:
             slot_idx: Target GPU slot index
@@ -701,6 +702,10 @@ class OffloadEngine:
             # This prevents data race: transfer must not start until attention finishes reading
             stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
+            # Also wait for any pending offload of this slot to complete
+            # This prevents race: load must not write GPU slot while offload is reading from it
+            stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
             self.k_cache_gpu[layer_id, slot_idx].copy_(
                 self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
             )
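The two wait_event calls above are the consumer half of an event handshake. A minimal sketch of the pattern with placeholder tensors, events, and shapes (only the structure, not the engine's real buffers); in the engine itself the events would be recorded by the attention path and the offload path before this load is issued:

    import torch

    transfer_stream = torch.cuda.Stream()
    compute_done = torch.cuda.Event()   # recorded after attention finishes reading the slot
    offload_done = torch.cuda.Event()   # recorded after the D2H offload finishes reading it
    gpu_slot = torch.empty(256, 128, device="cuda")
    cpu_block = torch.empty(256, 128, pin_memory=True)

    with torch.cuda.stream(transfer_stream):
        # Overwriting the GPU slot is only safe once every earlier reader has drained.
        transfer_stream.wait_event(compute_done)
        transfer_stream.wait_event(offload_done)
        gpu_slot.copy_(cpu_block, non_blocking=True)   # async H2D, overlapped with compute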
@@ -763,7 +768,11 @@ class OffloadEngine:
         torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
         with torch.cuda.stream(self.transfer_stream_main):
+            # Wait for both compute_stream and default stream
+            # - compute_stream: for flash attention operations
+            # - default_stream: for store_kvcache which runs on default stream
             self.transfer_stream_main.wait_stream(self.compute_stream)
+            self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
             memcpy_2d_async(
                 self.k_cache_cpu[:, cpu_block_id],
                 self.k_cache_gpu[:, slot_idx],
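Both wait_stream calls make the transfer stream order itself behind two independent producers before the D2H copy reads the slot. A self-contained sketch of the same ordering, with made-up shapes and a plain Tensor.copy_ standing in for the repo's memcpy_2d_async helper:

    import torch

    compute_stream = torch.cuda.Stream()    # stands in for self.compute_stream
    transfer_stream = torch.cuda.Stream()   # stands in for self.transfer_stream_main
    k_gpu = torch.empty(8, 256, 128, device="cuda")   # one slot across 8 layers (placeholder shape)
    k_cpu = torch.empty(8, 256, 128, pin_memory=True)

    with torch.cuda.stream(transfer_stream):
        transfer_stream.wait_stream(compute_stream)                # flash-attention reads/writes
        transfer_stream.wait_stream(torch.cuda.default_stream())   # store_kvcache writes
        k_cpu.copy_(k_gpu, non_blocking=True)                      # async D2H into pinned memory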
@@ -793,7 +802,9 @@ class OffloadEngine:
             cpu_block_id: Target CPU block ID
         """
         with torch.cuda.stream(self.transfer_stream_main):
+            # Wait for both compute_stream and default stream
             self.transfer_stream_main.wait_stream(self.compute_stream)
+            self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
             self.k_cache_cpu[layer_id, cpu_block_id].copy_(
                 self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
             )
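For the waits in the load path (ring_slot_compute_done / ring_slot_all_layers_offload_done) to be meaningful, something has to record those events after the offload copies are queued; this diff does not show where that happens, so the following is only a guess at the producer side, reusing the placeholder names from the sketches above:

    with torch.cuda.stream(transfer_stream):
        k_cpu.copy_(k_gpu, non_blocking=True)   # queue the D2H offload
        offload_done.record(transfer_stream)    # hypothetical record site; lets a later
                                                # ring-buffer load wait_event() on this slot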