fix: correct off-by-one in decode slot_mapping; rename capture_model to capture_cudagraph
This commit is contained in:
@@ -24,7 +24,7 @@ class ModelRunner:
|
||||
self.sampler = Sampler()
|
||||
self.allocate_kv_cache(config.gpu_memory_utilization)
|
||||
if not self.enforce_eager:
|
||||
self.capture_model()
|
||||
self.capture_cudagraph()
|
||||
torch.set_default_device("cpu")
|
||||
torch.set_default_dtype(default_dtype)
|
||||
|
||||
@@ -101,7 +101,7 @@ class ModelRunner:
|
||||
input_ids.append(seq.last_token)
|
||||
positions.append(len(seq))
|
||||
context_lens.append(len(seq))
|
||||
slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()))
|
||||
slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()) - 1)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
@@ -152,7 +152,7 @@ class ModelRunner:
|
||||
return token_ids
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self):
|
||||
def capture_cudagraph(self):
|
||||
get_rng_state = torch.cuda.get_rng_state
|
||||
set_rng_state = torch.cuda.set_rng_state
|
||||
rng_state = torch.cuda.get_rng_state()
|
||||
|
||||
Reference in New Issue
Block a user