same as vllm

GeeeekExplorer
2025-06-27 18:50:56 +08:00
parent 658520b788
commit 1caeec8dfa
3 changed files with 20 additions and 23 deletions


@@ -14,9 +14,9 @@ A lightweight vLLM implementation built from scratch.
 pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
 ```
-## Manual download
-If you'd rather fetch the model weights yourself, you can use:
+## Manual Download
+If you prefer to download the model weights manually, use the following command:
 ```bash
 huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
   --local-dir ~/huggingface/Qwen3-0.6B/ \
@@ -25,7 +25,7 @@ huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
 ## Quick Start
-See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method.
+See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
 ```python
 from nanovllm import LLM, SamplingParams
 llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
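
For context, the quick-start example is cut off at the hunk boundary above. A minimal end-to-end sketch of how the call might continue is shown below; the sampling values and the `"text"` return field are illustrative assumptions rather than content of this diff.

```python
from nanovllm import LLM, SamplingParams

# Illustrative values; adjust the model path and sampling settings to your setup.
llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
prompts = ["Hello, Nano-vLLM."]

# Assumption: generate() returns a list with one dict per prompt, each holding "text".
outputs = llm.generate(prompts, sampling_params)
print(outputs[0]["text"])
```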


@@ -22,7 +22,7 @@ class ModelRunner:
         self.world_size = config.tensor_parallel_size
         self.rank = rank
         self.event = event
         dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
@@ -31,8 +31,8 @@ class ModelRunner:
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
         self.sampler = Sampler()
-        peak = self.warmup_model()
-        self.allocate_kv_cache(config.gpu_memory_utilization, peak)
+        self.warmup_model()
+        self.allocate_kv_cache()
         if not self.enforce_eager:
             self.capture_cudagraph()
         torch.set_default_device("cpu")
@@ -47,18 +47,6 @@ class ModelRunner:
                 self.shm = SharedMemory(name="nanovllm")
                 self.loop()

-    def warmup_model(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
-        before = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
-        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
-        self.run(seqs, True)
-        torch.cuda.empty_cache()
-        after = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
-        return after - before
-
     def exit(self):
         if self.world_size > 1:
             self.shm.close()
@@ -102,15 +90,25 @@ class ModelRunner:
         assert callable(method)
         return method(*args)

-    def allocate_kv_cache(self, gpu_memory_utilization, peak):
+    def warmup_model(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        self.run(seqs, True)
+        torch.cuda.empty_cache()
+
+    def allocate_kv_cache(self):
         config = self.config
         hf_config = config.hf_config
         free, total = torch.cuda.mem_get_info()
         used = total - free
+        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak) // block_bytes
+        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
-        print(f"{config.num_kvcache_blocks=}")
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
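
Taken together, the two hunks above replace the old return-value plumbing (`peak = self.warmup_model()` feeding `allocate_kv_cache`) with vLLM-style profiling: `warmup_model` simply drives the CUDA caching allocator to its peak with a maximal dummy batch, and `allocate_kv_cache` reads the allocator statistics itself. The block count then comes from the device budget `total * gpu_memory_utilization`, minus memory already in use (`total - free`), minus the transient activation headroom observed during warmup (`peak - current`), divided by `block_bytes` (2 tensors * layers * block size * KV heads * head dim * dtype size). A rough sketch of that arithmetic, with invented numbers and names mirroring the diff rather than the project's actual code:

```python
# Hedged sketch of the KV-cache sizing formula after this commit; all numbers are invented.
def kvcache_blocks(total, free, peak, current, gpu_memory_utilization, block_bytes):
    used = total - free                      # bytes already occupied on the device
    budget = total * gpu_memory_utilization  # fraction of the GPU the engine may use
    # peak - current is the transient activation memory seen during warmup;
    # subtracting it keeps that headroom free for later forward passes.
    return int(budget - used - (peak - current)) // block_bytes

GiB, MiB = 1 << 30, 1 << 20
# Example: 24 GiB card, 2 GiB already used, warmup peaked at 3 GiB and settled at 1 GiB,
# 90% utilization target, 2 MiB per KV-cache block.
print(kvcache_blocks(24 * GiB, 22 * GiB, 3 * GiB, 1 * GiB, 0.9, 2 * MiB))
```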


@@ -61,8 +61,7 @@ class Attention(nn.Module):
         k = k.view(-1, self.num_kv_heads, self.head_dim)
         v = v.view(-1, self.num_kv_heads, self.head_dim)
         context = get_context()
-        k_cache = self.k_cache
-        v_cache = self.v_cache
+        k_cache, v_cache = self.k_cache, self.v_cache
         if k_cache.numel() and v_cache.numel():
             store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill: