From 03cfc13bb3a2a46ce0097e4c8d9561af998bdcad Mon Sep 17 00:00:00 2001
From: GeeeekExplorer <2651904866@qq.com>
Date: Mon, 23 Jun 2025 00:51:52 +0800
Subject: [PATCH] faster pickle

---
 nanovllm/engine/model_runner.py | 12 +++---------
 nanovllm/engine/sequence.py     | 16 +++++++---------
 pyproject.toml                  |  1 -
 3 files changed, 10 insertions(+), 19 deletions(-)

diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 3f5636a..823f0af 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -183,9 +183,11 @@ class ModelRunner:
         else:
             bs = input_ids.size(0)
             context = get_context()
-            self.reset_graph_vars()
             graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
             graph_vars = self.graph_vars
+            for k, v in graph_vars.items():
+                if k != "outputs":
+                    v.zero_()
             graph_vars["input_ids"][:bs] = input_ids
             graph_vars["positions"][:bs] = positions
             graph_vars["slot_mapping"][:bs] = context.slot_mapping
@@ -194,14 +196,6 @@ class ModelRunner:
             graph.replay()
             return self.model.compute_logits(graph_vars["outputs"][:bs])
 
-    def reset_graph_vars(self):
-        graph_vars = self.graph_vars
-        graph_vars["input_ids"].zero_()
-        graph_vars["positions"].zero_()
-        graph_vars["slot_mapping"].zero_()
-        graph_vars["context_lens"].zero_()
-        graph_vars["block_tables"].zero_()
-
     def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
         temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py
index 094b77e..9f25fe6 100644
--- a/nanovllm/engine/sequence.py
+++ b/nanovllm/engine/sequence.py
@@ -72,14 +72,12 @@ class Sequence:
         self.num_tokens += 1
 
     def __getstate__(self):
-        state = {
-            "num_tokens": self.num_tokens,
-            "num_prompt_tokens": self.num_prompt_tokens,
-            "num_cached_tokens": self.num_cached_tokens,
-            "block_table": self.block_table,
-        }
+        return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
+                self.token_ids if self.num_completion_tokens == 0 else self.last_token)
+
+    def __setstate__(self, state):
+        self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
         if self.num_completion_tokens == 0:
-            state["token_ids"] = self.token_ids
+            self.token_ids = state[-1]
         else:
-            state["last_token"] = self.last_token
-        return state
+            self.last_token = state[-1]
diff --git a/pyproject.toml b/pyproject.toml
index bf3ecda..ffe1094 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,6 @@ dependencies = [
     "triton>=3.0.0",
     "transformers>=4.51.0",
     "flash-attn",
-    "nvidia-ml-py",
     "xxhash",
 ]
 