diff --git a/bench_offload.py b/bench_offload.py
index 1fe90fc..ed1b953 100644
--- a/bench_offload.py
+++ b/bench_offload.py
@@ -38,8 +38,8 @@ def main():
     llm = LLM(
         path,
         enforce_eager=False,
-        max_model_len=128 * 1024,
-        max_num_batched_tokens=128 * 1024,
+        max_model_len=256 * 1024,
+        max_num_batched_tokens=256 * 1024,
         enable_cpu_offload=True,
         num_gpu_blocks=120,
         num_prefetch_blocks=4,
@@ -54,12 +54,12 @@ def main():
     # bench_prefill(llm, num_seqs=1, input_len=1024)
     # bench_prefill(llm, num_seqs=1, input_len=2048)
     # bench_prefill(llm, num_seqs=1, input_len=4096)
-    bench_prefill(llm, num_seqs=1, input_len=16 * 1024)
+    bench_prefill(llm, num_seqs=1, input_len=128 * 1024)
 
     print("=" * 60)
     print("Decode Benchmark (CPU Offload)")
     print("=" * 60)
-    bench_decode(llm, num_seqs=1, input_len=16 * 1024, max_output_len=128)
+    bench_decode(llm, num_seqs=1, input_len=128 * 1024, max_output_len=128)
 
     # bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)
diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py
index 8990222..a7f8ba7 100644
--- a/nanovllm/kvcache/offload_engine.py
+++ b/nanovllm/kvcache/offload_engine.py
@@ -152,6 +152,14 @@ class OffloadEngine:
         self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
         self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
 
+        # ========== Per-slot Per-layer compute_done events for async pipeline ==========
+        # ring_slot_compute_done[slot_idx][layer_id] = CUDA Event for compute completion
+        # This is used to ensure we don't overwrite data before it's been read by attention
+        self.ring_slot_compute_done = [
+            [torch.cuda.Event() for _ in range(num_layers)]
+            for _ in range(self.num_ring_slots)
+        ]
+
         # ========== Event tracking for async transfers ==========
         self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
 
@@ -622,11 +630,26 @@ class OffloadEngine:
 
     # ----- Per-slot Per-layer loading methods -----
 
+    def record_slot_compute_done(self, slot_idx: int, layer_id: int) -> None:
+        """
+        Record that computation using this slot's data is done.
+
+        This event is used by load_to_slot_layer to ensure we don't overwrite
+        data before it's been read by attention computation.
+
+        Args:
+            slot_idx: GPU slot index that was just used for computation
+            layer_id: Layer index
+        """
+        self.ring_slot_compute_done[slot_idx][layer_id].record()
+
     def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
         """
         Async load a single CPU block to a ring buffer slot for one layer.
 
         This is the core building block for ring buffer pipelining.
+        Before starting the transfer, waits for any previous compute on this slot
+        to complete (using compute_done event).
 
         Args:
             slot_idx: Target GPU slot index
@@ -636,6 +659,10 @@ class OffloadEngine:
         logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
 
         with torch.cuda.stream(self.transfer_stream_main):
+            # Wait for previous compute on this slot to complete before overwriting
+            # This prevents data race: transfer must not start until attention finishes reading
+            self.transfer_stream_main.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
+
             self.k_cache_gpu[layer_id, slot_idx].copy_(
                 self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
             )
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index dfde652..b99b679 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -209,13 +209,26 @@ class Attention(nn.Module):
         offload_engine,
     ):
         """
-        Ring buffer synchronous loading for previous chunks.
+        Ring buffer async pipeline loading with double buffering.
 
-        For correctness, we use synchronous loading:
-        - Load one block at a time
-        - Wait for transfer, compute attention, then load next
+        Uses compute_done events to ensure safe buffer reuse:
+        - Before loading to slot X, wait for previous compute on slot X to finish
+        - Before computing on slot X, wait for load to slot X to finish
 
-        This ensures no data races between transfer and computation.
+        Timeline with 2 slots (A, B):
+            ┌──────────────┐
+            │  Load B0→A   │
+            └──────────────┘
+                   ┌──────────────┐      ┌──────────────┐
+                   │  Load B1→B   │      │  Load B2→A   │  ...
+                   └──────────────┘      └──────────────┘
+                          ↘                      ↘
+                   ┌──────────────┐      ┌──────────────┐
+                   │  Compute(A)  │      │  Compute(B)  │  ...
+                   └──────────────┘      └──────────────┘
+
+        The load_to_slot_layer internally waits for compute_done[slot] before
+        starting the transfer, ensuring no data race.
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -224,29 +237,62 @@ class Attention(nn.Module):
             return None, None
 
         pipeline_depth = len(load_slots)
+        if pipeline_depth == 0:
+            return None, None
+
         o_acc, lse_acc = None, None
 
-        # Process blocks one by one (synchronous)
+        if pipeline_depth == 1:
+            # Only 1 slot available, cannot pipeline - use synchronous mode
+            slot = load_slots[0]
+            for block_idx in range(num_blocks):
+                offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
+                offload_engine.wait_slot_layer(slot, self.layer_id)
+                prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
+                prev_o, prev_lse = flash_attn_with_lse(
+                    q_batched, prev_k, prev_v,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )
+                # Record compute done so next load can safely reuse this slot
+                offload_engine.record_slot_compute_done(slot, self.layer_id)
+                if o_acc is None:
+                    o_acc, lse_acc = prev_o, prev_lse
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+            return o_acc, lse_acc
+
+        # Double buffering with 2 slots
+        slot_A = load_slots[0]
+        slot_B = load_slots[1]
+
+        # Pre-load first block to slot_A (async)
+        offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0])
+
         for block_idx in range(num_blocks):
-            # Determine which slot to use (cycle through load_slots)
-            slot_idx = load_slots[block_idx % pipeline_depth]
-            cpu_block_id = cpu_block_table[block_idx]
+            # Alternate between slot_A and slot_B
+            current_slot = slot_A if block_idx % 2 == 0 else slot_B
+            next_slot = slot_B if block_idx % 2 == 0 else slot_A
 
-            # Load block to slot (async)
-            offload_engine.load_to_slot_layer(slot_idx, self.layer_id, cpu_block_id)
+            # Wait for current slot's transfer to complete
+            offload_engine.wait_slot_layer(current_slot, self.layer_id)
 
-            # Wait for transfer to complete
-            offload_engine.wait_slot_layer(slot_idx, self.layer_id)
-
-            # Get KV from slot and compute attention
-            prev_k, prev_v = offload_engine.get_kv_for_slot(slot_idx, self.layer_id)
+            # Start async load of next block to the OTHER slot
+            # load_to_slot_layer internally waits for next_slot's compute_done
+            if block_idx + 1 < num_blocks:
+                offload_engine.load_to_slot_layer(next_slot, self.layer_id, cpu_block_table[block_idx + 1])
 
+            # Compute attention on current slot's data
+            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
             prev_o, prev_lse = flash_attn_with_lse(
                 q_batched, prev_k, prev_v,
                 softmax_scale=self.scale,
                 causal=False,
             )
+            # Record compute done - this allows the next round to safely load into this slot
+            offload_engine.record_slot_compute_done(current_slot, self.layer_id)
+
             # Merge with accumulated
             if o_acc is None:
                 o_acc, lse_acc = prev_o, prev_lse
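
The core of this change is a pair of CUDA events per ring slot: a transfer-ready event (recorded on the copy stream, waited on by the compute stream before attention reads the slot) and a compute_done event (recorded on the compute stream, waited on by the copy stream before the slot is overwritten). The sketch below is a minimal, self-contained illustration of that double-buffering pattern, not the nanovllm API: the names (`pipelined_sum`, `transfer_stream`, `ready`, `compute_done`) and the toy "compute" (a running sum standing in for chunked attention) are illustrative assumptions.

```python
# Minimal sketch of event-synchronized double buffering (illustrative, not nanovllm code).
import torch

def pipelined_sum(cpu_blocks: list[torch.Tensor]) -> torch.Tensor:
    """Stream pinned CPU blocks through 2 GPU slots and sum them on the GPU."""
    assert torch.cuda.is_available()
    device = torch.device("cuda")
    transfer_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()

    num_slots = 2
    slots = [torch.empty_like(cpu_blocks[0], device=device) for _ in range(num_slots)]
    ready = [torch.cuda.Event() for _ in range(num_slots)]         # transfer finished
    compute_done = [torch.cuda.Event() for _ in range(num_slots)]  # slot safe to overwrite

    def load(slot: int, block: torch.Tensor) -> None:
        with torch.cuda.stream(transfer_stream):
            # Do not overwrite the slot until the previous compute has read it.
            transfer_stream.wait_event(compute_done[slot])
            slots[slot].copy_(block, non_blocking=True)
            ready[slot].record()

    acc = torch.zeros_like(cpu_blocks[0], device=device)
    load(0, cpu_blocks[0])                        # pre-load first block into slot 0
    for i in range(len(cpu_blocks)):
        cur, nxt = i % 2, (i + 1) % 2
        compute_stream.wait_event(ready[cur])     # wait for current slot's transfer
        if i + 1 < len(cpu_blocks):
            load(nxt, cpu_blocks[i + 1])          # overlap: prefetch next block into other slot
        acc += slots[cur]                         # "compute" on current slot's data
        compute_done[cur].record(compute_stream)  # slot may now be reused by a later load
    return acc

if __name__ == "__main__":
    blocks = [torch.full((1024, 1024), float(i)).pin_memory() for i in range(8)]
    out = pipelined_sum(blocks)
    torch.cuda.synchronize()
    print(out.mean().item())  # sums blocks 0..7, so every element (and the mean) is 28.0
```

With a single staging slot the same two events degenerate to a fully synchronous load-then-compute loop, which is what the `pipeline_depth == 1` branch in the diff falls back to.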