[WIP] Before fix bench_offload.py.

This commit is contained in:
Zijie Tian
2026-01-06 18:41:08 +08:00
parent c7ac39dfbd
commit 535f2037ab
7 changed files with 66 additions and 44 deletions

View File

@@ -96,46 +96,41 @@ def generate_needle_prompt(
idx = int(len(parts) * needle_position)
return max(0, min(idx, len(parts)))
# Phase 1: Build haystack with full paragraphs until we exceed target
# Pre-compute tokens per paragraph for efficiency (avoid O(n²) tokenization)
para_tokens = []
for para in HAYSTACK_PARAGRAPHS:
para_tokens.append(len(tokenizer.encode(para, add_special_tokens=False)))
avg_tokens_per_para = sum(para_tokens) / len(para_tokens)
# Estimate overhead (needle + question + chat template)
overhead_prompt = build_prompt([HAYSTACK_PARAGRAPHS[0]], 0)
overhead_tokens = count_tokens(overhead_prompt) - para_tokens[0]
# Phase 1: Estimate number of paragraphs needed
estimated_paras = int((target_length - overhead_tokens) / avg_tokens_per_para) + 1
# Build haystack with estimated paragraphs
haystack_parts = []
para_idx = 0
for i in range(estimated_paras):
haystack_parts.append(HAYSTACK_PARAGRAPHS[i % len(HAYSTACK_PARAGRAPHS)])
while True:
# Phase 2: Adjust to get closer to target
prompt = build_prompt(haystack_parts, get_needle_idx(haystack_parts))
current_tokens = count_tokens(prompt)
# Add more if under target
para_idx = len(haystack_parts)
while current_tokens < target_length and para_idx < 100000:
para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
test_parts = haystack_parts + [para]
prompt = build_prompt(test_parts, get_needle_idx(test_parts))
if count_tokens(prompt) > target_length:
break
haystack_parts.append(para)
current_tokens += para_tokens[para_idx % len(HAYSTACK_PARAGRAPHS)]
para_idx += 1
if para_idx > 10000: # Safety limit
break
# Phase 2: Fine-tune by adding words from next paragraph
next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
words = next_para.split()
best_parts = haystack_parts.copy()
best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))
for i in range(1, len(words) + 1):
partial = " ".join(words[:i]) + " "
test_parts = haystack_parts + [partial]
prompt = build_prompt(test_parts, get_needle_idx(test_parts))
token_count = count_tokens(prompt)
diff = abs(target_length - token_count)
if diff < best_diff:
best_diff = diff
best_parts = test_parts.copy()
if token_count >= target_length:
break
haystack_parts = best_parts
# Remove if over target
while current_tokens > target_length + 100 and len(haystack_parts) > 1:
removed_para_idx = (len(haystack_parts) - 1) % len(HAYSTACK_PARAGRAPHS)
haystack_parts.pop()
current_tokens -= para_tokens[removed_para_idx]
# Final build
needle_idx = get_needle_idx(haystack_parts)