This commit is contained in:
GeeeekExplorer
2025-08-31 19:44:57 +08:00
parent 6a6d217de7
commit df99418f7d
11 changed files with 47 additions and 96 deletions

View File

@@ -8,11 +8,9 @@ class Sampler(nn.Module):
super().__init__()
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.to(torch.float)
logits = logits.float()
greedy_tokens = logits.argmax(dim=-1)
logits.div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
epsilon = 1e-10
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + epsilon).argmax(dim=-1)
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + 1e-10).argmax(dim=-1)
return torch.where(temperatures == 0, greedy_tokens, sample_tokens)