[feat] Added metric into tqdm bar.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import atexit
|
||||
from dataclasses import fields
|
||||
from time import perf_counter
|
||||
from time import perf_counter, perf_counter_ns
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
import torch.multiprocessing as mp
|
||||
@@ -10,6 +10,7 @@ from nanovllm.sampling_params import SamplingParams
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.engine.scheduler import Scheduler
|
||||
from nanovllm.engine.model_runner import ModelRunner
|
||||
from nanovllm.utils.observer import Observer
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
@@ -47,6 +48,15 @@ class LLMEngine:
|
||||
|
||||
def step(self):
|
||||
seqs, is_prefill = self.scheduler.schedule()
|
||||
if not is_prefill:
|
||||
# The end of the prefill mode. Get TTFT.
|
||||
if Observer.ttft_start != 0:
|
||||
Observer.ttft = perf_counter_ns() - Observer.ttft_start
|
||||
Observer.reset_ttft()
|
||||
# The start of the decode mode. Get TPOT.
|
||||
if Observer.tpot_start != 0:
|
||||
Observer.tpot = perf_counter_ns() - Observer.tpot_start
|
||||
Observer.tpot_start = perf_counter_ns()
|
||||
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
||||
self.scheduler.postprocess(seqs, token_ids)
|
||||
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
|
||||
@@ -62,6 +72,7 @@ class LLMEngine:
|
||||
sampling_params: SamplingParams | list[SamplingParams],
|
||||
use_tqdm: bool = True,
|
||||
) -> list[str]:
|
||||
Observer.complete_reset()
|
||||
if use_tqdm:
|
||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
||||
if not isinstance(sampling_params, list):
|
||||
@@ -81,6 +92,8 @@ class LLMEngine:
|
||||
pbar.set_postfix({
|
||||
"Prefill": f"{int(prefill_throughput)}tok/s",
|
||||
"Decode": f"{int(decode_throughput)}tok/s",
|
||||
"ttft": f"{float(Observer.ttft) / 1e6}ms",
|
||||
"tpot": f"{float(Observer.tpot) / 1e6}ms",
|
||||
})
|
||||
for seq_id, token_ids in output:
|
||||
outputs[seq_id] = token_ids
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from collections import deque
|
||||
from time import perf_counter_ns
|
||||
|
||||
from nanovllm.config import Config
|
||||
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
||||
from nanovllm.engine.block_manager import BlockManager
|
||||
from nanovllm.utils.observer import Observer
|
||||
|
||||
|
||||
class Scheduler:
|
||||
@@ -27,6 +29,8 @@ class Scheduler:
|
||||
num_seqs = 0
|
||||
num_batched_tokens = 0
|
||||
while self.waiting and num_seqs < self.max_num_seqs:
|
||||
if Observer.ttft_start == 0:
|
||||
Observer.ttft_start = perf_counter_ns()
|
||||
seq = self.waiting[0]
|
||||
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
||||
break
|
||||
|
||||
17
nanovllm/utils/observer.py
Normal file
17
nanovllm/utils/observer.py
Normal file
@@ -0,0 +1,17 @@
|
||||
class Observer():
|
||||
ttft_start = 0
|
||||
tpot_start = 0
|
||||
|
||||
ttft = 0
|
||||
tpot = 0
|
||||
|
||||
@classmethod
|
||||
def reset_ttft(cls):
|
||||
cls.ttft_start = 0
|
||||
|
||||
@classmethod
|
||||
def complete_reset(cls):
|
||||
cls.ttft_start = 0
|
||||
cls.tpot_start = 0
|
||||
cls.ttft = 0
|
||||
cls.tpot = 0
|
||||
Reference in New Issue
Block a user