[feat] Added metric into tqdm bar.

This commit is contained in:
Zijie Tian
2025-12-10 00:52:13 +08:00
parent 761929390e
commit 204fe2b38f
3 changed files with 35 additions and 1 deletions

View File

@@ -1,6 +1,6 @@
import atexit
from dataclasses import fields
from time import perf_counter
from time import perf_counter, perf_counter_ns
from tqdm.auto import tqdm
from transformers import AutoTokenizer
import torch.multiprocessing as mp
@@ -10,6 +10,7 @@ from nanovllm.sampling_params import SamplingParams
from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner
from nanovllm.utils.observer import Observer
class LLMEngine:
@@ -47,6 +48,15 @@ class LLMEngine:
def step(self):
seqs, is_prefill = self.scheduler.schedule()
if not is_prefill:
# The end of the prefill mode. Get TTFT.
if Observer.ttft_start != 0:
Observer.ttft = perf_counter_ns() - Observer.ttft_start
Observer.reset_ttft()
# The start of the decode mode. Get TPOT.
if Observer.tpot_start != 0:
Observer.tpot = perf_counter_ns() - Observer.tpot_start
Observer.tpot_start = perf_counter_ns()
token_ids = self.model_runner.call("run", seqs, is_prefill)
self.scheduler.postprocess(seqs, token_ids)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
@@ -62,6 +72,7 @@ class LLMEngine:
sampling_params: SamplingParams | list[SamplingParams],
use_tqdm: bool = True,
) -> list[str]:
Observer.complete_reset()
if use_tqdm:
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
if not isinstance(sampling_params, list):
@@ -81,6 +92,8 @@ class LLMEngine:
pbar.set_postfix({
"Prefill": f"{int(prefill_throughput)}tok/s",
"Decode": f"{int(decode_throughput)}tok/s",
"ttft": f"{float(Observer.ttft) / 1e6}ms",
"tpot": f"{float(Observer.tpot) / 1e6}ms",
})
for seq_id, token_ids in output:
outputs[seq_id] = token_ids