Merge perf_opt-1 and perf_opt-2 branches
Combines two performance optimization features:

- perf_opt-1: Cross-layer pipeline for decode (double-buffered layer cache)
- perf_opt-2: Per-layer prefill buffer for async offload

Both features are complementary and improve CPU offload performance.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@@ -99,8 +99,23 @@ class Attention(nn.Module):
         # torch.cuda.synchronize()
         #! =======================================================
 
-        if is_chunked_offload:
-            # Chunked offload mode: use compute_stream for store_kvcache
+        if is_chunked_offload and context.is_prefill:
+            # Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
+            # This enables fully async offloads since each layer has its own buffer.
+            offload_engine = context.kvcache_manager.offload_engine
+            compute_stream = offload_engine.compute_stream
+
+            # Wait for default stream to ensure slot_mapping tensor transfer is complete
+            compute_stream.wait_stream(torch.cuda.default_stream())
+
+            with torch.cuda.stream(compute_stream):
+                # Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
+                # k, v shape: [num_tokens, kv_heads, head_dim]
+                num_tokens = k.shape[0]
+                offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
+                offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
+        elif is_chunked_offload:
+            # Chunked decode mode: use compute_stream for store_kvcache
+            # This ensures proper synchronization with per-layer offload
             compute_stream = context.kvcache_manager.offload_engine.compute_stream
         if k_cache.numel() and v_cache.numel():
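For reference, prefill_k_buffer and prefill_v_buffer are used here as preallocated per-layer GPU tensors written contiguously, which is why no slot_mapping scatter is needed. A minimal sketch of what that allocation could look like; the class name, constructor arguments, and shapes are assumptions, only the two buffer attribute names and compute_stream come from the diff:

import torch

class OffloadEngineSketch:
    """Hypothetical layout of the per-layer prefill buffers assumed by the hunk above."""
    def __init__(self, num_layers, max_chunk_tokens, num_kv_heads, head_dim,
                 dtype=torch.float16, device="cuda"):
        self.compute_stream = torch.cuda.Stream(device=device)
        # One contiguous slice per layer: layer L writes tokens [0:num_tokens] of
        # prefill_*_buffer[L], so no per-token slot_mapping scatter is required.
        self.prefill_k_buffer = torch.empty(
            num_layers, max_chunk_tokens, num_kv_heads, head_dim,
            dtype=dtype, device=device)
        self.prefill_v_buffer = torch.empty_like(self.prefill_k_buffer)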
@@ -157,36 +172,36 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute attention with unified ring buffer for chunked prefill.
+        Compute attention with per-layer prefill buffer for async offload.
 
-        Ring buffer design:
-        - Current chunk's KV is written to ring_slot[chunk_idx % N]
-        - Previous chunks' KV are loaded from CPU using N-1 available slots
-        - Pipeline: pre-fill slots, then process with overlapped load/compute
+        Optimized design:
+        - Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
+        - Previous chunks' KV are loaded from CPU using GPU slots
+        - Each layer offloads from its own buffer - no waiting required!
 
         For each layer:
-        1. Current chunk's KV is in k_batched, v_batched (just written by model)
+        1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
         2. Load previous chunks from CPU using available slots (pipeline)
         3. Compute attention against previous KV (no causal mask)
-        4. Compute attention against current KV (causal)
+        4. Compute attention against current KV from prefill buffer (causal)
         5. Merge all results using online softmax
+        6. Async offload prefill buffer to CPU (no waiting!)
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
         current_chunk_idx = context.current_chunk_idx
         torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
 
-        # q, k, v shape: [total_tokens, num_heads, head_dim]
-        # Reshape for flash attention: [batch, seq, heads, dim]
+        # q shape: [total_tokens, num_heads, head_dim]
         q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
-        k_batched = k.unsqueeze(0)
-        v_batched = v.unsqueeze(0)
+        num_tokens = k.shape[0]
 
         o_acc = None
         lse_acc = None
 
         kvcache_manager = context.kvcache_manager
         seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
+        offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
 
         if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
            # Get prefilled CPU blocks (blocks from previous chunks)
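Step 5 of the docstring merges the per-chunk partial results with their log-sum-exp values. The repository's merge_attention_outputs is not shown in this diff; a minimal sketch of the merge rule it presumably applies, assuming the flash-attn style layout (o: [batch, seq, heads, dim], lse: [batch, heads, seq]):

import torch

def merge_with_lse(o1: torch.Tensor, lse1: torch.Tensor,
                   o2: torch.Tensor, lse2: torch.Tensor):
    """Merge two partial attention results computed over disjoint KV chunks."""
    lse = torch.logaddexp(lse1, lse2)                          # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [batch, seq, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse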
@@ -210,11 +225,8 @@ class Attention(nn.Module):
            )
 
            if cpu_block_table:
-                offload_engine = kvcache_manager.offload_engine
-
-                # Get write slot for current chunk and available load slots
-                write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
-                load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
+                # Get available load slots (all slots can be used since we use prefill buffer)
+                load_slots = list(range(offload_engine.num_ring_slots))
                pipeline_depth = len(load_slots)
 
                if pipeline_depth == 0:
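The load_slots list feeds the overlapped load/compute pipeline described in docstring step 2 (the loop itself is outside this diff). A schematic sketch of that scheduling idea; load_fn and attend_fn stand in for the engine's real H2D copy and attention calls, which are assumptions here:

def run_load_compute_pipeline(num_prev_chunks, load_slots, load_fn, attend_fn):
    """Schematic pipeline: keep every ring slot busy, reusing a slot for a later
    chunk only after the chunk currently in that slot has been attended to.
    Assumes load_fn issues an async H2D copy and attend_fn waits on it (the real
    engine would synchronize with CUDA events/streams)."""
    depth = len(load_slots)
    # Pre-issue the first `depth` loads, one per slot.
    for i in range(min(depth, num_prev_chunks)):
        load_fn(chunk_idx=i, slot=load_slots[i])
    for i in range(num_prev_chunks):
        slot = load_slots[i % depth]
        attend_fn(chunk_idx=i, slot=slot)      # compute on the already-loaded slot
        nxt = i + depth
        if nxt < num_prev_chunks:
            load_fn(chunk_idx=nxt, slot=slot)  # refill the slot for a later chunk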
@@ -230,15 +242,14 @@ class Attention(nn.Module):
                )
 
        # Get compute stream for all attention operations
-        compute_stream = None
-        if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
-            compute_stream = kvcache_manager.offload_engine.compute_stream
+        compute_stream = offload_engine.compute_stream if offload_engine is not None else None
 
-        # Compute attention against current chunk's KV (with causal mask)
-        # Use compute_stream to ensure proper sync with store_kvcache and offload
+        # Compute attention against current chunk's KV from prefill buffer (with causal mask)
        if compute_stream is not None:
            with torch.cuda.stream(compute_stream):
                torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
+                # Get KV from per-layer prefill buffer
+                k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
                current_o, current_lse = flash_attn_with_lse(
                    q_batched,
                    k_batched,
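get_prefill_buffer_slice is called here but not defined in this diff. Given the contiguous writes in the first hunk, a plausible sketch (an assumption, not the repository's actual method) simply returns batched views over the layer's slice:

def get_prefill_buffer_slice(self, layer_id: int, num_tokens: int):
    # Hypothetical: [num_tokens, kv_heads, head_dim] -> [1, num_tokens, kv_heads, head_dim]
    # views matching the batched layout flash_attn_with_lse expects; no copy is made.
    k = self.prefill_k_buffer[layer_id, :num_tokens].unsqueeze(0)
    v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
    return k, v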
@@ -249,6 +260,8 @@ class Attention(nn.Module):
                torch.cuda.nvtx.range_pop()
        else:
            torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
+            k_batched = k.unsqueeze(0)
+            v_batched = v.unsqueeze(0)
            current_o, current_lse = flash_attn_with_lse(
                q_batched,
                k_batched,
@@ -274,26 +287,16 @@ class Attention(nn.Module):
 
        torch.cuda.nvtx.range_pop()  # ChunkedPrefill
 
-        # Per-layer offload: In new GPU cache architecture (no layer dimension),
-        # each layer must offload its KV to CPU before next layer overwrites the GPU slot.
-        if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
-            offload_engine = kvcache_manager.offload_engine
-            write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
-            seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
-            if seq is not None:
-                cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
-                if current_chunk_idx < len(cpu_block_ids):
-                    cpu_block_id = cpu_block_ids[current_chunk_idx]
-                    # k.shape[0] = number of tokens in current chunk
-                    num_valid_tokens = k.shape[0]
-                    offload_engine.offload_slot_layer_to_cpu(
-                        write_slot, self.layer_id, cpu_block_id, num_valid_tokens
-                    )
-
-                    # CRITICAL: compute_stream must wait for offload to complete
-                    # before the next layer's store_kvcache can overwrite the GPU slot.
-                    # Without this, Layer N+1's store races with Layer N's offload copy.
-                    compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot])
+        # Per-layer ASYNC offload: offload prefill buffer to CPU
+        # No waiting required! Each layer has its own buffer and stream.
+        if offload_engine is not None and seq is not None:
+            cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
+            if current_chunk_idx < len(cpu_block_ids):
+                cpu_block_id = cpu_block_ids[current_chunk_idx]
+                # Async offload - no waiting, fully parallel across layers
+                offload_engine.offload_prefill_buffer_async(
+                    self.layer_id, cpu_block_id, num_tokens
+                )
 
        # Sync default stream with compute_stream before returning
        # This ensures the result is ready for the rest of the model (layernorm, MLP)
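offload_prefill_buffer_async is the key new call but is also not shown in this diff. A sketch of one way it could work, using one CUDA stream per layer and pinned CPU cache tensors; everything here except the method name, compute_stream, and the prefill buffers is an assumption:

import torch

def offload_prefill_buffer_async(self, layer_id: int, cpu_block_id: int, num_tokens: int):
    """Hypothetical: copy this layer's prefill buffer slice to its pinned CPU block
    on a per-layer stream, so layers offload in parallel without blocking compute."""
    stream = self.per_layer_offload_streams[layer_id]   # assumed: one stream per layer
    stream.wait_stream(self.compute_stream)             # the KV write must land first
    with torch.cuda.stream(stream):
        # cpu_*_cache assumed pinned, shape [num_layers, num_blocks, block_tokens, heads, dim];
        # non_blocking D2H copies require a pinned destination.
        self.cpu_k_cache[layer_id, cpu_block_id, :num_tokens].copy_(
            self.prefill_k_buffer[layer_id, :num_tokens], non_blocking=True)
        self.cpu_v_cache[layer_id, cpu_block_id, :num_tokens].copy_(
            self.prefill_v_buffer[layer_id, :num_tokens], non_blocking=True)
        self.prefill_offload_done[layer_id].record(stream)  # assumed per-layer event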