diff --git a/bench.py b/bench.py
index ea977ba..a7c85f7 100644
--- a/bench.py
+++ b/bench.py
@@ -5,23 +5,28 @@ from nanovllm import LLM, SamplingParams
 # from vllm import LLM, SamplingParams
 
 
-seed(0)
-num_seqs = 256
-max_input_len = 1024
-max_ouput_len = 1024
+def main():
+    seed(0)
+    num_seqs = 256
+    max_input_len = 1024
+    max_ouput_len = 1024
 
-path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
-llm = LLM(path, enforce_eager=False, max_model_len=4096)
+    path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
+    llm = LLM(path, enforce_eager=False, max_model_len=4096)
 
-prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
-sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
-# uncomment the following line for vllm
-# prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
+    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
+    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
+    # uncomment the following line for vllm
+    # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
 
-llm.generate(["Benchmark: "], SamplingParams())
-t = time.time()
-llm.generate(prompt_token_ids, sampling_params)
-t = (time.time() - t)
-total_tokens = sum(sp.max_tokens for sp in sampling_params)
-throughput = total_tokens / t
-print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+    llm.generate(["Benchmark: "], SamplingParams())
+    t = time.time()
+    llm.generate(prompt_token_ids, sampling_params)
+    t = (time.time() - t)
+    total_tokens = sum(sp.max_tokens for sp in sampling_params)
+    throughput = total_tokens / t
+    print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/example.py b/example.py
index fef1f30..33540f6 100644
--- a/example.py
+++ b/example.py
@@ -3,27 +3,32 @@ from nanovllm import LLM, SamplingParams
 from transformers import AutoTokenizer
 
 
-path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
-tokenizer = AutoTokenizer.from_pretrained(path)
-llm = LLM(path, enforce_eager=True)
+def main():
+    path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
+    tokenizer = AutoTokenizer.from_pretrained(path)
+    llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)
 
-sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
-prompts = [
-    "introduce yourself",
-    "list all prime numbers within 100",
-]
-prompts = [
-    tokenizer.apply_chat_template(
-        [{"role": "user", "content": prompt}],
-        tokenize=False,
-        add_generation_prompt=True,
-        enable_thinking=True
-    )
-    for prompt in prompts
-]
-outputs = llm.generate(prompts, sampling_params)
+    sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
+    prompts = [
+        "introduce yourself",
+        "list all prime numbers within 100",
+    ]
+    prompts = [
+        tokenizer.apply_chat_template(
+            [{"role": "user", "content": prompt}],
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=True
+        )
+        for prompt in prompts
+    ]
+    outputs = llm.generate(prompts, sampling_params)
 
-for prompt, output in zip(prompts, outputs):
-    print("\n")
-    print(f"Prompt: {prompt!r}")
-    print(f"Completion: {output['text']!r}")
+    for prompt, output in zip(prompts, outputs):
+        print("\n")
+        print(f"Prompt: {prompt!r}")
+        print(f"Completion: {output['text']!r}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py
index 7d73d42..6e64afd 100644
--- a/nanovllm/engine/llm_engine.py
+++ b/nanovllm/engine/llm_engine.py
@@ -20,9 +20,10 @@ class LLMEngine:
         config = Config(model, **config_kwargs)
         self.ps = []
         self.events = []
+        ctx = mp.get_context("spawn")
         for i in range(1, config.tensor_parallel_size):
-            event = mp.Event()
-            process = mp.Process(target=ModelRunner, args=(config, i, event))
+            event = ctx.Event()
+            process = ctx.Process(target=ModelRunner, args=(config, i, event))
             process.start()
             self.ps.append(process)
             self.events.append(event)
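Note on the llm_engine.py hunk: with the default fork start method, child processes inherit parent state (including any already-initialized CUDA context), which is a common source of hangs and "cannot re-initialize CUDA in forked subprocess" errors; building the Event and Process objects from an explicit "spawn" context avoids that. Spawned children re-import the entry-point script, which is presumably why bench.py and example.py also gain an if __name__ == "__main__": guard in this change. A minimal standalone sketch of the pattern, independent of nanovllm (worker is a hypothetical stand-in for ModelRunner):

import multiprocessing as mp


def worker(rank, event):
    # Stand-in for ModelRunner-style setup in the child process,
    # followed by a readiness signal back to the parent.
    print(f"worker {rank} ready")
    event.set()


def main():
    ctx = mp.get_context("spawn")            # explicit start method, independent of the platform default
    event = ctx.Event()
    proc = ctx.Process(target=worker, args=(1, event))
    proc.start()
    event.wait()                             # block until the child signals readiness
    proc.join()


if __name__ == "__main__":
    # Required with "spawn": each child re-imports this module, so any
    # unguarded top-level code would run again in every child process.
    main()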