simplify
@@ -36,7 +36,7 @@ class Qwen3Attention(nn.Module):
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
-        self.scaling = self.head_dim**-0.5
+        self.scaling = self.head_dim ** -0.5

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
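
A quick note on what these attributes feed: `q_size` and `kv_size` are the widths later used to split the fused QKV projection, and `scaling` is the standard 1/sqrt(d_k) attention factor. A minimal sketch of that arithmetic with hypothetical head counts (not taken from this commit):

# Hypothetical Qwen3-style sizes, for illustration only.
num_heads, num_kv_heads, head_dim = 16, 8, 128

q_size = num_heads * head_dim       # 2048 features for queries
kv_size = num_kv_heads * head_dim   # 1024 features each for keys and values

# The fused QKV projection emits q_size + 2 * kv_size features per token,
# which is exactly what the split in the forward pass below carves up.
assert q_size + 2 * kv_size == 4096
scaling = head_dim ** -0.5          # 1/sqrt(128) ~= 0.0884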
@@ -73,15 +73,12 @@ class Qwen3Attention(nn.Module):
     ) -> torch.Tensor:
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q_by_head = q.view(-1, self.num_heads, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
-        q = q_by_head.view(q.shape)
-        k_by_head = k.view(-1, self.num_kv_heads, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
-        k = k_by_head.view(k.shape)
+        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
+        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
         q, k = self.rotary_emb(positions, q, k)
         o = self.attn(q, k, v)
-        output = self.o_proj(o)
+        output = self.o_proj(o.flatten(1, -1))
         return output
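
This hunk does more than cut lines: the old code normalized per head and then viewed q and k back to their flat 2-D shapes, while the new code keeps q, k, and v in (num_tokens, num_heads, head_dim) layout, which is why `o.flatten(1, -1)` now appears before `o_proj`. The normalization itself is unchanged; a minimal sketch checking the equivalence, with `torch.nn.RMSNorm` (PyTorch >= 2.4) standing in for the model's `q_norm` (an assumption; the repo defines its own RMSNorm):

import torch

num_heads, head_dim, tokens = 4, 8, 3
norm = torch.nn.RMSNorm(head_dim)  # stand-in for q_norm, an assumption
q = torch.randn(tokens, num_heads * head_dim)

# Old shape-preserving form:
q_by_head = q.view(-1, num_heads, head_dim)
q_by_head = norm(q_by_head)
q_old = q_by_head.view(q.shape)

# New fused form keeps the per-head layout; flatten only to compare.
q_new = norm(q.view(-1, num_heads, head_dim))

assert torch.allclose(q_old, q_new.flatten(1, -1))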
@@ -147,8 +144,7 @@ class Qwen3DecoderLayer(nn.Module):
         residual: torch.Tensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states)
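
The one-liner relies on Python's tuple assignment: the right-hand side is evaluated in full before either name is bound, so `residual` still receives the pre-norm activations exactly as the two-line version did. A standalone sketch of that ordering (the norm is a stub, not the repo's RMSNorm):

import torch

def norm_stub(x):
    # Stand-in for self.input_layernorm; the exact norm is irrelevant here.
    return x / x.norm(dim=-1, keepdim=True)

x = torch.randn(2, 4)

# Old two-line form:
residual_old = x
normed_old = norm_stub(x)

# New one-liner: the right-hand tuple is built before either name is
# rebound, so `residual` still ends up holding the pre-norm tensor.
normed, residual = norm_stub(x), x

assert torch.equal(residual, residual_old)
assert torch.equal(normed, normed_old)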
@@ -205,12 +201,10 @@ class Qwen3ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions)
-        return hidden_states
+        return self.model(input_ids, positions)

     def compute_logits(
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        logits = self.lm_head(hidden_states)
-        return logits
+        return self.lm_head(hidden_states)
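
Both methods here just drop a single-use temporary; behavior is identical. For orientation, a toy sketch of how the two calls chain at a sampling site, with stand-in modules (all names hypothetical, not from this commit):

import torch
import torch.nn as nn

# Toy stand-ins, for illustration only: `model` plays the role of
# self.model and `lm_head` the role of self.lm_head.
model = nn.Embedding(32, 16)
lm_head = nn.Linear(16, 32, bias=False)

input_ids = torch.tensor([1, 2, 3])
hidden_states = model(input_ids)   # like forward(): hidden states per token
logits = lm_head(hidden_states)    # like compute_logits(): vocab projection
print(logits.shape)                # torch.Size([3, 32])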