import torch
from torch import nn
import triton
import triton.language as tl

from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    # One program per token: copy its flattened (num_kv_heads * head_dim) key and
    # value rows into the cache slot assigned by slot_mapping.
    idx = tl.program_id(0)
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    slot = tl.load(slot_mapping_ptr + idx)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    # The kernel reads and writes D contiguous elements per token, so the last
    # two dims of key/value and each cache slot must be contiguous.
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)


class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # Placeholder caches; they are rebound to the allocated paged KV cache
        # before inference.
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        o: torch.Tensor
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        context = get_context()
        k_cache = self.k_cache
        v_cache = self.v_cache
        if context.is_prefill:
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
            if context.block_tables is None:    # normal prefill
                cu_seqlens_k = context.cu_seqlens_k
                seqused_k = None
            else:    # prefix cache: attend over the paged cache instead of the new k/v
                cu_seqlens_k = None
                seqused_k = context.context_lens
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k, seqused_k=seqused_k,
                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:    # decode: write this step's k/v into the cache and attend over it
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                        k.unsqueeze(1), v.unsqueeze(1),
                                        cache_seqlens=context.context_lens, block_table=context.block_tables,
                                        softmax_scale=self.scale, causal=True)
        o = o.view(-1, self.num_heads * self.head_dim)
        return o
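

# A minimal smoke test for store_kvcache, not part of the original module: a
# sketch that assumes a CUDA device, float16 tensors, and a paged cache laid
# out as (num_blocks, block_size, num_kv_heads, head_dim) so that each flat
# slot holds D = num_kv_heads * head_dim contiguous elements. All sizes and
# the slot_mapping values below are hypothetical.
if __name__ == "__main__":
    if torch.cuda.is_available():
        num_kv_heads, head_dim = 2, 64
        num_blocks, block_size = 4, 16
        N = 8                                   # number of new tokens
        D = num_kv_heads * head_dim
        key = torch.randn(N, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
        value = torch.randn_like(key)
        k_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim,
                              device="cuda", dtype=torch.float16)
        v_cache = torch.zeros_like(k_cache)
        # Scatter the N tokens into arbitrary, non-contiguous cache slots.
        slot_mapping = torch.tensor([3, 7, 16, 17, 40, 41, 42, 63],
                                    device="cuda", dtype=torch.int64)
        store_kvcache(key, value, k_cache, v_cache, slot_mapping)
        # Each token's flattened key/value row should land in its assigned slot.
        assert torch.equal(k_cache.view(-1, D)[slot_mapping], key.view(N, D))
        assert torch.equal(v_cache.view(-1, D)[slot_mapping], value.view(N, D))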