# Patch summary (review notes; everything before the first "diff --git" line is commentary)
#
# What the hunks below show:
#   * full_policy.py: adds FullAttentionPolicy.compute_chunked_decode plus the helpers
#     _decode_ring_buffer_pipeline and _decode_with_layer_pipeline. The helpers take
#     layer_id and softmax_scale as parameters where the removed attention.py versions
#     read self.layer_id and self.scale.
#   * policy.py: adds an @abstractmethod compute_chunked_decode to SparsePolicy(ABC).
#   * attention.py: drops the PolicyContext import, removes the two pipeline helpers from
#     Attention, and replaces the decode body with validation + a single call to
#     sparse_policy.compute_chunked_decode(q, self.layer_id, self.scale, ...).
#
# NOTE(review), to confirm before merging:
#   * Behaviour change: the removed code tolerated kvcache_manager.sparse_policy being None
#     (block selection was simply skipped); the new code raises RuntimeError in that case
#     and also when sparse_policy.supports_decode is falsy.
#   * Decode positions were read from context.decode_pos_in_block and
#     context.decode_start_pos_in_block; they are now derived from len(seq) and
#     kvcache_manager.get_decode_start_pos(seq). Presumably equivalent - verify.
#   * The code added to full_policy.py references torch, logger and PolicyContext; the hunk
#     does not show that file's imports - confirm they are in scope there.
#   * Locals carried over unused: num_prefill_blocks, cpu_block_id, total_tokens.
#   * last_block_valid_tokens is computed from the prefill length before select_blocks()
#     and is applied to the last entry of the table select_blocks() returns; this assumes
#     selection keeps the final prefilled block in last position - TODO confirm for
#     policies that filter blocks.
#   * Because the new SparsePolicy method is abstract, every concrete subclass has to
#     implement compute_chunked_decode; only FullAttentionPolicy is updated in this patch.
#   * The line structure, indentation and blank context lines of this patch were restored
#     from a whitespace-collapsed copy and cross-checked against the @@ line counts;
#     indentation is inferred, so validate with `git apply --check` (use
#     --ignore-whitespace if needed).
#
diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py
index 4df7b63..a6b9c23 100644
--- a/nanovllm/kvcache/sparse/full_policy.py
+++ b/nanovllm/kvcache/sparse/full_policy.py
@@ -192,5 +192,256 @@ class FullAttentionPolicy(SparsePolicy):
         # Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
         return final_o.squeeze(0)
 
+    def compute_chunked_decode(
+        self,
+        q: torch.Tensor,
+        layer_id: int,
+        softmax_scale: float,
+        offload_engine: "OffloadEngine",
+        kvcache_manager: "KVCacheManager",
+        seq: "Sequence",
+    ) -> torch.Tensor:
+        """
+        Compute full attention for chunked decode.
+
+        This method handles the complete chunked decode flow:
+        1. Get prefilled CPU blocks
+        2. Apply select_blocks for block filtering
+        3. Load blocks via pipeline (ring buffer or cross-layer)
+        4. Read accumulated decode tokens from decode buffer
+        5. Merge all results
+
+        Args:
+            q: Query tensor [batch_size, num_heads, head_dim]
+            layer_id: Current layer index
+            softmax_scale: Softmax scaling factor
+            offload_engine: OffloadEngine for loading blocks
+            kvcache_manager: KVCacheManager for block management
+            seq: Sequence object
+
+        Returns:
+            Attention output [batch_size, 1, num_heads, head_dim]
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
+        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
+
+        # Get only PREFILLED CPU blocks (exclude the current decode block)
+        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
+        if layer_id == 0:
+            logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
+        if not cpu_block_table:
+            raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
+
+        # Calculate valid tokens in the last CPU block
+        # CRITICAL: Use original prefill length, not current seq length!
+        # CPU blocks are fixed after prefill, their content doesn't change during decode.
+        block_size = kvcache_manager.block_size
+        num_prefill_blocks = len(cpu_block_table)
+        total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Original prefill length
+        last_block_valid_tokens = total_prefill_tokens % block_size
+        if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
+            last_block_valid_tokens = block_size  # Last block was exactly full
+
+        # Apply sparse policy (self) for block filtering
+        policy_ctx = PolicyContext(
+            query_chunk_idx=0,
+            num_query_chunks=1,
+            layer_id=layer_id,
+            query=q_batched,
+            is_prefill=False,
+            block_size=kvcache_manager.block_size,
+            total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
+        )
+        cpu_block_table = self.select_blocks(cpu_block_table, offload_engine, policy_ctx)
+
+        # Use cross-layer pipeline if active (initialized in model_runner)
+        if offload_engine.is_pipeline_active():
+            o_acc, lse_acc = self._decode_with_layer_pipeline(
+                q_batched, cpu_block_table, offload_engine,
+                block_size, last_block_valid_tokens, layer_id, softmax_scale
+            )
+        else:
+            # Fallback to original ring buffer pipeline
+            load_slots = offload_engine.decode_load_slots
+            o_acc, lse_acc = self._decode_ring_buffer_pipeline(
+                q_batched, cpu_block_table, load_slots, offload_engine,
+                block_size, last_block_valid_tokens, layer_id, softmax_scale
+            )
+
+        # Now attend to accumulated decode tokens from per-layer decode buffer
+        # Compute decode position information internally
+        seq_len = len(seq)
+        decode_pos_in_block = (seq_len - 1) % block_size
+        decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
+        decode_start_pos_in_block = decode_start_pos % block_size
+        num_accumulated = decode_pos_in_block - decode_start_pos_in_block + 1
+
+        # Sync compute_stream with default stream before reading decode_buffer
+        compute_stream = offload_engine.compute_stream
+        compute_stream.wait_stream(torch.cuda.default_stream())
+
+        with torch.cuda.stream(compute_stream):
+            if num_accumulated > 0:
+                # Read from per-layer decode buffer
+                decode_k = offload_engine.decode_k_buffer[layer_id, decode_start_pos_in_block:decode_pos_in_block+1]
+                decode_v = offload_engine.decode_v_buffer[layer_id, decode_start_pos_in_block:decode_pos_in_block+1]
+                decode_k = decode_k.unsqueeze(0)
+                decode_v = decode_v.unsqueeze(0)
+
+                decode_o, decode_lse = flash_attn_with_lse(
+                    q_batched, decode_k, decode_v,
+                    softmax_scale=softmax_scale,
+                    causal=False,
+                )
+
+                if o_acc is None:
+                    o_acc = decode_o
+                else:
+                    o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
+
+        if o_acc is None:
+            raise RuntimeError("Chunked decode attention failed: no KV available")
+
+        # Sync back to default stream before returning
+        torch.cuda.default_stream().wait_stream(compute_stream)
+
+        return o_acc
+
+    def _decode_ring_buffer_pipeline(
+        self,
+        q_batched: torch.Tensor,
+        cpu_block_table: list,
+        load_slots: list,
+        offload_engine: "OffloadEngine",
+        block_size: int,
+        last_block_valid_tokens: int,
+        layer_id: int,
+        softmax_scale: float,
+    ):
+        """
+        Ring buffer pipeline for decode prefill loading.
+
+        Loads one block at a time, computes attention, and merges results.
+        Uses load_to_slot_layer / wait_slot_layer / get_kv_for_slot methods.
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+        num_blocks = len(cpu_block_table)
+        if num_blocks == 0:
+            return None, None
+
+        if not load_slots:
+            return None, None
+
+        o_acc, lse_acc = None, None
+        num_slots = len(load_slots)
+        compute_stream = offload_engine.compute_stream
+
+        # Phase 1: Pre-load up to num_slots blocks
+        num_preload = min(num_slots, num_blocks)
+        for i in range(num_preload):
+            offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
+
+        # Phase 2: Process blocks with pipeline
+        for block_idx in range(num_blocks):
+            current_slot = load_slots[block_idx % num_slots]
+            cpu_block_id = cpu_block_table[block_idx]
+
+            # Wait for current slot's transfer to complete
+            offload_engine.wait_slot_layer(current_slot)
+
+            with torch.cuda.stream(compute_stream):
+                # Get KV from slot
+                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
+
+                # Handle partial last block
+                is_last_block = (block_idx == num_blocks - 1)
+                if is_last_block and last_block_valid_tokens < block_size:
+                    prev_k = prev_k[:, :last_block_valid_tokens, :, :]
+                    prev_v = prev_v[:, :last_block_valid_tokens, :, :]
+
+                # Compute attention
+                prev_o, prev_lse = flash_attn_with_lse(
+                    q_batched, prev_k, prev_v,
+                    softmax_scale=softmax_scale,
+                    causal=False,
+                )
+
+                # Record compute done for slot reuse
+                offload_engine.record_slot_compute_done(current_slot)
+
+            # Start loading next block (pipeline)
+            next_block_idx = block_idx + num_slots
+            if next_block_idx < num_blocks:
+                offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx])
+
+            # Merge with accumulated
+            with torch.cuda.stream(compute_stream):
+                if o_acc is None:
+                    o_acc, lse_acc = prev_o, prev_lse
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+
+        return o_acc, lse_acc
+
+    def _decode_with_layer_pipeline(
+        self,
+        q_batched: torch.Tensor,
+        cpu_block_table: list,
+        offload_engine: "OffloadEngine",
+        block_size: int,
+        last_block_valid_tokens: int,
+        layer_id: int,
+        softmax_scale: float,
+    ):
+        """
+        Decode using cross-layer pipeline for optimized H2D transfer.
+
+        Uses pre-loaded layer buffers instead of loading blocks one by one.
+        The pipeline loads the next layer's data while the current layer
+        computes, achieving transfer/compute overlap.
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+        num_blocks = len(cpu_block_table)
+        if num_blocks == 0:
+            return None, None
+
+        compute_stream = offload_engine.compute_stream
+
+        # Get KV from pre-loaded layer buffer (triggers next layer loading)
+        prev_k, prev_v = offload_engine.get_decode_layer_kv(layer_id, num_blocks)
+
+        # prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
+        # Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
+        total_tokens = num_blocks * block_size
+
+        # Handle partial last block
+        if last_block_valid_tokens < block_size:
+            # Only use valid tokens from last block
+            actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
+            # Flatten and truncate
+            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
+            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
+        else:
+            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
+            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
+
+        # Add batch dimension: [1, total_tokens, kv_heads, head_dim]
+        prev_k_batched = prev_k_flat.unsqueeze(0)
+        prev_v_batched = prev_v_flat.unsqueeze(0)
+
+        # Compute attention on all prefilled blocks at once
+        with torch.cuda.stream(compute_stream):
+            o_acc, lse_acc = flash_attn_with_lse(
+                q_batched, prev_k_batched, prev_v_batched,
+                softmax_scale=softmax_scale,
+                causal=False,
+            )
+
+        return o_acc, lse_acc
+
     def __repr__(self) -> str:
         return "FullAttentionPolicy()"
diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py
index 7cb9dcd..db09ec1 100644
--- a/nanovllm/kvcache/sparse/policy.py
+++ b/nanovllm/kvcache/sparse/policy.py
@@ -233,5 +233,43 @@ class SparsePolicy(ABC):
         """
         pass
 
+    @abstractmethod
+    def compute_chunked_decode(
+        self,
+        q: torch.Tensor,
+        layer_id: int,
+        softmax_scale: float,
+        offload_engine: "OffloadEngine",
+        kvcache_manager: "KVCacheManager",
+        seq: "Sequence",
+    ) -> torch.Tensor:
+        """
+        Compute chunked decode attention (complete flow).
+
+        This is the main entry point for decode attention computation.
+        It defines the complete decode flow:
+        1. Get prefilled blocks from CPU
+        2. Select blocks (call select_blocks)
+        3. Load blocks via pipeline (ring buffer or cross-layer)
+        4. Read accumulated decode tokens from decode buffer
+        5. Merge all results
+
+        The decode position information can be computed internally:
+        - decode_start_pos = kvcache_manager.get_decode_start_pos(seq)
+        - decode_pos_in_block = (len(seq) - 1) % kvcache_manager.block_size
+
+        Args:
+            q: [batch_size, num_heads, head_dim] query for decode token
+            layer_id: transformer layer index
+            softmax_scale: softmax scaling factor
+            offload_engine: OffloadEngine for loading blocks
+            kvcache_manager: KVCacheManager for block management
+            seq: Sequence object
+
+        Returns:
+            [batch_size, 1, num_heads, head_dim] final attention output
+        """
+        pass
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index e13456b..84af442 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -5,7 +5,6 @@ from torch import nn
 
 from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
-from nanovllm.kvcache.sparse.policy import PolicyContext
 
 logger = logging.getLogger(__name__)
 
@@ -237,240 +236,41 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention using cross-layer pipeline.
+        Compute decode attention by delegating to sparse policy.
 
-        Optimization: Uses double-buffered layer cache to overlap H2D transfer
-        with computation across layers:
-        - Layer N computes while Layer N+1's data is being loaded
-        - Each layer only waits for its own data, not all layers' data
+        Simplified design:
+        - All computation logic is delegated to sparse_policy.compute_chunked_decode()
+        - This method only validates the policy and delegates
 
-        This reduces effective latency from O(num_layers * transfer_time) to
-        O(transfer_time + num_layers * compute_time) when transfer < compute.
+        The policy handles:
+        1. Loading prefilled blocks from CPU via pipeline
+        2. Computing attention against prefilled KV
+        3. Reading accumulated decode tokens from decode buffer
+        4. Merging all results
         """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
-
-        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
-        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
-
         kvcache_manager = context.kvcache_manager
         seq = context.chunked_seq
-
-        # Get only PREFILLED CPU blocks (exclude the current decode block)
-        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
-        if self.layer_id == 0:
-            logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
-        if not cpu_block_table:
-            raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
-
-        # Calculate valid tokens in the last CPU block
-        # CRITICAL: Use original prefill length, not current seq length!
-        # CPU blocks are fixed after prefill, their content doesn't change during decode.
-        block_size = kvcache_manager.block_size
-        num_prefill_blocks = len(cpu_block_table)
-        total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Original prefill length
-        last_block_valid_tokens = total_prefill_tokens % block_size
-        if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
-            last_block_valid_tokens = block_size  # Last block was exactly full
-
         offload_engine = kvcache_manager.offload_engine
 
-        # Apply sparse policy if enabled (Quest does Top-K selection for decode)
+        # Get sparse policy - required for chunked decode
         sparse_policy = kvcache_manager.sparse_policy
-        if sparse_policy is not None:
-            policy_ctx = PolicyContext(
-                query_chunk_idx=0,
-                num_query_chunks=1,
-                layer_id=self.layer_id,
-                query=q_batched,
-                is_prefill=False,
-                block_size=kvcache_manager.block_size,
-                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
-            )
-            cpu_block_table = sparse_policy.select_blocks(
-                cpu_block_table, offload_engine, policy_ctx
-            )
+        if sparse_policy is None:
+            raise RuntimeError("sparse_policy is required for chunked decode")
 
-        # Use cross-layer pipeline if active (initialized in model_runner)
-        if offload_engine.is_pipeline_active():
-            o_acc, lse_acc = self._decode_with_layer_pipeline(
-                q_batched, cpu_block_table, offload_engine,
-                block_size, last_block_valid_tokens
-            )
-        else:
-            # Fallback to original ring buffer pipeline
-            load_slots = offload_engine.decode_load_slots
-            o_acc, lse_acc = self._decode_ring_buffer_pipeline(
-                q_batched, cpu_block_table, load_slots, offload_engine,
-                block_size, last_block_valid_tokens
-            )
+        # Check if policy supports decode phase
+        if not sparse_policy.supports_decode:
+            raise RuntimeError(f"{sparse_policy} does not support decode phase")
 
-        # Now attend to accumulated decode tokens from per-layer decode buffer
-        pos_in_block = context.decode_pos_in_block
-        start_pos = context.decode_start_pos_in_block
-        num_accumulated = pos_in_block - start_pos + 1
+        # [DEBUG] Verify execution path
+        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
+                     f"policy={sparse_policy}, layer={self.layer_id}")
 
-        # Sync compute_stream with default stream before reading decode_buffer
-        compute_stream = offload_engine.compute_stream
-        compute_stream.wait_stream(torch.cuda.default_stream())
-
-        with torch.cuda.stream(compute_stream):
-            if num_accumulated > 0:
-                # Read from per-layer decode buffer
-                decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block+1]
-                decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block+1]
-                decode_k = decode_k.unsqueeze(0)
-                decode_v = decode_v.unsqueeze(0)
-
-                decode_o, decode_lse = flash_attn_with_lse(
-                    q_batched, decode_k, decode_v,
-                    softmax_scale=self.scale,
-                    causal=False,
-                )
-
-                if o_acc is None:
-                    o_acc = decode_o
-                else:
-                    o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
-
-        if o_acc is None:
-            raise RuntimeError("Chunked decode attention failed: no KV available")
-
-        # Sync back to default stream before returning
-        torch.cuda.default_stream().wait_stream(compute_stream)
-
-        return o_acc
-
-    def _decode_ring_buffer_pipeline(
-        self,
-        q_batched: torch.Tensor,
-        cpu_block_table: list,
-        load_slots: list,
-        offload_engine,
-        block_size: int,
-        last_block_valid_tokens: int,
-    ):
-        """
-        Ring buffer pipeline for decode prefill loading (same mechanism as prefill).
-
-        Loads one block at a time, computes attention, and merges results.
-        Uses the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot
-        methods as prefill for proven correctness.
-        """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
-
-        num_blocks = len(cpu_block_table)
-        if num_blocks == 0:
-            return None, None
-
-        if not load_slots:
-            return None, None
-
-        o_acc, lse_acc = None, None
-        num_slots = len(load_slots)
-        compute_stream = offload_engine.compute_stream
-
-        # Phase 1: Pre-load up to num_slots blocks
-        num_preload = min(num_slots, num_blocks)
-        for i in range(num_preload):
-            offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
-
-        # Phase 2: Process blocks with pipeline
-        for block_idx in range(num_blocks):
-            current_slot = load_slots[block_idx % num_slots]
-            cpu_block_id = cpu_block_table[block_idx]
-
-            # Wait for current slot's transfer to complete
-            offload_engine.wait_slot_layer(current_slot)
-
-            with torch.cuda.stream(compute_stream):
-                # Get KV from slot
-                prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
-
-                # Handle partial last block
-                is_last_block = (block_idx == num_blocks - 1)
-                if is_last_block and last_block_valid_tokens < block_size:
-                    prev_k = prev_k[:, :last_block_valid_tokens, :, :]
-                    prev_v = prev_v[:, :last_block_valid_tokens, :, :]
-
-                # Compute attention
-                prev_o, prev_lse = flash_attn_with_lse(
-                    q_batched, prev_k, prev_v,
-                    softmax_scale=self.scale,
-                    causal=False,
-                )
-
-                # Record compute done for slot reuse
-                offload_engine.record_slot_compute_done(current_slot)
-
-            # Start loading next block (pipeline)
-            next_block_idx = block_idx + num_slots
-            if next_block_idx < num_blocks:
-                offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
-
-            # Merge with accumulated
-            with torch.cuda.stream(compute_stream):
-                if o_acc is None:
-                    o_acc, lse_acc = prev_o, prev_lse
-                else:
-                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
-
-        return o_acc, lse_acc
-
-    def _decode_with_layer_pipeline(
-        self,
-        q_batched: torch.Tensor,
-        cpu_block_table: list,
-        offload_engine,
-        block_size: int,
-        last_block_valid_tokens: int,
-    ):
-        """
-        Decode using cross-layer pipeline for optimized H2D transfer.
-
-        This method uses pre-loaded layer buffers instead of loading
-        blocks one by one. The pipeline loads the next layer's data
-        while the current layer computes, achieving transfer/compute overlap.
-
-        The key insight is that each layer needs the SAME blocks but from
-        different layers of CPU cache. By double-buffering and pipelining
-        across layers, we reduce total latency.
-        """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
-
-        num_blocks = len(cpu_block_table)
-        if num_blocks == 0:
-            return None, None
-
-        compute_stream = offload_engine.compute_stream
-
-        # Get KV from pre-loaded layer buffer (triggers next layer loading)
-        prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
-
-        # prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
-        # Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
-        total_tokens = num_blocks * block_size
-
-        # Handle partial last block
-        if last_block_valid_tokens < block_size:
-            # Only use valid tokens from last block
-            actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
-            # Flatten and truncate
-            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
-            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
-        else:
-            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
-            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
-
-        # Add batch dimension: [1, total_tokens, kv_heads, head_dim]
-        prev_k_batched = prev_k_flat.unsqueeze(0)
-        prev_v_batched = prev_v_flat.unsqueeze(0)
-
-        # Compute attention on all prefilled blocks at once
-        with torch.cuda.stream(compute_stream):
-            o_acc, lse_acc = flash_attn_with_lse(
-                q_batched, prev_k_batched, prev_v_batched,
-                softmax_scale=self.scale,
-                causal=False,
-            )
-
-        return o_acc, lse_acc
+        # Delegate all computation to policy (no flash_attn or merge calls here!)
+        return sparse_policy.compute_chunked_decode(
+            q,
+            self.layer_id,
+            self.scale,
+            offload_engine,
+            kvcache_manager,
+            seq,
+        )