From ccf27d3a74d27a1669ec872bcdda244cda71ccba Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 7 Jan 2026 05:58:23 +0800 Subject: [PATCH] [claudesquad] update from 'perf_opt-1' on 07 Jan 26 05:58 CST --- CLAUDE.md | 11 +++ nanovllm/engine/model_runner.py | 18 +++- nanovllm/kvcache/offload_engine.py | 152 ++++++++++++++++++++++++++++- nanovllm/layers/attention.py | 94 +++++++++++++++--- 4 files changed, 255 insertions(+), 20 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index c40c588..a29f610 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -46,6 +46,17 @@ python bench_offload.py ## Local Package Installation for Multi-Instance +**CRITICAL**: After ANY code modification in the `nanovllm/` directory, you MUST reinstall the package before running tests or benchmarks: + +```bash +pip install -e . --prefix=./.local --no-deps +``` + +Then run with PYTHONPATH: +```bash +PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py> +``` + **IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally as it will affect other instances. Instead, use local installation: 1. **Install to worktree-local directory**: diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 3b463f0..4dd19f5 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -590,14 +590,15 @@ class ModelRunner: def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]: """ - Run decode with ring buffer (CPU is primary storage). + Run decode with cross-layer pipeline (CPU is primary storage). All KV is on CPU. Uses decode_slot (slot[0]) to write new KV. - Other slots (slots[1:]) are used to load previous KV chunks via pipeline. - New token's KV is written to decode_slot then offloaded to CPU only when block is full. + Optimized with cross-layer pipeline: Layer N's data is loaded while + Layer N-1 computes, achieving transfer/compute overlap. 
Key: decode_slot is dedicated to writing new KV, never used for loading. - Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens. + Optimization: Cross-layer pipeline reduces effective latency by overlapping + H2D transfers with attention computation across layers. """ assert len(seqs) == 1, "Ring buffer decode only supports single sequence" seq = seqs[0] @@ -618,6 +619,12 @@ class ModelRunner: # Get decode start position for accumulated token tracking decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq) + # Get prefilled CPU blocks for pipeline initialization + cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq) + + # Start cross-layer pipeline (preloads Layer 0's data) + offload_engine.start_decode_pipeline(cpu_block_table) + # Set up context for chunked decode set_context( is_prefill=False, @@ -634,6 +641,9 @@ class ModelRunner: logits = self.run_model(input_ids, positions, is_prefill=False) reset_context() + # End cross-layer pipeline + offload_engine.end_decode_pipeline() + # Only offload when block is full (pos_in_block == block_size - 1) # This avoids unnecessary offloading on every decode step if pos_in_block == self.block_size - 1: diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index 5260906..5270c0e 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -142,6 +142,40 @@ class OffloadEngine: decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024) logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB") + # ========== Cross-layer pipeline buffers for decode ========== + # Double-buffered layer cache for pipelined decode: + # - Buffer A: Current layer's prefilled KV being computed + # - Buffer B: Next layer's prefilled KV being loaded + # Shape: [max_prefill_blocks, block_size, kv_heads, head_dim] + # Memory: 4 * max_prefill_blocks * block_size * kv_heads * 
head_dim * dtype_size + max_prefill_blocks = num_cpu_blocks # Can hold all prefill blocks + self.layer_k_buffer_a = torch.zeros( + max_prefill_blocks, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cuda" + ) + self.layer_v_buffer_a = torch.zeros( + max_prefill_blocks, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cuda" + ) + self.layer_k_buffer_b = torch.zeros( + max_prefill_blocks, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cuda" + ) + self.layer_v_buffer_b = torch.zeros( + max_prefill_blocks, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cuda" + ) + layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024) + logger.info(f" Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)") + + # Pipeline state tracking + self._pipeline_active = False + self._pipeline_current_buffer = 0 # 0 = buffer A, 1 = buffer B + self._pipeline_next_layer_event = torch.cuda.Event() + self._pipeline_cpu_blocks: list = [] # CPU block IDs to load + self._pipeline_num_blocks = 0 + self._pipeline_layer_stream = torch.cuda.Stream() # Dedicated stream for layer loading + # ========== Fixed-address CPU KV cache (pinned memory) ========== self.k_cache_cpu = torch.zeros( num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim, @@ -1063,4 +1097,120 @@ class OffloadEngine: # Allow pdb quit to propagate if e.__class__.__name__ == 'BdbQuit': raise - logger.warning(f"Debug hook error: {e}") \ No newline at end of file + logger.warning(f"Debug hook error: {e}") + + # ========== Cross-layer Pipeline Methods for Decode ========== + + def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None: + """ + Start cross-layer pipeline for decode. + + Called at the beginning of a decode step to initialize the pipeline. + Preloads Layer 0's data into buffer A. 
+ + Args: + cpu_block_ids: List of CPU block IDs for prefilled blocks + """ + if not cpu_block_ids: + self._pipeline_active = False + return + + self._pipeline_active = True + self._pipeline_cpu_blocks = cpu_block_ids + self._pipeline_num_blocks = len(cpu_block_ids) + self._pipeline_current_buffer = 0 + + # Preload Layer 0 into buffer A + self._load_layer_to_buffer(0, 0) # layer_id=0, buffer_idx=0 (A) + + def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]: + """ + Get KV cache for a layer during decode. + + If pipeline is active, returns data from the current buffer. + Also triggers preloading of the next layer (if not last layer). + + Args: + layer_id: Current layer ID + num_blocks: Number of blocks to return + + Returns: + (k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim] + """ + if not self._pipeline_active: + raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.") + + # Wait for current layer's data to be ready + self.compute_stream.wait_event(self._pipeline_next_layer_event) + + # Get current buffer + if self._pipeline_current_buffer == 0: + k = self.layer_k_buffer_a[:num_blocks] + v = self.layer_v_buffer_a[:num_blocks] + else: + k = self.layer_k_buffer_b[:num_blocks] + v = self.layer_v_buffer_b[:num_blocks] + + # Trigger preloading of next layer (if not last layer) + next_layer_id = layer_id + 1 + if next_layer_id < self.num_layers: + # Use the other buffer for next layer + next_buffer_idx = 1 - self._pipeline_current_buffer + self._load_layer_to_buffer(next_layer_id, next_buffer_idx) + # Switch to next buffer for next layer + self._pipeline_current_buffer = next_buffer_idx + + return k, v + + def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None: + """ + Async load a layer's prefilled blocks to the specified buffer. + + Uses sgDMA for efficient strided transfer from CPU cache. 
+ + Args: + layer_id: Layer index to load + buffer_idx: 0 for buffer A, 1 for buffer B + """ + num_blocks = self._pipeline_num_blocks + cpu_block_ids = self._pipeline_cpu_blocks + + # Select target buffer + if buffer_idx == 0: + k_buffer = self.layer_k_buffer_a + v_buffer = self.layer_v_buffer_a + else: + k_buffer = self.layer_k_buffer_b + v_buffer = self.layer_v_buffer_b + + # Load all blocks for this layer using dedicated stream + with torch.cuda.stream(self._pipeline_layer_stream): + for i, cpu_block_id in enumerate(cpu_block_ids): + # Copy from CPU cache (has layer dimension) to GPU buffer + k_buffer[i].copy_( + self.k_cache_cpu[layer_id, cpu_block_id], + non_blocking=True + ) + v_buffer[i].copy_( + self.v_cache_cpu[layer_id, cpu_block_id], + non_blocking=True + ) + # Record event when all transfers complete + self._pipeline_next_layer_event.record(self._pipeline_layer_stream) + + def end_decode_pipeline(self) -> None: + """ + End the cross-layer pipeline. + + Called at the end of a decode step to clean up pipeline state. + """ + if self._pipeline_active: + # Ensure all transfers complete before ending + self._pipeline_layer_stream.synchronize() + self._pipeline_active = False + self._pipeline_cpu_blocks = [] + self._pipeline_num_blocks = 0 + + def is_pipeline_active(self) -> bool: + """Check if decode pipeline is currently active.""" + return self._pipeline_active \ No newline at end of file diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 197d082..da36154 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -479,17 +479,15 @@ class Attention(nn.Module): context, ) -> torch.Tensor: """ - Compute decode attention using ring buffer pipeline (same as prefill). + Compute decode attention using cross-layer pipeline. 
- Uses the same loading mechanism as _chunked_prefill_attention: - - Load one block at a time from CPU to GPU slot - - Compute attention for each block - - Merge results using online softmax - - Finally merge with decode buffer (accumulated decode tokens) + Optimization: Uses double-buffered layer cache to overlap H2D transfer + with computation across layers: + - Layer N computes while Layer N+1's data is being loaded + - Each layer only waits for its own data, not all layers' data - This approach is simpler and proven correct (prefill tests pass). - The only difference from prefill is the additional decode buffer - that stores new tokens generated during decode. + This reduces effective latency from O(num_layers * transfer_time) to + O(transfer_time + num_layers * compute_time) when transfer < compute. """ from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs @@ -533,13 +531,20 @@ class Attention(nn.Module): ) offload_engine = kvcache_manager.offload_engine - load_slots = offload_engine.decode_load_slots # Available slots for loading - # Use ring buffer pipeline (same as prefill) to load prefilled blocks - o_acc, lse_acc = self._decode_ring_buffer_pipeline( - q_batched, cpu_block_table, load_slots, offload_engine, - block_size, last_block_valid_tokens - ) + # Use cross-layer pipeline if active (initialized in model_runner) + if offload_engine.is_pipeline_active(): + o_acc, lse_acc = self._decode_with_layer_pipeline( + q_batched, cpu_block_table, offload_engine, + block_size, last_block_valid_tokens + ) + else: + # Fallback to original ring buffer pipeline + load_slots = offload_engine.decode_load_slots + o_acc, lse_acc = self._decode_ring_buffer_pipeline( + q_batched, cpu_block_table, load_slots, offload_engine, + block_size, last_block_valid_tokens + ) # Now attend to accumulated decode tokens from per-layer decode buffer pos_in_block = context.decode_pos_in_block @@ -652,3 +657,62 @@ class Attention(nn.Module): o_acc, 
lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) return o_acc, lse_acc + + def _decode_with_layer_pipeline( + self, + q_batched: torch.Tensor, + cpu_block_table: list, + offload_engine, + block_size: int, + last_block_valid_tokens: int, + ): + """ + Decode using cross-layer pipeline for optimized H2D transfer. + + This method uses pre-loaded layer buffers instead of loading + blocks one by one. The pipeline loads the next layer's data + while the current layer computes, achieving transfer/compute overlap. + + The key insight is that each layer needs the SAME blocks but from + different layers of CPU cache. By double-buffering and pipelining + across layers, we reduce total latency. + """ + from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + + num_blocks = len(cpu_block_table) + if num_blocks == 0: + return None, None + + compute_stream = offload_engine.compute_stream + + # Get KV from pre-loaded layer buffer (triggers next layer loading) + prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks) + + # prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim] + # Reshape to [1, num_blocks * block_size, kv_heads, head_dim] + total_tokens = num_blocks * block_size + + # Handle partial last block + if last_block_valid_tokens < block_size: + # Only use valid tokens from last block + actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens + # Flatten and truncate + prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens] + prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens] + else: + prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1]) + prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1]) + + # Add batch dimension: [1, total_tokens, kv_heads, head_dim] + prev_k_batched = prev_k_flat.unsqueeze(0) + prev_v_batched = prev_v_flat.unsqueeze(0) + + # Compute attention on all 
prefilled blocks at once + with torch.cuda.stream(compute_stream): + o_acc, lse_acc = flash_attn_with_lse( + q_batched, prev_k_batched, prev_v_batched, + softmax_scale=self.scale, + causal=False, + ) + + return o_acc, lse_acc