[refactor] Translate comments into English, avoid Chinese due to Claude.
```diff
@@ -79,8 +79,12 @@ class Qwen3Attention(nn.Module):
         k = k.view(-1, self.num_kv_heads, self.head_dim)
         v = v.view(-1, self.num_kv_heads, self.head_dim)
         if not self.qkv_bias:
-            q = self.q_norm(q)
-            k = self.k_norm(k)
+            # Reshape to 2D before RMSNorm to avoid torch.compile recompilation
+            # q: [num_tokens, num_heads, head_dim] -> [num_tokens * num_heads, head_dim]
+            # After norm, reshape back to 3D
+            num_tokens = q.shape[0]
+            q = self.q_norm(q.reshape(-1, self.head_dim)).view(num_tokens, self.num_heads, self.head_dim)
+            k = self.k_norm(k.reshape(-1, self.head_dim)).view(num_tokens, self.num_kv_heads, self.head_dim)
         q, k = self.rotary_emb(positions, q, k)
         o = self.attn(q, k, v)
         output = self.o_proj(o.flatten(1, -1))
```
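The diff flattens q and k to 2D before the per-head RMSNorm so that torch.compile sees a fixed last dimension (head_dim) and a single varying leading dimension, rather than a 3D shape whose token count changes every forward pass. Below is a minimal standalone sketch of that pattern; the RMSNorm class and the shape loop are illustrative stand-ins for the model's q_norm/k_norm, not code from this commit:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Standard RMSNorm over the last dimension; a hypothetical stand-in
    # for the q_norm / k_norm modules used in the diff.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        x = (x.float() * torch.rsqrt(var + self.eps)).type_as(x)
        return x * self.weight

num_heads, head_dim = 8, 64
norm = torch.compile(RMSNorm(head_dim))

for num_tokens in (3, 17, 256):  # token count varies across steps
    q = torch.randn(num_tokens, num_heads, head_dim)
    # Flatten to 2D for the norm, then restore the 3D view, as in the diff.
    q = norm(q.reshape(-1, head_dim)).view(num_tokens, num_heads, head_dim)
```

Without the flatten, the norm's input is [num_tokens, num_heads, head_dim]; since num_tokens differs per batch, a compiled graph specialized on the 3D shape can be re-specialized repeatedly, whereas the 2D form confines all variation to one leading dimension.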