[WIP] Before modifying nanovllm CPU-GPU kvcache.
@@ -1093,4 +1093,7 @@ class OffloadEngine:
                 try:
                     hook(slot_idx, layer_id, cpu_block_id, k, v)
                 except Exception as e:
+                    # Allow pdb quit to propagate
+                    if e.__class__.__name__ == 'BdbQuit':
+                        raise
                     logger.warning(f"Debug hook error: {e}")
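This hunk is the only framework change: debug-hook failures are still downgraded to a warning so a broken hook cannot crash the engine, but pdb's quit exception must escape, otherwise pressing q in a set_trace() session inside a hook would just be logged and swallowed. Comparing the class name rather than calling isinstance(e, bdb.BdbQuit) spares the engine a bdb import. A minimal standalone sketch of the same pattern (run_hooks and the logger here are illustrative stand-ins, not nanovllm API):

import logging

logger = logging.getLogger("offload")

def run_hooks(hooks, *args):
    # Dispatch debug hooks: a failing hook only logs a warning,
    # but quitting pdb raises bdb.BdbQuit, which must propagate
    # so 'q' in the debugger actually stops the run.
    for hook in hooks:
        try:
            hook(*args)
        except Exception as e:
            if e.__class__.__name__ == 'BdbQuit':  # same name check as the hunk
                raise
            logger.warning(f"Debug hook error: {e}")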
@@ -1,196 +1,85 @@
 """
-Test script for verifying KV cache offload correctness using debug hooks.
-
-Strategy:
-1. Inject distinctive K/V values (K=chunk_idx+1, V=-(chunk_idx+1))
-2. Register debug hook to receive loaded tensor
-3. Hook reads tensor values to verify correct block was loaded
-4. No verification logic in framework - all external
-
-This tests the framework's normal async execution path.
+Test KV cache offload correctness using debug hooks.
+Injects distinctive K/V values, verifies loaded tensors match expected patterns.
 """
 
 import os
-os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
+os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
 
 from random import randint, seed
-from typing import Dict, List, Tuple
+from typing import Dict, List
 import torch
 from torch import Tensor
 from nanovllm import LLM, SamplingParams
 from nanovllm.utils.context import get_context
 
-# ============================================================
-# Configuration
-# ============================================================
+# Config
 
 MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
 MAX_MODEL_LEN = 32 * 1024
 NUM_GPU_BLOCKS = 4
 INPUT_LEN = 32 * 1024
 BLOCK_SIZE = 1024
 
-# ============================================================
-# External state (managed by test, not framework)
-# ============================================================
-
-# Record all load operations: list of {cpu_block_id, k_value, v_value, ...}
+# State
 load_log: List[Dict] = []
-
-# Track current chunk for grouping loads
-current_chunk: List[int] = [0]  # mutable container
+current_chunk: List[int] = [0]
 
 
-# ============================================================
-# Debug hook - receives loaded tensor directly
-# ============================================================
-
 def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
-    """
-    Debug hook called after each H2D load.
-    Reads tensor values to verify which block was actually loaded.
-    """
-    # Only record layer 0 for efficiency
+    """Record loaded tensor values for layer 0."""
     if layer_id != 0:
         return
 
-    # Read tensor values (the distinctive pattern we injected)
-    k_val = k.float().mean().item()
-    v_val = v.float().mean().item()
+    if layer_id == 0:
+        __import__('pdb').set_trace()
 
     load_log.append({
         "chunk_idx": current_chunk[0],
-        "slot_idx": slot_idx,
         "cpu_block_id": cpu_block_id,
-        "k_value": k_val,
-        "v_value": v_val,
+        "k_value": k.float().mean().item(),
     })
 
 
-# ============================================================
-# Pattern injection hook - injects distinctive values into K/V
-# ============================================================
-
 def make_pattern_injection_hook(layer_id):
-    """Inject distinctive patterns: K = chunk_idx + 1, V = -(chunk_idx + 1)"""
+    """Inject K = chunk_idx + 1, V = -(chunk_idx + 1) for layer 0."""
     def hook(module, inputs):
         ctx = get_context()
-        if not ctx.is_prefill:
-            return inputs
-
-        if layer_id != 0:
+        if not ctx.is_prefill or layer_id != 0:
             return inputs
 
         chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
-        current_chunk[0] = chunk_idx  # Update for debug_load_hook
+        current_chunk[0] = chunk_idx
 
         if len(inputs) >= 3:
             q, k, v = inputs[0], inputs[1], inputs[2]
-            k_pattern = float(chunk_idx + 1)
-            v_pattern = float(-(chunk_idx + 1))
-            k_new = torch.full_like(k, k_pattern)
-            v_new = torch.full_like(v, v_pattern)
+            k_new = torch.full_like(k, float(chunk_idx + 1))
+            v_new = torch.full_like(v, float(-(chunk_idx + 1)))
             return (q, k_new, v_new) + inputs[3:]
 
         return inputs
     return hook
 
 
-# ============================================================
-# Verification functions (all external, not in framework)
-# ============================================================
-
-def verify_load_order() -> Tuple[int, int, List[Dict]]:
-    """Verify blocks were loaded in correct order by checking K values."""
-    # Group loads by chunk
-    chunk_loads: Dict[int, List[Tuple[int, float]]] = {}
+def verify() -> bool:
+    """Verify blocks loaded in correct order with correct K values."""
+    chunk_loads: Dict[int, List[tuple]] = {}
     for log in load_log:
         chunk = log["chunk_idx"]
         if chunk not in chunk_loads:
             chunk_loads[chunk] = []
         chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"]))
 
-    correct = 0
-    incorrect = 0
-    errors = []
-
-    for chunk in sorted(chunk_loads.keys()):
-        loads = chunk_loads[chunk]
-        # Expected: blocks [0, 1, ..., chunk-1] with K values [1, 2, ..., chunk]
+    for chunk, loads in chunk_loads.items():
         expected_blocks = list(range(chunk))
-        actual_blocks = [block_id for block_id, _ in loads]
-
-        # Also verify K values match expected pattern
-        k_values = [k_val for _, k_val in loads]
-        expected_k_values = [float(b + 1) for b in expected_blocks]
-
-        blocks_ok = actual_blocks == expected_blocks
-        # Check K values with tolerance
-        k_ok = all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k_values)) if len(k_values) == len(expected_k_values) else False
-
-        if blocks_ok and k_ok:
-            correct += 1
-        else:
-            incorrect += 1
-            errors.append({
-                "chunk_idx": chunk,
-                "expected_blocks": expected_blocks,
-                "actual_blocks": actual_blocks,
-                "expected_k": expected_k_values,
-                "actual_k": k_values,
-            })
-
-    return correct, incorrect, errors
-
-
-def print_verification_summary():
-    """Print verification results."""
-    correct, incorrect, errors = verify_load_order()
-
-    # Group for display
-    chunk_loads: Dict[int, List[int]] = {}
-    for log in load_log:
-        chunk = log["chunk_idx"]
-        if chunk not in chunk_loads:
-            chunk_loads[chunk] = []
-        chunk_loads[chunk].append(log["cpu_block_id"])
-
-    print(f"\n{'='*60}")
-    print("Debug Verification Summary")
-    print(f"{'='*60}")
-
-    print(f"\n1. Load Operations:")
-    print(f"   Total H2D loads recorded: {len(load_log)}")
-    print(f"   Chunks with correct order: {correct}")
-    print(f"   Chunks with incorrect order: {incorrect}")
-
-    if incorrect > 0:
-        print(f"\n   Errors:")
-        for err in errors[:5]:
-            print(f"      Chunk {err['chunk_idx']}:")
-            print(f"         Expected blocks: {err['expected_blocks']}")
-            print(f"         Actual blocks: {err['actual_blocks']}")
-            print(f"         K values: {[f'{v:.1f}' for v in err['actual_k']]}")
-
-    print(f"\n2. Load Order Sample (first 5 and last 2 chunks):")
-    sorted_chunks = sorted(chunk_loads.keys())
-    display_chunks = sorted_chunks[:5] + sorted_chunks[-2:] if len(sorted_chunks) > 7 else sorted_chunks
-    for chunk in display_chunks:
-        blocks = chunk_loads[chunk]
-        expected = list(range(chunk))
-        status = "OK" if blocks == expected else "WRONG"
-        print(f"   Chunk {chunk}: {blocks} [{status}]")
-
-    print(f"\n{'='*60}")
-
-
-# ============================================================
-# Main Test Script
-# ============================================================
-
-print("Initializing LLM with CPU offload...")
+        actual_blocks = [b for b, _ in loads]
+        k_values = [k for _, k in loads]
+        expected_k = [float(b + 1) for b in expected_blocks]
+
+        if actual_blocks != expected_blocks:
+            return False
+        if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)):
+            return False
+    return True
+
+
+# Main
 llm = LLM(
     MODEL_PATH,
     enforce_eager=True,
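The numbers behind the PASSED condition in the rewritten test: 32K input tokens at BLOCK_SIZE = 1024 give 32 prefill chunks, and chunk i must re-load the CPU blocks of all i preceding chunks, so a correct run produces 0 + 1 + ... + 31 = 32 * 31 / 2 = 496 H2D loads of layer 0. A self-contained sanity check of that arithmetic and of the shape of a correct load log (synthetic_log is illustrative; the real log is filled in by debug_load_hook):

# Chunk i re-loads CPU blocks [0 .. i-1], so across n chunks the
# total number of H2D loads is 0 + 1 + ... + (n-1) = n*(n-1)/2.
n_chunks = 32 * 1024 // 1024          # INPUT_LEN // BLOCK_SIZE = 32
expected_loads = n_chunks * (n_chunks - 1) // 2
assert expected_loads == 496

# A correct run injects K = block_id + 1, so a synthetic log of a
# perfect run must satisfy both checks that verify() performs.
synthetic_log = [
    {"chunk_idx": c, "cpu_block_id": b, "k_value": float(b + 1)}
    for c in range(n_chunks)
    for b in range(c)
]
assert len(synthetic_log) == expected_loads
for c in range(n_chunks):
    blocks = [e["cpu_block_id"] for e in synthetic_log if e["chunk_idx"] == c]
    assert blocks == list(range(c))  # loaded in order, oldest block first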
@@ -202,66 +91,28 @@ llm = LLM(
     dtype="float16",
 )
 
-# Get offload engine and enable debug mode
-kvcache_manager = llm.model_runner.kvcache_manager
-offload_engine = kvcache_manager.offload_engine
+offload_engine = llm.model_runner.kvcache_manager.offload_engine
 offload_engine.enable_debug_mode()
-
-# Register our debug hook
 offload_engine.register_debug_hook(debug_load_hook)
-print("Debug mode enabled with custom hook")
 
-# Register pattern injection hooks
 hooks = []
-model = llm.model_runner.model
-for layer_idx, decoder_layer in enumerate(model.model.layers):
-    attn_module = decoder_layer.self_attn.attn
-    pre_hook = attn_module.register_forward_pre_hook(make_pattern_injection_hook(layer_idx))
-    hooks.append(pre_hook)
-print(f"Registered {len(hooks)} pattern injection hooks")
+for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
+    hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook(
+        make_pattern_injection_hook(layer_idx)
+    ))
 
-# Generate input
 seed(42)
 prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
-num_chunks = INPUT_LEN // BLOCK_SIZE
-print(f"\nInput: {INPUT_LEN} tokens, {num_chunks} chunks expected")
-print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
-
-# Run prefill
-print("\n" + "=" * 60)
-print("Starting Prefill...")
-print("=" * 60)
-
-sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
-outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
+outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1), use_tqdm=False)
 
-# Remove hooks
 for hook in hooks:
     hook.remove()
 offload_engine.remove_debug_hook(debug_load_hook)
 
-# Verify and print
-print("\n" + "=" * 60)
-print("Post-Execution Verification")
-print("=" * 60)
-print_verification_summary()
-
-# Final verdict
-correct, incorrect, _ = verify_load_order()
-expected_loads = num_chunks * (num_chunks - 1) // 2
-actual_loads = len(load_log)
-
-print(f"\nResults:")
-print(f"  Total loads: {actual_loads} (expected: {expected_loads})")
-print(f"  Order verification: {correct} correct, {incorrect} incorrect")
-
-print("\n" + "=" * 60)
-all_passed = incorrect == 0 and actual_loads == expected_loads
-
-if all_passed:
-    print("test_debug_verification: PASSED")
-else:
-    print("test_debug_verification: FAILED")
-print("=" * 60)
-
 offload_engine.disable_debug_mode()
 
+# Verify
+num_chunks = INPUT_LEN // BLOCK_SIZE
+expected_loads = num_chunks * (num_chunks - 1) // 2
+passed = len(load_log) == expected_loads and verify()
+
+print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}")
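Both test hooks lean on PyTorch's forward-pre-hook contract: a hook registered with register_forward_pre_hook receives (module, inputs) and, if it returns a tuple, that tuple replaces the module's positional arguments. This is how the test swaps the attention K/V for constant patterns before they are written to the cache. A minimal self-contained demonstration (nn.Linear stands in for the attention module; shapes are arbitrary):

import torch
from torch import nn

layer = nn.Linear(4, 4)

def pre_hook(module, inputs):
    # Returning a tuple replaces the module's positional arguments.
    (x,) = inputs
    return (torch.full_like(x, 2.0),)

handle = layer.register_forward_pre_hook(pre_hook)
out = layer(torch.zeros(1, 4))  # layer actually computes on the all-2.0 input
handle.remove()                 # same cleanup as hook.remove() in the test

The handle returned by registration is what the test collects in hooks and removes after prefill, so the injection only affects the measured run.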