Author: GeeeekExplorer
Date: 2025-06-15 13:09:05 +08:00
Parent: 326b121fad
Commit: 7e42fa6f63

3 changed files with 7 additions and 5 deletions

nanovllm/engine/llm_engine.py

@@ -34,6 +34,7 @@ class LLMEngine:
     def exit(self):
         self.model_runner.call("exit")
+        del self.model_runner
         for p in self.ps:
             p.join()
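Deleting the `ModelRunner` reference before joining makes its cleanup run immediately, while the workers are still reachable, rather than at interpreter teardown after `p.join()` may already be blocking. A toy illustration of why the explicit `del` matters (hypothetical class, not from the repo):

    # CPython frees an object, and runs its cleanup, as soon as the last
    # reference drops; `del` forces that to happen at a known point.
    class Resource:
        def __del__(self):
            print("cleanup runs here")

    r = Resource()
    del r                                  # prints "cleanup runs here" now
    print("workers are joined after this")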

nanovllm/engine/model_runner.py

@@ -1,8 +1,8 @@
import pickle import pickle
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from multiprocess.synchronize import Event from multiprocessing.synchronize import Event
from multiprocess.shared_memory import SharedMemory from multiprocessing.shared_memory import SharedMemory
from nanovllm.config import Config from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence from nanovllm.engine.sequence import Sequence
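`multiprocess` is a third-party, dill-based fork of the standard library's `multiprocessing`; the corrected imports point at the stdlib module. A self-contained sketch of the stdlib `SharedMemory` API these imports refer to:

    # Stdlib shared memory, as used by the fixed imports.
    from multiprocessing.shared_memory import SharedMemory

    shm = SharedMemory(create=True, size=16)
    try:
        shm.buf[:5] = b"hello"         # writable memoryview into the segment
        print(bytes(shm.buf[:5]))      # b'hello'
    finally:
        shm.close()                    # drop this process's mapping
        shm.unlink()                   # the creator removes the segment itself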
@@ -50,6 +50,7 @@ class ModelRunner:
     def exit(self):
         if self.world_size > 1:
             self.shm.close()
+            dist.barrier()
             if self.rank == 0:
                 self.shm.unlink()
         # dist.destroy_process_group()
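The added `dist.barrier()` makes every rank close its own mapping before rank 0 unlinks the segment, so no rank races against the removal. The ordering in isolation, assuming an already-initialized process group (hypothetical helper name):

    import torch.distributed as dist
    from multiprocessing.shared_memory import SharedMemory

    def release(shm: SharedMemory, rank: int) -> None:
        shm.close()        # each rank drops its own mapping first
        dist.barrier()     # wait until every rank has done so
        if rank == 0:
            shm.unlink()   # only the creator removes the segment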
@@ -178,7 +179,7 @@ class ModelRunner:
     @torch.inference_mode()
     def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
-        if is_prefill or self.enforce_eager or input_ids.size(0) > 256:
+        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
             return self.model.compute_logits(self.model(input_ids, positions))
         else:
             bs = input_ids.size(0)
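Both 256 → 512 edits in this file move the same cutoff: decode batches no larger than the biggest captured CUDA graph are replayed, while prefill, `enforce_eager`, or oversized batches fall back to a plain forward pass. A hypothetical sketch of that dispatch (`replay` stands in for the graph path and is not a repo function):

    # Hypothetical dispatch guarded by the new 512 threshold.
    def run_model(model, input_ids, positions, is_prefill, enforce_eager, replay):
        if is_prefill or enforce_eager or input_ids.size(0) > 512:
            return model.compute_logits(model(input_ids, positions))  # eager path
        return replay(input_ids, positions)  # padded CUDA-graph replay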
@@ -220,7 +221,7 @@ class ModelRunner:
         config = self.config
         hf_config = config.hf_config
-        max_bs = min(self.config.max_num_seqs, 256)
+        max_bs = min(self.config.max_num_seqs, 512)
         max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
         input_ids = torch.zeros(max_bs, dtype=torch.int64)
         positions = torch.zeros(max_bs, dtype=torch.int64)
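`max_bs` sizes the static buffers the CUDA graphs are captured over, so it must match the replay cutoff in `run_model`; this hunk keeps the two constants in sync at 512. A minimal, generic capture/replay pattern (assumes a CUDA device; not the repo's code):

    import torch

    model = torch.nn.Linear(16, 16, device="cuda")
    static_in = torch.zeros(512, 16, device="cuda")  # sized for the largest batch

    # Warm up on a side stream so lazy allocations happen outside capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = model(static_in)

    static_in.copy_(torch.randn(512, 16, device="cuda"))  # refill the fixed buffer
    g.replay()                                            # static_out updated in place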

nanovllm/layers/activation.py

@@ -11,4 +11,4 @@ class SiluAndMul(nn.Module):
     @torch.compile
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, y = x.chunk(2, -1)
-        return y.mul_(F.silu(x))
+        return F.silu(x) * y
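`x.chunk(2, -1)` returns views into the caller's tensor, so `y.mul_(...)` silently overwrote half of the layer's input through the view, an in-place write that is unsafe whenever the caller still needs that buffer (for instance when it is a captured CUDA-graph input), which is presumably why it was replaced. A small demonstration of the difference:

    import torch
    import torch.nn.functional as F

    t = torch.randn(2, 8)
    saved = t.clone()
    a, b = t.chunk(2, -1)         # a and b are views into t's storage
    b.mul_(F.silu(a))             # old form: mutates t through the view
    print(torch.equal(t, saved))  # False, the input was clobbered

    t = saved.clone()
    a, b = t.chunk(2, -1)
    out = F.silu(a) * b           # new form: out-of-place, t unchanged
    print(torch.equal(t, saved))  # True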