#!/usr/bin/env python3
"""
CUDA Graph Memory Analysis Test

This script analyzes the memory overhead of CUDA Graph at each stage:
1. Model loading
2. StaticCache allocation
3. Warmup runs
4. Graph capture
5. Graph replay

Usage:
    CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py
    CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --model ~/models/Qwen3-0.6B
    CUDA_VISIBLE_DEVICES=4 python tests/test_cudagraph_memory.py --max-cache-len 2048
"""

import argparse
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache


def get_memory_mb():
    """Get current allocated memory in MB."""
    return torch.cuda.memory_allocated() / 1024**2


def get_memory_gb():
    """Get current allocated memory in GB."""
    return torch.cuda.memory_allocated() / 1024**3


def get_peak_memory_gb():
    """Get peak allocated memory in GB."""
    return torch.cuda.max_memory_allocated() / 1024**3


def print_separator(title=None):
    """Print a separator line, optionally with a centered section title."""
    if title:
        print(f"\n{'=' * 70}")
        print(f" {title}")
        print(f"{'=' * 70}")
    else:
        print("-" * 70)


def _get_head_dim(config):
    """Return the per-head KV dimension for *config*.

    Prefer an explicit ``head_dim`` attribute when the config defines one:
    models such as Qwen3 set ``head_dim`` independently, and it may differ
    from ``hidden_size // num_attention_heads``. Falling back to the
    quotient keeps older configs working.
    """
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    return head_dim


def test_memory_stages(model_path: str, max_cache_len: int, batch_size: int = 1):
    """
    Test memory usage at each stage of CUDA Graph setup.

    Args:
        model_path: Path to the model
        max_cache_len: Maximum cache length for StaticCache
        batch_size: Batch size for inference

    Returns:
        dict mapping stage names to allocated-memory measurements (MB).
    """
    print_separator("CUDA Graph Memory Analysis")
    print(f"Model: {model_path}")
    print(f"Max cache length: {max_cache_len}")
    print(f"Batch size: {batch_size}")

    results = {}

    # Stage 0: Initial — clear the caching allocator so later deltas are clean.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    results["initial"] = get_memory_mb()

    # Stage 1: Load model
    print_separator("Stage 1: Model Loading")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()

    results["after_model"] = get_memory_mb()
    model_size = results["after_model"] - results["initial"]
    print(f"  Memory: {results['after_model']:.0f} MB")
    print(f"  Model size: {model_size:.0f} MB ({model_size/1024:.2f} GB)")

    config = model.config
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # Stage 2: Allocate StaticCache
    print_separator("Stage 2: StaticCache Allocation")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    static_cache = StaticCache(
        config=config,
        max_batch_size=batch_size,
        max_cache_len=max_cache_len,
        device=device,
        dtype=dtype,
    )

    results["after_cache"] = get_memory_mb()
    cache_size = results["after_cache"] - before
    print(f"  Memory: {results['after_cache']:.0f} MB")
    print(f"  StaticCache size: {cache_size:.0f} MB")

    # Calculate theoretical cache size:
    #   layers * (K + V) * batch * kv_heads * seq_len * head_dim * bytes/elem
    num_layers = config.num_hidden_layers
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
    # NOTE: use the explicit head_dim when available — for Qwen3-style
    # configs hidden_size // num_attention_heads is NOT the KV head dim.
    head_dim = _get_head_dim(config)
    dtype_size = 2  # bfloat16

    theoretical_cache = (
        num_layers * 2 * batch_size * num_kv_heads * max_cache_len * head_dim * dtype_size
    ) / (1024**2)
    print(f"  Theoretical: {theoretical_cache:.0f} MB")
    print(f"  Overhead: {cache_size - theoretical_cache:.0f} MB ({(cache_size/theoretical_cache - 1)*100:.1f}%)")

    # Stage 3: Prepare static tensors (the fixed-address inputs the graph
    # will read; their contents are mutated between replays).
    print_separator("Stage 3: Static Tensor Allocation")
    before = get_memory_mb()

    static_input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    static_position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    static_cache_position = torch.tensor([0], dtype=torch.long, device=device)

    results["after_tensors"] = get_memory_mb()
    tensor_size = results["after_tensors"] - before
    print(f"  Memory: {results['after_tensors']:.0f} MB")
    print(f"  Static tensors: {tensor_size:.2f} MB (negligible)")

    # Stage 4: Warmup runs — required before capture so lazy allocations
    # (cuBLAS workspaces, autotuning buffers) happen outside the graph.
    print_separator("Stage 4: Warmup Runs (3 iterations)")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    with torch.inference_mode():
        for i in range(3):
            _ = model(
                input_ids=static_input_ids,
                position_ids=static_position_ids,
                past_key_values=static_cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
        torch.cuda.synchronize()

    results["after_warmup"] = get_memory_mb()
    results["warmup_peak"] = get_peak_memory_gb() * 1024
    warmup_size = results["after_warmup"] - before
    print(f"  Memory: {results['after_warmup']:.0f} MB")
    print(f"  Peak: {results['warmup_peak']:.0f} MB")
    print(f"  Warmup overhead: {warmup_size:.0f} MB")

    # Stage 5: CUDA Graph capture
    print_separator("Stage 5: CUDA Graph Capture")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode():
        with torch.cuda.graph(graph):
            outputs = model(
                input_ids=static_input_ids,
                position_ids=static_position_ids,
                past_key_values=static_cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
            # Keep a handle on the output buffer; replays write into it.
            static_logits = outputs.logits
    torch.cuda.synchronize()

    results["after_capture"] = get_memory_mb()
    results["capture_peak"] = get_peak_memory_gb() * 1024
    capture_size = results["after_capture"] - before
    print(f"  Memory: {results['after_capture']:.0f} MB")
    print(f"  Peak: {results['capture_peak']:.0f} MB")
    print(f"  Graph capture overhead: {capture_size:.0f} MB")

    # Stage 6: Graph replay — should allocate nothing; all buffers reused.
    print_separator("Stage 6: Graph Replay (10 iterations)")
    torch.cuda.reset_peak_memory_stats()
    before = get_memory_mb()

    with torch.inference_mode():
        for _ in range(10):
            static_input_ids.fill_(1)
            static_cache_position.fill_(0)
            graph.replay()
        torch.cuda.synchronize()

    results["after_replay"] = get_memory_mb()
    results["replay_peak"] = get_peak_memory_gb() * 1024
    replay_change = results["after_replay"] - before
    print(f"  Memory: {results['after_replay']:.0f} MB")
    print(f"  Peak: {results['replay_peak']:.0f} MB")
    print(f"  Replay memory change: {replay_change:.0f} MB (should be ~0)")

    # Summary
    print_separator("SUMMARY")
    total_overhead = results["after_capture"] - results["after_model"]

    print(f"{'Stage':<25} {'Memory (MB)':>12} {'Delta (MB)':>12}")
    print("-" * 50)
    print(f"{'Model loaded':<25} {results['after_model']:>12.0f} {model_size:>+12.0f}")
    print(f"{'StaticCache allocated':<25} {results['after_cache']:>12.0f} {cache_size:>+12.0f}")
    print(f"{'After warmup':<25} {results['after_warmup']:>12.0f} {warmup_size:>+12.0f}")
    print(f"{'After graph capture':<25} {results['after_capture']:>12.0f} {capture_size:>+12.0f}")
    print(f"{'After graph replay':<25} {results['after_replay']:>12.0f} {replay_change:>+12.0f}")
    print("-" * 50)
    print(f"{'Total (excl. model)':<25} {'':<12} {total_overhead:>+12.0f}")

    print_separator("KEY FINDINGS")
    print(f"  1. Model size: {model_size/1024:.2f} GB")
    print(f"  2. StaticCache: {cache_size:.0f} MB (main overhead, scales with cache_len)")
    print(f"  3. Graph capture: {capture_size:.0f} MB (small, stores kernel sequence)")
    print(f"  4. Graph replay: {replay_change:.0f} MB (zero allocation, reuses memory)")
    print(f"  5. Total CUDA Graph overhead: {total_overhead:.0f} MB")

    return results


def test_cache_length_scaling(model_path: str, cache_lengths: list):
    """
    Test how memory scales with different cache lengths.

    Args:
        model_path: Path to the model
        cache_lengths: List of cache lengths to test

    Returns:
        list of (cache_len, total_mb, overhead_mb) tuples, one per length.
    """
    print_separator("Cache Length Scaling Test")
    print(f"Model: {model_path}")
    print(f"Cache lengths: {cache_lengths}")

    # Load model once and reuse it across all cache lengths.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.eval()

    config = model.config
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    model_mem = get_memory_mb()

    results = []
    for cache_len in cache_lengths:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Create cache and capture graph
        static_cache = StaticCache(
            config=config,
            max_batch_size=1,
            max_cache_len=cache_len,
            device=device,
            dtype=dtype,
        )

        static_input_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
        static_position_ids = torch.zeros(1, 1, dtype=torch.long, device=device)
        static_cache_position = torch.tensor([0], dtype=torch.long, device=device)

        with torch.inference_mode():
            # Warmup (see test_memory_stages: pushes lazy allocations
            # outside the capture region)
            for _ in range(3):
                _ = model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                    past_key_values=static_cache,
                    cache_position=static_cache_position,
                    use_cache=True,
                )
            torch.cuda.synchronize()

            # Capture
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                outputs = model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                    past_key_values=static_cache,
                    cache_position=static_cache_position,
                    use_cache=True,
                )
            torch.cuda.synchronize()

        total_mem = get_memory_mb()
        overhead = total_mem - model_mem
        results.append((cache_len, total_mem, overhead))

        # Release everything tied to this cache length (including the
        # captured output and static inputs) so iterations don't accumulate.
        del static_cache, graph, outputs
        del static_input_ids, static_position_ids, static_cache_position
        torch.cuda.empty_cache()

    # Print results
    print()
    print(f"{'Cache Length':>12} | {'Total (MB)':>12} | {'Overhead (MB)':>14} | {'Per 1K tokens':>14}")
    print("-" * 60)
    for cache_len, total, overhead in results:
        per_1k = overhead / (cache_len / 1000)
        print(f"{cache_len:>12} | {total:>12.0f} | {overhead:>14.0f} | {per_1k:>14.1f}")

    return results


def main():
    """Parse CLI arguments and run the selected memory analysis."""
    parser = argparse.ArgumentParser(description="CUDA Graph Memory Analysis")
    parser.add_argument(
        "--model",
        type=str,
        default="~/models/Qwen3-4B-Instruct-2507",
        help="Model path",
    )
    parser.add_argument(
        "--max-cache-len",
        type=int,
        default=1024,
        help="Maximum cache length",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size",
    )
    parser.add_argument(
        "--test-scaling",
        action="store_true",
        help="Test cache length scaling",
    )
    args = parser.parse_args()

    model_path = os.path.expanduser(args.model)

    if not torch.cuda.is_available():
        print("CUDA is not available!")
        return

    print(f"Device: cuda:{torch.cuda.current_device()}")
    print(f"GPU: {torch.cuda.get_device_name()}")

    if args.test_scaling:
        cache_lengths = [256, 512, 1024, 2048, 4096]
        test_cache_length_scaling(model_path, cache_lengths)
    else:
        test_memory_stages(model_path, args.max_cache_len, args.batch_size)

    print("\ntest_cudagraph_memory: PASSED")


if __name__ == "__main__":
    main()