From 30462fe89aeda71d110afd4a0ff5230087c82fc0 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 31 Dec 2025 23:35:25 +0800 Subject: [PATCH] [WIP] Before fix needle. --- nanovllm/engine/model_runner.py | 31 +-- nanovllm/kvcache/offload_engine.py | 406 ++++++++++++----------------- nanovllm/layers/attention.py | 41 +-- tests/test_debug_verification.py | 5 +- tests/test_offload_correctness.py | 19 +- 5 files changed, 212 insertions(+), 290 deletions(-) diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 3e281e6..6a9b2bc 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -489,24 +489,15 @@ class ModelRunner: logical_id = seq.block_table[block_idx] self.kvcache_manager.prefilled_blocks.add(logical_id) - # Offload this chunk's ring buffer slot to CPU (async) + # NOTE: Per-layer offloading is now done in attention.forward + # Each layer offloads its KV to CPU immediately after computing attention. + # We just need to wait for the last offload to complete before reusing the slot. if block_idx < len(cpu_block_ids): - cpu_block_id = cpu_block_ids[block_idx] - - # Call sparse policy hook before offload (to capture metadata) - sparse_policy = self.kvcache_manager.sparse_policy - if sparse_policy is not None: - num_tokens = chunk_end - chunk_start - for layer_id in range(offload_engine.num_layers): - k_cache = offload_engine.k_cache_gpu[layer_id, write_slot, :num_tokens] - sparse_policy.on_block_offloaded( - cpu_block_id=cpu_block_id, - layer_id=layer_id, - k_cache=k_cache, - num_valid_tokens=num_tokens, - ) - - offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id) + # TODO: Sparse policy hook needs update for new GPU cache architecture + # The GPU cache no longer has layer dimension, so we can't access + # k_cache_gpu[layer_id, write_slot]. Sparse policy should be called + # in attention.forward after per-layer offload. + pass # Wait for offload to complete before next chunk # (slot will be reused after N chunks) @@ -628,7 +619,11 @@ class ModelRunner: if pos_in_block == self.block_size - 1: last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq) if last_cpu_block >= 0: - offload_engine.offload_decode_slot(last_cpu_block) + # TODO: In new GPU cache architecture (no layer dimension), + # decode offload should be done per-layer in attention.forward. + # For now, offload all layers sequentially. 
+ for layer_id in range(offload_engine.num_layers): + offload_engine.offload_decode_slot_layer(layer_id, last_cpu_block) offload_engine.wait_all_offload_done() # Reset decode start position for next block self.kvcache_manager.reset_decode_start_pos(seq) diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index c89e16b..688f8be 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -67,14 +67,19 @@ class OffloadEngine: self.block_numel = block_size * self.kv_dim # ========== sgDMA pitch parameters for strided transfers ========== + # CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] + # GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dim) + # For CPU-to-GPU transfer (H2D): copy single layer, single block at a time + # For all-layer CPU operations (D2H offload to all layers): use sgDMA self.dtype_size = dtype.itemsize + # CPU pitch: stride between layers in CPU cache (for all-layer operations) self.cpu_pitch = num_cpu_blocks * self.block_numel * self.dtype_size - self.gpu_pitch = num_gpu_blocks * self.block_numel * self.dtype_size - self.width = self.block_numel * self.dtype_size - self.height = num_layers + # GPU has no layer dimension, so single block transfer is contiguous + self.gpu_block_bytes = self.block_numel * self.dtype_size + self.height = num_layers # For CPU all-layer operations - logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, gpu_pitch={self.gpu_pitch}, " - f"width={self.width}, height={self.height}") + logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, " + f"gpu_block_bytes={self.gpu_block_bytes}, height={self.height}") # ========== Unified Ring Buffer configuration ========== # Constraint checks @@ -100,14 +105,16 @@ class OffloadEngine: logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading") # ========== Fixed-address GPU KV cache ========== - # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] - # Use zeros initialization to avoid uninitialized memory issues + # Shape: [num_gpu_blocks, block_size, kv_heads, head_dim] + # NOTE: No num_layers dimension! GPU slots are shared across layers. + # Each layer reuses the same slots (layers execute sequentially). + # This saves 28x GPU memory compared to per-layer allocation. self.k_cache_gpu = torch.zeros( - num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim, + num_gpu_blocks, block_size, num_kv_heads, head_dim, dtype=dtype, device="cuda" ) self.v_cache_gpu = torch.zeros( - num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim, + num_gpu_blocks, block_size, num_kv_heads, head_dim, dtype=dtype, device="cuda" ) @@ -159,35 +166,23 @@ class OffloadEngine: # Decode offload event self.decode_offload_done = torch.cuda.Event() - # ========== Per-slot Per-layer events for ring buffer ========== - # ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion - # ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion - self.ring_slot_ready = [ - [torch.cuda.Event() for _ in range(num_layers)] - for _ in range(self.num_ring_slots) - ] - self.ring_slot_offload_done = [ - [torch.cuda.Event() for _ in range(num_layers)] - for _ in range(self.num_ring_slots) - ] + # ========== Per-slot events for ring buffer ========== + # Since GPU cache has no layer dimension and layers execute sequentially, + # we only need per-slot events (not per-slot per-layer). 
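# --- Illustrative aside (not part of this patch): the per-slot event handshake ---
# A minimal sketch of the three-event-per-slot protocol described above, using only
# torch.cuda primitives; SlotEvents, load_slot and use_slot are hypothetical names,
# not APIs from this repo.
import torch

class SlotEvents:
    def __init__(self):
        self.ready = torch.cuda.Event()         # H2D load into the slot finished
        self.offload_done = torch.cuda.Event()  # D2H offload out of the slot finished
        self.compute_done = torch.cuda.Event()  # attention finished reading the slot
        self.compute_done.record()              # slot is free to load on first use

def load_slot(ev, transfer_stream, k_gpu_slot, v_gpu_slot, k_cpu_block, v_cpu_block):
    # H2D load of one (layer, block) into a GPU slot that is shared by all layers.
    with torch.cuda.stream(transfer_stream):
        transfer_stream.wait_event(ev.compute_done)  # don't overwrite data still being read
        transfer_stream.wait_event(ev.offload_done)  # don't overwrite data still being offloaded
        k_gpu_slot.copy_(k_cpu_block, non_blocking=True)
        v_gpu_slot.copy_(v_cpu_block, non_blocking=True)
        ev.ready.record(transfer_stream)

def use_slot(ev, compute_stream):
    # Attention side: wait for the load, read the slot, then release it for reuse.
    compute_stream.wait_event(ev.ready)
    # ... flash attention over the slot's K/V runs here on compute_stream ...
    ev.compute_done.record(compute_stream)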
+ # ring_slot_ready[slot_idx] = CUDA Event for H2D completion + # ring_slot_offload_done[slot_idx] = CUDA Event for D2H completion + self.ring_slot_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)] + self.ring_slot_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)] - # Per-slot events for all-layer operations (used in some legacy paths) - self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)] - self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)] - - # ========== Per-slot Per-layer compute_done events for async pipeline ========== - # ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion - # This is used to ensure we don't overwrite data before it's been read by attention - self.ring_slot_compute_done = [ - [torch.cuda.Event() for _ in range(num_layers)] - for _ in range(self.num_ring_slots) - ] + # ========== Per-slot compute_done events for async pipeline ========== + # ring_slot_compute_done[slot_idx] = CUDA Event for compute completion + # This ensures we don't overwrite data before it's been read by attention + self.ring_slot_compute_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)] # Initialize all compute_done events (record them once) # This prevents undefined behavior on first load_to_slot_layer call for slot_idx in range(self.num_ring_slots): - for layer_id in range(num_layers): - self.ring_slot_compute_done[slot_idx][layer_id].record() + self.ring_slot_compute_done[slot_idx].record() torch.cuda.synchronize() # Ensure all events are recorded # ========== Event tracking for async transfers ========== @@ -204,23 +199,24 @@ class OffloadEngine: return stream # ========== CUDA Graph compatible methods ========== + # NOTE: These methods need to be updated for the new GPU cache architecture. + # GPU cache no longer has layer dimension, so gathered copy semantics change. + # For now, these are kept for reference but should not be used without updating. def gathered_h2d_layer(self, layer_id: int) -> None: """ Execute gathered H2D copy for a single layer. - This method is CUDA Graph compatible - can be captured into a graph. - Before calling, update_gather_indices() must be called to set up - which CPU blocks to copy to which GPU slots. - - Args: - layer_id: Layer index to transfer + WARNING: This method needs updating for new GPU cache architecture. + GPU cache no longer has layer dimension. """ + # GPU cache has no layer dimension - use flat indexing + # Source is CPU[layer_id], dest is GPU (shared across layers) gathered_copy_kv( k_src=self.k_cache_cpu[layer_id], v_src=self.v_cache_cpu[layer_id], - k_dst=self.k_cache_gpu[layer_id], - v_dst=self.v_cache_gpu[layer_id], + k_dst=self.k_cache_gpu, # No layer indexing + v_dst=self.v_cache_gpu, # No layer indexing indices=self.gather_indices_gpu[layer_id], ) @@ -228,7 +224,8 @@ class OffloadEngine: """ Execute gathered H2D copy for all layers. - CUDA Graph compatible - can be captured into a single graph. + WARNING: In new architecture, GPU slots are shared across layers. + This method would overwrite slots multiple times. Not recommended. """ for layer_id in range(self.num_layers): self.gathered_h2d_layer(layer_id) @@ -297,10 +294,10 @@ class OffloadEngine: """ Async prefetch a single block from CPU to GPU. - For use in prefill phase where CUDA graphs are not used. + GPU cache has no layer dimension - layer_id is for CPU cache indexing. 
Args: - layer_id: Layer index + layer_id: Layer index (for CPU cache) cpu_block_id: Source block in CPU cache gpu_block_id: Destination slot in GPU cache @@ -313,13 +310,12 @@ class OffloadEngine: logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]") with torch.cuda.stream(stream): - # K cache - self.k_cache_gpu[layer_id, gpu_block_id].copy_( + # GPU: no layer dimension, CPU: has layer dimension + self.k_cache_gpu[gpu_block_id].copy_( self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) - # V cache - self.v_cache_gpu[layer_id, gpu_block_id].copy_( + self.v_cache_gpu[gpu_block_id].copy_( self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) @@ -356,8 +352,10 @@ class OffloadEngine: """ Async offload a block from GPU to CPU. + GPU cache has no layer dimension - layer_id is for CPU cache indexing. + Args: - layer_id: Layer index + layer_id: Layer index (for CPU cache) gpu_block_id: Source slot in GPU cache cpu_block_id: Destination block in CPU cache @@ -373,14 +371,13 @@ class OffloadEngine: # Wait for any compute using this block stream.wait_stream(self.compute_stream) - # K cache + # GPU: no layer dimension, CPU: has layer dimension self.k_cache_cpu[layer_id, cpu_block_id].copy_( - self.k_cache_gpu[layer_id, gpu_block_id], + self.k_cache_gpu[gpu_block_id], non_blocking=True ) - # V cache self.v_cache_cpu[layer_id, cpu_block_id].copy_( - self.v_cache_gpu[layer_id, gpu_block_id], + self.v_cache_gpu[gpu_block_id], non_blocking=True ) event.record() @@ -417,11 +414,10 @@ class OffloadEngine: """ Load CPU blocks to specific GPU slots for chunked decode. - Uses the main GPU KV cache slots, not a separate temp buffer. - This is the same mechanism as chunked prefill uses. + GPU cache has no layer dimension - layer_id is for CPU cache indexing. Args: - layer_id: Layer index + layer_id: Layer index (for CPU cache) cpu_block_ids: List of CPU block IDs to load gpu_slot_ids: List of GPU slot IDs to load into (must be same length) """ @@ -434,12 +430,12 @@ class OffloadEngine: with torch.cuda.stream(stream): for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids): - # Copy from pinned CPU memory to GPU KV cache slot - self.k_cache_gpu[layer_id, gpu_slot].copy_( + # GPU: no layer dimension, CPU: has layer dimension + self.k_cache_gpu[gpu_slot].copy_( self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) - self.v_cache_gpu[layer_id, gpu_slot].copy_( + self.v_cache_gpu[gpu_slot].copy_( self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) @@ -456,8 +452,10 @@ class OffloadEngine: """ Async version: Load CPU blocks to GPU slots. + GPU cache has no layer dimension - layer_id is for CPU cache indexing. 
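# --- Illustrative aside (not part of this patch): cache layouts and pinned memory ---
# Toy-sized sketch of the two layouts implied by the comments above; all sizes and
# variable names are placeholders. The non_blocking=True copies used throughout this
# file only overlap with compute because the CPU cache lives in pinned host memory.
import torch

num_layers, num_cpu_blocks, num_gpu_blocks = 4, 16, 4
block_size, num_kv_heads, head_dim = 16, 2, 64

# CPU cache keeps the layer dimension and is pinned for async H2D/D2H transfers.
k_cache_cpu = torch.zeros(num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
                          dtype=torch.float16, pin_memory=True)
# GPU cache drops the layer dimension: the same slots are reused layer by layer.
k_cache_gpu = torch.zeros(num_gpu_blocks, block_size, num_kv_heads, head_dim,
                          dtype=torch.float16, device="cuda")

transfer_stream = torch.cuda.Stream()
done = torch.cuda.Event()
with torch.cuda.stream(transfer_stream):
    # One (layer, block) of the CPU cache copied into one shared GPU slot.
    k_cache_gpu[0].copy_(k_cache_cpu[3, 7], non_blocking=True)
    done.record(transfer_stream)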
+ Args: - layer_id: Layer index + layer_id: Layer index (for CPU cache) cpu_block_ids: List of CPU block IDs to load gpu_slot_ids: List of GPU slot IDs to load into @@ -474,11 +472,12 @@ class OffloadEngine: with torch.cuda.stream(stream): for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids): - self.k_cache_gpu[layer_id, gpu_slot].copy_( + # GPU: no layer dimension, CPU: has layer dimension + self.k_cache_gpu[gpu_slot].copy_( self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) - self.v_cache_gpu[layer_id, gpu_slot].copy_( + self.v_cache_gpu[gpu_slot].copy_( self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) @@ -486,44 +485,8 @@ class OffloadEngine: return event - def load_cpu_blocks_to_gpu_slots_all_layers( - self, - cpu_block_ids: List[int], - gpu_slot_ids: List[int], - ) -> None: - """ - Load CPU blocks to GPU slots for ALL layers at once. - - More efficient than per-layer loading when we know the mapping upfront. - - Args: - cpu_block_ids: List of CPU block IDs to load - gpu_slot_ids: List of GPU slot IDs to load into - """ - assert len(cpu_block_ids) == len(gpu_slot_ids) - - if cpu_block_ids: - logger.debug(f"H2D all layers: CPU{cpu_block_ids} -> GPU{gpu_slot_ids}") - - stream = self._get_next_stream() - - with torch.cuda.stream(stream): - for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids): - # Copy all layers at once using sgDMA - memcpy_2d_async( - self.k_cache_gpu[:, gpu_slot], - self.k_cache_cpu[:, cpu_block_id], - self.gpu_pitch, self.cpu_pitch, self.width, self.height, - "h2d", stream=stream - ) - memcpy_2d_async( - self.v_cache_gpu[:, gpu_slot], - self.v_cache_cpu[:, cpu_block_id], - self.gpu_pitch, self.cpu_pitch, self.width, self.height, - "h2d", stream=stream - ) - - stream.synchronize() + # NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has + # layer dimension. Each GPU slot holds data for ONE layer at a time. # ========== Synchronization methods ========== @@ -548,21 +511,27 @@ class OffloadEngine: def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]: """ - Get GPU K/V cache tensors for a specific layer. + Get GPU K/V cache tensors for attention layer. + + NOTE: GPU cache has no layer dimension - all layers share the same slots. + The layer_id parameter is kept for API compatibility but not used. Returns: - (k_cache, v_cache) tensors for the layer + (k_cache, v_cache) tensors Shape: [num_gpu_blocks, block_size, kv_heads, head_dim] """ - return self.k_cache_gpu[layer_id], self.v_cache_gpu[layer_id] + # GPU cache is shared across all layers (no layer dimension) + return self.k_cache_gpu, self.v_cache_gpu def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]: """ Get full GPU K/V cache tensors. + NOTE: GPU cache has no layer dimension in the new architecture. + Returns: (k_cache, v_cache) tensors - Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] + Shape: [num_gpu_blocks, block_size, kv_heads, head_dim] """ return self.k_cache_gpu, self.v_cache_gpu @@ -668,7 +637,7 @@ class OffloadEngine: # ----- Per-slot Per-layer loading methods ----- - def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None: + def record_slot_compute_done(self, slot_idx: int) -> None: """ Record that computation using this slot's data is done. 
@@ -677,22 +646,23 @@ class OffloadEngine: Args: slot_idx: GPU slot index that was just used for computation - layer_id: Layer index """ - self.ring_slot_compute_done[slot_idx][layer_id].record() + self.ring_slot_compute_done[slot_idx].record() def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None: """ Async load a single CPU block to a ring buffer slot for one layer. This is the core building block for ring buffer pipelining. + GPU cache has no layer dimension - slots are shared across all layers. + CPU cache still has layer dimension for persistent storage. + Before starting the transfer, waits for: 1. Any previous compute on this slot to complete - 2. Any pending offload of this slot to complete Args: slot_idx: Target GPU slot index - layer_id: Layer index to load + layer_id: Layer index to load (for CPU cache indexing) cpu_block_id: Source CPU block ID """ logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]") @@ -704,150 +674,105 @@ class OffloadEngine: with torch.cuda.stream(stream): # Wait for previous compute on this slot to complete before overwriting # This prevents data race: transfer must not start until attention finishes reading - stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id]) + stream.wait_event(self.ring_slot_compute_done[slot_idx]) # Also wait for any pending offload of this slot to complete # This prevents race: load must not write GPU slot while offload is reading from it - stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx]) + stream.wait_event(self.ring_slot_offload_done[slot_idx]) - self.k_cache_gpu[layer_id, slot_idx].copy_( + # GPU: no layer dimension, CPU: has layer dimension + self.k_cache_gpu[slot_idx].copy_( self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) - self.v_cache_gpu[layer_id, slot_idx].copy_( + self.v_cache_gpu[slot_idx].copy_( self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) - self.ring_slot_ready[slot_idx][layer_id].record(stream) + self.ring_slot_ready[slot_idx].record(stream) torch.cuda.nvtx.range_pop() - def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None: + def wait_slot_layer(self, slot_idx: int) -> None: """ - Wait for a slot's loading to complete for a specific layer. + Wait for a slot's loading to complete. Args: slot_idx: GPU slot index to wait for - layer_id: Layer index to wait for """ - self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id]) + self.compute_stream.wait_event(self.ring_slot_ready[slot_idx]) - def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None: - """ - Async load a CPU block to a ring buffer slot for ALL layers. 
- - Args: - slot_idx: Target GPU slot index - cpu_block_id: Source CPU block ID - """ - logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]") - - with torch.cuda.stream(self.transfer_stream_main): - memcpy_2d_async( - self.k_cache_gpu[:, slot_idx], - self.k_cache_cpu[:, cpu_block_id], - self.gpu_pitch, self.cpu_pitch, self.width, self.height, - "h2d", stream=self.transfer_stream_main - ) - memcpy_2d_async( - self.v_cache_gpu[:, slot_idx], - self.v_cache_cpu[:, cpu_block_id], - self.gpu_pitch, self.cpu_pitch, self.width, self.height, - "h2d", stream=self.transfer_stream_main - ) - self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main) - - def wait_slot_all_layers(self, slot_idx: int) -> None: - """Wait for a slot's loading to complete for ALL layers.""" - self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx]) + # NOTE: load_to_slot_all_layers removed - GPU cache no longer has layer dimension. + # Each GPU slot holds data for ONE layer at a time. Layers execute sequentially, + # reusing the same GPU slots. # ----- Slot offload methods ----- - def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None: + # NOTE: offload_slot_to_cpu (all-layers) removed - GPU cache no longer has layer dimension. + # Use offload_slot_layer_to_cpu for per-layer offloading. + + def wait_slot_offload(self, slot_idx: int) -> None: + """Wait for slot offload to complete.""" + self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx]) + + def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None: """ - Async offload a ring buffer slot to CPU (all layers). + Async offload a ring buffer slot to CPU for one layer. + + GPU cache has no layer dimension, so we copy from GPU slot to the + specific layer in CPU cache. Args: slot_idx: Source GPU slot index + layer_id: Target layer in CPU cache cpu_block_id: Target CPU block ID """ - logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]") + logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]") - torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]") + torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]") with torch.cuda.stream(self.transfer_stream_main): # Wait for both compute_stream and default stream # - compute_stream: for flash attention operations # - default_stream: for store_kvcache which runs on default stream self.transfer_stream_main.wait_stream(self.compute_stream) self.transfer_stream_main.wait_stream(torch.cuda.default_stream()) - memcpy_2d_async( - self.k_cache_cpu[:, cpu_block_id], - self.k_cache_gpu[:, slot_idx], - self.cpu_pitch, self.gpu_pitch, self.width, self.height, - "d2h", stream=self.transfer_stream_main - ) - memcpy_2d_async( - self.v_cache_cpu[:, cpu_block_id], - self.v_cache_gpu[:, slot_idx], - self.cpu_pitch, self.gpu_pitch, self.width, self.height, - "d2h", stream=self.transfer_stream_main - ) - self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main) - torch.cuda.nvtx.range_pop() - def wait_slot_offload(self, slot_idx: int) -> None: - """Wait for slot offload to complete.""" - self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx]) - - def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None: - """ - Async offload a ring buffer slot to CPU for one layer. 
- - Args: - slot_idx: Source GPU slot index - layer_id: Layer index to offload - cpu_block_id: Target CPU block ID - """ - with torch.cuda.stream(self.transfer_stream_main): - # Wait for both compute_stream and default stream - self.transfer_stream_main.wait_stream(self.compute_stream) - self.transfer_stream_main.wait_stream(torch.cuda.default_stream()) + # GPU: no layer dimension, CPU: has layer dimension self.k_cache_cpu[layer_id, cpu_block_id].copy_( - self.k_cache_gpu[layer_id, slot_idx], non_blocking=True + self.k_cache_gpu[slot_idx], non_blocking=True ) self.v_cache_cpu[layer_id, cpu_block_id].copy_( - self.v_cache_gpu[layer_id, slot_idx], non_blocking=True + self.v_cache_gpu[slot_idx], non_blocking=True ) - self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main) - - def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None: - """Wait for slot offload to complete for a specific layer.""" - self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id]) + self.ring_slot_offload_done[slot_idx].record(self.transfer_stream_main) + torch.cuda.nvtx.range_pop() # ----- KV access methods for ring buffer ----- - def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]: + def get_kv_for_slot(self, slot_idx: int) -> Tuple[Tensor, Tensor]: """ Get KV for a single ring buffer slot. + GPU cache has no layer dimension - slots contain data for whatever + layer was most recently loaded. + Args: slot_idx: GPU slot index - layer_id: Layer ID Returns: (k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim] """ - k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0) # [1, block_size, heads, dim] - v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0) + k = self.k_cache_gpu[slot_idx].unsqueeze(0) # [1, block_size, heads, dim] + v = self.v_cache_gpu[slot_idx].unsqueeze(0) return k, v def get_kv_for_slots( self, - layer_id: int, slot_indices: List[int], ) -> Tuple[Tensor, Tensor]: """ Get KV for multiple ring buffer slots. + GPU cache has no layer dimension - returns data from specified slots. + Args: - layer_id: Layer ID slot_indices: List of GPU slot indices Returns: @@ -855,92 +780,86 @@ class OffloadEngine: """ if not slot_indices: return None, None - k = self.k_cache_gpu[layer_id, slot_indices] - v = self.v_cache_gpu[layer_id, slot_indices] + k = self.k_cache_gpu[slot_indices] + v = self.v_cache_gpu[slot_indices] k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) return k, v # ----- Decode slot methods (kept for decode phase) ----- + # NOTE: For decode with CPU offload, the flow is per-layer: + # 1. Each layer stores to decode_slot (same GPU memory, reused) + # 2. Each layer offloads its data to CPU[layer_id, block_id] + # 3. Each layer loads prev blocks from CPU[layer_id] when needed - def offload_decode_slot(self, cpu_block_id: int) -> None: + def offload_decode_slot_layer(self, layer_id: int, cpu_block_id: int) -> None: """ - Offload KV from decode slot (slot[0]) to CPU. + Offload KV from decode slot (slot[0]) to CPU for one layer. 
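# --- Illustrative aside (not part of this patch): D2H ordering for a slot offload ---
# Standalone sketch of the stream ordering used by offload_slot_layer_to_cpu above;
# the function and argument names here are placeholders.
import torch

def offload_one_layer(k_cpu, v_cpu, k_gpu, v_gpu, slot, layer, cpu_block,
                      transfer_stream, compute_stream, offload_done):
    with torch.cuda.stream(transfer_stream):
        # The slot was last written by store_kvcache (default stream) and last read
        # by flash attention (compute stream); the copy must order after both.
        transfer_stream.wait_stream(compute_stream)
        transfer_stream.wait_stream(torch.cuda.default_stream())
        # GPU slot has no layer dimension; the CPU cache keeps one.
        k_cpu[layer, cpu_block].copy_(k_gpu[slot], non_blocking=True)
        v_cpu[layer, cpu_block].copy_(v_gpu[slot], non_blocking=True)
        offload_done.record(transfer_stream)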
Args: + layer_id: Layer ID cpu_block_id: Target CPU block ID """ - logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]") - - with torch.cuda.stream(self.transfer_stream_main): - self.transfer_stream_main.wait_stream(self.compute_stream) - memcpy_2d_async( - self.k_cache_cpu[:, cpu_block_id], - self.k_cache_gpu[:, self.decode_slot], - self.cpu_pitch, self.gpu_pitch, self.width, self.height, - "d2h", stream=self.transfer_stream_main - ) - memcpy_2d_async( - self.v_cache_cpu[:, cpu_block_id], - self.v_cache_gpu[:, self.decode_slot], - self.cpu_pitch, self.gpu_pitch, self.width, self.height, - "d2h", stream=self.transfer_stream_main - ) - self.decode_offload_done.record(self.transfer_stream_main) + # Reuse the existing per-layer offload method + self.offload_slot_layer_to_cpu(self.decode_slot, layer_id, cpu_block_id) def wait_decode_offload(self) -> None: """Wait for decode slot offload to complete.""" - self.compute_stream.wait_event(self.decode_offload_done) + self.wait_slot_offload(self.decode_slot) def get_kv_for_decode_slot( self, - layer_id: int, pos_in_block: int, ) -> Tuple[Tensor, Tensor]: """ Get KV at specified position in decode slot. + GPU cache has no layer dimension - decode slot contains data for + whatever layer was most recently stored. + Args: - layer_id: Layer ID pos_in_block: Token position within block (0 to block_size-1) Returns: (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim] """ - k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] - v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] + k = self.k_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1] + v = self.v_cache_gpu[self.decode_slot, pos_in_block:pos_in_block+1] k = k.unsqueeze(0) v = v.unsqueeze(0) return k, v def get_kv_for_decode_slot_accumulated( self, - layer_id: int, num_tokens: int, ) -> Tuple[Tensor, Tensor]: """ Get accumulated KV in decode slot (positions 0 to num_tokens-1). + GPU cache has no layer dimension - decode slot contains data for + whatever layer was most recently stored. + Args: - layer_id: Layer ID num_tokens: Number of accumulated tokens (1 to block_size) Returns: (k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim] """ - k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens] - v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens] + k = self.k_cache_gpu[self.decode_slot, :num_tokens] + v = self.v_cache_gpu[self.decode_slot, :num_tokens] k = k.unsqueeze(0) v = v.unsqueeze(0) return k, v # ----- Legacy compatibility methods (for decode double-buffering) ----- + # NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only. def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None: """ Legacy: Load CPU blocks to decode_load_slots for decode double-buffering. Uses first half of decode_load_slots as 'compute' region. + GPU cache has no layer dimension - layer_id is for CPU cache indexing. 
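# --- Illustrative aside (not part of this patch): decode double-buffering usage ---
# Rough usage sketch of the legacy compute/prefetch halves, using the method names
# defined in this file; the driver loop below is a simplification of what
# attention.forward actually does (the real synchronization is event-based).
def attend_over_cpu_blocks(engine, layer_id, cpu_blocks, blocks_per_chunk):
    chunks = [cpu_blocks[i:i + blocks_per_chunk]
              for i in range(0, len(cpu_blocks), blocks_per_chunk)]
    engine.load_to_compute_layer(layer_id, chunks[0])
    for i, chunk in enumerate(chunks):
        use_compute = (i % 2 == 0)
        if i + 1 < len(chunks):
            # Prefetch the next chunk into the other half while this one is consumed.
            if use_compute:
                engine.load_to_prefetch_layer(layer_id, chunks[i + 1])
            else:
                engine.load_to_compute_layer(layer_id, chunks[i + 1])
        if use_compute:
            engine.wait_compute_layer()
            k, v = engine.get_kv_for_compute(len(chunk))
        else:
            engine.wait_prefetch_layer()
            k, v = engine.get_kv_for_prefetch(len(chunk))
        # ... flash attention over (k, v), LSE-merged into the running output ...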
""" if not cpu_block_ids: return @@ -953,26 +872,27 @@ class OffloadEngine: for i in range(num_to_load): cpu_id = cpu_block_ids[i] gpu_slot = slots[i] - self.k_cache_gpu[layer_id, gpu_slot].copy_( + # GPU: no layer dimension, CPU: has layer dimension + self.k_cache_gpu[gpu_slot].copy_( self.k_cache_cpu[layer_id, cpu_id], non_blocking=True ) - self.v_cache_gpu[layer_id, gpu_slot].copy_( + self.v_cache_gpu[gpu_slot].copy_( self.v_cache_cpu[layer_id, cpu_id], non_blocking=True ) if num_to_load > 0: - self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main) + self.ring_slot_ready[slots[0]].record(self.transfer_stream_main) - def wait_compute_layer(self, layer_id: int) -> None: + def wait_compute_layer(self) -> None: """Legacy: Wait for 'compute' region loading.""" - half = max(1, len(self.decode_load_slots) // 2) if self.decode_load_slots: - self.wait_slot_layer(self.decode_load_slots[0], layer_id) + self.wait_slot_layer(self.decode_load_slots[0]) def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None: """ Legacy: Load CPU blocks to decode_load_slots for decode double-buffering. Uses second half of decode_load_slots as 'prefetch' region. + GPU cache has no layer dimension - layer_id is for CPU cache indexing. """ if not cpu_block_ids: return @@ -987,37 +907,36 @@ class OffloadEngine: for i in range(num_to_load): cpu_id = cpu_block_ids[i] gpu_slot = slots[i] - self.k_cache_gpu[layer_id, gpu_slot].copy_( + # GPU: no layer dimension, CPU: has layer dimension + self.k_cache_gpu[gpu_slot].copy_( self.k_cache_cpu[layer_id, cpu_id], non_blocking=True ) - self.v_cache_gpu[layer_id, gpu_slot].copy_( + self.v_cache_gpu[gpu_slot].copy_( self.v_cache_cpu[layer_id, cpu_id], non_blocking=True ) if num_to_load > 0: - self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main) + self.ring_slot_ready[slots[0]].record(self.transfer_stream_main) - def wait_prefetch_layer(self, layer_id: int) -> None: + def wait_prefetch_layer(self) -> None: """Legacy: Wait for 'prefetch' region loading.""" half = max(1, len(self.decode_load_slots) // 2) slots = self.decode_load_slots[half:] if slots: - self.wait_slot_layer(slots[0], layer_id) + self.wait_slot_layer(slots[0]) elif self.decode_load_slots: - self.wait_slot_layer(self.decode_load_slots[0], layer_id) + self.wait_slot_layer(self.decode_load_slots[0]) def get_kv_for_compute( self, - layer_id: int, num_blocks: int, ) -> Tuple[Tensor, Tensor]: """Legacy: Get KV from 'compute' region (first half of decode_load_slots).""" half = max(1, len(self.decode_load_slots) // 2) slots = self.decode_load_slots[:half][:num_blocks] - return self.get_kv_for_slots(layer_id, slots) + return self.get_kv_for_slots(slots) def get_kv_for_prefetch( self, - layer_id: int, num_blocks: int, ) -> Tuple[Tensor, Tensor]: """Legacy: Get KV from 'prefetch' region (second half of decode_load_slots).""" @@ -1026,7 +945,7 @@ class OffloadEngine: if not slots: slots = self.decode_load_slots slots = slots[:num_blocks] - return self.get_kv_for_slots(layer_id, slots) + return self.get_kv_for_slots(slots) # ========== Debug Hook Interface ========== # @@ -1082,12 +1001,15 @@ class OffloadEngine: Call all registered debug hooks with loaded tensor (internal use). Called by attention.py after wait_slot_layer completes. + GPU cache has no layer dimension - slot contains data for the layer + that was just loaded. 
""" if not self._debug_mode or not self._debug_hooks: return - k = self.k_cache_gpu[layer_id, slot_idx] - v = self.v_cache_gpu[layer_id, slot_idx] + # GPU cache has no layer dimension + k = self.k_cache_gpu[slot_idx] + v = self.v_cache_gpu[slot_idx] for hook in self._debug_hooks: try: diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 4171ad8..3cc170a 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -201,6 +201,18 @@ class Attention(nn.Module): torch.cuda.nvtx.range_pop() # ChunkedPrefill + # Per-layer offload: In new GPU cache architecture (no layer dimension), + # each layer must offload its KV to CPU before next layer overwrites the GPU slot. + if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'): + offload_engine = kvcache_manager.offload_engine + write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx) + seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None + if seq is not None: + cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq) + if current_chunk_idx < len(cpu_block_ids): + cpu_block_id = cpu_block_ids[current_chunk_idx] + offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id) + # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim] return final_o.squeeze(0) @@ -219,11 +231,11 @@ class Attention(nn.Module): for block_idx, cpu_block_id in enumerate(cpu_block_table): # Load to slot 0 (single slot) offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id) - offload_engine.wait_slot_layer(0, self.layer_id) + offload_engine.wait_slot_layer(0) # IMPORTANT: Must use compute_stream to match wait_slot_layer with torch.cuda.stream(compute_stream): - prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id) + prev_k, prev_v = offload_engine.get_kv_for_slot(0) prev_o, prev_lse = flash_attn_with_lse( q_batched, prev_k, prev_v, @@ -289,21 +301,21 @@ class Attention(nn.Module): for block_idx in range(num_blocks): cpu_block_id = cpu_block_table[block_idx] offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id) - offload_engine.wait_slot_layer(slot, self.layer_id) + offload_engine.wait_slot_layer(slot) with torch.cuda.stream(compute_stream): # Debug: call hooks on compute_stream (synchronized with transfer) if offload_engine.debug_mode: offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id) - prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id) + prev_k, prev_v = offload_engine.get_kv_for_slot(slot) prev_o, prev_lse = flash_attn_with_lse( q_batched, prev_k, prev_v, softmax_scale=self.scale, causal=False, ) # Record compute done so next load can safely reuse this slot - offload_engine.record_slot_compute_done(slot, self.layer_id) + offload_engine.record_slot_compute_done(slot) if o_acc is None: o_acc, lse_acc = prev_o, prev_lse else: @@ -332,7 +344,7 @@ class Attention(nn.Module): cpu_block_id = cpu_block_table[block_idx] # Wait for current slot's transfer to complete (on compute_stream) - offload_engine.wait_slot_layer(current_slot, self.layer_id) + offload_engine.wait_slot_layer(current_slot) # Compute attention on current slot's data # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream @@ -342,7 +354,7 @@ class Attention(nn.Module): offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id) torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}") - prev_k, prev_v = 
offload_engine.get_kv_for_slot(current_slot, self.layer_id) + prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot) prev_o, prev_lse = flash_attn_with_lse( q_batched, prev_k, prev_v, softmax_scale=self.scale, @@ -351,7 +363,7 @@ class Attention(nn.Module): torch.cuda.nvtx.range_pop() # Record compute done - this allows the next transfer to safely overwrite this slot - offload_engine.record_slot_compute_done(current_slot, self.layer_id) + offload_engine.record_slot_compute_done(current_slot) # Immediately start loading the NEXT block into this slot (if more blocks remain) # Key insight: reuse current_slot immediately after compute is done! @@ -464,13 +476,9 @@ class Attention(nn.Module): with torch.cuda.stream(compute_stream): # Get KV from current buffer FIRST, before prefetching overwrites it if use_compute: - k_chunk, v_chunk = offload_engine.get_kv_for_compute( - self.layer_id, num_blocks_in_chunk - ) + k_chunk, v_chunk = offload_engine.get_kv_for_compute(num_blocks_in_chunk) else: - k_chunk, v_chunk = offload_engine.get_kv_for_prefetch( - self.layer_id, num_blocks_in_chunk - ) + k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk) # Compute attention for this chunk o_chunk, lse_chunk = flash_attn_with_lse( @@ -512,8 +520,9 @@ class Attention(nn.Module): with torch.cuda.stream(compute_stream): if num_accumulated > 0: - decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1] - decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1] + # GPU cache has no layer dimension + decode_k = offload_engine.k_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1] + decode_v = offload_engine.v_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1] decode_k = decode_k.unsqueeze(0) decode_v = decode_v.unsqueeze(0) diff --git a/tests/test_debug_verification.py b/tests/test_debug_verification.py index 4258a8c..35d5179 100644 --- a/tests/test_debug_verification.py +++ b/tests/test_debug_verification.py @@ -29,10 +29,7 @@ def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, """Record loaded tensor values for layer 0.""" if layer_id != 0: return - - if layer_id == 0: - __import__('pdb').set_trace() - + load_log.append({ "chunk_idx": current_chunk[0], "cpu_block_id": cpu_block_id, diff --git a/tests/test_offload_correctness.py b/tests/test_offload_correctness.py index 8a5a54b..f8fb61d 100644 --- a/tests/test_offload_correctness.py +++ b/tests/test_offload_correctness.py @@ -20,7 +20,6 @@ import torch from random import randint, seed from nanovllm import LLM, SamplingParams from nanovllm.utils.context import get_context -from nanovllm.kvcache.debug_utils import dump_block_state # ============================================================ @@ -97,9 +96,9 @@ def make_verified_load_to_slot_layer(original_func, offload_engine): # cpu_block_id == chunk_idx in our sequential test expected_k, expected_v = get_expected_pattern(cpu_block_id) - # Read GPU slot data - gpu_k = offload_engine.k_cache_gpu[layer_id, slot_idx] - gpu_v = offload_engine.v_cache_gpu[layer_id, slot_idx] + # Read GPU slot data (GPU cache has no layer dimension) + gpu_k = offload_engine.k_cache_gpu[slot_idx] + gpu_v = offload_engine.v_cache_gpu[slot_idx] actual_k = gpu_k.float().mean().item() actual_v = gpu_v.float().mean().item() @@ -306,9 +305,9 @@ def make_gpu_write_verification_post_hook(layer_id: int): # Get expected pattern for current chunk expected_k, expected_v = 
get_expected_pattern(chunk_idx) - # Verify write_slot contains current chunk's data - gpu_k = oe.k_cache_gpu[layer_id, write_slot] - gpu_v = oe.v_cache_gpu[layer_id, write_slot] + # Verify write_slot contains current chunk's data (GPU cache has no layer dimension) + gpu_k = oe.k_cache_gpu[write_slot] + gpu_v = oe.v_cache_gpu[write_slot] actual_k_mean = gpu_k.float().mean().item() actual_v_mean = gpu_v.float().mean().item() @@ -419,9 +418,9 @@ def make_post_chunk_verification_hook(layer_id: int): expected_k, expected_v = get_expected_pattern(chunk_idx) - # Check GPU ring buffer - gpu_k = oe.k_cache_gpu[layer_id, ring_slot] - gpu_v = oe.v_cache_gpu[layer_id, ring_slot] + # Check GPU ring buffer (GPU cache has no layer dimension) + gpu_k = oe.k_cache_gpu[ring_slot] + gpu_v = oe.v_cache_gpu[ring_slot] k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}") v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}")
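# --- Illustrative aside (not part of this patch): merging chunked attention outputs ---
# attention.py above accumulates per-block partial results returned by
# flash_attn_with_lse and merges them; the repo's own merge helper is not shown in
# this diff. A standard log-sum-exp merge, assuming o of shape
# [batch, seq_q, heads, dim] and lse of shape [batch, heads, seq_q]
# (the flash-attn convention), looks like this:
import torch

def merge_attn_outputs(o1, lse1, o2, lse2):
    lse = torch.logaddexp(lse1, lse2)                          # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [batch, seq_q, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse

# In the chunked-prefill loop this merge is applied once per previous block
# (causal=False), and the running (o, lse) pair is finally combined with the current
# chunk's causal self-attention output before the batch dimension is squeezed away.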