Files
nano-vllm/tests/test_chunked_decode_hook.py

215 lines
6.7 KiB
Python

"""
Correctness test for chunked decode attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from random import randint, seed
from typing import Dict, List, Optional
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_func
# ------------------------------------------------------------
# Test configuration
# ------------------------------------------------------------
# Local checkpoint directory ("~" is expanded to the user's home).
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")

# Sizes are expressed in tokens.
BLOCK_SIZE = 1024
INPUT_LEN = 1024 * 16
MAX_MODEL_LEN = 1024 * 128

# Number of GPU KV-cache blocks handed to the engine.
NUM_GPU_BLOCKS = 2
# Number of tokens generated after the prompt.
NUM_DECODE_TOKENS = 5

# ------------------------------------------------------------
# Mutable capture state, appended to by the forward hooks
# ------------------------------------------------------------
decode_captures: List[Dict] = []
prefill_captures: List[Dict] = []
def make_ones_injection_hook():
    """Build a forward pre-hook that replaces Q, K and V with all-ones tensors.

    The returned callable takes ``(module, inputs)``. It substitutes the first
    three positional inputs with ``torch.ones_like`` copies and forwards any
    remaining positional inputs unchanged, so attention runs on deterministic
    values.
    """
    def hook(module, inputs):
        # Positions 0..2 are treated as Q, K, V respectively.
        ones_qkv = tuple(torch.ones_like(inputs[pos]) for pos in range(3))
        # Everything after the first three inputs is passed through as-is.
        return ones_qkv + inputs[3:]

    return hook
def make_capture_hook(layer_id: int):
    """Build a forward hook that records Q, K, V and the output of one layer.

    The returned callable takes ``(module, inputs, output)``. Each call
    appends a dict holding CPU copies of the tensors to ``prefill_captures``
    when the current context reports ``is_prefill``, and to
    ``decode_captures`` otherwise.

    Args:
        layer_id: Index stored under ``'layer_id'`` in every captured record.

    Returns:
        The hook function to register on the attention module.
    """
    def hook(module, inputs, output):
        ctx = get_context()
        # Index instead of unpacking so that extra positional inputs do not
        # raise; this mirrors how the ones-injection pre-hook reads inputs.
        q, k, v = inputs[0], inputs[1], inputs[2]
        tensors = {
            'q': q.clone().cpu(),
            'k': k.clone().cpu(),
            'v': v.clone().cpu(),
            'output': output.clone().cpu(),
        }
        if ctx.is_prefill:
            # Contexts without a chunk index are recorded as chunk 0.
            chunk_idx = getattr(ctx, 'current_chunk_idx', 0)
            prefill_captures.append({
                'layer_id': layer_id,
                'chunk_idx': chunk_idx,
                **tensors,
            })
        else:
            # The decode step is the number of decode captures this layer has
            # already recorded (assumes one hook call per layer per step —
            # TODO confirm against the model runner).
            decode_step = sum(
                1 for cap in decode_captures if cap['layer_id'] == layer_id
            )
            decode_captures.append({
                'layer_id': layer_id,
                'decode_step': decode_step,
                **tensors,
            })

    return hook
def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
                             k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                             block_size: int, num_prefill_chunks: int) -> Optional[torch.Tensor]:
    """
    Compute reference decode output using CPU KV cache and standard flash attention.

    For decode, query attends to:
    1. All prefill KV (from CPU cache)
    2. All previous decode tokens (from captured decode k, v)

    Args:
        layer_id: Layer whose captures and cache slices are used.
        decode_step: Decode step to build the reference for.
        scale: Softmax scale passed to flash attention.
        k_cache_cpu: Key cache, indexed as ``[layer_id, chunk_idx, :seq_len]``.
        v_cache_cpu: Value cache, indexed the same way as ``k_cache_cpu``.
        block_size: Unused; kept so existing callers keep working.
        num_prefill_chunks: Number of prefill chunk indices to look up.

    Returns:
        The reference attention output moved to CPU, or ``None`` when no
        decode capture exists for ``(layer_id, decode_step)``.
    """
    def find_capture(captures, key, value):
        # First capture of this layer whose `key` field equals `value`.
        return next(
            (cap for cap in captures
             if cap['layer_id'] == layer_id and cap[key] == value),
            None,
        )

    decode_cap = find_capture(decode_captures, 'decode_step', decode_step)
    if decode_cap is None:
        return None

    # Query: single decode token, expected [1, num_heads, head_dim];
    # add a batch dimension for flash attention.
    q_batched = decode_cap['q'].cuda().unsqueeze(0)

    all_k = []
    all_v = []

    # 1. Prefill chunks from CPU cache. The captured prefill Q only supplies
    #    the number of valid tokens in the chunk; chunks without a capture
    #    are skipped.
    for cidx in range(num_prefill_chunks):
        prefill_cap = find_capture(prefill_captures, 'chunk_idx', cidx)
        if prefill_cap is None:
            continue
        seq_len = prefill_cap['q'].shape[0]
        all_k.append(k_cache_cpu[layer_id, cidx, :seq_len].cuda())
        all_v.append(v_cache_cpu[layer_id, cidx, :seq_len].cuda())

    # 2. Decode tokens from captures (up to and including current step).
    #    The current step's capture exists (checked above), so the lists are
    #    never empty after this loop.
    for step in range(decode_step + 1):
        step_cap = find_capture(decode_captures, 'decode_step', step)
        if step_cap is not None:
            all_k.append(step_cap['k'].cuda())
            all_v.append(step_cap['v'].cuda())

    # Concatenate along the token axis and add a batch dimension:
    # expected [1, total_len, kv_heads, head_dim].
    full_k = torch.cat(all_k, dim=0).unsqueeze(0)
    full_v = torch.cat(all_v, dim=0).unsqueeze(0)

    # Run flash attention (non-causal since we explicitly control what KV to include)
    output = flash_attn_func(
        q_batched, full_k, full_v,
        softmax_scale=scale,
        causal=False,
    )
    return output.squeeze(0).squeeze(0).cpu()  # expected [num_heads, head_dim]
# ============================================================
# Main
# ============================================================
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5

# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook: inject all ones for Q, K, V
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)

# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
try:
    outputs = llm.generate(
        prompt_token_ids,
        SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS),
        use_tqdm=False,
    )
finally:
    # Remove hooks even if generation raises, so the model is left clean.
    for hook in hooks:
        hook.remove()

# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()

# Calculate number of prefill chunks. Ceiling division so that a partial
# final chunk is counted when INPUT_LEN is not a multiple of BLOCK_SIZE.
num_prefill_chunks = (INPUT_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE

# Verify decode outputs
all_passed = True
num_checked = 0
for c in decode_captures:
    layer_id = c['layer_id']
    decode_step = c['decode_step']
    ref_output = compute_decode_reference(
        layer_id, decode_step, scale,
        k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
    )
    if ref_output is None:
        continue
    num_checked += 1

    # Drop leading singleton dimensions so the shapes line up with the
    # reference output.
    actual_output = c['output'].squeeze(0)
    if actual_output.dim() == 3:
        actual_output = actual_output.squeeze(0)

    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()
    passed = max_diff < 1e-1
    all_passed = all_passed and passed
    # Only report comparisons that actually exceeded the tolerance.
    if not passed:
        print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")

# A run that compared nothing must not be reported as a pass.
if num_checked == 0:
    print("[FAIL] No decode outputs were verified")
    all_passed = False

print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")