better
@@ -11,4 +11,4 @@ class SiluAndMul(nn.Module):
     @torch.compile
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, y = x.chunk(2, -1)
-        return F.silu(x) * y
+        return y.mul_(F.silu(x))
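
This hunk swaps the out-of-place product for an in-place multiply: y.mul_(F.silu(x)) writes the result into y's existing storage instead of allocating a fresh output tensor, computing the same SwiGLU-style activation with one less temporary. A minimal sketch checking that the two forms agree (shapes are hypothetical, not from the repo):

import torch
import torch.nn.functional as F

# x holds the gate and up halves concatenated on the last dim, as in SiluAndMul.
x = torch.randn(4, 8)            # hypothetical (batch, 2 * hidden) input
a, b = x.chunk(2, -1)            # views into x's storage
out = F.silu(a) * b              # old form: allocates a new result tensor
res = b.mul_(F.silu(a))          # new form: reuses b's storage (mutates b in place)
assert torch.allclose(out, res)

Note that chunk returns views, so the in-place multiply also mutates the caller's input buffer; that is acceptable here because forward consumes x and returns the activation.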
@@ -69,4 +69,4 @@ class ParallelLMHead(VocabParallelEmbedding):
         all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
         dist.gather(logits, all_logits, 0)
         logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
-        return logits
+        return logits
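
For context on the hunk above: with the vocabulary sharded across tensor-parallel ranks, each rank computes logits only for its own vocab shard. dist.gather then collects every shard on rank 0, which concatenates them along the last (vocab) dimension; all other ranks return None. A toy two-process sketch of that collective pattern, assuming a gloo backend and made-up sizes and address:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank: int, world_size: int):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    logits = torch.full((2, 4), float(rank))   # this rank's vocab shard
    # Only the destination rank allocates the gather list.
    all_logits = [torch.empty_like(logits) for _ in range(world_size)] if rank == 0 else None
    dist.gather(logits, all_logits, 0)         # shards land on rank 0 only
    if rank == 0:
        full = torch.cat(all_logits, -1)       # (2, 4 * world_size) full logits
        print(full.shape)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)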
@@ -70,4 +70,4 @@ def get_rope(
 ):
     assert rope_scaling is None
     rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
-    return rotary_emb
+    return rotary_emb
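
get_rope here is a thin factory: it supports only the unscaled case (rope_scaling must be None) and returns a plain RotaryEmbedding(head_size, rotary_dim, max_position, base). The class body is not part of this diff, so the snippet below is only a sketch of the standard RoPE rotation such a module typically computes; the half-split pair layout is an assumption, not confirmed by the source:

import torch

def apply_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0):
    # x: (..., seq_len, rotary_dim); positions: (seq_len,)
    rotary_dim = x.size(-1)
    # Per-pair inverse frequencies, one per rotated dimension pair.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]    # (seq_len, rotary_dim / 2)
    cos, sin = freqs.cos(), freqs.sin()
    x1, x2 = x.chunk(2, -1)                                   # assumed half-split pairing
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), -1)

q = torch.randn(1, 8, 64)        # hypothetical (batch, seq, rotary_dim) query
q_rot = apply_rope(q, torch.arange(8))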