♻️ refactor: unify KV cache operations through OffloadEngine

- Add write_to_prefill_buffer() and write_to_decode_buffer() methods
- Add chunk_idx parameter to load_to_slot_layer() for NVTX labeling
- Replace direct copy_() calls with OffloadEngine methods in attention.py
- Update all load_to_slot_layer() calls to pass chunk_idx
- NVTX markers now show chunk info: "H2D: L{layer} Chunk{chunk} CPU[{block}]->Slot[{slot}]"

All KV cache data transfers in chunked offload mode now go through
OffloadEngine, enabling better profiling and consistent management.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 02:20:59 +08:00
parent 3100724666
commit aea3812230
4 changed files with 89 additions and 28 deletions

View File

@@ -104,27 +104,21 @@ class Attention(nn.Module):
# This enables fully async offloads since each layer has its own buffer.
offload_engine = context.kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
chunk_idx = context.current_chunk_idx if hasattr(context, 'current_chunk_idx') else -1
# Wait for default stream to ensure slot_mapping tensor transfer is complete
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
# Write KV to per-layer prefill buffer via offload_engine
# k, v shape: [num_tokens, kv_heads, head_dim]
num_tokens = k.shape[0]
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
#! GPU 2 GPU
offload_engine.write_to_prefill_buffer(self.layer_id, k, v, chunk_idx=chunk_idx)
elif is_chunked_offload:
# Chunked decode mode: use compute_stream for store_kvcache
# This ensures proper synchronization with per-layer offload
compute_stream = context.kvcache_manager.offload_engine.compute_stream
if k_cache.numel() and v_cache.numel():
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
# slot_mapping is created with non_blocking=True on default stream, but we use it
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
# Chunked decode mode: write KV to per-layer decode buffer via offload_engine
# KV will be written to decode buffer in the decode branch below
# No store_kvcache needed - all KV management goes through offload_engine
pass
else:
# Normal mode: store on default stream
if k_cache.numel() and v_cache.numel():
@@ -155,8 +149,7 @@ class Attention(nn.Module):
offload_engine = kvcache_manager.offload_engine
pos_in_block = context.decode_pos_in_block
# k, v shape: [1, kv_heads, head_dim]
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
o = self._chunked_decode_attention(q, k, v, context)
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,