This commit is contained in:
GeeeekExplorer
2025-08-31 19:44:57 +08:00
parent 6a6d217de7
commit df99418f7d
11 changed files with 47 additions and 96 deletions

View File

@@ -10,7 +10,6 @@ class RMSNorm(nn.Module):
eps: float = 1e-6,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -20,7 +19,7 @@ class RMSNorm(nn.Module):
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.to(torch.float32)
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
@@ -33,7 +32,7 @@ class RMSNorm(nn.Module):
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype
x = x.to(torch.float32).add_(residual.to(torch.float32))
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))