faster pickle
This commit is contained in:
@@ -183,9 +183,11 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
bs = input_ids.size(0)
|
bs = input_ids.size(0)
|
||||||
context = get_context()
|
context = get_context()
|
||||||
self.reset_graph_vars()
|
|
||||||
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
||||||
graph_vars = self.graph_vars
|
graph_vars = self.graph_vars
|
||||||
|
for k, v in graph_vars.items():
|
||||||
|
if k != "outputs":
|
||||||
|
v.zero_()
|
||||||
graph_vars["input_ids"][:bs] = input_ids
|
graph_vars["input_ids"][:bs] = input_ids
|
||||||
graph_vars["positions"][:bs] = positions
|
graph_vars["positions"][:bs] = positions
|
||||||
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
||||||
@@ -194,14 +196,6 @@ class ModelRunner:
|
|||||||
graph.replay()
|
graph.replay()
|
||||||
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
||||||
|
|
||||||
def reset_graph_vars(self):
|
|
||||||
graph_vars = self.graph_vars
|
|
||||||
graph_vars["input_ids"].zero_()
|
|
||||||
graph_vars["positions"].zero_()
|
|
||||||
graph_vars["slot_mapping"].zero_()
|
|
||||||
graph_vars["context_lens"].zero_()
|
|
||||||
graph_vars["block_tables"].zero_()
|
|
||||||
|
|
||||||
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||||
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||||
|
|||||||
@@ -72,14 +72,12 @@ class Sequence:
|
|||||||
self.num_tokens += 1
|
self.num_tokens += 1
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = {
|
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
|
||||||
"num_tokens": self.num_tokens,
|
self.token_ids if self.num_completion_tokens == 0 else self.last_token)
|
||||||
"num_prompt_tokens": self.num_prompt_tokens,
|
|
||||||
"num_cached_tokens": self.num_cached_tokens,
|
def __setstate__(self, state):
|
||||||
"block_table": self.block_table,
|
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
|
||||||
}
|
|
||||||
if self.num_completion_tokens == 0:
|
if self.num_completion_tokens == 0:
|
||||||
state["token_ids"] = self.token_ids
|
self.token_ids = state[-1]
|
||||||
else:
|
else:
|
||||||
state["last_token"] = self.last_token
|
self.last_token = state[-1]
|
||||||
return state
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ dependencies = [
|
|||||||
"triton>=3.0.0",
|
"triton>=3.0.0",
|
||||||
"transformers>=4.51.0",
|
"transformers>=4.51.0",
|
||||||
"flash-attn",
|
"flash-attn",
|
||||||
"nvidia-ml-py",
|
|
||||||
"xxhash",
|
"xxhash",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user