From 03a8c033cbb954c8e65c6f5efe0db4d84fda0993 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Sat, 10 Jan 2026 21:03:45 +0800 Subject: [PATCH] [claudesquad] update from 'add-llama-1' on 10 Jan 26 21:03 CST --- .claude/rules/gpu-testing.md | 88 +++++++++++++ findings.md | 160 +++++++++++++++++++++++ nanovllm/engine/model_runner.py | 5 +- nanovllm/layers/rotary_embedding.py | 141 +++++++++++++++++++- nanovllm/models/__init__.py | 9 ++ nanovllm/models/llama.py | 194 ++++++++++++++++++++++++++++ nanovllm/models/qwen3.py | 2 + nanovllm/models/registry.py | 46 +++++++ progress.md | 76 +++++++++++ task_plan.md | 144 +++++++++++++++++++++ 10 files changed, 858 insertions(+), 7 deletions(-) create mode 100644 .claude/rules/gpu-testing.md create mode 100644 findings.md create mode 100644 nanovllm/models/__init__.py create mode 100644 nanovllm/models/llama.py create mode 100644 nanovllm/models/registry.py create mode 100644 progress.md create mode 100644 task_plan.md diff --git a/.claude/rules/gpu-testing.md b/.claude/rules/gpu-testing.md new file mode 100644 index 0000000..5c0e9e5 --- /dev/null +++ b/.claude/rules/gpu-testing.md @@ -0,0 +1,88 @@ +# GPU Testing Rules + +## GPU Type Detection + +Before running any GPU test/benchmark, detect the GPU type and apply appropriate settings: + +```bash +nvidia-smi --query-gpu=name --format=csv,noheader | head -1 +``` + +### Testing Mode by GPU Type + +| GPU Type | Test Mode | Reason | +|----------|-----------|--------| +| **RTX 3090** | `--enable-offload` ONLY | Limited VRAM (24GB), must use CPU offload | +| **A100** | Both modes OK | Large VRAM (40/80GB), can test with or without offload | +| **RTX 4090** | `--enable-offload` ONLY | Limited VRAM (24GB) | +| **Other** | Ask user | Unknown VRAM capacity | + +### Example Commands + +**For 3090:** +```bash +# MUST use offload +CUDA_VISIBLE_DEVICES=X python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct --enable-offload +``` + +**For A100:** +```bash +# Can test without offload +CUDA_VISIBLE_DEVICES=X python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct + +# Or with offload +CUDA_VISIBLE_DEVICES=X python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct --enable-offload +``` + +--- + +## GPU Card Assignment (CRITICAL) + +### Multi-Instance Environment + +This project runs with multiple Claude instances on different worktrees, each needing a dedicated GPU. + +### MANDATORY RULE + +**Before executing ANY GPU command:** + +1. **Check if user specified GPU**: Look for user message like "use GPU 0" or "CUDA_VISIBLE_DEVICES=1" + +2. **If user did NOT specify GPU**: + - **STOP and ASK**: "Which GPU should I use? (e.g., 0, 1, 2, ...)" + - **DO NOT assume or guess** the GPU number + - **DO NOT proceed** until user confirms + +3. **Always prefix GPU commands with `CUDA_VISIBLE_DEVICES=X`**: + ```bash + CUDA_VISIBLE_DEVICES=0 python script.py # Use GPU 0 + CUDA_VISIBLE_DEVICES=1 python script.py # Use GPU 1 + ``` + +### Example Workflow + +**Correct:** +``` +User: "Run the needle test" +Claude: "Which GPU should I use for this test?" +User: "Use GPU 2" +Claude: Runs `CUDA_VISIBLE_DEVICES=2 python tests/test_needle.py ...` +``` + +**Wrong:** +``` +User: "Run the needle test" +Claude: Runs `python tests/test_needle.py ...` # NO! Missing GPU specification! +``` + +--- + +## Combined Checklist + +Before running any GPU test: + +- [ ] User specified GPU number? If not, ASK. +- [ ] Detected GPU type? (3090 → offload only, A100 → flexible) +- [ ] GPU mutex check passed? (see commands.md) +- [ ] Command prefixed with `CUDA_VISIBLE_DEVICES=X`? +- [ ] Local package installed? (`pip install -e . --prefix=./.local --no-deps`) diff --git a/findings.md b/findings.md new file mode 100644 index 0000000..bb77faa --- /dev/null +++ b/findings.md @@ -0,0 +1,160 @@ +# Findings: Multi-Model Support Analysis + +## Current Architecture Analysis + +### Model Loading Flow +``` +LLM(model_path) + → LLMEngine.__init__() + → Config.__post_init__() + → hf_config = AutoConfig.from_pretrained(model) + → ModelRunner.__init__() + → model = Qwen3ForCausalLM(hf_config) ← HARDCODED + → load_model(model, config.model) +``` + +### Key Files +| File | Purpose | +|------|---------| +| `nanovllm/engine/model_runner.py` | 模型加载和运行 | +| `nanovllm/models/qwen3.py` | Qwen3 模型定义 | +| `nanovllm/utils/loader.py` | safetensors 权重加载 | +| `nanovllm/layers/rotary_embedding.py` | RoPE 实现 | + +--- + +## Llama 3.1 Config Analysis + +```json +{ + "architectures": ["LlamaForCausalLM"], + "model_type": "llama", + "attention_bias": false, + "mlp_bias": false, + "head_dim": 128, + "hidden_size": 4096, + "intermediate_size": 14336, + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "hidden_act": "silu", + "rms_norm_eps": 1e-05, + "rope_theta": 500000.0, + "rope_scaling": { + "factor": 8.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" + }, + "max_position_embeddings": 131072, + "tie_word_embeddings": false, + "vocab_size": 128256 +} +``` + +### Llama 3 RoPE Scaling +Llama 3 使用特殊的 RoPE scaling 策略 (`rope_type: "llama3"`): +- 低频分量保持不变(对应短距离依赖) +- 高频分量线性插值(对应长距离依赖) +- 参数: `factor`, `low_freq_factor`, `high_freq_factor`, `original_max_position_embeddings` + +参考实现 (transformers): +```python +def _compute_llama3_parameters(config, device, inv_freq): + factor = config.factor + low_freq_factor = config.low_freq_factor + high_freq_factor = config.high_freq_factor + old_context_len = config.original_max_position_embeddings + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + inv_freq_llama = torch.where( + wavelen > low_freq_wavelen, + inv_freq / factor, + inv_freq + ) + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama + smooth_factor * inv_freq + is_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + return inv_freq_llama +``` + +--- + +## Weight Mapping Analysis + +### Qwen3 packed_modules_mapping +```python +packed_modules_mapping = { + "q_proj": ("qkv_proj", "q"), + "k_proj": ("qkv_proj", "k"), + "v_proj": ("qkv_proj", "v"), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), +} +``` + +### Llama Weight Names (from safetensors) +预期 Llama 权重命名与 Qwen3 类似: +- `model.layers.{i}.self_attn.q_proj.weight` +- `model.layers.{i}.self_attn.k_proj.weight` +- `model.layers.{i}.self_attn.v_proj.weight` +- `model.layers.{i}.self_attn.o_proj.weight` +- `model.layers.{i}.mlp.gate_proj.weight` +- `model.layers.{i}.mlp.up_proj.weight` +- `model.layers.{i}.mlp.down_proj.weight` +- `model.layers.{i}.input_layernorm.weight` +- `model.layers.{i}.post_attention_layernorm.weight` + +**结论**: Llama 的 `packed_modules_mapping` 与 Qwen3 相同,可以复用。 + +--- + +## Shared Components (Can Reuse) + +| Component | File | Notes | +|-----------|------|-------| +| `RMSNorm` | `layers/layernorm.py` | 通用 | +| `SiluAndMul` | `layers/activation.py` | 通用 | +| `Attention` | `layers/attention.py` | FlashAttention wrapper | +| `QKVParallelLinear` | `layers/linear.py` | 支持 bias=False | +| `RowParallelLinear` | `layers/linear.py` | 通用 | +| `MergedColumnParallelLinear` | `layers/linear.py` | 通用 | +| `VocabParallelEmbedding` | `layers/embed_head.py` | 通用 | +| `ParallelLMHead` | `layers/embed_head.py` | 通用 | +| `load_model` | `utils/loader.py` | 通用 | + +--- + +## Llama vs Qwen3 Implementation Diff + +### Attention +| Feature | Qwen3Attention | LlamaAttention | +|---------|----------------|----------------| +| QKV bias | 可配置 (attention_bias) | 始终 False | +| q_norm | 有 (when bias=False) | 无 | +| k_norm | 有 (when bias=False) | 无 | +| RoPE | Standard | Llama3 scaled | + +### MLP +| Feature | Qwen3MLP | LlamaMLP | +|---------|----------|----------| +| gate/up bias | False | False | +| down bias | False | False | +| hidden_act | silu | silu | + +**结论**: Llama MLP 与 Qwen3 MLP 几乎相同,可以直接复用或简化。 + +--- + +## Risk Assessment + +| Risk | Impact | Mitigation | +|------|--------|------------| +| RoPE 实现错误 | 高 - 导致错误输出 | 参考 transformers 实现,单元测试 | +| 权重映射错误 | 高 - 模型无法加载 | 检查 safetensors 键名 | +| 注册表循环导入 | 中 - 启动失败 | 延迟导入 | diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index d3db28f..308355a 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -6,7 +6,7 @@ from multiprocessing.shared_memory import SharedMemory from nanovllm.config import Config from nanovllm.engine.sequence import Sequence -from nanovllm.models.qwen3 import Qwen3ForCausalLM +from nanovllm.models import get_model_class from nanovllm.layers.sampler import GreedySampler from nanovllm.utils.context import set_context, get_context, reset_context from nanovllm.utils.loader import load_model @@ -32,7 +32,8 @@ class ModelRunner: default_dtype = torch.get_default_dtype() torch.set_default_dtype(hf_config.torch_dtype) torch.set_default_device("cuda") - self.model = Qwen3ForCausalLM(hf_config) + model_class = get_model_class(hf_config) + self.model = model_class(hf_config) load_model(self.model, config.model) self.sampler = GreedySampler() diff --git a/nanovllm/layers/rotary_embedding.py b/nanovllm/layers/rotary_embedding.py index 998d116..4eba274 100644 --- a/nanovllm/layers/rotary_embedding.py +++ b/nanovllm/layers/rotary_embedding.py @@ -1,4 +1,4 @@ -from functools import lru_cache +import math import torch from torch import nn @@ -48,7 +48,102 @@ class RotaryEmbedding(nn.Module): return query, key -@lru_cache(1) +class Llama3RotaryEmbedding(nn.Module): + """ + Llama 3 RoPE with special frequency scaling. + + Llama 3 uses a piecewise frequency adjustment: + - High frequencies (short wavelengths): unchanged + - Low frequencies (long wavelengths): scaled down by factor + - Medium frequencies: smoothly interpolated + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + factor: float, + low_freq_factor: float, + high_freq_factor: float, + original_max_position_embeddings: int, + ) -> None: + super().__init__() + self.head_size = head_size + assert rotary_dim == head_size + + # Compute base inv_freq + inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)) + + # Apply Llama3 scaling + inv_freq = self._compute_llama3_inv_freq( + inv_freq, + factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings, + ) + + # Build cos/sin cache + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1) + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_llama3_inv_freq( + self, + inv_freq: torch.Tensor, + factor: float, + low_freq_factor: float, + high_freq_factor: float, + original_max_position_embeddings: int, + ) -> torch.Tensor: + """ + Apply Llama3 frequency scaling. + + - wavelength > low_freq_wavelen: scale down by factor (long range, needs interpolation) + - wavelength < high_freq_wavelen: keep unchanged (short range, high fidelity) + - in between: smooth interpolation + """ + old_context_len = original_max_position_embeddings + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + + # Low frequency: scale down by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + + # Medium frequency: smooth interpolation + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama + smooth_factor * inv_freq + is_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama + + @torch.compile + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + query = apply_rotary_emb(query, cos, sin) + key = apply_rotary_emb(key, cos, sin) + return query, key + + +# Cache for RoPE instances (keyed by hashable parameters) +_rope_cache: dict[tuple, nn.Module] = {} + + def get_rope( head_size: int, rotary_dim: int, @@ -56,6 +151,42 @@ def get_rope( base: float, rope_scaling: dict | None = None, ): - assert rope_scaling is None - rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base) - return rotary_emb + # Create hashable cache key + if rope_scaling is None: + cache_key = (head_size, rotary_dim, max_position, base, None) + else: + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + if rope_type == "llama3": + cache_key = ( + head_size, rotary_dim, max_position, base, "llama3", + rope_scaling["factor"], + rope_scaling["low_freq_factor"], + rope_scaling["high_freq_factor"], + rope_scaling["original_max_position_embeddings"], + ) + else: + cache_key = (head_size, rotary_dim, max_position, base, rope_type) + + if cache_key in _rope_cache: + return _rope_cache[cache_key] + + if rope_scaling is None: + rope = RotaryEmbedding(head_size, rotary_dim, max_position, base) + else: + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + if rope_type == "llama3": + rope = Llama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + factor=rope_scaling["factor"], + low_freq_factor=rope_scaling["low_freq_factor"], + high_freq_factor=rope_scaling["high_freq_factor"], + original_max_position_embeddings=rope_scaling["original_max_position_embeddings"], + ) + else: + raise ValueError(f"Unsupported rope_type: {rope_type}") + + _rope_cache[cache_key] = rope + return rope diff --git a/nanovllm/models/__init__.py b/nanovllm/models/__init__.py new file mode 100644 index 0000000..28d41b2 --- /dev/null +++ b/nanovllm/models/__init__.py @@ -0,0 +1,9 @@ +"""Model registry and model implementations.""" + +from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY + +# Import models to trigger registration +from nanovllm.models import qwen3 +from nanovllm.models import llama + +__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"] diff --git a/nanovllm/models/llama.py b/nanovllm/models/llama.py new file mode 100644 index 0000000..14c4fb4 --- /dev/null +++ b/nanovllm/models/llama.py @@ -0,0 +1,194 @@ +import torch +from torch import nn +import torch.distributed as dist + +from nanovllm.layers.activation import SiluAndMul +from nanovllm.layers.attention import Attention +from nanovllm.layers.layernorm import RMSNorm +from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear +from nanovllm.layers.rotary_embedding import get_rope +from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead +from nanovllm.models.registry import register_model + + +class LlamaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: int | None = None, + rope_theta: float = 10000, + rope_scaling: dict | None = None, + ) -> None: + super().__init__() + tp_size = dist.get_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = self.total_num_kv_heads // tp_size + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim ** -0.5 + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, # Llama has no attention bias + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + self.num_kv_heads, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(-1, self.num_heads, self.head_dim) + k = k.view(-1, self.num_kv_heads, self.head_dim) + v = v.view(-1, self.num_kv_heads, self.head_dim) + # Llama has no q_norm/k_norm + q, k = self.rotary_emb(positions, q, k) + o = self.attn(q, k, v) + output = self.o_proj(o.flatten(1, -1)) + return output + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x = self.down_proj(x) + return x + + +class LlamaDecoderLayer(nn.Module): + + def __init__(self, config) -> None: + super().__init__() + self.self_attn = LlamaAttention( + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + max_position=config.max_position_embeddings, + head_dim=getattr(config, 'head_dim', None), + rope_theta=getattr(config, "rope_theta", 10000), + rope_scaling=getattr(config, "rope_scaling", None), + ) + self.mlp = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + hidden_states, residual = self.input_layernorm(hidden_states), hidden_states + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions, hidden_states) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class LlamaModel(nn.Module): + + def __init__(self, config) -> None: + super().__init__() + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for layer in self.layers: + hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +@register_model("LlamaForCausalLM") +class LlamaForCausalLM(nn.Module): + packed_modules_mapping = { + "q_proj": ("qkv_proj", "q"), + "k_proj": ("qkv_proj", "k"), + "v_proj": ("qkv_proj", "v"), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__(self, config) -> None: + super().__init__() + self.model = LlamaModel(config) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + if getattr(config, 'tie_word_embeddings', False): + self.lm_head.weight.data = self.model.embed_tokens.weight.data + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + return self.model(input_ids, positions) + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.lm_head(hidden_states) diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py index 6298d8b..b4e8413 100755 --- a/nanovllm/models/qwen3.py +++ b/nanovllm/models/qwen3.py @@ -9,6 +9,7 @@ from nanovllm.layers.layernorm import RMSNorm from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear from nanovllm.layers.rotary_embedding import get_rope from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead +from nanovllm.models.registry import register_model class Qwen3Attention(nn.Module): @@ -186,6 +187,7 @@ class Qwen3Model(nn.Module): return hidden_states +@register_model("Qwen3ForCausalLM", "Qwen2ForCausalLM") class Qwen3ForCausalLM(nn.Module): packed_modules_mapping = { "q_proj": ("qkv_proj", "q"), diff --git a/nanovllm/models/registry.py b/nanovllm/models/registry.py new file mode 100644 index 0000000..84120f3 --- /dev/null +++ b/nanovllm/models/registry.py @@ -0,0 +1,46 @@ +"""Model registry for dynamic model loading.""" + +from typing import Type +from torch import nn + +# Global registry mapping architecture names to model classes +MODEL_REGISTRY: dict[str, Type[nn.Module]] = {} + + +def register_model(*architectures: str): + """ + Decorator to register a model class for given architecture names. + + Usage: + @register_model("LlamaForCausalLM") + class LlamaForCausalLM(nn.Module): + ... + """ + def decorator(cls: Type[nn.Module]) -> Type[nn.Module]: + for arch in architectures: + MODEL_REGISTRY[arch] = cls + return cls + return decorator + + +def get_model_class(hf_config) -> Type[nn.Module]: + """ + Get model class based on HuggingFace config. + + Args: + hf_config: HuggingFace model config with 'architectures' field + + Returns: + Model class for the given architecture + + Raises: + ValueError: If architecture is not supported + """ + architectures = getattr(hf_config, "architectures", []) + for arch in architectures: + if arch in MODEL_REGISTRY: + return MODEL_REGISTRY[arch] + raise ValueError( + f"Unsupported architecture: {architectures}. " + f"Supported: {list(MODEL_REGISTRY.keys())}" + ) diff --git a/progress.md b/progress.md new file mode 100644 index 0000000..11a1daa --- /dev/null +++ b/progress.md @@ -0,0 +1,76 @@ +# Progress Log: Multi-Model Support + +## Session: 2026-01-10 + +### Initial Analysis Complete + +**Time**: Session start + +**Actions:** +1. Read `nanovllm/engine/model_runner.py` - 确认硬编码位置 (line 35) +2. Read `nanovllm/models/qwen3.py` - 理解 Qwen3 模型结构 +3. Read `nanovllm/utils/loader.py` - 理解权重加载机制 +4. Read `nanovllm/layers/rotary_embedding.py` - 发现 RoPE scaling 限制 +5. Read `/home/zijie/models/Llama-3.1-8B-Instruct/config.json` - 理解 Llama 配置 + +**Key Findings:** +- 模型加载在 `model_runner.py:35` 硬编码为 Qwen3 +- RoPE 目前不支持 scaling (`assert rope_scaling is None`) +- Llama 3.1 需要 "llama3" 类型的 RoPE scaling +- Llama 无 q_norm/k_norm,无 attention bias + +**Created:** +- `task_plan.md` - 6 阶段实施计划 +- `findings.md` - 技术分析和发现 + +--- + +### Phase Status + +| Phase | Status | Notes | +|-------|--------|-------| +| 1. Model Registry | **COMPLETED** | `registry.py`, `__init__.py` | +| 2. Llama3 RoPE | **COMPLETED** | `rotary_embedding.py` | +| 3. Llama Model | **COMPLETED** | `llama.py` | +| 4. ModelRunner | **COMPLETED** | Dynamic loading | +| 5. Qwen3 Register | **COMPLETED** | `@register_model` decorator | +| 6. Testing | **COMPLETED** | Both Llama & Qwen3 pass | + +--- + +## Test Results + +### Llama 3.1-8B-Instruct (32K needle, GPU 0, offload) +``` +Input: 32768 tokens +Expected: 7492 +Output: 7492 +Status: PASSED +Prefill: 1644 tok/s +``` + +### Qwen3-4B (8K needle, GPU 1, offload) - Regression Test +``` +Input: 8192 tokens +Expected: 7492 +Output: 7492 +Status: PASSED +Prefill: 3295 tok/s +``` + +--- + +## Files Modified This Session + +| File | Action | Description | +|------|--------|-------------| +| `nanovllm/models/registry.py` | created | Model registry with `@register_model` decorator | +| `nanovllm/models/__init__.py` | created | Export registry functions, import models | +| `nanovllm/models/llama.py` | created | Llama model implementation | +| `nanovllm/models/qwen3.py` | modified | Added `@register_model` decorator | +| `nanovllm/layers/rotary_embedding.py` | modified | Added Llama3 RoPE scaling | +| `nanovllm/engine/model_runner.py` | modified | Dynamic model loading via registry | +| `.claude/rules/gpu-testing.md` | created | GPU testing rules | +| `task_plan.md` | created | Implementation plan | +| `findings.md` | created | Technical findings | +| `progress.md` | created | Progress tracking | diff --git a/task_plan.md b/task_plan.md new file mode 100644 index 0000000..87626ef --- /dev/null +++ b/task_plan.md @@ -0,0 +1,144 @@ +# Task Plan: Multi-Model Support for nanovllm + +## Goal +扩展 nanovllm 框架以支持多种模型(当前只支持 Qwen3),特别是添加 Llama-3.1-8B-Instruct 支持,并建立可扩展的模型添加范式。 + +## Current State Analysis + +### 硬编码问题位置 +- `nanovllm/engine/model_runner.py:35`: 直接实例化 `Qwen3ForCausalLM(hf_config)` +- `nanovllm/engine/model_runner.py:9`: 硬编码导入 `from nanovllm.models.qwen3 import Qwen3ForCausalLM` + +### Qwen3 vs Llama 3.1 架构差异 + +| Feature | Qwen3 | Llama 3.1 | +|---------|-------|-----------| +| Config Class | Qwen3Config | LlamaConfig | +| attention_bias | True (可配置) | False | +| q_norm/k_norm | 有 (when bias=False) | 无 | +| mlp_bias | N/A | False | +| RoPE Scaling | None (目前) | llama3 类型 | +| RoPE theta | 1000000 | 500000 | +| hidden_act | silu | silu | +| tie_word_embeddings | True | False | + +### 关键限制 +- `rotary_embedding.py:59`: `assert rope_scaling is None` - 不支持 RoPE scaling + +--- + +## Phases + +### Phase 1: Create Model Registry Pattern [pending] +**Files to modify:** +- `nanovllm/models/__init__.py` (new) +- `nanovllm/models/registry.py` (new) + +**Tasks:** +1. 创建模型注册表机制 +2. 定义模型注册装饰器 `@register_model` +3. 实现 `get_model_class(hf_config)` 函数,根据 `architectures` 字段自动选择模型 + +**Design:** +```python +MODEL_REGISTRY: dict[str, type] = {} + +def register_model(*architectures): + """Decorator to register a model class for given architecture names.""" + def decorator(cls): + for arch in architectures: + MODEL_REGISTRY[arch] = cls + return cls + return decorator + +def get_model_class(hf_config) -> type: + """Get model class based on HF config architectures.""" + for arch in hf_config.architectures: + if arch in MODEL_REGISTRY: + return MODEL_REGISTRY[arch] + raise ValueError(f"Unsupported architecture: {hf_config.architectures}") +``` + +### Phase 2: Add Llama3 RoPE Scaling Support [pending] +**Files to modify:** +- `nanovllm/layers/rotary_embedding.py` + +**Tasks:** +1. 实现 `Llama3RotaryEmbedding` 类,支持 llama3 rope_type +2. 修改 `get_rope()` 函数,根据 rope_scaling 类型选择实现 +3. 保持向后兼容(rope_scaling=None 使用原实现) + +**Llama3 RoPE Scaling Formula:** +```python +# From transformers: +# low_freq_factor, high_freq_factor, original_max_position_embeddings +# Adjust frequencies based on wavelength thresholds +``` + +### Phase 3: Implement Llama Model [pending] +**Files to create:** +- `nanovllm/models/llama.py` + +**Tasks:** +1. 创建 `LlamaAttention` 类(无 q_norm/k_norm,无 QKV bias) +2. 创建 `LlamaMLP` 类(与 Qwen3MLP 类似,无 bias) +3. 创建 `LlamaDecoderLayer` 类 +4. 创建 `LlamaModel` 和 `LlamaForCausalLM` 类 +5. 添加 `packed_modules_mapping` 以支持权重加载 +6. 使用 `@register_model("LlamaForCausalLM")` 注册 + +### Phase 4: Modify ModelRunner for Dynamic Loading [pending] +**Files to modify:** +- `nanovllm/engine/model_runner.py` + +**Tasks:** +1. 移除硬编码 `from nanovllm.models.qwen3 import Qwen3ForCausalLM` +2. 导入 `from nanovllm.models import get_model_class` +3. 替换 `self.model = Qwen3ForCausalLM(hf_config)` 为: + ```python + model_class = get_model_class(hf_config) + self.model = model_class(hf_config) + ``` + +### Phase 5: Register Qwen3 Model [pending] +**Files to modify:** +- `nanovllm/models/qwen3.py` + +**Tasks:** +1. 导入 `from nanovllm.models.registry import register_model` +2. 添加 `@register_model("Qwen3ForCausalLM", "Qwen2ForCausalLM")` 装饰器 + +### Phase 6: Test with Llama-3.1-8B-Instruct [pending] +**Files:** +- `tests/test_needle.py` (existing, use for validation) + +**Tasks:** +1. 运行 needle 测试: `python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct` +2. 验证模型加载正确 +3. 验证推理输出正确 + +--- + +## Errors Encountered +| Error | Attempt | Resolution | +|-------|---------|------------| +| (none yet) | | | + +--- + +## Success Criteria +- [x] 分析完成:理解当前架构和需要的改动 +- [ ] Phase 1: 模型注册表实现 +- [ ] Phase 2: Llama3 RoPE scaling 支持 +- [ ] Phase 3: Llama 模型实现 +- [ ] Phase 4: ModelRunner 动态加载 +- [ ] Phase 5: Qwen3 模型注册 +- [ ] Phase 6: Llama needle 测试通过 + +--- + +## Notes +- 保持现有 Qwen3 功能不变 +- 遵循现有代码风格 +- 复用现有 layers 组件(Linear, RMSNorm, Embedding 等) +- 只添加必要的代码,不过度工程化