[refactor] Delete unnecessary test, and refactor the offload prefix cache.
@@ -1,695 +0,0 @@
"""
Test script to verify CPU offload correctness using distinctive KV patterns.

Strategy:
1. Hook into the attention forward pass
2. Overwrite K/V with distinctive patterns based on chunk_idx (e.g., K=chunk_idx+1, V=-(chunk_idx+1))
3. After offload to CPU, verify the CPU cache contains the correct patterns
4. On subsequent chunks, verify the KV loaded from CPU has the correct patterns

This catches bugs like:
- Wrong block being offloaded
- Wrong block being loaded
- Data corruption during transfer
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"

import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context


# ============================================================
# Configuration
# ============================================================

MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 64 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024  # 32K tokens = 32 chunks (fits in 40 CPU blocks)
BLOCK_SIZE = 1024
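
# Geometry implied by these settings: INPUT_LEN // BLOCK_SIZE = 32768 // 1024 = 32
# prefill chunks. Assuming one ring slot per GPU block (num_ring_slots == 4), the
# GPU ring buffer wraps 32 / 4 = 8 times over a single prefill.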

# Test state
errors = []
chunk_patterns = {}  # chunk_idx -> (k_pattern, v_pattern)
block_coverage = {}  # chunk_idx -> set of blocks that were actually computed
load_operations = []  # List of per-load dicts: chunk_idx, slot_idx, cpu_block_id, expected/actual K/V, k_ok, v_ok
current_chunk_for_load = [0]  # Mutable container to track current chunk during loads


# ============================================================
# Pattern Helpers
# ============================================================

def get_expected_pattern(chunk_idx: int):
    """Get expected K/V pattern for a chunk."""
    # Use float values that are easy to identify
    k_val = float(chunk_idx + 1)     # 1.0, 2.0, 3.0, ...
    v_val = float(-(chunk_idx + 1))  # -1.0, -2.0, -3.0, ...
    return k_val, v_val


def fill_with_pattern(tensor: torch.Tensor, value: float):
    """Fill tensor with a constant value."""
    tensor.fill_(value)


def check_pattern(tensor: torch.Tensor, expected: float, name: str, tolerance: float = 1e-3):
    """Check if tensor contains expected pattern."""
    actual_mean = tensor.float().mean().item()
    if abs(actual_mean - expected) > tolerance:
        return False, f"{name}: expected mean={expected}, got {actual_mean}"
    return True, None
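

# Quick sanity check of the helpers above (illustrative, runnable standalone):
#   >>> get_expected_pattern(3)
#   (4.0, -4.0)
#   >>> check_pattern(torch.full((4, 8), 4.0), 4.0, "demo")
#   (True, None)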


# ============================================================
# Load Verification Instrumentation
# ============================================================

_original_load_to_slot_layer = None
_offload_engine_ref = None

def make_verified_load_to_slot_layer(original_func, offload_engine):
    """
    Create a wrapper around load_to_slot_layer that verifies each load operation.

    After each H2D transfer, checks that the GPU slot contains the expected
    pattern from the source CPU block.
    """
    def verified_load(slot_idx: int, layer_id: int, cpu_block_id: int):
        # Call original load
        original_func(slot_idx, layer_id, cpu_block_id)

        # Only verify layer 0 to reduce overhead
        if layer_id != 0:
            return

        # IMPORTANT: Synchronize CUDA to ensure the async transfer is complete.
        # The transfer happens on a per-slot stream, and wait_slot_layer only
        # makes compute_stream wait. We need a full sync to read on the default stream.
        torch.cuda.synchronize()

        # Get the expected pattern for this CPU block
        # (cpu_block_id == chunk_idx in our sequential test)
        expected_k, expected_v = get_expected_pattern(cpu_block_id)

        # Read GPU slot data (GPU cache has no layer dimension)
        gpu_k = offload_engine.k_cache_gpu[slot_idx]
        gpu_v = offload_engine.v_cache_gpu[slot_idx]

        actual_k = gpu_k.float().mean().item()
        actual_v = gpu_v.float().mean().item()

        k_ok = abs(actual_k - expected_k) < 1e-3
        v_ok = abs(actual_v - expected_v) < 1e-3

        chunk_idx = current_chunk_for_load[0]
        load_operations.append({
            'chunk_idx': chunk_idx,
            'slot_idx': slot_idx,
            'cpu_block_id': cpu_block_id,
            'expected_k': expected_k,
            'expected_v': expected_v,
            'actual_k': actual_k,
            'actual_v': actual_v,
            'k_ok': k_ok,
            'v_ok': v_ok,
        })

        if not (k_ok and v_ok):
            errors.append(f"Load verification failed: chunk {chunk_idx}, "
                          f"CPU block {cpu_block_id} -> GPU slot {slot_idx}: "
                          f"expected K={expected_k:.1f}/V={expected_v:.1f}, "
                          f"got K={actual_k:.4f}/V={actual_v:.4f}")

    return verified_load
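

# Note: verified_load is applied as a plain monkeypatch in
# install_load_verification() below; the original bound method is saved in
# _original_load_to_slot_layer so uninstall_load_verification() can restore it.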


def install_load_verification(llm):
    """Install verification wrapper on load_to_slot_layer."""
    global _original_load_to_slot_layer, _offload_engine_ref

    oe = llm.model_runner.kvcache_manager.offload_engine
    _offload_engine_ref = oe
    _original_load_to_slot_layer = oe.load_to_slot_layer

    oe.load_to_slot_layer = make_verified_load_to_slot_layer(
        _original_load_to_slot_layer, oe
    )
    print("Installed load verification wrapper on load_to_slot_layer")


def uninstall_load_verification():
    """Restore original load_to_slot_layer."""
    global _original_load_to_slot_layer, _offload_engine_ref

    if _offload_engine_ref is not None and _original_load_to_slot_layer is not None:
        _offload_engine_ref.load_to_slot_layer = _original_load_to_slot_layer
        print("Restored original load_to_slot_layer")

    _original_load_to_slot_layer = None
    _offload_engine_ref = None


# ============================================================
# Attention Hook
# ============================================================

def make_kv_pattern_pre_hook(layer_id: int):
    """
    Create a PRE-forward hook that overwrites K/V with distinctive patterns BEFORE
    they are stored to cache. This is called before attention.forward().

    register_forward_pre_hook receives (module, inputs) and can modify inputs in-place.
    """
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return

        q, k, v = inputs
        k_pattern, v_pattern = get_expected_pattern(chunk_idx)

        # === Overwrite current chunk's K/V with distinctive pattern ===
        # This happens BEFORE forward(), so these values will be stored to cache
        k.fill_(k_pattern)
        v.fill_(v_pattern)

        # Only print for first few and last few chunks to reduce noise
        num_chunks = INPUT_LEN // BLOCK_SIZE
        if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
            print(f"[Chunk {chunk_idx:3d}] Set K={k_pattern:.1f}, V={v_pattern:.1f}")
        elif chunk_idx == 3:
            print(f"... (chunks 3 to {num_chunks - 3} omitted) ...")

    return hook


def make_block_coverage_pre_hook(layer_id: int):
    """
    Create a PRE-forward hook to verify that all previous blocks are included
    in the cpu_block_table for chunked prefill attention.

    This catches bugs where:
    - Some blocks are missing from the computation
    - Sparse policy incorrectly filters out blocks (when not intended)
    - Block table construction has off-by-one errors
    """
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return

        # Update current chunk for load verification tracking
        current_chunk_for_load[0] = chunk_idx

        # No previous blocks for chunk 0
        if chunk_idx == 0:
            return

        # Get the sequence and its block table (same logic as _chunked_prefill_attention)
        seq = ctx.chunked_seq if hasattr(ctx, 'chunked_seq') else None
        if seq is None:
            return

        # Get the CPU block table that will be used for attention
        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

        # Expected blocks: 0 to chunk_idx-1 (all previous chunks)
        expected_blocks = set(range(chunk_idx))
        actual_blocks = set(cpu_block_table) if cpu_block_table else set()

        # Store for later summary
        block_coverage[chunk_idx] = {
            'expected': expected_blocks,
            'actual': actual_blocks,
        }

        # Check for missing blocks
        missing_blocks = expected_blocks - actual_blocks
        extra_blocks = actual_blocks - expected_blocks

        num_chunks = INPUT_LEN // BLOCK_SIZE
        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or missing_blocks:
            if not missing_blocks and not extra_blocks:
                print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [OK]")
            else:
                status_parts = []
                if missing_blocks:
                    status_parts.append(f"MISSING {sorted(missing_blocks)}")
                if extra_blocks:
                    status_parts.append(f"EXTRA {sorted(extra_blocks)}")
                print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [{', '.join(status_parts)}]")
        elif chunk_idx == 4:
            # Indicate that middle chunks are being verified silently
            print(f" ... (verifying chunks 4-{num_chunks - 3} silently) ...")

        if missing_blocks:
            errors.append(f"Chunk {chunk_idx} missing blocks: {sorted(missing_blocks)}")

    return hook


def make_gpu_write_verification_post_hook(layer_id: int):
    """
    Create a POST-forward hook to verify the current chunk's KV was correctly
    written to the GPU ring buffer write_slot.

    This is a more reliable verification than checking load slots, because:
    1. Post-hook runs AFTER forward() writes to GPU cache
    2. write_slot mapping is deterministic: chunk_idx % num_ring_slots
    3. We injected known patterns in pre-hook, now verify they're in GPU cache
    """
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return

        oe = kvcache_manager.offload_engine
        num_ring_slots = oe.num_ring_slots
        write_slot = chunk_idx % num_ring_slots

        # Get expected pattern for current chunk
        expected_k, expected_v = get_expected_pattern(chunk_idx)

        # Verify write_slot contains current chunk's data (GPU cache has no layer dimension)
        gpu_k = oe.k_cache_gpu[write_slot]
        gpu_v = oe.v_cache_gpu[write_slot]

        actual_k_mean = gpu_k.float().mean().item()
        actual_v_mean = gpu_v.float().mean().item()

        k_ok, _ = check_pattern(gpu_k, expected_k, f"GPU slot {write_slot}")
        v_ok, _ = check_pattern(gpu_v, expected_v, f"GPU slot {write_slot}")

        num_chunks = INPUT_LEN // BLOCK_SIZE
        # Print for first/last chunks, or if there's an error
        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or not (k_ok and v_ok):
            if k_ok and v_ok:
                print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: K={expected_k:.1f}, V={expected_v:.1f} [OK]")
            else:
                print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: expected K={expected_k:.1f}/V={expected_v:.1f}, "
                      f"got K={actual_k_mean:.2f}/V={actual_v_mean:.2f} [FAIL]")
        elif chunk_idx == 4:
            print(f" ... (GPU write verification for chunks 4-{num_chunks - 3} running silently) ...")

        if not (k_ok and v_ok):
            errors.append(f"GPU write_slot {write_slot} at chunk {chunk_idx}: "
                          f"expected K={expected_k}, V={expected_v}, got K={actual_k_mean:.4f}, V={actual_v_mean:.4f}")

    return hook
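

# Worked example of the write-slot mapping verified above, assuming
# num_ring_slots == NUM_GPU_BLOCKS == 4: chunks map to slots 0,1,2,3,0,1,...
# so chunk 5 lands in slot 1, overwriting the slot chunk 1 used; presumably
# chunk 1's block must already have been offloaded to CPU by that point.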


def make_kv_verification_post_hook(layer_id: int):
    """
    Create a POST-forward hook to verify CPU cache contains correct patterns
    from previously offloaded blocks.
    """
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return

        # === Verify previously offloaded blocks in CPU cache ===
        if chunk_idx >= 1:
            oe = kvcache_manager.offload_engine
            num_ok = 0
            num_fail = 0

            # Check all previously offloaded blocks
            for prev_chunk in range(chunk_idx):
                # CPU block ID = prev_chunk (in the simple sequential case)
                cpu_block_id = prev_chunk

                # Get expected pattern for this block
                expected_k, expected_v = get_expected_pattern(prev_chunk)

                # Read from CPU cache (layer 0)
                cpu_k = oe.k_cache_cpu[layer_id, cpu_block_id]
                cpu_v = oe.v_cache_cpu[layer_id, cpu_block_id]

                # Verify patterns
                k_ok, k_err = check_pattern(cpu_k, expected_k, f"CPU K block {cpu_block_id}")
                v_ok, v_err = check_pattern(cpu_v, expected_v, f"CPU V block {cpu_block_id}")

                if k_ok and v_ok:
                    num_ok += 1
                else:
                    num_fail += 1
                    if k_err:
                        errors.append(k_err)
                    if v_err:
                        errors.append(v_err)

            # Only print a summary for each chunk verification
            num_chunks = INPUT_LEN // BLOCK_SIZE
            if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or num_fail > 0:
                status = "OK" if num_fail == 0 else f"FAIL({num_fail})"
                print(f" CPU verify chunk {chunk_idx:2d}: {num_ok} blocks OK [{status}]")
            elif chunk_idx == 4:
                print(f" ... (CPU cache verification for chunks 4-{num_chunks - 3} running silently) ...")

    return hook
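

# Layout note, as used above: the CPU cache is indexed per layer
# (k_cache_cpu[layer_id, cpu_block_id]), while the GPU ring buffer holds one
# block per slot with no layer dimension (k_cache_gpu[slot_idx]).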


def make_post_chunk_verification_hook(layer_id: int):
    """
    Post-forward hook to verify GPU ring buffer state after attention.
    """
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill or layer_id != 0:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        oe = kvcache_manager.offload_engine

        # After attention, the current chunk's KV should be in the GPU ring buffer
        # Ring slot = chunk_idx % num_ring_slots
        ring_slot = chunk_idx % oe.num_ring_slots

        expected_k, expected_v = get_expected_pattern(chunk_idx)

        # Check GPU ring buffer (GPU cache has no layer dimension)
        gpu_k = oe.k_cache_gpu[ring_slot]
        gpu_v = oe.v_cache_gpu[ring_slot]

        k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}")
        v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}")

        if k_ok and v_ok:
            print(f" [OK] GPU slot {ring_slot} (chunk {chunk_idx}): K={expected_k}, V={expected_v}")
        else:
            if k_err:
                print(f" [FAIL] {k_err}")
                errors.append(k_err)
            if v_err:
                print(f" [FAIL] {v_err}")
                errors.append(v_err)

    return hook


def register_hooks(llm):
    """Register pre and post forward hooks."""
    hooks = []
    model = llm.model_runner.model

    for layer_idx, decoder_layer in enumerate(model.model.layers):
        attn_module = decoder_layer.self_attn.attn

        # PRE-forward hook 1: Verify all previous blocks are in cpu_block_table
        coverage_hook = attn_module.register_forward_pre_hook(make_block_coverage_pre_hook(layer_idx))
        hooks.append(coverage_hook)

        # PRE-forward hook 2: Inject K/V patterns before they're stored to cache
        pattern_hook = attn_module.register_forward_pre_hook(make_kv_pattern_pre_hook(layer_idx))
        hooks.append(pattern_hook)

        # POST-forward hook 1: Verify GPU write_slot contains current chunk's data
        gpu_verify_hook = attn_module.register_forward_hook(make_gpu_write_verification_post_hook(layer_idx))
        hooks.append(gpu_verify_hook)

        # POST-forward hook 2: Verify CPU cache contains correct patterns after offload
        cpu_verify_hook = attn_module.register_forward_hook(make_kv_verification_post_hook(layer_idx))
        hooks.append(cpu_verify_hook)

    return hooks


# ============================================================
# Final Verification
# ============================================================

def verify_final_cpu_state(llm, num_chunks: int):
    """Verify all CPU blocks have correct patterns after prefill completes."""
    print("\n" + "=" * 60)
    print("Final CPU Cache Verification")
    print("=" * 60)

    kvcache_manager = llm.model_runner.kvcache_manager
    oe = kvcache_manager.offload_engine

    num_ok = 0
    num_fail = 0
    fail_details = []

    # After prefill, all chunks should be in CPU
    for chunk_idx in range(num_chunks):
        cpu_block_id = chunk_idx
        expected_k, expected_v = get_expected_pattern(chunk_idx)

        # Check layer 0
        cpu_k = oe.k_cache_cpu[0, cpu_block_id]
        cpu_v = oe.v_cache_cpu[0, cpu_block_id]

        k_ok, k_err = check_pattern(cpu_k, expected_k, f"Final CPU K block {cpu_block_id}")
        v_ok, v_err = check_pattern(cpu_v, expected_v, f"Final CPU V block {cpu_block_id}")

        if k_ok and v_ok:
            num_ok += 1
            # Only print first few and last few
            if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
                actual_k_mean = cpu_k.float().mean().item()
                actual_v_mean = cpu_v.float().mean().item()
                print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
                      f"V={expected_v:.1f} ({actual_v_mean:.4f}) [OK]")
            elif chunk_idx == 3:
                print(f" ... (blocks 3 to {num_chunks - 3} verified OK) ...")
        else:
            num_fail += 1
            actual_k_mean = cpu_k.float().mean().item()
            actual_v_mean = cpu_v.float().mean().item()
            print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
                  f"V={expected_v:.1f} ({actual_v_mean:.4f}) [FAIL]")
            if k_err:
                errors.append(k_err)
            if v_err:
                errors.append(v_err)

    print(f"\nTotal: {num_ok} OK, {num_fail} FAIL out of {num_chunks} blocks")


def verify_block_coverage_summary(num_chunks: int):
    """Verify that all chunks had complete block coverage during prefill."""
    print("\n" + "=" * 60)
    print("Block Coverage Verification Summary")
    print("=" * 60)

    num_ok = 0
    num_fail = 0
    total_blocks_expected = 0
    total_blocks_computed = 0

    for chunk_idx in range(1, num_chunks):  # Start from 1 (chunk 0 has no previous)
        if chunk_idx not in block_coverage:
            print(f" Chunk {chunk_idx}: NO COVERAGE DATA [FAIL]")
            errors.append(f"Chunk {chunk_idx} has no block coverage data")
            num_fail += 1
            continue

        coverage = block_coverage[chunk_idx]
        expected = coverage['expected']
        actual = coverage['actual']
        missing = expected - actual

        total_blocks_expected += len(expected)
        total_blocks_computed += len(actual)

        if not missing:
            num_ok += 1
        else:
            num_fail += 1

    # Print summary
    if num_fail == 0:
        print(f" All {num_ok} chunks had complete block coverage [OK]")
        print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
    else:
        print(f" {num_ok} chunks OK, {num_fail} chunks with missing blocks [FAIL]")
        print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")

    # Verify the total is correct: sum of 0+1+2+...+(n-1) = n*(n-1)/2
    expected_total = num_chunks * (num_chunks - 1) // 2
    if total_blocks_expected == expected_total:
        print(f" Expected total blocks matches formula: {expected_total} [OK]")
    else:
        print(f" Expected total mismatch: got {total_blocks_expected}, formula gives {expected_total} [FAIL]")
        errors.append("Block coverage total mismatch")
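

# Sanity check of the formula above with this script's defaults: 32 chunks
# give 0+1+...+31 = 32*31/2 = 496 expected block computations in total.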


def verify_load_operations_summary(num_chunks: int):
    """Verify all H2D load operations transferred correct data."""
    print("\n" + "=" * 60)
    print("H2D Load Operations Verification Summary")
    print("=" * 60)

    if not load_operations:
        print(" WARNING: No load operations recorded!")
        print(" (This may indicate load verification was not installed)")
        return

    num_ok = 0
    num_fail = 0
    loads_per_chunk = {}

    for op in load_operations:
        chunk_idx = op['chunk_idx']
        if chunk_idx not in loads_per_chunk:
            loads_per_chunk[chunk_idx] = []
        loads_per_chunk[chunk_idx].append(op)

        if op['k_ok'] and op['v_ok']:
            num_ok += 1
        else:
            num_fail += 1

    # Print per-chunk summary for first/last chunks
    for chunk_idx in sorted(loads_per_chunk.keys()):
        ops = loads_per_chunk[chunk_idx]
        chunk_ok = sum(1 for op in ops if op['k_ok'] and op['v_ok'])
        chunk_fail = len(ops) - chunk_ok

        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or chunk_fail > 0:
            # Show loaded block IDs in order
            block_ids = [op['cpu_block_id'] for op in ops]
            if chunk_fail == 0:
                print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks {block_ids} [OK]")
            else:
                print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks, {chunk_fail} FAILED [FAIL]")
                for op in ops:
                    if not (op['k_ok'] and op['v_ok']):
                        print(f" CPU block {op['cpu_block_id']} -> slot {op['slot_idx']}: "
                              f"expected K={op['expected_k']:.1f}/V={op['expected_v']:.1f}, "
                              f"got K={op['actual_k']:.4f}/V={op['actual_v']:.4f}")
        elif chunk_idx == 4:
            print(f" ... (chunks 4-{num_chunks - 3} load verification running silently) ...")

    # Print overall summary
    print(f"\n Total load operations: {len(load_operations)}")
    print(f" Successful: {num_ok}, Failed: {num_fail}")

    if num_fail == 0:
        print(" All H2D transfers verified correct [OK]")
    else:
        print(f" {num_fail} H2D transfers had incorrect data [FAIL]")


# ============================================================
# Main Test Script
# ============================================================

if __name__ == "__main__":
    print("=" * 60)
    print("Test: CPU Offload Correctness with Distinctive KV Patterns")
    print("=" * 60)
    print(f"Input: {INPUT_LEN} tokens, {INPUT_LEN // BLOCK_SIZE} chunks")
    print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
    print("Pattern: K=chunk_idx+1, V=-(chunk_idx+1)")
    print()

    # 1. Initialize LLM
    print("Initializing LLM...")
    llm = LLM(
        MODEL_PATH,
        enforce_eager=True,
        max_model_len=MAX_MODEL_LEN,
        max_num_batched_tokens=MAX_MODEL_LEN,
        enable_cpu_offload=True,
        kvcache_block_size=BLOCK_SIZE,
        num_gpu_blocks=NUM_GPU_BLOCKS,
        dtype="float16",
    )

    # 2. Register hooks
    hooks = register_hooks(llm)
    print(f"Registered {len(hooks)} hooks")

    # 3. Install load verification (instrument load_to_slot_layer)
    install_load_verification(llm)

    # 4. Generate prompt
    seed(42)
    prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
    num_chunks = INPUT_LEN // BLOCK_SIZE

    # 5. Run prefill
    print("\n" + "=" * 60)
    print("Running Prefill with KV Pattern Injection...")
    print("=" * 60)

    sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
    outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)

    # 6. Remove hooks and uninstall load verification
    for hook in hooks:
        hook.remove()
    uninstall_load_verification()

    # 7. Final verification
    verify_final_cpu_state(llm, num_chunks)

    # 8. Block coverage summary
    verify_block_coverage_summary(num_chunks)

    # 9. H2D load operations summary
    verify_load_operations_summary(num_chunks)

    # 10. Report results
    print("\n" + "=" * 60)
    if errors:
        print(f"test_offload_correctness: FAILED ({len(errors)} errors)")
        for err in errors[:10]:  # Show first 10 errors
            print(f" - {err}")
        exit(1)
    else:
        print("test_offload_correctness: PASSED")
    print("=" * 60)