From 62b8a63314747c557160c42ab9b5c394e522987f Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Thu, 1 Jan 2026 03:32:26 +0800 Subject: [PATCH] [refactor] Refactor the test_chunked_prefill/decode. --- tests/test_chunked_decode_hook.py | 508 ++++++++++------------------ tests/test_chunked_prefill_hook.py | 517 +++++++---------------------- 2 files changed, 294 insertions(+), 731 deletions(-) diff --git a/tests/test_chunked_decode_hook.py b/tests/test_chunked_decode_hook.py index 90cce90..4113381 100644 --- a/tests/test_chunked_decode_hook.py +++ b/tests/test_chunked_decode_hook.py @@ -1,374 +1,214 @@ """ -Hook-based correctness test for chunked decode attention. +Correctness test for chunked decode attention. -Uses PyTorch register_forward_hook() to capture real inference I/O, -then compares against reference computation to locate bugs. - -This test targets the decode phase with CPU offload - after prefill, -the model generates tokens one by one while attending to all previous context. +Captures Q and output during inference, then computes reference using +CPU KV cache with standard flash attention. """ import os -os.environ["NANOVLLM_LOG_LEVEL"] = "INFO" +os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING" import torch from random import randint, seed +from typing import Dict, List from nanovllm import LLM, SamplingParams from nanovllm.utils.context import get_context -from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs +from flash_attn.flash_attn_interface import flash_attn_func - -# ============================================================ -# Configuration -# ============================================================ - -MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") -MAX_MODEL_LEN = 8 * 1024 +# Config +MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") +MAX_MODEL_LEN = 128 * 1024 NUM_GPU_BLOCKS = 2 -INPUT_LEN = 2 * 1024 # 2K tokens for prefill -NUM_DECODE_TOKENS = 5 # Generate 5 tokens to test decode +INPUT_LEN = 16 * 1024 +NUM_DECODE_TOKENS = 5 BLOCK_SIZE = 1024 - -# ============================================================ -# Global capture storage -# ============================================================ - -captures = [] -prefill_kv = {} # Store prefill k,v for reference computation +# State +prefill_captures: List[Dict] = [] +decode_captures: List[Dict] = [] -# ============================================================ -# Hook Functions -# ============================================================ - -def make_hook(layer_id): - """Create a forward hook for a specific layer.""" - def hook(module, inputs, output): - q, k, v = inputs - ctx = get_context() - - is_prefill = ctx.is_prefill - - capture_entry = { - 'layer_id': layer_id, - 'is_prefill': is_prefill, - 'q': q.clone().cpu(), - 'k': k.clone().cpu(), - 'v': v.clone().cpu(), - 'output': output.clone().cpu(), - 'is_chunked_prefill': ctx.is_chunked_prefill, - } - - if is_prefill: - # Store prefill k,v for reference computation - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - capture_entry['chunk_idx'] = chunk_idx - if layer_id not in prefill_kv: - prefill_kv[layer_id] = [] - prefill_kv[layer_id].append({ - 'chunk_idx': chunk_idx, - 'k': k.clone().cpu(), - 'v': v.clone().cpu(), - }) - else: - # Decode phase - capture decode token info - capture_entry['decode_step'] = len([c for c in captures - if c['layer_id'] == layer_id and not c['is_prefill']]) - - captures.append(capture_entry) +def make_ones_injection_hook(): + """Inject Q=K=V=1.0 for deterministic testing.""" + def hook(module, inputs): + q, k, v = inputs[0], inputs[1], inputs[2] + q_ones = torch.ones_like(q) + k_ones = torch.ones_like(k) + v_ones = torch.ones_like(v) + return (q_ones, k_ones, v_ones) + inputs[3:] return hook -def register_hooks(llm): - """Register forward hooks on all Attention modules.""" - hooks = [] - model = llm.model_runner.model +def make_capture_hook(layer_id: int): + """Capture Q, K, V, output during inference.""" + def hook(module, inputs, output): + ctx = get_context() + q, k, v = inputs - for layer_idx, decoder_layer in enumerate(model.model.layers): - attn_module = decoder_layer.self_attn.attn - hook = attn_module.register_forward_hook(make_hook(layer_idx)) - hooks.append(hook) - - return hooks - - -# ============================================================ -# Reference Computation -# ============================================================ - -def compute_decode_reference(layer_id, decode_step, scale, debug=False): - """ - Compute reference decode attention output for a specific layer. - - For decode, the query is a single token that attends to: - 1. All prefill KV (from CPU cache) - 2. All previous decode tokens (stored in GPU decode slot) - """ - # Get the decode capture - decode_captures = [c for c in captures - if c['layer_id'] == layer_id and not c['is_prefill']] - if decode_step >= len(decode_captures): - return None - - decode_capture = decode_captures[decode_step] - q = decode_capture['q'].cuda() # [1, num_heads, head_dim] - q_batched = q.unsqueeze(1) # [1, 1, num_heads, head_dim] - - if debug: - print(f" Reference for L{layer_id} D{decode_step}:") - print(f" q shape: {q_batched.shape}, mean={q_batched.mean().item():.4f}") - - o_acc, lse_acc = None, None - - # Attend to all prefill chunks - if layer_id in prefill_kv: - for chunk_data in sorted(prefill_kv[layer_id], key=lambda x: x['chunk_idx']): - k = chunk_data['k'].cuda().unsqueeze(0) # [1, seqlen, kv_heads, head_dim] - v = chunk_data['v'].cuda().unsqueeze(0) - - o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale=scale, causal=False) - - if debug: - print(f" Prefill chunk {chunk_data['chunk_idx']}: o.mean={o.mean().item():.6f}") - - if o_acc is None: - o_acc, lse_acc = o, lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse) - - # Attend to previous decode tokens (including current) - # In decode, the current token's k,v are stored, and we need to attend to all previous decode tokens - # For step 0, we just have the current token's k,v - # For step 1, we have tokens 0 and 1's k,v - # etc. - - # Collect k,v from all decode steps up to and including current - decode_kv = [] - for i in range(decode_step + 1): - if i < len(decode_captures): - decode_kv.append({ - 'k': decode_captures[i]['k'].cuda(), - 'v': decode_captures[i]['v'].cuda(), + if ctx.is_prefill: + chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 + prefill_captures.append({ + 'layer_id': layer_id, + 'chunk_idx': chunk_idx, + 'q': q.clone().cpu(), + 'k': k.clone().cpu(), + 'v': v.clone().cpu(), + 'output': output.clone().cpu(), }) - - if decode_kv: - # Stack decode k,v into a single tensor - decode_k = torch.cat([d['k'] for d in decode_kv], dim=0).unsqueeze(0) # [1, num_decode, kv_heads, head_dim] - decode_v = torch.cat([d['v'] for d in decode_kv], dim=0).unsqueeze(0) - - if debug: - print(f" Decode tokens: {len(decode_kv)}, k.shape={decode_k.shape}") - - # For decode, we use causal=False since we're attending to all decode tokens - # (the causal masking was already handled by only including tokens up to current) - o_decode, lse_decode = flash_attn_with_lse(q_batched, decode_k, decode_v, - softmax_scale=scale, causal=False) - - if debug: - print(f" Decode attention: o.mean={o_decode.mean().item():.6f}") - - if o_acc is None: - o_acc, lse_acc = o_decode, lse_decode else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_decode, lse_decode) + decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id]) + decode_captures.append({ + 'layer_id': layer_id, + 'decode_step': decode_step, + 'q': q.clone().cpu(), + 'k': k.clone().cpu(), + 'v': v.clone().cpu(), + 'output': output.clone().cpu(), + }) + return hook - if o_acc is None: + +def compute_decode_reference(layer_id: int, decode_step: int, scale: float, + k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor, + block_size: int, num_prefill_chunks: int) -> torch.Tensor: + """ + Compute reference decode output using CPU KV cache and standard flash attention. + + For decode, query attends to: + 1. All prefill KV (from CPU cache) + 2. All previous decode tokens (from captured decode k, v) + """ + # Get decode capture for this layer and step + decode_cap = None + for c in decode_captures: + if c['layer_id'] == layer_id and c['decode_step'] == decode_step: + decode_cap = c + break + + if decode_cap is None: return None - if debug: - print(f" Final: o.mean={o_acc.mean().item():.6f}") + # Query: single decode token + q = decode_cap['q'].cuda() # [1, num_heads, head_dim] + q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim] - return o_acc.squeeze(0).squeeze(0).cpu() # [num_heads, head_dim] + # Collect all K, V: prefill chunks from CPU cache + decode tokens from captures + all_k = [] + all_v = [] + # 1. Prefill chunks from CPU cache + for cidx in range(num_prefill_chunks): + # Get prefill capture to know the sequence length for this chunk + prefill_cap = None + for c in prefill_captures: + if c['layer_id'] == layer_id and c['chunk_idx'] == cidx: + prefill_cap = c + break -# ============================================================ -# Test Runner -# ============================================================ + if prefill_cap is not None: + seq_len = prefill_cap['q'].shape[0] + k = k_cache_cpu[layer_id, cidx, :seq_len].cuda() + v = v_cache_cpu[layer_id, cidx, :seq_len].cuda() + all_k.append(k) + all_v.append(v) -def run_test(verbose=True): - """Run the hook-based chunked decode correctness test.""" - global captures, prefill_kv - captures = [] - prefill_kv = {} + # 2. Decode tokens from captures (up to and including current step) + for step in range(decode_step + 1): + for c in decode_captures: + if c['layer_id'] == layer_id and c['decode_step'] == step: + all_k.append(c['k'].cuda()) + all_v.append(c['v'].cuda()) + break - if verbose: - print("=" * 70) - print("Test: Hook-Based Chunked Decode Correctness") - print("=" * 70) - print(f"Model: {MODEL_PATH}") - print(f"Input length: {INPUT_LEN} tokens") - print(f"Decode tokens: {NUM_DECODE_TOKENS}") - print(f"Block size: {BLOCK_SIZE}") - print() + if not all_k: + return None - # Initialize LLM with CPU offload - llm = LLM( - MODEL_PATH, - enforce_eager=True, - max_model_len=MAX_MODEL_LEN, - max_num_batched_tokens=MAX_MODEL_LEN, - enable_cpu_offload=True, - kvcache_block_size=BLOCK_SIZE, - num_gpu_blocks=NUM_GPU_BLOCKS, + # Concatenate all K, V + full_k = torch.cat(all_k, dim=0).unsqueeze(0) # [1, total_len, kv_heads, head_dim] + full_v = torch.cat(all_v, dim=0).unsqueeze(0) + + # Run flash attention (non-causal since we explicitly control what KV to include) + output = flash_attn_func( + q_batched, full_k, full_v, + softmax_scale=scale, + causal=False, ) - # Get model info - num_layers = len(llm.model_runner.model.model.layers) - head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim - scale = head_dim ** -0.5 - - if verbose: - print(f"Num layers: {num_layers}") - print(f"Head dim: {head_dim}") - print() - - # Register hooks - hooks = register_hooks(llm) - if verbose: - print(f"Registered {len(hooks)} hooks") - - # Generate random prompt - seed(42) - prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] - - # Run prefill and decode - if verbose: - print(f"Running inference with {NUM_DECODE_TOKENS} decode tokens...") - sampling_params = SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS) - outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) - - # Remove hooks - for hook in hooks: - hook.remove() - - # =========== VERIFICATION: Check CPU cache after prefill =========== - # Verify that CPU cache data matches captured prefill k,v - if verbose: - print("\n--- CPU Cache Verification (After Prefill) ---") - offload_engine = llm.model_runner.kvcache_manager.offload_engine - - # For each prefill capture, check if CPU cache matches - for layer_id in [0]: # Only check layer 0 for brevity - if layer_id not in prefill_kv: - continue - - for chunk_data in prefill_kv[layer_id]: - chunk_idx = chunk_data['chunk_idx'] - captured_k = chunk_data['k'] # [block_size, kv_heads, head_dim] - - # CPU block ID should be chunk_idx (based on allocation order) - cpu_block_id = chunk_idx - cpu_k = offload_engine.k_cache_cpu[layer_id, cpu_block_id].cpu() - - diff = (captured_k - cpu_k).abs().max().item() - print(f"Layer {layer_id}, Chunk {chunk_idx}: captured_k vs cpu_k max_diff={diff:.6f}") - if diff > 1e-3: - print(f" WARNING: CPU cache doesn't match captured k!") - print(f" captured_k[0,0,:5] = {captured_k[0,0,:5].tolist()}") - print(f" cpu_k[0,0,:5] = {cpu_k[0,0,:5].tolist()}") - print() - - # Analyze captures - prefill_count = sum(1 for c in captures if c['is_prefill']) - decode_count = sum(1 for c in captures if not c['is_prefill']) - if verbose: - print(f"\nCaptured {prefill_count} prefill calls, {decode_count} decode calls") - - # Count decode steps per layer - decode_per_layer = {} - for c in captures: - if not c['is_prefill']: - layer_id = c['layer_id'] - if layer_id not in decode_per_layer: - decode_per_layer[layer_id] = 0 - decode_per_layer[layer_id] += 1 - - if verbose: - print(f"Decode calls per layer: {decode_per_layer}") - print() - - # Verify decode correctness - all_passed = True - results = [] - first_fail_debug = True - - for c in captures: - if c['is_prefill']: - continue # Skip prefill (already tested in test_chunked_prefill_hook.py) - - layer_id = c['layer_id'] - decode_step = c['decode_step'] - - # Only test first decode step for now (simpler reference computation) - if decode_step > 0: - continue - - # Compute reference (debug first failure) - debug_this = (layer_id == 0 and first_fail_debug) - ref_output = compute_decode_reference(layer_id, decode_step, scale, debug=debug_this) - if ref_output is None: - continue - - # Compare - actual_output = c['output'].squeeze(0) # Remove seq dim for decode - if actual_output.dim() == 3: - actual_output = actual_output.squeeze(0) # Handle [1, heads, dim] case - - diff = (actual_output - ref_output).abs() - max_diff = diff.max().item() - mean_diff = diff.mean().item() - - tol = 1e-2 - passed = max_diff < tol - all_passed = all_passed and passed - - status = "PASS" if passed else "FAIL" - results.append((layer_id, decode_step, passed, max_diff, mean_diff)) - - if verbose: - print(f"[{status}] Layer {layer_id:2d}, Decode {decode_step}: " - f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") - - # Debug first failure - if not passed and first_fail_debug: - first_fail_debug = False - print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}") - print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}") - # Find where max diff is - max_idx = diff.argmax() - flat_actual = actual_output.flatten() - flat_ref = ref_output.flatten() - print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}") - - print() - print("=" * 70) - - # Summary - total_tests = len(results) - passed_count = sum(1 for r in results if r[2]) - - print(f"Results: {passed_count}/{total_tests} tests passed") - - if not all_passed: - print("\nFailed tests:") - for layer_id, decode_step, passed, max_diff, mean_diff in results: - if not passed: - print(f" - Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}") - - print() - return all_passed + return output.squeeze(0).squeeze(0).cpu() # [num_heads, head_dim] # ============================================================ # Main # ============================================================ -if __name__ == "__main__": - passed = run_test(verbose=True) +llm = LLM( + MODEL_PATH, + enforce_eager=True, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=MAX_MODEL_LEN, + enable_cpu_offload=True, + kvcache_block_size=BLOCK_SIZE, + num_gpu_blocks=NUM_GPU_BLOCKS, + dtype="float16", +) - if passed: - print("test_chunked_decode_hook: PASSED") - else: - print("test_chunked_decode_hook: FAILED") - exit(1) +# Get model info +num_layers = len(llm.model_runner.model.model.layers) +head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim +scale = head_dim ** -0.5 + +# Register hooks +hooks = [] +for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers): + # Pre-hook: inject all ones for Q, K, V + # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook()) + # hooks.append(pre_hook) + # Post-hook: capture Q, K, V, output + post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx)) + hooks.append(post_hook) + +# Run inference +seed(42) +prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] +outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False) + +# Remove hooks +for hook in hooks: + hook.remove() + +# Get CPU cache reference +offload_engine = llm.model_runner.kvcache_manager.offload_engine +k_cache_cpu = offload_engine.k_cache_cpu.clone() +v_cache_cpu = offload_engine.v_cache_cpu.clone() + +# Calculate number of prefill chunks +num_prefill_chunks = INPUT_LEN // BLOCK_SIZE + +# Verify decode outputs +all_passed = True + +for c in decode_captures: + layer_id = c['layer_id'] + decode_step = c['decode_step'] + + ref_output = compute_decode_reference( + layer_id, decode_step, scale, + k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks + ) + if ref_output is None: + continue + + actual_output = c['output'].squeeze(0) + if actual_output.dim() == 3: + actual_output = actual_output.squeeze(0) + + diff = (actual_output - ref_output).abs() + max_diff = diff.max().item() + + passed = max_diff < 1e-1 + all_passed = all_passed and passed + + # if not passed: + print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}") + +print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}") diff --git a/tests/test_chunked_prefill_hook.py b/tests/test_chunked_prefill_hook.py index 46b7278..cd00429 100644 --- a/tests/test_chunked_prefill_hook.py +++ b/tests/test_chunked_prefill_hook.py @@ -1,203 +1,111 @@ """ -Hook-based correctness test for chunked prefill attention. +Correctness test for chunked prefill attention. -Uses PyTorch register_forward_hook() to capture real inference I/O, -then compares against reference computation to locate bugs. - -This test targets the integration layer (context setup, cpu_block_table management) -which is where the needle test fails despite isolated attention tests passing. +Captures Q and output during inference, then computes reference using +CPU KV cache with standard flash attention. """ import os -os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG" +os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING" import torch from random import randint, seed +from typing import Dict, List from nanovllm import LLM, SamplingParams from nanovllm.utils.context import get_context -from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs from flash_attn.flash_attn_interface import flash_attn_varlen_func - -# ============================================================ -# Configuration -# ============================================================ - +# Config MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") -MAX_MODEL_LEN = 32 * 1024 +MAX_MODEL_LEN = 128 * 1024 NUM_GPU_BLOCKS = 2 -INPUT_LEN = 16 * 1024 # 4K tokens = 4 chunks with 1K block size +INPUT_LEN = 16 * 1024 BLOCK_SIZE = 1024 - -# ============================================================ -# Global capture storage -# ============================================================ - -captures = [] +# State - capture Q and output for each (layer, chunk) +captures: List[Dict] = [] -# ============================================================ -# Hook Functions -# ============================================================ - -def make_hook(layer_id): - """Create a forward hook for a specific layer.""" - def hook(module, inputs, output): - q, k, v = inputs +def make_ones_injection_hook(): + """Inject Q=K=V=1.0 for deterministic testing.""" + def hook(module, inputs): ctx = get_context() + if not ctx.is_prefill: + return inputs - # Only capture prefill phase + q, k, v = inputs[0], inputs[1], inputs[2] + q_ones = torch.ones_like(q) + k_ones = torch.ones_like(k) + v_ones = torch.ones_like(v) + return (q_ones, k_ones, v_ones) + inputs[3:] + return hook + + +def make_capture_hook(layer_id: int): + """Capture Q and output during prefill.""" + def hook(module, inputs, output): + ctx = get_context() if not ctx.is_prefill: return + q, k, v = inputs chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - capture_entry = { + captures.append({ 'layer_id': layer_id, 'chunk_idx': chunk_idx, 'q': q.clone().cpu(), 'k': k.clone().cpu(), 'v': v.clone().cpu(), 'output': output.clone().cpu(), - 'is_chunked_prefill': ctx.is_chunked_prefill, - } - - # For debugging: also capture CPU cache state for layer 0 - if layer_id == 0 and chunk_idx >= 2: - kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None - if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'): - oe = kvcache_manager.offload_engine - # Get what should have been loaded from CPU - cpu_k0 = oe.k_cache_cpu[0, 0].clone().cpu() # Layer 0, CPU block 0 - cpu_k1 = oe.k_cache_cpu[0, 1].clone().cpu() # Layer 0, CPU block 1 - capture_entry['cpu_k0'] = cpu_k0 - capture_entry['cpu_k1'] = cpu_k1 - - captures.append(capture_entry) + }) return hook -def register_hooks(llm): - """Register forward hooks on all Attention modules.""" - hooks = [] - model = llm.model_runner.model - - for layer_idx, decoder_layer in enumerate(model.model.layers): - attn_module = decoder_layer.self_attn.attn - hook = attn_module.register_forward_hook(make_hook(layer_idx)) - hooks.append(hook) - - return hooks - - -# ============================================================ -# Reference Computation -# ============================================================ - -def compute_reference(layer_id, chunk_idx, scale, debug=False): +def compute_reference(layer_id: int, chunk_idx: int, scale: float, + k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor, + block_size: int) -> torch.Tensor: """ - Compute reference attention output for a specific layer and chunk. + Compute reference output using CPU KV cache and standard flash attention. - Uses the captured k, v from all chunks up to and including chunk_idx. + Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention, + then extracts output for the current chunk. """ - # Filter captures for this layer + # Get all captures for this layer up to chunk_idx layer_captures = [c for c in captures if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx] - - if not layer_captures: - return None - - # Get current chunk's q - current_capture = [c for c in layer_captures if c['chunk_idx'] == chunk_idx][0] - q = current_capture['q'].cuda().unsqueeze(0) # [1, seqlen, nheads, headdim] - - # Collect all k, v up to current chunk - kv_list = [] - for c in sorted(layer_captures, key=lambda x: x['chunk_idx']): - k = c['k'].cuda().unsqueeze(0) # [1, seqlen, nheads, headdim] - v = c['v'].cuda().unsqueeze(0) - kv_list.append((k, v, c['chunk_idx'])) - - if debug: - print(f" Reference for L{layer_id} C{chunk_idx}:") - print(f" q shape: {q.shape}, mean={q.mean().item():.4f}") - print(f" kv_list: {len(kv_list)} chunks") - for i, (k, v, cidx) in enumerate(kv_list): - print(f" chunk {cidx}: k.mean={k.mean().item():.4f}, v.mean={v.mean().item():.4f}") - - o_acc, lse_acc = None, None - - # Previous chunks: non-causal attention - for i in range(len(kv_list) - 1): - k, v, _ = kv_list[i] - o, lse = flash_attn_with_lse(q, k, v, softmax_scale=scale, causal=False) - if o_acc is None: - o_acc, lse_acc = o, lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse) - - # Current chunk: causal attention - k_cur, v_cur, _ = kv_list[-1] - o_cur, lse_cur = flash_attn_with_lse(q, k_cur, v_cur, softmax_scale=scale, causal=True) - - if o_acc is None: - return o_cur.squeeze(0).cpu() - - final_o, _ = merge_attention_outputs(o_acc, lse_acc, o_cur, lse_cur) - return final_o.squeeze(0).cpu() - - -def compute_standard_reference(layer_id, chunk_idx, scale, debug=False): - """ - Compute reference using standard flash attention (single pass with all K, V). - - This simulates what standard (non-chunked) prefill would produce. - Concatenates all Q, K, V from chunks 0 to chunk_idx and runs a single - causal attention pass, then extracts the output for the current chunk. - """ - # Filter captures for this layer - layer_captures = [c for c in captures - if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx] - - if not layer_captures: - return None - - # Sort by chunk index layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx']) - # Concatenate all Q, K, V + if not layer_captures: + return None + + # Collect Q from captures, K/V from CPU cache all_q = [] all_k = [] all_v = [] chunk_lengths = [] for c in layer_captures: + cidx = c['chunk_idx'] q = c['q'].cuda() # [seqlen, nheads, headdim] - k = c['k'].cuda() - v = c['v'].cuda() all_q.append(q) - all_k.append(k) - all_v.append(v) chunk_lengths.append(q.shape[0]) - # Concatenate along sequence dimension - full_q = torch.cat(all_q, dim=0) # [total_seqlen, nheads, headdim] + # Get K, V from CPU cache (already offloaded during prefill) + # CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim] + k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda() + v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda() + all_k.append(k) + all_v.append(v) + + # Concatenate + full_q = torch.cat(all_q, dim=0) full_k = torch.cat(all_k, dim=0) full_v = torch.cat(all_v, dim=0) - total_len = full_q.shape[0] - if debug: - print(f" Standard Reference for L{layer_id} C{chunk_idx}:") - print(f" full_q shape: {full_q.shape}, mean={full_q.mean().item():.4f}") - print(f" full_k shape: {full_k.shape}, mean={full_k.mean().item():.4f}") - print(f" chunk_lengths: {chunk_lengths}") - # Run standard causal flash attention - # flash_attn_varlen_func expects: q, k, v with shape [total_seqlen, nheads, headdim] cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda') - full_o = flash_attn_varlen_func( full_q, full_k, full_v, cu_seqlens_q=cu_seqlens, @@ -208,266 +116,81 @@ def compute_standard_reference(layer_id, chunk_idx, scale, debug=False): causal=True, ) - # Extract output for current chunk only + # Extract output for current chunk start_pos = sum(chunk_lengths[:-1]) end_pos = sum(chunk_lengths) - chunk_output = full_o[start_pos:end_pos] - - if debug: - print(f" full_o shape: {full_o.shape}") - print(f" extracting positions [{start_pos}:{end_pos}]") - print(f" chunk_output shape: {chunk_output.shape}, mean={chunk_output.mean().item():.4f}") - - return chunk_output.cpu() - - -# ============================================================ -# Test Runner -# ============================================================ - -def run_test(verbose=True): - """Run the hook-based chunked prefill correctness test.""" - global captures - captures = [] - - if verbose: - print("=" * 70) - print("Test: Hook-Based Chunked Prefill Correctness") - print("=" * 70) - print(f"Model: {MODEL_PATH}") - print(f"Input length: {INPUT_LEN} tokens") - print(f"Block size: {BLOCK_SIZE}") - print(f"Expected chunks: {INPUT_LEN // BLOCK_SIZE}") - print() - - # Initialize LLM with CPU offload - llm = LLM( - MODEL_PATH, - enforce_eager=True, - max_model_len=MAX_MODEL_LEN, - max_num_batched_tokens=MAX_MODEL_LEN, - enable_cpu_offload=True, - kvcache_block_size=BLOCK_SIZE, - num_gpu_blocks=NUM_GPU_BLOCKS, - ) - - # Get model info - num_layers = len(llm.model_runner.model.model.layers) - head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim - scale = head_dim ** -0.5 - - if verbose: - print(f"Num layers: {num_layers}") - print(f"Head dim: {head_dim}") - print() - - # Register hooks - hooks = register_hooks(llm) - if verbose: - print(f"Registered {len(hooks)} hooks") - - # Generate random prompt - seed(42) - prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] - - # Run prefill only (max_tokens=1) - if verbose: - print("Running inference...") - sampling_params = SamplingParams(temperature=0.6, max_tokens=1) - outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) - - # Remove hooks - for hook in hooks: - hook.remove() - - # Analyze captures - if verbose: - print(f"\nCaptured {len(captures)} attention calls") - - # Group by layer and chunk - chunks_per_layer = {} - for c in captures: - layer_id = c['layer_id'] - chunk_idx = c['chunk_idx'] - if layer_id not in chunks_per_layer: - chunks_per_layer[layer_id] = set() - chunks_per_layer[layer_id].add(chunk_idx) - - if verbose: - print("Chunks per layer:", {k: sorted(v) for k, v in chunks_per_layer.items()}) - print() - - # First, verify CPU cache data integrity - if verbose: - print("\n--- CPU Cache Verification (Layer 0) ---") - # Get original k from chunk 0 and chunk 1 captures - chunk0_k = None - chunk1_k = None - chunk2_capture = None - for c in captures: - if c['layer_id'] == 0: - if c['chunk_idx'] == 0: - chunk0_k = c['k'] - elif c['chunk_idx'] == 1: - chunk1_k = c['k'] - elif c['chunk_idx'] == 2: - chunk2_capture = c - - if chunk0_k is not None and chunk2_capture is not None and 'cpu_k0' in chunk2_capture: - cpu_k0 = chunk2_capture['cpu_k0'] - diff_k0 = (chunk0_k - cpu_k0).abs().max().item() - print(f"Chunk 0 k vs CPU cache block 0: max_diff={diff_k0:.6f}") - if diff_k0 > 1e-3: - print(f" WARNING: CPU cache block 0 differs from original chunk 0 k!") - print(f" Original k[0,0,:5] = {chunk0_k[0,0,:5].tolist()}") - print(f" CPU k0[0,0,:5] = {cpu_k0[0,0,:5].tolist()}") - - if chunk1_k is not None and chunk2_capture is not None and 'cpu_k1' in chunk2_capture: - cpu_k1 = chunk2_capture['cpu_k1'] - diff_k1 = (chunk1_k - cpu_k1).abs().max().item() - print(f"Chunk 1 k vs CPU cache block 1: max_diff={diff_k1:.6f}") - if diff_k1 > 1e-3: - print(f" WARNING: CPU cache block 1 differs from original chunk 1 k!") - print(f" Original k[0,0,:5] = {chunk1_k[0,0,:5].tolist()}") - print(f" CPU k1[0,0,:5] = {cpu_k1[0,0,:5].tolist()}") - - print() - - # ================================================================ - # Test 1: Verify against merge-based reference (same algorithm) - # ================================================================ - if verbose: - print("--- Test 1: Merge-based Reference (verifies merge algorithm) ---") - - all_passed_merge = True - results_merge = [] - first_fail_debug = True - - for c in captures: - layer_id = c['layer_id'] - chunk_idx = c['chunk_idx'] - - if chunk_idx == 0: - continue - - debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug) - ref_output = compute_reference(layer_id, chunk_idx, scale, debug=debug_this) - if ref_output is None: - continue - - actual_output = c['output'] - diff = (actual_output - ref_output).abs() - max_diff = diff.max().item() - mean_diff = diff.mean().item() - - tol = 1e-2 - passed = max_diff < tol - all_passed_merge = all_passed_merge and passed - - status = "PASS" if passed else "FAIL" - results_merge.append((layer_id, chunk_idx, passed, max_diff, mean_diff)) - - if verbose: - print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: " - f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") - - if not passed and first_fail_debug: - first_fail_debug = False - print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}") - print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}") - max_idx = diff.argmax() - flat_actual = actual_output.flatten() - flat_ref = ref_output.flatten() - print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}") - - print() - - # ================================================================ - # Test 2: Verify against standard flash attention (single pass) - # ================================================================ - if verbose: - print("--- Test 2: Standard FlashAttn Reference (verifies correctness vs non-chunked) ---") - - all_passed_standard = True - results_standard = [] - first_fail_debug = True - - for c in captures: - layer_id = c['layer_id'] - chunk_idx = c['chunk_idx'] - - if chunk_idx == 0: - continue - - debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug) - std_ref_output = compute_standard_reference(layer_id, chunk_idx, scale, debug=debug_this) - if std_ref_output is None: - continue - - actual_output = c['output'] - diff = (actual_output - std_ref_output).abs() - max_diff = diff.max().item() - mean_diff = diff.mean().item() - - tol = 1e-2 - passed = max_diff < tol - all_passed_standard = all_passed_standard and passed - - status = "PASS" if passed else "FAIL" - results_standard.append((layer_id, chunk_idx, passed, max_diff, mean_diff)) - - if verbose: - print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: " - f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") - - if not passed and first_fail_debug: - first_fail_debug = False - print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}") - print(f" Debug: std_ref_output shape={std_ref_output.shape}, mean={std_ref_output.mean().item():.4f}") - max_idx = diff.argmax() - flat_actual = actual_output.flatten() - flat_ref = std_ref_output.flatten() - print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}") - - print() - print("=" * 70) - - # Summary - total_merge = len(results_merge) - passed_merge = sum(1 for r in results_merge if r[2]) - total_standard = len(results_standard) - passed_standard = sum(1 for r in results_standard if r[2]) - - print(f"Merge-based reference: {passed_merge}/{total_merge} tests passed") - print(f"Standard FlashAttn ref: {passed_standard}/{total_standard} tests passed") - - all_passed = all_passed_merge and all_passed_standard - - if not all_passed_merge: - print("\nFailed merge-based tests:") - for layer_id, chunk_idx, passed, max_diff, mean_diff in results_merge: - if not passed: - print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}") - - if not all_passed_standard: - print("\nFailed standard FlashAttn tests:") - for layer_id, chunk_idx, passed, max_diff, mean_diff in results_standard: - if not passed: - print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}") - - print() - return all_passed + return full_o[start_pos:end_pos].cpu() # ============================================================ # Main # ============================================================ -if __name__ == "__main__": - passed = run_test(verbose=True) +llm = LLM( + MODEL_PATH, + enforce_eager=True, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=MAX_MODEL_LEN, + enable_cpu_offload=True, + kvcache_block_size=BLOCK_SIZE, + num_gpu_blocks=NUM_GPU_BLOCKS, + dtype="float16", +) - if passed: - print("test_chunked_prefill_hook: PASSED") - else: - print("test_chunked_prefill_hook: FAILED") - exit(1) +# Get model info +num_layers = len(llm.model_runner.model.model.layers) +head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim +scale = head_dim ** -0.5 + +# Register hooks +hooks = [] +for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers): + # Pre-hook: inject all ones for Q, K, V + # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook()) + # hooks.append(pre_hook) + # Post-hook: capture Q, K, V, output + post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx)) + hooks.append(post_hook) + +# Run inference +seed(42) +prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] +outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False) + +# Remove hooks +for hook in hooks: + hook.remove() + +# Get CPU cache reference +offload_engine = llm.model_runner.kvcache_manager.offload_engine +k_cache_cpu = offload_engine.k_cache_cpu.clone() +v_cache_cpu = offload_engine.v_cache_cpu.clone() + +# Verify: compare actual output with reference computed from CPU cache +all_passed = True +num_chunks = INPUT_LEN // BLOCK_SIZE + +for idx,c in enumerate(captures): + layer_id = c['layer_id'] + chunk_idx = c['chunk_idx'] + + # Skip chunk 0 (no previous KV to load) + if chunk_idx == 0: + continue + + ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE) + if ref_output is None: + continue + + actual_output = c['output'] + diff = (actual_output - ref_output).abs() + max_diff = diff.max().item() + + passed = max_diff < 1e-1 # float16 tolerance + all_passed = all_passed and passed + + if not passed: + print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}") + __import__('pdb').set_trace() + +print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")