remove assert shape
This commit is contained in:
@@ -31,7 +31,7 @@ class RMSNorm(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32).add_(residual.to(torch.float32))
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
Reference in New Issue
Block a user