diff --git a/README.md b/README.md
index 3483f5c..7f93114 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ A lightweight vLLM implementation built from scratch.
 
 * 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
 * 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
-* ⚡ **Optimization Suite** - Prefix caching, Torch compilation, CUDA graph, etc.
+* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
 
 ## Installation
 
@@ -17,6 +17,14 @@ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
 
 ## Quick Start
 
 See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method.
+```python
+from nanovllm import LLM, SamplingParams
+llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
+sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
+prompts = ["Hello, Nano-vLLM."]
+outputs = llm.generate(prompts, sampling_params)
+outputs[0]["text"]
+```
 
 ## Benchmark
diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py
index 9327e95..02bc682 100644
--- a/nanovllm/engine/block_manager.py
+++ b/nanovllm/engine/block_manager.py
@@ -31,13 +31,11 @@ class Block:
         self.hash = -1
         self.token_ids = []
 
-    def __repr__(self):
-        return f"{(self.block_id, self.ref_count, self.hash)}"
-
 
 class BlockManager:
 
     def __init__(self, num_blocks: int, block_size: int):
+        assert num_blocks > 0
         self.block_size = block_size
         self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
         self.hash_to_block_id: dict[int, int] = dict()
diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py
index 028f170..7d73d42 100644
--- a/nanovllm/engine/llm_engine.py
+++ b/nanovllm/engine/llm_engine.py
@@ -2,7 +2,7 @@ import atexit
 from dataclasses import fields
 from time import perf_counter
 from tqdm.auto import tqdm
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoTokenizer
 import torch.multiprocessing as mp
 
 from nanovllm.config import Config
@@ -62,11 +62,7 @@ class LLMEngine:
         use_tqdm: bool = True,
     ) -> list[str]:
         if use_tqdm:
-            pbar = tqdm(
-                total=len(prompts),
-                desc="Generating",
-                dynamic_ncols=True,
-            )
+            pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
         if not isinstance(sampling_params, list):
             sampling_params = [sampling_params] * len(prompts)
         for prompt, sp in zip(prompts, sampling_params):
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index e5958ec..0c7b2ef 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -53,7 +53,7 @@ class ModelRunner:
             dist.barrier()
             if self.rank == 0:
                 self.shm.unlink()
-        # dist.destroy_process_group()
+        dist.destroy_process_group()
 
     def loop(self):
         while True:
@@ -92,7 +92,7 @@ class ModelRunner:
         hf_config = config.hf_config
         total, used, _ = get_gpu_memory()
         free = total * gpu_memory_utilization - used
-        num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
+        num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(free) // block_bytes
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
@@ -120,7 +120,6 @@ class ModelRunner:
         max_seqlen_q = 0
         max_seqlen_k = 0
         slot_mapping = []
-        context_lens = None
         block_tables = None
         for seq in seqs:
             seqlen = len(seq)
@@ -142,14 +141,13 @@
         assert len(input_ids) == len(slot_mapping)
         assert len(input_ids) == cu_seqlens_q[-1]
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
-            context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
+        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
         return input_ids, positions
 
     def prepare_decode(self, seqs: list[Sequence]):
@@ -205,7 +203,7 @@
 
     def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
-        temperatures = self.prepare_sample(seqs)
+        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
         logits = self.run_model(input_ids, positions, is_prefill)
         token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
         reset_context()
diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py
index 00866b4..094b77e 100644
--- a/nanovllm/engine/sequence.py
+++ b/nanovllm/engine/sequence.py
@@ -31,9 +31,6 @@ class Sequence:
     def __len__(self):
         return self.num_tokens
 
-    def __lt__(self, other):
-        return self.seq_id < other.seq_id
-
     def __getitem__(self, key):
         return self.token_ids[key]
 
@@ -75,7 +72,14 @@ class Sequence:
         self.num_tokens += 1
 
     def __getstate__(self):
-        state = vars(self).copy()
-        if self.num_completion_tokens:
-            state.pop("token_ids")
+        state = {
+            "num_tokens": self.num_tokens,
+            "num_prompt_tokens": self.num_prompt_tokens,
+            "num_cached_tokens": self.num_cached_tokens,
+            "block_table": self.block_table,
+        }
+        if self.num_completion_tokens == 0:
+            state["token_ids"] = self.token_ids
+        else:
+            state["last_token"] = self.last_token
         return state
diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py
index 65bbd18..a9540b7 100644
--- a/nanovllm/utils/context.py
+++ b/nanovllm/utils/context.py
@@ -19,7 +19,7 @@ _CONTEXT = Context()
 
 def get_context():
     return _CONTEXT
 
-def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, ):
+def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
     global _CONTEXT
     _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
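A note on the KV-cache sizing touched by the `num_kv_heads` change in `model_runner.py`: each cache block holds keys and values (the leading factor of 2) for every layer, `block_size` token slots, and only this GPU's share of the KV heads, which is why the head count is divided by the tensor-parallel world size. The sketch below walks the same formula with made-up model shapes chosen purely to make the arithmetic concrete; none of these numbers come from a real config.

```python
# Illustrative shapes only; not taken from any real model config.
num_hidden_layers = 32
num_key_value_heads = 8
head_dim = 128
itemsize = 2           # bytes per element for bf16/fp16
block_size = 256       # token slots per KV-cache block
world_size = 2         # tensor-parallel degree

# Per-GPU share of KV heads, mirroring the diff's division by world size.
num_kv_heads = num_key_value_heads // world_size

# 2x = one copy for keys, one for values, across all layers.
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
assert block_bytes == 16 * 1024**2   # 16 MiB per block with these shapes

free = int(8 * 1024**3 * 0.9)        # e.g. 90% of an 8 GiB memory budget
print(free // block_bytes)           # -> 460 blocks fit in the budget
```

Switching from `dist.get_world_size()` to `self.world_size` computes the same value from state the runner already holds, rather than querying the process group at that point.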
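A note on the `Sequence.__getstate__` rewrite in `sequence.py`: sequences cross a process boundary (the engine imports `torch.multiprocessing`), so their pickled size matters. Once a sequence has completion tokens, each decode step only needs the newest token, and the new `__getstate__` ships `last_token` instead of the whole `token_ids` list. The sketch below is a minimal, self-contained illustration of that pattern; the `Seq` class and its `__setstate__` (which the diff does not show) are hypothetical stand-ins, not nano-vllm's actual code.

```python
import pickle


class Seq:
    """Toy stand-in for nano-vllm's Sequence; fields are illustrative."""

    def __init__(self, token_ids: list[int]):
        self.token_ids = list(token_ids)
        self.num_prompt_tokens = len(self.token_ids)
        self.num_tokens = len(self.token_ids)

    @property
    def num_completion_tokens(self) -> int:
        return self.num_tokens - self.num_prompt_tokens

    def append_token(self, token_id: int):
        self.token_ids.append(token_id)
        self.num_tokens += 1

    def __getstate__(self):
        # Same idea as the diff: before decoding starts, workers need the
        # whole prompt; afterwards a single token per step is enough.
        state = {"num_tokens": self.num_tokens,
                 "num_prompt_tokens": self.num_prompt_tokens}
        if self.num_completion_tokens == 0:
            state["token_ids"] = self.token_ids
        else:
            state["last_token"] = self.token_ids[-1]
        return state

    def __setstate__(self, state):
        # Hypothetical counterpart, not shown in the diff: restore whichever
        # fields were serialized.
        self.__dict__.update(state)


prefill = Seq(list(range(1024)))    # pickles all 1024 prompt ids
decoding = Seq(list(range(1024)))
decoding.append_token(7)            # now pickles only the last token
assert len(pickle.dumps(decoding)) < len(pickle.dumps(prefill))
```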