# nano-vllm/tests/test_chunked_prefill_hook.py
"""
Correctness test for chunked prefill attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from random import randint, seed
from typing import Dict, List, Optional
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
BLOCK_SIZE = 1024
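# The checks below assume one prefill chunk per KV block, i.e. the prompt is processed
# in INPUT_LEN // BLOCK_SIZE = 16 chunks. NUM_GPU_BLOCKS is deliberately tiny,
# presumably so that earlier chunks' KV must be served from the CPU cache.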
# State - capture Q, K, V and output for each (layer, chunk)
captures: List[Dict] = []
def make_ones_injection_hook():
    """Inject Q=K=V=1.0 for deterministic testing."""
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return inputs
        q, k, v = inputs[0], inputs[1], inputs[2]
        q_ones = torch.ones_like(q)
        k_ones = torch.ones_like(k)
        v_ones = torch.ones_like(v)
        return (q_ones, k_ones, v_ones) + inputs[3:]
    return hook
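# Note: make_ones_injection_hook is only attached via the commented-out lines in the
# hook-registration loop below; enable it when debugging with constant Q/K/V inputs.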
def make_capture_hook(layer_id: int):
    """Capture Q, K, V and the attention output for each prefill chunk."""
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return
        q, k, v = inputs
        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        captures.append({
            'layer_id': layer_id,
            'chunk_idx': chunk_idx,
            'q': q.clone().cpu(),
            'k': k.clone().cpu(),
            'v': v.clone().cpu(),
            'output': output.clone().cpu(),
        })
    return hook
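# Captures are cloned to the CPU as they are taken, presumably so that storing one
# entry per (layer, chunk) does not compete with the KV cache for GPU memory.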
def compute_reference(layer_id: int, chunk_idx: int, scale: float,
                      k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                      block_size: int) -> Optional[torch.Tensor]:
    """
    Compute the reference output using the CPU KV cache and standard flash attention.
    Concatenates Q from all captured chunks 0..chunk_idx with K/V from the CPU cache,
    runs causal attention over the full prefix, then extracts the output rows for the
    current chunk.
    """
    # Get all captures for this layer up to chunk_idx, in chunk order
    layer_captures = [c for c in captures
                      if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
    layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
    if not layer_captures:
        return None
    # Collect Q from the captures, K/V from the CPU cache
    all_q = []
    all_k = []
    all_v = []
    chunk_lengths = []
    for c in layer_captures:
        cidx = c['chunk_idx']
        q = c['q'].cuda()  # [seqlen, nheads, headdim]
        all_q.append(q)
        chunk_lengths.append(q.shape[0])
        # Get K, V from the CPU cache (already offloaded during prefill)
        # CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
        k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        all_k.append(k)
        all_v.append(v)
    # Concatenate along the sequence dimension
    full_q = torch.cat(all_q, dim=0)
    full_k = torch.cat(all_k, dim=0)
    full_v = torch.cat(all_v, dim=0)
    total_len = full_q.shape[0]
    # Run standard causal flash attention over the whole prefix as one sequence
    cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
    full_o = flash_attn_varlen_func(
        full_q, full_k, full_v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=total_len,
        max_seqlen_k=total_len,
        softmax_scale=scale,
        causal=True,
    )
    # Extract the output rows for the current (last) chunk
    start_pos = sum(chunk_lengths[:-1])
    end_pos = sum(chunk_lengths)
    return full_o[start_pos:end_pos].cpu()
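# Note: compute_reference assumes CPU cache block `cidx` holds chunk `cidx`'s KV for
# every layer (blocks offloaded in chunk order); the `block_size` argument is
# currently unused.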
# ============================================================
# Main
# ============================================================
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)
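# enforce_eager=True keeps execution in eager mode so the Python forward hooks below
# fire on every attention call (they would not run inside captured CUDA graphs);
# dtype="float16" matches the 1e-1 comparison tolerance used at the end of the test.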
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5
# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook: inject all ones for Q, K, V
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)
# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)
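# max_tokens=1 limits generation to a single decode step; only the prefill path is
# under test here.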
# Remove hooks
for hook in hooks:
    hook.remove()
# Snapshot the CPU KV cache filled during prefill
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()
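# clone() takes a stable snapshot of the offloaded KV so that any later use of the
# engine's buffers cannot affect the reference computation.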
# Verify: compare actual output with reference computed from CPU cache
all_passed = True
num_chunks = INPUT_LEN // BLOCK_SIZE
for idx, c in enumerate(captures):
    layer_id = c['layer_id']
    chunk_idx = c['chunk_idx']
    # Skip chunk 0 (no previous KV to load)
    if chunk_idx == 0:
        continue
    ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
    if ref_output is None:
        continue
    actual_output = c['output']
    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()
    passed = max_diff < 1e-1  # float16 tolerance
    all_passed = all_passed and passed
    if not passed:
        print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
        __import__('pdb').set_trace()
print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")