support tensor parallel

This commit is contained in:
cheunglei
2025-06-15 01:31:24 +08:00
parent b6136383c9
commit 53b3ef2e32
9 changed files with 102 additions and 31 deletions

View File

@@ -14,8 +14,8 @@ class VocabParallelEmbedding(nn.Module):
         embedding_dim: int,
     ):
         super().__init__()
-        self.tp_rank = 0 # get_tensor_model_parallel_rank()
-        self.tp_size = 1 # get_tensor_model_parallel_world_size()
+        self.tp_rank = dist.get_rank()
+        self.tp_size = dist.get_world_size()
         assert num_embeddings % self.tp_size == 0
         self.num_embeddings = num_embeddings
         self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
@@ -39,7 +39,7 @@ class VocabParallelEmbedding(nn.Module):
             x = mask * (x - self.vocab_start_idx)
         y = F.embedding(x, self.weight)
         if self.tp_size > 1:
-            y = mask * y
+            y = mask.unsqueeze(1) * y
             dist.all_reduce(y)
         return y
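
Note: the masking-plus-all_reduce pattern in this hunk can be read in isolation. A minimal sketch, assuming the default torch.distributed group is already initialized; the function and argument names below are illustrative, not the module's API:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def vocab_parallel_embedding(x, weight, vocab_start_idx, vocab_end_idx, tp_size):
    # weight holds only the rows [vocab_start_idx, vocab_end_idx) on this rank.
    if tp_size > 1:
        mask = (x >= vocab_start_idx) & (x < vocab_end_idx)
        x = mask * (x - vocab_start_idx)   # out-of-range ids fall back to local row 0
    y = F.embedding(x, weight)             # partial lookup against the local vocab shard
    if tp_size > 1:
        y = mask.unsqueeze(1) * y          # zero out rows owned by other ranks
        dist.all_reduce(y)                 # sum of partials = full embedding
    return y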
@@ -65,8 +65,8 @@ class ParallelLMHead(VocabParallelEmbedding):
             last_indices = context.cu_seqlens_q[1:] - 1
             x = x[last_indices].contiguous()
         logits = F.linear(x, self.weight, self.bias)
-        # if self.tp_size > 1:
-        # all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)]
-        # dist.gather(logits, all_logits, 0)
-        # logits = torch.cat(all_logits, -1)
-        return logits if self.tp_rank == 0 else None
+        if self.tp_size > 1:
+            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
+            dist.gather(logits, all_logits, 0)
+            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
+        return logits
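
The new gather assembles the vocab-sharded logits only on rank 0. A self-contained sketch of the same pattern, assuming an initialized default process group; names are illustrative:

import torch
import torch.distributed as dist

def gather_vocab_shards(logits, tp_rank, tp_size):
    # Each rank computes logits for its own vocab slice; rank 0 concatenates the full row.
    if tp_size > 1:
        all_logits = [torch.empty_like(logits) for _ in range(tp_size)] if tp_rank == 0 else None
        dist.gather(logits, all_logits, 0)   # dst rank 0 receives every shard
        logits = torch.cat(all_logits, -1) if tp_rank == 0 else None
    return logits                            # full logits on rank 0, None elsewhere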

View File

@@ -21,8 +21,8 @@ class LinearBase(nn.Module):
         self.input_size = input_size
         self.output_size = output_size
         self.tp_dim = tp_dim
-        self.tp_rank = 0 # get_tensor_model_parallel_rank()
-        self.tp_size = 1 # get_tensor_model_parallel_world_size()
+        self.tp_rank = dist.get_rank()
+        self.tp_size = dist.get_world_size()
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
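
Since tp_rank and tp_size now come from the default process group instead of hard-coded constants, that group has to exist before any of these layers are constructed. A setup sketch, not the project's launcher; the function name and rendezvous address are made up for illustration, only standard torch.distributed calls are used:

import torch
import torch.distributed as dist

def init_tensor_parallel(rank: int, world_size: int):
    # Create the default group that dist.get_rank()/dist.get_world_size() read from.
    dist.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        init_method="tcp://localhost:29500",  # example rendezvous address
        rank=rank,
        world_size=world_size,
    )
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)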
@@ -65,7 +65,6 @@ class ColumnParallelLinear(LinearBase):
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
-        # If QKV or MergedColumn, use output size of each partition.
         if hasattr(self, "output_sizes"):
             self.output_partition_sizes = [
                 divide(output_size, self.tp_size)
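
For context, divide() is assumed here to be the usual exact-division helper, so each rank owns output_size // tp_size output columns. A small sketch of that assumption:

def divide(numerator: int, denominator: int) -> int:
    # Assumed behaviour of the helper used above: exact division, loud failure otherwise.
    assert numerator % denominator == 0, "size must be divisible by tp_size"
    return numerator // denominator

# e.g. output_size = 4096, tp_size = 4  ->  each rank holds a (1024, input_size) weight shard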
@@ -101,8 +100,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         bias: bool = False,
     ):
         self.output_sizes = output_sizes
-        tp_size = 1 # get_tensor_model_parallel_world_size()
-        assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias=bias)
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
@@ -110,7 +107,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
         shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
-        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
+        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
         assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)
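
The chunk()[tp_rank] slice replaces the commented-out narrow(); for evenly divisible dimensions the two are equivalent, as this toy check (not project code) shows:

import torch

full = torch.arange(24, dtype=torch.float32).reshape(6, 4)
tp_size, tp_rank, dim = 2, 1, 0
shard = full.size(dim) // tp_size
# chunk() splits the dim into tp_size equal pieces; index tp_rank picks this rank's shard.
assert torch.equal(full.chunk(tp_size, dim)[tp_rank],
                   full.narrow(dim, tp_rank * shard, shard))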
@@ -131,8 +128,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         if total_num_kv_heads is None:
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
-        # Divide the weight matrix along the last dimension.
-        tp_size = 1 # get_tensor_model_parallel_world_size()
+        tp_size = dist.get_world_size()
         self.num_heads = divide(self.total_num_heads, tp_size)
         self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
         input_size = self.hidden_size
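
A worked example of the per-rank head split introduced here, using hypothetical sizes (not taken from the repo's configs):

# Hypothetical model: 32 query heads, 8 KV heads, head_size 128, tp_size 4.
total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4
num_heads = total_num_heads // tp_size            # 8 query heads per rank
num_kv_heads = total_num_kv_heads // tp_size      # 2 KV heads per rank
# Local fused QKV output width: (8 + 2 + 2) * 128 = 1536 columns on each rank.
local_qkv_width = (num_heads + 2 * num_kv_heads) * head_size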
@@ -158,7 +154,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_size = self.num_kv_heads * self.head_size
             shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
-        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
+        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
         assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)