feat: add GLM-4-9B-Chat-1M model support

Add support for GLM-4 model architecture with the following changes:

- Add glm4.py with ChatGLMForCausalLM, GLM4Model, GLM4Attention, GLM4MLP
- Add GLM4RotaryEmbedding with interleaved partial rotation (rotary_dim = head_dim // 2)
- Add apply_rotary_emb_interleaved function for GLM-4 style RoPE
- Add GLM-4 weight name conversion and loading in loader.py
- Add GLM-4 chat template conversion in test_ruler.py
- Add trust_remote_code=True for GLM-4 config loading

Key GLM-4 specific adaptations:
- QKV bias enabled (add_qkv_bias: true)
- RoPE with rope_ratio scaling (base = 10000 * rope_ratio)
- Interleaved RoPE (pairs adjacent elements, not first/second half)
- Partial rotation (only half of head_dim is rotated)
- Uses multi_query_group_num instead of num_key_value_heads
- Uses kv_channels instead of head_dim
- Uses ffn_hidden_size instead of intermediate_size

Tested with RULER niah_single_1 (5 samples): 100% accuracy
Both GPU-only and CPU offload modes verified

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-28 13:15:57 +08:00
parent 8d19e61446
commit 726e4b58cf
8 changed files with 557 additions and 12 deletions

View File

@@ -30,6 +30,18 @@ def _find_free_port() -> int:
return s.getsockname()[1]
def get_num_kv_heads(hf_config) -> int:
"""Get number of KV heads from config (handles GLM-4's multi_query_group_num)."""
return getattr(hf_config, 'num_key_value_heads',
getattr(hf_config, 'multi_query_group_num', hf_config.num_attention_heads))
def get_head_dim(hf_config) -> int:
"""Get head dimension from config (handles GLM-4's kv_channels)."""
return getattr(hf_config, "head_dim",
getattr(hf_config, "kv_channels", hf_config.hidden_size // hf_config.num_attention_heads))
class ModelRunner:
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
@@ -144,8 +156,8 @@ class ModelRunner:
used = total - free
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
num_kv_heads = hf_config.num_key_value_heads // self.world_size
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
head_dim = get_head_dim(hf_config)
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
# Calculate max GPU blocks based on available memory
@@ -787,8 +799,8 @@ class ModelRunner:
- LastGraph: o_proj → post_norm → mlp → final_norm
"""
hf_config = self.config.hf_config
num_kv_heads = hf_config.num_key_value_heads // self.world_size
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
head_dim = get_head_dim(hf_config)
# Create Decode Graph Manager (seq_len=1)
self.decode_graph_manager = OffloadGraphManager(