""" Custom Qwen3 implementation using only torch and transformers. This file provides a clean reference implementation for understanding the model computation graph. Computation Graph: ================== Input: token_ids [batch, seq_len] │ ▼ ┌─────────────┐ │ Embedding │ embed_tokens: [vocab_size, hidden_size] └─────────────┘ │ ▼ hidden_states [batch, seq_len, hidden_size] │ ▼ ┌─────────────────────────────────────────────────────────┐ │ Decoder Layer (x N) │ │ ┌───────────────────────────────────────────────────┐ │ │ │ Self Attention Block │ │ │ │ │ │ │ │ input_layernorm (RMSNorm) │ │ │ │ │ │ │ │ │ ▼ │ │ │ │ ┌─────────────────────────────────────────────┐ │ │ │ │ │ Qwen3Attention │ │ │ │ │ │ Q = q_proj(x) → q_norm → reshape │ │ │ │ │ │ K = k_proj(x) → k_norm → reshape │ │ │ │ │ │ V = v_proj(x) → reshape │ │ │ │ │ │ │ │ │ │ │ │ │ ▼ │ │ │ │ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │ │ │ │ │ │ │ │ │ │ │ ▼ │ │ │ │ │ │ attn_output = attention(Q, K, V) │ │ │ │ │ │ │ │ │ │ │ │ │ ▼ │ │ │ │ │ │ output = o_proj(attn_output) │ │ │ │ │ └─────────────────────────────────────────────┘ │ │ │ │ │ │ │ │ │ ▼ │ │ │ │ hidden_states = residual + attn_output │ │ │ └───────────────────────────────────────────────────┘ │ │ │ │ │ ▼ │ │ ┌───────────────────────────────────────────────────┐ │ │ │ MLP Block │ │ │ │ │ │ │ │ post_attention_layernorm (RMSNorm) │ │ │ │ │ │ │ │ │ ▼ │ │ │ │ ┌─────────────────────────────────────────────┐ │ │ │ │ │ Qwen3MLP │ │ │ │ │ │ gate = gate_proj(x) │ │ │ │ │ │ up = up_proj(x) │ │ │ │ │ │ output = down_proj(silu(gate) * up) │ │ │ │ │ └─────────────────────────────────────────────┘ │ │ │ │ │ │ │ │ │ ▼ │ │ │ │ hidden_states = residual + mlp_output │ │ │ └───────────────────────────────────────────────────┘ │ └─────────────────────────────────────────────────────────┘ │ ▼ ┌─────────────┐ │ norm │ final RMSNorm └─────────────┘ │ ▼ ┌─────────────┐ │ lm_head │ [hidden_size, vocab_size] └─────────────┘ │ ▼ logits [batch, seq_len, vocab_size] """ import math from typing import Optional, Tuple, List import torch import torch.nn as nn import torch.nn.functional as F class Qwen3RMSNorm(nn.Module): """RMSNorm implementation.""" def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: input_dtype = x.dtype x = x.float() variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.eps) return self.weight * x.to(input_dtype) class Qwen3RotaryEmbedding(nn.Module): """Rotary Position Embedding (RoPE).""" def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base # Compute inverse frequencies inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) @torch.no_grad() def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: Input tensor [batch, seq_len, num_heads, head_dim] or similar position_ids: Position indices [batch, seq_len] Returns: cos, sin: [batch, seq_len, head_dim] """ # inv_freq: [dim/2] # position_ids: [batch, seq_len] inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1] position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len] # freqs: [batch, dim/2, seq_len] freqs = inv_freq_expanded @ position_ids_expanded # freqs: [batch, seq_len, dim/2] freqs = freqs.transpose(1, 2) # Duplicate for full head_dim: [batch, seq_len, dim] emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos().to(x.dtype) sin = emb.sin().to(x.dtype) return cos, sin def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotate half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary position embeddings to Q and K. Args: q: [batch, num_heads, seq_len, head_dim] k: [batch, num_kv_heads, seq_len, head_dim] cos: [batch, seq_len, head_dim] sin: [batch, seq_len, head_dim] Returns: q_embed, k_embed with same shapes as inputs """ # Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim] cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class Qwen3Attention(nn.Module): """ Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support. Data Flow: --------- hidden_states [batch, seq_len, hidden_size] │ ├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim] ├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim] └──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim] │ ▼ apply_rotary_pos_emb(Q, K) │ ▼ attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim] │ ▼ reshape ──► o_proj ──► output [batch, seq_len, hidden_size] """ def __init__( self, hidden_size: int, num_attention_heads: int, num_key_value_heads: int, head_dim: int, max_position_embeddings: int = 32768, rope_theta: float = 10000.0, attention_bias: bool = False, rms_norm_eps: float = 1e-6, layer_idx: int = 0, ): super().__init__() self.hidden_size = hidden_size self.num_heads = num_attention_heads self.num_kv_heads = num_key_value_heads self.head_dim = head_dim self.num_kv_groups = num_attention_heads // num_key_value_heads self.layer_idx = layer_idx # Scaling factor self.scaling = head_dim ** -0.5 # QKV projections self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias) self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias) self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias) self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias) # QK normalization (Qwen3 specific) self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps) self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps) # Rotary embeddings self.rotary_emb = Qwen3RotaryEmbedding( head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta, ) def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, output_qkv: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]: """ Args: hidden_states: [batch, seq_len, hidden_size] position_ids: [batch, seq_len] attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask) past_key_value: (k_cache, v_cache) from previous steps use_cache: Whether to return updated cache output_qkv: Whether to output Q, K, V tensors for debugging Returns: output: [batch, seq_len, hidden_size] past_key_value: Updated cache (if use_cache=True) qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True) """ batch_size, seq_len, _ = hidden_states.shape # === QKV Projections === q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim] k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim] v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim] # Reshape to [batch, seq_len, num_heads, head_dim] q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) # === QK Normalization (Qwen3 specific) === q = self.q_norm(q) k = self.k_norm(k) # Transpose to [batch, num_heads, seq_len, head_dim] q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # === Rotary Position Embeddings === cos, sin = self.rotary_emb(v, position_ids) q, k = apply_rotary_pos_emb(q, k, cos, sin) # === KV Cache Update === if past_key_value is not None: k_cache, v_cache = past_key_value k = torch.cat([k_cache, k], dim=2) v = torch.cat([v_cache, v], dim=2) new_past_key_value = (k, v) if use_cache else None # === Grouped Query Attention (expand KV heads if needed) === if self.num_kv_groups > 1: # Repeat KV for each query group k = k.repeat_interleave(self.num_kv_groups, dim=1) v = v.repeat_interleave(self.num_kv_groups, dim=1) # === Attention Computation (using SDPA for memory efficiency) === # Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend # is_causal only works when q_len == kv_len (prefill), not during decode q_len, kv_len = q.shape[2], k.shape[2] is_causal = (q_len == kv_len) and (q_len > 1) attn_output = F.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=0.0, is_causal=is_causal, scale=self.scaling, ) # [batch, num_heads, seq_len, head_dim] # === Output Projection === # Transpose back and reshape attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim] attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size] output = self.o_proj(attn_output) # Optional QKV output for debugging qkv_dict = None if output_qkv: qkv_dict = { "q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE) "k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded) "v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded) } return output, new_past_key_value, qkv_dict class Qwen3MLP(nn.Module): """ Qwen3 MLP with SwiGLU activation. Data Flow: --------- hidden_states [batch, seq_len, hidden_size] │ ├──► gate_proj ──► gate [batch, seq_len, intermediate_size] │ └──► up_proj ──► up [batch, seq_len, intermediate_size] │ ▼ silu(gate) * up │ ▼ down_proj ──► output [batch, seq_len, hidden_size] """ def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False): super().__init__() self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: gate = self.gate_proj(x) up = self.up_proj(x) return self.down_proj(F.silu(gate) * up) class Qwen3DecoderLayer(nn.Module): """Single Qwen3 Decoder Layer.""" def __init__( self, hidden_size: int, intermediate_size: int, num_attention_heads: int, num_key_value_heads: int, head_dim: int, max_position_embeddings: int = 32768, rope_theta: float = 10000.0, rms_norm_eps: float = 1e-6, attention_bias: bool = False, mlp_bias: bool = False, layer_idx: int = 0, ): super().__init__() self.layer_idx = layer_idx # Pre-attention LayerNorm self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps) # Self-attention self.self_attn = Qwen3Attention( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, head_dim=head_dim, max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, attention_bias=attention_bias, rms_norm_eps=rms_norm_eps, layer_idx=layer_idx, ) # Post-attention LayerNorm self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps) # MLP self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias) def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, output_qkv: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]: """ Args: hidden_states: [batch, seq_len, hidden_size] position_ids: [batch, seq_len] attention_mask: Causal attention mask past_key_value: KV cache for this layer use_cache: Whether to return updated cache output_qkv: Whether to output Q, K, V for debugging Returns: hidden_states: [batch, seq_len, hidden_size] past_key_value: Updated cache qkv_dict: QKV tensors (if output_qkv=True) """ # === Self Attention Block === residual = hidden_states hidden_states = self.input_layernorm(hidden_states) attn_output, new_past_key_value, qkv_dict = self.self_attn( hidden_states=hidden_states, position_ids=position_ids, attention_mask=attention_mask, past_key_value=past_key_value, use_cache=use_cache, output_qkv=output_qkv, ) hidden_states = residual + attn_output # === MLP Block === residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states, new_past_key_value, qkv_dict class Qwen3Model(nn.Module): """Qwen3 Transformer Model (without LM head).""" def __init__( self, vocab_size: int, hidden_size: int, intermediate_size: int, num_hidden_layers: int, num_attention_heads: int, num_key_value_heads: int, head_dim: int, max_position_embeddings: int = 32768, rope_theta: float = 10000.0, rms_norm_eps: float = 1e-6, attention_bias: bool = False, mlp_bias: bool = False, ): super().__init__() self.vocab_size = vocab_size self.num_hidden_layers = num_hidden_layers # Token embeddings self.embed_tokens = nn.Embedding(vocab_size, hidden_size) # Decoder layers self.layers = nn.ModuleList([ Qwen3DecoderLayer( hidden_size=hidden_size, intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, head_dim=head_dim, max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, rms_norm_eps=rms_norm_eps, attention_bias=attention_bias, mlp_bias=mlp_bias, layer_idx=i, ) for i in range(num_hidden_layers) ]) # Final LayerNorm self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps) def forward( self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: bool = False, output_qkv_layers: Optional[List[int]] = None, ) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]: """ Args: input_ids: [batch, seq_len] position_ids: [batch, seq_len] attention_mask: [batch, seq_len] or pre-computed 4D mask past_key_values: List of (k, v) tuples for each layer use_cache: Whether to return new cache output_qkv_layers: List of layer indices to output QKV for Returns: hidden_states: [batch, seq_len, hidden_size] new_past_key_values: Updated cache qkv_outputs: {layer_idx: qkv_dict} """ batch_size, seq_len = input_ids.shape # Embedding hidden_states = self.embed_tokens(input_ids) # Position IDs if position_ids is None: past_len = past_key_values[0][0].shape[2] if past_key_values else 0 position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) # Attention mask (create causal mask if not provided) if attention_mask is None or attention_mask.dim() == 2: kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0) causal_mask = torch.triu( torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device), diagonal=kv_seq_len - seq_len + 1, ) attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len] # Initialize cache list new_past_key_values = [] if use_cache else None qkv_outputs = {} if output_qkv_layers else None # Decoder layers for i, layer in enumerate(self.layers): past_kv = past_key_values[i] if past_key_values else None output_qkv = output_qkv_layers is not None and i in output_qkv_layers hidden_states, new_kv, qkv_dict = layer( hidden_states=hidden_states, position_ids=position_ids, attention_mask=attention_mask, past_key_value=past_kv, use_cache=use_cache, output_qkv=output_qkv, ) if use_cache: new_past_key_values.append(new_kv) if qkv_dict is not None: qkv_outputs[i] = qkv_dict # Final norm hidden_states = self.norm(hidden_states) return hidden_states, new_past_key_values, qkv_outputs class Qwen3ForCausalLM(nn.Module): """Qwen3 Model with Language Modeling head.""" def __init__( self, vocab_size: int, hidden_size: int, intermediate_size: int, num_hidden_layers: int, num_attention_heads: int, num_key_value_heads: int, head_dim: int, max_position_embeddings: int = 32768, rope_theta: float = 10000.0, rms_norm_eps: float = 1e-6, attention_bias: bool = False, mlp_bias: bool = False, tie_word_embeddings: bool = True, ): super().__init__() self.vocab_size = vocab_size self.tie_word_embeddings = tie_word_embeddings # Transformer model self.model = Qwen3Model( vocab_size=vocab_size, hidden_size=hidden_size, intermediate_size=intermediate_size, num_hidden_layers=num_hidden_layers, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, head_dim=head_dim, max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, rms_norm_eps=rms_norm_eps, attention_bias=attention_bias, mlp_bias=mlp_bias, ) # LM head self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) def forward( self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: bool = False, output_qkv_layers: Optional[List[int]] = None, ) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]: """ Args: input_ids: [batch, seq_len] ... (same as Qwen3Model) Returns: logits: [batch, seq_len, vocab_size] past_key_values: Updated KV cache qkv_outputs: QKV tensors for specified layers """ hidden_states, new_past_key_values, qkv_outputs = self.model( input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_qkv_layers=output_qkv_layers, ) logits = self.lm_head(hidden_states) return logits, new_past_key_values, qkv_outputs @classmethod def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM": """ Load weights from a pretrained Qwen3 model. Args: model_path: Path to model directory containing config.json and model weights dtype: Data type for model weights Returns: Initialized Qwen3ForCausalLM model """ import json import os from safetensors.torch import load_file # Load config config_path = os.path.join(model_path, "config.json") with open(config_path) as f: config = json.load(f) # Create model model = cls( vocab_size=config["vocab_size"], hidden_size=config["hidden_size"], intermediate_size=config["intermediate_size"], num_hidden_layers=config["num_hidden_layers"], num_attention_heads=config["num_attention_heads"], num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]), head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]), max_position_embeddings=config.get("max_position_embeddings", 32768), rope_theta=config.get("rope_theta", 10000.0), rms_norm_eps=config.get("rms_norm_eps", 1e-6), attention_bias=config.get("attention_bias", False), mlp_bias=config.get("mlp_bias", False), tie_word_embeddings=config.get("tie_word_embeddings", True), ) # Load weights weight_files = sorted([ f for f in os.listdir(model_path) if f.endswith(".safetensors") ]) state_dict = {} for wf in weight_files: state_dict.update(load_file(os.path.join(model_path, wf))) # Load into model model.load_state_dict(state_dict, strict=False) # Tie lm_head weights to embed_tokens if configured if model.tie_word_embeddings: model.lm_head.weight = model.model.embed_tokens.weight model = model.to(dtype) return model @torch.no_grad() def generate( self, input_ids: torch.Tensor, max_new_tokens: int = 32, temperature: float = 1.0, do_sample: bool = True, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, ) -> torch.Tensor: """Simple autoregressive generation.""" device = input_ids.device batch_size, seq_len = input_ids.shape past_key_values = None generated = input_ids.clone() for _ in range(max_new_tokens): if past_key_values is None: current_input = generated else: current_input = generated[:, -1:] logits, past_key_values, _ = self( input_ids=current_input, past_key_values=past_key_values, use_cache=True, ) next_token_logits = logits[:, -1, :] if temperature > 0 and do_sample: next_token_logits = next_token_logits / temperature probs = torch.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: next_token = next_token_logits.argmax(dim=-1, keepdim=True) generated = torch.cat([generated, next_token], dim=1) if eos_token_id is not None and (next_token == eos_token_id).all(): break return generated def print_computation_graph(): """Print the computation graph for reference.""" print(__doc__) if __name__ == "__main__": print_computation_graph()