nano-vllm/nanovllm/engine/scheduler.py
Zijie Tian 29e102720b 🐛 fix: support multiple EOS tokens for GLM-4
GLM-4 uses multiple EOS tokens [151329, 151336, 151338], of which 151336
(<|user|>) should also stop generation. Previously only the first EOS token
from the tokenizer was used, so generation always ran to max_tokens.

Changes:
- config.py: Change eos type to int | list[int]
- llm_engine.py: Read eos_token_id from hf_config (contains full list)
- scheduler.py: Use set for efficient multi-EOS lookup

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-28 13:23:53 +08:00
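
For reference, a minimal sketch of how the three changes compose, using the GLM-4 token ids from the message above. The helper name resolve_eos and its arguments are hypothetical; the commit only confirms that llm_engine.py reads eos_token_id from hf_config and that scheduler.py normalizes the value into a set:

    # Hypothetical sketch of the normalization path (not the verbatim diff).

    def resolve_eos(hf_config, tokenizer) -> int | list[int]:
        # GLM-4's hf_config.eos_token_id holds the full list
        # [151329, 151336, 151338]; fall back to the tokenizer's single
        # eos_token_id when the hf_config attribute is absent.
        eos = getattr(hf_config, "eos_token_id", None)
        return eos if eos is not None else tokenizer.eos_token_id

    # scheduler.py then normalizes either shape into a set:
    # eos_set = set(eos) if isinstance(eos, list) else {eos}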

103 lines · 4.2 KiB · Python

from collections import deque
from time import perf_counter_ns
from typing import TYPE_CHECKING

from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence, SequenceStatus
from nanovllm.utils.observer import InferenceObserver

if TYPE_CHECKING:
    from nanovllm.kvcache import KVCacheManager


class Scheduler:

    def __init__(self, config: Config, kvcache_manager: "KVCacheManager"):
        self.max_num_seqs = config.max_num_seqs
        self.max_num_batched_tokens = config.max_num_batched_tokens
        # Convert EOS to a set for efficient lookup (supports a single int or a list)
        eos = config.eos
        self.eos_set = set(eos) if isinstance(eos, list) else {eos}
        self.kvcache_manager = kvcache_manager
        self.waiting: deque[Sequence] = deque()
        self.running: deque[Sequence] = deque()

    def is_finished(self):
        return not self.waiting and not self.running

    def add(self, seq: Sequence):
        self.waiting.append(seq)

    def schedule(self) -> tuple[list[Sequence], bool]:
        # Prefill: admit waiting sequences until the sequence or token budget is hit.
        scheduled_seqs = []
        num_seqs = 0
        num_batched_tokens = 0
        while self.waiting and num_seqs < self.max_num_seqs:
            if InferenceObserver.ttft_start == 0:
                InferenceObserver.ttft_start = perf_counter_ns()
            seq = self.waiting[0]
            if not self.running and num_seqs == 0:
                # First sequence: fail fast with a clear error if it can never be scheduled.
                if len(seq) > self.max_num_batched_tokens:
                    raise RuntimeError(
                        f"Sequence too long: {len(seq)} tokens exceeds "
                        f"max_num_batched_tokens={self.max_num_batched_tokens}. "
                        f"Increase max_num_batched_tokens (set equal to max_model_len for long sequences)."
                    )
                if not self.kvcache_manager.can_allocate(seq):
                    blocks_needed = seq.num_blocks
                    blocks_available = self.kvcache_manager.num_free_blocks
                    raise RuntimeError(
                        f"Cannot allocate KV cache for sequence: "
                        f"need {blocks_needed} blocks ({len(seq)} tokens), "
                        f"but only {blocks_available} blocks available. "
                        f"Increase max_model_len to allocate more blocks."
                    )
            if num_batched_tokens + len(seq) > self.max_num_batched_tokens:
                break
            if not self.kvcache_manager.can_allocate(seq):
                break
            num_seqs += 1
            self.kvcache_manager.allocate(seq)
            num_batched_tokens += len(seq) - seq.num_cached_tokens
            seq.status = SequenceStatus.RUNNING
            self.waiting.popleft()
            self.running.append(seq)
            scheduled_seqs.append(seq)
        if scheduled_seqs:
            return scheduled_seqs, True

        # Decode: schedule running sequences, preempting from the back of the
        # queue when the KV cache cannot fit another token.
        while self.running and num_seqs < self.max_num_seqs:
            seq = self.running.popleft()
            while not self.kvcache_manager.can_append(seq):
                if self.running:
                    self.preempt(self.running.pop())
                else:
                    self.preempt(seq)
                    break
            else:
                num_seqs += 1
                self.kvcache_manager.may_append(seq)
                scheduled_seqs.append(seq)
        assert scheduled_seqs, "No sequences scheduled - this should not happen"
        self.running.extendleft(reversed(scheduled_seqs))
        return scheduled_seqs, False

    def preempt(self, seq: Sequence):
        # Return the sequence to the front of the waiting queue and free its blocks.
        seq.status = SequenceStatus.WAITING
        self.kvcache_manager.deallocate(seq)
        self.waiting.appendleft(seq)
    def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> None:
        # Append sampled tokens and finish sequences that emit any EOS token
        # or reach their completion budget.
        for seq, token_id in zip(seqs, token_ids):
            seq.append_token(token_id)
            if (not seq.ignore_eos and token_id in self.eos_set) or seq.num_completion_tokens == seq.max_tokens:
                seq.status = SequenceStatus.FINISHED
                self.kvcache_manager.deallocate(seq)
                self.running.remove(seq)
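
A quick standalone check of the stop condition, using only the token ids from the commit message (the Scheduler itself is not constructed here):

    GLM4_EOS = [151329, 151336, 151338]

    # Old behavior: equality against the tokenizer's single (first) EOS id.
    old_eos = 151329
    print(151336 == old_eos)   # False, so <|user|> never stopped generation

    # New behavior: membership in the normalized set, as in postprocess() above.
    eos_set = set(GLM4_EOS) if isinstance(GLM4_EOS, list) else {GLM4_EOS}
    print(151336 in eos_set)   # True, so the sequence finishes on <|user|>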