358 lines
12 KiB
Python
358 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
CUDA Graph Memory Analysis Test
|
|
|
|
This script analyzes the memory overhead of CUDA Graph at each stage:
|
|
1. Model loading
|
|
2. StaticCache allocation
|
|
3. Warmup runs
|
|
4. Graph capture
|
|
5. Graph replay
|
|
|
|
Usage:
|
|
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py
|
|
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B
|
|
CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import torch
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
from transformers.cache_utils import StaticCache
|
|
|
|
|
|
def get_memory_mb():
    """Return the currently allocated CUDA memory, expressed in megabytes."""
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / (1024 * 1024)
|
|
|
|
|
|
def get_memory_gb():
    """Return the currently allocated CUDA memory, expressed in gigabytes."""
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / (1024 * 1024 * 1024)
|
|
|
|
|
|
def get_peak_memory_gb():
    """Return the peak allocated CUDA memory since the last stats reset, in GB."""
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / (1024 * 1024 * 1024)
|
|
|
|
|
|
def print_separator(title=None):
    """Print a 70-char horizontal rule; with a truthy *title*, frame it in '=' rules."""
    if not title:
        # Plain section divider.
        print("-" * 70)
        return
    rule = "=" * 70
    print(f"\n{rule}")
    print(f" {title}")
    print(f"{rule}")
|
|
|
|
|
|
def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1) -> dict:
    """
    Test memory usage at each stage of CUDA Graph setup.

    Measures allocated CUDA memory before/after each stage (model load,
    StaticCache allocation, static tensors, warmup, graph capture, replay)
    and prints per-stage deltas plus a summary table.

    Args:
        model_path: Path to the model
        max_cache_len: Maximum cache length for StaticCache
        batch_size: Batch size for inference

    Returns:
        dict mapping stage names (e.g. "after_model", "after_capture") to
        allocated-memory readings in MB (peaks stored as "*_peak").
    """
    print_separator("CUDA Graph Memory Analysis")
    print(f"Model: {model_path}")
    print(f"Max cache length: {max_cache_len}")
    print(f"Batch size: {batch_size}")

    results = {}

    # Stage 0: Initial — clear cached blocks and reset the peak counter so
    # later readings are relative to a clean baseline.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    results["initial"] = get_memory_mb()

    # Stage 1: Load model
    print_separator("Stage 1: Model Loading")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()

    results["after_model"] = get_memory_mb()
    model_size = results["after_model"] - results["initial"]
    print(f" Memory: {results['after_model']:.0f} MB")
    print(f" Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)")

    # Take device/dtype from the loaded parameters so the cache and static
    # tensors land on the same device with the same precision.
    config = model.config
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # Stage 2: Allocate StaticCache
    print_separator("Stage 2: StaticCache Allocation")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    static_cache = StaticCache(
        config=config,
        max_batch_size=batch_size,
        max_cache_len=max_cache_len,
        device=device,
        dtype=dtype,
    )

    results["after_cache"] = get_memory_mb()
    cache_size = results["after_cache"] - before
    print(f" Memory: {results['after_cache']:.0f} MB")
    print(f" StaticCache size: {cache_size:.0f} MB")

    # Calculate theoretical cache size:
    # layers * (K + V) * batch * kv_heads * seq_len * head_dim * bytes/elem.
    num_layers = config.num_hidden_layers
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
    head_dim = config.hidden_size // config.num_attention_heads
    dtype_size = 2  # bytes per element for bfloat16

    theoretical_cache = (
        num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size
    ) / (1024**2)
    print(f" Theoretical: {theoretical_cache:.0f} MB")
    print(f" Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)")

    # Stage 3: Prepare static tensors — fixed-address inputs that the captured
    # graph will read from on every replay.
    print_separator("Stage 3: Static Tensor Allocation")
    before = get_memory_mb()

    static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    static_cache_position = torch.tensor([0], dtype=torch.long, device=device)

    results["after_tensors"] = get_memory_mb()
    tensor_size = results["after_tensors"] - before
    print(f" Memory: {results['after_tensors']:.0f} MB")
    print(f" Static tensors: {tensor_size:.2f} MB (negligible)")

    # Stage 4: Warmup runs — eager forward passes so allocator pools and
    # kernels are settled before capture.
    print_separator("Stage 4: Warmup Runs (3 iterations)")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    with torch.inference_mode():
        for i in range(3):
            _ = model(
                input_ids=static_input_ids,
                position_ids=static_position_ids,
                past_key_values=static_cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
        torch.cuda.synchronize()

    results["after_warmup"] = get_memory_mb()
    results["warmup_peak"] = get_peak_memory_gb() * 1024  # convert GB -> MB
    warmup_size = results["after_warmup"] - before
    print(f" Memory: {results['after_warmup']:.0f} MB")
    print(f" Peak: {results['warmup_peak']:.0f} MB")
    print(f" Warmup overhead: {warmup_size:.0f} MB")

    # Stage 5: CUDA Graph capture
    print_separator("Stage 5: CUDA Graph Capture")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode():
        with torch.cuda.graph(graph):
            outputs = model(
                input_ids=static_input_ids,
                position_ids=static_position_ids,
                past_key_values=static_cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
        # Keep a reference to the captured output buffer so replays have a
        # stable tensor to write into.
        static_logits = outputs.logits
        torch.cuda.synchronize()

    results["after_capture"] = get_memory_mb()
    results["capture_peak"] = get_peak_memory_gb() * 1024  # convert GB -> MB
    capture_size = results["after_capture"] - before
    print(f" Memory: {results['after_capture']:.0f} MB")
    print(f" Peak: {results['capture_peak']:.0f} MB")
    print(f" Graph capture overhead: {capture_size:.0f} MB")

    # Stage 6: Graph replay — only mutate the static input tensors in place,
    # then replay; no new allocations are expected.
    print_separator("Stage 6: Graph Replay (10 iterations)")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    with torch.inference_mode():
        for _ in range(10):
            static_input_ids.fill_(1)
            static_cache_position.fill_(0)
            graph.replay()
        torch.cuda.synchronize()

    results["after_replay"] = get_memory_mb()
    results["replay_peak"] = get_peak_memory_gb() * 1024  # convert GB -> MB
    replay_change = results["after_replay"] - before
    print(f" Memory: {results['after_replay']:.0f} MB")
    print(f" Peak: {results['replay_peak']:.0f} MB")
    print(f" Replay memory change: {replay_change:.0f} MB (should be ~0)")

    # Summary
    print_separator("SUMMARY")
    total_overhead = results["after_capture"] - results["after_model"]

    print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}")
    print("-" * 50)
    print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}")
    print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}")
    print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}")
    print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}")
    print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}")
    print("-" * 50)
    print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}")

    print_separator("KEY FINDINGS")
    print(f" 1. Model size: {model_size/1024:.2f} GB")
    print(f" 2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)")
    print(f" 3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)")
    print(f" 4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)")
    print(f" 5. Total CUDA Graph overhead: {total_overhead:.0f} MB")

    return results
|
|
|
|
|
|
def test_cache_length_scaling(model_path: str, cache_lengths: list) -> list:
    """
    Test how memory scales with different cache lengths.

    Loads the model once, then for each cache length allocates a StaticCache,
    warms up, captures a CUDA graph, records total memory, and tears the cache
    and graph down before the next iteration.

    Args:
        model_path: Path to the model
        cache_lengths: List of cache lengths to test

    Returns:
        List of (cache_len, total_memory_mb, overhead_mb) tuples, where
        overhead is total memory minus the model-only baseline.
    """
    print_separator("Cache Length Scaling Test")
    print(f"Model: {model_path}")
    print(f"Cache lengths: {cache_lengths}")

    # Load model once
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()

    config = model.config
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # Baseline: memory with just the model resident.
    model_mem = get_memory_mb()

    results = []
    for cache_len in cache_lengths:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Create cache and capture graph
        static_cache = StaticCache(
            config=config,
            max_batch_size=1,
            max_cache_len=cache_len,
            device=device,
            dtype=dtype,
        )

        static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
        static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
        static_cache_position = torch.tensor([0], dtype=torch.long, device=device)

        with torch.inference_mode():
            # Warmup
            for _ in range(3):
                _ = model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                    past_key_values=static_cache,
                    cache_position=static_cache_position,
                    use_cache=True,
                )
            torch.cuda.synchronize()

            # Capture
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                outputs = model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                    past_key_values=static_cache,
                    cache_position=static_cache_position,
                    use_cache=True,
                )
            torch.cuda.synchronize()

        total_mem = get_memory_mb()
        overhead = total_mem - model_mem
        results.append((cache_len, total_mem, overhead))

        # Free this iteration's cache and graph before the next length.
        del static_cache, graph
        torch.cuda.empty_cache()

    # Print results
    print()
    print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K tokens':>14}")
    print("-" * 60)
    for cache_len, total, overhead in results:
        per_1k = overhead / (cache_len / 1000)
        print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}")

    return results
|
|
|
|
|
|
def main():
    """Entry point: parse CLI flags and run the selected memory analysis."""
    parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis")
    parser.add_argument("--model", type=str, default="~/models/Qwen3-4B-Instruct-2507", help="Model path")
    parser.add_argument("--max-cache-len", type=int, default=1024, help="Maximum cache length")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
    parser.add_argument("--test-scaling", action="store_true", help="Test cache length scaling")
    args = parser.parse_args()

    model_path = os.path.expanduser(args.model)

    # Everything below requires a GPU; bail out early if CUDA is missing.
    if not torch.cuda.is_available():
        print("CUDA is not available!")
        return

    print(f"Device: cuda:{torch.cuda.current_device()}")
    print(f"GPU: {torch.cuda.get_device_name()}")

    if args.test_scaling:
        test_cache_length_scaling(model_path, [256, 512, 1024, 2048, 4096])
    else:
        test_memory_stages(model_path, args.max_cache_len, args.batch_size)

    print("\ntest_cudagraph_memory: PASSED")
|
|
|
|
|
|
# Script entry point: run the analysis only when executed directly.
if __name__ == "__main__":
    main()
|