From 6575099a069926bcb60d30c0b4d6cc4c0f673887 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 7 Jan 2026 06:25:21 +0800 Subject: [PATCH] [refactor] Cleanup unused code after perf_opt merge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removed ~460 lines of unused/redundant code from offload_engine.py: - CUDA gather methods (gathered_h2d_*, update_gather_indices) - Legacy async transfer methods (prefetch_block_async, offload_block_async) - Legacy sync/wait methods (wait_for_block, wait_all_transfers, sync_indices) - Legacy compatibility methods (load_to_compute_layer, wait_compute_layer) - Unused gather_indices tensors and memory calculations Updated class docstring to reflect current architecture. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- nanovllm/kvcache/offload_engine.py | 471 +---------------------------- 1 file changed, 7 insertions(+), 464 deletions(-) diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index 4151692..ceeae44 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -40,14 +40,13 @@ class OffloadEngine: High-performance CPU-GPU async transfer engine for KV cache offloading. Memory layout: - - GPU cache: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] + - GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dimension) - CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned) - - Gather indices: [num_layers, num_gpu_blocks] (fixed address, variable content) - CUDA Graph compatibility: - - gathered_h2d_layer() can be captured into CUDA graphs - - update_gather_indices() is called outside graphs to prepare indices - - All tensor addresses remain fixed across graph replays + Features: + - Unified ring buffer for chunked prefill/decode + - Per-layer prefill buffer for async offload + - Cross-layer pipeline for decode with double-buffering """ def __init__( @@ -210,19 +209,6 @@ class OffloadEngine: dtype=dtype, device="cpu", pin_memory=True ) - # ========== Fixed-address gather indices (content is variable) ========== - # gather_indices[layer][i] = CPU block id to copy to GPU slot i - # -1 means no-op (skip this slot) - self.gather_indices_cpu = torch.empty( - num_layers, num_gpu_blocks, - dtype=torch.int64, device="cpu", pin_memory=True - ) - self.gather_indices_cpu.fill_(-1) - self.gather_indices_gpu = torch.full( - (num_layers, num_gpu_blocks), -1, - dtype=torch.int64, device="cuda" - ) - # Log memory allocation gpu_mem_mb = self.gpu_memory_bytes() / (1024 * 1024) cpu_mem_mb = self.cpu_memory_bytes() / (1024 * 1024) @@ -277,321 +263,6 @@ class OffloadEngine: # ========== Sparse attention policy (set at construction time) ========== self.sparse_policy = sparse_policy - def _get_next_stream(self) -> torch.cuda.Stream: - """Round-robin stream selection for parallel transfers.""" - stream = self.transfer_streams[self._stream_idx] - self._stream_idx = (self._stream_idx + 1) % len(self.transfer_streams) - return stream - - # ========== CUDA Graph compatible methods ========== - # NOTE: These methods need to be updated for the new GPU cache architecture. - # GPU cache no longer has layer dimension, so gathered copy semantics change. - # For now, these are kept for reference but should not be used without updating. - - def gathered_h2d_layer(self, layer_id: int) -> None: - """ - Execute gathered H2D copy for a single layer. - - WARNING: This method needs updating for new GPU cache architecture. - GPU cache no longer has layer dimension. - """ - # GPU cache has no layer dimension - use flat indexing - # Source is CPU[layer_id], dest is GPU (shared across layers) - gathered_copy_kv( - k_src=self.k_cache_cpu[layer_id], - v_src=self.v_cache_cpu[layer_id], - k_dst=self.k_cache_gpu, # No layer indexing - v_dst=self.v_cache_gpu, # No layer indexing - indices=self.gather_indices_gpu[layer_id], - ) - - def gathered_h2d_all_layers(self) -> None: - """ - Execute gathered H2D copy for all layers. - - WARNING: In new architecture, GPU slots are shared across layers. - This method would overwrite slots multiple times. Not recommended. - """ - for layer_id in range(self.num_layers): - self.gathered_h2d_layer(layer_id) - - def update_gather_indices( - self, - layer_id: int, - mappings: List[Tuple[int, int]], - ) -> None: - """ - Update gather indices for a layer (call OUTSIDE CUDA graph). - - Args: - layer_id: Layer index - mappings: List of (cpu_block_id, gpu_slot) tuples - Only these slots will be updated; others keep their values - """ - for cpu_block_id, gpu_slot in mappings: - self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id - - # Async copy to GPU - self.gather_indices_gpu[layer_id].copy_( - self.gather_indices_cpu[layer_id], - non_blocking=True - ) - - def update_gather_indices_all_layers( - self, - mappings_per_layer: List[List[Tuple[int, int]]], - ) -> None: - """ - Update gather indices for all layers. - - Args: - mappings_per_layer: mappings_per_layer[layer_id] = [(cpu_block_id, gpu_slot), ...] - """ - for layer_id, mappings in enumerate(mappings_per_layer): - for cpu_block_id, gpu_slot in mappings: - self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id - - # Batch copy all layers - self.gather_indices_gpu.copy_(self.gather_indices_cpu, non_blocking=True) - - def clear_gather_indices(self, layer_id: Optional[int] = None) -> None: - """ - Clear gather indices (set all to -1, meaning no-op). - - Args: - layer_id: If provided, clear only this layer; otherwise clear all - """ - if layer_id is not None: - self.gather_indices_cpu[layer_id].fill_(-1) - self.gather_indices_gpu[layer_id].fill_(-1) - else: - self.gather_indices_cpu.fill_(-1) - self.gather_indices_gpu.fill_(-1) - - # ========== Async transfer methods (for prefill, outside CUDA graph) ========== - - def prefetch_block_async( - self, - layer_id: int, - cpu_block_id: int, - gpu_block_id: int, - ) -> torch.cuda.Event: - """ - Async prefetch a single block from CPU to GPU. - - GPU cache has no layer dimension - layer_id is for CPU cache indexing. - - Args: - layer_id: Layer index (for CPU cache) - cpu_block_id: Source block in CPU cache - gpu_block_id: Destination slot in GPU cache - - Returns: - CUDA event that signals completion - """ - stream = self._get_next_stream() - event = torch.cuda.Event() - - logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]") - - with torch.cuda.stream(stream): - # GPU: no layer dimension, CPU: has layer dimension - self.k_cache_gpu[gpu_block_id].copy_( - self.k_cache_cpu[layer_id, cpu_block_id], - non_blocking=True - ) - self.v_cache_gpu[gpu_block_id].copy_( - self.v_cache_cpu[layer_id, cpu_block_id], - non_blocking=True - ) - event.record() - - self.pending_events[(layer_id, gpu_block_id)] = event - return event - - def prefetch_blocks_batch_async( - self, - transfers: List[Tuple[int, int, int]], # [(layer_id, cpu_block_id, gpu_block_id), ...] - ) -> List[torch.cuda.Event]: - """ - Batch async prefetch multiple blocks. - - Args: - transfers: List of (layer_id, cpu_block_id, gpu_block_id) tuples - - Returns: - List of CUDA events for each transfer - """ - events = [] - for layer_id, cpu_block_id, gpu_block_id in transfers: - event = self.prefetch_block_async(layer_id, cpu_block_id, gpu_block_id) - events.append(event) - return events - - def offload_block_async( - self, - layer_id: int, - gpu_block_id: int, - cpu_block_id: int, - ) -> torch.cuda.Event: - """ - Async offload a block from GPU to CPU. - - GPU cache has no layer dimension - layer_id is for CPU cache indexing. - - Args: - layer_id: Layer index (for CPU cache) - gpu_block_id: Source slot in GPU cache - cpu_block_id: Destination block in CPU cache - - Returns: - CUDA event that signals completion - """ - stream = self._get_next_stream() - event = torch.cuda.Event() - - logger.debug(f"D2H offload: layer={layer_id}, GPU[{gpu_block_id}] -> CPU[{cpu_block_id}]") - - with torch.cuda.stream(stream): - # Wait for any compute using this block - stream.wait_stream(self.compute_stream) - - # GPU: no layer dimension, CPU: has layer dimension - self.k_cache_cpu[layer_id, cpu_block_id].copy_( - self.k_cache_gpu[gpu_block_id], - non_blocking=True - ) - self.v_cache_cpu[layer_id, cpu_block_id].copy_( - self.v_cache_gpu[gpu_block_id], - non_blocking=True - ) - event.record() - - return event - - def offload_blocks_batch_async( - self, - transfers: List[Tuple[int, int, int]], # [(layer_id, gpu_block_id, cpu_block_id), ...] - ) -> List[torch.cuda.Event]: - """ - Batch async offload multiple blocks. - - Args: - transfers: List of (layer_id, gpu_block_id, cpu_block_id) tuples - - Returns: - List of CUDA events - """ - events = [] - for layer_id, gpu_block_id, cpu_block_id in transfers: - event = self.offload_block_async(layer_id, gpu_block_id, cpu_block_id) - events.append(event) - return events - - # ========== Chunked Decode: Load CPU blocks to GPU slots ========== - - def load_cpu_blocks_to_gpu_slots( - self, - layer_id: int, - cpu_block_ids: List[int], - gpu_slot_ids: List[int], - ) -> None: - """ - Load CPU blocks to specific GPU slots for chunked decode. - - GPU cache has no layer dimension - layer_id is for CPU cache indexing. - - Args: - layer_id: Layer index (for CPU cache) - cpu_block_ids: List of CPU block IDs to load - gpu_slot_ids: List of GPU slot IDs to load into (must be same length) - """ - assert len(cpu_block_ids) == len(gpu_slot_ids) - - if cpu_block_ids: - logger.debug(f"H2D chunked load: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}") - - stream = self._get_next_stream() - - with torch.cuda.stream(stream): - for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids): - # GPU: no layer dimension, CPU: has layer dimension - self.k_cache_gpu[gpu_slot].copy_( - self.k_cache_cpu[layer_id, cpu_block_id], - non_blocking=True - ) - self.v_cache_gpu[gpu_slot].copy_( - self.v_cache_cpu[layer_id, cpu_block_id], - non_blocking=True - ) - - # Wait for transfer to complete - stream.synchronize() - - def load_cpu_blocks_to_gpu_slots_async( - self, - layer_id: int, - cpu_block_ids: List[int], - gpu_slot_ids: List[int], - ) -> torch.cuda.Event: - """ - Async version: Load CPU blocks to GPU slots. - - GPU cache has no layer dimension - layer_id is for CPU cache indexing. - - Args: - layer_id: Layer index (for CPU cache) - cpu_block_ids: List of CPU block IDs to load - gpu_slot_ids: List of GPU slot IDs to load into - - Returns: - CUDA event to wait on - """ - assert len(cpu_block_ids) == len(gpu_slot_ids) - - if cpu_block_ids: - logger.debug(f"H2D chunked load async: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}") - - stream = self._get_next_stream() - event = torch.cuda.Event() - - with torch.cuda.stream(stream): - for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids): - # GPU: no layer dimension, CPU: has layer dimension - self.k_cache_gpu[gpu_slot].copy_( - self.k_cache_cpu[layer_id, cpu_block_id], - non_blocking=True - ) - self.v_cache_gpu[gpu_slot].copy_( - self.v_cache_cpu[layer_id, cpu_block_id], - non_blocking=True - ) - event.record() - - return event - - # NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has - # layer dimension. Each GPU slot holds data for ONE layer at a time. - - # ========== Synchronization methods ========== - - def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None: - """Wait for a specific block's transfer to complete.""" - key = (layer_id, gpu_block_id) - if key in self.pending_events: - self.pending_events[key].synchronize() - del self.pending_events[key] - - def wait_all_transfers(self) -> None: - """Wait for all pending transfers to complete.""" - for stream in self.transfer_streams: - stream.synchronize() - self.pending_events.clear() - - def sync_indices(self) -> None: - """Synchronize to ensure all index updates are complete.""" - torch.cuda.default_stream().synchronize() - # ========== Cache access methods ========== def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]: @@ -605,54 +276,22 @@ class OffloadEngine: (k_cache, v_cache) tensors Shape: [num_gpu_blocks, block_size, kv_heads, head_dim] """ - # GPU cache is shared across all layers (no layer dimension) return self.k_cache_gpu, self.v_cache_gpu - def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]: - """ - Get full GPU K/V cache tensors. - - NOTE: GPU cache has no layer dimension in the new architecture. - - Returns: - (k_cache, v_cache) tensors - Shape: [num_gpu_blocks, block_size, kv_heads, head_dim] - """ - return self.k_cache_gpu, self.v_cache_gpu - - def get_cpu_block( - self, - layer_id: int, - cpu_block_id: int, - ) -> Tuple[Tensor, Tensor]: - """ - Get a specific CPU block's K/V cache. - - Returns: - (k_cache, v_cache) for the block - Shape: [block_size, kv_heads, head_dim] - """ - return ( - self.k_cache_cpu[layer_id, cpu_block_id], - self.v_cache_cpu[layer_id, cpu_block_id], - ) - # ========== Memory info ========== def gpu_memory_bytes(self) -> int: """Total GPU memory used by KV caches.""" return ( self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() + - self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() + - self.gather_indices_gpu.numel() * self.gather_indices_gpu.element_size() + self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() ) def cpu_memory_bytes(self) -> int: """Total CPU memory used by KV caches.""" return ( self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() + - self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() + - self.gather_indices_cpu.numel() * self.gather_indices_cpu.element_size() + self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() ) def __repr__(self) -> str: @@ -955,102 +594,6 @@ class OffloadEngine: v = v.unsqueeze(0) return k, v - # ----- Legacy compatibility methods (for decode double-buffering) ----- - # NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only. - - def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None: - """ - Legacy: Load CPU blocks to decode_load_slots for decode double-buffering. - - Uses first half of decode_load_slots as 'compute' region. - GPU cache has no layer dimension - layer_id is for CPU cache indexing. - """ - if not cpu_block_ids: - return - - half = max(1, len(self.decode_load_slots) // 2) - slots = self.decode_load_slots[:half] - num_to_load = min(len(cpu_block_ids), len(slots)) - - with torch.cuda.stream(self.transfer_stream_main): - for i in range(num_to_load): - cpu_id = cpu_block_ids[i] - gpu_slot = slots[i] - # GPU: no layer dimension, CPU: has layer dimension - self.k_cache_gpu[gpu_slot].copy_( - self.k_cache_cpu[layer_id, cpu_id], non_blocking=True - ) - self.v_cache_gpu[gpu_slot].copy_( - self.v_cache_cpu[layer_id, cpu_id], non_blocking=True - ) - if num_to_load > 0: - self.ring_slot_ready[slots[0]].record(self.transfer_stream_main) - - def wait_compute_layer(self) -> None: - """Legacy: Wait for 'compute' region loading.""" - if self.decode_load_slots: - self.wait_slot_layer(self.decode_load_slots[0]) - - def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None: - """ - Legacy: Load CPU blocks to decode_load_slots for decode double-buffering. - - Uses second half of decode_load_slots as 'prefetch' region. - GPU cache has no layer dimension - layer_id is for CPU cache indexing. - """ - if not cpu_block_ids: - return - - half = max(1, len(self.decode_load_slots) // 2) - slots = self.decode_load_slots[half:] - if not slots: - slots = self.decode_load_slots # Fallback if only 1-2 slots - num_to_load = min(len(cpu_block_ids), len(slots)) - - with torch.cuda.stream(self.transfer_stream_main): - for i in range(num_to_load): - cpu_id = cpu_block_ids[i] - gpu_slot = slots[i] - # GPU: no layer dimension, CPU: has layer dimension - self.k_cache_gpu[gpu_slot].copy_( - self.k_cache_cpu[layer_id, cpu_id], non_blocking=True - ) - self.v_cache_gpu[gpu_slot].copy_( - self.v_cache_cpu[layer_id, cpu_id], non_blocking=True - ) - if num_to_load > 0: - self.ring_slot_ready[slots[0]].record(self.transfer_stream_main) - - def wait_prefetch_layer(self) -> None: - """Legacy: Wait for 'prefetch' region loading.""" - half = max(1, len(self.decode_load_slots) // 2) - slots = self.decode_load_slots[half:] - if slots: - self.wait_slot_layer(slots[0]) - elif self.decode_load_slots: - self.wait_slot_layer(self.decode_load_slots[0]) - - def get_kv_for_compute( - self, - num_blocks: int, - ) -> Tuple[Tensor, Tensor]: - """Legacy: Get KV from 'compute' region (first half of decode_load_slots).""" - half = max(1, len(self.decode_load_slots) // 2) - slots = self.decode_load_slots[:half][:num_blocks] - return self.get_kv_for_slots(slots) - - def get_kv_for_prefetch( - self, - num_blocks: int, - ) -> Tuple[Tensor, Tensor]: - """Legacy: Get KV from 'prefetch' region (second half of decode_load_slots).""" - half = max(1, len(self.decode_load_slots) // 2) - slots = self.decode_load_slots[half:] - if not slots: - slots = self.decode_load_slots - slots = slots[:num_blocks] - return self.get_kv_for_slots(slots) - # ========== Debug Hook Interface ========== # # Minimal generic hook system for debugging.