nano-vllm/docs/cuda_graph_offload_guide.md

CUDA Graph Support for CPU Offload Mode

This document describes the CUDA graph implementation for the CPU offload decode path, which substantially improves decode throughput.

Overview

CUDA graphs capture a sequence of GPU operations and replay them with minimal CPU overhead. In offload mode, we capture per-layer graphs for the decode path, achieving roughly a 4x improvement in decode throughput.

Performance Results

| Metric | Eager Mode | CUDA Graph | Improvement |
| --- | --- | --- | --- |
| Decode throughput | ~12 tok/s | ~50 tok/s | 4.2x |
| TPOT (time per output token) | ~80 ms | ~19 ms | 4.2x |
| Prefill throughput | ~8000 tok/s | ~8000 tok/s | Same |
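
As a quick consistency check, TPOT is just the inverse of decode throughput, and the reported numbers line up:

```python
# TPOT in ms is 1000 / throughput in tok/s; compare with the table above.
eager_tpot = 1000 / 12   # ~83 ms (table reports ~80 ms)
graph_tpot = 1000 / 50   # 20 ms  (table reports ~19 ms)
print(round(eager_tpot), round(graph_tpot))
```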

Architecture

Why Standard CUDA Graph Capture Doesn't Work

The standard capture_cudagraph() captures the PagedAttention decode path:

  • Uses block tables for scattered KV cache access
  • Attention.k_cache/v_cache point to PagedAttention buffers

In offload mode, the decode path is different:

  • Uses contiguous ring buffers for KV cache
  • Attention.k_cache/v_cache dynamically point to ring buffer slices
  • H2D transfers interleaved with compute
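
The indexing difference can be sketched in a few lines. The helper names and BLOCK_SIZE below are illustrative, not the nano-vllm API; only the addressing logic matters:

```python
# Sketch: how the two decode paths locate the KV slot for a token.
BLOCK_SIZE = 16  # assumed PagedAttention block size for illustration

def paged_slot(position, block_table):
    # Block table maps logical blocks to scattered physical blocks.
    block = block_table[position // BLOCK_SIZE]
    return block * BLOCK_SIZE + position % BLOCK_SIZE

def contiguous_slot(position):
    # Contiguous ring buffer: the slot is just the token position.
    return position

block_table = [7, 2, 5]  # logical block 0 lives in physical block 7, etc.
print(paged_slot(17, block_table))  # 33 (physical block 2, offset 1)
print(contiguous_slot(17))          # 17
```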

Per-Layer Graph Design

We capture one CUDA graph per transformer layer:

┌─────────────────────────────────────────────────────────────┐
│                    Offload Decode with CUDA Graphs          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Initialization:                                            │
│    capture_offload_cudagraph() captures 36 layer graphs     │
│    Each graph: layer.forward() with ring buffer as cache    │
│                                                             │
│  Decode Step:                                               │
│    1. Embedding (eager, outside graph)                      │
│    2. For each layer:                                       │
│       a. Wait for H2D load (outside graph)                  │
│       b. Copy decode KV to ring buffer (outside graph)      │
│       c. Set Attention.k_cache = ring_buffer[buffer_idx]    │
│       d. Set context (slot_mapping, context_lens)           │
│       e. graph.replay() - layer forward                     │
│       f. synchronize()                                      │
│       g. Copy layer_outputs -> hidden_states                │
│       h. Copy new KV to decode buffer (outside graph)       │
│       i. Start next layer H2D load                          │
│    3. Final norm and logits (eager)                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘
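
The decode step above can be sketched as pure control flow, with the GPU work replaced by log entries. Every name here is a stand-in for illustration, not the real API; only the per-layer ordering of H2D waits, cache pointer swaps, graph replays, and state copies mirrors the diagram:

```python
# Control-flow sketch of one offload decode step (hypothetical names).
NUM_LAYERS = 4    # 36 in the real model; small here for readability
NUM_BUFFERS = 2   # ring buffer slots

log = []
log.append("embedding (eager)")
for layer_id in range(NUM_LAYERS):
    buf = layer_id % NUM_BUFFERS
    log.append(f"wait H2D into buffer {buf}")              # step 2a
    log.append(f"copy decode KV to buffer {buf}")          # step 2b
    log.append(f"k_cache -> ring_buffer[{buf}]")           # step 2c
    log.append("set_context(slot_mapping, context_lens)")  # step 2d
    log.append(f"replay graph[{layer_id}]")                # step 2e
    log.append("synchronize")                              # step 2f
    log.append("layer_outputs -> hidden_states")           # step 2g
    log.append("new KV -> decode buffer")                  # step 2h
    if layer_id + 1 < NUM_LAYERS:
        log.append(f"start H2D for layer {layer_id + 1}")  # step 2i
log.append("final norm + logits (eager)")
print(len(log))
```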

Ring Buffer Mapping

Each layer maps to a ring buffer slot:

buffer_idx = layer_id % num_kv_buffers

With 4 buffers and 36 layers:

  • Layer 0, 4, 8, ... use buffer 0
  • Layer 1, 5, 9, ... use buffer 1
  • Layer 2, 6, 10, ... use buffer 2
  • Layer 3, 7, 11, ... use buffer 3
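
The round-robin mapping above can be computed directly:

```python
# Each of the 4 ring-buffer slots serves 9 of the 36 layers.
NUM_KV_BUFFERS = 4
NUM_LAYERS = 36

layers_per_buffer = {
    b: [l for l in range(NUM_LAYERS) if l % NUM_KV_BUFFERS == b]
    for b in range(NUM_KV_BUFFERS)
}
print(layers_per_buffer[0][:3])   # [0, 4, 8]
print(len(layers_per_buffer[3]))  # 9
```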

Implementation Details

Graph Capture (capture_offload_cudagraph)

Location: model_runner.py:1075-1164

def capture_offload_cudagraph(self):
    # Fixed-address tensors for graph I/O
    hidden_states = torch.randn(1, hidden_size, ...)
    residual = torch.randn(1, hidden_size, ...)
    layer_outputs = torch.zeros(1, hidden_size, ...)
    layer_residual = torch.zeros(1, hidden_size, ...)

    for layer_id in range(num_layers):
        buffer_idx = layer_id % num_buffers

        # Set Attention cache to ring buffer slice
        attn_module.k_cache = ring_buffer[buffer_idx:buffer_idx+1]
        attn_module.v_cache = ring_buffer[buffer_idx:buffer_idx+1]

        # Set context for contiguous mode
        set_context(is_prefill=False, slot_mapping=...,
                    context_lens=..., block_tables=None)

        # Warmup pass, then capture into a per-layer graph
        layer(positions, hidden_states, residual)
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool):
            out_h, out_r = layer(positions, hidden_states, residual)
            layer_outputs.copy_(out_h)
            layer_residual.copy_(out_r)

        # Propagate state for next layer's capture
        hidden_states.copy_(layer_outputs)
        residual.copy_(layer_residual)

Key design decisions:

  1. Fixed-address tensors: Graph inputs/outputs use pre-allocated tensors
  2. Include copy in graph: layer_outputs.copy_(out_h) is captured
  3. State propagation: Update hidden_states between layer captures
  4. Random initialization: Use randn instead of zeros for realistic distributions
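
Decisions 1-3 can be illustrated without CUDA at all. In this pure-Python analogue, a "captured graph" is a closure that only ever reads and writes fixed buffers, so new inputs must be copied into those buffers and outputs copied back between replays, matching the dance the real code performs:

```python
# Pure-Python analogue of fixed-address graph I/O and state propagation.
hidden = [0.0]      # fixed-address "input" buffer
layer_out = [0.0]   # fixed-address "output" buffer

def capture_layer(scale):
    # The closure captures the buffer objects themselves, mirroring how a
    # CUDA graph bakes in device pointers at capture time.
    def replay():
        layer_out[0] = hidden[0] * scale
    return replay

graphs = [capture_layer(s) for s in (2.0, 3.0)]

hidden[0] = 1.0                  # "copy" fresh input into the fixed buffer
for g in graphs:
    g()                          # replay: reads hidden, writes layer_out
    hidden[0] = layer_out[0]     # manual state propagation between layers

print(hidden[0])  # 6.0
```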

Graph Replay (run_layerwise_offload_decode)

Location: model_runner.py:844-1031

use_cuda_graph = not self.enforce_eager and hasattr(self, 'offload_graphs')

if use_cuda_graph:
    # Use fixed-address tensors
    graph_vars["positions"][0] = len(seq) - 1
    graph_vars["slot_mapping"][0] = context_len
    graph_vars["context_lens"][0] = context_len + 1
    graph_vars["hidden_states"].copy_(embedding)
    graph_vars["residual"].zero_()

for layer_id in range(num_layers):
    # H2D and buffer setup (outside graph)
    offload_engine.wait_buffer_load(current_buffer)
    attn_module.k_cache = ring_buffer[current_buffer:current_buffer+1]
    set_context(...)

    if use_cuda_graph:
        # Replay graph
        self.offload_graphs[layer_id].replay()
        torch.cuda.current_stream().synchronize()

        # Copy outputs to inputs for next layer
        if layer_id < num_layers - 1:
            graph_vars["hidden_states"].copy_(graph_vars["layer_outputs"])
            graph_vars["residual"].copy_(graph_vars["layer_residual"])
    else:
        # Eager execution
        hidden_states, residual = layer(positions, hidden_states, residual)

Key points:

  1. Synchronization required: synchronize() after each replay ensures the captured output copy has finished before the host-side state copy
  2. Manual state propagation: layer_outputs must be copied into hidden_states between replays
  3. H2D outside graph: ring buffer loads complete before each graph replay

Limitations and Future Work

Current Limitations

  1. Per-layer sync overhead: Each layer requires synchronization
  2. No kernel fusion across layers: Each layer is a separate graph
  3. Fixed batch size: Only supports batch_size=1 for offload

Future Optimization: Full-Decode Graph

Potential improvement: capture the entire decode step as a single graph

  • Complete all H2D loads before graph
  • Single graph covers all 36 layers
  • Better kernel fusion, less CPU overhead
  • More complex to implement (handle buffer rotation inside graph)

Testing

Run needle test with CUDA graph:

PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle.py \
    --input-len 32768 \
    --enable-offload \
    --use-cuda-graph

Run benchmark:

PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python bench_offload.py \
    --input-len 16384 \
    --bench-all

Files Modified

| File | Changes |
| --- | --- |
| model_runner.py:46-50 | Call capture_offload_cudagraph() for offload mode |
| model_runner.py:69-73 | Clean up offload graph resources in exit() |
| model_runner.py:844-1031 | Add CUDA graph support to run_layerwise_offload_decode() |
| model_runner.py:1075-1164 | New capture_offload_cudagraph() method |
| tests/test_needle.py | Add --use-cuda-graph flag |