[feat] Added bench_offload.py and GreedySampler.

This commit is contained in:
Zijie Tian
2025-12-12 00:24:08 +08:00
parent 0bd7ba7536
commit 60d24f7c12
4 changed files with 37 additions and 8 deletions

View File

@@ -13,3 +13,13 @@ class Sampler(nn.Module):
probs = torch.softmax(logits, dim=-1)
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
return sample_tokens
class GreedySampler(nn.Module):
def __init__(self):
super().__init__()
@torch.compile
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor = None):
return logits.argmax(dim=-1)