This commit is contained in:
GeeeekExplorer
2025-06-14 00:36:32 +08:00
parent 9b59dae751
commit 4a8aa090a7
5 changed files with 20 additions and 8 deletions

View File

@@ -5,8 +5,8 @@ A lightweight vLLM implementation built from scratch.
## Key Features ## Key Features
* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM * 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
* 📖 **Readable codebase** - Clean implementation under 1,200 lines of Python code * 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
***Optimization Suite** - Prefix caching, Torch compilation, CUDA graph, etc ***Optimization Suite** - Prefix caching, Torch compilation, CUDA graph, etc.
## Installation ## Installation

View File

@@ -4,7 +4,7 @@ from transformers import AutoConfig, AutoTokenizer
from nanovllm.config import Config from nanovllm.config import Config
from nanovllm.sampling_params import SamplingParams from nanovllm.sampling_params import SamplingParams
from nanovllm.engine.sequence import Sequence, SequenceStatus from nanovllm.engine.sequence import Sequence
from nanovllm.engine.scheduler import Scheduler from nanovllm.engine.scheduler import Scheduler
from nanovllm.engine.model_runner import ModelRunner from nanovllm.engine.model_runner import ModelRunner
@@ -34,7 +34,7 @@ class LLMEngine:
seqs, is_prefill = self.scheduler.schedule() seqs, is_prefill = self.scheduler.schedule()
token_ids = self.model_runner.run(seqs, is_prefill) token_ids = self.model_runner.run(seqs, is_prefill)
self.scheduler.postprocess(seqs, token_ids) self.scheduler.postprocess(seqs, token_ids)
outputs = [(seq.seq_id, seq[seq.num_prompt_tokens:]) for seq in seqs if seq.status == SequenceStatus.FINISHED] outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
return outputs, num_tokens return outputs, num_tokens

View File

@@ -44,7 +44,7 @@ class ModelRunner:
module.v_cache = self.kv_cache[1, layer_id] module.v_cache = self.kv_cache[1, layer_id]
layer_id += 1 layer_id += 1
def preare_block_tables(self, seqs: list[Sequence]): def prepare_block_tables(self, seqs: list[Sequence]):
max_len = max(len(seq.block_table) for seq in seqs) max_len = max(len(seq.block_table) for seq in seqs)
block_tables = [ block_tables = [
seq.block_table + [-1] * (max_len - len(seq.block_table)) seq.block_table + [-1] * (max_len - len(seq.block_table))
@@ -84,7 +84,7 @@ class ModelRunner:
assert len(input_ids) == cu_seqlens_q[-1] assert len(input_ids) == cu_seqlens_q[-1]
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
block_tables = self.preare_block_tables(seqs) block_tables = self.prepare_block_tables(seqs)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -107,7 +107,7 @@ class ModelRunner:
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
block_tables = self.preare_block_tables(seqs) block_tables = self.prepare_block_tables(seqs)
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables) set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
return input_ids, positions return input_ids, positions

View File

@@ -35,10 +35,22 @@ class Sequence:
def __getitem__(self, key): def __getitem__(self, key):
return self.token_ids[key] return self.token_ids[key]
@property
def is_finished(self):
return self.status == SequenceStatus.FINISHED
@property @property
def num_completion_tokens(self): def num_completion_tokens(self):
return len(self.token_ids) - self.num_prompt_tokens return len(self.token_ids) - self.num_prompt_tokens
@property
def prompt_token_ids(self):
return self.token_ids[:self.num_prompt_tokens]
@property
def completion_token_ids(self):
return self.token_ids[self.num_prompt_tokens:]
@property @property
def num_cached_tokens(self): def num_cached_tokens(self):
return self._num_cached_tokens return self._num_cached_tokens

View File

@@ -9,7 +9,7 @@ authors = [{ name = "Xingkai Yu" }]
license = "MIT" license = "MIT"
license-files = ["LICENSE"] license-files = ["LICENSE"]
readme = "README.md" readme = "README.md"
description = "a mimic VLLM implementation from scratch" description = "a lightweight vLLM implementation built from scratch"
requires-python = ">=3.10,<3.13" requires-python = ">=3.10,<3.13"
dependencies = [ dependencies = [
"torch>=2.4.0", "torch>=2.4.0",