♻️ refactor: unify KV cache operations through OffloadEngine
- Add write_to_prefill_buffer() and write_to_decode_buffer() methods
- Add chunk_idx parameter to load_to_slot_layer() for NVTX labeling
- Replace direct copy_() calls with OffloadEngine methods in attention.py
- Update all load_to_slot_layer() calls to pass chunk_idx
- NVTX markers now show chunk info: "H2D: L{layer} Chunk{chunk} CPU[{block}]->Slot[{slot}]"
All KV cache data transfers in chunked offload mode now go through
OffloadEngine, enabling better profiling and consistent management.
Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -104,27 +104,21 @@ class Attention(nn.Module):
|
||||
# This enables fully async offloads since each layer has its own buffer.
|
||||
offload_engine = context.kvcache_manager.offload_engine
|
||||
compute_stream = offload_engine.compute_stream
|
||||
chunk_idx = context.current_chunk_idx if hasattr(context, 'current_chunk_idx') else -1
|
||||
|
||||
# Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||
|
||||
with torch.cuda.stream(compute_stream):
|
||||
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
|
||||
# Write KV to per-layer prefill buffer via offload_engine
|
||||
# k, v shape: [num_tokens, kv_heads, head_dim]
|
||||
num_tokens = k.shape[0]
|
||||
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
|
||||
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
|
||||
#! GPU-to-GPU copy: k/v are written into the per-layer prefill buffer
|
||||
offload_engine.write_to_prefill_buffer(self.layer_id, k, v, chunk_idx=chunk_idx)
|
||||
elif is_chunked_offload:
|
||||
# Chunked decode mode: use compute_stream for store_kvcache
|
||||
# This ensures proper synchronization with per-layer offload
|
||||
compute_stream = context.kvcache_manager.offload_engine.compute_stream
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||
# slot_mapping is created with non_blocking=True on default stream, but we use it
|
||||
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
|
||||
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||
with torch.cuda.stream(compute_stream):
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
# Chunked decode mode: write KV to per-layer decode buffer via offload_engine
|
||||
# KV will be written to decode buffer in the decode branch below
|
||||
# No store_kvcache needed - all KV management goes through offload_engine
|
||||
pass
|
||||
else:
|
||||
# Normal mode: store on default stream
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
@@ -155,8 +149,7 @@ class Attention(nn.Module):
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
pos_in_block = context.decode_pos_in_block
|
||||
# k, v shape: [1, kv_heads, head_dim]
|
||||
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
|
||||
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
|
||||
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
|
||||
o = self._chunked_decode_attention(q, k, v, context)
|
||||
else:
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
|
||||
Reference in New Issue
Block a user