init commit
nanovllm/layers/attention.py (new file, 86 lines added)
@@ -0,0 +1,86 @@
import torch
from torch import nn
import triton
import triton.language as tl

from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    # One program per token: load its D contiguous key/value elements and
    # store them into the flat cache slot given by slot_mapping.
    idx = tl.program_id(0)
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    slot = tl.load(slot_mapping_ptr + idx)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
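
# One layout that satisfies the checks above (sizes here are illustrative
# assumptions, e.g. num_kv_heads = 4 and head_dim = 64, so D = 256):
#   key, value       : (N, 4, 64), contiguous over the head and dim axes
#   k_cache, v_cache : (num_blocks, block_size, 4, 64), so stride(1) == 4 * 64 == D
#   slot_mapping     : (N,) flat slot ids, slot = block_id * block_size + offset_in_block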


class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])  # placeholders until the paged KV cache is allocated

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        o: torch.Tensor
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        context = get_context()
        k_cache = self.k_cache
        v_cache = self.v_cache
        if context.is_prefill:
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
            if context.block_tables is None:    # normal prefill
                cu_seqlens_k = context.cu_seqlens_k
                seqused_k = None
            else:    # prefix cache
                cu_seqlens_k = None
                seqused_k = context.context_lens
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
                                       seqused_k=seqused_k, softmax_scale=self.scale,
                                       causal=True, block_table=context.block_tables)
        else:    # decode
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1),
                                        cache_seqlens=context.context_lens, block_table=context.block_tables,
                                        softmax_scale=self.scale, causal=True)
        o = o.view(-1, self.num_heads * self.head_dim)
        return o
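

if __name__ == "__main__":
    # Minimal, illustrative smoke test for store_kvcache; the sizes below are
    # arbitrary example values, and running it assumes a CUDA device with the
    # project's dependencies installed (triton, vllm_flash_attn, nanovllm).
    # Attention.forward is not exercised here because it reads the per-step
    # prefill/decode context that the engine populates via nanovllm.utils.context.
    num_kv_heads, head_dim = 4, 64      # D = num_kv_heads * head_dim = 256
    num_blocks, block_size = 2, 4       # 8 cache slots in total
    key = torch.randn(2, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
    value = torch.randn_like(key)
    k_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim,
                          device="cuda", dtype=torch.float16)
    v_cache = torch.zeros_like(k_cache)
    slot_mapping = torch.tensor([3, 5], device="cuda", dtype=torch.int32)
    store_kvcache(key, value, k_cache, v_cache, slot_mapping)
    # Each token's key/value vector should now sit at its flat slot index.
    flat_k = k_cache.view(num_blocks * block_size, num_kv_heads, head_dim)
    assert torch.equal(flat_k[slot_mapping.long()], key)
    print("copied", key.size(0), "tokens into cache slots", slot_mapping.tolist())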