[refactor] refactor test_align.py.
@@ -1,6 +1,6 @@
 """
 Test alignment between nanovllm and custom torch Qwen3 implementation.
-Compares attention layer outputs to verify correctness.
+Compares attention layer outputs and QKV tensors to verify correctness.
 """
 
 import os
@@ -14,88 +14,94 @@ from utils import generate_needle_prompt
 
 # Config
 MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
-INPUT_LEN = 512  # Use shorter length for alignment test
+INPUT_LEN = 64
 DTYPE = torch.float16
 
 # Storage for captured tensors
 nanovllm_outputs = {}
 torch_outputs = {}
+nanovllm_qkv = {}
+nanovllm_proj_inputs = {}  # Input to qkv_proj
+torch_proj_inputs = {}  # Input to q_proj
 
 
 def make_nanovllm_hook(layer_id: int, storage: dict):
-    """Capture nanovllm self_attn outputs (after o_proj)."""
     def hook(module, inputs, output):
-        # Qwen3Attention output is a tuple (attn_output, None)
-        if isinstance(output, tuple):
-            attn_output = output[0]
-        else:
-            attn_output = output
-        # nanovllm shape: [num_tokens, hidden_size] -> add batch dim
+        attn_output = output[0] if isinstance(output, tuple) else output
         if attn_output.dim() == 2:
             attn_output = attn_output.unsqueeze(0)
         storage[layer_id] = attn_output.detach().clone()
     return hook
 
 
-def make_torch_hook(layer_id: int, storage: dict):
-    """Capture torch model self_attn outputs (after o_proj)."""
-    def hook(module, inputs, output):
-        # Qwen3Attention output is (attn_output, past_kv, qkv_dict)
-        attn_output, _, _ = output
-        storage[layer_id] = attn_output.detach().clone()
+def make_nanovllm_qkv_hook(layer_id: int, storage: dict):
+    def hook(module, inputs):
+        q, k, v = inputs[0], inputs[1], inputs[2]
+        storage[layer_id] = {
+            "q": q.detach().clone(),
+            "k": k.detach().clone(),
+            "v": v.detach().clone(),
+        }
     return hook
 
 
-def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-2):
-    """Compare two tensors and print statistics."""
-    # Handle shape differences
-    if t1.shape != t2.shape:
-        print(f"[{name}] Shape mismatch: {t1.shape} vs {t2.shape}")
-        # Try to reshape for comparison if possible
-        if t1.numel() == t2.numel():
-            t2 = t2.view(t1.shape)
-        else:
-            return False
-
-    diff = (t1.float() - t2.float()).abs()
-    max_diff = diff.max().item()
-    mean_diff = diff.mean().item()
-
-    passed = max_diff < atol
-    status = "PASS" if passed else "FAIL"
-
-    print(f"[{name}] {status}")
-    print(f"  Shape: {list(t1.shape)}")
-    print(f"  t1 mean: {t1.float().mean():.6f}, std: {t1.float().std():.6f}")
-    print(f"  t2 mean: {t2.float().mean():.6f}, std: {t2.float().std():.6f}")
-    print(f"  Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}")
-
-    return passed
+def make_proj_input_hook(layer_id: int, storage: dict):
+    """Capture input to projection layer (hidden_states after layernorm)."""
+    def hook(module, inputs):
+        # inputs[0] is hidden_states
+        hidden = inputs[0]
+        if hidden.dim() == 2:
+            hidden = hidden.unsqueeze(0)
+        storage[layer_id] = hidden.detach().clone()
+    return hook
+
+
+def make_torch_hook(layer_id: int, storage: dict):
+    def hook(module, inputs, output):
+        storage[layer_id] = output[0].detach().clone()
+    return hook
+
+
+def max_diff(t1: torch.Tensor, t2: torch.Tensor) -> float:
+    return (t1.float() - t2.float()).abs().max().item()
+
+
+def compute_qkv_diffs(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
+    """Compute Q, K, V max diffs. Returns (q_diff, k_diff, v_diff)."""
+    nano_q = nano_qkv["q"]
+    torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1)
+    q_diff = max_diff(nano_q, torch_q)
+
+    nano_k = nano_qkv["k"]
+    torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
+    k_diff = max_diff(nano_k, torch_k)
+
+    nano_v = nano_qkv["v"]
+    torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
+    v_diff = max_diff(nano_v, torch_v)
+
+    return q_diff, k_diff, v_diff
 
 
 # ============================================================
-# Load nanovllm model
+# Load models
 # ============================================================
-print("=" * 60)
 print("Loading nanovllm model...")
-print("=" * 60)
 
 llm = LLM(
     MODEL_PATH,
     enforce_eager=True,
     max_model_len=4096,
     max_num_batched_tokens=4096,
-    enable_cpu_offload=False,  # Disable offload for alignment test
+    enable_cpu_offload=False,
     dtype="float16",
 )
 
-# ============================================================
-# Load torch model
-# ============================================================
-print("\n" + "=" * 60)
-print("Loading custom torch model...")
-print("=" * 60)
+num_heads = llm.model_runner.model.model.layers[0].self_attn.num_heads
+num_kv_heads = llm.model_runner.model.model.layers[0].self_attn.num_kv_heads
+num_kv_groups = num_heads // num_kv_heads
+num_layers = len(llm.model_runner.model.model.layers)
 
+print("Loading torch model...")
 torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
 torch_model = torch_model.to("cuda")
 torch_model.eval()
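Note (not part of the diff): the new capture helpers lean on PyTorch's standard hook API, with register_forward_pre_hook seeing a module's positional inputs before forward() runs and register_forward_hook seeing the output afterwards. A minimal self-contained sketch of that mechanism, using a toy nn.Linear rather than the nanovllm/Qwen3 modules assumed by the test:

import torch
import torch.nn as nn

captured = {}

def pre_hook(module, inputs):
    # Pre-hook: inputs is the tuple of positional args passed to forward()
    captured["input"] = inputs[0].detach().clone()

def fwd_hook(module, inputs, output):
    # Forward hook: additionally receives the module output
    captured["output"] = output.detach().clone()

layer = nn.Linear(8, 8)
handles = [
    layer.register_forward_pre_hook(pre_hook),
    layer.register_forward_hook(fwd_hook),
]
layer(torch.randn(4, 8))
for h in handles:
    h.remove()  # detach hooks when done, as the test does in its cleanup section
print(captured["input"].shape, captured["output"].shape)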
@@ -103,110 +109,78 @@ torch_model.eval()
 # ============================================================
 # Generate test input
 # ============================================================
-print("\n" + "=" * 60)
-print("Generating test input...")
-print("=" * 60)
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
-prompt, _ = generate_needle_prompt(
-    tokenizer=tokenizer,
-    target_length=INPUT_LEN,
-    verbose=True,
-)
+prompt, _ = generate_needle_prompt(tokenizer=tokenizer, target_length=INPUT_LEN, verbose=True)
 input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
 print(f"Input shape: {input_ids.shape}")
 
 # ============================================================
-# Register hooks on both models
+# Register hooks
 # ============================================================
-print("\n" + "=" * 60)
-print("Registering hooks...")
-print("=" * 60)
 
-# Hook on nanovllm (self_attn is Qwen3Attention, captures output after o_proj)
 nanovllm_hooks = []
 for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
-    if layer_idx >= 2:  # Only first 2 layers
-        break
-    nanovllm_hooks.append(
-        layer.self_attn.register_forward_hook(
-            make_nanovllm_hook(layer_idx, nanovllm_outputs)
-        )
-    )
-    print(f"  Registered nanovllm hook on layer {layer_idx} self_attn")
+    nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs)))
+    nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv)))
+    nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs)))
 
-# Hook on torch model (self_attn is Qwen3Attention, captures output after o_proj)
 torch_hooks = []
 for layer_idx, layer in enumerate(torch_model.model.layers):
-    if layer_idx >= 2:  # Only first 2 layers
-        break
-    torch_hooks.append(
-        layer.self_attn.register_forward_hook(
-            make_torch_hook(layer_idx, torch_outputs)
-        )
-    )
-    print(f"  Registered torch hook on layer {layer_idx} self_attn")
+    torch_hooks.append(layer.self_attn.register_forward_hook(make_torch_hook(layer_idx, torch_outputs)))
+    torch_hooks.append(layer.self_attn.q_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, torch_proj_inputs)))
 
 # ============================================================
-# Run nanovllm inference
+# Run inference
 # ============================================================
-print("\n" + "=" * 60)
 print("Running nanovllm inference...")
-print("=" * 60)
-
-# Use prompt_token_ids to ensure same input
-prompt_token_ids = input_ids[0].tolist()
-nanovllm_result = llm.generate(
-    [prompt_token_ids],
-    SamplingParams(temperature=0.01, max_tokens=1),  # Near-greedy for determinism
-    use_tqdm=False,
-)
+nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False)
 
-# ============================================================
-# Run torch inference
-# ============================================================
-print("\n" + "=" * 60)
 print("Running torch inference...")
-print("=" * 60)
-
 with torch.no_grad():
-    torch_logits, _, _ = torch_model(input_ids)
+    torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers)))
 
 # ============================================================
-# Compare outputs
+# Compare QKVO per layer (one line each)
 # ============================================================
-print("\n" + "=" * 60)
-print("Comparing attention outputs...")
-print("=" * 60)
+print("\n" + "=" * 82)
+print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}")
+print("=" * 82)
 
 all_passed = True
-for layer_idx in sorted(nanovllm_outputs.keys()):
-    if layer_idx not in torch_outputs:
-        print(f"[Layer {layer_idx}] Missing torch output")
-        all_passed = False
-        continue
+atol = 0.1
+
+for layer_idx in range(num_layers):
+    # Input diff (to qkv_proj / q_proj)
+    nano_in = nanovllm_proj_inputs[layer_idx]
+    torch_in = torch_proj_inputs[layer_idx]
+    if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
+        torch_in = torch_in.view(nano_in.shape)
+    i_diff = max_diff(nano_in, torch_in)
+
+    # QKV diffs
+    q_diff, k_diff, v_diff = compute_qkv_diffs(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
 
+    # O diff
     nano_out = nanovllm_outputs[layer_idx]
     torch_out = torch_outputs[layer_idx]
+    if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
+        torch_out = torch_out.view(nano_out.shape)
+    o_diff = max_diff(nano_out, torch_out)
 
-    print(f"\n--- Layer {layer_idx} ---")
-    passed = compare_tensors(f"Layer {layer_idx} attn_output", nano_out, torch_out, atol=0.1)
+    # Check pass/fail
+    passed = all(d < atol for d in [i_diff, q_diff, k_diff, v_diff, o_diff])
     all_passed = all_passed and passed
+    status = "" if passed else " *"
+
+    print(f"Layer {layer_idx:2d}{status:<3} {i_diff:>10.6f} {q_diff:>10.6f} {k_diff:>10.6f} {v_diff:>10.6f} {o_diff:>10.6f}")
 
 # ============================================================
-# Cleanup
+# Cleanup and result
 # ============================================================
-for hook in nanovllm_hooks:
-    hook.remove()
-for hook in torch_hooks:
+for hook in nanovllm_hooks + torch_hooks:
     hook.remove()
 
-# ============================================================
-# Result
-# ============================================================
-print("\n" + "=" * 60)
+print("=" * 82)
 if all_passed:
-    print("test_align: PASSED - nanovllm and torch outputs aligned!")
+    print("test_align: PASSED")
 else:
-    print("test_align: FAILED - outputs differ!")
-print("=" * 60)
+    print("test_align: FAILED (* = max_diff >= 0.1)")
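Note (not part of the diff): the [::num_kv_groups] slice and transpose in compute_qkv_diffs appear to assume the torch model returns K/V already expanded to num_heads (each KV head repeated per GQA group) in a [batch, heads, seq, head_dim] layout, while nanovllm keeps per-KV-head tensors as [tokens, kv_heads, head_dim]. A toy sketch of that index mapping, under those layout assumptions:

import torch

num_heads, num_kv_heads, seq, head_dim = 16, 8, 4, 2
num_kv_groups = num_heads // num_kv_heads  # query heads sharing each KV head

# Hypothetical expanded K: each KV head repeated num_kv_groups times (repeat_interleave layout)
k_small = torch.randn(num_kv_heads, seq, head_dim)
k_expanded = k_small.repeat_interleave(num_kv_groups, dim=0).unsqueeze(0)  # [1, num_heads, seq, head_dim]

# Taking every num_kv_groups-th head recovers one copy per KV head, and
# transpose(0, 1) moves it to the [seq, kv_heads, head_dim] layout used for comparison.
k_recovered = k_expanded.squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
assert torch.equal(k_recovered, k_small.transpose(0, 1))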
tests/utils.py (105 changed lines)
@@ -55,7 +55,7 @@ def generate_needle_prompt(
     verbose: bool = True,
 ) -> Tuple[str, str]:
     """
-    Generate a needle-in-haystack prompt of approximately target_length tokens.
+    Generate a needle-in-haystack prompt of exactly target_length tokens.
 
     Args:
         tokenizer: HuggingFace tokenizer for length estimation
@@ -71,68 +71,79 @@ def generate_needle_prompt(
     # The needle sentence
     needle = f"The secret number you need to remember is {needle_value}. This is very important. "
 
-    # Question at the end
-    question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
+    # Question text
+    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
+        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
+    else:
+        question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
 
-    # Estimate tokens for fixed parts
-    needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
-    question_text = "What is the secret number mentioned in the text above? Answer with just the number."
-    question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
-    # Buffer for chat template, special tokens, etc.
-    overhead_tokens = 100 if use_chat_template else 50
-
-    # Available tokens for haystack
-    haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
-    if haystack_target_tokens < 100:
-        raise ValueError(f"target_length {target_length} is too short for needle test")
-
-    # Build haystack by repeating paragraphs
+    def build_prompt(haystack_parts, needle_idx):
+        """Build full prompt from haystack parts with needle inserted."""
+        parts = haystack_parts.copy()
+        parts.insert(needle_idx, needle)
+        full_text = "".join(parts)
+
+        if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
+            messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}]
+            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        else:
+            return full_text + question_text
+
+    def count_tokens(prompt):
+        return len(tokenizer.encode(prompt, add_special_tokens=False))
+
+    def get_needle_idx(parts):
+        idx = int(len(parts) * needle_position)
+        return max(0, min(idx, len(parts)))
+
+    # Phase 1: Build haystack with full paragraphs until we exceed target
     haystack_parts = []
-    current_tokens = 0
     para_idx = 0
 
-    while current_tokens < haystack_target_tokens:
+    while True:
         para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
-        para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
-        if current_tokens + para_tokens > haystack_target_tokens:
+        test_parts = haystack_parts + [para]
+        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
+
+        if count_tokens(prompt) > target_length:
             break
+
         haystack_parts.append(para)
-        current_tokens += para_tokens
         para_idx += 1
 
-    # Calculate needle insertion point
-    needle_idx = int(len(haystack_parts) * needle_position)
-    needle_idx = max(0, min(needle_idx, len(haystack_parts)))
-
-    # Insert needle
-    haystack_parts.insert(needle_idx, needle)
-
-    # Assemble prompt
-    full_text = "".join(haystack_parts)
-
-    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
-        # Use chat template for instruct models
-        # For Qwen3, add /no_think to disable thinking mode
-        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
-        messages = [
-            {"role": "user", "content": f"{full_text}\n\n{question_text}"}
-        ]
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-    else:
-        # Raw text format for base models
-        question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
-        prompt = full_text + question
-
-    # Verify length
-    actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
+        if para_idx > 10000:  # Safety limit
+            break
+
+    # Phase 2: Fine-tune by adding words from next paragraph
+    next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
+    words = next_para.split()
+
+    best_parts = haystack_parts.copy()
+    best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))
+
+    for i in range(1, len(words) + 1):
+        partial = " ".join(words[:i]) + " "
+        test_parts = haystack_parts + [partial]
+        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
+        token_count = count_tokens(prompt)
+        diff = abs(target_length - token_count)
+
+        if diff < best_diff:
+            best_diff = diff
+            best_parts = test_parts.copy()
+
+        if token_count >= target_length:
+            break
+
+    haystack_parts = best_parts
+
+    # Final build
+    needle_idx = get_needle_idx(haystack_parts)
+    prompt = build_prompt(haystack_parts, needle_idx)
+
+    actual_tokens = count_tokens(prompt)
     if verbose:
-        print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
-        print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
-        print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
+        print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})")
 
     return prompt, needle_value