This commit is contained in:
GeeeekExplorer
2025-06-15 10:31:48 +08:00
parent c1fd4ea3c2
commit fc778a4da9
10 changed files with 19 additions and 22 deletions

View File

@@ -11,4 +11,4 @@ class SiluAndMul(nn.Module):
@torch.compile
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, y = x.chunk(2, -1)
return F.silu(x) * y
return y.mul_(F.silu(x))

View File

@@ -69,4 +69,4 @@ class ParallelLMHead(VocabParallelEmbedding):
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
dist.gather(logits, all_logits, 0)
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
return logits
return logits

View File

@@ -70,4 +70,4 @@ def get_rope(
):
assert rope_scaling is None
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
return rotary_emb
return rotary_emb