diff --git a/README.md b/README.md
index 4059cf1..3483f5c 100644
--- a/README.md
+++ b/README.md
@@ -5,8 +5,8 @@ A lightweight vLLM implementation built from scratch.
 ## Key Features
 
 * 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
-* 📖 **Readable codebase** - Clean implementation under 1,200 lines of Python code
-* ⚡ **Optimization Suite** - Prefix caching, Torch compilation, CUDA graph, etc
+* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
+* ⚡ **Optimization Suite** - Prefix caching, Torch compilation, CUDA graph, etc.
 
 ## Installation
 
diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py
index eaa16f8..07a53c2 100644
--- a/nanovllm/engine/llm_engine.py
+++ b/nanovllm/engine/llm_engine.py
@@ -4,7 +4,7 @@ from transformers import AutoConfig, AutoTokenizer
 
 from nanovllm.config import Config
 from nanovllm.sampling_params import SamplingParams
-from nanovllm.engine.sequence import Sequence, SequenceStatus
+from nanovllm.engine.sequence import Sequence
 from nanovllm.engine.scheduler import Scheduler
 from nanovllm.engine.model_runner import ModelRunner
 
@@ -34,7 +34,7 @@ class LLMEngine:
         seqs, is_prefill = self.scheduler.schedule()
         token_ids = self.model_runner.run(seqs, is_prefill)
         self.scheduler.postprocess(seqs, token_ids)
-        outputs = [(seq.seq_id, seq[seq.num_prompt_tokens:]) for seq in seqs if seq.status == SequenceStatus.FINISHED]
+        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
         num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
         return outputs, num_tokens
 
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 623fe2d..b39fabb 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -44,7 +44,7 @@ class ModelRunner:
             module.v_cache = self.kv_cache[1, layer_id]
             layer_id += 1
 
-    def preare_block_tables(self, seqs: list[Sequence]):
+    def prepare_block_tables(self, seqs: list[Sequence]):
         max_len = max(len(seq.block_table) for seq in seqs)
         block_tables = [
             seq.block_table + [-1] * (max_len - len(seq.block_table))
@@ -84,7 +84,7 @@ class ModelRunner:
         assert len(input_ids) == cu_seqlens_q[-1]
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
             context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-            block_tables = self.preare_block_tables(seqs)
+            block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -107,7 +107,7 @@ class ModelRunner:
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        block_tables = self.preare_block_tables(seqs)
+        block_tables = self.prepare_block_tables(seqs)
         set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
         return input_ids, positions
 
diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py
index d519fb7..51c0438 100644
--- a/nanovllm/engine/sequence.py
+++ b/nanovllm/engine/sequence.py
@@ -35,10 +35,22 @@ class Sequence:
     def __getitem__(self, key):
         return self.token_ids[key]
 
+    @property
+    def is_finished(self):
+        return self.status == SequenceStatus.FINISHED
+
     @property
     def num_completion_tokens(self):
         return len(self.token_ids) - self.num_prompt_tokens
 
+    @property
+    def prompt_token_ids(self):
+        return self.token_ids[:self.num_prompt_tokens]
+
+    @property
+    def completion_token_ids(self):
+        return self.token_ids[self.num_prompt_tokens:]
+
     @property
     def num_cached_tokens(self):
         return self._num_cached_tokens
diff --git a/pyproject.toml b/pyproject.toml
index e5000b4..efb424b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ authors = [{ name = "Xingkai Yu" }]
 license = "MIT"
 license-files = ["LICENSE"]
 readme = "README.md"
-description = "a mimic VLLM implementation from scratch"
+description = "a lightweight vLLM implementation built from scratch"
 requires-python = ">=3.10,<3.13"
 dependencies = [
     "torch>=2.4.0",