Merge branch 'tzj/minference' of ssh://git.zijie-tian.site:2222/zijie-tian/nano-vllm into tzj/minference
This commit is contained in:
@@ -49,7 +49,14 @@ class LLMEngine:
|
||||
self.scheduler.add(seq)
|
||||
|
||||
def step(self):
|
||||
import os
|
||||
debug_enabled = os.environ.get('NANOVLLM_LOG_LEVEL', 'INFO').upper() == 'DEBUG'
|
||||
|
||||
seqs, is_prefill = self.scheduler.schedule()
|
||||
if debug_enabled:
|
||||
mode = "PREFILL" if is_prefill else "DECODE"
|
||||
print(f"[DEBUG LLMEngine.step] Mode={mode}, active_sequences={len(seqs)}")
|
||||
|
||||
if not is_prefill:
|
||||
# The end of the prefill mode. Get TTFT.
|
||||
if Observer.ttft_start != 0:
|
||||
@@ -62,7 +69,11 @@ class LLMEngine:
|
||||
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
||||
self.scheduler.postprocess(seqs, token_ids)
|
||||
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
|
||||
|
||||
|
||||
if debug_enabled and outputs:
|
||||
for seq_id, tokens in outputs:
|
||||
print(f"[DEBUG LLMEngine.step] Sequence {seq_id} finished, {len(tokens)} tokens generated")
|
||||
|
||||
#> Calculate number of tokens processed
|
||||
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
|
||||
return outputs, num_tokens
|
||||
@@ -76,6 +87,10 @@ class LLMEngine:
|
||||
sampling_params: SamplingParams | list[SamplingParams],
|
||||
use_tqdm: bool = True,
|
||||
) -> list[str]:
|
||||
import os
|
||||
log_level = os.environ.get('NANOVLLM_LOG_LEVEL', 'INFO')
|
||||
debug_enabled = log_level.upper() == 'DEBUG'
|
||||
|
||||
Observer.complete_reset()
|
||||
if use_tqdm:
|
||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
||||
@@ -85,7 +100,24 @@ class LLMEngine:
|
||||
self.add_request(prompt, sp)
|
||||
outputs = {}
|
||||
prefill_throughput = decode_throughput = 0.
|
||||
iteration = 0
|
||||
last_output_count = 0
|
||||
|
||||
while not self.is_finished():
|
||||
if debug_enabled and iteration % 100 == 0:
|
||||
print(f"[DEBUG LLMEngine] Iteration {iteration}, finished_sequences={len(outputs)}, total_prompts={len(prompts)}")
|
||||
|
||||
# Timeout check (32K sample should finish within 20 minutes = 1200 seconds)
|
||||
if iteration == 0:
|
||||
import time
|
||||
start_time = time.time()
|
||||
elif debug_enabled and iteration % 100 == 0:
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > 1200: # 20 minutes
|
||||
print(f"[WARNING] Test exceeded 20 minutes timeout! Iteration={iteration}, forcing exit.")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
t = perf_counter()
|
||||
output, num_tokens = self.step()
|
||||
if use_tqdm:
|
||||
|
||||
Reference in New Issue
Block a user