diff --git a/nanovllm/layers/layernorm.py b/nanovllm/layers/layernorm.py
index b86b1f5..61a3205 100755
--- a/nanovllm/layers/layernorm.py
+++ b/nanovllm/layers/layernorm.py
@@ -27,13 +27,13 @@ class RMSNorm(nn.Module):
         x = x.to(orig_dtype).mul_(self.weight)
         return x
 
-    @torch.compile
     def add_rms_forward(
         self,
         x: torch.Tensor,
         residual: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
+        # Note: @torch.compile removed due to OOM with 64k sequences (memory fragmentation)
         orig_dtype = x.dtype
         x = x.float().add_(residual.float())
         residual = x.to(orig_dtype)