From a5a4909e6a011e5443e354362a7e9c7b8b203e64 Mon Sep 17 00:00:00 2001 From: GeeeekExplorer <2651904866@qq.com> Date: Tue, 10 Jun 2025 00:23:23 +0800 Subject: [PATCH] init commit --- .gitignore | 195 ++++++++++++++++++++++ LICENSE | 21 +++ README.md | 1 + bench.py | 20 +++ example.py | 29 ++++ nanovllm/__init__.py | 2 + nanovllm/config.py | 20 +++ nanovllm/engine/block_manager.py | 118 +++++++++++++ nanovllm/engine/llm_engine.py | 66 ++++++++ nanovllm/engine/model_runner.py | 198 ++++++++++++++++++++++ nanovllm/engine/scheduler.py | 84 ++++++++++ nanovllm/engine/sequence.py | 73 ++++++++ nanovllm/layers/activation.py | 14 ++ nanovllm/layers/attention.py | 86 ++++++++++ nanovllm/layers/embed_head.py | 72 ++++++++ nanovllm/layers/layernorm.py | 51 ++++++ nanovllm/layers/linear.py | 199 ++++++++++++++++++++++ nanovllm/layers/rotary_embedding.py | 73 ++++++++ nanovllm/layers/sampler.py | 17 ++ nanovllm/llm.py | 5 + nanovllm/models/qwen3.py | 247 ++++++++++++++++++++++++++++ nanovllm/sampling_params.py | 8 + nanovllm/utils/context.py | 28 ++++ nanovllm/utils/memory.py | 14 ++ nanovllm/utils/timer.py | 31 ++++ requirements.txt | 5 + 26 files changed, 1677 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 bench.py create mode 100644 example.py create mode 100644 nanovllm/__init__.py create mode 100644 nanovllm/config.py create mode 100644 nanovllm/engine/block_manager.py create mode 100644 nanovllm/engine/llm_engine.py create mode 100644 nanovllm/engine/model_runner.py create mode 100644 nanovllm/engine/scheduler.py create mode 100644 nanovllm/engine/sequence.py create mode 100755 nanovllm/layers/activation.py create mode 100644 nanovllm/layers/attention.py create mode 100644 nanovllm/layers/embed_head.py create mode 100755 nanovllm/layers/layernorm.py create mode 100755 nanovllm/layers/linear.py create mode 100644 nanovllm/layers/rotary_embedding.py create mode 100644 nanovllm/layers/sampler.py create mode 100644 nanovllm/llm.py create mode 100755 nanovllm/models/qwen3.py create mode 100644 nanovllm/sampling_params.py create mode 100644 nanovllm/utils/context.py create mode 100644 nanovllm/utils/memory.py create mode 100644 nanovllm/utils/timer.py create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..00fafbd --- /dev/null +++ b/.gitignore @@ -0,0 +1,195 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..8eb3afc --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Xingkai Yu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..24fef66 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# Nano-VLLM \ No newline at end of file diff --git a/bench.py b/bench.py new file mode 100644 index 0000000..967df23 --- /dev/null +++ b/bench.py @@ -0,0 +1,20 @@ +import os +import time +import torch +from nanovllm import LLM, SamplingParams + + +batch_size = 256 +seq_len = 1024 +max_tokens = 512 + +path = os.path.expanduser("~/huggingface/Qwen3-0.6B/") +llm = LLM(path, enforce_eager=False) + +prompt_token_ids = torch.randint(0, 10240, (batch_size, seq_len)).tolist() +sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_tokens) + +t = time.time() +completions = llm.generate(prompt_token_ids, sampling_params) +troughput = batch_size * max_tokens / (time.time() - t) +print(f"Throughput: {troughput: .2f}") diff --git a/example.py b/example.py new file mode 100644 index 0000000..7b7cf06 --- /dev/null +++ b/example.py @@ -0,0 +1,29 @@ +import os +from nanovllm import LLM, SamplingParams +from transformers import AutoTokenizer + + +path = os.path.expanduser("~/huggingface/Qwen3-0.6B/") +tokenizer = AutoTokenizer.from_pretrained(path) +llm = LLM(path, enforce_eager=True) + +sampling_params = SamplingParams(temperature=0.6, max_tokens=256) +prompts = [ + "自我介绍一下吧!", + "列出100内所有素数", +] +prompts = [ + tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=True + ) + for prompt in prompts +] +completions = llm.generate(prompts, sampling_params) + +for p, c in zip(prompts, completions): + print("\n\n") + print(f"Prompt: {p}") + print(f"Completion: {c}") diff --git a/nanovllm/__init__.py b/nanovllm/__init__.py new file mode 100644 index 0000000..e84e8cc --- /dev/null +++ b/nanovllm/__init__.py @@ -0,0 +1,2 @@ +from nanovllm.llm import LLM +from nanovllm.sampling_params import SamplingParams \ No newline at end of file diff --git a/nanovllm/config.py b/nanovllm/config.py new file mode 100644 index 0000000..e669d50 --- /dev/null +++ b/nanovllm/config.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from transformers import AutoConfig + + +@dataclass +class Config: + model: str = '' + max_num_batched_tokens: int = 16384 + max_num_seqs: int = 512 + max_model_len: int = 4096 + gpu_memory_utilization: float = 0.95 + enforce_eager: bool = False + hf_config: AutoConfig | None = None + eos: int = -1 + kvcache_block_size: int = 256 + num_kvcache_blocks: int = -1 + + def __post_init__(self): + assert self.model + assert self.kvcache_block_size == 256 \ No newline at end of file diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py new file mode 100644 index 0000000..4739c5b --- /dev/null +++ b/nanovllm/engine/block_manager.py @@ -0,0 +1,118 @@ +from collections import deque +import xxhash +import numpy as np + +from nanovllm.engine.sequence import Sequence + + +def compute_hash(token_ids: list[int], prefix: int = -1): + h = xxhash.xxh64() + if prefix != -1: + h.update(prefix.to_bytes(8)) + h.update(np.array(token_ids).tobytes()) + return h.intdigest() + + +class Block: + + def __init__(self, block_id): + self.block_id = block_id + self.ref_count = 0 + self.hash = -1 + self.token_ids = [] + + def update(self, hash: int, token_ids: list[int]): + assert hash != -1 + assert len(token_ids) == 256 + self.hash = hash + self.token_ids = token_ids + + def reset(self): + self.ref_count = 1 + self.hash = -1 + self.token_ids = [] + + def __repr__(self): + return f"{(self.block_id, self.ref_count, self.hash)}" + + +class BlockManager: + + def __init__(self, num_blocks: int, block_size: int = 256): + assert block_size == 256 + self.block_size = block_size + self.blocks: list[Block] = [Block(i) for i in range(num_blocks)] + self.hash_to_block_id: dict[int, int] = dict() + self.free_block_ids: deque[int] = deque(range(num_blocks)) + self.used_block_ids: set[int] = set() + + def _allocate_block(self, block_id: int): + block = self.blocks[block_id] + assert block.ref_count == 0 + block.reset() + self.free_block_ids.remove(block_id) + self.used_block_ids.add(block_id) + return self.blocks[block_id] + + def _deallocate_block(self, block_id: int): + assert self.blocks[block_id].ref_count == 0 + self.used_block_ids.remove(block_id) + self.free_block_ids.append(block_id) + + def can_allocate(self, seq: Sequence): + return seq.num_blocks <= len(self.free_block_ids) + + def allocate(self, seq: Sequence): + assert not seq.block_table + h = -1 + cache_miss = False + for i in range(seq.num_blocks): + token_ids = seq.block(i, self.block_size) + h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1 + block_id = self.hash_to_block_id.get(h, -1) + if block_id == -1 or self.blocks[block_id].token_ids != token_ids: + cache_miss = True + if cache_miss: + block_id = self.free_block_ids[0] + block = self._allocate_block(block_id) + else: + seq.num_cached_tokens += self.block_size + if block_id in self.used_block_ids: + block = self.blocks[block_id] + block.ref_count += 1 + else: + block = self._allocate_block(block_id) + if h != -1: + block.update(h, token_ids) + self.hash_to_block_id[h] = block_id + seq.block_table.append(block_id) + + def deallocate(self, seq: Sequence): + for block_id in seq.block_table: + block = self.blocks[block_id] + block.ref_count -= 1 + if block.ref_count == 0: + self._deallocate_block(block_id) + seq.num_cached_tokens = 0 + seq.block_table.clear() + + def can_append(self): + return len(self.free_block_ids) >= 1 + + def may_append(self, seq: Sequence): + block_table = seq.block_table + last_block = self.blocks[block_table[-1]] + if len(seq) % self.block_size == 1: + assert last_block.hash != -1 + block_id = self.free_block_ids[0] + self._allocate_block(block_id) + block_table.append(block_id) + elif len(seq) % self.block_size == 0: + assert last_block.hash == -1 + token_ids = seq.last_block(self.block_size) + prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 + h = compute_hash(token_ids, prefix) + last_block.update(h, token_ids) + self.hash_to_block_id[h] = last_block.block_id + else: + assert last_block.hash == -1 diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py new file mode 100644 index 0000000..e22ce7f --- /dev/null +++ b/nanovllm/engine/llm_engine.py @@ -0,0 +1,66 @@ +from collections import defaultdict +from tqdm.auto import tqdm +from transformers import AutoConfig, AutoTokenizer + +from nanovllm.config import Config +from nanovllm.sampling_params import SamplingParams +from nanovllm.engine.sequence import Sequence +from nanovllm.engine.scheduler import Scheduler +from nanovllm.engine.model_runner import ModelRunner + + +class LLMEngine: + + def __init__(self, model, **kwargs): + config = Config(model) + for k, v in kwargs.items(): + if hasattr(config, k): + setattr(config, k, v) + config.hf_config = AutoConfig.from_pretrained(config.model) + config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings) + self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True) + config.eos = self.tokenizer.eos_token_id + self.model_runner = ModelRunner(config) + self.scheduler = Scheduler(config) + + def add_request(self, prompt: str | list[int], sampling_params: SamplingParams): + if isinstance(prompt, str): + prompt = self.tokenizer.encode(prompt) + seq = Sequence(prompt, sampling_params) + self.scheduler.add(seq) + + def step(self): + seqs, is_prefill = self.scheduler.schedule() + token_ids = self.model_runner.run(seqs, is_prefill) + finished = self.scheduler.postprocess(seqs, token_ids) + return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)] + + def is_finished(self): + return self.scheduler.is_finished() + + def generate( + self, + prompts: list[str] | list[list[int]], + sampling_params: SamplingParams | list[SamplingParams], + use_tqdm: bool = True, + ) -> list[str]: + if use_tqdm: + pbar = tqdm(total=len(prompts), + desc="Processed prompts", + ) + if not isinstance(SamplingParams, list): + sampling_params = [sampling_params] * len(prompts) + for prompt, sp in zip(prompts, sampling_params): + self.add_request(prompt, sp) + outputs = defaultdict(list) + while not self.is_finished(): + output = self.step() + for seq_id, token_id, finish in output: + outputs[seq_id].append(token_id) + if use_tqdm and finish: + pbar.update(1) + outputs = [outputs[seq_id] for seq_id in sorted(outputs)] + outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs] + if use_tqdm: + pbar.close() + return outputs diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py new file mode 100644 index 0000000..4724674 --- /dev/null +++ b/nanovllm/engine/model_runner.py @@ -0,0 +1,198 @@ +import torch + +from nanovllm.config import Config +from nanovllm.engine.sequence import Sequence +from nanovllm.utils.context import set_context, get_context, reset_context +from nanovllm.utils.memory import get_gpu_memory +from nanovllm.models.qwen3 import Qwen3ForCausalLM +from nanovllm.layers.sampler import Sampler + + +class ModelRunner: + + def __init__(self, config: Config): + self.config = config + hf_config = config.hf_config + self.block_size = config.kvcache_block_size + self.enforce_eager = config.enforce_eager + + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(hf_config.torch_dtype) + torch.set_default_device("cuda") + self.model = Qwen3ForCausalLM(hf_config) + self.model.load_weights(config.model) + self.sampler = Sampler() + self.allocate_kv_cache(config.gpu_memory_utilization) + if not self.enforce_eager: + self.capture_model() + torch.set_default_device("cpu") + torch.set_default_dtype(default_dtype) + + def allocate_kv_cache(self, gpu_memory_utilization): + config = self.config + hf_config = config.hf_config + total, used, _ = get_gpu_memory() + free = total * gpu_memory_utilization - used + block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize + config.num_kvcache_blocks = int(free * 1e6) // block_bytes + self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim) + layer_id = 0 + for module in self.model.modules(): + if hasattr(module, "k_cache") and hasattr(module, "v_cache"): + module.k_cache = self.kv_cache[0, layer_id] + module.v_cache = self.kv_cache[1, layer_id] + layer_id += 1 + + def preare_block_tables(self, seqs: list[Sequence]): + max_len = max(len(seq.block_table) for seq in seqs) + block_tables = [ + seq.block_table + [-1] * (max_len - len(seq.block_table)) + for seq in seqs + ] + block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + return block_tables + + def prepare_prefill(self, seqs: list[Sequence]): + input_ids = [] + positions = [] + cu_seqlens_q = [0] + cu_seqlens_k = [0] + max_seqlen_q = 0 + max_seqlen_k = 0 + slot_mapping = [] + context_lens = None + block_tables = None + for seq in seqs: + seqlen = len(seq) + input_ids.extend(seq[seq.num_cached_tokens:]) + positions.extend(list(range(seq.num_cached_tokens, len(seq)))) + seqlen_q = seqlen - seq.num_cached_tokens + seqlen_k = seqlen + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) + cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) + max_seqlen_q = max(seqlen_q, max_seqlen_q) + max_seqlen_k = max(seqlen_k, max_seqlen_k) + for i in range(seq.num_cached_blocks, seq.num_blocks): + start = seq.block_table[i] * self.block_size + if i != seq.num_blocks - 1: + end = start + self.block_size + else: + end = start + len(seq.last_block()) + slot_mapping.extend(list(range(start, end))) + assert len(input_ids) == len(slot_mapping) + assert len(input_ids) == cu_seqlens_q[-1] + if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache + context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + block_tables = self.preare_block_tables(seqs) + input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) + return input_ids, positions + + def prepare_decode(self, seqs: list[Sequence]): + input_ids = [] + positions = [] + slot_mapping = [] + context_lens = [] + for seq in seqs: + input_ids.append(seq.last_token) + positions.append(len(seq)) + context_lens.append(len(seq)) + slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block())) + input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + block_tables = self.preare_block_tables(seqs) + set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables) + return input_ids, positions + + def prepare_sample(self, seqs: list[Sequence]): + temperatures = [] + for seq in seqs: + temperatures.append(seq.temperature) + temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True) + return temperatures + + @torch.inference_mode() + def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill): + if is_prefill or self.enforce_eager or input_ids.size(0) > 256: + return self.model.compute_logits(self.model(input_ids, positions)) + else: + bs = input_ids.size(0) + context = get_context() + self.reset_graph_vars() + graph = self.graphs[next(x for x in self.graph_bs if x >= bs)] + graph_vars = self.graph_vars + graph_vars["input_ids"][:bs] = input_ids + graph_vars["positions"][:bs] = positions + graph_vars["slot_mapping"][:bs] = context.slot_mapping + graph_vars["context_lens"][:bs] = context.context_lens + graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables + graph.replay() + return self.model.compute_logits(graph_vars["outputs"][:bs]) + + def reset_graph_vars(self): + graph_vars = self.graph_vars + graph_vars["input_ids"].zero_() + graph_vars["positions"].zero_() + graph_vars["slot_mapping"].zero_() + graph_vars["context_lens"].zero_() + graph_vars["block_tables"].zero_() + + def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: + input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + temperatures = self.prepare_sample(seqs) + logits = self.run_model(input_ids, positions, is_prefill) + token_ids = self.sampler(logits, temperatures).tolist() + reset_context() + return token_ids + + @torch.inference_mode() + def capture_model(self): + get_rng_state = torch.cuda.get_rng_state + set_rng_state = torch.cuda.set_rng_state + rng_state = torch.cuda.get_rng_state() + torch.cuda.get_rng_state = lambda: rng_state + torch.cuda.set_rng_state = lambda _: None + + config = self.config + hf_config = config.hf_config + max_bs = min(self.config.max_num_seqs, 256) + max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + input_ids = torch.zeros(max_bs, dtype=torch.int64) + positions = torch.zeros(max_bs, dtype=torch.int64) + slot_mapping = torch.zeros(max_bs, dtype=torch.int32) + context_lens = torch.zeros(max_bs, dtype=torch.int32) + block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32) + outputs = torch.zeros(max_bs, hf_config.hidden_size) + self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32)) + self.graphs = {} + self.graph_pool = None + + for bs in reversed(self.graph_bs): + graph = torch.cuda.CUDAGraph() + set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs]) + outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup + with torch.cuda.graph(graph, self.graph_pool): + outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture + if self.graph_pool is None: + self.graph_pool = graph.pool() + self.graphs[bs] = graph + torch.cuda.synchronize() + reset_context() + + self.graph_vars = dict( + input_ids=input_ids, + positions=positions, + slot_mapping=slot_mapping, + context_lens=context_lens, + block_tables=block_tables, + outputs=outputs, + ) + + torch.cuda.get_rng_state = get_rng_state + torch.cuda.set_rng_state = set_rng_state diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py new file mode 100644 index 0000000..1e5e684 --- /dev/null +++ b/nanovllm/engine/scheduler.py @@ -0,0 +1,84 @@ +from collections import deque + +from nanovllm.config import Config +from nanovllm.engine.sequence import Sequence, SequenceStatus +from nanovllm.engine.block_manager import BlockManager + + +class Scheduler: + + def __init__(self, config: Config): + self.max_num_seqs = config.max_num_seqs + self.max_num_batched_tokens = config.max_num_batched_tokens + self.eos = config.eos + self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size) + self.waiting: deque[Sequence] = deque() + self.running: deque[Sequence] = deque() + self.num_finished = 0 + self.num_tokens = 0 + + def is_finished(self): + return not self.waiting and not self.running + + def add(self, seq: Sequence): + self.waiting.append(seq) + + def schedule(self) -> tuple[list[Sequence], SequenceStatus]: + # prefill + scheduled_seqs = [] + num_seqs = 0 + num_batched_tokens = 0 + while self.waiting and num_seqs < self.max_num_seqs: + seq = self.waiting[0] + if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq): + break + num_seqs += 1 + self.block_manager.allocate(seq) + num_batched_tokens += len(seq) - seq.num_cached_tokens + seq.status = SequenceStatus.RUNNING + self.waiting.popleft() + self.running.append(seq) + scheduled_seqs.append(seq) + if scheduled_seqs: + return scheduled_seqs, True + + # decode + # self.running = deque(sorted(self.running)) + while self.running and num_seqs < self.max_num_seqs: + seq = self.running.popleft() + while not self.block_manager.can_append(): + if self.running: + self.preempt(self.running.pop()) + else: + self.preempt(seq) + break + else: + num_seqs += 1 + self.block_manager.may_append(seq) + scheduled_seqs.append(seq) + running = deque(scheduled_seqs) + running.extend(self.running) + self.running = running + if scheduled_seqs: + return scheduled_seqs, False + + def preempt(self, seq: Sequence): + seq.status = SequenceStatus.WAITING + self.block_manager.deallocate(seq) + self.waiting.appendleft(seq) + return True + + def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]: + self.num_tokens += len(token_ids) + finished = [] + for seq, token_id in zip(seqs, token_ids): + seq.append_token(token_id) + if token_id == self.eos or seq.num_completion_tokens == seq.max_tokens: + seq.status = SequenceStatus.FINISHED + self.block_manager.deallocate(seq) + self.running.remove(seq) + self.num_finished += 1 + finished.append(True) + else: + finished.append(False) + return finished diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py new file mode 100644 index 0000000..5d4f792 --- /dev/null +++ b/nanovllm/engine/sequence.py @@ -0,0 +1,73 @@ +from copy import copy +from enum import Enum, auto +from itertools import count + +from nanovllm.sampling_params import SamplingParams + + +class SequenceStatus(Enum): + WAITING = auto() + RUNNING = auto() + FINISHED = auto() + + +class Sequence: + counter = count() + + def __init__(self, token_ids: list[int], sampling_params: SamplingParams): + self.seq_id = next(Sequence.counter) + self.status = SequenceStatus.WAITING + self.token_ids = copy(token_ids) + self.num_prompt_tokens = len(token_ids) + self._num_cached_tokens = 0 + self.block_table = [] + self.temperature = sampling_params.temperature + self.max_tokens = sampling_params.max_tokens + self.ignore_eos = sampling_params.ignore_eos + + def __len__(self): + return len(self.token_ids) + + def __lt__(self, other): + return self.seq_id < other.seq_id + + def __getitem__(self, key): + return self.token_ids[key] + + @property + def num_completion_tokens(self): + return len(self.token_ids) - self.num_prompt_tokens + + @property + def num_cached_tokens(self): + return self._num_cached_tokens + + @num_cached_tokens.setter + def num_cached_tokens(self, num_cached_tokens): + assert num_cached_tokens % 256 == 0 + self._num_cached_tokens = num_cached_tokens + + @property + def num_cached_blocks(self): + return self.num_cached_tokens // 256 + + @property + def num_blocks(self): + return (len(self.token_ids) + 255) // 256 + + @property + def last_token(self): + return self.token_ids[-1] + + def block(self, i, block_size=256): + return self.token_ids[i*block_size: (i+1)*block_size] + + def last_block(self, block_size=256): + n = self.num_blocks + t = len(self) + block_size - self.num_blocks * block_size + x = self.token_ids[(n-1)*block_size:] + assert len(x) == t + return x + + def append_token(self, token_id: int): + self.token_ids.append(token_id) diff --git a/nanovllm/layers/activation.py b/nanovllm/layers/activation.py new file mode 100755 index 0000000..041ee20 --- /dev/null +++ b/nanovllm/layers/activation.py @@ -0,0 +1,14 @@ +import torch +from torch import nn +import torch.nn.functional as F + + +class SiluAndMul(nn.Module): + + def __init__(self): + super().__init__() + + @torch.compile + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, y = x.chunk(2, -1) + return F.silu(x) * y diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py new file mode 100644 index 0000000..6e58865 --- /dev/null +++ b/nanovllm/layers/attention.py @@ -0,0 +1,86 @@ +import torch +from torch import nn +import triton +import triton.language as tl + +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from nanovllm.utils.context import get_context + + +@triton.jit +def store_kvcache_kernel( + key_ptr, + key_stride, + value_ptr, + value_stride, + k_cache_ptr, + v_cache_ptr, + slot_mapping_ptr, + D: tl.constexpr, +): + idx = tl.program_id(0) + key_offsets = idx * key_stride + tl.arange(0, D) + value_offsets = idx * value_stride + tl.arange(0, D) + key = tl.load(key_ptr + key_offsets) + value = tl.load(value_ptr + value_offsets) + slot = tl.load(slot_mapping_ptr + idx) + cache_offsets = slot * D + tl.arange(0, D) + tl.store(k_cache_ptr + cache_offsets, key) + tl.store(v_cache_ptr + cache_offsets, value) + + +def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor): + N, num_heads, head_dim = key.shape + D = num_heads * head_dim + assert key.stride(-1) == 1 and value.stride(-1) == 1 + assert key.stride(1) == head_dim and value.stride(1) == head_dim + assert k_cache.stride(1) == D and v_cache.stride(1) == D + assert slot_mapping.numel() == N + store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D) + + +class Attention(nn.Module): + + def __init__( + self, + num_heads, + head_dim, + scale, + num_kv_heads, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.k_cache = self.v_cache = torch.tensor([]) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + o: torch.Tensor + q = q.view(-1, self.num_heads, self.head_dim) + k = k.view(-1, self.num_kv_heads, self.head_dim) + v = v.view(-1, self.num_kv_heads, self.head_dim) + context = get_context() + k_cache = self.k_cache + v_cache = self.v_cache + if context.is_prefill: + store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) + if context.block_tables is None: # normal prefill + cu_seqlens_k = context.cu_seqlens_k + seqused_k = None + else: # prefix cache + cu_seqlens_k = None + seqused_k = context.context_lens + k, v = k_cache, v_cache + o = flash_attn_varlen_func(q, k, v, + max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, + max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k, + seqused_k=seqused_k, softmax_scale=self.scale, + causal=True, block_table=context.block_tables) + else: # decode + o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1), + cache_seqlens=context.context_lens, block_table=context.block_tables, + softmax_scale=self.scale, causal=True) + o = o.view(-1, self.num_heads * self.head_dim) + return o diff --git a/nanovllm/layers/embed_head.py b/nanovllm/layers/embed_head.py new file mode 100644 index 0000000..72b6e01 --- /dev/null +++ b/nanovllm/layers/embed_head.py @@ -0,0 +1,72 @@ +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist + +from nanovllm.utils.context import get_context + + +class VocabParallelEmbedding(nn.Module): + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + ): + super().__init__() + self.tp_rank = 0 # get_tensor_model_parallel_rank() + self.tp_size = 1 # get_tensor_model_parallel_world_size() + assert num_embeddings % self.tp_size == 0 + self.num_embeddings = num_embeddings + self.num_embeddings_per_partition = self.num_embeddings // self.tp_size + self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank + self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition + self.embedding_dim = embedding_dim + self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim)) + self.weight.weight_loader = self.weight_loader + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + param_data = param.data + shard_size = param_data.size(0) + start_idx = self.tp_rank * shard_size + loaded_weight = loaded_weight.narrow(0, start_idx, shard_size) + assert param_data.size() == loaded_weight.size() + param_data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor): + if self.tp_size > 1: + mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx) + x = mask * (x - self.vocab_start_idx) + y = F.embedding(x, self.weight) + if self.tp_size > 1: + y = mask * y + dist.all_reduce(y) + return y + + +class ParallelLMHead(VocabParallelEmbedding): + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + ): + super().__init__(num_embeddings, embedding_dim) + if bias: + self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition)) + self.bias.weight_loader = self.weight_loader + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor): + context = get_context() + if context.is_prefill: + last_indices = context.cu_seqlens_q[1:] - 1 + x = x[last_indices].contiguous() + logits = F.linear(x, self.weight, self.bias) + # if self.tp_size > 1: + # all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] + # dist.gather(logits, all_logits, 0) + # logits = torch.cat(all_logits, -1) + return logits if self.tp_rank == 0 else None \ No newline at end of file diff --git a/nanovllm/layers/layernorm.py b/nanovllm/layers/layernorm.py new file mode 100755 index 0000000..011638e --- /dev/null +++ b/nanovllm/layers/layernorm.py @@ -0,0 +1,51 @@ +import torch +from torch import nn + + +class RMSNorm(nn.Module): + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + @torch.compile + def rms_forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + var = x.pow(2).mean(dim=-1, keepdim=True) + x.mul_(torch.rsqrt(var + self.eps)) + x = x.to(orig_dtype).mul_(self.weight) + return x + + @torch.compile + def add_rms_forward( + self, + x: torch.Tensor, + residual: torch.Tensor, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + orig_dtype = x.dtype + x = x.to(torch.float32).add_(residual.to(torch.float32)) + residual = x.to(orig_dtype) + var = x.pow(2).mean(dim=-1, keepdim=True) + x.mul_(torch.rsqrt(var + self.eps)) + x = x.to(orig_dtype).mul_(self.weight) + return x, residual + + def forward( + self, + x: torch.Tensor, + residual: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if residual is None: + return self.rms_forward(x) + else: + return self.add_rms_forward(x, residual) diff --git a/nanovllm/layers/linear.py b/nanovllm/layers/linear.py new file mode 100755 index 0000000..f39fa59 --- /dev/null +++ b/nanovllm/layers/linear.py @@ -0,0 +1,199 @@ +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist + + +def divide(numerator, denominator): + assert numerator % denominator == 0 + return numerator // denominator + + +class LinearBase(nn.Module): + + def __init__( + self, + input_size: int, + output_size: int, + tp_dim: int | None = None, + ): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.tp_dim = tp_dim + self.tp_rank = 0 # get_tensor_model_parallel_rank() + self.tp_size = 1 # get_tensor_model_parallel_world_size() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = False, + ): + super().__init__(input_size, output_size) + self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size)) + self.weight.weight_loader = self.weight_loader + if bias: + self.bias = nn.Parameter(torch.empty(self.output_size)) + self.bias.weight_loader = self.weight_loader + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self.weight, self.bias) + + +class ColumnParallelLinear(LinearBase): + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = False, + ): + super().__init__(input_size, output_size, 0) + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) + for output_size in self.output_sizes + ] + + self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size)) + self.weight.weight_loader = self.weight_loader + if bias: + self.bias = nn.Parameter(torch.empty(self.output_size_per_partition)) + self.bias.weight_loader = self.weight_loader + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + param_data = param.data + shard_size = param_data.size(self.tp_dim) + start_idx = self.tp_rank * shard_size + loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) + assert param_data.size() == loaded_weight.size() + param_data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self.weight, self.bias) + + +class MergedColumnParallelLinear(ColumnParallelLinear): + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = False, + ): + self.output_sizes = output_sizes + tp_size = 1 # get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__(input_size, sum(output_sizes), bias=bias) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None): + param_data = param.data + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size + param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) + # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size) + assert param_data.size() == loaded_weight.size() + param_data.copy_(loaded_weight) + + +class QKVParallelLinear(ColumnParallelLinear): + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: int | None = None, + bias: bool = False, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + tp_size = 1 # get_tensor_model_parallel_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + input_size = self.hidden_size + output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__(input_size, output_size, bias) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None): + param_data = param.data + assert loaded_shard_id in ["q", "k", "v"] + if loaded_shard_id == "q": + shard_size = self.num_heads * self.head_size + shard_offset = 0 + elif loaded_shard_id == "k": + shard_size = self.num_kv_heads * self.head_size + shard_offset = self.num_heads * self.head_size + else: + shard_size = self.num_kv_heads * self.head_size + shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size + param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) + # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size) + assert param_data.size() == loaded_weight.size() + param_data.copy_(loaded_weight) + + +class RowParallelLinear(LinearBase): + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = False, + ): + super().__init__(input_size, output_size, 1) + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition)) + self.weight.weight_loader = self.weight_loader + if bias: + self.bias = nn.Parameter(torch.empty(self.output_size)) + self.bias.weight_loader = self.weight_loader + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + param_data = param.data + shard_size = param_data.size(self.tp_dim) + start_idx = self.tp_rank * shard_size + loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) + assert param_data.size() == loaded_weight.size() + param_data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None) + if self.tp_size > 1: + dist.all_reduce(y) + return y diff --git a/nanovllm/layers/rotary_embedding.py b/nanovllm/layers/rotary_embedding.py new file mode 100644 index 0000000..26ca5f9 --- /dev/null +++ b/nanovllm/layers/rotary_embedding.py @@ -0,0 +1,73 @@ +from functools import lru_cache + +import torch +from torch import nn + + +def apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> torch.Tensor: + cos = cos.unsqueeze(-2) + sin = sin.unsqueeze(-2) + x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1) + y1 = x1 * cos - x2 * sin + y2 = x2 * cos + x1 * sin + return torch.cat((y1, y2), dim=-1).to(x.dtype) + + +class RotaryEmbedding(nn.Module): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + assert rotary_dim == head_size + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + self.register_buffer("cos_sin_cache", cache, persistent=False) + + @torch.compile + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query = apply_rotary_emb(query, cos, sin).view(query_shape) + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key = apply_rotary_emb(key, cos, sin).view(key_shape) + return query, key + + +@lru_cache(1) +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: float, + rope_scaling: dict | None = None, +): + assert rope_scaling is None + rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base) + return rotary_emb \ No newline at end of file diff --git a/nanovllm/layers/sampler.py b/nanovllm/layers/sampler.py new file mode 100644 index 0000000..12d8888 --- /dev/null +++ b/nanovllm/layers/sampler.py @@ -0,0 +1,17 @@ +import torch +from torch import nn + + +class Sampler(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None): + logits = logits.to(torch.float) + if temperatures is not None: + logits.div_(temperatures.unsqueeze(dim=1)) + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1) + return sampled_tokens diff --git a/nanovllm/llm.py b/nanovllm/llm.py new file mode 100644 index 0000000..1c3efe9 --- /dev/null +++ b/nanovllm/llm.py @@ -0,0 +1,5 @@ +from nanovllm.engine.llm_engine import LLMEngine + + +class LLM(LLMEngine): + pass \ No newline at end of file diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py new file mode 100755 index 0000000..5fb463f --- /dev/null +++ b/nanovllm/models/qwen3.py @@ -0,0 +1,247 @@ +import torch +from torch import nn +from transformers import Qwen3Config + +from nanovllm.layers.activation import SiluAndMul +from nanovllm.layers.attention import Attention +from nanovllm.layers.layernorm import RMSNorm +from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear +from nanovllm.layers.rotary_embedding import get_rope +from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead + + +class Qwen3Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: int | None = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + rope_theta: float = 10000, + rope_scaling: tuple | None = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = 1 # get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q_by_head = q.view(-1, self.num_heads, self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.view(-1, self.num_kv_heads, self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + o = self.attn(q, k, v) + output = self.o_proj(o) + return output + + +class Qwen3MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + ) + assert hidden_act == "silu" + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x = self.down_proj(x) + return x + + +class Qwen3DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen3Config, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen3Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + max_position=config.max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + rope_theta=getattr(config, "rope_theta", 1000000), + rope_scaling=getattr(config, "rope_scaling", None), + ) + self.mlp = Qwen3MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen3Model(nn.Module): + + def __init__( + self, + config: Qwen3Config, + ): + super().__init__() + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3ForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__( + self, + config: Qwen3Config + ): + super().__init__() + self.model = Qwen3Model(config) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + if config.tie_word_embeddings: + self.lm_head.weight.data = self.model.embed_tokens.weight.data + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + logits = self.lm_head(hidden_states) + return logits + + def load_weights(self, path: str): + import os + from safetensors import safe_open + default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight) + with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f: + for n, p in self.named_parameters(): + for x in self.packed_modules_mapping: + if x in n: + weight_loader = getattr(p, "weight_loader", default_weight_loader) + for i, y in enumerate(self.packed_modules_mapping[x]): + weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i) + break + else: + weight_loader = getattr(p, "weight_loader", default_weight_loader) + weight_loader(p, f.get_tensor(n)) diff --git a/nanovllm/sampling_params.py b/nanovllm/sampling_params.py new file mode 100644 index 0000000..67d60bb --- /dev/null +++ b/nanovllm/sampling_params.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass + + +@dataclass +class SamplingParams: + temperature: float = 1.0 + max_tokens: int = 64 + ignore_eos: bool = False diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py new file mode 100644 index 0000000..d4cfd3d --- /dev/null +++ b/nanovllm/utils/context.py @@ -0,0 +1,28 @@ +from contextlib import contextmanager +from dataclasses import dataclass +import torch + + +@dataclass +class Context: + is_prefill: bool = False + cu_seqlens_q: torch.Tensor | None = None + cu_seqlens_k: torch.Tensor | None = None + max_seqlen_q: int = 0 + max_seqlen_k: int = 0 + slot_mapping: torch.Tensor | None = None + context_lens: torch.Tensor | None = None + block_tables: torch.Tensor | None = None + +_CONTEXT = Context() + +def get_context(): + return _CONTEXT + +def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, ): + global _CONTEXT + _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) + +def reset_context(): + global _CONTEXT + _CONTEXT = Context() \ No newline at end of file diff --git a/nanovllm/utils/memory.py b/nanovllm/utils/memory.py new file mode 100644 index 0000000..4d87b31 --- /dev/null +++ b/nanovllm/utils/memory.py @@ -0,0 +1,14 @@ +import os +import subprocess +import torch + + +def get_gpu_memory(device_id: int = 0): + torch.cuda.synchronize() + result = subprocess.check_output( + ['nvidia-smi', '-i', str(device_id), '--query-gpu=memory.total,memory.used,memory.free', '--format=csv,nounits,noheader'], + encoding='utf-8' + ) + total_memory, used_memory, free_memory = [int(x) for x in result.strip().split(', ')] + return total_memory, used_memory, free_memory + \ No newline at end of file diff --git a/nanovllm/utils/timer.py b/nanovllm/utils/timer.py new file mode 100644 index 0000000..0e8b1bc --- /dev/null +++ b/nanovllm/utils/timer.py @@ -0,0 +1,31 @@ +from contextlib import contextmanager +from collections import defaultdict +import torch + + +class CUDATimer: + + def __init__(self): + self.events = defaultdict(list) + + @contextmanager + def record(self, name, enabled=True): + if not enabled: + yield + else: + start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + self.events[name].append((start, end)) + start.record() + yield + end.record() + + def log(self): + torch.cuda.synchronize() + ret = [] + for name, events in self.events.items(): + total = 0 + count = len(self.events) + for start, end in events: + total += start.elapsed_time(end) + ret.append(f"{name} {total:.2f}ms/{count}times") + return ", ".join(ret) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b1f4fd9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +torch +triton +transformers +cmake +ninja \ No newline at end of file