multi file loader
@@ -26,4 +26,4 @@ outputs = llm.generate(prompts, sampling_params)
 for prompt, output in zip(prompts, outputs):
     print("\n")
     print(f"Prompt: {prompt}")
-    print(f"Completion: {output["text"]}")
+    print(f"Completion: {output['text']}")
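Side note on the one-character change above: nesting double quotes inside a double-quoted f-string is only accepted from Python 3.12 (PEP 701) onward, so switching the inner quotes keeps the example valid on older interpreters. A minimal check (the dict is an illustrative stand-in for a generation result):

    output = {"text": "hello"}              # illustrative stand-in, not the real output object
    print(f"Completion: {output['text']}")  # inner single quotes avoid the quote clash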
@@ -6,6 +6,7 @@ from nanovllm.utils.context import set_context, get_context, reset_context
 from nanovllm.utils.memory import get_gpu_memory
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
 from nanovllm.layers.sampler import Sampler
+from nanovllm.utils.loader import load_model


 class ModelRunner:
@@ -20,7 +21,7 @@ class ModelRunner:
         torch.set_default_dtype(hf_config.torch_dtype)
         torch.set_default_device("cuda")
         self.model = Qwen3ForCausalLM(hf_config)
-        self.model.load_weights(config.model)
+        load_model(self.model, config.model)
         self.sampler = Sampler()
         self.allocate_kv_cache(config.gpu_memory_utilization)
         if not self.enforce_eager:
@@ -105,7 +105,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias=bias)

-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None):
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
         param_data = param.data
         shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
         shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
@@ -145,7 +145,7 @@ class QKVParallelLinear(ColumnParallelLinear):

         super().__init__(input_size, output_size, bias)

-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None):
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
         param_data = param.data
         assert loaded_shard_id in ["q", "k", "v"]
         if loaded_shard_id == "q":
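For orientation (not part of the commit): dropping the None default in both weight_loader signatures fits the new loader, which always passes an explicit shard id for packed parameters. A simplified, hypothetical sketch of what a loader does with that id, reusing the offset arithmetic from the MergedColumnParallelLinear hunk above with tp_size fixed to 1 (class and variable names are illustrative, not the actual nanovllm code):

    import torch
    from torch import nn

    class FusedLinearDemo(nn.Module):
        # Hypothetical stand-in for a merged column-parallel linear layer.
        def __init__(self, input_size: int, output_sizes: list[int]):
            super().__init__()
            self.output_sizes = output_sizes
            self.tp_size = 1
            self.weight = nn.Parameter(torch.empty(sum(output_sizes), input_size))

        def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
            # Each sub-projection owns a contiguous row slice of the fused weight.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
            param.data.narrow(0, shard_offset, shard_size).copy_(loaded_weight)

    layer = FusedLinearDemo(input_size=8, output_sizes=[16, 16])              # e.g. gate_proj + up_proj
    layer.weight_loader(layer.weight, torch.randn(16, 8), loaded_shard_id=1)  # fills rows 16..31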
@@ -194,15 +194,11 @@ class Qwen3Model(nn.Module):

 class Qwen3ForCausalLM(nn.Module):
     packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
     }

     def __init__(
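The mapping above is inverted relative to the old layout: instead of fused parameter name → list of checkpoint names, each checkpoint substring now maps directly to a (fused parameter name, shard id) pair, which is exactly what the new loader below looks up per tensor. A small illustration, assuming a typical Qwen3 checkpoint key (the key string is an assumed example):

    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    weight_name = "model.layers.0.self_attn.q_proj.weight"   # assumed example key
    for k, (v, shard_id) in packed_modules_mapping.items():
        if k in weight_name:
            # -> "model.layers.0.self_attn.qkv_proj.weight", shard id "q"
            print(weight_name.replace(k, v), shard_id)
            break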
@@ -230,21 +226,3 @@ class Qwen3ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         logits = self.lm_head(hidden_states)
         return logits
-
-    def load_weights(self, path: str):
-        import os
-        from safetensors import safe_open
-        default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
-        with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
-            for n, p in self.named_parameters():
-                if self.tie_word_embeddings and "lm_head" in n:
-                    continue
-                for x in self.packed_modules_mapping:
-                    if x in n:
-                        weight_loader = getattr(p, "weight_loader", default_weight_loader)
-                        for i, y in enumerate(self.packed_modules_mapping[x]):
-                            weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
-                        break
-                else:
-                    weight_loader = getattr(p, "weight_loader", default_weight_loader)
-                    weight_loader(p, f.get_tensor(n))
nanovllm/utils/loader.py (new file, +29 lines)
@@ -0,0 +1,29 @@
+import os
+from glob import glob
+import torch
+from torch import nn
+from safetensors import safe_open
+
+
+def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
+    param.data.copy_(loaded_weight)
+
+
+def load_model(model: nn.Module, path: str):
+    assert os.path.isdir(path)
+    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
+    for file in glob(os.path.join(path, "*.safetensors")):
+        with safe_open(file, "pt", "cpu") as f:
+            for weight_name in f.keys():
+                for k in packed_modules_mapping:
+                    if k in weight_name:
+                        v, shard_id = packed_modules_mapping[k]
+                        param_name = weight_name.replace(k, v)
+                        param = model.get_parameter(param_name)
+                        weight_loader = getattr(param, "weight_loader")
+                        weight_loader(param, f.get_tensor(weight_name), shard_id)
+                        break
+                else:
+                    param = model.get_parameter(weight_name)
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, f.get_tensor(weight_name))
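A hedged usage sketch of the new utility, mirroring the ModelRunner hunk above; the checkpoint directory and the AutoConfig call are illustrative assumptions, not taken from this commit:

    import torch
    from transformers import AutoConfig

    from nanovllm.models.qwen3 import Qwen3ForCausalLM
    from nanovllm.utils.loader import load_model

    model_dir = "/path/to/qwen3-checkpoint"            # assumed: directory with one or more *.safetensors shards
    hf_config = AutoConfig.from_pretrained(model_dir)  # assumed way of obtaining hf_config

    torch.set_default_dtype(hf_config.torch_dtype)     # same setup as in ModelRunner above
    torch.set_default_device("cuda")
    model = Qwen3ForCausalLM(hf_config)
    load_model(model, model_dir)                       # walks every shard; packed tensors are routed
                                                       # through packed_modules_mapping + weight_loader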