This commit is contained in:
GeeeekExplorer
2025-06-21 17:04:53 +08:00
parent ad4e95fbdc
commit cde3fc22c2
9 changed files with 42 additions and 100 deletions

View File

@@ -64,12 +64,6 @@ class ColumnParallelLinear(LinearBase):
super().__init__(input_size, output_size, 0)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
self.weight.weight_loader = self.weight_loader
@@ -122,23 +116,14 @@ class QKVParallelLinear(ColumnParallelLinear):
total_num_kv_heads: int | None = None,
bias: bool = False,
):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
self.total_num_kv_heads = total_num_kv_heads or total_num_heads
tp_size = dist.get_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
input_size = self.hidden_size
output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
input_size = hidden_size
output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
super().__init__(input_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
@@ -170,7 +155,6 @@ class RowParallelLinear(LinearBase):
super().__init__(input_size, output_size, 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
self.weight.weight_loader = self.weight_loader