remove assert shape

This commit is contained in:
GeeeekExplorer
2025-06-27 23:00:30 +08:00
parent 2de882a395
commit 38baf0bbe4
3 changed files with 1 additions and 7 deletions

View File

@@ -31,7 +31,7 @@ class RMSNorm(nn.Module):
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype
x = x.to(torch.float32).add_(residual.to(torch.float32))
residual = x.to(orig_dtype)