♻️ refactor: unify KV cache operations through OffloadEngine

- Add write_to_prefill_buffer() and write_to_decode_buffer() methods
- Add chunk_idx parameter to load_to_slot_layer() for NVTX labeling
- Replace direct copy_() calls with OffloadEngine methods in attention.py
- Update all load_to_slot_layer() calls to pass chunk_idx
- NVTX markers now show chunk info: "H2D: L{layer} Chunk{chunk} CPU[{block}]->Slot[{slot}]"

All KV cache data transfers in chunked offload mode now go through
OffloadEngine, enabling better profiling and consistent management.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-01-27 02:20:59 +08:00
parent 3100724666
commit aea3812230
4 changed files with 89 additions and 28 deletions

View File

@@ -374,7 +374,9 @@ class OffloadEngine:
"""
self.ring_slot_compute_done[slot_idx].record()
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
def load_to_slot_layer(
self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1
) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.
@@ -389,13 +391,19 @@ class OffloadEngine:
slot_idx: Target GPU slot index
layer_id: Layer index to load (for CPU cache indexing)
cpu_block_id: Source CPU block ID
chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
# Use per-slot stream for parallel transfers across different slots
stream = self.slot_transfer_streams[slot_idx]
torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
# Build NVTX label with optional chunk info
if chunk_idx >= 0:
nvtx_label = f"H2D: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
else:
nvtx_label = f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
torch.cuda.nvtx.range_push(nvtx_label)
with torch.cuda.stream(stream):
# Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading
@@ -702,6 +710,61 @@ class OffloadEngine:
v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
return k, v
def write_to_prefill_buffer(
self,
layer_id: int,
k: Tensor,
v: Tensor,
chunk_idx: int = -1,
) -> None:
"""
Write KV tensors to prefill buffer (D2D copy within GPU).
This is called during chunked prefill to store current chunk's KV
before computing attention.
Args:
layer_id: Layer index
k: Key tensor [num_tokens, kv_heads, head_dim]
v: Value tensor [num_tokens, kv_heads, head_dim]
chunk_idx: Current chunk index for NVTX labeling (-1 = not specified)
"""
num_tokens = k.shape[0]
# Build NVTX label
if chunk_idx >= 0:
nvtx_label = f"D2D: L{layer_id} Chunk{chunk_idx} WritePrefillBuffer"
else:
nvtx_label = f"D2D: L{layer_id} WritePrefillBuffer"
torch.cuda.nvtx.range_push(nvtx_label)
self.prefill_k_buffer[layer_id, :num_tokens].copy_(k)
self.prefill_v_buffer[layer_id, :num_tokens].copy_(v)
torch.cuda.nvtx.range_pop()
def write_to_decode_buffer(
self,
layer_id: int,
pos_in_block: int,
k: Tensor,
v: Tensor,
) -> None:
"""
Write KV tensors to decode buffer (D2D copy within GPU).
This is called during chunked decode to store current decode token's KV.
Args:
layer_id: Layer index
pos_in_block: Position within the current block
k: Key tensor [kv_heads, head_dim] (single token, squeezed)
v: Value tensor [kv_heads, head_dim] (single token, squeezed)
"""
torch.cuda.nvtx.range_push(f"D2D: L{layer_id} Pos{pos_in_block} WriteDecodeBuffer")
self.decode_k_buffer[layer_id, pos_in_block].copy_(k)
self.decode_v_buffer[layer_id, pos_in_block].copy_(v)
torch.cuda.nvtx.range_pop()
def offload_prefill_buffer_async(
self,
layer_id: int,

View File

@@ -139,7 +139,8 @@ class FullAttentionPolicy(SparsePolicy):
slot = load_slots[0]
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
# cpu_block_id is the chunk index (block N = chunk N)
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
@@ -159,7 +160,8 @@ class FullAttentionPolicy(SparsePolicy):
num_slots = len(load_slots)
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
@@ -186,7 +188,7 @@ class FullAttentionPolicy(SparsePolicy):
if next_block_idx < num_blocks:
next_slot = load_slots[next_block_idx % num_slots]
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Step 4: Compute attention to current chunk (causal mask)
with torch.cuda.stream(compute_stream):
@@ -350,7 +352,8 @@ class FullAttentionPolicy(SparsePolicy):
# Phase 1: Pre-load up to num_slots blocks
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
# Phase 2: Process blocks with pipeline
for block_idx in range(num_blocks):
@@ -383,7 +386,8 @@ class FullAttentionPolicy(SparsePolicy):
# Start loading next block (pipeline)
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(current_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Merge with accumulated
with torch.cuda.stream(compute_stream):

View File

@@ -189,8 +189,8 @@ class XAttentionBSAPolicy(SparsePolicy):
reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128
for cpu_block_id in available_blocks:
# Load K block from CPU to GPU
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
# Load K block from CPU to GPU (cpu_block_id is chunk index)
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
# Get KV: [1, block_size, num_kv_heads, head_dim]
@@ -382,7 +382,7 @@ class XAttentionBSAPolicy(SparsePolicy):
slot = load_slots[0]
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
@@ -402,7 +402,8 @@ class XAttentionBSAPolicy(SparsePolicy):
num_slots = len(load_slots)
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
@@ -428,7 +429,7 @@ class XAttentionBSAPolicy(SparsePolicy):
if next_block_idx < num_blocks:
next_slot = load_slots[next_block_idx % num_slots]
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Compute attention to current chunk (causal mask)
with torch.cuda.stream(compute_stream):

View File

@@ -104,27 +104,21 @@ class Attention(nn.Module):
# This enables fully async offloads since each layer has its own buffer.
offload_engine = context.kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
chunk_idx = context.current_chunk_idx if hasattr(context, 'current_chunk_idx') else -1
# Wait for default stream to ensure slot_mapping tensor transfer is complete
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
# Write KV to per-layer prefill buffer via offload_engine
# k, v shape: [num_tokens, kv_heads, head_dim]
num_tokens = k.shape[0]
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
#! GPU 2 GPU
offload_engine.write_to_prefill_buffer(self.layer_id, k, v, chunk_idx=chunk_idx)
elif is_chunked_offload:
# Chunked decode mode: use compute_stream for store_kvcache
# This ensures proper synchronization with per-layer offload
compute_stream = context.kvcache_manager.offload_engine.compute_stream
if k_cache.numel() and v_cache.numel():
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
# slot_mapping is created with non_blocking=True on default stream, but we use it
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
# Chunked decode mode: write KV to per-layer decode buffer via offload_engine
# KV will be written to decode buffer in the decode branch below
# No store_kvcache needed - all KV management goes through offload_engine
pass
else:
# Normal mode: store on default stream
if k_cache.numel() and v_cache.numel():
@@ -155,8 +149,7 @@ class Attention(nn.Module):
offload_engine = kvcache_manager.offload_engine
pos_in_block = context.decode_pos_in_block
# k, v shape: [1, kv_heads, head_dim]
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
o = self._chunked_decode_attention(q, k, v, context)
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,