From 658520b788ed9789ab269852d83092f1d9e90d1d Mon Sep 17 00:00:00 2001 From: GeeeekExplorer <2651904866@qq.com> Date: Fri, 27 Jun 2025 01:51:57 +0800 Subject: [PATCH] warmup and allocate --- nanovllm/config.py | 3 ++- nanovllm/engine/model_runner.py | 25 ++++++++++++++++++++----- nanovllm/engine/sequence.py | 2 +- nanovllm/layers/attention.py | 3 ++- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/nanovllm/config.py b/nanovllm/config.py index 6c4e7f9..959ffb3 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -6,7 +6,7 @@ from transformers import AutoConfig @dataclass class Config: model: str - max_num_batched_tokens: int = 32768 + max_num_batched_tokens: int = 16384 max_num_seqs: int = 512 max_model_len: int = 4096 gpu_memory_utilization: float = 0.9 @@ -23,3 +23,4 @@ class Config: assert 1 <= self.tensor_parallel_size <= 8 self.hf_config = AutoConfig.from_pretrained(self.model) self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings) + assert self.max_num_batched_tokens >= self.max_model_len diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 823f0af..699d36c 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -31,7 +31,8 @@ class ModelRunner: self.model = Qwen3ForCausalLM(hf_config) load_model(self.model, config.model) self.sampler = Sampler() - self.allocate_kv_cache(config.gpu_memory_utilization) + peak = self.warmup_model() + self.allocate_kv_cache(config.gpu_memory_utilization, peak) if not self.enforce_eager: self.capture_cudagraph() torch.set_default_device("cpu") @@ -46,6 +47,18 @@ class ModelRunner: self.shm = SharedMemory(name="nanovllm") self.loop() + def warmup_model(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + before = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) + max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len + num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs) + seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)] + self.run(seqs, True) + torch.cuda.empty_cache() + after = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) + return after - before + def exit(self): if self.world_size > 1: self.shm.close() @@ -89,14 +102,15 @@ class ModelRunner: assert callable(method) return method(*args) - def allocate_kv_cache(self, gpu_memory_utilization): + def allocate_kv_cache(self, gpu_memory_utilization, peak): config = self.config hf_config = config.hf_config free, total = torch.cuda.mem_get_info() used = total - free num_kv_heads = hf_config.num_key_value_heads // self.world_size block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize - config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes + config.num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak) // block_bytes + print(f"{config.num_kvcache_blocks=}") self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim) layer_id = 0 for module in self.model.modules(): @@ -133,6 +147,8 @@ class ModelRunner: cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) max_seqlen_q = max(seqlen_q, max_seqlen_q) max_seqlen_k = max(seqlen_k, max_seqlen_k) + if not seq.block_table: + continue for i in range(seq.num_cached_blocks, seq.num_blocks): start = seq.block_table[i] * self.block_size if i != seq.num_blocks - 1: @@ -140,7 +156,6 @@ class ModelRunner: else: end = start + seq.last_block_num_tokens slot_mapping.extend(list(range(start, end))) - assert len(input_ids) == len(slot_mapping) if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache block_tables = self.prepare_block_tables(seqs) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) @@ -177,7 +192,7 @@ class ModelRunner: return temperatures @torch.inference_mode() - def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill): + def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool): if is_prefill or self.enforce_eager or input_ids.size(0) > 512: return self.model.compute_logits(self.model(input_ids, positions)) else: diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py index 9f25fe6..49d9ee6 100644 --- a/nanovllm/engine/sequence.py +++ b/nanovllm/engine/sequence.py @@ -15,7 +15,7 @@ class Sequence: block_size = 256 counter = count() - def __init__(self, token_ids: list[int], sampling_params: SamplingParams): + def __init__(self, token_ids: list[int], sampling_params = SamplingParams()): self.seq_id = next(Sequence.counter) self.status = SequenceStatus.WAITING self.token_ids = copy(token_ids) diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index bb5344e..5620b13 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -63,7 +63,8 @@ class Attention(nn.Module): context = get_context() k_cache = self.k_cache v_cache = self.v_cache - store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) + if k_cache.numel() and v_cache.numel(): + store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) if context.is_prefill: if context.block_tables is not None: # prefix cache k, v = k_cache, v_cache