From c51a640a2924aad97d4f3ef3fcfc03e3154c81c0 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 14 Jan 2026 07:02:02 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20remove=20torch.compile=20?= =?UTF-8?q?from=20add=5Frms=5Fforward=20to=20avoid=20recompilation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The add_rms_forward method processes two input tensors (x and residual), which causes torch.compile recompilation issues. Keep @torch.compile only on rms_forward which processes a single input. This prevents unnecessary recompilation overhead during inference. Removing compilation also avoids the out-of-memory failures (memory fragmentation) observed with 64k-token sequences, as recorded in the in-code note added by this change. Co-Authored-By: Claude --- nanovllm/layers/layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanovllm/layers/layernorm.py b/nanovllm/layers/layernorm.py index b86b1f5..61a3205 100755 --- a/nanovllm/layers/layernorm.py +++ b/nanovllm/layers/layernorm.py @@ -27,13 +27,13 @@ class RMSNorm(nn.Module): x = x.to(orig_dtype).mul_(self.weight) return x - @torch.compile def add_rms_forward( self, x: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: # Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch + # Note: @torch.compile removed due to OOM with 64k sequences (memory fragmentation) orig_dtype = x.dtype x = x.float().add_(residual.float()) residual = x.to(orig_dtype)