"""
|
|
Correctness test for chunked decode attention.
|
|
|
|
Captures Q and output during inference, then computes reference using
|
|
CPU KV cache with standard flash attention.
|
|
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import torch
from random import randint, seed
from typing import Dict, List, Optional
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_func

# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
NUM_DECODE_TOKENS = 5
BLOCK_SIZE = 1024
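# With these values the prefill spans INPUT_LEN // BLOCK_SIZE = 16 chunks while only
# NUM_GPU_BLOCKS = 2 GPU blocks are available, so the prefill KV is expected to be
# offloaded to the CPU cache and decode exercises the chunked-attention path under test.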

# State
prefill_captures: List[Dict] = []
decode_captures: List[Dict] = []


def make_ones_injection_hook():
    """Inject Q=K=V=1.0 for deterministic testing."""
    def hook(module, inputs):
        q, k, v = inputs[0], inputs[1], inputs[2]
        q_ones = torch.ones_like(q)
        k_ones = torch.ones_like(k)
        v_ones = torch.ones_like(v)
        return (q_ones, k_ones, v_ones) + inputs[3:]
    return hook
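# (PyTorch forward pre-hooks that return a tuple replace the module's positional
# inputs, which is how the all-ones Q/K/V injection above takes effect when enabled.)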


def make_capture_hook(layer_id: int):
    """Capture Q, K, V, output during inference."""
    def hook(module, inputs, output):
        ctx = get_context()
        q, k, v = inputs

        if ctx.is_prefill:
            chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
            prefill_captures.append({
                'layer_id': layer_id,
                'chunk_idx': chunk_idx,
                'q': q.clone().cpu(),
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
                'output': output.clone().cpu(),
            })
        else:
            decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id])
            decode_captures.append({
                'layer_id': layer_id,
                'decode_step': decode_step,
                'q': q.clone().cpu(),
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
                'output': output.clone().cpu(),
            })
    return hook


def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
                             k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                             block_size: int, num_prefill_chunks: int) -> Optional[torch.Tensor]:
    """
    Compute the reference decode output with standard flash attention.

    For decode, the query attends to:
    1. All prefill KV (taken from the captured prefill chunks; see NOTE below)
    2. All previous decode tokens (from the captured decode k, v)

    The k_cache_cpu / v_cache_cpu / block_size arguments are currently unused here;
    the CPU cache is compared against the captures in the debug sections below.
    """
    # Get decode capture for this layer and step
    decode_cap = None
    for c in decode_captures:
        if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
            decode_cap = c
            break

    if decode_cap is None:
        return None

    # Query: single decode token
    q = decode_cap['q'].cuda()  # [1, num_heads, head_dim]
    q_batched = q.unsqueeze(0)  # [1, 1, num_heads, head_dim]

    # Collect all K, V: prefill chunks from captures + decode tokens from captures
    # NOTE: We use prefill captures directly instead of the CPU cache because
    # the CPU block ID may not equal the chunk index.
    all_k = []
    all_v = []

    # 1. Prefill chunks from captures (use captured K/V, not CPU cache)
    for cidx in range(num_prefill_chunks):
        prefill_cap = None
        for c in prefill_captures:
            if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
                prefill_cap = c
                break

        if prefill_cap is not None:
            # Use captured K/V directly (guaranteed to be correct layer data)
            all_k.append(prefill_cap['k'].cuda())
            all_v.append(prefill_cap['v'].cuda())

    # 2. Decode tokens from captures (up to and including current step)
    for step in range(decode_step + 1):
        for c in decode_captures:
            if c['layer_id'] == layer_id and c['decode_step'] == step:
                all_k.append(c['k'].cuda())
                all_v.append(c['v'].cuda())
                break

    if not all_k:
        return None

    # Concatenate all K, V
    full_k = torch.cat(all_k, dim=0).unsqueeze(0)  # [1, total_len, kv_heads, head_dim]
    full_v = torch.cat(all_v, dim=0).unsqueeze(0)

    # Run flash attention (non-causal since we explicitly control which KV to include)
    output = flash_attn_func(
        q_batched, full_k, full_v,
        softmax_scale=scale,
        causal=False,
    )

    return output.squeeze(0).squeeze(0).cpu()  # [num_heads, head_dim]
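

# Optional cross-check helper (illustrative sketch, not used by the test below): the
# same single-token attention computed without flash-attn. Assumes q is
# [num_heads, head_dim], k/v are [total_len, kv_heads, head_dim], num_heads is an
# integer multiple of kv_heads (GQA), and computes in float32 for stability.
def naive_decode_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           scale: float) -> torch.Tensor:
    num_heads, kv_heads = q.shape[0], k.shape[1]
    # Expand KV heads so each query head sees its grouped KV head
    k = k.repeat_interleave(num_heads // kv_heads, dim=1)  # [total_len, num_heads, head_dim]
    v = v.repeat_interleave(num_heads // kv_heads, dim=1)
    q32, k32, v32 = q.float(), k.float(), v.float()
    scores = torch.einsum("hd,thd->ht", q32, k32) * scale   # [num_heads, total_len]
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("ht,thd->hd", probs, v32)           # [num_heads, head_dim]
# Possible usage (hypothetical): naive_decode_attention(q.squeeze(0), full_k.squeeze(0),
# full_v.squeeze(0), scale) inside compute_decode_reference, compared to its output.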


# ============================================================
# Main
# ============================================================

llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5  # standard 1/sqrt(head_dim) softmax scaling

# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook: inject all ones for Q, K, V (disabled; enable for deterministic debugging)
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)

# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False)

# Remove hooks
for hook in hooks:
    hook.remove()

# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()

# Calculate number of prefill chunks
num_prefill_chunks = INPUT_LEN // BLOCK_SIZE
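# The CPU cache is indexed below as k_cache_cpu[layer, block, token, kv_head, dim],
# i.e. a [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] layout is assumed.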

# Debug: Compare decode_buffer with captured K/V
print("\n=== DEBUG: Comparing decode_buffer with captured K/V ===")
decode_k_buffer = offload_engine.decode_k_buffer.clone().cpu()
for step in range(NUM_DECODE_TOKENS):
    for layer_id in [0, 17, 35]:  # Sample a few layers
        # Find captured K for this step and layer
        for c in decode_captures:
            if c['layer_id'] == layer_id and c['decode_step'] == step:
                captured_k = c['k'].squeeze(0)  # [kv_heads, head_dim]
                buffer_k = decode_k_buffer[layer_id, step]  # [kv_heads, head_dim]
                diff = (captured_k - buffer_k).abs().max().item()
                print(f"Step {step}, Layer {layer_id}: captured vs buffer max_diff={diff:.6f}")
                break

# Debug: Verify that decode_buffer slices match concatenated captures
print("\n=== DEBUG: Verifying decode_buffer slices ===")
for layer_id in [0]:
    for decode_step in [1, 2]:  # Check steps that use multiple tokens
        # Build expected slice from captures
        expected_k_list = []
        for step in range(decode_step + 1):
            for c in decode_captures:
                if c['layer_id'] == layer_id and c['decode_step'] == step:
                    expected_k_list.append(c['k'].squeeze(0))  # [kv_heads, head_dim]
                    break
        if expected_k_list:
            expected_k = torch.stack(expected_k_list, dim=0)  # [num_tokens, kv_heads, head_dim]
            buffer_slice = decode_k_buffer[layer_id, 0:decode_step + 1]
            diff = (expected_k - buffer_slice).abs().max().item()
            print(f"Decode step {decode_step}, Layer {layer_id}: buffer slice vs expected max_diff={diff:.6f}")
            # Print first values
            print(f" Buffer[0,0,0]={buffer_slice[0,0,0].item():.6f}, Expected[0,0,0]={expected_k[0,0,0].item():.6f}")
            if decode_step >= 1:
                print(f" Buffer[1,0,0]={buffer_slice[1,0,0].item():.6f}, Expected[1,0,0]={expected_k[1,0,0].item():.6f}")

# Debug: Print expected K value for block 0, layer 0 (to compare with actual loading)
print("\n=== DEBUG: Expected K values for block 0, layer 0 ===")
for c in prefill_captures:
    if c['layer_id'] == 0 and c['chunk_idx'] == 0:
        print(f"Captured K[0,0,0] for layer 0, chunk 0: {c['k'][0,0,0].item():.6f}")
        break
print(f"CPU cache K[0,0,0,0,0] for layer 0, block 0: {k_cache_cpu[0,0,0,0,0].item():.6f}")

# Debug: Compare CPU cache with captured prefill K/V
# (assumes chunk_idx maps to the same CPU block id; see the CPU block order check below)
print("\n=== DEBUG: Comparing CPU cache with captured prefill K/V ===")
for chunk_idx in [0, 7, 15]:  # Sample a few chunks
    for layer_id in [0, 17, 35]:  # Sample a few layers
        # Find captured K for this chunk and layer
        for c in prefill_captures:
            if c['layer_id'] == layer_id and c['chunk_idx'] == chunk_idx:
                captured_k = c['k']  # [seq_len, kv_heads, head_dim]
                cpu_cache_k = k_cache_cpu[layer_id, chunk_idx, :captured_k.shape[0]]
                diff = (captured_k - cpu_cache_k).abs().max().item()
                print(f"Chunk {chunk_idx}, Layer {layer_id}: captured vs CPU cache max_diff={diff:.6f}")
                break

# Debug: Get cpu_block_table to check order
kvcache_manager = llm.model_runner.kvcache_manager
# Find the sequence (it should still exist)
from nanovllm.engine.sequence import Sequence  # noqa: F401  (currently unused)
for attr_name in ['sequences', '_sequences', 'active_sequences']:
    if hasattr(kvcache_manager, attr_name):
        print(f"Found {attr_name}")
        break

# Try to get cpu_block_table through a different way
print("\n=== DEBUG: CPU block order ===")
# For each prefill capture, check which CPU block it ended up in
for chunk_idx in range(num_prefill_chunks):
    for c in prefill_captures:
        if c['layer_id'] == 0 and c['chunk_idx'] == chunk_idx:
            # Check if this chunk's K matches any CPU block
            captured_k_first = c['k'][0, 0, 0].item()
            for block_id in range(num_prefill_chunks):
                cpu_k_first = k_cache_cpu[0, block_id, 0, 0, 0].item()
                if abs(captured_k_first - cpu_k_first) < 1e-6:
                    print(f"Chunk {chunk_idx} -> CPU block {block_id}")
                    break
            break

# Debug: Check reference vs actual for decode steps 0 and 1
# Also compute partial references (prefill only, decode only) to isolate the bug
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
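

# Illustrative sketch (an assumption about the math, not the library implementation):
# merge_attention_outputs is expected to combine two partial attention results over
# disjoint KV sets exactly, using their log-sum-exp statistics. Shapes assumed here:
#   o1, o2:     [batch, seqlen_q, num_heads, head_dim]
#   lse1, lse2: [batch, num_heads, seqlen_q]
def merge_attention_outputs_reference(o1, lse1, o2, lse2):
    lse = torch.logaddexp(lse1, lse2)                         # combined log-normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [batch, seqlen_q, num_heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse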

for decode_step in [0, 1]:
    print(f"\n=== DEBUG: Reference vs Actual for layer 0, decode {decode_step} ===")
    layer_id = 0
    # Find the capture
    for c in decode_captures:
        if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
            q = c['q'].cuda()  # [1, num_heads, head_dim]
            q_batched = q.unsqueeze(0)  # [1, 1, num_heads, head_dim]

            # Build prefill K/V per-block for block-by-block reference
            prefill_k_blocks = []
            prefill_v_blocks = []
            for cidx in range(num_prefill_chunks):
                for pc in prefill_captures:
                    if pc['layer_id'] == layer_id and pc['chunk_idx'] == cidx:
                        prefill_k_blocks.append(pc['k'].cuda().unsqueeze(0))  # [1, block_size, kv_heads, head_dim]
                        prefill_v_blocks.append(pc['v'].cuda().unsqueeze(0))
                        break

            # Build decode K/V
            decode_k_list = []
            decode_v_list = []
            for step in range(decode_step + 1):
                for dc in decode_captures:
                    if dc['layer_id'] == layer_id and dc['decode_step'] == step:
                        decode_k_list.append(dc['k'].cuda())
                        decode_v_list.append(dc['v'].cuda())
                        break

            full_prefill_k = torch.cat([kb.squeeze(0) for kb in prefill_k_blocks], dim=0).unsqueeze(0)
            full_prefill_v = torch.cat([vb.squeeze(0) for vb in prefill_v_blocks], dim=0).unsqueeze(0)
            full_decode_k = torch.cat(decode_k_list, dim=0).unsqueeze(0)
            full_decode_v = torch.cat(decode_v_list, dim=0).unsqueeze(0)

            full_k = torch.cat([full_prefill_k, full_decode_k], dim=1)
            full_v = torch.cat([full_prefill_v, full_decode_v], dim=1)

            print(f"Q shape: {q_batched.shape}")
            print(f"Prefill K shape: {full_prefill_k.shape}")
            print(f"Decode K shape: {full_decode_k.shape}")
            print(f"Full K shape: {full_k.shape}")
            print(f"Total tokens: prefill={num_prefill_chunks * BLOCK_SIZE}, decode={decode_step + 1}")

            # Reference output (single attention over all)
            ref_output = flash_attn_func(
                q_batched, full_k, full_v,
                softmax_scale=scale,
                causal=False,
            )

            # Chunked reference: prefill attention + decode attention + merge
            prefill_o, prefill_lse = flash_attn_with_lse(
                q_batched, full_prefill_k, full_prefill_v,
                softmax_scale=scale,
                causal=False,
            )
            decode_o, decode_lse = flash_attn_with_lse(
                q_batched, full_decode_k, full_decode_v,
                softmax_scale=scale,
                causal=False,
            )
            chunked_output, _ = merge_attention_outputs(prefill_o, prefill_lse, decode_o, decode_lse)

            # Block-by-block reference (simulating ring buffer pipeline)
            block_o_acc, block_lse_acc = None, None
            for kb, vb in zip(prefill_k_blocks, prefill_v_blocks):
                o_blk, lse_blk = flash_attn_with_lse(q_batched, kb, vb, softmax_scale=scale, causal=False)
                if block_o_acc is None:
                    block_o_acc, block_lse_acc = o_blk, lse_blk
                else:
                    block_o_acc, block_lse_acc = merge_attention_outputs(block_o_acc, block_lse_acc, o_blk, lse_blk)

            # Compare block-by-block vs single
            block_vs_single_diff = (block_o_acc - prefill_o).abs().max().item()
            print(f"Block-by-block vs Single max_diff: {block_vs_single_diff:.6f}")

            # Compare full reference vs chunked reference
            ref_vs_chunked_diff = (ref_output - chunked_output).abs().max().item()
            print(f"Reference vs Chunked-reference max_diff: {ref_vs_chunked_diff:.6f}")

            ref_output = ref_output.squeeze(0).squeeze(0).cpu()
            chunked_output_cpu = chunked_output.squeeze(0).squeeze(0).cpu()

            # Actual output
            actual_output = c['output'].squeeze(0)
            if actual_output.dim() == 3:
                actual_output = actual_output.squeeze(0)

            diff_ref = (actual_output - ref_output).abs()
            diff_chunked = (actual_output - chunked_output_cpu).abs()
            print(f"Actual vs Reference max_diff: {diff_ref.max().item():.6f}")
            print(f"Actual vs Chunked-ref max_diff: {diff_chunked.max().item():.6f}")
            break
    print()

# Verify decode outputs
all_passed = True

for c in decode_captures:
    layer_id = c['layer_id']
    decode_step = c['decode_step']

    ref_output = compute_decode_reference(
        layer_id, decode_step, scale,
        k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
    )
    if ref_output is None:
        continue

    actual_output = c['output'].squeeze(0)
    if actual_output.dim() == 3:
        actual_output = actual_output.squeeze(0)

    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()

    passed = max_diff < 1e-1
    all_passed = all_passed and passed

    if not passed:
        print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")

print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")