# CUDA Graph Support for CPU Offload Mode

This document describes the CUDA graph implementation for the CPU offload decode path, which substantially improves decode throughput.

## Overview

CUDA graphs capture a sequence of GPU operations and replay them with minimal CPU overhead. In offload mode, we capture one graph per transformer layer for the decode path, achieving roughly a **4x** improvement in decode throughput.

## Performance Results

| Metric | Eager Mode | CUDA Graph | Improvement |
|--------|------------|------------|-------------|
| Decode Throughput | ~12 tok/s | ~50 tok/s | **4.2x** |
| TPOT (time per output token) | ~80 ms | ~19 ms | **4.2x** |
| Prefill Throughput | ~8000 tok/s | ~8000 tok/s | Same |
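
As a sanity check, the throughput and TPOT figures in the table are mutually consistent, since decode throughput is the reciprocal of TPOT:

```python
# Numbers taken from the table above.
tpot_eager_s = 0.080   # ~80 ms per output token in eager mode
tpot_graph_s = 0.019   # ~19 ms per output token with CUDA graphs

print(round(1 / tpot_eager_s, 1))             # 12.5 tok/s (~12 in the table)
print(round(1 / tpot_graph_s, 1))             # 52.6 tok/s (~50 in the table)
print(round(tpot_eager_s / tpot_graph_s, 1))  # 4.2x speedup
```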
## Architecture

### Why Standard CUDA Graph Capture Doesn't Work

The standard `capture_cudagraph()` captures the PagedAttention decode path:
- Uses block tables for scattered KV cache access
- `Attention.k_cache/v_cache` point to PagedAttention buffers

In offload mode, the decode path is different:
- Uses contiguous ring buffers for the KV cache
- `Attention.k_cache/v_cache` dynamically point to ring buffer slices
- H2D transfers are interleaved with compute
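
The difference between the two access patterns can be sketched in plain Python (toy data and hypothetical names; the real buffers are GPU tensors):

```python
# PagedAttention decode: KV blocks are gathered through a block table,
# so the addresses touched depend on runtime block assignments (scattered).
kv_block_pool = ["blk0", "blk1", "blk2", "blk3"]
block_table = [2, 0, 3]                      # logical position -> physical block
paged_kv = [kv_block_pool[b] for b in block_table]

# Offload decode: each layer's KV lives in one contiguous ring-buffer slice,
# selected by a simple index -- a fixed-shape access that a graph can capture.
ring_buffer = ["buf0", "buf1", "buf2", "buf3"]
buffer_idx = 1
contiguous_kv = ring_buffer[buffer_idx:buffer_idx + 1]
```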

### Per-Layer Graph Design

We capture one CUDA graph per transformer layer:

```
┌─────────────────────────────────────────────────────────────┐
│               Offload Decode with CUDA Graphs               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Initialization:                                            │
│    capture_offload_cudagraph() captures 36 layer graphs     │
│    Each graph: layer.forward() with ring buffer as cache    │
│                                                             │
│  Decode Step:                                               │
│    1. Embedding (eager, outside graph)                      │
│    2. For each layer:                                       │
│       a. Wait for H2D load (outside graph)                  │
│       b. Copy decode KV to ring buffer (outside graph)      │
│       c. Set Attention.k_cache = ring_buffer[buffer_idx]    │
│       d. Set context (slot_mapping, context_lens)           │
│       e. graph.replay() - layer forward                     │
│       f. synchronize()                                      │
│       g. Copy layer_outputs -> hidden_states                │
│       h. Copy new KV to decode buffer (outside graph)       │
│       i. Start next layer H2D load                          │
│    3. Final norm and logits (eager)                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```

### Ring Buffer Mapping

Each layer maps to a ring buffer slot:

```python
buffer_idx = layer_id % num_kv_buffers
```

With 4 buffers and 36 layers:

- Layers 0, 4, 8, ... use buffer 0
- Layers 1, 5, 9, ... use buffer 1
- Layers 2, 6, 10, ... use buffer 2
- Layers 3, 7, 11, ... use buffer 3
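
The mapping above follows directly from the modulo formula:

```python
num_kv_buffers = 4
num_layers = 36

# Group layers by the ring-buffer slot they map to.
assignment = {b: [] for b in range(num_kv_buffers)}
for layer_id in range(num_layers):
    assignment[layer_id % num_kv_buffers].append(layer_id)

print(assignment[0][:3])  # [0, 4, 8]
print(assignment[3][:3])  # [3, 7, 11]
```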

## Implementation Details

### Graph Capture (`capture_offload_cudagraph`)

Location: `model_runner.py:1075-1164`

```python
def capture_offload_cudagraph(self):
    # Fixed-address tensors for graph I/O
    hidden_states = torch.randn(1, hidden_size, ...)
    residual = torch.randn(1, hidden_size, ...)
    layer_outputs = torch.zeros(1, hidden_size, ...)
    layer_residual = torch.zeros(1, hidden_size, ...)

    for layer_id in range(num_layers):
        buffer_idx = layer_id % num_buffers

        # Set Attention cache to ring buffer slice
        attn_module.k_cache = ring_buffer[buffer_idx:buffer_idx+1]
        attn_module.v_cache = ring_buffer[buffer_idx:buffer_idx+1]

        # Set context for contiguous mode
        set_context(is_prefill=False, slot_mapping=...,
                    context_lens=..., block_tables=None)

        # Warmup and capture
        with torch.cuda.graph(graph, pool):
            out_h, out_r = layer(positions, hidden_states, residual)
            layer_outputs.copy_(out_h)
            layer_residual.copy_(out_r)

        # Propagate state for next layer's capture
        hidden_states.copy_(layer_outputs)
        residual.copy_(layer_residual)
```

Key design decisions:

1. **Fixed-address tensors**: graph inputs and outputs use pre-allocated tensors, since replay always reads and writes the addresses recorded at capture time.
2. **Copies inside the graph**: `layer_outputs.copy_(out_h)` is captured as part of each layer's graph.
3. **State propagation**: `hidden_states` and `residual` are updated between layer captures.
4. **Random initialization**: `randn` is used instead of zeros so activations during capture have realistic distributions.
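
The fixed-address requirement can be illustrated with a toy replay in plain Python (no CUDA involved): a captured graph keeps references to the buffers it was recorded with, so inputs must be updated in place (as with `Tensor.copy_()`), never rebound to new objects.

```python
# Toy stand-in for a captured graph: a closure over fixed buffers.
graph_input = [0.0, 0.0]
graph_output = [0.0, 0.0]

def replay():
    # Always reads graph_input's storage and writes graph_output's storage,
    # just as a CUDA graph replays against the addresses it captured.
    for i, x in enumerate(graph_input):
        graph_output[i] = 2.0 * x

graph_input[:] = [1.0, 3.0]   # in-place update: visible to replay()
replay()
print(graph_output)           # [2.0, 6.0]
# Rebinding (graph_input = [...]) would NOT be visible to replay().
```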

### Graph Replay (`run_layerwise_offload_decode`)

Location: `model_runner.py:844-1031`

```python
use_cuda_graph = not self.enforce_eager and hasattr(self, 'offload_graphs')

if use_cuda_graph:
    # Use fixed-address tensors
    graph_vars["positions"][0] = len(seq) - 1
    graph_vars["slot_mapping"][0] = context_len
    graph_vars["context_lens"][0] = context_len + 1
    graph_vars["hidden_states"].copy_(embedding)
    graph_vars["residual"].zero_()

for layer_id in range(num_layers):
    # H2D and buffer setup (outside graph)
    offload_engine.wait_buffer_load(current_buffer)
    attn_module.k_cache = ring_buffer[current_buffer:current_buffer+1]
    set_context(...)

    if use_cuda_graph:
        # Replay graph
        self.offload_graphs[layer_id].replay()
        torch.cuda.current_stream().synchronize()

        # Copy outputs to inputs for next layer
        if layer_id < num_layers - 1:
            graph_vars["hidden_states"].copy_(graph_vars["layer_outputs"])
            graph_vars["residual"].copy_(graph_vars["layer_residual"])
    else:
        # Eager execution
        hidden_states, residual = layer(positions, hidden_states, residual)
```

Key points:

1. **Synchronization required**: the stream is synchronized after each graph replay.
2. **Manual state propagation**: `layer_outputs` is copied into `hidden_states` between replays.
3. **H2D outside the graph**: ring buffer loads happen before each graph replay.
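
The manual state propagation can be mimicked with toy per-layer "graphs" (plain Python; the scale factors are illustrative only). Each graph reads and writes the same fixed slots, so the output slot must be copied back into the input slot between layers:

```python
hidden = [1.0]        # fixed input slot (like graph_vars["hidden_states"])
layer_out = [0.0]     # fixed output slot (like graph_vars["layer_outputs"])

def make_graph(scale):
    def replay():
        # Each "layer" reads the shared input slot, writes the output slot.
        layer_out[0] = hidden[0] * scale
    return replay

graphs = [make_graph(s) for s in (2.0, 3.0, 5.0)]
for replay in graphs:
    replay()
    hidden[0] = layer_out[0]   # analogous to hidden_states.copy_(layer_outputs)

print(hidden[0])  # 30.0 (= 1.0 * 2 * 3 * 5)
```

Skipping the copy-back would make every "layer" see the original input, which is exactly the bug the explicit `copy_()` calls prevent.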

## Limitations and Future Work

### Current Limitations

1. **Per-layer sync overhead**: each layer requires its own synchronization.
2. **No kernel fusion across layers**: each layer is a separate graph.
3. **Fixed batch size**: offload mode only supports batch_size=1.

### Future Optimization: Full-Decode Graph

A potential improvement is to capture the entire decode step as a single graph:

- Complete all H2D loads before launching the graph
- A single graph covers all 36 layers
- Better kernel fusion and lower CPU overhead
- More complex to implement (buffer rotation must be handled inside the graph)

## Testing

Run the needle test with CUDA graphs enabled:

```bash
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle.py \
    --input-len 32768 \
    --enable-offload \
    --use-cuda-graph
```

Run the benchmark:

```bash
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python bench_offload.py \
    --input-len 16384 \
    --bench-all
```

## Files Modified

| File | Changes |
|------|---------|
| `model_runner.py:46-50` | Call `capture_offload_cudagraph()` for offload mode |
| `model_runner.py:69-73` | Clean up offload graph resources in `exit()` |
| `model_runner.py:844-1031` | Add CUDA graph support to `run_layerwise_offload_decode()` |
| `model_runner.py:1075-1164` | New `capture_offload_cudagraph()` method |
| `tests/test_needle.py` | Add `--use-cuda-graph` flag |
|