fix
@@ -212,7 +212,8 @@ class Qwen3ForCausalLM(nn.Module):
         super().__init__()
         self.model = Qwen3Model(config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        if config.tie_word_embeddings:
+        self.tie_word_embeddings = config.tie_word_embeddings
+        if self.tie_word_embeddings:
             self.lm_head.weight.data = self.model.embed_tokens.weight.data
 
     def forward(
@@ -236,6 +237,8 @@ class Qwen3ForCausalLM(nn.Module):
         default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
         with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
             for n, p in self.named_parameters():
+                if self.tie_word_embeddings and "lm_head" in n:
+                    continue
                 for x in self.packed_modules_mapping:
                     if x in n:
                         weight_loader = getattr(p, "weight_loader", default_weight_loader)