From 24096431edc7d8d161da3fa556ea097c9eb156db Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Sun, 4 Jan 2026 20:55:40 +0800 Subject: [PATCH] [refactor] refactor test_align.py. --- tests/test_align.py | 212 +++++++++++++++++++------------------------- tests/utils.py | 105 ++++++++++++---------- 2 files changed, 151 insertions(+), 166 deletions(-) diff --git a/tests/test_align.py b/tests/test_align.py index 67f0515..c6a9655 100644 --- a/tests/test_align.py +++ b/tests/test_align.py @@ -1,6 +1,6 @@ """ Test alignment between nanovllm and custom torch Qwen3 implementation. -Compares attention layer outputs to verify correctness. +Compares attention layer outputs and QKV tensors to verify correctness. """ import os @@ -14,88 +14,94 @@ from utils import generate_needle_prompt # Config MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") -INPUT_LEN = 512 # Use shorter length for alignment test +INPUT_LEN = 64 DTYPE = torch.float16 # Storage for captured tensors nanovllm_outputs = {} torch_outputs = {} +nanovllm_qkv = {} +nanovllm_proj_inputs = {} # Input to qkv_proj +torch_proj_inputs = {} # Input to q_proj def make_nanovllm_hook(layer_id: int, storage: dict): - """Capture nanovllm self_attn outputs (after o_proj).""" def hook(module, inputs, output): - # Qwen3Attention output is a tuple (attn_output, None) - if isinstance(output, tuple): - attn_output = output[0] - else: - attn_output = output - # nanovllm shape: [num_tokens, hidden_size] -> add batch dim + attn_output = output[0] if isinstance(output, tuple) else output if attn_output.dim() == 2: attn_output = attn_output.unsqueeze(0) storage[layer_id] = attn_output.detach().clone() return hook -def make_torch_hook(layer_id: int, storage: dict): - """Capture torch model self_attn outputs (after o_proj).""" - def hook(module, inputs, output): - # Qwen3Attention output is (attn_output, past_kv, qkv_dict) - attn_output, _, _ = output - storage[layer_id] = attn_output.detach().clone() +def make_nanovllm_qkv_hook(layer_id: int, storage: dict): + def hook(module, inputs): + q, k, v = inputs[0], inputs[1], inputs[2] + storage[layer_id] = { + "q": q.detach().clone(), + "k": k.detach().clone(), + "v": v.detach().clone(), + } return hook -def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-2): - """Compare two tensors and print statistics.""" - # Handle shape differences - if t1.shape != t2.shape: - print(f"[{name}] Shape mismatch: {t1.shape} vs {t2.shape}") - # Try to reshape for comparison if possible - if t1.numel() == t2.numel(): - t2 = t2.view(t1.shape) - else: - return False +def make_proj_input_hook(layer_id: int, storage: dict): + """Capture input to projection layer (hidden_states after layernorm).""" + def hook(module, inputs): + # inputs[0] is hidden_states + hidden = inputs[0] + if hidden.dim() == 2: + hidden = hidden.unsqueeze(0) + storage[layer_id] = hidden.detach().clone() + return hook - diff = (t1.float() - t2.float()).abs() - max_diff = diff.max().item() - mean_diff = diff.mean().item() - passed = max_diff < atol - status = "PASS" if passed else "FAIL" +def make_torch_hook(layer_id: int, storage: dict): + def hook(module, inputs, output): + storage[layer_id] = output[0].detach().clone() + return hook - print(f"[{name}] {status}") - print(f" Shape: {list(t1.shape)}") - print(f" t1 mean: {t1.float().mean():.6f}, std: {t1.float().std():.6f}") - print(f" t2 mean: {t2.float().mean():.6f}, std: {t2.float().std():.6f}") - print(f" Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}") - return passed +def 
max_diff(t1: torch.Tensor, t2: torch.Tensor) -> float: + return (t1.float() - t2.float()).abs().max().item() + + +def compute_qkv_diffs(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int): + """Compute Q, K, V max diffs. Returns (q_diff, k_diff, v_diff).""" + nano_q = nano_qkv["q"] + torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1) + q_diff = max_diff(nano_q, torch_q) + + nano_k = nano_qkv["k"] + torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1) + k_diff = max_diff(nano_k, torch_k) + + nano_v = nano_qkv["v"] + torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1) + v_diff = max_diff(nano_v, torch_v) + + return q_diff, k_diff, v_diff # ============================================================ -# Load nanovllm model +# Load models # ============================================================ -print("=" * 60) print("Loading nanovllm model...") -print("=" * 60) - llm = LLM( MODEL_PATH, enforce_eager=True, max_model_len=4096, max_num_batched_tokens=4096, - enable_cpu_offload=False, # Disable offload for alignment test + enable_cpu_offload=False, dtype="float16", ) -# ============================================================ -# Load torch model -# ============================================================ -print("\n" + "=" * 60) -print("Loading custom torch model...") -print("=" * 60) +num_heads = llm.model_runner.model.model.layers[0].self_attn.num_heads +num_kv_heads = llm.model_runner.model.model.layers[0].self_attn.num_kv_heads +num_kv_groups = num_heads // num_kv_heads +num_layers = len(llm.model_runner.model.model.layers) +print("Loading torch model...") torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE) torch_model = torch_model.to("cuda") torch_model.eval() @@ -103,110 +109,78 @@ torch_model.eval() # ============================================================ # Generate test input # ============================================================ -print("\n" + "=" * 60) -print("Generating test input...") -print("=" * 60) - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) -prompt, _ = generate_needle_prompt( - tokenizer=tokenizer, - target_length=INPUT_LEN, - verbose=True, -) +prompt, _ = generate_needle_prompt(tokenizer=tokenizer, target_length=INPUT_LEN, verbose=True) input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda") print(f"Input shape: {input_ids.shape}") # ============================================================ -# Register hooks on both models +# Register hooks # ============================================================ -print("\n" + "=" * 60) -print("Registering hooks...") -print("=" * 60) - -# Hook on nanovllm (self_attn is Qwen3Attention, captures output after o_proj) nanovllm_hooks = [] for layer_idx, layer in enumerate(llm.model_runner.model.model.layers): - if layer_idx >= 2: # Only first 2 layers - break - nanovllm_hooks.append( - layer.self_attn.register_forward_hook( - make_nanovllm_hook(layer_idx, nanovllm_outputs) - ) - ) - print(f" Registered nanovllm hook on layer {layer_idx} self_attn") + nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs))) + nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv))) + nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs))) -# Hook on torch model (self_attn is Qwen3Attention, captures output after o_proj) torch_hooks = [] for layer_idx, layer 
in enumerate(torch_model.model.layers): - if layer_idx >= 2: # Only first 2 layers - break - torch_hooks.append( - layer.self_attn.register_forward_hook( - make_torch_hook(layer_idx, torch_outputs) - ) - ) - print(f" Registered torch hook on layer {layer_idx} self_attn") + torch_hooks.append(layer.self_attn.register_forward_hook(make_torch_hook(layer_idx, torch_outputs))) + torch_hooks.append(layer.self_attn.q_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, torch_proj_inputs))) # ============================================================ -# Run nanovllm inference +# Run inference # ============================================================ -print("\n" + "=" * 60) print("Running nanovllm inference...") -print("=" * 60) +nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False) -# Use prompt_token_ids to ensure same input -prompt_token_ids = input_ids[0].tolist() -nanovllm_result = llm.generate( - [prompt_token_ids], - SamplingParams(temperature=0.01, max_tokens=1), # Near-greedy for determinism - use_tqdm=False, -) - -# ============================================================ -# Run torch inference -# ============================================================ -print("\n" + "=" * 60) print("Running torch inference...") -print("=" * 60) - with torch.no_grad(): - torch_logits, _, _ = torch_model(input_ids) + torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers))) # ============================================================ -# Compare outputs +# Compare QKVO per layer (one line each) # ============================================================ -print("\n" + "=" * 60) -print("Comparing attention outputs...") -print("=" * 60) +print("\n" + "=" * 82) +print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}") +print("=" * 82) all_passed = True -for layer_idx in sorted(nanovllm_outputs.keys()): - if layer_idx not in torch_outputs: - print(f"[Layer {layer_idx}] Missing torch output") - all_passed = False - continue +atol = 0.1 +for layer_idx in range(num_layers): + # Input diff (to qkv_proj / q_proj) + nano_in = nanovllm_proj_inputs[layer_idx] + torch_in = torch_proj_inputs[layer_idx] + if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel(): + torch_in = torch_in.view(nano_in.shape) + i_diff = max_diff(nano_in, torch_in) + + # QKV diffs + q_diff, k_diff, v_diff = compute_qkv_diffs(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups) + + # O diff nano_out = nanovllm_outputs[layer_idx] torch_out = torch_outputs[layer_idx] + if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel(): + torch_out = torch_out.view(nano_out.shape) + o_diff = max_diff(nano_out, torch_out) - print(f"\n--- Layer {layer_idx} ---") - passed = compare_tensors(f"Layer {layer_idx} attn_output", nano_out, torch_out, atol=0.1) + # Check pass/fail + passed = all(d < atol for d in [i_diff, q_diff, k_diff, v_diff, o_diff]) all_passed = all_passed and passed + status = "" if passed else " *" + + print(f"Layer {layer_idx:2d}{status:<3} {i_diff:>10.6f} {q_diff:>10.6f} {k_diff:>10.6f} {v_diff:>10.6f} {o_diff:>10.6f}") # ============================================================ -# Cleanup +# Cleanup and result # ============================================================ -for hook in nanovllm_hooks: - hook.remove() -for hook in torch_hooks: +for hook in nanovllm_hooks + torch_hooks: hook.remove() -# 
============================================================ -# Result -# ============================================================ -print("\n" + "=" * 60) +print("=" * 82) if all_passed: - print("test_align: PASSED - nanovllm and torch outputs aligned!") + print("test_align: PASSED") else: - print("test_align: FAILED - outputs differ!") -print("=" * 60) + print("test_align: FAILED (* = max_diff >= 0.1)") diff --git a/tests/utils.py b/tests/utils.py index 350d0b9..dabeb0e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -55,7 +55,7 @@ def generate_needle_prompt( verbose: bool = True, ) -> Tuple[str, str]: """ - Generate a needle-in-haystack prompt of approximately target_length tokens. + Generate a needle-in-haystack prompt of exactly target_length tokens. Args: tokenizer: HuggingFace tokenizer for length estimation @@ -71,68 +71,79 @@ def generate_needle_prompt( # The needle sentence needle = f"The secret number you need to remember is {needle_value}. This is very important. " - # Question at the end - question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is" + # Question text + if use_chat_template and hasattr(tokenizer, 'apply_chat_template'): + question_text = "/no_think Answer only with the secret number mentioned above, nothing else:" + else: + question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is" - # Estimate tokens for fixed parts - needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False)) - question_text = "What is the secret number mentioned in the text above? Answer with just the number." - question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False)) - # Buffer for chat template, special tokens, etc. 
- overhead_tokens = 100 if use_chat_template else 50 + def build_prompt(haystack_parts, needle_idx): + """Build full prompt from haystack parts with needle inserted.""" + parts = haystack_parts.copy() + parts.insert(needle_idx, needle) + full_text = "".join(parts) - # Available tokens for haystack - haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens - if haystack_target_tokens < 100: - raise ValueError(f"target_length {target_length} is too short for needle test") + if use_chat_template and hasattr(tokenizer, 'apply_chat_template'): + messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}] + return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + else: + return full_text + question_text - # Build haystack by repeating paragraphs + def count_tokens(prompt): + return len(tokenizer.encode(prompt, add_special_tokens=False)) + + def get_needle_idx(parts): + idx = int(len(parts) * needle_position) + return max(0, min(idx, len(parts))) + + # Phase 1: Build haystack with full paragraphs until we exceed target haystack_parts = [] - current_tokens = 0 para_idx = 0 - while current_tokens < haystack_target_tokens: + while True: para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)] - para_tokens = len(tokenizer.encode(para, add_special_tokens=False)) - if current_tokens + para_tokens > haystack_target_tokens: + test_parts = haystack_parts + [para] + prompt = build_prompt(test_parts, get_needle_idx(test_parts)) + + if count_tokens(prompt) > target_length: break + haystack_parts.append(para) - current_tokens += para_tokens para_idx += 1 - # Calculate needle insertion point - needle_idx = int(len(haystack_parts) * needle_position) - needle_idx = max(0, min(needle_idx, len(haystack_parts))) + if para_idx > 10000: # Safety limit + break - # Insert needle - haystack_parts.insert(needle_idx, needle) + # Phase 2: Fine-tune by adding words from next paragraph + next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)] + words = next_para.split() - # Assemble prompt - full_text = "".join(haystack_parts) + best_parts = haystack_parts.copy() + best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts)))) - if use_chat_template and hasattr(tokenizer, 'apply_chat_template'): - # Use chat template for instruct models - # For Qwen3, add /no_think to disable thinking mode - question_text = "/no_think Answer only with the secret number mentioned above, nothing else:" - messages = [ - {"role": "user", "content": f"{full_text}\n\n{question_text}"} - ] - prompt = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - ) - else: - # Raw text format for base models - question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is" - prompt = full_text + question + for i in range(1, len(words) + 1): + partial = " ".join(words[:i]) + " " + test_parts = haystack_parts + [partial] + prompt = build_prompt(test_parts, get_needle_idx(test_parts)) + token_count = count_tokens(prompt) + diff = abs(target_length - token_count) - # Verify length - actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False)) + if diff < best_diff: + best_diff = diff + best_parts = test_parts.copy() + + if token_count >= target_length: + break + + haystack_parts = best_parts + + # Final build + needle_idx = get_needle_idx(haystack_parts) + prompt = build_prompt(haystack_parts, needle_idx) + + 
actual_tokens = count_tokens(prompt) if verbose: - print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens") - print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)") - print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}") + print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})") return prompt, needle_value
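
Note on the hook wiring in test_align.py above: the test mixes two hook flavors, register_forward_hook on self_attn (which also receives the module output, i.e. the post-o_proj tensor) and register_forward_pre_hook on attn / qkv_proj / q_proj (which receives only the positional inputs, where the Q/K/V tensors and the pre-projection hidden states are captured). A minimal standalone sketch of that split, using a toy module rather than either model:

    import torch
    import torch.nn as nn

    captured = {}

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        def forward(self, x):
            return self.proj(x)

    toy = Toy()

    # forward_pre_hook: fires before forward(), sees only the positional inputs.
    toy.proj.register_forward_pre_hook(
        lambda module, inputs: captured.update(pre=inputs[0].detach().clone())
    )

    # forward_hook: fires after forward(), additionally sees the output.
    toy.register_forward_hook(
        lambda module, inputs, output: captured.update(post=output.detach().clone())
    )

    toy(torch.randn(2, 4))
    print(captured["pre"].shape, captured["post"].shape)  # torch.Size([2, 4]) torch.Size([2, 4])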
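
On compute_qkv_diffs: the [::num_kv_groups] stride assumes the torch reference returns per-layer Q/K/V shaped [batch, num_heads, seq, head_dim] with grouped-query K/V already repeated num_kv_groups times per KV head in interleaved order, while nanovllm's attention sees [num_tokens, heads, head_dim]. The shapes are an assumption inferred from the slicing, not something the patch states; a small self-check of that layout with dummy tensors:

    import torch

    num_kv_heads, num_kv_groups, seq_len, head_dim = 2, 4, 8, 16

    # nanovllm-style K: [num_tokens, num_kv_heads, head_dim]
    nano_k = torch.randn(seq_len, num_kv_heads, head_dim)

    # Assumed torch-reference K after GQA repetition: [1, num_heads, seq, head_dim],
    # each KV head duplicated num_kv_groups times in interleaved order (kv0, kv0, ..., kv1, ...).
    torch_k = (
        nano_k.transpose(0, 1)                      # [num_kv_heads, seq, head_dim]
              .repeat_interleave(num_kv_groups, 0)  # [num_heads, seq, head_dim]
              .unsqueeze(0)                         # [1, num_heads, seq, head_dim]
    )

    # The slicing used in compute_qkv_diffs should map back to nanovllm's layout exactly.
    recovered = torch_k.squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
    assert torch.equal(recovered, nano_k)
    print("max |diff|:", (recovered - nano_k).abs().max().item())  # 0.0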
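
On the reworked generate_needle_prompt in tests/utils.py: phase 1 stacks whole haystack paragraphs until the next one would overshoot, phase 2 pads with single words from the following paragraph and keeps the closest token count, so the returned prompt should land on, or within a token or two of, the requested length. A usage sketch, assuming the same local Qwen3 checkpoint the test uses is available and tests/ is on the import path:

    import os
    from transformers import AutoTokenizer
    from utils import generate_needle_prompt  # tests/utils.py

    MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    prompt, needle_value = generate_needle_prompt(tokenizer=tokenizer, target_length=64, verbose=True)
    actual = len(tokenizer.encode(prompt, add_special_tokens=False))
    print(f"target=64 actual={actual} needle={needle_value}")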