From 6ef2a4f630ab162a7855dce500b86adff2a4465c Mon Sep 17 00:00:00 2001
From: GeeeekExplorer <2651904866@qq.com>
Date: Sun, 31 Aug 2025 22:55:34 +0800
Subject: [PATCH] compile random sampling

---
Review notes (not part of the commit message):
* Sampler.forward is now wrapped in @torch.compile.
* forward() no longer has a greedy path: the
  torch.where(temperatures == 0, greedy_tokens, sample_tokens) selection is
  removed, so every row is drawn by the exponential-noise argmax.
* The exponential noise is now clamped with clamp_min_(1e-10) instead of
  being offset by "+ 1e-10".
* SamplingParams.__post_init__ rejects temperature <= 1e-10 through an
  assert. assert statements are skipped under "python -O", so in optimized
  mode the check is not enforced and forward() would divide the logits by
  the given temperature unguarded. Consider raising ValueError instead.

 nanovllm/layers/sampler.py  | 11 +++++------
 nanovllm/sampling_params.py |  3 +++
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/nanovllm/layers/sampler.py b/nanovllm/layers/sampler.py
index a5e7ddc..b101018 100644
--- a/nanovllm/layers/sampler.py
+++ b/nanovllm/layers/sampler.py
@@ -7,10 +7,9 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
 
+    @torch.compile
     def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
-        logits = logits.float()
-        greedy_tokens = logits.argmax(dim=-1)
-        logits.div_(temperatures.unsqueeze(dim=1))
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + 1e-10).argmax(dim=-1)
-        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
+        logits = logits.float().div_(temperatures.unsqueeze(dim=1))
+        probs = torch.softmax(logits, dim=-1)
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
+        return sample_tokens
diff --git a/nanovllm/sampling_params.py b/nanovllm/sampling_params.py
index 67d60bb..c9f872f 100644
--- a/nanovllm/sampling_params.py
+++ b/nanovllm/sampling_params.py
@@ -6,3 +6,6 @@ class SamplingParams:
     temperature: float = 1.0
     max_tokens: int = 64
     ignore_eos: bool = False
+
+    def __post_init__(self):
+        assert self.temperature > 1e-10, "greedy sampling is not permitted"