Merge pull request #39 from xiaohajiayou/main

This commit is contained in:
Xingkai Yu
2025-06-24 22:51:58 +08:00
committed by GitHub
2 changed files with 4 additions and 3 deletions

View File

@@ -15,8 +15,8 @@ from nanovllm.engine.model_runner import ModelRunner
 class LLMEngine:
     def __init__(self, model, **kwargs):
-        config_fileds = {field.name for field in fields(Config)}
-        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fileds}
+        config_fields = {field.name for field in fields(Config)}
+        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
         config = Config(model, **config_kwargs)
         self.ps = []
         self.events = []

View File

@@ -13,5 +13,6 @@ class Sampler(nn.Module):
         logits.div_(temperatures.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+        epsilon = 1e-10
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + epsilon).argmax(dim=-1)
         return torch.where(temperatures == 0, greedy_tokens, sample_tokens)