🐛 fix: support multiple EOS tokens for GLM-4

GLM-4 uses multiple EOS tokens [151329, 151336, 151338] where 151336
(<|user|>) should also stop generation. Previously only the first EOS
from tokenizer was used, causing generation to always hit max_tokens.

Changes:
- config.py: Change eos type to int | list[int]
- llm_engine.py: Read eos_token_id from hf_config (contains full list)
- scheduler.py: Use set for efficient multi-EOS lookup

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-28 13:23:53 +08:00
parent 726e4b58cf
commit 29e102720b
3 changed files with 12 additions and 4 deletions

View File

@@ -22,7 +22,7 @@ class Config:
tensor_parallel_size: int = 1
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
eos: int | list[int] = -1 # Single EOS token or list of EOS tokens (e.g., GLM-4)
kvcache_block_size: int = 1024
num_kvcache_blocks: int = -1
dtype: str | None = None # "float16", "bfloat16", or None (use model default)

View File

@@ -31,7 +31,13 @@ class LLMEngine:
self.events.append(event)
self.model_runner = ModelRunner(config, 0, self.events)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True)
config.eos = self.tokenizer.eos_token_id
# Get EOS token(s) from config (may be int or list, e.g., GLM-4 uses list)
# Prefer hf_config.eos_token_id which contains full list, fallback to tokenizer
eos_from_config = getattr(config.hf_config, 'eos_token_id', None)
if eos_from_config is not None:
config.eos = eos_from_config
else:
config.eos = self.tokenizer.eos_token_id
# Set Sequence.block_size to match the KV cache block size
Sequence.block_size = config.kvcache_block_size
self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)

View File

@@ -15,7 +15,9 @@ class Scheduler:
def __init__(self, config: Config, kvcache_manager: "KVCacheManager"):
self.max_num_seqs = config.max_num_seqs
self.max_num_batched_tokens = config.max_num_batched_tokens
self.eos = config.eos
# Convert EOS to set for efficient lookup (supports single int or list)
eos = config.eos
self.eos_set = set(eos) if isinstance(eos, list) else {eos}
self.kvcache_manager = kvcache_manager
self.waiting: deque[Sequence] = deque()
self.running: deque[Sequence] = deque()
@@ -94,7 +96,7 @@ class Scheduler:
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
for seq, token_id in zip(seqs, token_ids):
seq.append_token(token_id)
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
if (not seq.ignore_eos and token_id in self.eos_set) or seq.num_completion_tokens == seq.max_tokens:
seq.status = SequenceStatus.FINISHED
self.kvcache_manager.deallocate(seq)
self.running.remove(seq)