support tensor parallel

cheunglei
2025-06-15 01:31:24 +08:00
parent b6136383c9
commit 53b3ef2e32
9 changed files with 102 additions and 31 deletions


@@ -1,6 +1,8 @@
 import atexit
 from time import perf_counter
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoTokenizer
+import torch.multiprocessing as mp
 from nanovllm.config import Config
 from nanovllm.sampling_params import SamplingParams
@@ -19,10 +21,24 @@ class LLMEngine:
         Sequence.block_size = config.kvcache_block_size
         config.hf_config = AutoConfig.from_pretrained(config.model)
         config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
+        self.ps = []
+        self.events = []
+        for i in range(1, config.tensor_parallel_size):
+            event = mp.Event()
+            process = mp.Process(target=ModelRunner, args=(config, i, event))
+            process.start()
+            self.ps.append(process)
+            self.events.append(event)
+        self.model_runner = ModelRunner(config, 0, self.events)
         self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
         config.eos = self.tokenizer.eos_token_id
-        self.model_runner = ModelRunner(config)
         self.scheduler = Scheduler(config)
+        atexit.register(self.exit)
+
+    def exit(self):
+        self.model_runner.call("exit")
+        for p in self.ps:
+            p.join()
 
     def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
         if isinstance(prompt, str):
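For orientation, the block below is a minimal, hedged sketch of the driver/worker pattern the constructor above implies: rank 0 keeps a local runner, while ranks 1..N-1 are spawned as processes whose target is the runner class itself, so each worker blocks inside its own constructor waiting for calls. ToyRunner, the Pipe, and every name in it are illustrative assumptions, not code from this commit; the real ModelRunner changes live elsewhere in this changeset and are not shown on this page.

# Hedged sketch only -- none of these names come from the diff above.
import torch.multiprocessing as mp


class ToyRunner:
    """Rank 0 executes calls directly; ranks > 0 block in a receive loop."""

    def __init__(self, rank, conn=None):
        self.rank = rank
        self.conn = conn          # pipe end owned by a worker rank
        if rank > 0:
            self.loop()           # worker processes never leave __init__

    def loop(self):
        # Mirror whatever rank 0 asks for until it says "exit".
        while True:
            method, args = self.conn.recv()
            if method == "exit":
                break
            getattr(self, method)(*args)

    def run(self, step):
        print(f"rank {self.rank} ran step {step}")


if __name__ == "__main__":
    parent_end, child_end = mp.Pipe()
    worker = mp.Process(target=ToyRunner, args=(1, child_end))
    worker.start()                     # rank 1 blocks in loop()
    driver = ToyRunner(0)
    driver.run(0)                      # rank 0 executes locally ...
    parent_end.send(("run", (0,)))     # ... and forwards the same call to rank 1
    parent_end.send(("exit", ()))
    worker.join()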
@@ -32,7 +48,7 @@ class LLMEngine:
 
     def step(self):
         seqs, is_prefill = self.scheduler.schedule()
-        token_ids = self.model_runner.run(seqs, is_prefill)
+        token_ids = self.model_runner.call("run", seqs, is_prefill)
         self.scheduler.postprocess(seqs, token_ids)
         outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
         num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
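As a usage sketch, assuming the package's public LLM wrapper simply forwards keyword arguments into Config (the model path and parallel degree below are placeholders), the new option would be exercised like this:

from nanovllm import LLM, SamplingParams

# Placeholder model path; tensor_parallel_size > 1 makes the engine spawn
# one extra ModelRunner process per additional rank, as in the diff above.
llm = LLM("/path/to/model", tensor_parallel_size=2)
outputs = llm.generate(["Hello"], SamplingParams(temperature=0.6, max_tokens=64))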