[refactor] refactor test_align.py.

Zijie Tian
2026-01-04 20:55:40 +08:00
parent 772313db8f
commit 24096431ed
2 changed files with 151 additions and 166 deletions

View File

@@ -1,6 +1,6 @@
""" """
Test alignment between nanovllm and custom torch Qwen3 implementation. Test alignment between nanovllm and custom torch Qwen3 implementation.
Compares attention layer outputs to verify correctness. Compares attention layer outputs and QKV tensors to verify correctness.
""" """
import os import os
@@ -14,88 +14,94 @@ from utils import generate_needle_prompt
# Config # Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 512 # Use shorter length for alignment test INPUT_LEN = 64
DTYPE = torch.float16 DTYPE = torch.float16
# Storage for captured tensors # Storage for captured tensors
nanovllm_outputs = {} nanovllm_outputs = {}
torch_outputs = {} torch_outputs = {}
nanovllm_qkv = {}
nanovllm_proj_inputs = {} # Input to qkv_proj
torch_proj_inputs = {} # Input to q_proj
def make_nanovllm_hook(layer_id: int, storage: dict): def make_nanovllm_hook(layer_id: int, storage: dict):
"""Capture nanovllm self_attn outputs (after o_proj)."""
def hook(module, inputs, output): def hook(module, inputs, output):
# Qwen3Attention output is a tuple (attn_output, None) attn_output = output[0] if isinstance(output, tuple) else output
if isinstance(output, tuple):
attn_output = output[0]
else:
attn_output = output
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
if attn_output.dim() == 2: if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0) attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone() storage[layer_id] = attn_output.detach().clone()
return hook return hook
def make_torch_hook(layer_id: int, storage: dict): def make_nanovllm_qkv_hook(layer_id: int, storage: dict):
"""Capture torch model self_attn outputs (after o_proj).""" def hook(module, inputs):
def hook(module, inputs, output): q, k, v = inputs[0], inputs[1], inputs[2]
# Qwen3Attention output is (attn_output, past_kv, qkv_dict) storage[layer_id] = {
attn_output, _, _ = output "q": q.detach().clone(),
storage[layer_id] = attn_output.detach().clone() "k": k.detach().clone(),
"v": v.detach().clone(),
}
return hook return hook
def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-2): def make_proj_input_hook(layer_id: int, storage: dict):
"""Compare two tensors and print statistics.""" """Capture input to projection layer (hidden_states after layernorm)."""
# Handle shape differences def hook(module, inputs):
if t1.shape != t2.shape: # inputs[0] is hidden_states
print(f"[{name}] Shape mismatch: {t1.shape} vs {t2.shape}") hidden = inputs[0]
# Try to reshape for comparison if possible if hidden.dim() == 2:
if t1.numel() == t2.numel(): hidden = hidden.unsqueeze(0)
t2 = t2.view(t1.shape) storage[layer_id] = hidden.detach().clone()
else: return hook
return False
diff = (t1.float() - t2.float()).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
passed = max_diff < atol def make_torch_hook(layer_id: int, storage: dict):
status = "PASS" if passed else "FAIL" def hook(module, inputs, output):
storage[layer_id] = output[0].detach().clone()
return hook
print(f"[{name}] {status}")
print(f" Shape: {list(t1.shape)}")
print(f" t1 mean: {t1.float().mean():.6f}, std: {t1.float().std():.6f}")
print(f" t2 mean: {t2.float().mean():.6f}, std: {t2.float().std():.6f}")
print(f" Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}")
return passed def max_diff(t1: torch.Tensor, t2: torch.Tensor) -> float:
return (t1.float() - t2.float()).abs().max().item()
def compute_qkv_diffs(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
"""Compute Q, K, V max diffs. Returns (q_diff, k_diff, v_diff)."""
nano_q = nano_qkv["q"]
torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1)
q_diff = max_diff(nano_q, torch_q)
nano_k = nano_qkv["k"]
torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
k_diff = max_diff(nano_k, torch_k)
nano_v = nano_qkv["v"]
torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
v_diff = max_diff(nano_v, torch_v)
return q_diff, k_diff, v_diff
# ============================================================ # ============================================================
# Load nanovllm model # Load models
# ============================================================ # ============================================================
print("=" * 60)
print("Loading nanovllm model...") print("Loading nanovllm model...")
print("=" * 60)
llm = LLM( llm = LLM(
MODEL_PATH, MODEL_PATH,
enforce_eager=True, enforce_eager=True,
max_model_len=4096, max_model_len=4096,
max_num_batched_tokens=4096, max_num_batched_tokens=4096,
enable_cpu_offload=False, # Disable offload for alignment test enable_cpu_offload=False,
dtype="float16", dtype="float16",
) )
# ============================================================ num_heads = llm.model_runner.model.model.layers[0].self_attn.num_heads
# Load torch model num_kv_heads = llm.model_runner.model.model.layers[0].self_attn.num_kv_heads
# ============================================================ num_kv_groups = num_heads // num_kv_heads
print("\n" + "=" * 60) num_layers = len(llm.model_runner.model.model.layers)
print("Loading custom torch model...")
print("=" * 60)
print("Loading torch model...")
torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE) torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
torch_model = torch_model.to("cuda") torch_model = torch_model.to("cuda")
torch_model.eval() torch_model.eval()
@@ -103,110 +109,78 @@ torch_model.eval()
# ============================================================ # ============================================================
# Generate test input # Generate test input
# ============================================================ # ============================================================
print("\n" + "=" * 60)
print("Generating test input...")
print("=" * 60)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
prompt, _ = generate_needle_prompt( prompt, _ = generate_needle_prompt(tokenizer=tokenizer, target_length=INPUT_LEN, verbose=True)
tokenizer=tokenizer,
target_length=INPUT_LEN,
verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda") input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}") print(f"Input shape: {input_ids.shape}")
# ============================================================ # ============================================================
# Register hooks on both models # Register hooks
# ============================================================ # ============================================================
print("\n" + "=" * 60)
print("Registering hooks...")
print("=" * 60)
# Hook on nanovllm (self_attn is Qwen3Attention, captures output after o_proj)
nanovllm_hooks = [] nanovllm_hooks = []
for layer_idx, layer in enumerate(llm.model_runner.model.model.layers): for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
if layer_idx >= 2: # Only first 2 layers nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs)))
break nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv)))
nanovllm_hooks.append( nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs)))
layer.self_attn.register_forward_hook(
make_nanovllm_hook(layer_idx, nanovllm_outputs)
)
)
print(f" Registered nanovllm hook on layer {layer_idx} self_attn")
# Hook on torch model (self_attn is Qwen3Attention, captures output after o_proj)
torch_hooks = [] torch_hooks = []
for layer_idx, layer in enumerate(torch_model.model.layers): for layer_idx, layer in enumerate(torch_model.model.layers):
if layer_idx >= 2: # Only first 2 layers torch_hooks.append(layer.self_attn.register_forward_hook(make_torch_hook(layer_idx, torch_outputs)))
break torch_hooks.append(layer.self_attn.q_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, torch_proj_inputs)))
torch_hooks.append(
layer.self_attn.register_forward_hook(
make_torch_hook(layer_idx, torch_outputs)
)
)
print(f" Registered torch hook on layer {layer_idx} self_attn")
# ============================================================ # ============================================================
# Run nanovllm inference # Run inference
# ============================================================ # ============================================================
print("\n" + "=" * 60)
print("Running nanovllm inference...") print("Running nanovllm inference...")
print("=" * 60) nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False)
# Use prompt_token_ids to ensure same input
prompt_token_ids = input_ids[0].tolist()
nanovllm_result = llm.generate(
[prompt_token_ids],
SamplingParams(temperature=0.01, max_tokens=1), # Near-greedy for determinism
use_tqdm=False,
)
# ============================================================
# Run torch inference
# ============================================================
print("\n" + "=" * 60)
print("Running torch inference...") print("Running torch inference...")
print("=" * 60)
with torch.no_grad(): with torch.no_grad():
torch_logits, _, _ = torch_model(input_ids) torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers)))
# ============================================================ # ============================================================
# Compare outputs # Compare QKVO per layer (one line each)
# ============================================================ # ============================================================
print("\n" + "=" * 60) print("\n" + "=" * 82)
print("Comparing attention outputs...") print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}")
print("=" * 60) print("=" * 82)
all_passed = True all_passed = True
for layer_idx in sorted(nanovllm_outputs.keys()): atol = 0.1
if layer_idx not in torch_outputs:
print(f"[Layer {layer_idx}] Missing torch output")
all_passed = False
continue
for layer_idx in range(num_layers):
# Input diff (to qkv_proj / q_proj)
nano_in = nanovllm_proj_inputs[layer_idx]
torch_in = torch_proj_inputs[layer_idx]
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
torch_in = torch_in.view(nano_in.shape)
i_diff = max_diff(nano_in, torch_in)
# QKV diffs
q_diff, k_diff, v_diff = compute_qkv_diffs(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
# O diff
nano_out = nanovllm_outputs[layer_idx] nano_out = nanovllm_outputs[layer_idx]
torch_out = torch_outputs[layer_idx] torch_out = torch_outputs[layer_idx]
if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
torch_out = torch_out.view(nano_out.shape)
o_diff = max_diff(nano_out, torch_out)
print(f"\n--- Layer {layer_idx} ---") # Check pass/fail
passed = compare_tensors(f"Layer {layer_idx} attn_output", nano_out, torch_out, atol=0.1) passed = all(d < atol for d in [i_diff, q_diff, k_diff, v_diff, o_diff])
all_passed = all_passed and passed all_passed = all_passed and passed
status = "" if passed else " *"
print(f"Layer {layer_idx:2d}{status:<3} {i_diff:>10.6f} {q_diff:>10.6f} {k_diff:>10.6f} {v_diff:>10.6f} {o_diff:>10.6f}")
# ============================================================ # ============================================================
# Cleanup # Cleanup and result
# ============================================================ # ============================================================
for hook in nanovllm_hooks: for hook in nanovllm_hooks + torch_hooks:
hook.remove()
for hook in torch_hooks:
hook.remove() hook.remove()
# ============================================================ print("=" * 82)
# Result
# ============================================================
print("\n" + "=" * 60)
if all_passed: if all_passed:
print("test_align: PASSED - nanovllm and torch outputs aligned!") print("test_align: PASSED")
else: else:
print("test_align: FAILED - outputs differ!") print("test_align: FAILED (* = max_diff >= 0.1)")
print("=" * 60)

View File

@@ -55,7 +55,7 @@ def generate_needle_prompt(
     verbose: bool = True,
 ) -> Tuple[str, str]:
     """
-    Generate a needle-in-haystack prompt of approximately target_length tokens.
+    Generate a needle-in-haystack prompt of exactly target_length tokens.
 
     Args:
         tokenizer: HuggingFace tokenizer for length estimation
@@ -71,68 +71,79 @@ def generate_needle_prompt(
     # The needle sentence
     needle = f"The secret number you need to remember is {needle_value}. This is very important. "
 
-    # Question at the end
-    question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
+    # Question text
+    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
+        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
+    else:
+        question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
 
-    # Estimate tokens for fixed parts
-    needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
-    question_text = "What is the secret number mentioned in the text above? Answer with just the number."
-    question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
-    # Buffer for chat template, special tokens, etc.
-    overhead_tokens = 100 if use_chat_template else 50
-    # Available tokens for haystack
-    haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
-    if haystack_target_tokens < 100:
-        raise ValueError(f"target_length {target_length} is too short for needle test")
+    def build_prompt(haystack_parts, needle_idx):
+        """Build full prompt from haystack parts with needle inserted."""
+        parts = haystack_parts.copy()
+        parts.insert(needle_idx, needle)
+        full_text = "".join(parts)
+        if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
+            messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}]
+            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        else:
+            return full_text + question_text
 
-    # Build haystack by repeating paragraphs
+    def count_tokens(prompt):
+        return len(tokenizer.encode(prompt, add_special_tokens=False))
+
+    def get_needle_idx(parts):
+        idx = int(len(parts) * needle_position)
+        return max(0, min(idx, len(parts)))
+
+    # Phase 1: Build haystack with full paragraphs until we exceed target
     haystack_parts = []
-    current_tokens = 0
     para_idx = 0
-    while current_tokens < haystack_target_tokens:
+    while True:
         para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
-        para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
-        if current_tokens + para_tokens > haystack_target_tokens:
+        test_parts = haystack_parts + [para]
+        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
+        if count_tokens(prompt) > target_length:
             break
         haystack_parts.append(para)
-        current_tokens += para_tokens
         para_idx += 1
+        if para_idx > 10000:  # Safety limit
+            break
 
-    # Calculate needle insertion point
-    needle_idx = int(len(haystack_parts) * needle_position)
-    needle_idx = max(0, min(needle_idx, len(haystack_parts)))
-
-    # Insert needle
-    haystack_parts.insert(needle_idx, needle)
-
-    # Assemble prompt
-    full_text = "".join(haystack_parts)
-    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
-        # Use chat template for instruct models
-        # For Qwen3, add /no_think to disable thinking mode
-        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
-        messages = [
-            {"role": "user", "content": f"{full_text}\n\n{question_text}"}
-        ]
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-    else:
-        # Raw text format for base models
-        question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
-        prompt = full_text + question
+    # Phase 2: Fine-tune by adding words from next paragraph
+    next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
+    words = next_para.split()
+    best_parts = haystack_parts.copy()
+    best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))
+    for i in range(1, len(words) + 1):
+        partial = " ".join(words[:i]) + " "
+        test_parts = haystack_parts + [partial]
+        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
+        token_count = count_tokens(prompt)
+        diff = abs(target_length - token_count)
+        if diff < best_diff:
+            best_diff = diff
+            best_parts = test_parts.copy()
+        if token_count >= target_length:
+            break
+    haystack_parts = best_parts
 
-    # Verify length
-    actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
+    # Final build
+    needle_idx = get_needle_idx(haystack_parts)
+    prompt = build_prompt(haystack_parts, needle_idx)
+    actual_tokens = count_tokens(prompt)
     if verbose:
-        print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
-        print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
-        print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
+        print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})")
 
     return prompt, needle_value
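Side note (not part of the commit): a usage sketch of the refactored generate_needle_prompt, under the assumptions that it is exposed by a local utils module and that a Qwen3 tokenizer is available at ~/models/Qwen3-0.6B/ (adjust names and paths to your setup). With the two-phase fill, whole paragraphs first and then word-by-word, the encoded prompt length should land on or very close to target_length:

import os
from transformers import AutoTokenizer
from utils import generate_needle_prompt  # local test helper, module name assumed

tokenizer = AutoTokenizer.from_pretrained(os.path.expanduser("~/models/Qwen3-0.6B/"))
prompt, needle_value = generate_needle_prompt(tokenizer=tokenizer, target_length=64, verbose=True)

# The needle answer is returned alongside the prompt so the caller can check recall later.
actual = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"tokens={actual}, needle={needle_value}")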