faster pickle
This commit is contained in:
@@ -183,9 +183,11 @@ class ModelRunner:
|
||||
else:
|
||||
bs = input_ids.size(0)
|
||||
context = get_context()
|
||||
self.reset_graph_vars()
|
||||
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
||||
graph_vars = self.graph_vars
|
||||
for k, v in graph_vars.items():
|
||||
if k != "outputs":
|
||||
v.zero_()
|
||||
graph_vars["input_ids"][:bs] = input_ids
|
||||
graph_vars["positions"][:bs] = positions
|
||||
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
||||
@@ -194,14 +196,6 @@ class ModelRunner:
|
||||
graph.replay()
|
||||
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
||||
|
||||
def reset_graph_vars(self):
|
||||
graph_vars = self.graph_vars
|
||||
graph_vars["input_ids"].zero_()
|
||||
graph_vars["positions"].zero_()
|
||||
graph_vars["slot_mapping"].zero_()
|
||||
graph_vars["context_lens"].zero_()
|
||||
graph_vars["block_tables"].zero_()
|
||||
|
||||
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
|
||||
@@ -72,14 +72,12 @@ class Sequence:
|
||||
self.num_tokens += 1
|
||||
|
||||
def __getstate__(self):
|
||||
state = {
|
||||
"num_tokens": self.num_tokens,
|
||||
"num_prompt_tokens": self.num_prompt_tokens,
|
||||
"num_cached_tokens": self.num_cached_tokens,
|
||||
"block_table": self.block_table,
|
||||
}
|
||||
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
|
||||
self.token_ids if self.num_completion_tokens == 0 else self.last_token)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
|
||||
if self.num_completion_tokens == 0:
|
||||
state["token_ids"] = self.token_ids
|
||||
self.token_ids = state[-1]
|
||||
else:
|
||||
state["last_token"] = self.last_token
|
||||
return state
|
||||
self.last_token = state[-1]
|
||||
|
||||
@@ -16,7 +16,6 @@ dependencies = [
|
||||
"triton>=3.0.0",
|
||||
"transformers>=4.51.0",
|
||||
"flash-attn",
|
||||
"nvidia-ml-py",
|
||||
"xxhash",
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user