From 5012b112919409a3d03efeb6e6a7ea32050762c9 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Fri, 9 Jan 2026 15:20:37 +0800 Subject: [PATCH] [bench] Modify bench_vllm.py --- bench_vllm.py | 63 +++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/bench_vllm.py b/bench_vllm.py index 6d1e269..dce213f 100644 --- a/bench_vllm.py +++ b/bench_vllm.py @@ -1,4 +1,5 @@ import os + os.environ["VLLM_USE_V1"] = "1" import time from random import randint, seed @@ -8,8 +9,12 @@ from vllm import LLM, SamplingParams def bench_decode(llm, num_seqs, input_len, output_len): """Benchmark decode performance""" seed(0) - prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)] - sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=output_len) + prompt_token_ids = [ + [randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs) + ] + sampling_params = SamplingParams( + temperature=0.6, ignore_eos=True, max_tokens=output_len + ) prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids] t = time.time() @@ -21,15 +26,21 @@ def bench_decode(llm, num_seqs, input_len, output_len): decode_tokens = num_seqs * output_len decode_throughput = decode_tokens / t - print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s") - print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)") + print( + f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s" + ) + print( + f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)" + ) def bench_prefill(llm, num_seqs, input_len): """Benchmark prefill performance""" seed(0) # Fixed length input, minimal output to focus on prefill - prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)] + prompt_token_ids = [ + [randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs) + ] sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1) prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids] @@ -38,17 +49,39 @@ def bench_prefill(llm, num_seqs, input_len): t = time.time() - t total_input_tokens = num_seqs * input_len throughput = total_input_tokens / t - print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") + print( + f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s" + ) def main(): import argparse - parser = argparse.ArgumentParser(description="Benchmark vLLM performance (for comparison)") - parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens") - parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)") - parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)") - parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)") - parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks") + + parser = argparse.ArgumentParser( + description="Benchmark vLLM performance (for comparison)" + ) + parser.add_argument( + "--input-len", type=int, default=None, help="Input length in tokens" + ) + parser.add_argument( + "--output-len", + type=int, + default=64, + help="Output length for decode benchmark (default: 64)", + ) + parser.add_argument( + "--max-len", type=int, default=32 * 1024, help="Max model length (default: 32K)" + ) + parser.add_argument( + "--bench-decode", + action="store_true", + help="Run decode benchmark (default: prefill only)", + ) + parser.add_argument( + "--bench-all", + action="store_true", + help="Run both prefill and decode benchmarks", + ) args = parser.parse_args() path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") @@ -61,7 +94,7 @@ def main(): enforce_eager=False, max_model_len=max_len, max_num_seqs=128, - gpu_memory_utilization=0.9, + gpu_memory_utilization=0.7, ) # Warmup @@ -86,7 +119,9 @@ def main(): print("\n" + "=" * 60) print("Decode Benchmark (vLLM)") print("=" * 60) - bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len) + bench_decode( + llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len + ) if __name__ == "__main__":