This commit is contained in:
GeeeekExplorer
2025-08-31 19:44:57 +08:00
parent 6a6d217de7
commit df99418f7d
11 changed files with 47 additions and 96 deletions

View File

@@ -36,7 +36,7 @@ class Qwen3Attention(nn.Module):
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.scaling = self.head_dim ** -0.5
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -73,15 +73,12 @@ class Qwen3Attention(nn.Module):
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q_by_head = q.view(-1, self.num_heads, self.head_dim)
q_by_head = self.q_norm(q_by_head)
q = q_by_head.view(q.shape)
k_by_head = k.view(-1, self.num_kv_heads, self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
v = v.view(-1, self.num_kv_heads, self.head_dim)
q, k = self.rotary_emb(positions, q, k)
o = self.attn(q, k, v)
output = self.o_proj(o)
output = self.o_proj(o.flatten(1, -1))
return output
@@ -147,8 +144,7 @@ class Qwen3DecoderLayer(nn.Module):
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions, hidden_states)
@@ -205,12 +201,10 @@ class Qwen3ForCausalLM(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions)
return hidden_states
return self.model(input_ids, positions)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
logits = self.lm_head(hidden_states)
return logits
return self.lm_head(hidden_states)