# Compare commits

4 commits: `8d19e61446...e09a2a5b10`

| SHA1 |
|---|
| `e09a2a5b10` |
| `a239bfb40d` |
| `29e102720b` |
| `726e4b58cf` |
`CLAUDE.md`:

```diff
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code when working with this repository.

 ## Overview

-Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3 models with CPU offload for long-context inference.
+Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3, Llama-3, and GLM-4 models with CPU offload for long-context inference.

 ## Documentation Index

@@ -35,6 +35,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/memory_communication_benchmark.md`](docs/memory_communication_benchmark.md) | 📊 Communication benchmark: Full vs XAttention traffic comparison (32K/64K), per-phase statistics |
 | [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: block_size performance analysis for the estimate phase; softmax_fuse_block_sum optimum (512-1024), current 4096 is 15x slower |
 | [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: list of models with 1M+ context length (Qwen/GLM/InternLM/Llama/VL), recommended models ≤10B |
+| [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: new model integration guide - config mapping, RoPE variants, EOS handling, weight conversion, verification checklist |

 ## Rules Index
```
`docs/new_model_integration_guide.md` (new file, 323 lines):
# New Model Integration Guide

This document summarizes the experience and common pitfalls of integrating a new model (such as GLM-4) into nanovllm.

## Integration Workflow Overview

```
1. Analyze the model config (config.json)
   ↓
2. Create the model file (nanovllm/models/<model>.py)
   ↓
3. Implement weight loading (nanovllm/utils/loader.py)
   ↓
4. Handle special components (RoPE, Attention, etc.)
   ↓
5. Handle tokenizer differences (EOS tokens, chat template)
   ↓
6. Verify output correctness
```

---
## 1. Configuration Field Mapping

Different models use different configuration field names, so a mapping is needed:

| Standard field | GLM-4 | Qwen | Llama | Notes |
|----------|-------|------|-------|------|
| `num_key_value_heads` | `multi_query_group_num` | `num_key_value_heads` | `num_key_value_heads` | Number of KV heads |
| `head_dim` | `kv_channels` | derived | derived | Dimension per head |
| `intermediate_size` | `ffn_hidden_size` | `intermediate_size` | `intermediate_size` | FFN hidden size |
| `max_position_embeddings` | `seq_length` | `max_position_embeddings` | `max_position_embeddings` | Maximum position |
| `rope_theta` | `10000 * rope_ratio` | `rope_theta` | `rope_theta` | RoPE base frequency |

### Code Example

```python
# Resolve config differences in the model __init__
num_kv_heads = getattr(config, 'num_key_value_heads',
                       getattr(config, 'multi_query_group_num', num_heads))

head_dim = getattr(config, 'head_dim',
                   getattr(config, 'kv_channels', hidden_size // num_heads))

intermediate_size = getattr(config, 'intermediate_size',
                            getattr(config, 'ffn_hidden_size', None))

max_position = getattr(config, 'max_position_embeddings',
                       getattr(config, 'seq_length', 4096))
```

---
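To see the fallback chain in action without a real checkpoint, the snippet below feeds it a hypothetical GLM-4-style config built from `SimpleNamespace` (the field values are illustrative, not taken from any actual `config.json`):

```python
from types import SimpleNamespace

# Hypothetical GLM-4-style config: exposes multi_query_group_num / kv_channels
# instead of the standard num_key_value_heads / head_dim fields.
glm_cfg = SimpleNamespace(
    num_attention_heads=32,
    hidden_size=4096,
    multi_query_group_num=2,
    kv_channels=128,
)

num_heads = glm_cfg.num_attention_heads
# Falls back to the GLM-4 name because the standard field is absent.
num_kv_heads = getattr(glm_cfg, 'num_key_value_heads',
                       getattr(glm_cfg, 'multi_query_group_num', num_heads))
head_dim = getattr(glm_cfg, 'head_dim',
                   getattr(glm_cfg, 'kv_channels', glm_cfg.hidden_size // num_heads))
```

For a Qwen/Llama-style config the first `getattr` succeeds and the fallbacks are never consulted, so the same code path serves both families.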
## 2. RoPE Implementation Differences

RoPE is the **most error-prone** part of model integration. Different models may use different RoPE variants:

### 2.1 Rotation Style

| Type | Description | Used by |
|------|------|----------|
| **Half rotation** | The first and second halves are rotated separately: `[x0,x1,...] → [x0*cos-x_{d/2}*sin, ...]` | Llama, Qwen |
| **Interleaved rotation** | Adjacent elements are rotated in pairs: `[x0,x1,...] → [x0*cos-x1*sin, x1*cos+x0*sin, ...]` | GLM-4 |

### 2.2 Rotation Dimension

| Type | Description | Used by |
|------|------|----------|
| **Full rotation** | Rotates the entire head_dim | Llama, Qwen |
| **Partial rotation** | Rotates only part of head_dim; the rest passes through | GLM-4 (rotary_dim = head_dim // 2) |
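The two rotation styles can be contrasted with a framework-free sketch (plain Python lists instead of tensors; the helper names `rotate_half` and `rotate_interleaved` are ours, not nanovllm APIs):

```python
import math

def rotate_half(x, cos, sin):
    # Half rotation: element i is paired with element i + d/2.
    d = len(x) // 2
    x1, x2 = x[:d], x[d:]
    return [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)] + \
           [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]

def rotate_interleaved(x, cos, sin):
    # Interleaved rotation: adjacent elements (2i, 2i+1) form a pair.
    out = []
    for i in range(len(x) // 2):
        x0, x1 = x[2 * i], x[2 * i + 1]
        out += [x0 * cos[i] - x1 * sin[i], x1 * cos[i] + x0 * sin[i]]
    return out

theta = math.pi / 2
cos, sin = [math.cos(theta)] * 2, [math.sin(theta)] * 2
x = [1.0, 2.0, 3.0, 4.0]
```

For the same input and angle the two styles produce different vectors, which is exactly why loading a half-rotation model with an interleaved implementation (or vice versa) yields garbage output.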
### 2.3 GLM-4 RoPE Implementation

```python
class GLM4RotaryEmbedding(nn.Module):
    def __init__(self, head_dim, rotary_dim, ...):
        # GLM-4 rotates only half of the dimensions
        self.rotary_dim = rotary_dim  # = head_dim // 2

    def forward(self, positions, query, key):
        # Separate the rotated part from the pass-through part
        q_rot = query[..., :self.rotary_dim]
        q_pass = query[..., self.rotary_dim:]

        # Apply interleaved RoPE to the rotated part only
        q_rot = apply_rotary_emb_interleaved(q_rot, cos, sin)

        # Concatenate back together
        return torch.cat([q_rot, q_pass], dim=-1), ...
```

### 2.4 Debugging RoPE Issues

**Symptom**: the model outputs garbage or repeats meaningless content (e.g., "The. The. The...").

**How to debug**:
```python
# Compare against the HuggingFace reference implementation
hf_q, hf_k = hf_model.apply_rotary_pos_emb(query, key, cos, sin)
my_q, my_k = my_rotary_emb(positions, query, key)

print(f"Q max diff: {(hf_q - my_q).abs().max()}")  # should be < 1e-5
print(f"K max diff: {(hf_k - my_k).abs().max()}")  # should be < 1e-5
```

---
## 3. Weight Name Mapping

Models follow different weight naming conventions:

### 3.1 Common Mappings

| Component | Llama/Qwen | GLM-4 |
|------|------------|-------|
| Attention QKV | `q_proj`, `k_proj`, `v_proj` | `query_key_value` (merged) |
| Attention Output | `o_proj` | `dense` |
| MLP Gate | `gate_proj` | `dense_h_to_4h` (part) |
| MLP Up | `up_proj` | `dense_h_to_4h` (part) |
| MLP Down | `down_proj` | `dense_4h_to_h` |
| LayerNorm | `input_layernorm` | `input_layernorm` |
| Post-Attention LN | `post_attention_layernorm` | `post_attention_layernorm` |

### 3.2 Implementing the Weight Conversion

```python
def convert_glm4_weights(name, param):
    """Convert GLM-4 weight names to the nanovllm format."""
    # Handle the merged QKV weight
    if "query_key_value" in name:
        # Split into q, k, v
        q, k, v = param.split([q_size, kv_size, kv_size], dim=0)
        return {"q_proj": q, "k_proj": k, "v_proj": v}

    # Handle the merged gate+up weight
    if "dense_h_to_4h" in name:
        gate, up = param.chunk(2, dim=0)
        return {"gate_proj": gate, "up_proj": up}

    return {name: param}
```

---
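A quick sanity check on the split sizes used above (the head counts below are illustrative, roughly GLM-4-9B-shaped, not read from any checkpoint):

```python
# Hypothetical GLM-4-like shape: 32 query heads, 2 KV heads, head_dim 128.
num_heads, num_kv_heads, head_dim = 32, 2, 128
q_size = num_heads * head_dim        # rows belonging to Q
kv_size = num_kv_heads * head_dim    # rows belonging to each of K and V

# query_key_value packs rows as [q | k | v] along dim 0,
# so these are the boundaries passed to param.split(..., dim=0).
split_sizes = [q_size, kv_size, kv_size]
total_rows = sum(split_sizes)        # row count of the merged weight
```

If `total_rows` does not match the first dimension of the checkpoint tensor, the head-count or head-dim mapping from Section 1 is wrong — a useful early assertion to add in the loader.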
## 4. EOS Token Handling

### 4.1 The Problem

Some models use **multiple EOS tokens**:

| Model | EOS Token(s) | Notes |
|------|--------------|------|
| Llama | `128001` | single EOS |
| Qwen | `151643` | single EOS |
| GLM-4 | `[151329, 151336, 151338]` | multiple: endoftext, user, observation |

**Problem**: `tokenizer.eos_token_id` returns only the first one, so the model never stops at the other EOS tokens.

### 4.2 Solution

```python
# config.py - support multiple EOS tokens
eos: int | list[int] = -1

# llm_engine.py - read the full EOS list from hf_config
eos_from_config = getattr(config.hf_config, 'eos_token_id', None)
if eos_from_config is not None:
    config.eos = eos_from_config
else:
    config.eos = self.tokenizer.eos_token_id

# scheduler.py - use a set for efficient lookup
self.eos_set = set(eos) if isinstance(eos, list) else {eos}

# check with `in` instead of `==`
if token_id in self.eos_set:
    # stop generation
```
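The set-based check can be exercised standalone (the `make_eos_set` and `should_stop` helpers are a distilled sketch of the scheduler logic above, not actual nanovllm functions):

```python
def make_eos_set(eos):
    # Normalize an int-or-list EOS config into a set for O(1) membership tests.
    return set(eos) if isinstance(eos, list) else {eos}

def should_stop(token_id, eos_set, ignore_eos=False):
    # Mirrors the scheduler check: `in` works for both single and multiple EOS.
    return (not ignore_eos) and token_id in eos_set

qwen_eos = make_eos_set(151643)                    # single EOS
glm4_eos = make_eos_set([151329, 151336, 151338])  # multiple EOS
```

Note that with the old `token_id == self.eos` comparison, a GLM-4 `<|user|>` token (151336) would never match, which is exactly the "generates until max_tokens" symptom described below.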
### 4.3 Debugging EOS Issues

**Symptom**: generation always runs until max_tokens before stopping.

**How to debug**:
```python
# Inspect the EOS configuration
print(f"tokenizer.eos_token_id: {tokenizer.eos_token_id}")
print(f"hf_config.eos_token_id: {config.hf_config.eos_token_id}")

# Look for EOS tokens in the output
output = llm.generate([prompt], params)
for eos_id in [151329, 151336, 151338]:
    if eos_id in output[0]['token_ids']:
        print(f"Found EOS {eos_id} at position {output[0]['token_ids'].index(eos_id)}")
```

---
## 5. Chat Template

Models use different conversation templates:

| Model | Template format |
|------|----------|
| Llama-3 | `<\|begin_of_text\|><\|start_header_id\|>user<\|end_header_id\|>\n{content}<\|eot_id\|><\|start_header_id\|>assistant<\|end_header_id\|>\n` |
| Qwen | `<\|im_start\|>user\n{content}<\|im_end\|>\n<\|im_start\|>assistant\n` |
| GLM-4 | `[gMASK]<sop><\|user\|>\n{content}<\|assistant\|>\n` |

### Implementing Template Conversion

```python
def convert_to_model_prompt(prompt: str, model_type: str) -> str:
    """Convert a plain prompt into the model-specific format."""
    if model_type == "glm4":
        return f"[gMASK]<sop><|user|>\n{prompt}<|assistant|>\n"
    elif model_type == "llama3":
        return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    # ...
```

---
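A minimal self-contained version of the converter can be exercised directly; the `qwen` branch here is filled in from the table above and is not part of the original sketch:

```python
def convert_to_model_prompt(prompt: str, model_type: str) -> str:
    """Wrap a raw prompt in a model-specific chat template (subset of the table above)."""
    if model_type == "glm4":
        return f"[gMASK]<sop><|user|>\n{prompt}<|assistant|>\n"
    elif model_type == "qwen":
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    raise ValueError(f"unknown model_type: {model_type}")

p = convert_to_model_prompt("Hi", "glm4")
```

Getting the template wrong usually does not crash anything; it just degrades output quality, so it is worth a unit test like this rather than eyeballing generations.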
## 6. Verification Checklist

After integrating a new model, verify in this order:

### 6.1 Weight Loading

```python
# Check that every weight was actually loaded
for name, param in model.named_parameters():
    if param.abs().sum() == 0:
        print(f"WARNING: {name} is all zeros!")
```

### 6.2 Per-Layer Outputs

```python
# Compare the embedding layer output
my_emb = my_model.embed_tokens(input_ids)
hf_emb = hf_model.model.embed_tokens(input_ids)
print(f"Embedding diff: {(my_emb - hf_emb).abs().max()}")  # < 1e-5

# Compare the first layer output
my_out = my_model.layers[0](my_emb, ...)
hf_out = hf_model.model.layers[0](hf_emb, ...)
print(f"Layer 0 diff: {(my_out - hf_out).abs().max()}")  # < 1e-4
```

### 6.3 Generation Quality

```python
# Simple Q&A test
prompt = "Hello, how are you?"
output = llm.generate([prompt], SamplingParams(max_tokens=50))
print(output[0]['text'])  # should be a coherent answer

# Check that generation stops correctly
print(f"Generated {len(output[0]['token_ids'])} tokens (max=50)")
```

### 6.4 RULER Benchmark

```bash
# Quick sanity check with 1 sample
python tests/test_ruler.py --model <path> --num-samples 1

# Once that passes, run the full test
python tests/test_ruler.py --model <path> --num-samples 100
```

---
## 7. Common Issues Quick Reference

| Symptom | Likely cause | Fix |
|------|----------|----------|
| Garbled/repetitive output | Incorrect RoPE implementation | Check the rotation style (interleaved vs half) and rotation dimension (full vs partial) |
| Numerical blow-up (NaN/Inf) | Weight loading error or dtype mismatch | Check the weight mapping; ensure consistent dtypes |
| Never stops generating | EOS token handling error | Read the full EOS list from hf_config |
| Poor output quality | Missing LayerNorm or bias | Check configs such as add_qkv_bias |
| Wrong positional encoding | max_position_embeddings read incorrectly | Check config field names (seq_length, etc.) |

---
## 8. File Structure

Files to create or modify when integrating a new model:

```
nanovllm/
├── models/
│   └── <model>.py              # new: model definition
├── layers/
│   └── rotary_embedding.py     # modify: if a special RoPE is needed
├── utils/
│   └── loader.py               # modify: weight loading
├── config.py                   # maybe modify: new config fields
└── engine/
    ├── llm_engine.py           # maybe modify: EOS handling
    └── scheduler.py            # maybe modify: EOS check
tests/
└── test_ruler.py               # modify: chat template
```

---
## Appendix: GLM-4 Integration Case Study

### Problems Encountered and Fixes

1. **Config field differences** → added getattr fallback chains
2. **Interleaved RoPE** → implemented `apply_rotary_emb_interleaved`
3. **Partial rotation (head_dim//2)** → implemented `GLM4RotaryEmbedding`
4. **Multiple EOS tokens** → config/llm_engine/scheduler now accept a list
5. **Merged QKV weights** → split in the loader

### Key Code Locations

- RoPE implementation: `nanovllm/layers/rotary_embedding.py:GLM4RotaryEmbedding`
- Model definition: `nanovllm/models/glm4.py`
- Weight loading: `nanovllm/utils/loader.py:load_glm4_weights`
- EOS handling: `nanovllm/engine/scheduler.py:eos_set`
`config.py`:

```diff
@@ -22,7 +22,7 @@ class Config:
     tensor_parallel_size: int = 1
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
-    eos: int = -1
+    eos: int | list[int] = -1  # Single EOS token or list of EOS tokens (e.g., GLM-4)
     kvcache_block_size: int = 1024
     num_kvcache_blocks: int = -1
     dtype: str | None = None  # "float16", "bfloat16", or None (use model default)
@@ -57,8 +57,11 @@ class Config:
         assert os.path.isdir(self.model)
         assert self.kvcache_block_size % 256 == 0
         assert 1 <= self.tensor_parallel_size <= 8
-        self.hf_config = AutoConfig.from_pretrained(self.model)
-        self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
+        self.hf_config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
+        # Get max position embeddings (GLM-4 uses seq_length instead of max_position_embeddings)
+        max_pos = getattr(self.hf_config, 'max_position_embeddings',
+                          getattr(self.hf_config, 'seq_length', 4096))
+        self.max_model_len = min(self.max_model_len, max_pos)
         assert self.max_num_batched_tokens >= self.max_model_len

         # Override torch_dtype if user specified
```
`llm_engine.py`:

```diff
@@ -30,8 +30,14 @@ class LLMEngine:
             self.ps.append(process)
             self.events.append(event)
         self.model_runner = ModelRunner(config, 0, self.events)
-        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
-        config.eos = self.tokenizer.eos_token_id
+        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True)
+        # Get EOS token(s) from config (may be int or list, e.g., GLM-4 uses list)
+        # Prefer hf_config.eos_token_id which contains full list, fallback to tokenizer
+        eos_from_config = getattr(config.hf_config, 'eos_token_id', None)
+        if eos_from_config is not None:
+            config.eos = eos_from_config
+        else:
+            config.eos = self.tokenizer.eos_token_id
         # Set Sequence.block_size to match the KV cache block size
         Sequence.block_size = config.kvcache_block_size
         self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
```
```diff
@@ -30,6 +30,18 @@ def _find_free_port() -> int:
     return s.getsockname()[1]


+def get_num_kv_heads(hf_config) -> int:
+    """Get number of KV heads from config (handles GLM-4's multi_query_group_num)."""
+    return getattr(hf_config, 'num_key_value_heads',
+                   getattr(hf_config, 'multi_query_group_num', hf_config.num_attention_heads))
+
+
+def get_head_dim(hf_config) -> int:
+    """Get head dimension from config (handles GLM-4's kv_channels)."""
+    return getattr(hf_config, "head_dim",
+                   getattr(hf_config, "kv_channels", hf_config.hidden_size // hf_config.num_attention_heads))
+
+
 class ModelRunner:

     def __init__(self, config: Config, rank: int, event: Event | list[Event]):
@@ -144,8 +156,8 @@ class ModelRunner:
         used = total - free
         peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
         current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
-        num_kv_heads = hf_config.num_key_value_heads // self.world_size
-        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
+        num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
+        head_dim = get_head_dim(hf_config)
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize

         # Calculate max GPU blocks based on available memory
@@ -787,8 +799,8 @@ class ModelRunner:
         - LastGraph: o_proj → post_norm → mlp → final_norm
         """
         hf_config = self.config.hf_config
-        num_kv_heads = hf_config.num_key_value_heads // self.world_size
-        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
+        num_kv_heads = get_num_kv_heads(hf_config) // self.world_size
+        head_dim = get_head_dim(hf_config)

         # Create Decode Graph Manager (seq_len=1)
         self.decode_graph_manager = OffloadGraphManager(
```
`scheduler.py`:

```diff
@@ -15,7 +15,9 @@ class Scheduler:
     def __init__(self, config: Config, kvcache_manager: "KVCacheManager"):
         self.max_num_seqs = config.max_num_seqs
         self.max_num_batched_tokens = config.max_num_batched_tokens
-        self.eos = config.eos
+        # Convert EOS to set for efficient lookup (supports single int or list)
+        eos = config.eos
+        self.eos_set = set(eos) if isinstance(eos, list) else {eos}
         self.kvcache_manager = kvcache_manager
         self.waiting: deque[Sequence] = deque()
         self.running: deque[Sequence] = deque()
@@ -94,7 +96,7 @@ class Scheduler:
     def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
         for seq, token_id in zip(seqs, token_ids):
             seq.append_token(token_id)
-            if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
+            if (not seq.ignore_eos and token_id in self.eos_set) or seq.num_completion_tokens == seq.max_tokens:
                 seq.status = SequenceStatus.FINISHED
                 self.kvcache_manager.deallocate(seq)
                 self.running.remove(seq)
```
`nanovllm/layers/rotary_embedding.py`:

```diff
@@ -8,12 +8,43 @@ def apply_rotary_emb(
     cos: torch.Tensor,
     sin: torch.Tensor,
 ) -> torch.Tensor:
+    """Non-interleaved RoPE (used by Llama, Qwen, etc.)"""
     x1, x2 = torch.chunk(x.float(), 2, dim=-1)
     y1 = x1 * cos - x2 * sin
     y2 = x2 * cos + x1 * sin
     return torch.cat((y1, y2), dim=-1).to(x.dtype)


+def apply_rotary_emb_interleaved(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> torch.Tensor:
+    """Interleaved RoPE (used by GLM-4, etc.)
+
+    Args:
+        x: [seq_len, num_heads, head_dim]
+        cos: [seq_len, 1, head_dim // 2]
+        sin: [seq_len, 1, head_dim // 2]
+
+    x is reshaped to [seq_len, num_heads, head_dim // 2, 2] where:
+    - x[..., 0] are even positions
+    - x[..., 1] are odd positions
+    """
+    rot_dim = x.shape[-1]
+    # x_shaped: [seq_len, num_heads, rot_dim // 2, 2]
+    x_shaped = x.float().reshape(*x.shape[:-1], rot_dim // 2, 2)
+    # x_0, x_1: [seq_len, num_heads, rot_dim // 2]
+    x_0 = x_shaped[..., 0]
+    x_1 = x_shaped[..., 1]
+    # cos/sin: [seq_len, 1, rot_dim // 2] - broadcasts to num_heads
+    x_out = torch.stack([
+        x_0 * cos - x_1 * sin,
+        x_1 * cos + x_0 * sin,
+    ], dim=-1)
+    return x_out.flatten(-2).to(x.dtype)
+
+
 class RotaryEmbedding(nn.Module):

     def __init__(
```
```diff
@@ -140,6 +171,76 @@ class Llama3RotaryEmbedding(nn.Module):
         return query, key


+class GLM4RotaryEmbedding(nn.Module):
+    """
+    GLM-4 RoPE with interleaved rotation and partial rotation.
+
+    GLM-4 uses:
+    - Interleaved rotation (pairs adjacent elements, not first/second half)
+    - rope_ratio to scale base: base = 10000 * rope_ratio
+    - Partial rotation: only rotates first rotary_dim elements, rest pass through
+    - rotary_dim = head_dim // 2 (only half of head_dim is rotated)
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim  # GLM-4: rotary_dim = head_dim // 2
+        # inv_freq shape: [rotary_dim // 2]
+        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)  # [max_pos, rotary_dim // 2]
+        cos = freqs.cos()
+        sin = freqs.sin()
+        # cache shape [max_pos, 1, rotary_dim // 2, 2]
+        cache = torch.stack((cos, sin), dim=-1).unsqueeze_(1)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    @torch.compile
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply RoPE to query and key.
+
+        Args:
+            positions: [seq_len]
+            query: [seq_len, num_heads, head_dim]
+            key: [seq_len, num_kv_heads, head_dim]
+
+        Returns:
+            Rotated query and key with same shapes as input.
+        """
+        cache = self.cos_sin_cache[positions]  # [seq_len, 1, rotary_dim // 2, 2]
+        cos = cache[..., 0]  # [seq_len, 1, rotary_dim // 2]
+        sin = cache[..., 1]  # [seq_len, 1, rotary_dim // 2]
+
+        # Split into rotated and pass-through parts
+        q_rot = query[..., :self.rotary_dim]
+        q_pass = query[..., self.rotary_dim:]
+        k_rot = key[..., :self.rotary_dim]
+        k_pass = key[..., self.rotary_dim:]
+
+        # Apply interleaved RoPE to rotated part
+        q_rot = apply_rotary_emb_interleaved(q_rot, cos, sin)
+        k_rot = apply_rotary_emb_interleaved(k_rot, cos, sin)
+
+        # Concatenate rotated and pass-through parts
+        query = torch.cat([q_rot, q_pass], dim=-1)
+        key = torch.cat([k_rot, k_pass], dim=-1)
+
+        return query, key
+
+
 # Cache for RoPE instances (keyed by hashable parameters)
 _rope_cache: dict[tuple, nn.Module] = {}
```
```diff
@@ -150,10 +251,11 @@ def get_rope(
     max_position: int,
     base: float,
     rope_scaling: dict | None = None,
+    is_interleaved: bool = False,
 ):
     # Create hashable cache key
     if rope_scaling is None:
-        cache_key = (head_size, rotary_dim, max_position, base, None)
+        cache_key = (head_size, rotary_dim, max_position, base, None, is_interleaved)
     else:
         rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
         if rope_type == "llama3":
@@ -163,15 +265,19 @@ def get_rope(
                 rope_scaling["low_freq_factor"],
                 rope_scaling["high_freq_factor"],
                 rope_scaling["original_max_position_embeddings"],
+                is_interleaved,
             )
         else:
-            cache_key = (head_size, rotary_dim, max_position, base, rope_type)
+            cache_key = (head_size, rotary_dim, max_position, base, rope_type, is_interleaved)

     if cache_key in _rope_cache:
         return _rope_cache[cache_key]

     if rope_scaling is None:
-        rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
+        if is_interleaved:
+            rope = GLM4RotaryEmbedding(head_size, rotary_dim, max_position, base)
+        else:
+            rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
     else:
         rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
         if rope_type == "llama3":
```
`nanovllm/models/__init__.py`:

```diff
@@ -3,7 +3,9 @@
 from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY

 # Import models to trigger registration
+from nanovllm.models import qwen2
 from nanovllm.models import qwen3
 from nanovllm.models import llama
+from nanovllm.models import glm4

 __all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]
```
`nanovllm/models/glm4.py` (new file, 235 lines):

```python
"""GLM-4 model implementation for nano-vllm."""
import torch
from torch import nn
import torch.distributed as dist

from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
from nanovllm.models.registry import register_model


class GLM4Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 1048576,
        head_dim: int = 128,
        rope_theta: float = 10000,
        rope_scaling: dict | None = None,
    ) -> None:
        super().__init__()
        tp_size = dist.get_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim ** -0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,  # GLM-4 has QKV bias
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,  # GLM-4 has no output bias
        )
        # GLM-4 only rotates half of head_dim
        rotary_dim = self.head_dim // 2
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=rotary_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_interleaved=True,  # GLM-4 uses interleaved RoPE
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q, k = self.rotary_emb(positions, q, k)
        o = self.attn(q, k, v)
        output = self.o_proj(o.flatten(1, -1))
        return output


class GLM4MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,  # GLM-4 has no MLP bias
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x = self.down_proj(x)
        return x


class GLM4DecoderLayer(nn.Module):

    def __init__(self, config) -> None:
        super().__init__()
        # GLM-4 config field mapping
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
        head_dim = getattr(config, 'kv_channels', hidden_size // num_heads)
        max_position = getattr(config, 'seq_length', 1048576)
        rope_ratio = getattr(config, 'rope_ratio', 1)
        rope_theta = 10000 * rope_ratio  # GLM-4 uses rope_ratio to scale base
        intermediate_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))
        rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)

        self.self_attn = GLM4Attention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_position=max_position,
            head_dim=head_dim,
            rope_theta=rope_theta,
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = GLM4MLP(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )
        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class GLM4Model(nn.Module):

    def __init__(self, config) -> None:
        super().__init__()
        vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
        num_layers = getattr(config, 'num_layers', config.num_hidden_layers)
        rms_norm_eps = getattr(config, 'layernorm_epsilon', 1e-5)

        self.embed_tokens = VocabParallelEmbedding(vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([GLM4DecoderLayer(config) for _ in range(num_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@register_model("ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMForCausalLM(nn.Module):
    """
    GLM-4 model for causal language modeling.

    Weight mapping from HuggingFace to nanovllm:
    - transformer.embedding.word_embeddings → model.embed_tokens
    - transformer.encoder.layers.X.input_layernorm → model.layers.X.input_layernorm
    - transformer.encoder.layers.X.self_attention.query_key_value → model.layers.X.self_attn.qkv_proj (split q/k/v)
    - transformer.encoder.layers.X.self_attention.dense → model.layers.X.self_attn.o_proj
    - transformer.encoder.layers.X.post_attention_layernorm → model.layers.X.post_attention_layernorm
    - transformer.encoder.layers.X.mlp.dense_h_to_4h → model.layers.X.mlp.gate_up_proj (split gate/up)
    - transformer.encoder.layers.X.mlp.dense_4h_to_h → model.layers.X.mlp.down_proj
    - transformer.encoder.final_layernorm → model.norm
    - transformer.output_layer → lm_head
    """
    packed_modules_mapping = {
        # QKV is merged in GLM-4 as query_key_value
        "query_key_value": ("qkv_proj", None),  # Special handling needed
        # MLP gate and up are merged as dense_h_to_4h
        "dense_h_to_4h": ("gate_up_proj", None),  # Special handling needed
    }

    # Weight name mapping for loader
    hf_to_nanovllm_mapping = {
        "transformer.embedding.word_embeddings": "model.embed_tokens",
        "transformer.encoder.final_layernorm": "model.norm",
        "transformer.output_layer": "lm_head",
    }

    def __init__(self, config) -> None:
        super().__init__()
        vocab_size = getattr(config, 'padded_vocab_size', config.vocab_size)
        self.config = config
        self.model = GLM4Model(config)
        self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
        # GLM-4 does not tie embeddings
        # if getattr(config, 'tie_word_embeddings', False):
        #     self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        return self.lm_head(hidden_states)
```
207
nanovllm/models/qwen2.py
Normal file
207
nanovllm/models/qwen2.py
Normal file
@@ -0,0 +1,207 @@
import torch
from torch import nn
import torch.distributed as dist
from transformers import Qwen2Config

from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
from nanovllm.models.registry import register_model


class Qwen2Attention(nn.Module):
    """Qwen2/2.5 Attention without QK norm (unlike Qwen3)."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rope_theta: float = 10000,
        rope_scaling: tuple | None = None,
    ) -> None:
        super().__init__()
        tp_size = dist.get_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim ** -0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,  # Qwen2/2.5 always uses bias
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        q, k = self.rotary_emb(positions, q, k)
        o = self.attn(q, k, v)
        output = self.o_proj(o.flatten(1, -1))
        return output


class Qwen2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x = self.down_proj(x)
        return x


class Qwen2DecoderLayer(nn.Module):

    def __init__(
        self,
        config: Qwen2Config,
    ) -> None:
        super().__init__()
        self.self_attn = Qwen2Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            head_dim=getattr(config, 'head_dim', None),
            rope_theta=getattr(config, "rope_theta", 1000000),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen2MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class Qwen2Model(nn.Module):

    def __init__(
        self,
        config: Qwen2Config,
    ) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(positions, hidden_states, residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@register_model("Qwen2ForCausalLM")
class Qwen2ForCausalLM(nn.Module):
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen2Config,
    ) -> None:
        super().__init__()
        self.model = Qwen2Model(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        if config.tie_word_embeddings:
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        return self.model(input_ids, positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        return self.lm_head(hidden_states)
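In `Qwen2Attention.forward`, the fused QKV output is carved up with `qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)`. A minimal sketch of that grouped-query sizing arithmetic — the head counts below are hypothetical, not taken from any particular model config:

```python
def qkv_split_sizes(num_heads: int, num_kv_heads: int, head_dim: int) -> list[int]:
    """Row sizes of the fused QKV projection output under grouped-query attention."""
    q_size = num_heads * head_dim       # all query heads
    kv_size = num_kv_heads * head_dim   # shared key/value heads (smaller under GQA)
    return [q_size, kv_size, kv_size]

# Hypothetical GQA shape: 28 query heads sharing 4 KV heads, head_dim 128.
print(qkv_split_sizes(28, 4, 128))  # [3584, 512, 512]
```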
@@ -187,7 +187,7 @@ class Qwen3Model(nn.Module):
         return hidden_states
 
 
-@register_model("Qwen3ForCausalLM", "Qwen2ForCausalLM")
+@register_model("Qwen3ForCausalLM")
 class Qwen3ForCausalLM(nn.Module):
     packed_modules_mapping = {
         "q_proj": ("qkv_proj", "q"),
@@ -1,4 +1,5 @@
 import os
+import re
 from glob import glob
 import torch
 from torch import nn
@@ -9,20 +10,146 @@ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    param.data.copy_(loaded_weight)


# GLM-4 weight name mappings
GLM4_NAME_MAPPING = {
    "transformer.embedding.word_embeddings": "model.embed_tokens",
    "transformer.encoder.final_layernorm": "model.norm",
    "transformer.output_layer": "lm_head",
}

GLM4_LAYER_MAPPING = {
    "self_attention.query_key_value": "self_attn.qkv_proj",
    "self_attention.dense": "self_attn.o_proj",
    "mlp.dense_h_to_4h": "mlp.gate_up_proj",
    "mlp.dense_4h_to_h": "mlp.down_proj",
}


def convert_glm4_weight_name(weight_name: str) -> tuple[str | None, str | None]:
    """
    Convert a GLM-4 weight name to nanovllm format.

    Returns:
        tuple: (converted_name, shard_id) where shard_id is used for packed modules.
        Returns (None, None) for weights that should be skipped.
    """
    # Skip rotary embedding weights (we use our own RoPE implementation)
    if "rotary_pos_emb" in weight_name:
        return None, None

    # Check direct mappings first
    for glm_name, nano_name in GLM4_NAME_MAPPING.items():
        if weight_name.startswith(glm_name):
            return weight_name.replace(glm_name, nano_name), None

    # Handle layer weights: transformer.encoder.layers.X.xxx
    layer_match = re.match(r"transformer\.encoder\.layers\.(\d+)\.(.+)", weight_name)
    if layer_match:
        layer_idx = layer_match.group(1)
        remainder = layer_match.group(2)

        # Handle packed modules (QKV and gate_up)
        for glm_subname, nano_subname in GLM4_LAYER_MAPPING.items():
            if remainder.startswith(glm_subname):
                suffix = remainder[len(glm_subname):]  # .weight or .bias
                new_name = f"model.layers.{layer_idx}.{nano_subname}{suffix}"

                # Determine shard_id for packed modules
                if "qkv_proj" in nano_subname:
                    return new_name, "qkv"  # Special marker for GLM-4 QKV
                elif "gate_up_proj" in nano_subname:
                    return new_name, "gate_up"  # Special marker for GLM-4 gate_up
                else:
                    return new_name, None

        # Handle non-packed layer weights (layernorms)
        new_name = f"model.layers.{layer_idx}.{remainder}"
        return new_name, None

    # No mapping found, return original
    return weight_name, None


def load_glm4_qkv(param: nn.Parameter, loaded_weight: torch.Tensor, config):
    """Load GLM-4 merged QKV weights by splitting into q, k, v."""
    num_heads = config.num_attention_heads
    num_kv_heads = getattr(config, 'multi_query_group_num', num_heads)
    head_dim = getattr(config, 'kv_channels', config.hidden_size // num_heads)

    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim

    # Split QKV: [q_size + kv_size + kv_size, hidden_size]
    q, k, v = loaded_weight.split([q_size, kv_size, kv_size], dim=0)

    # Load each part using the weight_loader
    weight_loader = getattr(param, "weight_loader")
    weight_loader(param, q, "q")
    weight_loader(param, k, "k")
    weight_loader(param, v, "v")


def load_glm4_gate_up(param: nn.Parameter, loaded_weight: torch.Tensor, config):
    """Load GLM-4 merged gate_up weights by splitting into gate, up."""
    ffn_hidden_size = getattr(config, 'ffn_hidden_size', getattr(config, 'intermediate_size', None))

    # Split gate_up: [ffn_hidden_size * 2, hidden_size]
    gate, up = loaded_weight.split([ffn_hidden_size, ffn_hidden_size], dim=0)

    # Load each part using the weight_loader
    weight_loader = getattr(param, "weight_loader")
    weight_loader(param, gate, 0)  # gate_proj is shard 0
    weight_loader(param, up, 1)    # up_proj is shard 1


def is_glm4_model(model: nn.Module) -> bool:
    """Check if the model is a GLM-4 model."""
    return model.__class__.__name__ in ("ChatGLMForCausalLM",)


def load_model(model: nn.Module, path: str):
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    is_glm4 = is_glm4_model(model)
    config = getattr(model, "config", None)

    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                loaded_weight = f.get_tensor(weight_name)

                # GLM-4 specific handling
                if is_glm4:
                    param_name, shard_id = convert_glm4_weight_name(weight_name)

                    # Skip weights that don't need to be loaded
                    if param_name is None:
                        continue

                    if shard_id == "qkv":
                        param = model.get_parameter(param_name)
                        load_glm4_qkv(param, loaded_weight, config)
                        continue
                    elif shard_id == "gate_up":
                        param = model.get_parameter(param_name)
                        load_glm4_gate_up(param, loaded_weight, config)
                        continue
                    else:
                        # Regular weight, use converted name
                        param = model.get_parameter(param_name)
                        weight_loader = getattr(param, "weight_loader", default_weight_loader)
                        weight_loader(param, loaded_weight)
                        continue

                # Original loading logic for other models
                for k in packed_modules_mapping:
                    if k in weight_name:
                        v, shard_id = packed_modules_mapping[k]
                        param_name = weight_name.replace(k, v)
                        param = model.get_parameter(param_name)
                        weight_loader = getattr(param, "weight_loader")
-                        weight_loader(param, f.get_tensor(weight_name), shard_id)
+                        weight_loader(param, loaded_weight, shard_id)
                        break
                else:
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                    weight_loader(param, f.get_tensor(weight_name))
+                    weight_loader(param, loaded_weight)
@@ -48,6 +48,62 @@ from nanovllm import LLM, SamplingParams
# ============================================================

DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_64k"


# ============================================================
# Chat Template Conversion
# ============================================================

def convert_llama_to_glm4_format(prompt: str) -> str:
    """
    Convert Llama 3 chat template format to GLM-4 format.

    Llama 3 format:
        <|begin_of_text|><|start_header_id|>user<|end_header_id|>

        {user_content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

        {assistant_prefix}

    GLM-4 format:
        [gMASK]<sop><|user|>
        {user_content}<|assistant|>
        {assistant_prefix}
    """
    # Split into user content and assistant prefix
    parts = prompt.split("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")

    # Extract user content (remove Llama header tokens)
    user_content = parts[0]
    user_content = user_content.replace("<|begin_of_text|>", "")
    user_content = user_content.replace("<|start_header_id|>user<|end_header_id|>", "")
    user_content = user_content.strip()

    # Extract assistant prefix (if it exists)
    assistant_prefix = ""
    if len(parts) > 1:
        assistant_prefix = parts[1].replace("<|eot_id|>", "").strip()

    # Apply GLM-4 format
    glm_prompt = f"[gMASK]<sop><|user|>\n{user_content}<|assistant|>"
    if assistant_prefix:
        glm_prompt += f"\n{assistant_prefix}"

    return glm_prompt


def is_glm_model(model_path: str) -> bool:
    """Check if the model is a GLM model based on its config."""
    from transformers import AutoConfig
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    return getattr(config, 'model_type', '') == 'chatglm'


def convert_prompt_for_model(prompt: str, model_path: str) -> str:
    """Convert prompt format based on model type."""
    if is_glm_model(model_path):
        return convert_llama_to_glm4_format(prompt)
    return prompt  # Keep original format for Llama and other models


DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
# Note: max_model_len must be > max_input_len to leave room for output tokens
# The 64k benchmark has inputs up to 65536 tokens, so we need 65536 + 128 = 65664
@@ -161,6 +217,7 @@ def run_task_test(
    verbose: bool = True,
    llm_factory: Optional[callable] = None,
    fresh_llm: bool = False,
    model_path: Optional[str] = None,
) -> Dict:
    """
    Run test for a single RULER task.
@@ -198,6 +255,9 @@ def run_task_test(
    for sample in samples:
        idx = sample.get("index", sample["_local_idx"])
        prompt = sample["input"]
        # Convert prompt format for GLM models
        if model_path:
            prompt = convert_prompt_for_model(prompt, model_path)
        expected = sample["outputs"]

        # Fresh LLM mode: reinitialize for each sample
@@ -367,6 +427,7 @@ def run_ruler_benchmark(
        verbose=verbose and not json_output,
        llm_factory=create_llm,
        fresh_llm=fresh_llm,
        model_path=model_path,
    )
    task_results.append(result)
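As a quick check of the template rewrite in the benchmark hunk above, the pure string transformation can be exercised standalone (a mirror of `convert_llama_to_glm4_format`; the sample prompt text is made up for illustration):

```python
def llama3_to_glm4(prompt: str) -> str:
    """Standalone mirror of convert_llama_to_glm4_format for testing."""
    parts = prompt.split("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")
    # Strip the Llama 3 header tokens from the user turn
    user = (parts[0]
            .replace("<|begin_of_text|>", "")
            .replace("<|start_header_id|>user<|end_header_id|>", "")
            .strip())
    out = f"[gMASK]<sop><|user|>\n{user}<|assistant|>"
    # Preserve any pre-filled assistant prefix
    if len(parts) > 1:
        prefix = parts[1].replace("<|eot_id|>", "").strip()
        if prefix:
            out += f"\n{prefix}"
    return out

llama_prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "What is 2+2?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
print(llama3_to_glm4(llama_prompt))
# [gMASK]<sop><|user|>
# What is 2+2?<|assistant|>
```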