[refactor] Refactor the test_chunked_prefill/decode.

Zijie Tian
2026-01-01 03:32:26 +08:00
parent 965c8aff12
commit 62b8a63314
2 changed files with 294 additions and 731 deletions

View File

@@ -1,208 +1,146 @@
"""
Correctness test for chunked decode attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import torch
from random import randint, seed
from typing import Dict, List

from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_func

# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
NUM_DECODE_TOKENS = 5
BLOCK_SIZE = 1024

# State
prefill_captures: List[Dict] = []
decode_captures: List[Dict] = []


def make_ones_injection_hook():
    """Inject Q=K=V=1.0 for deterministic testing."""
    def hook(module, inputs):
        q, k, v = inputs[0], inputs[1], inputs[2]
        q_ones = torch.ones_like(q)
        k_ones = torch.ones_like(k)
        v_ones = torch.ones_like(v)
        return (q_ones, k_ones, v_ones) + inputs[3:]
    return hook


def make_capture_hook(layer_id: int):
    """Capture Q, K, V, output during inference."""
    def hook(module, inputs, output):
        ctx = get_context()
        q, k, v = inputs
        if ctx.is_prefill:
            chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
            prefill_captures.append({
                'layer_id': layer_id,
                'chunk_idx': chunk_idx,
                'q': q.clone().cpu(),
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
                'output': output.clone().cpu(),
            })
        else:
            decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id])
            decode_captures.append({
                'layer_id': layer_id,
                'decode_step': decode_step,
                'q': q.clone().cpu(),
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
                'output': output.clone().cpu(),
            })
    return hook


def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
                             k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                             block_size: int, num_prefill_chunks: int) -> torch.Tensor:
    """
    Compute reference decode output using CPU KV cache and standard flash attention.
    For decode, the query attends to:
    1. All prefill KV (from CPU cache)
    2. All previous decode tokens (from captured decode k, v)
    """
    # Get decode capture for this layer and step
    decode_cap = None
    for c in decode_captures:
        if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
            decode_cap = c
            break
    if decode_cap is None:
        return None

    # Query: single decode token
    q = decode_cap['q'].cuda()   # [1, num_heads, head_dim]
    q_batched = q.unsqueeze(0)   # [1, 1, num_heads, head_dim]

    # Collect all K, V: prefill chunks from CPU cache + decode tokens from captures
    all_k = []
    all_v = []

    # 1. Prefill chunks from CPU cache
    for cidx in range(num_prefill_chunks):
        # Use the prefill capture to know the sequence length of this chunk
        prefill_cap = None
        for c in prefill_captures:
            if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
                prefill_cap = c
                break
        if prefill_cap is not None:
            seq_len = prefill_cap['q'].shape[0]
            k = k_cache_cpu[layer_id, cidx, :seq_len].cuda()
            v = v_cache_cpu[layer_id, cidx, :seq_len].cuda()
            all_k.append(k)
            all_v.append(v)

    # 2. Decode tokens from captures (up to and including the current step)
    for step in range(decode_step + 1):
        for c in decode_captures:
            if c['layer_id'] == layer_id and c['decode_step'] == step:
                all_k.append(c['k'].cuda())
                all_v.append(c['v'].cuda())
                break

    if not all_k:
        return None

    # Concatenate all K, V
    full_k = torch.cat(all_k, dim=0).unsqueeze(0)  # [1, total_len, kv_heads, head_dim]
    full_v = torch.cat(all_v, dim=0).unsqueeze(0)

    # Run flash attention (non-causal since we explicitly control which KV to include)
    output = flash_attn_func(
        q_batched, full_k, full_v,
        softmax_scale=scale,
        causal=False,
    )

    return output.squeeze(0).squeeze(0).cpu()  # [num_heads, head_dim]


# ============================================================
# Main
# ============================================================
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
@@ -211,6 +149,7 @@ def run_test(verbose=True):
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

# Get model info
@@ -218,157 +157,58 @@ def run_test(verbose=True):
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5

# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook (optional): inject all ones for Q, K, V
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)

# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False)

# Remove hooks
for hook in hooks:
    hook.remove()

# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()

# Calculate number of prefill chunks
num_prefill_chunks = INPUT_LEN // BLOCK_SIZE

# Verify decode outputs
all_passed = True
for c in decode_captures:
    layer_id = c['layer_id']
    decode_step = c['decode_step']
    ref_output = compute_decode_reference(
        layer_id, decode_step, scale,
        k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
    )
    if ref_output is None:
        continue
    actual_output = c['output'].squeeze(0)
    if actual_output.dim() == 3:
        actual_output = actual_output.squeeze(0)
    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()
    passed = max_diff < 1e-1  # float16 tolerance
    all_passed = all_passed and passed
    if not passed:
        print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")

print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")
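Both versions of the decode reference hinge on the same fact: attention over a KV set split into chunks can be computed per chunk and then combined exactly using each chunk's log-sum-exp (LSE). The pre-refactor test did this explicitly via flash_attn_with_lse and merge_attention_outputs; the refactored test avoids it by concatenating all KV and calling flash_attn_func once. Below is a minimal plain-PyTorch sketch of the merge identity; the helper names are illustrative, not nanovllm's implementation.

import torch

def naive_attn_with_lse(q, k, v, scale):
    # q: [B, Sq, H, D], k/v: [B, Sk, H, D] (no GQA, no masking; illustration only)
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
    lse = torch.logsumexp(scores, dim=-1)                       # [B, H, Sq]
    out = torch.einsum("bhqk,bkhd->bqhd", torch.softmax(scores, dim=-1), v)
    return out, lse

def merge_with_lse(o1, lse1, o2, lse2):
    # Exact combination of two partial attention outputs: each chunk is weighted
    # by its softmax normalizer exp(lse_i), computed stably via the max trick.
    lse_max = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - lse_max)
    w2 = torch.exp(lse2 - lse_max)
    denom = w1 + w2
    w1 = (w1 / denom).transpose(1, 2).unsqueeze(-1)             # [B, Sq, H, 1]
    w2 = (w2 / denom).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2

torch.manual_seed(0)
q = torch.randn(1, 4, 2, 8)
k = torch.randn(1, 10, 2, 8)
v = torch.randn(1, 10, 2, 8)
scale = 8 ** -0.5
o_full, _ = naive_attn_with_lse(q, k, v, scale)                 # single pass over all 10 keys
o1, lse1 = naive_attn_with_lse(q, k[:, :6], v[:, :6], scale)    # chunk 0
o2, lse2 = naive_attn_with_lse(q, k[:, 6:], v[:, 6:], scale)    # chunk 1
assert torch.allclose(merge_with_lse(o1, lse1, o2, lse2), o_full, atol=1e-5)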

View File

@@ -1,203 +1,111 @@
"""
Correctness test for chunked prefill attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import torch
from random import randint, seed
from typing import Dict, List

from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_varlen_func

# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
BLOCK_SIZE = 1024

# State - capture Q and output for each (layer, chunk)
captures: List[Dict] = []


def make_ones_injection_hook():
    """Inject Q=K=V=1.0 for deterministic testing."""
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return inputs
        q, k, v = inputs[0], inputs[1], inputs[2]
        q_ones = torch.ones_like(q)
        k_ones = torch.ones_like(k)
        v_ones = torch.ones_like(v)
        return (q_ones, k_ones, v_ones) + inputs[3:]
    return hook


def make_capture_hook(layer_id: int):
    """Capture Q, K, V, output during prefill."""
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return
        q, k, v = inputs
        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        captures.append({
            'layer_id': layer_id,
            'chunk_idx': chunk_idx,
            'q': q.clone().cpu(),
            'k': k.clone().cpu(),
            'v': v.clone().cpu(),
            'output': output.clone().cpu(),
        })
    return hook


def compute_reference(layer_id: int, chunk_idx: int, scale: float,
                      k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                      block_size: int) -> torch.Tensor:
    """
    Compute reference output using CPU KV cache and standard flash attention.
    Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention,
    then extracts output for the current chunk.
    """
    # Get all captures for this layer up to chunk_idx
    layer_captures = [c for c in captures
                      if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
    layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
    if not layer_captures:
        return None

    # Collect Q from captures, K/V from CPU cache
    all_q = []
    all_k = []
    all_v = []
    chunk_lengths = []
    for c in layer_captures:
        cidx = c['chunk_idx']
        q = c['q'].cuda()  # [seqlen, nheads, headdim]
        all_q.append(q)
        chunk_lengths.append(q.shape[0])
        # Get K, V from CPU cache (already offloaded during prefill)
        # CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
        k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        all_k.append(k)
        all_v.append(v)

    # Concatenate
    full_q = torch.cat(all_q, dim=0)
    full_k = torch.cat(all_k, dim=0)
    full_v = torch.cat(all_v, dim=0)
    total_len = full_q.shape[0]

    # Run standard causal flash attention
    cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
    full_o = flash_attn_varlen_func(
        full_q, full_k, full_v,
        cu_seqlens_q=cu_seqlens,
@@ -208,39 +116,16 @@ def compute_standard_reference(layer_id, chunk_idx, scale, debug=False):
        causal=True,
    )

    # Extract output for current chunk
    start_pos = sum(chunk_lengths[:-1])
    end_pos = sum(chunk_lengths)
    return full_o[start_pos:end_pos].cpu()


# ============================================================
# Main
# ============================================================
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
@@ -249,6 +134,7 @@ def run_test(verbose=True):
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

# Get model info
@@ -256,218 +142,55 @@ def run_test(verbose=True):
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5

# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook (optional): inject all ones for Q, K, V
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)

# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)

# Remove hooks
for hook in hooks:
    hook.remove()

# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()

# Verify: compare actual output with reference computed from CPU cache
all_passed = True
num_chunks = INPUT_LEN // BLOCK_SIZE

for c in captures:
    layer_id = c['layer_id']
    chunk_idx = c['chunk_idx']
    # Skip chunk 0 (no previous KV to load)
    if chunk_idx == 0:
        continue
    ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
    if ref_output is None:
        continue
    actual_output = c['output']
    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()
    passed = max_diff < 1e-1  # float16 tolerance
    all_passed = all_passed and passed
    if not passed:
        print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
        __import__('pdb').set_trace()

print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")
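The refactored prefill reference relies on one further identity: for a given chunk, running causal attention with that chunk's queries against the concatenated keys/values of chunks 0..chunk_idx yields exactly the rows obtained by slicing the full-sequence causal output, because the causal mask already hides every later position from those queries. A small plain-PyTorch check of this slicing identity with toy sizes, using a naive masked softmax in place of flash-attn; the helper name is illustrative, not part of nanovllm.

import torch

def causal_attn(q, k, v, scale):
    # q: [Sq, H, D] are the *last* Sq positions of the Sk keys in k/v: [Sk, H, D]
    Sq, Sk = q.shape[0], k.shape[0]
    scores = torch.einsum("qhd,khd->hqk", q, k) * scale
    q_pos = torch.arange(Sk - Sq, Sk).unsqueeze(-1)   # global position of each query row
    k_pos = torch.arange(Sk).unsqueeze(0)
    scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
    return torch.einsum("hqk,khd->qhd", torch.softmax(scores, dim=-1), v)

torch.manual_seed(0)
H, D, chunk, n_chunks = 2, 8, 16, 3
q = torch.randn(n_chunks * chunk, H, D)
k = torch.randn(n_chunks * chunk, H, D)
v = torch.randn(n_chunks * chunk, H, D)
scale = D ** -0.5

full = causal_attn(q, k, v, scale)                    # one causal pass over the whole sequence
for i in range(n_chunks):
    end = (i + 1) * chunk
    # chunk i's queries against all keys up to and including chunk i
    per_chunk = causal_attn(q[i * chunk:end], k[:end], v[:end], scale)
    assert torch.allclose(per_chunk, full[i * chunk:end], atol=1e-5)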