# nano-vllm/tests/test_chunked_decode_hook.py
"""
Hook-based correctness test for chunked decode attention.
Uses PyTorch register_forward_hook() to capture real inference I/O,
then compares against reference computation to locate bugs.
This test targets the decode phase with CPU offload - after prefill,
the model generates tokens one by one while attending to all previous context.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
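
# The two helpers imported above are exercised throughout this test. Their behavior
# is inferred here from how they are called below (an assumption, not their source):
#   flash_attn_with_lse(q, k, v, softmax_scale=..., causal=...) -> (o, lse)
#       attention output plus the per-query log-sum-exp of the scaled scores
#   merge_attention_outputs(o1, lse1, o2, lse2) -> (o, lse)
#       combines two partial attention results over disjoint key sets; the standard
#       LSE merge is assumed:
#           lse = log(exp(lse1) + exp(lse2))
#           o   = exp(lse1 - lse) * o1 + exp(lse2 - lse) * o2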
# ============================================================
# Configuration
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 8 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 2 * 1024 # 2K tokens for prefill
NUM_DECODE_TOKENS = 5 # Generate 5 tokens to test decode
BLOCK_SIZE = 1024
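
# With these settings, prefill should be split into INPUT_LEN / BLOCK_SIZE = 2 chunks
# of 1024 tokens each, so each layer is expected to record 2 prefill captures and
# NUM_DECODE_TOKENS = 5 decode captures (assuming one attention call per chunk/step).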
# ============================================================
# Global capture storage
# ============================================================
captures = []
prefill_kv = {} # Store prefill k,v for reference computation
# ============================================================
# Hook Functions
# ============================================================
def make_hook(layer_id):
    """Create a forward hook for a specific layer."""
    def hook(module, inputs, output):
        q, k, v = inputs
        ctx = get_context()
        is_prefill = ctx.is_prefill
        capture_entry = {
            'layer_id': layer_id,
            'is_prefill': is_prefill,
            'q': q.clone().cpu(),
            'k': k.clone().cpu(),
            'v': v.clone().cpu(),
            'output': output.clone().cpu(),
            'is_chunked_prefill': ctx.is_chunked_prefill,
        }
        if is_prefill:
            # Store prefill k,v for reference computation
            chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
            capture_entry['chunk_idx'] = chunk_idx
            if layer_id not in prefill_kv:
                prefill_kv[layer_id] = []
            prefill_kv[layer_id].append({
                'chunk_idx': chunk_idx,
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
            })
        else:
            # Decode phase - capture decode token info
            capture_entry['decode_step'] = len([c for c in captures
                                                if c['layer_id'] == layer_id and not c['is_prefill']])
        captures.append(capture_entry)
    return hook
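
# Assumed shapes of the captured tensors (inferred from how the reference code
# below unsqueezes them; verify against the Attention module if they change):
#   prefill: q/k/v are [chunk_len, num_(kv_)heads, head_dim]
#   decode:  q/k/v are [1, num_(kv_)heads, head_dim] (a single new token)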


def register_hooks(llm):
    """Register forward hooks on all Attention modules."""
    hooks = []
    model = llm.model_runner.model
    for layer_idx, decoder_layer in enumerate(model.model.layers):
        attn_module = decoder_layer.self_attn.attn
        hook = attn_module.register_forward_hook(make_hook(layer_idx))
        hooks.append(hook)
    return hooks
# ============================================================
# Reference Computation
# ============================================================
def compute_decode_reference(layer_id, decode_step, scale, debug=False):
    """
    Compute the reference decode attention output for a specific layer.
    For decode, the query is a single token that attends to:
    1. All prefill KV (from the CPU cache)
    2. All decode tokens so far, including the current one (stored in the GPU decode slot)
    """
    # Get the decode capture
    decode_captures = [c for c in captures
                       if c['layer_id'] == layer_id and not c['is_prefill']]
    if decode_step >= len(decode_captures):
        return None
    decode_capture = decode_captures[decode_step]
    q = decode_capture['q'].cuda()  # [1, num_heads, head_dim]
    q_batched = q.unsqueeze(1)      # [1, 1, num_heads, head_dim]
    if debug:
        print(f" Reference for L{layer_id} D{decode_step}:")
        print(f" q shape: {q_batched.shape}, mean={q_batched.mean().item():.4f}")
    o_acc, lse_acc = None, None

    # Attend to all prefill chunks
    if layer_id in prefill_kv:
        for chunk_data in sorted(prefill_kv[layer_id], key=lambda x: x['chunk_idx']):
            k = chunk_data['k'].cuda().unsqueeze(0)  # [1, seqlen, kv_heads, head_dim]
            v = chunk_data['v'].cuda().unsqueeze(0)
            o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale=scale, causal=False)
            if debug:
                print(f" Prefill chunk {chunk_data['chunk_idx']}: o.mean={o.mean().item():.6f}")
            if o_acc is None:
                o_acc, lse_acc = o, lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)

    # Attend to the decode tokens. The current token's k,v are captured along with
    # those of every earlier decode step, so for decode_step i we gather k,v from
    # steps 0..i (step 0 sees only its own k,v, step 1 sees steps 0 and 1, etc.).
    decode_kv = []
    for i in range(decode_step + 1):
        if i < len(decode_captures):
            decode_kv.append({
                'k': decode_captures[i]['k'].cuda(),
                'v': decode_captures[i]['v'].cuda(),
            })
    if decode_kv:
        # Stack decode k,v into a single tensor
        decode_k = torch.cat([d['k'] for d in decode_kv], dim=0).unsqueeze(0)  # [1, num_decode, kv_heads, head_dim]
        decode_v = torch.cat([d['v'] for d in decode_kv], dim=0).unsqueeze(0)
        if debug:
            print(f" Decode tokens: {len(decode_kv)}, k.shape={decode_k.shape}")
        # causal=False is correct here: causal masking is already enforced by only
        # including decode tokens up to and including the current step.
        o_decode, lse_decode = flash_attn_with_lse(q_batched, decode_k, decode_v,
                                                   softmax_scale=scale, causal=False)
        if debug:
            print(f" Decode attention: o.mean={o_decode.mean().item():.6f}")
        if o_acc is None:
            o_acc, lse_acc = o_decode, lse_decode
        else:
            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_decode, lse_decode)

    if o_acc is None:
        return None
    if debug:
        print(f" Final: o.mean={o_acc.mean().item():.6f}")
    return o_acc.squeeze(0).squeeze(0).cpu()  # [num_heads, head_dim]
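

# Hypothetical cross-check, not used by the test: a pure-PyTorch single-query
# attention over one chunk, under the assumption that flash_attn_with_lse computes
# softmax(q @ K^T * scale) @ V per head, with grouped-query (GQA) head broadcasting.
# Shapes assumed to match the captured tensors described above.
def _naive_single_query_attention(q, k, v, scale):
    """q: [num_heads, head_dim]; k, v: [seqlen, kv_heads, head_dim]."""
    num_heads, kv_heads = q.shape[0], k.shape[1]
    group = num_heads // kv_heads
    # Broadcast each kv head across its query-head group
    k = k.repeat_interleave(group, dim=1)  # [seqlen, num_heads, head_dim]
    v = v.repeat_interleave(group, dim=1)
    scores = torch.einsum("hd,shd->hs", q.float(), k.float()) * scale
    probs = scores.softmax(dim=-1)
    return torch.einsum("hs,shd->hd", probs, v.float())  # [num_heads, head_dim]
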
# ============================================================
# Test Runner
# ============================================================
def run_test(verbose=True):
    """Run the hook-based chunked decode correctness test."""
    global captures, prefill_kv
    captures = []
    prefill_kv = {}

    if verbose:
        print("=" * 70)
        print("Test: Hook-Based Chunked Decode Correctness")
        print("=" * 70)
        print(f"Model: {MODEL_PATH}")
        print(f"Input length: {INPUT_LEN} tokens")
        print(f"Decode tokens: {NUM_DECODE_TOKENS}")
        print(f"Block size: {BLOCK_SIZE}")
        print()

    # Initialize LLM with CPU offload
    llm = LLM(
        MODEL_PATH,
        enforce_eager=True,
        max_model_len=MAX_MODEL_LEN,
        max_num_batched_tokens=MAX_MODEL_LEN,
        enable_cpu_offload=True,
        kvcache_block_size=BLOCK_SIZE,
        num_gpu_blocks=NUM_GPU_BLOCKS,
    )

    # Get model info
    num_layers = len(llm.model_runner.model.model.layers)
    head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
    scale = head_dim ** -0.5  # standard 1/sqrt(head_dim) attention scaling
    if verbose:
        print(f"Num layers: {num_layers}")
        print(f"Head dim: {head_dim}")
        print()

    # Register hooks
    hooks = register_hooks(llm)
    if verbose:
        print(f"Registered {len(hooks)} hooks")

    # Generate a random prompt
    seed(42)
    prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]

    # Run prefill and decode
    if verbose:
        print(f"Running inference with {NUM_DECODE_TOKENS} decode tokens...")
    sampling_params = SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS)
    outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # =========== VERIFICATION: Check CPU cache after prefill ===========
    # Verify that CPU cache data matches the captured prefill k,v
    if verbose:
        print("\n--- CPU Cache Verification (After Prefill) ---")
    offload_engine = llm.model_runner.kvcache_manager.offload_engine
    # For each prefill capture, check if the CPU cache matches
    for layer_id in [0]:  # Only check layer 0 for brevity
        if layer_id not in prefill_kv:
            continue
        for chunk_data in prefill_kv[layer_id]:
            chunk_idx = chunk_data['chunk_idx']
            captured_k = chunk_data['k']  # [block_size, kv_heads, head_dim]
            # CPU block ID should be chunk_idx (based on allocation order)
            cpu_block_id = chunk_idx
            cpu_k = offload_engine.k_cache_cpu[layer_id, cpu_block_id].cpu()
            diff = (captured_k - cpu_k).abs().max().item()
            print(f"Layer {layer_id}, Chunk {chunk_idx}: captured_k vs cpu_k max_diff={diff:.6f}")
            if diff > 1e-3:
                print(f" WARNING: CPU cache doesn't match captured k!")
                print(f" captured_k[0,0,:5] = {captured_k[0,0,:5].tolist()}")
                print(f" cpu_k[0,0,:5] = {cpu_k[0,0,:5].tolist()}")
    print()

    # Analyze captures
    prefill_count = sum(1 for c in captures if c['is_prefill'])
    decode_count = sum(1 for c in captures if not c['is_prefill'])
    if verbose:
        print(f"\nCaptured {prefill_count} prefill calls, {decode_count} decode calls")

    # Count decode steps per layer
    decode_per_layer = {}
    for c in captures:
        if not c['is_prefill']:
            layer_id = c['layer_id']
            if layer_id not in decode_per_layer:
                decode_per_layer[layer_id] = 0
            decode_per_layer[layer_id] += 1
    if verbose:
        print(f"Decode calls per layer: {decode_per_layer}")
        print()

    # Verify decode correctness
    all_passed = True
    results = []
    first_fail_debug = True
    for c in captures:
        if c['is_prefill']:
            continue  # Skip prefill (already tested in test_chunked_prefill_hook.py)
        layer_id = c['layer_id']
        decode_step = c['decode_step']
        # Only test the first decode step for now (simpler reference computation)
        if decode_step > 0:
            continue
        # Compute the reference (with debug output for layer 0)
        debug_this = (layer_id == 0 and first_fail_debug)
        ref_output = compute_decode_reference(layer_id, decode_step, scale, debug=debug_this)
        if ref_output is None:
            continue
        # Compare
        actual_output = c['output'].squeeze(0)  # Remove seq dim for decode
        if actual_output.dim() == 3:
            actual_output = actual_output.squeeze(0)  # Handle [1, heads, dim] case
        diff = (actual_output - ref_output).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()
        tol = 1e-2
        passed = max_diff < tol
        all_passed = all_passed and passed
        status = "PASS" if passed else "FAIL"
        results.append((layer_id, decode_step, passed, max_diff, mean_diff))
        if verbose:
            print(f"[{status}] Layer {layer_id:2d}, Decode {decode_step}: "
                  f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
        # Debug the first failure
        if not passed and first_fail_debug:
            first_fail_debug = False
            print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
            print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}")
            # Find where the max diff is
            max_idx = diff.argmax()
            flat_actual = actual_output.flatten()
            flat_ref = ref_output.flatten()
            print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")

    print()
    print("=" * 70)

    # Summary
    total_tests = len(results)
    passed_count = sum(1 for r in results if r[2])
    print(f"Results: {passed_count}/{total_tests} tests passed")
    if not all_passed:
        print("\nFailed tests:")
        for layer_id, decode_step, passed, max_diff, mean_diff in results:
            if not passed:
                print(f" - Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
    print()
    return all_passed
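
# Note: only decode step 0 is verified above. compute_decode_reference already
# accumulates decode k/v up to an arbitrary decode_step, so later steps could be
# checked by dropping the `if decode_step > 0: continue` guard (untested assumption).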
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
passed = run_test(verbose=True)
if passed:
print("test_chunked_decode_hook: PASSED")
else:
print("test_chunked_decode_hook: FAILED")
exit(1)