better
@@ -69,4 +69,4 @@ class ParallelLMHead(VocabParallelEmbedding):
             all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
             dist.gather(logits, all_logits, 0)
             logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
-            return logits
+        return logits
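A minimal, runnable sketch of the rank-0 gather pattern in the hunk above: each tensor-parallel rank computes logits for its own vocabulary shard, dist.gather collects the shards on rank 0 only, and rank 0 concatenates them along the vocab dimension while the other ranks return None. Using gather rather than all_gather means only the rank that runs the sampler pays for the full logits tensor. The gloo backend, world size of 2, and toy shapes here are illustrative assumptions, not taken from the commit.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank holds logits for its own vocab shard (hypothetical shapes:
    # 2 tokens, vocab_size // tp_size == 8 entries per shard).
    logits = torch.full((2, 8), float(rank))

    # Destination buffers exist only on the gathering rank, as in the diff.
    all_logits = [torch.empty_like(logits) for _ in range(world_size)] if rank == 0 else None
    dist.gather(logits, all_logits, dst=0)

    # Only rank 0 reassembles the full vocabulary; other ranks return None.
    logits = torch.cat(all_logits, dim=-1) if rank == 0 else None
    if rank == 0:
        print(logits.shape)  # torch.Size([2, 16])
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)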