116 lines
4.6 KiB
Python
116 lines
4.6 KiB
Python
import torch
|
|
from torch import nn
|
|
|
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
|
from nanovllm.utils.context import get_context
|
|
|
|
|
|
def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """
    Store key/value tensors into the KV cache using a slot mapping.

    This is a pure PyTorch implementation replacing the previous Triton kernel.
    Uses index_copy_ for an efficient in-place scatter operation.

    Args:
        key: [N, num_kv_heads, head_dim]
        value: [N, num_kv_heads, head_dim]
        k_cache: [num_blocks, block_size, num_kv_heads, head_dim] or any
            contiguous layout whose elements flatten to multiples of
            num_kv_heads * head_dim
        v_cache: same shape as k_cache
        slot_mapping: [N] with values as flat indices, -1 means skip
    """
    # Guard with is_available(): on CPU-only torch builds the private capture
    # query is a stub that raises, and capture can only happen on CUDA anyway.
    is_capturing = torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()

    if is_capturing:
        # During CUDA graph capture, assume all slots are valid.
        # CUDA graphs don't support data-dependent operations like boolean indexing.
        # This is safe because decode (captured) always has valid slots.
        valid_slots = slot_mapping
        valid_keys = key
        valid_values = value
    else:
        # Normal execution: filter out invalid slots (slot == -1)
        valid_mask = slot_mapping >= 0
        if not valid_mask.any():
            return
        valid_slots = slot_mapping[valid_mask]
        valid_keys = key[valid_mask]  # [M, num_kv_heads, head_dim]
        valid_values = value[valid_mask]

    # Flatten cache and KV for the scatter: the cache is viewed as
    # [total_slots, D] where D = num_kv_heads * head_dim, so a single
    # index_copy_ writes each token's K/V into its flat slot.
    _, num_kv_heads, head_dim = key.shape
    D = num_kv_heads * head_dim
    total_slots = k_cache.numel() // D

    # In-place scatter using index_copy_
    k_cache.view(total_slots, D).index_copy_(0, valid_slots.long(), valid_keys.reshape(-1, D))
    v_cache.view(total_slots, D).index_copy_(0, valid_slots.long(), valid_values.reshape(-1, D))
|
|
|
|
|
|
class Attention(nn.Module):
    """
    Attention layer for GPU-only mode.

    For CPU offload mode, attention is computed directly in model_runner's
    run_layerwise_offload_prefill/decode methods using FlashAttention.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # Empty placeholders; model_runner binds the real cache tensors later.
        self.k_cache = self.v_cache = torch.tensor([])
        # Layer ID set by model_runner after model creation
        self.layer_id: int = -1

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        ctx = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache

        # Store KV to cache (for GPU-only mode)
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, ctx.slot_mapping)

        if not ctx.is_prefill:
            # Decode: one query token per sequence against the paged KV cache.
            return flash_attn_with_kvcache(
                q.unsqueeze(1), k_cache, v_cache,
                cache_seqlens=ctx.context_lens,
                block_table=ctx.block_tables,
                softmax_scale=self.scale,
                causal=True,
            )

        # Prefill paths.
        if ctx.block_tables is not None:
            # Prefix cache hit: attend over the cached K/V via the block table.
            k, v = k_cache, v_cache
        elif ctx.attention_policy is not None:
            # Attention via policy (GPU-only) - delegate to policy.
            return ctx.attention_policy.compute_prefill(
                q, k, v, self.layer_id, softmax_scale=self.scale
            )

        # Plain prefill and prefix-cache prefill differ only in the K/V
        # source; the varlen call itself is identical (block_tables is None
        # in the plain case).
        return flash_attn_varlen_func(
            q, k, v,
            max_seqlen_q=ctx.max_seqlen_q, cu_seqlens_q=ctx.cu_seqlens_q,
            max_seqlen_k=ctx.max_seqlen_k, cu_seqlens_k=ctx.cu_seqlens_k,
            softmax_scale=self.scale, causal=True,
            block_table=ctx.block_tables,
        )
|