[feat] Added bench_offload.py and GreedySampler.
This commit is contained in:
@@ -13,3 +13,13 @@ class Sampler(nn.Module):
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
|
||||
return sample_tokens
|
||||
|
||||
|
||||
class GreedySampler(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@torch.compile
|
||||
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor = None):
|
||||
return logits.argmax(dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user