From 658520b788ed9789ab269852d83092f1d9e90d1d Mon Sep 17 00:00:00 2001
From: GeeeekExplorer <2651904866@qq.com>
Date: Fri, 27 Jun 2025 01:51:57 +0800
Subject: [PATCH 1/3] warmup and allocate

---
 nanovllm/config.py              |  3 ++-
 nanovllm/engine/model_runner.py | 25 ++++++++++++++++++++-----
 nanovllm/engine/sequence.py     |  2 +-
 nanovllm/layers/attention.py    |  3 ++-
 4 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/nanovllm/config.py b/nanovllm/config.py
index 6c4e7f9..959ffb3 100644
--- a/nanovllm/config.py
+++ b/nanovllm/config.py
@@ -6,7 +6,7 @@ from transformers import AutoConfig
 @dataclass
 class Config:
     model: str
-    max_num_batched_tokens: int = 32768
+    max_num_batched_tokens: int = 16384
     max_num_seqs: int = 512
     max_model_len: int = 4096
     gpu_memory_utilization: float = 0.9
@@ -23,3 +23,4 @@ class Config:
         assert 1 <= self.tensor_parallel_size <= 8
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
+        assert self.max_num_batched_tokens >= self.max_model_len
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 823f0af..699d36c 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -31,7 +31,8 @@ class ModelRunner:
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
         self.sampler = Sampler()
-        self.allocate_kv_cache(config.gpu_memory_utilization)
+        peak = self.warmup_model()
+        self.allocate_kv_cache(config.gpu_memory_utilization, peak)
         if not self.enforce_eager:
             self.capture_cudagraph()
         torch.set_default_device("cpu")
@@ -46,6 +47,18 @@ class ModelRunner:
                 self.shm = SharedMemory(name="nanovllm")
                 self.loop()
 
+    def warmup_model(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        before = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        self.run(seqs, True)
+        torch.cuda.empty_cache()
+        after = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
+        return after - before
+
     def exit(self):
         if self.world_size > 1:
             self.shm.close()
@@ -89,14 +102,15 @@ class ModelRunner:
         assert callable(method)
         return method(*args)
 
-    def allocate_kv_cache(self, gpu_memory_utilization):
+    def allocate_kv_cache(self, gpu_memory_utilization, peak):
         config = self.config
         hf_config = config.hf_config
         free, total = torch.cuda.mem_get_info()
         used = total - free
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
+        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak) // block_bytes
+        print(f"{config.num_kvcache_blocks=}")
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
@@ -133,6 +147,8 @@ class ModelRunner:
             cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
             max_seqlen_q = max(seqlen_q, max_seqlen_q)
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
+            if not seq.block_table:
+                continue
             for i in range(seq.num_cached_blocks, seq.num_blocks):
                 start = seq.block_table[i] * self.block_size
                 if i != seq.num_blocks - 1:
@@ -140,7 +156,6 @@ class ModelRunner:
                 else:
                     end = start + seq.last_block_num_tokens
                 slot_mapping.extend(list(range(start, end)))
-        assert len(input_ids) == len(slot_mapping)
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
@@ -177,7 +192,7 @@ class ModelRunner:
         return temperatures
 
     @torch.inference_mode()
-    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
+    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
         if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
             return self.model.compute_logits(self.model(input_ids, positions))
         else:
diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py
index 9f25fe6..49d9ee6 100644
--- a/nanovllm/engine/sequence.py
+++ b/nanovllm/engine/sequence.py
@@ -15,7 +15,7 @@ class Sequence:
     block_size = 256
     counter = count()
 
-    def __init__(self, token_ids: list[int], sampling_params: SamplingParams):
+    def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
        self.seq_id = next(Sequence.counter)
        self.status = SequenceStatus.WAITING
        self.token_ids = copy(token_ids)
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index bb5344e..5620b13 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -63,7 +63,8 @@ class Attention(nn.Module):
         context = get_context()
         k_cache = self.k_cache
         v_cache = self.v_cache
-        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+        if k_cache.numel() and v_cache.numel():
+            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
             if context.block_tables is not None:    # prefix cache
                 k, v = k_cache, v_cache
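Note on PATCH 1/3: this patch sizes the KV cache empirically. `warmup_model` runs a worst-case prefill batch, measures the peak allocation it caused, and `allocate_kv_cache` then fits as many cache blocks as the remaining budget allows, where one block covers `block_size` tokens of K and V across every layer. The sketch below only replays that arithmetic with made-up GPU and model numbers (28 layers, 8 KV heads, head_dim 128, fp16); none of these values are read from a real config, and the helper names are illustrative, not part of the repo.

```python
def kv_block_bytes(num_layers: int, block_size: int, num_kv_heads: int,
                   head_dim: int, dtype_bytes: int = 2) -> int:
    # One cache block stores `block_size` tokens of K *and* V for every layer.
    return 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_bytes


def num_kvcache_blocks(total: int, used: int, peak: int,
                       gpu_memory_utilization: float, block_bytes: int) -> int:
    # Mirrors the formula in allocate_kv_cache() after this patch:
    # target fraction of the GPU, minus what is already resident,
    # minus the activation peak observed during warmup.
    return int(total * gpu_memory_utilization - used - peak) // block_bytes


if __name__ == "__main__":
    GiB = 1024 ** 3
    block_bytes = kv_block_bytes(num_layers=28, block_size=256, num_kv_heads=8, head_dim=128)
    blocks = num_kvcache_blocks(total=24 * GiB, used=2 * GiB, peak=3 * GiB,
                                gpu_memory_utilization=0.9, block_bytes=block_bytes)
    print(f"{block_bytes=} {blocks=}")  # 29,360,128 bytes per block, roughly 600 blocks here
```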
From 1caeec8dfa9cd7a4997d294ca2293f19f278e463 Mon Sep 17 00:00:00 2001
From: GeeeekExplorer <2651904866@qq.com>
Date: Fri, 27 Jun 2025 18:50:56 +0800
Subject: [PATCH 2/3] same as vllm

---
 README.md                       |  6 +++---
 nanovllm/engine/model_runner.py | 34 ++++++++++++++++------------------
 nanovllm/layers/attention.py    |  3 +--
 3 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/README.md b/README.md
index e6907be..f56a733 100644
--- a/README.md
+++ b/README.md
@@ -14,9 +14,9 @@ A lightweight vLLM implementation built from scratch.
 pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
 ```
 
-## Manual download
+## Manual Download
 
-If you’d rather fetch the model weights yourself, you can use:
+If you prefer to download the model weights manually, use the following command:
 ```bash
 huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
   --local-dir ~/huggingface/Qwen3-0.6B/ \
@@ -25,7 +25,7 @@ huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
 
 ## Quick Start
 
-See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method.
+See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
 ```python
 from nanovllm import LLM, SamplingParams
 llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 699d36c..033d160 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -22,7 +22,7 @@ class ModelRunner:
         self.world_size = config.tensor_parallel_size
         self.rank = rank
         self.event = event
-        
+
         dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
@@ -31,8 +31,8 @@ class ModelRunner:
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
         self.sampler = Sampler()
-        peak = self.warmup_model()
-        self.allocate_kv_cache(config.gpu_memory_utilization, peak)
+        self.warmup_model()
+        self.allocate_kv_cache()
         if not self.enforce_eager:
             self.capture_cudagraph()
         torch.set_default_device("cpu")
@@ -47,18 +47,6 @@ class ModelRunner:
                 self.shm = SharedMemory(name="nanovllm")
                 self.loop()
 
-    def warmup_model(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
-        before = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
-        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
-        self.run(seqs, True)
-        torch.cuda.empty_cache()
-        after = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
-        return after - before
-
     def exit(self):
         if self.world_size > 1:
             self.shm.close()
@@ -102,15 +90,25 @@ class ModelRunner:
         assert callable(method)
         return method(*args)
 
-    def allocate_kv_cache(self, gpu_memory_utilization, peak):
+    def warmup_model(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        self.run(seqs, True)
+        torch.cuda.empty_cache()
+
+    def allocate_kv_cache(self):
         config = self.config
         hf_config = config.hf_config
         free, total = torch.cuda.mem_get_info()
         used = total - free
+        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak) // block_bytes
-        print(f"{config.num_kvcache_blocks=}")
+        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
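Note on PATCH 2/3: the reworked `allocate_kv_cache` no longer threads a measured `peak` through the constructor; it reads the CUDA allocator statistics directly, as vLLM does. As I read the formula, `peak - current` is the transient activation headroom observed during warmup (memory that was live at the worst moment but has since been freed), so the KV budget becomes `total * gpu_memory_utilization - used - (peak - current)`. The sketch below isolates that accounting; it assumes a warmup pass has already run after `torch.cuda.reset_peak_memory_stats()`, and the helper name is mine, not part of the repo.

```python
import torch


def kv_cache_budget_bytes(gpu_memory_utilization: float) -> int:
    """Sketch of the post-patch KV-cache budget (assumes a prior warmup pass)."""
    free, total = torch.cuda.mem_get_info()
    used = total - free
    stats = torch.cuda.memory_stats()
    peak = stats["allocated_bytes.all.peak"]        # worst-case allocation seen during warmup
    current = stats["allocated_bytes.all.current"]  # what is still resident (weights, buffers)
    # Reserve (peak - current) for transient activations; whatever remains of the
    # utilization target is available for KV-cache blocks.
    return int(total * gpu_memory_utilization - used - (peak - current))
```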
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 5620b13..d036641 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -61,8 +61,7 @@ class Attention(nn.Module):
         k = k.view(-1, self.num_kv_heads, self.head_dim)
         v = v.view(-1, self.num_kv_heads, self.head_dim)
         context = get_context()
-        k_cache = self.k_cache
-        v_cache = self.v_cache
+        k_cache, v_cache = self.k_cache, self.v_cache
         if k_cache.numel() and v_cache.numel():
             store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
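Note on the attention change: the `numel()` guard is needed because warmup now runs before `allocate_kv_cache`, so during that warmup pass `self.k_cache`/`self.v_cache` are still empty placeholder tensors and the cache write must be skipped. For readers who want the intent of `store_kvcache` without reading the kernel, here is a plain-PyTorch stand-in with the semantics I infer from the slot-mapping code in PATCH 1/3 (flat slot index = block_id * block_size + offset); this is an illustrative assumption, not the kernel itself.

```python
import torch


def store_kvcache_ref(k: torch.Tensor, v: torch.Tensor,
                      k_cache: torch.Tensor, v_cache: torch.Tensor,
                      slot_mapping: torch.Tensor) -> None:
    # k, v:             [num_tokens, num_kv_heads, head_dim]
    # k_cache, v_cache: [num_blocks, block_size, num_kv_heads, head_dim]
    # slot_mapping:     [num_tokens], flat slot index = block_id * block_size + offset
    num_kv_heads, head_dim = k.shape[1], k.shape[2]
    # Scatter each token's K/V into its assigned slot of the paged cache.
    k_cache.view(-1, num_kv_heads, head_dim)[slot_mapping] = k
    v_cache.view(-1, num_kv_heads, head_dim)[slot_mapping] = v
```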
From cb0b3dec3f416455f47cf3391c72f75167592b8b Mon Sep 17 00:00:00 2001
From: GeeeekExplorer <2651904866@qq.com>
Date: Fri, 27 Jun 2025 22:50:33 +0800
Subject: [PATCH 3/3] remove rng state

---
 nanovllm/engine/model_runner.py | 17 ++---------------
 1 file changed, 2 insertions(+), 15 deletions(-)

diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 033d160..d48a0eb 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -77,7 +77,6 @@ class ModelRunner:
         assert self.world_size > 1 and not self.rank
         data = pickle.dumps([method_name, *args])
         n = len(data)
-        assert n + 4 <= self.shm.size
         self.shm.buf[0:4] = n.to_bytes(4, "little")
         self.shm.buf[4:n+4] = data
         for event in self.event:
@@ -87,7 +86,6 @@ class ModelRunner:
         if self.world_size > 1 and self.rank == 0:
             self.write_shm(method_name, *args)
         method = getattr(self, method_name, None)
-        assert callable(method)
         return method(*args)
 
     def warmup_model(self):
@@ -109,6 +107,7 @@ class ModelRunner:
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
+        assert config.num_kvcache_blocks > 0
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
@@ -119,10 +118,7 @@ class ModelRunner:
 
     def prepare_block_tables(self, seqs: list[Sequence]):
         max_len = max(len(seq.block_table) for seq in seqs)
-        block_tables = [
-            seq.block_table + [-1] * (max_len - len(seq.block_table))
-            for seq in seqs
-        ]
+        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
         block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         return block_tables
 
@@ -219,12 +215,6 @@ class ModelRunner:
 
     @torch.inference_mode()
     def capture_cudagraph(self):
-        get_rng_state = torch.cuda.get_rng_state
-        set_rng_state = torch.cuda.set_rng_state
-        rng_state = torch.cuda.get_rng_state()
-        torch.cuda.get_rng_state = lambda: rng_state
-        torch.cuda.set_rng_state = lambda _: None
-
         config = self.config
         hf_config = config.hf_config
         max_bs = min(self.config.max_num_seqs, 512)
@@ -259,6 +249,3 @@ class ModelRunner:
             block_tables=block_tables,
             outputs=outputs,
         )
-
-        torch.cuda.get_rng_state = get_rng_state
-        torch.cuda.set_rng_state = set_rng_state
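Note on PATCH 3/3: besides dropping two defensive asserts (and adding one that fails fast when the KV cache would end up with zero blocks), this patch removes the `get_rng_state`/`set_rng_state` monkey-patch that previously wrapped `capture_cudagraph`, relying on PyTorch's stock graph-capture path instead. For context, the capture/replay pattern that `capture_cudagraph` builds on looks roughly like the sketch below; the tiny linear model and shapes are placeholders, whereas the real code captures one graph per decode batch size and reuses pinned `graph_vars` buffers across replays.

```python
import torch

# Minimal CUDA-graph capture/replay sketch (placeholder model and shapes).
model = torch.nn.Linear(64, 64).cuda().eval()
static_in = torch.zeros(8, 64, device="cuda")

# Warm up on a side stream before capturing, as the CUDA graphs docs require.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = model(static_in)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a graph; tensors used here become static buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay: overwrite the captured input buffer in place, then launch the whole graph.
static_in.copy_(torch.randn(8, 64, device="cuda"))
graph.replay()
print(static_out.shape)
```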