import logging

import torch
from torch import nn
import triton
import triton.language as tl

from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context

logger = logging.getLogger(__name__)


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    # One program instance per token: copy that token's D key elements and
    # D value elements into the cache row selected by slot_mapping[idx].
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    # A slot of -1 means "do not store this token"; leave the cache untouched.
    if slot == -1:
        return
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    # Cache rows are addressed as slot * D, i.e. D contiguous elements per slot.
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Write each token's key/value into the KV cache slot given by ``slot_mapping``.

    Args:
        key: Tensor of shape ``[N, num_heads, head_dim]``.
        value: Tensor with the same layout requirements as ``key``.
        k_cache: Destination key cache; ``stride(1)`` must equal
            ``num_heads * head_dim``.
        v_cache: Destination value cache; same layout requirement as ``k_cache``.
        slot_mapping: ``N`` slot indices, one per token. Tokens mapped to ``-1``
            are skipped by the kernel.

    Raises:
        AssertionError: If a layout precondition of the kernel is violated.
    """
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    # The kernel reads D contiguous elements per token and writes D contiguous
    # elements per slot, so the (heads, dim) axes must be densely packed.
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](
        key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D
    )


class Attention(nn.Module):
    """Attention layer supporting plain, prefix-cached and chunked (CPU-offloaded) KV.

    The execution mode is selected per call from the global context returned by
    ``get_context()`` (``is_prefill``, ``is_chunked_prefill``, ``block_tables``).
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # Empty placeholders; forward() only stores to the cache once these
        # have been replaced by non-empty tensors.
        self.k_cache = self.v_cache = torch.tensor([])
        # Layer ID set by model_runner after model creation
        self.layer_id: int = -1

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Store the new K/V into the cache (if allocated) and compute attention.

        Args:
            q: Query tensor.
            k: Key tensor for the tokens of this step.
            v: Value tensor for the tokens of this step.

        Returns:
            The attention output produced by the branch selected from the
            current context.
        """
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

        if context.is_prefill:
            if context.is_chunked_prefill:
                # Chunked prefill: merge attention from previous KV
                return self._chunked_prefill_attention(q, k, v, context)
            if context.block_tables is not None:
                # Prefix cache: attend over the paged cache instead of the
                # freshly computed k/v.
                k, v = k_cache, v_cache
            return flash_attn_varlen_func(
                q, k, v,
                max_seqlen_q=context.max_seqlen_q,
                cu_seqlens_q=context.cu_seqlens_q,
                max_seqlen_k=context.max_seqlen_k,
                cu_seqlens_k=context.cu_seqlens_k,
                softmax_scale=self.scale,
                causal=True,
                block_table=context.block_tables,
            )

        # decode
        if context.is_chunked_prefill:
            # Chunked decode: need to load all KV from CPU+GPU
            return self._chunked_decode_attention(q, k, v, context)
        return flash_attn_with_kvcache(
            q.unsqueeze(1), k_cache, v_cache,
            cache_seqlens=context.context_lens,
            block_table=context.block_tables,
            softmax_scale=self.scale,
            causal=True,
        )

    def _chunked_prefill_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute attention for one prefill chunk, merging in previous chunks' KV.

        Steps for each layer:
        1. Look up the CPU blocks already prefilled for the sequence.
        2. Load those blocks through the available GPU slots (pipelined when
           more than zero load slots are available, synchronous otherwise).
        3. Compute attention against each previous block (no causal mask).
        4. Compute attention against the current chunk's k/v (causal).
        5. Merge all partial results via ``merge_attention_outputs``.

        Args:
            q, k, v: Tensors of shape [total_tokens, num_heads, head_dim].
            context: Execution context providing ``kvcache_manager``,
                ``current_chunk_idx`` and (optionally) ``chunked_seq``.

        Returns:
            Attention output with the temporary batch dimension removed.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q, k, v shape: [total_tokens, num_heads, head_dim]
        # Reshape for flash attention: [batch, seq, heads, dim]
        q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
        k_batched = k.unsqueeze(0)
        v_batched = v.unsqueeze(0)

        o_acc = None
        lse_acc = None

        kvcache_manager = context.kvcache_manager
        seq = getattr(context, 'chunked_seq', None)
        current_chunk_idx = context.current_chunk_idx

        if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
            # Get prefilled CPU blocks (blocks from previous chunks)
            cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

            if cpu_block_table:
                offload_engine = kvcache_manager.offload_engine

                # Get write slot for current chunk and available load slots
                write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
                load_slots = offload_engine.get_load_slots_for_prefill(write_slot)

                if len(load_slots) == 0:
                    # Only 1 slot total, cannot pipeline - use sync loading
                    o_acc, lse_acc = self._sync_load_previous_chunks(
                        q_batched, cpu_block_table, offload_engine
                    )
                else:
                    # Use ring buffer pipeline
                    o_acc, lse_acc = self._ring_buffer_pipeline_load(
                        q_batched, cpu_block_table, load_slots, offload_engine
                    )

        # Compute attention against current chunk's KV (with causal mask)
        current_o, current_lse = flash_attn_with_lse(
            q_batched,
            k_batched,
            v_batched,
            softmax_scale=self.scale,
            causal=True,
        )

        # Merge with accumulated
        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)

    def _attend_blocks_single_slot(
        self,
        q_batched: torch.Tensor,
        cpu_block_table: list,
        slot,
        offload_engine,
    ):
        """Attend to every CPU block in turn through a single GPU slot.

        Each block is loaded, waited for, attended to (non-causal) and merged
        before the next load is issued, so there is no load/compute overlap.

        Returns:
            ``(o_acc, lse_acc)``; both are ``None`` when ``cpu_block_table`` is empty.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        o_acc, lse_acc = None, None
        for cpu_block_id in cpu_block_table:
            offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
            offload_engine.wait_slot_layer(slot, self.layer_id)

            prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
            prev_o, prev_lse = flash_attn_with_lse(
                q_batched, prev_k, prev_v,
                softmax_scale=self.scale,
                causal=False,
            )

            # Record compute done so next load can safely reuse this slot
            offload_engine.record_slot_compute_done(slot, self.layer_id)

            if o_acc is None:
                o_acc, lse_acc = prev_o, prev_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        return o_acc, lse_acc

    def _sync_load_previous_chunks(
        self,
        q_batched: torch.Tensor,
        cpu_block_table: list,
        offload_engine,
    ):
        """Synchronous loading fallback when pipeline_depth=0.

        Loads every block through slot 0 (single slot). Compute completion is
        recorded after each block, matching the single-slot path of
        ``_ring_buffer_pipeline_load``, so the next load into the same slot
        does not race with attention still reading it.

        Returns:
            ``(o_acc, lse_acc)``; both ``None`` if ``cpu_block_table`` is empty.
        """
        return self._attend_blocks_single_slot(q_batched, cpu_block_table, 0, offload_engine)

    def _ring_buffer_pipeline_load(
        self,
        q_batched: torch.Tensor,
        cpu_block_table: list,
        load_slots: list,
        offload_engine,
    ):
        """
        Ring buffer async pipeline loading with double buffering.

        Uses compute_done events to ensure safe buffer reuse:
        - Before loading to slot X, wait for previous compute on slot X to finish
        - Before computing on slot X, wait for load to slot X to finish

        Timeline with 2 slots (A, B):
        ┌──────────────┐
        │ Load B0→A    │
        └──────────────┘
        ┌──────────────┐  ┌──────────────┐
        │ Load B1→B    │  │ Load B2→A    │ ...
        └──────────────┘  └──────────────┘
               ↘                 ↘
          ┌──────────────┐  ┌──────────────┐
          │ Compute(A)   │  │ Compute(B)   │ ...
          └──────────────┘  └──────────────┘

        The load_to_slot_layer internally waits for compute_done[slot] before
        starting the transfer, ensuring no data race.

        Only the first two entries of ``load_slots`` are used for double
        buffering; with a single entry the blocks are processed synchronously.

        Returns:
            ``(o_acc, lse_acc)``; both ``None`` when there are no blocks or no slots.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        num_blocks = len(cpu_block_table)
        if num_blocks == 0:
            return None, None

        pipeline_depth = len(load_slots)
        if pipeline_depth == 0:
            return None, None

        if pipeline_depth == 1:
            # Only 1 slot available, cannot pipeline - use synchronous mode
            return self._attend_blocks_single_slot(
                q_batched, cpu_block_table, load_slots[0], offload_engine
            )

        o_acc, lse_acc = None, None

        # Double buffering with 2 slots
        slot_A = load_slots[0]
        slot_B = load_slots[1]

        # Pre-load first block to slot_A (async)
        offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0])

        for block_idx in range(num_blocks):
            # Alternate between slot_A and slot_B
            if block_idx % 2 == 0:
                current_slot, next_slot = slot_A, slot_B
            else:
                current_slot, next_slot = slot_B, slot_A

            # Wait for current slot's transfer to complete
            offload_engine.wait_slot_layer(current_slot, self.layer_id)

            # Start async load of next block to the OTHER slot
            # load_to_slot_layer internally waits for next_slot's compute_done
            if block_idx + 1 < num_blocks:
                offload_engine.load_to_slot_layer(
                    next_slot, self.layer_id, cpu_block_table[block_idx + 1]
                )

            # Compute attention on current slot's data
            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
            prev_o, prev_lse = flash_attn_with_lse(
                q_batched, prev_k, prev_v,
                softmax_scale=self.scale,
                causal=False,
            )

            # Record compute done - this allows the next round to safely load into this slot
            offload_engine.record_slot_compute_done(current_slot, self.layer_id)

            # Merge with accumulated
            if o_acc is None:
                o_acc, lse_acc = prev_o, prev_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        return o_acc, lse_acc

    def _chunked_decode_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention over CPU-offloaded KV with double buffering.

        The sequence's CPU blocks are split into chunks of
        ``offload_engine.num_prefetch_blocks`` blocks. Chunks are loaded
        alternately into the engine's Compute and Prefetch regions so the load
        of chunk i+1 is issued before attention on chunk i is computed.
        Finally the Decode region (``decode_slot``), holding the accumulated
        decode tokens, is attended to and merged.

        Timeline:
        ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
        │Load C0→buf0 │  │Load C1→buf1 │  │Load C2→buf0 │ ...
        └─────────────┘  └─────────────┘  └─────────────┘
               ↘                ↘                ↘
          ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
          │  Attn(C0)   │  │  Attn(C1)   │  │  Attn(C2)   │
          └─────────────┘  └─────────────┘  └─────────────┘

        Args:
            q: Query of shape [batch_size, num_heads, head_dim] (single decode
                token per sequence).
            k, v: Not read by this method; all KV is taken from the offload engine.
            context: Execution context providing ``kvcache_manager``,
                ``chunked_seq``, ``decode_pos_in_block`` and
                ``decode_start_pos_in_block``.

        Returns:
            The merged attention output.

        Raises:
            RuntimeError: If the sequence has no CPU blocks, or no KV was
                available to attend to.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]

        kvcache_manager = context.kvcache_manager
        seq = context.chunked_seq

        # Get all CPU blocks for this sequence
        cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
        # Guarded so the block-table copy and formatting are skipped unless
        # DEBUG logging is actually enabled.
        if self.layer_id == 0 and logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Decode attention: cpu_block_table=%s, seq.block_table=%s",
                cpu_block_table,
                list(seq.block_table),
            )
        if not cpu_block_table:
            raise RuntimeError("Chunked decode attention failed: no CPU blocks available")

        offload_engine = kvcache_manager.offload_engine

        # Use prefetch_size as chunk size for double buffering
        # This ensures both Compute and Prefetch regions can hold a full chunk
        chunk_size = offload_engine.num_prefetch_blocks
        total_blocks = len(cpu_block_table)
        num_chunks = (total_blocks + chunk_size - 1) // chunk_size

        o_acc = None
        lse_acc = None

        # Double buffering state: True = use Compute region, False = use Prefetch region
        use_compute = True

        # Pre-load first chunk to Compute region (async)
        first_chunk_ids = cpu_block_table[:min(chunk_size, total_blocks)]
        offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)

        for chunk_idx in range(num_chunks):
            start = chunk_idx * chunk_size
            end = min(start + chunk_size, total_blocks)
            num_blocks_in_chunk = end - start

            # Wait for current buffer to be ready
            if use_compute:
                offload_engine.wait_compute_layer(self.layer_id)
            else:
                offload_engine.wait_prefetch_layer(self.layer_id)

            # Trigger async prefetch of next chunk to the OTHER buffer
            # This overlaps transfer with current chunk's computation
            if chunk_idx + 1 < num_chunks:
                next_end = min(end + chunk_size, total_blocks)
                next_chunk_ids = cpu_block_table[end:next_end]
                if use_compute:
                    # Current in Compute, prefetch next to Prefetch region
                    offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
                else:
                    # Current in Prefetch, prefetch next to Compute region
                    offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)

            # Get KV from current buffer
            if use_compute:
                k_chunk, v_chunk = offload_engine.get_kv_for_compute(
                    self.layer_id, num_blocks_in_chunk
                )
            else:
                k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
                    self.layer_id, num_blocks_in_chunk
                )

            # Compute attention for this chunk
            o_chunk, lse_chunk = flash_attn_with_lse(
                q_batched, k_chunk, v_chunk,
                softmax_scale=self.scale,
                causal=False,
            )

            # Merge with accumulated
            if o_acc is None:
                o_acc, lse_acc = o_chunk, lse_chunk
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)

            # Swap buffers for next iteration
            use_compute = not use_compute

        # Now attend to Decode region (contains accumulated decode tokens)
        pos_in_block = context.decode_pos_in_block
        start_pos = context.decode_start_pos_in_block
        num_accumulated = pos_in_block - start_pos + 1

        if num_accumulated > 0:
            decode_slot = offload_engine.decode_slot
            decode_k = offload_engine.k_cache_gpu[self.layer_id, decode_slot, start_pos:pos_in_block+1]
            decode_v = offload_engine.v_cache_gpu[self.layer_id, decode_slot, start_pos:pos_in_block+1]
            # Add a leading batch dimension for flash attention.
            decode_k = decode_k.unsqueeze(0)
            decode_v = decode_v.unsqueeze(0)

            decode_o, decode_lse = flash_attn_with_lse(
                q_batched, decode_k, decode_v,
                softmax_scale=self.scale,
                causal=False,
            )

            if o_acc is None:
                o_acc = decode_o
            else:
                o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)

        if o_acc is None:
            raise RuntimeError("Chunked decode attention failed: no KV available")

        return o_acc