commit bc0ad5a116 (parent 7e42fa6f63)
Author: GeeeekExplorer
Date:   2025-06-17 23:15:02 +08:00

6 changed files with 27 additions and 23 deletions

View File

@@ -6,7 +6,7 @@ A lightweight vLLM implementation built from scratch.
* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
* 📖 **Readable codebase** - Clean implementation in ~1,200 lines of Python code
-* **Optimization Suite** - Prefix caching, Torch compilation, CUDA graph, etc.
+* **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.

## Installation
@@ -17,6 +17,14 @@ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
## Quick Start

See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method.
+```python
+from nanovllm import LLM, SamplingParams
+llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
+sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
+prompts = ["Hello, Nano-vLLM."]
+outputs = llm.generate(prompts, sampling_params)
+outputs[0]["text"]
+```
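(The "minor differences" in `LLM.generate` are visible in the last line: outputs are plain dicts, so the generated string is read with `outputs[0]["text"]` instead of vLLM's `outputs[0].outputs[0].text`.)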
## Benchmark

View File

@@ -31,13 +31,11 @@ class Block:
        self.hash = -1
        self.token_ids = []

-    def __repr__(self):
-        return f"{(self.block_id, self.ref_count, self.hash)}"

class BlockManager:

    def __init__(self, num_blocks: int, block_size: int):
+        assert num_blocks > 0
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        self.hash_to_block_id: dict[int, int] = dict()
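Context for this hunk: `Block.hash` together with `hash_to_block_id` is the prefix-caching index. When a block fills up, its token ids are hashed; a later sequence whose next block of tokens hashes to the same value can reuse the cached block instead of recomputing its KV entries. A minimal sketch of such a block hash, assuming SHA-256 and chaining in the previous block's hash (the repo's actual hash function may differ):

```python
from hashlib import sha256

def compute_block_hash(token_ids: list[int], prefix_hash: int = -1) -> int:
    # Chain the previous block's hash so that identical token blocks at
    # different prompt positions get different hashes (assumed scheme).
    h = sha256()
    h.update(prefix_hash.to_bytes(8, "little", signed=True))
    for t in token_ids:
        h.update(t.to_bytes(8, "little", signed=True))
    return int.from_bytes(h.digest()[:8], "little", signed=True)
```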

View File

@@ -2,7 +2,7 @@ import atexit
from dataclasses import fields
from time import perf_counter
from tqdm.auto import tqdm
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoTokenizer
import torch.multiprocessing as mp

from nanovllm.config import Config
@@ -62,11 +62,7 @@ class LLMEngine:
        use_tqdm: bool = True,
    ) -> list[str]:
        if use_tqdm:
-            pbar = tqdm(
-                total=len(prompts),
-                desc="Generating",
-                dynamic_ncols=True,
-            )
+            pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        for prompt, sp in zip(prompts, sampling_params):

View File

@@ -53,7 +53,7 @@ class ModelRunner:
        dist.barrier()
        if self.rank == 0:
            self.shm.unlink()
-        # dist.destroy_process_group()
+        dist.destroy_process_group()

    def loop(self):
        while True:
@@ -92,7 +92,7 @@ class ModelRunner:
        hf_config = config.hf_config
        total, used, _ = get_gpu_memory()
        free = total * gpu_memory_utilization - used
-        num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
+        num_kv_heads = hf_config.num_key_value_heads // self.world_size
        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
        config.num_kvcache_blocks = int(free) // block_bytes
        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
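Switching to `self.world_size` reads the same value the runner was constructed with instead of querying the process group. To make the `block_bytes` formula concrete, here is a back-of-the-envelope sizing; the model config below is illustrative, not taken from this repo:

```python
# Illustrative: an 8B-class model sharded over 2 GPUs, bf16 KV cache.
num_hidden_layers = 32
num_key_value_heads = 8
head_dim = 128
itemsize = 2                  # bf16
world_size = 2
block_size = 256

num_kv_heads = num_key_value_heads // world_size          # 4 KV heads per rank
# Factor 2 covers the separate K and V planes of the cache:
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
print(block_bytes)            # 16_777_216 bytes = 16 MiB per block per rank

free = 20 * 1024**3           # say 20 GiB left after weights and activations
print(free // block_bytes)    # 1280 blocks -> 1280 * 256 = 327,680 cacheable tokens
```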
@@ -120,7 +120,6 @@ class ModelRunner:
        max_seqlen_q = 0
        max_seqlen_k = 0
        slot_mapping = []
-        context_lens = None
        block_tables = None
        for seq in seqs:
            seqlen = len(seq)
@@ -142,14 +141,13 @@ class ModelRunner:
        assert len(input_ids) == len(slot_mapping)
        assert len(input_ids) == cu_seqlens_q[-1]
        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
-            context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
            block_tables = self.prepare_block_tables(seqs)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
+        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
        return input_ids, positions

    def prepare_decode(self, seqs: list[Sequence]):
@@ -205,7 +203,7 @@ class ModelRunner:
    def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
        input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
-        temperatures = self.prepare_sample(seqs)
+        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        logits = self.run_model(input_ids, positions, is_prefill)
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        reset_context()
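Guarding `prepare_sample` with `self.rank == 0` matches the existing guard on the sampler call below it: under tensor parallelism every rank runs the model, but only rank 0 ends up with the full logits and samples token ids, so only rank 0 needs the temperature tensor. A sketch of what a temperature sampler at that point might look like (an assumption, not necessarily this repo's `Sampler`):

```python
import torch

def sample(logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
    # logits: [num_seqs, vocab_size]; temperatures: [num_seqs]
    greedy = logits.argmax(dim=-1)
    # Scale logits by temperature (clamped to avoid division by zero).
    probs = torch.softmax(logits / temperatures.clamp(min=1e-5).unsqueeze(-1), dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    # Treat temperature 0 as greedy decoding.
    return torch.where(temperatures == 0, greedy, sampled)
```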

View File

@@ -31,9 +31,6 @@ class Sequence:
    def __len__(self):
        return self.num_tokens

-    def __lt__(self, other):
-        return self.seq_id < other.seq_id

    def __getitem__(self, key):
        return self.token_ids[key]
@@ -75,7 +72,14 @@ class Sequence:
        self.num_tokens += 1

    def __getstate__(self):
-        state = vars(self).copy()
-        if self.num_completion_tokens:
-            state.pop("token_ids")
+        state = {
+            "num_tokens": self.num_tokens,
+            "num_prompt_tokens": self.num_prompt_tokens,
+            "num_cached_tokens": self.num_cached_tokens,
+            "block_table": self.block_table,
+        }
+        if self.num_completion_tokens == 0:
+            state["token_ids"] = self.token_ids
+        else:
+            state["last_token"] = self.last_token
        return state
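This rewrite shrinks what gets pickled across the `torch.multiprocessing` pipe on every scheduling step: the full `token_ids` list travels only while the sequence is still a pure prompt; once completion tokens exist, only `last_token` does. A `__setstate__` consistent with this `__getstate__` could look like the following sketch (the repo's counterpart is not shown in this diff):

```python
def __setstate__(self, state):
    self.num_tokens = state["num_tokens"]
    self.num_prompt_tokens = state["num_prompt_tokens"]
    self.num_cached_tokens = state["num_cached_tokens"]
    self.block_table = state["block_table"]
    if "token_ids" in state:
        # Prefill: the full prompt crosses the pipe once.
        self.token_ids = state["token_ids"]
    else:
        # Decode: only the newest token is needed as the next model input.
        self.token_ids = [state["last_token"]]
```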

View File

@@ -19,7 +19,7 @@ _CONTEXT = Context()
def get_context():
    return _CONTEXT

-def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, ):
+def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
    global _CONTEXT
    _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
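For reference, `_CONTEXT` is a process-global snapshot of the attention metadata for the current step, which attention layers read back via `get_context()`. A `Context` dataclass consistent with this positional call would be roughly (a sketch; field names come from the signature, defaults are assumed):

```python
from dataclasses import dataclass
import torch

@dataclass
class Context:
    is_prefill: bool = False
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    slot_mapping: torch.Tensor | None = None
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None
```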