This commit is contained in:
GeeeekExplorer
2025-06-10 08:52:58 +08:00
parent a5a4909e6a
commit b98e1ca305
10 changed files with 39 additions and 26 deletions

View File

@@ -4,7 +4,6 @@ import triton
import triton.language as tl
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
@@ -64,8 +63,8 @@ class Attention(nn.Module):
context = get_context()
k_cache = self.k_cache
v_cache = self.v_cache
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.block_tables is None: # normal prefill
cu_seqlens_k = context.cu_seqlens_k
seqused_k = None
@@ -79,7 +78,7 @@ class Attention(nn.Module):
seqused_k=seqused_k, softmax_scale=self.scale,
causal=True, block_table=context.block_tables)
else: # decode
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1),
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
o = o.view(-1, self.num_heads * self.head_dim)

View File

@@ -7,11 +7,11 @@ class Sampler(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.to(torch.float)
if temperatures is not None:
logits.div_(temperatures.unsqueeze(dim=1))
greedy_tokens = logits.argmax(dim=-1)
logits.div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
return sampled_tokens
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
return torch.where(temperatures == 0, greedy_tokens, sample_tokens)