From 1caeec8dfa9cd7a4997d294ca2293f19f278e463 Mon Sep 17 00:00:00 2001
From: GeeeekExplorer <2651904866@qq.com>
Date: Fri, 27 Jun 2025 18:50:56 +0800
Subject: [PATCH] same as vllm

---
 README.md                       |  6 +++---
 nanovllm/engine/model_runner.py | 34 ++++++++++++++++------------------
 nanovllm/layers/attention.py    |  3 +--
 3 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/README.md b/README.md
index e6907be..f56a733 100644
--- a/README.md
+++ b/README.md
@@ -14,9 +14,9 @@ A lightweight vLLM implementation built from scratch.
 pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
 ```
 
-## Manual download
+## Manual Download
 
-If you’d rather fetch the model weights yourself, you can use:
+If you prefer to download the model weights manually, use the following command:
 ```bash
 huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
   --local-dir ~/huggingface/Qwen3-0.6B/ \
@@ -25,7 +25,7 @@ huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
 
 ## Quick Start
 
-See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method.
+See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
 ```python
 from nanovllm import LLM, SamplingParams
 llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 699d36c..033d160 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -22,7 +22,7 @@ class ModelRunner:
         self.world_size = config.tensor_parallel_size
         self.rank = rank
         self.event = event
-        
+
         dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
@@ -31,8 +31,8 @@ class ModelRunner:
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
         self.sampler = Sampler()
-        peak = self.warmup_model()
-        self.allocate_kv_cache(config.gpu_memory_utilization, peak)
+        self.warmup_model()
+        self.allocate_kv_cache()
         if not self.enforce_eager:
             self.capture_cudagraph()
         torch.set_default_device("cpu")
@@ -47,18 +47,6 @@ class ModelRunner:
             self.shm = SharedMemory(name="nanovllm")
             self.loop()
 
-    def warmup_model(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
-        before = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
-        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
-        self.run(seqs, True)
-        torch.cuda.empty_cache()
-        after = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
-        return after - before
-
     def exit(self):
         if self.world_size > 1:
             self.shm.close()
@@ -102,15 +90,25 @@ class ModelRunner:
         assert callable(method)
         return method(*args)
 
-    def allocate_kv_cache(self, gpu_memory_utilization, peak):
+    def warmup_model(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        self.run(seqs, True)
+        torch.cuda.empty_cache()
+
+    def allocate_kv_cache(self):
         config = self.config
         hf_config = config.hf_config
         free, total = torch.cuda.mem_get_info()
         used = total - free
+        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak) // block_bytes
-        print(f"{config.num_kvcache_blocks=}")
+        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 5620b13..d036641 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -61,8 +61,7 @@ class Attention(nn.Module):
         k = k.view(-1, self.num_kv_heads, self.head_dim)
         v = v.view(-1, self.num_kv_heads, self.head_dim)
         context = get_context()
-        k_cache = self.k_cache
-        v_cache = self.v_cache
+        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
             store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
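
Note: the rewritten `allocate_kv_cache` sizes the cache from `torch.cuda.mem_get_info()` plus the allocator statistics recorded during warmup, which is why `warmup_model` no longer returns a peak delta and the call site drops its arguments. Below is a minimal standalone sketch of that sizing arithmetic; the GPU sizes and model shapes in the example are illustrative assumptions, not values measured from the repo.

```python
# Standalone sketch of the KV-cache sizing arithmetic used by the patched
# allocate_kv_cache. All concrete numbers below are hypothetical.

def num_kvcache_blocks(total, used, peak, current, gpu_memory_utilization,
                       num_hidden_layers, block_size, num_kv_heads, head_dim,
                       dtype_itemsize):
    # One block stores K and V (factor 2) for every layer, with
    # block_size x num_kv_heads x head_dim elements per tensor.
    block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * dtype_itemsize
    # (peak - current) is the transient activation memory seen during warmup
    # that has since been freed; it is kept as headroom so real batches do not
    # run out of memory. `used` (total - free from the driver) already covers
    # the weights and the currently allocated tensors.
    budget = int(total * gpu_memory_utilization) - used - (peak - current)
    return budget // block_bytes

# Hypothetical example: a 24 GiB GPU running a small Qwen3-style model.
GiB = 1024 ** 3
blocks = num_kvcache_blocks(
    total=24 * GiB, used=3 * GiB, peak=2 * GiB, current=GiB,
    gpu_memory_utilization=0.9,
    num_hidden_layers=28, block_size=256, num_kv_heads=8, head_dim=128,
    dtype_itemsize=2,  # bf16
)
print(blocks)
```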