[refactor] Translate into english, void Chinese due to claude.

This commit is contained in:
Zijie Tian
2025-12-11 00:30:24 +08:00
parent e85c2b4776
commit babfa17354
9 changed files with 297 additions and 187 deletions

View File

@@ -79,8 +79,12 @@ class Qwen3Attention(nn.Module):
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
if not self.qkv_bias:
q = self.q_norm(q)
k = self.k_norm(k)
# Reshape to 2D before RMSNorm to avoid torch.compile recompilation
# q: [num_tokens, num_heads, head_dim] -> [num_tokens * num_heads, head_dim]
# After norm, reshape back to 3D
num_tokens = q.shape[0]
q = self.q_norm(q.reshape(-1, self.head_dim)).view(num_tokens, self.num_heads, self.head_dim)
k = self.k_norm(k.reshape(-1, self.head_dim)).view(num_tokens, self.num_kv_heads, self.head_dim)
q, k = self.rotary_emb(positions, q, k)
o = self.attn(q, k, v)
output = self.o_proj(o.flatten(1, -1))