[refactor] Refactor needle prompt generation in test_align.py.

Zijie Tian
2026-01-04 20:55:40 +08:00
parent 772313db8f
commit 24096431ed
2 changed files with 151 additions and 166 deletions


@@ -55,7 +55,7 @@ def generate_needle_prompt(
     verbose: bool = True,
 ) -> Tuple[str, str]:
     """
-    Generate a needle-in-haystack prompt of approximately target_length tokens.
+    Generate a needle-in-haystack prompt of exactly target_length tokens.
 
     Args:
         tokenizer: HuggingFace tokenizer for length estimation
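
The docstring change above, from "approximately" to "exactly", is the contract the rest of this commit implements: prompt length is now measured, not estimated. A minimal sketch of how that guarantee could be exercised, assuming generate_needle_prompt is importable from test_align and using an arbitrary HuggingFace tokenizer (the model name, import path, and the few-token tolerance are illustrative assumptions, not part of the commit):

    from transformers import AutoTokenizer
    from test_align import generate_needle_prompt  # assumed import path

    tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with .encode() works here
    for target in (512, 2048, 8192):
        prompt, needle_value = generate_needle_prompt(
            tokenizer=tok,          # argument names taken from the diff; other defaults assumed
            target_length=target,
            verbose=False,
        )
        actual = len(tok.encode(prompt, add_special_tokens=False))
        # Phase 2 below stops at a word boundary, so a few tokens of slack is realistic.
        assert abs(actual - target) <= 3, f"off by {actual - target} at target={target}"
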
@@ -71,68 +71,79 @@ def generate_needle_prompt(
     # The needle sentence
     needle = f"The secret number you need to remember is {needle_value}. This is very important. "
 
-    # Question at the end
-    question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
+    # Question text
+    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
+        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
+    else:
+        question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
 
-    # Estimate tokens for fixed parts
-    needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
-    question_text = "What is the secret number mentioned in the text above? Answer with just the number."
-    question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
-    # Buffer for chat template, special tokens, etc.
-    overhead_tokens = 100 if use_chat_template else 50
-    # Available tokens for haystack
-    haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
-    if haystack_target_tokens < 100:
-        raise ValueError(f"target_length {target_length} is too short for needle test")
+    def build_prompt(haystack_parts, needle_idx):
+        """Build full prompt from haystack parts with needle inserted."""
+        parts = haystack_parts.copy()
+        parts.insert(needle_idx, needle)
+        full_text = "".join(parts)
+        if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
+            messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}]
+            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        else:
+            return full_text + question_text
 
-    # Build haystack by repeating paragraphs
+    def count_tokens(prompt):
+        return len(tokenizer.encode(prompt, add_special_tokens=False))
+
+    def get_needle_idx(parts):
+        idx = int(len(parts) * needle_position)
+        return max(0, min(idx, len(parts)))
+
+    # Phase 1: Build haystack with full paragraphs until we exceed target
     haystack_parts = []
-    current_tokens = 0
     para_idx = 0
-    while current_tokens < haystack_target_tokens:
+    while True:
         para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
-        para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
-        if current_tokens + para_tokens > haystack_target_tokens:
+        test_parts = haystack_parts + [para]
+        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
+        if count_tokens(prompt) > target_length:
             break
         haystack_parts.append(para)
-        current_tokens += para_tokens
         para_idx += 1
+        if para_idx > 10000:  # Safety limit
+            break
 
-    # Calculate needle insertion point
-    needle_idx = int(len(haystack_parts) * needle_position)
-    needle_idx = max(0, min(needle_idx, len(haystack_parts)))
-    # Insert needle
-    haystack_parts.insert(needle_idx, needle)
-    # Assemble prompt
-    full_text = "".join(haystack_parts)
+    # Phase 2: Fine-tune by adding words from next paragraph
+    next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
+    words = next_para.split()
+    best_parts = haystack_parts.copy()
+    best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))
 
-    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
-        # Use chat template for instruct models
-        # For Qwen3, add /no_think to disable thinking mode
-        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
-        messages = [
-            {"role": "user", "content": f"{full_text}\n\n{question_text}"}
-        ]
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-    else:
-        # Raw text format for base models
-        question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
-        prompt = full_text + question
-    # Verify length
-    actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
+    for i in range(1, len(words) + 1):
+        partial = " ".join(words[:i]) + " "
+        test_parts = haystack_parts + [partial]
+        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
+        token_count = count_tokens(prompt)
+        diff = abs(target_length - token_count)
+        if diff < best_diff:
+            best_diff = diff
+            best_parts = test_parts.copy()
+        if token_count >= target_length:
+            break
+    haystack_parts = best_parts
 
+    # Final build
+    needle_idx = get_needle_idx(haystack_parts)
+    prompt = build_prompt(haystack_parts, needle_idx)
+    actual_tokens = count_tokens(prompt)
 
     if verbose:
-        print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
         print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
         print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
+        print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})")
 
     return prompt, needle_value
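
Net effect: the old version budgeted token counts arithmetically (needle_tokens + question_tokens + overhead_tokens) and hoped the overhead constant absorbed the chat template; the new version renders every candidate haystack through build_prompt, template and question included, and counts the exact string it will return, so the length check and the returned prompt cannot drift apart. A self-contained sketch of that two-phase search, where the function name and the render/count callback interface are illustrative rather than the commit's own API:

    from typing import Callable, List

    def fit_to_token_budget(
        paragraphs: List[str],                # non-empty pool of filler paragraphs
        render: Callable[[List[str]], str],   # builds the full prompt from haystack parts
        count: Callable[[str], int],          # tokenizer-based length of a rendered prompt
        target: int,
    ) -> List[str]:
        # Phase 1: append whole paragraphs while the rendered prompt stays at or under target.
        parts: List[str] = []
        i = 0
        while count(render(parts + [paragraphs[i % len(paragraphs)]])) <= target:
            parts.append(paragraphs[i % len(paragraphs)])
            i += 1
            if i > 10000:  # safety limit, mirroring the commit
                break
        # Phase 2: top up word by word from the next paragraph, keeping the closest fit seen.
        words = paragraphs[i % len(paragraphs)].split()
        best, best_diff = list(parts), abs(target - count(render(parts)))
        for j in range(1, len(words) + 1):
            candidate = parts + [" ".join(words[:j]) + " "]
            n = count(render(candidate))
            if abs(target - n) < best_diff:
                best, best_diff = candidate, abs(target - n)
            if n >= target:  # crossed the target; the closest fit so far wins
                break
        return best

Because render() produces the same string the caller will actually send (in the commit's code, via tokenizer.apply_chat_template when available), the diff reported by the new verbose print is a few tokens of word-boundary residue rather than the 50-100 tokens of slack the old overhead_tokens constant had to hide.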