[refactor] refactor test_align.py.
@@ -1,6 +1,6 @@
 """
 Test alignment between nanovllm and custom torch Qwen3 implementation.
-Compares attention layer outputs to verify correctness.
+Compares attention layer outputs and QKV tensors to verify correctness.
 """

 import os
@@ -14,88 +14,94 @@ from utils import generate_needle_prompt
 # Config
 MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
-INPUT_LEN = 512  # Use shorter length for alignment test
+INPUT_LEN = 64
 DTYPE = torch.float16

 # Storage for captured tensors
 nanovllm_outputs = {}
 torch_outputs = {}
+nanovllm_qkv = {}
+nanovllm_proj_inputs = {}  # Input to qkv_proj
+torch_proj_inputs = {}  # Input to q_proj


 def make_nanovllm_hook(layer_id: int, storage: dict):
     """Capture nanovllm self_attn outputs (after o_proj)."""
     def hook(module, inputs, output):
-        # Qwen3Attention output is a tuple (attn_output, None)
-        if isinstance(output, tuple):
-            attn_output = output[0]
-        else:
-            attn_output = output
-        # nanovllm shape: [num_tokens, hidden_size] -> add batch dim
+        attn_output = output[0] if isinstance(output, tuple) else output
         if attn_output.dim() == 2:
             attn_output = attn_output.unsqueeze(0)
         storage[layer_id] = attn_output.detach().clone()
     return hook


-def make_torch_hook(layer_id: int, storage: dict):
-    """Capture torch model self_attn outputs (after o_proj)."""
-    def hook(module, inputs, output):
-        # Qwen3Attention output is (attn_output, past_kv, qkv_dict)
-        attn_output, _, _ = output
-        storage[layer_id] = attn_output.detach().clone()
+def make_nanovllm_qkv_hook(layer_id: int, storage: dict):
+    def hook(module, inputs):
+        q, k, v = inputs[0], inputs[1], inputs[2]
+        storage[layer_id] = {
+            "q": q.detach().clone(),
+            "k": k.detach().clone(),
+            "v": v.detach().clone(),
+        }
     return hook


-def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-2):
-    """Compare two tensors and print statistics."""
-    # Handle shape differences
-    if t1.shape != t2.shape:
-        print(f"[{name}] Shape mismatch: {t1.shape} vs {t2.shape}")
-        # Try to reshape for comparison if possible
-        if t1.numel() == t2.numel():
-            t2 = t2.view(t1.shape)
-        else:
-            return False
+def make_proj_input_hook(layer_id: int, storage: dict):
+    """Capture input to projection layer (hidden_states after layernorm)."""
+    def hook(module, inputs):
+        # inputs[0] is hidden_states
+        hidden = inputs[0]
+        if hidden.dim() == 2:
+            hidden = hidden.unsqueeze(0)
+        storage[layer_id] = hidden.detach().clone()
+    return hook

-    diff = (t1.float() - t2.float()).abs()
-    max_diff = diff.max().item()
-    mean_diff = diff.mean().item()
-
-    passed = max_diff < atol
-    status = "PASS" if passed else "FAIL"
+def make_torch_hook(layer_id: int, storage: dict):
+    def hook(module, inputs, output):
+        storage[layer_id] = output[0].detach().clone()
+    return hook

-    print(f"[{name}] {status}")
-    print(f"  Shape: {list(t1.shape)}")
-    print(f"  t1 mean: {t1.float().mean():.6f}, std: {t1.float().std():.6f}")
-    print(f"  t2 mean: {t2.float().mean():.6f}, std: {t2.float().std():.6f}")
-    print(f"  Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}")
-
-    return passed
+def max_diff(t1: torch.Tensor, t2: torch.Tensor) -> float:
+    return (t1.float() - t2.float()).abs().max().item()


+def compute_qkv_diffs(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
+    """Compute Q, K, V max diffs. Returns (q_diff, k_diff, v_diff)."""
+    nano_q = nano_qkv["q"]
+    torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1)
+    q_diff = max_diff(nano_q, torch_q)
+
+    nano_k = nano_qkv["k"]
+    torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
+    k_diff = max_diff(nano_k, torch_k)
+
+    nano_v = nano_qkv["v"]
+    torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
+    v_diff = max_diff(nano_v, torch_v)
+
+    return q_diff, k_diff, v_diff


 # ============================================================
-# Load nanovllm model
+# Load models
 # ============================================================
 print("=" * 60)
 print("Loading nanovllm model...")
 print("=" * 60)

 llm = LLM(
     MODEL_PATH,
     enforce_eager=True,
     max_model_len=4096,
     max_num_batched_tokens=4096,
-    enable_cpu_offload=False,  # Disable offload for alignment test
+    enable_cpu_offload=False,
     dtype="float16",
 )

-# ============================================================
-# Load torch model
-# ============================================================
-print("\n" + "=" * 60)
-print("Loading custom torch model...")
-print("=" * 60)
+num_heads = llm.model_runner.model.model.layers[0].self_attn.num_heads
+num_kv_heads = llm.model_runner.model.model.layers[0].self_attn.num_kv_heads
+num_kv_groups = num_heads // num_kv_heads
+num_layers = len(llm.model_runner.model.model.layers)

+print("Loading torch model...")
 torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
 torch_model = torch_model.to("cuda")
 torch_model.eval()
@@ -103,110 +109,78 @@ torch_model.eval()
 # ============================================================
 # Generate test input
 # ============================================================
 print("\n" + "=" * 60)
 print("Generating test input...")
 print("=" * 60)

 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
-prompt, _ = generate_needle_prompt(
-    tokenizer=tokenizer,
-    target_length=INPUT_LEN,
-    verbose=True,
-)
+prompt, _ = generate_needle_prompt(tokenizer=tokenizer, target_length=INPUT_LEN, verbose=True)
 input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
 print(f"Input shape: {input_ids.shape}")

 # ============================================================
-# Register hooks on both models
+# Register hooks
 # ============================================================
 print("\n" + "=" * 60)
 print("Registering hooks...")
 print("=" * 60)

-# Hook on nanovllm (self_attn is Qwen3Attention, captures output after o_proj)
 nanovllm_hooks = []
 for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
-    if layer_idx >= 2:  # Only first 2 layers
-        break
-    nanovllm_hooks.append(
-        layer.self_attn.register_forward_hook(
-            make_nanovllm_hook(layer_idx, nanovllm_outputs)
-        )
-    )
-    print(f"  Registered nanovllm hook on layer {layer_idx} self_attn")
+    nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs)))
+    nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv)))
+    nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs)))

-# Hook on torch model (self_attn is Qwen3Attention, captures output after o_proj)
 torch_hooks = []
 for layer_idx, layer in enumerate(torch_model.model.layers):
-    if layer_idx >= 2:  # Only first 2 layers
-        break
-    torch_hooks.append(
-        layer.self_attn.register_forward_hook(
-            make_torch_hook(layer_idx, torch_outputs)
-        )
-    )
-    print(f"  Registered torch hook on layer {layer_idx} self_attn")
+    torch_hooks.append(layer.self_attn.register_forward_hook(make_torch_hook(layer_idx, torch_outputs)))
+    torch_hooks.append(layer.self_attn.q_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, torch_proj_inputs)))

 # ============================================================
-# Run nanovllm inference
+# Run inference
 # ============================================================
 print("\n" + "=" * 60)
 print("Running nanovllm inference...")
 print("=" * 60)
-
-# Use prompt_token_ids to ensure same input
-prompt_token_ids = input_ids[0].tolist()
-nanovllm_result = llm.generate(
-    [prompt_token_ids],
-    SamplingParams(temperature=0.01, max_tokens=1),  # Near-greedy for determinism
-    use_tqdm=False,
-)
+nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False)

 # ============================================================
 # Run torch inference
 # ============================================================
 print("\n" + "=" * 60)
 print("Running torch inference...")
 print("=" * 60)

 with torch.no_grad():
-    torch_logits, _, _ = torch_model(input_ids)
+    torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers)))

 # ============================================================
-# Compare outputs
+# Compare QKVO per layer (one line each)
 # ============================================================
-print("\n" + "=" * 60)
-print("Comparing attention outputs...")
-print("=" * 60)
+print("\n" + "=" * 82)
+print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}")
+print("=" * 82)

 all_passed = True
-for layer_idx in sorted(nanovllm_outputs.keys()):
-    if layer_idx not in torch_outputs:
-        print(f"[Layer {layer_idx}] Missing torch output")
-        all_passed = False
-        continue
+atol = 0.1

+for layer_idx in range(num_layers):
+    # Input diff (to qkv_proj / q_proj)
+    nano_in = nanovllm_proj_inputs[layer_idx]
+    torch_in = torch_proj_inputs[layer_idx]
+    if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
+        torch_in = torch_in.view(nano_in.shape)
+    i_diff = max_diff(nano_in, torch_in)
+
+    # QKV diffs
+    q_diff, k_diff, v_diff = compute_qkv_diffs(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
+
+    # O diff
     nano_out = nanovllm_outputs[layer_idx]
     torch_out = torch_outputs[layer_idx]
+    if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
+        torch_out = torch_out.view(nano_out.shape)
+    o_diff = max_diff(nano_out, torch_out)

-    print(f"\n--- Layer {layer_idx} ---")
-    passed = compare_tensors(f"Layer {layer_idx} attn_output", nano_out, torch_out, atol=0.1)
+    # Check pass/fail
+    passed = all(d < atol for d in [i_diff, q_diff, k_diff, v_diff, o_diff])
     all_passed = all_passed and passed
+    status = "" if passed else " *"
+
+    print(f"Layer {layer_idx:2d}{status:<3} {i_diff:>10.6f} {q_diff:>10.6f} {k_diff:>10.6f} {v_diff:>10.6f} {o_diff:>10.6f}")

 # ============================================================
-# Cleanup
+# Cleanup and result
 # ============================================================
-for hook in nanovllm_hooks:
-    hook.remove()
-for hook in torch_hooks:
+for hook in nanovllm_hooks + torch_hooks:
     hook.remove()

-# ============================================================
-# Result
-# ============================================================
-print("\n" + "=" * 60)
+print("=" * 82)
 if all_passed:
-    print("test_align: PASSED - nanovllm and torch outputs aligned!")
+    print("test_align: PASSED")
 else:
-    print("test_align: FAILED - outputs differ!")
-print("=" * 60)
+    print("test_align: FAILED (* = max_diff >= 0.1)")
tests/utils.py
@@ -55,7 +55,7 @@ def generate_needle_prompt(
     verbose: bool = True,
 ) -> Tuple[str, str]:
     """
-    Generate a needle-in-haystack prompt of approximately target_length tokens.
+    Generate a needle-in-haystack prompt of exactly target_length tokens.

     Args:
         tokenizer: HuggingFace tokenizer for length estimation
@@ -71,68 +71,79 @@ def generate_needle_prompt(
     # The needle sentence
     needle = f"The secret number you need to remember is {needle_value}. This is very important. "

-    # Question at the end
-    question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
+    # Question text
+    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
+        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
+    else:
+        question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"

-    # Estimate tokens for fixed parts
-    needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
-    question_text = "What is the secret number mentioned in the text above? Answer with just the number."
-    question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
-    # Buffer for chat template, special tokens, etc.
-    overhead_tokens = 100 if use_chat_template else 50
+    def build_prompt(haystack_parts, needle_idx):
+        """Build full prompt from haystack parts with needle inserted."""
+        parts = haystack_parts.copy()
+        parts.insert(needle_idx, needle)
+        full_text = "".join(parts)

-    # Available tokens for haystack
-    haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
-    if haystack_target_tokens < 100:
-        raise ValueError(f"target_length {target_length} is too short for needle test")
+        if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
+            messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}]
+            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        else:
+            return full_text + question_text

-    # Build haystack by repeating paragraphs
+    def count_tokens(prompt):
+        return len(tokenizer.encode(prompt, add_special_tokens=False))
+
+    def get_needle_idx(parts):
+        idx = int(len(parts) * needle_position)
+        return max(0, min(idx, len(parts)))
+
+    # Phase 1: Build haystack with full paragraphs until we exceed target
     haystack_parts = []
-    current_tokens = 0
     para_idx = 0

-    while current_tokens < haystack_target_tokens:
+    while True:
         para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
-        para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
-        if current_tokens + para_tokens > haystack_target_tokens:
+        test_parts = haystack_parts + [para]
+        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
+
+        if count_tokens(prompt) > target_length:
             break

         haystack_parts.append(para)
-        current_tokens += para_tokens
         para_idx += 1

-    # Calculate needle insertion point
-    needle_idx = int(len(haystack_parts) * needle_position)
-    needle_idx = max(0, min(needle_idx, len(haystack_parts)))
+        if para_idx > 10000:  # Safety limit
+            break

-    # Insert needle
-    haystack_parts.insert(needle_idx, needle)
+    # Phase 2: Fine-tune by adding words from next paragraph
+    next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
+    words = next_para.split()

-    # Assemble prompt
-    full_text = "".join(haystack_parts)
+    best_parts = haystack_parts.copy()
+    best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))

-    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
-        # Use chat template for instruct models
-        # For Qwen3, add /no_think to disable thinking mode
-        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
-        messages = [
-            {"role": "user", "content": f"{full_text}\n\n{question_text}"}
-        ]
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-    else:
-        # Raw text format for base models
-        question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
-        prompt = full_text + question
+    for i in range(1, len(words) + 1):
+        partial = " ".join(words[:i]) + " "
+        test_parts = haystack_parts + [partial]
+        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
+        token_count = count_tokens(prompt)
+        diff = abs(target_length - token_count)

-    # Verify length
-    actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
+        if diff < best_diff:
+            best_diff = diff
+            best_parts = test_parts.copy()
+
+        if token_count >= target_length:
+            break
+
+    haystack_parts = best_parts
+
+    # Final build
+    needle_idx = get_needle_idx(haystack_parts)
+    prompt = build_prompt(haystack_parts, needle_idx)
+
+    actual_tokens = count_tokens(prompt)
     if verbose:
-        print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
-        print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
-        print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
+        print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})")

     return prompt, needle_value