# nano-vllm/tests/test_chunked_prefill_hook.py
"""
Hook-based correctness test for chunked prefill attention.
Uses PyTorch register_forward_hook() to capture real inference I/O,
then compares against reference computation to locate bugs.
This test targets the integration layer (context setup, cpu_block_table management)
which is where the needle test fails despite isolated attention tests passing.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# ============================================================
# Configuration
# ============================================================
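# With 1K-token blocks and only 2 GPU blocks for a 16K-token prompt, earlier KV
# blocks should be offloaded to CPU during prefill, exercising the chunked path.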
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024  # 16K tokens = 16 chunks with 1K block size
BLOCK_SIZE = 1024
# ============================================================
# Global capture storage
# ============================================================
captures = []
# ============================================================
# Hook Functions
# ============================================================
def make_hook(layer_id):
"""Create a forward hook for a specific layer."""
def hook(module, inputs, output):
q, k, v = inputs
ctx = get_context()
# Only capture prefill phase
if not ctx.is_prefill:
return
        chunk_idx = getattr(ctx, 'current_chunk_idx', 0)
capture_entry = {
'layer_id': layer_id,
'chunk_idx': chunk_idx,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
'is_chunked_prefill': ctx.is_chunked_prefill,
}
# For debugging: also capture CPU cache state for layer 0
if layer_id == 0 and chunk_idx >= 2:
            kvcache_manager = getattr(ctx, 'kvcache_manager', None)
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
oe = kvcache_manager.offload_engine
# Get what should have been loaded from CPU
cpu_k0 = oe.k_cache_cpu[0, 0].clone().cpu() # Layer 0, CPU block 0
cpu_k1 = oe.k_cache_cpu[0, 1].clone().cpu() # Layer 0, CPU block 1
capture_entry['cpu_k0'] = cpu_k0
capture_entry['cpu_k1'] = cpu_k1
captures.append(capture_entry)
return hook
def register_hooks(llm):
"""Register forward hooks on all Attention modules."""
hooks = []
model = llm.model_runner.model
for layer_idx, decoder_layer in enumerate(model.model.layers):
attn_module = decoder_layer.self_attn.attn
hook = attn_module.register_forward_hook(make_hook(layer_idx))
hooks.append(hook)
return hooks
# ============================================================
# Reference Computation
# ============================================================
def compute_reference(layer_id, chunk_idx, scale, debug=False):
"""
Compute reference attention output for a specific layer and chunk.
Uses the captured k, v from all chunks up to and including chunk_idx.
"""
# Filter captures for this layer
layer_captures = [c for c in captures
if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
if not layer_captures:
return None
# Get current chunk's q
    current_capture = next(c for c in layer_captures if c['chunk_idx'] == chunk_idx)
q = current_capture['q'].cuda().unsqueeze(0) # [1, seqlen, nheads, headdim]
# Collect all k, v up to current chunk
kv_list = []
for c in sorted(layer_captures, key=lambda x: x['chunk_idx']):
k = c['k'].cuda().unsqueeze(0) # [1, seqlen, nheads, headdim]
v = c['v'].cuda().unsqueeze(0)
kv_list.append((k, v, c['chunk_idx']))
if debug:
print(f" Reference for L{layer_id} C{chunk_idx}:")
print(f" q shape: {q.shape}, mean={q.mean().item():.4f}")
print(f" kv_list: {len(kv_list)} chunks")
for i, (k, v, cidx) in enumerate(kv_list):
print(f" chunk {cidx}: k.mean={k.mean().item():.4f}, v.mean={v.mean().item():.4f}")
o_acc, lse_acc = None, None
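    # Mirror the chunked algorithm: attend to each earlier chunk without a causal
    # mask, then fold the partial outputs together using their log-sum-exp stats.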
# Previous chunks: non-causal attention
for i in range(len(kv_list) - 1):
k, v, _ = kv_list[i]
o, lse = flash_attn_with_lse(q, k, v, softmax_scale=scale, causal=False)
if o_acc is None:
o_acc, lse_acc = o, lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)
# Current chunk: causal attention
k_cur, v_cur, _ = kv_list[-1]
o_cur, lse_cur = flash_attn_with_lse(q, k_cur, v_cur, softmax_scale=scale, causal=True)
if o_acc is None:
return o_cur.squeeze(0).cpu()
final_o, _ = merge_attention_outputs(o_acc, lse_acc, o_cur, lse_cur)
return final_o.squeeze(0).cpu()
def compute_standard_reference(layer_id, chunk_idx, scale, debug=False):
"""
Compute reference using standard flash attention (single pass with all K, V).
This simulates what standard (non-chunked) prefill would produce.
Concatenates all Q, K, V from chunks 0 to chunk_idx and runs a single
causal attention pass, then extracts the output for the current chunk.
"""
# Filter captures for this layer
layer_captures = [c for c in captures
if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
if not layer_captures:
return None
# Sort by chunk index
layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
# Concatenate all Q, K, V
all_q = []
all_k = []
all_v = []
chunk_lengths = []
for c in layer_captures:
q = c['q'].cuda() # [seqlen, nheads, headdim]
k = c['k'].cuda()
v = c['v'].cuda()
all_q.append(q)
all_k.append(k)
all_v.append(v)
chunk_lengths.append(q.shape[0])
# Concatenate along sequence dimension
full_q = torch.cat(all_q, dim=0) # [total_seqlen, nheads, headdim]
full_k = torch.cat(all_k, dim=0)
full_v = torch.cat(all_v, dim=0)
total_len = full_q.shape[0]
if debug:
print(f" Standard Reference for L{layer_id} C{chunk_idx}:")
print(f" full_q shape: {full_q.shape}, mean={full_q.mean().item():.4f}")
print(f" full_k shape: {full_k.shape}, mean={full_k.mean().item():.4f}")
print(f" chunk_lengths: {chunk_lengths}")
# Run standard causal flash attention
# flash_attn_varlen_func expects: q, k, v with shape [total_seqlen, nheads, headdim]
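    # For a single packed sequence, cu_seqlens is simply [0, total_len], which makes
    # the varlen call equivalent to dense causal attention over the concatenated prompt.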
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
full_o = flash_attn_varlen_func(
full_q, full_k, full_v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_len,
max_seqlen_k=total_len,
softmax_scale=scale,
causal=True,
)
# Extract output for current chunk only
start_pos = sum(chunk_lengths[:-1])
end_pos = sum(chunk_lengths)
chunk_output = full_o[start_pos:end_pos]
if debug:
print(f" full_o shape: {full_o.shape}")
print(f" extracting positions [{start_pos}:{end_pos}]")
print(f" chunk_output shape: {chunk_output.shape}, mean={chunk_output.mean().item():.4f}")
return chunk_output.cpu()
# ============================================================
# Test Runner
# ============================================================
def run_test(verbose=True):
"""Run the hook-based chunked prefill correctness test."""
global captures
captures = []
if verbose:
print("=" * 70)
print("Test: Hook-Based Chunked Prefill Correctness")
print("=" * 70)
print(f"Model: {MODEL_PATH}")
print(f"Input length: {INPUT_LEN} tokens")
print(f"Block size: {BLOCK_SIZE}")
print(f"Expected chunks: {INPUT_LEN // BLOCK_SIZE}")
print()
# Initialize LLM with CPU offload
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
)
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
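    # The reference computations must use the same softmax scale as the model;
    # the standard 1/sqrt(head_dim) scaling is assumed here.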
scale = head_dim ** -0.5
if verbose:
print(f"Num layers: {num_layers}")
print(f"Head dim: {head_dim}")
print()
# Register hooks
hooks = register_hooks(llm)
if verbose:
print(f"Registered {len(hooks)} hooks")
# Generate random prompt
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
# Run prefill only (max_tokens=1)
if verbose:
print("Running inference...")
sampling_params = SamplingParams(temperature=0.6, max_tokens=1)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
# Remove hooks
for hook in hooks:
hook.remove()
# Analyze captures
if verbose:
print(f"\nCaptured {len(captures)} attention calls")
# Group by layer and chunk
chunks_per_layer = {}
for c in captures:
layer_id = c['layer_id']
chunk_idx = c['chunk_idx']
if layer_id not in chunks_per_layer:
chunks_per_layer[layer_id] = set()
chunks_per_layer[layer_id].add(chunk_idx)
if verbose:
print("Chunks per layer:", {k: sorted(v) for k, v in chunks_per_layer.items()})
print()
# First, verify CPU cache data integrity
if verbose:
print("\n--- CPU Cache Verification (Layer 0) ---")
# Get original k from chunk 0 and chunk 1 captures
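    # If the offloaded copies differ from what the hooks captured, the bug is in
    # the CPU offload/copy path rather than in the attention math itself.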
chunk0_k = None
chunk1_k = None
chunk2_capture = None
for c in captures:
if c['layer_id'] == 0:
if c['chunk_idx'] == 0:
chunk0_k = c['k']
elif c['chunk_idx'] == 1:
chunk1_k = c['k']
elif c['chunk_idx'] == 2:
chunk2_capture = c
if chunk0_k is not None and chunk2_capture is not None and 'cpu_k0' in chunk2_capture:
cpu_k0 = chunk2_capture['cpu_k0']
diff_k0 = (chunk0_k - cpu_k0).abs().max().item()
print(f"Chunk 0 k vs CPU cache block 0: max_diff={diff_k0:.6f}")
if diff_k0 > 1e-3:
print(f" WARNING: CPU cache block 0 differs from original chunk 0 k!")
print(f" Original k[0,0,:5] = {chunk0_k[0,0,:5].tolist()}")
print(f" CPU k0[0,0,:5] = {cpu_k0[0,0,:5].tolist()}")
if chunk1_k is not None and chunk2_capture is not None and 'cpu_k1' in chunk2_capture:
cpu_k1 = chunk2_capture['cpu_k1']
diff_k1 = (chunk1_k - cpu_k1).abs().max().item()
print(f"Chunk 1 k vs CPU cache block 1: max_diff={diff_k1:.6f}")
if diff_k1 > 1e-3:
print(f" WARNING: CPU cache block 1 differs from original chunk 1 k!")
print(f" Original k[0,0,:5] = {chunk1_k[0,0,:5].tolist()}")
print(f" CPU k1[0,0,:5] = {cpu_k1[0,0,:5].tolist()}")
print()
# ================================================================
# Test 1: Verify against merge-based reference (same algorithm)
# ================================================================
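    # Test 1 rebuilds the output with the same merge algorithm the kernel uses, so
    # it isolates data-plumbing bugs (wrong K/V fed in) from merge-math bugs; Test 2
    # below compares against a single-pass reference instead.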
if verbose:
print("--- Test 1: Merge-based Reference (verifies merge algorithm) ---")
all_passed_merge = True
results_merge = []
first_fail_debug = True
for c in captures:
layer_id = c['layer_id']
chunk_idx = c['chunk_idx']
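        # Chunk 0 has no previous chunks to merge against, so it is skipped here.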
if chunk_idx == 0:
continue
debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug)
ref_output = compute_reference(layer_id, chunk_idx, scale, debug=debug_this)
if ref_output is None:
continue
actual_output = c['output']
diff = (actual_output - ref_output).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
tol = 1e-2
passed = max_diff < tol
all_passed_merge = all_passed_merge and passed
status = "PASS" if passed else "FAIL"
results_merge.append((layer_id, chunk_idx, passed, max_diff, mean_diff))
if verbose:
print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: "
f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
if not passed and first_fail_debug:
first_fail_debug = False
print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}")
max_idx = diff.argmax()
flat_actual = actual_output.flatten()
flat_ref = ref_output.flatten()
print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")
print()
# ================================================================
# Test 2: Verify against standard flash attention (single pass)
# ================================================================
if verbose:
print("--- Test 2: Standard FlashAttn Reference (verifies correctness vs non-chunked) ---")
all_passed_standard = True
results_standard = []
first_fail_debug = True
for c in captures:
layer_id = c['layer_id']
chunk_idx = c['chunk_idx']
if chunk_idx == 0:
continue
debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug)
std_ref_output = compute_standard_reference(layer_id, chunk_idx, scale, debug=debug_this)
if std_ref_output is None:
continue
actual_output = c['output']
diff = (actual_output - std_ref_output).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
tol = 1e-2
passed = max_diff < tol
all_passed_standard = all_passed_standard and passed
status = "PASS" if passed else "FAIL"
results_standard.append((layer_id, chunk_idx, passed, max_diff, mean_diff))
if verbose:
print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: "
f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
if not passed and first_fail_debug:
first_fail_debug = False
print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
print(f" Debug: std_ref_output shape={std_ref_output.shape}, mean={std_ref_output.mean().item():.4f}")
max_idx = diff.argmax()
flat_actual = actual_output.flatten()
flat_ref = std_ref_output.flatten()
print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")
print()
print("=" * 70)
# Summary
total_merge = len(results_merge)
passed_merge = sum(1 for r in results_merge if r[2])
total_standard = len(results_standard)
passed_standard = sum(1 for r in results_standard if r[2])
print(f"Merge-based reference: {passed_merge}/{total_merge} tests passed")
print(f"Standard FlashAttn ref: {passed_standard}/{total_standard} tests passed")
all_passed = all_passed_merge and all_passed_standard
if not all_passed_merge:
print("\nFailed merge-based tests:")
for layer_id, chunk_idx, passed, max_diff, mean_diff in results_merge:
if not passed:
print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
if not all_passed_standard:
print("\nFailed standard FlashAttn tests:")
for layer_id, chunk_idx, passed, max_diff, mean_diff in results_standard:
if not passed:
print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
print()
return all_passed
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
passed = run_test(verbose=True)
if passed:
print("test_chunked_prefill_hook: PASSED")
else:
print("test_chunked_prefill_hook: FAILED")
exit(1)