♻️ refactor: unify KV cache operations through OffloadEngine
- Add write_to_prefill_buffer() and write_to_decode_buffer() methods
- Add chunk_idx parameter to load_to_slot_layer() for NVTX labeling
- Replace direct copy_() calls with OffloadEngine methods in attention.py
- Update all load_to_slot_layer() calls to pass chunk_idx
- NVTX markers now show chunk info: "H2D: L{layer} Chunk{chunk} CPU[{block}]->Slot[{slot}]"
All KV cache data transfers in chunked offload mode now go through
OffloadEngine, enabling better profiling and consistent management.
Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -374,7 +374,9 @@ class OffloadEngine:
|
||||
"""
|
||||
self.ring_slot_compute_done[slot_idx].record()
|
||||
|
||||
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
def load_to_slot_layer(
|
||||
self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1
|
||||
) -> None:
|
||||
"""
|
||||
Async load a single CPU block to a ring buffer slot for one layer.
|
||||
|
||||
@@ -389,13 +391,19 @@ class OffloadEngine:
|
||||
slot_idx: Target GPU slot index
|
||||
layer_id: Layer index to load (for CPU cache indexing)
|
||||
cpu_block_id: Source CPU block ID
|
||||
chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
|
||||
"""
|
||||
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
|
||||
|
||||
# Use per-slot stream for parallel transfers across different slots
|
||||
stream = self.slot_transfer_streams[slot_idx]
|
||||
|
||||
torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
|
||||
# Build NVTX label with optional chunk info
|
||||
if chunk_idx >= 0:
|
||||
nvtx_label = f"H2D: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
|
||||
else:
|
||||
nvtx_label = f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
|
||||
torch.cuda.nvtx.range_push(nvtx_label)
|
||||
with torch.cuda.stream(stream):
|
||||
# Wait for previous compute on this slot to complete before overwriting
|
||||
# This prevents data race: transfer must not start until attention finishes reading
|
||||
@@ -702,6 +710,61 @@ class OffloadEngine:
|
||||
v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
|
||||
return k, v
|
||||
|
||||
def write_to_prefill_buffer(
    self,
    layer_id: int,
    k: Tensor,
    v: Tensor,
    chunk_idx: int = -1,
) -> None:
    """
    Copy the current chunk's KV tensors into the GPU prefill buffer (D2D).

    Invoked during chunked prefill so the chunk's keys/values are staged
    before attention is computed over them.

    Args:
        layer_id: Layer index selecting the buffer row to write.
        k: Key tensor [num_tokens, kv_heads, head_dim]
        v: Value tensor [num_tokens, kv_heads, head_dim]
        chunk_idx: Current chunk index for NVTX labeling (-1 = not specified)
    """
    token_count = k.shape[0]

    # Tag the NVTX range with the chunk index when the caller supplied one.
    tag = (
        f"D2D: L{layer_id} Chunk{chunk_idx} WritePrefillBuffer"
        if chunk_idx >= 0
        else f"D2D: L{layer_id} WritePrefillBuffer"
    )

    torch.cuda.nvtx.range_push(tag)
    # K first, then V — same write order as before.
    for dst, src in ((self.prefill_k_buffer, k), (self.prefill_v_buffer, v)):
        dst[layer_id, :token_count].copy_(src)
    torch.cuda.nvtx.range_pop()
|
||||
|
||||
def write_to_decode_buffer(
    self,
    layer_id: int,
    pos_in_block: int,
    k: Tensor,
    v: Tensor,
) -> None:
    """
    Copy a single decode token's KV into the GPU decode buffer (D2D).

    Invoked during chunked decode to stash the KV produced for the token
    just generated.

    Args:
        layer_id: Layer index selecting the buffer row to write.
        pos_in_block: Position within the current block
        k: Key tensor [kv_heads, head_dim] (single token, squeezed)
        v: Value tensor [kv_heads, head_dim] (single token, squeezed)
    """
    torch.cuda.nvtx.range_push(f"D2D: L{layer_id} Pos{pos_in_block} WriteDecodeBuffer")
    # K first, then V — same write order as before.
    for dst, src in ((self.decode_k_buffer, k), (self.decode_v_buffer, v)):
        dst[layer_id, pos_in_block].copy_(src)
    torch.cuda.nvtx.range_pop()
|
||||
|
||||
def offload_prefill_buffer_async(
|
||||
self,
|
||||
layer_id: int,
|
||||
|
||||
Reference in New Issue
Block a user