This commit is contained in:
GeeeekExplorer
2025-08-31 19:44:57 +08:00
parent 6a6d217de7
commit df99418f7d
11 changed files with 47 additions and 96 deletions

View File

@@ -19,11 +19,12 @@ def store_kvcache_kernel(
D: tl.constexpr,
):
idx = tl.program_id(0)
slot = tl.load(slot_mapping_ptr + idx)
if slot == -1: return
key_offsets = idx * key_stride + tl.arange(0, D)
value_offsets = idx * value_stride + tl.arange(0, D)
key = tl.load(key_ptr + key_offsets)
value = tl.load(value_ptr + value_offsets)
slot = tl.load(slot_mapping_ptr + idx)
cache_offsets = slot * D + tl.arange(0, D)
tl.store(k_cache_ptr + cache_offsets, key)
tl.store(v_cache_ptr + cache_offsets, value)
@@ -56,10 +57,6 @@ class Attention(nn.Module):
self.k_cache = self.v_cache = torch.tensor([])
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
o: torch.Tensor
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
@@ -75,5 +72,4 @@ class Attention(nn.Module):
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
o = o.view(-1, self.num_heads * self.head_dim)
return o