Merge perf_opt-1 and perf_opt-2 branches

Combines two performance optimization features:
- perf_opt-1: Cross-layer pipeline for decode (double-buffered layer cache)
- perf_opt-2: Per-layer prefill buffer for async offload

Both features are complementary and improve CPU offload performance.
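
For orientation, here is a minimal sketch of the per-layer prefill staging area that perf_opt-2 introduces. The `prefill_k_buffer`/`prefill_v_buffer` names and the `[num_layers, tokens, kv_heads, head_dim]` layout are taken from the diff below; the class name, `max_chunk_tokens`, and the per-layer offload streams are illustrative assumptions, not the actual nanovllm code.

```python
import torch

class OffloadEngineSketch:
    """Illustrative only: per-layer staging buffers for async KV offload."""

    def __init__(self, num_layers: int, max_chunk_tokens: int,
                 kv_heads: int, head_dim: int,
                 dtype=torch.float16, device="cuda"):
        # One staging buffer per layer, so layer N+1 never waits for
        # layer N's device-to-host copy before writing its own KV.
        shape = (num_layers, max_chunk_tokens, kv_heads, head_dim)
        self.prefill_k_buffer = torch.empty(shape, dtype=dtype, device=device)
        self.prefill_v_buffer = torch.empty(shape, dtype=dtype, device=device)
        # Separate streams: one for attention math, one per layer for offload.
        self.compute_stream = torch.cuda.Stream()
        self.offload_streams = [torch.cuda.Stream() for _ in range(num_layers)]
```

The extra layer dimension is what makes the offload fully asynchronous: each layer's device-to-host copy can still be in flight while later layers write their own buffers.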

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Zijie Tian
Date: 2026-01-07 06:03:44 +08:00
4 changed files with 175 additions and 68 deletions


@@ -99,8 +99,23 @@ class Attention(nn.Module):
# torch.cuda.synchronize()
#! =======================================================
if is_chunked_offload:
# Chunked offload mode: use compute_stream for store_kvcache
if is_chunked_offload and context.is_prefill:
# Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
# This enables fully async offloads since each layer has its own buffer.
offload_engine = context.kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
# Wait for default stream to ensure slot_mapping tensor transfer is complete
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
# k, v shape: [num_tokens, kv_heads, head_dim]
num_tokens = k.shape[0]
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
elif is_chunked_offload:
# Chunked decode mode: use compute_stream for store_kvcache
# This ensures proper synchronization with per-layer offload
compute_stream = context.kvcache_manager.offload_engine.compute_stream
if k_cache.numel() and v_cache.numel():
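
Note: the `compute_stream.wait_stream(torch.cuda.default_stream())` call above is the usual PyTorch idiom for handing a tensor produced on one stream to a consumer on another. A generic sketch of that pattern (not the project's code; `copy_on_side_stream` is a made-up helper):

```python
import torch

def copy_on_side_stream(src: torch.Tensor, dst: torch.Tensor,
                        side: torch.cuda.Stream) -> None:
    # k/v above play the role of `src`: they were produced on the default
    # stream, so the side stream must order itself after that work first.
    side.wait_stream(torch.cuda.default_stream())
    with torch.cuda.stream(side):
        dst.copy_(src, non_blocking=True)
    # Tell the caching allocator that `src` is still in use on `side`, so its
    # memory is not handed back to the default stream before the copy finishes.
    src.record_stream(side)
```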
@@ -157,36 +172,36 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with unified ring buffer for chunked prefill.
Compute attention with per-layer prefill buffer for async offload.
Ring buffer design:
- Current chunk's KV is written to ring_slot[chunk_idx % N]
- Previous chunks' KV are loaded from CPU using N-1 available slots
- Pipeline: pre-fill slots, then process with overlapped load/compute
Optimized design:
- Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
- Previous chunks' KV are loaded from CPU using GPU slots
- Each layer offloads from its own buffer - no waiting required!
For each layer:
1. Current chunk's KV is in k_batched, v_batched (just written by model)
1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
2. Load previous chunks from CPU using available slots (pipeline)
3. Compute attention against previous KV (no causal mask)
4. Compute attention against current KV (causal)
4. Compute attention against current KV from prefill buffer (causal)
5. Merge all results using online softmax
6. Async offload prefill buffer to CPU (no waiting!)
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
current_chunk_idx = context.current_chunk_idx
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
# q, k, v shape: [total_tokens, num_heads, head_dim]
# Reshape for flash attention: [batch, seq, heads, dim]
# q shape: [total_tokens, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
num_tokens = k.shape[0]
o_acc = None
lse_acc = None
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks from previous chunks)
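
Note on step 5 of the docstring above ("merge all results using online softmax"): the merge only needs each partial result's output and its log-sum-exp. A hedged sketch of such a merge, assuming flash-attn-style shapes; the actual `merge_attention_outputs` in `nanovllm.kvcache.chunked_attention` may differ:

```python
import torch

def merge_partial_attention(o1: torch.Tensor, lse1: torch.Tensor,
                            o2: torch.Tensor, lse2: torch.Tensor):
    """Merge attention computed over two disjoint key sets.

    o1, o2:     [batch, seqlen_q, num_heads, head_dim]
    lse1, lse2: [batch, num_heads, seqlen_q] (log-sum-exp of the logits)
    """
    lse = torch.logaddexp(lse1, lse2)                         # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [B, Sq, H, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse
```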
@@ -210,11 +225,8 @@ class Attention(nn.Module):
)
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
# Get write slot for current chunk and available load slots
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
# Get available load slots (all slots can be used since we use prefill buffer)
load_slots = list(range(offload_engine.num_ring_slots))
pipeline_depth = len(load_slots)
if pipeline_depth == 0:
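
Note: the "pipeline" referred to here is an overlap of host-to-device loads with attention compute, cycling through the ring slots. An illustrative loop under assumed names (`cpu_blocks`, `gpu_slots`, and `attend` are placeholders, not the engine's API):

```python
import torch

def pipelined_prev_chunk_attention(q, cpu_blocks, gpu_slots,
                                   load_stream, compute_stream, attend):
    """cpu_blocks: [(k_cpu, v_cpu), ...] pinned host tensors, one per earlier chunk.
    gpu_slots:  [(k_gpu, v_gpu), ...] reusable device buffers (the ring slots).
    attend(q, k, v) -> (o, lse): partial attention over one chunk's KV."""
    results = []
    slot_free = [torch.cuda.Event() for _ in gpu_slots]
    for i, (k_cpu, v_cpu) in enumerate(cpu_blocks):
        slot = i % len(gpu_slots)
        k_gpu, v_gpu = gpu_slots[slot]
        with torch.cuda.stream(load_stream):
            if i >= len(gpu_slots):
                # Don't overwrite a slot until its previous reader is done.
                load_stream.wait_event(slot_free[slot])
            k_gpu.copy_(k_cpu, non_blocking=True)
            v_gpu.copy_(v_cpu, non_blocking=True)
        # Compute for chunk i starts once its load is done, while the load
        # for chunk i+1 can already proceed on load_stream.
        compute_stream.wait_stream(load_stream)
        with torch.cuda.stream(compute_stream):
            results.append(attend(q, k_gpu, v_gpu))
            slot_free[slot].record(compute_stream)
    return results
```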
@@ -230,15 +242,14 @@ class Attention(nn.Module):
)
# Get compute stream for all attention operations
compute_stream = None
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
compute_stream = kvcache_manager.offload_engine.compute_stream
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
# Compute attention against current chunk's KV (with causal mask)
# Use compute_stream to ensure proper sync with store_kvcache and offload
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
# Get KV from per-layer prefill buffer
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
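
Note: `get_prefill_buffer_slice` itself is not part of this diff; from the write path in the first hunk and the `unsqueeze(0)` fallback in the next hunk, it presumably does something along these lines (an assumption, not the actual implementation):

```python
import torch

def get_prefill_buffer_slice(prefill_k_buffer: torch.Tensor,
                             prefill_v_buffer: torch.Tensor,
                             layer_id: int, num_tokens: int):
    # Return the current chunk's KV for one layer as batched tensors of shape
    # [1, num_tokens, kv_heads, head_dim], ready for flash attention, with no
    # slot_mapping gather needed.
    k = prefill_k_buffer[layer_id, :num_tokens].unsqueeze(0)
    v = prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
    return k, v
```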
@@ -249,6 +260,8 @@ class Attention(nn.Module):
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
@@ -274,26 +287,16 @@ class Attention(nn.Module):
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Per-layer offload: In new GPU cache architecture (no layer dimension),
# each layer must offload its KV to CPU before next layer overwrites the GPU slot.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
if seq is not None:
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
if current_chunk_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[current_chunk_idx]
# k.shape[0] = number of tokens in current chunk
num_valid_tokens = k.shape[0]
offload_engine.offload_slot_layer_to_cpu(
write_slot, self.layer_id, cpu_block_id, num_valid_tokens
)
# CRITICAL: compute_stream must wait for offload to complete
# before the next layer's store_kvcache can overwrite the GPU slot.
# Without this, Layer N+1's store races with Layer N's offload copy.
compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot])
# Per-layer ASYNC offload: offload prefill buffer to CPU
# No waiting required! Each layer has its own buffer and stream.
if offload_engine is not None and seq is not None:
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
if current_chunk_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[current_chunk_idx]
# Async offload - no waiting, fully parallel across layers
offload_engine.offload_prefill_buffer_async(
self.layer_id, cpu_block_id, num_tokens
)
# Sync default stream with compute_stream before returning
# This ensures the result is ready for the rest of the model (layernorm, MLP)
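
Note: `offload_prefill_buffer_async` is also not shown in this diff. Continuing the sketch class from the top of the page, one plausible shape for it; the CPU cache layout, the per-layer streams, and the `prefill_offload_done` events are all assumptions:

```python
import torch

def offload_prefill_buffer_async(engine, layer_id: int, cpu_block_id: int,
                                 num_tokens: int) -> None:
    # Assumed: engine.cpu_k_cache / engine.cpu_v_cache are pinned host tensors
    # shaped [num_blocks, num_layers, block_tokens, kv_heads, head_dim].
    stream = engine.offload_streams[layer_id]
    # Order the D2H copy after the compute-stream work that filled the buffer.
    stream.wait_stream(engine.compute_stream)
    with torch.cuda.stream(stream):
        engine.cpu_k_cache[cpu_block_id, layer_id, :num_tokens].copy_(
            engine.prefill_k_buffer[layer_id, :num_tokens], non_blocking=True)
        engine.cpu_v_cache[cpu_block_id, layer_id, :num_tokens].copy_(
            engine.prefill_v_buffer[layer_id, :num_tokens], non_blocking=True)
        # Record a per-layer event; the next chunk's write into this layer's
        # buffer can fence on it one full forward pass later, instead of the
        # old pattern of waiting right after every layer's offload.
        engine.prefill_offload_done[layer_id].record(stream)
```

The final synchronization the trailing comments describe is presumably a single `torch.cuda.default_stream().wait_stream(compute_stream)`, so that layernorm and the MLP, which run on the default stream, see a finished attention output.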