# Viewer metadata from the original paste: 375 lines, 13 KiB, Python.
"""
Hook-based correctness test for chunked decode attention.

Uses PyTorch register_forward_hook() to capture the real inputs (q, k, v)
and outputs of every attention module during inference, then compares them
against an independently computed reference to locate bugs.

This test targets the decode phase with CPU offload - after prefill, the
model generates tokens one by one while attending to all previous context
(the prefill KV plus the decode tokens generated so far).

Run directly as a script; it exits with status 1 on failure.
"""

import os

# Set before the nanovllm imports below. Presumably the log level is read
# when nanovllm sets up logging, so keep this line above those imports.
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"

import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs


# ============================================================
# Configuration
# ============================================================

# Test geometry: the prompt spans exactly two KV-cache blocks
# (INPUT_LEN / BLOCK_SIZE == 2), then NUM_DECODE_TOKENS are generated.
MODEL_PATH: str = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN: int = 8 * 1024          # also used as max_num_batched_tokens
NUM_GPU_BLOCKS: int = 2                # passed to LLM(num_gpu_blocks=...)
INPUT_LEN: int = 2 * 1024              # prompt (prefill) length in tokens
NUM_DECODE_TOKENS: int = 5             # tokens generated after prefill
BLOCK_SIZE: int = 1024                 # tokens per KV-cache block


# ============================================================
# Global capture storage
# ============================================================

# One entry per attention-hook call, in call order (reset by run_test()).
captures: list = list()
# layer_id -> list of prefill snapshots {'chunk_idx', 'k', 'v'}, used to
# rebuild the reference decode attention.
prefill_kv: dict = dict()


# ============================================================
# Hook Functions
# ============================================================

def _cpu_snapshot(tensor):
    """Return a detached copy of *tensor* on the CPU.

    ``Tensor.to(..., copy=True)`` always allocates a new tensor, so the
    snapshot never aliases the original (even if it already lives on the
    CPU). Unlike ``clone().cpu()``, it does not first allocate a temporary
    copy on the tensor's own device.
    """
    return tensor.detach().to("cpu", copy=True)


def make_hook(layer_id):
    """Create a forward hook that records one attention layer's I/O.

    The returned hook unpacks the module's positional inputs as ``(q, k, v)``
    and expects a single output tensor. Each call appends one dict to the
    module-level ``captures`` list with CPU snapshots of q, k, v and the
    output, plus ``layer_id``, ``is_prefill`` and ``is_chunked_prefill``.
    In addition:

    - prefill calls store ``chunk_idx`` (the context's ``current_chunk_idx``,
      or 0 if the context has no such attribute). The same k/v snapshots
      are also appended to ``prefill_kv[layer_id]``.
    - decode calls store ``decode_step``, the number of decode calls already
      recorded for this layer (0 for the first).

    Args:
        layer_id: Index of the decoder layer that owns the hooked module.

    Returns:
        A callable ``hook(module, inputs, output)`` for
        ``register_forward_hook``.
    """

    def hook(module, inputs, output):
        q, k, v = inputs
        ctx = get_context()
        is_prefill = ctx.is_prefill

        k_cpu = _cpu_snapshot(k)
        v_cpu = _cpu_snapshot(v)
        capture_entry = {
            'layer_id': layer_id,
            'is_prefill': is_prefill,
            'q': _cpu_snapshot(q),
            'k': k_cpu,
            'v': v_cpu,
            'output': _cpu_snapshot(output),
            'is_chunked_prefill': ctx.is_chunked_prefill,
        }

        if is_prefill:
            # Store prefill k,v for reference computation
            chunk_idx = getattr(ctx, 'current_chunk_idx', 0)
            capture_entry['chunk_idx'] = chunk_idx
            # Reuse the snapshots already taken above instead of copying k/v
            # to host memory a second time; consumers in this test only read
            # them.
            prefill_kv.setdefault(layer_id, []).append({
                'chunk_idx': chunk_idx,
                'k': k_cpu,
                'v': v_cpu,
            })
        else:
            # Decode phase: the 0-based step is the number of decode
            # captures already recorded for this layer.
            capture_entry['decode_step'] = sum(
                1 for c in captures
                if c['layer_id'] == layer_id and not c['is_prefill']
            )

        captures.append(capture_entry)

    return hook


def register_hooks(llm):
    """Attach a capture hook to the attention module of every decoder layer.

    For each index ``i`` of ``llm.model_runner.model.model.layers``, this
    registers ``make_hook(i)`` as a forward hook on
    ``layers[i].self_attn.attn``.

    Returns:
        The handles returned by ``register_forward_hook``, in layer order.
        Call ``.remove()`` on each one to detach its hook.
    """
    decoder_layers = llm.model_runner.model.model.layers
    return [
        layer.self_attn.attn.register_forward_hook(make_hook(idx))
        for idx, layer in enumerate(decoder_layers)
    ]


# ============================================================
# Reference Computation
# ============================================================

def _merge_partial(o_acc, lse_acc, o, lse):
    """Fold one partial attention result ``(o, lse)`` into the running pair.

    Returns ``(o, lse)`` as-is when nothing has been accumulated yet
    (``o_acc is None``). Otherwise it combines the two with
    ``merge_attention_outputs``.
    """
    if o_acc is None:
        return o, lse
    return merge_attention_outputs(o_acc, lse_acc, o, lse)


def compute_decode_reference(layer_id, decode_step, scale, debug=False):
    """
    Compute reference decode attention output for a specific layer.

    For decode, the query is a single token that attends to:
      1. All prefill KV (the k/v captured for each prefill chunk)
      2. All decode tokens up to and including the current one

    Each KV segment is attended to separately with ``flash_attn_with_lse``
    (non-causal). The partial results are combined with
    ``merge_attention_outputs``.

    Args:
        layer_id: Decoder layer whose captures are used.
        decode_step: 0-based index of the decode call to reproduce.
        scale: Softmax scale passed to ``flash_attn_with_lse``.
        debug: When True, print shapes and means of intermediate results.

    Returns:
        The reference output moved to the CPU, with its two leading
        singleton dims squeezed away (``[num_heads, head_dim]`` per the shape
        comments below). Returns None when ``decode_step`` is negative or
        that step was not captured for ``layer_id``.
    """
    decode_captures = [c for c in captures
                       if c['layer_id'] == layer_id and not c['is_prefill']]
    # Reject negative steps too: they would silently index from the end.
    if not 0 <= decode_step < len(decode_captures):
        return None

    q = decode_captures[decode_step]['q'].cuda()  # [1, num_heads, head_dim]
    q_batched = q.unsqueeze(1)  # [1, 1, num_heads, head_dim]

    if debug:
        print(f" Reference for L{layer_id} D{decode_step}:")
        print(f" q shape: {q_batched.shape}, mean={q_batched.mean().item():.4f}")

    o_acc, lse_acc = None, None

    # 1. Prefill KV: one segment per captured chunk, in chunk order.
    for chunk_data in sorted(prefill_kv.get(layer_id, []), key=lambda d: d['chunk_idx']):
        k = chunk_data['k'].cuda().unsqueeze(0)  # [1, seqlen, kv_heads, head_dim]
        v = chunk_data['v'].cuda().unsqueeze(0)

        o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale=scale, causal=False)

        if debug:
            print(f" Prefill chunk {chunk_data['chunk_idx']}: o.mean={o.mean().item():.6f}")

        o_acc, lse_acc = _merge_partial(o_acc, lse_acc, o, lse)

    # 2. Decode tokens 0..decode_step (the current token attends to itself).
    #    Including only this prefix already applies the causal mask, so the
    #    kernel runs with causal=False. The range check above guarantees the
    #    prefix is non-empty.
    seen = decode_captures[:decode_step + 1]
    # Concatenate on the host, then move to the GPU once.
    decode_k = torch.cat([c['k'] for c in seen], dim=0).cuda().unsqueeze(0)  # [1, num_decode, kv_heads, head_dim]
    decode_v = torch.cat([c['v'] for c in seen], dim=0).cuda().unsqueeze(0)

    if debug:
        print(f" Decode tokens: {len(seen)}, k.shape={decode_k.shape}")

    o_decode, lse_decode = flash_attn_with_lse(q_batched, decode_k, decode_v,
                                               softmax_scale=scale, causal=False)

    if debug:
        print(f" Decode attention: o.mean={o_decode.mean().item():.6f}")

    o_acc, lse_acc = _merge_partial(o_acc, lse_acc, o_decode, lse_decode)

    if debug:
        print(f" Final: o.mean={o_acc.mean().item():.6f}")

    return o_acc.squeeze(0).squeeze(0).cpu()  # [num_heads, head_dim]


# ============================================================
# Test Runner
# ============================================================

def _verify_cpu_cache(llm):
    """Print how layer 0's captured prefill K compares with the CPU KV cache.

    This is a diagnostic only: a mismatch prints a warning but does not
    affect the test result. It runs after ``generate()`` has returned, so
    decode has also run by this point.
    """
    print("\n--- CPU Cache Verification (After Prefill) ---")
    offload_engine = llm.model_runner.kvcache_manager.offload_engine

    # For each prefill capture, check if CPU cache matches
    for layer_id in [0]:  # Only check layer 0 for brevity
        for chunk_data in prefill_kv.get(layer_id, []):
            chunk_idx = chunk_data['chunk_idx']
            captured_k = chunk_data['k']  # [chunk_len, kv_heads, head_dim]

            # CPU block ID should be chunk_idx (based on allocation order).
            # NOTE(review): assumption carried over from the original test.
            cpu_block_id = chunk_idx
            # Compare only the slots this chunk wrote. A final partial chunk
            # fills just its first chunk_len rows, and slicing keeps the
            # subtraction shape-compatible. Assumes k_cache_cpu[layer, block]
            # is [block_size, kv_heads, head_dim] -- TODO confirm.
            num_tokens = captured_k.shape[0]
            cpu_k = offload_engine.k_cache_cpu[layer_id, cpu_block_id][:num_tokens].cpu()

            diff = (captured_k - cpu_k).abs().max().item()
            print(f"Layer {layer_id}, Chunk {chunk_idx}: captured_k vs cpu_k max_diff={diff:.6f}")
            if diff > 1e-3:
                print(" WARNING: CPU cache doesn't match captured k!")
                print(f" captured_k[0,0,:5] = {captured_k[0,0,:5].tolist()}")
                print(f" cpu_k[0,0,:5] = {cpu_k[0,0,:5].tolist()}")
    print()


def _check_decode_outputs(scale, verbose, tol=1e-2):
    """Compare each layer's first decode-step output with the reference.

    Returns:
        A list of ``(layer_id, decode_step, passed, max_diff, mean_diff)``,
        one entry per comparison actually performed.
    """
    results = []
    first_fail_debug = True

    for c in captures:
        if c['is_prefill']:
            continue  # Skip prefill (already tested in test_chunked_prefill_hook.py)

        layer_id = c['layer_id']
        decode_step = c['decode_step']

        # Only test first decode step for now (simpler reference computation)
        if decode_step > 0:
            continue

        # Trace the reference computation for layer 0 (only when verbose,
        # since the trace is printed output).
        debug_this = verbose and layer_id == 0 and first_fail_debug
        ref_output = compute_decode_reference(layer_id, decode_step, scale, debug=debug_this)
        if ref_output is None:
            continue

        # Bring the captured output to the reference layout
        # ([num_heads, head_dim]). This works whether the module returned
        # [1, 1, H, D], [1, H, D] or a flattened [1, H*D]. A size mismatch
        # raises instead of silently broadcasting.
        actual_output = c['output'].reshape(ref_output.shape)

        diff = (actual_output - ref_output).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()
        passed = max_diff < tol
        results.append((layer_id, decode_step, passed, max_diff, mean_diff))

        if not verbose:
            continue

        status = "PASS" if passed else "FAIL"
        print(f"[{status}] Layer {layer_id:2d}, Decode {decode_step}: "
              f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")

        # Debug first failure
        if not passed and first_fail_debug:
            first_fail_debug = False
            print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
            print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}")
            # Find where max diff is (argmax returns a flat index)
            max_idx = diff.argmax()
            flat_actual = actual_output.flatten()
            flat_ref = ref_output.flatten()
            print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")

    return results


def _print_summary(results, all_passed):
    """Print the pass count and list any failing (layer, decode step) pairs."""
    print()
    print("=" * 70)

    passed_count = sum(1 for r in results if r[2])
    print(f"Results: {passed_count}/{len(results)} tests passed")

    if not results:
        print("\nNo decode outputs were compared (no decode captures) - treating as FAILED.")
    elif not all_passed:
        print("\nFailed tests:")
        for layer_id, decode_step, passed, max_diff, _mean_diff in results:
            if not passed:
                print(f" - Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")

    print()


def run_test(verbose=True):
    """Run the hook-based chunked decode correctness test.

    Builds an ``LLM`` with CPU offload and registers capture hooks on every
    attention module. It then generates ``NUM_DECODE_TOKENS`` tokens from a
    seeded random ``INPUT_LEN``-token prompt and compares each layer's
    first decode-step output with ``compute_decode_reference``.

    Args:
        verbose: Print configuration, per-layer results and diagnostics.

    Returns:
        True only if at least one decode output was compared and every
        comparison was within tolerance.
    """
    from collections import Counter

    global captures, prefill_kv
    captures = []
    prefill_kv = {}

    if verbose:
        print("=" * 70)
        print("Test: Hook-Based Chunked Decode Correctness")
        print("=" * 70)
        print(f"Model: {MODEL_PATH}")
        print(f"Input length: {INPUT_LEN} tokens")
        print(f"Decode tokens: {NUM_DECODE_TOKENS}")
        print(f"Block size: {BLOCK_SIZE}")
        print()

    # Initialize LLM with CPU offload
    llm = LLM(
        MODEL_PATH,
        enforce_eager=True,
        max_model_len=MAX_MODEL_LEN,
        max_num_batched_tokens=MAX_MODEL_LEN,
        enable_cpu_offload=True,
        kvcache_block_size=BLOCK_SIZE,
        num_gpu_blocks=NUM_GPU_BLOCKS,
    )

    # Get model info
    layers = llm.model_runner.model.model.layers
    num_layers = len(layers)
    head_dim = layers[0].self_attn.attn.head_dim
    scale = head_dim ** -0.5

    if verbose:
        print(f"Num layers: {num_layers}")
        print(f"Head dim: {head_dim}")
        print()

    # Register hooks
    hooks = register_hooks(llm)
    if verbose:
        print(f"Registered {len(hooks)} hooks")

    # Generate random prompt (seeded for reproducibility)
    seed(42)
    prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]

    # Run prefill and decode
    if verbose:
        print(f"Running inference with {NUM_DECODE_TOKENS} decode tokens...")
    sampling_params = SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS)
    try:
        llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
    finally:
        # Detach the hooks even if generation fails.
        for hook in hooks:
            hook.remove()

    if verbose:
        _verify_cpu_cache(llm)

        # Analyze captures
        prefill_count = sum(1 for c in captures if c['is_prefill'])
        decode_count = len(captures) - prefill_count
        print(f"\nCaptured {prefill_count} prefill calls, {decode_count} decode calls")

        decode_per_layer = dict(Counter(c['layer_id'] for c in captures if not c['is_prefill']))
        print(f"Decode calls per layer: {decode_per_layer}")
        print()

    # Verify decode correctness
    results = _check_decode_outputs(scale, verbose)
    # An empty comparison set must not count as a pass: it means no decode
    # output was verified at all (e.g. hooks never fired during decode).
    all_passed = bool(results) and all(r[2] for r in results)

    _print_summary(results, all_passed)
    return all_passed


# ============================================================
# Main
# ============================================================

if __name__ == "__main__":
    passed = run_test(verbose=True)

    if passed:
        print("test_chunked_decode_hook: PASSED")
    else:
        print("test_chunked_decode_hook: FAILED")
        # Raise SystemExit directly: the exit() builtin is injected by the
        # `site` module and is not guaranteed to exist (e.g. `python -S`).
        raise SystemExit(1)