[refactor] refactor test_align.py.
This commit is contained in:
105
tests/utils.py
105
tests/utils.py
@@ -55,7 +55,7 @@ def generate_needle_prompt(
|
||||
verbose: bool = True,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Generate a needle-in-haystack prompt of approximately target_length tokens.
|
||||
Generate a needle-in-haystack prompt of exactly target_length tokens.
|
||||
|
||||
Args:
|
||||
tokenizer: HuggingFace tokenizer for length estimation
|
||||
@@ -71,68 +71,79 @@ def generate_needle_prompt(
|
||||
# The needle sentence
|
||||
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
|
||||
|
||||
# Question at the end
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
# Question text
|
||||
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
|
||||
else:
|
||||
question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
|
||||
# Estimate tokens for fixed parts
|
||||
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
|
||||
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
|
||||
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
|
||||
# Buffer for chat template, special tokens, etc.
|
||||
overhead_tokens = 100 if use_chat_template else 50
|
||||
def build_prompt(haystack_parts, needle_idx):
|
||||
"""Build full prompt from haystack parts with needle inserted."""
|
||||
parts = haystack_parts.copy()
|
||||
parts.insert(needle_idx, needle)
|
||||
full_text = "".join(parts)
|
||||
|
||||
# Available tokens for haystack
|
||||
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
|
||||
if haystack_target_tokens < 100:
|
||||
raise ValueError(f"target_length {target_length} is too short for needle test")
|
||||
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||
messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}]
|
||||
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
else:
|
||||
return full_text + question_text
|
||||
|
||||
# Build haystack by repeating paragraphs
|
||||
def count_tokens(prompt):
|
||||
return len(tokenizer.encode(prompt, add_special_tokens=False))
|
||||
|
||||
def get_needle_idx(parts):
|
||||
idx = int(len(parts) * needle_position)
|
||||
return max(0, min(idx, len(parts)))
|
||||
|
||||
# Phase 1: Build haystack with full paragraphs until we exceed target
|
||||
haystack_parts = []
|
||||
current_tokens = 0
|
||||
para_idx = 0
|
||||
|
||||
while current_tokens < haystack_target_tokens:
|
||||
while True:
|
||||
para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
|
||||
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
|
||||
if current_tokens + para_tokens > haystack_target_tokens:
|
||||
test_parts = haystack_parts + [para]
|
||||
prompt = build_prompt(test_parts, get_needle_idx(test_parts))
|
||||
|
||||
if count_tokens(prompt) > target_length:
|
||||
break
|
||||
|
||||
haystack_parts.append(para)
|
||||
current_tokens += para_tokens
|
||||
para_idx += 1
|
||||
|
||||
# Calculate needle insertion point
|
||||
needle_idx = int(len(haystack_parts) * needle_position)
|
||||
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
|
||||
if para_idx > 10000: # Safety limit
|
||||
break
|
||||
|
||||
# Insert needle
|
||||
haystack_parts.insert(needle_idx, needle)
|
||||
# Phase 2: Fine-tune by adding words from next paragraph
|
||||
next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
|
||||
words = next_para.split()
|
||||
|
||||
# Assemble prompt
|
||||
full_text = "".join(haystack_parts)
|
||||
best_parts = haystack_parts.copy()
|
||||
best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))
|
||||
|
||||
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||
# Use chat template for instruct models
|
||||
# For Qwen3, add /no_think to disable thinking mode
|
||||
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
|
||||
messages = [
|
||||
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
# Raw text format for base models
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
prompt = full_text + question
|
||||
for i in range(1, len(words) + 1):
|
||||
partial = " ".join(words[:i]) + " "
|
||||
test_parts = haystack_parts + [partial]
|
||||
prompt = build_prompt(test_parts, get_needle_idx(test_parts))
|
||||
token_count = count_tokens(prompt)
|
||||
diff = abs(target_length - token_count)
|
||||
|
||||
# Verify length
|
||||
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
|
||||
if diff < best_diff:
|
||||
best_diff = diff
|
||||
best_parts = test_parts.copy()
|
||||
|
||||
if token_count >= target_length:
|
||||
break
|
||||
|
||||
haystack_parts = best_parts
|
||||
|
||||
# Final build
|
||||
needle_idx = get_needle_idx(haystack_parts)
|
||||
prompt = build_prompt(haystack_parts, needle_idx)
|
||||
|
||||
actual_tokens = count_tokens(prompt)
|
||||
if verbose:
|
||||
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
|
||||
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
|
||||
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
|
||||
print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})")
|
||||
|
||||
return prompt, needle_value
|
||||
|
||||
|
||||
Reference in New Issue
Block a user