[WIP] Before modifying the nanovllm CPU-GPU KV cache.

Zijie Tian
2025-12-31 22:41:07 +08:00
parent 31e90a7268
commit ccd1b3d4ab
2 changed files with 47 additions and 193 deletions

View File

@@ -1093,4 +1093,7 @@ class OffloadEngine:
         try:
             hook(slot_idx, layer_id, cpu_block_id, k, v)
         except Exception as e:
+            # Allow pdb quit to propagate
+            if e.__class__.__name__ == 'BdbQuit':
+                raise
             logger.warning(f"Debug hook error: {e}")

View File

@@ -1,196 +1,85 @@
 """
-Test script for verifying KV cache offload correctness using debug hooks.
-Strategy:
-1. Inject distinctive K/V values (K=chunk_idx+1, V=-(chunk_idx+1))
-2. Register debug hook to receive loaded tensor
-3. Hook reads tensor values to verify correct block was loaded
-4. No verification logic in framework - all external
-This tests the framework's normal async execution path.
+Test KV cache offload correctness using debug hooks.
+Injects distinctive K/V values, verifies loaded tensors match expected patterns.
 """
 import os
-os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
+os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
 from random import randint, seed
-from typing import Dict, List, Tuple
+from typing import Dict, List
 import torch
 from torch import Tensor
 from nanovllm import LLM, SamplingParams
 from nanovllm.utils.context import get_context
-# ============================================================
-# Configuration
-# ============================================================
+# Config
 MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
 MAX_MODEL_LEN = 32 * 1024
 NUM_GPU_BLOCKS = 4
 INPUT_LEN = 32 * 1024
 BLOCK_SIZE = 1024
-# ============================================================
-# External state (managed by test, not framework)
-# ============================================================
-# Record all load operations: list of {cpu_block_id, k_value, v_value, ...}
+# State
 load_log: List[Dict] = []
-# Track current chunk for grouping loads
-current_chunk: List[int] = [0]  # mutable container
-# ============================================================
-# Debug hook - receives loaded tensor directly
-# ============================================================
+current_chunk: List[int] = [0]
 def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
-    """
-    Debug hook called after each H2D load.
-    Reads tensor values to verify which block was actually loaded.
-    """
-    # Only record layer 0 for efficiency
+    """Record loaded tensor values for layer 0."""
     if layer_id != 0:
         return
-    # Read tensor values (the distinctive pattern we injected)
-    k_val = k.float().mean().item()
-    v_val = v.float().mean().item()
+    if layer_id == 0:
+        __import__('pdb').set_trace()
     load_log.append({
         "chunk_idx": current_chunk[0],
-        "slot_idx": slot_idx,
         "cpu_block_id": cpu_block_id,
-        "k_value": k_val,
-        "v_value": v_val,
+        "k_value": k.float().mean().item(),
     })
-# ============================================================
-# Pattern injection hook - injects distinctive values into K/V
-# ============================================================
 def make_pattern_injection_hook(layer_id):
-    """Inject distinctive patterns: K = chunk_idx + 1, V = -(chunk_idx + 1)"""
+    """Inject K = chunk_idx + 1, V = -(chunk_idx + 1) for layer 0."""
     def hook(module, inputs):
         ctx = get_context()
-        if not ctx.is_prefill:
+        if not ctx.is_prefill or layer_id != 0:
             return inputs
-        if layer_id != 0:
-            return inputs
         chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
-        current_chunk[0] = chunk_idx  # Update for debug_load_hook
+        current_chunk[0] = chunk_idx
         if len(inputs) >= 3:
             q, k, v = inputs[0], inputs[1], inputs[2]
-            k_pattern = float(chunk_idx + 1)
-            v_pattern = float(-(chunk_idx + 1))
-            k_new = torch.full_like(k, k_pattern)
-            v_new = torch.full_like(v, v_pattern)
+            k_new = torch.full_like(k, float(chunk_idx + 1))
+            v_new = torch.full_like(v, float(-(chunk_idx + 1)))
             return (q, k_new, v_new) + inputs[3:]
         return inputs
     return hook
-# ============================================================
-# Verification functions (all external, not in framework)
-# ============================================================
-def verify_load_order() -> Tuple[int, int, List[Dict]]:
-    """Verify blocks were loaded in correct order by checking K values."""
-    # Group loads by chunk
-    chunk_loads: Dict[int, List[Tuple[int, float]]] = {}
+def verify() -> bool:
+    """Verify blocks loaded in correct order with correct K values."""
+    chunk_loads: Dict[int, List[tuple]] = {}
     for log in load_log:
         chunk = log["chunk_idx"]
         if chunk not in chunk_loads:
             chunk_loads[chunk] = []
         chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"]))
-    correct = 0
-    incorrect = 0
-    errors = []
-    for chunk in sorted(chunk_loads.keys()):
-        loads = chunk_loads[chunk]
-        # Expected: blocks [0, 1, ..., chunk-1] with K values [1, 2, ..., chunk]
+    for chunk, loads in chunk_loads.items():
         expected_blocks = list(range(chunk))
-        actual_blocks = [block_id for block_id, _ in loads]
-        # Also verify K values match expected pattern
-        k_values = [k_val for _, k_val in loads]
-        expected_k_values = [float(b + 1) for b in expected_blocks]
-        blocks_ok = actual_blocks == expected_blocks
-        # Check K values with tolerance
-        k_ok = all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k_values)) if len(k_values) == len(expected_k_values) else False
-        if blocks_ok and k_ok:
-            correct += 1
-        else:
-            incorrect += 1
-            errors.append({
-                "chunk_idx": chunk,
-                "expected_blocks": expected_blocks,
-                "actual_blocks": actual_blocks,
-                "expected_k": expected_k_values,
-                "actual_k": k_values,
-            })
-    return correct, incorrect, errors
-def print_verification_summary():
-    """Print verification results."""
-    correct, incorrect, errors = verify_load_order()
-    # Group for display
-    chunk_loads: Dict[int, List[int]] = {}
-    for log in load_log:
-        chunk = log["chunk_idx"]
-        if chunk not in chunk_loads:
-            chunk_loads[chunk] = []
-        chunk_loads[chunk].append(log["cpu_block_id"])
-    print(f"\n{'='*60}")
-    print("Debug Verification Summary")
-    print(f"{'='*60}")
-    print(f"\n1. Load Operations:")
-    print(f" Total H2D loads recorded: {len(load_log)}")
-    print(f" Chunks with correct order: {correct}")
-    print(f" Chunks with incorrect order: {incorrect}")
-    if incorrect > 0:
-        print(f"\n Errors:")
-        for err in errors[:5]:
-            print(f" Chunk {err['chunk_idx']}:")
-            print(f" Expected blocks: {err['expected_blocks']}")
-            print(f" Actual blocks: {err['actual_blocks']}")
-            print(f" K values: {[f'{v:.1f}' for v in err['actual_k']]}")
-    print(f"\n2. Load Order Sample (first 5 and last 2 chunks):")
-    sorted_chunks = sorted(chunk_loads.keys())
-    display_chunks = sorted_chunks[:5] + sorted_chunks[-2:] if len(sorted_chunks) > 7 else sorted_chunks
-    for chunk in display_chunks:
-        blocks = chunk_loads[chunk]
-        expected = list(range(chunk))
-        status = "OK" if blocks == expected else "WRONG"
-        print(f" Chunk {chunk}: {blocks} [{status}]")
-    print(f"\n{'='*60}")
-# ============================================================
-# Main Test Script
-# ============================================================
-print("Initializing LLM with CPU offload...")
+        actual_blocks = [b for b, _ in loads]
+        k_values = [k for _, k in loads]
+        expected_k = [float(b + 1) for b in expected_blocks]
+        if actual_blocks != expected_blocks:
+            return False
+        if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)):
+            return False
+    return True
+# Main
 llm = LLM(
     MODEL_PATH,
     enforce_eager=True,
@@ -202,66 +91,28 @@ llm = LLM(
dtype="float16", dtype="float16",
) )
# Get offload engine and enable debug mode offload_engine = llm.model_runner.kvcache_manager.offload_engine
kvcache_manager = llm.model_runner.kvcache_manager
offload_engine = kvcache_manager.offload_engine
offload_engine.enable_debug_mode() offload_engine.enable_debug_mode()
# Register our debug hook
offload_engine.register_debug_hook(debug_load_hook) offload_engine.register_debug_hook(debug_load_hook)
print("Debug mode enabled with custom hook")
# Register pattern injection hooks
hooks = [] hooks = []
model = llm.model_runner.model for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
for layer_idx, decoder_layer in enumerate(model.model.layers): hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook(
attn_module = decoder_layer.self_attn.attn make_pattern_injection_hook(layer_idx)
pre_hook = attn_module.register_forward_pre_hook(make_pattern_injection_hook(layer_idx)) ))
hooks.append(pre_hook)
print(f"Registered {len(hooks)} pattern injection hooks")
# Generate input
seed(42) seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
num_chunks = INPUT_LEN // BLOCK_SIZE outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1), use_tqdm=False)
print(f"\nInput: {INPUT_LEN} tokens, {num_chunks} chunks expected")
print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
# Run prefill
print("\n" + "=" * 60)
print("Starting Prefill...")
print("=" * 60)
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
# Remove hooks
for hook in hooks: for hook in hooks:
hook.remove() hook.remove()
offload_engine.remove_debug_hook(debug_load_hook) offload_engine.remove_debug_hook(debug_load_hook)
# Verify and print
print("\n" + "=" * 60)
print("Post-Execution Verification")
print("=" * 60)
print_verification_summary()
# Final verdict
correct, incorrect, _ = verify_load_order()
expected_loads = num_chunks * (num_chunks - 1) // 2
actual_loads = len(load_log)
print(f"\nResults:")
print(f" Total loads: {actual_loads} (expected: {expected_loads})")
print(f" Order verification: {correct} correct, {incorrect} incorrect")
print("\n" + "=" * 60)
all_passed = incorrect == 0 and actual_loads == expected_loads
if all_passed:
print("test_debug_verification: PASSED")
else:
print("test_debug_verification: FAILED")
print("=" * 60)
offload_engine.disable_debug_mode() offload_engine.disable_debug_mode()
# Verify
num_chunks = INPUT_LEN // BLOCK_SIZE
expected_loads = num_chunks * (num_chunks - 1) // 2
passed = len(load_log) == expected_loads and verify()
print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}")