♻️ refactor: consolidate RULER test files and document root cause
- test_ruler.py: add --fresh-llm, --sample-indices, --json-output options
- test_ruler.py: consolidate test_ruler_single_sample.py, test_ruler_sequential.py, test_ruler_samples.py
- docs: update chunked offload issue with root cause (state leakage confirmed)
- docs: add single-sample test results showing 100% accuracy for niah_single_1

Deleted redundant test files:
- tests/test_ruler_single_sample.py
- tests/test_ruler_sequential.py
- tests/test_ruler_samples.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -17,6 +17,15 @@ Usage:
|
||||
|
||||
# Test all samples in all datasets
|
||||
python tests/test_ruler.py --enable-offload
|
||||
|
||||
# Test specific sample indices (comma-separated)
|
||||
python tests/test_ruler.py --enable-offload --datasets niah_single_1 --sample-indices 28,33,40
|
||||
|
||||
# Single-sample mode: reinitialize LLM for each sample (avoids state leakage)
|
||||
python tests/test_ruler.py --enable-offload --datasets niah_single_1 --fresh-llm
|
||||
|
||||
# JSON output mode for scripting
|
||||
python tests/test_ruler.py --enable-offload --datasets niah_single_1 --json-output
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -150,17 +159,30 @@ def run_task_test(
|
||||
sample_indices: Optional[List[int]] = None,
|
||||
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
||||
verbose: bool = True,
|
||||
llm_factory: Optional[callable] = None,
|
||||
fresh_llm: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
Run test for a single RULER task.
|
||||
|
||||
Args:
|
||||
llm: LLM instance (ignored if fresh_llm=True)
|
||||
task_name: Name of the task to test
|
||||
data_dir: Path to data directory
|
||||
sample_indices: Optional list of specific sample indices to test
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
verbose: Print detailed output
|
||||
llm_factory: Callable to create LLM instance (required if fresh_llm=True)
|
||||
fresh_llm: If True, reinitialize LLM for each sample (avoids state leakage)
|
||||
|
||||
Returns dict with: task, correct, total, score, results
|
||||
"""
|
||||
data_file = data_dir / task_name / "validation.jsonl"
|
||||
samples = load_samples(data_file, sample_indices)
|
||||
|
||||
if verbose:
|
||||
print(f"\n Testing {task_name}: {len(samples)} samples")
|
||||
mode_str = " [fresh-llm mode]" if fresh_llm else ""
|
||||
print(f"\n Testing {task_name}: {len(samples)} samples{mode_str}")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.1,
|
||||
@@ -171,13 +193,26 @@ def run_task_test(
|
||||
total_score = 0.0
|
||||
results = []
|
||||
|
||||
current_llm = llm
|
||||
|
||||
for sample in samples:
|
||||
idx = sample.get("index", sample["_local_idx"])
|
||||
prompt = sample["input"]
|
||||
expected = sample["outputs"]
|
||||
|
||||
# Fresh LLM mode: reinitialize for each sample
|
||||
if fresh_llm:
|
||||
if llm_factory is None:
|
||||
raise ValueError("llm_factory required when fresh_llm=True")
|
||||
# Cleanup previous LLM
|
||||
if current_llm is not None:
|
||||
del current_llm
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
current_llm = llm_factory()
|
||||
|
||||
# Generate
|
||||
outputs = llm.generate([prompt], sampling_params, use_tqdm=False)
|
||||
outputs = current_llm.generate([prompt], sampling_params, use_tqdm=False)
|
||||
output_text = outputs[0]["text"]
|
||||
|
||||
# Evaluate
|
||||
@@ -200,6 +235,12 @@ def run_task_test(
|
||||
out_preview = output_text[:50].replace('\n', ' ')
|
||||
print(f" [{idx:3d}] {status} (score={score:.2f}) exp={exp_preview}... | out={out_preview}...")
|
||||
|
||||
# Cleanup last LLM instance in fresh mode
|
||||
if fresh_llm and current_llm is not None:
|
||||
del current_llm
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
avg_score = total_score / len(samples) if samples else 0.0
|
||||
|
||||
return {
|
||||
@@ -217,6 +258,7 @@ def run_ruler_benchmark(
|
||||
data_dir: Path,
|
||||
datasets: Optional[List[str]] = None,
|
||||
num_samples: Optional[int] = None,
|
||||
sample_indices: Optional[List[int]] = None,
|
||||
max_model_len: int = DEFAULT_MAX_MODEL_LEN,
|
||||
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
||||
enable_cpu_offload: bool = False,
|
||||
@@ -226,6 +268,8 @@ def run_ruler_benchmark(
|
||||
gpu_utilization: float = 0.9,
|
||||
enforce_eager: bool = True,
|
||||
verbose: bool = True,
|
||||
fresh_llm: bool = False,
|
||||
json_output: bool = False,
|
||||
sparse_policy: Optional[str] = None,
|
||||
sparse_threshold: float = 0.9,
|
||||
sparse_samples: int = 128,
|
||||
@@ -239,7 +283,9 @@ def run_ruler_benchmark(
|
||||
data_dir: Directory containing task subdirectories
|
||||
datasets: List of task names to test (None = all)
|
||||
num_samples: Number of samples per task (None = all)
|
||||
...other LLM config params...
|
||||
sample_indices: Specific sample indices to test (overrides num_samples)
|
||||
fresh_llm: If True, reinitialize LLM for each sample (avoids state leakage)
|
||||
json_output: If True, output JSON results at the end
|
||||
sparse_policy: Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)
|
||||
|
||||
Returns:
|
||||
@@ -251,21 +297,29 @@ def run_ruler_benchmark(
|
||||
else:
|
||||
tasks = datasets
|
||||
|
||||
# Sample indices
|
||||
sample_indices = list(range(num_samples)) if num_samples else None
|
||||
# Sample indices: explicit list takes precedence over num_samples
|
||||
if sample_indices is not None:
|
||||
indices = sample_indices
|
||||
elif num_samples:
|
||||
indices = list(range(num_samples))
|
||||
else:
|
||||
indices = None
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"RULER Benchmark")
|
||||
print(f"{'='*60}")
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Data dir: {data_dir}")
|
||||
print(f"Tasks: {len(tasks)}")
|
||||
print(f"Samples per task: {num_samples if num_samples else 'all'}")
|
||||
print(f"CPU offload: {enable_cpu_offload}")
|
||||
print(f"{'='*60}")
|
||||
samples_desc = str(sample_indices) if sample_indices else (str(num_samples) if num_samples else 'all')
|
||||
|
||||
# Initialize LLM
|
||||
print("\nInitializing LLM...")
|
||||
if not json_output:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"RULER Benchmark")
|
||||
print(f"{'='*60}")
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Data dir: {data_dir}")
|
||||
print(f"Tasks: {len(tasks)}")
|
||||
print(f"Samples: {samples_desc}")
|
||||
print(f"CPU offload: {enable_cpu_offload}")
|
||||
print(f"Fresh LLM mode: {fresh_llm}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# LLM initialization kwargs
|
||||
llm_kwargs = {
|
||||
"max_model_len": max_model_len,
|
||||
"max_num_batched_tokens": max_model_len,
|
||||
@@ -286,7 +340,16 @@ def run_ruler_benchmark(
|
||||
llm_kwargs["sparse_threshold"] = sparse_threshold
|
||||
llm_kwargs["sparse_samples_per_chunk"] = sparse_samples
|
||||
|
||||
llm = LLM(model_path, **llm_kwargs)
|
||||
# Factory function for fresh_llm mode
|
||||
def create_llm():
|
||||
return LLM(model_path, **llm_kwargs)
|
||||
|
||||
# Initialize LLM (only once if not fresh_llm mode)
|
||||
llm = None
|
||||
if not fresh_llm:
|
||||
if not json_output:
|
||||
print("\nInitializing LLM...")
|
||||
llm = create_llm()
|
||||
|
||||
# Run tests
|
||||
start_time = time.time()
|
||||
@@ -297,22 +360,25 @@ def run_ruler_benchmark(
|
||||
llm=llm,
|
||||
task_name=task_name,
|
||||
data_dir=data_dir,
|
||||
sample_indices=sample_indices,
|
||||
sample_indices=indices,
|
||||
max_new_tokens=max_new_tokens,
|
||||
verbose=verbose,
|
||||
verbose=verbose and not json_output,
|
||||
llm_factory=create_llm,
|
||||
fresh_llm=fresh_llm,
|
||||
)
|
||||
task_results.append(result)
|
||||
|
||||
if verbose:
|
||||
if verbose and not json_output:
|
||||
print(f" -> {task_name}: {result['correct']}/{result['total']} "
|
||||
f"({result['accuracy']*100:.1f}%) avg_score={result['avg_score']:.3f}")
|
||||
|
||||
total_time = time.time() - start_time
|
||||
|
||||
# Cleanup
|
||||
del llm
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
# Cleanup (only if not fresh_llm mode, since fresh mode cleans up itself)
|
||||
if llm is not None:
|
||||
del llm
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Aggregate results
|
||||
total_correct = sum(r["correct"] for r in task_results)
|
||||
@@ -320,28 +386,53 @@ def run_ruler_benchmark(
|
||||
overall_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
|
||||
avg_score = sum(r["avg_score"] for r in task_results) / len(task_results) if task_results else 0.0
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
print(f"RULER Benchmark Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"\n{'Task':<20} {'Correct':<10} {'Accuracy':<12} {'Avg Score':<12}")
|
||||
print(f"{'-'*54}")
|
||||
# Collect failed samples
|
||||
failed_samples = {}
|
||||
for r in task_results:
|
||||
print(f"{r['task']:<20} {r['correct']}/{r['total']:<7} {r['accuracy']*100:>6.1f}% {r['avg_score']:.3f}")
|
||||
print(f"{'-'*54}")
|
||||
print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
|
||||
print(f"\nTime: {total_time:.1f}s")
|
||||
print(f"{'='*60}\n")
|
||||
failed = [res["index"] for res in r["results"] if not res["passed"]]
|
||||
if failed:
|
||||
failed_samples[r["task"]] = failed
|
||||
|
||||
return {
|
||||
# Print summary
|
||||
if not json_output:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"RULER Benchmark Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"\n{'Task':<20} {'Correct':<10} {'Accuracy':<12} {'Avg Score':<12}")
|
||||
print(f"{'-'*54}")
|
||||
for r in task_results:
|
||||
print(f"{r['task']:<20} {r['correct']}/{r['total']:<7} {r['accuracy']*100:>6.1f}% {r['avg_score']:.3f}")
|
||||
print(f"{'-'*54}")
|
||||
print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
|
||||
print(f"\nTime: {total_time:.1f}s")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
results = {
|
||||
"total_correct": total_correct,
|
||||
"total_samples": total_samples,
|
||||
"overall_accuracy": overall_accuracy,
|
||||
"avg_score": avg_score,
|
||||
"time": total_time,
|
||||
"task_results": task_results,
|
||||
"failed_samples": failed_samples,
|
||||
}
|
||||
|
||||
# JSON output
|
||||
if json_output:
|
||||
json_results = {
|
||||
"total_correct": total_correct,
|
||||
"total_samples": total_samples,
|
||||
"overall_accuracy": overall_accuracy,
|
||||
"avg_score": avg_score,
|
||||
"time": total_time,
|
||||
"tasks": {r["task"]: {"correct": r["correct"], "total": r["total"], "accuracy": r["accuracy"]}
|
||||
for r in task_results},
|
||||
"failed_samples": failed_samples,
|
||||
}
|
||||
print(json.dumps(json_results, indent=2))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI Entry Point
|
||||
@@ -361,6 +452,8 @@ if __name__ == "__main__":
|
||||
help="Comma-separated list of datasets to test (default: all)")
|
||||
parser.add_argument("--num-samples", type=int, default=0,
|
||||
help="Number of samples per dataset (default: 0 = all)")
|
||||
parser.add_argument("--sample-indices", type=str, default="",
|
||||
help="Comma-separated specific sample indices (e.g., 28,33,40)")
|
||||
parser.add_argument("--max-model-len", type=int, default=DEFAULT_MAX_MODEL_LEN,
|
||||
help=f"Maximum model context length (default: {DEFAULT_MAX_MODEL_LEN})")
|
||||
parser.add_argument("--max-new-tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS,
|
||||
@@ -379,6 +472,10 @@ if __name__ == "__main__":
|
||||
help="Enable CUDA graph")
|
||||
parser.add_argument("--quiet", "-q", action="store_true",
|
||||
help="Quiet mode")
|
||||
parser.add_argument("--fresh-llm", action="store_true",
|
||||
help="Reinitialize LLM for each sample (avoids state leakage)")
|
||||
parser.add_argument("--json-output", action="store_true",
|
||||
help="Output results in JSON format")
|
||||
parser.add_argument("--sparse-policy", type=str, default="",
|
||||
help="Sparse attention policy (FULL, QUEST, XATTN_BSA)")
|
||||
# XAttention BSA specific parameters
|
||||
@@ -395,6 +492,11 @@ if __name__ == "__main__":
|
||||
datasets = args.datasets.split(",") if args.datasets else None
|
||||
num_samples = args.num_samples if args.num_samples > 0 else None
|
||||
|
||||
# Parse sample indices (takes precedence over num_samples)
|
||||
sample_indices = None
|
||||
if args.sample_indices:
|
||||
sample_indices = [int(x.strip()) for x in args.sample_indices.split(",")]
|
||||
|
||||
# Parse sparse policy
|
||||
sparse_policy_str = args.sparse_policy.upper() if args.sparse_policy else None
|
||||
|
||||
@@ -403,6 +505,7 @@ if __name__ == "__main__":
|
||||
data_dir=Path(args.data_dir),
|
||||
datasets=datasets,
|
||||
num_samples=num_samples,
|
||||
sample_indices=sample_indices,
|
||||
max_model_len=args.max_model_len,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
enable_cpu_offload=args.enable_offload,
|
||||
@@ -412,15 +515,18 @@ if __name__ == "__main__":
|
||||
gpu_utilization=args.gpu_utilization,
|
||||
enforce_eager=not args.use_cuda_graph,
|
||||
verbose=not args.quiet,
|
||||
fresh_llm=args.fresh_llm,
|
||||
json_output=args.json_output,
|
||||
sparse_policy=sparse_policy_str,
|
||||
sparse_threshold=args.sparse_threshold,
|
||||
sparse_samples=args.sparse_samples,
|
||||
sparse_block_size=args.sparse_block_size,
|
||||
)
|
||||
|
||||
# Exit code
|
||||
if results["overall_accuracy"] >= 0.5:
|
||||
print("test_ruler: PASSED")
|
||||
else:
|
||||
print(f"test_ruler: FAILED (accuracy={results['overall_accuracy']*100:.1f}%)")
|
||||
exit(1)
|
||||
# Exit code (skip for json output mode)
|
||||
if not args.json_output:
|
||||
if results["overall_accuracy"] >= 0.5:
|
||||
print("test_ruler: PASSED")
|
||||
else:
|
||||
print(f"test_ruler: FAILED (accuracy={results['overall_accuracy']*100:.1f}%)")
|
||||
exit(1)
|
||||
|
||||
Reference in New Issue
Block a user