Files
nano-vllm/nanovllm/layers/sampler.py
GeeeekExplorer b98e1ca305 fix
2025-06-10 21:25:54 +08:00

18 lines
631 B
Python

import torch
from torch import nn
class Sampler(nn.Module):
    """Pick one token id per row from a batch of logits.

    Rows whose temperature is exactly 0 are decoded greedily (argmax of the
    raw logits); every other row is sampled from
    ``softmax(logits / temperature)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
        """Return the chosen token id for each row.

        Args:
            logits: 2-D scores, one row per sequence (rows are scaled by
                ``temperatures.unsqueeze(1)``, so presumably
                ``(num_seqs, vocab_size)``). Never modified.
            temperatures: 1-D, one temperature per row; 0 means greedy.

        Returns:
            1-D int64 tensor of token indices, one per row.
        """
        # Always copy: `.to(torch.float)` returns the *same* tensor when the
        # input is already float32, so an in-place op below would otherwise
        # silently corrupt the caller's logits. For bf16/fp16 input the dtype
        # conversion is the single copy either way.
        logits = logits.to(torch.float, copy=True)
        greedy_tokens = logits.argmax(dim=-1)

        is_greedy = temperatures == 0
        # Greedy rows divide by 1 instead of 0, avoiding inf/NaN in the
        # softmax; their sampled value is discarded by torch.where anyway.
        safe_temps = temperatures.masked_fill(is_greedy, 1.0)
        logits.div_(safe_temps.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1)

        # Exponential-race (Gumbel-max equivalent) trick: with E_i ~ Exp(1)
        # i.i.d., argmax_i(p_i / E_i) returns index i with probability p_i,
        # i.e. a categorical sample, without a multinomial call.
        noise = torch.empty_like(probs).exponential_(1)
        sample_tokens = probs.div_(noise).argmax(dim=-1)

        return torch.where(is_greedy, greedy_tokens, sample_tokens)