From 29e102720b3bbcb53f789805bd8949a005f9ab61 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 28 Jan 2026 13:23:53 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20support=20multiple=20EOS?= =?UTF-8?q?=20tokens=20for=20GLM-4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit GLM-4 uses multiple EOS tokens [151329, 151336, 151338] where 151336 (<|user|>) should also stop generation. Previously only the first EOS from tokenizer was used, causing generation to always hit max_tokens. Changes: - config.py: Change eos type to int | list[int] - llm_engine.py: Read eos_token_id from hf_config (contains full list) - scheduler.py: Use set for efficient multi-EOS lookup Co-Authored-By: Claude Opus 4.5 --- nanovllm/config.py | 2 +- nanovllm/engine/llm_engine.py | 8 +++++++- nanovllm/engine/scheduler.py | 6 ++++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/nanovllm/config.py b/nanovllm/config.py index c36cb58..2c75300 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -22,7 +22,7 @@ class Config: tensor_parallel_size: int = 1 enforce_eager: bool = False hf_config: AutoConfig | None = None - eos: int = -1 + eos: int | list[int] = -1 # Single EOS token or list of EOS tokens (e.g., GLM-4) kvcache_block_size: int = 1024 num_kvcache_blocks: int = -1 dtype: str | None = None # "float16", "bfloat16", or None (use model default) diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py index 055241d..bd2d280 100644 --- a/nanovllm/engine/llm_engine.py +++ b/nanovllm/engine/llm_engine.py @@ -31,7 +31,13 @@ class LLMEngine: self.events.append(event) self.model_runner = ModelRunner(config, 0, self.events) self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True) - config.eos = self.tokenizer.eos_token_id + # Get EOS token(s) from config (may be int or list, e.g., GLM-4 uses list) + # Prefer hf_config.eos_token_id which contains full list, fallback to tokenizer + eos_from_config = getattr(config.hf_config, 'eos_token_id', None) + if eos_from_config is not None: + config.eos = eos_from_config + else: + config.eos = self.tokenizer.eos_token_id # Set Sequence.block_size to match the KV cache block size Sequence.block_size = config.kvcache_block_size self.scheduler = Scheduler(config, self.model_runner.kvcache_manager) diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py index 994ddd6..9944de3 100644 --- a/nanovllm/engine/scheduler.py +++ b/nanovllm/engine/scheduler.py @@ -15,7 +15,9 @@ class Scheduler: def __init__(self, config: Config, kvcache_manager: "KVCacheManager"): self.max_num_seqs = config.max_num_seqs self.max_num_batched_tokens = config.max_num_batched_tokens - self.eos = config.eos + # Convert EOS to set for efficient lookup (supports single int or list) + eos = config.eos + self.eos_set = set(eos) if isinstance(eos, list) else {eos} self.kvcache_manager = kvcache_manager self.waiting: deque[Sequence] = deque() self.running: deque[Sequence] = deque() @@ -94,7 +96,7 @@ class Scheduler: def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]: for seq, token_id in zip(seqs, token_ids): seq.append_token(token_id) - if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens: + if (not seq.ignore_eos and token_id in self.eos_set) or seq.num_completion_tokens == seq.max_tokens: seq.status = SequenceStatus.FINISHED self.kvcache_manager.deallocate(seq) self.running.remove(seq)