GeeeekExplorer
2025-06-21 17:04:53 +08:00
parent ad4e95fbdc
commit cde3fc22c2
9 changed files with 42 additions and 100 deletions


@@ -1,5 +1,4 @@
-from functools import lru_cache
 import torch
 from torch import nn
@@ -28,12 +27,9 @@ class RotaryEmbedding(nn.Module):
     ) -> None:
         super().__init__()
         self.head_size = head_size
-        self.rotary_dim = rotary_dim
+        assert rotary_dim == head_size
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
-        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
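
For context: the cached table is the standard RoPE frequency grid, theta = base^(-i/rotary_dim) for even channel index i, with angle t * theta at position t. A minimal standalone sketch of the same precomputation (the helper name build_cos_sin_cache is illustrative, not from this file):

import torch

def build_cos_sin_cache(rotary_dim: int, max_positions: int, base: float = 10000.0) -> torch.Tensor:
    # theta = base^(-i/rotary_dim) for even indices i = 0, 2, ..., rotary_dim - 2.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    # One row of angles per position; torch.outer is equivalent to einsum("i,j -> ij", t, inv_freq).
    t = torch.arange(max_positions, dtype=torch.float)
    freqs = torch.outer(t, inv_freq)
    # cos and sin stored side by side: shape [max_positions, rotary_dim].
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)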
@@ -47,8 +43,7 @@ class RotaryEmbedding(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        positions = positions.flatten()
-        num_tokens = positions.shape[0]
+        num_tokens = positions.size(0)
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
         query_shape = query.shape
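
Downstream, each token's row of the cache supplies its rotation. A hedged sketch of how the looked-up cos/sin halves are typically applied to the reshaped query/key (apply_rotary and the shapes are assumptions, not this file's actual helper):

import torch

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, num_heads, head_size]; cos/sin: [num_tokens, head_size // 2].
    x1, x2 = x.float().chunk(2, dim=-1)  # split each head's channels into rotation pairs
    cos = cos.unsqueeze(-2)              # broadcast over the heads dimension
    sin = sin.unsqueeze(-2)
    # Rotate each (x1, x2) pair by the cached angle for its position.
    out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return out.to(x.dtype)

E.g. query = apply_rotary(query.view(num_tokens, -1, self.head_size), cos, sin) after the cache lookup above, and likewise for key; note the exact pairing convention (half-split vs. interleaved) varies between implementations.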