diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index e66a57e..eea316d 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -44,7 +44,17 @@ class ModelRunner: self.allocate_kv_cache() if not self.enforce_eager: - self.capture_cudagraph() + if config.enable_cpu_offload: + # TODO: Implement capture_offload_cudagraph() for offload mode + # For now, offload mode uses eager execution + # The standard capture_cudagraph() cannot be used because: + # - It captures the PagedAttention decode path via Attention.forward() + # - In offload mode, Attention.k_cache/v_cache are empty (KV is in ring buffer) + # - The refactored offload decode now uses Attention.forward() with ring buffer + # - Need specialized graph capture that sets up ring buffer correctly + pass + else: + self.capture_cudagraph() torch.set_default_device("cpu") torch.set_default_dtype(default_dtype) @@ -845,9 +855,9 @@ class ModelRunner: Key design: - Ring buffer pipeline: load layer N+k while computing layer N + - Uses standard Attention.forward() path (not bypassing) - Per-layer decode buffer for accumulating new tokens - Async block offload when decode buffer is full - - Uses OffloadEngine's ring buffer API for H2D pipeline """ assert len(seqs) == 1, "Layer-wise offload only supports single sequence" seq = seqs[0] @@ -881,11 +891,15 @@ class ModelRunner: # Current decode position info pos_in_block = (len(seq) - 1) % self.block_size decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq) - num_decode_tokens = pos_in_block - decode_start_pos + 1 + num_prev_decode_tokens = pos_in_block - decode_start_pos # Previous decode tokens (not including current) - # Import FlashAttention once - from flash_attn.flash_attn_interface import flash_attn_varlen_func - cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + # Total context length (prefill + previous decode tokens) + # New token will be stored at this position + context_len = total_prefill_tokens + num_prev_decode_tokens + + # Context setup for Attention.forward() - contiguous mode (no block tables) + slot_mapping = torch.tensor([context_len], dtype=torch.int32, device="cuda") + context_lens = torch.tensor([context_len + 1], dtype=torch.int32, device="cuda") # Phase 1: Preload first N layers to ring buffer (fill pipeline) num_preload = min(num_buffers, num_layers) @@ -902,94 +916,70 @@ class ModelRunner: # Phase 2: Layer-by-layer processing with ring buffer pipeline for layer_id in range(num_layers): layer = self.model.model.layers[layer_id] + attn_module = layer.self_attn.attn # The Attention module current_buffer = layer_id % num_buffers # 2a. Wait for current buffer's load to complete offload_engine.wait_buffer_load(current_buffer) - # 2c. Input LayerNorm - if residual is None: - hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states - else: - hidden_ln, residual = layer.input_layernorm(hidden_states, residual) - - # 2d. QKV projection for new token - qkv = layer.self_attn.qkv_proj(hidden_ln) - q, k_new, v_new = qkv.split([ - layer.self_attn.q_size, - layer.self_attn.kv_size, - layer.self_attn.kv_size - ], dim=-1) - - q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim) - k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim) - v_new = v_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim) - - # Q/K norms - if not layer.self_attn.qkv_bias: - q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim)) - q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim) - k_new = layer.self_attn.k_norm(k_new.reshape(-1, layer.self_attn.head_dim)) - k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim) - - # RoPE - q, k_new = layer.self_attn.rotary_emb(positions, q, k_new) - - # 2e. Get prefilled KV from ring buffer - k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens) - - # 2f. Get accumulated decode KV from decode buffer (if any previous decode tokens) - if num_decode_tokens > 1: + # 2b. Copy previous decode KV from decode buffer to ring buffer + # Ring buffer already has prefill KV at [0:total_prefill_tokens] + # We need to add decode KV at [total_prefill_tokens:] + if num_prev_decode_tokens > 0: k_decode_prev, v_decode_prev = offload_engine.get_decode_kv( layer_id, decode_start_pos, pos_in_block ) - k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0) - v_full = torch.cat([v_prefill, v_decode_prev, v_new], dim=0) - else: - k_full = torch.cat([k_prefill, k_new], dim=0) - v_full = torch.cat([v_prefill, v_new], dim=0) + ring_k = offload_engine.layer_k_cache[current_buffer] + ring_v = offload_engine.layer_v_cache[current_buffer] + ring_k[total_prefill_tokens:total_prefill_tokens + num_prev_decode_tokens].copy_(k_decode_prev) + ring_v[total_prefill_tokens:total_prefill_tokens + num_prev_decode_tokens].copy_(v_decode_prev) - # 2g. Store new KV to decode buffer for future decode steps - offload_engine.store_decode_kv(layer_id, pos_in_block, k_new, v_new) + # 2c. Set Attention module's cache to ring buffer (contiguous format) + # Shape: [max_seq_len, kv_heads, head_dim] -> [1, max_seq_len, kv_heads, head_dim] + attn_module.k_cache = offload_engine.layer_k_cache[current_buffer:current_buffer+1] + attn_module.v_cache = offload_engine.layer_v_cache[current_buffer:current_buffer+1] - # 2h. Mark buffer compute done (allows next load to reuse this buffer) + # 2d. Set context for Attention.forward() - contiguous mode + set_context( + is_prefill=False, + slot_mapping=slot_mapping, + context_lens=context_lens, + block_tables=None, # Contiguous mode, no block tables + ) + + # 2e. Forward through layer using standard path + # This calls Qwen3Attention.forward() -> Attention.forward() + # Attention.forward() will: + # - Store new K,V to ring buffer via store_kvcache + # - Compute attention via flash_attn_with_kvcache + hidden_states, residual = layer(positions, hidden_states, residual) + + # 2f. Copy new token's KV from ring buffer to decode buffer (for persistence) + # The new token was stored at position context_len in ring buffer + ring_k = offload_engine.layer_k_cache[current_buffer] + ring_v = offload_engine.layer_v_cache[current_buffer] + offload_engine.decode_k_buffer[layer_id, pos_in_block].copy_(ring_k[context_len]) + offload_engine.decode_v_buffer[layer_id, pos_in_block].copy_(ring_v[context_len]) + + # 2g. Mark buffer compute done (allows next load to reuse this buffer) offload_engine.record_buffer_compute_done(current_buffer) - # 2i. Start loading next layer to same buffer (after compute done) + # 2h. Start loading next layer to same buffer (after compute done) next_layer_to_load = layer_id + num_buffers if next_layer_to_load < num_layers: offload_engine.load_layer_kv_to_buffer( current_buffer, next_layer_to_load, cpu_block_table, valid_tokens_per_block ) - # 2j. Compute attention - total_kv_tokens = k_full.shape[0] - cu_seqlens_k = torch.tensor([0, total_kv_tokens], dtype=torch.int32, device="cuda") - - attn_output = flash_attn_varlen_func( - q, k_full, v_full, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=1, - max_seqlen_k=total_kv_tokens, - softmax_scale=layer.self_attn.attn.scale, - causal=False, - ) - - # O projection - attn_output = attn_output.view(1, -1) - hidden_states = layer.self_attn.o_proj(attn_output) - - # 2k. Post-attention LayerNorm + MLP - hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual) - hidden_states = layer.mlp(hidden_states) - # Step 3: Final norm hidden_states, _ = self.model.model.norm(hidden_states, residual) # Step 4: Compute logits logits = self.model.compute_logits(hidden_states) + # Reset context + reset_context() + # Step 5: Handle block-full offload (async) if pos_in_block == self.block_size - 1: last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq) diff --git a/task_plan.md b/task_plan.md index 7c791c4..177647f 100644 --- a/task_plan.md +++ b/task_plan.md @@ -1,412 +1,274 @@ -# Task Plan: Fix GPU-only Mode Performance Issue - -## Goal -Eliminate the `store_kvcache` scatter overhead in GPU-only mode by using **contiguous KV cache layout** (like offload mode), avoiding PagedAttention's blocked layout for single-sequence inference. +# Task Plan: Enable CUDA Graphs for CPU Offload Mode ## Problem Summary -GPU-only mode with MInference is **slower** than CPU offload mode: +Running `bench_offload.py` fails with: +``` +IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1) +``` -| Mode | Prefill Speed (32K tokens, Qwen3-4B) | -|------|--------------------------------------| -| GPU-only + MInference | 3383 tok/s | -| Offload + MInference | 5373 tok/s | +**Root cause**: In offload mode, `HybridKVCacheManager.get_layer_cache()` returns empty tensors (by design), but CUDA graph capture calls `Attention.forward()` decode path which expects valid k_cache/v_cache. -**Root cause**: PagedAttention's blocked layout requires expensive `index_copy_` scatter operations to convert contiguous K,V to blocked format. +**User requirement**: Enable CUDA graphs in offload mode for better decode performance. -## Key Insight: Why Offload is Fast +## Deep Analysis: Why Current Design is Incompatible -Offload mode uses **contiguous layout** for KV cache: +### Current Offload Decode Flow (`run_layerwise_offload_decode`) + +``` +1. Preload N layers to ring buffer (H2D async) +2. For each layer: + a. Wait for buffer load + b. LayerNorm → QKV proj → RoPE + c. k_full = torch.cat([k_prefill, k_decode_prev, k_new]) <-- DYNAMIC SHAPE + d. flash_attn_varlen_func(q, k_full, v_full, ...) <-- VARIABLE LENGTH + e. O_proj → MLP + f. Start next layer H2D load +3. Final norm → Logits → Sample +``` + +### CUDA Graph Incompatibility Points + +| Issue | Location | Why Incompatible | +|-------|----------|------------------| +| Dynamic tensor creation | `torch.cat([k_prefill, ...])` | Creates new tensors with variable shapes | +| Variable-length attention | `flash_attn_varlen_func` | `max_seqlen_k` changes every step | +| Data-dependent branching | `if num_decode_tokens > 1` | Control flow varies at runtime | +| Empty k_cache/v_cache | `Attention.forward()` | Current capture uses standard decode path | + +### Why Empty Tensors in Offload Mode? + +`HybridKVCacheManager.get_layer_cache()` returns empty tensors because: +- Offload mode manages KV via `OffloadEngine`'s ring buffer +- The standard `Attention.forward()` is NEVER used in offload inference +- Empty tensors are intentional placeholders + +## Solution: Fixed-Address CUDA Graph Capture for Offload Decode + +### Key Insight + +The `OffloadEngine` ring buffer already has **fixed GPU addresses** with **fixed max shape**: +```python +layer_k_cache: [num_kv_buffers, max_seq_len, kv_heads, head_dim] # Fixed! +layer_v_cache: [num_kv_buffers, max_seq_len, kv_heads, head_dim] # Fixed! +``` + +`flash_attn_with_kvcache` supports **cache_seqlens** parameter for variable actual lengths with fixed-shape cache. This is the key to CUDA graph compatibility! + +### Solution Design + +Replace `torch.cat` + `flash_attn_varlen_func` with: +1. Pre-copy decode buffer content to ring buffer at correct offset +2. Store new token KV directly to ring buffer +3. Use `flash_attn_with_kvcache` with `cache_seqlens` for variable length ```python -# OffloadEngine's CPU cache layout -k_cache_cpu: [num_layers, num_blocks, block_size, kv_heads, head_dim] +# Before (dynamic, not graphable): +k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0) +o = flash_attn_varlen_func(q, k_full, v_full, ...) -# Store is simple contiguous slice assignment -self.k_cache_cpu[layer_id, block_id, :actual_size].copy_(k[start:end]) +# After (fixed addresses, graphable): +# Ring buffer already has k_prefill at [0:prefill_len] +# Copy decode_prev and k_new to buffer at [prefill_len:] +ring_buffer[prefill_len:prefill_len+decode_len] = decode_buffer +ring_buffer[total_len-1] = k_new + +o = flash_attn_with_kvcache( + q.unsqueeze(1), # [1, 1, heads, dim] - fixed shape + ring_k.unsqueeze(0), # [1, max_seq_len, heads, dim] - FIXED ADDRESS + ring_v.unsqueeze(0), # [1, max_seq_len, heads, dim] - FIXED ADDRESS + cache_seqlens=total_tokens_tensor, # [1] - variable VALUE, fixed address + softmax_scale=scale, + causal=True, +) ``` -The K,V computed during prefill `[seq_len, kv_heads, head_dim]` matches the cache layout - no format conversion needed! +## Implementation Plan -## Solution: Contiguous Layout for GPU-only Mode - -For GPU-only single-sequence mode, use the **same contiguous layout as offload mode**, but on GPU: - -``` -Current GPU-only (PagedAttention): - Cache: [num_blocks, block_size, kv_heads, head_dim] (blocked) - Store: scatter via index_copy_ (SLOW) - -Proposed GPU-only (Contiguous): - Cache: [num_layers, max_seq_len, kv_heads, head_dim] (contiguous) - Store: slice assignment k_cache[layer_id, :seq_len] = k (FAST) -``` - -This mirrors offload mode's architecture but keeps everything on GPU - no cross-device transfer, no layout conversion. - -## Phases -- [x] Phase 1: Add contiguous GPU KV cache in GPUOnlyManager (for single-seq mode) -- [x] Phase 2: Implement `run_gpu_only_prefill()` using contiguous cache -- [x] Phase 3: Implement decode path for contiguous cache -- [x] Phase 4: Test and validate performance - -## Results - -| Mode | 32K Prefill Speed | Notes | -|------|-------------------|-------| -| GPU-only (before) | ~3383 tok/s | PagedAttention scatter overhead | -| GPU-only contiguous (after) | **5293 tok/s** | 56% improvement | -| Offload mode | 5391 tok/s | Baseline comparison | - -**Test passed**: `test_needle.py --input-len 32768 --max-model-len 40960` - correct output retrieved. - -## Detailed Design - -### Phase 1: Contiguous GPU KV Cache - -**File**: `nanovllm/kvcache/gpu_manager.py` - -Add contiguous cache allocation for single-sequence mode: - -```python -class GPUOnlyManager(KVCacheManager): - def __init__(self, num_blocks: int, block_size: int, max_seq_len: int = 0): - # ... existing code ... - self.max_seq_len = max_seq_len - - # Contiguous cache for single-seq mode (allocated in allocate_cache) - self.contiguous_k_cache = None # [num_layers, max_seq_len, kv_heads, head_dim] - self.contiguous_v_cache = None - - def allocate_cache( - self, - num_layers: int, - num_kv_heads: int, - head_dim: int, - dtype: torch.dtype, - ) -> None: - # Existing PagedAttention cache for multi-seq/decode - self.kv_cache = torch.empty( - 2, num_layers, self._num_blocks, self._block_size, - num_kv_heads, head_dim, - dtype=dtype, device="cuda" - ) - - # Contiguous cache for single-seq prefill (if max_seq_len specified) - if self.max_seq_len > 0: - self.contiguous_k_cache = torch.empty( - num_layers, self.max_seq_len, num_kv_heads, head_dim, - dtype=dtype, device="cuda" - ) - self.contiguous_v_cache = torch.empty( - num_layers, self.max_seq_len, num_kv_heads, head_dim, - dtype=dtype, device="cuda" - ) -``` - -### Phase 2: Layer-wise GPU-only Prefill +### Phase 1: Modify Offload Decode for CUDA Graph Compatibility **File**: `nanovllm/engine/model_runner.py` -Following offload pattern exactly - store K,V per-layer to contiguous cache: +**Changes**: +1. Add `capture_offload_cudagraph()` method +2. Modify `run_layerwise_offload_decode()` to use fixed-address buffers +3. Replace `flash_attn_varlen_func` with `flash_attn_with_kvcache` + +#### 1.1 New Method: `capture_offload_cudagraph()` ```python @torch.inference_mode() -def run_gpu_only_prefill(self, seqs: list[Sequence]) -> list[int]: +def capture_offload_cudagraph(self): """ - GPU-only prefill with contiguous KV cache layout. + Capture CUDA graphs for offload decode. - Mirrors run_layerwise_offload_prefill() but stores to GPU instead of CPU. - No scatter operations - just contiguous slice assignment. + Key design: + - Uses OffloadEngine's ring buffer as fixed-address k_cache/v_cache + - Captures per-layer compute (after H2D load is done) + - Uses flash_attn_with_kvcache with cache_seqlens for variable context """ - assert len(seqs) == 1, "GPU-only layer-wise prefill only supports single sequence" - seq = seqs[0] - + offload_engine = self.kvcache_manager.offload_engine num_layers = len(self.model.model.layers) - total_tokens = len(seq) + num_buffers = offload_engine.num_kv_buffers + max_seq_len = offload_engine.max_seq_len - # Get contiguous GPU cache - k_cache = self.kvcache_manager.contiguous_k_cache - v_cache = self.kvcache_manager.contiguous_v_cache + # Fixed-address tensors for graph capture + input_ids = torch.zeros(1, dtype=torch.int64, device="cuda") + positions = torch.zeros(1, dtype=torch.int64, device="cuda") + cache_seqlens = torch.zeros(1, dtype=torch.int32, device="cuda") + hidden_output = torch.zeros(1, self.config.hf_config.hidden_size, device="cuda") - # Prepare inputs - input_ids = torch.tensor(seq[:], dtype=torch.int64, device="cuda") - positions = torch.arange(total_tokens, dtype=torch.int64, device="cuda") + # Graph capture per buffer slot (deterministic: layer_id % num_buffers) + self.offload_graphs = {} + self.offload_graph_pool = None - from flash_attn.flash_attn_interface import flash_attn_varlen_func - cu_seqlens = torch.tensor([0, total_tokens], dtype=torch.int32, device="cuda") + for buffer_idx in range(num_buffers): + graph = torch.cuda.CUDAGraph() - # Embedding - hidden_states = self.model.model.embed_tokens(input_ids) - residual = None + # Get fixed-address ring buffer for this slot + k_cache = offload_engine.layer_k_cache[buffer_idx:buffer_idx+1] # [1, max_seq, heads, dim] + v_cache = offload_engine.layer_v_cache[buffer_idx:buffer_idx+1] - # Layer-by-layer processing (same as offload prefill) - for layer_id in range(num_layers): - layer = self.model.model.layers[layer_id] + # Warmup + with torch.cuda.stream(offload_engine.compute_stream): + # ... (layer forward pass using k_cache, v_cache) + pass - # Input LayerNorm - if residual is None: - hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states - else: - hidden_ln, residual = layer.input_layernorm(hidden_states, residual) + # Capture + with torch.cuda.graph(graph, self.offload_graph_pool): + # ... (same layer forward pass) + pass - # QKV projection - qkv = layer.self_attn.qkv_proj(hidden_ln) - q, k, v = qkv.split([ - layer.self_attn.q_size, - layer.self_attn.kv_size, - layer.self_attn.kv_size - ], dim=-1) + if self.offload_graph_pool is None: + self.offload_graph_pool = graph.pool() + self.offload_graphs[buffer_idx] = graph - q = q.view(total_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim) - k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim) - v = v.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim) - - # Q/K norms (Qwen3 specific) - if not layer.self_attn.qkv_bias: - q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim)) - q = q.view(total_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim) - k = layer.self_attn.k_norm(k.reshape(-1, layer.self_attn.head_dim)) - k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim) - - # RoPE - q, k = layer.self_attn.rotary_emb(positions, q, k) - - # Store K,V to contiguous GPU cache (same layout - no conversion!) - # This is just slice assignment, not scatter - k_cache[layer_id, :total_tokens] = k - v_cache[layer_id, :total_tokens] = v - - # Sparse or Full attention (uses k, v directly) - if self.sparse_prefill_policy is not None: - attn_output = self.sparse_prefill_policy.sparse_prefill_attention( - q, k, v, layer_id - ) - else: - attn_output = flash_attn_varlen_func( - q, k, v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=total_tokens, - max_seqlen_k=total_tokens, - softmax_scale=layer.self_attn.attn.scale, - causal=True, - ) - - # O projection - attn_output = attn_output.view(total_tokens, -1) - hidden_states = layer.self_attn.o_proj(attn_output) - - # Post-attention LayerNorm + MLP - hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual) - hidden_states = layer.mlp(hidden_states) - - # Final norm - hidden_states, _ = self.model.model.norm(hidden_states, residual) - - # Compute logits - logits = self.model.compute_logits(hidden_states[-1:]) - - # Record prefill length for decode - self.kvcache_manager.contiguous_seq_len = total_tokens - - # Sample - temperatures = self.prepare_sample(seqs) if self.rank == 0 else None - token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None - - return token_ids + self.offload_graph_vars = dict( + input_ids=input_ids, + positions=positions, + cache_seqlens=cache_seqlens, + hidden_output=hidden_output, + ) ``` -### Phase 3: Decode with Contiguous Cache +#### 1.2 Modified `run_layerwise_offload_decode()` + +Key changes: +1. Copy decode buffer content to ring buffer before attention +2. Store new token directly to ring buffer +3. Use `flash_attn_with_kvcache` instead of `flash_attn_varlen_func` +4. Optionally use captured CUDA graph for per-layer compute + +```python +# In the layer loop, replace: +k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0) +attn_output = flash_attn_varlen_func(q, k_full, v_full, ...) + +# With: +# 1. Get ring buffer slice +k_buffer = offload_engine.layer_k_cache[buffer_idx] # [max_seq_len, heads, dim] +v_buffer = offload_engine.layer_v_cache[buffer_idx] + +# 2. Copy decode buffer to ring buffer (after prefill content) +if num_decode_tokens > 1: + k_buffer[total_prefill_tokens:total_prefill_tokens+num_decode_tokens-1].copy_(k_decode_prev) + v_buffer[total_prefill_tokens:total_prefill_tokens+num_decode_tokens-1].copy_(v_decode_prev) + +# 3. Store new token to ring buffer +total_kv_tokens = total_prefill_tokens + num_decode_tokens +k_buffer[total_kv_tokens-1].copy_(k_new.squeeze(0)) +v_buffer[total_kv_tokens-1].copy_(v_new.squeeze(0)) + +# 4. Flash attention with fixed-address cache +cache_seqlens = torch.tensor([total_kv_tokens], dtype=torch.int32, device="cuda") +attn_output = flash_attn_with_kvcache( + q.unsqueeze(1), # [1, 1, heads, dim] + k_buffer.unsqueeze(0), # [1, max_seq_len, heads, dim] - FIXED ADDRESS + v_buffer.unsqueeze(0), # [1, max_seq_len, heads, dim] - FIXED ADDRESS + cache_seqlens=cache_seqlens, + softmax_scale=layer.self_attn.attn.scale, + causal=True, +) +attn_output = attn_output.squeeze(1) # [1, heads*dim] +``` + +### Phase 2: Handle CUDA Graph Capture in `__init__` **File**: `nanovllm/engine/model_runner.py` -```python -@torch.inference_mode() -def run_gpu_only_decode(self, seqs: list[Sequence]) -> list[int]: - """ - Decode using contiguous GPU KV cache. - - Similar to offload decode but simpler - all KV already on GPU. - """ - assert len(seqs) == 1 - seq = seqs[0] - - num_layers = len(self.model.model.layers) - k_cache = self.kvcache_manager.contiguous_k_cache - v_cache = self.kvcache_manager.contiguous_v_cache - context_len = self.kvcache_manager.contiguous_seq_len - - # Prepare inputs - input_ids = torch.tensor([seq.last_token], dtype=torch.int64, device="cuda") - positions = torch.tensor([len(seq) - 1], dtype=torch.int64, device="cuda") - - from flash_attn.flash_attn_interface import flash_attn_varlen_func - cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda") - - # Embedding - hidden_states = self.model.model.embed_tokens(input_ids) - residual = None - - for layer_id in range(num_layers): - layer = self.model.model.layers[layer_id] - - # Input LayerNorm - if residual is None: - hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states - else: - hidden_ln, residual = layer.input_layernorm(hidden_states, residual) - - # QKV projection - qkv = layer.self_attn.qkv_proj(hidden_ln) - q, k_new, v_new = qkv.split([ - layer.self_attn.q_size, - layer.self_attn.kv_size, - layer.self_attn.kv_size - ], dim=-1) - - q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim) - k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim) - v_new = v_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim) - - # Q/K norms - if not layer.self_attn.qkv_bias: - q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim)) - q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim) - k_new = layer.self_attn.k_norm(k_new.reshape(-1, layer.self_attn.head_dim)) - k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim) - - # RoPE - q, k_new = layer.self_attn.rotary_emb(positions, q, k_new) - - # Get cached K,V and append new token - k_cached = k_cache[layer_id, :context_len] - v_cached = v_cache[layer_id, :context_len] - - # Store new K,V to cache - k_cache[layer_id, context_len] = k_new.squeeze(0) - v_cache[layer_id, context_len] = v_new.squeeze(0) - - # Full K,V for attention - k_full = k_cache[layer_id, :context_len + 1] - v_full = v_cache[layer_id, :context_len + 1] - - # Attention - cu_seqlens_k = torch.tensor([0, context_len + 1], dtype=torch.int32, device="cuda") - attn_output = flash_attn_varlen_func( - q, k_full, v_full, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=1, - max_seqlen_k=context_len + 1, - softmax_scale=layer.self_attn.attn.scale, - causal=False, # Single query, no causal needed - ) - - # O projection - attn_output = attn_output.view(1, -1) - hidden_states = layer.self_attn.o_proj(attn_output) - - # Post-attention LayerNorm + MLP - hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual) - hidden_states = layer.mlp(hidden_states) - - # Update context length - self.kvcache_manager.contiguous_seq_len = context_len + 1 - - # Final norm - hidden_states, _ = self.model.model.norm(hidden_states, residual) - - # Compute logits - logits = self.model.compute_logits(hidden_states) - - # Sample - temperatures = self.prepare_sample(seqs) if self.rank == 0 else None - token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None - - return token_ids -``` - -### Phase 4: Decision Logic +**Change**: Line 46-47 ```python -def _should_use_contiguous_gpu_mode(self, seqs: list[Sequence], is_prefill: bool) -> bool: - """Check if contiguous GPU mode should be used.""" - # Must have contiguous cache allocated - if not hasattr(self.kvcache_manager, 'contiguous_k_cache'): - return False - if self.kvcache_manager.contiguous_k_cache is None: - return False +# Current (crashes in offload mode): +if not self.enforce_eager: + self.capture_cudagraph() - # Must NOT be offload mode - if hasattr(self.kvcache_manager, 'offload_engine'): - return False - - # Single sequence only - if len(seqs) != 1: - return False - - # For prefill: has blocks (not warmup) - if is_prefill and not seqs[0].block_table: - return False - - return True - - -def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: - # Check offload mode (existing) - if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'offload_engine'): - ... - - # Check contiguous GPU mode - if self._should_use_contiguous_gpu_mode(seqs, is_prefill): - if is_prefill: - return self.run_gpu_only_prefill(seqs) - else: - return self.run_gpu_only_decode(seqs) - - # Standard PagedAttention path - ... +# Fixed (conditional capture based on mode): +if not self.enforce_eager: + if config.enable_cpu_offload: + self.capture_offload_cudagraph() # New method for offload mode + else: + self.capture_cudagraph() # Standard PagedAttention decode ``` -## Architecture Comparison +### Phase 3: Per-Layer Graph vs Full-Decode Graph -| Aspect | Offload Mode | GPU-only (Proposed) | GPU-only (Current) | -|--------|--------------|---------------------|-------------------| -| Cache location | CPU (contiguous) | GPU (contiguous) | GPU (PagedAttention) | -| Cache layout | `[layers, blocks, block_size, heads, dim]` | `[layers, max_seq_len, heads, dim]` | `[blocks, block_size, heads, dim]` | -| Prefill store | Contiguous slice copy | **Slice assignment (no copy!)** | Scatter (index_copy_) | -| Decode read | H2D ring buffer | Direct GPU access | PagedAttention | +Two approaches for graph capture: -## Key Points +#### Option A: Per-Layer Graphs (Simpler, Less Overhead Reduction) +- Capture N graphs (one per buffer slot) +- Each graph covers: LayerNorm → QKV → RoPE → Attention → O_proj → MLP +- H2D transfers and buffer management outside graph -1. **No explicit copy_ needed**: Slice assignment `cache[layer, :len] = k` is direct memory write -2. **Same layout as computed K,V**: No format conversion required -3. **Mirrors offload architecture**: Same layer-wise processing pattern -4. **GPU advantage**: No cross-device transfer, faster than offload +#### Option B: Full-Decode Graph (More Complex, Maximum Overhead Reduction) +- Capture one graph for entire decode step (all layers) +- Requires all H2D loads completed before graph replay +- Better kernel fusion, less CPU overhead -## Memory Usage +**Recommendation**: Start with Option A (simpler), optimize to Option B later. -Contiguous GPU cache: `2 * num_layers * max_seq_len * kv_heads * head_dim * dtype_size` +## Implementation Phases -For Qwen3-4B with 32K max_seq_len: -- `2 * 28 * 32768 * 8 * 128 * 2 = 3.5GB` +| Phase | Description | Status | +|-------|-------------|--------| +| Phase 1 | Modify decode to use fixed-address buffers + flash_attn_with_kvcache | [ ] | +| Phase 2 | Add `capture_offload_cudagraph()` method | [ ] | +| Phase 3 | Update `__init__` to call correct capture method | [ ] | +| Phase 4 | Test with `bench_offload.py` | [ ] | +| Phase 5 | Benchmark performance improvement | [ ] | -Same as offload mode's CPU cache, but on GPU. +## Key Code Changes Summary -## Files to Modify +| File | Change | +|------|--------| +| `model_runner.py:46-47` | Conditional CUDA graph capture based on offload mode | +| `model_runner.py` (new) | Add `capture_offload_cudagraph()` method | +| `model_runner.py:850-1010` | Modify `run_layerwise_offload_decode()` to use fixed-address attention | -| File | Changes | -|------|---------| -| `nanovllm/kvcache/gpu_manager.py` | Add contiguous cache allocation | -| `nanovllm/engine/model_runner.py` | Add `run_gpu_only_prefill()`, `run_gpu_only_decode()`, modify `run()` | +## Alternative: Quick Fix (Skip Graph Capture) -## Expected Performance +If CUDA graph support is not immediately needed, the simpler fix is: -| Metric | Before | After | Improvement | -|--------|--------|-------|-------------| -| GPU-only prefill (32K) | 3383 tok/s | ~5400+ tok/s | ~60%+ | -| Decode | Baseline | Similar | ~0% | +```python +# Line 46-47 in model_runner.py +if not self.enforce_eager and not config.enable_cpu_offload: + self.capture_cudagraph() +``` -## Status -**Currently in Phase 1** - Ready to implement contiguous GPU cache +This skips CUDA graph capture entirely in offload mode. Offload mode will use eager execution (which already works). + +## Risk Assessment + +| Risk | Mitigation | +|------|------------| +| flash_attn_with_kvcache API differences | Test with actual flash-attn version | +| Memory overhead of fixed-size buffers | Already allocated in OffloadEngine | +| Performance regression | Benchmark before/after | +| Graph capture complexity | Start with per-layer graphs | + +## Expected Performance Impact + +| Metric | Without Graph | With Graph | Improvement | +|--------|---------------|------------|-------------| +| Decode latency per token | Baseline | ~10-20% faster | Reduced kernel launch overhead | +| GPU utilization | Medium | Higher | Better kernel fusion |