[test] Added test_align.py and captured nanovllm attention state before the change.

commit e897380127
parent 24096431ed
Author: Zijie Tian
Date:   2026-01-04 22:48:01 +08:00

3 changed files with 58 additions and 33 deletions


@@ -480,7 +480,7 @@ class ModelRunner:
             if input_ids.numel() == 0:
                 break
-            # Run model forward
+            #> Run model forward
             logits = self.run_model(input_ids, positions, is_prefill=True)
             reset_context()


@@ -34,6 +34,14 @@ class Sequence:
     def __getitem__(self, key):
         return self.token_ids[key]
 
+    def __repr__(self):
+        ids = self.token_ids
+        if len(ids) > 20:
+            ids_str = "[" + ", ".join(map(str, ids[:10])) + ", ..., " + ", ".join(map(str, ids[-5:])) + "]"
+        else:
+            ids_str = str(ids)
+        return f"Seq(id={self.seq_id}, status={self.status.name}, tokens={self.num_tokens}, ids={ids_str})"
+
     @property
     def is_finished(self):
         return self.status == SequenceStatus.FINISHED
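For illustration, the truncation rule used by the new __repr__, pulled out as a standalone snippet (the helper name preview is for this sketch only): long ID lists print the first 10 and last 5 tokens, short ones print in full.

def preview(ids):
    # Same truncation rule as Sequence.__repr__ above.
    if len(ids) > 20:
        return "[" + ", ".join(map(str, ids[:10])) + ", ..., " + ", ".join(map(str, ids[-5:])) + "]"
    return str(ids)

print(preview(list(range(30))))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 25, 26, 27, 28, 29]
print(preview([1, 2, 3]))
# [1, 2, 3]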

test_align.py

@@ -1,28 +1,44 @@
 """
 Test alignment between nanovllm and custom torch Qwen3 implementation.
 Compares attention layer outputs and QKV tensors to verify correctness.
+
+Usage:
+    python test_align.py                   # Without CPU offload
+    python test_align.py --enable-offload  # With CPU offload
+    python test_align.py --input-len 4096  # Custom input length
 """
 import os
 os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
 
+import argparse
 import torch
 from transformers import AutoTokenizer
 from nanovllm import LLM, SamplingParams
 from modeling_qwen3 import Qwen3ForCausalLM
 from utils import generate_needle_prompt
 
+# Parse arguments
+parser = argparse.ArgumentParser()
+parser.add_argument("--enable-offload", action="store_true", help="Enable CPU offload")
+parser.add_argument("--input-len", type=int, default=1024 * 12, help="Input sequence length")
+parser.add_argument("--model-path", type=str, default="~/models/Qwen3-0.6B/", help="Model path")
+args = parser.parse_args()
+
 # Config
-MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
-INPUT_LEN = 64
+MODEL_PATH = os.path.expanduser(args.model_path)
+INPUT_LEN = args.input_len
+ENABLE_OFFLOAD = args.enable_offload
 DTYPE = torch.float16
+
+print(f"Config: input_len={INPUT_LEN}, enable_offload={ENABLE_OFFLOAD}")
 
 # Storage for captured tensors
 nanovllm_outputs = {}
 torch_outputs = {}
 nanovllm_qkv = {}
-nanovllm_proj_inputs = {}  # Input to qkv_proj
-torch_proj_inputs = {}  # Input to q_proj
+nanovllm_proj_inputs = {}
+torch_proj_inputs = {}
 
 def make_nanovllm_hook(layer_id: int, storage: dict):
@@ -46,9 +62,7 @@ def make_nanovllm_qkv_hook(layer_id: int, storage: dict):
 def make_proj_input_hook(layer_id: int, storage: dict):
-    """Capture input to projection layer (hidden_states after layernorm)."""
     def hook(module, inputs):
-        # inputs[0] is hidden_states
         hidden = inputs[0]
         if hidden.dim() == 2:
             hidden = hidden.unsqueeze(0)
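For context, this capture pattern is plain PyTorch hook machinery: register_forward_hook receives a module's output, while register_forward_pre_hook receives its inputs, which matches the (module, inputs) signature of make_proj_input_hook above. A minimal self-contained sketch (the Linear layer and shapes are illustrative only, not the test's actual modules):

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
captured = {}

def make_output_hook(layer_id, storage):
    # Forward hook: called as hook(module, inputs, output) after forward().
    def hook(module, inputs, output):
        storage[layer_id] = output.detach().cpu()
    return hook

def make_input_hook(layer_id, storage):
    # Forward pre-hook: called as hook(module, inputs) before forward().
    def hook(module, inputs):
        storage[layer_id] = inputs[0].detach().cpu()
    return hook

h1 = layer.register_forward_hook(make_output_hook("out_0", captured))
h2 = layer.register_forward_pre_hook(make_input_hook("in_0", captured))
layer(torch.randn(2, 8))
h1.remove()  # hooks must be removed afterwards, as test_align.py does at the end
h2.remove()
print(captured["in_0"].shape, captured["out_0"].shape)  # torch.Size([2, 8]) torch.Size([2, 8])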
@@ -62,25 +76,25 @@ def make_torch_hook(layer_id: int, storage: dict):
     return hook
 
-def max_diff(t1: torch.Tensor, t2: torch.Tensor) -> float:
-    return (t1.float() - t2.float()).abs().max().item()
+def cosine_sim(t1: torch.Tensor, t2: torch.Tensor) -> float:
+    """Cosine similarity between flattened tensors (1.0 = identical)."""
+    return torch.nn.functional.cosine_similarity(
+        t1.flatten().float(), t2.flatten().float(), dim=0
+    ).item()
 
-def compute_qkv_diffs(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
-    """Compute Q, K, V max diffs. Returns (q_diff, k_diff, v_diff)."""
+def compute_qkv_sims(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
+    """Compute Q, K, V cosine similarities. Returns (q_sim, k_sim, v_sim)."""
     nano_q = nano_qkv["q"]
     torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1)
-    q_diff = max_diff(nano_q, torch_q)
 
     nano_k = nano_qkv["k"]
     torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
-    k_diff = max_diff(nano_k, torch_k)
 
     nano_v = nano_qkv["v"]
     torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
-    v_diff = max_diff(nano_v, torch_v)
 
-    return q_diff, k_diff, v_diff
+    return cosine_sim(nano_q, torch_q), cosine_sim(nano_k, torch_k), cosine_sim(nano_v, torch_v)
 
 # ============================================================
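Two notes on the new comparison path. The [::num_kv_groups] stride appears to downselect the torch model's group-expanded K/V heads back to one head per KV group, so shapes match nanovllm's unexpanded K/V. And cosine similarity is a scale-robust alignment signal: a quick standalone check (illustrative tensors only) shows it stays near 1.0 under small element-wise noise of the kind fp16 rounding introduces, but collapses when tensors are genuinely misaligned:

import torch

def cosine_sim(t1, t2):
    return torch.nn.functional.cosine_similarity(
        t1.flatten().float(), t2.flatten().float(), dim=0
    ).item()

torch.manual_seed(0)
t = torch.randn(16, 128)
noisy = t + 0.01 * torch.randn_like(t)     # small element-wise noise
shifted = torch.roll(t, shifts=1, dims=0)  # genuinely misaligned rows
print(f"noisy:   {cosine_sim(t, noisy):.6f}")    # ~0.99995
print(f"shifted: {cosine_sim(t, shifted):.6f}")  # near 0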
@@ -90,9 +104,10 @@ print("Loading nanovllm model...")
 llm = LLM(
     MODEL_PATH,
     enforce_eager=True,
-    max_model_len=4096,
-    max_num_batched_tokens=4096,
-    enable_cpu_offload=False,
+    max_model_len=32768,
+    gpu_memory_utilization=0.2,
+    max_num_batched_tokens=32768,
+    enable_cpu_offload=ENABLE_OFFLOAD,
     dtype="float16",
 )
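For scale, a back-of-envelope KV-cache estimate at the new max_model_len, which suggests why CPU offload and the lower gpu_memory_utilization are being exercised together (a sketch assuming Qwen3-0.6B's published config of 28 layers, 8 KV heads, and head_dim 128; the actual checkpoint may differ):

# Rough fp16 KV-cache footprint at max_model_len=32768,
# assuming 28 layers, 8 KV heads, head_dim 128 (Qwen3-0.6B).
layers, kv_heads, head_dim, bytes_fp16 = 28, 8, 128, 2
per_token = 2 * layers * kv_heads * head_dim * bytes_fp16  # K and V planes
print(per_token)                  # 114688 bytes, i.e. 112 KiB per token
print(per_token * 32768 / 2**30)  # 3.5 GiB for one full-length sequence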
@@ -139,39 +154,41 @@ with torch.no_grad():
     torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers)))
 
 # ============================================================
-# Compare QKVO per layer (one line each)
+# Compare using cosine similarity (1.0 = perfect alignment)
 # ============================================================
-print("\n" + "=" * 82)
+print("\n" + "=" * 70)
 print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}")
-print("=" * 82)
+print("=" * 70)
 
 all_passed = True
-atol = 0.1
+threshold = 0.999  # Cosine similarity threshold
 
 for layer_idx in range(num_layers):
-    # Input diff (to qkv_proj / q_proj)
+    # Input similarity
     nano_in = nanovllm_proj_inputs[layer_idx]
     torch_in = torch_proj_inputs[layer_idx]
     if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
         torch_in = torch_in.view(nano_in.shape)
-    i_diff = max_diff(nano_in, torch_in)
+    i_sim = cosine_sim(nano_in, torch_in)
 
-    # QKV diffs
-    q_diff, k_diff, v_diff = compute_qkv_diffs(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
+    # QKV similarities
+    q_sim, k_sim, v_sim = compute_qkv_sims(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
 
-    # O diff
+    # O similarity
     nano_out = nanovllm_outputs[layer_idx]
     torch_out = torch_outputs[layer_idx]
     if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
         torch_out = torch_out.view(nano_out.shape)
-    o_diff = max_diff(nano_out, torch_out)
+    o_sim = cosine_sim(nano_out, torch_out)
 
     # Check pass/fail
-    passed = all(d < atol for d in [i_diff, q_diff, k_diff, v_diff, o_diff])
+    passed = all(s >= threshold for s in [i_sim, q_sim, k_sim, v_sim, o_sim])
     all_passed = all_passed and passed
     status = "" if passed else " *"
-    print(f"Layer {layer_idx:2d}{status:<3} {i_diff:>10.6f} {q_diff:>10.6f} {k_diff:>10.6f} {v_diff:>10.6f} {o_diff:>10.6f}")
+    print(f"Layer {layer_idx:2d}{status:<3} {i_sim:>10.6f} {q_sim:>10.6f} {k_sim:>10.6f} {v_sim:>10.6f} {o_sim:>10.6f}")
 
 # ============================================================
 # Cleanup and result
@@ -179,8 +196,8 @@ for layer_idx in range(num_layers):
 for hook in nanovllm_hooks + torch_hooks:
     hook.remove()
 
-print("=" * 82)
+print("=" * 70)
 if all_passed:
-    print("test_align: PASSED")
+    print("test_align: PASSED (cosine_sim >= 0.999)")
 else:
-    print("test_align: FAILED (* = max_diff >= 0.1)")
+    print("test_align: FAILED (* = cosine_sim < 0.999)")