[refactor] Refactor the test_chunked_prefill/decode.

Zijie Tian
2026-01-01 03:32:26 +08:00
parent 965c8aff12
commit 62b8a63314
2 changed files with 294 additions and 731 deletions

View File

@@ -1,374 +1,214 @@
"""
Correctness test for chunked decode attention.

Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import torch
from random import randint, seed
from typing import Dict, List

from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_func

# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
NUM_DECODE_TOKENS = 5
BLOCK_SIZE = 1024

# State
prefill_captures: List[Dict] = []
decode_captures: List[Dict] = []


def make_ones_injection_hook():
    """Inject Q=K=V=1.0 for deterministic testing."""
    def hook(module, inputs):
        q, k, v = inputs[0], inputs[1], inputs[2]
        q_ones = torch.ones_like(q)
        k_ones = torch.ones_like(k)
        v_ones = torch.ones_like(v)
        return (q_ones, k_ones, v_ones) + inputs[3:]
    return hook


def make_capture_hook(layer_id: int):
    """Capture Q, K, V, output during inference."""
    def hook(module, inputs, output):
        ctx = get_context()
        q, k, v = inputs
        if ctx.is_prefill:
            chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
            prefill_captures.append({
                'layer_id': layer_id,
                'chunk_idx': chunk_idx,
                'q': q.clone().cpu(),
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
                'output': output.clone().cpu(),
            })
        else:
            decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id])
            decode_captures.append({
                'layer_id': layer_id,
                'decode_step': decode_step,
                'q': q.clone().cpu(),
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
                'output': output.clone().cpu(),
            })
    return hook


def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
                             k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                             block_size: int, num_prefill_chunks: int) -> torch.Tensor:
    """
    Compute reference decode output using CPU KV cache and standard flash attention.

    For decode, the query attends to:
    1. All prefill KV (from the CPU cache)
    2. All previous decode tokens (from the captured decode k, v)
    """
    # Get the decode capture for this layer and step
    decode_cap = None
    for c in decode_captures:
        if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
            decode_cap = c
            break
    if decode_cap is None:
        return None

    # Query: single decode token
    q = decode_cap['q'].cuda()   # [1, num_heads, head_dim]
    q_batched = q.unsqueeze(0)   # [1, 1, num_heads, head_dim]

    # Collect all K, V: prefill chunks from the CPU cache + decode tokens from captures
    all_k = []
    all_v = []

    # 1. Prefill chunks from the CPU cache
    for cidx in range(num_prefill_chunks):
        # Use the prefill capture to know the sequence length of this chunk
        prefill_cap = None
        for c in prefill_captures:
            if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
                prefill_cap = c
                break
        if prefill_cap is not None:
            seq_len = prefill_cap['q'].shape[0]
            k = k_cache_cpu[layer_id, cidx, :seq_len].cuda()
            v = v_cache_cpu[layer_id, cidx, :seq_len].cuda()
            all_k.append(k)
            all_v.append(v)

    # 2. Decode tokens from captures (up to and including the current step)
    for step in range(decode_step + 1):
        for c in decode_captures:
            if c['layer_id'] == layer_id and c['decode_step'] == step:
                all_k.append(c['k'].cuda())
                all_v.append(c['v'].cuda())
                break

    if not all_k:
        return None

    # Concatenate all K, V
    full_k = torch.cat(all_k, dim=0).unsqueeze(0)   # [1, total_len, kv_heads, head_dim]
    full_v = torch.cat(all_v, dim=0).unsqueeze(0)

    # Run flash attention (non-causal, since we explicitly control which KV to include)
    output = flash_attn_func(
        q_batched, full_k, full_v,
        softmax_scale=scale,
        causal=False,
    )

    return output.squeeze(0).squeeze(0).cpu()   # [num_heads, head_dim]


# ============================================================
# Main
# ============================================================
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5

# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook: inject all ones for Q, K, V
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)

# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False)

# Remove hooks
for hook in hooks:
    hook.remove()

# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()

# Number of prefill chunks
num_prefill_chunks = INPUT_LEN // BLOCK_SIZE

# Verify decode outputs
all_passed = True
for c in decode_captures:
    layer_id = c['layer_id']
    decode_step = c['decode_step']

    ref_output = compute_decode_reference(
        layer_id, decode_step, scale,
        k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
    )
    if ref_output is None:
        continue

    actual_output = c['output'].squeeze(0)
    if actual_output.dim() == 3:
        actual_output = actual_output.squeeze(0)

    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()
    passed = max_diff < 1e-1   # float16 tolerance
    all_passed = all_passed and passed
    if not passed:
        print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")

View File

@@ -1,203 +1,111 @@
"""
Correctness test for chunked prefill attention.

Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import torch
from random import randint, seed
from typing import Dict, List

from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_varlen_func

# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
BLOCK_SIZE = 1024

# State - capture Q and output for each (layer, chunk)
captures: List[Dict] = []


def make_ones_injection_hook():
    """Inject Q=K=V=1.0 for deterministic testing."""
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return inputs
        q, k, v = inputs[0], inputs[1], inputs[2]
        q_ones = torch.ones_like(q)
        k_ones = torch.ones_like(k)
        v_ones = torch.ones_like(v)
        return (q_ones, k_ones, v_ones) + inputs[3:]
    return hook


def make_capture_hook(layer_id: int):
    """Capture Q and output during prefill."""
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return
        q, k, v = inputs
        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        captures.append({
            'layer_id': layer_id,
            'chunk_idx': chunk_idx,
            'q': q.clone().cpu(),
            'k': k.clone().cpu(),
            'v': v.clone().cpu(),
            'output': output.clone().cpu(),
        })
    return hook


def compute_reference(layer_id: int, chunk_idx: int, scale: float,
                      k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                      block_size: int) -> torch.Tensor:
    """
    Compute reference output using CPU KV cache and standard flash attention.

    Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention,
    then extracts the output for the current chunk.
    """
    # Get all captures for this layer up to chunk_idx
    layer_captures = [c for c in captures
                      if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
    layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
    if not layer_captures:
        return None

    # Collect Q from captures, K/V from the CPU cache
    all_q = []
    all_k = []
    all_v = []
    chunk_lengths = []
    for c in layer_captures:
        cidx = c['chunk_idx']
        q = c['q'].cuda()   # [seqlen, nheads, headdim]
        all_q.append(q)
        chunk_lengths.append(q.shape[0])

        # Get K, V from the CPU cache (already offloaded during prefill)
        # CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
        k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        all_k.append(k)
        all_v.append(v)

    # Concatenate
    full_q = torch.cat(all_q, dim=0)
    full_k = torch.cat(all_k, dim=0)
    full_v = torch.cat(all_v, dim=0)
    total_len = full_q.shape[0]

    # Run standard causal flash attention
    cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
    full_o = flash_attn_varlen_func(
        full_q, full_k, full_v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=total_len,
        max_seqlen_k=total_len,
        softmax_scale=scale,
        causal=True,
    )

    # Extract the output for the current chunk
    start_pos = sum(chunk_lengths[:-1])
    end_pos = sum(chunk_lengths)
    return full_o[start_pos:end_pos].cpu()


# ============================================================
# Main
# ============================================================
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5

# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook: inject all ones for Q, K, V
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)

# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)

# Remove hooks
for hook in hooks:
    hook.remove()

# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()

# Verify: compare actual output with the reference computed from the CPU cache
all_passed = True
num_chunks = INPUT_LEN // BLOCK_SIZE
for idx, c in enumerate(captures):
    layer_id = c['layer_id']
    chunk_idx = c['chunk_idx']

    # Skip chunk 0 (no previous KV to load)
    if chunk_idx == 0:
        continue

    ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
    if ref_output is None:
        continue

    actual_output = c['output']
    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()
    passed = max_diff < 1e-1   # float16 tolerance
    all_passed = all_passed and passed
    if not passed:
        print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")