Merge pull request #11 from GeeeekExplorer/tp_dev

Xingkai Yu authored 2025-06-15 10:37:21 +08:00; committed by GitHub
16 changed files with 135 additions and 69 deletions

View File

@@ -1,2 +1,2 @@
 from nanovllm.llm import LLM
 from nanovllm.sampling_params import SamplingParams

View File

@@ -1,14 +1,16 @@
+import os
 from dataclasses import dataclass
 from transformers import AutoConfig


 @dataclass
 class Config:
-    model: str = ''
+    model: str
     max_num_batched_tokens: int = 32768
     max_num_seqs: int = 512
     max_model_len: int = 4096
     gpu_memory_utilization: float = 0.9
+    tensor_parallel_size: int = 1
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
     eos: int = -1
@@ -16,5 +18,8 @@ class Config:
     num_kvcache_blocks: int = -1

     def __post_init__(self):
-        assert self.model
+        assert os.path.isdir(self.model)
         assert self.kvcache_block_size % 256 == 0
+        assert 1 <= self.tensor_parallel_size <= 8
+        self.hf_config = AutoConfig.from_pretrained(self.model)
+        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)

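With `tensor_parallel_size` promoted to a regular `Config` field (validated in `__post_init__`), tensor parallelism can be requested straight through the engine's keyword arguments, which the engine now filters against the dataclass fields. A minimal usage sketch, assuming a local model directory (the path and size below are placeholders):

```python
from nanovllm import LLM, SamplingParams

# Placeholder model path; Config.__post_init__ asserts it is a local directory
# and that 1 <= tensor_parallel_size <= 8.
llm = LLM("/path/to/Qwen3-0.6B", enforce_eager=True, tensor_parallel_size=2)
outputs = llm.generate(["Hello, nano-vllm."], SamplingParams(temperature=0.6, max_tokens=32))
print(outputs[0]["text"])
```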
View File

@@ -107,7 +107,7 @@ class BlockManager:
             block_table.append(block_id)
         elif len(seq) % self.block_size == 0:
             assert last_block.hash == -1
-            token_ids = seq.last_block()
+            token_ids = seq.block(seq.num_blocks-1)
             prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
             h = compute_hash(token_ids, prefix)
             last_block.update(h, token_ids)

View File

@@ -1,6 +1,9 @@
+import atexit
+from dataclasses import fields
 from time import perf_counter
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoTokenizer
+import torch.multiprocessing as mp

 from nanovllm.config import Config
 from nanovllm.sampling_params import SamplingParams
@@ -12,17 +15,27 @@ from nanovllm.engine.model_runner import ModelRunner
 class LLMEngine:

     def __init__(self, model, **kwargs):
-        config = Config(model)
-        for k, v in kwargs.items():
-            if hasattr(config, k):
-                setattr(config, k, v)
-        Sequence.block_size = config.kvcache_block_size
-        config.hf_config = AutoConfig.from_pretrained(config.model)
-        config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
+        config_fileds = {field.name for field in fields(Config)}
+        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fileds}
+        config = Config(model, **config_kwargs)
+        self.ps = []
+        self.events = []
+        for i in range(1, config.tensor_parallel_size):
+            event = mp.Event()
+            process = mp.Process(target=ModelRunner, args=(config, i, event))
+            process.start()
+            self.ps.append(process)
+            self.events.append(event)
+        self.model_runner = ModelRunner(config, 0, self.events)
         self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
         config.eos = self.tokenizer.eos_token_id
-        self.model_runner = ModelRunner(config)
         self.scheduler = Scheduler(config)
+        atexit.register(self.exit)
+
+    def exit(self):
+        self.model_runner.call("exit")
+        for p in self.ps:
+            p.join()

     def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
if isinstance(prompt, str): if isinstance(prompt, str):
@@ -32,7 +45,7 @@ class LLMEngine:
     def step(self):
         seqs, is_prefill = self.scheduler.schedule()
-        token_ids = self.model_runner.run(seqs, is_prefill)
+        token_ids = self.model_runner.call("run", seqs, is_prefill)
         self.scheduler.postprocess(seqs, token_ids)
         outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
         num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
@@ -79,4 +92,4 @@ class LLMEngine:
         outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
         if use_tqdm:
             pbar.close()
         return outputs

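Passing the `ModelRunner` class itself as the `mp.Process` target works because, on ranks 1..N-1, the constructor ends by blocking in `loop()` until rank 0 sends `"exit"`, while rank 0 is constructed in the parent process. A stripped-down sketch of that topology (the `Worker` class below is a hypothetical stand-in, not the engine's real code):

```python
import torch.multiprocessing as mp

class Worker:
    def __init__(self, rank, event):
        self.rank = rank
        self.event = event
        if rank > 0:
            self.loop()            # followers never return from __init__

    def loop(self):
        self.event.wait()          # in the real engine this reads a command from shared memory
        print(f"rank {self.rank} exiting")

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    events, procs = [], []
    for rank in range(1, 2):       # i.e. tensor_parallel_size == 2 in this sketch
        event = mp.Event()
        proc = mp.Process(target=Worker, args=(rank, event))
        proc.start()
        events.append(event)
        procs.append(proc)
    leader = Worker(0, events)     # rank 0 is built in the parent process
    for event in events:
        event.set()                # the equivalent of call("exit")
    for proc in procs:
        proc.join()
```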
View File

@@ -1,4 +1,8 @@
+import pickle
 import torch
+import torch.distributed as dist
+from multiprocessing.synchronize import Event
+from multiprocessing.shared_memory import SharedMemory

 from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
@@ -11,12 +15,17 @@ from nanovllm.utils.loader import load_model
 class ModelRunner:

-    def __init__(self, config: Config):
+    def __init__(self, config: Config, rank: int, event: Event | list[Event]):
         self.config = config
         hf_config = config.hf_config
         self.block_size = config.kvcache_block_size
         self.enforce_eager = config.enforce_eager
+        self.world_size = config.tensor_parallel_size
+        self.rank = rank
+        self.event = event
+        dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
+        torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(hf_config.torch_dtype)
         torch.set_default_device("cuda")
@@ -29,14 +38,63 @@ class ModelRunner:
         torch.set_default_device("cpu")
         torch.set_default_dtype(default_dtype)
+        if self.world_size > 1:
+            if rank == 0:
+                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
+                dist.barrier()
+            else:
+                dist.barrier()
+                self.shm = SharedMemory(name="nanovllm")
+                self.loop()
+
+    def exit(self):
+        if self.world_size > 1:
+            self.shm.close()
+            if self.rank == 0:
+                self.shm.unlink()
+        # dist.destroy_process_group()
+
+    def loop(self):
+        while True:
+            method_name, args = self.read_shm()
+            self.call(method_name, *args)
+            if method_name == "exit":
+                break
+
+    def read_shm(self):
+        assert self.world_size > 1 and self.rank
+        self.event.wait()
+        n = int.from_bytes(self.shm.buf[0:4], "little")
+        method_name, *args = pickle.loads(self.shm.buf[4:n+4])
+        self.event.clear()
+        return method_name, args
+
+    def write_shm(self, method_name, *args):
+        assert self.world_size > 1 and not self.rank
+        data = pickle.dumps([method_name, *args])
+        n = len(data)
+        assert n + 4 <= self.shm.size
+        self.shm.buf[0:4] = n.to_bytes(4, "little")
+        self.shm.buf[4:n+4] = data
+        for event in self.event:
+            event.set()
+
+    def call(self, method_name, *args):
+        if self.world_size > 1 and self.rank == 0:
+            self.write_shm(method_name, *args)
+        method = getattr(self, method_name, None)
+        assert callable(method)
+        return method(*args)

     def allocate_kv_cache(self, gpu_memory_utilization):
         config = self.config
         hf_config = config.hf_config
         total, used, _ = get_gpu_memory()
         free = total * gpu_memory_utilization - used
-        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
+        num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
+        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(free) // block_bytes
-        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim)
+        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
             if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
@@ -66,7 +124,7 @@ class ModelRunner:
         for seq in seqs:
             seqlen = len(seq)
             input_ids.extend(seq[seq.num_cached_tokens:])
-            positions.extend(list(range(seq.num_cached_tokens, len(seq))))
+            positions.extend(list(range(seq.num_cached_tokens, seqlen)))
             seqlen_q = seqlen - seq.num_cached_tokens
             seqlen_k = seqlen
             cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
@@ -78,7 +136,7 @@ class ModelRunner:
                 if i != seq.num_blocks - 1:
                     end = start + self.block_size
                 else:
-                    end = start + len(seq.last_block())
+                    end = start + seq.last_block_num_tokens
                 slot_mapping.extend(list(range(start, end)))
         assert len(input_ids) == len(slot_mapping)
         assert len(input_ids) == cu_seqlens_q[-1]
@@ -102,7 +160,7 @@ class ModelRunner:
             input_ids.append(seq.last_token)
             positions.append(len(seq))
             context_lens.append(len(seq))
-            slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()) - 1)
+            slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -148,7 +206,7 @@ class ModelRunner:
         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
         temperatures = self.prepare_sample(seqs)
         logits = self.run_model(input_ids, positions, is_prefill)
-        token_ids = self.sampler(logits, temperatures).tolist()
+        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
         reset_context()
         return token_ids

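The shared-memory mailbox above is the heart of the change: rank 0 pickles `[method_name, *args]` behind a 4-byte little-endian length prefix and wakes each follower through its `Event`; followers wait, decode, clear the event, and dispatch via `call`. A self-contained, CPU-only sketch of the same wire format (the buffer name and size are placeholders, not the engine's values):

```python
import pickle
from multiprocessing import Event, Process
from multiprocessing.shared_memory import SharedMemory

def follower(event, shm_name):
    shm = SharedMemory(name=shm_name)                # attach to the leader's buffer
    event.wait()                                     # leader signals that a command is ready
    n = int.from_bytes(shm.buf[0:4], "little")
    method_name, *args = pickle.loads(shm.buf[4:n+4])
    event.clear()
    print("follower got:", method_name, args)
    shm.close()

if __name__ == "__main__":
    shm = SharedMemory(name="rpc_demo", create=True, size=2**16)
    event = Event()
    proc = Process(target=follower, args=(event, "rpc_demo"))
    proc.start()
    data = pickle.dumps(["run", [1, 2, 3], True])    # [method_name, *args]
    shm.buf[0:4] = len(data).to_bytes(4, "little")   # 4-byte length prefix
    shm.buf[4:len(data)+4] = data
    event.set()                                      # wake the follower
    proc.join()
    shm.close()
    shm.unlink()
```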
View File

@@ -14,8 +14,6 @@ class Scheduler:
         self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
         self.waiting: deque[Sequence] = deque()
         self.running: deque[Sequence] = deque()
-        self.num_finished = 0
-        self.num_tokens = 0

     def is_finished(self):
         return not self.waiting and not self.running
@@ -67,11 +65,9 @@ class Scheduler:
         self.waiting.appendleft(seq)

     def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
-        self.num_tokens += len(token_ids)
         for seq, token_id in zip(seqs, token_ids):
             seq.append_token(token_id)
             if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                 seq.status = SequenceStatus.FINISHED
                 self.block_manager.deallocate(seq)
                 self.running.remove(seq)
-                self.num_finished += 1

View File

@@ -19,15 +19,17 @@ class Sequence:
         self.seq_id = next(Sequence.counter)
         self.status = SequenceStatus.WAITING
         self.token_ids = copy(token_ids)
+        self.last_token = token_ids[-1]
+        self.num_tokens = len(self.token_ids)
         self.num_prompt_tokens = len(token_ids)
-        self._num_cached_tokens = 0
+        self.num_cached_tokens = 0
         self.block_table = []
         self.temperature = sampling_params.temperature
         self.max_tokens = sampling_params.max_tokens
         self.ignore_eos = sampling_params.ignore_eos

     def __len__(self):
-        return len(self.token_ids)
+        return self.num_tokens

     def __lt__(self, other):
         return self.seq_id < other.seq_id
@@ -41,7 +43,7 @@ class Sequence:
     @property
     def num_completion_tokens(self):
-        return len(self.token_ids) - self.num_prompt_tokens
+        return self.num_tokens - self.num_prompt_tokens

     @property
     def prompt_token_ids(self):
@@ -51,33 +53,29 @@ class Sequence:
     def completion_token_ids(self):
         return self.token_ids[self.num_prompt_tokens:]

-    @property
-    def num_cached_tokens(self):
-        return self._num_cached_tokens
-
-    @num_cached_tokens.setter
-    def num_cached_tokens(self, num_cached_tokens):
-        assert num_cached_tokens % self.block_size == 0
-        self._num_cached_tokens = num_cached_tokens
-
     @property
     def num_cached_blocks(self):
         return self.num_cached_tokens // self.block_size

     @property
     def num_blocks(self):
-        return (len(self.token_ids) + self.block_size - 1) // self.block_size
+        return (self.num_tokens + self.block_size - 1) // self.block_size

     @property
-    def last_token(self):
-        return self.token_ids[-1]
+    def last_block_num_tokens(self):
+        return self.num_tokens - (self.num_blocks - 1) * self.block_size

     def block(self, i):
+        assert 0 <= i < self.num_blocks
         return self.token_ids[i*self.block_size: (i+1)*self.block_size]

-    def last_block(self):
-        n = self.num_blocks
-        return self.token_ids[(n-1)*self.block_size:]
-
     def append_token(self, token_id: int):
         self.token_ids.append(token_id)
+        self.last_token = token_id
+        self.num_tokens += 1
+
+    def __getstate__(self):
+        state = vars(self).copy()
+        if self.num_completion_tokens:
+            state.pop("token_ids")
+        return state

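`__getstate__` keeps the cross-process pickles small: once a sequence has completion tokens, its full `token_ids` list is dropped from the serialized state, since decode steps on follower ranks only need scalar metadata such as `last_token` (which is why it is now a plain attribute). A simplified stand-in class to illustrate the effect (not the real `Sequence`):

```python
import pickle

class Seq:
    """Simplified stand-in for Sequence, keeping only the fields the trick needs."""
    def __init__(self, token_ids):
        self.token_ids = list(token_ids)
        self.last_token = token_ids[-1]
        self.num_tokens = len(token_ids)
        self.num_prompt_tokens = len(token_ids)

    @property
    def num_completion_tokens(self):
        return self.num_tokens - self.num_prompt_tokens

    def append_token(self, token_id):
        self.token_ids.append(token_id)
        self.last_token = token_id
        self.num_tokens += 1

    def __getstate__(self):
        state = vars(self).copy()
        if self.num_completion_tokens:   # decoding: followers don't need the full token list
            state.pop("token_ids")
        return state

seq = Seq(list(range(1000)))
before = len(pickle.dumps(seq))          # prefill: token_ids still included
seq.append_token(7)
after = len(pickle.dumps(seq))           # decode: token_ids dropped from the payload
print(before, after)
```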
View File

@@ -11,4 +11,4 @@ class SiluAndMul(nn.Module):
     @torch.compile
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, y = x.chunk(2, -1)
-        return F.silu(x) * y
+        return y.mul_(F.silu(x))

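The in-place `mul_` writes the result into `y`'s storage instead of allocating a new tensor. A quick equivalence check (a sketch, using a cloned `y` so the comparison is unaffected by the in-place write):

```python
import torch
import torch.nn.functional as F

t = torch.randn(4, 8)
x, y = t.chunk(2, -1)
reference = F.silu(x) * y              # old out-of-place form
fused = y.clone().mul_(F.silu(x))      # new in-place form, applied to a copy of y
assert torch.allclose(reference, fused)
```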
View File

@@ -14,8 +14,8 @@ class VocabParallelEmbedding(nn.Module):
         embedding_dim: int,
     ):
         super().__init__()
-        self.tp_rank = 0  # get_tensor_model_parallel_rank()
-        self.tp_size = 1  # get_tensor_model_parallel_world_size()
+        self.tp_rank = dist.get_rank()
+        self.tp_size = dist.get_world_size()
         assert num_embeddings % self.tp_size == 0
         self.num_embeddings = num_embeddings
         self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
@@ -39,7 +39,7 @@ class VocabParallelEmbedding(nn.Module):
             x = mask * (x - self.vocab_start_idx)
         y = F.embedding(x, self.weight)
         if self.tp_size > 1:
-            y = mask * y
+            y = mask.unsqueeze(1) * y
             dist.all_reduce(y)
         return y
@@ -65,8 +65,8 @@ class ParallelLMHead(VocabParallelEmbedding):
             last_indices = context.cu_seqlens_q[1:] - 1
             x = x[last_indices].contiguous()
         logits = F.linear(x, self.weight, self.bias)
-        # if self.tp_size > 1:
-        #     all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)]
-        #     dist.gather(logits, all_logits, 0)
-        #     logits = torch.cat(all_logits, -1)
-        return logits if self.tp_rank == 0 else None
+        if self.tp_size > 1:
+            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
+            dist.gather(logits, all_logits, 0)
+            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
+        return logits

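Enabling the gather means each rank computes logits only for its vocabulary shard and rank 0 concatenates them; with `torch.distributed.gather`, only the destination rank supplies a `gather_list`, which is why `all_logits` stays `None` elsewhere and non-zero ranks end up returning `None`. A CPU-only sketch of the same call pattern using the gloo backend (port and tensor shapes are placeholders):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            world_size=world_size, rank=rank)
    logits = torch.full((2, 4), float(rank))    # this rank's shard of the logits
    all_logits = [torch.empty_like(logits) for _ in range(world_size)] if rank == 0 else None
    dist.gather(logits, all_logits, dst=0)      # only the destination passes a gather_list
    logits = torch.cat(all_logits, -1) if rank == 0 else None
    print(rank, None if logits is None else tuple(logits.shape))
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```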
View File

@@ -21,8 +21,8 @@ class LinearBase(nn.Module):
         self.input_size = input_size
         self.output_size = output_size
         self.tp_dim = tp_dim
-        self.tp_rank = 0  # get_tensor_model_parallel_rank()
-        self.tp_size = 1  # get_tensor_model_parallel_world_size()
+        self.tp_rank = dist.get_rank()
+        self.tp_size = dist.get_world_size()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
@@ -65,7 +65,6 @@ class ColumnParallelLinear(LinearBase):
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
-        # If QKV or MergedColumn, use output size of each partition.
         if hasattr(self, "output_sizes"):
             self.output_partition_sizes = [
                 divide(output_size, self.tp_size)
@@ -101,8 +100,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         bias: bool = False,
     ):
         self.output_sizes = output_sizes
-        tp_size = 1  # get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias=bias)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
@@ -110,7 +107,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
         shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
-        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
+        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
         assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)
@@ -131,8 +128,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         if total_num_kv_heads is None:
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
-        # Divide the weight matrix along the last dimension.
-        tp_size = 1  # get_tensor_model_parallel_world_size()
+        tp_size = dist.get_world_size()
         self.num_heads = divide(self.total_num_heads, tp_size)
         self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
         input_size = self.hidden_size
@@ -158,7 +154,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_size = self.num_kv_heads * self.head_size
             shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
-        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
+        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
         assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)

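Weight loading now shards each checkpoint tensor instead of copying it whole: `chunk(self.tp_size, self.tp_dim)[self.tp_rank]` picks this rank's slice along the tensor-parallel dimension, so the narrowed parameter and the loaded slice have matching shapes. A tiny sketch of the idiom with made-up sizes:

```python
import torch

full_weight = torch.arange(24.).reshape(8, 3)   # pretend checkpoint tensor, tp_dim = 0
tp_size, tp_rank, tp_dim = 2, 1, 0
shard = full_weight.chunk(tp_size, dim=tp_dim)[tp_rank]
print(shard.shape)   # torch.Size([4, 3]) -- the partition kept by rank 1
```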
View File

@@ -70,4 +70,4 @@ def get_rope(
 ):
     assert rope_scaling is None
     rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
     return rotary_emb

View File

@@ -2,4 +2,4 @@ from nanovllm.engine.llm_engine import LLMEngine
 class LLM(LLMEngine):
     pass

View File

@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+import torch.distributed as dist
 from transformers import Qwen3Config

 from nanovllm.layers.activation import SiluAndMul
@@ -26,7 +27,7 @@ class Qwen3Attention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = 1  # get_tensor_model_parallel_world_size()
+        tp_size = dist.get_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size

View File

@@ -25,4 +25,4 @@ def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0
 def reset_context():
     global _CONTEXT
     _CONTEXT = Context()

View File

@@ -10,7 +10,6 @@ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
 def load_model(model: nn.Module, path: str):
-    assert os.path.isdir(path)
     packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
     for file in glob(os.path.join(path, "*.safetensors")):
         with safe_open(file, "pt", "cpu") as f:

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "nano-vllm"
-version = "0.1.0"
+version = "0.2.0"
 authors = [{ name = "Xingkai Yu" }]
 license = "MIT"
 license-files = ["LICENSE"]