support tensor parallel
@@ -1,4 +1,8 @@
+import pickle
 import torch
+import torch.distributed as dist
+from multiprocessing.synchronize import Event
+from multiprocessing.shared_memory import SharedMemory
 
 from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
@@ -11,12 +15,17 @@ from nanovllm.utils.loader import load_model
 
 class ModelRunner:
 
-    def __init__(self, config: Config):
+    def __init__(self, config: Config, rank: int, event: Event | list[Event]):
         self.config = config
         hf_config = config.hf_config
         self.block_size = config.kvcache_block_size
         self.enforce_eager = config.enforce_eager
+        self.world_size = config.tensor_parallel_size
+        self.rank = rank
+        self.event = event
+
+        dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
+        torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(hf_config.torch_dtype)
         torch.set_default_device("cuda")
@@ -29,14 +38,66 @@ class ModelRunner:
         torch.set_default_device("cpu")
         torch.set_default_dtype(default_dtype)
 
+        if self.world_size > 1:
+            if rank == 0:
+                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
+                dist.barrier()
+            else:
+                dist.barrier()
+                self.shm = SharedMemory(name="nanovllm")
+                self.loop()
+
+    def exit(self):
+        if self.world_size > 1:
+            self.shm.close()
+            if self.rank == 0:
+                self.shm.unlink()
+        # dist.destroy_process_group()
+
+    def loop(self):
+        while True:
+            method_name, args = self.read_shm()
+            method = getattr(self, method_name, None)
+            assert callable(method)
+            method(*args)
+            if method_name == "exit":
+                break
+
+    def read_shm(self):
+        assert self.world_size > 1 and self.rank
+        self.event.wait()
+        n = int.from_bytes(self.shm.buf[0:4], "little")
+        method_name, *args = pickle.loads(self.shm.buf[4:n+4])
+        self.event.clear()
+        return method_name, args
+
+    def write_shm(self, method_name, *args):
+        assert self.world_size > 1 and not self.rank
+        data = pickle.dumps([method_name, *args])
+        n = len(data)
+        assert n + 4 <= self.shm.size
+        self.shm.buf[0:4] = n.to_bytes(4, "little")
+        self.shm.buf[4:n+4] = data
+        for event in self.event:
+            event.set()
+
+    def call(self, method_name, *args):
+        assert self.rank == 0
+        if self.world_size > 1:
+            self.write_shm(method_name, *args)
+        method = getattr(self, method_name, None)
+        assert callable(method)
+        return method(*args)
+
     def allocate_kv_cache(self, gpu_memory_utilization):
         config = self.config
         hf_config = config.hf_config
         total, used, _ = get_gpu_memory()
         free = total * gpu_memory_utilization - used
-        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
+        num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
+        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(free) // block_bytes
-        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim)
+        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
             if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
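
The sizing change above is the core of the KV-cache split: each rank now stores only num_key_value_heads // world_size heads, so block_bytes and the allocated tensor shrink by the tensor-parallel degree. A rough worked example with made-up shapes (36 layers, 8 KV heads, head_dim 128, bfloat16, block_size 256, tensor_parallel_size 2; none of these numbers come from this diff):

# Illustrative arithmetic only; the shapes below are assumptions, not values from this commit.
num_hidden_layers = 36      # assumed
block_size = 256            # assumed kvcache_block_size
num_key_value_heads = 8     # assumed
head_dim = 128              # assumed
itemsize = 2                # bfloat16
tensor_parallel_size = 2

num_kv_heads = num_key_value_heads // tensor_parallel_size          # 4 KV heads per rank
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
assert block_bytes == 18_874_368                                     # 18 MiB per block per rank (K and V)
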
@@ -148,7 +209,7 @@ class ModelRunner:
         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
         temperatures = self.prepare_sample(seqs)
         logits = self.run_model(input_ids, positions, is_prefill)
-        token_ids = self.sampler(logits, temperatures).tolist()
+        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
         reset_context()
         return token_ids
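
Taken together, the new methods form a small single-producer broadcast: rank 0's call() pickles [method_name, *args] into the shared-memory block behind a 4-byte little-endian length prefix and sets one Event per worker, while every other rank blocks in loop() until its Event fires, reads the same call back in read_shm(), and executes it locally. That is why run() above can skip sampling on non-zero ranks: every rank runs run_model on the same batch, but only rank 0 needs the token ids. A minimal standalone sketch of the same mailbox pattern (toy names such as demo_mailbox, not the nanovllm classes):

# Minimal sketch of the rank-0 -> worker mailbox used above (toy example, not nanovllm code).
import pickle
from multiprocessing import Event, Process
from multiprocessing.shared_memory import SharedMemory

SHM_NAME = "demo_mailbox"   # hypothetical name, analogous to "nanovllm"

def worker(event):
    shm = SharedMemory(name=SHM_NAME)           # attach to the existing block
    event.wait()                                # block until the writer publishes a call
    n = int.from_bytes(shm.buf[0:4], "little")  # 4-byte little-endian length prefix
    method_name, *args = pickle.loads(shm.buf[4:n + 4])
    event.clear()
    print(f"worker executing {method_name}{tuple(args)}")
    shm.close()

if __name__ == "__main__":
    shm = SharedMemory(name=SHM_NAME, create=True, size=2**20)
    events = [Event() for _ in range(2)]
    procs = [Process(target=worker, args=(e,)) for e in events]
    for p in procs:
        p.start()

    data = pickle.dumps(["run", [1, 2, 3], True])  # [method_name, *args]
    shm.buf[0:4] = len(data).to_bytes(4, "little")
    shm.buf[4:len(data) + 4] = data
    for e in events:                               # wake every worker
        e.set()

    for p in procs:
        p.join()
    shm.close()
    shm.unlink()

The length prefix plus pickle keeps the protocol free of any framing beyond one Event per worker, at the cost of capping a single call at the shared-memory size (2**20 bytes in this commit, enforced by the assert in write_shm).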
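
Not shown in this diff is the engine side, which has to start one ModelRunner per tensor-parallel rank and hand rank 0 the list of worker Events. One plausible wiring, assuming a spawn context and a nanovllm.engine.model_runner module path (both are assumptions, not taken from this commit):

# Hypothetical engine-side wiring for tensor_parallel_size ranks (sketch only).
import multiprocessing as mp

from nanovllm.config import Config
from nanovllm.engine.model_runner import ModelRunner   # assumed module path

def start_runners(config: Config):
    ctx = mp.get_context("spawn")
    events, processes = [], []
    for rank in range(1, config.tensor_parallel_size):
        event = ctx.Event()
        # Worker ranks fall into self.loop() at the end of __init__ and only
        # return once an "exit" call arrives through the shared-memory mailbox.
        p = ctx.Process(target=ModelRunner, args=(config, rank, event))
        p.start()
        events.append(event)
        processes.append(p)
    # Rank 0 keeps the full event list so write_shm() can wake every worker.
    runner = ModelRunner(config, 0, events)
    return runner, processes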