multi file loader

This commit is contained in:
GeeeekExplorer
2025-06-11 22:32:48 +08:00
parent 386290d69e
commit 08c84ec08d
5 changed files with 39 additions and 31 deletions

View File

@@ -194,15 +194,11 @@ class Qwen3Model(nn.Module):
class Qwen3ForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"q_proj": ("qkv_proj", "q"),
"k_proj": ("qkv_proj", "k"),
"v_proj": ("qkv_proj", "v"),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
@@ -230,21 +226,3 @@ class Qwen3ForCausalLM(nn.Module):
) -> torch.Tensor:
logits = self.lm_head(hidden_states)
return logits
def load_weights(self, path: str):
import os
from safetensors import safe_open
default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
for n, p in self.named_parameters():
if self.tie_word_embeddings and "lm_head" in n:
continue
for x in self.packed_modules_mapping:
if x in n:
weight_loader = getattr(p, "weight_loader", default_weight_loader)
for i, y in enumerate(self.packed_modules_mapping[x]):
weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
break
else:
weight_loader = getattr(p, "weight_loader", default_weight_loader)
weight_loader(p, f.get_tensor(n))