Merge pull request #11 from GeeeekExplorer/tp_dev
@@ -1,2 +1,2 @@
 from nanovllm.llm import LLM
 from nanovllm.sampling_params import SamplingParams

@@ -1,14 +1,16 @@
+import os
 from dataclasses import dataclass
 from transformers import AutoConfig


 @dataclass
 class Config:
-    model: str = ''
+    model: str
     max_num_batched_tokens: int = 32768
     max_num_seqs: int = 512
     max_model_len: int = 4096
     gpu_memory_utilization: float = 0.9
+    tensor_parallel_size: int = 1
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
     eos: int = -1
@@ -16,5 +18,8 @@ class Config:
     num_kvcache_blocks: int = -1

     def __post_init__(self):
-        assert self.model
+        assert os.path.isdir(self.model)
         assert self.kvcache_block_size % 256 == 0
+        assert 1 <= self.tensor_parallel_size <= 8
+        self.hf_config = AutoConfig.from_pretrained(self.model)
+        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)

@@ -107,7 +107,7 @@ class BlockManager:
            block_table.append(block_id)
        elif len(seq) % self.block_size == 0:
            assert last_block.hash == -1
-            token_ids = seq.last_block()
+            token_ids = seq.block(seq.num_blocks-1)
            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
            h = compute_hash(token_ids, prefix)
            last_block.update(h, token_ids)

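For context, the hunk above finalizes a block's hash only once the sequence fills it: the block's token ids are hashed together with the previous block's hash, so identical prefixes of full blocks produce identical hash chains and can be reused from the cache. The sketch below illustrates that chaining only; it uses hashlib.sha256 as a stand-in, since the real compute_hash and its algorithm are not shown in this diff.

import hashlib

def compute_hash_sketch(token_ids: list[int], prefix: int = -1) -> int:
    # Chain the previous block's hash with this block's token ids so that
    # two sequences sharing a prefix of full blocks get identical hashes.
    h = hashlib.sha256()
    h.update(prefix.to_bytes(8, "little", signed=True))
    h.update(b"".join(t.to_bytes(4, "little") for t in token_ids))
    return int.from_bytes(h.digest()[:8], "little")

block_size = 4
tokens = [11, 12, 13, 14, 21, 22, 23, 24]
prev = -1
for i in range(len(tokens) // block_size):
    block = tokens[i * block_size:(i + 1) * block_size]
    prev = compute_hash_sketch(block, prev)
    print(i, prev)
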
@@ -1,6 +1,9 @@
+import atexit
+from dataclasses import fields
 from time import perf_counter
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoTokenizer
+import torch.multiprocessing as mp

 from nanovllm.config import Config
 from nanovllm.sampling_params import SamplingParams
@@ -12,17 +15,27 @@ from nanovllm.engine.model_runner import ModelRunner
 class LLMEngine:

     def __init__(self, model, **kwargs):
-        config = Config(model)
-        for k, v in kwargs.items():
-            if hasattr(config, k):
-                setattr(config, k, v)
-        Sequence.block_size = config.kvcache_block_size
-        config.hf_config = AutoConfig.from_pretrained(config.model)
-        config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
+        config_fields = {field.name for field in fields(Config)}
+        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
+        config = Config(model, **config_kwargs)
+        self.ps = []
+        self.events = []
+        for i in range(1, config.tensor_parallel_size):
+            event = mp.Event()
+            process = mp.Process(target=ModelRunner, args=(config, i, event))
+            process.start()
+            self.ps.append(process)
+            self.events.append(event)
+        self.model_runner = ModelRunner(config, 0, self.events)
         self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
         config.eos = self.tokenizer.eos_token_id
-        self.model_runner = ModelRunner(config)
         self.scheduler = Scheduler(config)
+        atexit.register(self.exit)

+    def exit(self):
+        self.model_runner.call("exit")
+        for p in self.ps:
+            p.join()

     def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
         if isinstance(prompt, str):
@@ -32,7 +45,7 @@ class LLMEngine:

     def step(self):
         seqs, is_prefill = self.scheduler.schedule()
-        token_ids = self.model_runner.run(seqs, is_prefill)
+        token_ids = self.model_runner.call("run", seqs, is_prefill)
         self.scheduler.postprocess(seqs, token_ids)
         outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
         num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
@@ -79,4 +92,4 @@ class LLMEngine:
         outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
         if use_tqdm:
             pbar.close()
         return outputs

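A standalone sketch of the constructor-kwargs filtering used above: only keyword arguments whose names match a Config dataclass field are forwarded, everything else is dropped before Config is built. ToyConfig and make_config are illustrative stand-ins, not the project's real names.

from dataclasses import dataclass, fields

@dataclass
class ToyConfig:
    model: str
    tensor_parallel_size: int = 1
    enforce_eager: bool = False

def make_config(model, **kwargs):
    # Keep only kwargs that name a real dataclass field, mirroring LLMEngine.__init__.
    config_fields = {field.name for field in fields(ToyConfig)}
    config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
    return ToyConfig(model, **config_kwargs)

cfg = make_config("some-model-dir", tensor_parallel_size=2, unknown_option=123)
print(cfg)  # ToyConfig(model='some-model-dir', tensor_parallel_size=2, enforce_eager=False)
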
@@ -1,4 +1,8 @@
+import pickle
 import torch
+import torch.distributed as dist
+from multiprocessing.synchronize import Event
+from multiprocessing.shared_memory import SharedMemory

 from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
@@ -11,12 +15,17 @@ from nanovllm.utils.loader import load_model

 class ModelRunner:

-    def __init__(self, config: Config):
+    def __init__(self, config: Config, rank: int, event: Event | list[Event]):
         self.config = config
         hf_config = config.hf_config
         self.block_size = config.kvcache_block_size
         self.enforce_eager = config.enforce_eager
+        self.world_size = config.tensor_parallel_size
+        self.rank = rank
+        self.event = event

+        dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
+        torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(hf_config.torch_dtype)
         torch.set_default_device("cuda")
@@ -29,14 +38,63 @@ class ModelRunner:
         torch.set_default_device("cpu")
         torch.set_default_dtype(default_dtype)

+        if self.world_size > 1:
+            if rank == 0:
+                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
+                dist.barrier()
+            else:
+                dist.barrier()
+                self.shm = SharedMemory(name="nanovllm")
+                self.loop()
+
+    def exit(self):
+        if self.world_size > 1:
+            self.shm.close()
+            if self.rank == 0:
+                self.shm.unlink()
+        # dist.destroy_process_group()
+
+    def loop(self):
+        while True:
+            method_name, args = self.read_shm()
+            self.call(method_name, *args)
+            if method_name == "exit":
+                break
+
+    def read_shm(self):
+        assert self.world_size > 1 and self.rank
+        self.event.wait()
+        n = int.from_bytes(self.shm.buf[0:4], "little")
+        method_name, *args = pickle.loads(self.shm.buf[4:n+4])
+        self.event.clear()
+        return method_name, args
+
+    def write_shm(self, method_name, *args):
+        assert self.world_size > 1 and not self.rank
+        data = pickle.dumps([method_name, *args])
+        n = len(data)
+        assert n + 4 <= self.shm.size
+        self.shm.buf[0:4] = n.to_bytes(4, "little")
+        self.shm.buf[4:n+4] = data
+        for event in self.event:
+            event.set()
+
+    def call(self, method_name, *args):
+        if self.world_size > 1 and self.rank == 0:
+            self.write_shm(method_name, *args)
+        method = getattr(self, method_name, None)
+        assert callable(method)
+        return method(*args)

     def allocate_kv_cache(self, gpu_memory_utilization):
         config = self.config
         hf_config = config.hf_config
         total, used, _ = get_gpu_memory()
         free = total * gpu_memory_utilization - used
-        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
+        num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
+        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(free) // block_bytes
-        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim)
+        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
             if hasattr(module, "k_cache") and hasattr(module, "v_cache"):

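A minimal single-process sketch of the length-prefixed framing that write_shm/read_shm implement above: rank 0 pickles [method_name, *args], writes a 4-byte little-endian length followed by the payload into the shared segment, and signals workers via events; each worker waits, decodes the call, dispatches it, and clears its event. The sketch below skips the processes and events and just round-trips one message through a shared-memory buffer.

import pickle
from multiprocessing.shared_memory import SharedMemory

shm = SharedMemory(name="nanovllm_demo", create=True, size=2**20)
try:
    # Writer side (rank 0): 4-byte length prefix, then the pickled call.
    data = pickle.dumps(["run", [1, 2, 3], True])
    shm.buf[0:4] = len(data).to_bytes(4, "little")
    shm.buf[4:4 + len(data)] = data

    # Reader side (worker rank): recover the length, then the call.
    n = int.from_bytes(shm.buf[0:4], "little")
    method_name, *args = pickle.loads(shm.buf[4:4 + n])
    print(method_name, args)  # run [[1, 2, 3], True]
finally:
    shm.close()
    shm.unlink()
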
@@ -66,7 +124,7 @@ class ModelRunner:
         for seq in seqs:
             seqlen = len(seq)
             input_ids.extend(seq[seq.num_cached_tokens:])
-            positions.extend(list(range(seq.num_cached_tokens, len(seq))))
+            positions.extend(list(range(seq.num_cached_tokens, seqlen)))
             seqlen_q = seqlen - seq.num_cached_tokens
             seqlen_k = seqlen
             cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
@@ -78,7 +136,7 @@ class ModelRunner:
                 if i != seq.num_blocks - 1:
                     end = start + self.block_size
                 else:
-                    end = start + len(seq.last_block())
+                    end = start + seq.last_block_num_tokens
                 slot_mapping.extend(list(range(start, end)))
         assert len(input_ids) == len(slot_mapping)
         assert len(input_ids) == cu_seqlens_q[-1]
@@ -102,7 +160,7 @@ class ModelRunner:
             input_ids.append(seq.last_token)
             positions.append(len(seq))
             context_lens.append(len(seq))
-            slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()) - 1)
+            slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -148,7 +206,7 @@ class ModelRunner:
         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
         temperatures = self.prepare_sample(seqs)
         logits = self.run_model(input_ids, positions, is_prefill)
-        token_ids = self.sampler(logits, temperatures).tolist()
+        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
         reset_context()
         return token_ids

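Worked arithmetic for the decode-time slot_mapping entry above, under an assumed kvcache_block_size of 256 and a hypothetical block table: the last token of a 600-token sequence lives in the third block, at offset last_block_num_tokens - 1 within it.

block_size = 256
num_tokens = 600
block_table = [7, 3, 9]  # hypothetical physical block ids for this sequence

num_blocks = (num_tokens + block_size - 1) // block_size             # 3
last_block_num_tokens = num_tokens - (num_blocks - 1) * block_size   # 88
slot = block_table[-1] * block_size + last_block_num_tokens - 1      # 9*256 + 88 - 1 = 2391
print(num_blocks, last_block_num_tokens, slot)
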
@@ -14,8 +14,6 @@ class Scheduler:
         self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
         self.waiting: deque[Sequence] = deque()
         self.running: deque[Sequence] = deque()
-        self.num_finished = 0
-        self.num_tokens = 0

     def is_finished(self):
         return not self.waiting and not self.running
@@ -67,11 +65,9 @@ class Scheduler:
         self.waiting.appendleft(seq)

     def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
-        self.num_tokens += len(token_ids)
         for seq, token_id in zip(seqs, token_ids):
             seq.append_token(token_id)
             if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                 seq.status = SequenceStatus.FINISHED
                 self.block_manager.deallocate(seq)
                 self.running.remove(seq)
-                self.num_finished += 1

@@ -19,15 +19,17 @@ class Sequence:
         self.seq_id = next(Sequence.counter)
         self.status = SequenceStatus.WAITING
         self.token_ids = copy(token_ids)
+        self.last_token = token_ids[-1]
+        self.num_tokens = len(self.token_ids)
         self.num_prompt_tokens = len(token_ids)
-        self._num_cached_tokens = 0
+        self.num_cached_tokens = 0
         self.block_table = []
         self.temperature = sampling_params.temperature
         self.max_tokens = sampling_params.max_tokens
         self.ignore_eos = sampling_params.ignore_eos

     def __len__(self):
-        return len(self.token_ids)
+        return self.num_tokens

     def __lt__(self, other):
         return self.seq_id < other.seq_id
@@ -41,7 +43,7 @@ class Sequence:

     @property
     def num_completion_tokens(self):
-        return len(self.token_ids) - self.num_prompt_tokens
+        return self.num_tokens - self.num_prompt_tokens

     @property
     def prompt_token_ids(self):
@@ -51,33 +53,29 @@ class Sequence:
     def completion_token_ids(self):
         return self.token_ids[self.num_prompt_tokens:]

-    @property
-    def num_cached_tokens(self):
-        return self._num_cached_tokens
-
-    @num_cached_tokens.setter
-    def num_cached_tokens(self, num_cached_tokens):
-        assert num_cached_tokens % self.block_size == 0
-        self._num_cached_tokens = num_cached_tokens
-
     @property
     def num_cached_blocks(self):
         return self.num_cached_tokens // self.block_size

     @property
     def num_blocks(self):
-        return (len(self.token_ids) + self.block_size - 1) // self.block_size
+        return (self.num_tokens + self.block_size - 1) // self.block_size

     @property
-    def last_token(self):
-        return self.token_ids[-1]
+    def last_block_num_tokens(self):
+        return self.num_tokens - (self.num_blocks - 1) * self.block_size

     def block(self, i):
+        assert 0 <= i < self.num_blocks
         return self.token_ids[i*self.block_size: (i+1)*self.block_size]

-    def last_block(self):
-        n = self.num_blocks
-        return self.token_ids[(n-1)*self.block_size:]
-
     def append_token(self, token_id: int):
         self.token_ids.append(token_id)
+        self.last_token = token_id
+        self.num_tokens += 1
+
+    def __getstate__(self):
+        state = vars(self).copy()
+        if self.num_completion_tokens:
+            state.pop("token_ids")
+        return state

|
|||||||
@torch.compile
|
@torch.compile
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x, y = x.chunk(2, -1)
|
x, y = x.chunk(2, -1)
|
||||||
return F.silu(x) * y
|
return y.mul_(F.silu(x))
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ class VocabParallelEmbedding(nn.Module):
         embedding_dim: int,
     ):
         super().__init__()
-        self.tp_rank = 0 # get_tensor_model_parallel_rank()
-        self.tp_size = 1 # get_tensor_model_parallel_world_size()
+        self.tp_rank = dist.get_rank()
+        self.tp_size = dist.get_world_size()
         assert num_embeddings % self.tp_size == 0
         self.num_embeddings = num_embeddings
         self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
@@ -39,7 +39,7 @@ class VocabParallelEmbedding(nn.Module):
             x = mask * (x - self.vocab_start_idx)
         y = F.embedding(x, self.weight)
         if self.tp_size > 1:
-            y = mask * y
+            y = mask.unsqueeze(1) * y
             dist.all_reduce(y)
         return y

@@ -65,8 +65,8 @@ class ParallelLMHead(VocabParallelEmbedding):
             last_indices = context.cu_seqlens_q[1:] - 1
             x = x[last_indices].contiguous()
         logits = F.linear(x, self.weight, self.bias)
-        # if self.tp_size > 1:
-        # all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)]
-        # dist.gather(logits, all_logits, 0)
-        # logits = torch.cat(all_logits, -1)
-        return logits if self.tp_rank == 0 else None
+        if self.tp_size > 1:
+            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
+            dist.gather(logits, all_logits, 0)
+            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
+        return logits

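A single-process sketch of the vocab-parallel embedding trick touched above: each rank holds a slice of the vocabulary, maps out-of-shard token ids to 0, zeroes the corresponding rows with the (unsqueezed) mask, and a sum across ranks (dist.all_reduce in the real code) reconstructs the full lookup. Simulated here with plain tensors instead of torch.distributed.

import torch

vocab_size, dim, tp_size = 8, 4, 2
weight = torch.arange(vocab_size * dim, dtype=torch.float32).reshape(vocab_size, dim)
x = torch.tensor([1, 5, 6])

partial = []
per_rank = vocab_size // tp_size
for rank in range(tp_size):
    start, end = rank * per_rank, (rank + 1) * per_rank
    shard = weight[start:end]                  # this rank's rows
    mask = (x >= start) & (x < end)            # tokens owned by this rank
    local_ids = mask * (x - start)             # out-of-shard ids become 0
    y = torch.nn.functional.embedding(local_ids, shard)
    y = mask.unsqueeze(1) * y                  # zero the rows this rank does not own
    partial.append(y)

full = sum(partial)                            # stands in for dist.all_reduce
assert torch.equal(full, torch.nn.functional.embedding(x, weight))
print(full)
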
@@ -21,8 +21,8 @@ class LinearBase(nn.Module):
         self.input_size = input_size
         self.output_size = output_size
         self.tp_dim = tp_dim
-        self.tp_rank = 0 # get_tensor_model_parallel_rank()
-        self.tp_size = 1 # get_tensor_model_parallel_world_size()
+        self.tp_rank = dist.get_rank()
+        self.tp_size = dist.get_world_size()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
@@ -65,7 +65,6 @@ class ColumnParallelLinear(LinearBase):
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
-        # If QKV or MergedColumn, use output size of each partition.
         if hasattr(self, "output_sizes"):
             self.output_partition_sizes = [
                 divide(output_size, self.tp_size)
@@ -101,8 +100,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         bias: bool = False,
     ):
         self.output_sizes = output_sizes
-        tp_size = 1 # get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias=bias)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
@@ -110,7 +107,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
         shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
-        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
+        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
         assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)

@@ -131,8 +128,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         if total_num_kv_heads is None:
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
-        # Divide the weight matrix along the last dimension.
-        tp_size = 1 # get_tensor_model_parallel_world_size()
+        tp_size = dist.get_world_size()
         self.num_heads = divide(self.total_num_heads, tp_size)
         self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
         input_size = self.hidden_size
@@ -158,7 +154,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_size = self.num_kv_heads * self.head_size
             shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
-        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
+        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

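A small sketch of the weight-loader change above: instead of narrow() with a precomputed offset, each rank takes its slice of the full checkpoint tensor with chunk() along the tensor-parallel dimension. Shown for a column-parallel weight with an assumed tp_size of 2 and made-up shapes.

import torch

tp_size, tp_dim = 2, 0          # column-parallel: shard along the output dimension
output_size, input_size = 8, 4
loaded_weight = torch.randn(output_size, input_size)   # full checkpoint tensor

for tp_rank in range(tp_size):
    param = torch.empty(output_size // tp_size, input_size)  # this rank's parameter
    shard = loaded_weight.chunk(tp_size, tp_dim)[tp_rank]
    assert param.size() == shard.size()
    param.copy_(shard)
    print(tp_rank, param.shape)  # each rank holds a (4, 4) slice
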
@@ -70,4 +70,4 @@ def get_rope(
 ):
     assert rope_scaling is None
     rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
     return rotary_emb

@@ -2,4 +2,4 @@ from nanovllm.engine.llm_engine import LLMEngine


 class LLM(LLMEngine):
     pass

@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+import torch.distributed as dist
 from transformers import Qwen3Config

 from nanovllm.layers.activation import SiluAndMul
@@ -26,7 +27,7 @@ class Qwen3Attention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = 1 # get_tensor_model_parallel_world_size()
+        tp_size = dist.get_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size

@@ -25,4 +25,4 @@ def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0

 def reset_context():
     global _CONTEXT
     _CONTEXT = Context()

@@ -10,7 +10,6 @@ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):


 def load_model(model: nn.Module, path: str):
-    assert os.path.isdir(path)
     packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
     for file in glob(os.path.join(path, "*.safetensors")):
         with safe_open(file, "pt", "cpu") as f:

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "nano-vllm"
-version = "0.1.0"
+version = "0.2.0"
 authors = [{ name = "Xingkai Yu" }]
 license = "MIT"
 license-files = ["LICENSE"]