17 lines
559 B
Python
17 lines
559 B
Python
import torch
|
|
from torch import nn
|
|
|
|
|
|
class Sampler(nn.Module):
    """Pick one token id per row of logits, greedily or by temperature sampling.

    Rows whose temperature equals 0 take the argmax of the raw logits; every
    other row is sampled from ``softmax(logits / temperature)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
        """Return the selected token index for each row of ``logits``.

        Args:
            logits: Unnormalised scores; the last dimension is the one that
                is reduced by argmax/softmax. The caller's tensor is never
                modified.
            temperatures: One temperature per row (it is broadcast against
                ``logits`` via ``unsqueeze(dim=1)``, so ``logits`` is
                presumably ``(batch, vocab)`` -- confirm against the caller).
                A value of exactly 0 selects greedy decoding for that row.

        Returns:
            Tensor of token indices, one per row.
        """
        # Always work on a private float32 copy: ``Tensor.float()`` returns
        # the input itself when it is already float32, so dividing in place
        # afterwards would silently overwrite the caller's logits.
        scaled = logits.to(torch.float, copy=True)
        greedy_tokens = scaled.argmax(dim=-1)

        # Greedy rows get a divisor of 1 so they do not turn into inf/NaN;
        # their sampled result is discarded by the final ``where`` anyway.
        greedy_rows = temperatures == 0
        divisors = temperatures.masked_fill(greedy_rows, 1)
        scaled.div_(divisors.unsqueeze(dim=1))

        probs = torch.softmax(scaled, dim=-1, dtype=torch.float)
        # Dividing the probabilities by i.i.d. Exponential(1) noise and
        # taking the argmax draws an index in proportion to ``probs``; the
        # small epsilon keeps the divisor strictly positive.
        noise = torch.empty_like(probs).exponential_(1) + 1e-10
        sample_tokens = probs.div_(noise).argmax(dim=-1)

        return torch.where(greedy_rows, greedy_tokens, sample_tokens)
|