From 1081ab51eac6668fefbaaa295fea9bfd9a763998 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Mon, 15 Dec 2025 01:13:58 +0800 Subject: [PATCH] [refactor] Refactor offload code to multi-chunk. --- CLAUDE.md | 7 +- nanovllm/engine/model_runner.py | 40 +++--- nanovllm/kvcache/hybrid_manager.py | 6 +- nanovllm/kvcache/offload_engine.py | 193 +---------------------------- nanovllm/layers/attention.py | 5 +- nanovllm/utils/context.py | 14 +-- tests/test_chunked_attention.py | 4 +- 7 files changed, 36 insertions(+), 233 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 0fbcdeb..84d8d22 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -17,6 +17,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L **ModelRunner** (`nanovllm/engine/model_runner.py`): - Loads model weights, allocates KV cache, captures CUDA graphs - Rank 0 is main process; ranks 1+ run via `loop()` with shared memory events +- Chunked offload methods: `run_chunked_offload_prefill()`, `run_chunked_offload_decode()` **Scheduler** (`nanovllm/engine/scheduler.py`): - Two-phase scheduling: prefill (waiting queue) then decode (running queue) @@ -34,7 +35,8 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L **Global Context** (`nanovllm/utils/context.py`): - Stores attention metadata via `get_context()`/`set_context()` -- Key fields: `cu_seqlens`, `slot_mapping`, `block_tables`, `chunked_seq` +- Key fields: `cu_seqlens`, `slot_mapping`, `block_tables`, `chunked_seq`, `kvcache_manager` +- `kvcache_manager`: Reference to HybridKVCacheManager for chunked attention (set when `is_chunked_prefill=True`) ## CPU Offload System @@ -118,9 +120,10 @@ Compute: [C0] [C1] [C2] Manages both GPU and CPU blocks: - `allocate()`: Allocate GPU block first, fallback to CPU -- `allocate_cpu_only()`: Force CPU allocation (for chunked prefill) +- `allocate_cpu_only()`: Force CPU allocation (for chunked offload mode) - `get_all_cpu_blocks(seq)`: Get all CPU block IDs for a sequence - `get_prefilled_cpu_blocks(seq)`: Get CPU blocks from previous chunks +- `get_write_slot_for_chunked_offload(seq)`: Get GPU slot for writing new KV (returns decode_slot) - `may_offload()`: Offload GPU blocks to CPU when decode slot fills ### Online Softmax Merge diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index caab2ed..6fbdb78 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -169,17 +169,17 @@ class ModelRunner: ) if config.enable_cpu_offload: - ping_size = config.num_gpu_kvcache_blocks // 2 - tokens_per_ping = ping_size * self.block_size + compute_size = config.num_gpu_kvcache_blocks // 2 + tokens_per_chunk = compute_size * self.block_size logger.info( - f"KV Cache allocated (Ping-Pong mode): " + f"KV Cache allocated (Chunked Offload mode): " f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), " f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), " f"Total={total_memory_mb:.1f}MB" ) logger.info( - f"Ping-Pong config: ping_size={ping_size} blocks, " - f"tokens_per_chunk={tokens_per_ping}, " + f"Chunked Offload config: compute_size={compute_size} blocks, " + f"tokens_per_chunk={tokens_per_chunk}, " f"block_size={self.block_size}" ) else: @@ -374,14 +374,14 @@ class ModelRunner: return self.model.compute_logits(graph_vars["outputs"][:bs]) def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: - # Check if Ping-Pong mode should be used (all blocks on CPU) + # Check if Chunked Offload mode should be used (all 
blocks on CPU) if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'): - use_pingpong = self._should_use_pingpong(seqs, is_prefill) - if use_pingpong: + use_chunked_offload = self._should_use_chunked_offload(seqs, is_prefill) + if use_chunked_offload: if is_prefill: - return self.run_pingpong_prefill(seqs) + return self.run_chunked_offload_prefill(seqs) else: - return self.run_pingpong_decode(seqs) + return self.run_chunked_offload_decode(seqs) # Check if chunked prefill is needed (legacy path) if is_prefill and hasattr(self, 'kvcache_manager'): @@ -410,7 +410,7 @@ class ModelRunner: reset_context() return token_ids - def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool: + def _should_use_chunked_offload(self, seqs: list[Sequence], is_prefill: bool) -> bool: """ Check if three-region mode should be used. @@ -553,7 +553,7 @@ class ModelRunner: slot_mapping=slot_mapping, context_lens=context_len, is_chunked_prefill=True, # Use chunked attention path - offload_engine=self.kvcache_manager, + kvcache_manager=self.kvcache_manager, chunked_seq=seq, ) @@ -622,13 +622,13 @@ class ModelRunner: max_seqlen_k=seqlen, slot_mapping=slot_mapping, is_chunked_prefill=True, - offload_engine=self.kvcache_manager, # Pass manager for loading previous KV + kvcache_manager=self.kvcache_manager, # Pass manager for loading previous KV chunked_seq=seq, # Pass sequence for loading previous KV ) return input_ids, positions - def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]: + def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]: """ Run prefill with three-region GPU buffer (CPU is primary storage). @@ -679,7 +679,7 @@ class ModelRunner: gpu_slots = offload_engine.compute_slots[:num_blocks] # Prepare inputs - input_ids, positions = self._prepare_pingpong_chunk( + input_ids, positions = self._prepare_chunked_offload_chunk( seq, chunk_start, chunk_end, gpu_slots, start_block_idx ) @@ -718,7 +718,7 @@ class ModelRunner: return token_ids - def _prepare_pingpong_chunk( + def _prepare_chunked_offload_chunk( self, seq: Sequence, chunk_start: int, @@ -726,7 +726,7 @@ class ModelRunner: gpu_slots: list[int], start_block_idx: int, ) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare inputs for a Ping-Pong prefill chunk.""" + """Prepare inputs for a chunked offload prefill chunk.""" # Input tokens for this chunk input_ids = seq[chunk_start:chunk_end] positions = list(range(chunk_start, chunk_end)) @@ -768,13 +768,13 @@ class ModelRunner: max_seqlen_k=seqlen, slot_mapping=slot_mapping, is_chunked_prefill=True, - offload_engine=self.kvcache_manager, + kvcache_manager=self.kvcache_manager, chunked_seq=seq, ) return input_ids, positions - def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]: + def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]: """ Run decode with three-region GPU buffer. @@ -809,7 +809,7 @@ class ModelRunner: slot_mapping=slot_mapping, context_lens=context_len, is_chunked_prefill=True, # Use chunked attention path - offload_engine=self.kvcache_manager, + kvcache_manager=self.kvcache_manager, chunked_seq=seq, decode_pos_in_block=pos_in_block, decode_start_pos_in_block=decode_start_pos, diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py index cc02f8b..baa8450 100644 --- a/nanovllm/kvcache/hybrid_manager.py +++ b/nanovllm/kvcache/hybrid_manager.py @@ -336,7 +336,7 @@ class HybridKVCacheManager(KVCacheManager): """ Allocate logical blocks for prefill. 
- In cpu_primary mode (Ping-Pong): All blocks are allocated to CPU. + In cpu_primary mode (Chunked Offload): All blocks are allocated to CPU. In legacy mode: Blocks are allocated to GPU when possible, overflow to CPU. """ assert not seq.block_table, "Sequence already has blocks" @@ -1167,9 +1167,9 @@ class HybridKVCacheManager(KVCacheManager): return block.cpu_block_id return -1 - def get_write_slot_for_pingpong(self, seq: Sequence) -> int: + def get_write_slot_for_chunked_offload(self, seq: Sequence) -> int: """ - Get GPU slot for writing new KV during three-region decode. + Get GPU slot for writing new KV during chunked offload decode. In three-region design, always use Decode region (slot 0) to write new KV. This avoids conflicts with Compute/Prefetch region loading operations. diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index 0450e88..616f395 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -91,12 +91,6 @@ class OffloadEngine: self.num_gpu_slots = num_gpu_blocks # alias - # Keep old ping/pong attributes for compatibility (will be removed later) - self.ping_size = self.num_compute_blocks - self.pong_size = self.num_prefetch_blocks - self.ping_slots = self.compute_slots.copy() - self.pong_slots = self.prefetch_slots.copy() - logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, " f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}") @@ -148,13 +142,6 @@ class OffloadEngine: self.prefetch_ready = torch.cuda.Event() self.decode_offload_done = torch.cuda.Event() - # Keep old ping/pong events for compatibility (will be removed later) - self.pingpong_stream = self.transfer_stream_main - self.ping_ready = self.compute_ready - self.pong_ready = self.prefetch_ready - self.ping_offload_done = torch.cuda.Event() - self.pong_offload_done = torch.cuda.Event() - # ========== Per-layer events for chunked attention ========== # Each layer has its own event for synchronization self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)] @@ -579,185 +566,9 @@ class OffloadEngine: f")" ) - # ========== Ping-Pong double buffering methods ========== - - def load_to_ping(self, cpu_block_ids: List[int]) -> None: - """ - Async load CPU blocks to Ping buffer. - - Args: - cpu_block_ids: List of CPU block IDs to load - """ - if not cpu_block_ids: - self.ping_ready.record(self.pingpong_stream) - return - - num_to_load = min(len(cpu_block_ids), self.ping_size) - logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}") - - with torch.cuda.stream(self.pingpong_stream): - for i in range(num_to_load): - cpu_id = cpu_block_ids[i] - gpu_slot = self.ping_slots[i] - # Copy all layers together - self.k_cache_gpu[:, gpu_slot].copy_( - self.k_cache_cpu[:, cpu_id], non_blocking=True - ) - self.v_cache_gpu[:, gpu_slot].copy_( - self.v_cache_cpu[:, cpu_id], non_blocking=True - ) - self.ping_ready.record(self.pingpong_stream) - - def load_to_pong(self, cpu_block_ids: List[int]) -> None: - """ - Async load CPU blocks to Pong buffer. 
- - Args: - cpu_block_ids: List of CPU block IDs to load - """ - if not cpu_block_ids: - self.pong_ready.record(self.pingpong_stream) - return - - num_to_load = min(len(cpu_block_ids), self.pong_size) - logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}") - - with torch.cuda.stream(self.pingpong_stream): - for i in range(num_to_load): - cpu_id = cpu_block_ids[i] - gpu_slot = self.pong_slots[i] - self.k_cache_gpu[:, gpu_slot].copy_( - self.k_cache_cpu[:, cpu_id], non_blocking=True - ) - self.v_cache_gpu[:, gpu_slot].copy_( - self.v_cache_cpu[:, cpu_id], non_blocking=True - ) - self.pong_ready.record(self.pingpong_stream) - - def wait_ping(self) -> None: - """Wait for Ping buffer loading to complete.""" - self.compute_stream.wait_event(self.ping_ready) - - def wait_pong(self) -> None: - """Wait for Pong buffer loading to complete.""" - self.compute_stream.wait_event(self.pong_ready) - - def offload_buffer_to_cpu( - self, - buffer: str, - cpu_block_ids: List[int], - ) -> None: - """ - Async offload KV from buffer to CPU. - - Args: - buffer: "ping" or "pong" - cpu_block_ids: Target CPU block IDs list - """ - slots = self.ping_slots if buffer == "ping" else self.pong_slots - event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done - - if not cpu_block_ids: - event.record(self.pingpong_stream) - return - - num_to_offload = min(len(cpu_block_ids), len(slots)) - logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}") - - with torch.cuda.stream(self.pingpong_stream): - # Wait for compute to complete - self.pingpong_stream.wait_stream(self.compute_stream) - - for i in range(num_to_offload): - gpu_slot = slots[i] - cpu_id = cpu_block_ids[i] - self.k_cache_cpu[:, cpu_id].copy_( - self.k_cache_gpu[:, gpu_slot], non_blocking=True - ) - self.v_cache_cpu[:, cpu_id].copy_( - self.v_cache_gpu[:, gpu_slot], non_blocking=True - ) - event.record(self.pingpong_stream) - - def offload_slot_to_cpu( - self, - gpu_slot: int, - cpu_block_id: int, - ) -> None: - """ - Async offload a single GPU slot's KV to CPU. - - Args: - gpu_slot: GPU slot ID - cpu_block_id: Target CPU block ID - """ - logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]") - - with torch.cuda.stream(self.pingpong_stream): - self.pingpong_stream.wait_stream(self.compute_stream) - self.k_cache_cpu[:, cpu_block_id].copy_( - self.k_cache_gpu[:, gpu_slot], non_blocking=True - ) - self.v_cache_cpu[:, cpu_block_id].copy_( - self.v_cache_gpu[:, gpu_slot], non_blocking=True - ) - - def wait_ping_offload_done(self) -> None: - """Wait for Ping buffer offload to complete.""" - self.compute_stream.wait_event(self.ping_offload_done) - - def wait_pong_offload_done(self) -> None: - """Wait for Pong buffer offload to complete.""" - self.compute_stream.wait_event(self.pong_offload_done) - def wait_all_offload_done(self) -> None: """Wait for all offload operations to complete.""" - self.pingpong_stream.synchronize() - - def get_kv_for_ping_slots( - self, - layer_id: int, - num_slots: int, - ) -> Tuple[Tensor, Tensor]: - """ - Get KV for specified number of slots in Ping buffer. 
- - Args: - layer_id: Layer ID - num_slots: Number of slots needed - - Returns: - (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim] - """ - slots = self.ping_slots[:num_slots] - k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim] - v = self.v_cache_gpu[layer_id, slots] - # Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim] - k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) - v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) - return k, v - - def get_kv_for_pong_slots( - self, - layer_id: int, - num_slots: int, - ) -> Tuple[Tensor, Tensor]: - """ - Get KV for specified number of slots in Pong buffer. - - Args: - layer_id: Layer ID - num_slots: Number of slots needed - - Returns: - (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim] - """ - slots = self.pong_slots[:num_slots] - k = self.k_cache_gpu[layer_id, slots] - v = self.v_cache_gpu[layer_id, slots] - k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) - v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) - return k, v + self.transfer_stream_main.synchronize() def get_kv_for_slots( self, @@ -918,8 +729,6 @@ class OffloadEngine: def swap_compute_prefetch(self) -> None: """Swap roles of Compute region and Prefetch region.""" self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots - # Also update old ping/pong slots for compatibility - self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots def offload_decode_slot(self, cpu_block_id: int) -> None: """ diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 1df50b9..d50456d 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -123,8 +123,7 @@ class Attention(nn.Module): lse_acc = None # Load previous KV from CPU using Compute/Prefetch region - # Note: context.offload_engine is actually HybridKVCacheManager - kvcache_manager = context.offload_engine + kvcache_manager = context.kvcache_manager seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None if kvcache_manager is not None and seq is not None and self.layer_id >= 0: @@ -215,7 +214,7 @@ class Attention(nn.Module): # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence) q_batched = q.unsqueeze(1) # [batch, 1, heads, dim] - kvcache_manager = context.offload_engine + kvcache_manager = context.kvcache_manager seq = context.chunked_seq # Get all CPU blocks for this sequence diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py index 5addf8b..86163a9 100644 --- a/nanovllm/utils/context.py +++ b/nanovllm/utils/context.py @@ -21,7 +21,7 @@ class Context: # Current chunk's position offset (for causal mask) chunk_offset: int = 0 # Reference to kvcache manager for loading previous KV (HybridKVCacheManager) - offload_engine: Any = None + kvcache_manager: Any = None # Current layer's previous K/V chunks (loaded from CPU) # Set by model_runner before each layer's forward prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list) @@ -33,14 +33,6 @@ class Context: # Used when batching decode offloads - we need to attend to all accumulated tokens decode_start_pos_in_block: int = 0 - # ========== Per-layer chunked attention state ========== - # Whether chunked decode/prefill is currently active (for hooks to check) - chunked_decode_active: bool = False - # CPU block IDs for the current chunk being processed - chunked_decode_chunk_ids: List[int] = 
field(default_factory=list) - # Current chunk index being processed - chunked_decode_current_chunk: int = 0 - _CONTEXT = Context() @@ -61,7 +53,7 @@ def set_context( is_chunked_prefill=False, prev_kv_ranges=None, chunk_offset=0, - offload_engine=None, + kvcache_manager=None, chunked_seq=None, decode_pos_in_block=0, decode_start_pos_in_block=0, @@ -79,7 +71,7 @@ def set_context( is_chunked_prefill=is_chunked_prefill, prev_kv_ranges=prev_kv_ranges or [], chunk_offset=chunk_offset, - offload_engine=offload_engine, + kvcache_manager=kvcache_manager, chunked_seq=chunked_seq, decode_pos_in_block=decode_pos_in_block, decode_start_pos_in_block=decode_start_pos_in_block, diff --git a/tests/test_chunked_attention.py b/tests/test_chunked_attention.py index b2be4ff..52ca8e6 100644 --- a/tests/test_chunked_attention.py +++ b/tests/test_chunked_attention.py @@ -71,7 +71,7 @@ def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_p path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") print(f"=" * 60) - print(f"Chunked Prefill Test (Ping-Pong)") + print(f"Chunked Prefill Test (Chunked Offload)") print(f"=" * 60) print(f" target_input_len: ~{input_len} tokens") print(f" num_gpu_blocks: {num_gpu_blocks}") @@ -111,7 +111,7 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_p path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") print(f"=" * 60) - print(f"Chunked Decode Test (Ping-Pong)") + print(f"Chunked Decode Test (Chunked Offload)") print(f"=" * 60) print(f" target_input_len: ~{input_len} tokens") print(f" output_len: {output_len} tokens")
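
Note (not part of the patch): the chunked attention path touched above accumulates per-chunk outputs (see the `lse_acc = None` context in `nanovllm/layers/attention.py`) and CLAUDE.md's "Online Softmax Merge" section describes how those partial results are combined. Below is a minimal, hedged sketch of such a merge for readers unfamiliar with the technique; the helper name `merge_chunk_outputs` and the tensor shapes are assumptions for illustration, not the repository's actual implementation.

```python
# Illustrative sketch only: numerically stable merge of two partial
# attention results using their log-sum-exp (online softmax merge).
# Shapes follow the flash-attn convention and are assumptions here:
#   o_*:   [batch, seqlen_q, num_heads, head_dim]  partial attention outputs
#   lse_*: [batch, num_heads, seqlen_q]            log-sum-exp of scores
import torch

def merge_chunk_outputs(o_acc, lse_acc, o_new, lse_new):
    """Fold a new chunk's attention output into the running accumulator."""
    if o_acc is None:
        # First chunk: nothing to merge yet.
        return o_new, lse_new
    # Combined log-sum-exp over both chunks.
    lse_merged = torch.logaddexp(lse_acc, lse_new)
    # Each partial output is rescaled by its share of the combined softmax mass.
    w_acc = torch.exp(lse_acc - lse_merged).transpose(1, 2).unsqueeze(-1)
    w_new = torch.exp(lse_new - lse_merged).transpose(1, 2).unsqueeze(-1)
    o_merged = o_acc * w_acc + o_new * w_new
    return o_merged, lse_merged
```

Under these assumptions, calling the helper once per KV chunk (in any order) yields the same result as attending over the full sequence at once, which is what lets the chunked-offload path stream CPU-resident KV blocks through the Compute/Prefetch regions without keeping the whole cache on the GPU.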