[feat] Add TTFT and TPOT metrics to the tqdm progress bar.
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import atexit
|
import atexit
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
from time import perf_counter
|
from time import perf_counter, perf_counter_ns
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
@@ -10,6 +10,7 @@ from nanovllm.sampling_params import SamplingParams
|
|||||||
from nanovllm.engine.sequence import Sequence
|
from nanovllm.engine.sequence import Sequence
|
||||||
from nanovllm.engine.scheduler import Scheduler
|
from nanovllm.engine.scheduler import Scheduler
|
||||||
from nanovllm.engine.model_runner import ModelRunner
|
from nanovllm.engine.model_runner import ModelRunner
|
||||||
|
from nanovllm.utils.observer import Observer
|
||||||
|
|
||||||
|
|
||||||
class LLMEngine:
|
class LLMEngine:
|
||||||
@@ -47,6 +48,15 @@ class LLMEngine:
|
|||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
seqs, is_prefill = self.scheduler.schedule()
|
seqs, is_prefill = self.scheduler.schedule()
|
||||||
|
if not is_prefill:
|
||||||
|
# The end of the prefill mode. Get TTFT.
|
||||||
|
if Observer.ttft_start != 0:
|
||||||
|
Observer.ttft = perf_counter_ns() - Observer.ttft_start
|
||||||
|
Observer.reset_ttft()
|
||||||
|
# The start of the decode mode. Get TPOT.
|
||||||
|
if Observer.tpot_start != 0:
|
||||||
|
Observer.tpot = perf_counter_ns() - Observer.tpot_start
|
||||||
|
Observer.tpot_start = perf_counter_ns()
|
||||||
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
||||||
self.scheduler.postprocess(seqs, token_ids)
|
self.scheduler.postprocess(seqs, token_ids)
|
||||||
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
|
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
|
||||||
@@ -62,6 +72,7 @@ class LLMEngine:
|
|||||||
sampling_params: SamplingParams | list[SamplingParams],
|
sampling_params: SamplingParams | list[SamplingParams],
|
||||||
use_tqdm: bool = True,
|
use_tqdm: bool = True,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
|
Observer.complete_reset()
|
||||||
if use_tqdm:
|
if use_tqdm:
|
||||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
||||||
if not isinstance(sampling_params, list):
|
if not isinstance(sampling_params, list):
|
||||||
@@ -81,6 +92,8 @@ class LLMEngine:
|
|||||||
pbar.set_postfix({
|
pbar.set_postfix({
|
||||||
"Prefill": f"{int(prefill_throughput)}tok/s",
|
"Prefill": f"{int(prefill_throughput)}tok/s",
|
||||||
"Decode": f"{int(decode_throughput)}tok/s",
|
"Decode": f"{int(decode_throughput)}tok/s",
|
||||||
|
"ttft": f"{float(Observer.ttft) / 1e6}ms",
|
||||||
|
"tpot": f"{float(Observer.tpot) / 1e6}ms",
|
||||||
})
|
})
|
||||||
for seq_id, token_ids in output:
|
for seq_id, token_ids in output:
|
||||||
outputs[seq_id] = token_ids
|
outputs[seq_id] = token_ids
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
|
from time import perf_counter_ns
|
||||||
|
|
||||||
from nanovllm.config import Config
|
from nanovllm.config import Config
|
||||||
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
||||||
from nanovllm.engine.block_manager import BlockManager
|
from nanovllm.engine.block_manager import BlockManager
|
||||||
|
from nanovllm.utils.observer import Observer
|
||||||
|
|
||||||
|
|
||||||
class Scheduler:
|
class Scheduler:
|
||||||
@@ -27,6 +29,8 @@ class Scheduler:
|
|||||||
num_seqs = 0
|
num_seqs = 0
|
||||||
num_batched_tokens = 0
|
num_batched_tokens = 0
|
||||||
while self.waiting and num_seqs < self.max_num_seqs:
|
while self.waiting and num_seqs < self.max_num_seqs:
|
||||||
|
if Observer.ttft_start == 0:
|
||||||
|
Observer.ttft_start = perf_counter_ns()
|
||||||
seq = self.waiting[0]
|
seq = self.waiting[0]
|
||||||
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
||||||
break
|
break
|
||||||
|
|||||||
17
nanovllm/utils/observer.py
Normal file
17
nanovllm/utils/observer.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
class Observer:
    """Process-wide holder for latency timing state.

    All state is kept in class attributes, so the class is used directly
    (e.g. ``Observer.ttft``) and is never instantiated. The call sites in
    this change store ``perf_counter_ns()`` readings in the ``*_start``
    attributes and differences of such readings in ``ttft`` / ``tpot``, so
    the values are integer nanoseconds. A start mark of ``0`` is the
    sentinel for "no measurement in progress".
    """

    # Start mark of the current TTFT (time-to-first-token) measurement;
    # 0 means no measurement has been started.
    ttft_start = 0
    # Start mark of the current TPOT (time-per-output-token) measurement;
    # 0 means no measurement has been started.
    tpot_start = 0

    # Most recently recorded TTFT duration (0 until one is recorded).
    ttft = 0
    # Most recently recorded TPOT duration (0 until one is recorded).
    tpot = 0

    @classmethod
    def reset_ttft(cls):
        """Clear the TTFT start mark, leaving the recorded durations intact."""
        cls.ttft_start = 0

    @classmethod
    def complete_reset(cls):
        """Reset every start mark and every recorded duration to zero."""
        # Reuse the dedicated helper so the TTFT reset logic lives in one place.
        cls.reset_ttft()
        cls.tpot_start = 0
        cls.ttft = 0
        cls.tpot = 0
|
||||||
Reference in New Issue
Block a user