simplify
This commit is contained in:
@@ -10,7 +10,6 @@ class RMSNorm(nn.Module):
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
- self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
|
||||
@@ -20,7 +19,7 @@ class RMSNorm(nn.Module):
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
- x = x.to(torch.float32)
+ x = x.float()
|
||||
var = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x.mul_(torch.rsqrt(var + self.eps))
|
||||
x = x.to(orig_dtype).mul_(self.weight)
|
||||
@@ -33,7 +32,7 @@ class RMSNorm(nn.Module):
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
orig_dtype = x.dtype
|
||||
- x = x.to(torch.float32).add_(residual.to(torch.float32))
+ x = x.float().add_(residual.float())
|
||||
residual = x.to(orig_dtype)
|
||||
var = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x.mul_(torch.rsqrt(var + self.eps))
|
||||
|
||||
Reference in New Issue
Block a user