diff --git a/nanovllm/layers/sampler.py b/nanovllm/layers/sampler.py
index a5e7ddc..b101018 100644
--- a/nanovllm/layers/sampler.py
+++ b/nanovllm/layers/sampler.py
@@ -7,10 +7,9 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
 
+    @torch.compile
     def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
-        logits = logits.float()
-        greedy_tokens = logits.argmax(dim=-1)
-        logits.div_(temperatures.unsqueeze(dim=1))
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + 1e-10).argmax(dim=-1)
-        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
+        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
+        probs = torch.softmax(logits, dim=-1)
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
+        return sample_tokens
diff --git a/nanovllm/sampling_params.py b/nanovllm/sampling_params.py
index 67d60bb..c9f872f 100644
--- a/nanovllm/sampling_params.py
+++ b/nanovllm/sampling_params.py
@@ -6,3 +6,6 @@ class SamplingParams:
     temperature: float = 1.0
     max_tokens: int = 64
     ignore_eos: bool = False
+
+    def __post_init__(self):
+        assert self.temperature > 1e-10, "greedy sampling is not permitted"