82 lines
3.2 KiB
Python
82 lines
3.2 KiB
Python
from time import perf_counter
|
|
from tqdm.auto import tqdm
|
|
from transformers import AutoConfig, AutoTokenizer
|
|
|
|
from nanovllm.config import Config
|
|
from nanovllm.sampling_params import SamplingParams
|
|
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
|
from nanovllm.engine.scheduler import Scheduler
|
|
from nanovllm.engine.model_runner import ModelRunner
|
|
|
|
|
|
class LLMEngine:
    """Engine that ties together a tokenizer, a scheduler and a model runner.

    Requests are queued with ``add_request`` and advanced one scheduler
    iteration at a time with ``step``. ``generate`` is the blocking batch
    wrapper that keeps calling ``step`` until the scheduler reports that
    everything is finished.
    """

    def __init__(self, model, **kwargs):
        """Build the engine for ``model``.

        Args:
            model: Passed to ``Config``. ``config.model`` is then used as the
                name/path for ``AutoConfig`` and ``AutoTokenizer``.
            **kwargs: Overrides for attributes of the ``Config`` instance.
                Keys that are not existing attributes are silently ignored.
        """
        config = Config(model)
        for k, v in kwargs.items():
            # Unknown keys are dropped rather than rejected.
            # NOTE(review): overrides are applied after Config() has been
            # constructed, so any validation Config does at construction time
            # does not see them — confirm against Config.
            if hasattr(config, k):
                setattr(config, k, v)
        # Class-level assignment: affects every Sequence in this process, so
        # two engines with different block sizes would overwrite each other.
        Sequence.block_size = config.kvcache_block_size
        config.hf_config = AutoConfig.from_pretrained(config.model)
        # Cap the configured max length at the model's max_position_embeddings.
        config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = self.tokenizer.eos_token_id
        self.model_runner = ModelRunner(config)
        self.scheduler = Scheduler(config)

    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
        """Queue one request with the scheduler.

        Args:
            prompt: Raw text (encoded with the engine's tokenizer) or an
                already-tokenized list of token ids.
            sampling_params: Sampling settings attached to this request.
        """
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        seq = Sequence(prompt, sampling_params)
        self.scheduler.add(seq)

    def step(self):
        """Run one scheduler iteration through the model and post-process it.

        Returns:
            A tuple ``(outputs, num_tokens)``.

            ``outputs`` lists ``(seq_id, token_ids)`` for every scheduled
            sequence whose status is FINISHED after this step. ``token_ids``
            holds the tokens after the first ``num_prompt_tokens``.

            ``num_tokens`` is signed to encode the step kind. For a prefill
            step it is the positive sum of the scheduled sequences' lengths.
            For a decode step it is the negated number of scheduled sequences.
        """
        seqs, is_prefill = self.scheduler.schedule()
        token_ids = self.model_runner.run(seqs, is_prefill)
        self.scheduler.postprocess(seqs, token_ids)
        outputs = [(seq.seq_id, seq[seq.num_prompt_tokens:]) for seq in seqs if seq.status == SequenceStatus.FINISHED]
        num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
        return outputs, num_tokens

    def is_finished(self):
        """Return whatever the scheduler reports as its finished state."""
        return self.scheduler.is_finished()

    def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
    ) -> list[dict]:
        """Generate completions for a batch of prompts, blocking until done.

        The loop runs until the scheduler is finished. Any requests queued
        via ``add_request`` before this call are therefore processed and
        returned as well.

        Args:
            prompts: Text prompts or pre-tokenized prompts.
            sampling_params: One ``SamplingParams`` shared by every prompt, or
                a list with exactly one entry per prompt.
            use_tqdm: Show a progress bar with prefill/decode throughput.

        Returns:
            One dict per finished sequence, ordered by ``seq_id``, with keys
            ``"text"`` (decoded tokens) and ``"token_ids"``. Whether seq_id
            order equals submission order depends on how Sequence assigns ids
            — presumably incrementally; confirm.

        Raises:
            ValueError: If a list of sampling params does not have the same
                length as ``prompts``.
        """
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        if len(sampling_params) != len(prompts):
            # Without this check zip() would silently drop the surplus entries,
            # and the progress bar would never reach its total.
            raise ValueError(
                f"got {len(sampling_params)} sampling params for {len(prompts)} prompts"
            )

        pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) if use_tqdm else None
        outputs = {}
        try:
            for prompt, sp in zip(prompts, sampling_params):
                self.add_request(prompt, sp)
            prefill_throughput = decode_throughput = 0.
            while not self.is_finished():
                t = perf_counter()
                output, num_tokens = self.step()
                if pbar is not None:
                    elapsed = perf_counter() - t
                    # step() encodes prefill as a positive token count and
                    # decode as a negated sequence count.
                    if num_tokens > 0:
                        prefill_throughput = num_tokens / elapsed
                    else:
                        decode_throughput = -num_tokens / elapsed
                    pbar.set_postfix({
                        "Prefill": f"{int(prefill_throughput)}tok/s",
                        "Decode": f"{int(decode_throughput)}tok/s",
                    })
                for seq_id, token_ids in output:
                    outputs[seq_id] = token_ids
                    if pbar is not None:
                        pbar.update(1)
        finally:
            # Close the bar even if tokenization or a step raises.
            if pbar is not None:
                pbar.close()

        ordered = [outputs[seq_id] for seq_id in sorted(outputs)]
        return [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in ordered]