✨ feat: add GLM-4-9B-Chat-1M model support
Add support for GLM-4 model architecture with the following changes: - Add glm4.py with ChatGLMForCausalLM, GLM4Model, GLM4Attention, GLM4MLP - Add GLM4RotaryEmbedding with interleaved partial rotation (rotary_dim = head_dim // 2) - Add apply_rotary_emb_interleaved function for GLM-4 style RoPE - Add GLM-4 weight name conversion and loading in loader.py - Add GLM-4 chat template conversion in test_ruler.py - Add trust_remote_code=True for GLM-4 config loading Key GLM-4 specific adaptations: - QKV bias enabled (add_qkv_bias: true) - RoPE with rope_ratio scaling (base = 10000 * rope_ratio) - Interleaved RoPE (pairs adjacent elements, not first/second half) - Partial rotation (only half of head_dim is rotated) - Uses multi_query_group_num instead of num_key_value_heads - Uses kv_channels instead of head_dim - Uses ffn_hidden_size instead of intermediate_size Tested with RULER niah_single_1 (5 samples): 100% accuracy Both GPU-only and CPU offload modes verified Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -48,6 +48,62 @@ from nanovllm import LLM, SamplingParams
|
||||
# ============================================================
# Default location of the pre-generated RULER 64k benchmark data,
# resolved relative to this script's directory.
DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_64k"


# ============================================================
# Chat Template Conversion
# ============================================================
def convert_llama_to_glm4_format(prompt: str) -> str:
    """
    Rewrite a Llama 3 chat-template prompt into the GLM-4 prompt format.

    Llama 3 input shape:
        <|begin_of_text|><|start_header_id|>user<|end_header_id|>

        {user_content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

        {assistant_prefix}

    GLM-4 output shape:
        [gMASK]<sop><|user|>
        {user_content}<|assistant|>
        {assistant_prefix}
    """
    assistant_marker = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    pieces = prompt.split(assistant_marker)

    # Strip the Llama-specific framing tokens from the user turn.
    user_text = pieces[0]
    for token in ("<|begin_of_text|>", "<|start_header_id|>user<|end_header_id|>"):
        user_text = user_text.replace(token, "")
    user_text = user_text.strip()

    converted = f"[gMASK]<sop><|user|>\n{user_text}<|assistant|>"

    # Carry over any pre-seeded assistant text (e.g. an answer prefix).
    if len(pieces) > 1:
        tail = pieces[1].replace("<|eot_id|>", "").strip()
        if tail:
            converted += f"\n{tail}"

    return converted
|
||||
|
||||
|
||||
def is_glm_model(model_path: str) -> bool:
    """
    Return True if the model at ``model_path`` is a GLM (chatglm) model.

    The result is memoized per path on the function object: this is called
    once per sample via convert_prompt_for_model inside the benchmark loop,
    and re-reading the HuggingFace config from disk on every call is
    wasted I/O for an answer that never changes during a run.
    """
    cache: dict = is_glm_model.__dict__.setdefault("_cache", {})
    if model_path not in cache:
        from transformers import AutoConfig

        # trust_remote_code is required: GLM-4 ships its config class
        # inside the model repo rather than in transformers itself.
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        cache[model_path] = getattr(config, 'model_type', '') == 'chatglm'
    return cache[model_path]
|
||||
|
||||
|
||||
def convert_prompt_for_model(prompt: str, model_path: str) -> str:
    """Return ``prompt`` rewritten for the target model's chat template.

    GLM models get the GLM-4 format; every other model (Llama included)
    receives the prompt unchanged.
    """
    if not is_glm_model(model_path):
        return prompt
    return convert_llama_to_glm4_format(prompt)
|
||||
# Local path to the default model checkpoint used when no model is specified.
DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
# Note: max_model_len must be > max_input_len to leave room for output tokens
# 64k benchmark has inputs up to 65536 tokens, so we need 65536 + 128 = 65664
|
||||
@@ -161,6 +217,7 @@ def run_task_test(
|
||||
verbose: bool = True,
|
||||
llm_factory: Optional[callable] = None,
|
||||
fresh_llm: bool = False,
|
||||
model_path: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Run test for a single RULER task.
|
||||
@@ -198,6 +255,9 @@ def run_task_test(
|
||||
for sample in samples:
|
||||
idx = sample.get("index", sample["_local_idx"])
|
||||
prompt = sample["input"]
|
||||
# Convert prompt format for GLM models
|
||||
if model_path:
|
||||
prompt = convert_prompt_for_model(prompt, model_path)
|
||||
expected = sample["outputs"]
|
||||
|
||||
# Fresh LLM mode: reinitialize for each sample
|
||||
@@ -367,6 +427,7 @@ def run_ruler_benchmark(
|
||||
verbose=verbose and not json_output,
|
||||
llm_factory=create_llm,
|
||||
fresh_llm=fresh_llm,
|
||||
model_path=model_path,
|
||||
)
|
||||
task_results.append(result)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user