From 0ad86eb449738516362a3a701c7d0e87766d292e Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 7 Jan 2026 05:58:10 +0800 Subject: [PATCH] [claudesquad] update from 'perf_opt-2' on 07 Jan 26 05:58 CST --- CLAUDE.md | 6 ++ nanovllm/engine/model_runner.py | 34 +++------ nanovllm/kvcache/offload_engine.py | 114 ++++++++++++++++++++++++++++- nanovllm/layers/attention.py | 89 +++++++++++----------- 4 files changed, 175 insertions(+), 68 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index c40c588..bfc5813 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -66,6 +66,12 @@ python bench_offload.py **Note**: The Python version in the path (python3.10) should match your environment. +**CRITICAL**: After making code changes to `nanovllm/` source files, you MUST reinstall the package for changes to take effect: +```bash +pip install -e . --prefix=./.local --no-deps +``` +Without reinstallation, Python will use the old cached version and your changes will NOT be reflected! + ## Sparse Attention For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md). diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 3b463f0..cd3ab77 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -455,8 +455,6 @@ class ModelRunner: 3. After each chunk, offload from ring buffer slot to CPU 4. All N-1 other slots are used to load previous chunks for attention """ - import sys - assert len(seqs) == 1, "Ring buffer prefill only supports single sequence" seq = seqs[0] @@ -466,10 +464,9 @@ class ModelRunner: total_tokens = len(seq) num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk - print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, " + logger.debug(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, " f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, " - f"total_chunks={num_chunks}", - file=sys.stderr) + f"total_chunks={num_chunks}") chunk_idx = 0 logits = None @@ -488,9 +485,8 @@ class ModelRunner: # CPU block index for this chunk block_idx = chunk_idx - print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, " - f"write_slot={write_slot}", - file=sys.stderr) + logger.debug(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, " + f"write_slot={write_slot}") # Prepare inputs input_ids, positions = self._prepare_chunked_offload_chunk( @@ -509,27 +505,17 @@ class ModelRunner: logical_id = seq.block_table[block_idx] self.kvcache_manager.prefilled_blocks.add(logical_id) - # NOTE: Per-layer offloading is now done in attention.forward - # Each layer offloads its KV to CPU immediately after computing attention. - # We just need to wait for the last offload to complete before reusing the slot. - if block_idx < len(cpu_block_ids): - # TODO: Sparse policy hook needs update for new GPU cache architecture - # The GPU cache no longer has layer dimension, so we can't access - # k_cache_gpu[layer_id, write_slot]. Sparse policy should be called - # in attention.forward after per-layer offload. - pass - - # Wait for offload to complete before next chunk - # (slot will be reused after N chunks) - offload_engine.wait_slot_offload(write_slot) + # NOTE: Per-layer async offloading is now done in attention.forward + # Each layer offloads from its own prefill buffer - no waiting required! 
+ # The sparse policy hook is called in offload_prefill_buffer_async. processed_tokens = chunk_end chunk_idx += 1 - # Wait for all offloads to complete - offload_engine.wait_all_offload_done() + # Wait for all async prefill offloads to complete + offload_engine.wait_all_prefill_offloads() - print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr) + logger.debug(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks") # Sample from last logits # For chunked prefill, ParallelLMHead automatically selects last position's logits diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index 5260906..e10ad2a 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -142,6 +142,30 @@ class OffloadEngine: decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024) logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB") + # ========== Per-layer prefill buffer for async offload ========== + # During chunked prefill, all layers share the same GPU slot. This means + # each layer must wait for offload to complete before the next layer can + # write to the same slot. This serializes offloads and hurts performance. + # Solution: Maintain separate per-layer buffers for prefill. + # Each layer writes to its own buffer, enabling fully async offloads. + # Shape: [num_layers, block_size, kv_heads, head_dim] + self.prefill_k_buffer = torch.zeros( + num_layers, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cuda" + ) + self.prefill_v_buffer = torch.zeros( + num_layers, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cuda" + ) + prefill_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024) + logger.info(f" Per-layer prefill buffer: {prefill_buf_mb:.1f} MB") + + # Per-layer offload events for async prefill offload + # Each layer has its own event to track offload completion + self.prefill_offload_events = [torch.cuda.Event() for _ in range(num_layers)] + # Per-layer transfer streams for parallel offloads + self.prefill_offload_streams = [torch.cuda.Stream() for _ in range(num_layers)] + # ========== Fixed-address CPU KV cache (pinned memory) ========== self.k_cache_cpu = torch.zeros( num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim, @@ -1063,4 +1087,92 @@ class OffloadEngine: # Allow pdb quit to propagate if e.__class__.__name__ == 'BdbQuit': raise - logger.warning(f"Debug hook error: {e}") \ No newline at end of file + logger.warning(f"Debug hook error: {e}") + + # ========== Per-layer Prefill Buffer Methods ========== + # These methods enable async offload during chunked prefill by using + # per-layer buffers instead of shared GPU slots. + + def get_prefill_buffer(self, layer_id: int) -> Tuple[Tensor, Tensor]: + """ + Get prefill buffer for a layer. + + Args: + layer_id: Layer index + + Returns: + (k_buffer, v_buffer), shape: [block_size, kv_heads, head_dim] + """ + return self.prefill_k_buffer[layer_id], self.prefill_v_buffer[layer_id] + + def get_prefill_buffer_slice( + self, + layer_id: int, + num_tokens: int, + ) -> Tuple[Tensor, Tensor]: + """ + Get a slice of prefill buffer for attention computation. 
+ + Args: + layer_id: Layer index + num_tokens: Number of valid tokens in current chunk + + Returns: + (k, v) with shape [1, num_tokens, kv_heads, head_dim] + """ + k = self.prefill_k_buffer[layer_id, :num_tokens].unsqueeze(0) + v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0) + return k, v + + def offload_prefill_buffer_async( + self, + layer_id: int, + cpu_block_id: int, + num_valid_tokens: int = -1, + ) -> None: + """ + Async offload prefill buffer to CPU (no waiting required). + + This uses per-layer streams and events to enable fully async offloads. + Each layer can offload independently without blocking other layers. + + Args: + layer_id: Layer index + cpu_block_id: Target CPU block ID + num_valid_tokens: Number of valid tokens (-1 = use block_size) + """ + valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size + + # Collect sparse policy metadata before offload + if self.sparse_policy is not None: + k_cache = self.prefill_k_buffer[layer_id] + self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens) + + # Use per-layer stream for parallel offloads + stream = self.prefill_offload_streams[layer_id] + + torch.cuda.nvtx.range_push(f"AsyncPrefillOffload: L{layer_id}->CPU[{cpu_block_id}]") + with torch.cuda.stream(stream): + # Wait for compute to finish writing to prefill buffer + stream.wait_stream(self.compute_stream) + + # Copy from prefill buffer to CPU + self.k_cache_cpu[layer_id, cpu_block_id].copy_( + self.prefill_k_buffer[layer_id], non_blocking=True + ) + self.v_cache_cpu[layer_id, cpu_block_id].copy_( + self.prefill_v_buffer[layer_id], non_blocking=True + ) + + # Record completion event + self.prefill_offload_events[layer_id].record(stream) + torch.cuda.nvtx.range_pop() + + def wait_all_prefill_offloads(self) -> None: + """Wait for all prefill buffer offloads to complete.""" + for stream in self.prefill_offload_streams: + stream.synchronize() + + def wait_prefill_offload(self, layer_id: int) -> None: + """Wait for a specific layer's prefill offload to complete.""" + self.prefill_offload_events[layer_id].synchronize() \ No newline at end of file diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 197d082..6283eb9 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -99,8 +99,23 @@ class Attention(nn.Module): # torch.cuda.synchronize() #! ======================================================= - if is_chunked_offload: - # Chunked offload mode: use compute_stream for store_kvcache + if is_chunked_offload and context.is_prefill: + # Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot) + # This enables fully async offloads since each layer has its own buffer. 
+ offload_engine = context.kvcache_manager.offload_engine + compute_stream = offload_engine.compute_stream + + # Wait for default stream to ensure slot_mapping tensor transfer is complete + compute_stream.wait_stream(torch.cuda.default_stream()) + + with torch.cuda.stream(compute_stream): + # Write KV to per-layer prefill buffer (contiguous write, no slot_mapping) + # k, v shape: [num_tokens, kv_heads, head_dim] + num_tokens = k.shape[0] + offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k) + offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v) + elif is_chunked_offload: + # Chunked decode mode: use compute_stream for store_kvcache # This ensures proper synchronization with per-layer offload compute_stream = context.kvcache_manager.offload_engine.compute_stream if k_cache.numel() and v_cache.numel(): @@ -157,36 +172,36 @@ class Attention(nn.Module): context, ) -> torch.Tensor: """ - Compute attention with unified ring buffer for chunked prefill. + Compute attention with per-layer prefill buffer for async offload. - Ring buffer design: - - Current chunk's KV is written to ring_slot[chunk_idx % N] - - Previous chunks' KV are loaded from CPU using N-1 available slots - - Pipeline: pre-fill slots, then process with overlapped load/compute + Optimized design: + - Current chunk's KV is written to per-layer prefill buffer (not GPU slot) + - Previous chunks' KV are loaded from CPU using GPU slots + - Each layer offloads from its own buffer - no waiting required! For each layer: - 1. Current chunk's KV is in k_batched, v_batched (just written by model) + 1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model) 2. Load previous chunks from CPU using available slots (pipeline) 3. Compute attention against previous KV (no causal mask) - 4. Compute attention against current KV (causal) + 4. Compute attention against current KV from prefill buffer (causal) 5. Merge all results using online softmax + 6. Async offload prefill buffer to CPU (no waiting!) 
""" from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs current_chunk_idx = context.current_chunk_idx torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}") - # q, k, v shape: [total_tokens, num_heads, head_dim] - # Reshape for flash attention: [batch, seq, heads, dim] + # q shape: [total_tokens, num_heads, head_dim] q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim] - k_batched = k.unsqueeze(0) - v_batched = v.unsqueeze(0) + num_tokens = k.shape[0] o_acc = None lse_acc = None kvcache_manager = context.kvcache_manager seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None + offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None if kvcache_manager is not None and seq is not None and self.layer_id >= 0: # Get prefilled CPU blocks (blocks from previous chunks) @@ -210,11 +225,8 @@ class Attention(nn.Module): ) if cpu_block_table: - offload_engine = kvcache_manager.offload_engine - - # Get write slot for current chunk and available load slots - write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx) - load_slots = offload_engine.get_load_slots_for_prefill(write_slot) + # Get available load slots (all slots can be used since we use prefill buffer) + load_slots = list(range(offload_engine.num_ring_slots)) pipeline_depth = len(load_slots) if pipeline_depth == 0: @@ -230,15 +242,14 @@ class Attention(nn.Module): ) # Get compute stream for all attention operations - compute_stream = None - if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'): - compute_stream = kvcache_manager.offload_engine.compute_stream + compute_stream = offload_engine.compute_stream if offload_engine is not None else None - # Compute attention against current chunk's KV (with causal mask) - # Use compute_stream to ensure proper sync with store_kvcache and offload + # Compute attention against current chunk's KV from prefill buffer (with causal mask) if compute_stream is not None: with torch.cuda.stream(compute_stream): torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)") + # Get KV from per-layer prefill buffer + k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens) current_o, current_lse = flash_attn_with_lse( q_batched, k_batched, @@ -249,6 +260,8 @@ class Attention(nn.Module): torch.cuda.nvtx.range_pop() else: torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)") + k_batched = k.unsqueeze(0) + v_batched = v.unsqueeze(0) current_o, current_lse = flash_attn_with_lse( q_batched, k_batched, @@ -274,26 +287,16 @@ class Attention(nn.Module): torch.cuda.nvtx.range_pop() # ChunkedPrefill - # Per-layer offload: In new GPU cache architecture (no layer dimension), - # each layer must offload its KV to CPU before next layer overwrites the GPU slot. 
- if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'): - offload_engine = kvcache_manager.offload_engine - write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx) - seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None - if seq is not None: - cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq) - if current_chunk_idx < len(cpu_block_ids): - cpu_block_id = cpu_block_ids[current_chunk_idx] - # k.shape[0] = number of tokens in current chunk - num_valid_tokens = k.shape[0] - offload_engine.offload_slot_layer_to_cpu( - write_slot, self.layer_id, cpu_block_id, num_valid_tokens - ) - - # CRITICAL: compute_stream must wait for offload to complete - # before the next layer's store_kvcache can overwrite the GPU slot. - # Without this, Layer N+1's store races with Layer N's offload copy. - compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot]) + # Per-layer ASYNC offload: offload prefill buffer to CPU + # No waiting required! Each layer has its own buffer and stream. + if offload_engine is not None and seq is not None: + cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq) + if current_chunk_idx < len(cpu_block_ids): + cpu_block_id = cpu_block_ids[current_chunk_idx] + # Async offload - no waiting, fully parallel across layers + offload_engine.offload_prefill_buffer_async( + self.layer_id, cpu_block_id, num_tokens + ) # Sync default stream with compute_stream before returning # This ensures the result is ready for the rest of the model (layernorm, MLP)
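The per-layer buffer scheme introduced in `offload_engine.py` and used from `attention.py` is easier to see outside the diff. Below is a minimal, self-contained sketch of the pattern: one GPU staging buffer, one CUDA stream, and one completion event per layer, with `non_blocking` device-to-host copies into pinned CPU memory. All names (`PerLayerOffloader`, `write_chunk`, `offload_async`) are illustrative, not the repo's API, and the `wait_event` before reusing a layer's buffer is a conservative safeguard added only in this sketch to cover cross-chunk buffer reuse.

```python
import torch

class PerLayerOffloader:
    """Illustrative per-layer KV offload: one staging buffer, stream, and event per layer."""

    def __init__(self, num_layers, num_cpu_blocks, block_size, kv_heads, head_dim,
                 dtype=torch.float16):
        gpu_shape = (num_layers, block_size, kv_heads, head_dim)
        cpu_shape = (num_layers, num_cpu_blocks, block_size, kv_heads, head_dim)
        # GPU staging buffers (stand-ins for prefill_k_buffer / prefill_v_buffer).
        self.k_buf = torch.zeros(*gpu_shape, dtype=dtype, device="cuda")
        self.v_buf = torch.zeros(*gpu_shape, dtype=dtype, device="cuda")
        # Pinned CPU cache so non_blocking D2H copies run truly asynchronously.
        self.k_cpu = torch.zeros(*cpu_shape, dtype=dtype, pin_memory=True)
        self.v_cpu = torch.zeros(*cpu_shape, dtype=dtype, pin_memory=True)
        self.streams = [torch.cuda.Stream() for _ in range(num_layers)]
        self.events = [torch.cuda.Event() for _ in range(num_layers)]

    def write_chunk(self, layer_id, k, v, compute_stream):
        # Safeguard (sketch only): don't overwrite this layer's buffer until the
        # previous chunk's offload of it has finished. This is a no-op on the
        # first chunk because the event has never been recorded.
        compute_stream.wait_event(self.events[layer_id])
        with torch.cuda.stream(compute_stream):
            n = k.shape[0]
            self.k_buf[layer_id, :n].copy_(k)
            self.v_buf[layer_id, :n].copy_(v)

    def offload_async(self, layer_id, cpu_block_id, compute_stream):
        s = self.streams[layer_id]
        s.wait_stream(compute_stream)  # buffer must be fully written before the copy
        with torch.cuda.stream(s):
            self.k_cpu[layer_id, cpu_block_id].copy_(self.k_buf[layer_id], non_blocking=True)
            self.v_cpu[layer_id, cpu_block_id].copy_(self.v_buf[layer_id], non_blocking=True)
            self.events[layer_id].record(s)

    def wait_all(self):
        for s in self.streams:
            s.synchronize()
```

Per chunk, each layer writes its KV with `write_chunk`, computes attention, then fires `offload_async`; only after the last chunk does the runner block on `wait_all`, mirroring `wait_all_prefill_offloads` in the patch.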
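The attention docstring above merges per-block results "using online softmax" via `flash_attn_with_lse` and `merge_attention_outputs` from `nanovllm/kvcache/chunked_attention.py`, which this patch does not show. The sketch below is a generic version of that log-sum-exp merge under assumed flash-attn-style shapes; it is not the repo's implementation.

```python
import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    """
    Merge two partial attention results computed over disjoint key blocks for
    the same queries, using their log-sum-exp (LSE) statistics.

    Assumed shapes:
        o:   [batch, seq_q, num_heads, head_dim]
        lse: [batch, num_heads, seq_q]
    """
    lse = torch.logaddexp(lse1, lse2)  # combined softmax normalizer
    # Rescale each partial output by its share of the combined softmax mass.
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [b, seq_q, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse
```

Because this merge is associative, the chunked-prefill path can fold one CPU block at a time into a running pair, which is what the `o_acc` / `lse_acc` accumulation in the diff relies on.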
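The previous-chunk path ("Load previous chunks from CPU using available slots (pipeline)") is only partly visible in this hunk. The sketch below shows the general slot-based prefetch idea under stated assumptions: a dedicated load stream copies pinned CPU blocks into a small set of GPU slots while the compute stream attends to blocks that are already resident. `attend_fn`, the slot tensors, and the event bookkeeping are illustrative, not the repo's API.

```python
import torch

def attend_over_cpu_blocks(q, cpu_k, cpu_v, gpu_k_slots, gpu_v_slots, attend_fn):
    """
    Double-buffered load/compute pipeline over CPU-resident KV blocks.

    Assumed shapes:
        cpu_k, cpu_v:             [num_blocks, block_size, kv_heads, head_dim] (pinned)
        gpu_k_slots, gpu_v_slots: [num_slots,  block_size, kv_heads, head_dim] (CUDA)
    attend_fn(q, k, v) returns (o, lse) for one block; the caller merges results.
    """
    num_blocks, num_slots = cpu_k.shape[0], gpu_k_slots.shape[0]
    load_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()
    ready = [torch.cuda.Event() for _ in range(num_slots)]     # slot holds its block
    reusable = [torch.cuda.Event() for _ in range(num_slots)]  # slot may be overwritten

    def prefetch(block_idx, slot):
        with torch.cuda.stream(load_stream):
            # Never overwrite a slot before the compute that read it has finished.
            load_stream.wait_event(reusable[slot])
            gpu_k_slots[slot].copy_(cpu_k[block_idx], non_blocking=True)
            gpu_v_slots[slot].copy_(cpu_v[block_idx], non_blocking=True)
            ready[slot].record(load_stream)

    for i in range(min(num_slots, num_blocks)):   # fill the pipeline
        prefetch(i, i)

    results = []
    for i in range(num_blocks):
        slot = i % num_slots
        compute_stream.wait_event(ready[slot])    # the block is resident in this slot
        results.append(attend_fn(q, gpu_k_slots[slot], gpu_v_slots[slot]))
        reusable[slot].record(compute_stream)     # compute done; slot can be refilled
        nxt = i + num_slots
        if nxt < num_blocks:
            prefetch(nxt, slot)                   # prefetch the block that reuses this slot
    return results
```

With `num_slots` equal to the patch's `num_ring_slots`, this is the overlap the attention path aims for: the copy for block `i + num_slots` runs while attention on block `i` executes.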