This commit is contained in:
GeeeekExplorer
2025-06-10 08:52:58 +08:00
parent a5a4909e6a
commit b98e1ca305
10 changed files with 39 additions and 26 deletions

View File

@@ -7,11 +7,11 @@ class Sampler(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.to(torch.float)
if temperatures is not None:
logits.div_(temperatures.unsqueeze(dim=1))
greedy_tokens = logits.argmax(dim=-1)
logits.div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
return sampled_tokens
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
return torch.where(temperatures == 0, greedy_tokens, sample_tokens)