faster pickle

This commit is contained in:
GeeeekExplorer
2025-06-23 00:51:52 +08:00
parent 8162578b60
commit 03cfc13bb3
3 changed files with 10 additions and 19 deletions

View File

@@ -183,9 +183,11 @@ class ModelRunner:
else:
bs = input_ids.size(0)
context = get_context()
self.reset_graph_vars()
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
graph_vars = self.graph_vars
for k, v in graph_vars.items():
if k != "outputs":
v.zero_()
graph_vars["input_ids"][:bs] = input_ids
graph_vars["positions"][:bs] = positions
graph_vars["slot_mapping"][:bs] = context.slot_mapping
@@ -194,14 +196,6 @@ class ModelRunner:
graph.replay()
return self.model.compute_logits(graph_vars["outputs"][:bs])
def reset_graph_vars(self):
    """Zero the input entries of ``self.graph_vars`` in place.

    Calls ``zero_()`` on the ``input_ids``, ``positions``, ``slot_mapping``,
    ``context_lens`` and ``block_tables`` entries, in that order. No other
    entry of the mapping is touched.

    NOTE(review): the entries are presumably preallocated torch tensors
    reused across graph replays -- confirm against where ``graph_vars``
    is built.
    """
    buffers = self.graph_vars
    input_keys = (
        "input_ids",
        "positions",
        "slot_mapping",
        "context_lens",
        "block_tables",
    )
    for key in input_keys:
        buffers[key].zero_()
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None