Merge branch 'zijie/add-llama-1': Add multi-model support

- Add model registry system for dynamic model loading
- Implement LlamaForCausalLM with Llama3 RoPE scaling
- Register Qwen3ForCausalLM and Qwen2ForCausalLM
- Update ModelRunner to use get_model_class() for dynamic model selection

Tested: needle 32k test PASSED

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-01-10 21:20:53 +08:00
10 changed files with 947 additions and 7 deletions

View File

@@ -1,4 +1,4 @@
from functools import lru_cache
import math
import torch
from torch import nn
@@ -48,7 +48,102 @@ class RotaryEmbedding(nn.Module):
return query, key
class Llama3RotaryEmbedding(nn.Module):
    """Rotary position embedding with Llama 3's piecewise frequency scaling.

    Llama 3 adjusts the standard RoPE inverse frequencies by wavelength:
      - high frequencies (short wavelengths): unchanged,
      - low frequencies (long wavelengths): divided by ``factor``,
      - medium frequencies: smoothly interpolated between the two.

    NOTE: the stray ``@lru_cache(1)`` decorator that preceded this class was
    merge debris (it belonged to the old ``get_rope``); caching an nn.Module
    constructor would silently alias instances, so it is removed. Instance
    reuse is handled by ``_rope_cache`` in ``get_rope`` instead.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        original_max_position_embeddings: int,
    ) -> None:
        """Precompute the scaled cos/sin cache.

        Args:
            head_size: Per-head hidden size; must equal ``rotary_dim``.
            rotary_dim: Number of dimensions rotated (the full head here).
            max_position_embeddings: Number of positions to precompute.
            base: RoPE frequency base (theta).
            factor: Divisor applied to low-frequency components.
            low_freq_factor: Divides the original context length to get the
                long-wavelength threshold.
            high_freq_factor: Divides the original context length to get the
                short-wavelength threshold.
            original_max_position_embeddings: Pre-scaling context length.
        """
        super().__init__()
        self.head_size = head_size
        assert rotary_dim == head_size
        # Base inverse frequencies of standard RoPE.
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        # Apply Llama3's piecewise rescaling.
        inv_freq = self._compute_llama3_inv_freq(
            inv_freq,
            factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position_embeddings,
        )
        # Precompute [cos | sin] for every position;
        # cache shape: (max_position_embeddings, 1, rotary_dim).
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_llama3_inv_freq(
        self,
        inv_freq: torch.Tensor,
        factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        original_max_position_embeddings: int,
    ) -> torch.Tensor:
        """Apply Llama3 frequency scaling to *inv_freq*.

        - wavelength > low_freq_wavelen: divide by ``factor`` (long range,
          needs interpolation)
        - wavelength < high_freq_wavelen: keep unchanged (short range,
          high fidelity)
        - in between: smooth linear interpolation between the two
        """
        old_context_len = original_max_position_embeddings
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        wavelen = 2 * math.pi / inv_freq
        # Low frequency: scale down by factor.
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        # Medium frequency: smooth interpolation between scaled and original.
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama + smooth_factor * inv_freq
        is_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
        return inv_freq_llama

    @torch.compile
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate *query* and *key* by the cached angles at *positions*."""
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query = apply_rotary_emb(query, cos, sin)
        key = apply_rotary_emb(key, cos, sin)
        return query, key
# Cache of constructed RoPE modules, keyed by hashable parameters, so layers
# that share one configuration reuse a single module (and its cos/sin cache).
_rope_cache: dict[tuple, nn.Module] = {}


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: dict | None = None,
):
    """Return a rotary-embedding module for the given configuration.

    Instances are memoized in ``_rope_cache`` by their full parameter set, so
    repeated calls with an identical configuration return the same module.
    (The pre-merge ``assert rope_scaling is None`` / early-return lines left
    embedded here made everything below unreachable; they are removed.)

    Args:
        head_size: Per-head hidden size.
        rotary_dim: Number of dimensions to rotate.
        max_position: Maximum sequence position to precompute.
        base: RoPE frequency base (theta).
        rope_scaling: Optional HF-style scaling config dict; ``None`` selects
            plain RoPE, ``rope_type``/``type`` == "llama3" selects Llama3
            scaling.

    Raises:
        ValueError: If ``rope_scaling`` specifies an unsupported rope type.
    """
    # Build a hashable cache key; look up rope_type once and reuse it below.
    if rope_scaling is None:
        rope_type = None
        cache_key = (head_size, rotary_dim, max_position, base, None)
    else:
        rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
        if rope_type == "llama3":
            cache_key = (
                head_size, rotary_dim, max_position, base, "llama3",
                rope_scaling["factor"],
                rope_scaling["low_freq_factor"],
                rope_scaling["high_freq_factor"],
                rope_scaling["original_max_position_embeddings"],
            )
        else:
            cache_key = (head_size, rotary_dim, max_position, base, rope_type)
    if cache_key in _rope_cache:
        return _rope_cache[cache_key]
    if rope_scaling is None:
        rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
    elif rope_type == "llama3":
        rope = Llama3RotaryEmbedding(
            head_size,
            rotary_dim,
            max_position,
            base,
            factor=rope_scaling["factor"],
            low_freq_factor=rope_scaling["low_freq_factor"],
            high_freq_factor=rope_scaling["high_freq_factor"],
            original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
        )
    else:
        raise ValueError(f"Unsupported rope_type: {rope_type}")
    _rope_cache[cache_key] = rope
    return rope