[refactor] Translate into English, avoid Chinese due to Claude.
This commit is contained in:
@@ -18,6 +18,8 @@ class RMSNorm(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
|
||||
# Callers should reshape 3D tensors to 2D before calling
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
var = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
@@ -31,6 +33,7 @@ class RMSNorm(nn.Module):
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
|
||||
orig_dtype = x.dtype
|
||||
x = x.float().add_(residual.float())
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
Reference in New Issue
Block a user