[claudesquad] update from 'fix-bug-2' on 09 Jan 26 16:05 CST

Zijie Tian
2026-01-09 16:05:36 +08:00
parent ccf04d3917
commit 1425510a2e
3 changed files with 267 additions and 34 deletions


@@ -38,6 +38,7 @@ def run_needle_test(
     minference_vertical: int = 1000,
     minference_slash: int = 6096,
     gpu_utilization: float = 0.9,
+    enforce_eager: bool = True,
     verbose: bool = True,
 ) -> bool:
     """
@@ -97,7 +98,7 @@ def run_needle_test(
     # 1. Initialize LLM
     llm_kwargs = {
-        "enforce_eager": True,
+        "enforce_eager": enforce_eager,
         "max_model_len": max_model_len,
         "max_num_batched_tokens": max_model_len,
         "enable_cpu_offload": enable_cpu_offload,
@@ -259,11 +260,25 @@ if __name__ == "__main__":
         default=0.9,
         help="GPU memory utilization (default: 0.9)"
     )
+    parser.add_argument(
+        "--enforce-eager",
+        action="store_true",
+        default=True,
+        help="Force eager execution (disable CUDA graphs)"
+    )
+    parser.add_argument(
+        "--use-cuda-graph",
+        action="store_true",
+        help="Enable CUDA graph (disable enforce_eager)"
+    )
     args = parser.parse_args()

     # Convert budget=0 to None for fixed mode
     minference_budget = args.minference_budget if args.minference_budget > 0 else None
+
+    # Determine enforce_eager: use_cuda_graph overrides enforce_eager
+    enforce_eager = not args.use_cuda_graph

     passed = run_needle_test(
         model_path=args.model,
         max_model_len=args.max_model_len,
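Note the flag semantics: because --enforce-eager combines action="store_true" with default=True, passing it changes nothing, and the effective value is computed solely from --use-cuda-graph, per the comment above. Illustrative invocations (script name assumed):

    python needle_test.py --model /models/example-7b                    # eager execution (default)
    python needle_test.py --model /models/example-7b --enforce-eager    # still eager; the flag is a no-op
    python needle_test.py --model /models/example-7b --use-cuda-graph   # enforce_eager=False, CUDA graphs on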
@@ -282,6 +297,7 @@ if __name__ == "__main__":
         minference_vertical=args.minference_vertical,
         minference_slash=args.minference_slash,
         gpu_utilization=args.gpu_utilization,
+        enforce_eager=enforce_eager,
         verbose=True,
     )
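As a design note, the two-flag pattern could be collapsed into one tri-state option using the standard library's argparse.BooleanOptionalAction (Python 3.9+). This is a sketch of the alternative, not what the commit does:

    import argparse

    parser = argparse.ArgumentParser()
    # A single declaration yields both --enforce-eager and --no-enforce-eager.
    parser.add_argument(
        "--enforce-eager",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Force eager execution (disable CUDA graphs)",
    )
    args = parser.parse_args()
    enforce_eager = args.enforce_eager  # True by default, False after --no-enforce-eager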