faster pickle

This commit is contained in:
GeeeekExplorer
2025-06-23 00:51:52 +08:00
parent 8162578b60
commit 03cfc13bb3
3 changed files with 10 additions and 19 deletions

View File

@@ -183,9 +183,11 @@ class ModelRunner:
else:
bs = input_ids.size(0)
context = get_context()
self.reset_graph_vars()
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
graph_vars = self.graph_vars
for k, v in graph_vars.items():
if k != "outputs":
v.zero_()
graph_vars["input_ids"][:bs] = input_ids
graph_vars["positions"][:bs] = positions
graph_vars["slot_mapping"][:bs] = context.slot_mapping
@@ -194,14 +196,6 @@ class ModelRunner:
graph.replay()
return self.model.compute_logits(graph_vars["outputs"][:bs])
# Zero, in place, five of the buffers stored in self.graph_vars:
# "input_ids", "positions", "slot_mapping", "context_lens" and "block_tables".
# No other entry is touched; in particular the "outputs" entry that the
# replay path reads is left as-is. Returns nothing.
# NOTE(review): zero_() looks like the in-place torch.Tensor method — confirm
# against where graph_vars is populated.
def reset_graph_vars(self):
graph_vars = self.graph_vars
graph_vars["input_ids"].zero_()
graph_vars["positions"].zero_()
graph_vars["slot_mapping"].zero_()
graph_vars["context_lens"].zero_()
graph_vars["block_tables"].zero_()
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None

View File

@@ -72,14 +72,12 @@ class Sequence:
self.num_tokens += 1
def __getstate__(self):
state = {
"num_tokens": self.num_tokens,
"num_prompt_tokens": self.num_prompt_tokens,
"num_cached_tokens": self.num_cached_tokens,
"block_table": self.block_table,
}
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
self.token_ids if self.num_completion_tokens == 0 else self.last_token)
def __setstate__(self, state):
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
if self.num_completion_tokens == 0:
state["token_ids"] = self.token_ids
self.token_ids = state[-1]
else:
state["last_token"] = self.last_token
return state
self.last_token = state[-1]

View File

@@ -16,7 +16,6 @@ dependencies = [
"triton>=3.0.0",
"transformers>=4.51.0",
"flash-attn",
"nvidia-ml-py",
"xxhash",
]