♻️ refactor: unify KV cache operations through OffloadEngine

- Add write_to_prefill_buffer() and write_to_decode_buffer() methods
- Add chunk_idx parameter to load_to_slot_layer() for NVTX labeling
- Replace direct copy_() calls with OffloadEngine methods in attention.py
- Update all load_to_slot_layer() calls to pass chunk_idx
- NVTX markers now show chunk info: "H2D: L{layer} Chunk{chunk} CPU[{block}]->Slot[{slot}]"

All KV cache data transfers in chunked offload mode now go through
OffloadEngine, enabling better profiling and consistent management.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 02:20:59 +08:00
parent 3100724666
commit aea3812230
4 changed files with 89 additions and 28 deletions

View File

@@ -189,8 +189,8 @@ class XAttentionBSAPolicy(SparsePolicy):
reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128
for cpu_block_id in available_blocks:
# Load K block from CPU to GPU
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
# Load K block from CPU to GPU (cpu_block_id is chunk index)
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
# Get KV: [1, block_size, num_kv_heads, head_dim]
@@ -382,7 +382,7 @@ class XAttentionBSAPolicy(SparsePolicy):
slot = load_slots[0]
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
@@ -402,7 +402,8 @@ class XAttentionBSAPolicy(SparsePolicy):
num_slots = len(load_slots)
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
@@ -428,7 +429,7 @@ class XAttentionBSAPolicy(SparsePolicy):
if next_block_idx < num_blocks:
next_slot = load_slots[next_block_idx % num_slots]
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Compute attention to current chunk (causal mask)
with torch.cuda.stream(compute_stream):