support fast pickle

This commit is contained in:
GeeeekExplorer
2025-06-14 13:36:57 +08:00
parent 4a8aa090a7
commit b6136383c9
3 changed files with 21 additions and 23 deletions

View File

@@ -66,7 +66,7 @@ class ModelRunner:
for seq in seqs:
seqlen = len(seq)
input_ids.extend(seq[seq.num_cached_tokens:])
positions.extend(list(range(seq.num_cached_tokens, len(seq))))
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
seqlen_q = seqlen - seq.num_cached_tokens
seqlen_k = seqlen
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
@@ -78,7 +78,7 @@ class ModelRunner:
if i != seq.num_blocks - 1:
end = start + self.block_size
else:
end = start + len(seq.last_block())
end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end)))
assert len(input_ids) == len(slot_mapping)
assert len(input_ids) == cu_seqlens_q[-1]
@@ -102,7 +102,7 @@ class ModelRunner:
input_ids.append(seq.last_token)
positions.append(len(seq))
context_lens.append(len(seq))
slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()) - 1)
slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)