nano-vllm/nanovllm/engine/llm_engine.py

import atexit
from dataclasses import fields
from time import perf_counter, perf_counter_ns
from tqdm.auto import tqdm
from transformers import AutoTokenizer
import torch.multiprocessing as mp
from nanovllm.config import Config
from nanovllm.sampling_params import SamplingParams
from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner
from nanovllm.utils.observer import Observer
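

# LLMEngine wires together the tokenizer, the scheduler, and one ModelRunner
# per tensor-parallel rank: rank 0 runs in this process, while ranks 1..N-1
# are spawned as separate processes from the constructor below.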
class LLMEngine:
    def __init__(self, model, **kwargs):
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)
        self.ps = []
        self.events = []
        ctx = mp.get_context("spawn")
        # Spawn one worker process per extra tensor-parallel rank; rank 0's
        # ModelRunner lives in this process and signals the workers via events.
        for i in range(1, config.tensor_parallel_size):
            event = ctx.Event()
            process = ctx.Process(target=ModelRunner, args=(config, i, event))
            process.start()
            self.ps.append(process)
            self.events.append(event)
        self.model_runner = ModelRunner(config, 0, self.events)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = self.tokenizer.eos_token_id
        self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
        atexit.register(self.exit)
    def exit(self):
        self.model_runner.call("exit")
        del self.model_runner
        for p in self.ps:
            p.join()
    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        seq = Sequence(prompt, sampling_params)
        self.scheduler.add(seq)
    def step(self):
        seqs, is_prefill = self.scheduler.schedule()
        if not is_prefill:
            # Prefill just ended: record time to first token (TTFT).
            if Observer.ttft_start != 0:
                Observer.ttft = perf_counter_ns() - Observer.ttft_start
                Observer.reset_ttft()
            # Decode is running: record time per output token (TPOT) and
            # restart the timer for the next decode step.
            if Observer.tpot_start != 0:
                Observer.tpot = perf_counter_ns() - Observer.tpot_start
            Observer.tpot_start = perf_counter_ns()
        token_ids = self.model_runner.call("run", seqs, is_prefill)
        self.scheduler.postprocess(seqs, token_ids)
        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
        # A positive count means prefill tokens; a negative count encodes the
        # number of decode tokens (one per sequence) for the caller.
        num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
        return outputs, num_tokens
    def is_finished(self):
        return self.scheduler.is_finished()
    def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
    ) -> list[dict]:
        Observer.complete_reset()
        if use_tqdm:
            pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        for prompt, sp in zip(prompts, sampling_params):
            self.add_request(prompt, sp)
        outputs = {}
        prefill_throughput = decode_throughput = 0.
        while not self.is_finished():
            t = perf_counter()
            output, num_tokens = self.step()
            if use_tqdm:
                # step() returns a positive token count for prefill batches
                # and a negative sequence count for decode batches.
                if num_tokens > 0:
                    prefill_throughput = num_tokens / (perf_counter() - t)
                else:
                    decode_throughput = -num_tokens / (perf_counter() - t)
                pbar.set_postfix({
                    "Prefill": f"{int(prefill_throughput)}tok/s",
                    "Decode": f"{int(decode_throughput)}tok/s",
                    "ttft": f"{float(Observer.ttft) / 1e6:.2f}ms",
                    "tpot": f"{float(Observer.tpot) / 1e6:.2f}ms",
                })
            for seq_id, token_ids in output:
                outputs[seq_id] = token_ids
                if use_tqdm:
                    pbar.update(1)
        outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
        outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
        if use_tqdm:
            pbar.close()
        return outputs
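

# A minimal usage sketch, not part of the original file: the model path and
# sampling settings below are hypothetical placeholders, assuming SamplingParams
# accepts temperature and max_tokens as in nano-vllm.
if __name__ == "__main__":
    engine = LLMEngine("Qwen/Qwen3-0.6B", tensor_parallel_size=1)
    params = SamplingParams(temperature=0.6, max_tokens=64)
    results = engine.generate(["Hello, world!"], params)
    # generate() returns one dict per prompt with "text" and "token_ids" keys.
    print(results[0]["text"])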