From b6136383c90ab236387718bfc5b9a93a2c21bd69 Mon Sep 17 00:00:00 2001
From: GeeeekExplorer <2651904866@qq.com>
Date: Sat, 14 Jun 2025 13:36:57 +0800
Subject: [PATCH] support fast pickle

---
 nanovllm/engine/block_manager.py |  2 +-
 nanovllm/engine/model_runner.py  |  6 +++---
 nanovllm/engine/sequence.py      | 36 +++++++++++++++-----------------
 3 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py
index f916fc4..9327e95 100644
--- a/nanovllm/engine/block_manager.py
+++ b/nanovllm/engine/block_manager.py
@@ -107,7 +107,7 @@ class BlockManager:
             block_table.append(block_id)
         elif len(seq) % self.block_size == 0:
             assert last_block.hash == -1
-            token_ids = seq.last_block()
+            token_ids = seq.block(seq.num_blocks-1)
             prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
             h = compute_hash(token_ids, prefix)
             last_block.update(h, token_ids)
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index b39fabb..0978e32 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -66,7 +66,7 @@ class ModelRunner:
         for seq in seqs:
             seqlen = len(seq)
             input_ids.extend(seq[seq.num_cached_tokens:])
-            positions.extend(list(range(seq.num_cached_tokens, len(seq))))
+            positions.extend(list(range(seq.num_cached_tokens, seqlen)))
             seqlen_q = seqlen - seq.num_cached_tokens
             seqlen_k = seqlen
             cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
@@ -78,7 +78,7 @@ class ModelRunner:
                     if i != seq.num_blocks - 1:
                         end = start + self.block_size
                     else:
-                        end = start + len(seq.last_block())
+                        end = start + seq.last_block_num_tokens
                     slot_mapping.extend(list(range(start, end)))
         assert len(input_ids) == len(slot_mapping)
         assert len(input_ids) == cu_seqlens_q[-1]
@@ -102,7 +102,7 @@ class ModelRunner:
             input_ids.append(seq.last_token)
             positions.append(len(seq))
             context_lens.append(len(seq))
-            slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block()) - 1)
+            slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py
index 51c0438..2216a86 100644
--- a/nanovllm/engine/sequence.py
+++ b/nanovllm/engine/sequence.py
@@ -19,15 +19,17 @@ class Sequence:
         self.seq_id = next(Sequence.counter)
         self.status = SequenceStatus.WAITING
         self.token_ids = copy(token_ids)
+        self.last_token = token_ids[-1]
+        self.num_tokens = len(self.token_ids)
         self.num_prompt_tokens = len(token_ids)
-        self._num_cached_tokens = 0
+        self.num_cached_tokens = 0
         self.block_table = []
         self.temperature = sampling_params.temperature
         self.max_tokens = sampling_params.max_tokens
         self.ignore_eos = sampling_params.ignore_eos
 
     def __len__(self):
-        return len(self.token_ids)
+        return self.num_tokens
 
     def __lt__(self, other):
         return self.seq_id < other.seq_id
@@ -41,7 +43,7 @@
 
     @property
     def num_completion_tokens(self):
-        return len(self.token_ids) - self.num_prompt_tokens
+        return self.num_tokens - self.num_prompt_tokens
 
     @property
     def prompt_token_ids(self):
@@ -51,33 +53,29 @@
     def completion_token_ids(self):
         return self.token_ids[self.num_prompt_tokens:]
 
-    @property
-    def num_cached_tokens(self):
-        return self._num_cached_tokens
-
-    @num_cached_tokens.setter
-    def num_cached_tokens(self, num_cached_tokens):
-        assert num_cached_tokens % self.block_size == 0
-        self._num_cached_tokens = num_cached_tokens
-
     @property
     def num_cached_blocks(self):
         return self.num_cached_tokens // self.block_size
 
     @property
     def num_blocks(self):
-        return (len(self.token_ids) + self.block_size - 1) // self.block_size
+        return (self.num_tokens + self.block_size - 1) // self.block_size
 
     @property
-    def last_token(self):
-        return self.token_ids[-1]
+    def last_block_num_tokens(self):
+        return self.num_tokens - (self.num_blocks - 1) * self.block_size
 
     def block(self, i):
+        assert 0 <= i < self.num_blocks
         return self.token_ids[i*self.block_size: (i+1)*self.block_size]
 
-    def last_block(self):
-        n = self.num_blocks
-        return self.token_ids[(n-1)*self.block_size:]
-
     def append_token(self, token_id: int):
         self.token_ids.append(token_id)
+        self.last_token = token_id
+        self.num_tokens += 1
+
+    def __getstate__(self):
+        state = super().__getstate__()
+        if self.num_completion_tokens:
+            state.pop("token_ids")
+        return state