# Patch summary: make the VLLM_USE_V1 setting in bench_vllm.py selectable from
# the command line.
#
# Before this patch the script always set VLLM_USE_V1 to "1". After it, the
# variable is set to "1" only when the literal argument --use-v1 is present in
# sys.argv (the flag is then removed from sys.argv), and to "0" otherwise.
# The assignment sits above the `from vllm import ...` line, so it runs before
# vllm is imported.
#
# NOTE(review): list.remove() deletes only the first match, so passing
# --use-v1 twice would leave one copy in sys.argv -- confirm nothing later in
# the script parses sys.argv strictly.
# NOTE(review): the else-branch overwrites any VLLM_USE_V1 value already
# present in the environment -- confirm that this is intended.
diff --git a/bench_vllm.py b/bench_vllm.py
index 483b311..1636bee 100644
--- a/bench_vllm.py
+++ b/bench_vllm.py
@@ -1,5 +1,14 @@
 import os
-os.environ["VLLM_USE_V1"] = "1"
+import sys
+
+# Parse --use-v1 flag before importing vllm
+use_v1 = "--use-v1" in sys.argv
+if use_v1:
+    os.environ["VLLM_USE_V1"] = "1"
+    sys.argv.remove("--use-v1")
+else:
+    os.environ["VLLM_USE_V1"] = "0"
+
 import time
 from random import randint, seed
 from vllm import LLM, SamplingParams