This commit is contained in:
GeeeekExplorer
2025-06-15 13:09:05 +08:00
parent 326b121fad
commit 7e42fa6f63
3 changed files with 7 additions and 5 deletions

View File

@@ -1,8 +1,8 @@
import pickle
import torch
import torch.distributed as dist
from multiprocess.synchronize import Event
from multiprocess.shared_memory import SharedMemory
from multiprocessing.synchronize import Event
from multiprocessing.shared_memory import SharedMemory
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence
@@ -50,6 +50,7 @@ class ModelRunner:
def exit(self):
if self.world_size > 1:
self.shm.close()
dist.barrier()
if self.rank == 0:
self.shm.unlink()
# dist.destroy_process_group()
@@ -178,7 +179,7 @@ class ModelRunner:
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
if is_prefill or self.enforce_eager or input_ids.size(0) > 256:
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
return self.model.compute_logits(self.model(input_ids, positions))
else:
bs = input_ids.size(0)
@@ -220,7 +221,7 @@ class ModelRunner:
config = self.config
hf_config = config.hf_config
max_bs = min(self.config.max_num_seqs, 256)
max_bs = min(self.config.max_num_seqs, 512)
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
input_ids = torch.zeros(max_bs, dtype=torch.int64)
positions = torch.zeros(max_bs, dtype=torch.int64)