refactor
This commit is contained in:
@@ -169,7 +169,7 @@ class ModelRunner:
|
||||
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
||||
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
||||
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
||||
self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32))
|
||||
self.graph_bs = [1, 2, 4, 8, 16] + list(range(16, max_bs + 1, 16))
|
||||
self.graphs = {}
|
||||
self.graph_pool = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user