This commit is contained in:
GeeeekExplorer
2025-06-10 08:52:58 +08:00
parent a5a4909e6a
commit b98e1ca305
10 changed files with 39 additions and 26 deletions

View File

@@ -24,7 +24,7 @@ class ModelRunner:
self.sampler = Sampler()
self.allocate_kv_cache(config.gpu_memory_utilization)
if not self.enforce_eager:
-self.capture_model()
+self.capture_cudagraph()
torch.set_default_device("cpu")
torch.set_default_dtype(default_dtype)
@@ -101,7 +101,7 @@ class ModelRunner:
input_ids.append(seq.last_token)
positions.append(len(seq))
context_lens.append(len(seq))
-slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()))
+slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()) - 1)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -152,7 +152,7 @@ class ModelRunner:
return token_ids
@torch.inference_mode()
-def capture_model(self):
+def capture_cudagraph(self):
get_rng_state = torch.cuda.get_rng_state
set_rng_state = torch.cuda.set_rng_state
rng_state = torch.cuda.get_rng_state()