init commit
This commit is contained in:
195
.gitignore
vendored
Normal file
195
.gitignore
vendored
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
#uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
#poetry.toml
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||||
|
.pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
# Abstra
|
||||||
|
# Abstra is an AI-powered process automation framework.
|
||||||
|
# Ignore directories containing user credentials, local state, and settings.
|
||||||
|
# Learn more at https://abstra.io/docs
|
||||||
|
.abstra/
|
||||||
|
|
||||||
|
# Visual Studio Code
|
||||||
|
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||||
|
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||||
|
# you could uncomment the following to ignore the entire vscode folder
|
||||||
|
# .vscode/
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
|
||||||
|
# Cursor
|
||||||
|
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
||||||
|
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
||||||
|
# refer to https://docs.cursor.com/context/ignore-files
|
||||||
|
.cursorignore
|
||||||
|
.cursorindexingignore
|
||||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2025 Xingkai Yu
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
20
bench.py
Normal file
20
bench.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from nanovllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
batch_size = 256
|
||||||
|
seq_len = 1024
|
||||||
|
max_tokens = 512
|
||||||
|
|
||||||
|
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
|
||||||
|
llm = LLM(path, enforce_eager=False)
|
||||||
|
|
||||||
|
prompt_token_ids = torch.randint(0, 10240, (batch_size, seq_len)).tolist()
|
||||||
|
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_tokens)
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
completions = llm.generate(prompt_token_ids, sampling_params)
|
||||||
|
troughput = batch_size * max_tokens / (time.time() - t)
|
||||||
|
print(f"Throughput: {troughput: .2f}")
|
||||||
29
example.py
Normal file
29
example.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import os
|
||||||
|
from nanovllm import LLM, SamplingParams
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(path)
|
||||||
|
llm = LLM(path, enforce_eager=True)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
|
||||||
|
prompts = [
|
||||||
|
"自我介绍一下吧!",
|
||||||
|
"列出100内所有素数",
|
||||||
|
]
|
||||||
|
prompts = [
|
||||||
|
tokenizer.apply_chat_template(
|
||||||
|
[{"role": "user", "content": prompt}],
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=True
|
||||||
|
)
|
||||||
|
for prompt in prompts
|
||||||
|
]
|
||||||
|
completions = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
for p, c in zip(prompts, completions):
|
||||||
|
print("\n\n")
|
||||||
|
print(f"Prompt: {p}")
|
||||||
|
print(f"Completion: {c}")
|
||||||
2
nanovllm/__init__.py
Normal file
2
nanovllm/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from nanovllm.llm import LLM
|
||||||
|
from nanovllm.sampling_params import SamplingParams
|
||||||
20
nanovllm/config.py
Normal file
20
nanovllm/config.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
model: str = ''
|
||||||
|
max_num_batched_tokens: int = 16384
|
||||||
|
max_num_seqs: int = 512
|
||||||
|
max_model_len: int = 4096
|
||||||
|
gpu_memory_utilization: float = 0.95
|
||||||
|
enforce_eager: bool = False
|
||||||
|
hf_config: AutoConfig | None = None
|
||||||
|
eos: int = -1
|
||||||
|
kvcache_block_size: int = 256
|
||||||
|
num_kvcache_blocks: int = -1
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
assert self.model
|
||||||
|
assert self.kvcache_block_size == 256
|
||||||
118
nanovllm/engine/block_manager.py
Normal file
118
nanovllm/engine/block_manager.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
from collections import deque
|
||||||
|
import xxhash
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from nanovllm.engine.sequence import Sequence
|
||||||
|
|
||||||
|
|
||||||
|
def compute_hash(token_ids: list[int], prefix: int = -1):
|
||||||
|
h = xxhash.xxh64()
|
||||||
|
if prefix != -1:
|
||||||
|
h.update(prefix.to_bytes(8))
|
||||||
|
h.update(np.array(token_ids).tobytes())
|
||||||
|
return h.intdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class Block:
|
||||||
|
|
||||||
|
def __init__(self, block_id):
|
||||||
|
self.block_id = block_id
|
||||||
|
self.ref_count = 0
|
||||||
|
self.hash = -1
|
||||||
|
self.token_ids = []
|
||||||
|
|
||||||
|
def update(self, hash: int, token_ids: list[int]):
|
||||||
|
assert hash != -1
|
||||||
|
assert len(token_ids) == 256
|
||||||
|
self.hash = hash
|
||||||
|
self.token_ids = token_ids
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.ref_count = 1
|
||||||
|
self.hash = -1
|
||||||
|
self.token_ids = []
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{(self.block_id, self.ref_count, self.hash)}"
|
||||||
|
|
||||||
|
|
||||||
|
class BlockManager:
|
||||||
|
|
||||||
|
def __init__(self, num_blocks: int, block_size: int = 256):
|
||||||
|
assert block_size == 256
|
||||||
|
self.block_size = block_size
|
||||||
|
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
||||||
|
self.hash_to_block_id: dict[int, int] = dict()
|
||||||
|
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
||||||
|
self.used_block_ids: set[int] = set()
|
||||||
|
|
||||||
|
def _allocate_block(self, block_id: int):
|
||||||
|
block = self.blocks[block_id]
|
||||||
|
assert block.ref_count == 0
|
||||||
|
block.reset()
|
||||||
|
self.free_block_ids.remove(block_id)
|
||||||
|
self.used_block_ids.add(block_id)
|
||||||
|
return self.blocks[block_id]
|
||||||
|
|
||||||
|
def _deallocate_block(self, block_id: int):
|
||||||
|
assert self.blocks[block_id].ref_count == 0
|
||||||
|
self.used_block_ids.remove(block_id)
|
||||||
|
self.free_block_ids.append(block_id)
|
||||||
|
|
||||||
|
def can_allocate(self, seq: Sequence):
|
||||||
|
return seq.num_blocks <= len(self.free_block_ids)
|
||||||
|
|
||||||
|
def allocate(self, seq: Sequence):
|
||||||
|
assert not seq.block_table
|
||||||
|
h = -1
|
||||||
|
cache_miss = False
|
||||||
|
for i in range(seq.num_blocks):
|
||||||
|
token_ids = seq.block(i, self.block_size)
|
||||||
|
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
||||||
|
block_id = self.hash_to_block_id.get(h, -1)
|
||||||
|
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
||||||
|
cache_miss = True
|
||||||
|
if cache_miss:
|
||||||
|
block_id = self.free_block_ids[0]
|
||||||
|
block = self._allocate_block(block_id)
|
||||||
|
else:
|
||||||
|
seq.num_cached_tokens += self.block_size
|
||||||
|
if block_id in self.used_block_ids:
|
||||||
|
block = self.blocks[block_id]
|
||||||
|
block.ref_count += 1
|
||||||
|
else:
|
||||||
|
block = self._allocate_block(block_id)
|
||||||
|
if h != -1:
|
||||||
|
block.update(h, token_ids)
|
||||||
|
self.hash_to_block_id[h] = block_id
|
||||||
|
seq.block_table.append(block_id)
|
||||||
|
|
||||||
|
def deallocate(self, seq: Sequence):
|
||||||
|
for block_id in seq.block_table:
|
||||||
|
block = self.blocks[block_id]
|
||||||
|
block.ref_count -= 1
|
||||||
|
if block.ref_count == 0:
|
||||||
|
self._deallocate_block(block_id)
|
||||||
|
seq.num_cached_tokens = 0
|
||||||
|
seq.block_table.clear()
|
||||||
|
|
||||||
|
def can_append(self):
|
||||||
|
return len(self.free_block_ids) >= 1
|
||||||
|
|
||||||
|
def may_append(self, seq: Sequence):
|
||||||
|
block_table = seq.block_table
|
||||||
|
last_block = self.blocks[block_table[-1]]
|
||||||
|
if len(seq) % self.block_size == 1:
|
||||||
|
assert last_block.hash != -1
|
||||||
|
block_id = self.free_block_ids[0]
|
||||||
|
self._allocate_block(block_id)
|
||||||
|
block_table.append(block_id)
|
||||||
|
elif len(seq) % self.block_size == 0:
|
||||||
|
assert last_block.hash == -1
|
||||||
|
token_ids = seq.last_block(self.block_size)
|
||||||
|
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
||||||
|
h = compute_hash(token_ids, prefix)
|
||||||
|
last_block.update(h, token_ids)
|
||||||
|
self.hash_to_block_id[h] = last_block.block_id
|
||||||
|
else:
|
||||||
|
assert last_block.hash == -1
|
||||||
66
nanovllm/engine/llm_engine.py
Normal file
66
nanovllm/engine/llm_engine.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from transformers import AutoConfig, AutoTokenizer
|
||||||
|
|
||||||
|
from nanovllm.config import Config
|
||||||
|
from nanovllm.sampling_params import SamplingParams
|
||||||
|
from nanovllm.engine.sequence import Sequence
|
||||||
|
from nanovllm.engine.scheduler import Scheduler
|
||||||
|
from nanovllm.engine.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
|
class LLMEngine:
|
||||||
|
|
||||||
|
def __init__(self, model, **kwargs):
|
||||||
|
config = Config(model)
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if hasattr(config, k):
|
||||||
|
setattr(config, k, v)
|
||||||
|
config.hf_config = AutoConfig.from_pretrained(config.model)
|
||||||
|
config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
||||||
|
config.eos = self.tokenizer.eos_token_id
|
||||||
|
self.model_runner = ModelRunner(config)
|
||||||
|
self.scheduler = Scheduler(config)
|
||||||
|
|
||||||
|
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
prompt = self.tokenizer.encode(prompt)
|
||||||
|
seq = Sequence(prompt, sampling_params)
|
||||||
|
self.scheduler.add(seq)
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
seqs, is_prefill = self.scheduler.schedule()
|
||||||
|
token_ids = self.model_runner.run(seqs, is_prefill)
|
||||||
|
finished = self.scheduler.postprocess(seqs, token_ids)
|
||||||
|
return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)]
|
||||||
|
|
||||||
|
def is_finished(self):
|
||||||
|
return self.scheduler.is_finished()
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompts: list[str] | list[list[int]],
|
||||||
|
sampling_params: SamplingParams | list[SamplingParams],
|
||||||
|
use_tqdm: bool = True,
|
||||||
|
) -> list[str]:
|
||||||
|
if use_tqdm:
|
||||||
|
pbar = tqdm(total=len(prompts),
|
||||||
|
desc="Processed prompts",
|
||||||
|
)
|
||||||
|
if not isinstance(SamplingParams, list):
|
||||||
|
sampling_params = [sampling_params] * len(prompts)
|
||||||
|
for prompt, sp in zip(prompts, sampling_params):
|
||||||
|
self.add_request(prompt, sp)
|
||||||
|
outputs = defaultdict(list)
|
||||||
|
while not self.is_finished():
|
||||||
|
output = self.step()
|
||||||
|
for seq_id, token_id, finish in output:
|
||||||
|
outputs[seq_id].append(token_id)
|
||||||
|
if use_tqdm and finish:
|
||||||
|
pbar.update(1)
|
||||||
|
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
|
||||||
|
outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
|
||||||
|
if use_tqdm:
|
||||||
|
pbar.close()
|
||||||
|
return outputs
|
||||||
198
nanovllm/engine/model_runner.py
Normal file
198
nanovllm/engine/model_runner.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from nanovllm.config import Config
|
||||||
|
from nanovllm.engine.sequence import Sequence
|
||||||
|
from nanovllm.utils.context import set_context, get_context, reset_context
|
||||||
|
from nanovllm.utils.memory import get_gpu_memory
|
||||||
|
from nanovllm.models.qwen3 import Qwen3ForCausalLM
|
||||||
|
from nanovllm.layers.sampler import Sampler
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRunner:
|
||||||
|
|
||||||
|
def __init__(self, config: Config):
|
||||||
|
self.config = config
|
||||||
|
hf_config = config.hf_config
|
||||||
|
self.block_size = config.kvcache_block_size
|
||||||
|
self.enforce_eager = config.enforce_eager
|
||||||
|
|
||||||
|
default_dtype = torch.get_default_dtype()
|
||||||
|
torch.set_default_dtype(hf_config.torch_dtype)
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
self.model = Qwen3ForCausalLM(hf_config)
|
||||||
|
self.model.load_weights(config.model)
|
||||||
|
self.sampler = Sampler()
|
||||||
|
self.allocate_kv_cache(config.gpu_memory_utilization)
|
||||||
|
if not self.enforce_eager:
|
||||||
|
self.capture_model()
|
||||||
|
torch.set_default_device("cpu")
|
||||||
|
torch.set_default_dtype(default_dtype)
|
||||||
|
|
||||||
|
def allocate_kv_cache(self, gpu_memory_utilization):
|
||||||
|
config = self.config
|
||||||
|
hf_config = config.hf_config
|
||||||
|
total, used, _ = get_gpu_memory()
|
||||||
|
free = total * gpu_memory_utilization - used
|
||||||
|
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
|
||||||
|
config.num_kvcache_blocks = int(free * 1e6) // block_bytes
|
||||||
|
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim)
|
||||||
|
layer_id = 0
|
||||||
|
for module in self.model.modules():
|
||||||
|
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
|
||||||
|
module.k_cache = self.kv_cache[0, layer_id]
|
||||||
|
module.v_cache = self.kv_cache[1, layer_id]
|
||||||
|
layer_id += 1
|
||||||
|
|
||||||
|
def preare_block_tables(self, seqs: list[Sequence]):
|
||||||
|
max_len = max(len(seq.block_table) for seq in seqs)
|
||||||
|
block_tables = [
|
||||||
|
seq.block_table + [-1] * (max_len - len(seq.block_table))
|
||||||
|
for seq in seqs
|
||||||
|
]
|
||||||
|
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
return block_tables
|
||||||
|
|
||||||
|
def prepare_prefill(self, seqs: list[Sequence]):
|
||||||
|
input_ids = []
|
||||||
|
positions = []
|
||||||
|
cu_seqlens_q = [0]
|
||||||
|
cu_seqlens_k = [0]
|
||||||
|
max_seqlen_q = 0
|
||||||
|
max_seqlen_k = 0
|
||||||
|
slot_mapping = []
|
||||||
|
context_lens = None
|
||||||
|
block_tables = None
|
||||||
|
for seq in seqs:
|
||||||
|
seqlen = len(seq)
|
||||||
|
input_ids.extend(seq[seq.num_cached_tokens:])
|
||||||
|
positions.extend(list(range(seq.num_cached_tokens, len(seq))))
|
||||||
|
seqlen_q = seqlen - seq.num_cached_tokens
|
||||||
|
seqlen_k = seqlen
|
||||||
|
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
||||||
|
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
||||||
|
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
||||||
|
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
||||||
|
for i in range(seq.num_cached_blocks, seq.num_blocks):
|
||||||
|
start = seq.block_table[i] * self.block_size
|
||||||
|
if i != seq.num_blocks - 1:
|
||||||
|
end = start + self.block_size
|
||||||
|
else:
|
||||||
|
end = start + len(seq.last_block())
|
||||||
|
slot_mapping.extend(list(range(start, end)))
|
||||||
|
assert len(input_ids) == len(slot_mapping)
|
||||||
|
assert len(input_ids) == cu_seqlens_q[-1]
|
||||||
|
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
||||||
|
context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
block_tables = self.preare_block_tables(seqs)
|
||||||
|
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
|
||||||
|
return input_ids, positions
|
||||||
|
|
||||||
|
def prepare_decode(self, seqs: list[Sequence]):
|
||||||
|
input_ids = []
|
||||||
|
positions = []
|
||||||
|
slot_mapping = []
|
||||||
|
context_lens = []
|
||||||
|
for seq in seqs:
|
||||||
|
input_ids.append(seq.last_token)
|
||||||
|
positions.append(len(seq))
|
||||||
|
context_lens.append(len(seq))
|
||||||
|
slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()))
|
||||||
|
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
block_tables = self.preare_block_tables(seqs)
|
||||||
|
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
||||||
|
return input_ids, positions
|
||||||
|
|
||||||
|
def prepare_sample(self, seqs: list[Sequence]):
|
||||||
|
temperatures = []
|
||||||
|
for seq in seqs:
|
||||||
|
temperatures.append(seq.temperature)
|
||||||
|
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
return temperatures
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
|
||||||
|
if is_prefill or self.enforce_eager or input_ids.size(0) > 256:
|
||||||
|
return self.model.compute_logits(self.model(input_ids, positions))
|
||||||
|
else:
|
||||||
|
bs = input_ids.size(0)
|
||||||
|
context = get_context()
|
||||||
|
self.reset_graph_vars()
|
||||||
|
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
||||||
|
graph_vars = self.graph_vars
|
||||||
|
graph_vars["input_ids"][:bs] = input_ids
|
||||||
|
graph_vars["positions"][:bs] = positions
|
||||||
|
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
||||||
|
graph_vars["context_lens"][:bs] = context.context_lens
|
||||||
|
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
|
||||||
|
graph.replay()
|
||||||
|
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
||||||
|
|
||||||
|
def reset_graph_vars(self):
|
||||||
|
graph_vars = self.graph_vars
|
||||||
|
graph_vars["input_ids"].zero_()
|
||||||
|
graph_vars["positions"].zero_()
|
||||||
|
graph_vars["slot_mapping"].zero_()
|
||||||
|
graph_vars["context_lens"].zero_()
|
||||||
|
graph_vars["block_tables"].zero_()
|
||||||
|
|
||||||
|
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||||
|
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||||
|
temperatures = self.prepare_sample(seqs)
|
||||||
|
logits = self.run_model(input_ids, positions, is_prefill)
|
||||||
|
token_ids = self.sampler(logits, temperatures).tolist()
|
||||||
|
reset_context()
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def capture_model(self):
|
||||||
|
get_rng_state = torch.cuda.get_rng_state
|
||||||
|
set_rng_state = torch.cuda.set_rng_state
|
||||||
|
rng_state = torch.cuda.get_rng_state()
|
||||||
|
torch.cuda.get_rng_state = lambda: rng_state
|
||||||
|
torch.cuda.set_rng_state = lambda _: None
|
||||||
|
|
||||||
|
config = self.config
|
||||||
|
hf_config = config.hf_config
|
||||||
|
max_bs = min(self.config.max_num_seqs, 256)
|
||||||
|
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
|
||||||
|
input_ids = torch.zeros(max_bs, dtype=torch.int64)
|
||||||
|
positions = torch.zeros(max_bs, dtype=torch.int64)
|
||||||
|
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
|
||||||
|
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
||||||
|
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
||||||
|
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
||||||
|
self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32))
|
||||||
|
self.graphs = {}
|
||||||
|
self.graph_pool = None
|
||||||
|
|
||||||
|
for bs in reversed(self.graph_bs):
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
|
||||||
|
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
|
||||||
|
with torch.cuda.graph(graph, self.graph_pool):
|
||||||
|
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
|
||||||
|
if self.graph_pool is None:
|
||||||
|
self.graph_pool = graph.pool()
|
||||||
|
self.graphs[bs] = graph
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
reset_context()
|
||||||
|
|
||||||
|
self.graph_vars = dict(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
context_lens=context_lens,
|
||||||
|
block_tables=block_tables,
|
||||||
|
outputs=outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.get_rng_state = get_rng_state
|
||||||
|
torch.cuda.set_rng_state = set_rng_state
|
||||||
84
nanovllm/engine/scheduler.py
Normal file
84
nanovllm/engine/scheduler.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from nanovllm.config import Config
|
||||||
|
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
||||||
|
from nanovllm.engine.block_manager import BlockManager
|
||||||
|
|
||||||
|
|
||||||
|
class Scheduler:
|
||||||
|
|
||||||
|
def __init__(self, config: Config):
|
||||||
|
self.max_num_seqs = config.max_num_seqs
|
||||||
|
self.max_num_batched_tokens = config.max_num_batched_tokens
|
||||||
|
self.eos = config.eos
|
||||||
|
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
|
||||||
|
self.waiting: deque[Sequence] = deque()
|
||||||
|
self.running: deque[Sequence] = deque()
|
||||||
|
self.num_finished = 0
|
||||||
|
self.num_tokens = 0
|
||||||
|
|
||||||
|
def is_finished(self):
|
||||||
|
return not self.waiting and not self.running
|
||||||
|
|
||||||
|
def add(self, seq: Sequence):
|
||||||
|
self.waiting.append(seq)
|
||||||
|
|
||||||
|
def schedule(self) -> tuple[list[Sequence], SequenceStatus]:
|
||||||
|
# prefill
|
||||||
|
scheduled_seqs = []
|
||||||
|
num_seqs = 0
|
||||||
|
num_batched_tokens = 0
|
||||||
|
while self.waiting and num_seqs < self.max_num_seqs:
|
||||||
|
seq = self.waiting[0]
|
||||||
|
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
||||||
|
break
|
||||||
|
num_seqs += 1
|
||||||
|
self.block_manager.allocate(seq)
|
||||||
|
num_batched_tokens += len(seq) - seq.num_cached_tokens
|
||||||
|
seq.status = SequenceStatus.RUNNING
|
||||||
|
self.waiting.popleft()
|
||||||
|
self.running.append(seq)
|
||||||
|
scheduled_seqs.append(seq)
|
||||||
|
if scheduled_seqs:
|
||||||
|
return scheduled_seqs, True
|
||||||
|
|
||||||
|
# decode
|
||||||
|
# self.running = deque(sorted(self.running))
|
||||||
|
while self.running and num_seqs < self.max_num_seqs:
|
||||||
|
seq = self.running.popleft()
|
||||||
|
while not self.block_manager.can_append():
|
||||||
|
if self.running:
|
||||||
|
self.preempt(self.running.pop())
|
||||||
|
else:
|
||||||
|
self.preempt(seq)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
num_seqs += 1
|
||||||
|
self.block_manager.may_append(seq)
|
||||||
|
scheduled_seqs.append(seq)
|
||||||
|
running = deque(scheduled_seqs)
|
||||||
|
running.extend(self.running)
|
||||||
|
self.running = running
|
||||||
|
if scheduled_seqs:
|
||||||
|
return scheduled_seqs, False
|
||||||
|
|
||||||
|
def preempt(self, seq: Sequence):
|
||||||
|
seq.status = SequenceStatus.WAITING
|
||||||
|
self.block_manager.deallocate(seq)
|
||||||
|
self.waiting.appendleft(seq)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
||||||
|
self.num_tokens += len(token_ids)
|
||||||
|
finished = []
|
||||||
|
for seq, token_id in zip(seqs, token_ids):
|
||||||
|
seq.append_token(token_id)
|
||||||
|
if token_id == self.eos or seq.num_completion_tokens == seq.max_tokens:
|
||||||
|
seq.status = SequenceStatus.FINISHED
|
||||||
|
self.block_manager.deallocate(seq)
|
||||||
|
self.running.remove(seq)
|
||||||
|
self.num_finished += 1
|
||||||
|
finished.append(True)
|
||||||
|
else:
|
||||||
|
finished.append(False)
|
||||||
|
return finished
|
||||||
73
nanovllm/engine/sequence.py
Normal file
73
nanovllm/engine/sequence.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from copy import copy
|
||||||
|
from enum import Enum, auto
|
||||||
|
from itertools import count
|
||||||
|
|
||||||
|
from nanovllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceStatus(Enum):
|
||||||
|
WAITING = auto()
|
||||||
|
RUNNING = auto()
|
||||||
|
FINISHED = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class Sequence:
|
||||||
|
counter = count()
|
||||||
|
|
||||||
|
def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
|
||||||
|
self.seq_id = next(Sequence.counter)
|
||||||
|
self.status = SequenceStatus.WAITING
|
||||||
|
self.token_ids = copy(token_ids)
|
||||||
|
self.num_prompt_tokens = len(token_ids)
|
||||||
|
self._num_cached_tokens = 0
|
||||||
|
self.block_table = []
|
||||||
|
self.temperature = sampling_params.temperature
|
||||||
|
self.max_tokens = sampling_params.max_tokens
|
||||||
|
self.ignore_eos = sampling_params.ignore_eos
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.token_ids)
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
return self.seq_id < other.seq_id
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.token_ids[key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_completion_tokens(self):
|
||||||
|
return len(self.token_ids) - self.num_prompt_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_cached_tokens(self):
|
||||||
|
return self._num_cached_tokens
|
||||||
|
|
||||||
|
@num_cached_tokens.setter
|
||||||
|
def num_cached_tokens(self, num_cached_tokens):
|
||||||
|
assert num_cached_tokens % 256 == 0
|
||||||
|
self._num_cached_tokens = num_cached_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_cached_blocks(self):
|
||||||
|
return self.num_cached_tokens // 256
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_blocks(self):
|
||||||
|
return (len(self.token_ids) + 255) // 256
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_token(self):
|
||||||
|
return self.token_ids[-1]
|
||||||
|
|
||||||
|
def block(self, i, block_size=256):
|
||||||
|
return self.token_ids[i*block_size: (i+1)*block_size]
|
||||||
|
|
||||||
|
def last_block(self, block_size=256):
|
||||||
|
n = self.num_blocks
|
||||||
|
t = len(self) + block_size - self.num_blocks * block_size
|
||||||
|
x = self.token_ids[(n-1)*block_size:]
|
||||||
|
assert len(x) == t
|
||||||
|
return x
|
||||||
|
|
||||||
|
def append_token(self, token_id: int):
|
||||||
|
self.token_ids.append(token_id)
|
||||||
14
nanovllm/layers/activation.py
Executable file
14
nanovllm/layers/activation.py
Executable file
@@ -0,0 +1,14 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class SiluAndMul(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@torch.compile
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x, y = x.chunk(2, -1)
|
||||||
|
return F.silu(x) * y
|
||||||
86
nanovllm/layers/attention.py
Normal file
86
nanovllm/layers/attention.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
|
# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
|
from nanovllm.utils.context import get_context
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def store_kvcache_kernel(
|
||||||
|
key_ptr,
|
||||||
|
key_stride,
|
||||||
|
value_ptr,
|
||||||
|
value_stride,
|
||||||
|
k_cache_ptr,
|
||||||
|
v_cache_ptr,
|
||||||
|
slot_mapping_ptr,
|
||||||
|
D: tl.constexpr,
|
||||||
|
):
|
||||||
|
idx = tl.program_id(0)
|
||||||
|
key_offsets = idx * key_stride + tl.arange(0, D)
|
||||||
|
value_offsets = idx * value_stride + tl.arange(0, D)
|
||||||
|
key = tl.load(key_ptr + key_offsets)
|
||||||
|
value = tl.load(value_ptr + value_offsets)
|
||||||
|
slot = tl.load(slot_mapping_ptr + idx)
|
||||||
|
cache_offsets = slot * D + tl.arange(0, D)
|
||||||
|
tl.store(k_cache_ptr + cache_offsets, key)
|
||||||
|
tl.store(v_cache_ptr + cache_offsets, value)
|
||||||
|
|
||||||
|
|
||||||
|
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
|
||||||
|
N, num_heads, head_dim = key.shape
|
||||||
|
D = num_heads * head_dim
|
||||||
|
assert key.stride(-1) == 1 and value.stride(-1) == 1
|
||||||
|
assert key.stride(1) == head_dim and value.stride(1) == head_dim
|
||||||
|
assert k_cache.stride(1) == D and v_cache.stride(1) == D
|
||||||
|
assert slot_mapping.numel() == N
|
||||||
|
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads,
|
||||||
|
head_dim,
|
||||||
|
scale,
|
||||||
|
num_kv_heads,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.scale = scale
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
self.k_cache = self.v_cache = torch.tensor([])
|
||||||
|
|
||||||
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||||
|
o: torch.Tensor
|
||||||
|
q = q.view(-1, self.num_heads, self.head_dim)
|
||||||
|
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
context = get_context()
|
||||||
|
k_cache = self.k_cache
|
||||||
|
v_cache = self.v_cache
|
||||||
|
if context.is_prefill:
|
||||||
|
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||||
|
if context.block_tables is None: # normal prefill
|
||||||
|
cu_seqlens_k = context.cu_seqlens_k
|
||||||
|
seqused_k = None
|
||||||
|
else: # prefix cache
|
||||||
|
cu_seqlens_k = None
|
||||||
|
seqused_k = context.context_lens
|
||||||
|
k, v = k_cache, v_cache
|
||||||
|
o = flash_attn_varlen_func(q, k, v,
|
||||||
|
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||||
|
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
|
||||||
|
seqused_k=seqused_k, softmax_scale=self.scale,
|
||||||
|
causal=True, block_table=context.block_tables)
|
||||||
|
else: # decode
|
||||||
|
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1),
|
||||||
|
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||||
|
softmax_scale=self.scale, causal=True)
|
||||||
|
o = o.view(-1, self.num_heads * self.head_dim)
|
||||||
|
return o
|
||||||
72
nanovllm/layers/embed_head.py
Normal file
72
nanovllm/layers/embed_head.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from nanovllm.utils.context import get_context
|
||||||
|
|
||||||
|
|
||||||
|
class VocabParallelEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_embeddings: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.tp_rank = 0 # get_tensor_model_parallel_rank()
|
||||||
|
self.tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||||
|
assert num_embeddings % self.tp_size == 0
|
||||||
|
self.num_embeddings = num_embeddings
|
||||||
|
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
|
||||||
|
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
|
||||||
|
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
|
||||||
|
self.weight.weight_loader = self.weight_loader
|
||||||
|
|
||||||
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||||
|
param_data = param.data
|
||||||
|
shard_size = param_data.size(0)
|
||||||
|
start_idx = self.tp_rank * shard_size
|
||||||
|
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||||
|
assert param_data.size() == loaded_weight.size()
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
if self.tp_size > 1:
|
||||||
|
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
|
||||||
|
x = mask * (x - self.vocab_start_idx)
|
||||||
|
y = F.embedding(x, self.weight)
|
||||||
|
if self.tp_size > 1:
|
||||||
|
y = mask * y
|
||||||
|
dist.all_reduce(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLMHead(VocabParallelEmbedding):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_embeddings: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
bias: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(num_embeddings, embedding_dim)
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
|
||||||
|
self.bias.weight_loader = self.weight_loader
|
||||||
|
else:
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
context = get_context()
|
||||||
|
if context.is_prefill:
|
||||||
|
last_indices = context.cu_seqlens_q[1:] - 1
|
||||||
|
x = x[last_indices].contiguous()
|
||||||
|
logits = F.linear(x, self.weight, self.bias)
|
||||||
|
# if self.tp_size > 1:
|
||||||
|
# all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)]
|
||||||
|
# dist.gather(logits, all_logits, 0)
|
||||||
|
# logits = torch.cat(all_logits, -1)
|
||||||
|
return logits if self.tp_rank == 0 else None
|
||||||
51
nanovllm/layers/layernorm.py
Executable file
51
nanovllm/layers/layernorm.py
Executable file
@@ -0,0 +1,51 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
|
||||||
|
@torch.compile
|
||||||
|
def rms_forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
var = x.pow(2).mean(dim=-1, keepdim=True)
|
||||||
|
x.mul_(torch.rsqrt(var + self.eps))
|
||||||
|
x = x.to(orig_dtype).mul_(self.weight)
|
||||||
|
return x
|
||||||
|
|
||||||
|
@torch.compile
|
||||||
|
def add_rms_forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32).add_(residual.to(torch.float32))
|
||||||
|
residual = x.to(orig_dtype)
|
||||||
|
var = x.pow(2).mean(dim=-1, keepdim=True)
|
||||||
|
x.mul_(torch.rsqrt(var + self.eps))
|
||||||
|
x = x.to(orig_dtype).mul_(self.weight)
|
||||||
|
return x, residual
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if residual is None:
|
||||||
|
return self.rms_forward(x)
|
||||||
|
else:
|
||||||
|
return self.add_rms_forward(x, residual)
|
||||||
199
nanovllm/layers/linear.py
Executable file
199
nanovllm/layers/linear.py
Executable file
@@ -0,0 +1,199 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
def divide(numerator, denominator):
|
||||||
|
assert numerator % denominator == 0
|
||||||
|
return numerator // denominator
|
||||||
|
|
||||||
|
|
||||||
|
class LinearBase(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
tp_dim: int | None = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.output_size = output_size
|
||||||
|
self.tp_dim = tp_dim
|
||||||
|
self.tp_rank = 0 # get_tensor_model_parallel_rank()
|
||||||
|
self.tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicatedLinear(LinearBase):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
bias: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(input_size, output_size)
|
||||||
|
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
|
||||||
|
self.weight.weight_loader = self.weight_loader
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.empty(self.output_size))
|
||||||
|
self.bias.weight_loader = self.weight_loader
|
||||||
|
else:
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||||
|
assert param.size() == loaded_weight.size()
|
||||||
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.linear(x, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnParallelLinear(LinearBase):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
bias: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(input_size, output_size, 0)
|
||||||
|
self.input_size_per_partition = input_size
|
||||||
|
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||||
|
self.output_partition_sizes = [self.output_size_per_partition]
|
||||||
|
# If QKV or MergedColumn, use output size of each partition.
|
||||||
|
if hasattr(self, "output_sizes"):
|
||||||
|
self.output_partition_sizes = [
|
||||||
|
divide(output_size, self.tp_size)
|
||||||
|
for output_size in self.output_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
|
||||||
|
self.weight.weight_loader = self.weight_loader
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
|
||||||
|
self.bias.weight_loader = self.weight_loader
|
||||||
|
else:
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||||
|
param_data = param.data
|
||||||
|
shard_size = param_data.size(self.tp_dim)
|
||||||
|
start_idx = self.tp_rank * shard_size
|
||||||
|
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
|
||||||
|
assert param_data.size() == loaded_weight.size()
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.linear(x, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
|
class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_sizes: list[int],
|
||||||
|
bias: bool = False,
|
||||||
|
):
|
||||||
|
self.output_sizes = output_sizes
|
||||||
|
tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||||
|
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||||
|
super().__init__(input_size, sum(output_sizes), bias=bias)
|
||||||
|
|
||||||
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None):
|
||||||
|
param_data = param.data
|
||||||
|
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
|
||||||
|
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
||||||
|
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
||||||
|
# loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
|
||||||
|
assert param_data.size() == loaded_weight.size()
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class QKVParallelLinear(ColumnParallelLinear):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
head_size: int,
|
||||||
|
total_num_heads: int,
|
||||||
|
total_num_kv_heads: int | None = None,
|
||||||
|
bias: bool = False,
|
||||||
|
):
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.head_size = head_size
|
||||||
|
self.total_num_heads = total_num_heads
|
||||||
|
if total_num_kv_heads is None:
|
||||||
|
total_num_kv_heads = total_num_heads
|
||||||
|
self.total_num_kv_heads = total_num_kv_heads
|
||||||
|
# Divide the weight matrix along the last dimension.
|
||||||
|
tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||||
|
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||||
|
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||||
|
input_size = self.hidden_size
|
||||||
|
output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
|
||||||
|
self.output_sizes = [
|
||||||
|
self.num_heads * self.head_size * tp_size, # q_proj
|
||||||
|
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||||
|
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
||||||
|
]
|
||||||
|
|
||||||
|
super().__init__(input_size, output_size, bias)
|
||||||
|
|
||||||
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None):
|
||||||
|
param_data = param.data
|
||||||
|
assert loaded_shard_id in ["q", "k", "v"]
|
||||||
|
if loaded_shard_id == "q":
|
||||||
|
shard_size = self.num_heads * self.head_size
|
||||||
|
shard_offset = 0
|
||||||
|
elif loaded_shard_id == "k":
|
||||||
|
shard_size = self.num_kv_heads * self.head_size
|
||||||
|
shard_offset = self.num_heads * self.head_size
|
||||||
|
else:
|
||||||
|
shard_size = self.num_kv_heads * self.head_size
|
||||||
|
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
|
||||||
|
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
||||||
|
# loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
|
||||||
|
assert param_data.size() == loaded_weight.size()
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class RowParallelLinear(LinearBase):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
bias: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(input_size, output_size, 1)
|
||||||
|
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||||
|
self.output_size_per_partition = output_size
|
||||||
|
self.output_partition_sizes = [output_size]
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
|
||||||
|
self.weight.weight_loader = self.weight_loader
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.empty(self.output_size))
|
||||||
|
self.bias.weight_loader = self.weight_loader
|
||||||
|
else:
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||||
|
param_data = param.data
|
||||||
|
shard_size = param_data.size(self.tp_dim)
|
||||||
|
start_idx = self.tp_rank * shard_size
|
||||||
|
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
|
||||||
|
assert param_data.size() == loaded_weight.size()
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
|
||||||
|
if self.tp_size > 1:
|
||||||
|
dist.all_reduce(y)
|
||||||
|
return y
|
||||||
73
nanovllm/layers/rotary_embedding.py
Normal file
73
nanovllm/layers/rotary_embedding.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
cos = cos.unsqueeze(-2)
|
||||||
|
sin = sin.unsqueeze(-2)
|
||||||
|
x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
|
||||||
|
y1 = x1 * cos - x2 * sin
|
||||||
|
y2 = x2 * cos + x1 * sin
|
||||||
|
return torch.cat((y1, y2), dim=-1).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
base: float,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.head_size = head_size
|
||||||
|
self.rotary_dim = rotary_dim
|
||||||
|
assert rotary_dim == head_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
||||||
|
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||||
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
|
cos = freqs.cos()
|
||||||
|
sin = freqs.sin()
|
||||||
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||||
|
|
||||||
|
@torch.compile
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
positions = positions.flatten()
|
||||||
|
num_tokens = positions.shape[0]
|
||||||
|
cos_sin = self.cos_sin_cache[positions]
|
||||||
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
|
query_shape = query.shape
|
||||||
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
|
query = apply_rotary_emb(query, cos, sin).view(query_shape)
|
||||||
|
key_shape = key.shape
|
||||||
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
|
key = apply_rotary_emb(key, cos, sin).view(key_shape)
|
||||||
|
return query, key
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(1)
|
||||||
|
def get_rope(
|
||||||
|
head_size: int,
|
||||||
|
rotary_dim: int,
|
||||||
|
max_position: int,
|
||||||
|
base: float,
|
||||||
|
rope_scaling: dict | None = None,
|
||||||
|
):
|
||||||
|
assert rope_scaling is None
|
||||||
|
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||||
|
return rotary_emb
|
||||||
17
nanovllm/layers/sampler.py
Normal file
17
nanovllm/layers/sampler.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
|
||||||
|
logits = logits.to(torch.float)
|
||||||
|
if temperatures is not None:
|
||||||
|
logits.div_(temperatures.unsqueeze(dim=1))
|
||||||
|
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
# logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
|
||||||
|
return sampled_tokens
|
||||||
5
nanovllm/llm.py
Normal file
5
nanovllm/llm.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from nanovllm.engine.llm_engine import LLMEngine
|
||||||
|
|
||||||
|
|
||||||
|
class LLM(LLMEngine):
|
||||||
|
pass
|
||||||
247
nanovllm/models/qwen3.py
Executable file
247
nanovllm/models/qwen3.py
Executable file
@@ -0,0 +1,247 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import Qwen3Config
|
||||||
|
|
||||||
|
from nanovllm.layers.activation import SiluAndMul
|
||||||
|
from nanovllm.layers.attention import Attention
|
||||||
|
from nanovllm.layers.layernorm import RMSNorm
|
||||||
|
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
||||||
|
from nanovllm.layers.rotary_embedding import get_rope
|
||||||
|
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3Attention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
max_position: int = 4096 * 32,
|
||||||
|
head_dim: int | None = None,
|
||||||
|
rms_norm_eps: float = 1e-06,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: tuple | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
self.head_dim,
|
||||||
|
self.total_num_heads,
|
||||||
|
self.total_num_kv_heads,
|
||||||
|
bias=qkv_bias,
|
||||||
|
)
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.total_num_heads * self.head_dim,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position,
|
||||||
|
base=self.rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
)
|
||||||
|
self.attn = Attention(self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads)
|
||||||
|
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||||
|
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q_by_head = q.view(-1, self.num_heads, self.head_dim)
|
||||||
|
q_by_head = self.q_norm(q_by_head)
|
||||||
|
q = q_by_head.view(q.shape)
|
||||||
|
k_by_head = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||||
|
k_by_head = self.k_norm(k_by_head)
|
||||||
|
k = k_by_head.view(k.shape)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
o = self.attn(q, k, v)
|
||||||
|
output = self.o_proj(o)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3MLP(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
|
hidden_size,
|
||||||
|
[intermediate_size] * 2,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
assert hidden_act == "silu"
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate_up = self.gate_up_proj(x)
|
||||||
|
x = self.act_fn(gate_up)
|
||||||
|
x = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3DecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Qwen3Config,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.self_attn = Qwen3Attention(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=config.num_key_value_heads,
|
||||||
|
max_position=config.max_position_embeddings,
|
||||||
|
rms_norm_eps=config.rms_norm_eps,
|
||||||
|
qkv_bias=getattr(config, 'attention_bias', False),
|
||||||
|
head_dim=getattr(config, 'head_dim', None),
|
||||||
|
rope_theta=getattr(config, "rope_theta", 1000000),
|
||||||
|
rope_scaling=getattr(config, "rope_scaling", None),
|
||||||
|
)
|
||||||
|
self.mlp = Qwen3MLP(
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
)
|
||||||
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3Model(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Qwen3Config,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
|
||||||
|
self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
residual = None
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states, residual = layer(
|
||||||
|
positions,
|
||||||
|
hidden_states,
|
||||||
|
residual,
|
||||||
|
)
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3ForCausalLM(nn.Module):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Qwen3Config
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model = Qwen3Model(config)
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head.weight.data = self.model.embed_tokens.weight.data
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states = self.model(input_ids, positions)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def load_weights(self, path: str):
|
||||||
|
import os
|
||||||
|
from safetensors import safe_open
|
||||||
|
default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
|
||||||
|
with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
|
||||||
|
for n, p in self.named_parameters():
|
||||||
|
for x in self.packed_modules_mapping:
|
||||||
|
if x in n:
|
||||||
|
weight_loader = getattr(p, "weight_loader", default_weight_loader)
|
||||||
|
for i, y in enumerate(self.packed_modules_mapping[x]):
|
||||||
|
weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
weight_loader = getattr(p, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(p, f.get_tensor(n))
|
||||||
8
nanovllm/sampling_params.py
Normal file
8
nanovllm/sampling_params.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SamplingParams:
|
||||||
|
temperature: float = 1.0
|
||||||
|
max_tokens: int = 64
|
||||||
|
ignore_eos: bool = False
|
||||||
28
nanovllm/utils/context.py
Normal file
28
nanovllm/utils/context.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Context:
|
||||||
|
is_prefill: bool = False
|
||||||
|
cu_seqlens_q: torch.Tensor | None = None
|
||||||
|
cu_seqlens_k: torch.Tensor | None = None
|
||||||
|
max_seqlen_q: int = 0
|
||||||
|
max_seqlen_k: int = 0
|
||||||
|
slot_mapping: torch.Tensor | None = None
|
||||||
|
context_lens: torch.Tensor | None = None
|
||||||
|
block_tables: torch.Tensor | None = None
|
||||||
|
|
||||||
|
_CONTEXT = Context()
|
||||||
|
|
||||||
|
def get_context():
|
||||||
|
return _CONTEXT
|
||||||
|
|
||||||
|
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, ):
|
||||||
|
global _CONTEXT
|
||||||
|
_CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
|
||||||
|
|
||||||
|
def reset_context():
|
||||||
|
global _CONTEXT
|
||||||
|
_CONTEXT = Context()
|
||||||
14
nanovllm/utils/memory.py
Normal file
14
nanovllm/utils/memory.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_gpu_memory(device_id: int = 0):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
result = subprocess.check_output(
|
||||||
|
['nvidia-smi', '-i', str(device_id), '--query-gpu=memory.total,memory.used,memory.free', '--format=csv,nounits,noheader'],
|
||||||
|
encoding='utf-8'
|
||||||
|
)
|
||||||
|
total_memory, used_memory, free_memory = [int(x) for x in result.strip().split(', ')]
|
||||||
|
return total_memory, used_memory, free_memory
|
||||||
|
|
||||||
31
nanovllm/utils/timer.py
Normal file
31
nanovllm/utils/timer.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from collections import defaultdict
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class CUDATimer:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.events = defaultdict(list)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def record(self, name, enabled=True):
|
||||||
|
if not enabled:
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
||||||
|
self.events[name].append((start, end))
|
||||||
|
start.record()
|
||||||
|
yield
|
||||||
|
end.record()
|
||||||
|
|
||||||
|
def log(self):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
ret = []
|
||||||
|
for name, events in self.events.items():
|
||||||
|
total = 0
|
||||||
|
count = len(self.events)
|
||||||
|
for start, end in events:
|
||||||
|
total += start.elapsed_time(end)
|
||||||
|
ret.append(f"{name} {total:.2f}ms/{count}times")
|
||||||
|
return ", ".join(ret)
|
||||||
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
torch
|
||||||
|
triton
|
||||||
|
transformers
|
||||||
|
cmake
|
||||||
|
ninja
|
||||||
Reference in New Issue
Block a user