# nano-vllm/tests/test_debug_verification.py
"""
Test KV cache offload correctness using debug hooks.
Injects distinctive K/V values, verifies loaded tensors match expected patterns.
"""
import os

# Silence nanovllm logging; set before the nanovllm imports below so the
# level is picked up.
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

from random import randint, seed
from typing import Dict, List

import torch
from torch import Tensor

from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context

# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024
BLOCK_SIZE = 1024
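# With INPUT_LEN == 32 * BLOCK_SIZE the prefill splits into 32 chunks, but only
# NUM_GPU_BLOCKS == 4 blocks fit on the GPU, so earlier blocks must be
# offloaded to CPU and reloaded when later chunks attend to them.
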
# State
load_log: List[Dict] = []
current_chunk: List[int] = [0]
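# current_chunk is a one-element list used as a mutable cell: the injection
# hook writes the chunk index being prefilled and debug_load_hook reads it.
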
def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
    """Record the mean K value of each block loaded back from CPU (layer 0 only)."""
    if layer_id != 0:
        return
    load_log.append({
        "chunk_idx": current_chunk[0],
        "cpu_block_id": cpu_block_id,
        # The injected K tensor is constant, so its mean recovers the fill value.
        "k_value": k.float().mean().item(),
    })

def make_pattern_injection_hook(layer_id):
    """Inject K = chunk_idx + 1, V = -(chunk_idx + 1) on layer 0 during prefill."""
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill or layer_id != 0:
            return inputs
        chunk_idx = getattr(ctx, "current_chunk_idx", 0)
        current_chunk[0] = chunk_idx
        if len(inputs) >= 3:
            q, k, v = inputs[0], inputs[1], inputs[2]
            k_new = torch.full_like(k, float(chunk_idx + 1))
            v_new = torch.full_like(v, float(-(chunk_idx + 1)))
            return (q, k_new, v_new) + inputs[3:]
        return inputs
    return hook

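# Example: while prefilling chunk 3, a load of cpu_block_id 1 should append
# {"chunk_idx": 3, "cpu_block_id": 1, "k_value": 2.0}, since chunk 1 filled
# that block's K with 1 + 1 = 2.
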
def verify() -> bool:
    """Verify blocks were loaded in the right order with the injected K values."""
    chunk_loads: Dict[int, List[tuple]] = {}
    for log in load_log:
        chunk_loads.setdefault(log["chunk_idx"], []).append(
            (log["cpu_block_id"], log["k_value"])
        )
    for chunk, loads in chunk_loads.items():
        # Chunk i should reload the blocks written by chunks 0..i-1, in order.
        expected_blocks = list(range(chunk))
        actual_blocks = [b for b, _ in loads]
        k_values = [k for _, k in loads]
        expected_k = [float(b + 1) for b in expected_blocks]
        if actual_blocks != expected_blocks:
            return False
        if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)):
            return False
    return True

# Main
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

offload_engine = llm.model_runner.kvcache_manager.offload_engine
offload_engine.enable_debug_mode()
offload_engine.register_debug_hook(debug_load_hook)
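# In debug mode the offload engine is expected to call each registered hook as
# hook(slot_idx, layer_id, cpu_block_id, k, v) for every block it loads back
# to the GPU; see debug_load_hook above.
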
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hooks fire before attention, so the injected K/V are what get cached.
    hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook(
        make_pattern_injection_hook(layer_idx)
    ))

seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(
    prompt_token_ids,
    SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1),
    use_tqdm=False,
)

for hook in hooks:
    hook.remove()
offload_engine.remove_debug_hook(debug_load_hook)
offload_engine.disable_debug_mode()

# Verify
num_chunks = INPUT_LEN // BLOCK_SIZE
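# Chunk i reloads the i blocks written by chunks 0..i-1, so layer-0 loads sum
# to 0 + 1 + ... + (num_chunks - 1) = num_chunks * (num_chunks - 1) / 2,
# i.e. 32 * 31 / 2 = 496 entries in load_log.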
expected_loads = num_chunks * (num_chunks - 1) // 2
passed = len(load_log) == expected_loads and verify()
print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}")