Files
nano-vllm/nanovllm/layers/layernorm.py
Zijie Tian c51a640a29 🐛 fix: remove torch.compile from add_rms_forward to avoid recompilation
The add_rms_forward method processes two input tensors (x and residual),
which causes torch.compile recompilation issues. Keep @torch.compile only
on rms_forward which processes a single input.

This prevents unnecessary recompilation overhead during inference.

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 07:02:02 +08:00

54 lines
1.6 KiB
Python
Executable File

import torch
from torch import nn
class RMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
@torch.compile
def rms_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
# Callers should reshape 3D tensors to 2D before calling
orig_dtype = x.dtype
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x
def add_rms_forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
# Note: @torch.compile removed due to OOM with 64k sequences (memory fragmentation)
orig_dtype = x.dtype
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x, residual
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None:
return self.rms_forward(x)
else:
return self.add_rms_forward(x, residual)