248 lines
8.2 KiB
Python
Executable File
248 lines
8.2 KiB
Python
Executable File
import torch
|
|
from torch import nn
|
|
from transformers import Qwen3Config
|
|
|
|
from nanovllm.layers.activation import SiluAndMul
|
|
from nanovllm.layers.attention import Attention
|
|
from nanovllm.layers.layernorm import RMSNorm
|
|
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
|
from nanovllm.layers.rotary_embedding import get_rope
|
|
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
|
|
|
|
|
class Qwen3Attention(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
hidden_size: int,
|
|
num_heads: int,
|
|
num_kv_heads: int,
|
|
max_position: int = 4096 * 32,
|
|
head_dim: int | None = None,
|
|
rms_norm_eps: float = 1e-06,
|
|
qkv_bias: bool = False,
|
|
rope_theta: float = 10000,
|
|
rope_scaling: tuple | None = None,
|
|
) -> None:
|
|
super().__init__()
|
|
self.hidden_size = hidden_size
|
|
tp_size = 1 # get_tensor_model_parallel_world_size()
|
|
self.total_num_heads = num_heads
|
|
assert self.total_num_heads % tp_size == 0
|
|
self.num_heads = self.total_num_heads // tp_size
|
|
self.total_num_kv_heads = num_kv_heads
|
|
assert self.total_num_kv_heads % tp_size == 0
|
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
|
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
|
self.q_size = self.num_heads * self.head_dim
|
|
self.kv_size = self.num_kv_heads * self.head_dim
|
|
self.scaling = self.head_dim**-0.5
|
|
self.rope_theta = rope_theta
|
|
|
|
self.qkv_proj = QKVParallelLinear(
|
|
hidden_size,
|
|
self.head_dim,
|
|
self.total_num_heads,
|
|
self.total_num_kv_heads,
|
|
bias=qkv_bias,
|
|
)
|
|
self.o_proj = RowParallelLinear(
|
|
self.total_num_heads * self.head_dim,
|
|
hidden_size,
|
|
bias=False,
|
|
)
|
|
|
|
self.rotary_emb = get_rope(
|
|
self.head_dim,
|
|
rotary_dim=self.head_dim,
|
|
max_position=max_position,
|
|
base=self.rope_theta,
|
|
rope_scaling=rope_scaling,
|
|
)
|
|
self.attn = Attention(self.num_heads,
|
|
self.head_dim,
|
|
self.scaling,
|
|
num_kv_heads=self.num_kv_heads)
|
|
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
|
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
|
|
|
def forward(
|
|
self,
|
|
positions: torch.Tensor,
|
|
hidden_states: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
qkv = self.qkv_proj(hidden_states)
|
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
|
q_by_head = q.view(-1, self.num_heads, self.head_dim)
|
|
q_by_head = self.q_norm(q_by_head)
|
|
q = q_by_head.view(q.shape)
|
|
k_by_head = k.view(-1, self.num_kv_heads, self.head_dim)
|
|
k_by_head = self.k_norm(k_by_head)
|
|
k = k_by_head.view(k.shape)
|
|
q, k = self.rotary_emb(positions, q, k)
|
|
o = self.attn(q, k, v)
|
|
output = self.o_proj(o)
|
|
return output
|
|
|
|
|
|
class Qwen3MLP(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
hidden_size: int,
|
|
intermediate_size: int,
|
|
hidden_act: str,
|
|
) -> None:
|
|
super().__init__()
|
|
self.gate_up_proj = MergedColumnParallelLinear(
|
|
hidden_size,
|
|
[intermediate_size] * 2,
|
|
bias=False,
|
|
)
|
|
self.down_proj = RowParallelLinear(
|
|
intermediate_size,
|
|
hidden_size,
|
|
bias=False,
|
|
)
|
|
assert hidden_act == "silu"
|
|
self.act_fn = SiluAndMul()
|
|
|
|
def forward(self, x):
|
|
gate_up = self.gate_up_proj(x)
|
|
x = self.act_fn(gate_up)
|
|
x = self.down_proj(x)
|
|
return x
|
|
|
|
|
|
class Qwen3DecoderLayer(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
config: Qwen3Config,
|
|
) -> None:
|
|
super().__init__()
|
|
self.hidden_size = config.hidden_size
|
|
self.self_attn = Qwen3Attention(
|
|
hidden_size=self.hidden_size,
|
|
num_heads=config.num_attention_heads,
|
|
num_kv_heads=config.num_key_value_heads,
|
|
max_position=config.max_position_embeddings,
|
|
rms_norm_eps=config.rms_norm_eps,
|
|
qkv_bias=getattr(config, 'attention_bias', False),
|
|
head_dim=getattr(config, 'head_dim', None),
|
|
rope_theta=getattr(config, "rope_theta", 1000000),
|
|
rope_scaling=getattr(config, "rope_scaling", None),
|
|
)
|
|
self.mlp = Qwen3MLP(
|
|
hidden_size=config.hidden_size,
|
|
intermediate_size=config.intermediate_size,
|
|
hidden_act=config.hidden_act,
|
|
)
|
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
|
eps=config.rms_norm_eps)
|
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
|
eps=config.rms_norm_eps)
|
|
|
|
def forward(
|
|
self,
|
|
positions: torch.Tensor,
|
|
hidden_states: torch.Tensor,
|
|
residual: torch.Tensor | None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
if residual is None:
|
|
residual = hidden_states
|
|
hidden_states = self.input_layernorm(hidden_states)
|
|
else:
|
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
|
hidden_states = self.self_attn(
|
|
positions=positions,
|
|
hidden_states=hidden_states,
|
|
)
|
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
|
hidden_states = self.mlp(hidden_states)
|
|
return hidden_states, residual
|
|
|
|
|
|
class Qwen3Model(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
config: Qwen3Config,
|
|
):
|
|
super().__init__()
|
|
self.vocab_size = config.vocab_size
|
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
|
|
self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
def forward(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
hidden_states = self.embed_tokens(input_ids)
|
|
residual = None
|
|
for layer in self.layers:
|
|
hidden_states, residual = layer(
|
|
positions,
|
|
hidden_states,
|
|
residual,
|
|
)
|
|
hidden_states, _ = self.norm(hidden_states, residual)
|
|
return hidden_states
|
|
|
|
|
|
class Qwen3ForCausalLM(nn.Module):
|
|
packed_modules_mapping = {
|
|
"qkv_proj": [
|
|
"q_proj",
|
|
"k_proj",
|
|
"v_proj",
|
|
],
|
|
"gate_up_proj": [
|
|
"gate_proj",
|
|
"up_proj",
|
|
],
|
|
}
|
|
|
|
def __init__(
|
|
self,
|
|
config: Qwen3Config
|
|
):
|
|
super().__init__()
|
|
self.model = Qwen3Model(config)
|
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
|
if config.tie_word_embeddings:
|
|
self.lm_head.weight.data = self.model.embed_tokens.weight.data
|
|
|
|
def forward(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
hidden_states = self.model(input_ids, positions)
|
|
return hidden_states
|
|
|
|
def compute_logits(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
logits = self.lm_head(hidden_states)
|
|
return logits
|
|
|
|
def load_weights(self, path: str):
|
|
import os
|
|
from safetensors import safe_open
|
|
default_weight_loader = lambda param, loaded_weight: param.data.copy_(loaded_weight)
|
|
with safe_open(os.path.join(path, "model.safetensors"), "pt", "cpu") as f:
|
|
for n, p in self.named_parameters():
|
|
for x in self.packed_modules_mapping:
|
|
if x in n:
|
|
weight_loader = getattr(p, "weight_loader", default_weight_loader)
|
|
for i, y in enumerate(self.packed_modules_mapping[x]):
|
|
weight_loader(p, f.get_tensor(n.replace(x, y)), y[0] if x == "qkv_proj" else i)
|
|
break
|
|
else:
|
|
weight_loader = getattr(p, "weight_loader", default_weight_loader)
|
|
weight_loader(p, f.get_tensor(n))
|