This commit is contained in:
GeeeekExplorer
2025-08-31 19:44:57 +08:00
parent 6a6d217de7
commit df99418f7d
11 changed files with 47 additions and 96 deletions

View File

@@ -8,9 +8,7 @@ def apply_rotary_emb(
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
y1 = x1 * cos - x2 * sin
y2 = x2 * cos + x1 * sin
return torch.cat((y1, y2), dim=-1).to(x.dtype)
@@ -33,7 +31,7 @@ class RotaryEmbedding(nn.Module):
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
self.register_buffer("cos_sin_cache", cache, persistent=False)
@torch.compile
@@ -43,15 +41,10 @@ class RotaryEmbedding(nn.Module):
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = positions.size(0)
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query = apply_rotary_emb(query, cos, sin).view(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key = apply_rotary_emb(key, cos, sin).view(key_shape)
query = apply_rotary_emb(query, cos, sin)
key = apply_rotary_emb(key, cos, sin)
return query, key