diff --git a/example.py b/example.py
index 5ae260e..face5eb 100644
--- a/example.py
+++ b/example.py
@@ -26,4 +26,4 @@ outputs = llm.generate(prompts, sampling_params)
 for prompt, output in zip(prompts, outputs):
     print("\n")
     print(f"Prompt: {prompt}")
-    print(f"Completion: {output["text"]}")
+    print(f"Completion: {output['text']}")
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index c8f8588..b803278 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -6,6 +6,7 @@ from nanovllm.utils.context import set_context, get_context, reset_context
 from nanovllm.utils.memory import get_gpu_memory
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
 from nanovllm.layers.sampler import Sampler
+from nanovllm.utils.loader import load_model
 
 
 class ModelRunner:
@@ -20,7 +21,7 @@ class ModelRunner:
         torch.set_default_dtype(hf_config.torch_dtype)
         torch.set_default_device("cuda")
         self.model = Qwen3ForCausalLM(hf_config)
-        self.model.load_weights(config.model)
+        load_model(self.model, config.model)
         self.sampler = Sampler()
         self.allocate_kv_cache(config.gpu_memory_utilization)
         if not self.enforce_eager:
diff --git a/nanovllm/layers/linear.py b/nanovllm/layers/linear.py
index f39fa59..d5133f1 100755
--- a/nanovllm/layers/linear.py
+++ b/nanovllm/layers/linear.py
@@ -105,7 +105,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias=bias)
 
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None):
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
         param_data = param.data
         shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
         shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
@@ -145,7 +145,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         super().__init__(input_size, output_size, bias)
 
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None):
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
         param_data = param.data
         assert loaded_shard_id in ["q", "k", "v"]
         if loaded_shard_id == "q":
diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py
index 1d3ff41..f65475f 100755
--- a/nanovllm/models/qwen3.py
+++ b/nanovllm/models/qwen3.py
@@ -194,15 +194,11 @@ class Qwen3Model(nn.Module):
 class Qwen3ForCausalLM(nn.Module):
     packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
     }
 
     def __init__(
         self,
@@ -230,21 +226,3 @@
     ) -> torch.Tensor:
         logits = self.lm_head(hidden_states)
         return logits
-
-    def load_weights(self, path: str):
-        import os
-        from safetensors import safe_open
-        default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
-        with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
-            for n, p in self.named_parameters():
-                if self.tie_word_embeddings and "lm_head" in n:
-                    continue
-                for x in self.packed_modules_mapping:
-                    if x in n:
-                        weight_loader = getattr(p, "weight_loader", default_weight_loader)
-                        for i, y in enumerate(self.packed_modules_mapping[x]):
-                            weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
-                        break
-                else:
-                    weight_loader = getattr(p, "weight_loader", default_weight_loader)
-                    weight_loader(p, f.get_tensor(n))
diff --git a/nanovllm/utils/loader.py b/nanovllm/utils/loader.py
new file mode 100644
index 0000000..c052e0f
--- /dev/null
+++ b/nanovllm/utils/loader.py
@@ -0,0 +1,29 @@
+import os
+from glob import glob
+import torch
+from torch import nn
+from safetensors import safe_open
+
+
+def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
+    param.data.copy_(loaded_weight)
+
+
+def load_model(model: nn.Module, path: str):
+    assert os.path.isdir(path)
+    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
+    for file in glob(os.path.join(path, "*.safetensors")):
+        with safe_open(file, "pt", "cpu") as f:
+            for weight_name in f.keys():
+                for k in packed_modules_mapping:
+                    if k in weight_name:
+                        v, shard_id = packed_modules_mapping[k]
+                        param_name = weight_name.replace(k, v)
+                        param = model.get_parameter(param_name)
+                        weight_loader = getattr(param, "weight_loader")
+                        weight_loader(param, f.get_tensor(weight_name), shard_id)
+                        break
+                else:
+                    param = model.get_parameter(weight_name)
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, f.get_tensor(weight_name))
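Note on the loading contract this diff introduces: load_model iterates over checkpoint tensor names rather than model parameters, routes names that match a packed_modules_mapping key to the fused parameter's weight_loader with a shard id, and falls through the for/else to default_weight_loader otherwise. The following is a minimal, self-contained sketch of that contract using a toy module and a temporary checkpoint; ToyModel and its fused gate_up_proj layer are illustrative, not nano-vllm code:

    import os
    import tempfile
    import torch
    from torch import nn
    from safetensors.torch import save_file
    from nanovllm.utils.loader import load_model

    class ToyModel(nn.Module):
        # checkpoint suffix -> (fused parameter suffix, shard id),
        # same shape as the new Qwen3ForCausalLM mapping
        packed_modules_mapping = {
            "gate_proj": ("gate_up_proj", 0),
            "up_proj": ("gate_up_proj", 1),
        }

        def __init__(self):
            super().__init__()
            # fused weight: two stacked 4x4 halves
            self.gate_up_proj = nn.Linear(4, 8, bias=False)

            def weight_loader(param: nn.Parameter, loaded: torch.Tensor, shard_id: int):
                # copy the checkpoint half into its slice of the fused weight
                param.data.narrow(0, shard_id * 4, 4).copy_(loaded)

            self.gate_up_proj.weight.weight_loader = weight_loader

    with tempfile.TemporaryDirectory() as tmp:
        # write a fake sharded checkpoint with unfused names
        save_file(
            {"gate_proj.weight": torch.ones(4, 4), "up_proj.weight": torch.zeros(4, 4)},
            os.path.join(tmp, "model.safetensors"),
        )
        model = ToyModel()
        # "gate_proj.weight" matches the "gate_proj" key, is renamed to
        # "gate_up_proj.weight", and lands in shard 0; "up_proj.weight" in shard 1
        load_model(model, tmp)
        assert torch.equal(model.gate_up_proj.weight.data[:4], torch.ones(4, 4))
        assert torch.equal(model.gate_up_proj.weight.data[4:], torch.zeros(4, 4))

Attaching weight_loader directly to the nn.Parameter is the mechanism MergedColumnParallelLinear and QKVParallelLinear rely on, which is why this diff can drop the None defaults from their loaded_shard_id arguments: the packed path always supplies a shard id.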