From b98e1ca3050caaab91482f184ae0d9f6a3f31aac Mon Sep 17 00:00:00 2001
From: GeeeekExplorer <2651904866@qq.com>
Date: Tue, 10 Jun 2025 08:52:58 +0800
Subject: [PATCH] fix

---
 bench.py                        |  4 ++--
 example.py                      |  2 +-
 nanovllm/config.py              |  2 +-
 nanovllm/engine/llm_engine.py   | 24 +++++++++++++++++++-----
 nanovllm/engine/model_runner.py |  6 +++---
 nanovllm/engine/scheduler.py    |  2 +-
 nanovllm/engine/sequence.py     |  5 +----
 nanovllm/layers/attention.py    |  5 ++---
 nanovllm/layers/sampler.py      | 10 +++++-----
 nanovllm/models/qwen3.py        |  5 ++++-
 10 files changed, 39 insertions(+), 26 deletions(-)

diff --git a/bench.py b/bench.py
index 967df23..5789754 100644
--- a/bench.py
+++ b/bench.py
@@ -16,5 +16,5 @@ sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=ma
 
 t = time.time()
 completions = llm.generate(prompt_token_ids, sampling_params)
-troughput = batch_size * max_tokens / (time.time() - t)
-print(f"Throughput: {troughput: .2f}")
+throughput = batch_size * max_tokens / (time.time() - t)
+print(f"Throughput: {throughput: .2f}")
diff --git a/example.py b/example.py
index 7b7cf06..3dc65f8 100644
--- a/example.py
+++ b/example.py
@@ -24,6 +24,6 @@ prompts = [
 
 completions = llm.generate(prompts, sampling_params)
 for p, c in zip(prompts, completions):
-    print("\n\n")
+    print("\n")
     print(f"Prompt: {p}")
     print(f"Completion: {c}")
diff --git a/nanovllm/config.py b/nanovllm/config.py
index e669d50..4c33837 100644
--- a/nanovllm/config.py
+++ b/nanovllm/config.py
@@ -5,7 +5,7 @@ from transformers import AutoConfig
 @dataclass
 class Config:
     model: str = ''
-    max_num_batched_tokens: int = 16384
+    max_num_batched_tokens: int = 32768
     max_num_seqs: int = 512
     max_model_len: int = 4096
     gpu_memory_utilization: float = 0.95
diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py
index e22ce7f..6c48af9 100644
--- a/nanovllm/engine/llm_engine.py
+++ b/nanovllm/engine/llm_engine.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+from time import perf_counter
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoTokenizer
 
@@ -33,7 +34,7 @@ class LLMEngine:
         seqs, is_prefill = self.scheduler.schedule()
         token_ids = self.model_runner.run(seqs, is_prefill)
         finished = self.scheduler.postprocess(seqs, token_ids)
-        return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)]
+        return [(seq.seq_id, token_id, finish) for seq, token_id, finish in zip(seqs, token_ids, finished)], sum(len(seq) for seq in seqs) if is_prefill else len(seqs)
 
     def is_finished(self):
         return self.scheduler.is_finished()
@@ -45,19 +46,32 @@ class LLMEngine:
         use_tqdm: bool = True,
     ) -> list[str]:
         if use_tqdm:
-            pbar = tqdm(total=len(prompts),
-                        desc="Processed prompts",
+            pbar = tqdm(
+                total=len(prompts),
+                desc="Generating",
+                dynamic_ncols=True,
             )
         if not isinstance(SamplingParams, list):
             sampling_params = [sampling_params] * len(prompts)
         for prompt, sp in zip(prompts, sampling_params):
             self.add_request(prompt, sp)
         outputs = defaultdict(list)
+        prefill_throughput = decode_throughput = 0.
         while not self.is_finished():
-            output = self.step()
+            t = perf_counter()
+            output, num_tokens = self.step()
+            if use_tqdm:
+                if num_tokens > len(output):
+                    prefill_throughput = num_tokens / (perf_counter() - t)
+                else:
+                    decode_throughput = num_tokens / (perf_counter() - t)
+                pbar.set_postfix({
+                    "Prefill": f"{int(prefill_throughput)}tok/s",
+                    "Decode": f"{int(decode_throughput)}tok/s",
+                })
             for seq_id, token_id, finish in output:
                 outputs[seq_id].append(token_id)
-                if use_tqdm and finish:
+                if finish and use_tqdm:
                     pbar.update(1)
         outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
         outputs = [self.tokenizer.decode(token_ids) for token_ids in outputs]
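
The llm_engine.py hunks above make step() return how many tokens the step consumed alongside the per-sequence outputs: a prefill step processes every scheduled prompt token, so the count exceeds the number of sequences returned, while a decode step emits exactly one new token per running sequence. generate() uses that distinction to show separate prefill and decode throughputs in the tqdm postfix. A minimal standalone sketch of the same accounting, with an invented classify_step helper that is not part of the patch:

    def classify_step(num_tokens: int, num_seqs: int) -> str:
        # Prefill touches every prompt token, so num_tokens > num_seqs;
        # decode produces exactly one new token per sequence, so num_tokens == num_seqs.
        return "prefill" if num_tokens > num_seqs else "decode"

    assert classify_step(num_tokens=4096, num_seqs=8) == "prefill"   # 8 prompts, 4096 prompt tokens
    assert classify_step(num_tokens=8, num_seqs=8) == "decode"       # one new token per sequence
    # Per-phase throughput is then num_tokens / elapsed, timed with time.perf_counter().
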
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 4724674..3d26e48 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -24,7 +24,7 @@ class ModelRunner:
         self.sampler = Sampler()
         self.allocate_kv_cache(config.gpu_memory_utilization)
         if not self.enforce_eager:
-            self.capture_model()
+            self.capture_cudagraph()
         torch.set_default_device("cpu")
         torch.set_default_dtype(default_dtype)
 
@@ -101,7 +101,7 @@ class ModelRunner:
             input_ids.append(seq.last_token)
             positions.append(len(seq))
             context_lens.append(len(seq))
-            slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()))
+            slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()) - 1)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -152,7 +152,7 @@ class ModelRunner:
         return token_ids
 
     @torch.inference_mode()
-    def capture_model(self):
+    def capture_cudagraph(self):
         get_rng_state = torch.cuda.get_rng_state
         set_rng_state = torch.cuda.set_rng_state
         rng_state = torch.cuda.get_rng_state()
diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py
index 1e5e684..28cc298 100644
--- a/nanovllm/engine/scheduler.py
+++ b/nanovllm/engine/scheduler.py
@@ -73,7 +73,7 @@ class Scheduler:
         finished = []
         for seq, token_id in zip(seqs, token_ids):
             seq.append_token(token_id)
-            if token_id == self.eos or seq.num_completion_tokens == seq.max_tokens:
+            if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
                 seq.status = SequenceStatus.FINISHED
                 self.block_manager.deallocate(seq)
                 self.running.remove(seq)
diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py
index 5d4f792..717434a 100644
--- a/nanovllm/engine/sequence.py
+++ b/nanovllm/engine/sequence.py
@@ -64,10 +64,7 @@ class Sequence:
 
     def last_block(self, block_size=256):
         n = self.num_blocks
-        t = len(self) + block_size - self.num_blocks * block_size
-        x = self.token_ids[(n-1)*block_size:]
-        assert len(x) == t
-        return x
+        return self.token_ids[(n-1)*block_size:]
 
     def append_token(self, token_id: int):
         self.token_ids.append(token_id)
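
The model_runner.py decode hunk and the simplified Sequence.last_block above concern the same arithmetic: last_block() already includes the token whose key/value is about to be written, so its length is one more than that token's zero-based offset inside the block, hence the trailing "- 1" in the slot computation. A standalone sketch of that slot arithmetic (BLOCK_SIZE and decode_slot are illustrative names, not part of the patch; 256 matches last_block's default block_size):

    BLOCK_SIZE = 256

    def decode_slot(num_tokens: int, block_table: list[int]) -> int:
        # Zero-based offset of the newest token inside the last block of the sequence.
        offset = (num_tokens - 1) % BLOCK_SIZE
        return block_table[-1] * BLOCK_SIZE + offset

    # 257 tokens span two blocks; the newest token sits at offset 0 of the second block.
    # len(last_block()) is 1 here, so the pre-fix formula pointed one slot too far.
    assert decode_slot(257, block_table=[7, 9]) == 9 * BLOCK_SIZE + 0
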
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 6e58865..3f52a3b 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -4,7 +4,6 @@ import triton
 import triton.language as tl
 
 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
-# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
 from nanovllm.utils.context import get_context
 
@@ -64,8 +63,8 @@ class Attention(nn.Module):
         context = get_context()
         k_cache = self.k_cache
         v_cache = self.v_cache
+        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
-            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
             if context.block_tables is None:    # normal prefill
                 cu_seqlens_k = context.cu_seqlens_k
                 seqused_k = None
@@ -79,7 +78,7 @@ class Attention(nn.Module):
                                        seqused_k=seqused_k,
                                        softmax_scale=self.scale, causal=True, block_table=context.block_tables)
         else:    # decode
-            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1),
+            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                         cache_seqlens=context.context_lens, block_table=context.block_tables,
                                         softmax_scale=self.scale, causal=True)
         o = o.view(-1, self.num_heads * self.head_dim)
diff --git a/nanovllm/layers/sampler.py b/nanovllm/layers/sampler.py
index 12d8888..88e59ee 100644
--- a/nanovllm/layers/sampler.py
+++ b/nanovllm/layers/sampler.py
@@ -7,11 +7,11 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
         logits = logits.to(torch.float)
-        if temperatures is not None:
-            logits.div_(temperatures.unsqueeze(dim=1))
+        greedy_tokens = logits.argmax(dim=-1)
+        logits.div_(temperatures.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-        sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
-        return sampled_tokens
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py
index 5fb463f..1d3ff41 100755
--- a/nanovllm/models/qwen3.py
+++ b/nanovllm/models/qwen3.py
@@ -212,7 +212,8 @@ class Qwen3ForCausalLM(nn.Module):
         super().__init__()
         self.model = Qwen3Model(config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        if config.tie_word_embeddings:
+        self.tie_word_embeddings = config.tie_word_embeddings
+        if self.tie_word_embeddings:
             self.lm_head.weight.data = self.model.embed_tokens.weight.data
 
     def forward(
@@ -236,6 +237,8 @@ class Qwen3ForCausalLM(nn.Module):
         default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
         with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
             for n, p in self.named_parameters():
+                if self.tie_word_embeddings and "lm_head" in n:
+                    continue
                 for x in self.packed_modules_mapping:
                     if x in n:
                         weight_loader = getattr(p, "weight_loader", default_weight_loader)
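
The sampler.py hunk above handles greedy decoding explicitly: the argmax of the raw logits is taken before the temperature division, and torch.where swaps it in wherever the per-sequence temperature is 0, while every other row is drawn with the exponential-noise (Gumbel-max style) trick, in which dividing the softmax probabilities by i.i.d. Exponential(1) noise and taking the argmax yields a categorical sample. A standalone sketch of the same idea; the clamp is added here only to keep the division well defined for zero temperatures (the patch instead relies on torch.where discarding those rows):

    import torch

    def sample(logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
        logits = logits.float()
        greedy_tokens = logits.argmax(dim=-1)
        # Temperature scaling; the clamp avoids dividing by zero for greedy rows.
        probs = torch.softmax(logits / temperatures.clamp(min=1e-6).unsqueeze(1), dim=-1)
        # Gumbel-max style draw: argmax of probs / Exponential(1) noise samples the categorical.
        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)

    tokens = sample(torch.randn(4, 32), torch.tensor([0.0, 0.6, 0.8, 1.0]))   # one token id per row
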