refactor
This commit is contained in:
@@ -43,7 +43,6 @@ class Scheduler:
|
||||
return scheduled_seqs, True
|
||||
|
||||
# decode
|
||||
# self.running = deque(sorted(self.running))
|
||||
while self.running and num_seqs < self.max_num_seqs:
|
||||
seq = self.running.popleft()
|
||||
while not self.block_manager.can_append(seq):
|
||||
@@ -59,8 +58,8 @@ class Scheduler:
|
||||
running = deque(scheduled_seqs)
|
||||
running.extend(self.running)
|
||||
self.running = running
|
||||
if scheduled_seqs:
|
||||
return scheduled_seqs, False
|
||||
assert scheduled_seqs
|
||||
return scheduled_seqs, False
|
||||
|
||||
def preempt(self, seq: Sequence):
|
||||
seq.status = SequenceStatus.WAITING
|
||||
@@ -69,7 +68,6 @@ class Scheduler:
|
||||
|
||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
||||
self.num_tokens += len(token_ids)
|
||||
finished = []
|
||||
for seq, token_id in zip(seqs, token_ids):
|
||||
seq.append_token(token_id)
|
||||
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
|
||||
@@ -77,7 +75,3 @@ class Scheduler:
|
||||
self.block_manager.deallocate(seq)
|
||||
self.running.remove(seq)
|
||||
self.num_finished += 1
|
||||
finished.append(True)
|
||||
else:
|
||||
finished.append(False)
|
||||
return finished
|
||||
|
||||
Reference in New Issue
Block a user