diff --git a/CLAUDE.md b/CLAUDE.md index 935c01c..e236d36 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -173,13 +173,16 @@ Compute: [C0] [C1] [C2] **File**: `nanovllm/kvcache/hybrid_manager.py` -Manages both GPU and CPU blocks: -- `allocate()`: Allocate GPU block first, fallback to CPU -- `allocate_cpu_only()`: Force CPU allocation (for ring buffer mode) +CPU-primary KV cache manager with GPU ring buffer design: +- All KV cache is stored on CPU as primary storage +- GPU is used as a ring buffer for computation only +- Ring buffer enables pipelined H2D transfers overlapped with computation + +Key methods: +- `allocate()` / `allocate_cpu_only()`: Allocate all blocks to CPU - `get_all_cpu_blocks(seq)`: Get all CPU block IDs for a sequence - `get_prefilled_cpu_blocks(seq)`: Get CPU blocks from previous chunks - `get_write_slot_for_chunked_offload(seq)`: Get GPU slot for writing new KV (returns decode_slot) -- `may_offload()`: Offload GPU blocks to CPU when decode slot fills ### Online Softmax Merge diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index e39e6c5..212036a 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -388,26 +388,6 @@ class ModelRunner: else: return self.run_chunked_offload_decode(seqs) - # Check if chunked prefill is needed (legacy path) - if is_prefill and hasattr(self, 'kvcache_manager'): - needs_chunked = any( - hasattr(self.kvcache_manager, 'needs_chunked_prefill') and - self.kvcache_manager.needs_chunked_prefill(seq) - for seq in seqs if seq.block_table - ) - if needs_chunked: - return self.run_chunked_prefill(seqs) - - # Check if chunked decode is needed (legacy path) - if not is_prefill and hasattr(self, 'kvcache_manager'): - needs_chunked = any( - hasattr(self.kvcache_manager, 'needs_chunked_decode') and - self.kvcache_manager.needs_chunked_decode(seq) - for seq in seqs if seq.block_table - ) - if needs_chunked: - return self.run_chunked_decode(seqs) - input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) temperatures = self.prepare_sample(seqs) if self.rank == 0 else None logits = self.run_model(input_ids, positions, is_prefill) @@ -445,194 +425,6 @@ class ModelRunner: return False - def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]: - """ - Run prefill in chunks when sequences exceed GPU capacity. - - For each chunk: - 1. Process tokens through model forward pass - 2. At each attention layer: - - Load previous KV from CPU (handled by attention layer) - - Compute attention with online softmax merging - - Store current KV to GPU cache - 3. After chunk completes, offload KV to CPU - 4. 
Load next chunk's blocks to GPU - """ - import sys - - # Currently only supporting single sequence for chunked prefill - assert len(seqs) == 1, "Chunked prefill only supports single sequence" - seq = seqs[0] - - total_blocks = seq.num_blocks - print(f"[Chunked Prefill] Starting: {total_blocks} total blocks, " - f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr) - - chunk_num = 0 - logits = None - - while True: - # Get chunk info (which blocks are on GPU and not yet prefilled) - chunk_info = self.kvcache_manager.get_gpu_block_tables_partial(seqs) - gpu_blocks, start_block_idx, end_block_idx = chunk_info[0] - - if not gpu_blocks: - # No more blocks to process - break - - chunk_num += 1 - chunk_tokens = (end_block_idx - start_block_idx) * self.block_size - if end_block_idx == seq.num_blocks: - # Last block may be partial - chunk_tokens = len(seq) - start_block_idx * self.block_size - - print(f"[Chunked Prefill] Chunk {chunk_num}: blocks {start_block_idx}-{end_block_idx-1}, " - f"~{chunk_tokens} tokens", file=sys.stderr) - - # Prepare inputs for this chunk - input_ids, positions = self._prepare_chunked_prefill(seq, gpu_blocks, start_block_idx, end_block_idx) - - if input_ids.numel() == 0: - print(f"[Chunked Prefill] No input tokens, breaking", file=sys.stderr) - break - - print(f"[Chunked Prefill] Running model with {input_ids.numel()} tokens...", file=sys.stderr) - - # Run model forward pass - logits = self.run_model(input_ids, positions, is_prefill=True) - reset_context() - - print(f"[Chunked Prefill] Model forward complete", file=sys.stderr) - - # Check if this is the last chunk - # Mark current chunk as prefilled and offload to CPU - self.kvcache_manager.complete_prefill_chunk(seq) - - # Check if more chunks needed - if not self.kvcache_manager.needs_chunked_prefill(seq): - print(f"[Chunked Prefill] All chunks done, sampling", file=sys.stderr) - break - - print(f"[Chunked Prefill] Chunk transfer complete, loading next...", file=sys.stderr) - - # Sample from the last chunk's logits - temperatures = self.prepare_sample(seqs) if self.rank == 0 else None - if logits is not None: - token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None - else: - token_ids = [0] if self.rank == 0 else None - - return token_ids - - def run_chunked_decode(self, seqs: list[Sequence]) -> list[int]: - """ - Run decode with chunked attention when sequence exceeds GPU capacity. - - For decode, we need attention over ALL previous tokens. With CPU offload, - we load KV chunks and compute attention incrementally per-layer. - - Flow: - 1. Ensure last block is on GPU (for writing new KV token) - 2. Run model forward - each attention layer: - a. Compute attention on GPU blocks - b. Load CPU blocks in chunks, compute + merge - 3. 
Sample from output - """ - # Currently only supporting single sequence for chunked decode - assert len(seqs) == 1, "Chunked decode only supports single sequence" - seq = seqs[0] - - # Prepare inputs - input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) - positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) - - # Ensure last block is on GPU for writing new KV token - last_gpu_slot = self.kvcache_manager.ensure_last_block_on_gpu(seq) - slot = last_gpu_slot * self.block_size + seq.last_block_num_tokens - 1 - slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - - # Set up context for chunked decode - set_context( - is_prefill=False, # Decode mode - slot_mapping=slot_mapping, - context_lens=context_len, - is_chunked_prefill=True, # Use chunked attention path - kvcache_manager=self.kvcache_manager, - chunked_seq=seq, - ) - - # Run model forward pass - # Each attention layer will handle chunked KV loading internally - logits = self.run_model(input_ids, positions, is_prefill=False) - reset_context() - - # Sample - temperatures = self.prepare_sample(seqs) if self.rank == 0 else None - token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None - - return token_ids - - def _prepare_chunked_prefill( - self, - seq: Sequence, - gpu_blocks: list[int], - start_block_idx: int, - end_block_idx: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Prepare inputs for a single chunk in chunked prefill. - - Sets up context with is_chunked_prefill=True so attention layers - know to load previous KV from CPU. 
- """ - # Calculate token range for this chunk - start_token = start_block_idx * self.block_size - end_token = min(end_block_idx * self.block_size, len(seq)) - - # Input tokens for this chunk - input_ids = seq[start_token:end_token] - positions = list(range(start_token, end_token)) - - # Slot mapping for storing KV cache - slot_mapping = [] - for i, gpu_block_id in enumerate(gpu_blocks): - block_idx = start_block_idx + i - start = gpu_block_id * self.block_size - if block_idx != seq.num_blocks - 1: - end = start + self.block_size - else: - end = start + seq.last_block_num_tokens - slot_mapping.extend(list(range(start, end))) - - # Trim slot_mapping to match actual token count - actual_tokens = end_token - start_token - slot_mapping = slot_mapping[:actual_tokens] - - # Convert to tensors - input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) - positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - - # Set up context for chunked prefill - seqlen = actual_tokens - cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - - set_context( - is_prefill=True, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=seqlen, - max_seqlen_k=seqlen, - slot_mapping=slot_mapping, - is_chunked_prefill=True, - kvcache_manager=self.kvcache_manager, # Pass manager for loading previous KV - chunked_seq=seq, # Pass sequence for loading previous KV - ) - - return input_ids, positions - def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]: """ Run prefill with unified ring buffer (CPU is primary storage). diff --git a/nanovllm/kvcache/__init__.py b/nanovllm/kvcache/__init__.py index ef34a81..e8eb7f9 100644 --- a/nanovllm/kvcache/__init__.py +++ b/nanovllm/kvcache/__init__.py @@ -3,7 +3,7 @@ KV Cache management module. This module provides pluggable KV cache management strategies: - GPUOnlyManager: Pure GPU (default, current nano-vllm behavior) -- HybridKVCacheManager: CPU offload with CUDA Graph support +- HybridKVCacheManager: CPU-primary storage with GPU ring buffer for computation Usage: from nanovllm.kvcache import create_kvcache_manager diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py index f4dde34..5be9e94 100644 --- a/nanovllm/kvcache/hybrid_manager.py +++ b/nanovllm/kvcache/hybrid_manager.py @@ -65,22 +65,19 @@ class LogicalBlock: class HybridKVCacheManager(KVCacheManager): """ - Hybrid CPU-GPU KV cache manager with CUDA Graph support. + Hybrid CPU-GPU KV cache manager with ring buffer design. - Architecture: - - GPU buffer: Fixed-size working set (num_gpu_slots) - - CPU pool: Overflow storage (num_cpu_blocks) + Architecture (CPU-primary mode): + - CPU pool: Primary storage for all KV cache (num_cpu_blocks) + - GPU buffer: Ring buffer for computation (num_gpu_slots) - Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks) - CUDA Graph compatibility: - - All tensor addresses fixed at init time - - prepare_for_attention() updates gather_indices (outside graph) - - gathered_h2d_layer() executes transfer (inside graph) - - Strategy: - 1. New KV data written to GPU slots - 2. Cold blocks evicted to CPU using configurable policy - 3. 
Needed blocks prefetched back to GPU before attention + Design: + - All KV cache is stored on CPU as primary storage + - GPU is used as a ring buffer for computation only + - During prefill: KV is written to GPU ring slot, then offloaded to CPU + - During decode: Previous KV is loaded from CPU to GPU for attention + - Ring buffer enables pipelined H2D transfers overlapped with computation """ def __init__( @@ -89,26 +86,25 @@ class HybridKVCacheManager(KVCacheManager): num_cpu_blocks: int, block_size: int, policy: Optional[EvictionPolicy] = None, - cpu_primary: bool = True, num_prefetch_blocks: int = 2, ): """ - Initialize hybrid manager. + Initialize hybrid manager with CPU-primary ring buffer design. + + All KV cache is stored on CPU as primary storage. GPU slots are used + as a ring buffer for computation only. Args: - num_gpu_slots: Number of GPU buffer slots (working set) - num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage) + num_gpu_slots: Number of GPU buffer slots (ring buffer for computation) + num_cpu_blocks: Number of CPU pool blocks (primary storage) block_size: Tokens per block - policy: Eviction policy (default: LRU) - cpu_primary: If True, use CPU as primary storage with ring buffer GPU design. - If False, use GPU as primary with CPU as overflow (legacy mode). + policy: Eviction policy (default: LRU, used for prefix cache management) num_prefetch_blocks: Number of blocks for ring buffer pipeline (deprecated, ring_slots = num_gpu_slots) """ self._block_size = block_size self.num_gpu_slots = num_gpu_slots self.num_cpu_blocks = num_cpu_blocks self.total_blocks = num_gpu_slots + num_cpu_blocks - self.cpu_primary = cpu_primary # Ring buffer mode flag self.num_prefetch_blocks = num_prefetch_blocks # Ring buffer design parameter (deprecated) # Eviction policy @@ -200,160 +196,6 @@ class HybridKVCacheManager(KVCacheManager): self.sparse_policy = policy logger.info(f"Sparse attention policy set: {policy}") - def _allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> int: - """ - Get a free GPU slot, evicting if necessary. - - Args: - protected_logical_ids: Logical block IDs that cannot be evicted - - Returns: - GPU slot ID - - Raises: - RuntimeError: If no GPU slot is available - """ - if self.free_gpu_slots: - return self.free_gpu_slots.popleft() - - # Need to evict - find victim using policy - return self._evict_to_cpu(protected_logical_ids) - - def _try_allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> Optional[int]: - """ - Try to get a free GPU slot, evicting if necessary. - - Unlike _allocate_gpu_slot(), returns None instead of raising if no eviction possible. - - Args: - protected_logical_ids: Logical block IDs that cannot be evicted - - Returns: - GPU slot ID, or None if no slot available - """ - if self.free_gpu_slots: - return self.free_gpu_slots.popleft() - - # Check if we can evict - protected = protected_logical_ids or set() - for gpu_slot, logical_id in self.gpu_slot_to_logical.items(): - if logical_id not in protected: - block = self.logical_blocks[logical_id] - if block.ref_count > 0: - # Found evictable block - return self._evict_to_cpu(protected_logical_ids) - - # No evictable blocks - return None - - def _evict_to_cpu(self, protected_logical_ids: Optional[Set[int]] = None) -> int: - """ - Evict a GPU block to CPU to make room. 
- - Args: - protected_logical_ids: Logical block IDs that cannot be evicted - - Returns: - The freed GPU slot ID - """ - protected = protected_logical_ids or set() - - # Find candidates (blocks currently on GPU with ref_count > 0, excluding protected) - candidates: Set[int] = set() - for gpu_slot, logical_id in self.gpu_slot_to_logical.items(): - if logical_id in protected: - continue # Skip protected blocks - block = self.logical_blocks[logical_id] - if block.ref_count > 0: # Only evict blocks still in use - candidates.add(gpu_slot) - - if not candidates: - raise RuntimeError( - f"No GPU slots available for eviction. " - f"GPU slots: {self.num_gpu_slots}, protected: {len(protected)}, " - f"need more GPU memory or reduce sequence length" - ) - - # Use policy to select victim - victim_gpu_slot = self.policy.select_victim(candidates) - logical_id = self.gpu_slot_to_logical[victim_gpu_slot] - block = self.logical_blocks[logical_id] - - # Allocate CPU block - if not self.free_cpu_blocks: - raise RuntimeError("Both GPU and CPU are full") - cpu_block_id = self.free_cpu_blocks.popleft() - - # Async offload GPU -> CPU - self.offload_engine.offload_block_async( - layer_id=0, # TODO: handle per-layer offloading - gpu_block_id=victim_gpu_slot, - cpu_block_id=cpu_block_id, - ) - - # Update mappings - del self.gpu_slot_to_logical[victim_gpu_slot] - self.cpu_block_to_logical[cpu_block_id] = logical_id - - block.location = BlockLocation.CPU - block.gpu_slot = -1 - block.cpu_block_id = cpu_block_id - - # Notify policy - self.policy.on_block_evicted(victim_gpu_slot) - - return victim_gpu_slot - - def _ensure_on_gpu( - self, - logical_id: int, - protected_logical_ids: Optional[Set[int]] = None, - ) -> int: - """ - Ensure a logical block is on GPU. - - Args: - logical_id: Logical block ID - protected_logical_ids: Logical block IDs that cannot be evicted - - Returns: - GPU slot ID where the block is/will be - """ - block = self.logical_blocks[logical_id] - - if block.location == BlockLocation.GPU: - # Already on GPU, update policy - self.policy.on_block_access(block.gpu_slot, self.current_step) - return block.gpu_slot - - if block.location == BlockLocation.CPU: - # Need to prefetch from CPU - gpu_slot = self._allocate_gpu_slot(protected_logical_ids) - - # Async prefetch CPU -> GPU - self.offload_engine.prefetch_block_async( - layer_id=0, # TODO: handle per-layer - cpu_block_id=block.cpu_block_id, - gpu_block_id=gpu_slot, - ) - - # Update mappings - self.free_cpu_blocks.append(block.cpu_block_id) - del self.cpu_block_to_logical[block.cpu_block_id] - - self.gpu_slot_to_logical[gpu_slot] = logical_id - - block.location = BlockLocation.GPU - block.gpu_slot = gpu_slot - block.cpu_block_id = -1 - - # Notify policy - self.policy.on_block_prefetched(gpu_slot, self.current_step) - - return gpu_slot - - raise RuntimeError(f"Block {logical_id} is in invalid state") - def can_allocate(self, seq: Sequence) -> bool: """Check if we can allocate blocks for a new sequence.""" return len(self.free_logical_ids) >= seq.num_blocks @@ -362,89 +204,10 @@ class HybridKVCacheManager(KVCacheManager): """ Allocate logical blocks for prefill. - In cpu_primary mode (Chunked Offload): All blocks are allocated to CPU. - In legacy mode: Blocks are allocated to GPU when possible, overflow to CPU. + All blocks are allocated to CPU (primary storage). + GPU is used as ring buffer for computation only. 
""" - assert not seq.block_table, "Sequence already has blocks" - - # Ring buffer mode: all blocks are allocated to CPU - if self.cpu_primary: - return self.allocate_cpu_only(seq) - - # Legacy mode: GPU as primary, CPU as overflow - h = -1 - cache_miss = False - - # Track blocks allocated for this sequence to protect them from eviction - allocated_for_seq: Set[int] = set() - - for i in range(seq.num_blocks): - token_ids = seq.block(i) - - # Hash for full blocks only - if len(token_ids) == self._block_size: - h = self.compute_hash(token_ids, h) - else: - h = -1 - - # Check prefix cache - cached_logical_id = self.hash_to_logical_id.get(h, -1) - if cached_logical_id != -1: - cached_block = self.logical_blocks[cached_logical_id] - if cached_block.token_ids == token_ids and cached_block.ref_count > 0: - # Cache hit - cached_block.ref_count += 1 - seq.num_cached_tokens += self._block_size - seq.block_table.append(cached_logical_id) - allocated_for_seq.add(cached_logical_id) - - # Ensure block is on GPU (protect already allocated blocks) - if cached_block.location == BlockLocation.CPU: - self._ensure_on_gpu(cached_logical_id, allocated_for_seq) - - continue - - cache_miss = True - - # Allocate new logical block - logical_id = self.free_logical_ids.popleft() - block = self.logical_blocks[logical_id] - block.ref_count = 1 - block.hash = h - block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else [] - - # Try to allocate GPU slot - gpu_slot = self._try_allocate_gpu_slot(allocated_for_seq) - if gpu_slot is not None: - # Got GPU slot - block.location = BlockLocation.GPU - block.gpu_slot = gpu_slot - block.cpu_block_id = -1 - self.gpu_slot_to_logical[gpu_slot] = logical_id - else: - # GPU full and can't evict (all protected) - allocate to CPU - # This block will be written via chunked prefill - if not self.free_cpu_blocks: - raise RuntimeError( - f"Both GPU and CPU are full. 
Need {seq.num_blocks} blocks, " - f"GPU has {self.num_gpu_slots}, CPU has {self.num_cpu_blocks}" - ) - cpu_block_id = self.free_cpu_blocks.popleft() - block.location = BlockLocation.CPU - block.gpu_slot = -1 - block.cpu_block_id = cpu_block_id - self.cpu_block_to_logical[cpu_block_id] = logical_id - - allocated_for_seq.add(logical_id) - - # Update prefix cache - if h != -1: - self.hash_to_logical_id[h] = logical_id - - # Notify policy - self.policy.on_block_allocated(gpu_slot, self.current_step) - - seq.block_table.append(logical_id) + return self.allocate_cpu_only(seq) def deallocate(self, seq: Sequence) -> None: """Release all blocks for a sequence.""" @@ -496,22 +259,14 @@ class HybridKVCacheManager(KVCacheManager): block.hash = -1 block.token_ids = [] - if self.cpu_primary: - # Ring buffer mode: new block allocated to CPU - if not self.free_cpu_blocks: - raise RuntimeError("No free CPU blocks for decode") - cpu_block_id = self.free_cpu_blocks.popleft() - block.location = BlockLocation.CPU - block.cpu_block_id = cpu_block_id - block.gpu_slot = -1 - self.cpu_block_to_logical[cpu_block_id] = logical_id - else: - # Legacy mode: new block allocated to GPU - gpu_slot = self._allocate_gpu_slot() - block.location = BlockLocation.GPU - block.gpu_slot = gpu_slot - self.gpu_slot_to_logical[gpu_slot] = logical_id - self.policy.on_block_allocated(gpu_slot, self.current_step) + # Allocate new block to CPU (ring buffer mode) + if not self.free_cpu_blocks: + raise RuntimeError("No free CPU blocks for decode") + cpu_block_id = self.free_cpu_blocks.popleft() + block.location = BlockLocation.CPU + block.cpu_block_id = cpu_block_id + block.gpu_slot = -1 + self.cpu_block_to_logical[cpu_block_id] = logical_id block_table.append(logical_id) @@ -536,235 +291,10 @@ class HybridKVCacheManager(KVCacheManager): """ Prepare KV cache for attention computation. - For prefill: async prefetch blocks from CPU to GPU. - For decode: update gather_indices for CUDA graph. + In ring buffer mode, this is a no-op because chunked offload + paths handle H2D transfers directly in the attention layer. 
""" - self.current_step += 1 - - # Collect all needed logical blocks - needed_logical_ids: Set[int] = set() - for seq in seqs: - needed_logical_ids.update(seq.block_table) - - if is_prefill: - # Prefill: ensure all blocks on GPU (async prefetch) - # Pass needed_logical_ids as protected to prevent evicting blocks we need - for logical_id in needed_logical_ids: - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.CPU: - self._ensure_on_gpu(logical_id, needed_logical_ids) - - # Wait for all prefetches to complete - self.offload_engine.wait_all_transfers() - - else: - # Decode: Check if we need chunked decode - cpu_blocks_count = sum( - 1 for lid in needed_logical_ids - if self.logical_blocks[lid].location == BlockLocation.CPU - ) - - if cpu_blocks_count > self.num_gpu_slots: - # Too many blocks on CPU - will use chunked decode - # Don't try to load all blocks now - return - - # Standard decode: prepare gather_indices for CUDA graph - # Identify blocks needing transfer - self.pending_gpu_loads.clear() - mappings_per_layer: List[List[Tuple[int, int]]] = [ - [] for _ in range(self.offload_engine.num_layers) - ] - - for logical_id in needed_logical_ids: - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.CPU: - # Allocate GPU slot (protect needed blocks from eviction) - gpu_slot = self._allocate_gpu_slot(needed_logical_ids) - - # Record mapping for each layer - for layer_id in range(self.offload_engine.num_layers): - mappings_per_layer[layer_id].append( - (block.cpu_block_id, gpu_slot) - ) - - # Update block state - self.free_cpu_blocks.append(block.cpu_block_id) - del self.cpu_block_to_logical[block.cpu_block_id] - - self.gpu_slot_to_logical[gpu_slot] = logical_id - block.location = BlockLocation.GPU - block.gpu_slot = gpu_slot - block.cpu_block_id = -1 - - self.pending_gpu_loads.add(logical_id) - self.policy.on_block_prefetched(gpu_slot, self.current_step) - - elif block.location == BlockLocation.GPU: - self.policy.on_block_access(block.gpu_slot, self.current_step) - - # Update gather indices (outside graph) - self.offload_engine.update_gather_indices_all_layers(mappings_per_layer) - self.offload_engine.sync_indices() - - def needs_chunked_decode(self, seq: Sequence) -> bool: - """ - Check if sequence needs chunked decode. - - Returns True if there are blocks on CPU and total blocks exceed GPU capacity. - """ - cpu_blocks = 0 - for logical_id in seq.block_table: - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.CPU: - cpu_blocks += 1 - return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots - - # ========== Chunked Decode Support ========== - - def get_decode_chunk_info(self, seq: Sequence) -> Tuple[List[int], List[int], int]: - """ - Get information for chunked decode. 
- - Returns: - (cpu_block_ids, cpu_logical_ids, num_chunks) - - cpu_block_ids: List of CPU block IDs in sequence order - - cpu_logical_ids: Corresponding logical block IDs - - num_chunks: Number of chunks needed - """ - cpu_block_ids = [] - cpu_logical_ids = [] - - for logical_id in seq.block_table: - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.CPU: - cpu_block_ids.append(block.cpu_block_id) - cpu_logical_ids.append(logical_id) - - # Each chunk uses available GPU slots minus 1 (reserved for write block) - usable_slots = self.num_gpu_slots - 1 - num_chunks = (len(cpu_block_ids) + usable_slots - 1) // usable_slots if usable_slots > 0 else 0 - - return cpu_block_ids, cpu_logical_ids, num_chunks - - def load_decode_chunk( - self, - seq: Sequence, - cpu_block_ids: List[int], - cpu_logical_ids: List[int], - chunk_idx: int, - ) -> List[int]: - """ - Load one chunk of CPU blocks to GPU for chunked decode. - - Similar to chunked prefill: uses GPU slots to hold a batch of blocks. - - Args: - seq: Sequence being decoded - cpu_block_ids: All CPU block IDs for this sequence - cpu_logical_ids: Corresponding logical block IDs - chunk_idx: Which chunk to load (0-indexed) - - Returns: - List of GPU slot IDs where the chunk was loaded - """ - chunk_size = self.num_gpu_slots - start = chunk_idx * chunk_size - end = min(start + chunk_size, len(cpu_block_ids)) - - chunk_cpu_ids = cpu_block_ids[start:end] - chunk_logical_ids = cpu_logical_ids[start:end] - - # Use GPU slots 0, 1, 2, ... for this chunk - gpu_slots = list(range(len(chunk_cpu_ids))) - - # Load all layers at once using offload_engine - self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers( - chunk_cpu_ids, gpu_slots - ) - - return gpu_slots - - def get_gpu_blocks_for_decode(self, seq: Sequence) -> Tuple[List[int], List[int]]: - """ - Get blocks currently on GPU for this sequence. - - Returns: - (gpu_slots, logical_ids) - GPU slot IDs and corresponding logical block IDs - """ - gpu_slots = [] - logical_ids = [] - - for logical_id in seq.block_table: - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.GPU: - gpu_slots.append(block.gpu_slot) - logical_ids.append(logical_id) - - return gpu_slots, logical_ids - - def get_kv_for_gpu_slots( - self, - layer_id: int, - gpu_slots: List[int], - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Get KV tensors for specific GPU slots. - - Args: - layer_id: Layer index - gpu_slots: List of GPU slot IDs - - Returns: - (k, v) tensors with shape [1, num_tokens, kv_heads, head_dim] - """ - k_cache, v_cache = self.offload_engine.get_layer_cache(layer_id) - # k_cache, v_cache shape: [num_gpu_blocks, block_size, kv_heads, head_dim] - - k_chunks = [k_cache[slot] for slot in gpu_slots] - v_chunks = [v_cache[slot] for slot in gpu_slots] - - # Concatenate and add batch dimension - k = torch.cat(k_chunks, dim=0).unsqueeze(0) # [1, tokens, heads, dim] - v = torch.cat(v_chunks, dim=0).unsqueeze(0) - - return k, v - - def ensure_last_block_on_gpu(self, seq: Sequence) -> int: - """ - Ensure the last block is on GPU for writing new KV. - - Uses a RESERVED slot (last slot) to avoid conflicts with chunked decode - which uses slots 0, 1, 2, ... for loading CPU blocks. 
- - Returns: - GPU slot ID for the last block - """ - last_logical_id = seq.block_table[-1] - block = self.logical_blocks[last_logical_id] - - if block.location == BlockLocation.GPU: - return block.gpu_slot - - # Use last slot as reserved slot for write block - # This avoids conflicts with chunked decode which uses slots 0, 1, 2... - reserved_slot = self.num_gpu_slots - 1 - - # Load this block to GPU for all layers - self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers( - [block.cpu_block_id], [reserved_slot] - ) - - # Update block state - self.free_cpu_blocks.append(block.cpu_block_id) - del self.cpu_block_to_logical[block.cpu_block_id] - - self.gpu_slot_to_logical[reserved_slot] = last_logical_id - block.location = BlockLocation.GPU - block.gpu_slot = reserved_slot - block.cpu_block_id = -1 - - return reserved_slot + pass def get_gpu_block_tables( self, @@ -773,19 +303,13 @@ class HybridKVCacheManager(KVCacheManager): """ Get GPU slot tables for sequences. - Returns GPU slot IDs, which may differ from logical block IDs. + In ring buffer mode, all blocks are on CPU, so this raises an error + if called. Use run_chunked_offload_* methods instead. """ - result = [] - for seq in seqs: - gpu_table = [] - for logical_id in seq.block_table: - block = self.logical_blocks[logical_id] - assert block.location == BlockLocation.GPU, ( - f"Block {logical_id} not on GPU (location={block.location})" - ) - gpu_table.append(block.gpu_slot) - result.append(gpu_table) - return result + raise RuntimeError( + "get_gpu_block_tables should not be called in ring buffer mode. " + "Use run_chunked_offload_prefill/decode instead." + ) def post_attention_cleanup( self, @@ -795,180 +319,12 @@ class HybridKVCacheManager(KVCacheManager): """ Cleanup after attention. - Clear pending loads and optionally proactive offload. + In ring buffer mode, this is a no-op because offload is handled + directly in the chunked prefill/decode paths. """ - self.pending_gpu_loads.clear() + pass - # ========== Chunked Prefill Support ========== - - def needs_chunked_prefill(self, seq: Sequence) -> bool: - """ - Check if sequence needs chunked prefill. - - Returns True if there are unprefilled blocks that are on CPU. - This indicates we need to process in chunks because not all blocks fit on GPU. - """ - for logical_id in seq.block_table: - if logical_id not in self.prefilled_blocks: - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.CPU: - return True - return False - - def get_gpu_block_count(self, seq: Sequence) -> int: - """Get number of blocks currently on GPU for this sequence.""" - count = 0 - for logical_id in seq.block_table: - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.GPU: - count += 1 - return count - - def get_prefill_chunk_info(self, seq: Sequence) -> Tuple[int, int, List[int]]: - """ - Get information for current prefill chunk. 
- - Returns: - (start_block_idx, end_block_idx, gpu_block_ids) - - start_block_idx: First block index in this chunk - - end_block_idx: Last block index (exclusive) in this chunk - - gpu_block_ids: GPU slot IDs for blocks in this chunk - """ - start_idx = -1 - end_idx = -1 - gpu_block_ids = [] - - for i, logical_id in enumerate(seq.block_table): - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.GPU: - if start_idx == -1: - start_idx = i - end_idx = i + 1 - gpu_block_ids.append(block.gpu_slot) - elif start_idx != -1: - # Found CPU block after GPU blocks - stop here - break - - if start_idx == -1: - return (0, 0, []) - - return (start_idx, end_idx, gpu_block_ids) - - def complete_prefill_chunk(self, seq: Sequence) -> bool: - """ - Complete a prefill chunk: mark blocks as prefilled, offload to CPU, load next chunk. - - Returns: - True if there are more chunks to process, False if done. - """ - # Find blocks currently on GPU that were just prefilled - gpu_blocks_to_offload = [] - for logical_id in seq.block_table: - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.GPU and logical_id not in self.prefilled_blocks: - # Mark as prefilled - self.prefilled_blocks.add(logical_id) - gpu_blocks_to_offload.append(logical_id) - - # Offload prefilled GPU blocks to CPU - for logical_id in gpu_blocks_to_offload: - block = self.logical_blocks[logical_id] - if not self.free_cpu_blocks: - raise RuntimeError("No free CPU blocks for offload") - - cpu_block_id = self.free_cpu_blocks.popleft() - - # Async offload all layers - for layer_id in range(self.offload_engine.num_layers): - self.offload_engine.offload_block_async( - layer_id=layer_id, - gpu_block_id=block.gpu_slot, - cpu_block_id=cpu_block_id, - ) - - # Update mappings - self.free_gpu_slots.append(block.gpu_slot) - del self.gpu_slot_to_logical[block.gpu_slot] - self.cpu_block_to_logical[cpu_block_id] = logical_id - - block.location = BlockLocation.CPU - block.cpu_block_id = cpu_block_id - block.gpu_slot = -1 - - # Wait for offload to complete - self.offload_engine.wait_all_transfers() - - # Find next UNPREFILLED CPU blocks and bring them to GPU - cpu_blocks_to_load = [] - for logical_id in seq.block_table: - if logical_id in self.prefilled_blocks: - continue # Skip already prefilled - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.CPU: - if len(cpu_blocks_to_load) >= self.num_gpu_slots: - break # GPU is full - cpu_blocks_to_load.append(logical_id) - - if not cpu_blocks_to_load: - return False # All blocks have been prefilled - - # Load unprefilled CPU blocks to GPU - for logical_id in cpu_blocks_to_load: - block = self.logical_blocks[logical_id] - gpu_slot = self.free_gpu_slots.popleft() - - # Note: We're NOT prefetching existing data - these blocks are being - # loaded for the first time, so we just need to assign GPU slots - # The model will write new KV cache data to these slots - - # Update mappings - self.free_cpu_blocks.append(block.cpu_block_id) - del self.cpu_block_to_logical[block.cpu_block_id] - self.gpu_slot_to_logical[gpu_slot] = logical_id - - block.location = BlockLocation.GPU - block.gpu_slot = gpu_slot - block.cpu_block_id = -1 - - return True # More chunks to process - - def get_gpu_block_tables_partial( - self, - seqs: List[Sequence], - ) -> List[Tuple[List[int], int, int]]: - """ - Get GPU block tables for chunked prefill. - - Returns list of (gpu_block_ids, start_block_idx, end_block_idx) per sequence. 
- Only includes blocks that are currently on GPU AND haven't been prefilled yet. - """ - result = [] - for seq in seqs: - gpu_table = [] - start_idx = -1 - end_idx = -1 - - for i, logical_id in enumerate(seq.block_table): - # Skip already prefilled blocks - if logical_id in self.prefilled_blocks: - continue - - block = self.logical_blocks[logical_id] - if block.location == BlockLocation.GPU: - if start_idx == -1: - start_idx = i - end_idx = i + 1 - gpu_table.append(block.gpu_slot) - elif start_idx != -1: - # Stop at first non-GPU block after GPU blocks - break - - if start_idx == -1: - start_idx = 0 - end_idx = 0 - - result.append((gpu_table, start_idx, end_idx)) - return result + # ========== Ring Buffer CPU-primary Chunked Prefill Support ========== def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]: """ @@ -991,66 +347,6 @@ class HybridKVCacheManager(KVCacheManager): ) return cpu_blocks - def load_prev_kv_for_layer( - self, - seq: Sequence, - layer_id: int, - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Load previous prefilled KV from CPU for a specific layer. - - This concatenates KV from all previously prefilled blocks for use - during chunked prefill attention. - - Args: - seq: Sequence to load KV for - layer_id: Layer index - - Returns: - (k, v) tensors with shape [1, total_prev_tokens, kv_heads, head_dim] - or (None, None) if no previous KV exists - """ - cpu_blocks = self.get_prefilled_cpu_blocks(seq) - if not cpu_blocks: - return None, None - - k_chunks = [] - v_chunks = [] - - for cpu_block_id in cpu_blocks: - k, v = self.offload_engine.get_cpu_block(layer_id, cpu_block_id) - # k, v shape: [block_size, kv_heads, head_dim] - k_chunks.append(k) - v_chunks.append(v) - - # Concatenate all chunks - k_prev = torch.cat(k_chunks, dim=0) # [total_prev_tokens, kv_heads, head_dim] - v_prev = torch.cat(v_chunks, dim=0) - - # Move to GPU and add batch dimension - k_prev = k_prev.to("cuda", non_blocking=True).unsqueeze(0) # [1, tokens, heads, dim] - v_prev = v_prev.to("cuda", non_blocking=True).unsqueeze(0) - - return k_prev, v_prev - - def get_chunk_start_position(self, seq: Sequence) -> int: - """ - Get the starting token position for the current chunk. - - This is the total number of tokens in previously prefilled blocks. - - Returns: - Token position offset for current chunk - """ - pos = 0 - for logical_id in seq.block_table: - if logical_id in self.prefilled_blocks: - # Full block's worth of tokens - pos += self._block_size - else: - break - return pos - # ========== Ring Buffer CPU-primary support ========== def allocate_cpu_only(self, seq: Sequence) -> None:
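# --- Illustrative sketch: online softmax merge for chunked attention ---
# A minimal sketch, assuming each chunk's attention kernel returns a partial
# output plus a per-query log-sum-exp (LSE), as in the "Online Softmax Merge"
# section of CLAUDE.md and the chunked decode flow above. The names
# `merge_partial_attention` and `attention_with_lse` are illustrative
# placeholders, not this module's API.
import torch

def merge_partial_attention(
    o_a: torch.Tensor, lse_a: torch.Tensor,
    o_b: torch.Tensor, lse_b: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Merge attention results computed over two disjoint KV chunks.

    o_*:   [num_heads, q_len, head_dim] partial attention outputs
    lse_*: [num_heads, q_len] log-sum-exp of the raw attention scores
    """
    lse = torch.logaddexp(lse_a, lse_b)          # combined normalizer over both chunks
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # rescaling weight for chunk A
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # rescaling weight for chunk B
    return w_a * o_a + w_b * o_b, lse

# Usage sketch: fold the merge over KV chunks streamed CPU -> GPU ring buffer.
#   out, lse = attention_with_lse(q, k_chunk0, v_chunk0)
#   for k_chunk, v_chunk in remaining_chunks:
#       o_c, lse_c = attention_with_lse(q, k_chunk, v_chunk)
#       out, lse = merge_partial_attention(out, lse, o_c, lse_c)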