multi-file loader
@@ -26,4 +26,4 @@ outputs = llm.generate(prompts, sampling_params)
 for prompt, output in zip(prompts, outputs):
     print("\n")
     print(f"Prompt: {prompt}")
-    print(f"Completion: {output["text"]}")
+    print(f"Completion: {output['text']}")
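A note on the one-character fix above: before Python 3.12 (PEP 701), an f-string cannot reuse its own delimiter quote inside a replacement field, so output["text"] inside a double-quoted f-string is a SyntaxError; switching the subscript to single quotes works on every supported version. A minimal standalone repro:

    output = {"text": "hello"}
    # print(f"Completion: {output["text"]}")  # SyntaxError before Python 3.12
    print(f"Completion: {output['text']}")    # valid everywhere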
@@ -6,6 +6,7 @@ from nanovllm.utils.context import set_context, get_context, reset_context
 from nanovllm.utils.memory import get_gpu_memory
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
 from nanovllm.layers.sampler import Sampler
+from nanovllm.utils.loader import load_model


 class ModelRunner:
@@ -20,7 +21,7 @@ class ModelRunner:
         torch.set_default_dtype(hf_config.torch_dtype)
         torch.set_default_device("cuda")
         self.model = Qwen3ForCausalLM(hf_config)
-        self.model.load_weights(config.model)
+        load_model(self.model, config.model)
         self.sampler = Sampler()
         self.allocate_kv_cache(config.gpu_memory_utilization)
         if not self.enforce_eager:
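This call-site change is the heart of the refactor: weight loading moves out of Qwen3ForCausalLM into the model-agnostic load_model utility added below, so the runner no longer depends on each model shipping its own load_weights method. The only contract load_model assumes of a model is sketched here (MyModel is a hypothetical stand-in; the attribute names come from this diff):

    from torch import nn

    class MyModel(nn.Module):  # hypothetical stand-in for Qwen3ForCausalLM
        # checkpoint key fragment -> (fused parameter fragment, shard id)
        packed_modules_mapping = {
            "q_proj": ("qkv_proj", "q"),
            "gate_proj": ("gate_up_proj", 0),
        }
        # Parameters of fused layers additionally expose a weight_loader
        # attribute; everything else falls back to default_weight_loader.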
@@ -105,7 +105,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias=bias)

-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None):
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
         param_data = param.data
         shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
         shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
@@ -145,7 +145,7 @@ class QKVParallelLinear(ColumnParallelLinear):

         super().__init__(input_size, output_size, bias)

-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None):
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
         param_data = param.data
         assert loaded_shard_id in ["q", "k", "v"]
         if loaded_shard_id == "q":
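Both signature changes drop the Optional default: fused layers are now only ever loaded shard by shard, and the new loader always passes a shard id (an int index for MergedColumnParallelLinear, one of "q"/"k"/"v" for QKVParallelLinear), so the None path would be dead code. The shard id drives offset arithmetic like the following, reproduced standalone with illustrative sizes:

    # Offset math from MergedColumnParallelLinear.weight_loader, with
    # hypothetical gate/up widths of 1024 each and tensor parallel size 2.
    output_sizes = [1024, 1024]
    tp_size = 2
    for shard_id in range(len(output_sizes)):
        shard_offset = sum(output_sizes[:shard_id]) // tp_size
        shard_size = output_sizes[shard_id] // tp_size
        print(shard_id, shard_offset, shard_size)  # -> 0 0 512, then 1 512 512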
@@ -194,15 +194,11 @@ class Qwen3Model(nn.Module):

 class Qwen3ForCausalLM(nn.Module):
     packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
     }

     def __init__(
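The mapping is inverted to match the new loader's direction of iteration. The old form mapped a fused module to the list of checkpoint projections it packs, which suited a loop over model parameters; the new form maps each checkpoint name fragment to a (fused parameter fragment, shard id) pair, which suits a loop over checkpoint tensors. A self-contained sketch of the lookup (the weight name is a representative Qwen-style key, not taken from this diff):

    mapping = {"q_proj": ("qkv_proj", "q"), "gate_proj": ("gate_up_proj", 0)}
    weight_name = "model.layers.0.self_attn.q_proj.weight"  # hypothetical key
    for k, (v, shard_id) in mapping.items():
        if k in weight_name:
            param_name = weight_name.replace(k, v)
            print(param_name, shard_id)
            # -> model.layers.0.self_attn.qkv_proj.weight q
            break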
@@ -230,21 +226,3 @@ class Qwen3ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         logits = self.lm_head(hidden_states)
         return logits
-
-    def load_weights(self, path: str):
-        import os
-        from safetensors import safe_open
-        default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
-        with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
-            for n, p in self.named_parameters():
-                if self.tie_word_embeddings and "lm_head" in n:
-                    continue
-                for x in self.packed_modules_mapping:
-                    if x in n:
-                        weight_loader = getattr(p, "weight_loader", default_weight_loader)
-                        for i, y in enumerate(self.packed_modules_mapping[x]):
-                            weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
-                        break
-                else:
-                    weight_loader = getattr(p, "weight_loader", default_weight_loader)
-                    weight_loader(p, f.get_tensor(n))
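The deleted method iterated over the model's parameters and hard-coded a single model.safetensors, which is exactly why multi-file checkpoints failed; its replacement below iterates over the tensors of every *.safetensors shard instead. The tie_word_embeddings special case also falls away, presumably because a tied checkpoint ships no lm_head tensor, so a loop over checkpoint keys never looks one up. Both versions lean on Python's for-else, which is easy to misread, so here is a minimal standalone illustration:

    def classify(name: str) -> str:
        # else runs only if the loop completed without hitting break,
        # i.e. no packed-module key matched this weight name.
        for key in ("qkv_proj", "gate_up_proj"):
            if key in name:
                break
        else:
            return "plain"
        return "packed:" + key

    assert classify("model.layers.0.self_attn.qkv_proj.weight") == "packed:qkv_proj"
    assert classify("model.embed_tokens.weight") == "plain"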
nanovllm/utils/loader.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import os
+from glob import glob
+import torch
+from torch import nn
+from safetensors import safe_open
+
+
+def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
+    param.data.copy_(loaded_weight)
+
+
+def load_model(model: nn.Module, path: str):
+    assert os.path.isdir(path)
+    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
+    for file in glob(os.path.join(path, "*.safetensors")):
+        with safe_open(file, "pt", "cpu") as f:
+            for weight_name in f.keys():
+                for k in packed_modules_mapping:
+                    if k in weight_name:
+                        v, shard_id = packed_modules_mapping[k]
+                        param_name = weight_name.replace(k, v)
+                        param = model.get_parameter(param_name)
+                        weight_loader = getattr(param, "weight_loader")
+                        weight_loader(param, f.get_tensor(weight_name), shard_id)
+                        break
+                else:
+                    param = model.get_parameter(weight_name)
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, f.get_tensor(weight_name))
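Usage is a two-liner once a model instance exists. A hedged sketch (the checkpoint directory is a placeholder; AutoConfig comes from transformers, which the runner evidently already uses for hf_config above):

    from transformers import AutoConfig
    from nanovllm.models.qwen3 import Qwen3ForCausalLM
    from nanovllm.utils.loader import load_model

    path = "/path/to/Qwen3-0.6B"  # hypothetical local checkpoint directory
    hf_config = AutoConfig.from_pretrained(path)
    model = Qwen3ForCausalLM(hf_config)
    load_model(model, path)  # resolves q/k/v and gate/up shards across all *.safetensors files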