remove rng state
@@ -77,7 +77,6 @@ class ModelRunner:
         assert self.world_size > 1 and not self.rank
         data = pickle.dumps([method_name, *args])
         n = len(data)
-        assert n + 4 <= self.shm.size
         self.shm.buf[0:4] = n.to_bytes(4, "little")
         self.shm.buf[4:n+4] = data
         for event in self.event:
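The hunk above is ModelRunner.write_shm: rank 0 pickles the method name and arguments, writes a 4-byte little-endian length prefix followed by the payload into the shared-memory buffer, and then loops over self.event, presumably to signal the workers; the size assert it drops was the only guard against overflowing the buffer. A minimal standalone sketch of the same length-prefixed layout, assuming a 1 MiB buffer and hypothetical helper names:

import pickle
from multiprocessing import shared_memory

# Illustrative only: 4 little-endian length bytes, then the pickled payload.
shm = shared_memory.SharedMemory(create=True, size=1 << 20)

def write_message(method_name, *args):
    data = pickle.dumps([method_name, *args])
    n = len(data)
    assert n + 4 <= shm.size              # the bounds check the commit removes
    shm.buf[0:4] = n.to_bytes(4, "little")
    shm.buf[4:n + 4] = data

def read_message():
    n = int.from_bytes(shm.buf[0:4], "little")
    return pickle.loads(shm.buf[4:n + 4])

write_message("run", 1, 2)
print(read_message())                     # ['run', 1, 2]
shm.close()
shm.unlink()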
@@ -87,7 +86,6 @@ class ModelRunner:
         if self.world_size > 1 and self.rank == 0:
             self.write_shm(method_name, *args)
         method = getattr(self, method_name, None)
-        assert callable(method)
         return method(*args)
 
     def warmup_model(self):
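This hunk is the dispatch path that pairs with write_shm: on a multi-GPU run, rank 0 first broadcasts the call through shared memory, then every rank looks the method up by name with getattr and invokes it; the removed assert callable(method) was the only protection against a misspelled method name resolving to None. A small sketch of that name-based dispatch pattern (class and method names here are made up):

class Worker:
    def ping(self, x):
        return x + 1

    def call(self, method_name, *args):
        # getattr returns None for unknown names instead of raising.
        method = getattr(self, method_name, None)
        assert callable(method)           # the guard this commit drops
        return method(*args)

w = Worker()
print(w.call("ping", 41))                 # 42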
@@ -109,6 +107,7 @@ class ModelRunner:
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
+        assert config.num_kvcache_blocks > 0
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
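The hunk above sizes the KV cache: block_bytes is the footprint of one cache block (a factor of 2 for K and V, times layers, block size, per-GPU KV heads, head dim, and dtype size), and num_kvcache_blocks is how many such blocks fit in the remaining GPU memory budget; the new assert fails fast when that count comes out non-positive instead of allocating an empty cache. A worked example with made-up numbers, not taken from any real model config:

# Hypothetical figures for illustration only.
num_hidden_layers = 32
block_size = 256
num_kv_heads = 8                   # per-GPU heads after tensor-parallel split
head_dim = 128
itemsize = 2                       # fp16/bf16

# One block holds K and V for every layer: 2 * layers * block * heads * dim * bytes.
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
print(block_bytes)                 # 33554432 bytes = 32 MiB per block

total = 24 * 1024**3               # 24 GiB device
used, peak, current = 6 * 1024**3, 7 * 1024**3, 6 * 1024**3
gpu_memory_utilization = 0.9

num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak + current) // block_bytes
assert num_kvcache_blocks > 0      # the check this commit adds
print(num_kvcache_blocks)          # 467 blocks with these numbers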
@@ -119,10 +118,7 @@ class ModelRunner:
 
     def prepare_block_tables(self, seqs: list[Sequence]):
         max_len = max(len(seq.block_table) for seq in seqs)
-        block_tables = [
-            seq.block_table + [-1] * (max_len - len(seq.block_table))
-            for seq in seqs
-        ]
+        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
         block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         return block_tables
 
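prepare_block_tables turns ragged per-sequence block tables into one rectangular tensor: every table is right-padded with -1 up to the longest table in the batch, and the commit simply collapses the padding comprehension onto one line without changing behavior. A CPU-only sketch of the padding step, with invented input data:

import torch

# Hypothetical ragged block tables for three sequences.
ragged = [[0, 1, 2], [3], [4, 5]]

max_len = max(len(bt) for bt in ragged)
padded = [bt + [-1] * (max_len - len(bt)) for bt in ragged]
print(padded)                      # [[0, 1, 2], [3, -1, -1], [4, 5, -1]]

block_tables = torch.tensor(padded, dtype=torch.int32)
# The real code additionally pins host memory and copies asynchronously:
#   torch.tensor(..., dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)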
@@ -219,12 +215,6 @@ class ModelRunner:
 
     @torch.inference_mode()
     def capture_cudagraph(self):
-        get_rng_state = torch.cuda.get_rng_state
-        set_rng_state = torch.cuda.set_rng_state
-        rng_state = torch.cuda.get_rng_state()
-        torch.cuda.get_rng_state = lambda: rng_state
-        torch.cuda.set_rng_state = lambda _: None
-
         config = self.config
         hf_config = config.hf_config
         max_bs = min(self.config.max_num_seqs, 512)
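The removed block monkey-patched torch.cuda.get_rng_state and torch.cuda.set_rng_state to return a frozen snapshot and do nothing, presumably to keep RNG state save/restore out of the CUDA graph capture path; the commit drops that workaround, so capture_cudagraph now runs against the unpatched functions. For context, a minimal, self-contained CUDA graph capture/replay sketch using only the public PyTorch API (requires a CUDA device; the Linear module is a stand-in for the real model):

import torch

model = torch.nn.Linear(64, 64).cuda().eval()
static_input = torch.zeros(8, 64, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Record the forward pass into a graph, then replay it with refreshed inputs.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

static_input.copy_(torch.randn(8, 64, device="cuda"))
graph.replay()
print(static_output.shape)         # torch.Size([8, 64])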
@@ -259,6 +249,3 @@ class ModelRunner:
             block_tables=block_tables,
             outputs=outputs,
         )
-
-        torch.cuda.get_rng_state = get_rng_state
-        torch.cuda.set_rng_state = set_rng_state
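This last hunk deletes the matching restore of the patched functions at the end of capture_cudagraph. If a temporary monkey-patch like this were genuinely needed, wrapping it in a context manager with try/finally would keep the restore from being skipped when capture raises partway; a sketch of that pattern (illustrative, not code from this project):

import contextlib
import torch

@contextlib.contextmanager
def frozen_cuda_rng_state():
    # Save the real functions, freeze the reported state, and always restore.
    get_rng_state = torch.cuda.get_rng_state
    set_rng_state = torch.cuda.set_rng_state
    rng_state = torch.cuda.get_rng_state()
    torch.cuda.get_rng_state = lambda device=None: rng_state
    torch.cuda.set_rng_state = lambda *args, **kwargs: None
    try:
        yield
    finally:
        torch.cuda.get_rng_state = get_rng_state
        torch.cuda.set_rng_state = set_rng_state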