diff --git a/bench.py b/bench.py index 8717ef1..d1dac2e 100644 --- a/bench.py +++ b/bench.py @@ -58,6 +58,8 @@ def main(): help="Enable sparse policy routing (FullAttentionPolicy by default)") parser.add_argument("--gpu-util", type=float, default=0.9, help="GPU memory utilization (default: 0.9)") + parser.add_argument("--enforce-eager", action="store_true", + help="Disable CUDA graphs (default: False)") args = parser.parse_args() path = os.path.expanduser(args.model) @@ -76,7 +78,7 @@ def main(): llm = LLM( path, - enforce_eager=False, + enforce_eager=args.enforce_eager, max_model_len=max_len, max_num_batched_tokens=max_len, sparse_policy=sparse_policy,