init commit
This commit is contained in:
195
.gitignore
vendored
Normal file
195
.gitignore
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
#poetry.toml
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Abstra
|
||||
# Abstra is an AI-powered process automation framework.
|
||||
# Ignore directories containing user credentials, local state, and settings.
|
||||
# Learn more at https://abstra.io/docs
|
||||
.abstra/
|
||||
|
||||
# Visual Studio Code
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# you could uncomment the following to ignore the entire vscode folder
|
||||
# .vscode/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Cursor
|
||||
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
||||
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
||||
# refer to https://docs.cursor.com/context/ignore-files
|
||||
.cursorignore
|
||||
.cursorindexingignore
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Xingkai Yu
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
20
bench.py
Normal file
20
bench.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from nanovllm import LLM, SamplingParams
|
||||
|
||||
|
||||
batch_size = 256
|
||||
seq_len = 1024
|
||||
max_tokens = 512
|
||||
|
||||
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
|
||||
llm = LLM(path, enforce_eager=False)
|
||||
|
||||
prompt_token_ids = torch.randint(0, 10240, (batch_size, seq_len)).tolist()
|
||||
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=max_tokens)
|
||||
|
||||
t = time.time()
|
||||
completions = llm.generate(prompt_token_ids, sampling_params)
|
||||
troughput = batch_size * max_tokens / (time.time() - t)
|
||||
print(f"Throughput: {troughput: .2f}")
|
||||
29
example.py
Normal file
29
example.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
|
||||
tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
llm = LLM(path, enforce_eager=True)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
|
||||
prompts = [
|
||||
"自我介绍一下吧!",
|
||||
"列出100内所有素数",
|
||||
]
|
||||
prompts = [
|
||||
tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
completions = llm.generate(prompts, sampling_params)
|
||||
|
||||
for p, c in zip(prompts, completions):
|
||||
print("\n\n")
|
||||
print(f"Prompt: {p}")
|
||||
print(f"Completion: {c}")
|
||||
2
nanovllm/__init__.py
Normal file
2
nanovllm/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from nanovllm.llm import LLM
|
||||
from nanovllm.sampling_params import SamplingParams
|
||||
20
nanovllm/config.py
Normal file
20
nanovllm/config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
model: str = ''
|
||||
max_num_batched_tokens: int = 16384
|
||||
max_num_seqs: int = 512
|
||||
max_model_len: int = 4096
|
||||
gpu_memory_utilization: float = 0.95
|
||||
enforce_eager: bool = False
|
||||
hf_config: AutoConfig | None = None
|
||||
eos: int = -1
|
||||
kvcache_block_size: int = 256
|
||||
num_kvcache_blocks: int = -1
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.model
|
||||
assert self.kvcache_block_size == 256
|
||||
118
nanovllm/engine/block_manager.py
Normal file
118
nanovllm/engine/block_manager.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from collections import deque
|
||||
import xxhash
|
||||
import numpy as np
|
||||
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
|
||||
|
||||
def compute_hash(token_ids: list[int], prefix: int = -1):
|
||||
h = xxhash.xxh64()
|
||||
if prefix != -1:
|
||||
h.update(prefix.to_bytes(8))
|
||||
h.update(np.array(token_ids).tobytes())
|
||||
return h.intdigest()
|
||||
|
||||
|
||||
class Block:
|
||||
|
||||
def __init__(self, block_id):
|
||||
self.block_id = block_id
|
||||
self.ref_count = 0
|
||||
self.hash = -1
|
||||
self.token_ids = []
|
||||
|
||||
def update(self, hash: int, token_ids: list[int]):
|
||||
assert hash != -1
|
||||
assert len(token_ids) == 256
|
||||
self.hash = hash
|
||||
self.token_ids = token_ids
|
||||
|
||||
def reset(self):
|
||||
self.ref_count = 1
|
||||
self.hash = -1
|
||||
self.token_ids = []
|
||||
|
||||
def __repr__(self):
|
||||
return f"{(self.block_id, self.ref_count, self.hash)}"
|
||||
|
||||
|
||||
class BlockManager:
|
||||
|
||||
def __init__(self, num_blocks: int, block_size: int = 256):
|
||||
assert block_size == 256
|
||||
self.block_size = block_size
|
||||
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
||||
self.hash_to_block_id: dict[int, int] = dict()
|
||||
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
||||
self.used_block_ids: set[int] = set()
|
||||
|
||||
def _allocate_block(self, block_id: int):
|
||||
block = self.blocks[block_id]
|
||||
assert block.ref_count == 0
|
||||
block.reset()
|
||||
self.free_block_ids.remove(block_id)
|
||||
self.used_block_ids.add(block_id)
|
||||
return self.blocks[block_id]
|
||||
|
||||
def _deallocate_block(self, block_id: int):
|
||||
assert self.blocks[block_id].ref_count == 0
|
||||
self.used_block_ids.remove(block_id)
|
||||
self.free_block_ids.append(block_id)
|
||||
|
||||
def can_allocate(self, seq: Sequence):
|
||||
return seq.num_blocks <= len(self.free_block_ids)
|
||||
|
||||
def allocate(self, seq: Sequence):
|
||||
assert not seq.block_table
|
||||
h = -1
|
||||
cache_miss = False
|
||||
for i in range(seq.num_blocks):
|
||||
token_ids = seq.block(i, self.block_size)
|
||||
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
||||
block_id = self.hash_to_block_id.get(h, -1)
|
||||
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
||||
cache_miss = True
|
||||
if cache_miss:
|
||||
block_id = self.free_block_ids[0]
|
||||
block = self._allocate_block(block_id)
|
||||
else:
|
||||
seq.num_cached_tokens += self.block_size
|
||||
if block_id in self.used_block_ids:
|
||||
block = self.blocks[block_id]
|
||||
block.ref_count += 1
|
||||
else:
|
||||
block = self._allocate_block(block_id)
|
||||
if h != -1:
|
||||
block.update(h, token_ids)
|
||||
self.hash_to_block_id[h] = block_id
|
||||
seq.block_table.append(block_id)
|
||||
|
||||
def deallocate(self, seq: Sequence):
|
||||
for block_id in seq.block_table:
|
||||
block = self.blocks[block_id]
|
||||
block.ref_count -= 1
|
||||
if block.ref_count == 0:
|
||||
self._deallocate_block(block_id)
|
||||
seq.num_cached_tokens = 0
|
||||
seq.block_table.clear()
|
||||
|
||||
def can_append(self):
|
||||
return len(self.free_block_ids) >= 1
|
||||
|
||||
def may_append(self, seq: Sequence):
|
||||
block_table = seq.block_table
|
||||
last_block = self.blocks[block_table[-1]]
|
||||
if len(seq) % self.block_size == 1:
|
||||
assert last_block.hash != -1
|
||||
block_id = self.free_block_ids[0]
|
||||
self._allocate_block(block_id)
|
||||
block_table.append(block_id)
|
||||
elif len(seq) % self.block_size == 0:
|
||||
assert last_block.hash == -1
|
||||
token_ids = seq.last_block(self.block_size)
|
||||
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
||||
h = compute_hash(token_ids, prefix)
|
||||
last_block.update(h, token_ids)
|
||||
self.hash_to_block_id[h] = last_block.block_id
|
||||
else:
|
||||
assert last_block.hash == -1
|
||||
66
nanovllm/engine/llm_engine.py
Normal file
66
nanovllm/engine/llm_engine.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from collections import defaultdict
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from nanovllm.config import Config
|
||||
from nanovllm.sampling_params import SamplingParams
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.engine.scheduler import Scheduler
|
||||
from nanovllm.engine.model_runner import ModelRunner
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
|
||||
def __init__(self, model, **kwargs):
|
||||
config = Config(model)
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(config, k):
|
||||
setattr(config, k, v)
|
||||
config.hf_config = AutoConfig.from_pretrained(config.model)
|
||||
config.max_model_len = min(config.max_model_len, config.hf_config.max_position_embeddings)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
||||
config.eos = self.tokenizer.eos_token_id
|
||||
self.model_runner = ModelRunner(config)
|
||||
self.scheduler = Scheduler(config)
|
||||
|
||||
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
|
||||
if isinstance(prompt, str):
|
||||
prompt = self.tokenizer.encode(prompt)
|
||||
seq = Sequence(prompt, sampling_params)
|
||||
self.scheduler.add(seq)
|
||||
|
||||
def step(self):
|
||||
seqs, is_prefill = self.scheduler.schedule()
|
||||
token_ids = self.model_runner.run(seqs, is_prefill)
|
||||
finished = self.scheduler.postprocess(seqs, token_ids)
|
||||
return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)]
|
||||
|
||||
def is_finished(self):
|
||||
return self.scheduler.is_finished()
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[str] | list[list[int]],
|
||||
sampling_params: SamplingParams | list[SamplingParams],
|
||||
use_tqdm: bool = True,
|
||||
) -> list[str]:
|
||||
if use_tqdm:
|
||||
pbar = tqdm(total=len(prompts),
|
||||
desc="Processed prompts",
|
||||
)
|
||||
if not isinstance(SamplingParams, list):
|
||||
sampling_params = [sampling_params] * len(prompts)
|
||||
for prompt, sp in zip(prompts, sampling_params):
|
||||
self.add_request(prompt, sp)
|
||||
outputs = defaultdict(list)
|
||||
while not self.is_finished():
|
||||
output = self.step()
|
||||
for seq_id, token_id, finish in output:
|
||||
outputs[seq_id].append(token_id)
|
||||
if use_tqdm and finish:
|
||||
pbar.update(1)
|
||||
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
|
||||
outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
return outputs
|
||||
198
nanovllm/engine/model_runner.py
Normal file
198
nanovllm/engine/model_runner.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import torch
|
||||
|
||||
from nanovllm.config import Config
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.utils.context import set_context, get_context, reset_context
|
||||
from nanovllm.utils.memory import get_gpu_memory
|
||||
from nanovllm.models.qwen3 import Qwen3ForCausalLM
|
||||
from nanovllm.layers.sampler import Sampler
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
hf_config = config.hf_config
|
||||
self.block_size = config.kvcache_block_size
|
||||
self.enforce_eager = config.enforce_eager
|
||||
|
||||
default_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(hf_config.torch_dtype)
|
||||
torch.set_default_device("cuda")
|
||||
self.model = Qwen3ForCausalLM(hf_config)
|
||||
self.model.load_weights(config.model)
|
||||
self.sampler = Sampler()
|
||||
self.allocate_kv_cache(config.gpu_memory_utilization)
|
||||
if not self.enforce_eager:
|
||||
self.capture_model()
|
||||
torch.set_default_device("cpu")
|
||||
torch.set_default_dtype(default_dtype)
|
||||
|
||||
def allocate_kv_cache(self, gpu_memory_utilization):
|
||||
config = self.config
|
||||
hf_config = config.hf_config
|
||||
total, used, _ = get_gpu_memory()
|
||||
free = total * gpu_memory_utilization - used
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
|
||||
config.num_kvcache_blocks = int(free * 1e6) // block_bytes
|
||||
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim)
|
||||
layer_id = 0
|
||||
for module in self.model.modules():
|
||||
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
|
||||
module.k_cache = self.kv_cache[0, layer_id]
|
||||
module.v_cache = self.kv_cache[1, layer_id]
|
||||
layer_id += 1
|
||||
|
||||
def preare_block_tables(self, seqs: list[Sequence]):
|
||||
max_len = max(len(seq.block_table) for seq in seqs)
|
||||
block_tables = [
|
||||
seq.block_table + [-1] * (max_len - len(seq.block_table))
|
||||
for seq in seqs
|
||||
]
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
return block_tables
|
||||
|
||||
def prepare_prefill(self, seqs: list[Sequence]):
|
||||
input_ids = []
|
||||
positions = []
|
||||
cu_seqlens_q = [0]
|
||||
cu_seqlens_k = [0]
|
||||
max_seqlen_q = 0
|
||||
max_seqlen_k = 0
|
||||
slot_mapping = []
|
||||
context_lens = None
|
||||
block_tables = None
|
||||
for seq in seqs:
|
||||
seqlen = len(seq)
|
||||
input_ids.extend(seq[seq.num_cached_tokens:])
|
||||
positions.extend(list(range(seq.num_cached_tokens, len(seq))))
|
||||
seqlen_q = seqlen - seq.num_cached_tokens
|
||||
seqlen_k = seqlen
|
||||
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
||||
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
||||
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
||||
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
||||
for i in range(seq.num_cached_blocks, seq.num_blocks):
|
||||
start = seq.block_table[i] * self.block_size
|
||||
if i != seq.num_blocks - 1:
|
||||
end = start + self.block_size
|
||||
else:
|
||||
end = start + len(seq.last_block())
|
||||
slot_mapping.extend(list(range(start, end)))
|
||||
assert len(input_ids) == len(slot_mapping)
|
||||
assert len(input_ids) == cu_seqlens_q[-1]
|
||||
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
||||
context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
block_tables = self.preare_block_tables(seqs)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
input_ids = []
|
||||
positions = []
|
||||
slot_mapping = []
|
||||
context_lens = []
|
||||
for seq in seqs:
|
||||
input_ids.append(seq.last_token)
|
||||
positions.append(len(seq))
|
||||
context_lens.append(len(seq))
|
||||
slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()))
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
block_tables = self.preare_block_tables(seqs)
|
||||
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_sample(self, seqs: list[Sequence]):
|
||||
temperatures = []
|
||||
for seq in seqs:
|
||||
temperatures.append(seq.temperature)
|
||||
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
|
||||
return temperatures
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
|
||||
if is_prefill or self.enforce_eager or input_ids.size(0) > 256:
|
||||
return self.model.compute_logits(self.model(input_ids, positions))
|
||||
else:
|
||||
bs = input_ids.size(0)
|
||||
context = get_context()
|
||||
self.reset_graph_vars()
|
||||
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
||||
graph_vars = self.graph_vars
|
||||
graph_vars["input_ids"][:bs] = input_ids
|
||||
graph_vars["positions"][:bs] = positions
|
||||
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
||||
graph_vars["context_lens"][:bs] = context.context_lens
|
||||
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
|
||||
graph.replay()
|
||||
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
||||
|
||||
def reset_graph_vars(self):
|
||||
graph_vars = self.graph_vars
|
||||
graph_vars["input_ids"].zero_()
|
||||
graph_vars["positions"].zero_()
|
||||
graph_vars["slot_mapping"].zero_()
|
||||
graph_vars["context_lens"].zero_()
|
||||
graph_vars["block_tables"].zero_()
|
||||
|
||||
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||
temperatures = self.prepare_sample(seqs)
|
||||
logits = self.run_model(input_ids, positions, is_prefill)
|
||||
token_ids = self.sampler(logits, temperatures).tolist()
|
||||
reset_context()
|
||||
return token_ids
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self):
|
||||
get_rng_state = torch.cuda.get_rng_state
|
||||
set_rng_state = torch.cuda.set_rng_state
|
||||
rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.get_rng_state = lambda: rng_state
|
||||
torch.cuda.set_rng_state = lambda _: None
|
||||
|
||||
config = self.config
|
||||
hf_config = config.hf_config
|
||||
max_bs = min(self.config.max_num_seqs, 256)
|
||||
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
|
||||
input_ids = torch.zeros(max_bs, dtype=torch.int64)
|
||||
positions = torch.zeros(max_bs, dtype=torch.int64)
|
||||
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
|
||||
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
||||
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
||||
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
||||
self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32))
|
||||
self.graphs = {}
|
||||
self.graph_pool = None
|
||||
|
||||
for bs in reversed(self.graph_bs):
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
|
||||
with torch.cuda.graph(graph, self.graph_pool):
|
||||
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
|
||||
if self.graph_pool is None:
|
||||
self.graph_pool = graph.pool()
|
||||
self.graphs[bs] = graph
|
||||
torch.cuda.synchronize()
|
||||
reset_context()
|
||||
|
||||
self.graph_vars = dict(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
torch.cuda.get_rng_state = get_rng_state
|
||||
torch.cuda.set_rng_state = set_rng_state
|
||||
84
nanovllm/engine/scheduler.py
Normal file
84
nanovllm/engine/scheduler.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from collections import deque
|
||||
|
||||
from nanovllm.config import Config
|
||||
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
||||
from nanovllm.engine.block_manager import BlockManager
|
||||
|
||||
|
||||
class Scheduler:
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.max_num_seqs = config.max_num_seqs
|
||||
self.max_num_batched_tokens = config.max_num_batched_tokens
|
||||
self.eos = config.eos
|
||||
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
|
||||
self.waiting: deque[Sequence] = deque()
|
||||
self.running: deque[Sequence] = deque()
|
||||
self.num_finished = 0
|
||||
self.num_tokens = 0
|
||||
|
||||
def is_finished(self):
|
||||
return not self.waiting and not self.running
|
||||
|
||||
def add(self, seq: Sequence):
|
||||
self.waiting.append(seq)
|
||||
|
||||
def schedule(self) -> tuple[list[Sequence], SequenceStatus]:
|
||||
# prefill
|
||||
scheduled_seqs = []
|
||||
num_seqs = 0
|
||||
num_batched_tokens = 0
|
||||
while self.waiting and num_seqs < self.max_num_seqs:
|
||||
seq = self.waiting[0]
|
||||
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
||||
break
|
||||
num_seqs += 1
|
||||
self.block_manager.allocate(seq)
|
||||
num_batched_tokens += len(seq) - seq.num_cached_tokens
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
self.waiting.popleft()
|
||||
self.running.append(seq)
|
||||
scheduled_seqs.append(seq)
|
||||
if scheduled_seqs:
|
||||
return scheduled_seqs, True
|
||||
|
||||
# decode
|
||||
# self.running = deque(sorted(self.running))
|
||||
while self.running and num_seqs < self.max_num_seqs:
|
||||
seq = self.running.popleft()
|
||||
while not self.block_manager.can_append():
|
||||
if self.running:
|
||||
self.preempt(self.running.pop())
|
||||
else:
|
||||
self.preempt(seq)
|
||||
break
|
||||
else:
|
||||
num_seqs += 1
|
||||
self.block_manager.may_append(seq)
|
||||
scheduled_seqs.append(seq)
|
||||
running = deque(scheduled_seqs)
|
||||
running.extend(self.running)
|
||||
self.running = running
|
||||
if scheduled_seqs:
|
||||
return scheduled_seqs, False
|
||||
|
||||
def preempt(self, seq: Sequence):
|
||||
seq.status = SequenceStatus.WAITING
|
||||
self.block_manager.deallocate(seq)
|
||||
self.waiting.appendleft(seq)
|
||||
return True
|
||||
|
||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
||||
self.num_tokens += len(token_ids)
|
||||
finished = []
|
||||
for seq, token_id in zip(seqs, token_ids):
|
||||
seq.append_token(token_id)
|
||||
if token_id == self.eos or seq.num_completion_tokens == seq.max_tokens:
|
||||
seq.status = SequenceStatus.FINISHED
|
||||
self.block_manager.deallocate(seq)
|
||||
self.running.remove(seq)
|
||||
self.num_finished += 1
|
||||
finished.append(True)
|
||||
else:
|
||||
finished.append(False)
|
||||
return finished
|
||||
73
nanovllm/engine/sequence.py
Normal file
73
nanovllm/engine/sequence.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from copy import copy
|
||||
from enum import Enum, auto
|
||||
from itertools import count
|
||||
|
||||
from nanovllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
class SequenceStatus(Enum):
|
||||
WAITING = auto()
|
||||
RUNNING = auto()
|
||||
FINISHED = auto()
|
||||
|
||||
|
||||
class Sequence:
|
||||
counter = count()
|
||||
|
||||
def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
|
||||
self.seq_id = next(Sequence.counter)
|
||||
self.status = SequenceStatus.WAITING
|
||||
self.token_ids = copy(token_ids)
|
||||
self.num_prompt_tokens = len(token_ids)
|
||||
self._num_cached_tokens = 0
|
||||
self.block_table = []
|
||||
self.temperature = sampling_params.temperature
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
self.ignore_eos = sampling_params.ignore_eos
|
||||
|
||||
def __len__(self):
|
||||
return len(self.token_ids)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.seq_id < other.seq_id
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.token_ids[key]
|
||||
|
||||
@property
|
||||
def num_completion_tokens(self):
|
||||
return len(self.token_ids) - self.num_prompt_tokens
|
||||
|
||||
@property
|
||||
def num_cached_tokens(self):
|
||||
return self._num_cached_tokens
|
||||
|
||||
@num_cached_tokens.setter
|
||||
def num_cached_tokens(self, num_cached_tokens):
|
||||
assert num_cached_tokens % 256 == 0
|
||||
self._num_cached_tokens = num_cached_tokens
|
||||
|
||||
@property
|
||||
def num_cached_blocks(self):
|
||||
return self.num_cached_tokens // 256
|
||||
|
||||
@property
|
||||
def num_blocks(self):
|
||||
return (len(self.token_ids) + 255) // 256
|
||||
|
||||
@property
|
||||
def last_token(self):
|
||||
return self.token_ids[-1]
|
||||
|
||||
def block(self, i, block_size=256):
|
||||
return self.token_ids[i*block_size: (i+1)*block_size]
|
||||
|
||||
def last_block(self, block_size=256):
|
||||
n = self.num_blocks
|
||||
t = len(self) + block_size - self.num_blocks * block_size
|
||||
x = self.token_ids[(n-1)*block_size:]
|
||||
assert len(x) == t
|
||||
return x
|
||||
|
||||
def append_token(self, token_id: int):
|
||||
self.token_ids.append(token_id)
|
||||
14
nanovllm/layers/activation.py
Executable file
14
nanovllm/layers/activation.py
Executable file
@@ -0,0 +1,14 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@torch.compile
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x, y = x.chunk(2, -1)
|
||||
return F.silu(x) * y
|
||||
86
nanovllm/layers/attention.py
Normal file
86
nanovllm/layers/attention.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
from nanovllm.utils.context import get_context
|
||||
|
||||
|
||||
@triton.jit
|
||||
def store_kvcache_kernel(
|
||||
key_ptr,
|
||||
key_stride,
|
||||
value_ptr,
|
||||
value_stride,
|
||||
k_cache_ptr,
|
||||
v_cache_ptr,
|
||||
slot_mapping_ptr,
|
||||
D: tl.constexpr,
|
||||
):
|
||||
idx = tl.program_id(0)
|
||||
key_offsets = idx * key_stride + tl.arange(0, D)
|
||||
value_offsets = idx * value_stride + tl.arange(0, D)
|
||||
key = tl.load(key_ptr + key_offsets)
|
||||
value = tl.load(value_ptr + value_offsets)
|
||||
slot = tl.load(slot_mapping_ptr + idx)
|
||||
cache_offsets = slot * D + tl.arange(0, D)
|
||||
tl.store(k_cache_ptr + cache_offsets, key)
|
||||
tl.store(v_cache_ptr + cache_offsets, value)
|
||||
|
||||
|
||||
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
|
||||
N, num_heads, head_dim = key.shape
|
||||
D = num_heads * head_dim
|
||||
assert key.stride(-1) == 1 and value.stride(-1) == 1
|
||||
assert key.stride(1) == head_dim and value.stride(1) == head_dim
|
||||
assert k_cache.stride(1) == D and v_cache.stride(1) == D
|
||||
assert slot_mapping.numel() == N
|
||||
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
head_dim,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.scale = scale
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.k_cache = self.v_cache = torch.tensor([])
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
o: torch.Tensor
|
||||
q = q.view(-1, self.num_heads, self.head_dim)
|
||||
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||
context = get_context()
|
||||
k_cache = self.k_cache
|
||||
v_cache = self.v_cache
|
||||
if context.is_prefill:
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
if context.block_tables is None: # normal prefill
|
||||
cu_seqlens_k = context.cu_seqlens_k
|
||||
seqused_k = None
|
||||
else: # prefix cache
|
||||
cu_seqlens_k = None
|
||||
seqused_k = context.context_lens
|
||||
k, v = k_cache, v_cache
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
|
||||
seqused_k=seqused_k, softmax_scale=self.scale,
|
||||
causal=True, block_table=context.block_tables)
|
||||
else: # decode
|
||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1),
|
||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||
softmax_scale=self.scale, causal=True)
|
||||
o = o.view(-1, self.num_heads * self.head_dim)
|
||||
return o
|
||||
72
nanovllm/layers/embed_head.py
Normal file
72
nanovllm/layers/embed_head.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanovllm.utils.context import get_context
|
||||
|
||||
|
||||
class VocabParallelEmbedding(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_rank = 0 # get_tensor_model_parallel_rank()
|
||||
self.tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||
assert num_embeddings % self.tp_size == 0
|
||||
self.num_embeddings = num_embeddings
|
||||
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
|
||||
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
|
||||
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
|
||||
self.embedding_dim = embedding_dim
|
||||
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
param_data = param.data
|
||||
shard_size = param_data.size(0)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||
assert param_data.size() == loaded_weight.size()
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.tp_size > 1:
|
||||
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
|
||||
x = mask * (x - self.vocab_start_idx)
|
||||
y = F.embedding(x, self.weight)
|
||||
if self.tp_size > 1:
|
||||
y = mask * y
|
||||
dist.all_reduce(y)
|
||||
return y
|
||||
|
||||
|
||||
class ParallelLMHead(VocabParallelEmbedding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
|
||||
self.bias.weight_loader = self.weight_loader
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
context = get_context()
|
||||
if context.is_prefill:
|
||||
last_indices = context.cu_seqlens_q[1:] - 1
|
||||
x = x[last_indices].contiguous()
|
||||
logits = F.linear(x, self.weight, self.bias)
|
||||
# if self.tp_size > 1:
|
||||
# all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)]
|
||||
# dist.gather(logits, all_logits, 0)
|
||||
# logits = torch.cat(all_logits, -1)
|
||||
return logits if self.tp_rank == 0 else None
|
||||
51
nanovllm/layers/layernorm.py
Executable file
51
nanovllm/layers/layernorm.py
Executable file
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
|
||||
@torch.compile
|
||||
def rms_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
var = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x.mul_(torch.rsqrt(var + self.eps))
|
||||
x = x.to(orig_dtype).mul_(self.weight)
|
||||
return x
|
||||
|
||||
@torch.compile
|
||||
def add_rms_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32).add_(residual.to(torch.float32))
|
||||
residual = x.to(orig_dtype)
|
||||
var = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x.mul_(torch.rsqrt(var + self.eps))
|
||||
x = x.to(orig_dtype).mul_(self.weight)
|
||||
return x, residual
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
return self.rms_forward(x)
|
||||
else:
|
||||
return self.add_rms_forward(x, residual)
|
||||
199
nanovllm/layers/linear.py
Executable file
199
nanovllm/layers/linear.py
Executable file
@@ -0,0 +1,199 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
assert numerator % denominator == 0
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
class LinearBase(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
tp_dim: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.tp_dim = tp_dim
|
||||
self.tp_rank = 0 # get_tensor_model_parallel_rank()
|
||||
self.tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ReplicatedLinear(LinearBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__(input_size, output_size)
|
||||
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(self.output_size))
|
||||
self.bias.weight_loader = self.weight_loader
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
assert param.size() == loaded_weight.size()
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__(input_size, output_size, 0)
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, self.tp_size)
|
||||
for output_size in self.output_sizes
|
||||
]
|
||||
|
||||
self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
|
||||
self.bias.weight_loader = self.weight_loader
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
param_data = param.data
|
||||
shard_size = param_data.size(self.tp_dim)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
|
||||
assert param_data.size() == loaded_weight.size()
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = False,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||
super().__init__(input_size, sum(output_sizes), bias=bias)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None):
|
||||
param_data = param.data
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
||||
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
||||
# loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
|
||||
assert param_data.size() == loaded_weight.size()
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: int | None = None,
|
||||
bias: bool = False,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
if total_num_kv_heads is None:
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||
input_size = self.hidden_size
|
||||
output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
|
||||
self.output_sizes = [
|
||||
self.num_heads * self.head_size * tp_size, # q_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
||||
]
|
||||
|
||||
super().__init__(input_size, output_size, bias)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None):
|
||||
param_data = param.data
|
||||
assert loaded_shard_id in ["q", "k", "v"]
|
||||
if loaded_shard_id == "q":
|
||||
shard_size = self.num_heads * self.head_size
|
||||
shard_offset = 0
|
||||
elif loaded_shard_id == "k":
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
shard_offset = self.num_heads * self.head_size
|
||||
else:
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
|
||||
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
||||
# loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
|
||||
assert param_data.size() == loaded_weight.size()
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class RowParallelLinear(LinearBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__(input_size, output_size, 1)
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(self.output_size))
|
||||
self.bias.weight_loader = self.weight_loader
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
param_data = param.data
|
||||
shard_size = param_data.size(self.tp_dim)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
|
||||
assert param_data.size() == loaded_weight.size()
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
|
||||
if self.tp_size > 1:
|
||||
dist.all_reduce(y)
|
||||
return y
|
||||
73
nanovllm/layers/rotary_embedding.py
Normal file
73
nanovllm/layers/rotary_embedding.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
cos = cos.unsqueeze(-2)
|
||||
sin = sin.unsqueeze(-2)
|
||||
x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
|
||||
y1 = x1 * cos - x2 * sin
|
||||
y2 = x2 * cos + x1 * sin
|
||||
return torch.cat((y1, y2), dim=-1).to(x.dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
assert rotary_dim == head_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
@torch.compile
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query = apply_rotary_emb(query, cos, sin).view(query_shape)
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key = apply_rotary_emb(key, cos, sin).view(key_shape)
|
||||
return query, key
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def get_rope(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position: int,
|
||||
base: float,
|
||||
rope_scaling: dict | None = None,
|
||||
):
|
||||
assert rope_scaling is None
|
||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||
return rotary_emb
|
||||
17
nanovllm/layers/sampler.py
Normal file
17
nanovllm/layers/sampler.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
|
||||
logits = logits.to(torch.float)
|
||||
if temperatures is not None:
|
||||
logits.div_(temperatures.unsqueeze(dim=1))
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||
sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
|
||||
return sampled_tokens
|
||||
5
nanovllm/llm.py
Normal file
5
nanovllm/llm.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from nanovllm.engine.llm_engine import LLMEngine
|
||||
|
||||
|
||||
class LLM(LLMEngine):
|
||||
pass
|
||||
247
nanovllm/models/qwen3.py
Executable file
247
nanovllm/models/qwen3.py
Executable file
@@ -0,0 +1,247 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen3Config
|
||||
|
||||
from nanovllm.layers.activation import SiluAndMul
|
||||
from nanovllm.layers.attention import Attention
|
||||
from nanovllm.layers.layernorm import RMSNorm
|
||||
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
||||
from nanovllm.layers.rotary_embedding import get_rope
|
||||
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
head_dim: int | None = None,
|
||||
rms_norm_eps: float = 1e-06,
|
||||
qkv_bias: bool = False,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: tuple | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q_by_head = q.view(-1, self.num_heads, self.head_dim)
|
||||
q_by_head = self.q_norm(q_by_head)
|
||||
q = q_by_head.view(q.shape)
|
||||
k_by_head = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||
k_by_head = self.k_norm(k_by_head)
|
||||
k = k_by_head.view(k.shape)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
o = self.attn(q, k, v)
|
||||
output = self.o_proj(o)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen3MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
assert hidden_act == "silu"
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3Config,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = Qwen3Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, 'attention_bias', False),
|
||||
head_dim=getattr(config, 'head_dim', None),
|
||||
rope_theta=getattr(config, "rope_theta", 1000000),
|
||||
rope_scaling=getattr(config, "rope_scaling", None),
|
||||
)
|
||||
self.mlp = Qwen3MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3Config,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(nn.Module):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3Config
|
||||
):
|
||||
super().__init__()
|
||||
self.model = Qwen3Model(config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head.weight.data = self.model.embed_tokens.weight.data
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, path: str):
|
||||
import os
|
||||
from safetensors import safe_open
|
||||
default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
|
||||
with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
|
||||
for n, p in self.named_parameters():
|
||||
for x in self.packed_modules_mapping:
|
||||
if x in n:
|
||||
weight_loader = getattr(p, "weight_loader", default_weight_loader)
|
||||
for i, y in enumerate(self.packed_modules_mapping[x]):
|
||||
weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
|
||||
break
|
||||
else:
|
||||
weight_loader = getattr(p, "weight_loader", default_weight_loader)
|
||||
weight_loader(p, f.get_tensor(n))
|
||||
8
nanovllm/sampling_params.py
Normal file
8
nanovllm/sampling_params.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingParams:
|
||||
temperature: float = 1.0
|
||||
max_tokens: int = 64
|
||||
ignore_eos: bool = False
|
||||
28
nanovllm/utils/context.py
Normal file
28
nanovllm/utils/context.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
is_prefill: bool = False
|
||||
cu_seqlens_q: torch.Tensor | None = None
|
||||
cu_seqlens_k: torch.Tensor | None = None
|
||||
max_seqlen_q: int = 0
|
||||
max_seqlen_k: int = 0
|
||||
slot_mapping: torch.Tensor | None = None
|
||||
context_lens: torch.Tensor | None = None
|
||||
block_tables: torch.Tensor | None = None
|
||||
|
||||
_CONTEXT = Context()
|
||||
|
||||
def get_context():
|
||||
return _CONTEXT
|
||||
|
||||
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, ):
|
||||
global _CONTEXT
|
||||
_CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
|
||||
|
||||
def reset_context():
|
||||
global _CONTEXT
|
||||
_CONTEXT = Context()
|
||||
14
nanovllm/utils/memory.py
Normal file
14
nanovllm/utils/memory.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import os
|
||||
import subprocess
|
||||
import torch
|
||||
|
||||
|
||||
def get_gpu_memory(device_id: int = 0):
|
||||
torch.cuda.synchronize()
|
||||
result = subprocess.check_output(
|
||||
['nvidia-smi', '-i', str(device_id), '--query-gpu=memory.total,memory.used,memory.free', '--format=csv,nounits,noheader'],
|
||||
encoding='utf-8'
|
||||
)
|
||||
total_memory, used_memory, free_memory = [int(x) for x in result.strip().split(', ')]
|
||||
return total_memory, used_memory, free_memory
|
||||
|
||||
31
nanovllm/utils/timer.py
Normal file
31
nanovllm/utils/timer.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from contextlib import contextmanager
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
|
||||
|
||||
class CUDATimer:
|
||||
|
||||
def __init__(self):
|
||||
self.events = defaultdict(list)
|
||||
|
||||
@contextmanager
|
||||
def record(self, name, enabled=True):
|
||||
if not enabled:
|
||||
yield
|
||||
else:
|
||||
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
|
||||
self.events[name].append((start, end))
|
||||
start.record()
|
||||
yield
|
||||
end.record()
|
||||
|
||||
def log(self):
|
||||
torch.cuda.synchronize()
|
||||
ret = []
|
||||
for name, events in self.events.items():
|
||||
total = 0
|
||||
count = len(self.events)
|
||||
for start, end in events:
|
||||
total += start.elapsed_time(end)
|
||||
ret.append(f"{name} {total:.2f}ms/{count}times")
|
||||
return ", ".join(ret)
|
||||
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
torch
|
||||
triton
|
||||
transformers
|
||||
cmake
|
||||
ninja
|
||||
Reference in New Issue
Block a user