Merge branch 'zijie/add-llama-1': Add multi-model support

- Add model registry system for dynamic model loading
- Implement LlamaForCausalLM with Llama3 RoPE scaling
- Register Qwen3ForCausalLM and Qwen2ForCausalLM
- Update ModelRunner to use get_model_class() for dynamic model selection

Tested: needle 32k test PASSED

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Zijie Tian
2026-01-10 21:20:53 +08:00
10 changed files with 947 additions and 7 deletions
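For orientation, this is the control flow the commit introduces, condensed from the diffs below (a sketch only; hf_config and config are the objects already available inside ModelRunner.__init__):

    from nanovllm.models import get_model_class

    model_class = get_model_class(hf_config)   # e.g. LlamaForCausalLM when architectures=["LlamaForCausalLM"]
    model = model_class(hf_config)             # replaces the hard-coded Qwen3ForCausalLM(hf_config)
    load_model(model, config.model)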

nanovllm/engine/model_runner.py
View File

@@ -6,7 +6,7 @@ from multiprocessing.shared_memory import SharedMemory
from nanovllm.config import Config, SparsePolicyType
from nanovllm.engine.sequence import Sequence
-from nanovllm.models.qwen3 import Qwen3ForCausalLM
+from nanovllm.models import get_model_class
from nanovllm.layers.sampler import GreedySampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
@@ -32,7 +32,8 @@ class ModelRunner:
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_device("cuda")
-        self.model = Qwen3ForCausalLM(hf_config)
+        model_class = get_model_class(hf_config)
+        self.model = model_class(hf_config)
load_model(self.model, config.model)
self.sampler = GreedySampler()

nanovllm/layers/rotary_embedding.py
View File

@@ -1,4 +1,4 @@
from functools import lru_cache
import math
import torch
from torch import nn
@@ -48,7 +48,102 @@ class RotaryEmbedding(nn.Module):
return query, key
@lru_cache(1)
class Llama3RotaryEmbedding(nn.Module):
"""
Llama 3 RoPE with special frequency scaling.
Llama 3 uses a piecewise frequency adjustment:
- High frequencies (short wavelengths): unchanged
- Low frequencies (long wavelengths): scaled down by factor
- Medium frequencies: smoothly interpolated
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
factor: float,
low_freq_factor: float,
high_freq_factor: float,
original_max_position_embeddings: int,
) -> None:
super().__init__()
self.head_size = head_size
assert rotary_dim == head_size
# Compute base inv_freq
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
# Apply Llama3 scaling
inv_freq = self._compute_llama3_inv_freq(
inv_freq,
factor,
low_freq_factor,
high_freq_factor,
original_max_position_embeddings,
)
# Build cos/sin cache
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_llama3_inv_freq(
self,
inv_freq: torch.Tensor,
factor: float,
low_freq_factor: float,
high_freq_factor: float,
original_max_position_embeddings: int,
) -> torch.Tensor:
"""
Apply Llama3 frequency scaling.
- wavelength > low_freq_wavelen: scale down by factor (long range, needs interpolation)
- wavelength < high_freq_wavelen: keep unchanged (short range, high fidelity)
- in between: smooth interpolation
"""
old_context_len = original_max_position_embeddings
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * math.pi / inv_freq
# Low frequency: scale down by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# Medium frequency: smooth interpolation
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        # Blend between the scaled (inv_freq / factor) and unscaled branches
        smoothed_inv_freq = (1 - smooth_factor) * inv_freq / factor + smooth_factor * inv_freq
is_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
return inv_freq_llama
@torch.compile
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query = apply_rotary_emb(query, cos, sin)
key = apply_rotary_emb(key, cos, sin)
return query, key
# Cache for RoPE instances (keyed by hashable parameters)
_rope_cache: dict[tuple, nn.Module] = {}
def get_rope(
head_size: int,
rotary_dim: int,
@@ -56,6 +151,42 @@ def get_rope(
base: float,
rope_scaling: dict | None = None,
):
-    assert rope_scaling is None
-    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
-    return rotary_emb
# Create hashable cache key
if rope_scaling is None:
cache_key = (head_size, rotary_dim, max_position, base, None)
else:
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
if rope_type == "llama3":
cache_key = (
head_size, rotary_dim, max_position, base, "llama3",
rope_scaling["factor"],
rope_scaling["low_freq_factor"],
rope_scaling["high_freq_factor"],
rope_scaling["original_max_position_embeddings"],
)
else:
cache_key = (head_size, rotary_dim, max_position, base, rope_type)
if cache_key in _rope_cache:
return _rope_cache[cache_key]
if rope_scaling is None:
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
else:
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
if rope_type == "llama3":
rope = Llama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
factor=rope_scaling["factor"],
low_freq_factor=rope_scaling["low_freq_factor"],
high_freq_factor=rope_scaling["high_freq_factor"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
else:
raise ValueError(f"Unsupported rope_type: {rope_type}")
_rope_cache[cache_key] = rope
return rope
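A quick usage sketch for the new get_rope path (the numeric values here are illustrative, not taken from any particular checkpoint; real callers pass values from the HF config, as llama.py does below):

    scaling = {
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    }
    rope_a = get_rope(128, 128, 131072, 500000.0, scaling)
    rope_b = get_rope(128, 128, 131072, 500000.0, scaling)
    assert rope_a is rope_b                             # served from _rope_cache; the dict itself is never hashed
    plain = get_rope(128, 128, 32768, 10000.0, None)    # falls back to the unscaled RotaryEmbedding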

nanovllm/models/__init__.py
View File

@@ -0,0 +1,9 @@
"""Model registry and model implementations."""
from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY
# Import models to trigger registration
from nanovllm.models import qwen3
from nanovllm.models import llama
__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]

nanovllm/models/llama.py (new file, 194 lines)
View File

@@ -0,0 +1,194 @@
import torch
from torch import nn
import torch.distributed as dist
from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
from nanovllm.models.registry import register_model
class LlamaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
head_dim: int | None = None,
rope_theta: float = 10000,
rope_scaling: dict | None = None,
) -> None:
super().__init__()
tp_size = dist.get_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim ** -0.5
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False, # Llama has no attention bias
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
self.num_kv_heads,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
# Llama has no q_norm/k_norm
q, k = self.rotary_emb(positions, q, k)
o = self.attn(q, k, v)
output = self.o_proj(o.flatten(1, -1))
return output
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.down_proj(x)
return x
class LlamaDecoderLayer(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.self_attn = LlamaAttention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
max_position=config.max_position_embeddings,
head_dim=getattr(config, 'head_dim', None),
rope_theta=getattr(config, "rope_theta", 10000),
rope_scaling=getattr(config, "rope_scaling", None),
)
self.mlp = LlamaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions, hidden_states)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class LlamaModel(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@register_model("LlamaForCausalLM")
class LlamaForCausalLM(nn.Module):
packed_modules_mapping = {
"q_proj": ("qkv_proj", "q"),
"k_proj": ("qkv_proj", "k"),
"v_proj": ("qkv_proj", "v"),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(self, config) -> None:
super().__init__()
self.model = LlamaModel(config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if getattr(config, 'tie_word_embeddings', False):
self.lm_head.weight.data = self.model.embed_tokens.weight.data
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
return self.model(input_ids, positions)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
return self.lm_head(hidden_states)
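The packed_modules_mapping at the top of LlamaForCausalLM mirrors the one already defined on Qwen3ForCausalLM: it tells the weight loader that the per-projection tensors in a HuggingFace checkpoint belong to the fused modules built here. Roughly, as a sketch of the convention only (not the actual load_model code in nanovllm/utils/loader.py):

    from nanovllm.models.llama import LlamaForCausalLM

    weight_name = "model.layers.0.self_attn.q_proj.weight"   # typical checkpoint key
    for shard_name, (packed_name, shard_id) in LlamaForCausalLM.packed_modules_mapping.items():
        if shard_name in weight_name:
            # ...self_attn.q_proj.weight -> ...self_attn.qkv_proj.weight; the shard
            # id ("q") tells the fused parameter's weight_loader which slice to fill.
            weight_name = weight_name.replace(shard_name, packed_name)
            break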

nanovllm/models/qwen3.py
View File

@@ -9,6 +9,7 @@ from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
from nanovllm.models.registry import register_model
class Qwen3Attention(nn.Module):
@@ -186,6 +187,7 @@ class Qwen3Model(nn.Module):
return hidden_states
@register_model("Qwen3ForCausalLM", "Qwen2ForCausalLM")
class Qwen3ForCausalLM(nn.Module):
packed_modules_mapping = {
"q_proj": ("qkv_proj", "q"),

nanovllm/models/registry.py
View File

@@ -0,0 +1,46 @@
"""Model registry for dynamic model loading."""
from typing import Type
from torch import nn
# Global registry mapping architecture names to model classes
MODEL_REGISTRY: dict[str, Type[nn.Module]] = {}
def register_model(*architectures: str):
"""
Decorator to register a model class for given architecture names.
Usage:
@register_model("LlamaForCausalLM")
class LlamaForCausalLM(nn.Module):
...
"""
def decorator(cls: Type[nn.Module]) -> Type[nn.Module]:
for arch in architectures:
MODEL_REGISTRY[arch] = cls
return cls
return decorator
def get_model_class(hf_config) -> Type[nn.Module]:
"""
Get model class based on HuggingFace config.
Args:
hf_config: HuggingFace model config with 'architectures' field
Returns:
Model class for the given architecture
Raises:
ValueError: If architecture is not supported
"""
architectures = getattr(hf_config, "architectures", [])
for arch in architectures:
if arch in MODEL_REGISTRY:
return MODEL_REGISTRY[arch]
raise ValueError(
f"Unsupported architecture: {architectures}. "
f"Supported: {list(MODEL_REGISTRY.keys())}"
)
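A minimal sketch of how the registry resolves a config (SimpleNamespace stands in for the HuggingFace config object; only its architectures attribute is consulted):

    from types import SimpleNamespace

    from nanovllm.models import get_model_class       # importing the package registers the models
    from nanovllm.models.llama import LlamaForCausalLM

    cfg = SimpleNamespace(architectures=["LlamaForCausalLM"])
    assert get_model_class(cfg) is LlamaForCausalLM

    unknown = SimpleNamespace(architectures=["MistralForCausalLM"])
    try:
        get_model_class(unknown)
    except ValueError as err:
        print(err)                                     # lists the supported architectures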