diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py
index 53a4887..028f170 100644
--- a/nanovllm/engine/llm_engine.py
+++ b/nanovllm/engine/llm_engine.py
@@ -34,6 +34,7 @@ class LLMEngine:
 
     def exit(self):
         self.model_runner.call("exit")
+        del self.model_runner
         for p in self.ps:
             p.join()
 
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 8973310..e5958ec 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -1,8 +1,8 @@
 import pickle
 import torch
 import torch.distributed as dist
-from multiprocess.synchronize import Event
-from multiprocess.shared_memory import SharedMemory
+from multiprocessing.synchronize import Event
+from multiprocessing.shared_memory import SharedMemory
 
 from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
@@ -50,6 +50,7 @@ class ModelRunner:
     def exit(self):
         if self.world_size > 1:
             self.shm.close()
+            dist.barrier()
             if self.rank == 0:
                 self.shm.unlink()
         # dist.destroy_process_group()
@@ -178,7 +179,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
-        if is_prefill or self.enforce_eager or input_ids.size(0) > 256:
+        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
             return self.model.compute_logits(self.model(input_ids, positions))
         else:
             bs = input_ids.size(0)
@@ -220,7 +221,7 @@ class ModelRunner:
 
         config = self.config
         hf_config = config.hf_config
-        max_bs = min(self.config.max_num_seqs, 256)
+        max_bs = min(self.config.max_num_seqs, 512)
         max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
         input_ids = torch.zeros(max_bs, dtype=torch.int64)
         positions = torch.zeros(max_bs, dtype=torch.int64)
diff --git a/nanovllm/layers/activation.py b/nanovllm/layers/activation.py
index 8d026e1..041ee20 100755
--- a/nanovllm/layers/activation.py
+++ b/nanovllm/layers/activation.py
@@ -11,4 +11,4 @@ class SiluAndMul(nn.Module):
     @torch.compile
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, y = x.chunk(2, -1)
-        return y.mul_(F.silu(x))
+        return F.silu(x) * y
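
Sidenote on the activation.py hunk, as a minimal standalone sketch (plain PyTorch, not part of the patch): chunk() returns views into the input tensor, so the old y.mul_(F.silu(x)) wrote its result back into the caller's buffer. That silent mutation is likely the motivation for the change, since it is unsafe whenever the input buffer is reused, e.g. the static input buffers replayed by CUDA graphs.

# why_out_of_place.py -- illustrative sketch, not part of the patch.
import torch
import torch.nn.functional as F

inp = torch.randn(4, 8)
saved = inp.clone()

# old behavior: in-place multiply mutates the second half of `inp`
x, y = inp.chunk(2, -1)
_ = y.mul_(F.silu(x))
print(torch.equal(inp, saved))  # False -- input buffer was clobbered

# new behavior: out-of-place multiply leaves `inp` untouched
inp = saved.clone()
x, y = inp.chunk(2, -1)
_ = F.silu(x) * y
print(torch.equal(inp, saved))  # True -- safe to reuse the buffer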