# CUDA Graph Support for CPU Offload Mode
This document describes the CUDA graph implementation for the CPU offload decode path, which provides significant performance improvements for decode throughput.
## Overview
CUDA graphs capture a sequence of GPU operations and replay them with minimal CPU overhead. In offload mode, we capture per-layer graphs for the decode path, achieving roughly a 4x improvement in decode throughput.
## Performance Results
| Metric | Eager Mode | CUDA Graph | Improvement |
|---|---|---|---|
| Decode Throughput | ~12 tok/s | ~50 tok/s | 4.2x |
| TPOT (Time per output token) | ~80ms | ~19ms | 4.2x |
| Prefill Throughput | ~8000 tok/s | ~8000 tok/s | Same |
## Architecture
### Why Standard CUDA Graph Capture Doesn't Work

The standard `capture_cudagraph()` captures the PagedAttention decode path:
- Uses block tables for scattered KV cache access
- `Attention.k_cache`/`v_cache` point to PagedAttention buffers

In offload mode, the decode path is different:
- Uses contiguous ring buffers for KV cache
- `Attention.k_cache`/`v_cache` dynamically point to ring buffer slices
- H2D transfers are interleaved with compute
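To make the contrast concrete, here is a toy, pure-Python sketch of the two indexing schemes (illustrative names and sizes only, not the real kernels or memory layouts): the paged path gathers KV entries through a block table, while the offload path reads one contiguous ring-buffer slice.

```python
# Toy sketch: paged (block-table) vs contiguous (ring-buffer) KV access.
# All names and sizes here are illustrative, not the real implementation.

BLOCK_SIZE = 4

# Paged path: logical blocks are scattered; a block table maps
# logical block index -> physical block id.
kv_cache = {7: ["k0", "k1", "k2", "k3"], 2: ["k4", "k5", "k6", "k7"]}
block_table = [7, 2]

def paged_read(context_len):
    # Gather keys position by position through the block table.
    keys = []
    for pos in range(context_len):
        block = block_table[pos // BLOCK_SIZE]
        keys.append(kv_cache[block][pos % BLOCK_SIZE])
    return keys

# Offload path: the layer's KV sits in one dense ring-buffer slice.
ring_buffer = ["k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7"]

def contiguous_read(context_len):
    return ring_buffer[:context_len]

assert paged_read(6) == contiguous_read(6)
```

Because the two paths index memory so differently, a graph captured against the paged kernels cannot simply be replayed against the ring buffers.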
### Per-Layer Graph Design
We capture one CUDA graph per transformer layer:
```
┌─────────────────────────────────────────────────────────────┐
│               Offload Decode with CUDA Graphs               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Initialization:                                            │
│    capture_offload_cudagraph() captures 36 layer graphs     │
│    Each graph: layer.forward() with ring buffer as cache    │
│                                                             │
│  Decode Step:                                               │
│    1. Embedding (eager, outside graph)                      │
│    2. For each layer:                                       │
│       a. Wait for H2D load (outside graph)                  │
│       b. Copy decode KV to ring buffer (outside graph)      │
│       c. Set Attention.k_cache = ring_buffer[buffer_idx]    │
│       d. Set context (slot_mapping, context_lens)           │
│       e. graph.replay() - layer forward                     │
│       f. synchronize()                                      │
│       g. Copy layer_outputs -> hidden_states                │
│       h. Copy new KV to decode buffer (outside graph)       │
│       i. Start next layer H2D load                          │
│    3. Final norm and logits (eager)                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```
### Ring Buffer Mapping

Each layer maps to a ring buffer slot:

```python
buffer_idx = layer_id % num_kv_buffers
```
With 4 buffers and 36 layers:
- Layer 0, 4, 8, ... use buffer 0
- Layer 1, 5, 9, ... use buffer 1
- Layer 2, 6, 10, ... use buffer 2
- Layer 3, 7, 11, ... use buffer 3
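The mapping above can be sketched directly with the document's numbers (4 buffers, 36 layers):

```python
# Layer -> ring-buffer-slot mapping with the document's sizes.
NUM_KV_BUFFERS = 4
NUM_LAYERS = 36

mapping = {layer_id: layer_id % NUM_KV_BUFFERS for layer_id in range(NUM_LAYERS)}

assert mapping[0] == mapping[4] == mapping[8] == 0
assert mapping[1] == mapping[5] == mapping[9] == 1
assert mapping[3] == mapping[7] == mapping[11] == 3

# A slot is reused NUM_KV_BUFFERS layers later, which is why the decode loop
# waits on a buffer's H2D load before a layer can read its slot.
```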
## Implementation Details
### Graph Capture (`capture_offload_cudagraph`)

Location: `model_runner.py:1075-1164`
```python
def capture_offload_cudagraph(self):
    # Fixed-address tensors for graph I/O
    hidden_states = torch.randn(1, hidden_size, ...)
    residual = torch.randn(1, hidden_size, ...)
    layer_outputs = torch.zeros(1, hidden_size, ...)
    layer_residual = torch.zeros(1, hidden_size, ...)

    for layer_id in range(num_layers):
        buffer_idx = layer_id % num_buffers

        # Set Attention cache to ring buffer slice
        attn_module.k_cache = ring_buffer[buffer_idx:buffer_idx+1]
        attn_module.v_cache = ring_buffer[buffer_idx:buffer_idx+1]

        # Set context for contiguous mode
        set_context(is_prefill=False, slot_mapping=...,
                    context_lens=..., block_tables=None)

        # Warmup and capture
        with torch.cuda.graph(graph, pool):
            out_h, out_r = layer(positions, hidden_states, residual)
            layer_outputs.copy_(out_h)
            layer_residual.copy_(out_r)

        # Propagate state for next layer's capture
        hidden_states.copy_(layer_outputs)
        residual.copy_(layer_residual)
```
Key design decisions:
- Fixed-address tensors: graph inputs/outputs use pre-allocated tensors
- Copies included in the graph: `layer_outputs.copy_(out_h)` is captured
- State propagation: `hidden_states` is updated between layer captures
- Random initialization: `randn` instead of zeros gives realistic distributions
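The fixed-address requirement exists because graph capture bakes device pointers into the recorded kernels, and replay re-executes those kernels against the same addresses. A pure-Python analogy (lists standing in for tensors, `id()` standing in for a device pointer) of why `copy_` into a pre-allocated tensor works while rebinding the Python name would not:

```python
# Analogy only: a Python list mutated in place keeps its identity, the way a
# tensor updated via copy_() keeps the device address a captured graph uses.

hidden_states = [0.0] * 4          # stands in for the pre-allocated tensor
captured_addr = id(hidden_states)  # stands in for the pointer baked at capture

hidden_states[:] = [1.0, 2.0, 3.0, 4.0]    # like tensor.copy_(): same storage
assert id(hidden_states) == captured_addr  # the "graph" sees the new values

hidden_states = [5.0, 6.0, 7.0, 8.0]       # rebinding: fresh storage
assert id(hidden_states) != captured_addr  # the "graph" would read stale data
```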
### Graph Replay (`run_layerwise_offload_decode`)

Location: `model_runner.py:844-1031`
```python
use_cuda_graph = not self.enforce_eager and hasattr(self, 'offload_graphs')

if use_cuda_graph:
    # Use fixed-address tensors
    graph_vars["positions"][0] = len(seq) - 1
    graph_vars["slot_mapping"][0] = context_len
    graph_vars["context_lens"][0] = context_len + 1
    graph_vars["hidden_states"].copy_(embedding)
    graph_vars["residual"].zero_()

for layer_id in range(num_layers):
    # H2D and buffer setup (outside graph)
    offload_engine.wait_buffer_load(current_buffer)
    attn_module.k_cache = ring_buffer[current_buffer:current_buffer+1]
    set_context(...)

    if use_cuda_graph:
        # Replay graph
        self.offload_graphs[layer_id].replay()
        torch.cuda.current_stream().synchronize()

        # Copy outputs to inputs for next layer
        if layer_id < num_layers - 1:
            graph_vars["hidden_states"].copy_(graph_vars["layer_outputs"])
            graph_vars["residual"].copy_(graph_vars["layer_residual"])
    else:
        # Eager execution
        hidden_states, residual = layer(positions, hidden_states, residual)
```
Key points:
- Synchronization required: `synchronize()` after each graph replay
- Manual state propagation: copy `layer_outputs` to `hidden_states` between replays
- H2D outside graph: ring buffer loads happen before graph replay
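The replay loop's manual state propagation can be sketched in pure Python (illustrative names; `replay` stands in for `offload_graphs[layer_id].replay()`): each "graph" reads a fixed input slot and writes a fixed output slot, and the host copies output back to input before the next layer.

```python
# Pure-Python sketch of manual state propagation between per-layer replays.
NUM_LAYERS = 3
graph_vars = {"hidden_states": [1.0], "layer_outputs": [0.0]}

def replay(layer_id):
    # Stands in for offload_graphs[layer_id].replay(): reads the fixed input
    # slot, writes the fixed output slot (here, a toy "+1" per layer).
    graph_vars["layer_outputs"][0] = graph_vars["hidden_states"][0] + 1.0

for layer_id in range(NUM_LAYERS):
    replay(layer_id)  # graph replay (followed by synchronize() in the real code)
    if layer_id < NUM_LAYERS - 1:
        # Manual state propagation: copy outputs back into the input slot
        graph_vars["hidden_states"][0] = graph_vars["layer_outputs"][0]

assert graph_vars["layer_outputs"][0] == 4.0  # 1.0 + one increment per layer
```

Without the host-side copy between replays, every layer's graph would keep reading the original embedding rather than the previous layer's output.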
## Limitations and Future Work
### Current Limitations

- Per-layer sync overhead: each layer requires a host synchronization after replay
- No kernel fusion across layers: each layer is a separate graph
- Fixed batch size: the offload path only supports `batch_size=1`
### Future Optimization: Full-Decode Graph

A potential improvement is to capture the entire decode step as a single graph:
- Complete all H2D loads before the graph
- One graph covers all 36 layers
- Better kernel fusion, less CPU overhead
- More complex to implement (buffer rotation must be handled inside the graph)
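One reason the single-graph design is harder: with fewer buffers than layers, slot reuse creates cross-layer ordering constraints that must be encoded inside the graph rather than as CPU-side waits. A toy sketch (hypothetical helper name, not real code) of which earlier layer each H2D load must be ordered after:

```python
# Toy sketch of the slot-reuse dependency a full-decode graph must express.
# load_depends_on is a hypothetical helper, not part of the implementation.
NUM_BUFFERS = 4
NUM_LAYERS = 36

def load_depends_on(layer_id):
    # Layer layer_id reuses the slot last read by layer_id - NUM_BUFFERS,
    # so its load must be ordered after that layer's compute.
    prev = layer_id - NUM_BUFFERS
    return prev if prev >= 0 else None  # the first NUM_BUFFERS loads are free

assert load_depends_on(0) is None
assert load_depends_on(4) == 0    # buffer 0: layer 4's load waits on layer 0
assert load_depends_on(35) == 31  # buffer 3: layer 35's load waits on layer 31
```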
## Testing

Run the needle test with CUDA graphs enabled:

```shell
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle.py \
    --input-len 32768 \
    --enable-offload \
    --use-cuda-graph
```

Run the benchmark:

```shell
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python bench_offload.py \
    --input-len 16384 \
    --bench-all
```
## Files Modified

| File | Changes |
|---|---|
| `model_runner.py:46-50` | Call `capture_offload_cudagraph()` for offload mode |
| `model_runner.py:69-73` | Clean up offload graph resources in `exit()` |
| `model_runner.py:844-1031` | Add CUDA graph support to `run_layerwise_offload_decode()` |
| `model_runner.py:1075-1164` | New `capture_offload_cudagraph()` method |
| `tests/test_needle.py` | Add `--use-cuda-graph` flag |