support tensor parallel
This commit is contained in:
@@ -14,8 +14,8 @@ class VocabParallelEmbedding(nn.Module):
|
||||
embedding_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_rank = 0 # get_tensor_model_parallel_rank()
|
||||
self.tp_size = 1 # get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = dist.get_rank()
|
||||
self.tp_size = dist.get_world_size()
|
||||
assert num_embeddings % self.tp_size == 0
|
||||
self.num_embeddings = num_embeddings
|
||||
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
|
||||
@@ -39,7 +39,7 @@ class VocabParallelEmbedding(nn.Module):
|
||||
x = mask * (x - self.vocab_start_idx)
|
||||
y = F.embedding(x, self.weight)
|
||||
if self.tp_size > 1:
|
||||
y = mask * y
|
||||
y = mask.unsqueeze(1) * y
|
||||
dist.all_reduce(y)
|
||||
return y
|
||||
|
||||
@@ -65,8 +65,8 @@ class ParallelLMHead(VocabParallelEmbedding):
|
||||
last_indices = context.cu_seqlens_q[1:] - 1
|
||||
x = x[last_indices].contiguous()
|
||||
logits = F.linear(x, self.weight, self.bias)
|
||||
# if self.tp_size > 1:
|
||||
# all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)]
|
||||
# dist.gather(logits, all_logits, 0)
|
||||
# logits = torch.cat(all_logits, -1)
|
||||
return logits if self.tp_rank == 0 else None
|
||||
if self.tp_size > 1:
|
||||
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
|
||||
dist.gather(logits, all_logits, 0)
|
||||
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
|
||||
return logits
|
||||
Reference in New Issue
Block a user