Merge perf_opt-1 and perf_opt-2 branches
Combines two performance optimization features: - perf_opt-1: Cross-layer pipeline for decode (double-buffered layer cache) - perf_opt-2: Per-layer prefill buffer for async offload Both features are complementary and improve CPU offload performance. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -77,6 +77,12 @@ PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py>
|
||||
|
||||
**Note**: The Python version in the path (python3.10) should match your environment.
|
||||
|
||||
**CRITICAL**: After making code changes to `nanovllm/` source files, you MUST reinstall the package for changes to take effect:
|
||||
```bash
|
||||
pip install -e . --prefix=./.local --no-deps
|
||||
```
|
||||
Without reinstallation, Python will use the old cached version and your changes will NOT be reflected!
|
||||
|
||||
## Sparse Attention
|
||||
|
||||
For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).
|
||||
|
||||
@@ -455,8 +455,6 @@ class ModelRunner:
|
||||
3. After each chunk, offload from ring buffer slot to CPU
|
||||
4. All N-1 other slots are used to load previous chunks for attention
|
||||
"""
|
||||
import sys
|
||||
|
||||
assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
|
||||
seq = seqs[0]
|
||||
|
||||
@@ -466,10 +464,9 @@ class ModelRunner:
|
||||
|
||||
total_tokens = len(seq)
|
||||
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
|
||||
print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
|
||||
logger.debug(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
|
||||
f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
|
||||
f"total_chunks={num_chunks}",
|
||||
file=sys.stderr)
|
||||
f"total_chunks={num_chunks}")
|
||||
|
||||
chunk_idx = 0
|
||||
logits = None
|
||||
@@ -488,9 +485,8 @@ class ModelRunner:
|
||||
# CPU block index for this chunk
|
||||
block_idx = chunk_idx
|
||||
|
||||
print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
|
||||
f"write_slot={write_slot}",
|
||||
file=sys.stderr)
|
||||
logger.debug(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
|
||||
f"write_slot={write_slot}")
|
||||
|
||||
# Prepare inputs
|
||||
input_ids, positions = self._prepare_chunked_offload_chunk(
|
||||
@@ -509,27 +505,17 @@ class ModelRunner:
|
||||
logical_id = seq.block_table[block_idx]
|
||||
self.kvcache_manager.prefilled_blocks.add(logical_id)
|
||||
|
||||
# NOTE: Per-layer offloading is now done in attention.forward
|
||||
# Each layer offloads its KV to CPU immediately after computing attention.
|
||||
# We just need to wait for the last offload to complete before reusing the slot.
|
||||
if block_idx < len(cpu_block_ids):
|
||||
# TODO: Sparse policy hook needs update for new GPU cache architecture
|
||||
# The GPU cache no longer has layer dimension, so we can't access
|
||||
# k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
|
||||
# in attention.forward after per-layer offload.
|
||||
pass
|
||||
|
||||
# Wait for offload to complete before next chunk
|
||||
# (slot will be reused after N chunks)
|
||||
offload_engine.wait_slot_offload(write_slot)
|
||||
# NOTE: Per-layer async offloading is now done in attention.forward
|
||||
# Each layer offloads from its own prefill buffer - no waiting required!
|
||||
# The sparse policy hook is called in offload_prefill_buffer_async.
|
||||
|
||||
processed_tokens = chunk_end
|
||||
chunk_idx += 1
|
||||
|
||||
# Wait for all offloads to complete
|
||||
offload_engine.wait_all_offload_done()
|
||||
# Wait for all async prefill offloads to complete
|
||||
offload_engine.wait_all_prefill_offloads()
|
||||
|
||||
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
|
||||
logger.debug(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks")
|
||||
|
||||
# Sample from last logits
|
||||
# For chunked prefill, ParallelLMHead automatically selects last position's logits
|
||||
|
||||
@@ -176,6 +176,30 @@ class OffloadEngine:
|
||||
self._pipeline_num_blocks = 0
|
||||
self._pipeline_layer_stream = torch.cuda.Stream() # Dedicated stream for layer loading
|
||||
|
||||
# ========== Per-layer prefill buffer for async offload ==========
|
||||
# During chunked prefill, all layers share the same GPU slot. This means
|
||||
# each layer must wait for offload to complete before the next layer can
|
||||
# write to the same slot. This serializes offloads and hurts performance.
|
||||
# Solution: Maintain separate per-layer buffers for prefill.
|
||||
# Each layer writes to its own buffer, enabling fully async offloads.
|
||||
# Shape: [num_layers, block_size, kv_heads, head_dim]
|
||||
self.prefill_k_buffer = torch.zeros(
|
||||
num_layers, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
self.prefill_v_buffer = torch.zeros(
|
||||
num_layers, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
prefill_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
|
||||
logger.info(f" Per-layer prefill buffer: {prefill_buf_mb:.1f} MB")
|
||||
|
||||
# Per-layer offload events for async prefill offload
|
||||
# Each layer has its own event to track offload completion
|
||||
self.prefill_offload_events = [torch.cuda.Event() for _ in range(num_layers)]
|
||||
# Per-layer transfer streams for parallel offloads
|
||||
self.prefill_offload_streams = [torch.cuda.Stream() for _ in range(num_layers)]
|
||||
|
||||
# ========== Fixed-address CPU KV cache (pinned memory) ==========
|
||||
self.k_cache_cpu = torch.zeros(
|
||||
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
@@ -1214,3 +1238,91 @@ class OffloadEngine:
|
||||
def is_pipeline_active(self) -> bool:
|
||||
"""Check if decode pipeline is currently active."""
|
||||
return self._pipeline_active
|
||||
|
||||
# ========== Per-layer Prefill Buffer Methods ==========
|
||||
# These methods enable async offload during chunked prefill by using
|
||||
# per-layer buffers instead of shared GPU slots.
|
||||
|
||||
def get_prefill_buffer(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get prefill buffer for a layer.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
|
||||
Returns:
|
||||
(k_buffer, v_buffer), shape: [block_size, kv_heads, head_dim]
|
||||
"""
|
||||
return self.prefill_k_buffer[layer_id], self.prefill_v_buffer[layer_id]
|
||||
|
||||
def get_prefill_buffer_slice(
|
||||
self,
|
||||
layer_id: int,
|
||||
num_tokens: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get a slice of prefill buffer for attention computation.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
num_tokens: Number of valid tokens in current chunk
|
||||
|
||||
Returns:
|
||||
(k, v) with shape [1, num_tokens, kv_heads, head_dim]
|
||||
"""
|
||||
k = self.prefill_k_buffer[layer_id, :num_tokens].unsqueeze(0)
|
||||
v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
|
||||
return k, v
|
||||
|
||||
def offload_prefill_buffer_async(
|
||||
self,
|
||||
layer_id: int,
|
||||
cpu_block_id: int,
|
||||
num_valid_tokens: int = -1,
|
||||
) -> None:
|
||||
"""
|
||||
Async offload prefill buffer to CPU (no waiting required).
|
||||
|
||||
This uses per-layer streams and events to enable fully async offloads.
|
||||
Each layer can offload independently without blocking other layers.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
cpu_block_id: Target CPU block ID
|
||||
num_valid_tokens: Number of valid tokens (-1 = use block_size)
|
||||
"""
|
||||
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
|
||||
|
||||
# Collect sparse policy metadata before offload
|
||||
if self.sparse_policy is not None:
|
||||
k_cache = self.prefill_k_buffer[layer_id]
|
||||
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||
|
||||
# Use per-layer stream for parallel offloads
|
||||
stream = self.prefill_offload_streams[layer_id]
|
||||
|
||||
torch.cuda.nvtx.range_push(f"AsyncPrefillOffload: L{layer_id}->CPU[{cpu_block_id}]")
|
||||
with torch.cuda.stream(stream):
|
||||
# Wait for compute to finish writing to prefill buffer
|
||||
stream.wait_stream(self.compute_stream)
|
||||
|
||||
# Copy from prefill buffer to CPU
|
||||
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.prefill_k_buffer[layer_id], non_blocking=True
|
||||
)
|
||||
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.prefill_v_buffer[layer_id], non_blocking=True
|
||||
)
|
||||
|
||||
# Record completion event
|
||||
self.prefill_offload_events[layer_id].record(stream)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
def wait_all_prefill_offloads(self) -> None:
|
||||
"""Wait for all prefill buffer offloads to complete."""
|
||||
for stream in self.prefill_offload_streams:
|
||||
stream.synchronize()
|
||||
|
||||
def wait_prefill_offload(self, layer_id: int) -> None:
|
||||
"""Wait for a specific layer's prefill offload to complete."""
|
||||
self.prefill_offload_events[layer_id].synchronize()
|
||||
|
||||
@@ -99,8 +99,23 @@ class Attention(nn.Module):
|
||||
# torch.cuda.synchronize()
|
||||
#! =======================================================
|
||||
|
||||
if is_chunked_offload:
|
||||
# Chunked offload mode: use compute_stream for store_kvcache
|
||||
if is_chunked_offload and context.is_prefill:
|
||||
# Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
|
||||
# This enables fully async offloads since each layer has its own buffer.
|
||||
offload_engine = context.kvcache_manager.offload_engine
|
||||
compute_stream = offload_engine.compute_stream
|
||||
|
||||
# Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||
|
||||
with torch.cuda.stream(compute_stream):
|
||||
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
|
||||
# k, v shape: [num_tokens, kv_heads, head_dim]
|
||||
num_tokens = k.shape[0]
|
||||
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
|
||||
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
|
||||
elif is_chunked_offload:
|
||||
# Chunked decode mode: use compute_stream for store_kvcache
|
||||
# This ensures proper synchronization with per-layer offload
|
||||
compute_stream = context.kvcache_manager.offload_engine.compute_stream
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
@@ -157,36 +172,36 @@ class Attention(nn.Module):
|
||||
context,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute attention with unified ring buffer for chunked prefill.
|
||||
Compute attention with per-layer prefill buffer for async offload.
|
||||
|
||||
Ring buffer design:
|
||||
- Current chunk's KV is written to ring_slot[chunk_idx % N]
|
||||
- Previous chunks' KV are loaded from CPU using N-1 available slots
|
||||
- Pipeline: pre-fill slots, then process with overlapped load/compute
|
||||
Optimized design:
|
||||
- Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
|
||||
- Previous chunks' KV are loaded from CPU using GPU slots
|
||||
- Each layer offloads from its own buffer - no waiting required!
|
||||
|
||||
For each layer:
|
||||
1. Current chunk's KV is in k_batched, v_batched (just written by model)
|
||||
1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
|
||||
2. Load previous chunks from CPU using available slots (pipeline)
|
||||
3. Compute attention against previous KV (no causal mask)
|
||||
4. Compute attention against current KV (causal)
|
||||
4. Compute attention against current KV from prefill buffer (causal)
|
||||
5. Merge all results using online softmax
|
||||
6. Async offload prefill buffer to CPU (no waiting!)
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
current_chunk_idx = context.current_chunk_idx
|
||||
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
|
||||
|
||||
# q, k, v shape: [total_tokens, num_heads, head_dim]
|
||||
# Reshape for flash attention: [batch, seq, heads, dim]
|
||||
# q shape: [total_tokens, num_heads, head_dim]
|
||||
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
|
||||
k_batched = k.unsqueeze(0)
|
||||
v_batched = v.unsqueeze(0)
|
||||
num_tokens = k.shape[0]
|
||||
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
|
||||
kvcache_manager = context.kvcache_manager
|
||||
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
||||
offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
|
||||
|
||||
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
|
||||
# Get prefilled CPU blocks (blocks from previous chunks)
|
||||
@@ -210,11 +225,8 @@ class Attention(nn.Module):
|
||||
)
|
||||
|
||||
if cpu_block_table:
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
|
||||
# Get write slot for current chunk and available load slots
|
||||
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
|
||||
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
|
||||
# Get available load slots (all slots can be used since we use prefill buffer)
|
||||
load_slots = list(range(offload_engine.num_ring_slots))
|
||||
pipeline_depth = len(load_slots)
|
||||
|
||||
if pipeline_depth == 0:
|
||||
@@ -230,15 +242,14 @@ class Attention(nn.Module):
|
||||
)
|
||||
|
||||
# Get compute stream for all attention operations
|
||||
compute_stream = None
|
||||
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
|
||||
compute_stream = kvcache_manager.offload_engine.compute_stream
|
||||
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
|
||||
|
||||
# Compute attention against current chunk's KV (with causal mask)
|
||||
# Use compute_stream to ensure proper sync with store_kvcache and offload
|
||||
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
|
||||
if compute_stream is not None:
|
||||
with torch.cuda.stream(compute_stream):
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||
# Get KV from per-layer prefill buffer
|
||||
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_batched,
|
||||
@@ -249,6 +260,8 @@ class Attention(nn.Module):
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||
k_batched = k.unsqueeze(0)
|
||||
v_batched = v.unsqueeze(0)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_batched,
|
||||
@@ -274,26 +287,16 @@ class Attention(nn.Module):
|
||||
|
||||
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||
|
||||
# Per-layer offload: In new GPU cache architecture (no layer dimension),
|
||||
# each layer must offload its KV to CPU before next layer overwrites the GPU slot.
|
||||
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
|
||||
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
||||
if seq is not None:
|
||||
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||
if current_chunk_idx < len(cpu_block_ids):
|
||||
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
||||
# k.shape[0] = number of tokens in current chunk
|
||||
num_valid_tokens = k.shape[0]
|
||||
offload_engine.offload_slot_layer_to_cpu(
|
||||
write_slot, self.layer_id, cpu_block_id, num_valid_tokens
|
||||
)
|
||||
|
||||
# CRITICAL: compute_stream must wait for offload to complete
|
||||
# before the next layer's store_kvcache can overwrite the GPU slot.
|
||||
# Without this, Layer N+1's store races with Layer N's offload copy.
|
||||
compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot])
|
||||
# Per-layer ASYNC offload: offload prefill buffer to CPU
|
||||
# No waiting required! Each layer has its own buffer and stream.
|
||||
if offload_engine is not None and seq is not None:
|
||||
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||
if current_chunk_idx < len(cpu_block_ids):
|
||||
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
||||
# Async offload - no waiting, fully parallel across layers
|
||||
offload_engine.offload_prefill_buffer_async(
|
||||
self.layer_id, cpu_block_id, num_tokens
|
||||
)
|
||||
|
||||
# Sync default stream with compute_stream before returning
|
||||
# This ensures the result is ready for the rest of the model (layernorm, MLP)
|
||||
|
||||
Reference in New Issue
Block a user