✨ feat: add GLM-4-9B-Chat-1M model support
Add support for GLM-4 model architecture with the following changes:
- Add glm4.py with ChatGLMForCausalLM, GLM4Model, GLM4Attention, GLM4MLP
- Add GLM4RotaryEmbedding with interleaved partial rotation (rotary_dim = head_dim // 2)
- Add apply_rotary_emb_interleaved function for GLM-4 style RoPE
- Add GLM-4 weight name conversion and loading in loader.py
- Add GLM-4 chat template conversion in test_ruler.py
- Add trust_remote_code=True for GLM-4 config loading

Key GLM-4 specific adaptations:
- QKV bias enabled (add_qkv_bias: true)
- RoPE with rope_ratio scaling (base = 10000 * rope_ratio)
- Interleaved RoPE (pairs adjacent elements, not first/second half)
- Partial rotation (only half of head_dim is rotated)
- Uses multi_query_group_num instead of num_key_value_heads
- Uses kv_channels instead of head_dim
- Uses ffn_hidden_size instead of intermediate_size

Tested with RULER niah_single_1 (5 samples): 100% accuracy
Both GPU-only and CPU offload modes verified

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
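For reference, a minimal sketch of the config-key translation described in the message above, assuming a plain-dict view of the GLM-4 config.json. The helper name glm4_config_to_common and the output key names are illustrative, not the actual loader.py code:

def glm4_config_to_common(cfg: dict) -> dict:
    # Hypothetical helper: maps GLM-4 config names to the common names used elsewhere.
    rope_ratio = cfg.get("rope_ratio", 1.0)
    return {
        "num_key_value_heads": cfg["multi_query_group_num"],  # GLM-4 name for KV heads
        "head_dim": cfg["kv_channels"],                       # GLM-4 name for head size
        "intermediate_size": cfg["ffn_hidden_size"],          # GLM-4 name for MLP width
        "qkv_bias": cfg.get("add_qkv_bias", False),           # GLM-4 enables QKV bias
        "rope_base": 10000.0 * rope_ratio,                    # base = 10000 * rope_ratio
        "rotary_dim": cfg["kv_channels"] // 2,                # partial rotation: half of head_dim
    }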
@@ -8,12 +8,43 @@ def apply_rotary_emb(
     cos: torch.Tensor,
     sin: torch.Tensor,
 ) -> torch.Tensor:
+    """Non-interleaved RoPE (used by Llama, Qwen, etc.)"""
     x1, x2 = torch.chunk(x.float(), 2, dim=-1)
     y1 = x1 * cos - x2 * sin
     y2 = x2 * cos + x1 * sin
     return torch.cat((y1, y2), dim=-1).to(x.dtype)
 
 
+def apply_rotary_emb_interleaved(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> torch.Tensor:
+    """Interleaved RoPE (used by GLM-4, etc.)
+
+    Args:
+        x: [seq_len, num_heads, head_dim]
+        cos: [seq_len, 1, head_dim // 2]
+        sin: [seq_len, 1, head_dim // 2]
+
+    x is reshaped to [seq_len, num_heads, head_dim // 2, 2] where:
+        - x[..., 0] are even positions
+        - x[..., 1] are odd positions
+    """
+    rot_dim = x.shape[-1]
+    # x_shaped: [seq_len, num_heads, rot_dim // 2, 2]
+    x_shaped = x.float().reshape(*x.shape[:-1], rot_dim // 2, 2)
+    # x_0, x_1: [seq_len, num_heads, rot_dim // 2]
+    x_0 = x_shaped[..., 0]
+    x_1 = x_shaped[..., 1]
+    # cos/sin: [seq_len, 1, rot_dim // 2] - broadcasts to num_heads
+    x_out = torch.stack([
+        x_0 * cos - x_1 * sin,
+        x_1 * cos + x_0 * sin,
+    ], dim=-1)
+    return x_out.flatten(-2).to(x.dtype)
+
+
 class RotaryEmbedding(nn.Module):
 
     def __init__(
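Illustrative only (not part of the commit): with a 90-degree rotation (cos = 0, sin = 1) the two pairing schemes from the hunk above are easy to tell apart, since the rotation just swaps and negates whichever elements are paired. Shapes follow the docstrings above.

import torch

x = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])        # [seq_len=1, num_heads=1, head_dim=4]
cos = torch.zeros(1, 1, 2)                         # [seq_len, 1, head_dim // 2]
sin = torch.ones(1, 1, 2)

# Interleaved (GLM-4): adjacent elements (1, 2) and (3, 4) form the rotation pairs.
print(apply_rotary_emb_interleaved(x, cos, sin))   # tensor([[[-2., 1., -4., 3.]]])
# Non-interleaved (Llama/Qwen): element i pairs with element i + head_dim // 2.
print(apply_rotary_emb(x, cos, sin))               # tensor([[[-3., -4., 1., 2.]]])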
@@ -140,6 +171,76 @@ class Llama3RotaryEmbedding(nn.Module):
         return query, key
 
 
+class GLM4RotaryEmbedding(nn.Module):
+    """
+    GLM-4 RoPE with interleaved rotation and partial rotation.
+
+    GLM-4 uses:
+    - Interleaved rotation (pairs adjacent elements, not first/second half)
+    - rope_ratio to scale base: base = 10000 * rope_ratio
+    - Partial rotation: only rotates first rotary_dim elements, rest pass through
+    - rotary_dim = head_dim // 2 (only half of head_dim is rotated)
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim  # GLM-4: rotary_dim = head_dim // 2
+        # inv_freq shape: [rotary_dim // 2]
+        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)  # [max_pos, rotary_dim // 2]
+        cos = freqs.cos()
+        sin = freqs.sin()
+        # cache shape [max_pos, 1, rotary_dim // 2, 2]
+        cache = torch.stack((cos, sin), dim=-1).unsqueeze_(1)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    @torch.compile
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply RoPE to query and key.
+
+        Args:
+            positions: [seq_len]
+            query: [seq_len, num_heads, head_dim]
+            key: [seq_len, num_kv_heads, head_dim]
+
+        Returns:
+            Rotated query and key with same shapes as input.
+        """
+        cache = self.cos_sin_cache[positions]  # [seq_len, 1, rotary_dim // 2, 2]
+        cos = cache[..., 0]  # [seq_len, 1, rotary_dim // 2]
+        sin = cache[..., 1]  # [seq_len, 1, rotary_dim // 2]
+
+        # Split into rotated and pass-through parts
+        q_rot = query[..., :self.rotary_dim]
+        q_pass = query[..., self.rotary_dim:]
+        k_rot = key[..., :self.rotary_dim]
+        k_pass = key[..., self.rotary_dim:]
+
+        # Apply interleaved RoPE to rotated part
+        q_rot = apply_rotary_emb_interleaved(q_rot, cos, sin)
+        k_rot = apply_rotary_emb_interleaved(k_rot, cos, sin)
+
+        # Concatenate rotated and pass-through parts
+        query = torch.cat([q_rot, q_pass], dim=-1)
+        key = torch.cat([k_rot, k_pass], dim=-1)
+
+        return query, key
+
+
 # Cache for RoPE instances (keyed by hashable parameters)
 _rope_cache: dict[tuple, nn.Module] = {}
 
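Illustrative only (not part of the diff): constructing GLM4RotaryEmbedding with GLM-4-9B-style numbers and checking that the pass-through half of head_dim is untouched. head_dim = 128, rope_ratio = 500, and the 8192 position cap are placeholder values for the sketch, not figures read from the model config.

import torch

head_dim = 128
rotary_dim = head_dim // 2                      # only half of head_dim is rotated
rope_ratio = 500                                # example value; the real one comes from config.json
base = 10000.0 * rope_ratio                     # base = 10000 * rope_ratio
rope = GLM4RotaryEmbedding(head_dim, rotary_dim, max_position_embeddings=8192, base=base)

positions = torch.arange(16)
q = torch.randn(16, 32, head_dim)               # [seq_len, num_heads, head_dim]
k = torch.randn(16, 2, head_dim)                # [seq_len, num_kv_heads, head_dim]
q_out, k_out = rope(positions, q, k)
# The upper half of each head passes through unrotated.
assert torch.equal(q_out[..., rotary_dim:], q[..., rotary_dim:])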
@@ -150,10 +251,11 @@ def get_rope(
     max_position: int,
     base: float,
     rope_scaling: dict | None = None,
+    is_interleaved: bool = False,
 ):
     # Create hashable cache key
     if rope_scaling is None:
-        cache_key = (head_size, rotary_dim, max_position, base, None)
+        cache_key = (head_size, rotary_dim, max_position, base, None, is_interleaved)
     else:
         rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
         if rope_type == "llama3":
@@ -163,15 +265,19 @@ def get_rope(
                 rope_scaling["low_freq_factor"],
                 rope_scaling["high_freq_factor"],
                 rope_scaling["original_max_position_embeddings"],
+                is_interleaved,
             )
         else:
-            cache_key = (head_size, rotary_dim, max_position, base, rope_type)
+            cache_key = (head_size, rotary_dim, max_position, base, rope_type, is_interleaved)
 
     if cache_key in _rope_cache:
         return _rope_cache[cache_key]
 
     if rope_scaling is None:
-        rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
+        if is_interleaved:
+            rope = GLM4RotaryEmbedding(head_size, rotary_dim, max_position, base)
+        else:
+            rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
     else:
         rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
         if rope_type == "llama3":
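Illustrative only: how a GLM-4 attention layer might request its RoPE through the updated get_rope(). kv_channels and rope_ratio are the config names cited in the commit message; the SimpleNamespace stand-in and the seq_length field are assumptions for the sketch.

from types import SimpleNamespace

# Stand-in for the GLM-4 HF config; only the fields used below.
config = SimpleNamespace(kv_channels=128, rope_ratio=500, seq_length=8192)

rotary_dim = config.kv_channels // 2             # partial rotation: half of head_dim
rope = get_rope(
    head_size=config.kv_channels,
    rotary_dim=rotary_dim,
    max_position=config.seq_length,
    base=10000.0 * config.rope_ratio,            # base = 10000 * rope_ratio
    rope_scaling=None,                           # GLM-4 scales via rope_ratio, not rope_scaling
    is_interleaved=True,                         # selects GLM4RotaryEmbedding
)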