This commit is contained in:
GeeeekExplorer
2025-08-31 19:44:57 +08:00
parent 6a6d217de7
commit df99418f7d
11 changed files with 47 additions and 96 deletions

View File

@@ -15,14 +15,20 @@ class LinearBase(nn.Module):
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_dim: int | None = None,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.tp_dim = tp_dim
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@@ -36,14 +42,7 @@ class ReplicatedLinear(LinearBase):
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size)
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
super().__init__(input_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)
@@ -60,17 +59,8 @@ class ColumnParallelLinear(LinearBase):
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size, 0)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
tp_size = dist.get_world_size()
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
@@ -92,7 +82,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias: bool = False,
):
self.output_sizes = output_sizes
super().__init__(input_size, sum(output_sizes), bias=bias)
super().__init__(input_size, sum(output_sizes), bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
param_data = param.data
@@ -113,15 +103,13 @@ class QKVParallelLinear(ColumnParallelLinear):
total_num_kv_heads: int | None = None,
bias: bool = False,
):
self.head_size = head_size
self.total_num_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads or total_num_heads
tp_size = dist.get_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
input_size = hidden_size
output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
super().__init__(input_size, output_size, bias)
total_num_kv_heads = total_num_kv_heads or total_num_heads
self.head_size = head_size
self.num_heads = divide(total_num_heads, tp_size)
self.num_kv_heads = divide(total_num_kv_heads, tp_size)
output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
super().__init__(hidden_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
param_data = param.data
@@ -148,17 +136,8 @@ class RowParallelLinear(LinearBase):
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size, 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
tp_size = dist.get_world_size()
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data