faster pickle
This commit is contained in:
@@ -183,9 +183,11 @@ class ModelRunner:
|
||||
else:
|
||||
bs = input_ids.size(0)
|
||||
context = get_context()
|
||||
self.reset_graph_vars()
|
||||
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
||||
graph_vars = self.graph_vars
|
||||
for k, v in graph_vars.items():
|
||||
if k != "outputs":
|
||||
v.zero_()
|
||||
graph_vars["input_ids"][:bs] = input_ids
|
||||
graph_vars["positions"][:bs] = positions
|
||||
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
||||
@@ -194,14 +196,6 @@ class ModelRunner:
|
||||
graph.replay()
|
||||
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
||||
|
||||
def reset_graph_vars(self):
|
||||
graph_vars = self.graph_vars
|
||||
graph_vars["input_ids"].zero_()
|
||||
graph_vars["positions"].zero_()
|
||||
graph_vars["slot_mapping"].zero_()
|
||||
graph_vars["context_lens"].zero_()
|
||||
graph_vars["block_tables"].zero_()
|
||||
|
||||
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
|
||||
Reference in New Issue
Block a user