refactor
This commit is contained in:
@@ -8,7 +8,7 @@ class Config:
|
|||||||
max_num_batched_tokens: int = 32768
|
max_num_batched_tokens: int = 32768
|
||||||
max_num_seqs: int = 512
|
max_num_seqs: int = 512
|
||||||
max_model_len: int = 4096
|
max_model_len: int = 4096
|
||||||
gpu_memory_utilization: float = 0.95
|
gpu_memory_utilization: float = 0.9
|
||||||
enforce_eager: bool = False
|
enforce_eager: bool = False
|
||||||
hf_config: AutoConfig | None = None
|
hf_config: AutoConfig | None = None
|
||||||
eos: int = -1
|
eos: int = -1
|
||||||
|
|||||||
@@ -57,9 +57,6 @@ class BlockManager:
|
|||||||
self.used_block_ids.remove(block_id)
|
self.used_block_ids.remove(block_id)
|
||||||
self.free_block_ids.append(block_id)
|
self.free_block_ids.append(block_id)
|
||||||
|
|
||||||
def can_prefill(self):
|
|
||||||
return len(self.free_block_ids) > 0.1 * len(self.blocks)
|
|
||||||
|
|
||||||
def can_allocate(self, seq: Sequence):
|
def can_allocate(self, seq: Sequence):
|
||||||
return len(self.free_block_ids) >= seq.num_blocks
|
return len(self.free_block_ids) >= seq.num_blocks
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
from collections import defaultdict
|
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from transformers import AutoConfig, AutoTokenizer
|
from transformers import AutoConfig, AutoTokenizer
|
||||||
|
|
||||||
from nanovllm.config import Config
|
from nanovllm.config import Config
|
||||||
from nanovllm.sampling_params import SamplingParams
|
from nanovllm.sampling_params import SamplingParams
|
||||||
from nanovllm.engine.sequence import Sequence
|
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
||||||
from nanovllm.engine.scheduler import Scheduler
|
from nanovllm.engine.scheduler import Scheduler
|
||||||
from nanovllm.engine.model_runner import ModelRunner
|
from nanovllm.engine.model_runner import ModelRunner
|
||||||
|
|
||||||
@@ -34,8 +33,10 @@ class LLMEngine:
|
|||||||
def step(self):
|
def step(self):
|
||||||
seqs, is_prefill = self.scheduler.schedule()
|
seqs, is_prefill = self.scheduler.schedule()
|
||||||
token_ids = self.model_runner.run(seqs, is_prefill)
|
token_ids = self.model_runner.run(seqs, is_prefill)
|
||||||
finished = self.scheduler.postprocess(seqs, token_ids)
|
self.scheduler.postprocess(seqs, token_ids)
|
||||||
return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)], sum(len(seq) for seq in seqs) if is_prefill else len(seqs)
|
outputs = [(seq.seq_id, seq[seq.num_prompt_tokens:]) for seq in seqs if seq.status == SequenceStatus.FINISHED]
|
||||||
|
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
|
||||||
|
return outputs, num_tokens
|
||||||
|
|
||||||
def is_finished(self):
|
def is_finished(self):
|
||||||
return self.scheduler.is_finished()
|
return self.scheduler.is_finished()
|
||||||
@@ -56,23 +57,23 @@ class LLMEngine:
|
|||||||
sampling_params = [sampling_params] * len(prompts)
|
sampling_params = [sampling_params] * len(prompts)
|
||||||
for prompt, sp in zip(prompts, sampling_params):
|
for prompt, sp in zip(prompts, sampling_params):
|
||||||
self.add_request(prompt, sp)
|
self.add_request(prompt, sp)
|
||||||
outputs = defaultdict(list)
|
outputs = {}
|
||||||
prefill_throughput = decode_throughput = 0.
|
prefill_throughput = decode_throughput = 0.
|
||||||
while not self.is_finished():
|
while not self.is_finished():
|
||||||
t = perf_counter()
|
t = perf_counter()
|
||||||
output, num_tokens = self.step()
|
output, num_tokens = self.step()
|
||||||
if use_tqdm:
|
if use_tqdm:
|
||||||
if num_tokens > len(output):
|
if num_tokens > 0:
|
||||||
prefill_throughput = num_tokens / (perf_counter() - t)
|
prefill_throughput = num_tokens / (perf_counter() - t)
|
||||||
else:
|
else:
|
||||||
decode_throughput = num_tokens / (perf_counter() - t)
|
decode_throughput = -num_tokens / (perf_counter() - t)
|
||||||
pbar.set_postfix({
|
pbar.set_postfix({
|
||||||
"Prefill": f"{int(prefill_throughput)}tok/s",
|
"Prefill": f"{int(prefill_throughput)}tok/s",
|
||||||
"Decode": f"{int(decode_throughput)}tok/s",
|
"Decode": f"{int(decode_throughput)}tok/s",
|
||||||
})
|
})
|
||||||
for seq_id, token_id, finish in output:
|
for seq_id, token_ids in output:
|
||||||
outputs[seq_id].append(token_id)
|
outputs[seq_id] = token_ids
|
||||||
if finish and use_tqdm:
|
if use_tqdm:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
|
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
|
||||||
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ class Scheduler:
|
|||||||
return scheduled_seqs, True
|
return scheduled_seqs, True
|
||||||
|
|
||||||
# decode
|
# decode
|
||||||
# self.running = deque(sorted(self.running))
|
|
||||||
while self.running and num_seqs < self.max_num_seqs:
|
while self.running and num_seqs < self.max_num_seqs:
|
||||||
seq = self.running.popleft()
|
seq = self.running.popleft()
|
||||||
while not self.block_manager.can_append(seq):
|
while not self.block_manager.can_append(seq):
|
||||||
@@ -59,8 +58,8 @@ class Scheduler:
|
|||||||
running = deque(scheduled_seqs)
|
running = deque(scheduled_seqs)
|
||||||
running.extend(self.running)
|
running.extend(self.running)
|
||||||
self.running = running
|
self.running = running
|
||||||
if scheduled_seqs:
|
assert scheduled_seqs
|
||||||
return scheduled_seqs, False
|
return scheduled_seqs, False
|
||||||
|
|
||||||
def preempt(self, seq: Sequence):
|
def preempt(self, seq: Sequence):
|
||||||
seq.status = SequenceStatus.WAITING
|
seq.status = SequenceStatus.WAITING
|
||||||
@@ -69,7 +68,6 @@ class Scheduler:
|
|||||||
|
|
||||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
||||||
self.num_tokens += len(token_ids)
|
self.num_tokens += len(token_ids)
|
||||||
finished = []
|
|
||||||
for seq, token_id in zip(seqs, token_ids):
|
for seq, token_id in zip(seqs, token_ids):
|
||||||
seq.append_token(token_id)
|
seq.append_token(token_id)
|
||||||
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
|
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
|
||||||
@@ -77,7 +75,3 @@ class Scheduler:
|
|||||||
self.block_manager.deallocate(seq)
|
self.block_manager.deallocate(seq)
|
||||||
self.running.remove(seq)
|
self.running.remove(seq)
|
||||||
self.num_finished += 1
|
self.num_finished += 1
|
||||||
finished.append(True)
|
|
||||||
else:
|
|
||||||
finished.append(False)
|
|
||||||
return finished
|
|
||||||
|
|||||||
Reference in New Issue
Block a user