From b8b6478506dfc8fbecf7ea5bb7573c16b146b1d0 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Mon, 15 Dec 2025 06:58:40 +0800 Subject: [PATCH] [feat] Need to optimized with async prefetch. --- CLAUDE.md | 99 ++++-- bench_offload.py | 8 +- nanovllm/config.py | 2 +- nanovllm/engine/model_runner.py | 104 +++--- nanovllm/engine/sequence.py | 2 +- nanovllm/kvcache/hybrid_manager.py | 26 +- nanovllm/kvcache/offload_engine.py | 549 ++++++++++++++++------------- nanovllm/layers/attention.py | 166 ++++++--- nanovllm/utils/context.py | 4 + 9 files changed, 556 insertions(+), 404 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 84d8d22..6a117ca 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -44,74 +44,101 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L When `enable_cpu_offload=True`, KV cache is stored on CPU with a small GPU buffer for computation. This enables long-context inference with limited GPU memory. -### Three-Region GPU Buffer Design +### Unified Ring Buffer Design ``` -GPU Slots: [0] [1, 2, 3] [4, 5] - ↑ ↑ ↑ - decode compute prefetch - (1 slot) (N slots) (M slots) +GPU Slots: [0] [1] [2] [3] [4] ... + ←────────────────────────────→ + All slots as ring buffer -- Decode slot: New token's KV written here during decode -- Compute region: Load CPU blocks for current chunk computation -- Prefetch region: Async load next chunk while computing current +Prefill: ALL slots cycle as ring buffer [slot = chunk_idx % N] +Decode: slot[0] = decode_slot, slots[1:] = load slots for previous chunks ``` **File**: `nanovllm/kvcache/offload_engine.py` Key attributes: +- `num_ring_slots`: Total GPU slots (= num_gpu_blocks) +- `ring_slots`: List of all GPU slot indices [0, 1, 2, ...] - `decode_slot = 0`: Fixed slot for decode KV writes -- `compute_slots`: List of GPU slots for compute region -- `prefetch_slots`: List of GPU slots for prefetch region +- `decode_load_slots`: Slots[1:] for loading previous chunks during decode - `k_cache_gpu/v_cache_gpu`: Shape `[num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]` - `k_cache_cpu/v_cache_cpu`: Shape `[num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]` (pinned memory) -### Per-Layer Loading (Critical Design) - -**Problem solved**: Original design had layer 0 load ALL layers' KV at once. When layer 0 processed chunk 1, it overwrote chunk 0's data before layer 1+ could read it. 
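The new design keeps this guarantee but generalizes it: every ring slot carries its own per-layer CUDA events, so the compute stream waits only on the exact (slot, layer) transfer it needs. Below is a minimal sketch of that synchronization pattern, with placeholder sizes and only the K cache shown; the actual implementation is `OffloadEngine.load_to_slot_layer` / `wait_slot_layer` in `offload_engine.py`:

```python
import torch

# Placeholder sizes; only the K cache is shown to keep the sketch short.
num_layers, num_slots, block_size = 2, 4, 16
transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()

k_gpu = torch.empty(num_layers, num_slots, block_size, 8, 64, device="cuda")
k_cpu = torch.empty(num_layers, 32, block_size, 8, 64, pin_memory=True)  # pinned host memory

# One event per (slot, layer): records H2D completion.
ready = [[torch.cuda.Event() for _ in range(num_layers)] for _ in range(num_slots)]

def load_to_slot_layer(slot: int, layer: int, cpu_block: int) -> None:
    # Issue the async H2D copy on the transfer stream and mark completion.
    with torch.cuda.stream(transfer_stream):
        k_gpu[layer, slot].copy_(k_cpu[layer, cpu_block], non_blocking=True)
        ready[slot][layer].record(transfer_stream)

def wait_slot_layer(slot: int, layer: int) -> None:
    # The compute stream blocks only on this one (slot, layer) transfer.
    compute_stream.wait_event(ready[slot][layer])
```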
- -**Solution**: Each layer independently loads only its own KV data: +Key methods: ```python -# Per-layer methods in OffloadEngine -load_to_compute_layer(layer_id, cpu_block_ids) # Load single layer to compute region -wait_compute_layer(layer_id) # Wait for layer's transfer -load_to_prefetch_layer(layer_id, cpu_block_ids) # Load single layer to prefetch region -wait_prefetch_layer(layer_id) # Wait for layer's prefetch +# Prefill: get write slot and load slots +get_write_slot_for_prefill(chunk_idx) # Returns chunk_idx % num_ring_slots +get_load_slots_for_prefill(write_slot_idx) # Returns all slots except write_slot + +# Decode: get load slots (excludes decode_slot) +get_load_slots_for_decode() # Returns slots[1:] + +# Per-slot per-layer operations +load_to_slot_layer(slot_idx, layer_id, cpu_block_id) # Async load single block +wait_slot_layer(slot_idx, layer_id) # Wait for layer's transfer +offload_slot_to_cpu(slot_idx, cpu_block_id) # Async offload to CPU ``` -### Chunked Prefill Flow +### Per-Slot Per-Layer Events (Critical Design) + +Each slot has per-layer CUDA events for fine-grained synchronization: +- `ring_slot_ready[slot_idx][layer_id]`: H2D transfer completion +- `ring_slot_offload_done[slot_idx][layer_id]`: D2H transfer completion + +This enables: +1. Overlapped H2D transfer with attention computation +2. Each layer independently waits for its own data +3. Pipeline depth = N-1 for prefill (N slots, 1 for writing) + +### Chunked Prefill Flow (Ring Buffer Pipeline) **File**: `nanovllm/layers/attention.py` - `_chunked_prefill_attention()` ``` -For each prefill chunk: -1. Current chunk's KV is written to GPU (compute region slots) -2. Load previous chunks' KV from CPU to prefetch region +For prefill chunk K: +1. Current chunk's KV written to ring_slot[K % N] +2. Load previous chunks from CPU using N-1 available slots (pipeline) 3. Compute attention against previous KV (no causal mask) 4. Compute attention against current KV (causal mask) 5. Merge results using online softmax (LSE) -6. Offload current chunk's KV to CPU +6. Offload current slot to CPU + +Pipeline Timeline (with 4 slots, processing chunk 3): +write_slot = 3, load_slots = [0, 1, 2] + +┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ +│Load B0→S0 │ │Load B1→S1 │ │Load B2→S2 │ │ (wait) │ +└─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ + ↘ ↘ ↘ + ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ + │ Attn(B0) │ │ Attn(B1) │ │ Attn(B2) │ + └─────────────┘ └─────────────┘ └─────────────┘ ``` -**Important**: Prefill uses ONLY prefetch region to avoid conflict with current chunk's KV being written to compute region. +**Key**: Write slot cycles through ALL slots, load slots = all except write slot. ### Chunked Decode Flow (Double Buffering) **File**: `nanovllm/layers/attention.py` - `_chunked_decode_attention()` +Decode uses legacy double-buffering with `decode_load_slots`: +- First half of decode_load_slots: 'compute' buffer +- Second half: 'prefetch' buffer + ``` -Timeline (async double buffering): +Timeline: ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ -Load: │C0 → Compute │ │C1 → Prefetch│ │C2 → Compute │ +Load: │C0 → buf0 │ │C1 → buf1 │ │C2 → buf0 │ └─────────────┘ └─────────────┘ └─────────────┘ ↘ ↘ ↘ Compute: [C0] [C1] [C2] -1. Pre-load first chunk to compute region -2. Wait for current buffer, trigger async prefetch of next chunk to OTHER buffer +1. Pre-load first chunk to compute buffer +2. Wait for current buffer, trigger async prefetch to OTHER buffer 3. 
Compute attention, merge results 4. Swap buffers, repeat -5. Finally attend to decode slot (new token's KV) +5. Finally attend to decode_slot (new token's KV) ``` ### HybridKVCacheManager @@ -120,7 +147,7 @@ Compute: [C0] [C1] [C2] Manages both GPU and CPU blocks: - `allocate()`: Allocate GPU block first, fallback to CPU -- `allocate_cpu_only()`: Force CPU allocation (for chunked offload mode) +- `allocate_cpu_only()`: Force CPU allocation (for ring buffer mode) - `get_all_cpu_blocks(seq)`: Get all CPU block IDs for a sequence - `get_prefilled_cpu_blocks(seq)`: Get CPU blocks from previous chunks - `get_write_slot_for_chunked_offload(seq)`: Get GPU slot for writing new KV (returns decode_slot) @@ -136,9 +163,7 @@ def merge_attention_outputs(o1, lse1, o2, lse2): # Uses LSE to correctly weight and combine partial attention outputs ``` -### Ring Buffer Design (Future Optimization) +### Pipeline Depth -Current double-buffering limits pipeline depth. Planned improvement: -- Unified ring buffer using all GPU slots (except decode) -- Per-slot per-layer CUDA events for fine-grained sync -- Deeper pipeline: prefetch N-1 blocks ahead (vs 1 chunk) +- **Prefill**: Pipeline depth = N-1 (where N = num_gpu_blocks) +- **Decode**: Pipeline depth = (N-1)/2 (double buffering within decode_load_slots) diff --git a/bench_offload.py b/bench_offload.py index e4a1771..1fe90fc 100644 --- a/bench_offload.py +++ b/bench_offload.py @@ -41,8 +41,8 @@ def main(): max_model_len=128 * 1024, max_num_batched_tokens=128 * 1024, enable_cpu_offload=True, - num_gpu_blocks=6, - num_prefetch_blocks=2, + num_gpu_blocks=120, + num_prefetch_blocks=4, ) # Warmup @@ -54,12 +54,12 @@ def main(): # bench_prefill(llm, num_seqs=1, input_len=1024) # bench_prefill(llm, num_seqs=1, input_len=2048) # bench_prefill(llm, num_seqs=1, input_len=4096) - bench_prefill(llm, num_seqs=1, input_len=64 * 1024) + bench_prefill(llm, num_seqs=1, input_len=16 * 1024) print("=" * 60) print("Decode Benchmark (CPU Offload)") print("=" * 60) - bench_decode(llm, num_seqs=1, input_len=64 * 1024, max_output_len=128) + bench_decode(llm, num_seqs=1, input_len=16 * 1024, max_output_len=128) # bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128) diff --git a/nanovllm/config.py b/nanovllm/config.py index 13a8e29..5b9f98d 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -14,7 +14,7 @@ class Config: enforce_eager: bool = False hf_config: AutoConfig | None = None eos: int = -1 - kvcache_block_size: int = 256 + kvcache_block_size: int = 4096 num_kvcache_blocks: int = -1 # CPU Offload configuration diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 6fbdb78..ad60f61 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -630,29 +630,31 @@ class ModelRunner: def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]: """ - Run prefill with three-region GPU buffer (CPU is primary storage). + Run prefill with unified ring buffer (CPU is primary storage). Flow: 1. All blocks are allocated to CPU (primary storage) - 2. Process tokens in chunks using Compute region GPU buffer - 3. After each chunk, offload from Compute region to CPU - 4. Prefetch region is used to load previous KV (if any) + 2. Each chunk writes KV to ring buffer slot[chunk_idx % N] + 3. After each chunk, offload from ring buffer slot to CPU + 4. 
All N-1 other slots are used to load previous chunks for attention """ import sys - assert len(seqs) == 1, "Three-region prefill only supports single sequence" + assert len(seqs) == 1, "Ring buffer prefill only supports single sequence" seq = seqs[0] offload_engine = self.kvcache_manager.offload_engine - compute_size = offload_engine.num_compute_blocks - tokens_per_chunk = compute_size * self.block_size + # Each chunk uses 1 ring buffer slot = 1 block + tokens_per_chunk = self.block_size total_tokens = len(seq) - print(f"[Three-region Prefill] Starting: {total_tokens} tokens, " - f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens", + num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk + print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, " + f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, " + f"total_chunks={num_chunks}", file=sys.stderr) - chunk_num = 0 + chunk_idx = 0 logits = None processed_tokens = 0 @@ -660,27 +662,22 @@ class ModelRunner: cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq) while processed_tokens < total_tokens: - chunk_num += 1 chunk_start = processed_tokens chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens) - chunk_tokens = chunk_end - chunk_start - # Calculate which CPU blocks this chunk covers - start_block_idx = chunk_start // self.block_size - end_block_idx = (chunk_end + self.block_size - 1) // self.block_size - num_blocks = end_block_idx - start_block_idx + # Get ring buffer slot for this chunk + write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx) - print(f"[Three-region Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, " - f"blocks {start_block_idx}-{end_block_idx-1}, " - f"compute_slots={offload_engine.compute_slots[:num_blocks]}", + # CPU block index for this chunk + block_idx = chunk_idx + + print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, " + f"write_slot={write_slot}", file=sys.stderr) - # Get GPU slots for this chunk (using Compute region) - gpu_slots = offload_engine.compute_slots[:num_blocks] - # Prepare inputs input_ids, positions = self._prepare_chunked_offload_chunk( - seq, chunk_start, chunk_end, gpu_slots, start_block_idx + seq, chunk_start, chunk_end, write_slot, block_idx, chunk_idx ) if input_ids.numel() == 0: @@ -690,24 +687,27 @@ class ModelRunner: logits = self.run_model(input_ids, positions, is_prefill=True) reset_context() - # Mark blocks as prefilled - for i in range(start_block_idx, min(end_block_idx, len(seq.block_table))): - logical_id = seq.block_table[i] + # Mark block as prefilled + if block_idx < len(seq.block_table): + logical_id = seq.block_table[block_idx] self.kvcache_manager.prefilled_blocks.add(logical_id) - # Offload this chunk from Compute region to CPU (async) - chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx] - offload_engine.offload_compute_to_cpu(chunk_cpu_blocks) + # Offload this chunk's ring buffer slot to CPU (async) + if block_idx < len(cpu_block_ids): + cpu_block_id = cpu_block_ids[block_idx] + offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id) # Wait for offload to complete before next chunk - offload_engine.wait_all_offload_done() + # (slot will be reused after N chunks) + offload_engine.wait_slot_offload(write_slot) processed_tokens = chunk_end + chunk_idx += 1 # Wait for all offloads to complete offload_engine.wait_all_offload_done() - print(f"[Three-region Prefill] Complete: {chunk_num} chunks", 
file=sys.stderr) + print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr) # Sample from last logits temperatures = self.prepare_sample(seqs) if self.rank == 0 else None @@ -723,34 +723,24 @@ class ModelRunner: seq: Sequence, chunk_start: int, chunk_end: int, - gpu_slots: list[int], - start_block_idx: int, + write_slot: int, + block_idx: int, + chunk_idx: int, ) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare inputs for a chunked offload prefill chunk.""" + """Prepare inputs for a chunked offload prefill chunk (ring buffer design).""" # Input tokens for this chunk input_ids = seq[chunk_start:chunk_end] positions = list(range(chunk_start, chunk_end)) - # Create slot mapping pointing to GPU slots + # Create slot mapping pointing to the single write_slot slot_mapping = [] - num_tokens = chunk_end - chunk_start - - token_idx = 0 - for i, gpu_slot in enumerate(gpu_slots): - block_idx = start_block_idx + i - block_start = block_idx * self.block_size - block_end = min(block_start + self.block_size, len(seq)) - - # How many tokens in this block for this chunk - overlap_start = max(chunk_start, block_start) - overlap_end = min(chunk_end, block_end) - - for pos in range(overlap_start, overlap_end): - pos_in_block = pos % self.block_size - slot = gpu_slot * self.block_size + pos_in_block - slot_mapping.append(slot) + for pos in range(chunk_start, chunk_end): + pos_in_block = pos % self.block_size + slot = write_slot * self.block_size + pos_in_block + slot_mapping.append(slot) # Convert to tensors + num_tokens = chunk_end - chunk_start input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) @@ -770,21 +760,23 @@ class ModelRunner: is_chunked_prefill=True, kvcache_manager=self.kvcache_manager, chunked_seq=seq, + current_chunk_idx=chunk_idx, # Pass chunk index for ring buffer pipeline ) return input_ids, positions def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]: """ - Run decode with three-region GPU buffer. + Run decode with ring buffer (CPU is primary storage). - All KV is on CPU. Uses Decode region to write new KV, Compute/Prefetch region to load KV chunks. - New token's KV is written to Decode region (slot 0) then offloaded to CPU only when block is full. + All KV is on CPU. Uses decode_slot (slot[0]) to write new KV. + Other slots (slots[1:]) are used to load previous KV chunks via pipeline. + New token's KV is written to decode_slot then offloaded to CPU only when block is full. - Key: Decode region is never overwritten by Compute/Prefetch, dedicated to writing new KV. + Key: decode_slot is dedicated to writing new KV, never used for loading. Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens. 
""" - assert len(seqs) == 1, "Three-region decode only supports single sequence" + assert len(seqs) == 1, "Ring buffer decode only supports single sequence" seq = seqs[0] offload_engine = self.kvcache_manager.offload_engine diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py index 49d9ee6..59a17c1 100644 --- a/nanovllm/engine/sequence.py +++ b/nanovllm/engine/sequence.py @@ -12,7 +12,7 @@ class SequenceStatus(Enum): class Sequence: - block_size = 256 + block_size = 4096 counter = count() def __init__(self, token_ids: list[int], sampling_params = SamplingParams()): diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py index baa8450..c915acd 100644 --- a/nanovllm/kvcache/hybrid_manager.py +++ b/nanovllm/kvcache/hybrid_manager.py @@ -95,16 +95,16 @@ class HybridKVCacheManager(KVCacheManager): num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage) block_size: Tokens per block policy: Eviction policy (default: LRU) - cpu_primary: If True, use CPU as primary storage with three-region GPU buffer. + cpu_primary: If True, use CPU as primary storage with ring buffer GPU design. If False, use GPU as primary with CPU as overflow (legacy mode). - num_prefetch_blocks: Number of prefetch blocks for three-region GPU buffer design + num_prefetch_blocks: Number of blocks for ring buffer pipeline (deprecated, ring_slots = num_gpu_slots) """ self._block_size = block_size self.num_gpu_slots = num_gpu_slots self.num_cpu_blocks = num_cpu_blocks self.total_blocks = num_gpu_slots + num_cpu_blocks - self.cpu_primary = cpu_primary # Three-region mode flag - self.num_prefetch_blocks = num_prefetch_blocks # Three-region design parameter + self.cpu_primary = cpu_primary # Ring buffer mode flag + self.num_prefetch_blocks = num_prefetch_blocks # Ring buffer design parameter (deprecated) # Eviction policy self.policy = policy or LRUPolicy() @@ -341,7 +341,7 @@ class HybridKVCacheManager(KVCacheManager): """ assert not seq.block_table, "Sequence already has blocks" - # Three-region mode: all blocks are allocated to CPU + # Ring buffer mode: all blocks are allocated to CPU if self.cpu_primary: return self.allocate_cpu_only(seq) @@ -471,7 +471,7 @@ class HybridKVCacheManager(KVCacheManager): block.token_ids = [] if self.cpu_primary: - # Three-region mode: new block allocated to CPU + # Ring buffer mode: new block allocated to CPU if not self.free_cpu_blocks: raise RuntimeError("No free CPU blocks for decode") cpu_block_id = self.free_cpu_blocks.popleft() @@ -1025,14 +1025,14 @@ class HybridKVCacheManager(KVCacheManager): break return pos - # ========== Three-region double buffering support ========== + # ========== Ring Buffer CPU-primary support ========== def allocate_cpu_only(self, seq: Sequence) -> None: """ - Allocate CPU blocks for sequence (for three-region mode). + Allocate CPU blocks for sequence (for ring buffer mode). Unlike allocate(), here all blocks are allocated to CPU, - GPU is only used as working buffer. + GPU is only used as ring buffer for computation. Args: seq: Sequence to allocate @@ -1092,10 +1092,10 @@ class HybridKVCacheManager(KVCacheManager): cpu_blocks.append(block.cpu_block_id) else: # If block is on GPU, it should have a corresponding CPU block - # In three-region mode, all data ultimately resides on CPU + # In ring buffer mode, all data ultimately resides on CPU raise RuntimeError( f"Block {logical_id} not on CPU (location={block.location}). " - f"In three-region mode, all blocks should be on CPU." 
+ f"In ring buffer mode, all blocks should be on CPU." ) return cpu_blocks @@ -1171,8 +1171,8 @@ class HybridKVCacheManager(KVCacheManager): """ Get GPU slot for writing new KV during chunked offload decode. - In three-region design, always use Decode region (slot 0) to write new KV. - This avoids conflicts with Compute/Prefetch region loading operations. + In ring buffer design, always use decode_slot (slot[0]) to write new KV. + This avoids conflicts with loading operations which use slots[1:]. Args: seq: Sequence diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index 616f395..8990222 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -65,34 +65,30 @@ class OffloadEngine: self.kv_dim = num_kv_heads * head_dim self.block_numel = block_size * self.kv_dim - # ========== Three-region GPU Buffer configuration ========== + # ========== Unified Ring Buffer configuration ========== # Constraint checks - assert num_gpu_blocks >= 3, \ - f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}" - assert num_prefetch_blocks >= 1, \ - f"Need at least 1 prefetch block, got {num_prefetch_blocks}" - assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \ - f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}" + assert num_gpu_blocks >= 2, \ + f"Need at least 2 GPU blocks for ring buffer, got {num_gpu_blocks}" - # Three-region configuration - # Decode region: [0] - Fixed 1 block for writing new KV + # Unified Ring Buffer: all slots cycle for prefill + # Prefill: use ALL slots as ring buffer (slot[chunk_idx % N]) + # Decode: slot[0] as decode_slot, slots[1:] for loading previous chunks + self.num_ring_slots = num_gpu_blocks + self.ring_slots = list(range(num_gpu_blocks)) + + # Decode phase uses slot[0] for writing new token's KV self.decode_slot = 0 + # Decode phase uses slots[1:] for loading previous chunks from CPU + self.decode_load_slots = list(range(1, num_gpu_blocks)) + self.num_decode_load_slots = len(self.decode_load_slots) - # Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1] - compute_start = 1 - compute_end = num_gpu_blocks - num_prefetch_blocks - self.compute_slots = list(range(compute_start, compute_end)) - self.num_compute_blocks = len(self.compute_slots) - - # Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1] - prefetch_start = compute_end - self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks)) + # Keep num_prefetch_blocks for compatibility (used as chunk size for loading) self.num_prefetch_blocks = num_prefetch_blocks - self.num_gpu_slots = num_gpu_blocks # alias - logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, " - f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}") + logger.info(f"Unified Ring Buffer: {self.num_ring_slots} slots total") + logger.info(f" Prefill: all slots as ring buffer [0..{num_gpu_blocks-1}]") + logger.info(f" Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading") # ========== Fixed-address GPU KV cache ========== # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] @@ -134,18 +130,27 @@ class OffloadEngine: self.compute_stream = torch.cuda.current_stream() self._stream_idx = 0 - # ========== Three-region dedicated stream and events ========== + # ========== Ring Buffer dedicated stream and events ========== self.transfer_stream_main = torch.cuda.Stream() # Main 
transfer stream - # Sync events - three-region loading completion - self.compute_ready = torch.cuda.Event() - self.prefetch_ready = torch.cuda.Event() + # Decode offload event self.decode_offload_done = torch.cuda.Event() - # ========== Per-layer events for chunked attention ========== - # Each layer has its own event for synchronization - self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)] - self.prefetch_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)] + # ========== Per-slot Per-layer events for ring buffer ========== + # ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion + # ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion + self.ring_slot_ready = [ + [torch.cuda.Event() for _ in range(num_layers)] + for _ in range(self.num_ring_slots) + ] + self.ring_slot_offload_done = [ + [torch.cuda.Event() for _ in range(num_layers)] + for _ in range(self.num_ring_slots) + ] + + # Per-slot events for all-layer operations (used in some legacy paths) + self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)] + self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)] # ========== Event tracking for async transfers ========== self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {} @@ -560,7 +565,7 @@ class OffloadEngine: f" kv_heads={self.num_kv_heads},\n" f" head_dim={self.head_dim},\n" f" dtype={self.dtype},\n" - f" three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n" + f" ring_buffer: {self.num_ring_slots} slots, decode_slot={self.decode_slot}, decode_load_slots={self.decode_load_slots},\n" f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n" f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n" f")" @@ -570,174 +575,207 @@ class OffloadEngine: """Wait for all offload operations to complete.""" self.transfer_stream_main.synchronize() + # ========== Unified Ring Buffer methods ========== + + # ----- Prefill: Ring Buffer slot management ----- + + def get_write_slot_for_prefill(self, chunk_idx: int) -> int: + """ + Get ring buffer slot for writing prefill chunk. + + For prefill, ALL slots are used as ring buffer, cycling through. + + Args: + chunk_idx: Current chunk index (0, 1, 2, ...) + + Returns: + GPU slot index for writing + """ + return chunk_idx % self.num_ring_slots + + def get_load_slots_for_prefill(self, write_slot_idx: int) -> List[int]: + """ + Get available slots for loading previous chunks during prefill. + + Excludes the current write slot to avoid conflict. + + Args: + write_slot_idx: Current write slot index + + Returns: + List of slot indices available for loading (N-1 slots) + """ + return [i for i in range(self.num_ring_slots) if i != write_slot_idx] + + # ----- Decode: slot management ----- + + def get_load_slots_for_decode(self) -> List[int]: + """ + Get slots available for loading during decode. + + Excludes decode_slot (slot[0]) since it's used for writing new token's KV. + + Returns: + List of slot indices for loading (slots[1:]) + """ + return self.decode_load_slots + + # ----- Per-slot Per-layer loading methods ----- + + def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None: + """ + Async load a single CPU block to a ring buffer slot for one layer. + + This is the core building block for ring buffer pipelining. 
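+
+        Example (hypothetical slot/block ids, showing the load -> wait -> read
+        pattern; `engine` is an OffloadEngine instance):
+
+            engine.load_to_slot_layer(slot_idx=2, layer_id=0, cpu_block_id=7)  # async H2D
+            # ... overlap other work on the compute stream ...
+            engine.wait_slot_layer(2, 0)          # gate the compute stream on the copy
+            k, v = engine.get_kv_for_slot(2, 0)   # safe to read slot 2, layer 0 now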
+ + Args: + slot_idx: Target GPU slot index + layer_id: Layer index to load + cpu_block_id: Source CPU block ID + """ + logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]") + + with torch.cuda.stream(self.transfer_stream_main): + self.k_cache_gpu[layer_id, slot_idx].copy_( + self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True + ) + self.v_cache_gpu[layer_id, slot_idx].copy_( + self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True + ) + self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main) + + def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None: + """ + Wait for a slot's loading to complete for a specific layer. + + Args: + slot_idx: GPU slot index to wait for + layer_id: Layer index to wait for + """ + self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id]) + + def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None: + """ + Async load a CPU block to a ring buffer slot for ALL layers. + + Args: + slot_idx: Target GPU slot index + cpu_block_id: Source CPU block ID + """ + logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]") + + with torch.cuda.stream(self.transfer_stream_main): + self.k_cache_gpu[:, slot_idx].copy_( + self.k_cache_cpu[:, cpu_block_id], non_blocking=True + ) + self.v_cache_gpu[:, slot_idx].copy_( + self.v_cache_cpu[:, cpu_block_id], non_blocking=True + ) + self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main) + + def wait_slot_all_layers(self, slot_idx: int) -> None: + """Wait for a slot's loading to complete for ALL layers.""" + self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx]) + + # ----- Slot offload methods ----- + + def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None: + """ + Async offload a ring buffer slot to CPU (all layers). + + Args: + slot_idx: Source GPU slot index + cpu_block_id: Target CPU block ID + """ + logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]") + + with torch.cuda.stream(self.transfer_stream_main): + self.transfer_stream_main.wait_stream(self.compute_stream) + self.k_cache_cpu[:, cpu_block_id].copy_( + self.k_cache_gpu[:, slot_idx], non_blocking=True + ) + self.v_cache_cpu[:, cpu_block_id].copy_( + self.v_cache_gpu[:, slot_idx], non_blocking=True + ) + self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main) + + def wait_slot_offload(self, slot_idx: int) -> None: + """Wait for slot offload to complete.""" + self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx]) + + def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None: + """ + Async offload a ring buffer slot to CPU for one layer. 
+ + Args: + slot_idx: Source GPU slot index + layer_id: Layer index to offload + cpu_block_id: Target CPU block ID + """ + with torch.cuda.stream(self.transfer_stream_main): + self.transfer_stream_main.wait_stream(self.compute_stream) + self.k_cache_cpu[layer_id, cpu_block_id].copy_( + self.k_cache_gpu[layer_id, slot_idx], non_blocking=True + ) + self.v_cache_cpu[layer_id, cpu_block_id].copy_( + self.v_cache_gpu[layer_id, slot_idx], non_blocking=True + ) + self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main) + + def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None: + """Wait for slot offload to complete for a specific layer.""" + self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id]) + + # ----- KV access methods for ring buffer ----- + + def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]: + """ + Get KV for a single ring buffer slot. + + Args: + slot_idx: GPU slot index + layer_id: Layer ID + + Returns: + (k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim] + """ + k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0) # [1, block_size, heads, dim] + v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0) + return k, v + def get_kv_for_slots( self, layer_id: int, - gpu_slots: List[int], + slot_indices: List[int], ) -> Tuple[Tensor, Tensor]: """ - Get KV for specified GPU slots. + Get KV for multiple ring buffer slots. Args: layer_id: Layer ID - gpu_slots: List of GPU slot IDs + slot_indices: List of GPU slot indices Returns: (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim] """ - if not gpu_slots: + if not slot_indices: return None, None - k = self.k_cache_gpu[layer_id, gpu_slots] - v = self.v_cache_gpu[layer_id, gpu_slots] + k = self.k_cache_gpu[layer_id, slot_indices] + v = self.v_cache_gpu[layer_id, slot_indices] k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) return k, v - # ========== Three-region GPU Buffer methods ========== - - def load_to_compute(self, cpu_block_ids: List[int]) -> None: - """ - Async load CPU blocks to Compute region. - - Args: - cpu_block_ids: List of CPU block IDs to load - """ - if not cpu_block_ids: - self.compute_ready.record(self.transfer_stream_main) - return - - num_to_load = min(len(cpu_block_ids), len(self.compute_slots)) - logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}") - - with torch.cuda.stream(self.transfer_stream_main): - for i in range(num_to_load): - cpu_id = cpu_block_ids[i] - gpu_slot = self.compute_slots[i] - # Copy all layers together - self.k_cache_gpu[:, gpu_slot].copy_( - self.k_cache_cpu[:, cpu_id], non_blocking=True - ) - self.v_cache_gpu[:, gpu_slot].copy_( - self.v_cache_cpu[:, cpu_id], non_blocking=True - ) - self.compute_ready.record(self.transfer_stream_main) - - def load_to_prefetch(self, cpu_block_ids: List[int]) -> None: - """ - Async load CPU blocks to Prefetch region. 
- - Args: - cpu_block_ids: List of CPU block IDs to load - """ - if not cpu_block_ids: - self.prefetch_ready.record(self.transfer_stream_main) - return - - num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots)) - logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}") - - with torch.cuda.stream(self.transfer_stream_main): - for i in range(num_to_load): - cpu_id = cpu_block_ids[i] - gpu_slot = self.prefetch_slots[i] - self.k_cache_gpu[:, gpu_slot].copy_( - self.k_cache_cpu[:, cpu_id], non_blocking=True - ) - self.v_cache_gpu[:, gpu_slot].copy_( - self.v_cache_cpu[:, cpu_id], non_blocking=True - ) - self.prefetch_ready.record(self.transfer_stream_main) - - def wait_compute(self) -> None: - """Wait for Compute region loading to complete.""" - self.compute_stream.wait_event(self.compute_ready) - - def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None: - """ - Load CPU blocks to Compute region for a single layer only. - - This is used for per-layer chunked attention where each layer - independently loads its KV data. - - Args: - layer_id: Layer index to load - cpu_block_ids: List of CPU block IDs to load - """ - if not cpu_block_ids: - self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main) - return - - num_to_load = min(len(cpu_block_ids), len(self.compute_slots)) - logger.debug(f"Compute load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}") - - with torch.cuda.stream(self.transfer_stream_main): - for i in range(num_to_load): - cpu_id = cpu_block_ids[i] - gpu_slot = self.compute_slots[i] - # Copy only this layer (not all layers) - self.k_cache_gpu[layer_id, gpu_slot].copy_( - self.k_cache_cpu[layer_id, cpu_id], non_blocking=True - ) - self.v_cache_gpu[layer_id, gpu_slot].copy_( - self.v_cache_cpu[layer_id, cpu_id], non_blocking=True - ) - self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main) - - def wait_compute_layer(self, layer_id: int) -> None: - """Wait for specific layer's Compute region loading to complete.""" - self.compute_stream.wait_event(self.compute_ready_per_layer[layer_id]) - - def wait_prefetch(self) -> None: - """Wait for Prefetch region loading to complete.""" - self.compute_stream.wait_event(self.prefetch_ready) - - def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None: - """ - Load CPU blocks to Prefetch region for a single layer only. - - This is used for per-layer chunked attention where each layer - independently loads its KV data. 
- - Args: - layer_id: Layer index to load - cpu_block_ids: List of CPU block IDs to load - """ - if not cpu_block_ids: - self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main) - return - - num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots)) - logger.debug(f"Prefetch load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}") - - with torch.cuda.stream(self.transfer_stream_main): - for i in range(num_to_load): - cpu_id = cpu_block_ids[i] - gpu_slot = self.prefetch_slots[i] - # Copy only this layer (not all layers) - self.k_cache_gpu[layer_id, gpu_slot].copy_( - self.k_cache_cpu[layer_id, cpu_id], non_blocking=True - ) - self.v_cache_gpu[layer_id, gpu_slot].copy_( - self.v_cache_cpu[layer_id, cpu_id], non_blocking=True - ) - self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main) - - def wait_prefetch_layer(self, layer_id: int) -> None: - """Wait for specific layer's Prefetch region loading to complete.""" - self.compute_stream.wait_event(self.prefetch_ready_per_layer[layer_id]) - - def swap_compute_prefetch(self) -> None: - """Swap roles of Compute region and Prefetch region.""" - self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots + # ----- Decode slot methods (kept for decode phase) ----- def offload_decode_slot(self, cpu_block_id: int) -> None: """ - Offload KV from Decode region to CPU. + Offload KV from decode slot (slot[0]) to CPU. Args: cpu_block_id: Target CPU block ID """ - logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]") + logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]") with torch.cuda.stream(self.transfer_stream_main): self.transfer_stream_main.wait_stream(self.compute_stream) @@ -750,61 +788,16 @@ class OffloadEngine: self.decode_offload_done.record(self.transfer_stream_main) def wait_decode_offload(self) -> None: - """Wait for Decode region offload to complete.""" + """Wait for decode slot offload to complete.""" self.compute_stream.wait_event(self.decode_offload_done) - def get_kv_for_compute( - self, - layer_id: int, - num_blocks: int, - ) -> Tuple[Tensor, Tensor]: - """ - Get KV for specified number of blocks in Compute region. - - Args: - layer_id: Layer ID - num_blocks: Number of blocks needed - - Returns: - (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim] - """ - slots = self.compute_slots[:num_blocks] - k = self.k_cache_gpu[layer_id, slots] # [num_blocks, block_size, heads, dim] - v = self.v_cache_gpu[layer_id, slots] - # Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim] - k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) - v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) - return k, v - - def get_kv_for_prefetch( - self, - layer_id: int, - num_blocks: int, - ) -> Tuple[Tensor, Tensor]: - """ - Get KV for specified number of blocks in Prefetch region. 
- - Args: - layer_id: Layer ID - num_blocks: Number of blocks needed - - Returns: - (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim] - """ - slots = self.prefetch_slots[:num_blocks] - k = self.k_cache_gpu[layer_id, slots] - v = self.v_cache_gpu[layer_id, slots] - k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) - v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) - return k, v - def get_kv_for_decode_slot( self, layer_id: int, pos_in_block: int, ) -> Tuple[Tensor, Tensor]: """ - Get KV at specified position in Decode region (for new token during decode). + Get KV at specified position in decode slot. Args: layer_id: Layer ID @@ -813,9 +806,9 @@ class OffloadEngine: Returns: (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim] """ - k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] # [1, heads, dim] + k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] - k = k.unsqueeze(0) # [1, 1, heads, dim] + k = k.unsqueeze(0) v = v.unsqueeze(0) return k, v @@ -825,10 +818,7 @@ class OffloadEngine: num_tokens: int, ) -> Tuple[Tensor, Tensor]: """ - Get accumulated KV in Decode region (all tokens from position 0 to num_tokens-1). - - Used when batching decode offloads - attend to all accumulated tokens, - not just the current one. + Get accumulated KV in decode slot (positions 0 to num_tokens-1). Args: layer_id: Layer ID @@ -837,35 +827,102 @@ class OffloadEngine: Returns: (k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim] """ - k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens] # [num_tokens, heads, dim] + k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens] v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens] - k = k.unsqueeze(0) # [1, num_tokens, heads, dim] + k = k.unsqueeze(0) v = v.unsqueeze(0) return k, v - def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None: - """ - Offload KV from Compute region to CPU. + # ----- Legacy compatibility methods (for decode double-buffering) ----- - Args: - cpu_block_ids: Target CPU block IDs list + def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None: + """ + Legacy: Load CPU blocks to decode_load_slots for decode double-buffering. + + Uses first half of decode_load_slots as 'compute' region. 
""" if not cpu_block_ids: return - num_to_offload = min(len(cpu_block_ids), len(self.compute_slots)) - logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}") + half = max(1, len(self.decode_load_slots) // 2) + slots = self.decode_load_slots[:half] + num_to_load = min(len(cpu_block_ids), len(slots)) with torch.cuda.stream(self.transfer_stream_main): - # Wait for compute to complete - self.transfer_stream_main.wait_stream(self.compute_stream) - - for i in range(num_to_offload): - gpu_slot = self.compute_slots[i] + for i in range(num_to_load): cpu_id = cpu_block_ids[i] - self.k_cache_cpu[:, cpu_id].copy_( - self.k_cache_gpu[:, gpu_slot], non_blocking=True + gpu_slot = slots[i] + self.k_cache_gpu[layer_id, gpu_slot].copy_( + self.k_cache_cpu[layer_id, cpu_id], non_blocking=True ) - self.v_cache_cpu[:, cpu_id].copy_( - self.v_cache_gpu[:, gpu_slot], non_blocking=True - ) \ No newline at end of file + self.v_cache_gpu[layer_id, gpu_slot].copy_( + self.v_cache_cpu[layer_id, cpu_id], non_blocking=True + ) + if num_to_load > 0: + self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main) + + def wait_compute_layer(self, layer_id: int) -> None: + """Legacy: Wait for 'compute' region loading.""" + half = max(1, len(self.decode_load_slots) // 2) + if self.decode_load_slots: + self.wait_slot_layer(self.decode_load_slots[0], layer_id) + + def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None: + """ + Legacy: Load CPU blocks to decode_load_slots for decode double-buffering. + + Uses second half of decode_load_slots as 'prefetch' region. + """ + if not cpu_block_ids: + return + + half = max(1, len(self.decode_load_slots) // 2) + slots = self.decode_load_slots[half:] + if not slots: + slots = self.decode_load_slots # Fallback if only 1-2 slots + num_to_load = min(len(cpu_block_ids), len(slots)) + + with torch.cuda.stream(self.transfer_stream_main): + for i in range(num_to_load): + cpu_id = cpu_block_ids[i] + gpu_slot = slots[i] + self.k_cache_gpu[layer_id, gpu_slot].copy_( + self.k_cache_cpu[layer_id, cpu_id], non_blocking=True + ) + self.v_cache_gpu[layer_id, gpu_slot].copy_( + self.v_cache_cpu[layer_id, cpu_id], non_blocking=True + ) + if num_to_load > 0: + self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main) + + def wait_prefetch_layer(self, layer_id: int) -> None: + """Legacy: Wait for 'prefetch' region loading.""" + half = max(1, len(self.decode_load_slots) // 2) + slots = self.decode_load_slots[half:] + if slots: + self.wait_slot_layer(slots[0], layer_id) + elif self.decode_load_slots: + self.wait_slot_layer(self.decode_load_slots[0], layer_id) + + def get_kv_for_compute( + self, + layer_id: int, + num_blocks: int, + ) -> Tuple[Tensor, Tensor]: + """Legacy: Get KV from 'compute' region (first half of decode_load_slots).""" + half = max(1, len(self.decode_load_slots) // 2) + slots = self.decode_load_slots[:half][:num_blocks] + return self.get_kv_for_slots(layer_id, slots) + + def get_kv_for_prefetch( + self, + layer_id: int, + num_blocks: int, + ) -> Tuple[Tensor, Tensor]: + """Legacy: Get KV from 'prefetch' region (second half of decode_load_slots).""" + half = max(1, len(self.decode_load_slots) // 2) + slots = self.decode_load_slots[half:] + if not slots: + slots = self.decode_load_slots + slots = slots[:num_blocks] + return self.get_kv_for_slots(layer_id, slots) \ No newline at end of file diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 
d50456d..dfde652 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -100,16 +100,19 @@ class Attention(nn.Module): context, ) -> torch.Tensor: """ - Compute attention with three-region GPU buffer for chunked prefill. + Compute attention with unified ring buffer for chunked prefill. - For chunked prefill: - 1. Load previous KV from CPU using Compute/Prefetch region (if any previous chunks) - 2. Compute attention against previous KV chunks (no causal mask) - 3. Compute attention against current chunk's KV (causal) - 4. Merge all results using online softmax + Ring buffer design: + - Current chunk's KV is written to ring_slot[chunk_idx % N] + - Previous chunks' KV are loaded from CPU using N-1 available slots + - Pipeline: pre-fill slots, then process with overlapped load/compute - Three-region design guarantees: current chunk's KV is in Compute region, previous KV is loaded - from CPU to Prefetch region, so write and load regions never overlap. + For each layer: + 1. Current chunk's KV is in k_batched, v_batched (just written by model) + 2. Load previous chunks from CPU using available slots (pipeline) + 3. Compute attention against previous KV (no causal mask) + 4. Compute attention against current KV (causal) + 5. Merge all results using online softmax """ from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs @@ -122,51 +125,33 @@ class Attention(nn.Module): o_acc = None lse_acc = None - # Load previous KV from CPU using Compute/Prefetch region kvcache_manager = context.kvcache_manager seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None + current_chunk_idx = context.current_chunk_idx if kvcache_manager is not None and seq is not None and self.layer_id >= 0: - # Get prefilled CPU blocks (blocks already written in previous chunks) + # Get prefilled CPU blocks (blocks from previous chunks) cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) if cpu_block_table: offload_engine = kvcache_manager.offload_engine - # For prefill: ONLY use Prefetch region to avoid conflict with - # current chunk's KV being written to Compute region slots - # Use synchronous per-layer loading (async would conflict with writes) - chunk_size = offload_engine.num_prefetch_blocks - num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size - for chunk_idx in range(num_chunks): - start = chunk_idx * chunk_size - end = min(start + chunk_size, len(cpu_block_table)) - num_blocks_in_chunk = end - start - chunk_ids = cpu_block_table[start:end] + # Get write slot for current chunk and available load slots + write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx) + load_slots = offload_engine.get_load_slots_for_prefill(write_slot) + pipeline_depth = len(load_slots) - # Load to Prefetch region (per-layer, sync) - offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids) - offload_engine.wait_prefetch_layer(self.layer_id) - - prev_k, prev_v = offload_engine.get_kv_for_prefetch( - self.layer_id, num_blocks_in_chunk + if pipeline_depth == 0: + # Only 1 slot total, cannot pipeline - use sync loading + o_acc, lse_acc = self._sync_load_previous_chunks( + q_batched, cpu_block_table, offload_engine ) - - # Compute attention against this chunk (no causal mask) - prev_o, prev_lse = flash_attn_with_lse( - q_batched, - prev_k, - prev_v, - softmax_scale=self.scale, - causal=False, + else: + # Use ring buffer pipeline + o_acc, lse_acc = self._ring_buffer_pipeline_load( + q_batched, cpu_block_table, load_slots, 
offload_engine ) - # Merge with accumulated - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - # Compute attention against current chunk's KV (with causal mask) current_o, current_lse = flash_attn_with_lse( q_batched, @@ -185,6 +170,91 @@ class Attention(nn.Module): # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim] return final_o.squeeze(0) + def _sync_load_previous_chunks( + self, + q_batched: torch.Tensor, + cpu_block_table: list, + offload_engine, + ): + """Synchronous loading fallback when pipeline_depth=0.""" + from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + + o_acc, lse_acc = None, None + + for block_idx, cpu_block_id in enumerate(cpu_block_table): + # Load to slot 0 (single slot) + offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id) + offload_engine.wait_slot_layer(0, self.layer_id) + + prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id) + + prev_o, prev_lse = flash_attn_with_lse( + q_batched, prev_k, prev_v, + softmax_scale=self.scale, + causal=False, + ) + + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + + return o_acc, lse_acc + + def _ring_buffer_pipeline_load( + self, + q_batched: torch.Tensor, + cpu_block_table: list, + load_slots: list, + offload_engine, + ): + """ + Ring buffer synchronous loading for previous chunks. + + For correctness, we use synchronous loading: + - Load one block at a time + - Wait for transfer, compute attention, then load next + + This ensures no data races between transfer and computation. + """ + from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + + num_blocks = len(cpu_block_table) + if num_blocks == 0: + return None, None + + pipeline_depth = len(load_slots) + o_acc, lse_acc = None, None + + # Process blocks one by one (synchronous) + for block_idx in range(num_blocks): + # Determine which slot to use (cycle through load_slots) + slot_idx = load_slots[block_idx % pipeline_depth] + cpu_block_id = cpu_block_table[block_idx] + + # Load block to slot (async) + offload_engine.load_to_slot_layer(slot_idx, self.layer_id, cpu_block_id) + + # Wait for transfer to complete + offload_engine.wait_slot_layer(slot_idx, self.layer_id) + + # Get KV from slot and compute attention + prev_k, prev_v = offload_engine.get_kv_for_slot(slot_idx, self.layer_id) + + prev_o, prev_lse = flash_attn_with_lse( + q_batched, prev_k, prev_v, + softmax_scale=self.scale, + causal=False, + ) + + # Merge with accumulated + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + + return o_acc, lse_acc + def _chunked_decode_attention( self, q: torch.Tensor, @@ -193,20 +263,24 @@ class Attention(nn.Module): context, ) -> torch.Tensor: """ - Compute decode attention with async double-buffering using Compute and Prefetch regions. + Compute decode attention with double-buffering using decode_load_slots. 
+ + Decode uses: + - decode_slot (slot[0]): writes new token's KV + - decode_load_slots (slots[1:]): load previous chunks from CPU Pipeline design: - - Compute region: holds current chunk being computed - - Prefetch region: async loads next chunk while current is computing - - After computation, swap roles of the two regions + - First half of decode_load_slots: 'compute' buffer + - Second half: 'prefetch' buffer + - Double-buffer between them for async overlap Timeline: ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ - │Load C0→Comp │ │Load C1→Pref │ │Load C2→Comp │ ... + │Load C0→buf0 │ │Load C1→buf1 │ │Load C2→buf0 │ ... └─────────────┘ └─────────────┘ └─────────────┘ ↘ ↘ ↘ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ - │ Compute C0 │ │ Compute C1 │ │ Compute C2 │ + │ Attn(C0) │ │ Attn(C1) │ │ Attn(C2) │ └─────────────┘ └─────────────┘ └─────────────┘ """ from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py index 86163a9..23a1483 100644 --- a/nanovllm/utils/context.py +++ b/nanovllm/utils/context.py @@ -32,6 +32,8 @@ class Context: # Starting position within block where decode tokens began (for accumulated token tracking) # Used when batching decode offloads - we need to attend to all accumulated tokens decode_start_pos_in_block: int = 0 + # Current chunk index for ring buffer pipeline (prefill only) + current_chunk_idx: int = 0 _CONTEXT = Context() @@ -57,6 +59,7 @@ def set_context( chunked_seq=None, decode_pos_in_block=0, decode_start_pos_in_block=0, + current_chunk_idx=0, ): global _CONTEXT _CONTEXT = Context( @@ -75,6 +78,7 @@ def set_context( chunked_seq=chunked_seq, decode_pos_in_block=decode_pos_in_block, decode_start_pos_in_block=decode_start_pos_in_block, + current_chunk_idx=current_chunk_idx, )
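For reference on the chunked-attention math that CLAUDE.md cites: `merge_attention_outputs` combines two partial attention results computed over disjoint key sets using their log-sum-exp (LSE) values. A minimal sketch of that merge, assuming the flash-attn layout `o: [batch, q_len, heads, head_dim]` and `lse: [batch, heads, q_len]`; the repository's actual helper in `nanovllm/kvcache/chunked_attention.py` may differ in details:

```python
import torch

def merge_attention_outputs(o1, lse1, o2, lse2):
    # Combined normalizer over the union of both key sets, per query and head.
    lse = torch.logaddexp(lse1, lse2)                          # [batch, heads, q_len]
    # Reweight each partial output by its share of the total softmax mass,
    # broadcasting the weights to [batch, q_len, heads, 1].
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse
```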