compile random sampling
This commit is contained in:
@@ -7,10 +7,9 @@ class Sampler(nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@torch.compile
|
||||||
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
|
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
|
||||||
logits = logits.float()
|
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
||||||
greedy_tokens = logits.argmax(dim=-1)
|
probs = torch.softmax(logits, dim=-1)
|
||||||
logits.div_(temperatures.unsqueeze(dim=1))
|
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
|
||||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
return sample_tokens
|
||||||
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + 1e-10).argmax(dim=-1)
|
|
||||||
return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
|
|
||||||
|
|||||||
@@ -6,3 +6,6 @@ class SamplingParams:
|
|||||||
temperature: float = 1.0
|
temperature: float = 1.0
|
||||||
max_tokens: int = 64
|
max_tokens: int = 64
|
||||||
ignore_eos: bool = False
|
ignore_eos: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
assert self.temperature > 1e-10, "greedy sampling is not permitted"
|
||||||
|
|||||||
Reference in New Issue
Block a user