import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist


def divide(numerator, denominator):
    assert numerator % denominator == 0
    return numerator // denominator


class LinearBase(nn.Module):
    """Common bookkeeping for the linear layers below.

    tp_dim is the weight dimension that would be sharded across tensor-parallel
    ranks (0 = output dim, 1 = input dim). tp_rank / tp_size are stubbed out for
    a single rank; the commented calls show where the real values come from.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        tp_dim: int | None = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.tp_dim = tp_dim
        self.tp_rank = 0  # get_tensor_model_parallel_rank()
        self.tp_size = 1  # get_tensor_model_parallel_world_size()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Linear layer whose full weight is kept on every rank."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size)
        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


class ColumnParallelLinear(LinearBase):
    """Linear layer sharded along the output dimension (dim 0 of the weight).

    Each rank holds output_size / tp_size output rows; the weight_loader copies
    this rank's slice out of the full checkpoint tensor.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 0)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use the output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]

        self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Several column-parallel projections fused into one weight (e.g. gate_proj
    and up_proj). The weight_loader copies one sub-projection at a time, selected
    by loaded_shard_id, into the matching slice of the fused parameter.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        self.output_sizes = output_sizes
        tp_size = 1  # get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size, sum(output_sizes), bias=bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None):
        param_data = param.data
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        # Needed once tp_size > 1, to pick this rank's slice of the sub-projection:
        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)
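

# Illustrative sketch, not part of the original module: how a checkpoint loader
# would typically drive MergedColumnParallelLinear. The fused "gate_up_proj"
# layout and all sizes below are assumptions made up for the example.
def _example_merged_weight_loading():
    hidden_size, intermediate_size = 16, 32
    gate_up_proj = MergedColumnParallelLinear(hidden_size, [intermediate_size, intermediate_size])
    # Pretend these two tensors came out of a checkpoint's state dict.
    gate_weight = torch.randn(intermediate_size, hidden_size)
    up_weight = torch.randn(intermediate_size, hidden_size)
    # The loader hangs off the parameter itself; the shard id picks the slice.
    loader = gate_up_proj.weight.weight_loader
    loader(gate_up_proj.weight, gate_weight, 0)  # gate_proj -> rows [0, intermediate_size)
    loader(gate_up_proj.weight, up_weight, 1)    # up_proj   -> rows [intermediate_size, 2 * intermediate_size)
    assert torch.equal(gate_up_proj.weight.data[:intermediate_size], gate_weight)
    assert torch.equal(gate_up_proj.weight.data[intermediate_size:], up_weight)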


class QKVParallelLinear(ColumnParallelLinear):
    """Fused Q/K/V projection. The weight_loader drops the q, k or v checkpoint
    tensor into its offset inside the fused, column-parallel weight.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the heads across tensor-parallel ranks.
        tp_size = 1  # get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
        input_size = self.hidden_size
        output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,     # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]
        super().__init__(input_size, output_size, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None):
        param_data = param.data
        assert loaded_shard_id in ["q", "k", "v"]
        if loaded_shard_id == "q":
            shard_size = self.num_heads * self.head_size
            shard_offset = 0
        elif loaded_shard_id == "k":
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size
        else:
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        # Needed once tp_size > 1, to pick this rank's slice of the q/k/v tensor:
        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):
    """Linear layer sharded along the input dimension (dim 1 of the weight).

    Each rank computes a partial matmul over its input slice; the partial
    results are summed with an all-reduce when tp_size > 1.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add the bias on rank 0 only so it is not summed tp_size times below.
        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
        if self.tp_size > 1:
            dist.all_reduce(y)
        return y
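

# Minimal single-process sanity check, added as an illustrative sketch (not part
# of the original module). With tp_size stubbed to 1 above, the parallel layers
# reduce to plain linear layers, so the QKV weight loader and the
# column-parallel -> row-parallel pattern can be exercised without
# torch.distributed. All sizes below are made up for the example.
if __name__ == "__main__":
    hidden_size, head_size, num_heads, num_kv_heads = 32, 8, 4, 2

    qkv_proj = QKVParallelLinear(hidden_size, head_size, num_heads, num_kv_heads)
    o_proj = RowParallelLinear(num_heads * head_size, hidden_size)

    # Pretend these tensors came from a checkpoint with separate q/k/v projections.
    q_w = torch.randn(num_heads * head_size, hidden_size)
    k_w = torch.randn(num_kv_heads * head_size, hidden_size)
    v_w = torch.randn(num_kv_heads * head_size, hidden_size)
    qkv_proj.weight.weight_loader(qkv_proj.weight, q_w, "q")
    qkv_proj.weight.weight_loader(qkv_proj.weight, k_w, "k")
    qkv_proj.weight.weight_loader(qkv_proj.weight, v_w, "v")
    o_proj.weight.weight_loader(o_proj.weight, torch.randn(hidden_size, num_heads * head_size))

    x = torch.randn(2, hidden_size)
    q, k, v = qkv_proj(x).split(
        [num_heads * head_size, num_kv_heads * head_size, num_kv_heads * head_size],
        dim=-1,
    )
    out = o_proj(q)
    # Expected: q (2, 32), k (2, 16), v (2, 16), out (2, 32); the all-reduce in
    # RowParallelLinear.forward is skipped because tp_size == 1.
    print(q.shape, k.shape, v.shape, out.shape)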