From 31e90a726803d7c0fc43d02a15330b6252f3d4d1 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 31 Dec 2025 20:59:53 +0800 Subject: [PATCH] [test] Added offload correct verify. --- tests/test_offload_correctness.py | 1101 ++++++++++++++++------------- 1 file changed, 612 insertions(+), 489 deletions(-) diff --git a/tests/test_offload_correctness.py b/tests/test_offload_correctness.py index 4808015..8a5a54b 100644 --- a/tests/test_offload_correctness.py +++ b/tests/test_offload_correctness.py @@ -1,463 +1,627 @@ """ -Correctness test for chunked attention with CPU offload. +Test script to verify CPU offload correctness using distinctive KV patterns. -Validates that the offload pipeline (CPU -> GPU transfer + chunked attention) -produces the same result as direct GPU computation. +Strategy: +1. Hook into attention forward pass +2. Overwrite K/V with distinctive patterns based on chunk_idx (e.g., K=chunk_idx, V=-chunk_idx) +3. After offload to CPU, verify CPU cache contains correct patterns +4. On subsequent chunks, verify loaded KV from CPU has correct patterns -Test scenario: -1. Generate Q, K, V data -2. Reference: Compute full causal attention on GPU -3. Offload: Store K, V in CPU cache, load via pipeline, compute chunked attention -4. Compare results - -This test is designed to identify bugs in: -- CPU <-> GPU data transfer (sgDMA) -- Ring buffer slot management -- N-way pipeline ordering -- Triton merge kernel correctness +This catches bugs like: +- Wrong block being offloaded +- Wrong block being loaded +- Data corruption during transfer """ +import os +os.environ["NANOVLLM_LOG_LEVEL"] = "INFO" + import torch -from flash_attn.flash_attn_interface import flash_attn_func -from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager -from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs +from random import randint, seed +from nanovllm import LLM, SamplingParams +from nanovllm.utils.context import get_context +from nanovllm.kvcache.debug_utils import dump_block_state # ============================================================ # Configuration # ============================================================ -NUM_LAYERS = 4 -NUM_HEADS = 8 -NUM_KV_HEADS = 8 -HEAD_DIM = 64 -BLOCK_SIZE = 256 # Smaller for faster testing -DTYPE = torch.bfloat16 -DEVICE = "cuda" +MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") +MAX_MODEL_LEN = 64 * 1024 +NUM_GPU_BLOCKS = 4 +INPUT_LEN = 32 * 1024 # 32K tokens = 32 chunks (fits in 40 CPU blocks) +BLOCK_SIZE = 1024 + +# Test state +errors = [] +chunk_patterns = {} # chunk_idx -> (k_pattern, v_pattern) +block_coverage = {} # chunk_idx -> set of blocks that were actually computed +load_operations = [] # List of (chunk_idx, slot_id, cpu_block_id, k_ok, v_ok) tuples +current_chunk_for_load = [0] # Mutable container to track current chunk during loads # ============================================================ -# Reference Implementation (GPU only, no offload) +# Pattern Helpers # ============================================================ -def compute_reference_causal(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - """ - Compute reference causal attention using flash_attn_func. 
- - Args: - q, k, v: [batch, seqlen, nheads, headdim] - - Returns: - out: [batch, seqlen, nheads, headdim] - """ - return flash_attn_func(q, k, v, causal=True) +def get_expected_pattern(chunk_idx: int): + """Get expected K/V pattern for a chunk.""" + # Use float values that are easy to identify + k_val = float(chunk_idx + 1) # 1.0, 2.0, 3.0, ... + v_val = float(-(chunk_idx + 1)) # -1.0, -2.0, -3.0, ... + return k_val, v_val -def compute_reference_chunked( - q_chunks: list, - kv_chunks: list, - scale: float, -) -> torch.Tensor: - """ - Compute chunked prefill attention directly on GPU (no offload). +def fill_with_pattern(tensor: torch.Tensor, value: float): + """Fill tensor with a constant value.""" + tensor.fill_(value) - This is the "gold standard" for chunked attention correctness. - Args: - q_chunks: List of [batch, chunk_size, nheads, headdim] - kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim] - scale: Softmax scale - - Returns: - out: [batch, total_seqlen, nheads, headdim] - """ - out_chunks = [] - - for chunk_idx, q_chunk in enumerate(q_chunks): - o_acc, lse_acc = None, None - - # Attend to all previous chunks (no causal mask) - for i in range(chunk_idx): - k_chunk, v_chunk = kv_chunks[i] - chunk_o, chunk_lse = flash_attn_with_lse( - q_chunk, k_chunk, v_chunk, - softmax_scale=scale, - causal=False, - ) - if o_acc is None: - o_acc, lse_acc = chunk_o, chunk_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, chunk_o, chunk_lse) - - # Attend to current chunk (with causal mask) - k_chunk, v_chunk = kv_chunks[chunk_idx] - current_o, current_lse = flash_attn_with_lse( - q_chunk, k_chunk, v_chunk, - softmax_scale=scale, - causal=True, - ) - - if o_acc is None: - final_o = current_o - else: - final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) - - out_chunks.append(final_o) - - return torch.cat(out_chunks, dim=1) +def check_pattern(tensor: torch.Tensor, expected: float, name: str, tolerance: float = 1e-3): + """Check if tensor contains expected pattern.""" + actual_mean = tensor.float().mean().item() + if abs(actual_mean - expected) > tolerance: + return False, f"{name}: expected mean={expected}, got {actual_mean}" + return True, None # ============================================================ -# Offload Implementation +# Load Verification Instrumentation # ============================================================ -def create_manager(num_gpu_slots: int, num_cpu_blocks: int): - """Create HybridKVCacheManager with specified configuration.""" - manager = HybridKVCacheManager( - num_gpu_slots=num_gpu_slots, - num_cpu_blocks=num_cpu_blocks, - block_size=BLOCK_SIZE, +_original_load_to_slot_layer = None +_offload_engine_ref = None + +def make_verified_load_to_slot_layer(original_func, offload_engine): + """ + Create a wrapper around load_to_slot_layer that verifies each load operation. + + After each H2D transfer, checks that the GPU slot contains the expected + pattern from the source CPU block. + """ + def verified_load(slot_idx: int, layer_id: int, cpu_block_id: int): + # Call original load + original_func(slot_idx, layer_id, cpu_block_id) + + # Only verify layer 0 to reduce overhead + if layer_id != 0: + return + + # IMPORTANT: Synchronize CUDA to ensure async transfer is complete + # The transfer happens on a per-slot stream, and wait_slot_layer only + # makes compute_stream wait. We need full sync to read on default stream. 
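+        # Note (test-only tradeoff): torch.cuda.synchronize() waits for *all* streams
+        # on the device, which is much stronger than waiting on the per-slot copy
+        # stream alone. That is deliberate here; this wrapper is only installed by the
+        # correctness test, so stalling the pipeline is acceptable in exchange for
+        # reading settled data on the default stream below.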
+ torch.cuda.synchronize() + + # Get the expected pattern for this CPU block + # cpu_block_id == chunk_idx in our sequential test + expected_k, expected_v = get_expected_pattern(cpu_block_id) + + # Read GPU slot data + gpu_k = offload_engine.k_cache_gpu[layer_id, slot_idx] + gpu_v = offload_engine.v_cache_gpu[layer_id, slot_idx] + + actual_k = gpu_k.float().mean().item() + actual_v = gpu_v.float().mean().item() + + k_ok = abs(actual_k - expected_k) < 1e-3 + v_ok = abs(actual_v - expected_v) < 1e-3 + + chunk_idx = current_chunk_for_load[0] + load_operations.append({ + 'chunk_idx': chunk_idx, + 'slot_idx': slot_idx, + 'cpu_block_id': cpu_block_id, + 'expected_k': expected_k, + 'expected_v': expected_v, + 'actual_k': actual_k, + 'actual_v': actual_v, + 'k_ok': k_ok, + 'v_ok': v_ok, + }) + + if not (k_ok and v_ok): + errors.append(f"Load verification failed: chunk {chunk_idx}, " + f"CPU block {cpu_block_id} -> GPU slot {slot_idx}: " + f"expected K={expected_k:.1f}/V={expected_v:.1f}, " + f"got K={actual_k:.4f}/V={actual_v:.4f}") + + return verified_load + + +def install_load_verification(llm): + """Install verification wrapper on load_to_slot_layer.""" + global _original_load_to_slot_layer, _offload_engine_ref + + oe = llm.model_runner.kvcache_manager.offload_engine + _offload_engine_ref = oe + _original_load_to_slot_layer = oe.load_to_slot_layer + + oe.load_to_slot_layer = make_verified_load_to_slot_layer( + _original_load_to_slot_layer, oe ) - manager.allocate_cache( - num_layers=NUM_LAYERS, - num_kv_heads=NUM_KV_HEADS, - head_dim=HEAD_DIM, - dtype=DTYPE, - ) - return manager + print("Installed load verification wrapper on load_to_slot_layer") -def store_kv_to_cpu_cache(manager, kv_chunks: list, layer_id: int): +def uninstall_load_verification(): + """Restore original load_to_slot_layer.""" + global _original_load_to_slot_layer, _offload_engine_ref + + if _offload_engine_ref is not None and _original_load_to_slot_layer is not None: + _offload_engine_ref.load_to_slot_layer = _original_load_to_slot_layer + print("Restored original load_to_slot_layer") + + _original_load_to_slot_layer = None + _offload_engine_ref = None + + +# ============================================================ +# Attention Hook +# ============================================================ + +def make_kv_pattern_pre_hook(layer_id: int): """ - Store K, V chunks to CPU cache. + Create a PRE-forward hook that overwrites K/V with distinctive patterns BEFORE + they are stored to cache. This is called before attention.forward(). - Args: - manager: HybridKVCacheManager - kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim] - layer_id: Layer index - - Returns: - cpu_block_ids: List of CPU block IDs + register_forward_pre_hook receives (module, inputs) and can modify inputs in-place. 
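+
+    Illustrative example: get_expected_pattern(2) returns (3.0, -3.0), so for
+    chunk_idx == 2 the hook fills k with 3.0 and v with -3.0. Any later read of
+    that chunk (from the GPU ring buffer or from the CPU cache) should then
+    average to exactly those values, which is what the post-forward hooks check.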
""" - offload_engine = manager.offload_engine - cpu_block_ids = [] + def hook(module, inputs): + ctx = get_context() + if not ctx.is_prefill: + return - for block_idx, (k_chunk, v_chunk) in enumerate(kv_chunks): - # k_chunk, v_chunk: [batch, chunk_size, nheads, headdim] - # CPU cache layout: [num_layers, num_blocks, block_size, nheads, headdim] - k_data = k_chunk.squeeze(0) # [chunk_size, nheads, headdim] - v_data = v_chunk.squeeze(0) + chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 + kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None - offload_engine.k_cache_cpu[layer_id, block_idx, :k_data.shape[0]].copy_(k_data) - offload_engine.v_cache_cpu[layer_id, block_idx, :v_data.shape[0]].copy_(v_data) + if kvcache_manager is None: + return - cpu_block_ids.append(block_idx) + # Only process layer 0 for cleaner output + if layer_id != 0: + return - return cpu_block_ids + q, k, v = inputs + k_pattern, v_pattern = get_expected_pattern(chunk_idx) + + # === Overwrite current chunk's K/V with distinctive pattern === + # This happens BEFORE forward(), so these values will be stored to cache + k.fill_(k_pattern) + v.fill_(v_pattern) + + # Only print for first few and last few chunks to reduce noise + num_chunks = INPUT_LEN // BLOCK_SIZE + if chunk_idx < 3 or chunk_idx >= num_chunks - 2: + print(f"[Chunk {chunk_idx:3d}] Set K={k_pattern:.1f}, V={v_pattern:.1f}") + elif chunk_idx == 3: + print(f"... (chunks 3 to {num_chunks - 3} omitted) ...") + + return hook -def compute_offload_chunked_single_layer( - manager, - q_chunks: list, - cpu_block_ids: list, - layer_id: int, - scale: float, -) -> torch.Tensor: +def make_block_coverage_pre_hook(layer_id: int): """ - Compute chunked attention for a single layer using offload pipeline. + Create a PRE-forward hook to verify that all previous blocks are included + in the cpu_block_table for chunked prefill attention. - This mimics the behavior of Attention._ring_buffer_pipeline_load(). 
- - Args: - manager: HybridKVCacheManager - q_chunks: List of [batch, chunk_size, nheads, headdim] - cpu_block_ids: List of CPU block IDs containing K, V data - layer_id: Layer index - scale: Softmax scale - - Returns: - out: [batch, total_seqlen, nheads, headdim] + This catches bugs where: + - Some blocks are missing from the computation + - Sparse policy incorrectly filters out blocks (when not intended) + - Block table construction has off-by-one errors """ - offload_engine = manager.offload_engine - out_chunks = [] + def hook(module, inputs): + ctx = get_context() + if not ctx.is_prefill: + return - for chunk_idx, q_chunk in enumerate(q_chunks): - # CPU blocks to load: all blocks before current chunk - blocks_to_load = cpu_block_ids[:chunk_idx] + chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 + kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None - # Get slots for this chunk - write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx) - load_slots = offload_engine.get_load_slots_for_prefill(write_slot) + if kvcache_manager is None: + return - # Load and compute attention for previous chunks - o_acc, lse_acc = None, None + # Only process layer 0 for cleaner output + if layer_id != 0: + return - if len(blocks_to_load) > 0 and len(load_slots) > 0: - o_acc, lse_acc = _pipeline_load_and_compute( - offload_engine, - q_chunk, - blocks_to_load, - load_slots, - layer_id, - scale, - ) + # Update current chunk for load verification tracking + current_chunk_for_load[0] = chunk_idx - # Current chunk's K, V (load from CPU to GPU slot) - current_cpu_block = cpu_block_ids[chunk_idx] - offload_engine.load_to_slot_layer(write_slot, layer_id, current_cpu_block) - offload_engine.wait_slot_layer(write_slot, layer_id) + # No previous blocks for chunk 0 + if chunk_idx == 0: + return - current_k, current_v = offload_engine.get_kv_for_slot(write_slot, layer_id) + # Get the sequence and its block table (same logic as _chunked_prefill_attention) + seq = ctx.chunked_seq if hasattr(ctx, 'chunked_seq') else None + if seq is None: + return - # Compute attention with causal mask - current_o, current_lse = flash_attn_with_lse( - q_chunk, current_k, current_v, - softmax_scale=scale, - causal=True, - ) + # Get the CPU block table that will be used for attention + cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) - # Merge - if o_acc is None: - final_o = current_o + # Expected blocks: 0 to chunk_idx-1 (all previous chunks) + expected_blocks = set(range(chunk_idx)) + actual_blocks = set(cpu_block_table) if cpu_block_table else set() + + # Store for later summary + block_coverage[chunk_idx] = { + 'expected': expected_blocks, + 'actual': actual_blocks, + } + + # Check for missing blocks + missing_blocks = expected_blocks - actual_blocks + extra_blocks = actual_blocks - expected_blocks + + num_chunks = INPUT_LEN // BLOCK_SIZE + if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or missing_blocks: + if not missing_blocks and not extra_blocks: + print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [OK]") + else: + status_parts = [] + if missing_blocks: + status_parts.append(f"MISSING {sorted(missing_blocks)}") + if extra_blocks: + status_parts.append(f"EXTRA {sorted(extra_blocks)}") + print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [{', '.join(status_parts)}]") + elif chunk_idx == 4: + # Indicate that middle chunks are being verified silently + print(f" ... 
(verifying chunks 4-{num_chunks - 3} silently) ...") + + if missing_blocks: + errors.append(f"Chunk {chunk_idx} missing blocks: {sorted(missing_blocks)}") + + return hook + + +def make_gpu_write_verification_post_hook(layer_id: int): + """ + Create a POST-forward hook to verify the current chunk's KV was correctly + written to the GPU ring buffer write_slot. + + This is a more reliable verification than checking load slots, because: + 1. Post-hook runs AFTER forward() writes to GPU cache + 2. write_slot mapping is deterministic: chunk_idx % num_ring_slots + 3. We injected known patterns in pre-hook, now verify they're in GPU cache + """ + def hook(module, inputs, output): + ctx = get_context() + if not ctx.is_prefill: + return + + chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 + kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None + + if kvcache_manager is None: + return + + # Only process layer 0 for cleaner output + if layer_id != 0: + return + + oe = kvcache_manager.offload_engine + num_ring_slots = oe.num_ring_slots + write_slot = chunk_idx % num_ring_slots + + # Get expected pattern for current chunk + expected_k, expected_v = get_expected_pattern(chunk_idx) + + # Verify write_slot contains current chunk's data + gpu_k = oe.k_cache_gpu[layer_id, write_slot] + gpu_v = oe.v_cache_gpu[layer_id, write_slot] + + actual_k_mean = gpu_k.float().mean().item() + actual_v_mean = gpu_v.float().mean().item() + + k_ok, _ = check_pattern(gpu_k, expected_k, f"GPU slot {write_slot}") + v_ok, _ = check_pattern(gpu_v, expected_v, f"GPU slot {write_slot}") + + num_chunks = INPUT_LEN // BLOCK_SIZE + # Print for first/last chunks, or if there's an error + if True or chunk_idx >= num_chunks - 2 or not (k_ok and v_ok): + if k_ok and v_ok: + print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: K={expected_k:.1f}, V={expected_v:.1f} [OK]") + else: + print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: expected K={expected_k:.1f}/V={expected_v:.1f}, " + f"got K={actual_k_mean:.2f}/V={actual_v_mean:.2f} [FAIL]") + elif chunk_idx == 4: + print(f" ... (GPU write verification for chunks 4-{num_chunks - 3} silently) ...") + + if not (k_ok and v_ok): + errors.append(f"GPU write_slot {write_slot} at chunk {chunk_idx}: " + f"expected K={expected_k}, V={expected_v}, got K={actual_k_mean:.4f}, V={actual_v_mean:.4f}") + + return hook + + +def make_kv_verification_post_hook(layer_id: int): + """ + Create a POST-forward hook to verify CPU cache contains correct patterns + from previously offloaded blocks. 
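+
+    Concretely: after chunk i has run, CPU block j (for every j < i) should still
+    average to K = j + 1 and V = -(j + 1) on layer 0. A mismatch points to the
+    D2H offload writing the wrong block, or to that block being overwritten or
+    corrupted afterwards.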
+ """ + def hook(module, inputs, output): + ctx = get_context() + if not ctx.is_prefill: + return + + chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 + kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None + + if kvcache_manager is None: + return + + # Only process layer 0 for cleaner output + if layer_id != 0: + return + + # === Verify previously offloaded blocks in CPU cache === + if chunk_idx >= 1: + oe = kvcache_manager.offload_engine + num_ok = 0 + num_fail = 0 + + # Check all previously offloaded blocks + for prev_chunk in range(chunk_idx): + # CPU block ID = prev_chunk (in simple sequential case) + cpu_block_id = prev_chunk + + # Get expected pattern for this block + expected_k, expected_v = get_expected_pattern(prev_chunk) + + # Read from CPU cache (layer 0) + cpu_k = oe.k_cache_cpu[layer_id, cpu_block_id] + cpu_v = oe.v_cache_cpu[layer_id, cpu_block_id] + + # Verify patterns + k_ok, k_err = check_pattern(cpu_k, expected_k, f"CPU K block {cpu_block_id}") + v_ok, v_err = check_pattern(cpu_v, expected_v, f"CPU V block {cpu_block_id}") + + if k_ok and v_ok: + num_ok += 1 + else: + num_fail += 1 + if k_err: + errors.append(k_err) + if v_err: + errors.append(v_err) + + # Only print summary for each chunk verification + num_chunks = INPUT_LEN // BLOCK_SIZE + if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or num_fail > 0: + status = "OK" if num_fail == 0 else f"FAIL({num_fail})" + print(f" CPU verify chunk {chunk_idx:2d}: {num_ok} blocks OK [{status}]") + elif chunk_idx == 4: + print(f" ... (CPU cache verification for chunks 4-{num_chunks - 3} silently) ...") + + return hook + + +def make_post_chunk_verification_hook(layer_id: int): + """ + Post-forward hook to verify GPU ring buffer state after attention. + """ + def hook(module, inputs, output): + ctx = get_context() + if not ctx.is_prefill or layer_id != 0: + return + + chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 + kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None + + if kvcache_manager is None: + return + + oe = kvcache_manager.offload_engine + + # After attention, the current chunk's KV should be in the GPU ring buffer + # Ring slot = chunk_idx % num_ring_slots + ring_slot = chunk_idx % oe.num_ring_slots + + expected_k, expected_v = get_expected_pattern(chunk_idx) + + # Check GPU ring buffer + gpu_k = oe.k_cache_gpu[layer_id, ring_slot] + gpu_v = oe.v_cache_gpu[layer_id, ring_slot] + + k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}") + v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}") + + if k_ok and v_ok: + print(f" [OK] GPU slot {ring_slot} (chunk {chunk_idx}): K={expected_k}, V={expected_v}") else: - final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + if k_err: + print(f" [FAIL] {k_err}") + errors.append(k_err) + if v_err: + print(f" [FAIL] {v_err}") + errors.append(v_err) - out_chunks.append(final_o) - - return torch.cat(out_chunks, dim=1) + return hook -def _pipeline_load_and_compute( - offload_engine, - q_chunk: torch.Tensor, - cpu_block_table: list, - load_slots: list, - layer_id: int, - scale: float, -): - """ - Pipeline loading from CPU and computing attention. +def register_hooks(llm): + """Register pre and post forward hooks.""" + hooks = [] + model = llm.model_runner.model - Mirrors Attention._ring_buffer_pipeline_load() logic. 
- """ - num_blocks = len(cpu_block_table) - num_slots = len(load_slots) + for layer_idx, decoder_layer in enumerate(model.model.layers): + attn_module = decoder_layer.self_attn.attn - o_acc, lse_acc = None, None + # PRE-forward hook 1: Verify all previous blocks are in cpu_block_table + coverage_hook = attn_module.register_forward_pre_hook(make_block_coverage_pre_hook(layer_idx)) + hooks.append(coverage_hook) - # Phase 1: Pre-load up to num_slots blocks - num_preload = min(num_slots, num_blocks) - for i in range(num_preload): - offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i]) + # PRE-forward hook 2: Inject K/V patterns before they're stored to cache + pattern_hook = attn_module.register_forward_pre_hook(make_kv_pattern_pre_hook(layer_idx)) + hooks.append(pattern_hook) - # Phase 2: Main loop - compute_stream = offload_engine.compute_stream + # POST-forward hook 1: Verify GPU write_slot contains current chunk's data + gpu_verify_hook = attn_module.register_forward_hook(make_gpu_write_verification_post_hook(layer_idx)) + hooks.append(gpu_verify_hook) - for block_idx in range(num_blocks): - current_slot = load_slots[block_idx % num_slots] + # POST-forward hook 2: Verify CPU cache contains correct patterns after offload + cpu_verify_hook = attn_module.register_forward_hook(make_kv_verification_post_hook(layer_idx)) + hooks.append(cpu_verify_hook) - # Wait for transfer - offload_engine.wait_slot_layer(current_slot, layer_id) + return hooks - # Compute on dedicated stream - with torch.cuda.stream(compute_stream): - prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id) - prev_o, prev_lse = flash_attn_with_lse( - q_chunk, prev_k, prev_v, - softmax_scale=scale, - causal=False, - ) - offload_engine.record_slot_compute_done(current_slot, layer_id) - # Start next transfer - next_block_idx = block_idx + num_slots - if next_block_idx < num_blocks: - offload_engine.load_to_slot_layer( - current_slot, layer_id, cpu_block_table[next_block_idx] - ) +# ============================================================ +# Final Verification +# ============================================================ - # Merge - with torch.cuda.stream(compute_stream): - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse +def verify_final_cpu_state(llm, num_chunks: int): + """Verify all CPU blocks have correct patterns after prefill completes.""" + print("\n" + "=" * 60) + print("Final CPU Cache Verification") + print("=" * 60) + + kvcache_manager = llm.model_runner.kvcache_manager + oe = kvcache_manager.offload_engine + + num_ok = 0 + num_fail = 0 + fail_details = [] + + # After prefill, all chunks should be in CPU + for chunk_idx in range(num_chunks): + cpu_block_id = chunk_idx + expected_k, expected_v = get_expected_pattern(chunk_idx) + + # Check layer 0 + cpu_k = oe.k_cache_cpu[0, cpu_block_id] + cpu_v = oe.v_cache_cpu[0, cpu_block_id] + + k_ok, k_err = check_pattern(cpu_k, expected_k, f"Final CPU K block {cpu_block_id}") + v_ok, v_err = check_pattern(cpu_v, expected_v, f"Final CPU V block {cpu_block_id}") + + if k_ok and v_ok: + num_ok += 1 + # Only print first few and last few + if chunk_idx < 3 or chunk_idx >= num_chunks - 2: + actual_k_mean = cpu_k.float().mean().item() + actual_v_mean = cpu_v.float().mean().item() + print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), " + f"V={expected_v:.1f} ({actual_v_mean:.4f}) [OK]") + elif chunk_idx == 3: + print(f" ... 
(blocks 3 to {num_chunks - 3} verified OK) ...") + else: + num_fail += 1 + actual_k_mean = cpu_k.float().mean().item() + actual_v_mean = cpu_v.float().mean().item() + print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), " + f"V={expected_v:.1f} ({actual_v_mean:.4f}) [FAIL]") + if k_err: + errors.append(k_err) + if v_err: + errors.append(v_err) + + print(f"\nTotal: {num_ok} OK, {num_fail} FAIL out of {num_chunks} blocks") + + +def verify_block_coverage_summary(num_chunks: int): + """Verify that all chunks had complete block coverage during prefill.""" + print("\n" + "=" * 60) + print("Block Coverage Verification Summary") + print("=" * 60) + + num_ok = 0 + num_fail = 0 + total_blocks_expected = 0 + total_blocks_computed = 0 + + for chunk_idx in range(1, num_chunks): # Start from 1 (chunk 0 has no previous) + if chunk_idx not in block_coverage: + print(f" Chunk {chunk_idx}: NO COVERAGE DATA [FAIL]") + errors.append(f"Chunk {chunk_idx} has no block coverage data") + num_fail += 1 + continue + + coverage = block_coverage[chunk_idx] + expected = coverage['expected'] + actual = coverage['actual'] + missing = expected - actual + + total_blocks_expected += len(expected) + total_blocks_computed += len(actual) + + if not missing: + num_ok += 1 + else: + num_fail += 1 + + # Print summary + if num_fail == 0: + print(f" All {num_ok} chunks had complete block coverage [OK]") + print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})") + else: + print(f" {num_ok} chunks OK, {num_fail} chunks with missing blocks [FAIL]") + print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})") + + # Verify the total is correct: sum of 0+1+2+...+(n-1) = n*(n-1)/2 + expected_total = num_chunks * (num_chunks - 1) // 2 + if total_blocks_expected == expected_total: + print(f" Expected total blocks matches formula: {expected_total} [OK]") + else: + print(f" Expected total mismatch: got {total_blocks_expected}, formula gives {expected_total} [FAIL]") + errors.append(f"Block coverage total mismatch") + + +def verify_load_operations_summary(num_chunks: int): + """Verify all H2D load operations transferred correct data.""" + print("\n" + "=" * 60) + print("H2D Load Operations Verification Summary") + print("=" * 60) + + if not load_operations: + print(" WARNING: No load operations recorded!") + print(" (This may indicate load verification was not installed)") + return + + num_ok = 0 + num_fail = 0 + loads_per_chunk = {} + + for op in load_operations: + chunk_idx = op['chunk_idx'] + if chunk_idx not in loads_per_chunk: + loads_per_chunk[chunk_idx] = [] + loads_per_chunk[chunk_idx].append(op) + + if op['k_ok'] and op['v_ok']: + num_ok += 1 + else: + num_fail += 1 + + # Print per-chunk summary for first/last chunks + for chunk_idx in sorted(loads_per_chunk.keys()): + ops = loads_per_chunk[chunk_idx] + chunk_ok = sum(1 for op in ops if op['k_ok'] and op['v_ok']) + chunk_fail = len(ops) - chunk_ok + + if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or chunk_fail > 0: + # Show loaded block IDs in order + block_ids = [op['cpu_block_id'] for op in ops] + if chunk_fail == 0: + print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks {block_ids} [OK]") else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks, {chunk_fail} FAILED [FAIL]") + for op in ops: + if not (op['k_ok'] and op['v_ok']): + print(f" CPU block {op['cpu_block_id']} -> slot {op['slot_idx']}: " + 
f"expected K={op['expected_k']:.1f}/V={op['expected_v']:.1f}, " + f"got K={op['actual_k']:.4f}/V={op['actual_v']:.4f}") + elif chunk_idx == 4: + print(f" ... (chunks 4-{num_chunks - 3} load verification running silently) ...") - # Sync compute stream - compute_stream.synchronize() + # Print overall summary + print(f"\n Total load operations: {len(load_operations)}") + print(f" Successful: {num_ok}, Failed: {num_fail}") - return o_acc, lse_acc - - -# ============================================================ -# Test Runner -# ============================================================ - -def run_correctness_test( - num_chunks: int, - num_gpu_slots: int, - verbose: bool = True, -) -> tuple[bool, float, float]: - """ - Run a single correctness test. - - Args: - num_chunks: Number of chunks (= number of CPU blocks) - num_gpu_slots: Number of GPU ring buffer slots - verbose: Print detailed info - - Returns: - (passed, max_diff, mean_diff) - """ - torch.manual_seed(42) - - seqlen = num_chunks * BLOCK_SIZE - scale = HEAD_DIM ** -0.5 - - # Generate Q, K, V - q_full = torch.randn(1, seqlen, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) - k_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) - v_full = torch.randn(1, seqlen, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) - - # Split into chunks - q_chunks = [q_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE] for i in range(num_chunks)] - kv_chunks = [ - (k_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE], - v_full[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE]) - for i in range(num_chunks) - ] - - # Reference: chunked attention on GPU (no offload) - out_ref = compute_reference_chunked(q_chunks, kv_chunks, scale) - - # Create manager with enough CPU blocks - manager = create_manager(num_gpu_slots, num_chunks) - - # Test each layer - all_passed = True - max_diff_all = 0.0 - mean_diff_all = 0.0 - - for layer_id in range(NUM_LAYERS): - # Store K, V to CPU cache - cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id) - - # Compute with offload - out_offload = compute_offload_chunked_single_layer( - manager, q_chunks, cpu_block_ids, layer_id, scale - ) - - # Compare - diff = (out_ref - out_offload).abs() - max_diff = diff.max().item() - mean_diff = diff.mean().item() - - max_diff_all = max(max_diff_all, max_diff) - mean_diff_all = max(mean_diff_all, mean_diff) - - tol = 1e-2 - passed = max_diff < tol - all_passed = all_passed and passed - - if verbose and not passed: - print(f" Layer {layer_id}: FAIL max_diff={max_diff:.6f}") - - return all_passed, max_diff_all, mean_diff_all - - -# ============================================================ -# Decode Phase Test -# ============================================================ - -def run_decode_correctness_test( - num_prefill_chunks: int, - num_gpu_slots: int, - num_decode_steps: int = 4, - verbose: bool = True, -) -> tuple[bool, float, float]: - """ - Test decode phase correctness with CPU offload. - - Simulates: - 1. Prefill: Store K, V for multiple chunks in CPU cache - 2. Decode: Single token queries against all prefilled K, V - - This tests the scenario in needle test where decode reads all previous KV. 
- """ - torch.manual_seed(42) - - scale = HEAD_DIM ** -0.5 - prefill_len = num_prefill_chunks * BLOCK_SIZE - - # Generate prefill K, V (store in CPU) - k_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) - v_prefill = torch.randn(1, prefill_len, NUM_KV_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) - - # Split into chunks for CPU storage - kv_chunks = [ - (k_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE], - v_prefill[:, i*BLOCK_SIZE:(i+1)*BLOCK_SIZE]) - for i in range(num_prefill_chunks) - ] - - # Create manager - manager = create_manager(num_gpu_slots, num_prefill_chunks) - offload_engine = manager.offload_engine - - all_passed = True - max_diff_all = 0.0 - mean_diff_all = 0.0 - - for layer_id in range(NUM_LAYERS): - # Store prefilled K, V to CPU cache - cpu_block_ids = store_kv_to_cpu_cache(manager, kv_chunks, layer_id) - - for decode_step in range(num_decode_steps): - # Decode query: single token - q_decode = torch.randn(1, 1, NUM_HEADS, HEAD_DIM, device=DEVICE, dtype=DTYPE) - - # Reference: direct attention on GPU - # Concat all prefilled K, V and compute attention - out_ref = flash_attn_func( - q_decode, - k_prefill, - v_prefill, - causal=False, # Decode query can attend to all prefilled tokens - ) - - # Offload: load from CPU and compute - load_slots = offload_engine.get_load_slots_for_prefill(0) # Use all slots except decode slot - - if len(load_slots) == 0 or len(cpu_block_ids) == 0: - # No previous chunks to load - out_offload = out_ref # Trivially equal - else: - o_acc, lse_acc = _pipeline_load_and_compute( - offload_engine, - q_decode, - cpu_block_ids, - load_slots, - layer_id, - scale, - ) - out_offload = o_acc - - # Compare - diff = (out_ref - out_offload).abs() - max_diff = diff.max().item() - mean_diff = diff.mean().item() - - max_diff_all = max(max_diff_all, max_diff) - mean_diff_all = max(mean_diff_all, mean_diff) - - tol = 1e-2 - passed = max_diff < tol - all_passed = all_passed and passed - - if verbose and not passed: - print(f" Layer {layer_id} Step {decode_step}: FAIL max_diff={max_diff:.6f}") - - return all_passed, max_diff_all, mean_diff_all + if num_fail == 0: + print(f" All H2D transfers verified correct [OK]") + else: + print(f" {num_fail} H2D transfers had incorrect data [FAIL]") # ============================================================ @@ -465,109 +629,68 @@ def run_decode_correctness_test( # ============================================================ if __name__ == "__main__": - print("=" * 70) - print("Test: Offload Chunked Attention Correctness") - print("=" * 70) - print(f"Config: layers={NUM_LAYERS}, heads={NUM_HEADS}, kv_heads={NUM_KV_HEADS}, " - f"head_dim={HEAD_DIM}, block_size={BLOCK_SIZE}, dtype={DTYPE}") - print() - print("Comparing: Reference (GPU chunked) vs Offload (CPU->GPU pipeline)") + print("=" * 60) + print("Test: CPU Offload Correctness with Distinctive KV Patterns") + print("=" * 60) + print(f"Input: {INPUT_LEN} tokens, {INPUT_LEN // BLOCK_SIZE} chunks") + print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}") + print(f"Pattern: K=chunk_idx+1, V=-(chunk_idx+1)") print() - # Test configurations: (num_chunks, num_gpu_slots) - TEST_CASES = [ - # Basic tests - (2, 2), # Minimal: 2 chunks, 2 slots (no pipeline) - (2, 3), # 2 chunks, 3 slots (1-slot pipeline) - (4, 2), # 4 chunks, 2 slots (heavy slot reuse) - (4, 3), # 4 chunks, 3 slots - (4, 4), # 4 chunks, 4 slots - # Stress tests - (8, 3), # Many chunks, few slots - (8, 4), # Many chunks, moderate slots - (8, 6), # Many chunks, many slots (like 
bench_offload) - # Edge cases - (1, 2), # Single chunk - (3, 5), # Fewer chunks than slots - ] + # 1. Initialize LLM + print("Initializing LLM...") + llm = LLM( + MODEL_PATH, + enforce_eager=True, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=MAX_MODEL_LEN, + enable_cpu_offload=True, + kvcache_block_size=BLOCK_SIZE, + num_gpu_blocks=NUM_GPU_BLOCKS, + dtype="float16", + ) - all_passed = True - results = [] + # 2. Register hooks + hooks = register_hooks(llm) + print(f"Registered {len(hooks)} hooks") - for num_chunks, num_gpu_slots in TEST_CASES: - seqlen = num_chunks * BLOCK_SIZE - passed, max_diff, mean_diff = run_correctness_test( - num_chunks, num_gpu_slots, verbose=False - ) + # 3. Install load verification (instrument load_to_slot_layer) + install_load_verification(llm) - all_passed = all_passed and passed - status = "PASS" if passed else "FAIL" + # 4. Generate prompt + seed(42) + prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] + num_chunks = INPUT_LEN // BLOCK_SIZE - results.append((num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff)) + # 5. Run prefill + print("\n" + "=" * 60) + print("Running Prefill with KV Pattern Injection...") + print("=" * 60) - print(f"[{status}] chunks={num_chunks:2d} slots={num_gpu_slots:2d} " - f"seqlen={seqlen:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") + sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1) + outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) - print() + # 6. Remove hooks and uninstall load verification + for hook in hooks: + hook.remove() + uninstall_load_verification() - # ================================================================ - # Part 2: Decode Phase Tests - # ================================================================ - print("=" * 70) - print("Part 2: Decode Phase Correctness") - print("=" * 70) - print("Testing: Decode query (single token) against all prefilled K, V") - print() + # 7. Final verification + verify_final_cpu_state(llm, num_chunks) - DECODE_TEST_CASES = [ - # (num_prefill_chunks, num_gpu_slots) - (2, 2), - (4, 3), - (4, 4), - (8, 4), - (8, 6), - ] + # 8. Block coverage summary + verify_block_coverage_summary(num_chunks) - decode_results = [] + # 9. 
H2D load operations summary + verify_load_operations_summary(num_chunks) - for num_prefill_chunks, num_gpu_slots in DECODE_TEST_CASES: - prefill_len = num_prefill_chunks * BLOCK_SIZE - passed, max_diff, mean_diff = run_decode_correctness_test( - num_prefill_chunks, num_gpu_slots, num_decode_steps=4, verbose=False - ) - - all_passed = all_passed and passed - status = "PASS" if passed else "FAIL" - - decode_results.append((num_prefill_chunks, num_gpu_slots, prefill_len, passed, max_diff, mean_diff)) - - print(f"[{status}] prefill_chunks={num_prefill_chunks:2d} slots={num_gpu_slots:2d} " - f"prefill_len={prefill_len:5d} max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}") - - print() - print("=" * 70) - - # Summary - prefill_passed = sum(1 for r in results if r[3]) - decode_passed = sum(1 for r in decode_results if r[3]) - total_tests = len(results) + len(decode_results) - total_passed = prefill_passed + decode_passed - - print(f"Results: {total_passed}/{total_tests} tests passed") - print(f" - Prefill: {prefill_passed}/{len(results)}") - print(f" - Decode: {decode_passed}/{len(decode_results)}") - - if not all_passed: - print("\nFailed tests:") - for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in results: - if not passed: - print(f" - [Prefill] chunks={num_chunks}, slots={num_gpu_slots}, " - f"seqlen={seqlen}, max_diff={max_diff:.6f}") - for num_chunks, num_gpu_slots, seqlen, passed, max_diff, mean_diff in decode_results: - if not passed: - print(f" - [Decode] prefill_chunks={num_chunks}, slots={num_gpu_slots}, " - f"prefill_len={seqlen}, max_diff={max_diff:.6f}") - - print() - assert all_passed, "Some correctness tests failed!" - print("test_offload_correctness: PASSED") + # 10. Report results + print("\n" + "=" * 60) + if errors: + print(f"test_offload_correctness: FAILED ({len(errors)} errors)") + for err in errors[:10]: # Show first 10 errors + print(f" - {err}") + exit(1) + else: + print("test_offload_correctness: PASSED") + print("=" * 60)