better
README.md
@@ -6,7 +6,7 @@ A lightweight vLLM implementation built from scratch.

 * 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
 * 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
-* ⚡ **Optimization Suite** - Prefix caching, Torch compilation, CUDA graph, etc.
+* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.

 ## Installation

@@ -17,6 +17,14 @@ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
 ## Quick Start

 See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method.
+```python
+from nanovllm import LLM, SamplingParams
+
+llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
+sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
+prompts = ["Hello, Nano-vLLM."]
+outputs = llm.generate(prompts, sampling_params)
+outputs[0]["text"]
+```

 ## Benchmark

@@ -31,13 +31,11 @@ class Block:
         self.hash = -1
         self.token_ids = []

-    def __repr__(self):
-        return f"{(self.block_id, self.ref_count, self.hash)}"
-

 class BlockManager:

     def __init__(self, num_blocks: int, block_size: int):
+        assert num_blocks > 0
         self.block_size = block_size
         self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
         self.hash_to_block_id: dict[int, int] = dict()
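For context on the fields this hunk touches: each `Block` carries a `ref_count`, a `hash`, and the `token_ids` it stores, and `BlockManager` keeps a `hash_to_block_id` map, which is what makes the README's "prefix caching" feature possible. The sketch below is purely illustrative of that idea; `compute_hash` and `lookup_or_allocate` are hypothetical helpers, not the repo's actual methods.

```python
import hashlib

def compute_hash(token_ids: list[int], prefix_hash: int = -1) -> int:
    # Chain the previous block's hash so identical prefixes produce identical hashes.
    h = hashlib.sha256()
    h.update(prefix_hash.to_bytes(32, "little", signed=True))
    for t in token_ids:
        h.update(t.to_bytes(8, "little", signed=True))
    return int.from_bytes(h.digest()[:8], "little")

def lookup_or_allocate(manager, token_ids: list[int], prefix_hash: int):
    """Reuse a cached block for a full block of tokens, else grab a free one."""
    block_hash = compute_hash(token_ids, prefix_hash)
    block_id = manager.hash_to_block_id.get(block_hash, -1)
    if block_id != -1 and manager.blocks[block_id].token_ids == token_ids:
        manager.blocks[block_id].ref_count += 1        # cache hit: share the block
        return manager.blocks[block_id]
    block = next(b for b in manager.blocks if b.ref_count == 0)   # naive free scan
    block.ref_count, block.hash, block.token_ids = 1, block_hash, list(token_ids)
    manager.hash_to_block_id[block_hash] = block.block_id
    return block
```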
@@ -2,7 +2,7 @@ import atexit
 from dataclasses import fields
 from time import perf_counter
 from tqdm.auto import tqdm
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoTokenizer
 import torch.multiprocessing as mp

 from nanovllm.config import Config
@@ -62,11 +62,7 @@ class LLMEngine:
         use_tqdm: bool = True,
     ) -> list[str]:
         if use_tqdm:
-            pbar = tqdm(
-                total=len(prompts),
-                desc="Generating",
-                dynamic_ncols=True,
-            )
+            pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
         if not isinstance(sampling_params, list):
             sampling_params = [sampling_params] * len(prompts)
         for prompt, sp in zip(prompts, sampling_params):
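The `isinstance` branch above implies `generate` accepts either a single `SamplingParams` (broadcast to every prompt) or one per prompt. A hedged usage sketch, with made-up prompts and a placeholder model path:

```python
from nanovllm import LLM, SamplingParams

llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True)

prompts = ["Explain KV-cache paging in one sentence.", "Write a haiku about GPUs."]
# Either pass one SamplingParams for all prompts, or one per prompt as shown here.
params = [
    SamplingParams(temperature=0.6, max_tokens=128),
    SamplingParams(temperature=1.0, max_tokens=64),
]

outputs = llm.generate(prompts, params)
for out in outputs:
    print(out["text"])
```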
@@ -53,7 +53,7 @@ class ModelRunner:
         dist.barrier()
         if self.rank == 0:
             self.shm.unlink()
-        # dist.destroy_process_group()
+        dist.destroy_process_group()

     def loop(self):
         while True:
@@ -92,7 +92,7 @@ class ModelRunner:
         hf_config = config.hf_config
         total, used, _ = get_gpu_memory()
         free = total * gpu_memory_utilization - used
-        num_kv_heads = hf_config.num_key_value_heads // dist.get_world_size()
+        num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(free) // block_bytes
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
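For intuition about the sizing formula above, here is a back-of-the-envelope version. The model dimensions and the free-memory figure are assumed example values, not measurements from any particular model:

```python
# Illustrative numbers only: a hypothetical model with 28 layers, 8 KV heads per GPU
# (i.e. already divided by world_size), head_dim 128, bf16 KV cache, 256-token blocks.
num_hidden_layers = 28
num_kv_heads = 8
head_dim = 128
itemsize = 2              # bytes per bf16 element
block_size = 256

# The factor of 2 covers the K and V tensors of every layer.
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
print(block_bytes / 2**20)                 # 28.0 MiB per KV-cache block

free_bytes = 20 * 2**30                    # pretend 20 GiB is left for the cache
num_kvcache_blocks = free_bytes // block_bytes
print(num_kvcache_blocks)                  # 731 blocks
print(num_kvcache_blocks * block_size)     # ~187k cacheable token slots
```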
@@ -120,7 +120,6 @@ class ModelRunner:
         max_seqlen_q = 0
         max_seqlen_k = 0
         slot_mapping = []
-        context_lens = None
         block_tables = None
         for seq in seqs:
             seqlen = len(seq)
@@ -142,14 +141,13 @@ class ModelRunner:
         assert len(input_ids) == len(slot_mapping)
         assert len(input_ids) == cu_seqlens_q[-1]
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
-            context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
+        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
         return input_ids, positions

     def prepare_decode(self, seqs: list[Sequence]):
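For orientation, the `cu_seqlens_*` tensors built here follow the usual variable-length attention convention: cumulative sums of per-sequence lengths with a leading zero. A small made-up example of what that looks like with a cached prefix:

```python
import torch

# Hypothetical batch of three prompts with 3, 5, and 2 new (query) tokens,
# where the second prompt already has 4 tokens cached from a shared prefix.
q_lens = [3, 5, 2]
k_lens = [3, 9, 2]   # key length = cached prefix + new tokens

cu_seqlens_q = torch.tensor([0, 3, 8, 10], dtype=torch.int32)   # cumsum of q_lens, leading 0
cu_seqlens_k = torch.tensor([0, 3, 12, 14], dtype=torch.int32)  # cumsum of k_lens, leading 0

# cu_seqlens_k[-1] > cu_seqlens_q[-1] signals that some prefix is cached,
# which is exactly the condition guarding prepare_block_tables above.
assert cu_seqlens_k[-1] > cu_seqlens_q[-1]
```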
@@ -205,7 +203,7 @@ class ModelRunner:

     def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
-        temperatures = self.prepare_sample(seqs)
+        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
         logits = self.run_model(input_ids, positions, is_prefill)
         token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
         reset_context()
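In this hunk only rank 0 prepares temperatures and runs the sampler; the other tensor-parallel ranks just execute the model. As a rough illustration of what temperature sampling over the last-position logits can look like, here is a generic sketch, not nano-vllm's actual `Sampler`:

```python
import torch

def sample(logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
    """Temperature sampling over the final-position logits of each sequence.

    logits:       [num_seqs, vocab_size]
    temperatures: [num_seqs], assumed > 0 here (greedy decoding would bypass this).
    """
    scaled = logits.float() / temperatures.unsqueeze(-1)
    probs = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```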
@@ -31,9 +31,6 @@ class Sequence:
     def __len__(self):
         return self.num_tokens

-    def __lt__(self, other):
-        return self.seq_id < other.seq_id
-
     def __getitem__(self, key):
         return self.token_ids[key]

@@ -75,7 +72,14 @@ class Sequence:
         self.num_tokens += 1

     def __getstate__(self):
-        state = vars(self).copy()
-        if self.num_completion_tokens:
-            state.pop("token_ids")
+        state = {
+            "num_tokens": self.num_tokens,
+            "num_prompt_tokens": self.num_prompt_tokens,
+            "num_cached_tokens": self.num_cached_tokens,
+            "block_table": self.block_table,
+        }
+        if self.num_completion_tokens == 0:
+            state["token_ids"] = self.token_ids
+        else:
+            state["last_token"] = self.last_token
         return state
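The trimmed `__getstate__` only pickles what a worker process needs when sequences cross process boundaries: block-table bookkeeping, plus the full prompt before the first decode step, or just the last sampled token afterwards. A matching `__setstate__` would presumably look roughly like the following; this is a sketch, not necessarily the repo's exact counterpart:

```python
def __setstate__(self, state):
    self.num_tokens = state["num_tokens"]
    self.num_prompt_tokens = state["num_prompt_tokens"]
    self.num_cached_tokens = state["num_cached_tokens"]
    self.block_table = state["block_table"]
    if "token_ids" in state:
        # Prefill: the worker needs the whole prompt.
        self.token_ids = state["token_ids"]
    else:
        # Decode: only the newly sampled token is required.
        self.last_token = state["last_token"]
```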
@@ -19,7 +19,7 @@ _CONTEXT = Context()
 def get_context():
     return _CONTEXT

-def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None, ):
+def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
     global _CONTEXT
     _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)

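Given the positional `Context(...)` call and the keyword defaults in `set_context`, `Context` is presumably a small dataclass along these lines; the field names and defaults are inferred from the call site rather than copied from the file:

```python
from dataclasses import dataclass
import torch

@dataclass
class Context:
    is_prefill: bool = False
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    slot_mapping: torch.Tensor | None = None
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None
```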