Merge remote-tracking branch 'origin/zijie/fix-bug-2' into tzj/vs_offload

This commit is contained in:
Zijie Tian
2026-01-09 15:21:48 +08:00
2 changed files with 234 additions and 433 deletions

View File

@@ -44,7 +44,17 @@ class ModelRunner:
self.allocate_kv_cache()
if not self.enforce_eager:
self.capture_cudagraph()
if config.enable_cpu_offload:
# TODO: Implement capture_offload_cudagraph() for offload mode
# For now, offload mode uses eager execution
# The standard capture_cudagraph() cannot be used because:
# - It captures the PagedAttention decode path via Attention.forward()
# - In offload mode, Attention.k_cache/v_cache are empty (KV is in ring buffer)
# - The refactored offload decode now uses Attention.forward() with ring buffer
# - Need specialized graph capture that sets up ring buffer correctly
pass
else:
self.capture_cudagraph()
torch.set_default_device("cpu")
torch.set_default_dtype(default_dtype)
@@ -845,9 +855,9 @@ class ModelRunner:
Key design:
- Ring buffer pipeline: load layer N+k while computing layer N
- Uses standard Attention.forward() path (not bypassing)
- Per-layer decode buffer for accumulating new tokens
- Async block offload when decode buffer is full
- Uses OffloadEngine's ring buffer API for H2D pipeline
"""
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
seq = seqs[0]
@@ -881,11 +891,15 @@ class ModelRunner:
# Current decode position info
pos_in_block = (len(seq) - 1) % self.block_size
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
num_decode_tokens = pos_in_block - decode_start_pos + 1
num_prev_decode_tokens = pos_in_block - decode_start_pos # Previous decode tokens (not including current)
# Import FlashAttention once
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
# Total context length (prefill + previous decode tokens)
# New token will be stored at this position
context_len = total_prefill_tokens + num_prev_decode_tokens
# Context setup for Attention.forward() - contiguous mode (no block tables)
slot_mapping = torch.tensor([context_len], dtype=torch.int32, device="cuda")
context_lens = torch.tensor([context_len + 1], dtype=torch.int32, device="cuda")
# Phase 1: Preload first N layers to ring buffer (fill pipeline)
num_preload = min(num_buffers, num_layers)
@@ -902,94 +916,70 @@ class ModelRunner:
# Phase 2: Layer-by-layer processing with ring buffer pipeline
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
attn_module = layer.self_attn.attn # The Attention module
current_buffer = layer_id % num_buffers
# 2a. Wait for current buffer's load to complete
offload_engine.wait_buffer_load(current_buffer)
# 2c. Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# 2d. QKV projection for new token
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k_new, v_new = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v_new = v_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms
if not layer.self_attn.qkv_bias:
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = layer.self_attn.k_norm(k_new.reshape(-1, layer.self_attn.head_dim))
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k_new = layer.self_attn.rotary_emb(positions, q, k_new)
# 2e. Get prefilled KV from ring buffer
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
# 2f. Get accumulated decode KV from decode buffer (if any previous decode tokens)
if num_decode_tokens > 1:
# 2b. Copy previous decode KV from decode buffer to ring buffer
# Ring buffer already has prefill KV at [0:total_prefill_tokens]
# We need to add decode KV at [total_prefill_tokens:]
if num_prev_decode_tokens > 0:
k_decode_prev, v_decode_prev = offload_engine.get_decode_kv(
layer_id, decode_start_pos, pos_in_block
)
k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0)
v_full = torch.cat([v_prefill, v_decode_prev, v_new], dim=0)
else:
k_full = torch.cat([k_prefill, k_new], dim=0)
v_full = torch.cat([v_prefill, v_new], dim=0)
ring_k = offload_engine.layer_k_cache[current_buffer]
ring_v = offload_engine.layer_v_cache[current_buffer]
ring_k[total_prefill_tokens:total_prefill_tokens + num_prev_decode_tokens].copy_(k_decode_prev)
ring_v[total_prefill_tokens:total_prefill_tokens + num_prev_decode_tokens].copy_(v_decode_prev)
# 2g. Store new KV to decode buffer for future decode steps
offload_engine.store_decode_kv(layer_id, pos_in_block, k_new, v_new)
# 2c. Set Attention module's cache to ring buffer (contiguous format)
# Shape: [max_seq_len, kv_heads, head_dim] -> [1, max_seq_len, kv_heads, head_dim]
attn_module.k_cache = offload_engine.layer_k_cache[current_buffer:current_buffer+1]
attn_module.v_cache = offload_engine.layer_v_cache[current_buffer:current_buffer+1]
# 2h. Mark buffer compute done (allows next load to reuse this buffer)
# 2d. Set context for Attention.forward() - contiguous mode
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=None, # Contiguous mode, no block tables
)
# 2e. Forward through layer using standard path
# This calls Qwen3Attention.forward() -> Attention.forward()
# Attention.forward() will:
# - Store new K,V to ring buffer via store_kvcache
# - Compute attention via flash_attn_with_kvcache
hidden_states, residual = layer(positions, hidden_states, residual)
# 2f. Copy new token's KV from ring buffer to decode buffer (for persistence)
# The new token was stored at position context_len in ring buffer
ring_k = offload_engine.layer_k_cache[current_buffer]
ring_v = offload_engine.layer_v_cache[current_buffer]
offload_engine.decode_k_buffer[layer_id, pos_in_block].copy_(ring_k[context_len])
offload_engine.decode_v_buffer[layer_id, pos_in_block].copy_(ring_v[context_len])
# 2g. Mark buffer compute done (allows next load to reuse this buffer)
offload_engine.record_buffer_compute_done(current_buffer)
# 2i. Start loading next layer to same buffer (after compute done)
# 2h. Start loading next layer to same buffer (after compute done)
next_layer_to_load = layer_id + num_buffers
if next_layer_to_load < num_layers:
offload_engine.load_layer_kv_to_buffer(
current_buffer, next_layer_to_load, cpu_block_table, valid_tokens_per_block
)
# 2j. Compute attention
total_kv_tokens = k_full.shape[0]
cu_seqlens_k = torch.tensor([0, total_kv_tokens], dtype=torch.int32, device="cuda")
attn_output = flash_attn_varlen_func(
q, k_full, v_full,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=total_kv_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=False,
)
# O projection
attn_output = attn_output.view(1, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# 2k. Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# Step 3: Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Step 4: Compute logits
logits = self.model.compute_logits(hidden_states)
# Reset context
reset_context()
# Step 5: Handle block-full offload (async)
if pos_in_block == self.block_size - 1:
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)

View File

@@ -1,412 +1,223 @@
# Task Plan: Fix GPU-only Mode Performance Issue
# Task Plan: Enable CUDA Graphs for CPU Offload Mode
## Goal
Eliminate the `store_kvcache` scatter overhead in GPU-only mode by using **contiguous KV cache layout** (like offload mode), avoiding PagedAttention's blocked layout for single-sequence inference.
## Current Status
## Problem Summary
### Completed: Refactor Offload Decode to Use Standard Attention Path
GPU-only mode with MInference is **slower** than CPU offload mode:
**Problem solved**: The original offload decode (`run_layerwise_offload_decode`) bypassed `Attention.forward()` by manually calling attention components. This was inconsistent with the standard execution path.
| Mode | Prefill Speed (32K tokens, Qwen3-4B) |
|------|--------------------------------------|
| GPU-only + MInference | 3383 tok/s |
| Offload + MInference | 5373 tok/s |
**Root cause**: PagedAttention's blocked layout requires expensive `index_copy_` scatter operations to convert contiguous K,V to blocked format.
## Key Insight: Why Offload is Fast
Offload mode uses **contiguous layout** for KV cache:
```python
# OffloadEngine's CPU cache layout
k_cache_cpu: [num_layers, num_blocks, block_size, kv_heads, head_dim]
# Store is simple contiguous slice assignment
self.k_cache_cpu[layer_id, block_id, :actual_size].copy_(k[start:end])
**Solution implemented**: Refactored to use `layer.forward()` which goes through:
```
Qwen3DecoderLayer.forward()
→ Qwen3Attention.forward()
→ Attention.forward() ← Now properly used!
```
The K,V computed during prefill `[seq_len, kv_heads, head_dim]` matches the cache layout - no format conversion needed!
## Solution: Contiguous Layout for GPU-only Mode
For GPU-only single-sequence mode, use the **same contiguous layout as offload mode**, but on GPU:
```
Current GPU-only (PagedAttention):
Cache: [num_blocks, block_size, kv_heads, head_dim] (blocked)
Store: scatter via index_copy_ (SLOW)
Proposed GPU-only (Contiguous):
Cache: [num_layers, max_seq_len, kv_heads, head_dim] (contiguous)
Store: slice assignment k_cache[layer_id, :seq_len] = k (FAST)
```
This mirrors offload mode's architecture but keeps everything on GPU - no cross-device transfer, no layout conversion.
## Phases
- [x] Phase 1: Add contiguous GPU KV cache in GPUOnlyManager (for single-seq mode)
- [x] Phase 2: Implement `run_gpu_only_prefill()` using contiguous cache
- [x] Phase 3: Implement decode path for contiguous cache
- [x] Phase 4: Test and validate performance
## Results
| Mode | 32K Prefill Speed | Notes |
|------|-------------------|-------|
| GPU-only (before) | ~3383 tok/s | PagedAttention scatter overhead |
| GPU-only contiguous (after) | **5293 tok/s** | 56% improvement |
| Offload mode | 5391 tok/s | Baseline comparison |
**Test passed**: `test_needle.py --input-len 32768 --max-model-len 40960` - correct output retrieved.
## Detailed Design
### Phase 1: Contiguous GPU KV Cache
**File**: `nanovllm/kvcache/gpu_manager.py`
Add contiguous cache allocation for single-sequence mode:
```python
class GPUOnlyManager(KVCacheManager):
def __init__(self, num_blocks: int, block_size: int, max_seq_len: int = 0):
# ... existing code ...
self.max_seq_len = max_seq_len
# Contiguous cache for single-seq mode (allocated in allocate_cache)
self.contiguous_k_cache = None # [num_layers, max_seq_len, kv_heads, head_dim]
self.contiguous_v_cache = None
def allocate_cache(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
dtype: torch.dtype,
) -> None:
# Existing PagedAttention cache for multi-seq/decode
self.kv_cache = torch.empty(
2, num_layers, self._num_blocks, self._block_size,
num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
# Contiguous cache for single-seq prefill (if max_seq_len specified)
if self.max_seq_len > 0:
self.contiguous_k_cache = torch.empty(
num_layers, self.max_seq_len, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.contiguous_v_cache = torch.empty(
num_layers, self.max_seq_len, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
```
### Phase 2: Layer-wise GPU-only Prefill
### Code Changes Made
**File**: `nanovllm/engine/model_runner.py`
Following offload pattern exactly - store K,V per-layer to contiguous cache:
1. **`run_layerwise_offload_decode()` (line 841-991)** - Completely refactored:
Before (bypassed Attention):
```python
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k_new, v_new = qkv.split(...)
q = layer.self_attn.q_norm(...)
k = layer.self_attn.k_norm(...)
q, k = layer.self_attn.rotary_emb(...)
attn_output = flash_attn_varlen_func(q, k_full, v_full, ...) # Direct call!
hidden_states = layer.self_attn.o_proj(attn_output)
```
After (uses standard path):
```python
# Set up Attention module's cache to ring buffer
attn_module.k_cache = offload_engine.layer_k_cache[buffer_idx:buffer_idx+1]
attn_module.v_cache = offload_engine.layer_v_cache[buffer_idx:buffer_idx+1]
# Set context for contiguous mode
set_context(is_prefill=False, slot_mapping=..., context_lens=..., block_tables=None)
# Standard layer forward - goes through Attention.forward()!
hidden_states, residual = layer(positions, hidden_states, residual)
```
2. **`ModelRunner.__init__()` (line 46-57)** - Conditional CUDA graph capture:
```python
if not self.enforce_eager:
if config.enable_cpu_offload:
# TODO: Implement capture_offload_cudagraph()
pass # Temporarily use eager execution
else:
self.capture_cudagraph()
```
### Test Results
| Test | Mode | Status |
|------|------|--------|
| `test_needle.py --input-len 4096` | GPU-only | PASSED |
| `test_needle.py --input-len 4096 --enable-offload` | CPU offload | PASSED |
## Remaining Work: Implement Offload CUDA Graph
### Why Standard `capture_cudagraph()` Cannot Be Used
The standard capture function captures the PagedAttention decode path:
```python
# capture_cudagraph() sets up:
k_cache: [num_blocks, block_size, kv_heads, head_dim] # PagedAttention format
block_tables: [...] # Block indices for paged indexing
```
But offload mode uses contiguous ring buffer:
```python
# Offload decode sets up:
k_cache: [1, max_seq_len, kv_heads, head_dim] # Contiguous format
block_tables: None # No paging
```
### Implementation Plan for `capture_offload_cudagraph()`
#### Phase 1: Prepare Fixed-Address Tensors
```python
@torch.inference_mode()
def run_gpu_only_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
GPU-only prefill with contiguous KV cache layout.
def capture_offload_cudagraph(self):
"""Capture CUDA graphs for offload decode using ring buffer."""
offload_engine = self.kvcache_manager.offload_engine
num_buffers = offload_engine.num_kv_buffers
Mirrors run_layerwise_offload_prefill() but stores to GPU instead of CPU.
No scatter operations - just contiguous slice assignment.
"""
assert len(seqs) == 1, "GPU-only layer-wise prefill only supports single sequence"
seq = seqs[0]
# Fixed-address tensors for graph capture
input_ids = torch.zeros(1, dtype=torch.int64, device="cuda")
positions = torch.zeros(1, dtype=torch.int64, device="cuda")
slot_mapping = torch.zeros(1, dtype=torch.int32, device="cuda")
context_lens = torch.zeros(1, dtype=torch.int32, device="cuda")
num_layers = len(self.model.model.layers)
total_tokens = len(seq)
# Get contiguous GPU cache
k_cache = self.kvcache_manager.contiguous_k_cache
v_cache = self.kvcache_manager.contiguous_v_cache
# Prepare inputs
input_ids = torch.tensor(seq[:], dtype=torch.int64, device="cuda")
positions = torch.arange(total_tokens, dtype=torch.int64, device="cuda")
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens = torch.tensor([0, total_tokens], dtype=torch.int32, device="cuda")
# Embedding
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
# Layer-by-layer processing (same as offload prefill)
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# QKV projection
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k, v = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
q = q.view(total_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v = v.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms (Qwen3 specific)
if not layer.self_attn.qkv_bias:
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(total_tokens, layer.self_attn.num_heads, layer.self_attn.head_dim)
k = layer.self_attn.k_norm(k.reshape(-1, layer.self_attn.head_dim))
k = k.view(total_tokens, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Store K,V to contiguous GPU cache (same layout - no conversion!)
# This is just slice assignment, not scatter
k_cache[layer_id, :total_tokens] = k
v_cache[layer_id, :total_tokens] = v
# Sparse or Full attention (uses k, v directly)
if self.sparse_prefill_policy is not None:
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# O projection
attn_output = attn_output.view(total_tokens, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Compute logits
logits = self.model.compute_logits(hidden_states[-1:])
# Record prefill length for decode
self.kvcache_manager.contiguous_seq_len = total_tokens
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
return token_ids
self.offload_graphs = {}
self.offload_graph_pool = None
```
### Phase 3: Decode with Contiguous Cache
#### Phase 2: Capture Per-Buffer Graphs
**File**: `nanovllm/engine/model_runner.py`
Since layer processing rotates through ring buffers (`layer_id % num_buffers`), we need graphs for each buffer slot:
```python
@torch.inference_mode()
def run_gpu_only_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Decode using contiguous GPU KV cache.
for buffer_idx in range(num_buffers):
graph = torch.cuda.CUDAGraph()
Similar to offload decode but simpler - all KV already on GPU.
"""
assert len(seqs) == 1
seq = seqs[0]
# Set Attention cache to this buffer slot (fixed address)
for layer in self.model.model.layers:
layer.self_attn.attn.k_cache = offload_engine.layer_k_cache[buffer_idx:buffer_idx+1]
layer.self_attn.attn.v_cache = offload_engine.layer_v_cache[buffer_idx:buffer_idx+1]
num_layers = len(self.model.model.layers)
k_cache = self.kvcache_manager.contiguous_k_cache
v_cache = self.kvcache_manager.contiguous_v_cache
context_len = self.kvcache_manager.contiguous_seq_len
# Set context
set_context(is_prefill=False, slot_mapping=slot_mapping,
context_lens=context_lens, block_tables=None)
# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, device="cuda")
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, device="cuda")
# Warmup
hidden = self.model.model.embed_tokens(input_ids)
residual = None
for layer_id, layer in enumerate(self.model.model.layers):
if layer_id % num_buffers == buffer_idx:
hidden, residual = layer(positions, hidden, residual)
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
# Capture
with torch.cuda.graph(graph, self.offload_graph_pool):
# Same operations
...
# Embedding
hidden_states = self.model.model.embed_tokens(input_ids)
residual = None
for layer_id in range(num_layers):
layer = self.model.model.layers[layer_id]
# Input LayerNorm
if residual is None:
hidden_ln, residual = layer.input_layernorm(hidden_states), hidden_states
else:
hidden_ln, residual = layer.input_layernorm(hidden_states, residual)
# QKV projection
qkv = layer.self_attn.qkv_proj(hidden_ln)
q, k_new, v_new = qkv.split([
layer.self_attn.q_size,
layer.self_attn.kv_size,
layer.self_attn.kv_size
], dim=-1)
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
v_new = v_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# Q/K norms
if not layer.self_attn.qkv_bias:
q = layer.self_attn.q_norm(q.reshape(-1, layer.self_attn.head_dim))
q = q.view(1, layer.self_attn.num_heads, layer.self_attn.head_dim)
k_new = layer.self_attn.k_norm(k_new.reshape(-1, layer.self_attn.head_dim))
k_new = k_new.view(1, layer.self_attn.num_kv_heads, layer.self_attn.head_dim)
# RoPE
q, k_new = layer.self_attn.rotary_emb(positions, q, k_new)
# Get cached K,V and append new token
k_cached = k_cache[layer_id, :context_len]
v_cached = v_cache[layer_id, :context_len]
# Store new K,V to cache
k_cache[layer_id, context_len] = k_new.squeeze(0)
v_cache[layer_id, context_len] = v_new.squeeze(0)
# Full K,V for attention
k_full = k_cache[layer_id, :context_len + 1]
v_full = v_cache[layer_id, :context_len + 1]
# Attention
cu_seqlens_k = torch.tensor([0, context_len + 1], dtype=torch.int32, device="cuda")
attn_output = flash_attn_varlen_func(
q, k_full, v_full,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len + 1,
softmax_scale=layer.self_attn.attn.scale,
causal=False, # Single query, no causal needed
)
# O projection
attn_output = attn_output.view(1, -1)
hidden_states = layer.self_attn.o_proj(attn_output)
# Post-attention LayerNorm + MLP
hidden_states, residual = layer.post_attention_layernorm(hidden_states, residual)
hidden_states = layer.mlp(hidden_states)
# Update context length
self.kvcache_manager.contiguous_seq_len = context_len + 1
# Final norm
hidden_states, _ = self.model.model.norm(hidden_states, residual)
# Compute logits
logits = self.model.compute_logits(hidden_states)
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
return token_ids
self.offload_graphs[buffer_idx] = graph
```
### Phase 4: Decision Logic
#### Phase 3: Use Graphs in Decode
Modify `run_layerwise_offload_decode()` to replay graphs:
```python
def _should_use_contiguous_gpu_mode(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""Check if contiguous GPU mode should be used."""
# Must have contiguous cache allocated
if not hasattr(self.kvcache_manager, 'contiguous_k_cache'):
return False
if self.kvcache_manager.contiguous_k_cache is None:
return False
for layer_id in range(num_layers):
current_buffer = layer_id % num_buffers
# Must NOT be offload mode
if hasattr(self.kvcache_manager, 'offload_engine'):
return False
# Wait for H2D load
offload_engine.wait_buffer_load(current_buffer)
# Single sequence only
if len(seqs) != 1:
return False
# Copy decode buffer to ring buffer (same as current)
...
# For prefill: has blocks (not warmup)
if is_prefill and not seqs[0].block_table:
return False
# Update graph variables
self.offload_graph_vars["positions"][0] = positions[0]
self.offload_graph_vars["slot_mapping"][0] = context_len
self.offload_graph_vars["context_lens"][0] = context_len + 1
return True
# Replay graph instead of eager forward
self.offload_graphs[current_buffer].replay()
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
# Check offload mode (existing)
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'offload_engine'):
...
# Check contiguous GPU mode
if self._should_use_contiguous_gpu_mode(seqs, is_prefill):
if is_prefill:
return self.run_gpu_only_prefill(seqs)
else:
return self.run_gpu_only_decode(seqs)
# Standard PagedAttention path
# Copy new KV to decode buffer (same as current)
...
```
## Architecture Comparison
### Challenges and Considerations
| Aspect | Offload Mode | GPU-only (Proposed) | GPU-only (Current) |
|--------|--------------|---------------------|-------------------|
| Cache location | CPU (contiguous) | GPU (contiguous) | GPU (PagedAttention) |
| Cache layout | `[layers, blocks, block_size, heads, dim]` | `[layers, max_seq_len, heads, dim]` | `[blocks, block_size, heads, dim]` |
| Prefill store | Contiguous slice copy | **Slice assignment (no copy!)** | Scatter (index_copy_) |
| Decode read | H2D ring buffer | Direct GPU access | PagedAttention |
| Challenge | Solution |
|-----------|----------|
| H2D transfers interleaved with compute | H2D happens outside graph, only compute is captured |
| Different layers use different buffers | Capture per-buffer graphs, replay correct one |
| Variable context length | Use `cache_seqlens` parameter (fixed address, variable value) |
| Per-layer buffer rotation | Graph captures single-layer forward, loop in Python |
## Key Points
### Alternative: Full-Decode Graph (More Complex)
1. **No explicit copy_ needed**: Slice assignment `cache[layer, :len] = k` is direct memory write
2. **Same layout as computed K,V**: No format conversion required
3. **Mirrors offload architecture**: Same layer-wise processing pattern
4. **GPU advantage**: No cross-device transfer, faster than offload
Instead of per-layer graphs, capture entire decode step:
1. Complete all H2D loads before graph
2. Single graph covers all layers
3. Better kernel fusion, less CPU overhead
4. More complex to implement (need to handle buffer rotation inside graph)
## Memory Usage
## Implementation Phases
Contiguous GPU cache: `2 * num_layers * max_seq_len * kv_heads * head_dim * dtype_size`
| Phase | Description | Status |
|-------|-------------|--------|
| Phase 0 | Refactor offload decode to use Attention.forward() | ✅ Completed |
| Phase 1 | Implement `capture_offload_cudagraph()` with per-buffer graphs | ⬜ Pending |
| Phase 2 | Modify `run_layerwise_offload_decode()` to use graphs | ⬜ Pending |
| Phase 3 | Test and benchmark | ⬜ Pending |
| Phase 4 | (Optional) Optimize to full-decode graph | ⬜ Future |
For Qwen3-4B with 32K max_seq_len:
- `2 * 28 * 32768 * 8 * 128 * 2 = 3.5GB`
## Architecture After Refactoring
Same as offload mode's CPU cache, but on GPU.
```
┌─────────────────────────────────────────────────────────────────────────────┐
│ Offload Decode Flow (After Refactoring) │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ For each layer: │
│ 1. Wait for H2D load (ring buffer has prefill KV) │
│ 2. Copy decode buffer → ring buffer (at prefill_len offset) │
│ 3. Set Attention.k_cache = ring_buffer[buffer_idx] │
│ 4. Set context (slot_mapping, context_lens, block_tables=None) │
│ 5. layer.forward() → Qwen3Attention.forward() → Attention.forward() │
│ └── store_kvcache() stores new token to ring buffer │
│ └── flash_attn_with_kvcache() computes attention │
│ 6. Copy new token KV: ring buffer → decode buffer │
│ 7. Start next layer H2D load │
│ │
│ Key insight: Now uses standard Attention path, just with ring buffer │
│ as k_cache/v_cache in contiguous format (block_tables=None) │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
```
## Files to Modify
## Files Modified
| File | Changes |
|------|---------|
| `nanovllm/kvcache/gpu_manager.py` | Add contiguous cache allocation |
| `nanovllm/engine/model_runner.py` | Add `run_gpu_only_prefill()`, `run_gpu_only_decode()`, modify `run()` |
| `model_runner.py:46-57` | Conditional CUDA graph capture (skip for offload) |
| `model_runner.py:841-991` | Refactored `run_layerwise_offload_decode()` to use standard `layer.forward()` |
## Expected Performance
## Next Steps
| Metric | Before | After | Improvement |
|--------|--------|-------|-------------|
| GPU-only prefill (32K) | 3383 tok/s | ~5400+ tok/s | ~60%+ |
| Decode | Baseline | Similar | ~0% |
## Status
**Currently in Phase 1** - Ready to implement contiguous GPU cache
1. Implement `capture_offload_cudagraph()` method
2. Modify `run_layerwise_offload_decode()` to optionally use captured graphs
3. Benchmark performance improvement from CUDA graphs
4. Consider full-decode graph optimization for maximum performance