better

@@ -1,4 +1,5 @@
 import atexit
+from dataclasses import fields
 from time import perf_counter
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoTokenizer
@@ -14,13 +15,9 @@ from nanovllm.engine.model_runner import ModelRunner
 class LLMEngine:

     def __init__(self, model, **kwargs):
-        config = Config(model)
-        for k, v in kwargs.items():
-            if hasattr(config, k):
-                setattr(config, k, v)
-        Sequence.block_size = config.kvcache_block_size
-        config.hf_config = AutoConfig.from_pretrained(config.model)
-        config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
+        config_fields = {field.name for field in fields(Config)}
+        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
+        config = Config(model, **config_kwargs)
         self.ps = []
         self.events = []
         for i in range(1, config.tensor_parallel_size):
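
Note: the replacement constructs Config by filtering **kwargs down to the dataclass's declared field names, so unrecognized keys are dropped up front instead of being patched onto an existing instance with setattr. A minimal, self-contained sketch of the pattern (the Config fields below are illustrative stand-ins, not nano-vllm's real option set):

    from dataclasses import dataclass, fields

    @dataclass
    class Config:
        model: str
        # Illustrative fields only; the real Config declares many more options.
        max_model_len: int = 4096
        tensor_parallel_size: int = 1

    def make_config(model, **kwargs):
        # Names declared on the dataclass...
        field_names = {f.name for f in fields(Config)}
        # ...keep only the kwargs that match one of them.
        config_kwargs = {k: v for k, v in kwargs.items() if k in field_names}
        return Config(model, **config_kwargs)

    cfg = make_config("dummy-model", max_model_len=1024, not_a_field=True)
    print(cfg)  # Config(model='dummy-model', max_model_len=1024, tensor_parallel_size=1)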

@@ -95,4 +92,4 @@ class LLMEngine:
         outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
         if use_tqdm:
             pbar.close()
-        return outputs
+        return outputs

@@ -57,9 +57,7 @@ class ModelRunner:
     def loop(self):
         while True:
             method_name, args = self.read_shm()
-            method = getattr(self, method_name, None)
-            assert callable(method)
-            method(*args)
+            self.call(method_name, *args)
             if method_name == "exit":
                 break

@@ -82,8 +80,7 @@ class ModelRunner:
         event.set()

     def call(self, method_name, *args):
-        assert self.rank == 0
-        if self.world_size > 1:
+        if self.world_size > 1 and self.rank == 0:
             self.write_shm(method_name, *args)
         method = getattr(self, method_name, None)
         assert callable(method)
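
Note: after this refactor every rank funnels dispatch through call: rank 0 additionally broadcasts the method name and arguments over shared memory, while worker ranks replay those calls from loop. A rough sketch of the control flow, with a multiprocessing.Queue standing in for the shared-memory channel (read_shm/write_shm here are simplified stand-ins for nano-vllm's versions, and a real broadcast would reach every worker rank, not just one queue consumer):

    from multiprocessing import Queue

    class RunnerSketch:
        def __init__(self, rank, world_size, channel):
            self.rank = rank
            self.world_size = world_size
            self.channel = channel  # stand-in for the shared-memory buffer

        def write_shm(self, method_name, *args):
            self.channel.put((method_name, args))  # rank 0 -> workers

        def read_shm(self):
            return self.channel.get()  # workers block until rank 0 publishes

        def loop(self):
            # Worker ranks replay every call that rank 0 broadcast.
            while True:
                method_name, args = self.read_shm()
                self.call(method_name, *args)
                if method_name == "exit":
                    break

        def call(self, method_name, *args):
            # Only rank 0 broadcasts; every rank then dispatches locally.
            if self.world_size > 1 and self.rank == 0:
                self.write_shm(method_name, *args)
            method = getattr(self, method_name, None)
            assert callable(method)
            return method(*args)

        def exit(self):
            return None

Merging the two conditions into `if self.world_size > 1 and self.rank == 0` is what lets worker ranks reuse call for local dispatch, instead of loop duplicating the getattr/assert logic as it did before.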