# Files
# nano-vllm/nanovllm/layers/graphed_layers.py
# Zijie Tian 0437311068 feat: add Phase 5 CUDA Graph optimization for chunked prefill
# Implement extended CUDA Graph coverage for the CPU offload path:
# - Add graphed_layers.py with an N+2 graph architecture (EmbedGraph, FirstGraph, InterGraphs, LastGraph)
# - Support both prefill (seq_len=chunk_size) and decode (seq_len=1) graph modes
# - Extend graph coverage to ~70-80%, including qkv_proj, rotary, and o_proj
# - Only the attention core remains in eager mode, for dynamic offload

# Performance: prefill throughput improved by ~5.6% (3782 -> 3995 tok/s at 32K).

# Also adds:
# - an --enforce-eager flag to bench_offload.py for comparison
# - documentation of the offload-mode constraints in CLAUDE.md

# Generated with [Claude Code](https://claude.ai/code)
# via [Happy](https://happy.engineering)
#
# Co-Authored-By: Claude <noreply@anthropic.com>
# Co-Authored-By: Happy <yesreply@happy.engineering>
# 2026-01-27 07:38:40 +08:00

# 573 lines
# 19 KiB
# Python

"""
CUDA Graph wrapped layers for offload optimization.
This module provides Graph-wrapped versions of non-attention layers
to reduce kernel launch overhead in CPU offload path.
Phase 5 Design:
- Supports both Prefill (seq_len=chunk_size) and Decode (seq_len=1)
- Extended coverage: embed, input_norm, qkv_proj, rotary, o_proj, post_norm, mlp, final_norm
- Only attention core (attn.forward) remains in eager mode
Graph Structure (N layers):
- EmbedGraph: embed_tokens
- FirstGraph: input_norm → qkv_proj → rotary
- InterGraph[i]: o_proj → post_norm → mlp → input_norm → qkv_proj → rotary (N-1 graphs)
- LastGraph: o_proj → post_norm → mlp → final_norm
Total: N+2 graphs
"""
import torch
from torch import nn
from typing import Optional, Tuple
class EmbedGraph(nn.Module):
    """
    CUDA Graph wrapper around the token-embedding lookup.

    Input:  input_ids     [seq_len]
    Output: hidden_states [seq_len, hidden_size]

    The captured graph is replayed only when ``use_graph`` is True, a graph
    has been captured, and the batch has exactly ``seq_len`` tokens; in every
    other case the embedding runs eagerly.
    """

    def __init__(
        self,
        embed_tokens: nn.Module,
        seq_len: int,
        hidden_size: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.embed_tokens = embed_tokens
        self.seq_len, self.hidden_size, self.dtype = seq_len, hidden_size, dtype
        # Populated by capture_graph(): the graph and its static in/out buffers.
        self.graph: Optional[torch.cuda.CUDAGraph] = None
        self.ids_in: Optional[torch.Tensor] = None
        self.h_out: Optional[torch.Tensor] = None

    def _compute(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Eager embedding lookup; this is also the work recorded into the graph."""
        return self.embed_tokens(input_ids)

    def _step_static(self) -> None:
        """Embed the static input buffer and write the result into the static output buffer."""
        self.h_out.copy_(self._compute(self.ids_in))

    def capture_graph(self, graph_pool=None):
        """
        Allocate static buffers, warm up, and capture the CUDA Graph.

        Args:
            graph_pool: Optional memory-pool handle shared with other graphs.

        Returns:
            ``graph_pool`` when one was passed in, otherwise this graph's own
            pool handle, so the caller can hand it to the next capture.
        """
        # Static buffers are created outside inference_mode so forward() can
        # copy_ into them even when it is not itself running in inference_mode.
        self.ids_in = torch.zeros(self.seq_len, dtype=torch.long, device="cuda")
        self.h_out = torch.zeros(
            self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda"
        )

        with torch.inference_mode():
            # A few eager passes before capture.
            for _ in range(3):
                self._step_static()
            torch.cuda.synchronize()

            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph, pool=graph_pool):
                self._step_static()

        if graph_pool is None:
            return self.graph.pool()
        return graph_pool

    def forward(self, input_ids: torch.Tensor, use_graph: bool = False) -> torch.Tensor:
        """
        Embed ``input_ids``, replaying the captured graph when possible.

        Args:
            input_ids: Token ids; the graph path requires ``seq_len`` of them.
            use_graph: Request graph replay instead of eager execution.

        Returns:
            The embedded hidden states. On the graph path this is a clone of
            the static output buffer, so later replays do not overwrite it.
        """
        replayable = (
            use_graph
            and self.graph is not None
            and input_ids.shape[0] == self.seq_len
        )
        if not replayable:
            return self._compute(input_ids)

        self.ids_in.copy_(input_ids)
        self.graph.replay()
        return self.h_out.clone()
class FirstGraph(nn.Module):
    """
    Graph wrapper for the first layer's pre-attention work:
    input_norm → qkv_proj → split → reshape → rotary

    Input:  hidden_states [seq_len, hidden_size], positions [seq_len]
    Output: q [seq_len, num_heads, head_dim], k [seq_len, num_kv_heads, head_dim],
            v [seq_len, num_kv_heads, head_dim], residual [seq_len, hidden_size]

    The captured graph is replayed only when ``use_graph`` is True, a graph
    has been captured, and ``hidden_states`` has exactly ``seq_len`` rows;
    otherwise the same computation runs eagerly.

    NOTE(review): only input_norm, qkv_proj and rotary_emb are applied here. If
    the model's attention also normalizes q/k before rotary, that step is not
    reproduced — confirm against the model's attention forward.
    """

    def __init__(
        self,
        input_norm: nn.Module,
        qkv_proj: nn.Module,
        rotary_emb: nn.Module,
        # Shape parameters
        seq_len: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.input_norm = input_norm
        self.qkv_proj = qkv_proj
        self.rotary_emb = rotary_emb
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype
        # Split sizes of the fused QKV projection output.
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        # Graph state and static I/O buffers. All are created by capture_graph();
        # declared here (as EmbedGraph does) so they always exist.
        self.graph: Optional[torch.cuda.CUDAGraph] = None
        self.h_in: Optional[torch.Tensor] = None
        self.pos_in: Optional[torch.Tensor] = None
        self.q_out: Optional[torch.Tensor] = None
        self.k_out: Optional[torch.Tensor] = None
        self.v_out: Optional[torch.Tensor] = None
        self.r_out: Optional[torch.Tensor] = None

    def _compute(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        First layer computation:
        1. input_layernorm (residual = hidden_states for the first layer)
        2. QKV projection
        3. Split and reshape into per-head tensors
        4. Rotary embedding on q and k
        """
        # For the first layer the residual is the un-normalized input. Clone it
        # in case input_norm modifies its argument in place.
        residual = hidden_states.clone()
        hidden_states = self.input_norm(hidden_states)
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q, k = self.rotary_emb(positions, q, k)
        return q, k, v, residual

    def _run_static(self) -> None:
        """Run _compute on the static inputs and copy the results into the static outputs."""
        q, k, v, r = self._compute(self.h_in, self.pos_in)
        self.q_out.copy_(q)
        self.k_out.copy_(k)
        self.v_out.copy_(v)
        self.r_out.copy_(r)

    def capture_graph(self, graph_pool=None):
        """
        Allocate static buffers, warm up, and capture the CUDA Graph.

        Args:
            graph_pool: Optional memory-pool handle shared with other graphs.

        Returns:
            ``graph_pool`` when one was passed in, otherwise this graph's own
            pool handle, so the caller can hand it to the next capture.
        """
        # Static buffers are created outside inference_mode so forward() can
        # copy_ into them even when it is not itself running in inference_mode.
        self.h_in = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
        self.pos_in = torch.zeros(self.seq_len, dtype=torch.long, device="cuda")
        self.q_out = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
        self.k_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
        self.v_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
        self.r_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")

        with torch.inference_mode():
            # Warmup
            for _ in range(3):
                self._run_static()
            torch.cuda.synchronize()
            # Capture
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph, pool=graph_pool):
                self._run_static()
        return self.graph.pool() if graph_pool is None else graph_pool

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        use_graph: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute (q, k, v, residual), replaying the captured graph when possible.

        Args:
            hidden_states: Input hidden states; the graph path requires ``seq_len`` rows.
            positions: Token positions for the rotary embedding.
            use_graph: Request graph replay instead of eager execution.

        Returns:
            (q, k, v, residual). On the graph path these are clones of the
            static output buffers, so later replays do not overwrite them.
        """
        if use_graph and self.graph is not None and hidden_states.shape[0] == self.seq_len:
            self.h_in.copy_(hidden_states)
            self.pos_in.copy_(positions)
            self.graph.replay()
            return self.q_out.clone(), self.k_out.clone(), self.v_out.clone(), self.r_out.clone()
        return self._compute(hidden_states, positions)
class InterGraph(nn.Module):
    """
    Graph wrapper for inter-layer computation:
    o_proj → post_norm → mlp → input_norm → qkv_proj → rotary

    Merges the current layer's post-attention work with the next layer's
    pre-attention work.

    Input:  attn_output [seq_len, num_heads, head_dim] (the flattened
            [seq_len, num_heads * head_dim] layout is also accepted),
            residual [seq_len, hidden_size], positions [seq_len]
    Output: q [seq_len, num_heads, head_dim], k [seq_len, num_kv_heads, head_dim],
            v [seq_len, num_kv_heads, head_dim], residual [seq_len, hidden_size]

    NOTE(review): only next_input_norm, next_qkv_proj and next_rotary_emb are
    applied for the next layer. If the model's attention also normalizes q/k
    before rotary, that step is not reproduced — confirm against the model.
    """

    def __init__(
        self,
        # Current layer components
        o_proj: nn.Module,
        post_norm: nn.Module,
        mlp: nn.Module,
        # Next layer components
        next_input_norm: nn.Module,
        next_qkv_proj: nn.Module,
        next_rotary_emb: nn.Module,
        # Shape parameters
        seq_len: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        # Current layer
        self.o_proj = o_proj
        self.post_norm = post_norm
        self.mlp = mlp
        # Next layer
        self.next_input_norm = next_input_norm
        self.next_qkv_proj = next_qkv_proj
        self.next_rotary_emb = next_rotary_emb
        # Shape params
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype
        # Split sizes of the fused QKV projection output.
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        # Graph state and static I/O buffers. All are created by capture_graph();
        # declared here so they always exist.
        self.graph: Optional[torch.cuda.CUDAGraph] = None
        self.attn_in: Optional[torch.Tensor] = None
        self.r_in: Optional[torch.Tensor] = None
        self.pos_in: Optional[torch.Tensor] = None
        self.q_out: Optional[torch.Tensor] = None
        self.k_out: Optional[torch.Tensor] = None
        self.v_out: Optional[torch.Tensor] = None
        self.r_out: Optional[torch.Tensor] = None

    def _compute(
        self,
        attn_output: torch.Tensor,  # [seq_len, num_heads, head_dim] or [seq_len, num_heads * head_dim]
        residual: torch.Tensor,  # [seq_len, hidden_size]
        positions: torch.Tensor,  # [seq_len]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Inter-layer computation:
        1. O projection (attn_output flattened to [seq_len, num_heads * head_dim])
        2. Post-attention layernorm + residual
        3. MLP
        4. Next layer's input layernorm + residual
        5. QKV projection
        6. Split and reshape
        7. Rotary embedding
        """
        hidden_states = self.o_proj(attn_output.flatten(1, -1))
        # Post-attention of current layer
        hidden_states, residual = self.post_norm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        # Pre-attention of next layer
        hidden_states, residual = self.next_input_norm(hidden_states, residual)
        qkv = self.next_qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q, k = self.next_rotary_emb(positions, q, k)
        return q, k, v, residual

    def _run_static(self) -> None:
        """Run _compute on the static inputs and copy the results into the static outputs."""
        q, k, v, r = self._compute(self.attn_in, self.r_in, self.pos_in)
        self.q_out.copy_(q)
        self.k_out.copy_(k)
        self.v_out.copy_(v)
        self.r_out.copy_(r)

    def capture_graph(self, graph_pool=None):
        """
        Allocate static buffers, warm up, and capture the CUDA Graph.

        Args:
            graph_pool: Optional memory-pool handle shared with other graphs.

        Returns:
            ``graph_pool`` when one was passed in, otherwise this graph's own
            pool handle, so the caller can hand it to the next capture.
        """
        # Static buffers are created outside inference_mode so forward() can
        # copy_ into them even when it is not itself running in inference_mode.
        self.attn_in = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
        self.r_in = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
        self.pos_in = torch.zeros(self.seq_len, dtype=torch.long, device="cuda")
        self.q_out = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
        self.k_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
        self.v_out = torch.zeros(self.seq_len, self.num_kv_heads, self.head_dim, dtype=self.dtype, device="cuda")
        self.r_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")

        with torch.inference_mode():
            # Warmup
            for _ in range(3):
                self._run_static()
            torch.cuda.synchronize()
            # Capture
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph, pool=graph_pool):
                self._run_static()
        return self.graph.pool() if graph_pool is None else graph_pool

    def forward(
        self,
        attn_output: torch.Tensor,
        residual: torch.Tensor,
        positions: torch.Tensor,
        use_graph: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the next layer's (q, k, v, residual), replaying the graph when possible.

        Args:
            attn_output: Attention output of the current layer, either
                [seq_len, num_heads, head_dim] or [seq_len, num_heads * head_dim].
            residual: Residual stream entering the current layer's post-attention.
            positions: Token positions for the next layer's rotary embedding.
            use_graph: Request graph replay instead of eager execution.

        Returns:
            (q, k, v, residual). On the graph path these are clones of the
            static output buffers, so later replays do not overwrite them.
        """
        if use_graph and self.graph is not None and attn_output.shape[0] == self.seq_len:
            # View the contiguous static buffer in the caller's layout so the
            # flattened 2-D attention output is accepted just like the 3-D one,
            # matching the eager path's flatten(1, -1).
            self.attn_in.view(attn_output.shape).copy_(attn_output)
            self.r_in.copy_(residual)
            self.pos_in.copy_(positions)
            self.graph.replay()
            return self.q_out.clone(), self.k_out.clone(), self.v_out.clone(), self.r_out.clone()
        return self._compute(attn_output, residual, positions)
class LastGraph(nn.Module):
    """
    Graph wrapper for the last layer:
    o_proj → post_norm → mlp → final_norm

    Input:  attn_output [seq_len, num_heads, head_dim] (the flattened
            [seq_len, num_heads * head_dim] layout is also accepted),
            residual [seq_len, hidden_size]
    Output: hidden_states [seq_len, hidden_size]

    The captured graph is replayed only when ``use_graph`` is True, a graph
    has been captured, and ``attn_output`` has exactly ``seq_len`` rows;
    otherwise the same computation runs eagerly.
    """

    def __init__(
        self,
        o_proj: nn.Module,
        post_norm: nn.Module,
        mlp: nn.Module,
        final_norm: nn.Module,
        # Shape parameters
        seq_len: int,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.o_proj = o_proj
        self.post_norm = post_norm
        self.mlp = mlp
        self.final_norm = final_norm
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dtype = dtype
        # Graph state and static I/O buffers. All are created by capture_graph();
        # declared here so they always exist.
        self.graph: Optional[torch.cuda.CUDAGraph] = None
        self.attn_in: Optional[torch.Tensor] = None
        self.r_in: Optional[torch.Tensor] = None
        self.h_out: Optional[torch.Tensor] = None

    def _compute(
        self,
        attn_output: torch.Tensor,
        residual: torch.Tensor,
    ) -> torch.Tensor:
        """
        Last layer computation:
        1. O projection (attn_output flattened to [seq_len, num_heads * head_dim])
        2. Post-attention layernorm + residual
        3. MLP
        4. Final model norm + residual (the updated residual is discarded)
        """
        hidden_states = self.o_proj(attn_output.flatten(1, -1))
        hidden_states, residual = self.post_norm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        hidden_states, _ = self.final_norm(hidden_states, residual)
        return hidden_states

    def _run_static(self) -> None:
        """Run _compute on the static inputs and copy the result into the static output."""
        self.h_out.copy_(self._compute(self.attn_in, self.r_in))

    def capture_graph(self, graph_pool=None):
        """
        Allocate static buffers, warm up, and capture the CUDA Graph.

        Args:
            graph_pool: Optional memory-pool handle shared with other graphs.

        Returns:
            ``graph_pool`` when one was passed in, otherwise this graph's own
            pool handle, so the caller can hand it to the next capture.
        """
        # Static buffers are created outside inference_mode so forward() can
        # copy_ into them even when it is not itself running in inference_mode.
        self.attn_in = torch.zeros(self.seq_len, self.num_heads, self.head_dim, dtype=self.dtype, device="cuda")
        self.r_in = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")
        self.h_out = torch.zeros(self.seq_len, self.hidden_size, dtype=self.dtype, device="cuda")

        with torch.inference_mode():
            # Warmup
            for _ in range(3):
                self._run_static()
            torch.cuda.synchronize()
            # Capture
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph, pool=graph_pool):
                self._run_static()
        return self.graph.pool() if graph_pool is None else graph_pool

    def forward(
        self,
        attn_output: torch.Tensor,
        residual: torch.Tensor,
        use_graph: bool = False,
    ) -> torch.Tensor:
        """
        Compute the final hidden states, replaying the captured graph when possible.

        Args:
            attn_output: Attention output of the last layer, either
                [seq_len, num_heads, head_dim] or [seq_len, num_heads * head_dim].
            residual: Residual stream entering the last layer's post-attention.
            use_graph: Request graph replay instead of eager execution.

        Returns:
            The final hidden states. On the graph path this is a clone of the
            static output buffer, so later replays do not overwrite it.
        """
        if use_graph and self.graph is not None and attn_output.shape[0] == self.seq_len:
            # View the contiguous static buffer in the caller's layout so the
            # flattened 2-D attention output is accepted just like the 3-D one,
            # matching the eager path's flatten(1, -1).
            self.attn_in.view(attn_output.shape).copy_(attn_output)
            self.r_in.copy_(residual)
            self.graph.replay()
            return self.h_out.clone()
        return self._compute(attn_output, residual)
class OffloadGraphManager:
    """
    Manager for all CUDA Graphs in the offload path.

    Creates and manages N+2 graphs for an N-layer model:
    - 1 EmbedGraph   (embed_tokens)
    - 1 FirstGraph   (layer 0 pre-attention)
    - N-1 InterGraphs (layer i post-attention fused with layer i+1 pre-attention)
    - 1 LastGraph    (last layer post-attention + final norm)

    Supports both Prefill and Decode modes via the seq_len parameter; one
    manager instance serves a single seq_len.
    """

    def __init__(
        self,
        model: nn.Module,
        seq_len: int,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype,
    ):
        """
        Initialize the graph manager from a model (no graphs are captured yet).

        Args:
            model: The CausalLM model; must expose ``model.model.embed_tokens``,
                ``model.model.layers`` and ``model.model.norm``.
            seq_len: Sequence length (1 for decode, chunk_size for prefill).
            hidden_size: Model hidden dimension.
            num_heads: Number of attention heads.
            num_kv_heads: Number of KV heads.
            head_dim: Head dimension.
            dtype: Data type for the static buffers.

        Raises:
            ValueError: If the model has no decoder layers.
        """
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype

        layers = model.model.layers
        num_layers = len(layers)
        if num_layers == 0:
            # FirstGraph/LastGraph need at least one layer; fail clearly here
            # instead of with an IndexError on layers[0].
            raise ValueError("OffloadGraphManager requires a model with at least one decoder layer")
        self.num_layers = num_layers

        self.embed_graph = EmbedGraph(
            embed_tokens=model.model.embed_tokens,
            seq_len=seq_len,
            hidden_size=hidden_size,
            dtype=dtype,
        )
        # FirstGraph: input_norm_0 → qkv_proj_0 → rotary_0
        self.first_graph = FirstGraph(
            input_norm=layers[0].input_layernorm,
            qkv_proj=layers[0].self_attn.qkv_proj,
            rotary_emb=layers[0].self_attn.rotary_emb,
            seq_len=seq_len,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            dtype=dtype,
        )
        # InterGraphs: o_proj_i → post_norm_i → mlp_i → input_norm_{i+1} → qkv_proj_{i+1} → rotary_{i+1}
        self.inter_graphs = nn.ModuleList()
        for cur, nxt in zip(layers[:-1], layers[1:]):
            self.inter_graphs.append(InterGraph(
                o_proj=cur.self_attn.o_proj,
                post_norm=cur.post_attention_layernorm,
                mlp=cur.mlp,
                next_input_norm=nxt.input_layernorm,
                next_qkv_proj=nxt.self_attn.qkv_proj,
                next_rotary_emb=nxt.self_attn.rotary_emb,
                seq_len=seq_len,
                hidden_size=hidden_size,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                dtype=dtype,
            ))
        # LastGraph: o_proj_{N-1} → post_norm_{N-1} → mlp_{N-1} → final_norm
        self.last_graph = LastGraph(
            o_proj=layers[-1].self_attn.o_proj,
            post_norm=layers[-1].post_attention_layernorm,
            mlp=layers[-1].mlp,
            final_norm=model.model.norm,
            seq_len=seq_len,
            hidden_size=hidden_size,
            num_heads=num_heads,
            head_dim=head_dim,
            dtype=dtype,
        )
        self.captured = False
        self.graph_pool = None

    def capture_all(self):
        """
        Capture all graphs, sharing one memory pool between them.

        Graphs are captured in forward order: embed, first, inter[0..N-2], last.
        NOTE: per the PyTorch CUDA Graphs notes, sharing a pool across captures
        is only safe if the graphs are replayed in the order they were captured
        and never concurrently.
        """
        graph_pool = None
        for graph in (self.embed_graph, self.first_graph, *self.inter_graphs, self.last_graph):
            graph_pool = graph.capture_graph(graph_pool)
        self.graph_pool = graph_pool
        self.captured = True

    @property
    def num_graphs(self) -> int:
        """Total number of graphs: 1 + 1 + (N-1) + 1 = N+2."""
        return 1 + 1 + len(self.inter_graphs) + 1
# Legacy aliases: the older *LayerGraph names still resolve to the current
# classes so existing call sites keep working during the gradual migration.
FirstLayerGraph, InterLayerGraph, LastLayerGraph = FirstGraph, InterGraph, LastGraph