diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py
index 9f2d4f8..c89e16b 100644
--- a/nanovllm/kvcache/offload_engine.py
+++ b/nanovllm/kvcache/offload_engine.py
@@ -1093,4 +1093,7 @@ class OffloadEngine:
             try:
                 hook(slot_idx, layer_id, cpu_block_id, k, v)
             except Exception as e:
+                # Allow pdb quit to propagate
+                if e.__class__.__name__ == 'BdbQuit':
+                    raise
                 logger.warning(f"Debug hook error: {e}")
\ No newline at end of file
diff --git a/tests/test_debug_verification.py b/tests/test_debug_verification.py
index bf1d0cc..4258a8c 100644
--- a/tests/test_debug_verification.py
+++ b/tests/test_debug_verification.py
@@ -1,196 +1,81 @@
 """
-Test script for verifying KV cache offload correctness using debug hooks.
-
-Strategy:
-1. Inject distinctive K/V values (K=chunk_idx+1, V=-(chunk_idx+1))
-2. Register debug hook to receive loaded tensor
-3. Hook reads tensor values to verify correct block was loaded
-4. No verification logic in framework - all external
-
-This tests the framework's normal async execution path.
+Test KV cache offload correctness using debug hooks.
+Injects distinctive K/V values and verifies that loaded tensors match the expected patterns.
 """
 import os
-os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
+os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
 
 from random import randint, seed
-from typing import Dict, List, Tuple
+from typing import Dict, List
 
 import torch
 from torch import Tensor
 
 from nanovllm import LLM, SamplingParams
 from nanovllm.utils.context import get_context
 
-
-# ============================================================
-# Configuration
-# ============================================================
-
+# Config
 MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
 MAX_MODEL_LEN = 32 * 1024
 NUM_GPU_BLOCKS = 4
 INPUT_LEN = 32 * 1024
 BLOCK_SIZE = 1024
 
-
-# ============================================================
-# External state (managed by test, not framework)
-# ============================================================
-
-# Record all load operations: list of {cpu_block_id, k_value, v_value, ...}
+# State
 load_log: List[Dict] = []
+current_chunk: List[int] = [0]
 
-# Track current chunk for grouping loads
-current_chunk: List[int] = [0]  # mutable container
-
-
-# ============================================================
-# Debug hook - receives loaded tensor directly
-# ============================================================
 
 def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
-    """
-    Debug hook called after each H2D load.
-    Reads tensor values to verify which block was actually loaded.
- """ - # Only record layer 0 for efficiency + """Record loaded tensor values for layer 0.""" if layer_id != 0: return - - # Read tensor values (the distinctive pattern we injected) - k_val = k.float().mean().item() - v_val = v.float().mean().item() - + + if layer_id == 0: + __import__('pdb').set_trace() + load_log.append({ "chunk_idx": current_chunk[0], - "slot_idx": slot_idx, "cpu_block_id": cpu_block_id, - "k_value": k_val, - "v_value": v_val, + "k_value": k.float().mean().item(), }) -# ============================================================ -# Pattern injection hook - injects distinctive values into K/V -# ============================================================ - def make_pattern_injection_hook(layer_id): - """Inject distinctive patterns: K = chunk_idx + 1, V = -(chunk_idx + 1)""" + """Inject K = chunk_idx + 1, V = -(chunk_idx + 1) for layer 0.""" def hook(module, inputs): ctx = get_context() - if not ctx.is_prefill: + if not ctx.is_prefill or layer_id != 0: return inputs - - if layer_id != 0: - return inputs - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - current_chunk[0] = chunk_idx # Update for debug_load_hook - + current_chunk[0] = chunk_idx if len(inputs) >= 3: q, k, v = inputs[0], inputs[1], inputs[2] - k_pattern = float(chunk_idx + 1) - v_pattern = float(-(chunk_idx + 1)) - k_new = torch.full_like(k, k_pattern) - v_new = torch.full_like(v, v_pattern) + k_new = torch.full_like(k, float(chunk_idx + 1)) + v_new = torch.full_like(v, float(-(chunk_idx + 1))) return (q, k_new, v_new) + inputs[3:] - return inputs return hook -# ============================================================ -# Verification functions (all external, not in framework) -# ============================================================ - -def verify_load_order() -> Tuple[int, int, List[Dict]]: - """Verify blocks were loaded in correct order by checking K values.""" - # Group loads by chunk - chunk_loads: Dict[int, List[Tuple[int, float]]] = {} +def verify() -> bool: + """Verify blocks loaded in correct order with correct K values.""" + chunk_loads: Dict[int, List[tuple]] = {} for log in load_log: chunk = log["chunk_idx"] if chunk not in chunk_loads: chunk_loads[chunk] = [] chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"])) - correct = 0 - incorrect = 0 - errors = [] - - for chunk in sorted(chunk_loads.keys()): - loads = chunk_loads[chunk] - # Expected: blocks [0, 1, ..., chunk-1] with K values [1, 2, ..., chunk] + for chunk, loads in chunk_loads.items(): expected_blocks = list(range(chunk)) - actual_blocks = [block_id for block_id, _ in loads] + actual_blocks = [b for b, _ in loads] + k_values = [k for _, k in loads] + expected_k = [float(b + 1) for b in expected_blocks] - # Also verify K values match expected pattern - k_values = [k_val for _, k_val in loads] - expected_k_values = [float(b + 1) for b in expected_blocks] - - blocks_ok = actual_blocks == expected_blocks - # Check K values with tolerance - k_ok = all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k_values)) if len(k_values) == len(expected_k_values) else False - - if blocks_ok and k_ok: - correct += 1 - else: - incorrect += 1 - errors.append({ - "chunk_idx": chunk, - "expected_blocks": expected_blocks, - "actual_blocks": actual_blocks, - "expected_k": expected_k_values, - "actual_k": k_values, - }) - - return correct, incorrect, errors + if actual_blocks != expected_blocks: + return False + if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)): + return False + return 
 
-def print_verification_summary():
-    """Print verification results."""
-    correct, incorrect, errors = verify_load_order()
-
-    # Group for display
-    chunk_loads: Dict[int, List[int]] = {}
-    for log in load_log:
-        chunk = log["chunk_idx"]
-        if chunk not in chunk_loads:
-            chunk_loads[chunk] = []
-        chunk_loads[chunk].append(log["cpu_block_id"])
-
-    print(f"\n{'='*60}")
-    print("Debug Verification Summary")
-    print(f"{'='*60}")
-
-    print(f"\n1. Load Operations:")
-    print(f"   Total H2D loads recorded: {len(load_log)}")
-    print(f"   Chunks with correct order: {correct}")
-    print(f"   Chunks with incorrect order: {incorrect}")
-
-    if incorrect > 0:
-        print(f"\n   Errors:")
-        for err in errors[:5]:
-            print(f"     Chunk {err['chunk_idx']}:")
-            print(f"       Expected blocks: {err['expected_blocks']}")
-            print(f"       Actual blocks: {err['actual_blocks']}")
-            print(f"       K values: {[f'{v:.1f}' for v in err['actual_k']]}")
-
-    print(f"\n2. Load Order Sample (first 5 and last 2 chunks):")
-    sorted_chunks = sorted(chunk_loads.keys())
-    display_chunks = sorted_chunks[:5] + sorted_chunks[-2:] if len(sorted_chunks) > 7 else sorted_chunks
-    for chunk in display_chunks:
-        blocks = chunk_loads[chunk]
-        expected = list(range(chunk))
-        status = "OK" if blocks == expected else "WRONG"
-        print(f"   Chunk {chunk}: {blocks} [{status}]")
-
-    print(f"\n{'='*60}")
-
-
-# ============================================================
-# Main Test Script
-# ============================================================
-
-print("Initializing LLM with CPU offload...")
+# Main
 llm = LLM(
     MODEL_PATH,
     enforce_eager=True,
@@ -202,66 +87,28 @@ llm = LLM(
     dtype="float16",
 )
 
-# Get offload engine and enable debug mode
-kvcache_manager = llm.model_runner.kvcache_manager
-offload_engine = kvcache_manager.offload_engine
+offload_engine = llm.model_runner.kvcache_manager.offload_engine
 offload_engine.enable_debug_mode()
-
-# Register our debug hook
 offload_engine.register_debug_hook(debug_load_hook)
-print("Debug mode enabled with custom hook")
 
-# Register pattern injection hooks
 hooks = []
-model = llm.model_runner.model
-for layer_idx, decoder_layer in enumerate(model.model.layers):
-    attn_module = decoder_layer.self_attn.attn
-    pre_hook = attn_module.register_forward_pre_hook(make_pattern_injection_hook(layer_idx))
-    hooks.append(pre_hook)
-print(f"Registered {len(hooks)} pattern injection hooks")
+for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
+    hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook(
+        make_pattern_injection_hook(layer_idx)
+    ))
 
-# Generate input
 seed(42)
 prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
-num_chunks = INPUT_LEN // BLOCK_SIZE
-print(f"\nInput: {INPUT_LEN} tokens, {num_chunks} chunks expected")
-print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
+outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1), use_tqdm=False)
 
-# Run prefill
-print("\n" + "=" * 60)
-print("Starting Prefill...")
-print("=" * 60)
-
-sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
-outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
-
-# Remove hooks
 for hook in hooks:
     hook.remove()
 offload_engine.remove_debug_hook(debug_load_hook)
-
-# Verify and print
-print("\n" + "=" * 60)
-print("Post-Execution Verification")
-print("=" * 60)
-print_verification_summary()
-
-# Final verdict
-correct, incorrect, _ = verify_load_order()
-expected_loads = num_chunks * (num_chunks - 1) // 2
-actual_loads = len(load_log)
-
-print(f"\nResults:")
-print(f"  Total loads: {actual_loads} (expected: {expected_loads})")
-print(f"  Order verification: {correct} correct, {incorrect} incorrect")
-
-print("\n" + "=" * 60)
-all_passed = incorrect == 0 and actual_loads == expected_loads
-
-if all_passed:
-    print("test_debug_verification: PASSED")
-else:
-    print("test_debug_verification: FAILED")
-print("=" * 60)
-
 offload_engine.disable_debug_mode()
+
+# Verify
+num_chunks = INPUT_LEN // BLOCK_SIZE
+expected_loads = num_chunks * (num_chunks - 1) // 2
+passed = len(load_log) == expected_loads and verify()
+
+print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}")