init commit

This commit is contained in:
GeeeekExplorer
2025-06-10 00:23:23 +08:00
commit a5a4909e6a
26 changed files with 1677 additions and 0 deletions

195
.gitignore vendored Normal file

@@ -0,0 +1,195 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore

21
LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Xingkai Yu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

1
README.md Normal file

@@ -0,0 +1 @@
# Nano-VLLM

20
bench.py Normal file

@@ -0,0 +1,20 @@
import os
import time
import torch
from nanovllm import LLM, SamplingParams
batch_size = 256
seq_len = 1024
max_tokens = 512
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
llm = LLM(path, enforce_eager=False)
prompt_token_ids = torch.randint(0, 10240, (batch_size, seq_len)).tolist()
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_tokens)
t = time.time()
completions = llm.generate(prompt_token_ids, sampling_params)
throughput = batch_size * max_tokens / (time.time() - t)
print(f"Throughput: {throughput: .2f} tok/s")

29
example.py Normal file

@@ -0,0 +1,29 @@
import os
from nanovllm import LLM, SamplingParams
from transformers import AutoTokenizer
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
tokenizer = AutoTokenizer.from_pretrained(path)
llm = LLM(path, enforce_eager=True)
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
prompts = [
"自我介绍一下吧!",
"列出100内所有素数",
]
prompts = [
tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=True
)
for prompt in prompts
]
completions = llm.generate(prompts, sampling_params)
for p, c in zip(prompts, completions):
print("\n\n")
print(f"Prompt: {p}")
print(f"Completion: {c}")

2
nanovllm/__init__.py Normal file

@@ -0,0 +1,2 @@
from nanovllm.llm import LLM
from nanovllm.sampling_params import SamplingParams

20
nanovllm/config.py Normal file

@@ -0,0 +1,20 @@
from dataclasses import dataclass
from transformers import AutoConfig
@dataclass
class Config:
model: str = ''
max_num_batched_tokens: int = 16384
max_num_seqs: int = 512
max_model_len: int = 4096
gpu_memory_utilization: float = 0.95
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
kvcache_block_size: int = 256
num_kvcache_blocks: int = -1
def __post_init__(self):
assert self.model
assert self.kvcache_block_size == 256

118
nanovllm/engine/block_manager.py

@@ -0,0 +1,118 @@
from collections import deque
import xxhash
import numpy as np
from nanovllm.engine.sequence import Sequence
def compute_hash(token_ids: list[int], prefix: int = -1):
h = xxhash.xxh64()
if prefix != -1:
h.update(prefix.to_bytes(8))
h.update(np.array(token_ids).tobytes())
return h.intdigest()
class Block:
def __init__(self, block_id):
self.block_id = block_id
self.ref_count = 0
self.hash = -1
self.token_ids = []
def update(self, hash: int, token_ids: list[int]):
assert hash != -1
assert len(token_ids) == 256
self.hash = hash
self.token_ids = token_ids
def reset(self):
self.ref_count = 1
self.hash = -1
self.token_ids = []
def __repr__(self):
return f"{(self.block_id, self.ref_count, self.hash)}"
class BlockManager:
def __init__(self, num_blocks: int, block_size: int = 256):
assert block_size == 256
self.block_size = block_size
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()
self.free_block_ids: deque[int] = deque(range(num_blocks))
self.used_block_ids: set[int] = set()
def _allocate_block(self, block_id: int):
block = self.blocks[block_id]
assert block.ref_count == 0
block.reset()
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
return self.blocks[block_id]
def _deallocate_block(self, block_id: int):
assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
def can_allocate(self, seq: Sequence):
return seq.num_blocks <= len(self.free_block_ids)
def allocate(self, seq: Sequence):
assert not seq.block_table
h = -1
cache_miss = False
for i in range(seq.num_blocks):
token_ids = seq.block(i, self.block_size)
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
if cache_miss:
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
else:
seq.num_cached_tokens += self.block_size
if block_id in self.used_block_ids:
block = self.blocks[block_id]
block.ref_count += 1
else:
block = self._allocate_block(block_id)
if h != -1:
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
seq.block_table.append(block_id)
def deallocate(self, seq: Sequence):
for block_id in seq.block_table:
block = self.blocks[block_id]
block.ref_count -= 1
if block.ref_count == 0:
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.block_table.clear()
def can_append(self):
return len(self.free_block_ids) >= 1
def may_append(self, seq: Sequence):
block_table = seq.block_table
last_block = self.blocks[block_table[-1]]
if len(seq) % self.block_size == 1:
assert last_block.hash != -1
block_id = self.free_block_ids[0]
self._allocate_block(block_id)
block_table.append(block_id)
elif len(seq) % self.block_size == 0:
assert last_block.hash == -1
token_ids = seq.last_block(self.block_size)
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
h = compute_hash(token_ids, prefix)
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id
else:
assert last_block.hash == -1

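A minimal sketch (not part of the commit; the example token lists are made up) of the prefix-caching idea behind BlockManager.allocate: every full 256-token block is hashed together with the previous block's hash, so two sequences that share a prefix get identical hashes for the shared blocks and can map to the same physical KV-cache block, while blocks after the first divergence hash differently.

import xxhash
import numpy as np

def compute_hash(token_ids: list[int], prefix: int = -1):
    # Same scheme as block_manager.py above: chain the previous block's hash
    # into this one, so a block's identity depends on its entire prefix.
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()

BLOCK = 256
shared = list(range(BLOCK))                    # identical first block
a = shared + list(range(1000, 1000 + BLOCK))   # the sequences diverge in block 1
b = shared + list(range(2000, 2000 + BLOCK))

h_a0 = compute_hash(a[:BLOCK])
h_b0 = compute_hash(b[:BLOCK])
print(h_a0 == h_b0)    # True: the shared block maps to a single cache entry

h_a1 = compute_hash(a[BLOCK:], prefix=h_a0)
h_b1 = compute_hash(b[BLOCK:], prefix=h_b0)
print(h_a1 == h_b1)    # False: blocks after the divergence stay distinct
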
66
nanovllm/engine/llm_engine.py

@@ -0,0 +1,66 @@
from collections import defaultdict
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoTokenizer
from nanovllm.config import Config
from nanovllm.sampling_params import SamplingParams
from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner
class LLMEngine:
def __init__(self, model, **kwargs):
config = Config(model)
for k, v in kwargs.items():
if hasattr(config, k):
setattr(config, k, v)
config.hf_config = AutoConfig.from_pretrained(config.model)
config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
config.eos = self.tokenizer.eos_token_id
self.model_runner = ModelRunner(config)
self.scheduler = Scheduler(config)
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
if isinstance(prompt, str):
prompt = self.tokenizer.encode(prompt)
seq = Sequence(prompt, sampling_params)
self.scheduler.add(seq)
def step(self):
seqs, is_prefill = self.scheduler.schedule()
token_ids = self.model_runner.run(seqs, is_prefill)
finished = self.scheduler.postprocess(seqs, token_ids)
return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)]
def is_finished(self):
return self.scheduler.is_finished()
def generate(
self,
prompts: list[str] | list[list[int]],
sampling_params: SamplingParams | list[SamplingParams],
use_tqdm: bool = True,
) -> list[str]:
if use_tqdm:
pbar = tqdm(total=len(prompts),
desc="Processed prompts",
)
        if not isinstance(sampling_params, list):
sampling_params = [sampling_params] * len(prompts)
for prompt, sp in zip(prompts, sampling_params):
self.add_request(prompt, sp)
outputs = defaultdict(list)
while not self.is_finished():
output = self.step()
for seq_id, token_id, finish in output:
outputs[seq_id].append(token_id)
if use_tqdm and finish:
pbar.update(1)
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
if use_tqdm:
pbar.close()
return outputs

198
nanovllm/engine/model_runner.py

@@ -0,0 +1,198 @@
import torch
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.memory import get_gpu_memory
from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
class ModelRunner:
def __init__(self, config: Config):
self.config = config
hf_config = config.hf_config
self.block_size = config.kvcache_block_size
self.enforce_eager = config.enforce_eager
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_device("cuda")
self.model = Qwen3ForCausalLM(hf_config)
self.model.load_weights(config.model)
self.sampler = Sampler()
self.allocate_kv_cache(config.gpu_memory_utilization)
if not self.enforce_eager:
self.capture_model()
torch.set_default_device("cpu")
torch.set_default_dtype(default_dtype)
def allocate_kv_cache(self, gpu_memory_utilization):
config = self.config
hf_config = config.hf_config
total, used, _ = get_gpu_memory()
free = total * gpu_memory_utilization - used
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
config.num_kvcache_blocks = int(free * 1e6) // block_bytes
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim)
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
module.k_cache = self.kv_cache[0, layer_id]
module.v_cache = self.kv_cache[1, layer_id]
layer_id += 1
    def prepare_block_tables(self, seqs: list[Sequence]):
max_len = max(len(seq.block_table) for seq in seqs)
block_tables = [
seq.block_table + [-1] * (max_len - len(seq.block_table))
for seq in seqs
]
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
return block_tables
def prepare_prefill(self, seqs: list[Sequence]):
input_ids = []
positions = []
cu_seqlens_q = [0]
cu_seqlens_k = [0]
max_seqlen_q = 0
max_seqlen_k = 0
slot_mapping = []
context_lens = None
block_tables = None
for seq in seqs:
seqlen = len(seq)
input_ids.extend(seq[seq.num_cached_tokens:])
positions.extend(list(range(seq.num_cached_tokens, len(seq))))
seqlen_q = seqlen - seq.num_cached_tokens
seqlen_k = seqlen
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
for i in range(seq.num_cached_blocks, seq.num_blocks):
start = seq.block_table[i] * self.block_size
if i != seq.num_blocks - 1:
end = start + self.block_size
else:
end = start + len(seq.last_block())
slot_mapping.extend(list(range(start, end)))
assert len(input_ids) == len(slot_mapping)
assert len(input_ids) == cu_seqlens_q[-1]
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
            block_tables = self.prepare_block_tables(seqs)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
input_ids = []
positions = []
slot_mapping = []
context_lens = []
for seq in seqs:
input_ids.append(seq.last_token)
positions.append(len(seq))
context_lens.append(len(seq))
slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()))
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        block_tables = self.prepare_block_tables(seqs)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
return input_ids, positions
def prepare_sample(self, seqs: list[Sequence]):
temperatures = []
for seq in seqs:
temperatures.append(seq.temperature)
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
return temperatures
@torch.inference_mode()
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
if is_prefill or self.enforce_eager or input_ids.size(0) > 256:
return self.model.compute_logits(self.model(input_ids, positions))
else:
bs = input_ids.size(0)
context = get_context()
self.reset_graph_vars()
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
graph_vars = self.graph_vars
graph_vars["input_ids"][:bs] = input_ids
graph_vars["positions"][:bs] = positions
graph_vars["slot_mapping"][:bs] = context.slot_mapping
graph_vars["context_lens"][:bs] = context.context_lens
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
graph.replay()
return self.model.compute_logits(graph_vars["outputs"][:bs])
def reset_graph_vars(self):
graph_vars = self.graph_vars
graph_vars["input_ids"].zero_()
graph_vars["positions"].zero_()
graph_vars["slot_mapping"].zero_()
graph_vars["context_lens"].zero_()
graph_vars["block_tables"].zero_()
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs)
logits = self.run_model(input_ids, positions, is_prefill)
token_ids = self.sampler(logits, temperatures).tolist()
reset_context()
return token_ids
@torch.inference_mode()
def capture_model(self):
get_rng_state = torch.cuda.get_rng_state
set_rng_state = torch.cuda.set_rng_state
rng_state = torch.cuda.get_rng_state()
torch.cuda.get_rng_state = lambda: rng_state
torch.cuda.set_rng_state = lambda _: None
config = self.config
hf_config = config.hf_config
max_bs = min(self.config.max_num_seqs, 256)
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
input_ids = torch.zeros(max_bs, dtype=torch.int64)
positions = torch.zeros(max_bs, dtype=torch.int64)
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
context_lens = torch.zeros(max_bs, dtype=torch.int32)
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
outputs = torch.zeros(max_bs, hf_config.hidden_size)
self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32))
self.graphs = {}
self.graph_pool = None
for bs in reversed(self.graph_bs):
graph = torch.cuda.CUDAGraph()
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
with torch.cuda.graph(graph, self.graph_pool):
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
if self.graph_pool is None:
self.graph_pool = graph.pool()
self.graphs[bs] = graph
torch.cuda.synchronize()
reset_context()
self.graph_vars = dict(
input_ids=input_ids,
positions=positions,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
outputs=outputs,
)
torch.cuda.get_rng_state = get_rng_state
torch.cuda.set_rng_state = set_rng_state

84
nanovllm/engine/scheduler.py

@@ -0,0 +1,84 @@
from collections import deque
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence, SequenceStatus
from nanovllm.engine.block_manager import BlockManager
class Scheduler:
def __init__(self, config: Config):
self.max_num_seqs = config.max_num_seqs
self.max_num_batched_tokens = config.max_num_batched_tokens
self.eos = config.eos
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
self.waiting: deque[Sequence] = deque()
self.running: deque[Sequence] = deque()
self.num_finished = 0
self.num_tokens = 0
def is_finished(self):
return not self.waiting and not self.running
def add(self, seq: Sequence):
self.waiting.append(seq)
    def schedule(self) -> tuple[list[Sequence], bool]:
# prefill
scheduled_seqs = []
num_seqs = 0
num_batched_tokens = 0
while self.waiting and num_seqs < self.max_num_seqs:
seq = self.waiting[0]
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
break
num_seqs += 1
self.block_manager.allocate(seq)
num_batched_tokens += len(seq) - seq.num_cached_tokens
seq.status = SequenceStatus.RUNNING
self.waiting.popleft()
self.running.append(seq)
scheduled_seqs.append(seq)
if scheduled_seqs:
return scheduled_seqs, True
# decode
# self.running = deque(sorted(self.running))
while self.running and num_seqs < self.max_num_seqs:
seq = self.running.popleft()
while not self.block_manager.can_append():
if self.running:
self.preempt(self.running.pop())
else:
self.preempt(seq)
break
else:
num_seqs += 1
self.block_manager.may_append(seq)
scheduled_seqs.append(seq)
running = deque(scheduled_seqs)
running.extend(self.running)
self.running = running
if scheduled_seqs:
return scheduled_seqs, False
def preempt(self, seq: Sequence):
seq.status = SequenceStatus.WAITING
self.block_manager.deallocate(seq)
self.waiting.appendleft(seq)
return True
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
self.num_tokens += len(token_ids)
finished = []
for seq, token_id in zip(seqs, token_ids):
seq.append_token(token_id)
            if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED
self.block_manager.deallocate(seq)
self.running.remove(seq)
self.num_finished += 1
finished.append(True)
else:
finished.append(False)
return finished

73
nanovllm/engine/sequence.py

@@ -0,0 +1,73 @@
from copy import copy
from enum import Enum, auto
from itertools import count
from nanovllm.sampling_params import SamplingParams
class SequenceStatus(Enum):
WAITING = auto()
RUNNING = auto()
FINISHED = auto()
class Sequence:
counter = count()
def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
self.seq_id = next(Sequence.counter)
self.status = SequenceStatus.WAITING
self.token_ids = copy(token_ids)
self.num_prompt_tokens = len(token_ids)
self._num_cached_tokens = 0
self.block_table = []
self.temperature = sampling_params.temperature
self.max_tokens = sampling_params.max_tokens
self.ignore_eos = sampling_params.ignore_eos
def __len__(self):
return len(self.token_ids)
def __lt__(self, other):
return self.seq_id < other.seq_id
def __getitem__(self, key):
return self.token_ids[key]
@property
def num_completion_tokens(self):
return len(self.token_ids) - self.num_prompt_tokens
@property
def num_cached_tokens(self):
return self._num_cached_tokens
@num_cached_tokens.setter
def num_cached_tokens(self, num_cached_tokens):
assert num_cached_tokens % 256 == 0
self._num_cached_tokens = num_cached_tokens
@property
def num_cached_blocks(self):
return self.num_cached_tokens // 256
@property
def num_blocks(self):
return (len(self.token_ids) + 255) // 256
@property
def last_token(self):
return self.token_ids[-1]
def block(self, i, block_size=256):
return self.token_ids[i*block_size: (i+1)*block_size]
def last_block(self, block_size=256):
n = self.num_blocks
t = len(self) + block_size - self.num_blocks * block_size
x = self.token_ids[(n-1)*block_size:]
assert len(x) == t
return x
def append_token(self, token_id: int):
self.token_ids.append(token_id)

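A tiny walkthrough (assumes the nanovllm package from this commit is importable; the 600-token input is arbitrary) of how Sequence slices its token ids into 256-token blocks:

from nanovllm.engine.sequence import Sequence
from nanovllm.sampling_params import SamplingParams

seq = Sequence(list(range(600)), SamplingParams(temperature=1.0, max_tokens=16))
print(len(seq), seq.num_blocks)              # 600 3: two full blocks plus one partial
print(len(seq.block(0)), len(seq.block(1)))  # 256 256
print(len(seq.last_block()))                 # 88 tokens in the trailing partial block
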
14
nanovllm/layers/activation.py Executable file

@@ -0,0 +1,14 @@
import torch
from torch import nn
import torch.nn.functional as F
class SiluAndMul(nn.Module):
def __init__(self):
super().__init__()
@torch.compile
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, y = x.chunk(2, -1)
return F.silu(x) * y

86
nanovllm/layers/attention.py

@@ -0,0 +1,86 @@
import torch
from torch import nn
import triton
import triton.language as tl
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
@triton.jit
def store_kvcache_kernel(
key_ptr,
key_stride,
value_ptr,
value_stride,
k_cache_ptr,
v_cache_ptr,
slot_mapping_ptr,
D: tl.constexpr,
):
idx = tl.program_id(0)
key_offsets = idx * key_stride + tl.arange(0, D)
value_offsets = idx * value_stride + tl.arange(0, D)
key = tl.load(key_ptr + key_offsets)
value = tl.load(value_ptr + value_offsets)
slot = tl.load(slot_mapping_ptr + idx)
cache_offsets = slot * D + tl.arange(0, D)
tl.store(k_cache_ptr + cache_offsets, key)
tl.store(v_cache_ptr + cache_offsets, value)
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1
assert key.stride(1) == head_dim and value.stride(1) == head_dim
assert k_cache.stride(1) == D and v_cache.stride(1) == D
assert slot_mapping.numel() == N
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
class Attention(nn.Module):
def __init__(
self,
num_heads,
head_dim,
scale,
num_kv_heads,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
o: torch.Tensor
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
context = get_context()
k_cache = self.k_cache
v_cache = self.v_cache
if context.is_prefill:
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.block_tables is None: # normal prefill
cu_seqlens_k = context.cu_seqlens_k
seqused_k = None
else: # prefix cache
cu_seqlens_k = None
seqused_k = context.context_lens
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
seqused_k=seqused_k, softmax_scale=self.scale,
causal=True, block_table=context.block_tables)
else: # decode
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1),
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
o = o.view(-1, self.num_heads * self.head_dim)
return o

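A small sketch (made-up sizes, plain PyTorch instead of Triton) of the paged-KV layout that store_kvcache writes into: a slot index is block_id * block_size + offset_within_block, and the cache behaves like one flat array of slots, each holding num_kv_heads * head_dim values. The slot_mapping here is hand-written; in the engine it comes from ModelRunner.prepare_prefill / prepare_decode.

import torch

block_size, num_blocks, num_kv_heads, head_dim = 4, 3, 2, 8
k_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)
key = torch.randn(5, num_kv_heads, head_dim)      # 5 new tokens to store
slot_mapping = torch.tensor([8, 9, 10, 11, 4])    # block 2 offsets 0-3, then block 1 offset 0

# Equivalent of the Triton kernel above: scatter each token's K into its slot.
k_cache.view(-1, num_kv_heads, head_dim)[slot_mapping] = key
print(torch.equal(k_cache[2, 0], key[0]), torch.equal(k_cache[1, 0], key[4]))  # True True
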
72
nanovllm/layers/embed_head.py

@@ -0,0 +1,72 @@
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from nanovllm.utils.context import get_context
class VocabParallelEmbedding(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
):
super().__init__()
self.tp_rank = 0 # get_tensor_model_parallel_rank()
self.tp_size = 1 # get_tensor_model_parallel_world_size()
assert num_embeddings % self.tp_size == 0
self.num_embeddings = num_embeddings
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
self.embedding_dim = embedding_dim
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
self.weight.weight_loader = self.weight_loader
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(0)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor):
if self.tp_size > 1:
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
x = mask * (x - self.vocab_start_idx)
y = F.embedding(x, self.weight)
if self.tp_size > 1:
y = mask * y
dist.all_reduce(y)
return y
class ParallelLMHead(VocabParallelEmbedding):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
):
super().__init__(num_embeddings, embedding_dim)
if bias:
self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor):
context = get_context()
if context.is_prefill:
last_indices = context.cu_seqlens_q[1:] - 1
x = x[last_indices].contiguous()
logits = F.linear(x, self.weight, self.bias)
# if self.tp_size > 1:
# all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)]
# dist.gather(logits, all_logits, 0)
# logits = torch.cat(all_logits, -1)
return logits if self.tp_rank == 0 else None

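The tp_size > 1 branch above is never exercised in this commit (tp_size is hard-coded to 1), but the masking trick can be checked on CPU with toy sizes (made up below): each rank embeds only the token ids that fall inside its vocabulary slice, zeroes out the rest, and summing the per-rank partials (the job the all_reduce would do) reproduces the full embedding. Note the mask has to be broadcast over the embedding dimension when zeroing rows.

import torch
import torch.nn.functional as F

vocab, dim, tp_size = 8, 4, 2
full_weight = torch.randn(vocab, dim)
x = torch.tensor([1, 5, 6])

partials = []
for rank in range(tp_size):
    lo = rank * vocab // tp_size
    hi = (rank + 1) * vocab // tp_size
    shard = full_weight[lo:hi]                   # this rank's slice of the table
    mask = (x >= lo) & (x < hi)
    local = F.embedding(mask * (x - lo), shard)  # ids owned by other ranks are clamped to 0
    partials.append(mask.unsqueeze(-1) * local)  # and their rows zeroed here

print(torch.allclose(sum(partials), F.embedding(x, full_weight)))  # True
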
51
nanovllm/layers/layernorm.py Executable file

@@ -0,0 +1,51 @@
import torch
from torch import nn
class RMSNorm(nn.Module):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
@torch.compile
def rms_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.to(torch.float32)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x
@torch.compile
def add_rms_forward(
self,
x: torch.Tensor,
residual: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype
x = x.to(torch.float32).add_(residual.to(torch.float32))
residual = x.to(orig_dtype)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
return x, residual
def forward(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is None:
return self.rms_forward(x)
else:
return self.add_rms_forward(x, residual)

199
nanovllm/layers/linear.py Executable file

@@ -0,0 +1,199 @@
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
def divide(numerator, denominator):
assert numerator % denominator == 0
return numerator // denominator
class LinearBase(nn.Module):
def __init__(
self,
input_size: int,
output_size: int,
tp_dim: int | None = None,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.tp_dim = tp_dim
self.tp_rank = 0 # get_tensor_model_parallel_rank()
self.tp_size = 1 # get_tensor_model_parallel_world_size()
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size)
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
class ColumnParallelLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size, 0)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(self.tp_dim)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
class MergedColumnParallelLinear(ColumnParallelLinear):
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = False,
):
self.output_sizes = output_sizes
tp_size = 1 # get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias=bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None):
param_data = param.data
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
# loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear):
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: int | None = None,
bias: bool = False,
):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = 1 # get_tensor_model_parallel_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
input_size = self.hidden_size
output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
super().__init__(input_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None):
param_data = param.data
assert loaded_shard_id in ["q", "k", "v"]
if loaded_shard_id == "q":
shard_size = self.num_heads * self.head_size
shard_offset = 0
elif loaded_shard_id == "k":
shard_size = self.num_kv_heads * self.head_size
shard_offset = self.num_heads * self.head_size
else:
shard_size = self.num_kv_heads * self.head_size
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
# loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size, 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
shard_size = param_data.size(self.tp_dim)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
if self.tp_size > 1:
dist.all_reduce(y)
return y

73
nanovllm/layers/rotary_embedding.py

@@ -0,0 +1,73 @@
from functools import lru_cache
import torch
from torch import nn
def apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
y1 = x1 * cos - x2 * sin
y2 = x2 * cos + x1 * sin
return torch.cat((y1, y2), dim=-1).to(x.dtype)
class RotaryEmbedding(nn.Module):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
assert rotary_dim == head_size
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
self.register_buffer("cos_sin_cache", cache, persistent=False)
@torch.compile
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query = apply_rotary_emb(query, cos, sin).view(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key = apply_rotary_emb(key, cos, sin).view(key_shape)
return query, key
@lru_cache(1)
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
rope_scaling: dict | None = None,
):
assert rope_scaling is None
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
return rotary_emb

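As a sanity check on apply_rotary_emb (arbitrary values below): it pairs feature i with feature i + head_size/2 and rotates each pair by position * base**(-2i/head_size), which is the same as multiplying the pair, viewed as a complex number, by e^(i*angle).

import torch

head_size, base, pos = 8, 10000.0, 5
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2, dtype=torch.float) / head_size))
angles = pos * inv_freq                          # one angle per (x1, x2) feature pair

x = torch.randn(head_size)
x1, x2 = x.chunk(2)                              # same split-in-half pairing as above
rotated = torch.cat((x1 * angles.cos() - x2 * angles.sin(),
                     x2 * angles.cos() + x1 * angles.sin()))

# The same rotation written as complex multiplication.
z = torch.complex(x1, x2) * torch.exp(torch.complex(torch.zeros_like(angles), angles))
print(torch.allclose(rotated, torch.cat((z.real, z.imag)), atol=1e-5))  # True
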
17
nanovllm/layers/sampler.py

@@ -0,0 +1,17 @@
import torch
from torch import nn
class Sampler(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
logits = logits.to(torch.float)
if temperatures is not None:
logits.div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
return sampled_tokens

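The sampling line above is the exponential-race form of the Gumbel-max trick: dividing the softmax probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability probs[i]. A quick empirical check with made-up probabilities:

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
n = 200_000
noise = torch.empty(n, probs.numel()).exponential_(1)
samples = (probs / noise).argmax(dim=-1)
print(torch.bincount(samples, minlength=4).float() / n)  # roughly tensor([0.10, 0.20, 0.30, 0.40])
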
5
nanovllm/llm.py Normal file

@@ -0,0 +1,5 @@
from nanovllm.engine.llm_engine import LLMEngine
class LLM(LLMEngine):
pass

247
nanovllm/models/qwen3.py Executable file

@@ -0,0 +1,247 @@
import torch
from torch import nn
from transformers import Qwen3Config
from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
class Qwen3Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: int | None = None,
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float = 10000,
rope_scaling: tuple | None = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = 1 # get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q_by_head = q.view(-1, self.num_heads, self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(-1, self.num_kv_heads, self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
o = self.attn(q, k, v)
output = self.o_proj(o)
return output
class Qwen3MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
)
assert hidden_act == "silu"
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.down_proj(x)
return x
class Qwen3DecoderLayer(nn.Module):
def __init__(
self,
config: Qwen3Config,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen3Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
max_position=config.max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, 'attention_bias', False),
head_dim=getattr(config, 'head_dim', None),
rope_theta=getattr(config, "rope_theta", 1000000),
rope_scaling=getattr(config, "rope_scaling", None),
)
self.mlp = Qwen3MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class Qwen3Model(nn.Module):
def __init__(
self,
config: Qwen3Config,
):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Qwen3ForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(
self,
config: Qwen3Config
):
super().__init__()
self.model = Qwen3Model(config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if config.tie_word_embeddings:
self.lm_head.weight.data = self.model.embed_tokens.weight.data
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
logits = self.lm_head(hidden_states)
return logits
def load_weights(self, path: str):
import os
from safetensors import safe_open
default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
for n, p in self.named_parameters():
for x in self.packed_modules_mapping:
if x in n:
weight_loader = getattr(p, "weight_loader", default_weight_loader)
for i, y in enumerate(self.packed_modules_mapping[x]):
weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
break
else:
weight_loader = getattr(p, "weight_loader", default_weight_loader)
weight_loader(p, f.get_tensor(n))

8
nanovllm/sampling_params.py

@@ -0,0 +1,8 @@
from dataclasses import dataclass
@dataclass
class SamplingParams:
temperature: float = 1.0
max_tokens: int = 64
ignore_eos: bool = False

28
nanovllm/utils/context.py Normal file

@@ -0,0 +1,28 @@
from contextlib import contextmanager
from dataclasses import dataclass
import torch
@dataclass
class Context:
is_prefill: bool = False
cu_seqlens_q: torch.Tensor | None = None
cu_seqlens_k: torch.Tensor | None = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
slot_mapping: torch.Tensor | None = None
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
_CONTEXT = Context()
def get_context():
return _CONTEXT
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
global _CONTEXT
_CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
def reset_context():
global _CONTEXT
_CONTEXT = Context()

14
nanovllm/utils/memory.py Normal file

@@ -0,0 +1,14 @@
import os
import subprocess
import torch
def get_gpu_memory(device_id: int = 0):
torch.cuda.synchronize()
result = subprocess.check_output(
['nvidia-smi', '-i', str(device_id), '--query-gpu=memory.total,memory.used,memory.free', '--format=csv,nounits,noheader'],
encoding='utf-8'
)
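    # nvidia-smi with --format=csv,nounits reports memory.total/used/free in MiB.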
total_memory, used_memory, free_memory = [int(x) for x in result.strip().split(', ')]
return total_memory, used_memory, free_memory

31
nanovllm/utils/timer.py Normal file

@@ -0,0 +1,31 @@
from contextlib import contextmanager
from collections import defaultdict
import torch
class CUDATimer:
def __init__(self):
self.events = defaultdict(list)
@contextmanager
def record(self, name, enabled=True):
if not enabled:
yield
else:
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
self.events[name].append((start, end))
start.record()
yield
end.record()
def log(self):
torch.cuda.synchronize()
ret = []
for name, events in self.events.items():
total = 0
            count = len(events)
for start, end in events:
total += start.elapsed_time(end)
ret.append(f"{name} {total:.2f}ms/{count}times")
return ", ".join(ret)

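A hypothetical usage sketch for CUDATimer (assumes the nanovllm package from this commit is importable and a CUDA device is available, as the rest of the repo already requires):

import torch
from nanovllm.utils.timer import CUDATimer

timer = CUDATimer()
x = torch.randn(1024, 1024, device="cuda")
for _ in range(10):
    with timer.record("matmul"):
        torch.matmul(x, x)
print(timer.log())   # e.g. "matmul 1.23ms/10times"
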
5
requirements.txt Normal file

@@ -0,0 +1,5 @@
torch
triton
transformers
cmake
ninja