better
This commit is contained in:
10
README.md
10
README.md
@@ -6,7 +6,7 @@ A lightweight vLLM implementation built from scratch.
|
||||
|
||||
* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
|
||||
* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
|
||||
* ⚡ **Optimization Suite** - Prefix caching, Torch compilation, CUDA graph, etc.
|
||||
* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -17,6 +17,14 @@ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
|
||||
## Quick Start
|
||||
|
||||
See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method.
|
||||
```python
|
||||
from nanovllm import LLM, SamplingParams
|
||||
llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
|
||||
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
|
||||
prompts = ["Hello, Nano-vLLM."]
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
outputs[0]["text"]
|
||||
```
|
||||
|
||||
## Benchmark
|
||||
|
||||
|
||||
@@ -31,13 +31,11 @@ class Block:
|
||||
self.hash = -1
|
||||
self.token_ids = []
|
||||
|
||||
def __repr__(self):
|
||||
return f"{(self.block_id, self.ref_count, self.hash)}"
|
||||
|
||||
|
||||
class BlockManager:
|
||||
|
||||
def __init__(self, num_blocks: int, block_size: int):
|
||||
assert num_blocks > 0
|
||||
self.block_size = block_size
|
||||
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
||||
self.hash_to_block_id: dict[int, int] = dict()
|
||||
|
||||
@@ -2,7 +2,7 @@ import atexit
|
||||
from dataclasses import fields
|
||||
from time import perf_counter
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers import AutoTokenizer
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from nanovllm.config import Config
|
||||
@@ -62,11 +62,7 @@ class LLMEngine:
|
||||
use_tqdm: bool = True,
|
||||
) -> list[str]:
|
||||
if use_tqdm:
|
||||
pbar = tqdm(
|
||||
total=len(prompts),
|
||||
desc="Generating",
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
||||
if not isinstance(sampling_params, list):
|
||||
sampling_params = [sampling_params] * len(prompts)
|
||||
for prompt, sp in zip(prompts, sampling_params):
|
||||
|
||||
@@ -53,7 +53,7 @@ class ModelRunner:
|
||||
dist.barrier()
|
||||
if self.rank == 0:
|
||||
self.shm.unlink()
|
||||
# dist.destroy_process_group()
|
||||
dist.destroy_process_group()
|
||||
|
||||
def loop(self):
|
||||
while True:
|
||||
@@ -92,7 +92,7 @@ class ModelRunner:
|
||||
hf_config = config.hf_config
|
||||
total, used, _ = get_gpu_memory()
|
||||
free = total * gpu_memory_utilization - used
|
||||
num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
|
||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
|
||||
config.num_kvcache_blocks = int(free) // block_bytes
|
||||
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
|
||||
@@ -120,7 +120,6 @@ class ModelRunner:
|
||||
max_seqlen_q = 0
|
||||
max_seqlen_k = 0
|
||||
slot_mapping = []
|
||||
context_lens = None
|
||||
block_tables = None
|
||||
for seq in seqs:
|
||||
seqlen = len(seq)
|
||||
@@ -142,14 +141,13 @@ class ModelRunner:
|
||||
assert len(input_ids) == len(slot_mapping)
|
||||
assert len(input_ids) == cu_seqlens_q[-1]
|
||||
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
||||
context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
block_tables = self.prepare_block_tables(seqs)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
@@ -205,7 +203,7 @@ class ModelRunner:
|
||||
|
||||
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||
temperatures = self.prepare_sample(seqs)
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
logits = self.run_model(input_ids, positions, is_prefill)
|
||||
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
|
||||
reset_context()
|
||||
|
||||
@@ -31,9 +31,6 @@ class Sequence:
|
||||
def __len__(self):
|
||||
return self.num_tokens
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.seq_id < other.seq_id
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.token_ids[key]
|
||||
|
||||
@@ -75,7 +72,14 @@ class Sequence:
|
||||
self.num_tokens += 1
|
||||
|
||||
def __getstate__(self):
|
||||
state = vars(self).copy()
|
||||
if self.num_completion_tokens:
|
||||
state.pop("token_ids")
|
||||
state = {
|
||||
"num_tokens": self.num_tokens,
|
||||
"num_prompt_tokens": self.num_prompt_tokens,
|
||||
"num_cached_tokens": self.num_cached_tokens,
|
||||
"block_table": self.block_table,
|
||||
}
|
||||
if self.num_completion_tokens == 0:
|
||||
state["token_ids"] = self.token_ids
|
||||
else:
|
||||
state["last_token"] = self.last_token
|
||||
return state
|
||||
|
||||
@@ -19,7 +19,7 @@ _CONTEXT = Context()
|
||||
def get_context():
|
||||
return _CONTEXT
|
||||
|
||||
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, ):
|
||||
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
|
||||
global _CONTEXT
|
||||
_CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user