fix
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import pickle
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from multiprocess.synchronize import Event
|
||||
from multiprocess.shared_memory import SharedMemory
|
||||
from multiprocessing.synchronize import Event
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
|
||||
from nanovllm.config import Config
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
@@ -50,6 +50,7 @@ class ModelRunner:
|
||||
def exit(self):
|
||||
if self.world_size > 1:
|
||||
self.shm.close()
|
||||
dist.barrier()
|
||||
if self.rank == 0:
|
||||
self.shm.unlink()
|
||||
# dist.destroy_process_group()
|
||||
@@ -178,7 +179,7 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
|
||||
if is_prefill or self.enforce_eager or input_ids.size(0) > 256:
|
||||
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
|
||||
return self.model.compute_logits(self.model(input_ids, positions))
|
||||
else:
|
||||
bs = input_ids.size(0)
|
||||
@@ -220,7 +221,7 @@ class ModelRunner:
|
||||
|
||||
config = self.config
|
||||
hf_config = config.hf_config
|
||||
max_bs = min(self.config.max_num_seqs, 256)
|
||||
max_bs = min(self.config.max_num_seqs, 512)
|
||||
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
|
||||
input_ids = torch.zeros(max_bs, dtype=torch.int64)
|
||||
positions = torch.zeros(max_bs, dtype=torch.int64)
|
||||
|
||||
Reference in New Issue
Block a user