This commit is contained in:
GeeeekExplorer
2025-06-15 10:31:48 +08:00
parent c1fd4ea3c2
commit fc778a4da9
10 changed files with 19 additions and 22 deletions

View File

@@ -69,4 +69,4 @@ class ParallelLMHead(VocabParallelEmbedding):
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
dist.gather(logits, all_logits, 0)
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
return logits
return logits