support tensor parallel

This commit is contained in:
cheunglei
2025-06-15 01:31:24 +08:00
parent b6136383c9
commit 53b3ef2e32
9 changed files with 102 additions and 31 deletions

View File

@@ -1,6 +1,8 @@
import atexit
from time import perf_counter
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoTokenizer
import torch.multiprocessing as mp
from nanovllm.config import Config
from nanovllm.sampling_params import SamplingParams
@@ -19,10 +21,24 @@ class LLMEngine:
Sequence.block_size = config.kvcache_block_size
config.hf_config = AutoConfig.from_pretrained(config.model)
config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
self.ps = []
self.events = []
for i in range(1, config.tensor_parallel_size):
event = mp.Event()
process = mp.Process(target=ModelRunner, args=(config, i, event))
process.start()
self.ps.append(process)
self.events.append(event)
self.model_runner = ModelRunner(config, 0, self.events)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
config.eos = self.tokenizer.eos_token_id
self.model_runner = ModelRunner(config)
self.scheduler = Scheduler(config)
atexit.register(self.exit)
def exit(self):
self.model_runner.call("exit")
for p in self.ps:
p.join()
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
if isinstance(prompt, str):
@@ -32,7 +48,7 @@ class LLMEngine:
def step(self):
seqs, is_prefill = self.scheduler.schedule()
token_ids = self.model_runner.run(seqs, is_prefill)
token_ids = self.model_runner.call("run", seqs, is_prefill)
self.scheduler.postprocess(seqs, token_ids)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)

View File

@@ -1,4 +1,8 @@
import pickle
import torch
import torch.distributed as dist
from multiprocess.synchronize import Event
from multiprocess.shared_memory import SharedMemory
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence
@@ -11,12 +15,17 @@ from nanovllm.utils.loader import load_model
class ModelRunner:
def __init__(self, config: Config):
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
self.config = config
hf_config = config.hf_config
self.block_size = config.kvcache_block_size
self.enforce_eager = config.enforce_eager
self.world_size = config.tensor_parallel_size
self.rank = rank
self.event = event
dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
torch.cuda.set_device(rank)
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_device("cuda")
@@ -29,14 +38,66 @@ class ModelRunner:
torch.set_default_device("cpu")
torch.set_default_dtype(default_dtype)
if self.world_size > 1:
if rank == 0:
self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
dist.barrier()
else:
dist.barrier()
self.shm = SharedMemory(name="nanovllm")
self.loop()
def exit(self):
if self.world_size > 1:
self.shm.close()
if self.rank == 0:
self.shm.unlink()
# dist.destroy_process_group()
def loop(self):
while True:
method_name, args = self.read_shm()
method = getattr(self, method_name, None)
assert callable(method)
method(*args)
if method_name == "exit":
break
def read_shm(self):
assert self.world_size > 1 and self.rank
self.event.wait()
n = int.from_bytes(self.shm.buf[0:4], "little")
method_name, *args = pickle.loads(self.shm.buf[4:n+4])
self.event.clear()
return method_name, args
def write_shm(self, method_name, *args):
assert self.world_size > 1 and not self.rank
data = pickle.dumps([method_name, *args])
n = len(data)
assert n + 4 <= self.shm.size
self.shm.buf[0:4] = n.to_bytes(4, "little")
self.shm.buf[4:n+4] = data
for event in self.event:
event.set()
def call(self, method_name, *args):
assert self.rank == 0
if self.world_size > 1:
self.write_shm(method_name, *args)
method = getattr(self, method_name, None)
assert callable(method)
return method(*args)
def allocate_kv_cache(self, gpu_memory_utilization):
config = self.config
hf_config = config.hf_config
total, used, _ = get_gpu_memory()
free = total * gpu_memory_utilization - used
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
config.num_kvcache_blocks = int(free) // block_bytes
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim)
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
@@ -148,7 +209,7 @@ class ModelRunner:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs)
logits = self.run_model(input_ids, positions, is_prefill)
token_ids = self.sampler(logits, temperatures).tolist()
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
reset_context()
return token_ids

View File

@@ -14,8 +14,6 @@ class Scheduler:
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
self.waiting: deque[Sequence] = deque()
self.running: deque[Sequence] = deque()
self.num_finished = 0
self.num_tokens = 0
def is_finished(self):
return not self.waiting and not self.running
@@ -67,11 +65,9 @@ class Scheduler:
self.waiting.appendleft(seq)
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
self.num_tokens += len(token_ids)
for seq, token_id in zip(seqs, token_ids):
seq.append_token(token_id)
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED
self.block_manager.deallocate(seq)
self.running.remove(seq)
self.num_finished += 1

View File

@@ -75,7 +75,7 @@ class Sequence:
self.num_tokens += 1
def __getstate__(self):
state = super().__getstate__()
state = vars(self).copy()
if self.num_completion_tokens:
state.pop("token_ids")
return state