import torch
from torch import nn
import triton
import triton.language as tl
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache

from nanovllm.utils.context import get_context


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    if slot == -1:
        return
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
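
# Shape sketch for store_kvcache (a hedged illustration; the concrete sizes
# below are assumptions, not values used by the engine). Each cache is
# addressed as a flat array of token slots of width D = num_kv_heads * head_dim,
# and slot_mapping names one destination slot per input token, with -1 meaning
# "do not store this token":
#
#   key, value:   [N, num_kv_heads, head_dim], contiguous in the last two dims
#   k/v_cache:    any layout whose dim 1 has stride D, e.g. an assumed
#                 [num_blocks, block_size, num_kv_heads * head_dim] buffer
#   slot_mapping: [N] integer tensor of flattened slot indices (or -1)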


class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])
        # Layer ID set by model_runner after model creation
        self.layer_id: int = -1

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.is_chunked_prefill:
                # Chunked prefill: merge attention from previous KV
                o = self._chunked_prefill_attention(q, k, v, context)
            elif context.block_tables is not None:  # prefix cache
                k, v = k_cache, v_cache
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
            else:
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:  # decode
            if context.is_chunked_prefill:
                # Chunked decode: need to load all KV from CPU+GPU
                o = self._chunked_decode_attention(q, k, v, context)
            else:
                o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                            cache_seqlens=context.context_lens, block_table=context.block_tables,
                                            softmax_scale=self.scale, causal=True)
        return o

    def _chunked_prefill_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute attention with chunked KV from CPU cache.

        For chunked prefill:
        1. Load previous KV from CPU for this layer
        2. Compute attention against previous KV (no causal mask)
        3. Compute attention against current chunk's KV (causal)
        4. Merge results using online softmax
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q, k, v shape: [total_tokens, num_heads, head_dim]
        total_tokens = q.shape[0]

        # Reshape for flash attention: [batch, seq, heads, dim]
        q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
        k_batched = k.unsqueeze(0)
        v_batched = v.unsqueeze(0)

        accumulated_o = None
        accumulated_lse = None

        # Load previous KV from CPU for this layer
        if context.offload_engine is not None and self.layer_id >= 0:
            # Get the kvcache_manager from context
            kvcache_manager = context.offload_engine
            # For each sequence in the chunk, load previous KV
            # Currently assuming single sequence
            if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
                prev_k, prev_v = kvcache_manager.load_prev_kv_for_layer(
                    context.chunked_seq,
                    self.layer_id,
                )
                if prev_k is not None and prev_v is not None:
                    # Compute attention against previous KV (no causal mask)
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched, prev_k, prev_v,
                        softmax_scale=self.scale,
                        causal=False,  # No causal mask for previous context
                    )
                    accumulated_o = prev_o
                    accumulated_lse = prev_lse

        # Compute attention against current chunk's KV (with causal mask)
        current_o, current_lse = flash_attn_with_lse(
            q_batched, k_batched, v_batched,
            softmax_scale=self.scale,
            causal=True,  # Causal mask for current chunk
        )

        # Merge with accumulated
        if accumulated_o is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(
                accumulated_o, accumulated_lse,
                current_o, current_lse,
            )

        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)
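
    # Note on shapes (assumed, following flash-attn conventions): flash_attn_with_lse
    # is expected to return the attention output as [batch, seqlen_q, num_heads, head_dim]
    # and the log-sum-exp as [batch, num_heads, seqlen_q]; merge_attention_outputs
    # combines two such partial results computed over disjoint key sets.
    # See the reference sketch at the bottom of this file.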

    def _chunked_decode_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention with KV spread across CPU and GPU.

        Uses chunked attention similar to chunked prefill:
        1. Process blocks on GPU first (if any)
        2. Load CPU blocks in chunks to GPU slots (per-layer)
        3. Compute attention for each chunk, merge with online softmax
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
        # Need: [batch, seqlen, heads, dim]
        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]

        kvcache_manager = context.offload_engine
        seq = context.chunked_seq

        o_acc = None
        lse_acc = None

        # Step 1: Process blocks already on GPU (if any)
        gpu_slots, _ = kvcache_manager.get_gpu_blocks_for_decode(seq)
        if gpu_slots:
            k_gpu, v_gpu = kvcache_manager.get_kv_for_gpu_slots(self.layer_id, gpu_slots)
            o_gpu, lse_gpu = flash_attn_with_lse(
                q_batched, k_gpu, v_gpu,
                softmax_scale=self.scale,
                causal=False,
            )
            o_acc, lse_acc = o_gpu, lse_gpu

        # Step 2: Process CPU blocks in chunks
        # Get chunk info from kvcache_manager
        cpu_block_ids, cpu_logical_ids, num_chunks = kvcache_manager.get_decode_chunk_info(seq)
        if num_chunks > 0:
            # Use num_gpu_slots - 1 to avoid the reserved slot (used for write block)
            chunk_size = kvcache_manager.num_gpu_slots - 1
            for chunk_idx in range(num_chunks):
                start = chunk_idx * chunk_size
                end = min(start + chunk_size, len(cpu_block_ids))
                chunk_cpu_ids = cpu_block_ids[start:end]

                # Load this chunk to GPU slots 0, 1, 2, ... for THIS LAYER
                # (slot num_gpu_slots-1 is reserved for write block)
                gpu_slots_for_chunk = list(range(len(chunk_cpu_ids)))
                kvcache_manager.offload_engine.load_cpu_blocks_to_gpu_slots(
                    self.layer_id,
                    chunk_cpu_ids,
                    gpu_slots_for_chunk,
                )

                # Get KV for this chunk
                k_chunk, v_chunk = kvcache_manager.get_kv_for_gpu_slots(
                    self.layer_id, gpu_slots_for_chunk
                )

                # Compute attention for this chunk
                o_chunk, lse_chunk = flash_attn_with_lse(
                    q_batched, k_chunk, v_chunk,
                    softmax_scale=self.scale,
                    causal=False,
                )

                # Merge with accumulated
                if o_acc is None:
                    o_acc, lse_acc = o_chunk, lse_chunk
                else:
                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)

        if o_acc is None:
            raise RuntimeError("Chunked decode attention failed: no KV available")

        # Output shape: [batch, 1, heads, dim] (same as normal decode)
        return o_acc
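

# ---------------------------------------------------------------------------
# Reference sketch (documentation only, not used by the model): a minimal
# implementation of the online-softmax merge that merge_attention_outputs in
# nanovllm.kvcache.chunked_attention is assumed to perform. The function name
# and exact shape conventions here are illustrative assumptions; shapes follow
# the flash-attn convention of o: [batch, seqlen_q, num_heads, head_dim] and
# lse: [batch, num_heads, seqlen_q].
def _merge_attention_outputs_reference(
    o1: torch.Tensor, lse1: torch.Tensor,
    o2: torch.Tensor, lse2: torch.Tensor,
):
    # Combined log-sum-exp over the union of the two (disjoint) key sets
    lse = torch.logaddexp(lse1, lse2)
    # Per-query weights for each partial output, broadcast to [b, s, h, 1]
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    # Exact softmax-weighted combination of the two partial outputs
    return (o1 * w1 + o2 * w2).to(o1.dtype), lse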