diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index b2ad229..e860daf 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -374,7 +374,9 @@ class OffloadEngine: """ self.ring_slot_compute_done[slot_idx].record() - def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None: + def load_to_slot_layer( + self, slot_idx: int, layer_id: int, cpu_block_id: int, chunk_idx: int = -1 + ) -> None: """ Async load a single CPU block to a ring buffer slot for one layer. @@ -389,13 +391,19 @@ class OffloadEngine: slot_idx: Target GPU slot index layer_id: Layer index to load (for CPU cache indexing) cpu_block_id: Source CPU block ID + chunk_idx: Optional chunk index for NVTX labeling (-1 means not specified) """ logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]") # Use per-slot stream for parallel transfers across different slots stream = self.slot_transfer_streams[slot_idx] - torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]") + # Build NVTX label with optional chunk info + if chunk_idx >= 0: + nvtx_label = f"H2D: L{layer_id} Chunk{chunk_idx} CPU[{cpu_block_id}]->Slot[{slot_idx}]" + else: + nvtx_label = f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]" + torch.cuda.nvtx.range_push(nvtx_label) with torch.cuda.stream(stream): # Wait for previous compute on this slot to complete before overwriting # This prevents data race: transfer must not start until attention finishes reading @@ -702,6 +710,61 @@ class OffloadEngine: v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0) return k, v + def write_to_prefill_buffer( + self, + layer_id: int, + k: Tensor, + v: Tensor, + chunk_idx: int = -1, + ) -> None: + """ + Write KV tensors to prefill buffer (D2D copy within GPU). + + This is called during chunked prefill to store current chunk's KV + before computing attention. + + Args: + layer_id: Layer index + k: Key tensor [num_tokens, kv_heads, head_dim] + v: Value tensor [num_tokens, kv_heads, head_dim] + chunk_idx: Current chunk index for NVTX labeling (-1 = not specified) + """ + num_tokens = k.shape[0] + + # Build NVTX label + if chunk_idx >= 0: + nvtx_label = f"D2D: L{layer_id} Chunk{chunk_idx} WritePrefillBuffer" + else: + nvtx_label = f"D2D: L{layer_id} WritePrefillBuffer" + + torch.cuda.nvtx.range_push(nvtx_label) + self.prefill_k_buffer[layer_id, :num_tokens].copy_(k) + self.prefill_v_buffer[layer_id, :num_tokens].copy_(v) + torch.cuda.nvtx.range_pop() + + def write_to_decode_buffer( + self, + layer_id: int, + pos_in_block: int, + k: Tensor, + v: Tensor, + ) -> None: + """ + Write KV tensors to decode buffer (D2D copy within GPU). + + This is called during chunked decode to store current decode token's KV. + + Args: + layer_id: Layer index + pos_in_block: Position within the current block + k: Key tensor [kv_heads, head_dim] (single token, squeezed) + v: Value tensor [kv_heads, head_dim] (single token, squeezed) + """ + torch.cuda.nvtx.range_push(f"D2D: L{layer_id} Pos{pos_in_block} WriteDecodeBuffer") + self.decode_k_buffer[layer_id, pos_in_block].copy_(k) + self.decode_v_buffer[layer_id, pos_in_block].copy_(v) + torch.cuda.nvtx.range_pop() + def offload_prefill_buffer_async( self, layer_id: int, diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py index 9cfd061..1001e3b 100644 --- a/nanovllm/kvcache/sparse/full_policy.py +++ b/nanovllm/kvcache/sparse/full_policy.py @@ -139,7 +139,8 @@ class FullAttentionPolicy(SparsePolicy): slot = load_slots[0] for block_idx in range(num_blocks): cpu_block_id = cpu_block_table[block_idx] - offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id) + # cpu_block_id is the chunk index (block N = chunk N) + offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.wait_slot_layer(slot) with torch.cuda.stream(compute_stream): @@ -159,7 +160,8 @@ class FullAttentionPolicy(SparsePolicy): num_slots = len(load_slots) num_preload = min(num_slots, num_blocks) for i in range(num_preload): - offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i]) + cpu_block_id = cpu_block_table[i] + offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id) for block_idx in range(num_blocks): current_slot = load_slots[block_idx % num_slots] @@ -186,7 +188,7 @@ class FullAttentionPolicy(SparsePolicy): if next_block_idx < num_blocks: next_slot = load_slots[next_block_idx % num_slots] next_cpu_block_id = cpu_block_table[next_block_idx] - offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id) + offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id) # Step 4: Compute attention to current chunk (causal mask) with torch.cuda.stream(compute_stream): @@ -350,7 +352,8 @@ class FullAttentionPolicy(SparsePolicy): # Phase 1: Pre-load up to num_slots blocks num_preload = min(num_slots, num_blocks) for i in range(num_preload): - offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i]) + cpu_block_id = cpu_block_table[i] + offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id) # Phase 2: Process blocks with pipeline for block_idx in range(num_blocks): @@ -383,7 +386,8 @@ class FullAttentionPolicy(SparsePolicy): # Start loading next block (pipeline) next_block_idx = block_idx + num_slots if next_block_idx < num_blocks: - offload_engine.load_to_slot_layer(current_slot, layer_id, cpu_block_table[next_block_idx]) + next_cpu_block_id = cpu_block_table[next_block_idx] + offload_engine.load_to_slot_layer(current_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id) # Merge with accumulated with torch.cuda.stream(compute_stream): diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index c8aebf7..70dcd1d 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -189,8 +189,8 @@ class XAttentionBSAPolicy(SparsePolicy): reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128 for cpu_block_id in available_blocks: - # Load K block from CPU to GPU - offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id) + # Load K block from CPU to GPU (cpu_block_id is chunk index) + offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.wait_slot_layer(slot) # Get KV: [1, block_size, num_kv_heads, head_dim] @@ -382,7 +382,7 @@ class XAttentionBSAPolicy(SparsePolicy): slot = load_slots[0] for block_idx in range(num_blocks): cpu_block_id = cpu_block_table[block_idx] - offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id) + offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.wait_slot_layer(slot) with torch.cuda.stream(compute_stream): @@ -402,7 +402,8 @@ class XAttentionBSAPolicy(SparsePolicy): num_slots = len(load_slots) num_preload = min(num_slots, num_blocks) for i in range(num_preload): - offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i]) + cpu_block_id = cpu_block_table[i] + offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id) for block_idx in range(num_blocks): current_slot = load_slots[block_idx % num_slots] @@ -428,7 +429,7 @@ class XAttentionBSAPolicy(SparsePolicy): if next_block_idx < num_blocks: next_slot = load_slots[next_block_idx % num_slots] next_cpu_block_id = cpu_block_table[next_block_idx] - offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id) + offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id) # Compute attention to current chunk (causal mask) with torch.cuda.stream(compute_stream): diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 515bd10..a6422aa 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -104,27 +104,21 @@ class Attention(nn.Module): # This enables fully async offloads since each layer has its own buffer. offload_engine = context.kvcache_manager.offload_engine compute_stream = offload_engine.compute_stream + chunk_idx = context.current_chunk_idx if hasattr(context, 'current_chunk_idx') else -1 # Wait for default stream to ensure slot_mapping tensor transfer is complete compute_stream.wait_stream(torch.cuda.default_stream()) with torch.cuda.stream(compute_stream): - # Write KV to per-layer prefill buffer (contiguous write, no slot_mapping) + # Write KV to per-layer prefill buffer via offload_engine # k, v shape: [num_tokens, kv_heads, head_dim] - num_tokens = k.shape[0] - offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k) - offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v) + #! GPU 2 GPU + offload_engine.write_to_prefill_buffer(self.layer_id, k, v, chunk_idx=chunk_idx) elif is_chunked_offload: - # Chunked decode mode: use compute_stream for store_kvcache - # This ensures proper synchronization with per-layer offload - compute_stream = context.kvcache_manager.offload_engine.compute_stream - if k_cache.numel() and v_cache.numel(): - # CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete - # slot_mapping is created with non_blocking=True on default stream, but we use it - # on compute_stream. Without this sync, index_copy_ can get corrupted indices. - compute_stream.wait_stream(torch.cuda.default_stream()) - with torch.cuda.stream(compute_stream): - store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) + # Chunked decode mode: write KV to per-layer decode buffer via offload_engine + # KV will be written to decode buffer in the decode branch below + # No store_kvcache needed - all KV management goes through offload_engine + pass else: # Normal mode: store on default stream if k_cache.numel() and v_cache.numel(): @@ -155,8 +149,7 @@ class Attention(nn.Module): offload_engine = kvcache_manager.offload_engine pos_in_block = context.decode_pos_in_block # k, v shape: [1, kv_heads, head_dim] - offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0)) - offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0)) + offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0)) o = self._chunked_decode_attention(q, k, v, context) else: o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,