This commit is contained in:
GeeeekExplorer
2025-06-15 10:31:48 +08:00
parent c1fd4ea3c2
commit fc778a4da9
10 changed files with 19 additions and 22 deletions

View File

@@ -1,10 +1,11 @@
import os
from dataclasses import dataclass from dataclasses import dataclass
from transformers import AutoConfig from transformers import AutoConfig
@dataclass @dataclass
class Config: class Config:
model: str = '' model: str
max_num_batched_tokens: int = 32768 max_num_batched_tokens: int = 32768
max_num_seqs: int = 512 max_num_seqs: int = 512
max_model_len: int = 4096 max_model_len: int = 4096
@@ -17,5 +18,8 @@ class Config:
num_kvcache_blocks: int = -1 num_kvcache_blocks: int = -1
def __post_init__(self): def __post_init__(self):
assert self.model assert os.path.isdir(self.model)
assert self.kvcache_block_size % 256 == 0 assert self.kvcache_block_size % 256 == 0
assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = AutoConfig.from_pretrained(self.model)
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)

View File

@@ -1,4 +1,5 @@
import atexit import atexit
from dataclasses import fields
from time import perf_counter from time import perf_counter
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer
@@ -14,13 +15,9 @@ from nanovllm.engine.model_runner import ModelRunner
class LLMEngine: class LLMEngine:
def __init__(self, model, **kwargs): def __init__(self, model, **kwargs):
config = Config(model) config_fileds = {field.name for field in fields(Config)}
for k, v in kwargs.items(): config_kwargs = {k: v for k, v in kwargs.items() if k in config_fileds}
if hasattr(config, k): config = Config(model, **config_kwargs)
setattr(config, k, v)
Sequence.block_size = config.kvcache_block_size
config.hf_config = AutoConfig.from_pretrained(config.model)
config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
self.ps = [] self.ps = []
self.events = [] self.events = []
for i in range(1, config.tensor_parallel_size): for i in range(1, config.tensor_parallel_size):

View File

@@ -57,9 +57,7 @@ class ModelRunner:
def loop(self): def loop(self):
while True: while True:
method_name, args = self.read_shm() method_name, args = self.read_shm()
method = getattr(self, method_name, None) self.call(method_name, *args)
assert callable(method)
method(*args)
if method_name == "exit": if method_name == "exit":
break break
@@ -82,8 +80,7 @@ class ModelRunner:
event.set() event.set()
def call(self, method_name, *args): def call(self, method_name, *args):
assert self.rank == 0 if self.world_size > 1 and self.rank == 0:
if self.world_size > 1:
self.write_shm(method_name, *args) self.write_shm(method_name, *args)
method = getattr(self, method_name, None) method = getattr(self, method_name, None)
assert callable(method) assert callable(method)

View File

@@ -11,4 +11,4 @@ class SiluAndMul(nn.Module):
@torch.compile @torch.compile
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
x, y = x.chunk(2, -1) x, y = x.chunk(2, -1)
return F.silu(x) * y return y.mul_(F.silu(x))

View File

@@ -10,7 +10,6 @@ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
def load_model(model: nn.Module, path: str): def load_model(model: nn.Module, path: str):
assert os.path.isdir(path)
packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
for file in glob(os.path.join(path, "*.safetensors")): for file in glob(os.path.join(path, "*.safetensors")):
with safe_open(file, "pt", "cpu") as f: with safe_open(file, "pt", "cpu") as f:

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "nano-vllm" name = "nano-vllm"
version = "0.1.0" version = "0.2.0"
authors = [{ name = "Xingkai Yu" }] authors = [{ name = "Xingkai Yu" }]
license = "MIT" license = "MIT"
license-files = ["LICENSE"] license-files = ["LICENSE"]