Merge branch 'zijie/add-llama-1': Add multi-model support
- Add model registry system for dynamic model loading
- Implement LlamaForCausalLM with Llama3 RoPE scaling
- Register Qwen3ForCausalLM and Qwen2ForCausalLM
- Update ModelRunner to use get_model_class() for dynamic model selection

Tested: needle 32k test PASSED

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from functools import lru_cache
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -48,7 +48,102 @@ class RotaryEmbedding(nn.Module):
|
||||
return query, key
|
||||
|
||||
|
||||
class Llama3RotaryEmbedding(nn.Module):
    """Llama 3 rotary position embedding with piecewise frequency scaling.

    Llama 3 extends the usable context by rescaling the base RoPE inverse
    frequencies in three bands (matching HF transformers'
    ``_compute_llama3_parameters``):

    - high frequencies (wavelength < high_freq_wavelen): kept unchanged
    - low frequencies (wavelength > low_freq_wavelen): divided by ``factor``
    - medium frequencies: smoothly interpolated between the two regimes

    Args:
        head_size: per-head dimension; must equal ``rotary_dim`` (partial
            rotary embedding is not supported here).
        rotary_dim: number of dimensions rotated per head.
        max_position_embeddings: number of positions to precompute the
            cos/sin cache for.
        base: RoPE theta base.
        factor: overall context-extension factor applied to low frequencies.
        low_freq_factor: divides the original context length to get the
            long-wavelength boundary.
        high_freq_factor: divides the original context length to get the
            short-wavelength boundary.
        original_max_position_embeddings: context length the model was
            originally trained with.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        original_max_position_embeddings: int,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        assert rotary_dim == head_size

        # Base (unscaled) inverse frequencies: base^(-2i/d) for each pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

        # Apply the Llama3 piecewise rescaling.
        inv_freq = self._compute_llama3_inv_freq(
            inv_freq,
            factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position_embeddings,
        )

        # Precompute [cos | sin] for every position; cache shape is
        # (max_position_embeddings, 1, rotary_dim).
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_llama3_inv_freq(
        self,
        inv_freq: torch.Tensor,
        factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        original_max_position_embeddings: int,
    ) -> torch.Tensor:
        """Apply Llama3 frequency scaling to ``inv_freq``.

        - wavelength > low_freq_wavelen: scale down by factor (long range,
          needs interpolation)
        - wavelength < high_freq_wavelen: keep unchanged (short range,
          high fidelity)
        - in between: smooth interpolation between scaled and unscaled
        """
        old_context_len = original_max_position_embeddings

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        wavelen = 2 * math.pi / inv_freq

        # Low frequency (long wavelength): scale down by factor.
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

        # Medium frequency: interpolate between the fully scaled frequency
        # (inv_freq_llama / factor) and the unscaled one.
        # BUGFIX: the scaled endpoint must be divided by `factor` (as in HF
        # `_compute_llama3_parameters`); the previous form interpolated
        # between two identical values, leaving medium frequencies unscaled.
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
        is_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

        return inv_freq_llama

    @torch.compile
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Rotate `query` and `key` by the cached cos/sin at `positions`."""
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query = apply_rotary_emb(query, cos, sin)
        key = apply_rotary_emb(key, cos, sin)
        return query, key
||||
# Cache for RoPE instances (keyed by hashable parameters)
|
||||
_rope_cache: dict[tuple, nn.Module] = {}
|
||||
|
||||
|
||||
def get_rope(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
@@ -56,6 +151,42 @@ def get_rope(
|
||||
base: float,
|
||||
rope_scaling: dict | None = None,
|
||||
):
|
||||
assert rope_scaling is None
|
||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||
return rotary_emb
|
||||
# Create hashable cache key
|
||||
if rope_scaling is None:
|
||||
cache_key = (head_size, rotary_dim, max_position, base, None)
|
||||
else:
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||
if rope_type == "llama3":
|
||||
cache_key = (
|
||||
head_size, rotary_dim, max_position, base, "llama3",
|
||||
rope_scaling["factor"],
|
||||
rope_scaling["low_freq_factor"],
|
||||
rope_scaling["high_freq_factor"],
|
||||
rope_scaling["original_max_position_embeddings"],
|
||||
)
|
||||
else:
|
||||
cache_key = (head_size, rotary_dim, max_position, base, rope_type)
|
||||
|
||||
if cache_key in _rope_cache:
|
||||
return _rope_cache[cache_key]
|
||||
|
||||
if rope_scaling is None:
|
||||
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
||||
else:
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||
if rope_type == "llama3":
|
||||
rope = Llama3RotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
factor=rope_scaling["factor"],
|
||||
low_freq_factor=rope_scaling["low_freq_factor"],
|
||||
high_freq_factor=rope_scaling["high_freq_factor"],
|
||||
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported rope_type: {rope_type}")
|
||||
|
||||
_rope_cache[cache_key] = rope
|
||||
return rope
|
||||
|
||||
Reference in New Issue
Block a user