init commit

This commit is contained in:
GeeeekExplorer
2025-06-10 00:23:23 +08:00
commit a5a4909e6a
26 changed files with 1677 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
import torch
from torch import nn
class Sampler(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
logits = logits.to(torch.float)
if temperatures is not None:
logits.div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
return sampled_tokens