Fix: Division-by-Zero Risk and Typo

This commit is contained in:
xiaohajiayou
2025-06-24 02:02:33 +08:00
parent 03cfc13bb3
commit 054aec852d
2 changed files with 4 additions and 3 deletions

View File

@@ -13,5 +13,6 @@ class Sampler(nn.Module):
logits.div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
epsilon = 1e-10
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + epsilon).argmax(dim=-1)
return torch.where(temperatures == 0, greedy_tokens, sample_tokens)