✨ feat: add GLM-4-9B-Chat-1M model support
Add support for GLM-4 model architecture with the following changes: - Add glm4.py with ChatGLMForCausalLM, GLM4Model, GLM4Attention, GLM4MLP - Add GLM4RotaryEmbedding with interleaved partial rotation (rotary_dim = head_dim // 2) - Add apply_rotary_emb_interleaved function for GLM-4 style RoPE - Add GLM-4 weight name conversion and loading in loader.py - Add GLM-4 chat template conversion in test_ruler.py - Add trust_remote_code=True for GLM-4 config loading Key GLM-4 specific adaptations: - QKV bias enabled (add_qkv_bias: true) - RoPE with rope_ratio scaling (base = 10000 * rope_ratio) - Interleaved RoPE (pairs adjacent elements, not first/second half) - Partial rotation (only half of head_dim is rotated) - Uses multi_query_group_num instead of num_key_value_heads - Uses kv_channels instead of head_dim - Uses ffn_hidden_size instead of intermediate_size Tested with RULER niah_single_1 (5 samples): 100% accuracy Both GPU-only and CPU offload modes verified Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -57,8 +57,11 @@ class Config:
|
||||
assert os.path.isdir(self.model)
|
||||
assert self.kvcache_block_size % 256 == 0
|
||||
assert 1 <= self.tensor_parallel_size <= 8
|
||||
self.hf_config = AutoConfig.from_pretrained(self.model)
|
||||
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
|
||||
self.hf_config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
|
||||
# Get max position embeddings (GLM-4 uses seq_length instead of max_position_embeddings)
|
||||
max_pos = getattr(self.hf_config, 'max_position_embeddings',
|
||||
getattr(self.hf_config, 'seq_length', 4096))
|
||||
self.max_model_len = min(self.max_model_len, max_pos)
|
||||
assert self.max_num_batched_tokens >= self.max_model_len
|
||||
|
||||
# Override torch_dtype if user specified
|
||||
|
||||
@@ -30,7 +30,7 @@ class LLMEngine:
|
||||
self.ps.append(process)
|
||||
self.events.append(event)
|
||||
self.model_runner = ModelRunner(config, 0, self.events)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True)
|
||||
config.eos = self.tokenizer.eos_token_id
|
||||
# Set Sequence.block_size to match the KV cache block size
|
||||
Sequence.block_size = config.kvcache_block_size
|
||||
|
||||
@@ -30,6 +30,18 @@ def _find_free_port() -> int:
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_num_kv_heads(hf_config) -> int:
|
||||
"""Get number of KV heads from config (handles GLM-4's multi_query_group_num)."""
|
||||
return getattr(hf_config, 'num_key_value_heads',
|
||||
getattr(hf_config, 'multi_query_group_num', hf_config.num_attention_heads))
|
||||
|
||||
|
||||
def get_head_dim(hf_config) -> int:
|
||||
"""Get head dimension from config (handles GLM-4's kv_channels)."""
|
||||
return getattr(hf_config, "head_dim",
|
||||
getattr(hf_config, "kv_channels", hf_config.hidden_size // hf_config.num_attention_heads))
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
|
||||
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
|
||||
@@ -144,8 +156,8 @@ class ModelRunner:
|
||||
used = total - free
|
||||
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
|
||||
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
|
||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
|
||||
num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
|
||||
head_dim = get_head_dim(hf_config)
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
|
||||
|
||||
# Calculate max GPU blocks based on available memory
|
||||
@@ -787,8 +799,8 @@ class ModelRunner:
|
||||
- LastGraph: o_proj → post_norm → mlp → final_norm
|
||||
"""
|
||||
hf_config = self.config.hf_config
|
||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
|
||||
num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
|
||||
head_dim = get_head_dim(hf_config)
|
||||
|
||||
# Create Decode Graph Manager (seq_len=1)
|
||||
self.decode_graph_manager = OffloadGraphManager(
|
||||
|
||||
@@ -8,12 +8,43 @@ def apply_rotary_emb(
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Non-interleaved RoPE (used by Llama, Qwen, etc.)"""
|
||||
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
|
||||
y1 = x1 * cos - x2 * sin
|
||||
y2 = x2 * cos + x1 * sin
|
||||
return torch.cat((y1, y2), dim=-1).to(x.dtype)
|
||||
|
||||
|
||||
def apply_rotary_emb_interleaved(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Interleaved RoPE (used by GLM-4, etc.)
|
||||
|
||||
Args:
|
||||
x: [seq_len, num_heads, head_dim]
|
||||
cos: [seq_len, 1, head_dim // 2]
|
||||
sin: [seq_len, 1, head_dim // 2]
|
||||
|
||||
x is reshaped to [seq_len, num_heads, head_dim // 2, 2] where:
|
||||
- x[..., 0] are even positions
|
||||
- x[..., 1] are odd positions
|
||||
"""
|
||||
rot_dim = x.shape[-1]
|
||||
# x_shaped: [seq_len, num_heads, rot_dim // 2, 2]
|
||||
x_shaped = x.float().reshape(*x.shape[:-1], rot_dim // 2, 2)
|
||||
# x_0, x_1: [seq_len, num_heads, rot_dim // 2]
|
||||
x_0 = x_shaped[..., 0]
|
||||
x_1 = x_shaped[..., 1]
|
||||
# cos/sin: [seq_len, 1, rot_dim // 2] - broadcasts to num_heads
|
||||
x_out = torch.stack([
|
||||
x_0 * cos - x_1 * sin,
|
||||
x_1 * cos + x_0 * sin,
|
||||
], dim=-1)
|
||||
return x_out.flatten(-2).to(x.dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -140,6 +171,76 @@ class Llama3RotaryEmbedding(nn.Module):
|
||||
return query, key
|
||||
|
||||
|
||||
class GLM4RotaryEmbedding(nn.Module):
|
||||
"""
|
||||
GLM-4 RoPE with interleaved rotation and partial rotation.
|
||||
|
||||
GLM-4 uses:
|
||||
- Interleaved rotation (pairs adjacent elements, not first/second half)
|
||||
- rope_ratio to scale base: base = 10000 * rope_ratio
|
||||
- Partial rotation: only rotates first rotary_dim elements, rest pass through
|
||||
- rotary_dim = head_dim // 2 (only half of head_dim is rotated)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: float,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim # GLM-4: rotary_dim = head_dim // 2
|
||||
# inv_freq shape: [rotary_dim // 2]
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
||||
t = torch.arange(max_position_embeddings, dtype=torch.float)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq) # [max_pos, rotary_dim // 2]
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
# cache shape [max_pos, 1, rotary_dim // 2, 2]
|
||||
cache = torch.stack((cos, sin), dim=-1).unsqueeze_(1)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
@torch.compile
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply RoPE to query and key.
|
||||
|
||||
Args:
|
||||
positions: [seq_len]
|
||||
query: [seq_len, num_heads, head_dim]
|
||||
key: [seq_len, num_kv_heads, head_dim]
|
||||
|
||||
Returns:
|
||||
Rotated query and key with same shapes as input.
|
||||
"""
|
||||
cache = self.cos_sin_cache[positions] # [seq_len, 1, rotary_dim // 2, 2]
|
||||
cos = cache[..., 0] # [seq_len, 1, rotary_dim // 2]
|
||||
sin = cache[..., 1] # [seq_len, 1, rotary_dim // 2]
|
||||
|
||||
# Split into rotated and pass-through parts
|
||||
q_rot = query[..., :self.rotary_dim]
|
||||
q_pass = query[..., self.rotary_dim:]
|
||||
k_rot = key[..., :self.rotary_dim]
|
||||
k_pass = key[..., self.rotary_dim:]
|
||||
|
||||
# Apply interleaved RoPE to rotated part
|
||||
q_rot = apply_rotary_emb_interleaved(q_rot, cos, sin)
|
||||
k_rot = apply_rotary_emb_interleaved(k_rot, cos, sin)
|
||||
|
||||
# Concatenate rotated and pass-through parts
|
||||
query = torch.cat([q_rot, q_pass], dim=-1)
|
||||
key = torch.cat([k_rot, k_pass], dim=-1)
|
||||
|
||||
return query, key
|
||||
|
||||
|
||||
# Cache for RoPE instances (keyed by hashable parameters)
|
||||
_rope_cache: dict[tuple, nn.Module] = {}
|
||||
|
||||
@@ -150,10 +251,11 @@ def get_rope(
|
||||
max_position: int,
|
||||
base: float,
|
||||
rope_scaling: dict | None = None,
|
||||
is_interleaved: bool = False,
|
||||
):
|
||||
# Create hashable cache key
|
||||
if rope_scaling is None:
|
||||
cache_key = (head_size, rotary_dim, max_position, base, None)
|
||||
cache_key = (head_size, rotary_dim, max_position, base, None, is_interleaved)
|
||||
else:
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||
if rope_type == "llama3":
|
||||
@@ -163,15 +265,19 @@ def get_rope(
|
||||
rope_scaling["low_freq_factor"],
|
||||
rope_scaling["high_freq_factor"],
|
||||
rope_scaling["original_max_position_embeddings"],
|
||||
is_interleaved,
|
||||
)
|
||||
else:
|
||||
cache_key = (head_size, rotary_dim, max_position, base, rope_type)
|
||||
cache_key = (head_size, rotary_dim, max_position, base, rope_type, is_interleaved)
|
||||
|
||||
if cache_key in _rope_cache:
|
||||
return _rope_cache[cache_key]
|
||||
|
||||
if rope_scaling is None:
|
||||
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||
if is_interleaved:
|
||||
rope = GLM4RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||
else:
|
||||
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||
else:
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||
if rope_type == "llama3":
|
||||
|
||||
@@ -5,5 +5,6 @@ from nanovllm.models.registry import register_model, get_model_class, MODEL_REGI
|
||||
# Import models to trigger registration
|
||||
from nanovllm.models import qwen3
|
||||
from nanovllm.models import llama
|
||||
from nanovllm.models import glm4
|
||||
|
||||
__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]
|
||||
|
||||
235
nanovllm/models/glm4.py
Normal file
235
nanovllm/models/glm4.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""GLM-4 model implementation for nano-vllm."""
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanovllm.layers.activation import SiluAndMul
|
||||
from nanovllm.layers.attention import Attention
|
||||
from nanovllm.layers.layernorm import RMSNorm
|
||||
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
||||
from nanovllm.layers.rotary_embedding import get_rope
|
||||
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
||||
from nanovllm.models.registry import register_model
|
||||
|
||||
|
||||
class GLM4Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 1048576,
|
||||
head_dim: int = 128,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: dict | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
tp_size = dist.get_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||
self.head_dim = head_dim
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=True, # GLM-4 has QKV bias
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False, # GLM-4 has no output bias
|
||||
)
|
||||
# GLM-4 only rotates half of head_dim
|
||||
rotary_dim = self.head_dim // 2
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_interleaved=True, # GLM-4 uses interleaved RoPE
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
self.num_kv_heads,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q = q.view(-1, self.num_heads, self.head_dim)
|
||||
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
o = self.attn(q, k, v)
|
||||
output = self.o_proj(o.flatten(1, -1))
|
||||
return output
|
||||
|
||||
|
||||
class GLM4MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False, # GLM-4 has no MLP bias
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class GLM4DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__()
|
||||
# GLM-4 config field mapping
|
||||
hidden_size = config.hidden_size
|
||||
num_heads = config.num_attention_heads
|
||||
num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
|
||||
head_dim = getattr(config, 'kv_channels', hidden_size // num_heads)
|
||||
max_position = getattr(config, 'seq_length', 1048576)
|
||||
rope_ratio = getattr(config, 'rope_ratio', 1)
|
||||
rope_theta = 10000 * rope_ratio # GLM-4 uses rope_ratio to scale base
|
||||
intermediate_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))
|
||||
rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)
|
||||
|
||||
self.self_attn = GLM4Attention(
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
max_position=max_position,
|
||||
head_dim=head_dim,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=getattr(config, "rope_scaling", None),
|
||||
)
|
||||
self.mlp = GLM4MLP(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions, hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class GLM4Model(nn.Module):
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__()
|
||||
vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
|
||||
num_layers = getattr(config, 'num_layers', config.num_hidden_layers)
|
||||
rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(vocab_size, config.hidden_size)
|
||||
self.layers = nn.ModuleList([GLM4DecoderLayer(config) for _ in range(num_layers)])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@register_model("ChatGLMModel", "ChatGLMForConditionalGeneration")
|
||||
class ChatGLMForCausalLM(nn.Module):
|
||||
"""
|
||||
GLM-4 model for causal language modeling.
|
||||
|
||||
Weight mapping from HuggingFace to nanovllm:
|
||||
- transformer.embedding.word_embeddings → model.embed_tokens
|
||||
- transformer.encoder.layers.X.input_layernorm → model.layers.X.input_layernorm
|
||||
- transformer.encoder.layers.X.self_attention.query_key_value → model.layers.X.self_attn.qkv_proj (split q/k/v)
|
||||
- transformer.encoder.layers.X.self_attention.dense → model.layers.X.self_attn.o_proj
|
||||
- transformer.encoder.layers.X.post_attention_layernorm → model.layers.X.post_attention_layernorm
|
||||
- transformer.encoder.layers.X.mlp.dense_h_to_4h → model.layers.X.mlp.gate_up_proj (split gate/up)
|
||||
- transformer.encoder.layers.X.mlp.dense_4h_to_h → model.layers.X.mlp.down_proj
|
||||
- transformer.encoder.final_layernorm → model.norm
|
||||
- transformer.output_layer → lm_head
|
||||
"""
|
||||
packed_modules_mapping = {
|
||||
# QKV is merged in GLM-4 as query_key_value
|
||||
"query_key_value": ("qkv_proj", None), # Special handling needed
|
||||
# MLP gate and up are merged as dense_h_to_4h
|
||||
"dense_h_to_4h": ("gate_up_proj", None), # Special handling needed
|
||||
}
|
||||
|
||||
# Weight name mapping for loader
|
||||
hf_to_nanovllm_mapping = {
|
||||
"transformer.embedding.word_embeddings": "model.embed_tokens",
|
||||
"transformer.encoder.final_layernorm": "model.norm",
|
||||
"transformer.output_layer": "lm_head",
|
||||
}
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__()
|
||||
vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
|
||||
self.config = config
|
||||
self.model = GLM4Model(config)
|
||||
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
||||
# GLM-4 does not tie embeddings
|
||||
# if getattr(config, 'tie_word_embeddings', False):
|
||||
# self.lm_head.weight.data = self.model.embed_tokens.weight.data
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self.model(input_ids, positions)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self.lm_head(hidden_states)
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -9,20 +10,146 @@ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
|
||||
# GLM-4 weight name mappings
|
||||
GLM4_NAME_MAPPING = {
|
||||
"transformer.embedding.word_embeddings": "model.embed_tokens",
|
||||
"transformer.encoder.final_layernorm": "model.norm",
|
||||
"transformer.output_layer": "lm_head",
|
||||
}
|
||||
|
||||
GLM4_LAYER_MAPPING = {
|
||||
"self_attention.query_key_value": "self_attn.qkv_proj",
|
||||
"self_attention.dense": "self_attn.o_proj",
|
||||
"mlp.dense_h_to_4h": "mlp.gate_up_proj",
|
||||
"mlp.dense_4h_to_h": "mlp.down_proj",
|
||||
}
|
||||
|
||||
|
||||
def convert_glm4_weight_name(weight_name: str) -> tuple[str, str | None]:
|
||||
"""
|
||||
Convert GLM-4 weight name to nanovllm format.
|
||||
|
||||
Returns:
|
||||
tuple: (converted_name, shard_id) where shard_id is used for packed modules
|
||||
Returns (None, None) for weights that should be skipped
|
||||
"""
|
||||
# Skip rotary embedding weights (we use our own RoPE implementation)
|
||||
if "rotary_pos_emb" in weight_name:
|
||||
return None, None
|
||||
|
||||
# Check direct mappings first
|
||||
for glm_name, nano_name in GLM4_NAME_MAPPING.items():
|
||||
if weight_name.startswith(glm_name):
|
||||
return weight_name.replace(glm_name, nano_name), None
|
||||
|
||||
# Handle layer weights: transformer.encoder.layers.X.xxx
|
||||
layer_match = re.match(r"transformer\.encoder\.layers\.(\d+)\.(.+)", weight_name)
|
||||
if layer_match:
|
||||
layer_idx = layer_match.group(1)
|
||||
remainder = layer_match.group(2)
|
||||
|
||||
# Handle packed modules (QKV and gate_up)
|
||||
for glm_subname, nano_subname in GLM4_LAYER_MAPPING.items():
|
||||
if remainder.startswith(glm_subname):
|
||||
suffix = remainder[len(glm_subname):] # .weight or .bias
|
||||
new_name = f"model.layers.{layer_idx}.{nano_subname}{suffix}"
|
||||
|
||||
# Determine shard_id for packed modules
|
||||
if "qkv_proj" in nano_subname:
|
||||
return new_name, "qkv" # Special marker for GLM4 QKV
|
||||
elif "gate_up_proj" in nano_subname:
|
||||
return new_name, "gate_up" # Special marker for GLM4 gate_up
|
||||
else:
|
||||
return new_name, None
|
||||
|
||||
# Handle non-packed layer weights (layernorms)
|
||||
new_name = f"model.layers.{layer_idx}.{remainder}"
|
||||
return new_name, None
|
||||
|
||||
# No mapping found, return original
|
||||
return weight_name, None
|
||||
|
||||
|
||||
def load_glm4_qkv(param: nn.Parameter, loaded_weight: torch.Tensor, config):
|
||||
"""Load GLM-4 merged QKV weights by splitting into q, k, v."""
|
||||
num_heads = config.num_attention_heads
|
||||
num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
|
||||
head_dim = getattr(config, 'kv_channels', config.hidden_size // num_heads)
|
||||
|
||||
q_size = num_heads * head_dim
|
||||
kv_size = num_kv_heads * head_dim
|
||||
|
||||
# Split QKV: [q_size + kv_size + kv_size, hidden_size]
|
||||
q, k, v = loaded_weight.split([q_size, kv_size, kv_size], dim=0)
|
||||
|
||||
# Load each part using the weight_loader
|
||||
weight_loader = getattr(param, "weight_loader")
|
||||
weight_loader(param, q, "q")
|
||||
weight_loader(param, k, "k")
|
||||
weight_loader(param, v, "v")
|
||||
|
||||
|
||||
def load_glm4_gate_up(param: nn.Parameter, loaded_weight: torch.Tensor, config):
|
||||
"""Load GLM-4 merged gate_up weights by splitting into gate, up."""
|
||||
ffn_hidden_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))
|
||||
|
||||
# Split gate_up: [ffn_hidden_size * 2, hidden_size]
|
||||
gate, up = loaded_weight.split([ffn_hidden_size, ffn_hidden_size], dim=0)
|
||||
|
||||
# Load each part using the weight_loader
|
||||
weight_loader = getattr(param, "weight_loader")
|
||||
weight_loader(param, gate, 0) # gate_proj is shard 0
|
||||
weight_loader(param, up, 1) # up_proj is shard 1
|
||||
|
||||
|
||||
def is_glm4_model(model: nn.Module) -> bool:
|
||||
"""Check if the model is a GLM-4 model."""
|
||||
return model.__class__.__name__ in ("ChatGLMForCausalLM",)
|
||||
|
||||
|
||||
def load_model(model: nn.Module, path: str):
|
||||
packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
|
||||
is_glm4 = is_glm4_model(model)
|
||||
config = getattr(model, "config", None)
|
||||
|
||||
for file in glob(os.path.join(path, "*.safetensors")):
|
||||
with safe_open(file, "pt", "cpu") as f:
|
||||
for weight_name in f.keys():
|
||||
loaded_weight = f.get_tensor(weight_name)
|
||||
|
||||
# GLM-4 specific handling
|
||||
if is_glm4:
|
||||
param_name, shard_id = convert_glm4_weight_name(weight_name)
|
||||
|
||||
# Skip weights that don't need to be loaded
|
||||
if param_name is None:
|
||||
continue
|
||||
|
||||
if shard_id == "qkv":
|
||||
param = model.get_parameter(param_name)
|
||||
load_glm4_qkv(param, loaded_weight, config)
|
||||
continue
|
||||
elif shard_id == "gate_up":
|
||||
param = model.get_parameter(param_name)
|
||||
load_glm4_gate_up(param, loaded_weight, config)
|
||||
continue
|
||||
else:
|
||||
# Regular weight, use converted name
|
||||
param = model.get_parameter(param_name)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
|
||||
# Original loading logic for other models
|
||||
for k in packed_modules_mapping:
|
||||
if k in weight_name:
|
||||
v, shard_id = packed_modules_mapping[k]
|
||||
param_name = weight_name.replace(k, v)
|
||||
param = model.get_parameter(param_name)
|
||||
weight_loader = getattr(param, "weight_loader")
|
||||
weight_loader(param, f.get_tensor(weight_name), shard_id)
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = model.get_parameter(weight_name)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, f.get_tensor(weight_name))
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
Reference in New Issue
Block a user