This commit is contained in:
GeeeekExplorer
2025-06-17 23:15:02 +08:00
parent 7e42fa6f63
commit bc0ad5a116
6 changed files with 27 additions and 23 deletions

View File

@@ -6,7 +6,7 @@ A lightweight vLLM implementation built from scratch.
* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
***Optimization Suite** - Prefix caching, Torch compilation, CUDA graph, etc.
***Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
## Installation
@@ -17,6 +17,14 @@ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
## Quick Start
See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method.
```python
from nanovllm import LLM, SamplingParams
llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
prompts = ["Hello, Nano-vLLM."]
outputs = llm.generate(prompts, sampling_params)
outputs[0]["text"]
```
## Benchmark

View File

@@ -31,13 +31,11 @@ class Block:
self.hash = -1
self.token_ids = []
def __repr__(self):
return f"{(self.block_id, self.ref_count, self.hash)}"
class BlockManager:
def __init__(self, num_blocks: int, block_size: int):
assert num_blocks > 0
self.block_size = block_size
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()

View File

@@ -2,7 +2,7 @@ import atexit
from dataclasses import fields
from time import perf_counter
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoTokenizer
import torch.multiprocessing as mp
from nanovllm.config import Config
@@ -62,11 +62,7 @@ class LLMEngine:
use_tqdm: bool = True,
) -> list[str]:
if use_tqdm:
pbar = tqdm(
total=len(prompts),
desc="Generating",
dynamic_ncols=True,
)
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
if not isinstance(sampling_params, list):
sampling_params = [sampling_params] * len(prompts)
for prompt, sp in zip(prompts, sampling_params):

View File

@@ -53,7 +53,7 @@ class ModelRunner:
dist.barrier()
if self.rank == 0:
self.shm.unlink()
# dist.destroy_process_group()
dist.destroy_process_group()
def loop(self):
while True:
@@ -92,7 +92,7 @@ class ModelRunner:
hf_config = config.hf_config
total, used, _ = get_gpu_memory()
free = total * gpu_memory_utilization - used
num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
num_kv_heads = hf_config.num_key_value_heads // self.world_size
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
config.num_kvcache_blocks = int(free) // block_bytes
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
@@ -120,7 +120,6 @@ class ModelRunner:
max_seqlen_q = 0
max_seqlen_k = 0
slot_mapping = []
context_lens = None
block_tables = None
for seq in seqs:
seqlen = len(seq)
@@ -142,14 +141,13 @@ class ModelRunner:
assert len(input_ids) == len(slot_mapping)
assert len(input_ids) == cu_seqlens_q[-1]
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
block_tables = self.prepare_block_tables(seqs)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
@@ -205,7 +203,7 @@ class ModelRunner:
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
logits = self.run_model(input_ids, positions, is_prefill)
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
reset_context()

View File

@@ -31,9 +31,6 @@ class Sequence:
def __len__(self):
return self.num_tokens
def __lt__(self, other):
return self.seq_id < other.seq_id
def __getitem__(self, key):
return self.token_ids[key]
@@ -75,7 +72,14 @@ class Sequence:
self.num_tokens += 1
def __getstate__(self):
state = vars(self).copy()
if self.num_completion_tokens:
state.pop("token_ids")
state = {
"num_tokens": self.num_tokens,
"num_prompt_tokens": self.num_prompt_tokens,
"num_cached_tokens": self.num_cached_tokens,
"block_table": self.block_table,
}
if self.num_completion_tokens == 0:
state["token_ids"] = self.token_ids
else:
state["last_token"] = self.last_token
return state

View File

@@ -19,7 +19,7 @@ _CONTEXT = Context()
def get_context():
return _CONTEXT
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, ):
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
global _CONTEXT
_CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)