# Source snapshot: nano-vllm/tests/test_offload_correctness.py
# (2025-12-31 23:35:25 +08:00; 696 lines, 25 KiB, Python)

"""
Test script to verify CPU offload correctness using distinctive KV patterns.
Strategy:
1. Hook into attention forward pass
2. Overwrite K/V with distinctive patterns based on chunk_idx (K=chunk_idx+1, V=-(chunk_idx+1), so even chunk 0 is non-zero)
3. After offload to CPU, verify CPU cache contains correct patterns
4. On subsequent chunks, verify loaded KV from CPU has correct patterns
This catches bugs like:
- Wrong block being offloaded
- Wrong block being loaded
- Data corruption during transfer
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
# ============================================================
# Configuration
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 64 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024  # 32K tokens = 32 chunks (fits in 40 CPU blocks)
BLOCK_SIZE = 1024

# The hooks and summaries assume every chunk fills exactly one KV block
# (num_chunks = INPUT_LEN // BLOCK_SIZE, and whole GPU slots / CPU blocks are
# compared against a constant pattern), so fail fast on a partial last chunk.
if INPUT_LEN % BLOCK_SIZE != 0:
    raise ValueError(
        f"INPUT_LEN ({INPUT_LEN}) must be a multiple of BLOCK_SIZE ({BLOCK_SIZE})"
    )

# Test state (module-level; mutated by the hooks, read by the summaries)
errors = []  # failure messages; any entry makes the script exit with status 1
chunk_patterns = {}  # chunk_idx -> (k_pattern, v_pattern)
block_coverage = {}  # chunk_idx -> {'expected': set, 'actual': set} of CPU block ids
load_operations = []  # one dict per verified layer-0 H2D load (ids, expected/actual K/V means, k_ok, v_ok)
current_chunk_for_load = [0]  # mutable cell: chunk currently being prefilled, read by the load wrapper
# ============================================================
# Pattern Helpers
# ============================================================
def get_expected_pattern(chunk_idx: int):
    """Return the ``(k_value, v_value)`` constants injected for ``chunk_idx``.

    K is ``chunk_idx + 1`` and V is its negation (1.0/-1.0, 2.0/-2.0, ...), so
    every chunk, including chunk 0, gets a non-zero, easily identified value.
    """
    shifted = chunk_idx + 1
    return float(shifted), float(-shifted)
def fill_with_pattern(tensor: torch.Tensor, value: float):
    """Overwrite every element of ``tensor`` in place with ``value``; returns None."""
    # Unbound-method spelling of ``tensor.fill_(value)``; the result is discarded.
    torch.Tensor.fill_(tensor, value)
def check_pattern(tensor: torch.Tensor, expected: float, name: str, tolerance: float = 1e-3):
    """Check that every element of ``tensor`` equals ``expected`` within ``tolerance``.

    The injected patterns are constant fills, so a correct block must match
    element-wise. Comparing only the mean would accept, e.g., a block that is
    half chunk 0 (1.0) and half chunk 2 (3.0) as chunk 1 (2.0).

    Args:
        tensor: data to check (any dtype; compared in float32).
        expected: the constant every element should equal.
        name: label used as the prefix of the error message.
        tolerance: maximum allowed absolute deviation per element.

    Returns:
        ``(True, None)`` on success, otherwise ``(False, message)``. Empty
        tensors and tensors containing NaN are reported as failures.
    """
    if tensor.numel() == 0:
        return False, f"{name}: expected {expected}, got an empty tensor"
    values = tensor.float()
    max_dev = (values - expected).abs().max().item()
    # Written as ``not (x <= tol)`` so a NaN deviation counts as a failure
    # (``nan > tol`` is False and would otherwise pass silently).
    if not (max_dev <= tolerance):
        actual_mean = values.mean().item()
        return False, (f"{name}: expected mean={expected}, got {actual_mean} "
                       f"(max abs deviation {max_dev})")
    return True, None
# ============================================================
# Load Verification Instrumentation
# ============================================================
# Saved by install_load_verification() so uninstall_load_verification() can put
# the engine's original method back; both are None while nothing is installed.
_original_load_to_slot_layer = _offload_engine_ref = None
def make_verified_load_to_slot_layer(original_func, offload_engine):
    """
    Create a transparent wrapper around ``load_to_slot_layer`` that verifies
    each layer-0 load.

    After the original H2D transfer runs, the GPU slot's K/V means are compared
    with the pattern injected for the source CPU block (this test assumes
    ``cpu_block_id == chunk_idx``). Every verified load is recorded in
    ``load_operations``; mismatches are also appended to ``errors``.

    Args:
        original_func: the engine's original ``load_to_slot_layer`` callable.
        offload_engine: engine whose ``k_cache_gpu`` / ``v_cache_gpu`` are read.

    Returns:
        A callable with the same ``(slot_idx, layer_id, cpu_block_id)``
        signature that returns whatever ``original_func`` returns.
    """
    import functools

    tolerance = 1e-3

    @functools.wraps(original_func)
    def verified_load(slot_idx: int, layer_id: int, cpu_block_id: int):
        # Keep the original result so the wrapper stays transparent to callers.
        result = original_func(slot_idx, layer_id, cpu_block_id)
        # Only verify layer 0 to reduce overhead
        if layer_id != 0:
            return result
        # IMPORTANT: Synchronize CUDA to ensure async transfer is complete.
        # The transfer happens on a per-slot stream, and wait_slot_layer only
        # makes compute_stream wait. We need full sync to read on default stream.
        torch.cuda.synchronize()
        # cpu_block_id == chunk_idx in our sequential test
        expected_k, expected_v = get_expected_pattern(cpu_block_id)
        # Read GPU slot data (GPU cache has no layer dimension)
        gpu_k = offload_engine.k_cache_gpu[slot_idx]
        gpu_v = offload_engine.v_cache_gpu[slot_idx]
        actual_k = gpu_k.float().mean().item()
        actual_v = gpu_v.float().mean().item()
        # ``<`` makes a NaN mean count as a failure.
        k_ok = abs(actual_k - expected_k) < tolerance
        v_ok = abs(actual_v - expected_v) < tolerance
        chunk_idx = current_chunk_for_load[0]
        load_operations.append({
            'chunk_idx': chunk_idx,
            'slot_idx': slot_idx,
            'cpu_block_id': cpu_block_id,
            'expected_k': expected_k,
            'expected_v': expected_v,
            'actual_k': actual_k,
            'actual_v': actual_v,
            'k_ok': k_ok,
            'v_ok': v_ok,
        })
        if not (k_ok and v_ok):
            errors.append(f"Load verification failed: chunk {chunk_idx}, "
                          f"CPU block {cpu_block_id} -> GPU slot {slot_idx}: "
                          f"expected K={expected_k:.1f}/V={expected_v:.1f}, "
                          f"got K={actual_k:.4f}/V={actual_v:.4f}")
        return result

    return verified_load
def install_load_verification(llm):
    """Install the verifying wrapper on the offload engine's ``load_to_slot_layer``.

    The engine and its original method are saved in module globals so
    ``uninstall_load_verification`` can restore them. If a wrapper is already
    installed, it is removed first. This keeps wrappers from stacking and
    ensures the saved "original" is never itself a wrapper.
    """
    global _original_load_to_slot_layer, _offload_engine_ref
    if _original_load_to_slot_layer is not None:
        uninstall_load_verification()
    oe = llm.model_runner.kvcache_manager.offload_engine
    _offload_engine_ref = oe
    _original_load_to_slot_layer = oe.load_to_slot_layer
    oe.load_to_slot_layer = make_verified_load_to_slot_layer(
        _original_load_to_slot_layer, oe
    )
    print("Installed load verification wrapper on load_to_slot_layer")
def uninstall_load_verification():
    """Put the engine's original ``load_to_slot_layer`` back if a wrapper was saved.

    The saved engine/method globals are always cleared afterwards, so calling
    this when nothing is installed is a harmless no-op.
    """
    global _original_load_to_slot_layer, _offload_engine_ref
    saved_engine = _offload_engine_ref
    saved_method = _original_load_to_slot_layer
    installed = saved_engine is not None and saved_method is not None
    if installed:
        saved_engine.load_to_slot_layer = saved_method
        print("Restored original load_to_slot_layer")
    _original_load_to_slot_layer = None
    _offload_engine_ref = None
# ============================================================
# Attention Hook
# ============================================================
def make_kv_pattern_pre_hook(layer_id: int):
    """
    Build a forward PRE-hook that stamps the current chunk's K/V with a known pattern.

    Used with ``register_forward_pre_hook``, the hook receives
    ``(module, inputs)`` with ``inputs`` unpacked as ``(q, k, v)``. It fills
    ``k``/``v`` in place with ``get_expected_pattern(chunk_idx)`` before
    ``forward()`` runs, so those values are what get stored to the KV cache.
    It only acts during prefill, when the context carries a kvcache_manager,
    and only for layer 0.
    """
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return
        chunk_idx = getattr(ctx, 'current_chunk_idx', 0)
        manager = getattr(ctx, 'kvcache_manager', None)
        if manager is None or layer_id != 0:
            return
        _, k, v = inputs
        k_val, v_val = get_expected_pattern(chunk_idx)
        # In-place overwrite before forward(), so these values reach the cache.
        k.fill_(k_val)
        v.fill_(v_val)
        total = INPUT_LEN // BLOCK_SIZE
        # Log only the first and last few chunks to keep output short.
        if chunk_idx < 3 or chunk_idx >= total - 2:
            print(f"[Chunk {chunk_idx:3d}] Set K={k_val:.1f}, V={v_val:.1f}")
        elif chunk_idx == 3:
            print(f"... (chunks 3 to {total - 3} omitted) ...")

    return hook
def make_block_coverage_pre_hook(layer_id: int):
    """
    Build a forward PRE-hook that checks the CPU block table used for chunked
    prefill attention is exactly the set of previous chunks' blocks.

    For chunk ``i`` the table must contain CPU blocks ``0..i-1`` (this test
    assumes CPU block id == chunk index), each exactly once. This catches:
    - Some blocks missing from the computation
    - Sparse policy incorrectly filtering out blocks (when not intended)
    - Block table construction off-by-one errors (missing or extra blocks)
    - The same block being listed more than once

    The hook also stores the current chunk in ``current_chunk_for_load`` so the
    load-verification wrapper can attribute H2D loads to a chunk. It records
    expected/actual sets in ``block_coverage`` for the final summary. It only
    acts during prefill, with a kvcache_manager, for layer 0.
    """
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return
        chunk_idx = getattr(ctx, 'current_chunk_idx', 0)
        kvcache_manager = getattr(ctx, 'kvcache_manager', None)
        if kvcache_manager is None:
            return
        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return
        # Update current chunk for load verification tracking
        current_chunk_for_load[0] = chunk_idx
        # No previous blocks for chunk 0
        if chunk_idx == 0:
            return
        # Get the sequence and its block table (same logic as _chunked_prefill_attention)
        seq = getattr(ctx, 'chunked_seq', None)
        if seq is None:
            return
        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
        block_list = list(cpu_block_table) if cpu_block_table else []

        # Build the actual set while noting ids that appear more than once
        # (a plain set() would hide duplicated blocks).
        actual_blocks = set()
        duplicate_blocks = set()
        for block_id in block_list:
            if block_id in actual_blocks:
                duplicate_blocks.add(block_id)
            actual_blocks.add(block_id)

        # Expected blocks: 0 to chunk_idx-1 (all previous chunks)
        expected_blocks = set(range(chunk_idx))
        block_coverage[chunk_idx] = {
            'expected': expected_blocks,
            'actual': actual_blocks,
            'duplicates': duplicate_blocks,
        }
        missing_blocks = expected_blocks - actual_blocks
        extra_blocks = actual_blocks - expected_blocks

        problems = []
        if missing_blocks:
            problems.append(f"MISSING {sorted(missing_blocks)}")
        if extra_blocks:
            problems.append(f"EXTRA {sorted(extra_blocks)}")
        if duplicate_blocks:
            problems.append(f"DUPLICATE {sorted(duplicate_blocks)}")

        num_chunks = INPUT_LEN // BLOCK_SIZE
        # Always show first/last chunks, plus any chunk with a problem.
        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or problems:
            status = ', '.join(problems) if problems else "OK"
            print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [{status}]")
        elif chunk_idx == 4:
            # Indicate that middle chunks are being verified silently
            print(f" ... (verifying chunks 4-{num_chunks - 3} silently) ...")

        if missing_blocks:
            errors.append(f"Chunk {chunk_idx} missing blocks: {sorted(missing_blocks)}")
        if extra_blocks:
            errors.append(f"Chunk {chunk_idx} has unexpected extra blocks: {sorted(extra_blocks)}")
        if duplicate_blocks:
            errors.append(f"Chunk {chunk_idx} has duplicate blocks: {sorted(duplicate_blocks)}")

    return hook
def make_gpu_write_verification_post_hook(layer_id: int):
    """
    Build a forward POST-hook that verifies the current chunk's K/V was written
    to its GPU ring-buffer write slot.

    This is a more reliable check than inspecting load slots, because:
    1. The post-hook runs AFTER forward() has written to the GPU cache.
    2. The write slot is taken to be ``chunk_idx % num_ring_slots``.
    3. The pre-hook injected a known pattern, which must now be in that slot.

    A line is printed for the first/last chunks and for any failure. Failures
    are appended to ``errors``. The hook only acts during prefill, with a
    kvcache_manager, for layer 0.
    """
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return
        chunk_idx = getattr(ctx, 'current_chunk_idx', 0)
        kvcache_manager = getattr(ctx, 'kvcache_manager', None)
        if kvcache_manager is None:
            return
        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return
        oe = kvcache_manager.offload_engine
        write_slot = chunk_idx % oe.num_ring_slots
        expected_k, expected_v = get_expected_pattern(chunk_idx)
        # Verify write_slot contains current chunk's data (GPU cache has no layer dimension)
        gpu_k = oe.k_cache_gpu[write_slot]
        gpu_v = oe.v_cache_gpu[write_slot]
        actual_k_mean = gpu_k.float().mean().item()
        actual_v_mean = gpu_v.float().mean().item()
        k_ok, _ = check_pattern(gpu_k, expected_k, f"GPU slot {write_slot}")
        v_ok, _ = check_pattern(gpu_v, expected_v, f"GPU slot {write_slot}")
        ok = k_ok and v_ok
        num_chunks = INPUT_LEN // BLOCK_SIZE
        # Print for first/last chunks, or if there's an error
        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or not ok:
            if ok:
                print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: K={expected_k:.1f}, V={expected_v:.1f} [OK]")
            else:
                print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: expected K={expected_k:.1f}/V={expected_v:.1f}, "
                      f"got K={actual_k_mean:.2f}/V={actual_v_mean:.2f} [FAIL]")
        elif chunk_idx == 4:
            print(f" ... (GPU write verification for chunks 4-{num_chunks - 3} silently) ...")
        if not ok:
            errors.append(f"GPU write_slot {write_slot} at chunk {chunk_idx}: "
                          f"expected K={expected_k}, V={expected_v}, got K={actual_k_mean:.4f}, V={actual_v_mean:.4f}")

    return hook
def make_kv_verification_post_hook(layer_id: int):
    """
    Build a forward POST-hook that re-checks every previously offloaded CPU
    block against its injected pattern.

    For chunk ``i`` (i >= 1), layer-0 CPU blocks ``0..i-1`` are checked (this
    test assumes CPU block id == chunk index). Failures are appended to
    ``errors``. A one-line summary is printed for the first/last chunks and
    for any chunk with failures. The hook only acts during prefill, with a
    kvcache_manager, for layer 0.
    """
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return
        chunk_idx = getattr(ctx, 'current_chunk_idx', 0)
        kvcache_manager = getattr(ctx, 'kvcache_manager', None)
        if kvcache_manager is None:
            return
        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return
        # Nothing has been offloaded before chunk 1.
        if chunk_idx < 1:
            return
        oe = kvcache_manager.offload_engine
        # Offload copies may still be in flight on a side stream (the H2D
        # loads in this engine are async, see the load-verification wrapper);
        # wait for all device work before reading host memory.
        torch.cuda.synchronize()
        num_ok = 0
        num_fail = 0
        for prev_chunk in range(chunk_idx):
            # CPU block ID = prev_chunk (in simple sequential case)
            cpu_block_id = prev_chunk
            expected_k, expected_v = get_expected_pattern(prev_chunk)
            # Read from CPU cache (layer 0)
            cpu_k = oe.k_cache_cpu[layer_id, cpu_block_id]
            cpu_v = oe.v_cache_cpu[layer_id, cpu_block_id]
            k_ok, k_err = check_pattern(cpu_k, expected_k, f"CPU K block {cpu_block_id}")
            v_ok, v_err = check_pattern(cpu_v, expected_v, f"CPU V block {cpu_block_id}")
            if k_ok and v_ok:
                num_ok += 1
            else:
                num_fail += 1
                if k_err:
                    errors.append(k_err)
                if v_err:
                    errors.append(v_err)
        # Only print summary for each chunk verification
        num_chunks = INPUT_LEN // BLOCK_SIZE
        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or num_fail > 0:
            status = "OK" if num_fail == 0 else f"FAIL({num_fail})"
            print(f" CPU verify chunk {chunk_idx:2d}: {num_ok} blocks OK [{status}]")
        elif chunk_idx == 4:
            print(f" ... (CPU cache verification for chunks 4-{num_chunks - 3} silently) ...")

    return hook
def make_post_chunk_verification_hook(layer_id: int):
    """
    Build a forward POST-hook that checks the GPU ring-buffer slot for the
    current chunk (``chunk_idx % num_ring_slots``) holds that chunk's pattern.

    Prints an [OK]/[FAIL] line for every layer-0 prefill chunk and appends
    failure messages to ``errors``.

    NOTE: ``register_hooks`` in this file does not register this hook; its
    check overlaps with ``make_gpu_write_verification_post_hook``.
    """
    def hook(module, inputs, output):
        ctx = get_context()
        if not (ctx.is_prefill and layer_id == 0):
            return
        chunk_idx = getattr(ctx, 'current_chunk_idx', 0)
        manager = getattr(ctx, 'kvcache_manager', None)
        if manager is None:
            return
        engine = manager.offload_engine
        slot = chunk_idx % engine.num_ring_slots
        exp_k, exp_v = get_expected_pattern(chunk_idx)
        # GPU cache has no layer dimension
        k_slot = engine.k_cache_gpu[slot]
        v_slot = engine.v_cache_gpu[slot]
        checks = [
            check_pattern(k_slot, exp_k, f"GPU K slot {slot}"),
            check_pattern(v_slot, exp_v, f"GPU V slot {slot}"),
        ]
        if all(passed for passed, _ in checks):
            print(f" [OK] GPU slot {slot} (chunk {chunk_idx}): K={exp_k}, V={exp_v}")
            return
        for _, message in checks:
            if message:
                print(f" [FAIL] {message}")
                errors.append(message)

    return hook
def register_hooks(llm):
    """Attach the verification hooks to every decoder layer's attention module.

    For each layer, hooks are registered in this order: block-coverage
    pre-hook, K/V pattern pre-hook, GPU write-slot post-hook, then CPU-cache
    post-hook. Returns all handles so the caller can ``remove()`` them.
    """
    handles = []
    layers = llm.model_runner.model.model.layers
    for idx, decoder_layer in enumerate(layers):
        attn = decoder_layer.self_attn.attn
        handles.extend([
            # PRE 1: all previous blocks present in cpu_block_table
            attn.register_forward_pre_hook(make_block_coverage_pre_hook(idx)),
            # PRE 2: inject K/V patterns before they are stored to cache
            attn.register_forward_pre_hook(make_kv_pattern_pre_hook(idx)),
            # POST 1: GPU write_slot holds the current chunk's data
            attn.register_forward_hook(make_gpu_write_verification_post_hook(idx)),
            # POST 2: CPU cache holds correct patterns after offload
            attn.register_forward_hook(make_kv_verification_post_hook(idx)),
        ])
    return handles
# ============================================================
# Final Verification
# ============================================================
def verify_final_cpu_state(llm, num_chunks: int):
    """Verify every CPU block holds its chunk's pattern after prefill completes.

    Checks layer 0 of CPU blocks ``0..num_chunks-1`` (this test assumes CPU
    block id == chunk index). The first 3 and last 2 blocks are always
    printed; other blocks are printed only when they fail. Failure messages
    are appended to ``errors``.
    """
    print("\n" + "=" * 60)
    print("Final CPU Cache Verification")
    print("=" * 60)
    oe = llm.model_runner.kvcache_manager.offload_engine
    num_ok = 0
    num_fail = 0
    # After prefill, all chunks should be in CPU
    for chunk_idx in range(num_chunks):
        cpu_block_id = chunk_idx
        expected_k, expected_v = get_expected_pattern(chunk_idx)
        # Check layer 0
        cpu_k = oe.k_cache_cpu[0, cpu_block_id]
        cpu_v = oe.v_cache_cpu[0, cpu_block_id]
        k_ok, k_err = check_pattern(cpu_k, expected_k, f"Final CPU K block {cpu_block_id}")
        v_ok, v_err = check_pattern(cpu_v, expected_v, f"Final CPU V block {cpu_block_id}")
        ok = k_ok and v_ok
        if ok:
            num_ok += 1
        else:
            num_fail += 1
            errors.extend(err for err in (k_err, v_err) if err)

        always_shown = chunk_idx < 3 or chunk_idx >= num_chunks - 2
        if chunk_idx == 3 and not always_shown:
            # Announce the elided range without claiming it passed: middle
            # blocks have not been checked yet at this point.
            print(f" ... (blocks 3 to {num_chunks - 3}: only failures shown) ...")
        if always_shown or not ok:
            actual_k_mean = cpu_k.float().mean().item()
            actual_v_mean = cpu_v.float().mean().item()
            status = "OK" if ok else "FAIL"
            print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
                  f"V={expected_v:.1f} ({actual_v_mean:.4f}) [{status}]")
    print(f"\nTotal: {num_ok} OK, {num_fail} FAIL out of {num_chunks} blocks")
def verify_block_coverage_summary(num_chunks: int):
    """Summarize the per-chunk block coverage recorded during prefill.

    Chunks ``1..num_chunks-1`` are checked; chunk 0 has no previous blocks. A
    chunk passes only if its recorded actual block set equals the expected
    set, with nothing missing and nothing extra. The total number of blocks
    actually used is then compared with 0+1+...+(n-1) = n*(n-1)/2.

    Missing coverage data, chunks with extra blocks, and a total mismatch are
    appended to ``errors``. Missing blocks are reported per chunk by the
    coverage pre-hook.
    """
    print("\n" + "=" * 60)
    print("Block Coverage Verification Summary")
    print("=" * 60)
    num_ok = 0
    num_fail = 0
    total_blocks_expected = 0
    total_blocks_computed = 0
    chunks_with_extra = []
    for chunk_idx in range(1, num_chunks):  # Start from 1 (chunk 0 has no previous)
        coverage = block_coverage.get(chunk_idx)
        if coverage is None:
            print(f" Chunk {chunk_idx}: NO COVERAGE DATA [FAIL]")
            errors.append(f"Chunk {chunk_idx} has no block coverage data")
            num_fail += 1
            continue
        expected = coverage['expected']
        actual = coverage['actual']
        total_blocks_expected += len(expected)
        total_blocks_computed += len(actual)
        if actual - expected:
            chunks_with_extra.append(chunk_idx)
        if actual == expected:
            num_ok += 1
        else:
            num_fail += 1

    if num_fail == 0:
        print(f" All {num_ok} chunks had complete block coverage [OK]")
    else:
        print(f" {num_ok} chunks OK, {num_fail} chunks with missing/extra blocks [FAIL]")
    print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")

    if chunks_with_extra:
        errors.append(f"Chunks with extra blocks in coverage: {chunks_with_extra}")

    # The blocks actually used must add up to 0+1+2+...+(n-1) = n*(n-1)/2.
    expected_total = num_chunks * (num_chunks - 1) // 2
    if total_blocks_computed == expected_total:
        print(f" Total blocks computed matches formula: {expected_total} [OK]")
    else:
        print(f" Total blocks computed mismatch: got {total_blocks_computed}, formula gives {expected_total} [FAIL]")
        errors.append(f"Block coverage total mismatch: computed {total_blocks_computed}, expected {expected_total}")
def verify_load_operations_summary(num_chunks: int):
    """Report on every H2D load recorded by the load-verification wrapper.

    Groups ``load_operations`` by chunk. Per-chunk results are printed for the
    first 4 and last 2 chunks, and for any chunk with a failed load (listing
    each bad load). Overall totals are printed at the end. If nothing was
    recorded, a warning is printed and the function returns early. This
    function only reports; it does not append to ``errors``.
    """
    print("\n" + "=" * 60)
    print("H2D Load Operations Verification Summary")
    print("=" * 60)
    if not load_operations:
        print(" WARNING: No load operations recorded!")
        print(" (This may indicate load verification was not installed)")
        return

    def load_passed(record):
        return record['k_ok'] and record['v_ok']

    by_chunk = {}
    for record in load_operations:
        by_chunk.setdefault(record['chunk_idx'], []).append(record)
    passed_total = sum(1 for record in load_operations if load_passed(record))
    failed_total = len(load_operations) - passed_total

    # Per-chunk detail for first/last chunks and any chunk with failures
    for chunk_idx in sorted(by_chunk):
        records = by_chunk[chunk_idx]
        bad = [record for record in records if not load_passed(record)]
        notable = chunk_idx < 4 or chunk_idx >= num_chunks - 2 or len(bad) > 0
        if not notable:
            if chunk_idx == 4:
                print(f" ... (chunks 4-{num_chunks - 3} load verification running silently) ...")
            continue
        if not bad:
            # Show loaded block IDs in order
            loaded_ids = [record['cpu_block_id'] for record in records]
            print(f" Chunk {chunk_idx:2d}: loaded {len(records)} blocks {loaded_ids} [OK]")
            continue
        print(f" Chunk {chunk_idx:2d}: loaded {len(records)} blocks, {len(bad)} FAILED [FAIL]")
        for record in bad:
            print(f" CPU block {record['cpu_block_id']} -> slot {record['slot_idx']}: "
                  f"expected K={record['expected_k']:.1f}/V={record['expected_v']:.1f}, "
                  f"got K={record['actual_k']:.4f}/V={record['actual_v']:.4f}")

    print(f"\n Total load operations: {len(load_operations)}")
    print(f" Successful: {passed_total}, Failed: {failed_total}")
    if failed_total == 0:
        print(" All H2D transfers verified correct [OK]")
    else:
        print(f" {failed_total} H2D transfers had incorrect data [FAIL]")
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
    import sys

    num_chunks = INPUT_LEN // BLOCK_SIZE

    print("=" * 60)
    print("Test: CPU Offload Correctness with Distinctive KV Patterns")
    print("=" * 60)
    print(f"Input: {INPUT_LEN} tokens, {num_chunks} chunks")
    print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
    print("Pattern: K=chunk_idx+1, V=-(chunk_idx+1)")
    print()

    # 1. Initialize LLM
    print("Initializing LLM...")
    llm = LLM(
        MODEL_PATH,
        enforce_eager=True,
        max_model_len=MAX_MODEL_LEN,
        max_num_batched_tokens=MAX_MODEL_LEN,
        enable_cpu_offload=True,
        kvcache_block_size=BLOCK_SIZE,
        num_gpu_blocks=NUM_GPU_BLOCKS,
        dtype="float16",
    )

    # 2. Register hooks
    hooks = register_hooks(llm)
    print(f"Registered {len(hooks)} hooks")

    try:
        # 3. Install load verification (instrument load_to_slot_layer)
        install_load_verification(llm)

        # 4. Generate a reproducible random prompt
        seed(42)
        prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]

        # 5. Run prefill (max_tokens=1; the hooks ignore non-prefill steps)
        print("\n" + "=" * 60)
        print("Running Prefill with KV Pattern Injection...")
        print("=" * 60)
        sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
        llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
    finally:
        # 6. Always remove hooks and the load wrapper, even if generation fails
        for hook in hooks:
            hook.remove()
        uninstall_load_verification()

    # 7. Final verification
    verify_final_cpu_state(llm, num_chunks)

    # 8. Block coverage summary
    verify_block_coverage_summary(num_chunks)

    # 9. H2D load operations summary
    verify_load_operations_summary(num_chunks)

    # 10. Report results
    print("\n" + "=" * 60)
    if errors:
        print(f"test_offload_correctness: FAILED ({len(errors)} errors)")
        max_shown = 10
        for err in errors[:max_shown]:
            print(f" - {err}")
        if len(errors) > max_shown:
            print(f" ... ({len(errors) - max_shown} more not shown)")
        # sys.exit works even when the `site`-provided exit() builtin is absent.
        sys.exit(1)
    print("test_offload_correctness: PASSED")
    print("=" * 60)