♻️ refactor: unify KV cache operations through OffloadEngine
- Add write_to_prefill_buffer() and write_to_decode_buffer() methods
- Add chunk_idx parameter to load_to_slot_layer() for NVTX labeling
- Replace direct copy_() calls with OffloadEngine methods in attention.py
- Update all load_to_slot_layer() calls to pass chunk_idx
- NVTX markers now show chunk info: "H2D: L{layer} Chunk{chunk} CPU[{block}]->Slot[{slot}]"
All KV cache data transfers in chunked offload mode now go through
OffloadEngine, enabling better profiling and consistent management.
Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
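For context when reading the diff below, a minimal sketch of what the chunk-aware H2D load path could look like after this change. Only the method name, its parameters (slot, layer_id, cpu_block_id, chunk_idx), and the NVTX label format come from this commit; the class fields (h2d_stream, k_slots, v_slots, k_cache_cpu, v_cache_cpu) and the per-layer tensor layout are illustrative assumptions, not the actual implementation.

```python
from typing import Optional

import torch


class OffloadEngine:
    # Sketch only: the field names and tensor layout below are assumptions.
    # The signature and NVTX label format match this commit.
    def load_to_slot_layer(self, slot: int, layer_id: int, cpu_block_id: int,
                           chunk_idx: Optional[int] = None) -> None:
        """Async H2D copy of one CPU KV block into a GPU slot for one layer."""
        label = f"H2D: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot}]"
        torch.cuda.nvtx.range_push(label)
        try:
            # Issue the copies on a dedicated transfer stream (assumed field),
            # so they can overlap with compute on the default stream.
            with torch.cuda.stream(self.h2d_stream):
                self.k_slots[layer_id][slot].copy_(
                    self.k_cache_cpu[layer_id][cpu_block_id], non_blocking=True)
                self.v_slots[layer_id][slot].copy_(
                    self.v_cache_cpu[layer_id][cpu_block_id], non_blocking=True)
        finally:
            torch.cuda.nvtx.range_pop()
```

Routing the copy through a method like this is what lets every transfer carry a chunk-labeled NVTX range instead of an anonymous copy_() in attention.py.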
@@ -139,7 +139,8 @@ class FullAttentionPolicy(SparsePolicy):
         slot = load_slots[0]
         for block_idx in range(num_blocks):
             cpu_block_id = cpu_block_table[block_idx]
-            offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
+            # cpu_block_id is the chunk index (block N = chunk N)
+            offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
             offload_engine.wait_slot_layer(slot)
 
         with torch.cuda.stream(compute_stream):
@@ -159,7 +160,8 @@ class FullAttentionPolicy(SparsePolicy):
         num_slots = len(load_slots)
         num_preload = min(num_slots, num_blocks)
         for i in range(num_preload):
-            offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
+            cpu_block_id = cpu_block_table[i]
+            offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
 
         for block_idx in range(num_blocks):
             current_slot = load_slots[block_idx % num_slots]
@@ -186,7 +188,7 @@ class FullAttentionPolicy(SparsePolicy):
             if next_block_idx < num_blocks:
                 next_slot = load_slots[next_block_idx % num_slots]
                 next_cpu_block_id = cpu_block_table[next_block_idx]
-                offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
+                offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
 
             # Step 4: Compute attention to current chunk (causal mask)
             with torch.cuda.stream(compute_stream):
@@ -350,7 +352,8 @@ class FullAttentionPolicy(SparsePolicy):
         # Phase 1: Pre-load up to num_slots blocks
         num_preload = min(num_slots, num_blocks)
         for i in range(num_preload):
-            offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
+            cpu_block_id = cpu_block_table[i]
+            offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
 
         # Phase 2: Process blocks with pipeline
         for block_idx in range(num_blocks):
@@ -383,7 +386,8 @@ class FullAttentionPolicy(SparsePolicy):
             # Start loading next block (pipeline)
             next_block_idx = block_idx + num_slots
             if next_block_idx < num_blocks:
-                offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])
+                next_cpu_block_id = cpu_block_table[next_block_idx]
+                offload_engine.load_to_slot_layer(current_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
 
             # Merge with accumulated
             with torch.cuda.stream(compute_stream):