Files
nano-vllm/nanovllm/layers/sampler.py
2025-06-24 02:02:33 +08:00

19 lines
669 B
Python

import torch
from torch import nn
class Sampler(nn.Module):
    """Choose one next token per row from a batch of logits.

    Rows whose temperature is exactly 0 are decoded greedily (argmax of the
    raw logits); every other row is sampled from
    ``softmax(logits / temperature)``. The module has no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
        """Sample (or greedily pick) one token id per row.

        Args:
            logits: Scores, expected 2-D ``(num_rows, vocab_size)`` since the
                per-row temperature is broadcast via ``unsqueeze(dim=1)``.
                The caller's tensor is left unmodified.
            temperatures: One temperature per row, shape ``(num_rows,)``.
                A value of 0 selects greedy decoding for that row.

        Returns:
            int64 tensor of shape ``(num_rows,)`` with the chosen token ids.
        """
        # NOTE: `.to` returns `logits` itself when it is already float32, so
        # nothing below may operate in-place on `logits` or the caller's
        # tensor would be overwritten.
        logits = logits.to(torch.float)
        greedy_tokens = logits.argmax(dim=-1)

        # Greedy rows still flow through the sampling path (their result is
        # discarded by the final `where`); give them a harmless temperature
        # of 1 so the division cannot produce inf/NaN.
        is_greedy = temperatures == 0
        safe_temperatures = temperatures.masked_fill(is_greedy, 1.0)
        scaled = logits / safe_temperatures.unsqueeze(dim=1)
        probs = torch.softmax(scaled, dim=-1, dtype=torch.float)

        # Exponential-race sampling (equivalent to Gumbel-max): argmax of
        # p / E with E ~ Exp(1) is a draw from Categorical(p). The epsilon
        # keeps the divisor strictly positive.
        epsilon = 1e-10
        noise = torch.empty_like(probs).exponential_(1) + epsilon
        sample_tokens = probs.div_(noise).argmax(dim=-1)
        return torch.where(is_greedy, greedy_tokens, sample_tokens)