multi file loader

This commit is contained in:
GeeeekExplorer
2025-06-11 22:32:48 +08:00
parent 386290d69e
commit 08c84ec08d
5 changed files with 39 additions and 31 deletions

View File

@@ -26,4 +26,4 @@ outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs): for prompt, output in zip(prompts, outputs):
print("\n") print("\n")
print(f"Prompt: {prompt}") print(f"Prompt: {prompt}")
print(f"Completion: {output["text"]}") print(f"Completion: {output['text']}")

View File

@@ -6,6 +6,7 @@ from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.memory import get_gpu_memory from nanovllm.utils.memory import get_gpu_memory
from nanovllm.models.qwen3 import Qwen3ForCausalLM from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler from nanovllm.layers.sampler import Sampler
from nanovllm.utils.loader import load_model
class ModelRunner: class ModelRunner:
@@ -20,7 +21,7 @@ class ModelRunner:
torch.set_default_dtype(hf_config.torch_dtype) torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_device("cuda") torch.set_default_device("cuda")
self.model = Qwen3ForCausalLM(hf_config) self.model = Qwen3ForCausalLM(hf_config)
self.model.load_weights(config.model) load_model(self.model, config.model)
self.sampler = Sampler() self.sampler = Sampler()
self.allocate_kv_cache(config.gpu_memory_utilization) self.allocate_kv_cache(config.gpu_memory_utilization)
if not self.enforce_eager: if not self.enforce_eager:

View File

@@ -105,7 +105,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias=bias) super().__init__(input_size, sum(output_sizes), bias=bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None): def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
param_data = param.data param_data = param.data
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
@@ -145,7 +145,7 @@ class QKVParallelLinear(ColumnParallelLinear):
super().__init__(input_size, output_size, bias) super().__init__(input_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None): def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
param_data = param.data param_data = param.data
assert loaded_shard_id in ["q", "k", "v"] assert loaded_shard_id in ["q", "k", "v"]
if loaded_shard_id == "q": if loaded_shard_id == "q":

View File

@@ -194,15 +194,11 @@ class Qwen3Model(nn.Module):
class Qwen3ForCausalLM(nn.Module): class Qwen3ForCausalLM(nn.Module):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "q_proj": ("qkv_proj", "q"),
"q_proj", "k_proj": ("qkv_proj", "k"),
"k_proj", "v_proj": ("qkv_proj", "v"),
"v_proj", "gate_proj": ("gate_up_proj", 0),
], "up_proj": ("gate_up_proj", 1),
"gate_up_proj": [
"gate_proj",
"up_proj",
],
} }
def __init__( def __init__(
@@ -230,21 +226,3 @@ class Qwen3ForCausalLM(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
return logits return logits
def load_weights(self, path: str):
import os
from safetensors import safe_open
default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
for n, p in self.named_parameters():
if self.tie_word_embeddings and "lm_head" in n:
continue
for x in self.packed_modules_mapping:
if x in n:
weight_loader = getattr(p, "weight_loader", default_weight_loader)
for i, y in enumerate(self.packed_modules_mapping[x]):
weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
break
else:
weight_loader = getattr(p, "weight_loader", default_weight_loader)
weight_loader(p, f.get_tensor(n))

29
nanovllm/utils/loader.py Normal file
View File

@@ -0,0 +1,29 @@
import os
from glob import glob
import torch
from torch import nn
from safetensors import safe_open
def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)
def load_model(model: nn.Module, path: str):
assert os.path.isdir(path)
packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
for file in glob(os.path.join(path, "*.safetensors")):
with safe_open(file, "pt", "cpu") as f:
for weight_name in f.keys():
for k in packed_modules_mapping:
if k in weight_name:
v, shard_id = packed_modules_mapping[k]
param_name = weight_name.replace(k, v)
param = model.get_parameter(param_name)
weight_loader = getattr(param, "weight_loader")
weight_loader(param, f.get_tensor(weight_name), shard_id)
break
else:
param = model.get_parameter(weight_name)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, f.get_tensor(weight_name))