This commit is contained in:
GeeeekExplorer
2025-06-10 08:52:58 +08:00
parent a5a4909e6a
commit b98e1ca305
10 changed files with 39 additions and 26 deletions

View File

@@ -1,4 +1,5 @@
from collections import defaultdict
from time import perf_counter
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoTokenizer
@@ -33,7 +34,7 @@ class LLMEngine:
seqs, is_prefill = self.scheduler.schedule()
token_ids = self.model_runner.run(seqs, is_prefill)
finished = self.scheduler.postprocess(seqs, token_ids)
return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)]
return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)], sum(len(seq) for seq in seqs) if is_prefill else len(seqs)
def is_finished(self):
return self.scheduler.is_finished()
@@ -45,19 +46,32 @@ class LLMEngine:
use_tqdm: bool = True,
) -> list[str]:
if use_tqdm:
pbar = tqdm(total=len(prompts),
desc="Processed prompts",
pbar = tqdm(
total=len(prompts),
desc="Generating",
dynamic_ncols=True,
)
if not isinstance(SamplingParams, list):
sampling_params = [sampling_params] * len(prompts)
for prompt, sp in zip(prompts, sampling_params):
self.add_request(prompt, sp)
outputs = defaultdict(list)
prefill_throughput = decode_throughput = 0.
while not self.is_finished():
output = self.step()
t = perf_counter()
output, num_tokens = self.step()
if use_tqdm:
if num_tokens > len(output):
prefill_throughput = num_tokens / (perf_counter() - t)
else:
decode_throughput = num_tokens / (perf_counter() - t)
pbar.set_postfix({
"Prefill": f"{int(prefill_throughput)}tok/s",
"Decode": f"{int(decode_throughput)}tok/s",
})
for seq_id, token_id, finish in output:
outputs[seq_id].append(token_id)
if use_tqdm and finish:
if finish and use_tqdm:
pbar.update(1)
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]