[refactor] Refactor the test_chunked_prefill/decode.
@@ -1,208 +1,146 @@
 """
-Hook-based correctness test for chunked decode attention.
+Correctness test for chunked decode attention.
 
-Uses PyTorch register_forward_hook() to capture real inference I/O,
-then compares against reference computation to locate bugs.
+Captures Q and output during inference, then computes reference using
+CPU KV cache with standard flash attention.
 
-This test targets the decode phase with CPU offload - after prefill,
-the model generates tokens one by one while attending to all previous context.
 """
 
 import os
-os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
+os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
 
 import torch
 from random import randint, seed
+from typing import Dict, List
 from nanovllm import LLM, SamplingParams
 from nanovllm.utils.context import get_context
-from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+from flash_attn.flash_attn_interface import flash_attn_func
 
-# ============================================================
-# Configuration
-# ============================================================
+# Config
+MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
+MAX_MODEL_LEN = 128 * 1024
 
-MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
-MAX_MODEL_LEN = 8 * 1024
 NUM_GPU_BLOCKS = 2
-INPUT_LEN = 2 * 1024  # 2K tokens for prefill
-NUM_DECODE_TOKENS = 5  # Generate 5 tokens to test decode
+INPUT_LEN = 16 * 1024
+NUM_DECODE_TOKENS = 5
 BLOCK_SIZE = 1024
 
-# ============================================================
-# Global capture storage
-# ============================================================
+# State
+prefill_captures: List[Dict] = []
+decode_captures: List[Dict] = []
 
-captures = []
-prefill_kv = {}  # Store prefill k,v for reference computation
 
 
-# ============================================================
-# Hook Functions
-# ============================================================
+def make_ones_injection_hook():
+    """Inject Q=K=V=1.0 for deterministic testing."""
+    def hook(module, inputs):
+        q, k, v = inputs[0], inputs[1], inputs[2]
+        q_ones = torch.ones_like(q)
+        k_ones = torch.ones_like(k)
+        v_ones = torch.ones_like(v)
+        return (q_ones, k_ones, v_ones) + inputs[3:]
+    return hook
 
-def make_hook(layer_id):
-    """Create a forward hook for a specific layer."""
+def make_capture_hook(layer_id: int):
+    """Capture Q, K, V, output during inference."""
     def hook(module, inputs, output):
-        q, k, v = inputs
         ctx = get_context()
+        q, k, v = inputs
 
-        is_prefill = ctx.is_prefill
-        capture_entry = {
+        if ctx.is_prefill:
+            chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
+            prefill_captures.append({
                 'layer_id': layer_id,
-            'is_prefill': is_prefill,
+                'chunk_idx': chunk_idx,
                 'q': q.clone().cpu(),
                 'k': k.clone().cpu(),
                 'v': v.clone().cpu(),
                 'output': output.clone().cpu(),
-            'is_chunked_prefill': ctx.is_chunked_prefill,
-        }
 
-        if is_prefill:
-            # Store prefill k,v for reference computation
-            chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
-            capture_entry['chunk_idx'] = chunk_idx
-            if layer_id not in prefill_kv:
-                prefill_kv[layer_id] = []
-            prefill_kv[layer_id].append({
-                'chunk_idx': chunk_idx,
-                'k': k.clone().cpu(),
-                'v': v.clone().cpu(),
             })
         else:
-            # Decode phase - capture decode token info
-            capture_entry['decode_step'] = len([c for c in captures
-                                                if c['layer_id'] == layer_id and not c['is_prefill']])
-        captures.append(capture_entry)
+            decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id])
+            decode_captures.append({
+                'layer_id': layer_id,
+                'decode_step': decode_step,
+                'q': q.clone().cpu(),
+                'k': k.clone().cpu(),
+                'v': v.clone().cpu(),
+                'output': output.clone().cpu(),
+            })
     return hook
 
 
-def register_hooks(llm):
-    """Register forward hooks on all Attention modules."""
-    hooks = []
-    model = llm.model_runner.model
+def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
+                             k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
+                             block_size: int, num_prefill_chunks: int) -> torch.Tensor:
 
-    for layer_idx, decoder_layer in enumerate(model.model.layers):
-        attn_module = decoder_layer.self_attn.attn
-        hook = attn_module.register_forward_hook(make_hook(layer_idx))
-        hooks.append(hook)
 
-    return hooks
 
 
-# ============================================================
-# Reference Computation
-# ============================================================
 
-def compute_decode_reference(layer_id, decode_step, scale, debug=False):
     """
-    Compute reference decode attention output for a specific layer.
+    Compute reference decode output using CPU KV cache and standard flash attention.
 
-    For decode, the query is a single token that attends to:
+    For decode, query attends to:
     1. All prefill KV (from CPU cache)
-    2. All previous decode tokens (stored in GPU decode slot)
+    2. All previous decode tokens (from captured decode k, v)
     """
-    # Get the decode capture
-    decode_captures = [c for c in captures
-                       if c['layer_id'] == layer_id and not c['is_prefill']]
-    if decode_step >= len(decode_captures):
+    # Get decode capture for this layer and step
+    decode_cap = None
+    for c in decode_captures:
+        if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
+            decode_cap = c
+            break
 
+    if decode_cap is None:
         return None
 
-    decode_capture = decode_captures[decode_step]
-    q = decode_capture['q'].cuda()  # [1, num_heads, head_dim]
-    q_batched = q.unsqueeze(1)  # [1, 1, num_heads, head_dim]
+    # Query: single decode token
+    q = decode_cap['q'].cuda()  # [1, num_heads, head_dim]
+    q_batched = q.unsqueeze(0)  # [1, 1, num_heads, head_dim]
 
-    if debug:
-        print(f" Reference for L{layer_id} D{decode_step}:")
-        print(f" q shape: {q_batched.shape}, mean={q_batched.mean().item():.4f}")
+    # Collect all K, V: prefill chunks from CPU cache + decode tokens from captures
+    all_k = []
+    all_v = []
 
-    o_acc, lse_acc = None, None
+    # 1. Prefill chunks from CPU cache
+    for cidx in range(num_prefill_chunks):
+        # Get prefill capture to know the sequence length for this chunk
+        prefill_cap = None
+        for c in prefill_captures:
+            if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
+                prefill_cap = c
+                break
 
-    # Attend to all prefill chunks
-    if layer_id in prefill_kv:
-        for chunk_data in sorted(prefill_kv[layer_id], key=lambda x: x['chunk_idx']):
-            k = chunk_data['k'].cuda().unsqueeze(0)  # [1, seqlen, kv_heads, head_dim]
-            v = chunk_data['v'].cuda().unsqueeze(0)
+        if prefill_cap is not None:
+            seq_len = prefill_cap['q'].shape[0]
+            k = k_cache_cpu[layer_id, cidx, :seq_len].cuda()
+            v = v_cache_cpu[layer_id, cidx, :seq_len].cuda()
+            all_k.append(k)
+            all_v.append(v)
 
-            o, lse = flash_attn_with_lse(q_batched, k, v, softmax_scale=scale, causal=False)
+    # 2. Decode tokens from captures (up to and including current step)
+    for step in range(decode_step + 1):
+        for c in decode_captures:
+            if c['layer_id'] == layer_id and c['decode_step'] == step:
+                all_k.append(c['k'].cuda())
+                all_v.append(c['v'].cuda())
+                break
 
-            if debug:
-                print(f" Prefill chunk {chunk_data['chunk_idx']}: o.mean={o.mean().item():.6f}")
+    if not all_k:
 
-            if o_acc is None:
-                o_acc, lse_acc = o, lse
-            else:
-                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)
 
-    # Attend to previous decode tokens (including current)
-    # In decode, the current token's k,v are stored, and we need to attend to all previous decode tokens
-    # For step 0, we just have the current token's k,v
-    # For step 1, we have tokens 0 and 1's k,v
-    # etc.
 
-    # Collect k,v from all decode steps up to and including current
-    decode_kv = []
-    for i in range(decode_step + 1):
-        if i < len(decode_captures):
-            decode_kv.append({
-                'k': decode_captures[i]['k'].cuda(),
-                'v': decode_captures[i]['v'].cuda(),
-            })
 
-    if decode_kv:
-        # Stack decode k,v into a single tensor
-        decode_k = torch.cat([d['k'] for d in decode_kv], dim=0).unsqueeze(0)  # [1, num_decode, kv_heads, head_dim]
-        decode_v = torch.cat([d['v'] for d in decode_kv], dim=0).unsqueeze(0)
 
-        if debug:
-            print(f" Decode tokens: {len(decode_kv)}, k.shape={decode_k.shape}")
 
-        # For decode, we use causal=False since we're attending to all decode tokens
-        # (the causal masking was already handled by only including tokens up to current)
-        o_decode, lse_decode = flash_attn_with_lse(q_batched, decode_k, decode_v,
-                                                   softmax_scale=scale, causal=False)
 
-        if debug:
-            print(f" Decode attention: o.mean={o_decode.mean().item():.6f}")
 
-        if o_acc is None:
-            o_acc, lse_acc = o_decode, lse_decode
-        else:
-            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_decode, lse_decode)
 
-    if o_acc is None:
         return None
 
-    if debug:
-        print(f" Final: o.mean={o_acc.mean().item():.6f}")
+    # Concatenate all K, V
+    full_k = torch.cat(all_k, dim=0).unsqueeze(0)  # [1, total_len, kv_heads, head_dim]
+    full_v = torch.cat(all_v, dim=0).unsqueeze(0)
 
-    return o_acc.squeeze(0).squeeze(0).cpu()  # [num_heads, head_dim]
+    # Run flash attention (non-causal since we explicitly control what KV to include)
+    output = flash_attn_func(
+        q_batched, full_k, full_v,
+        softmax_scale=scale,
+        causal=False,
+    )
 
+    return output.squeeze(0).squeeze(0).cpu()  # [num_heads, head_dim]
 
 
 # ============================================================
-# Test Runner
+# Main
 # ============================================================
 
-def run_test(verbose=True):
-    """Run the hook-based chunked decode correctness test."""
-    global captures, prefill_kv
-    captures = []
-    prefill_kv = {}
 
-    if verbose:
-        print("=" * 70)
-        print("Test: Hook-Based Chunked Decode Correctness")
-        print("=" * 70)
-        print(f"Model: {MODEL_PATH}")
-        print(f"Input length: {INPUT_LEN} tokens")
-        print(f"Decode tokens: {NUM_DECODE_TOKENS}")
-        print(f"Block size: {BLOCK_SIZE}")
-        print()
 
-    # Initialize LLM with CPU offload
 llm = LLM(
     MODEL_PATH,
     enforce_eager=True,
@@ -211,6 +149,7 @@ def run_test(verbose=True):
     enable_cpu_offload=True,
     kvcache_block_size=BLOCK_SIZE,
     num_gpu_blocks=NUM_GPU_BLOCKS,
+    dtype="float16",
 )
 
 # Get model info
@@ -218,157 +157,58 @@
 head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
 scale = head_dim ** -0.5
 
-    if verbose:
-        print(f"Num layers: {num_layers}")
-        print(f"Head dim: {head_dim}")
-        print()
 
 # Register hooks
-    hooks = register_hooks(llm)
-    if verbose:
-        print(f"Registered {len(hooks)} hooks")
+hooks = []
+for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
+    # Pre-hook: inject all ones for Q, K, V
+    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
+    # hooks.append(pre_hook)
+    # Post-hook: capture Q, K, V, output
+    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
+    hooks.append(post_hook)
 
-    # Generate random prompt
+# Run inference
 seed(42)
 prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
-    # Run prefill and decode
-    if verbose:
-        print(f"Running inference with {NUM_DECODE_TOKENS} decode tokens...")
-    sampling_params = SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS)
-    outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
+outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False)
 
 # Remove hooks
 for hook in hooks:
     hook.remove()
 
-    # =========== VERIFICATION: Check CPU cache after prefill ===========
-    # Verify that CPU cache data matches captured prefill k,v
-    if verbose:
-        print("\n--- CPU Cache Verification (After Prefill) ---")
+# Get CPU cache reference
 offload_engine = llm.model_runner.kvcache_manager.offload_engine
+k_cache_cpu = offload_engine.k_cache_cpu.clone()
+v_cache_cpu = offload_engine.v_cache_cpu.clone()
 
-    # For each prefill capture, check if CPU cache matches
-    for layer_id in [0]:  # Only check layer 0 for brevity
-        if layer_id not in prefill_kv:
-            continue
+# Calculate number of prefill chunks
+num_prefill_chunks = INPUT_LEN // BLOCK_SIZE
 
-        for chunk_data in prefill_kv[layer_id]:
-            chunk_idx = chunk_data['chunk_idx']
-            captured_k = chunk_data['k']  # [block_size, kv_heads, head_dim]
+# Verify decode outputs
 
-            # CPU block ID should be chunk_idx (based on allocation order)
-            cpu_block_id = chunk_idx
-            cpu_k = offload_engine.k_cache_cpu[layer_id, cpu_block_id].cpu()
 
-            diff = (captured_k - cpu_k).abs().max().item()
-            print(f"Layer {layer_id}, Chunk {chunk_idx}: captured_k vs cpu_k max_diff={diff:.6f}")
-            if diff > 1e-3:
-                print(f" WARNING: CPU cache doesn't match captured k!")
-                print(f" captured_k[0,0,:5] = {captured_k[0,0,:5].tolist()}")
-                print(f" cpu_k[0,0,:5] = {cpu_k[0,0,:5].tolist()}")
-    print()
 
-    # Analyze captures
-    prefill_count = sum(1 for c in captures if c['is_prefill'])
-    decode_count = sum(1 for c in captures if not c['is_prefill'])
-    if verbose:
-        print(f"\nCaptured {prefill_count} prefill calls, {decode_count} decode calls")
 
-    # Count decode steps per layer
-    decode_per_layer = {}
-    for c in captures:
-        if not c['is_prefill']:
-            layer_id = c['layer_id']
-            if layer_id not in decode_per_layer:
-                decode_per_layer[layer_id] = 0
-            decode_per_layer[layer_id] += 1
 
-    if verbose:
-        print(f"Decode calls per layer: {decode_per_layer}")
-        print()
 
-    # Verify decode correctness
 all_passed = True
-    results = []
-    first_fail_debug = True
 
-    for c in captures:
-        if c['is_prefill']:
-            continue  # Skip prefill (already tested in test_chunked_prefill_hook.py)
 
+for c in decode_captures:
     layer_id = c['layer_id']
     decode_step = c['decode_step']
 
-        # Only test first decode step for now (simpler reference computation)
-        if decode_step > 0:
-            continue
-        # Compute reference (debug first failure)
-        debug_this = (layer_id == 0 and first_fail_debug)
-        ref_output = compute_decode_reference(layer_id, decode_step, scale, debug=debug_this)
+    ref_output = compute_decode_reference(
+        layer_id, decode_step, scale,
+        k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
+    )
     if ref_output is None:
         continue
 
-        # Compare
-        actual_output = c['output'].squeeze(0)  # Remove seq dim for decode
+    actual_output = c['output'].squeeze(0)
     if actual_output.dim() == 3:
-            actual_output = actual_output.squeeze(0)  # Handle [1, heads, dim] case
+        actual_output = actual_output.squeeze(0)
 
     diff = (actual_output - ref_output).abs()
     max_diff = diff.max().item()
-        mean_diff = diff.mean().item()
 
-        tol = 1e-2
-        passed = max_diff < tol
+    passed = max_diff < 1e-1
     all_passed = all_passed and passed
 
-        status = "PASS" if passed else "FAIL"
-        results.append((layer_id, decode_step, passed, max_diff, mean_diff))
+    # if not passed:
+    print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
 
-        if verbose:
-            print(f"[{status}] Layer {layer_id:2d}, Decode {decode_step}: "
-                  f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
+print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")
 
-            # Debug first failure
-            if not passed and first_fail_debug:
-                first_fail_debug = False
-                print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
-                print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}")
-                # Find where max diff is
-                max_idx = diff.argmax()
-                flat_actual = actual_output.flatten()
-                flat_ref = ref_output.flatten()
-                print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")
 
-    print()
-    print("=" * 70)
 
-    # Summary
-    total_tests = len(results)
-    passed_count = sum(1 for r in results if r[2])
 
-    print(f"Results: {passed_count}/{total_tests} tests passed")
 
-    if not all_passed:
-        print("\nFailed tests:")
-        for layer_id, decode_step, passed, max_diff, mean_diff in results:
-            if not passed:
-                print(f" - Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
 
-    print()
-    return all_passed
 
 
-# ============================================================
-# Main
-# ============================================================
 
-if __name__ == "__main__":
-    passed = run_test(verbose=True)
 
-    if passed:
-        print("test_chunked_decode_hook: PASSED")
-    else:
-        print("test_chunked_decode_hook: FAILED")
-        exit(1)
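A note on the decode-side reference above, for readers following the change: the refactored test rebuilds the expected output of a decode step by gathering the prefill K/V of every chunk from the CPU offload cache plus the captured K/V of all decode steps so far, then running one standard non-causal attention call over the concatenated keys (non-causal because only valid past positions are included). The sketch below restates that computation in plain PyTorch instead of flash_attn_func; the function name, tensor shapes, and the grouped-query expansion are illustrative assumptions, not code from the repository.

import torch

def decode_reference_sketch(q, prefill_kv, decode_kv, scale):
    """Illustrative stand-in for the test's decode reference computation.

    q:          [num_heads, head_dim]      query of the current decode token
    prefill_kv: list of (k, v) pairs, each [chunk_len, kv_heads, head_dim]
    decode_kv:  list of (k, v) pairs, each [1, kv_heads, head_dim]
    """
    ks = torch.cat([k for k, _ in prefill_kv + decode_kv], dim=0)  # [total_len, kv_heads, head_dim]
    vs = torch.cat([v for _, v in prefill_kv + decode_kv], dim=0)
    num_heads, head_dim = q.shape
    kv_heads = ks.shape[1]
    # Expand grouped-query KV heads to match the query heads (assumes num_heads % kv_heads == 0).
    rep = num_heads // kv_heads
    ks = ks.repeat_interleave(rep, dim=1)
    vs = vs.repeat_interleave(rep, dim=1)
    # Non-causal attention: every key passed in is a valid (past) position for this token.
    scores = torch.einsum("hd,thd->ht", q, ks) * scale   # [num_heads, total_len]
    weights = torch.softmax(scores, dim=-1)
    return torch.einsum("ht,thd->hd", weights, vs)        # [num_heads, head_dim]

Because softmax normalizes over whatever keys are present, restricting the key set to the prefill chunks plus the decode tokens generated so far reproduces causal masking without passing causal=True. The second file in the commit, the chunked prefill test, gets the same treatment in the diff that follows.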
@@ -1,203 +1,111 @@
 """
-Hook-based correctness test for chunked prefill attention.
+Correctness test for chunked prefill attention.
 
-Uses PyTorch register_forward_hook() to capture real inference I/O,
-then compares against reference computation to locate bugs.
+Captures Q and output during inference, then computes reference using
+CPU KV cache with standard flash attention.
 
-This test targets the integration layer (context setup, cpu_block_table management)
-which is where the needle test fails despite isolated attention tests passing.
 """
 
 import os
-os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
+os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
 
 import torch
 from random import randint, seed
+from typing import Dict, List
 from nanovllm import LLM, SamplingParams
 from nanovllm.utils.context import get_context
-from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 from flash_attn.flash_attn_interface import flash_attn_varlen_func
 
-# ============================================================
-# Configuration
-# ============================================================
+# Config
 
 MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
-MAX_MODEL_LEN = 32 * 1024
+MAX_MODEL_LEN = 128 * 1024
 NUM_GPU_BLOCKS = 2
-INPUT_LEN = 16 * 1024  # 4K tokens = 4 chunks with 1K block size
+INPUT_LEN = 16 * 1024
 BLOCK_SIZE = 1024
 
-# ============================================================
-# Global capture storage
-# ============================================================
+# State - capture Q and output for each (layer, chunk)
+captures: List[Dict] = []
 
-captures = []
 
 
-# ============================================================
-# Hook Functions
-# ============================================================
+def make_ones_injection_hook():
+    """Inject Q=K=V=1.0 for deterministic testing."""
+    def hook(module, inputs):
 
-def make_hook(layer_id):
-    """Create a forward hook for a specific layer."""
-    def hook(module, inputs, output):
-        q, k, v = inputs
         ctx = get_context()
+        if not ctx.is_prefill:
+            return inputs
 
-        # Only capture prefill phase
+        q, k, v = inputs[0], inputs[1], inputs[2]
+        q_ones = torch.ones_like(q)
+        k_ones = torch.ones_like(k)
+        v_ones = torch.ones_like(v)
+        return (q_ones, k_ones, v_ones) + inputs[3:]
+    return hook
 
 
+def make_capture_hook(layer_id: int):
+    """Capture Q and output during prefill."""
+    def hook(module, inputs, output):
+        ctx = get_context()
         if not ctx.is_prefill:
             return
 
+        q, k, v = inputs
         chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
 
-        capture_entry = {
+        captures.append({
             'layer_id': layer_id,
             'chunk_idx': chunk_idx,
             'q': q.clone().cpu(),
             'k': k.clone().cpu(),
             'v': v.clone().cpu(),
             'output': output.clone().cpu(),
-            'is_chunked_prefill': ctx.is_chunked_prefill,
-        }
+        })
 
-        # For debugging: also capture CPU cache state for layer 0
-        if layer_id == 0 and chunk_idx >= 2:
-            kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
-            if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
-                oe = kvcache_manager.offload_engine
-                # Get what should have been loaded from CPU
-                cpu_k0 = oe.k_cache_cpu[0, 0].clone().cpu()  # Layer 0, CPU block 0
-                cpu_k1 = oe.k_cache_cpu[0, 1].clone().cpu()  # Layer 0, CPU block 1
-                capture_entry['cpu_k0'] = cpu_k0
-                capture_entry['cpu_k1'] = cpu_k1
 
-        captures.append(capture_entry)
     return hook
 
 
-def register_hooks(llm):
-    """Register forward hooks on all Attention modules."""
-    hooks = []
-    model = llm.model_runner.model
+def compute_reference(layer_id: int, chunk_idx: int, scale: float,
+                      k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
+                      block_size: int) -> torch.Tensor:
 
-    for layer_idx, decoder_layer in enumerate(model.model.layers):
-        attn_module = decoder_layer.self_attn.attn
-        hook = attn_module.register_forward_hook(make_hook(layer_idx))
-        hooks.append(hook)
 
-    return hooks
 
 
-# ============================================================
-# Reference Computation
-# ============================================================
 
-def compute_reference(layer_id, chunk_idx, scale, debug=False):
     """
-    Compute reference attention output for a specific layer and chunk.
+    Compute reference output using CPU KV cache and standard flash attention.
 
-    Uses the captured k, v from all chunks up to and including chunk_idx.
+    Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention,
+    then extracts output for the current chunk.
     """
-    # Filter captures for this layer
+    # Get all captures for this layer up to chunk_idx
     layer_captures = [c for c in captures
                       if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
 
-    if not layer_captures:
-        return None
 
-    # Get current chunk's q
-    current_capture = [c for c in layer_captures if c['chunk_idx'] == chunk_idx][0]
-    q = current_capture['q'].cuda().unsqueeze(0)  # [1, seqlen, nheads, headdim]
 
-    # Collect all k, v up to current chunk
-    kv_list = []
-    for c in sorted(layer_captures, key=lambda x: x['chunk_idx']):
-        k = c['k'].cuda().unsqueeze(0)  # [1, seqlen, nheads, headdim]
-        v = c['v'].cuda().unsqueeze(0)
-        kv_list.append((k, v, c['chunk_idx']))
 
-    if debug:
-        print(f" Reference for L{layer_id} C{chunk_idx}:")
-        print(f" q shape: {q.shape}, mean={q.mean().item():.4f}")
-        print(f" kv_list: {len(kv_list)} chunks")
-        for i, (k, v, cidx) in enumerate(kv_list):
-            print(f" chunk {cidx}: k.mean={k.mean().item():.4f}, v.mean={v.mean().item():.4f}")
 
-    o_acc, lse_acc = None, None
 
-    # Previous chunks: non-causal attention
-    for i in range(len(kv_list) - 1):
-        k, v, _ = kv_list[i]
-        o, lse = flash_attn_with_lse(q, k, v, softmax_scale=scale, causal=False)
-        if o_acc is None:
-            o_acc, lse_acc = o, lse
-        else:
-            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)
 
-    # Current chunk: causal attention
-    k_cur, v_cur, _ = kv_list[-1]
-    o_cur, lse_cur = flash_attn_with_lse(q, k_cur, v_cur, softmax_scale=scale, causal=True)
 
-    if o_acc is None:
-        return o_cur.squeeze(0).cpu()
 
-    final_o, _ = merge_attention_outputs(o_acc, lse_acc, o_cur, lse_cur)
-    return final_o.squeeze(0).cpu()
 
 
-def compute_standard_reference(layer_id, chunk_idx, scale, debug=False):
-    """
-    Compute reference using standard flash attention (single pass with all K, V).
 
-    This simulates what standard (non-chunked) prefill would produce.
-    Concatenates all Q, K, V from chunks 0 to chunk_idx and runs a single
-    causal attention pass, then extracts the output for the current chunk.
-    """
-    # Filter captures for this layer
-    layer_captures = [c for c in captures
-                      if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
 
-    if not layer_captures:
-        return None
 
-    # Sort by chunk index
     layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
 
-    # Concatenate all Q, K, V
+    if not layer_captures:
+        return None
 
+    # Collect Q from captures, K/V from CPU cache
     all_q = []
     all_k = []
     all_v = []
     chunk_lengths = []
 
     for c in layer_captures:
+        cidx = c['chunk_idx']
         q = c['q'].cuda()  # [seqlen, nheads, headdim]
-        k = c['k'].cuda()
-        v = c['v'].cuda()
         all_q.append(q)
-        all_k.append(k)
-        all_v.append(v)
         chunk_lengths.append(q.shape[0])
 
-    # Concatenate along sequence dimension
-    full_q = torch.cat(all_q, dim=0)  # [total_seqlen, nheads, headdim]
+        # Get K, V from CPU cache (already offloaded during prefill)
+        # CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
+        k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
+        v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
+        all_k.append(k)
+        all_v.append(v)
 
+    # Concatenate
+    full_q = torch.cat(all_q, dim=0)
     full_k = torch.cat(all_k, dim=0)
     full_v = torch.cat(all_v, dim=0)
 
     total_len = full_q.shape[0]
 
-    if debug:
-        print(f" Standard Reference for L{layer_id} C{chunk_idx}:")
-        print(f" full_q shape: {full_q.shape}, mean={full_q.mean().item():.4f}")
-        print(f" full_k shape: {full_k.shape}, mean={full_k.mean().item():.4f}")
-        print(f" chunk_lengths: {chunk_lengths}")
 
     # Run standard causal flash attention
-    # flash_attn_varlen_func expects: q, k, v with shape [total_seqlen, nheads, headdim]
     cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
 
     full_o = flash_attn_varlen_func(
         full_q, full_k, full_v,
         cu_seqlens_q=cu_seqlens,
@@ -208,39 +116,16 @@ def compute_standard_reference(layer_id, chunk_idx, scale, debug=False):
         causal=True,
     )
 
-    # Extract output for current chunk only
+    # Extract output for current chunk
     start_pos = sum(chunk_lengths[:-1])
     end_pos = sum(chunk_lengths)
-    chunk_output = full_o[start_pos:end_pos]
+    return full_o[start_pos:end_pos].cpu()
 
-    if debug:
-        print(f" full_o shape: {full_o.shape}")
-        print(f" extracting positions [{start_pos}:{end_pos}]")
-        print(f" chunk_output shape: {chunk_output.shape}, mean={chunk_output.mean().item():.4f}")
 
-    return chunk_output.cpu()
 
 
 # ============================================================
-# Test Runner
+# Main
 # ============================================================
 
-def run_test(verbose=True):
-    """Run the hook-based chunked prefill correctness test."""
-    global captures
-    captures = []
 
-    if verbose:
-        print("=" * 70)
-        print("Test: Hook-Based Chunked Prefill Correctness")
-        print("=" * 70)
-        print(f"Model: {MODEL_PATH}")
-        print(f"Input length: {INPUT_LEN} tokens")
-        print(f"Block size: {BLOCK_SIZE}")
-        print(f"Expected chunks: {INPUT_LEN // BLOCK_SIZE}")
-        print()
 
-    # Initialize LLM with CPU offload
 llm = LLM(
     MODEL_PATH,
     enforce_eager=True,
@@ -249,6 +134,7 @@ def run_test(verbose=True):
     enable_cpu_offload=True,
     kvcache_block_size=BLOCK_SIZE,
     num_gpu_blocks=NUM_GPU_BLOCKS,
+    dtype="float16",
 )
 
 # Get model info
@@ -256,218 +142,55 @@
 head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
 scale = head_dim ** -0.5
 
-    if verbose:
-        print(f"Num layers: {num_layers}")
-        print(f"Head dim: {head_dim}")
-        print()
 
 # Register hooks
-    hooks = register_hooks(llm)
-    if verbose:
-        print(f"Registered {len(hooks)} hooks")
+hooks = []
+for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
+    # Pre-hook: inject all ones for Q, K, V
+    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
+    # hooks.append(pre_hook)
+    # Post-hook: capture Q, K, V, output
+    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
+    hooks.append(post_hook)
 
-    # Generate random prompt
+# Run inference
 seed(42)
 prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
-    # Run prefill only (max_tokens=1)
-    if verbose:
-        print("Running inference...")
-    sampling_params = SamplingParams(temperature=0.6, max_tokens=1)
-    outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
+outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)
 
 # Remove hooks
 for hook in hooks:
     hook.remove()
 
-    # Analyze captures
-    if verbose:
-        print(f"\nCaptured {len(captures)} attention calls")
+# Get CPU cache reference
+offload_engine = llm.model_runner.kvcache_manager.offload_engine
+k_cache_cpu = offload_engine.k_cache_cpu.clone()
+v_cache_cpu = offload_engine.v_cache_cpu.clone()
 
-    # Group by layer and chunk
-    chunks_per_layer = {}
-    for c in captures:
-        layer_id = c['layer_id']
-        chunk_idx = c['chunk_idx']
-        if layer_id not in chunks_per_layer:
-            chunks_per_layer[layer_id] = set()
-        chunks_per_layer[layer_id].add(chunk_idx)
+# Verify: compare actual output with reference computed from CPU cache
+all_passed = True
+num_chunks = INPUT_LEN // BLOCK_SIZE
 
-    if verbose:
-        print("Chunks per layer:", {k: sorted(v) for k, v in chunks_per_layer.items()})
-        print()
+for idx,c in enumerate(captures):
 
-    # First, verify CPU cache data integrity
-    if verbose:
-        print("\n--- CPU Cache Verification (Layer 0) ---")
-        # Get original k from chunk 0 and chunk 1 captures
-        chunk0_k = None
-        chunk1_k = None
-        chunk2_capture = None
-        for c in captures:
-            if c['layer_id'] == 0:
-                if c['chunk_idx'] == 0:
-                    chunk0_k = c['k']
-                elif c['chunk_idx'] == 1:
-                    chunk1_k = c['k']
-                elif c['chunk_idx'] == 2:
-                    chunk2_capture = c
 
-        if chunk0_k is not None and chunk2_capture is not None and 'cpu_k0' in chunk2_capture:
-            cpu_k0 = chunk2_capture['cpu_k0']
-            diff_k0 = (chunk0_k - cpu_k0).abs().max().item()
-            print(f"Chunk 0 k vs CPU cache block 0: max_diff={diff_k0:.6f}")
-            if diff_k0 > 1e-3:
-                print(f" WARNING: CPU cache block 0 differs from original chunk 0 k!")
-                print(f" Original k[0,0,:5] = {chunk0_k[0,0,:5].tolist()}")
-                print(f" CPU k0[0,0,:5] = {cpu_k0[0,0,:5].tolist()}")
 
-        if chunk1_k is not None and chunk2_capture is not None and 'cpu_k1' in chunk2_capture:
-            cpu_k1 = chunk2_capture['cpu_k1']
-            diff_k1 = (chunk1_k - cpu_k1).abs().max().item()
-            print(f"Chunk 1 k vs CPU cache block 1: max_diff={diff_k1:.6f}")
-            if diff_k1 > 1e-3:
-                print(f" WARNING: CPU cache block 1 differs from original chunk 1 k!")
-                print(f" Original k[0,0,:5] = {chunk1_k[0,0,:5].tolist()}")
-                print(f" CPU k1[0,0,:5] = {cpu_k1[0,0,:5].tolist()}")
 
-        print()
 
-    # ================================================================
-    # Test 1: Verify against merge-based reference (same algorithm)
-    # ================================================================
-    if verbose:
-        print("--- Test 1: Merge-based Reference (verifies merge algorithm) ---")
 
-    all_passed_merge = True
-    results_merge = []
-    first_fail_debug = True
 
-    for c in captures:
     layer_id = c['layer_id']
     chunk_idx = c['chunk_idx']
 
+    # Skip chunk 0 (no previous KV to load)
     if chunk_idx == 0:
         continue
 
-        debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug)
-        ref_output = compute_reference(layer_id, chunk_idx, scale, debug=debug_this)
+    ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
     if ref_output is None:
         continue
 
     actual_output = c['output']
     diff = (actual_output - ref_output).abs()
     max_diff = diff.max().item()
-        mean_diff = diff.mean().item()
 
-        tol = 1e-2
-        passed = max_diff < tol
-        all_passed_merge = all_passed_merge and passed
+    passed = max_diff < 1e-1  # float16 tolerance
+    all_passed = all_passed and passed
 
-        status = "PASS" if passed else "FAIL"
-        results_merge.append((layer_id, chunk_idx, passed, max_diff, mean_diff))
 
-        if verbose:
-            print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: "
-                  f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
 
-            if not passed and first_fail_debug:
-                first_fail_debug = False
-                print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
-                print(f" Debug: ref_output shape={ref_output.shape}, mean={ref_output.mean().item():.4f}")
-                max_idx = diff.argmax()
-                flat_actual = actual_output.flatten()
-                flat_ref = ref_output.flatten()
-                print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")
 
-    print()
 
-    # ================================================================
-    # Test 2: Verify against standard flash attention (single pass)
-    # ================================================================
-    if verbose:
-        print("--- Test 2: Standard FlashAttn Reference (verifies correctness vs non-chunked) ---")
 
-    all_passed_standard = True
-    results_standard = []
-    first_fail_debug = True
 
-    for c in captures:
-        layer_id = c['layer_id']
-        chunk_idx = c['chunk_idx']
 
-        if chunk_idx == 0:
-            continue
 
-        debug_this = (chunk_idx >= 2 and layer_id == 0 and first_fail_debug)
-        std_ref_output = compute_standard_reference(layer_id, chunk_idx, scale, debug=debug_this)
-        if std_ref_output is None:
-            continue
 
-        actual_output = c['output']
-        diff = (actual_output - std_ref_output).abs()
-        max_diff = diff.max().item()
-        mean_diff = diff.mean().item()
 
-        tol = 1e-2
-        passed = max_diff < tol
-        all_passed_standard = all_passed_standard and passed
 
-        status = "PASS" if passed else "FAIL"
-        results_standard.append((layer_id, chunk_idx, passed, max_diff, mean_diff))
 
-        if verbose:
-            print(f"[{status}] Layer {layer_id:2d}, Chunk {chunk_idx}: "
-                  f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}")
 
-            if not passed and first_fail_debug:
-                first_fail_debug = False
-                print(f" Debug: actual_output shape={actual_output.shape}, mean={actual_output.mean().item():.4f}")
-                print(f" Debug: std_ref_output shape={std_ref_output.shape}, mean={std_ref_output.mean().item():.4f}")
-                max_idx = diff.argmax()
-                flat_actual = actual_output.flatten()
-                flat_ref = std_ref_output.flatten()
-                print(f" Debug: max_diff at idx={max_idx.item()}, actual={flat_actual[max_idx].item():.4f}, ref={flat_ref[max_idx].item():.4f}")
 
-    print()
-    print("=" * 70)
 
-    # Summary
-    total_merge = len(results_merge)
-    passed_merge = sum(1 for r in results_merge if r[2])
-    total_standard = len(results_standard)
-    passed_standard = sum(1 for r in results_standard if r[2])
 
-    print(f"Merge-based reference: {passed_merge}/{total_merge} tests passed")
-    print(f"Standard FlashAttn ref: {passed_standard}/{total_standard} tests passed")
 
-    all_passed = all_passed_merge and all_passed_standard
 
-    if not all_passed_merge:
-        print("\nFailed merge-based tests:")
-        for layer_id, chunk_idx, passed, max_diff, mean_diff in results_merge:
     if not passed:
-                print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
+        print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
+        __import__('pdb').set_trace()
 
-    if not all_passed_standard:
-        print("\nFailed standard FlashAttn tests:")
-        for layer_id, chunk_idx, passed, max_diff, mean_diff in results_standard:
-            if not passed:
-                print(f" - Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
+print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")
 
-    print()
-    return all_passed
 
 
-# ============================================================
-# Main
-# ============================================================
 
-if __name__ == "__main__":
-    passed = run_test(verbose=True)
 
-    if passed:
-        print("test_chunked_prefill_hook: PASSED")
-    else:
-        print("test_chunked_prefill_hook: FAILED")
-        exit(1)
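Both refactored tests now compare against a single full-sequence attention call and drop the old merge-based references built on flash_attn_with_lse and merge_attention_outputs. The identity those helpers relied on is still useful when reading either version of the code: attention computed chunk by chunk can be recombined exactly if each partial result carries its log-sum-exp over its own key set. The self-contained PyTorch sketch below demonstrates that identity; it is an illustration under assumed shapes, not the nanovllm implementation.

import torch

def attn_with_lse(q, k, v, scale):
    """Single-chunk attention that also returns the log-sum-exp over its keys.

    q: [num_heads, head_dim], k/v: [chunk_len, num_heads, head_dim]
    """
    scores = torch.einsum("hd,thd->ht", q, k) * scale   # [num_heads, chunk_len]
    lse = torch.logsumexp(scores, dim=-1)                # [num_heads]
    out = torch.einsum("ht,thd->hd", torch.softmax(scores, dim=-1), v)
    return out, lse

def merge_partial_attention(o1, lse1, o2, lse2):
    """Combine attention over two disjoint key sets into attention over their union."""
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse

# Sanity check: the chunked result matches attention over the concatenated keys.
if __name__ == "__main__":
    torch.manual_seed(0)
    H, D, T1, T2 = 4, 8, 5, 7
    q = torch.randn(H, D)
    k = torch.randn(T1 + T2, H, D)
    v = torch.randn(T1 + T2, H, D)
    scale = D ** -0.5
    o_a, lse_a = attn_with_lse(q, k[:T1], v[:T1], scale)
    o_b, lse_b = attn_with_lse(q, k[T1:], v[T1:], scale)
    merged, _ = merge_partial_attention(o_a, lse_a, o_b, lse_b)
    full, _ = attn_with_lse(q, k, v, scale)
    assert torch.allclose(merged, full, atol=1e-5)

This rescaling is exactly what the removed merge-based reference did: attend to earlier chunks non-causally, attend to the current chunk causally, and then merge the two partial outputs using their log-sum-exp weights.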