From 87055cc5cec99fd49269af17df84e67726a6d3c5 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Wed, 10 Dec 2025 18:34:01 +0800
Subject: [PATCH] [refactor] Implement real chunked prefill mechanism.

---
 nanovllm/engine/model_runner.py    |  36 +++----
 nanovllm/kvcache/hybrid_manager.py | 163 +++++++++++++++++++++++------
 nanovllm/kvcache/offload_engine.py | 106 +++++++++++++++++++
 nanovllm/layers/attention.py       |  93 ++++++++++------
 4 files changed, 313 insertions(+), 85 deletions(-)

diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 38a9fcc..30371ea 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -429,37 +429,26 @@ class ModelRunner:
         Run decode with chunked attention when sequence exceeds GPU capacity.
 
         For decode, we need attention over ALL previous tokens. With CPU offload,
-        we load KV chunks and compute attention incrementally.
-        """
-        import sys
+        we load KV chunks and compute attention incrementally per-layer.
 
+        Flow:
+        1. Ensure last block is on GPU (for writing new KV token)
+        2. Run model forward - each attention layer:
+           a. Compute attention on GPU blocks
+           b. Load CPU blocks in chunks, compute + merge
+        3. Sample from output
+        """
         # Currently only supporting single sequence for chunked decode
        assert len(seqs) == 1, "Chunked decode only supports single sequence"
         seq = seqs[0]
 
-        total_blocks = len(seq.block_table)
-        print(f"[Chunked Decode] Sequence has {total_blocks} blocks, "
-              f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
-
         # Prepare inputs
         input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
 
-        # Compute slot mapping for the new token
-        # Get the last block's GPU slot if it's on GPU, otherwise we need to handle it
-        last_logical_id = seq.block_table[-1]
-        last_block = self.kvcache_manager.logical_blocks[last_logical_id]
-
-        if last_block.location.name == "GPU":
-            slot = last_block.gpu_slot * self.block_size + seq.last_block_num_tokens - 1
-        else:
-            # Last block is on CPU - we need to bring it to GPU for writing the new token
-            # This is a special case - allocate a temporary GPU slot
-            # For simplicity, use a fixed slot (this might conflict, but for decode
-            # we only write 1 token so it should be ok)
-            print(f"[Chunked Decode] Warning: last block on CPU, using temp slot", file=sys.stderr)
-            slot = 0  # Use first slot temporarily
-
+        # Ensure last block is on GPU for writing new KV token
+        last_gpu_slot = self.kvcache_manager.ensure_last_block_on_gpu(seq)
+        slot = last_gpu_slot * self.block_size + seq.last_block_num_tokens - 1
         slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
 
@@ -468,12 +457,13 @@ class ModelRunner:
             is_prefill=False,  # Decode mode
             slot_mapping=slot_mapping,
             context_lens=context_len,
-            is_chunked_prefill=True,  # Use chunked attention
+            is_chunked_prefill=True,  # Use chunked attention path
             offload_engine=self.kvcache_manager,
             chunked_seq=seq,
         )
 
         # Run model forward pass
+        # Each attention layer will handle chunked KV loading internally
         logits = self.run_model(input_ids, positions, is_prefill=False)
 
         reset_context()

diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py
index 27b0269..b40c51a 100644
--- a/nanovllm/kvcache/hybrid_manager.py
+++ b/nanovllm/kvcache/hybrid_manager.py
@@ -566,50 +566,151 @@ class HybridKVCacheManager(KVCacheManager):
                 cpu_blocks += 1
         return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots
 
-    def load_all_kv_for_layer(
-        self,
-        seq: Sequence,
-        layer_id: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Load ALL KV for a sequence from both GPU and CPU for a layer.
+    # ========== Chunked Decode Support ==========
 
-        Used during chunked decode to compute full attention.
+    def get_decode_chunk_info(self, seq: Sequence) -> Tuple[List[int], List[int], int]:
+        """
+        Get information for chunked decode.
 
         Returns:
-            (k, v) tensors with shape [1, total_tokens, kv_heads, head_dim]
+            (cpu_block_ids, cpu_logical_ids, num_chunks)
+            - cpu_block_ids: List of CPU block IDs in sequence order
+            - cpu_logical_ids: Corresponding logical block IDs
+            - num_chunks: Number of chunks needed
         """
-        k_chunks = []
-        v_chunks = []
+        cpu_block_ids = []
+        cpu_logical_ids = []
 
         for logical_id in seq.block_table:
             block = self.logical_blocks[logical_id]
+            if block.location == BlockLocation.CPU:
+                cpu_block_ids.append(block.cpu_block_id)
+                cpu_logical_ids.append(logical_id)
+
+        # Each chunk uses available GPU slots minus 1 (reserved for write block)
+        usable_slots = self.num_gpu_slots - 1
+        num_chunks = (len(cpu_block_ids) + usable_slots - 1) // usable_slots if usable_slots > 0 else 0
+
+        return cpu_block_ids, cpu_logical_ids, num_chunks
+
+    def load_decode_chunk(
+        self,
+        seq: Sequence,
+        cpu_block_ids: List[int],
+        cpu_logical_ids: List[int],
+        chunk_idx: int,
+    ) -> List[int]:
+        """
+        Load one chunk of CPU blocks to GPU for chunked decode.
+
+        Similar to chunked prefill: uses GPU slots to hold a batch of blocks.
+
+        Args:
+            seq: Sequence being decoded
+            cpu_block_ids: All CPU block IDs for this sequence
+            cpu_logical_ids: Corresponding logical block IDs
+            chunk_idx: Which chunk to load (0-indexed)
+
+        Returns:
+            List of GPU slot IDs where the chunk was loaded
+        """
+        chunk_size = self.num_gpu_slots - 1  # match get_decode_chunk_info: the last slot is reserved for the write block
+        start = chunk_idx * chunk_size
+        end = min(start + chunk_size, len(cpu_block_ids))
+
+        chunk_cpu_ids = cpu_block_ids[start:end]
+        chunk_logical_ids = cpu_logical_ids[start:end]
+
+        # Use GPU slots 0, 1, 2, ... for this chunk
+        gpu_slots = list(range(len(chunk_cpu_ids)))
+
+        # Load all layers at once using offload_engine
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            chunk_cpu_ids, gpu_slots
+        )
+
+        return gpu_slots
+
+    def get_gpu_blocks_for_decode(self, seq: Sequence) -> Tuple[List[int], List[int]]:
+        """
+        Get blocks currently on GPU for this sequence.
+
+        Returns:
+            (gpu_slots, logical_ids) - GPU slot IDs and corresponding logical block IDs
+        """
+        gpu_slots = []
+        logical_ids = []
+
+        for logical_id in seq.block_table:
+            block = self.logical_blocks[logical_id]
             if block.location == BlockLocation.GPU:
-                # Get from GPU cache
-                k, v = self.offload_engine.get_layer_cache(layer_id)
-                # k, v shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
-                k_block = k[block.gpu_slot]  # [block_size, kv_heads, head_dim]
-                v_block = v[block.gpu_slot]
-                k_chunks.append(k_block)
-                v_chunks.append(v_block)
+                gpu_slots.append(block.gpu_slot)
+                logical_ids.append(logical_id)
 
-            elif block.location == BlockLocation.CPU:
-                # Get from CPU cache
-                k_block, v_block = self.offload_engine.get_cpu_block(layer_id, block.cpu_block_id)
-                # Already [block_size, kv_heads, head_dim]
-                k_chunks.append(k_block.to("cuda", non_blocking=True))
-                v_chunks.append(v_block.to("cuda", non_blocking=True))
+        return gpu_slots, logical_ids
 
-        # Concatenate all chunks
-        k_all = torch.cat(k_chunks, dim=0)  # [total_tokens, kv_heads, head_dim]
-        v_all = torch.cat(v_chunks, dim=0)
+    def get_kv_for_gpu_slots(
+        self,
+        layer_id: int,
+        gpu_slots: List[int],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get KV tensors for specific GPU slots.
 
-        # Add batch dimension
-        k_all = k_all.unsqueeze(0)  # [1, total_tokens, kv_heads, head_dim]
-        v_all = v_all.unsqueeze(0)
+        Args:
+            layer_id: Layer index
+            gpu_slots: List of GPU slot IDs
 
-        return k_all, v_all
+        Returns:
+            (k, v) tensors with shape [1, num_tokens, kv_heads, head_dim]
+        """
+        k_cache, v_cache = self.offload_engine.get_layer_cache(layer_id)
+        # k_cache, v_cache shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
+
+        k_chunks = [k_cache[slot] for slot in gpu_slots]
+        v_chunks = [v_cache[slot] for slot in gpu_slots]
+
+        # Concatenate and add batch dimension
+        k = torch.cat(k_chunks, dim=0).unsqueeze(0)  # [1, tokens, heads, dim]
+        v = torch.cat(v_chunks, dim=0).unsqueeze(0)
+
+        return k, v
+
+    def ensure_last_block_on_gpu(self, seq: Sequence) -> int:
+        """
+        Ensure the last block is on GPU for writing new KV.
+
+        Uses a RESERVED slot (last slot) to avoid conflicts with chunked decode
+        which uses slots 0, 1, 2, ... for loading CPU blocks.
+
+        Returns:
+            GPU slot ID for the last block
+        """
+        last_logical_id = seq.block_table[-1]
+        block = self.logical_blocks[last_logical_id]
+
+        if block.location == BlockLocation.GPU:
+            return block.gpu_slot
+
+        # Use last slot as reserved slot for write block
+        # This avoids conflicts with chunked decode which uses slots 0, 1, 2...
+        reserved_slot = self.num_gpu_slots - 1
+
+        # Load this block to GPU for all layers
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            [block.cpu_block_id], [reserved_slot]
+        )
+
+        # Update block state
+        self.free_cpu_blocks.append(block.cpu_block_id)
+        del self.cpu_block_to_logical[block.cpu_block_id]
+
+        self.gpu_slot_to_logical[reserved_slot] = last_logical_id
+        block.location = BlockLocation.GPU
+        block.gpu_slot = reserved_slot
+        block.cpu_block_id = -1
+
+        return reserved_slot
 
     def get_gpu_block_tables(
         self,

diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py
index 188b74f..8f8c2d5 100644
--- a/nanovllm/kvcache/offload_engine.py
+++ b/nanovllm/kvcache/offload_engine.py
@@ -308,6 +308,112 @@ class OffloadEngine:
             events.append(event)
         return events
 
+    # ========== Chunked Decode: Load CPU blocks to GPU slots ==========
+
+    def load_cpu_blocks_to_gpu_slots(
+        self,
+        layer_id: int,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> None:
+        """
+        Load CPU blocks to specific GPU slots for chunked decode.
+
+        Uses the main GPU KV cache slots, not a separate temp buffer.
+        This is the same mechanism as chunked prefill uses.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+
+        stream = self._get_next_stream()
+
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                # Copy from pinned CPU memory to GPU KV cache slot
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+
+        # Wait for transfer to complete
+        stream.synchronize()
+
+    def load_cpu_blocks_to_gpu_slots_async(
+        self,
+        layer_id: int,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> torch.cuda.Event:
+        """
+        Async version: Load CPU blocks to GPU slots.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into
+
+        Returns:
+            CUDA event to wait on
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+
+        stream = self._get_next_stream()
+        event = torch.cuda.Event()
+
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+            event.record()
+
+        return event
+
+    def load_cpu_blocks_to_gpu_slots_all_layers(
+        self,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> None:
+        """
+        Load CPU blocks to GPU slots for ALL layers at once.
+
+        More efficient than per-layer loading when we know the mapping upfront.
+
+        Args:
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+
+        stream = self._get_next_stream()
+
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                # Copy all layers at once
+                self.k_cache_gpu[:, gpu_slot].copy_(
+                    self.k_cache_cpu[:, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[:, gpu_slot].copy_(
+                    self.v_cache_cpu[:, cpu_block_id],
+                    non_blocking=True
+                )
+
+        stream.synchronize()
+
     # ========== Synchronization methods ==========
 
     def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:

diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 39bdc38..1816181 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -174,45 +174,76 @@ class Attention(nn.Module):
         """
         Compute decode attention with KV spread across CPU and GPU.
 
-        For decode with chunked KV:
-        1. Load all KV for this layer from CPU+GPU
-        2. Compute attention (1 query token vs all KV)
-        3. Return output
+        Uses chunked attention similar to chunked prefill:
+        1. Process blocks on GPU first (if any)
+        2. Load CPU blocks in chunks to GPU slots (per-layer)
+        3. Compute attention for each chunk, merge with online softmax
         """
-        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
         # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
-        # We need to attend to ALL previous tokens
+        # Need: [batch, seqlen, heads, dim]
+        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
 
-        # Load all KV for this layer
-        if context.offload_engine is not None and self.layer_id >= 0:
-            kvcache_manager = context.offload_engine
+        kvcache_manager = context.offload_engine
+        seq = context.chunked_seq
 
-            if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
-                # Load all KV from both GPU and CPU for this layer
-                k_all, v_all = kvcache_manager.load_all_kv_for_layer(
-                    context.chunked_seq,
+        o_acc = None
+        lse_acc = None
+
+        # Step 1: Process blocks already on GPU (if any)
+        gpu_slots, _ = kvcache_manager.get_gpu_blocks_for_decode(seq)
+        if gpu_slots:
+            k_gpu, v_gpu = kvcache_manager.get_kv_for_gpu_slots(self.layer_id, gpu_slots)
+            o_gpu, lse_gpu = flash_attn_with_lse(
+                q_batched, k_gpu, v_gpu,
+                softmax_scale=self.scale,
+                causal=False,
+            )
+            o_acc, lse_acc = o_gpu, lse_gpu
+
+        # Step 2: Process CPU blocks in chunks
+        # Get chunk info from kvcache_manager
+        cpu_block_ids, cpu_logical_ids, num_chunks = kvcache_manager.get_decode_chunk_info(seq)
+
+        if num_chunks > 0:
+            # Use num_gpu_slots - 1 to avoid the reserved slot (used for write block)
+            chunk_size = kvcache_manager.num_gpu_slots - 1
+
+            for chunk_idx in range(num_chunks):
+                start = chunk_idx * chunk_size
+                end = min(start + chunk_size, len(cpu_block_ids))
+                chunk_cpu_ids = cpu_block_ids[start:end]
+
+                # Load this chunk to GPU slots 0, 1, 2, ... for THIS LAYER
+                # (slot num_gpu_slots-1 is reserved for write block)
+                gpu_slots_for_chunk = list(range(len(chunk_cpu_ids)))
+                kvcache_manager.offload_engine.load_cpu_blocks_to_gpu_slots(
                     self.layer_id,
+                    chunk_cpu_ids,
+                    gpu_slots_for_chunk,
                 )
 
-                if k_all is not None and v_all is not None:
-                    # q shape: [batch_size, num_heads, head_dim]
-                    # Need: [batch, seqlen, heads, dim]
-                    # Insert seqlen dimension at position 1
-                    q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]
+                # Get KV for this chunk
+                k_chunk, v_chunk = kvcache_manager.get_kv_for_gpu_slots(
+                    self.layer_id, gpu_slots_for_chunk
+                )
 
-                    # k_all, v_all shape: [1, total_kv_tokens, kv_heads, head_dim]
-                    # Compute attention (no causal mask for decode - we want all KV)
-                    out, _ = flash_attn_with_lse(
-                        q_batched,
-                        k_all,
-                        v_all,
-                        softmax_scale=self.scale,
-                        causal=False,  # No causal mask for decode
-                    )
+                # Compute attention for this chunk
+                o_chunk, lse_chunk = flash_attn_with_lse(
+                    q_batched, k_chunk, v_chunk,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )
 
-                    # Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
-                    return out.squeeze(1)
+                # Merge with accumulated
+                if o_acc is None:
+                    o_acc, lse_acc = o_chunk, lse_chunk
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
 
-        # Fallback: shouldn't reach here
-        raise RuntimeError("Chunked decode attention failed: no KV available")
+        if o_acc is None:
+            raise RuntimeError("Chunked decode attention failed: no KV available")
+
+        # Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
+        return o_acc.squeeze(1)
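
Notes (not part of the diff):

The slot budgeting shared by get_decode_chunk_info and the per-layer loop in attention.py reserves the last GPU slot for the block that receives the new token's KV and spreads the CPU-resident blocks over the remaining slots. The numbers below are illustrative, not taken from the patch:

    # Illustrative numbers: 8 GPU slots, 20 sequence blocks currently resident on CPU.
    num_gpu_slots = 8
    num_cpu_blocks = 20

    usable_slots = num_gpu_slots - 1                                  # slot 7 stays reserved for the write block
    num_chunks = (num_cpu_blocks + usable_slots - 1) // usable_slots  # ceil(20 / 7) = 3

    for chunk_idx in range(num_chunks):
        start = chunk_idx * usable_slots
        end = min(start + usable_slots, num_cpu_blocks)
        # chunk 0 -> CPU blocks 0..6, chunk 1 -> 7..13, chunk 2 -> 14..19,
        # each loaded into GPU slots 0..6 before attention is computed on it
        print(f"chunk {chunk_idx}: CPU blocks {list(range(start, end))}")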
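flash_attn_with_lse and merge_attention_outputs come from nanovllm/kvcache/chunked_attention, which this patch does not touch, so their exact signatures are not shown here. The sketch below only illustrates the standard log-sum-exp merge that the per-chunk path relies on, assuming the partial outputs have shape [batch, seqlen_q, heads, head_dim] and the LSE tensors have shape [batch, heads, seqlen_q] (the layout FlashAttention returns); treat it as an illustration, not the project's implementation:

    import torch

    def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
        # Combine attention computed over two disjoint KV chunks into the result
        # full attention over their union would produce.
        lse = torch.logaddexp(lse1, lse2)                         # combined log-normalizer, [b, h, s]
        w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # per-query weight, broadcast to [b, s, h, 1]
        w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
        return w1 * o1 + w2 * o2, lse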