remove rng state
This commit is contained in:
@@ -77,7 +77,6 @@ class ModelRunner:
|
||||
assert self.world_size > 1 and not self.rank
|
||||
data = pickle.dumps([method_name, *args])
|
||||
n = len(data)
|
||||
assert n + 4 <= self.shm.size
|
||||
self.shm.buf[0:4] = n.to_bytes(4, "little")
|
||||
self.shm.buf[4:n+4] = data
|
||||
for event in self.event:
|
||||
@@ -87,7 +86,6 @@ class ModelRunner:
|
||||
if self.world_size > 1 and self.rank == 0:
|
||||
self.write_shm(method_name, *args)
|
||||
method = getattr(self, method_name, None)
|
||||
assert callable(method)
|
||||
return method(*args)
|
||||
|
||||
def warmup_model(self):
|
||||
@@ -109,6 +107,7 @@ class ModelRunner:
|
||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
|
||||
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
|
||||
assert config.num_kvcache_blocks > 0
|
||||
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
|
||||
layer_id = 0
|
||||
for module in self.model.modules():
|
||||
@@ -119,10 +118,7 @@ class ModelRunner:
|
||||
|
||||
def prepare_block_tables(self, seqs: list[Sequence]):
|
||||
max_len = max(len(seq.block_table) for seq in seqs)
|
||||
block_tables = [
|
||||
seq.block_table + [-1] * (max_len - len(seq.block_table))
|
||||
for seq in seqs
|
||||
]
|
||||
block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
return block_tables
|
||||
|
||||
@@ -219,12 +215,6 @@ class ModelRunner:
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_cudagraph(self):
|
||||
get_rng_state = torch.cuda.get_rng_state
|
||||
set_rng_state = torch.cuda.set_rng_state
|
||||
rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.get_rng_state = lambda: rng_state
|
||||
torch.cuda.set_rng_state = lambda _: None
|
||||
|
||||
config = self.config
|
||||
hf_config = config.hf_config
|
||||
max_bs = min(self.config.max_num_seqs, 512)
|
||||
@@ -259,6 +249,3 @@ class ModelRunner:
|
||||
block_tables=block_tables,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
torch.cuda.get_rng_state = get_rng_state
|
||||
torch.cuda.set_rng_state = set_rng_state
|
||||
|
||||
Reference in New Issue
Block a user