Author: GeeeekExplorer
Date: 2025-06-12 09:41:12 +08:00
parent fee58d44e4
commit f16adb729e
4 changed files with 14 additions and 22 deletions

nanovllm/config.py

@@ -8,7 +8,7 @@ class Config:
     max_num_batched_tokens: int = 32768
     max_num_seqs: int = 512
     max_model_len: int = 4096
-    gpu_memory_utilization: float = 0.95
+    gpu_memory_utilization: float = 0.9
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
     eos: int = -1
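
The only change here lowers the default `gpu_memory_utilization` from 0.95 to 0.9, leaving more headroom for allocations outside the KV cache. As a rough sketch of what such a fraction controls, a budget like the following is typical; the helper name, `block_bytes`, and the accounting are illustrative assumptions, not nano-vllm's actual code:

import torch

def kv_cache_blocks(gpu_memory_utilization: float, block_bytes: int) -> int:
    # Fraction of total GPU memory the engine may claim, minus what is
    # already in use (weights, activations), divided into KV-cache blocks.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    used_bytes = total_bytes - free_bytes
    budget = total_bytes * gpu_memory_utilization - used_bytes
    return max(int(budget // block_bytes), 0)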

nanovllm/engine/block_manager.py

@@ -57,9 +57,6 @@ class BlockManager:
         self.used_block_ids.remove(block_id)
         self.free_block_ids.append(block_id)
 
-    def can_prefill(self):
-        return len(self.free_block_ids) > 0.1 * len(self.blocks)
-
     def can_allocate(self, seq: Sequence):
         return len(self.free_block_ids) >= seq.num_blocks
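
The deleted `can_prefill` was a pool-level admission heuristic: admit new prefills only while more than 10% of blocks remain free. Admission now rests on the exact per-sequence check `can_allocate` alone. An illustrative contrast of the two rules, reusing the names visible in this diff (not repo code):

def can_prefill_heuristic(free_block_ids: list[int], blocks: list) -> bool:
    # removed rule: keep a 10% reserve of free KV-cache blocks
    return len(free_block_ids) > 0.1 * len(blocks)

def can_allocate_exact(free_block_ids: list[int], seq) -> bool:
    # surviving rule: admit only if every block the sequence needs fits now
    return len(free_block_ids) >= seq.num_blocks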

nanovllm/engine/llm_engine.py

@@ -1,11 +1,10 @@
-from collections import defaultdict
 from time import perf_counter
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoTokenizer
 from nanovllm.config import Config
 from nanovllm.sampling_params import SamplingParams
-from nanovllm.engine.sequence import Sequence
+from nanovllm.engine.sequence import Sequence, SequenceStatus
 from nanovllm.engine.scheduler import Scheduler
 from nanovllm.engine.model_runner import ModelRunner
@@ -34,8 +33,10 @@ class LLMEngine:
     def step(self):
         seqs, is_prefill = self.scheduler.schedule()
         token_ids = self.model_runner.run(seqs, is_prefill)
-        finished = self.scheduler.postprocess(seqs, token_ids)
-        return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)], sum(len(seq) for seq in seqs) if is_prefill else len(seqs)
+        self.scheduler.postprocess(seqs, token_ids)
+        outputs = [(seq.seq_id, seq[seq.num_prompt_tokens:]) for seq in seqs if seq.status == SequenceStatus.FINISHED]
+        num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
+        return outputs, num_tokens
 
     def is_finished(self):
         return self.scheduler.is_finished()
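
`step` now reports only finished sequences, each paired with its full completion (`seq[seq.num_prompt_tokens:]`), and folds the phase into the sign of `num_tokens`: positive means a prefill over that many prompt tokens, negative a decode over `-num_tokens` sequences. A minimal consumer of that contract, assuming only the `step`/`is_finished` interface above:

def drain(engine):
    results, prefill_tokens, decode_tokens = {}, 0, 0
    while not engine.is_finished():
        outputs, num_tokens = engine.step()
        if num_tokens > 0:                 # positive: prefill step
            prefill_tokens += num_tokens
        else:                              # negative: one token per decode seq
            decode_tokens += -num_tokens
        for seq_id, completion_ids in outputs:  # finished sequences only
            results[seq_id] = completion_ids
    return results, prefill_tokens, decode_tokens
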
@@ -56,23 +57,23 @@ class LLMEngine:
             sampling_params = [sampling_params] * len(prompts)
         for prompt, sp in zip(prompts, sampling_params):
             self.add_request(prompt, sp)
-        outputs = defaultdict(list)
+        outputs = {}
         prefill_throughput = decode_throughput = 0.
         while not self.is_finished():
             t = perf_counter()
             output, num_tokens = self.step()
             if use_tqdm:
-                if num_tokens > len(output):
+                if num_tokens > 0:
                     prefill_throughput = num_tokens / (perf_counter() - t)
                 else:
-                    decode_throughput = num_tokens / (perf_counter() - t)
+                    decode_throughput = -num_tokens / (perf_counter() - t)
                 pbar.set_postfix({
                     "Prefill": f"{int(prefill_throughput)}tok/s",
                     "Decode": f"{int(decode_throughput)}tok/s",
                 })
-            for seq_id, token_id, finish in output:
-                outputs[seq_id].append(token_id)
-                if finish and use_tqdm:
+            for seq_id, token_ids in output:
+                outputs[seq_id] = token_ids
+                if use_tqdm:
                     pbar.update(1)
         outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
         outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
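
With `step` emitting each sequence exactly once at finish time, `generate` can keep a plain dict instead of appending token-by-token into a `defaultdict(list)`, and the progress bar advances once per finished sequence. Sorting by `seq_id` restores submission order before decoding; a toy illustration:

# Toy data, not repo output: seq_ids are assigned in submission order,
# so sorting the keys puts completions back in prompt order.
outputs = {1: [7, 8, 9], 0: [3, 4]}
ordered = [outputs[seq_id] for seq_id in sorted(outputs)]
assert ordered == [[3, 4], [7, 8, 9]]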

nanovllm/engine/scheduler.py

@@ -43,7 +43,6 @@ class Scheduler:
             return scheduled_seqs, True
 
         # decode
-        # self.running = deque(sorted(self.running))
         while self.running and num_seqs < self.max_num_seqs:
             seq = self.running.popleft()
             while not self.block_manager.can_append(seq):
@@ -59,8 +58,8 @@ class Scheduler:
         running = deque(scheduled_seqs)
         running.extend(self.running)
         self.running = running
-        if scheduled_seqs:
-            return scheduled_seqs, False
+        assert scheduled_seqs
+        return scheduled_seqs, False
 
     def preempt(self, seq: Sequence):
         seq.status = SequenceStatus.WAITING
@@ -69,7 +68,6 @@ class Scheduler:
     def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
         self.num_tokens += len(token_ids)
-        finished = []
         for seq, token_id in zip(seqs, token_ids):
             seq.append_token(token_id)
             if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                 seq.status = SequenceStatus.FINISHED
@@ -77,7 +75,3 @@ class Scheduler:
                 self.block_manager.deallocate(seq)
                 self.running.remove(seq)
                 self.num_finished += 1
-                finished.append(True)
-            else:
-                finished.append(False)
-        return finished
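
`postprocess` now only mutates sequence state; the parallel `finished: list[bool]` is gone, and `LLMEngine.step` recovers finished sequences by filtering on `seq.status`. (The `-> list[bool]` annotation left on `postprocess` is now stale, since the function no longer returns anything.) A self-contained sketch of the status-driven pattern, with stand-ins for the repo's `Sequence` and `SequenceStatus`:

from enum import Enum, auto

class SequenceStatus(Enum):   # stand-in for nanovllm's enum
    RUNNING = auto()
    FINISHED = auto()

class Seq:                    # minimal stand-in for Sequence
    def __init__(self, seq_id: int):
        self.seq_id = seq_id
        self.status = SequenceStatus.RUNNING

seqs = [Seq(0), Seq(1)]
seqs[1].status = SequenceStatus.FINISHED
# Instead of threading finish flags back to the caller, filter on status:
finished = [s.seq_id for s in seqs if s.status == SequenceStatus.FINISHED]
assert finished == [1]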