fix
@@ -4,7 +4,6 @@ import triton
 import triton.language as tl
 
 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
-# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
 
 
@@ -64,8 +63,8 @@ class Attention(nn.Module):
         context = get_context()
         k_cache = self.k_cache
         v_cache = self.v_cache
+        store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
         if context.is_prefill:
-            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
             if context.block_tables is None:    # normal prefill
                 cu_seqlens_k = context.cu_seqlens_k
                 seqused_k = None
@@ -79,7 +78,7 @@ class Attention(nn.Module):
                                        seqused_k=seqused_k, softmax_scale=self.scale,
                                        causal=True, block_table=context.block_tables)
         else:    # decode
-            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1),
+            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                         cache_seqlens=context.context_lens, block_table=context.block_tables,
                                         softmax_scale=self.scale, causal=True)
         o = o.view(-1, self.num_heads * self.head_dim)
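The attention change above writes the new key/value tensors into the paged KV cache with the store_kvcache Triton kernel for decode as well as prefill, so the decode path no longer passes k.unsqueeze(1)/v.unsqueeze(1) to flash_attn_with_kvcache (which would otherwise have to append them to the cache itself). A minimal pure-PyTorch sketch of what such a slot-mapping write amounts to, assuming the caches can be viewed as [num_slots, num_kv_heads, head_dim] and slot_mapping holds one flat slot index per token; store_kvcache_ref is a hypothetical reference helper, not the repo's Triton kernel:

import torch

def store_kvcache_ref(k: torch.Tensor, v: torch.Tensor,
                      k_cache: torch.Tensor, v_cache: torch.Tensor,
                      slot_mapping: torch.Tensor) -> None:
    # k, v:              [num_tokens, num_kv_heads, head_dim]
    # k_cache, v_cache:  [num_slots,  num_kv_heads, head_dim] (paged cache viewed flat)
    # slot_mapping:      [num_tokens] flat slot index assigned to each token
    # Scatter each token's key/value into its assigned cache slot.
    k_cache[slot_mapping] = k
    v_cache[slot_mapping] = v

With the current step's k/v already in the cache, flash_attn_with_kvcache only needs cache_seqlens and block_table to locate them during decode.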
@@ -7,11 +7,11 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
+    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
         logits = logits.to(torch.float)
-        if temperatures is not None:
-            logits.div_(temperatures.unsqueeze(dim=1))
+        greedy_tokens = logits.argmax(dim=-1)
+        logits.div_(temperatures.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
         # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-        sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
-        return sampled_tokens
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
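The sampler change makes temperatures a required per-sequence tensor: logits are scaled by temperature, turned into probabilities, and sampled with the exponential-race (Gumbel-max style) trick, while rows with temperature == 0 get the precomputed greedy argmax through torch.where, which also makes the division by a zero temperature harmless for those rows. A small self-contained check of the sampling trick, an illustration rather than code from the repo, assuming a fixed 3-way categorical distribution:

import torch

# Dividing probabilities by i.i.d. Exponential(1) noise and taking the argmax
# draws from the categorical distribution: P(argmax_i p_i / E_i == i) = p_i.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7]).expand(100_000, 3)
samples = probs.div(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
print(torch.bincount(samples, minlength=3) / samples.numel())  # ~ tensor([0.1, 0.2, 0.7])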