From 9b8165af5ac40a328b7dcaa08d51fb552408e221 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Fri, 12 Dec 2025 01:35:30 +0800
Subject: [PATCH] [fix] Fix KV cache offload for per-layer chunked attention

Previously only layer 0 triggered loads into the Compute/Prefetch regions and
copied every layer's KV at once, so later layers could read slots that had
already been overwritten by the next chunk. Each layer now loads only its own
KV and synchronizes on its own per-layer CUDA event.
---
 nanovllm/kvcache/offload_engine.py | 75 ++++++++++++++++++++++++++++++
 nanovllm/layers/attention.py       | 49 ++++++-------------
 nanovllm/utils/context.py          |  8 ++++
 3 files changed, 96 insertions(+), 36 deletions(-)

diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py
index 582325e..0450e88 100644
--- a/nanovllm/kvcache/offload_engine.py
+++ b/nanovllm/kvcache/offload_engine.py
@@ -155,6 +155,11 @@ class OffloadEngine:
         self.ping_offload_done = torch.cuda.Event()
         self.pong_offload_done = torch.cuda.Event()
 
+        # ========== Per-layer events for chunked attention ==========
+        # Each layer has its own event for synchronization
+        self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
+        self.prefetch_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
+
         # ========== Event tracking for async transfers ==========
         self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
 
@@ -836,10 +841,80 @@ class OffloadEngine:
         """Wait for Compute region loading to complete."""
         self.compute_stream.wait_event(self.compute_ready)
 
+    def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
+        """
+        Load CPU blocks to the Compute region for a single layer only.
+
+        This is used for per-layer chunked attention, where each layer
+        independently loads its KV data.
+
+        Args:
+            layer_id: Layer index to load
+            cpu_block_ids: List of CPU block IDs to load
+        """
+        if not cpu_block_ids:
+            self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
+            return
+
+        num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
+        logger.debug(f"Compute load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
+
+        with torch.cuda.stream(self.transfer_stream_main):
+            for i in range(num_to_load):
+                cpu_id = cpu_block_ids[i]
+                gpu_slot = self.compute_slots[i]
+                # Copy only this layer (not all layers)
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
+                )
+            self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
+
+    def wait_compute_layer(self, layer_id: int) -> None:
+        """Wait for a specific layer's Compute region loading to complete."""
+        self.compute_stream.wait_event(self.compute_ready_per_layer[layer_id])
+
     def wait_prefetch(self) -> None:
         """Wait for Prefetch region loading to complete."""
         self.compute_stream.wait_event(self.prefetch_ready)
 
+    def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
+        """
+        Load CPU blocks to the Prefetch region for a single layer only.
+
+        This is used for per-layer chunked attention, where each layer
+        independently loads its KV data.
+
+        Args:
+            layer_id: Layer index to load
+            cpu_block_ids: List of CPU block IDs to load
+        """
+        if not cpu_block_ids:
+            self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
+            return
+
+        num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
+        logger.debug(f"Prefetch load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
+
+        with torch.cuda.stream(self.transfer_stream_main):
+            for i in range(num_to_load):
+                cpu_id = cpu_block_ids[i]
+                gpu_slot = self.prefetch_slots[i]
+                # Copy only this layer (not all layers)
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
+                )
+            self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
+
+    def wait_prefetch_layer(self, layer_id: int) -> None:
+        """Wait for a specific layer's Prefetch region loading to complete."""
+        self.compute_stream.wait_event(self.prefetch_ready_per_layer[layer_id])
+
     def swap_compute_prefetch(self) -> None:
         """Swap roles of Compute region and Prefetch region."""
         self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index ab518f8..9ba1b3d 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -136,36 +136,20 @@ class Attention(nn.Module):
             # Use Prefetch region to load previous KV (won't conflict with current Compute region)
             prefetch_size = offload_engine.num_prefetch_blocks
             num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
-            use_compute = True  # Alternate between Compute region and Prefetch region
-
-            # First load previous KV to Prefetch region
-            # Only layer 0 triggers the load (loads ALL layers at once)
-            first_chunk_end = min(prefetch_size, len(cpu_block_table))
-            first_chunk_ids = cpu_block_table[:first_chunk_end]
-            if self.layer_id == 0:
-                offload_engine.load_to_prefetch(first_chunk_ids)
 
             for chunk_idx in range(num_chunks):
                 start = chunk_idx * prefetch_size
                 end = min(start + prefetch_size, len(cpu_block_table))
                 num_blocks_in_chunk = end - start
+                chunk_ids = cpu_block_table[start:end]
 
-                # Prefetch next chunk to other buffer (if exists)
-                # Only layer 0 triggers the load
-                if chunk_idx + 1 < num_chunks and self.layer_id == 0:
-                    next_start = end
-                    next_end = min(next_start + prefetch_size, len(cpu_block_table))
-                    next_chunk_ids = cpu_block_table[next_start:next_end]
-                    if use_compute:
-                        # Currently in Prefetch region, next load to Compute region (if space available)
-                        # Note: Compute region already has current chunk's KV written, cannot overwrite
-                        # So here we use simple sync strategy: wait for current to complete before loading
-                        pass  # Simplified version: no double buffering, only use Prefetch region
-                    else:
-                        offload_engine.load_to_prefetch(next_chunk_ids)
+                # Load this chunk to Prefetch region (per-layer loading)
+                # Each layer loads only its own KV, avoiding the bug where layer 0
+                # loads all layers and overwrites data before other layers can read it
+                offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
 
-                # Wait for Prefetch region and get KV
-                offload_engine.wait_prefetch()
+                # Wait for this layer's Prefetch region and get KV
+                offload_engine.wait_prefetch_layer(self.layer_id)
                 prev_k, prev_v = offload_engine.get_kv_for_prefetch(
                     self.layer_id, num_blocks_in_chunk
                 )
@@ -185,13 +169,6 @@ class Attention(nn.Module):
                 else:
                     o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
 
-                # Load next chunk to Prefetch region (if exists)
-                if chunk_idx + 1 < num_chunks and self.layer_id == 0:
-                    next_start = end
-                    next_end = min(next_start + prefetch_size, len(cpu_block_table))
-                    next_chunk_ids = cpu_block_table[next_start:next_end]
-                    offload_engine.load_to_prefetch(next_chunk_ids)
-
                 # Compute attention against current chunk's KV (with causal mask)
                 current_o, current_lse = flash_attn_with_lse(
                     q_batched,
@@ -262,13 +239,13 @@
                 num_blocks_in_chunk = end - start
                 chunk_ids = cpu_block_table[start:end]
 
-                # Load this chunk to Compute region
-                # Only layer 0 triggers the load (loads ALL layers at once)
-                if self.layer_id == 0:
-                    offload_engine.load_to_compute(chunk_ids)
+                # Load this chunk to Compute region (per-layer loading)
+                # Each layer loads only its own KV, avoiding the bug where layer 0
+                # loads all layers and overwrites data before other layers can read it
+                offload_engine.load_to_compute_layer(self.layer_id, chunk_ids)
 
-                # Wait for Compute region to be ready and get KV
-                offload_engine.wait_compute()
+                # Wait for this layer's Compute region to be ready and get KV
+                offload_engine.wait_compute_layer(self.layer_id)
                 k_chunk, v_chunk = offload_engine.get_kv_for_compute(
                     self.layer_id, num_blocks_in_chunk
                 )
diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py
index b32b573..5addf8b 100644
--- a/nanovllm/utils/context.py
+++ b/nanovllm/utils/context.py
@@ -33,6 +33,14 @@ class Context:
     # Used when batching decode offloads - we need to attend to all accumulated tokens
     decode_start_pos_in_block: int = 0
 
+    # ========== Per-layer chunked attention state ==========
+    # Whether chunked decode/prefill is currently active (for hooks to check)
+    chunked_decode_active: bool = False
+    # CPU block IDs for the current chunk being processed
+    chunked_decode_chunk_ids: List[int] = field(default_factory=list)
+    # Current chunk index being processed
+    chunked_decode_current_chunk: int = 0
+
 
 _CONTEXT = Context()
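
Note (not part of the patch): the fix hinges on recording one CUDA event per layer on the
transfer stream and having the compute stream wait on exactly that layer's event before it
reads the staging slots. The sketch below shows that pattern in isolation; the tensor shapes,
slot counts, and helper names (cpu_kv, gpu_kv, load_layer, attend_layer) are invented for
illustration and are not nano-vllm APIs.

    # Minimal, self-contained sketch of the per-layer event pattern (assumed names/shapes).
    import torch

    num_layers, num_slots, block_tokens, head_dim = 4, 2, 16, 64

    if torch.cuda.is_available():
        # Pinned CPU pool of KV blocks and a small GPU staging area per layer.
        cpu_kv = torch.randn(num_layers, 8, block_tokens, head_dim, pin_memory=True)
        gpu_kv = torch.empty(num_layers, num_slots, block_tokens, head_dim, device="cuda")

        transfer_stream = torch.cuda.Stream()
        compute_stream = torch.cuda.Stream()
        ready = [torch.cuda.Event() for _ in range(num_layers)]

        def load_layer(layer_id, cpu_block_ids):
            # Copy only this layer's blocks, then mark this layer as ready.
            with torch.cuda.stream(transfer_stream):
                for slot, cpu_id in enumerate(cpu_block_ids[:num_slots]):
                    gpu_kv[layer_id, slot].copy_(cpu_kv[layer_id, cpu_id], non_blocking=True)
                ready[layer_id].record(transfer_stream)

        def attend_layer(layer_id, num_blocks):
            # The compute stream waits only for this layer's copies, not for every
            # layer's, which is exactly what the per-layer events provide.
            compute_stream.wait_event(ready[layer_id])
            with torch.cuda.stream(compute_stream):
                return gpu_kv[layer_id, :num_blocks].sum()  # stand-in for attention

        for layer_id in range(num_layers):
            load_layer(layer_id, cpu_block_ids=[0, 1])
            attend_layer(layer_id, num_blocks=2)
        torch.cuda.synchronize()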
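
The chunked loops above combine per-chunk results with merge_attention_outputs and
flash_attn_with_lse. For reference, a log-sum-exp merge of two partial attention results
computed over disjoint key chunks looks like the sketch below; this is the standard formula,
not the repository's implementation, and the [tokens, heads, head_dim] / [tokens, heads]
layout is assumed for illustration.

    import torch

    def merge_partial_attention(o1, lse1, o2, lse2):
        # o1, o2:     [tokens, heads, head_dim] partial attention outputs
        # lse1, lse2: [tokens, heads] log-sum-exp of the corresponding attention scores
        m = torch.maximum(lse1, lse2)                # stabilizer for the exponentials
        w1 = torch.exp(lse1 - m).unsqueeze(-1)       # weight of the first partial result
        w2 = torch.exp(lse2 - m).unsqueeze(-1)       # weight of the second partial result
        o = (w1 * o1 + w2 * o2) / (w1 + w2)          # softmax-consistent combination
        lse = m + torch.log((w1 + w2).squeeze(-1))   # combined log-sum-exp
        return o, lse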