[WIP] Before fix bench_offload.py.
This commit is contained in:
@@ -96,46 +96,41 @@ def generate_needle_prompt(
|
||||
idx = int(len(parts) * needle_position)
|
||||
return max(0, min(idx, len(parts)))
|
||||
|
||||
# Phase 1: Build haystack with full paragraphs until we exceed target
|
||||
# Pre-compute tokens per paragraph for efficiency (avoid O(n²) tokenization)
|
||||
para_tokens = []
|
||||
for para in HAYSTACK_PARAGRAPHS:
|
||||
para_tokens.append(len(tokenizer.encode(para, add_special_tokens=False)))
|
||||
avg_tokens_per_para = sum(para_tokens) / len(para_tokens)
|
||||
|
||||
# Estimate overhead (needle + question + chat template)
|
||||
overhead_prompt = build_prompt([HAYSTACK_PARAGRAPHS[0]], 0)
|
||||
overhead_tokens = count_tokens(overhead_prompt) - para_tokens[0]
|
||||
|
||||
# Phase 1: Estimate number of paragraphs needed
|
||||
estimated_paras = int((target_length - overhead_tokens) / avg_tokens_per_para) + 1
|
||||
|
||||
# Build haystack with estimated paragraphs
|
||||
haystack_parts = []
|
||||
para_idx = 0
|
||||
for i in range(estimated_paras):
|
||||
haystack_parts.append(HAYSTACK_PARAGRAPHS[i % len(HAYSTACK_PARAGRAPHS)])
|
||||
|
||||
while True:
|
||||
# Phase 2: Adjust to get closer to target
|
||||
prompt = build_prompt(haystack_parts, get_needle_idx(haystack_parts))
|
||||
current_tokens = count_tokens(prompt)
|
||||
|
||||
# Add more if under target
|
||||
para_idx = len(haystack_parts)
|
||||
while current_tokens < target_length and para_idx < 100000:
|
||||
para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
|
||||
test_parts = haystack_parts + [para]
|
||||
prompt = build_prompt(test_parts, get_needle_idx(test_parts))
|
||||
|
||||
if count_tokens(prompt) > target_length:
|
||||
break
|
||||
|
||||
haystack_parts.append(para)
|
||||
current_tokens += para_tokens[para_idx % len(HAYSTACK_PARAGRAPHS)]
|
||||
para_idx += 1
|
||||
|
||||
if para_idx > 10000: # Safety limit
|
||||
break
|
||||
|
||||
# Phase 2: Fine-tune by adding words from next paragraph
|
||||
next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
|
||||
words = next_para.split()
|
||||
|
||||
best_parts = haystack_parts.copy()
|
||||
best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))
|
||||
|
||||
for i in range(1, len(words) + 1):
|
||||
partial = " ".join(words[:i]) + " "
|
||||
test_parts = haystack_parts + [partial]
|
||||
prompt = build_prompt(test_parts, get_needle_idx(test_parts))
|
||||
token_count = count_tokens(prompt)
|
||||
diff = abs(target_length - token_count)
|
||||
|
||||
if diff < best_diff:
|
||||
best_diff = diff
|
||||
best_parts = test_parts.copy()
|
||||
|
||||
if token_count >= target_length:
|
||||
break
|
||||
|
||||
haystack_parts = best_parts
|
||||
# Remove if over target
|
||||
while current_tokens > target_length + 100 and len(haystack_parts) > 1:
|
||||
removed_para_idx = (len(haystack_parts) - 1) % len(HAYSTACK_PARAGRAPHS)
|
||||
haystack_parts.pop()
|
||||
current_tokens -= para_tokens[removed_para_idx]
|
||||
|
||||
# Final build
|
||||
needle_idx = get_needle_idx(haystack_parts)
|
||||
|
||||
Reference in New Issue
Block a user