compile random sampling

This commit is contained in:
GeeeekExplorer
2025-08-31 22:55:34 +08:00
parent df99418f7d
commit 6ef2a4f630
2 changed files with 8 additions and 6 deletions

View File

@@ -7,10 +7,9 @@ class Sampler(nn.Module):
def __init__(self):
super().__init__()
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.float()
greedy_tokens = logits.argmax(dim=-1)
logits.div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + 1e-10).argmax(dim=-1)
return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1)
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
return sample_tokens