multi file loader
@@ -194,15 +194,11 @@ class Qwen3Model(nn.Module):
 
 class Qwen3ForCausalLM(nn.Module):
     packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
     }
 
     def __init__(
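The mapping is inverted here: the old form mapped each fused parameter to the list of checkpoint shards packed into it, which only a per-model load_weights walking named_parameters() could consume. The new form maps each checkpoint shard name to its (fused parameter, shard id), so a loader that iterates over the tensor names found in a file can resolve any name with one scan of the dict, and the `y[0] if x == "qkv_proj" else i` special case below becomes unnecessary because each entry states its shard id explicitly. A minimal sketch of that lookup; the `resolve` helper is illustrative, not part of this commit:

# Illustrative sketch of how the tuple-form mapping is meant to be read;
# the helper name `resolve` is hypothetical.
packed_modules_mapping = {
    "q_proj": ("qkv_proj", "q"),
    "k_proj": ("qkv_proj", "k"),
    "v_proj": ("qkv_proj", "v"),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def resolve(weight_name: str):
    # Rewrite a checkpoint tensor name to the fused parameter it belongs to
    # and report which shard of that parameter it fills.
    for shard_name, (packed_name, shard_id) in packed_modules_mapping.items():
        if shard_name in weight_name:
            return weight_name.replace(shard_name, packed_name), shard_id
    return weight_name, None  # not packed: load as-is

print(resolve("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')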
@@ -230,21 +226,3 @@ class Qwen3ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         logits = self.lm_head(hidden_states)
         return logits
-
-    def load_weights(self, path: str):
-        import os
-        from safetensors import safe_open
-        default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
-        with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
-            for n, p in self.named_parameters():
-                if self.tie_word_embeddings and "lm_head" in n:
-                    continue
-                for x in self.packed_modules_mapping:
-                    if x in n:
-                        weight_loader = getattr(p, "weight_loader", default_weight_loader)
-                        for i, y in enumerate(self.packed_modules_mapping[x]):
-                            weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
-                        break
-                else:
-                    weight_loader = getattr(p, "weight_loader", default_weight_loader)
-                    weight_loader(p, f.get_tensor(n))
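The deleted load_weights only ever opened a single model.safetensors and drove loading from the model's own parameter list, which is what the commit title targets: larger checkpoints ship as several *.safetensors shards. A replacement consistent with that title would iterate every shard file and drive loading from the tensor names inside each one. The sketch below is an assumption about that replacement, which lives outside this diff; the name load_model and its signature are hypothetical:

import os
from glob import glob

from safetensors import safe_open
from torch import nn


def load_model(model: nn.Module, path: str):  # hypothetical name/signature
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    params = dict(model.named_parameters())
    default_loader = lambda param, tensor: param.data.copy_(tensor)
    # Every shard file in the checkpoint directory, not just model.safetensors.
    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                for shard_name, (packed, shard_id) in packed_modules_mapping.items():
                    if shard_name in weight_name:
                        param = params[weight_name.replace(shard_name, packed)]
                        # Fused parameters are expected to carry a shard-aware
                        # weight_loader that copies into the right slice.
                        param.weight_loader(param, f.get_tensor(weight_name), shard_id)
                        break
                else:
                    param = params[weight_name]
                    loader = getattr(param, "weight_loader", default_loader)
                    loader(param, f.get_tensor(weight_name))

Driving the loop from file keys would also explain why the tie_word_embeddings guard disappears: a tied checkpoint simply contains no lm_head tensor, so there is nothing to skip.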