♻️ refactor: unify KV cache operations through OffloadEngine
- Add write_to_prefill_buffer() and write_to_decode_buffer() methods
- Add chunk_idx parameter to load_to_slot_layer() for NVTX labeling
- Replace direct copy_() calls with OffloadEngine methods in attention.py
- Update all load_to_slot_layer() calls to pass chunk_idx
- NVTX markers now show chunk info: "H2D: L{layer} Chunk{chunk} CPU[{block}]->Slot[{slot}]"
All KV cache data transfers in chunked offload mode now go through
OffloadEngine, enabling better profiling and consistent management.
Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)
Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -374,7 +374,9 @@ class OffloadEngine:
|
||||
"""
|
||||
self.ring_slot_compute_done[slot_idx].record()
|
||||
|
||||
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
def load_to_slot_layer(
|
||||
self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1
|
||||
) -> None:
|
||||
"""
|
||||
Async load a single CPU block to a ring buffer slot for one layer.
|
||||
|
||||
@@ -389,13 +391,19 @@ class OffloadEngine:
|
||||
slot_idx: Target GPU slot index
|
||||
layer_id: Layer index to load (for CPU cache indexing)
|
||||
cpu_block_id: Source CPU block ID
|
||||
chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified)
|
||||
"""
|
||||
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
|
||||
|
||||
# Use per-slot stream for parallel transfers across different slots
|
||||
stream = self.slot_transfer_streams[slot_idx]
|
||||
|
||||
torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
|
||||
# Build NVTX label with optional chunk info
|
||||
if chunk_idx >= 0:
|
||||
nvtx_label = f"H2D: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
|
||||
else:
|
||||
nvtx_label = f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]"
|
||||
torch.cuda.nvtx.range_push(nvtx_label)
|
||||
with torch.cuda.stream(stream):
|
||||
# Wait for previous compute on this slot to complete before overwriting
|
||||
# This prevents data race: transfer must not start until attention finishes reading
|
||||
@@ -702,6 +710,61 @@ class OffloadEngine:
|
||||
v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
|
||||
return k, v
|
||||
|
||||
def write_to_prefill_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
chunk_idx: int = -1,
|
||||
) -> None:
|
||||
"""
|
||||
Write KV tensors to prefill buffer (D2D copy within GPU).
|
||||
|
||||
This is called during chunked prefill to store current chunk's KV
|
||||
before computing attention.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
k: Key tensor [num_tokens, kv_heads, head_dim]
|
||||
v: Value tensor [num_tokens, kv_heads, head_dim]
|
||||
chunk_idx: Current chunk index for NVTX labeling (-1 = not specified)
|
||||
"""
|
||||
num_tokens = k.shape[0]
|
||||
|
||||
# Build NVTX label
|
||||
if chunk_idx >= 0:
|
||||
nvtx_label = f"D2D: L{layer_id} Chunk{chunk_idx} WritePrefillBuffer"
|
||||
else:
|
||||
nvtx_label = f"D2D: L{layer_id} WritePrefillBuffer"
|
||||
|
||||
torch.cuda.nvtx.range_push(nvtx_label)
|
||||
self.prefill_k_buffer[layer_id, :num_tokens].copy_(k)
|
||||
self.prefill_v_buffer[layer_id, :num_tokens].copy_(v)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
def write_to_decode_buffer(
|
||||
self,
|
||||
layer_id: int,
|
||||
pos_in_block: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Write KV tensors to decode buffer (D2D copy within GPU).
|
||||
|
||||
This is called during chunked decode to store current decode token's KV.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
pos_in_block: Position within the current block
|
||||
k: Key tensor [kv_heads, head_dim] (single token, squeezed)
|
||||
v: Value tensor [kv_heads, head_dim] (single token, squeezed)
|
||||
"""
|
||||
torch.cuda.nvtx.range_push(f"D2D: L{layer_id} Pos{pos_in_block} WriteDecodeBuffer")
|
||||
self.decode_k_buffer[layer_id, pos_in_block].copy_(k)
|
||||
self.decode_v_buffer[layer_id, pos_in_block].copy_(v)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
def offload_prefill_buffer_async(
|
||||
self,
|
||||
layer_id: int,
|
||||
|
||||
@@ -139,7 +139,8 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
slot = load_slots[0]
|
||||
for block_idx in range(num_blocks):
|
||||
cpu_block_id = cpu_block_table[block_idx]
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
||||
# cpu_block_id is the chunk index (block N = chunk N)
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||
offload_engine.wait_slot_layer(slot)
|
||||
|
||||
with torch.cuda.stream(compute_stream):
|
||||
@@ -159,7 +160,8 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
num_slots = len(load_slots)
|
||||
num_preload = min(num_slots, num_blocks)
|
||||
for i in range(num_preload):
|
||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
|
||||
cpu_block_id = cpu_block_table[i]
|
||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||
|
||||
for block_idx in range(num_blocks):
|
||||
current_slot = load_slots[block_idx % num_slots]
|
||||
@@ -186,7 +188,7 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
if next_block_idx < num_blocks:
|
||||
next_slot = load_slots[next_block_idx % num_slots]
|
||||
next_cpu_block_id = cpu_block_table[next_block_idx]
|
||||
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
|
||||
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
|
||||
|
||||
# Step 4: Compute attention to current chunk (causal mask)
|
||||
with torch.cuda.stream(compute_stream):
|
||||
@@ -350,7 +352,8 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
# Phase 1: Pre-load up to num_slots blocks
|
||||
num_preload = min(num_slots, num_blocks)
|
||||
for i in range(num_preload):
|
||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
|
||||
cpu_block_id = cpu_block_table[i]
|
||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||
|
||||
# Phase 2: Process blocks with pipeline
|
||||
for block_idx in range(num_blocks):
|
||||
@@ -383,7 +386,8 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
# Start loading next block (pipeline)
|
||||
next_block_idx = block_idx + num_slots
|
||||
if next_block_idx < num_blocks:
|
||||
offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])
|
||||
next_cpu_block_id = cpu_block_table[next_block_idx]
|
||||
offload_engine.load_to_slot_layer(current_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
|
||||
|
||||
# Merge with accumulated
|
||||
with torch.cuda.stream(compute_stream):
|
||||
|
||||
@@ -189,8 +189,8 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128
|
||||
|
||||
for cpu_block_id in available_blocks:
|
||||
# Load K block from CPU to GPU
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
||||
# Load K block from CPU to GPU (cpu_block_id is chunk index)
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||
offload_engine.wait_slot_layer(slot)
|
||||
|
||||
# Get KV: [1, block_size, num_kv_heads, head_dim]
|
||||
@@ -382,7 +382,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
slot = load_slots[0]
|
||||
for block_idx in range(num_blocks):
|
||||
cpu_block_id = cpu_block_table[block_idx]
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
|
||||
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||
offload_engine.wait_slot_layer(slot)
|
||||
|
||||
with torch.cuda.stream(compute_stream):
|
||||
@@ -402,7 +402,8 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
num_slots = len(load_slots)
|
||||
num_preload = min(num_slots, num_blocks)
|
||||
for i in range(num_preload):
|
||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
|
||||
cpu_block_id = cpu_block_table[i]
|
||||
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
|
||||
|
||||
for block_idx in range(num_blocks):
|
||||
current_slot = load_slots[block_idx % num_slots]
|
||||
@@ -428,7 +429,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
if next_block_idx < num_blocks:
|
||||
next_slot = load_slots[next_block_idx % num_slots]
|
||||
next_cpu_block_id = cpu_block_table[next_block_idx]
|
||||
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)
|
||||
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
|
||||
|
||||
# Compute attention to current chunk (causal mask)
|
||||
with torch.cuda.stream(compute_stream):
|
||||
|
||||
@@ -104,27 +104,21 @@ class Attention(nn.Module):
|
||||
# This enables fully async offloads since each layer has its own buffer.
|
||||
offload_engine = context.kvcache_manager.offload_engine
|
||||
compute_stream = offload_engine.compute_stream
|
||||
chunk_idx = context.current_chunk_idx if hasattr(context, 'current_chunk_idx') else -1
|
||||
|
||||
# Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||
|
||||
with torch.cuda.stream(compute_stream):
|
||||
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
|
||||
# Write KV to per-layer prefill buffer via offload_engine
|
||||
# k, v shape: [num_tokens, kv_heads, head_dim]
|
||||
num_tokens = k.shape[0]
|
||||
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
|
||||
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
|
||||
#! GPU 2 GPU
|
||||
offload_engine.write_to_prefill_buffer(self.layer_id, k, v, chunk_idx=chunk_idx)
|
||||
elif is_chunked_offload:
|
||||
# Chunked decode mode: use compute_stream for store_kvcache
|
||||
# This ensures proper synchronization with per-layer offload
|
||||
compute_stream = context.kvcache_manager.offload_engine.compute_stream
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||
# slot_mapping is created with non_blocking=True on default stream, but we use it
|
||||
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
|
||||
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||
with torch.cuda.stream(compute_stream):
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
# Chunked decode mode: write KV to per-layer decode buffer via offload_engine
|
||||
# KV will be written to decode buffer in the decode branch below
|
||||
# No store_kvcache needed - all KV management goes through offload_engine
|
||||
pass
|
||||
else:
|
||||
# Normal mode: store on default stream
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
@@ -155,8 +149,7 @@ class Attention(nn.Module):
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
pos_in_block = context.decode_pos_in_block
|
||||
# k, v shape: [1, kv_heads, head_dim]
|
||||
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
|
||||
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
|
||||
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
|
||||
o = self._chunked_decode_attention(q, k, v, context)
|
||||
else:
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||
|
||||
Reference in New Issue
Block a user