"""
|
|
Correctness test for chunked prefill attention.
|
|
|
|
Captures Q and output during inference, then computes reference using
|
|
CPU KV cache with standard flash attention.
|
|
"""
|
|
|
|
import os
|
|
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
|
|
|
import torch
|
|
from random import randint, seed
|
|
from typing import Dict, List
|
|
from nanovllm import LLM, SamplingParams
|
|
from nanovllm.utils.context import get_context
|
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
|
|
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
BLOCK_SIZE = 1024
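# With INPUT_LEN = 16 * 1024 and BLOCK_SIZE = 1024 the prompt prefills in 16 chunks.
# NUM_GPU_BLOCKS = 2 is presumably chosen so that only a couple of KV blocks stay
# resident on the GPU, forcing earlier chunks to be offloaded to (and read back from)
# the CPU cache during chunked prefill; this is an assumption about the offload
# scheduler, not something asserted by this script.
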
# State - capture Q and output for each (layer, chunk).
# Each entry holds 'layer_id', 'chunk_idx', plus CPU copies of 'q', 'k', 'v' and the
# attention 'output' for that (layer, chunk) attention call.
captures: List[Dict] = []


def make_ones_injection_hook():
    """Inject Q=K=V=1.0 for deterministic testing."""
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return inputs

        q, k, v = inputs[0], inputs[1], inputs[2]
        q_ones = torch.ones_like(q)
        k_ones = torch.ones_like(k)
        v_ones = torch.ones_like(v)
        return (q_ones, k_ones, v_ones) + inputs[3:]
    return hook

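# Usage sketch (kept disabled in the hook-registration loop below): registering this
# as a forward pre-hook on each attention module, e.g.
#     attn.register_forward_pre_hook(make_ones_injection_hook())
# replaces Q/K/V with all-ones tensors, so a mismatch can be reproduced with fully
# deterministic inputs.
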
def make_capture_hook(layer_id: int):
    """Capture Q, K, V and the attention output during prefill."""
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        q, k, v = inputs
        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0

        captures.append({
            'layer_id': layer_id,
            'chunk_idx': chunk_idx,
            'q': q.clone().cpu(),
            'k': k.clone().cpu(),
            'v': v.clone().cpu(),
            'output': output.clone().cpu(),
        })
    return hook


def compute_reference(layer_id: int, chunk_idx: int, scale: float,
                      k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                      block_size: int) -> Optional[torch.Tensor]:
    """
    Compute reference output using CPU KV cache and standard flash attention.

    Concatenates Q from the captured chunks and K/V from the CPU cache for
    chunks 0..chunk_idx, runs full causal attention, then extracts the rows
    belonging to the current chunk.
    """
    # Get all captures for this layer up to chunk_idx
    layer_captures = [c for c in captures
                      if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
    layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])

    if not layer_captures:
        return None

    # Collect Q from captures, K/V from CPU cache
    all_q = []
    all_k = []
    all_v = []
    chunk_lengths = []

    for c in layer_captures:
        cidx = c['chunk_idx']
        q = c['q'].cuda()  # [seqlen, nheads, headdim]
        all_q.append(q)
        chunk_lengths.append(q.shape[0])

        # Get K, V from CPU cache (already offloaded during prefill)
        # CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
        k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        all_k.append(k)
        all_v.append(v)

    # Concatenate
    full_q = torch.cat(all_q, dim=0)
    full_k = torch.cat(all_k, dim=0)
    full_v = torch.cat(all_v, dim=0)
    total_len = full_q.shape[0]

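    # Why this is a valid reference: with causal masking, the rows belonging to the
    # last chunk attend to every earlier position plus the causal prefix of their own
    # chunk, which is exactly what chunked prefill computes for that chunk against the
    # cached K/V. Recomputing the earlier chunks' rows here is redundant but harmless;
    # they are sliced away below.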
    # Run standard causal flash attention
    cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
    full_o = flash_attn_varlen_func(
        full_q, full_k, full_v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=total_len,
        max_seqlen_k=total_len,
        softmax_scale=scale,
        causal=True,
    )

    # Extract output for current chunk
    start_pos = sum(chunk_lengths[:-1])
    end_pos = sum(chunk_lengths)
    return full_o[start_pos:end_pos].cpu()


# ============================================================
# Main
# ============================================================

llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)
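# enforce_eager=True matters for this test: if the forward pass were captured into
# CUDA graphs, the Python forward hooks registered below would not fire on replayed
# steps. (This is the assumed motivation for the flag here, not something the script
# checks.)
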
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5  # standard 1/sqrt(head_dim) softmax scaling

# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook: inject all ones for Q, K, V (disabled; enable for deterministic debugging)
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)

# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)
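# max_tokens=1 limits decode to a single step; the hooks additionally skip any
# non-prefill call via ctx.is_prefill, so every capture corresponds to a prefill chunk.
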
# Remove hooks
for hook in hooks:
    hook.remove()

# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()
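# compute_reference() indexes the CPU cache as [layer, chunk_idx, :chunk_len], which
# assumes the blocks for this single prompt were allocated sequentially starting at
# block 0 (i.e. chunk i lives in CPU block i). That should hold for a one-sequence
# test like this, but it is an assumption about the offload engine, not a checked
# invariant.
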
# Verify: compare actual output with reference computed from CPU cache
all_passed = True
num_chunks = INPUT_LEN // BLOCK_SIZE  # expected number of prefill chunks (16)
for idx, c in enumerate(captures):
    layer_id = c['layer_id']
    chunk_idx = c['chunk_idx']

    # Skip chunk 0 (no previous KV to load)
    if chunk_idx == 0:
        continue

    ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
    if ref_output is None:
        continue

    actual_output = c['output']
    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()
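    # Tolerance note: the chunked kernel and the single-shot reference accumulate the
    # float16 softmax in different orders, so bitwise equality is not expected; 1e-1 on
    # the max absolute difference is a loose, assumed-reasonable bound for fp16 outputs.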
    passed = max_diff < 1e-1  # float16 tolerance
    all_passed = all_passed and passed

    if not passed:
        print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
        __import__('pdb').set_trace()  # drop into the debugger on the first mismatch
print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")
|