[refactor] Translate into English, avoid Chinese due to Claude.

This commit is contained in:
Zijie Tian
2025-12-11 00:30:24 +08:00
parent e85c2b4776
commit babfa17354
9 changed files with 297 additions and 187 deletions

View File

@@ -18,6 +18,8 @@ class RMSNorm(nn.Module):
self,
x: torch.Tensor,
) -> torch.Tensor:
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
# Callers should reshape 3D tensors to 2D before calling
orig_dtype = x.dtype
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
@@ -31,6 +33,7 @@ class RMSNorm(nn.Module):
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
orig_dtype = x.dtype
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)