[test] Added test_align.py and pre-change nanovllm attention.
@@ -480,7 +480,7 @@ class ModelRunner:
         if input_ids.numel() == 0:
             break
 
-        # Run model forward
+        #> Run model forward
         logits = self.run_model(input_ids, positions, is_prefill=True)
         reset_context()
 
@@ -34,6 +34,14 @@ class Sequence:
     def __getitem__(self, key):
         return self.token_ids[key]
 
+    def __repr__(self):
+        ids = self.token_ids
+        if len(ids) > 20:
+            ids_str = "[" + ", ".join(map(str, ids[:10])) + ", ..., " + ", ".join(map(str, ids[-5:])) + "]"
+        else:
+            ids_str = str(ids)
+        return f"Seq(id={self.seq_id}, status={self.status.name}, tokens={self.num_tokens}, ids={ids_str})"
+
     @property
     def is_finished(self):
         return self.status == SequenceStatus.FINISHED
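Note: the truncation logic in the new __repr__ can be checked standalone; this sketch uses a plain list in place of a Sequence's token_ids:

ids = list(range(100))
if len(ids) > 20:
    ids_str = "[" + ", ".join(map(str, ids[:10])) + ", ..., " + ", ".join(map(str, ids[-5:])) + "]"
else:
    ids_str = str(ids)
print(ids_str)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 95, 96, 97, 98, 99]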
@@ -1,28 +1,44 @@
 """
 Test alignment between nanovllm and custom torch Qwen3 implementation.
 Compares attention layer outputs and QKV tensors to verify correctness.
+
+Usage:
+    python test_align.py                   # Without CPU offload
+    python test_align.py --enable-offload  # With CPU offload
+    python test_align.py --input-len 4096  # Custom input length
 """
 
 import os
 os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
 
+import argparse
 import torch
 from transformers import AutoTokenizer
 from nanovllm import LLM, SamplingParams
 from modeling_qwen3 import Qwen3ForCausalLM
 from utils import generate_needle_prompt
 
+# Parse arguments
+parser = argparse.ArgumentParser()
+parser.add_argument("--enable-offload", action="store_true", help="Enable CPU offload")
+parser.add_argument("--input-len", type=int, default=1024 * 12, help="Input sequence length")
+parser.add_argument("--model-path", type=str, default="~/models/Qwen3-0.6B/", help="Model path")
+args = parser.parse_args()
+
 # Config
-MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
-INPUT_LEN = 64
+MODEL_PATH = os.path.expanduser(args.model_path)
+INPUT_LEN = args.input_len
+ENABLE_OFFLOAD = args.enable_offload
 DTYPE = torch.float16
 
+print(f"Config: input_len={INPUT_LEN}, enable_offload={ENABLE_OFFLOAD}")
+
 # Storage for captured tensors
 nanovllm_outputs = {}
 torch_outputs = {}
 nanovllm_qkv = {}
-nanovllm_proj_inputs = {}  # Input to qkv_proj
-torch_proj_inputs = {}  # Input to q_proj
+nanovllm_proj_inputs = {}
+torch_proj_inputs = {}
 
 
 def make_nanovllm_hook(layer_id: int, storage: dict):
@@ -46,9 +62,7 @@ def make_nanovllm_qkv_hook(layer_id: int, storage: dict):
 
 
 def make_proj_input_hook(layer_id: int, storage: dict):
-    """Capture input to projection layer (hidden_states after layernorm)."""
     def hook(module, inputs):
-        # inputs[0] is hidden_states
         hidden = inputs[0]
         if hidden.dim() == 2:
             hidden = hidden.unsqueeze(0)
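Note: make_proj_input_hook is written for PyTorch's pre-forward hook interface, where the hook receives the module and a tuple of positional inputs. A minimal self-contained sketch of the wiring; the Linear layer is a stand-in, not the real qkv_proj/q_proj module:

import torch

storage = {}
proj = torch.nn.Linear(8, 8)

def hook(module, inputs):
    hidden = inputs[0]  # inputs is the tuple of positional args to forward()
    if hidden.dim() == 2:
        hidden = hidden.unsqueeze(0)  # normalize to (batch, seq, hidden)
    storage[0] = hidden.detach().cpu()

handle = proj.register_forward_pre_hook(hook)
proj(torch.randn(4, 8))  # hook fires before forward; storage[0] now has shape (1, 4, 8)
handle.remove()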
@@ -62,25 +76,25 @@ def make_torch_hook(layer_id: int, storage: dict):
     return hook
 
 
-def max_diff(t1: torch.Tensor, t2: torch.Tensor) -> float:
-    return (t1.float() - t2.float()).abs().max().item()
+def cosine_sim(t1: torch.Tensor, t2: torch.Tensor) -> float:
+    """Cosine similarity between flattened tensors (1.0 = identical)."""
+    return torch.nn.functional.cosine_similarity(
+        t1.flatten().float(), t2.flatten().float(), dim=0
+    ).item()
 
 
-def compute_qkv_diffs(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
-    """Compute Q, K, V max diffs. Returns (q_diff, k_diff, v_diff)."""
+def compute_qkv_sims(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
+    """Compute Q, K, V cosine similarities. Returns (q_sim, k_sim, v_sim)."""
     nano_q = nano_qkv["q"]
     torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1)
-    q_diff = max_diff(nano_q, torch_q)
 
     nano_k = nano_qkv["k"]
     torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
-    k_diff = max_diff(nano_k, torch_k)
 
     nano_v = nano_qkv["v"]
     torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
-    v_diff = max_diff(nano_v, torch_v)
 
-    return q_diff, k_diff, v_diff
+    return cosine_sim(nano_q, torch_q), cosine_sim(nano_k, torch_k), cosine_sim(nano_v, torch_v)
 
 
 # ============================================================
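Note: the [::num_kv_groups, :, :] slice assumes the torch reference model materializes GQA by repeating each KV head num_kv_groups times along the head dimension (HF-style repeat_kv, equivalent to repeat_interleave on the head dim). A small sketch of why the slice recovers the original KV heads; the shapes here are made up:

import torch

num_kv_heads, num_kv_groups = 2, 4
kv = torch.randn(num_kv_heads, 3, 8)  # (kv_heads, seq, head_dim)
repeated = kv.repeat_interleave(num_kv_groups, dim=0)  # (8, 3, 8): h0 h0 h0 h0 h1 h1 h1 h1
assert torch.equal(repeated[::num_kv_groups], kv)  # every num_kv_groups-th head is an original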
@@ -90,9 +104,10 @@ print("Loading nanovllm model...")
 llm = LLM(
     MODEL_PATH,
     enforce_eager=True,
-    max_model_len=4096,
-    max_num_batched_tokens=4096,
-    enable_cpu_offload=False,
+    max_model_len=32768,
+    gpu_memory_utilization=0.2,
+    max_num_batched_tokens=32768,
+    enable_cpu_offload=ENABLE_OFFLOAD,
     dtype="float16",
 )
 
@@ -139,39 +154,41 @@ with torch.no_grad():
     torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers)))
 
 # ============================================================
-# Compare QKVO per layer (one line each)
+# Compare using cosine similarity (1.0 = perfect alignment)
 # ============================================================
-print("\n" + "=" * 82)
+print("\n" + "=" * 70)
 print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}")
-print("=" * 82)
+print("=" * 70)
 
 all_passed = True
-atol = 0.1
+threshold = 0.999  # Cosine similarity threshold
 
 for layer_idx in range(num_layers):
-    # Input diff (to qkv_proj / q_proj)
+    # Input similarity
     nano_in = nanovllm_proj_inputs[layer_idx]
     torch_in = torch_proj_inputs[layer_idx]
 
     if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
         torch_in = torch_in.view(nano_in.shape)
-    i_diff = max_diff(nano_in, torch_in)
+    i_sim = cosine_sim(nano_in, torch_in)
 
-    # QKV diffs
-    q_diff, k_diff, v_diff = compute_qkv_diffs(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
+    # QKV similarities
+    q_sim, k_sim, v_sim = compute_qkv_sims(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
 
-    # O diff
+    # O similarity
     nano_out = nanovllm_outputs[layer_idx]
     torch_out = torch_outputs[layer_idx]
     if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
         torch_out = torch_out.view(nano_out.shape)
-    o_diff = max_diff(nano_out, torch_out)
+    o_sim = cosine_sim(nano_out, torch_out)
 
     # Check pass/fail
-    passed = all(d < atol for d in [i_diff, q_diff, k_diff, v_diff, o_diff])
+    passed = all(s >= threshold for s in [i_sim, q_sim, k_sim, v_sim, o_sim])
     all_passed = all_passed and passed
     status = "" if passed else " *"
 
-    print(f"Layer {layer_idx:2d}{status:<3} {i_diff:>10.6f} {q_diff:>10.6f} {k_diff:>10.6f} {v_diff:>10.6f} {o_diff:>10.6f}")
+    print(f"Layer {layer_idx:2d}{status:<3} {i_sim:>10.6f} {q_sim:>10.6f} {k_sim:>10.6f} {v_sim:>10.6f} {o_sim:>10.6f}")
 
 # ============================================================
 # Cleanup and result
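Note: cosine similarity is scale-invariant, so this check tolerates per-tensor magnitude differences that the old max_diff would have flagged; it measures direction only. A quick illustration, assuming the same cosine_sim definition as in the diff above:

import torch
import torch.nn.functional as F

def cosine_sim(t1, t2):
    return F.cosine_similarity(t1.flatten().float(), t2.flatten().float(), dim=0).item()

a = torch.randn(4, 16)
print(cosine_sim(a, a))        # ~1.0: identical tensors
print(cosine_sim(a, 2 * a))    # ~1.0 as well: scaling does not change the metric
print(cosine_sim(a, a + 0.1 * torch.randn_like(a)))  # slightly below 1.0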
@@ -179,8 +196,8 @@ for layer_idx in range(num_layers):
 for hook in nanovllm_hooks + torch_hooks:
     hook.remove()
 
-print("=" * 82)
+print("=" * 70)
 if all_passed:
-    print("test_align: PASSED")
+    print("test_align: PASSED (cosine_sim >= 0.999)")
 else:
-    print("test_align: FAILED (* = max_diff >= 0.1)")
+    print("test_align: FAILED (* = cosine_sim < 0.999)")