#!/usr/bin/env python3
"""
CUDA Graph Memory Analysis Test
This script analyzes the memory overhead of CUDA Graph at each stage:
1. Model loading
2. StaticCache allocation
3. Warmup runs
4. Graph capture
5. Graph replay
Usage:
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048
"""
import argparse
import os

import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import StaticCache


def get_memory_mb():
    """Get current allocated memory in MB."""
    return torch.cuda.memory_allocated() / 1024**2


def get_memory_gb():
    """Get current allocated memory in GB."""
    return torch.cuda.memory_allocated() / 1024**3


def get_peak_memory_gb():
    """Get peak allocated memory in GB."""
    return torch.cuda.max_memory_allocated() / 1024**3
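

# Note: torch.cuda.memory_allocated() counts memory held by live tensors only.
# The caching allocator's reserved pool (torch.cuda.memory_reserved()) is larger
# and is closer to what nvidia-smi reports, so the deltas below measure tensor
# allocations rather than the total device-memory footprint.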


def print_separator(title=None):
    """Print a separator line."""
    if title:
        print(f"\n{'=' * 70}")
        print(f" {title}")
        print(f"{'=' * 70}")
    else:
        print("-" * 70)


def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1):
    """
    Test memory usage at each stage of CUDA Graph setup.

    Args:
        model_path: Path to the model
        max_cache_len: Maximum cache length for StaticCache
        batch_size: Batch size for inference
    """
    print_separator("CUDA Graph Memory Analysis")
    print(f"Model: {model_path}")
    print(f"Max cache length: {max_cache_len}")
    print(f"Batch size: {batch_size}")

    results = {}

    # Stage 0: Initial
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    results["initial"] = get_memory_mb()

    # Stage 1: Load model
    print_separator("Stage 1: Model Loading")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()

    results["after_model"] = get_memory_mb()
    model_size = results["after_model"] - results["initial"]
    print(f" Memory: {results['after_model']:.0f} MB")
    print(f" Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)")

    config = model.config
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
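
    # Sanity check: bf16 weights take 2 bytes per parameter, so a 4B-parameter
    # model should land near 4e9 * 2 bytes ~= 7.5 GiB, and a 0.6B model near
    # 1.1 GiB (plus small buffers).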

    # Stage 2: Allocate StaticCache
    print_separator("Stage 2: StaticCache Allocation")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()
    static_cache = StaticCache(
        config=config,
        max_batch_size=batch_size,
        max_cache_len=max_cache_len,
        device=device,
        dtype=dtype,
    )
    results["after_cache"] = get_memory_mb()
    cache_size = results["after_cache"] - before
    print(f" Memory: {results['after_cache']:.0f} MB")
    print(f" StaticCache size: {cache_size:.0f} MB")

    # Calculate theoretical cache size
    num_layers = config.num_hidden_layers
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
    # Some configs (e.g. Qwen3) set head_dim explicitly, and it can differ from
    # hidden_size // num_attention_heads; prefer the explicit value so the
    # estimate matches what StaticCache actually allocates.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
    dtype_size = 2  # bfloat16
    theoretical_cache = (
        num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size
    ) / (1024**2)
    print(f" Theoretical: {theoretical_cache:.0f} MB")
    print(f" Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)")

    # Stage 3: Prepare static tensors
    print_separator("Stage 3: Static Tensor Allocation")
    before = get_memory_mb()
    static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    static_cache_position = torch.tensor([0], dtype=torch.long, device=device)
    results["after_tensors"] = get_memory_mb()
    tensor_size = results["after_tensors"] - before
    print(f" Memory: {results['after_tensors']:.0f} MB")
    print(f" Static tensors: {tensor_size:.2f} MB (negligible)")

    # Stage 4: Warmup runs
    print_separator("Stage 4: Warmup Runs (3 iterations)")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()
    with torch.inference_mode():
        for i in range(3):
            _ = model(
                input_ids=static_input_ids,
                position_ids=static_position_ids,
                past_key_values=static_cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
        torch.cuda.synchronize()
    results["after_warmup"] = get_memory_mb()
    results["warmup_peak"] = get_peak_memory_gb() * 1024
    warmup_size = results["after_warmup"] - before
    print(f" Memory: {results['after_warmup']:.0f} MB")
    print(f" Peak: {results['warmup_peak']:.0f} MB")
    print(f" Warmup overhead: {warmup_size:.0f} MB")

    # Stage 5: CUDA Graph capture
    print_separator("Stage 5: CUDA Graph Capture")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()
    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode():
        with torch.cuda.graph(graph):
            outputs = model(
                input_ids=static_input_ids,
                position_ids=static_position_ids,
                past_key_values=static_cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
            static_logits = outputs.logits
    torch.cuda.synchronize()
    results["after_capture"] = get_memory_mb()
    results["capture_peak"] = get_peak_memory_gb() * 1024
    capture_size = results["after_capture"] - before
    print(f" Memory: {results['after_capture']:.0f} MB")
    print(f" Peak: {results['capture_peak']:.0f} MB")
    print(f" Graph capture overhead: {capture_size:.0f} MB")

    # Stage 6: Graph replay
    print_separator("Stage 6: Graph Replay (10 iterations)")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()
    with torch.inference_mode():
        for _ in range(10):
            static_input_ids.fill_(1)
            static_cache_position.fill_(0)
            graph.replay()
        torch.cuda.synchronize()
    results["after_replay"] = get_memory_mb()
    results["replay_peak"] = get_peak_memory_gb() * 1024
    replay_change = results["after_replay"] - before
    print(f" Memory: {results['after_replay']:.0f} MB")
    print(f" Peak: {results['replay_peak']:.0f} MB")
    print(f" Replay memory change: {replay_change:.0f} MB (should be ~0)")

    # Summary
    print_separator("SUMMARY")
    total_overhead = results["after_capture"] - results["after_model"]
    print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}")
    print("-" * 50)
    print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}")
    print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}")
    print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}")
    print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}")
    print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}")
    print("-" * 50)
    print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}")

    print_separator("KEY FINDINGS")
    print(f" 1. Model size: {model_size/1024:.2f} GB")
    print(f" 2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)")
    print(f" 3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)")
    print(f" 4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)")
    print(f" 5. Total CUDA Graph overhead: {total_overhead:.0f} MB")

    return results


def test_cache_length_scaling(model_path: str, cache_lengths: list):
    """
    Test how memory scales with different cache lengths.

    Args:
        model_path: Path to the model
        cache_lengths: List of cache lengths to test
    """
    print_separator("Cache Length Scaling Test")
    print(f"Model: {model_path}")
    print(f"Cache lengths: {cache_lengths}")

    # Load model once
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()
    config = model.config
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    model_mem = get_memory_mb()

    results = []
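
    # model_mem is captured once; each iteration below measures the cache + graph
    # overhead on top of it, and the del/empty_cache at the end of each iteration
    # keeps the runs independent of one another.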
    for cache_len in cache_lengths:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Create cache and capture graph
        static_cache = StaticCache(
            config=config,
            max_batch_size=1,
            max_cache_len=cache_len,
            device=device,
            dtype=dtype,
        )
        static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
        static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
        static_cache_position = torch.tensor([0], dtype=torch.long, device=device)

        with torch.inference_mode():
            # Warmup
            for _ in range(3):
                _ = model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                    past_key_values=static_cache,
                    cache_position=static_cache_position,
                    use_cache=True,
                )
            torch.cuda.synchronize()

            # Capture
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                outputs = model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                    past_key_values=static_cache,
                    cache_position=static_cache_position,
                    use_cache=True,
                )
            torch.cuda.synchronize()

        total_mem = get_memory_mb()
        overhead = total_mem - model_mem
        results.append((cache_len, total_mem, overhead))

        del static_cache, graph
        torch.cuda.empty_cache()
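
    # Overhead should grow roughly linearly with cache_len: the KV cache is
    # O(max_cache_len) while graph-capture bookkeeping stays roughly constant,
    # so the per-1K-token column below should be close to flat.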

    # Print results
    print()
    print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K (MB)':>14}")
    print("-" * 60)
    for cache_len, total, overhead in results:
        per_1k = overhead / (cache_len / 1000)
        print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}")

    return results


def main():
    parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis")
    parser.add_argument(
        "--model",
        type=str,
        default="~/models/Qwen3-4B-Instruct-2507",
        help="Model path",
    )
    parser.add_argument(
        "--max-cache-len",
        type=int,
        default=1024,
        help="Maximum cache length",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size",
    )
    parser.add_argument(
        "--test-scaling",
        action="store_true",
        help="Test cache length scaling",
    )
    args = parser.parse_args()

    model_path = os.path.expanduser(args.model)

    if not torch.cuda.is_available():
        print("CUDA is not available!")
        return

    print(f"Device: cuda:{torch.cuda.current_device()}")
    print(f"GPU: {torch.cuda.get_device_name()}")

    if args.test_scaling:
        cache_lengths = [256, 512, 1024, 2048, 4096]
        test_cache_length_scaling(model_path, cache_lengths)
    else:
        test_memory_stages(model_path, args.max_cache_len, args.batch_size)

    print("\ntest_cudagraph_memory: PASSED")


if __name__ == "__main__":
    main()