simplify
@@ -21,7 +21,6 @@ class VocabParallelEmbedding(nn.Module):
         self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
         self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
         self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
-        self.embedding_dim = embedding_dim
         self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
         self.weight.weight_loader = self.weight_loader
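Note: the dropped self.embedding_dim was only ever read locally when allocating the weight shard, so nothing downstream loses it. A minimal single-process sketch of the partition arithmetic this hunk keeps, with illustrative values (the masking step belongs to the forward pass, which is not shown in this hunk):

import torch

# Illustrative stand-ins for the attributes above (assume tp_size=2, tp_rank=1).
num_embeddings, embedding_dim, tp_size, tp_rank = 8, 4, 2, 1
num_embeddings_per_partition = num_embeddings // tp_size          # 4 rows per rank
vocab_start_idx = num_embeddings_per_partition * tp_rank          # 4
vocab_end_idx = vocab_start_idx + num_embeddings_per_partition    # 8

weight = torch.randn(num_embeddings_per_partition, embedding_dim)  # this rank's shard

x = torch.tensor([1, 5, 7])                                        # global token ids
mask = (x >= vocab_start_idx) & (x < vocab_end_idx)
local = (x - vocab_start_idx).clamp(0, num_embeddings_per_partition - 1)
y = weight[local] * mask.unsqueeze(-1)   # rows owned by other ranks come out as zeros
# across ranks, the partial embeddings are then summed (an all-reduce in the real module)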
@@ -64,12 +64,6 @@ class ColumnParallelLinear(LinearBase):
         super().__init__(input_size, output_size, 0)
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
-        self.output_partition_sizes = [self.output_size_per_partition]
-        if hasattr(self, "output_sizes"):
-            self.output_partition_sizes = [
-                divide(output_size, self.tp_size)
-                for output_size in self.output_sizes
-            ]

         self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
         self.weight.weight_loader = self.weight_loader
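For reference, a single-process sketch of what the [output_size_per_partition, input_size] shard buys: each rank computes a slice of the output features, and concatenating the per-rank results reproduces the full linear layer (toy sizes, no actual distributed calls):

import torch

input_size, output_size, tp_size = 6, 8, 2
full_w = torch.randn(output_size, input_size)
x = torch.randn(3, input_size)

shards = full_w.chunk(tp_size, dim=0)        # each shard: [output_size // tp_size, input_size]
per_rank = [x @ w.t() for w in shards]       # what each rank computes locally
assert torch.allclose(torch.cat(per_rank, dim=-1), x @ full_w.t())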
@@ -122,23 +116,14 @@ class QKVParallelLinear(ColumnParallelLinear):
         total_num_kv_heads: int | None = None,
         bias: bool = False,
     ):
-        self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
-        if total_num_kv_heads is None:
-            total_num_kv_heads = total_num_heads
-        self.total_num_kv_heads = total_num_kv_heads
+        self.total_num_kv_heads = total_num_kv_heads or total_num_heads
         tp_size = dist.get_world_size()
         self.num_heads = divide(self.total_num_heads, tp_size)
         self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
-        input_size = self.hidden_size
-        output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
-        self.output_sizes = [
-            self.num_heads * self.head_size * tp_size,    # q_proj
-            self.num_kv_heads * self.head_size * tp_size, # k_proj
-            self.num_kv_heads * self.head_size * tp_size, # v_proj
-        ]
-
+        input_size = hidden_size
+        output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
         super().__init__(input_size, output_size, bias)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
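The rewritten output_size is equivalent to the old per-rank formula whenever the head counts divide evenly by tp_size, which divide() already enforces. A quick check:

def old_output_size(total_heads, total_kv_heads, head_size, tp_size):
    num_heads = total_heads // tp_size
    num_kv_heads = total_kv_heads // tp_size
    return (num_heads + 2 * num_kv_heads) * tp_size * head_size

def new_output_size(total_heads, total_kv_heads, head_size, tp_size):
    return (total_heads + 2 * total_kv_heads) * head_size

# e.g. 32 query heads, 8 KV heads, head_size 128, tp_size 4
assert old_output_size(32, 8, 128, 4) == new_output_size(32, 8, 128, 4) == (32 + 16) * 128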
@@ -170,7 +155,6 @@ class RowParallelLinear(LinearBase):
         super().__init__(input_size, output_size, 1)
         self.input_size_per_partition = divide(input_size, self.tp_size)
         self.output_size_per_partition = output_size
-        self.output_partition_sizes = [output_size]

         self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
         self.weight.weight_loader = self.weight_loader
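And the mirror-image sketch for the row-parallel case: the input dimension is sharded, each rank computes a partial product, and the partials are summed (an all-reduce in the real layer):

import torch

input_size, output_size, tp_size = 8, 5, 2
full_w = torch.randn(output_size, input_size)
x = torch.randn(3, input_size)

w_shards = full_w.chunk(tp_size, dim=1)      # each: [output_size, input_size // tp_size]
x_shards = x.chunk(tp_size, dim=-1)
partials = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]
assert torch.allclose(sum(partials), x @ full_w.t(), atol=1e-6)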
@@ -1,5 +1,4 @@
 from functools import lru_cache
 
 import torch
 from torch import nn
 
@@ -28,12 +27,9 @@ class RotaryEmbedding(nn.Module):
     ) -> None:
         super().__init__()
         self.head_size = head_size
-        self.rotary_dim = rotary_dim
         assert rotary_dim == head_size
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
-        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
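Dropping self.rotary_dim, self.max_position_embeddings, and self.base works because they are consumed only here while building the cache. The cache construction, reproduced standalone (the final cat into cos_sin_cache is an assumption inferred from the chunk(2, dim=-1) in forward below):

import torch

head_size = rotary_dim = 8
max_position_embeddings, base = 16, 10000.0

inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)          # [positions, rotary_dim // 2]
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
assert cos_sin_cache.shape == (max_position_embeddings, rotary_dim)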
@@ -47,8 +43,7 @@ class RotaryEmbedding(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        positions = positions.flatten()
-        num_tokens = positions.shape[0]
+        num_tokens = positions.size(0)
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
         query_shape = query.shape
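positions.size(0) and positions.shape[0] are interchangeable; the flatten() only mattered if positions could arrive multi-dimensional. For context, a hypothetical split-half application of the looked-up cos/sin to one head (this mirrors common RoPE implementations; the module's actual apply function is not part of this diff):

import torch

def rotate_half_apply(x, cos, sin):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

head_size = 8
cos, sin = torch.randn(head_size // 2), torch.randn(head_size // 2)
q = torch.randn(4, head_size)                  # 4 tokens, one head
print(rotate_half_apply(q, cos, sin).shape)    # torch.Size([4, 8])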