diff --git a/bench.py b/bench.py index 3358157..46bc21f 100644 --- a/bench.py +++ b/bench.py @@ -2,6 +2,7 @@ import os import time from random import randint, seed from nanovllm import LLM, SamplingParams +from nanovllm.utils.observer import InferenceObserver def bench_decode(llm, num_seqs, input_len, output_len): @@ -14,13 +15,17 @@ def bench_decode(llm, num_seqs, input_len, output_len): llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) t = time.time() - t - # Calculate metrics - prefill_tokens = num_seqs * input_len + # Get metrics from InferenceObserver + ttft_ms = InferenceObserver.ttft / 1e6 + tpot_ms = InferenceObserver.tpot / 1e6 + + # Calculate throughput from observer metrics decode_tokens = num_seqs * output_len - decode_throughput = decode_tokens / t + decode_throughput = 1000.0 / tpot_ms if tpot_ms > 0 else 0 # tokens/s per sequence print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s") - print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)") + print(f" TTFT: {ttft_ms:.2f}ms, TPOT: {tpot_ms:.2f}ms") + print(f" Decode Throughput: {decode_throughput:.2f} tok/s (from observer)") def bench_prefill(llm, num_seqs, input_len): @@ -33,9 +38,19 @@ def bench_prefill(llm, num_seqs, input_len): t = time.time() llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) t = time.time() - t + + # Get TTFT from InferenceObserver + ttft_ms = InferenceObserver.ttft / 1e6 + ttft_s = ttft_ms / 1000.0 + total_input_tokens = num_seqs * input_len - throughput = total_input_tokens / t - print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") + # Use observer TTFT for accurate prefill throughput + throughput_observer = total_input_tokens / ttft_s if ttft_s > 0 else 0 + throughput_external = total_input_tokens / t + + print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len})") + print(f" External Time: {t:.2f}s, Throughput: {throughput_external:.2f}tok/s") + print(f" Observer TTFT: {ttft_ms:.2f}ms, Throughput: {throughput_observer:.2f}tok/s") def main(): diff --git a/bench_offload.py b/bench_offload.py index 140e568..90e4f4d 100644 --- a/bench_offload.py +++ b/bench_offload.py @@ -2,6 +2,7 @@ import os import time from random import randint, seed from nanovllm import LLM, SamplingParams +from nanovllm.utils.observer import InferenceObserver def bench_decode(llm, num_seqs, input_len, output_len): @@ -14,16 +15,17 @@ def bench_decode(llm, num_seqs, input_len, output_len): llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) t = time.time() - t - # Calculate metrics - prefill_tokens = num_seqs * input_len - decode_tokens = num_seqs * output_len + # Get metrics from InferenceObserver + ttft_ms = InferenceObserver.ttft / 1e6 + tpot_ms = InferenceObserver.tpot / 1e6 - # Approximate: assume prefill takes ~input_len/prefill_speed, rest is decode - # For more accurate measurement, we'd need internal timing - decode_throughput = decode_tokens / t # This includes prefill time, so it's a lower bound + # Calculate throughput from observer metrics + decode_tokens = num_seqs * output_len + decode_throughput = 1000.0 / tpot_ms if tpot_ms > 0 else 0 # tokens/s per sequence print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s") - print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)") + print(f" TTFT: {ttft_ms:.2f}ms, TPOT: {tpot_ms:.2f}ms") + print(f" Decode Throughput: {decode_throughput:.2f} tok/s (from observer)") def bench_prefill(llm, num_seqs, input_len): @@ -36,9 +38,19 @@ def bench_prefill(llm, num_seqs, input_len): t = time.time() llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) t = time.time() - t + + # Get TTFT from InferenceObserver + ttft_ms = InferenceObserver.ttft / 1e6 + ttft_s = ttft_ms / 1000.0 + total_input_tokens = num_seqs * input_len - throughput = total_input_tokens / t - print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") + # Use observer TTFT for accurate prefill throughput + throughput_observer = total_input_tokens / ttft_s if ttft_s > 0 else 0 + throughput_external = total_input_tokens / t + + print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len})") + print(f" External Time: {t:.2f}s, Throughput: {throughput_external:.2f}tok/s") + print(f" Observer TTFT: {ttft_ms:.2f}ms, Throughput: {throughput_observer:.2f}tok/s") def main(): diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py index e7c4858..938b16a 100644 --- a/nanovllm/engine/llm_engine.py +++ b/nanovllm/engine/llm_engine.py @@ -10,7 +10,7 @@ from nanovllm.sampling_params import SamplingParams from nanovllm.engine.sequence import Sequence from nanovllm.engine.scheduler import Scheduler from nanovllm.engine.model_runner import ModelRunner -from nanovllm.utils.observer import Observer +from nanovllm.utils.observer import InferenceObserver class LLMEngine: @@ -58,15 +58,18 @@ class LLMEngine: print(f"[DEBUG LLMEngine.step] Mode={mode}, active_sequences={len(seqs)}") if not is_prefill: - # The end of the prefill mode. Get TTFT. - if Observer.ttft_start != 0: - Observer.ttft = perf_counter_ns() - Observer.ttft_start - Observer.reset_ttft() - # The start of the decode mode. Get TPOT. - if Observer.tpot_start != 0: - Observer.tpot = perf_counter_ns() - Observer.tpot_start - Observer.tpot_start = perf_counter_ns() + # Decode mode: calculate TPOT from previous decode step + if InferenceObserver.tpot_start != 0: + InferenceObserver.tpot = perf_counter_ns() - InferenceObserver.tpot_start + InferenceObserver.tpot_start = perf_counter_ns() + token_ids = self.model_runner.call("run", seqs, is_prefill) + + if is_prefill: + # Calculate TTFT after prefill completes (including chunked prefill) + if InferenceObserver.ttft_start != 0: + InferenceObserver.ttft = perf_counter_ns() - InferenceObserver.ttft_start + InferenceObserver.reset_ttft() self.scheduler.postprocess(seqs, token_ids) outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] @@ -91,7 +94,7 @@ class LLMEngine: log_level = os.environ.get('NANOVLLM_LOG_LEVEL', 'INFO') debug_enabled = log_level.upper() == 'DEBUG' - Observer.complete_reset() + InferenceObserver.complete_reset() if use_tqdm: pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) if not isinstance(sampling_params, list): @@ -128,8 +131,8 @@ class LLMEngine: pbar.set_postfix({ "Prefill": f"{int(prefill_throughput)}tok/s", "Decode": f"{int(decode_throughput)}tok/s", - "ttft": f"{float(Observer.ttft) / 1e6}ms", - "tpot": f"{float(Observer.tpot) / 1e6}ms", + "ttft": f"{float(InferenceObserver.ttft) / 1e6}ms", + "tpot": f"{float(InferenceObserver.tpot) / 1e6}ms", }) for seq_id, token_ids in output: outputs[seq_id] = token_ids diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py index 47192af..994ddd6 100644 --- a/nanovllm/engine/scheduler.py +++ b/nanovllm/engine/scheduler.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from nanovllm.config import Config from nanovllm.engine.sequence import Sequence, SequenceStatus -from nanovllm.utils.observer import Observer +from nanovllm.utils.observer import InferenceObserver if TYPE_CHECKING: from nanovllm.kvcache import KVCacheManager @@ -32,8 +32,8 @@ class Scheduler: num_seqs = 0 num_batched_tokens = 0 while self.waiting and num_seqs < self.max_num_seqs: - if Observer.ttft_start == 0: - Observer.ttft_start = perf_counter_ns() + if InferenceObserver.ttft_start == 0: + InferenceObserver.ttft_start = perf_counter_ns() seq = self.waiting[0] # Check if sequence is too large diff --git a/nanovllm/utils/observer.py b/nanovllm/utils/observer.py index a1ac862..6b17eed 100644 --- a/nanovllm/utils/observer.py +++ b/nanovllm/utils/observer.py @@ -1,17 +1,106 @@ -class Observer(): - ttft_start = 0 - tpot_start = 0 +""" +Observer 基类和 InferenceObserver 实现。 - ttft = 0 - tpot = 0 +Observer 架构: +- Observer: 基类,定义通用接口 +- InferenceObserver: 推理性能观测(TTFT/TPOT) +- MemoryObserver: 内存传输观测(在 memory_observer.py 中定义) +""" + + +class Observer: + """ + Observer 基类,提供通用的启用/禁用、重置、输出接口。 + + 所有 Observer 子类应继承此类并实现: + - complete_reset(): 重置所有统计数据 + - get_summary(): 返回统计摘要 dict + - print_summary(): 打印人类可读的摘要 + """ + + _enabled: bool = True # 默认启用 @classmethod - def reset_ttft(cls): + def enable(cls) -> None: + """启用 observer""" + cls._enabled = True + + @classmethod + def disable(cls) -> None: + """禁用 observer""" + cls._enabled = False + + @classmethod + def is_enabled(cls) -> bool: + """检查是否启用""" + return cls._enabled + + @classmethod + def complete_reset(cls) -> None: + """重置所有统计数据(子类实现)""" + raise NotImplementedError + + @classmethod + def get_summary(cls) -> dict: + """返回统计摘要(子类实现)""" + raise NotImplementedError + + @classmethod + def print_summary(cls) -> None: + """打印人类可读的摘要(子类可选覆盖)""" + import json + print(json.dumps(cls.get_summary(), indent=2)) + + +class InferenceObserver(Observer): + """ + 推理性能 Observer,统计 TTFT 和 TPOT。 + + - TTFT (Time To First Token): 首个 token 生成延迟 + - TPOT (Time Per Output Token): 每个输出 token 的平均延迟 + + 统计位置: + - TTFT 开始: scheduler.py:35-36 - 第一个 sequence 从 waiting 队列取出时 + - TTFT 结束: llm_engine.py:69-72 - prefill 完成后(包括 chunked prefill 所有 chunks) + - TPOT 开始: llm_engine.py:65 - 每次 decode step 结束时 + - TPOT 结束: llm_engine.py:62-63 - 下一次 decode step 开始时计算(测量上一次 decode 时间) + - 重置: llm_engine.py:97 - generate() 开始时 + + 注意:TPOT 需要至少 2 个输出 token 才能计算(测量 decode step 间隔)。 + """ + + # 时间戳 (nanoseconds) + ttft_start: int = 0 + tpot_start: int = 0 + + # 统计结果 (nanoseconds) + ttft: int = 0 + tpot: int = 0 + + @classmethod + def reset_ttft(cls) -> None: + """重置 TTFT 计时器""" cls.ttft_start = 0 @classmethod - def complete_reset(cls): + def complete_reset(cls) -> None: + """重置所有统计数据""" cls.ttft_start = 0 cls.tpot_start = 0 cls.ttft = 0 cls.tpot = 0 + + @classmethod + def get_summary(cls) -> dict: + """返回统计摘要""" + return { + "ttft_ns": cls.ttft, + "ttft_ms": cls.ttft / 1e6, + "tpot_ns": cls.tpot, + "tpot_ms": cls.tpot / 1e6, + } + + @classmethod + def print_summary(cls) -> None: + """打印摘要""" + print(f"[InferenceObserver] TTFT: {cls.ttft / 1e6:.2f}ms, TPOT: {cls.tpot / 1e6:.2f}ms")