support qwen2
@@ -37,6 +37,7 @@ class Qwen3Attention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim ** -0.5
+        self.qkv_bias = qkv_bias

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -63,8 +64,9 @@ class Qwen3Attention(nn.Module):
             self.scaling,
             self.num_kv_heads,
         )
-        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        if not self.qkv_bias:
+            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)

     def forward(
         self,
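The intent of the conditional construction above appears to be: a Qwen2-style checkpoint keeps its biased QKV projection and gets no extra normalization, while a Qwen3-style checkpoint (no bias) gets a per-head RMSNorm over head_dim for queries and keys. A small stand-alone sketch of that pattern, using plain PyTorch modules in place of the repo's QKVParallelLinear and RMSNorm (all names below are stand-ins, and torch.nn.RMSNorm needs torch >= 2.4):

from torch import nn

class ToyAttnInit(nn.Module):
    """Mirrors the commit's __init__ branch: per-head RMSNorm on q/k exists
    only when the fused QKV projection has no bias (Qwen3-style)."""
    def __init__(self, hidden_size=32, num_heads=4, qkv_bias=False, eps=1e-6):
        super().__init__()
        head_dim = hidden_size // num_heads
        self.qkv_bias = qkv_bias
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=qkv_bias)
        if not self.qkv_bias:
            # normalized over head_dim, i.e. applied per attention head
            self.q_norm = nn.RMSNorm(head_dim, eps=eps)   # torch >= 2.4
            self.k_norm = nn.RMSNorm(head_dim, eps=eps)

print(sorted(ToyAttnInit(qkv_bias=True).state_dict()))
# ['qkv_proj.bias', 'qkv_proj.weight']                    (Qwen2-like)
print(sorted(ToyAttnInit(qkv_bias=False).state_dict()))
# ['k_norm.weight', 'q_norm.weight', 'qkv_proj.weight']   (Qwen3-like)

The two branches therefore expect different parameter sets from the checkpoint, which is why the flag has to be decided at construction time rather than in forward() alone.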
@@ -73,9 +75,12 @@ class Qwen3Attention(nn.Module):
     ) -> torch.Tensor:
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
-        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
+        q = q.view(-1, self.num_heads, self.head_dim)
+        k = k.view(-1, self.num_kv_heads, self.head_dim)
         v = v.view(-1, self.num_kv_heads, self.head_dim)
+        if not self.qkv_bias:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
         q, k = self.rotary_emb(positions, q, k)
         o = self.attn(q, k, v)
         output = self.o_proj(o.flatten(1, -1))
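In forward(), the reshape to (num_tokens, heads, head_dim) now happens unconditionally, and the q/k normalization runs only on the bias-free (Qwen3) path, still before the rotary embedding. A runnable sketch of just this tensor plumbing, with a random tensor standing in for the fused QKV output and a hand-rolled, weight-free RMSNorm (shapes and names are illustrative only, not the file's):

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Per-head RMSNorm over the last (head_dim) axis, learned weight omitted.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

num_tokens, num_heads, num_kv_heads, head_dim = 5, 4, 2, 8
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim
qkv = torch.randn(num_tokens, q_size + 2 * kv_size)   # fused QKV projection output
qkv_bias = False                                       # False -> Qwen3-style path

q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
q = q.view(-1, num_heads, head_dim)        # always reshape first ...
k = k.view(-1, num_kv_heads, head_dim)
v = v.view(-1, num_kv_heads, head_dim)
if not qkv_bias:                           # ... then normalize q/k only on the
    q = rms_norm(q)                        #     bias-free branch, ahead of the
    k = rms_norm(k)                        #     rotary embedding step
print(q.shape, k.shape, v.shape)           # torch.Size([5, 4, 8]) ...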
@@ -124,7 +129,7 @@ class Qwen3DecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             max_position=config.max_position_embeddings,
             rms_norm_eps=config.rms_norm_eps,
-            qkv_bias=getattr(config, 'attention_bias', False),
+            qkv_bias=getattr(config, 'attention_bias', True),
             head_dim=getattr(config, 'head_dim', None),
             rope_theta=getattr(config, "rope_theta", 1000000),
             rope_scaling=getattr(config, "rope_scaling", None),
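Flipping the getattr default from False to True is presumably what lets the same decoder-layer code serve both families: Qwen2-era configs, as far as I can tell, do not define attention_bias at all, so they now fall through to the biased, norm-free path, while Qwen3 configs carry an explicit attention_bias=False and keep the previous behaviour. A quick illustration with stand-in config objects rather than real transformers configs:

from types import SimpleNamespace

# Stand-ins for the HF config objects (illustrative, not loaded from checkpoints);
# the assumption is that a Qwen2-style config simply lacks the attribute, while a
# Qwen3-style config sets attention_bias=False explicitly.
qwen2_cfg = SimpleNamespace(num_attention_heads=28)
qwen3_cfg = SimpleNamespace(num_attention_heads=32, attention_bias=False)

for name, cfg in [("qwen2-like", qwen2_cfg), ("qwen3-like", qwen3_cfg)]:
    qkv_bias = getattr(cfg, 'attention_bias', True)
    print(name, "-> qkv_bias =", qkv_bias,
          "(q/k RMSNorm", "off)" if qkv_bias else "on)")
# qwen2-like -> qkv_bias = True (q/k RMSNorm off)
# qwen3-like -> qkv_bias = False (q/k RMSNorm on)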