From 726e4b58cf4afce2463eb5b3e8e6f00578969902 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Wed, 28 Jan 2026 13:15:57 +0800
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20GLM-4-9B-Chat-1M=20mo?=
 =?UTF-8?q?del=20support?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add support for GLM-4 model architecture with the following changes:

- Add glm4.py with ChatGLMForCausalLM, GLM4Model, GLM4Attention, GLM4MLP
- Add GLM4RotaryEmbedding with interleaved partial rotation (rotary_dim = head_dim // 2)
- Add apply_rotary_emb_interleaved function for GLM-4 style RoPE
- Add GLM-4 weight name conversion and loading in loader.py
- Add GLM-4 chat template conversion in test_ruler.py
- Add trust_remote_code=True for GLM-4 config loading

Key GLM-4 specific adaptations:
- QKV bias enabled (add_qkv_bias: true)
- RoPE with rope_ratio scaling (base = 10000 * rope_ratio)
- Interleaved RoPE (pairs adjacent elements, not first/second half)
- Partial rotation (only half of head_dim is rotated)
- Uses multi_query_group_num instead of num_key_value_heads
- Uses kv_channels instead of head_dim
- Uses ffn_hidden_size instead of intermediate_size
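
For reference, the config.json fields these mappings read look roughly like
this (values illustrative only, in the style of a GLM-4-9B-Chat-1M config;
check the actual checkpoint rather than trusting these numbers):

    "add_qkv_bias": true,
    "kv_channels": 128,
    "multi_query_group_num": 2,
    "ffn_hidden_size": 13696,
    "seq_length": 1048576,
    "rope_ratio": 500
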
Tested with RULER niah_single_1 (5 samples): 100% accuracy
Both GPU-only and CPU offload modes verified

Co-Authored-By: Claude Opus 4.5
---
 nanovllm/config.py                  |   7 +-
 nanovllm/engine/llm_engine.py       |   2 +-
 nanovllm/engine/model_runner.py     |  20 ++-
 nanovllm/layers/rotary_embedding.py | 112 ++++++++++++-
 nanovllm/models/__init__.py         |   1 +
 nanovllm/models/glm4.py             | 235 ++++++++++++++++++++++++++++
 nanovllm/utils/loader.py            | 131 +++++++++++++++-
 tests/test_ruler.py                 |  61 ++++++++
 8 files changed, 557 insertions(+), 12 deletions(-)
 create mode 100644 nanovllm/models/glm4.py

diff --git a/nanovllm/config.py b/nanovllm/config.py
index faaeb34..c36cb58 100644
--- a/nanovllm/config.py
+++ b/nanovllm/config.py
@@ -57,8 +57,11 @@ class Config:
         assert os.path.isdir(self.model)
         assert self.kvcache_block_size % 256 == 0
         assert 1 <= self.tensor_parallel_size <= 8
-        self.hf_config = AutoConfig.from_pretrained(self.model)
-        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
+        self.hf_config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
+        # Get max position embeddings (GLM-4 uses seq_length instead of max_position_embeddings)
+        max_pos = getattr(self.hf_config, 'max_position_embeddings',
+                          getattr(self.hf_config, 'seq_length', 4096))
+        self.max_model_len = min(self.max_model_len, max_pos)
         assert self.max_num_batched_tokens >= self.max_model_len
 
         # Override torch_dtype if user specified
diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py
index ce3087b..055241d 100644
--- a/nanovllm/engine/llm_engine.py
+++ b/nanovllm/engine/llm_engine.py
@@ -30,7 +30,7 @@ class LLMEngine:
             self.ps.append(process)
             self.events.append(event)
         self.model_runner = ModelRunner(config, 0, self.events)
-        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True)
         config.eos = self.tokenizer.eos_token_id
         # Set Sequence.block_size to match the KV cache block size
         Sequence.block_size = config.kvcache_block_size
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index e76e793..72e2e77 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -30,6 +30,18 @@ def _find_free_port() -> int:
         return s.getsockname()[1]
 
 
+def get_num_kv_heads(hf_config) -> int:
+    """Get number of KV heads from config (handles GLM-4's multi_query_group_num)."""
+    return getattr(hf_config, 'num_key_value_heads',
+                   getattr(hf_config, 'multi_query_group_num', hf_config.num_attention_heads))
+
+
+def get_head_dim(hf_config) -> int:
+    """Get head dimension from config (handles GLM-4's kv_channels)."""
+    return getattr(hf_config, "head_dim",
+                   getattr(hf_config, "kv_channels", hf_config.hidden_size // hf_config.num_attention_heads))
+
+
 class ModelRunner:
 
     def __init__(self, config: Config, rank: int, event: Event | list[Event]):
@@ -144,8 +156,8 @@ class ModelRunner:
         used = total - free
         peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
         current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
-        num_kv_heads = hf_config.num_key_value_heads // self.world_size
-        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
+        num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
+        head_dim = get_head_dim(hf_config)
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
 
         # Calculate max GPU blocks based on available memory
@@ -787,8 +799,8 @@ class ModelRunner:
         - LastGraph: o_proj → post_norm → mlp → final_norm
         """
         hf_config = self.config.hf_config
-        num_kv_heads = hf_config.num_key_value_heads // self.world_size
-        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
+        num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
+        head_dim = get_head_dim(hf_config)
 
         # Create Decode Graph Manager (seq_len=1)
         self.decode_graph_manager = OffloadGraphManager(
diff --git a/nanovllm/layers/rotary_embedding.py b/nanovllm/layers/rotary_embedding.py
index 4eba274..b681065 100644
--- a/nanovllm/layers/rotary_embedding.py
+++ b/nanovllm/layers/rotary_embedding.py
@@ -8,12 +8,43 @@ def apply_rotary_emb(
     cos: torch.Tensor,
     sin: torch.Tensor,
 ) -> torch.Tensor:
+    """Non-interleaved RoPE (used by Llama, Qwen, etc.)."""
     x1, x2 = torch.chunk(x.float(), 2, dim=-1)
     y1 = x1 * cos - x2 * sin
     y2 = x2 * cos + x1 * sin
     return torch.cat((y1, y2), dim=-1).to(x.dtype)
 
 
+def apply_rotary_emb_interleaved(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> torch.Tensor:
+    """Interleaved RoPE (used by GLM-4, etc.).
+
+    Args:
+        x: [seq_len, num_heads, rot_dim] (the slice being rotated; with
+           GLM-4's partial rotation this is rot_dim, not the full head_dim)
+        cos: [seq_len, 1, rot_dim // 2]
+        sin: [seq_len, 1, rot_dim // 2]
+
+    x is reshaped to [seq_len, num_heads, rot_dim // 2, 2] where:
+    - x[..., 0] are even positions
+    - x[..., 1] are odd positions
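+
+    Worked example (toy values for illustration only, not taken from the
+    code): with rot_dim = 4 and per-pair angles (t0, t1) from cos/sin,
+    adjacent pairs rotate together:
+
+        y0 = x0 * cos(t0) - x1 * sin(t0)
+        y1 = x1 * cos(t0) + x0 * sin(t0)
+        y2 = x2 * cos(t1) - x3 * sin(t1)
+        y3 = x3 * cos(t1) + x2 * sin(t1)
+
+    The non-interleaved variant above would instead pair (x0, x2) and (x1, x3).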
+    """
+    rot_dim = x.shape[-1]
+    # x_shaped: [seq_len, num_heads, rot_dim // 2, 2]
+    x_shaped = x.float().reshape(*x.shape[:-1], rot_dim // 2, 2)
+    # x_0, x_1: [seq_len, num_heads, rot_dim // 2]
+    x_0 = x_shaped[..., 0]
+    x_1 = x_shaped[..., 1]
+    # cos/sin: [seq_len, 1, rot_dim // 2] - broadcasts to num_heads
+    x_out = torch.stack([
+        x_0 * cos - x_1 * sin,
+        x_1 * cos + x_0 * sin,
+    ], dim=-1)
+    return x_out.flatten(-2).to(x.dtype)
+
+
 class RotaryEmbedding(nn.Module):
 
     def __init__(
@@ -140,6 +171,76 @@ class Llama3RotaryEmbedding(nn.Module):
         return query, key
 
 
+class GLM4RotaryEmbedding(nn.Module):
+    """
+    GLM-4 RoPE with interleaved rotation and partial rotation.
+
+    GLM-4 uses:
+    - Interleaved rotation (pairs adjacent elements, not first/second half)
+    - rope_ratio to scale base: base = 10000 * rope_ratio
+    - Partial rotation: only rotates first rotary_dim elements, rest pass through
+    - rotary_dim = head_dim // 2 (only half of head_dim is rotated)
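+
+    Worked example (values assumed for illustration, not asserted by this
+    code): with head_dim = 128, rotary_dim = 64, so elements 0..63 of each
+    head are rotated while 64..127 pass through unchanged; with
+    rope_ratio = 500 the effective base is 10000 * 500 = 5e6, which
+    stretches the rotation periods for long-context positions.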
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim  # GLM-4: rotary_dim = head_dim // 2
+        # inv_freq shape: [rotary_dim // 2]
+        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)  # [max_pos, rotary_dim // 2]
+        cos = freqs.cos()
+        sin = freqs.sin()
+        # cache shape [max_pos, 1, rotary_dim // 2, 2]
+        cache = torch.stack((cos, sin), dim=-1).unsqueeze_(1)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    @torch.compile
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply RoPE to query and key.
+
+        Args:
+            positions: [seq_len]
+            query: [seq_len, num_heads, head_dim]
+            key: [seq_len, num_kv_heads, head_dim]
+
+        Returns:
+            Rotated query and key with same shapes as input.
+        """
+        cache = self.cos_sin_cache[positions]  # [seq_len, 1, rotary_dim // 2, 2]
+        cos = cache[..., 0]  # [seq_len, 1, rotary_dim // 2]
+        sin = cache[..., 1]  # [seq_len, 1, rotary_dim // 2]
+
+        # Split into rotated and pass-through parts
+        q_rot = query[..., :self.rotary_dim]
+        q_pass = query[..., self.rotary_dim:]
+        k_rot = key[..., :self.rotary_dim]
+        k_pass = key[..., self.rotary_dim:]
+
+        # Apply interleaved RoPE to rotated part
+        q_rot = apply_rotary_emb_interleaved(q_rot, cos, sin)
+        k_rot = apply_rotary_emb_interleaved(k_rot, cos, sin)
+
+        # Concatenate rotated and pass-through parts
+        query = torch.cat([q_rot, q_pass], dim=-1)
+        key = torch.cat([k_rot, k_pass], dim=-1)
+
+        return query, key
+
+
 # Cache for RoPE instances (keyed by hashable parameters)
 _rope_cache: dict[tuple, nn.Module] = {}
 
@@ -150,10 +251,11 @@ def get_rope(
     max_position: int,
     base: float,
     rope_scaling: dict | None = None,
+    is_interleaved: bool = False,
 ):
     # Create hashable cache key
     if rope_scaling is None:
-        cache_key = (head_size, rotary_dim, max_position, base, None)
+        cache_key = (head_size, rotary_dim, max_position, base, None, is_interleaved)
     else:
         rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
         if rope_type == "llama3":
@@ -163,15 +265,19 @@ def get_rope(
                 rope_scaling["low_freq_factor"],
                 rope_scaling["high_freq_factor"],
                 rope_scaling["original_max_position_embeddings"],
+                is_interleaved,
             )
         else:
-            cache_key = (head_size, rotary_dim, max_position, base, rope_type)
+            cache_key = (head_size, rotary_dim, max_position, base, rope_type, is_interleaved)
 
     if cache_key in _rope_cache:
         return _rope_cache[cache_key]
 
     if rope_scaling is None:
-        rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
+        if is_interleaved:
+            rope = GLM4RotaryEmbedding(head_size, rotary_dim, max_position, base)
+        else:
+            rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
     else:
         rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
         if rope_type == "llama3":
diff --git a/nanovllm/models/__init__.py b/nanovllm/models/__init__.py
index 28d41b2..a865c1a 100644
--- a/nanovllm/models/__init__.py
+++ b/nanovllm/models/__init__.py
@@ -5,5 +5,6 @@ from nanovllm.models.registry import register_model, get_model_class, MODEL_REGI
 # Import models to trigger registration
 from nanovllm.models import qwen3
 from nanovllm.models import llama
+from nanovllm.models import glm4
 
 __all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]
diff --git a/nanovllm/models/glm4.py b/nanovllm/models/glm4.py
new file mode 100644
index 0000000..6752d4e
--- /dev/null
+++ b/nanovllm/models/glm4.py
@@ -0,0 +1,235 @@
+"""GLM-4 model implementation for nano-vllm."""
+import torch
+from torch import nn
+import torch.distributed as dist
+
+from nanovllm.layers.activation import SiluAndMul
+from nanovllm.layers.attention import Attention
+from nanovllm.layers.layernorm import RMSNorm
+from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
+from nanovllm.layers.rotary_embedding import get_rope
+from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
+from nanovllm.models.registry import register_model
+
+
+class GLM4Attention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 1048576,
+        head_dim: int = 128,
+        rope_theta: float = 10000,
+        rope_scaling: dict | None = None,
+    ) -> None:
+        super().__init__()
+        tp_size = dist.get_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.head_dim = head_dim
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim ** -0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=True,  # GLM-4 has QKV bias
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,  # GLM-4 has no output bias
+        )
+        # GLM-4 only rotates half of head_dim
+        rotary_dim = self.head_dim // 2
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=rotary_dim,
+            max_position=max_position,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_interleaved=True,  # GLM-4 uses interleaved RoPE
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_kv_heads,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
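+        """Shape flow (sizes illustrative, assuming GLM-4-9B-style settings
+        of num_heads=32, num_kv_heads=2, head_dim=128, i.e. q_size=4096 and
+        kv_size=256; these values are not asserted by this code):
+
+            hidden_states: [num_tokens, hidden_size]
+            qkv:           [num_tokens, q_size + 2 * kv_size]
+            q, k, v:       [num_tokens, num_(kv_)heads, head_dim]
+            output:        [num_tokens, hidden_size]
+        """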
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q = q.view(-1, self.num_heads, self.head_dim)
+        k = k.view(-1, self.num_kv_heads, self.head_dim)
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
+        q, k = self.rotary_emb(positions, q, k)
+        o = self.attn(q, k, v)
+        output = self.o_proj(o.flatten(1, -1))
+        return output
+
+
+class GLM4MLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,  # GLM-4 has no MLP bias
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+        )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
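+        # gate_up packs [gate, up] along the last dim; SiluAndMul splits it
+        # in half and returns silu(gate) * up (standard SwiGLU, as in the
+        # other models in this repo)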
+        gate_up = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x = self.down_proj(x)
+        return x
+
+
+class GLM4DecoderLayer(nn.Module):
+
+    def __init__(self, config) -> None:
+        super().__init__()
+        # GLM-4 config field mapping
+        hidden_size = config.hidden_size
+        num_heads = config.num_attention_heads
+        num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
+        head_dim = getattr(config, 'kv_channels', hidden_size // num_heads)
+        max_position = getattr(config, 'seq_length', 1048576)
+        rope_ratio = getattr(config, 'rope_ratio', 1)
+        rope_theta = 10000 * rope_ratio  # GLM-4 uses rope_ratio to scale base
+        intermediate_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))
+        rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)
+
+        self.self_attn = GLM4Attention(
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            max_position=max_position,
+            head_dim=head_dim,
+            rope_theta=rope_theta,
+            rope_scaling=getattr(config, "rope_scaling", None),
+        )
+        self.mlp = GLM4MLP(
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+        )
+        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(positions, hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class GLM4Model(nn.Module):
+
+    def __init__(self, config) -> None:
+        super().__init__()
+        vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
+        num_layers = getattr(config, 'num_layers', config.num_hidden_layers)
+        rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)
+
+        self.embed_tokens = VocabParallelEmbedding(vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList([GLM4DecoderLayer(config) for _ in range(num_layers)])
+        self.norm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(positions, hidden_states, residual)
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+@register_model("ChatGLMModel", "ChatGLMForConditionalGeneration")
+class ChatGLMForCausalLM(nn.Module):
+    """
+    GLM-4 model for causal language modeling.
+
+    Weight mapping from HuggingFace to nanovllm:
+    - transformer.embedding.word_embeddings → model.embed_tokens
+    - transformer.encoder.layers.X.input_layernorm → model.layers.X.input_layernorm
+    - transformer.encoder.layers.X.self_attention.query_key_value → model.layers.X.self_attn.qkv_proj (split q/k/v)
+    - transformer.encoder.layers.X.self_attention.dense → model.layers.X.self_attn.o_proj
+    - transformer.encoder.layers.X.post_attention_layernorm → model.layers.X.post_attention_layernorm
+    - transformer.encoder.layers.X.mlp.dense_h_to_4h → model.layers.X.mlp.gate_up_proj (split gate/up)
+    - transformer.encoder.layers.X.mlp.dense_4h_to_h → model.layers.X.mlp.down_proj
+    - transformer.encoder.final_layernorm → model.norm
+    - transformer.output_layer → lm_head
+    """
+    packed_modules_mapping = {
+        # QKV is merged in GLM-4 as query_key_value; loaded via the GLM-4
+        # path in loader.py, not the generic packed-module path
+        "query_key_value": ("qkv_proj", None),
+        # MLP gate and up are merged as dense_h_to_4h; same special handling
+        "dense_h_to_4h": ("gate_up_proj", None),
+    }
+
+    # Weight name mapping for the loader (kept in sync with
+    # GLM4_NAME_MAPPING in loader.py)
+    hf_to_nanovllm_mapping = {
+        "transformer.embedding.word_embeddings": "model.embed_tokens",
+        "transformer.encoder.final_layernorm": "model.norm",
+        "transformer.output_layer": "lm_head",
+    }
+
+    def __init__(self, config) -> None:
+        super().__init__()
+        vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
+        self.config = config
+        self.model = GLM4Model(config)
+        self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
+        # GLM-4 does not tie embeddings
+        # if getattr(config, 'tie_word_embeddings', False):
+        #     self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.model(input_ids, positions)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.lm_head(hidden_states)
diff --git a/nanovllm/utils/loader.py b/nanovllm/utils/loader.py
index 4ef8040..0485ab5 100644
--- a/nanovllm/utils/loader.py
+++ b/nanovllm/utils/loader.py
@@ -1,4 +1,5 @@
 import os
+import re
 from glob import glob
 import torch
 from torch import nn
@@ -9,20 +10,146 @@ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
     param.data.copy_(loaded_weight)
 
 
+# GLM-4 weight name mappings
+GLM4_NAME_MAPPING = {
+    "transformer.embedding.word_embeddings": "model.embed_tokens",
+    "transformer.encoder.final_layernorm": "model.norm",
+    "transformer.output_layer": "lm_head",
+}
+
+GLM4_LAYER_MAPPING = {
+    "self_attention.query_key_value": "self_attn.qkv_proj",
+    "self_attention.dense": "self_attn.o_proj",
+    "mlp.dense_h_to_4h": "mlp.gate_up_proj",
+    "mlp.dense_4h_to_h": "mlp.down_proj",
+}
+
+
+def convert_glm4_weight_name(weight_name: str) -> tuple[str | None, str | None]:
+    """
+    Convert GLM-4 weight name to nanovllm format.
+
+    Returns:
+        tuple: (converted_name, shard_id) where shard_id is used for packed modules
+        Returns (None, None) for weights that should be skipped
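+
+    Examples (layer index hypothetical, derived from the mappings above):
+        "transformer.embedding.word_embeddings.weight"
+            -> ("model.embed_tokens.weight", None)
+        "transformer.encoder.layers.0.self_attention.query_key_value.weight"
+            -> ("model.layers.0.self_attn.qkv_proj.weight", "qkv")
+        "transformer.encoder.layers.0.mlp.dense_h_to_4h.weight"
+            -> ("model.layers.0.mlp.gate_up_proj.weight", "gate_up")
+        "transformer.rotary_pos_emb.inv_freq"
+            -> (None, None)  # skipped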
+    """
+    # Skip rotary embedding weights (we use our own RoPE implementation)
+    if "rotary_pos_emb" in weight_name:
+        return None, None
+
+    # Check direct mappings first
+    for glm_name, nano_name in GLM4_NAME_MAPPING.items():
+        if weight_name.startswith(glm_name):
+            return weight_name.replace(glm_name, nano_name), None
+
+    # Handle layer weights: transformer.encoder.layers.X.xxx
+    layer_match = re.match(r"transformer\.encoder\.layers\.(\d+)\.(.+)", weight_name)
+    if layer_match:
+        layer_idx = layer_match.group(1)
+        remainder = layer_match.group(2)
+
+        # Handle packed modules (QKV and gate_up)
+        for glm_subname, nano_subname in GLM4_LAYER_MAPPING.items():
+            if remainder.startswith(glm_subname):
+                suffix = remainder[len(glm_subname):]  # .weight or .bias
+                new_name = f"model.layers.{layer_idx}.{nano_subname}{suffix}"
+
+                # Determine shard_id for packed modules
+                if "qkv_proj" in nano_subname:
+                    return new_name, "qkv"  # Special marker for GLM4 QKV
+                elif "gate_up_proj" in nano_subname:
+                    return new_name, "gate_up"  # Special marker for GLM4 gate_up
+                else:
+                    return new_name, None
+
+        # Handle non-packed layer weights (layernorms)
+        new_name = f"model.layers.{layer_idx}.{remainder}"
+        return new_name, None
+
+    # No mapping found, return original
+    return weight_name, None
+
+
+def load_glm4_qkv(param: nn.Parameter, loaded_weight: torch.Tensor, config):
+    """Load GLM-4 merged QKV weights by splitting into q, k, v."""
+    num_heads = config.num_attention_heads
+    num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
+    head_dim = getattr(config, 'kv_channels', config.hidden_size // num_heads)
+
+    q_size = num_heads * head_dim
+    kv_size = num_kv_heads * head_dim
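+
+    # With GLM-4-9B-style values (illustrative: 32 heads, 2 KV groups,
+    # head_dim 128) this splits the rows as [4096, 256, 256]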
+
+    # Split QKV: [q_size + kv_size + kv_size, hidden_size]
+    q, k, v = loaded_weight.split([q_size, kv_size, kv_size], dim=0)
+
+    # Load each part using the weight_loader
+    weight_loader = getattr(param, "weight_loader")
+    weight_loader(param, q, "q")
+    weight_loader(param, k, "k")
+    weight_loader(param, v, "v")
+
+
+def load_glm4_gate_up(param: nn.Parameter, loaded_weight: torch.Tensor, config):
+    """Load GLM-4 merged gate_up weights by splitting into gate, up."""
+    ffn_hidden_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))
+
+    # Split gate_up: [ffn_hidden_size * 2, hidden_size]
+    gate, up = loaded_weight.split([ffn_hidden_size, ffn_hidden_size], dim=0)
+
+    # Load each part using the weight_loader
+    weight_loader = getattr(param, "weight_loader")
+    weight_loader(param, gate, 0)  # gate_proj is shard 0
+    weight_loader(param, up, 1)  # up_proj is shard 1
+
+
+def is_glm4_model(model: nn.Module) -> bool:
+    """Check if the model is a GLM-4 model."""
+    return model.__class__.__name__ in ("ChatGLMForCausalLM",)
+
+
 def load_model(model: nn.Module, path: str):
     packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
+    is_glm4 = is_glm4_model(model)
+    config = getattr(model, "config", None)
+
     for file in glob(os.path.join(path, "*.safetensors")):
         with safe_open(file, "pt", "cpu") as f:
             for weight_name in f.keys():
+                loaded_weight = f.get_tensor(weight_name)
+
+                # GLM-4 specific handling
+                if is_glm4:
+                    param_name, shard_id = convert_glm4_weight_name(weight_name)
+
+                    # Skip weights that don't need to be loaded
+                    if param_name is None:
+                        continue
+
+                    if shard_id == "qkv":
+                        param = model.get_parameter(param_name)
+                        load_glm4_qkv(param, loaded_weight, config)
+                        continue
+                    elif shard_id == "gate_up":
+                        param = model.get_parameter(param_name)
+                        load_glm4_gate_up(param, loaded_weight, config)
+                        continue
+                    else:
+                        # Regular weight, use converted name
+                        param = model.get_parameter(param_name)
+                        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                        weight_loader(param, loaded_weight)
+                        continue
+
+                # Original loading logic for other models
                 for k in packed_modules_mapping:
                     if k in weight_name:
                         v, shard_id = packed_modules_mapping[k]
                         param_name = weight_name.replace(k, v)
                         param = model.get_parameter(param_name)
                         weight_loader = getattr(param, "weight_loader")
-                        weight_loader(param, f.get_tensor(weight_name), shard_id)
+                        weight_loader(param, loaded_weight, shard_id)
                         break
                 else:
                     param = model.get_parameter(weight_name)
                     weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                    weight_loader(param, f.get_tensor(weight_name))
+                    weight_loader(param, loaded_weight)
diff --git a/tests/test_ruler.py b/tests/test_ruler.py
index 5db2638..1f20bf6 100644
--- a/tests/test_ruler.py
+++ b/tests/test_ruler.py
@@ -48,6 +48,62 @@ from nanovllm import LLM, SamplingParams
 # ============================================================
 
 DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_64k"
+
+
+# ============================================================
+# Chat Template Conversion
+# ============================================================
+
+def convert_llama_to_glm4_format(prompt: str) -> str:
+    """
+    Convert Llama 3 chat template format to GLM-4 format.
+
+    Llama 3 format:
+        <|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+        {user_content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+        {assistant_prefix}
+
+    GLM-4 format:
+        [gMASK]<|user|>
+        {user_content}<|assistant|>
+        {assistant_prefix}
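+
+    Concrete example (hypothetical prompt, not taken from the test data):
+        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHi"
+        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    becomes:
+        "[gMASK]<|user|>\nHi<|assistant|>"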
+    """
+    # Split into user content and assistant prefix
+    parts = prompt.split("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
+
+    # Extract user content (remove Llama header tokens)
+    user_content = parts[0]
+    user_content = user_content.replace("<|begin_of_text|>", "")
+    user_content = user_content.replace("<|start_header_id|>user<|end_header_id|>", "")
+    user_content = user_content.strip()
+
+    # Extract assistant prefix (if exists)
+    assistant_prefix = ""
+    if len(parts) > 1:
+        assistant_prefix = parts[1].replace("<|eot_id|>", "").strip()
+
+    # Apply GLM-4 format
+    glm_prompt = f"[gMASK]<|user|>\n{user_content}<|assistant|>"
+    if assistant_prefix:
+        glm_prompt += f"\n{assistant_prefix}"
+
+    return glm_prompt
+
+
+def is_glm_model(model_path: str) -> bool:
+    """Check if the model is a GLM model based on config."""
+    from transformers import AutoConfig
+    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+    return getattr(config, 'model_type', '') == 'chatglm'
+
+
+def convert_prompt_for_model(prompt: str, model_path: str) -> str:
+    """Convert prompt format based on model type."""
+    if is_glm_model(model_path):
+        return convert_llama_to_glm4_format(prompt)
+    return prompt  # Keep original format for Llama and other models
 DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
 # Note: max_model_len must be > max_input_len to leave room for output tokens
 # 64k benchmark has inputs up to 65536 tokens, so we need 65536 + 128 = 65664
@@ -161,6 +217,7 @@ def run_task_test(
     verbose: bool = True,
     llm_factory: Optional[callable] = None,
     fresh_llm: bool = False,
+    model_path: Optional[str] = None,
 ) -> Dict:
     """
     Run test for a single RULER task.
@@ -198,6 +255,9 @@ def run_task_test(
     for sample in samples:
         idx = sample.get("index", sample["_local_idx"])
         prompt = sample["input"]
+        # Convert prompt format for GLM models
+        if model_path:
+            prompt = convert_prompt_for_model(prompt, model_path)
         expected = sample["outputs"]
 
         # Fresh LLM mode: reinitialize for each sample
@@ -367,6 +427,7 @@ def run_ruler_benchmark(
             verbose=verbose and not json_output,
             llm_factory=create_llm,
             fresh_llm=fresh_llm,
+            model_path=model_path,
         )
         task_results.append(result)