This commit is contained in:
GeeeekExplorer
2025-06-10 08:52:58 +08:00
parent a5a4909e6a
commit b98e1ca305
10 changed files with 39 additions and 26 deletions

View File

@@ -212,7 +212,8 @@ class Qwen3ForCausalLM(nn.Module):
super().__init__()
self.model = Qwen3Model(config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if config.tie_word_embeddings:
self.tie_word_embeddings = config.tie_word_embeddings
if self.tie_word_embeddings:
self.lm_head.weight.data = self.model.embed_tokens.weight.data
def forward(
@@ -236,6 +237,8 @@ class Qwen3ForCausalLM(nn.Module):
default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
for n, p in self.named_parameters():
if self.tie_word_embeddings and "lm_head" in n:
continue
for x in self.packed_modules_mapping:
if x in n:
weight_loader = getattr(p, "weight_loader", default_weight_loader)