[feat] Need to optimize with async prefetch.

This commit is contained in:
Zijie Tian
2025-12-15 06:58:40 +08:00
parent 1081ab51ea
commit b8b6478506
9 changed files with 556 additions and 404 deletions

View File

@@ -44,74 +44,101 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
When `enable_cpu_offload=True`, KV cache is stored on CPU with a small GPU buffer for computation. This enables long-context inference with limited GPU memory.
### Unified Ring Buffer Design
```
GPU Slots: [0] [1] [2] [3] [4] ...
           ←────────────────────────────→
             All slots as ring buffer

Prefill: ALL slots cycle as ring buffer [slot = chunk_idx % N]
Decode:  slot[0] = decode_slot, slots[1:] = load slots for previous chunks
```
**File**: `nanovllm/kvcache/offload_engine.py`
Key attributes:
- `num_ring_slots`: Total GPU slots (= num_gpu_blocks)
- `ring_slots`: List of all GPU slot indices [0, 1, 2, ...]
- `decode_slot = 0`: Fixed slot for decode KV writes
- `decode_load_slots`: Slots[1:] for loading previous chunks during decode
- `k_cache_gpu/v_cache_gpu`: Shape `[num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]`
- `k_cache_cpu/v_cache_cpu`: Shape `[num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]` (pinned memory)
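For a sense of scale, the shapes above fix the per-block transfer size. A minimal sketch of the arithmetic, assuming hypothetical model dimensions (28 layers, 8 KV heads, head_dim 128, fp16) together with the `block_size=4096` used elsewhere in this commit:

```python
# Back-of-the-envelope KV block sizing; the model dimensions are placeholders.
num_layers, kv_heads, head_dim = 28, 8, 128   # hypothetical, not taken from this repo
block_size = 4096                             # kvcache_block_size in this commit
dtype_bytes = 2                               # fp16

# K + V for one block of one layer (what load_to_slot_layer moves per call).
per_layer_block = 2 * block_size * kv_heads * head_dim * dtype_bytes
# One block across all layers (what offload_slot_to_cpu moves per call).
per_block_all_layers = num_layers * per_layer_block

print(f"{per_layer_block / 2**20:.0f} MiB per layer-block, "
      f"{per_block_all_layers / 2**20:.0f} MiB per block across layers")
# -> 16 MiB per layer-block, 448 MiB per block across layers
```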
Key methods:
```python
# Prefill: get write slot and load slots
get_write_slot_for_prefill(chunk_idx)       # Returns chunk_idx % num_ring_slots
get_load_slots_for_prefill(write_slot_idx)  # Returns all slots except write_slot

# Decode: get load slots (excludes decode_slot)
get_load_slots_for_decode()                 # Returns slots[1:]

# Per-slot per-layer operations
load_to_slot_layer(slot_idx, layer_id, cpu_block_id)  # Async load single block
wait_slot_layer(slot_idx, layer_id)                   # Wait for layer's transfer
offload_slot_to_cpu(slot_idx, cpu_block_id)           # Async offload to CPU
```
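To make the slot arithmetic concrete, here is a toy re-implementation of the two prefill helpers (my own sketch, assuming a 4-slot ring; the real methods live on `OffloadEngine`):

```python
# Toy illustration of the ring-slot selection rules described above.
num_ring_slots = 4

def get_write_slot_for_prefill(chunk_idx: int) -> int:
    # Same rule as the engine: slots are reused every num_ring_slots chunks.
    return chunk_idx % num_ring_slots

def get_load_slots_for_prefill(write_slot_idx: int) -> list[int]:
    # All slots except the one currently being written.
    return [s for s in range(num_ring_slots) if s != write_slot_idx]

for chunk_idx in range(6):
    w = get_write_slot_for_prefill(chunk_idx)
    print(chunk_idx, w, get_load_slots_for_prefill(w))
# chunk 0 -> write slot 0, load slots [1, 2, 3]
# chunk 4 -> write slot 0 again: the ring wraps around
```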
### Per-Slot Per-Layer Events (Critical Design)
Each slot has per-layer CUDA events for fine-grained synchronization:
- `ring_slot_ready[slot_idx][layer_id]`: H2D transfer completion
- `ring_slot_offload_done[slot_idx][layer_id]`: D2H transfer completion
This enables:
1. Overlapped H2D transfer with attention computation
2. Each layer independently waits for its own data
3. Pipeline depth = N-1 for prefill (N slots, 1 for writing)
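A standalone sketch of that event pattern (assumes a CUDA device; the sizes and flat tensor layout here are placeholders, not the engine's real configuration):

```python
import torch

# Placeholder sizes; the engine keys events by its real slot and layer counts.
num_slots, num_layers = 4, 2
transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()

# One H2D-completion event per (slot, layer), mirroring ring_slot_ready.
ring_slot_ready = [[torch.cuda.Event() for _ in range(num_layers)]
                   for _ in range(num_slots)]

cpu_blocks = torch.zeros(num_layers, 16, 8, 128, dtype=torch.float16, pin_memory=True)
gpu_slots = torch.empty(num_slots, num_layers, 16, 8, 128, dtype=torch.float16, device="cuda")

slot_idx, layer_id = 1, 0
with torch.cuda.stream(transfer_stream):
    # Async H2D copy of one layer's block into one ring slot ...
    gpu_slots[slot_idx, layer_id].copy_(cpu_blocks[layer_id], non_blocking=True)
    # ... and record completion for exactly this (slot, layer) pair.
    ring_slot_ready[slot_idx][layer_id].record(transfer_stream)

# The compute stream stalls only if this specific transfer has not finished yet.
compute_stream.wait_event(ring_slot_ready[slot_idx][layer_id])
```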
### Chunked Prefill Flow (Ring Buffer Pipeline)
**File**: `nanovllm/layers/attention.py` - `_chunked_prefill_attention()`
```
For prefill chunk K:
1. Current chunk's KV written to ring_slot[K % N]
2. Load previous chunks from CPU using N-1 available slots (pipeline)
3. Compute attention against previous KV (no causal mask)
4. Compute attention against current KV (causal mask)
5. Merge results using online softmax (LSE)
6. Offload current slot to CPU

Pipeline Timeline (with 4 slots, processing chunk 3):
write_slot = 3, load_slots = [0, 1, 2]
┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│Load B0→S0 │ │Load B1→S1 │ │Load B2→S2 │ │ (wait) │
└─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘
↘ ↘ ↘
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Attn(B0) │ │ Attn(B1) │ │ Attn(B2) │
└─────────────┘ └─────────────┘ └─────────────┘
```
**Key**: Write slot cycles through ALL slots, load slots = all except write slot.
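Putting the steps together, the per-layer loop over previous chunks looks roughly like the sketch below (a simplified mirror of the `_ring_buffer_pipeline_load` helper added in this commit; `engine` is the `OffloadEngine`, `scale` the softmax scale):

```python
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

# Sketch of attending to previously offloaded chunks during prefill of chunk K.
def attend_previous_chunks(q, layer_id, cpu_blocks, load_slots, engine, scale):
    o_acc, lse_acc = None, None
    for i, cpu_block_id in enumerate(cpu_blocks):
        slot = load_slots[i % len(load_slots)]                    # cycle the N-1 load slots
        engine.load_to_slot_layer(slot, layer_id, cpu_block_id)   # async H2D copy
        engine.wait_slot_layer(slot, layer_id)                    # per-layer event sync
        k_prev, v_prev = engine.get_kv_for_slot(slot, layer_id)
        o, lse = flash_attn_with_lse(q, k_prev, v_prev, softmax_scale=scale, causal=False)
        if o_acc is None:
            o_acc, lse_acc = o, lse
        else:
            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o, lse)
    return o_acc, lse_acc
```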
### Chunked Decode Flow (Double Buffering)
**File**: `nanovllm/layers/attention.py` - `_chunked_decode_attention()`
Decode uses legacy double-buffering with `decode_load_slots`:
- First half of decode_load_slots: 'compute' buffer
- Second half: 'prefetch' buffer
```
Timeline:
      ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
Load: │ C0 → buf0   │ │ C1 → buf1   │ │ C2 → buf0   │ ...
      └─────────────┘ └─────────────┘ └─────────────┘
             ↘               ↘               ↘
Compute:    [C0]            [C1]            [C2]

1. Pre-load first chunk to compute buffer
2. Wait for current buffer, trigger async prefetch to OTHER buffer
3. Compute attention, merge results
4. Swap buffers, repeat
5. Finally attend to decode_slot (new token's KV)
```
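In code, the double-buffered history pass can be sketched as follows (hedged: `chunks` is a list of CPU-block-id lists sized to each buffer half, and the method names are the legacy-compatibility wrappers added in this commit):

```python
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

# Simplified double-buffered decode loop over previously offloaded KV chunks.
def decode_attend_history(q, layer_id, chunks, engine, scale):
    o_acc, lse_acc = None, None
    if not chunks:
        return o_acc, lse_acc
    engine.load_to_compute_layer(layer_id, chunks[0])          # 1. pre-load first chunk
    use_compute = True
    for i, chunk in enumerate(chunks):
        # 2. Wait for the current buffer, then trigger async prefetch into the OTHER buffer.
        if use_compute:
            engine.wait_compute_layer(layer_id)
            k, v = engine.get_kv_for_compute(layer_id, len(chunk))
            if i + 1 < len(chunks):
                engine.load_to_prefetch_layer(layer_id, chunks[i + 1])
        else:
            engine.wait_prefetch_layer(layer_id)
            k, v = engine.get_kv_for_prefetch(layer_id, len(chunk))
            if i + 1 < len(chunks):
                engine.load_to_compute_layer(layer_id, chunks[i + 1])
        # 3. Compute attention against this chunk and merge.
        o, lse = flash_attn_with_lse(q, k, v, softmax_scale=scale, causal=False)
        o_acc, lse_acc = (o, lse) if o_acc is None else merge_attention_outputs(o_acc, lse_acc, o, lse)
        use_compute = not use_compute                          # 4. swap buffers
    return o_acc, lse_acc                                      # 5. caller then attends to decode_slot
```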
### HybridKVCacheManager
@@ -120,7 +147,7 @@ Compute: [C0] [C1] [C2]
Manages both GPU and CPU blocks:
- `allocate()`: Allocate GPU block first, fallback to CPU
- `allocate_cpu_only()`: Force CPU allocation (for ring buffer mode)
- `get_all_cpu_blocks(seq)`: Get all CPU block IDs for a sequence
- `get_prefilled_cpu_blocks(seq)`: Get CPU blocks from previous chunks
- `get_write_slot_for_chunked_offload(seq)`: Get GPU slot for writing new KV (returns decode_slot)
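A hedged sketch of how the ring-buffer prefill path drives these calls (mirroring `run_chunked_offload_prefill`; `manager` is a `HybridKVCacheManager` constructed with `cpu_primary=True`):

```python
# Hypothetical driver snippet for the CPU-primary allocation path.
manager.allocate_cpu_only(seq)                         # every block of `seq` lives on CPU

# Unpacked the same way run_chunked_offload_prefill does:
cpu_block_ids, logical_ids = manager.get_all_cpu_blocks(seq)
prev_blocks = manager.get_prefilled_cpu_blocks(seq)    # blocks written by earlier chunks
write_slot = manager.get_write_slot_for_chunked_offload(seq)  # decode_slot (0) during decode
```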
@@ -136,9 +163,7 @@ def merge_attention_outputs(o1, lse1, o2, lse2):
# Uses LSE to correctly weight and combine partial attention outputs
```
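The merge is the standard online-softmax combination of two partial results over disjoint key sets. A self-contained reference sketch (my own implementation of the formula, with assumed flash-attention shapes, not necessarily the exact code in `chunked_attention.py`):

```python
import torch

def merge_attention_outputs_ref(o1, lse1, o2, lse2):
    """Combine two partial attention results computed over disjoint key sets.

    Assumed shapes (flash-attention convention):
      o*:   [batch, q_len, heads, head_dim]  partial outputs
      lse*: [batch, heads, q_len]            log-sum-exp of attention scores
    """
    lse_max = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - lse_max)
    w2 = torch.exp(lse2 - lse_max)
    denom = w1 + w2
    # Move weights to [batch, q_len, heads, 1] so they broadcast over head_dim.
    w1 = (w1 / denom).transpose(1, 2).unsqueeze(-1)
    w2 = (w2 / denom).transpose(1, 2).unsqueeze(-1)
    merged_o = w1 * o1 + w2 * o2
    merged_lse = lse_max + torch.log(denom)
    return merged_o, merged_lse
```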
### Pipeline Depth
- **Prefill**: Pipeline depth = N-1 (where N = num_gpu_blocks)
- **Decode**: Pipeline depth = (N-1)/2 (double buffering within decode_load_slots)
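With the benchmark configuration in this commit (`num_gpu_blocks=120`), those formulas work out to:

```python
num_gpu_blocks = 120                        # bench setting below
prefill_depth = num_gpu_blocks - 1          # 119 slots can hold previous chunks while 1 is written
decode_depth = (num_gpu_blocks - 1) // 2    # 59: decode_load_slots split into two buffer halves
print(prefill_depth, decode_depth)          # -> 119 59
```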

View File

@@ -41,8 +41,8 @@ def main():
max_model_len=128 * 1024,
max_num_batched_tokens=128 * 1024,
enable_cpu_offload=True,
num_gpu_blocks=120,
num_prefetch_blocks=4,
)
# Warmup
@@ -54,12 +54,12 @@ def main():
# bench_prefill(llm, num_seqs=1, input_len=1024)
# bench_prefill(llm, num_seqs=1, input_len=2048)
# bench_prefill(llm, num_seqs=1, input_len=4096)
bench_prefill(llm, num_seqs=1, input_len=16 * 1024)
print("=" * 60)
print("Decode Benchmark (CPU Offload)")
print("=" * 60)
bench_decode(llm, num_seqs=1, input_len=16 * 1024, max_output_len=128)
# bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)

View File

@@ -14,7 +14,7 @@ class Config:
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
kvcache_block_size: int = 4096
num_kvcache_blocks: int = -1
# CPU Offload configuration

View File

@@ -630,29 +630,31 @@ class ModelRunner:
def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with unified ring buffer (CPU is primary storage).
Flow:
1. All blocks are allocated to CPU (primary storage)
2. Each chunk writes KV to ring buffer slot[chunk_idx % N]
3. After each chunk, offload from ring buffer slot to CPU
4. All N-1 other slots are used to load previous chunks for attention
"""
import sys
assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
# Each chunk uses 1 ring buffer slot = 1 block
tokens_per_chunk = self.block_size
total_tokens = len(seq)
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
f"total_chunks={num_chunks}",
file=sys.stderr)
chunk_idx = 0
logits = None
processed_tokens = 0
@@ -660,27 +662,22 @@ class ModelRunner:
cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)
while processed_tokens < total_tokens:
chunk_start = processed_tokens
chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)
# Get ring buffer slot for this chunk
write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)
# CPU block index for this chunk
block_idx = chunk_idx
print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
f"write_slot={write_slot}",
file=sys.stderr)
# Prepare inputs
input_ids, positions = self._prepare_chunked_offload_chunk(
seq, chunk_start, chunk_end, write_slot, block_idx, chunk_idx
)
if input_ids.numel() == 0:
@@ -690,24 +687,27 @@ class ModelRunner:
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()
# Mark block as prefilled
if block_idx < len(seq.block_table):
logical_id = seq.block_table[block_idx]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Offload this chunk's ring buffer slot to CPU (async)
if block_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[block_idx]
offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
# Wait for offload to complete before next chunk
# (slot will be reused after N chunks)
offload_engine.wait_slot_offload(write_slot)
processed_tokens = chunk_end
chunk_idx += 1
# Wait for all offloads to complete
offload_engine.wait_all_offload_done()
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
# Sample from last logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
@@ -723,34 +723,24 @@ class ModelRunner:
seq: Sequence,
chunk_start: int,
chunk_end: int,
write_slot: int,
block_idx: int,
chunk_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare inputs for a chunked offload prefill chunk (ring buffer design)."""
# Input tokens for this chunk
input_ids = seq[chunk_start:chunk_end]
positions = list(range(chunk_start, chunk_end))
# Create slot mapping pointing to the single write_slot
slot_mapping = []
for pos in range(chunk_start, chunk_end):
pos_in_block = pos % self.block_size
slot = write_slot * self.block_size + pos_in_block
slot_mapping.append(slot)
# Convert to tensors
num_tokens = chunk_end - chunk_start
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -770,21 +760,23 @@ class ModelRunner:
is_chunked_prefill=True,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
current_chunk_idx=chunk_idx,  # Pass chunk index for ring buffer pipeline
)
return input_ids, positions
def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with ring buffer (CPU is primary storage).
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
New token's KV is written to decode_slot then offloaded to CPU only when block is full.
Key: decode_slot is dedicated to writing new KV, never used for loading.
Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
"""
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine

View File

@@ -12,7 +12,7 @@ class SequenceStatus(Enum):
class Sequence:
block_size = 4096
counter = count()
def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):

View File

@@ -95,16 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU)
cpu_primary: If True, use CPU as primary storage with ring buffer GPU design.
If False, use GPU as primary with CPU as overflow (legacy mode).
num_prefetch_blocks: Number of blocks for ring buffer pipeline (deprecated, ring_slots = num_gpu_slots)
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
self.cpu_primary = cpu_primary  # Ring buffer mode flag
self.num_prefetch_blocks = num_prefetch_blocks  # Ring buffer design parameter (deprecated)
# Eviction policy
self.policy = policy or LRUPolicy()
@@ -341,7 +341,7 @@ class HybridKVCacheManager(KVCacheManager):
""" """
assert not seq.block_table, "Sequence already has blocks" assert not seq.block_table, "Sequence already has blocks"
# Three-region mode: all blocks are allocated to CPU # Ring buffer mode: all blocks are allocated to CPU
if self.cpu_primary: if self.cpu_primary:
return self.allocate_cpu_only(seq) return self.allocate_cpu_only(seq)
@@ -471,7 +471,7 @@ class HybridKVCacheManager(KVCacheManager):
block.token_ids = []
if self.cpu_primary:
# Ring buffer mode: new block allocated to CPU
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks for decode")
cpu_block_id = self.free_cpu_blocks.popleft()
@@ -1025,14 +1025,14 @@ class HybridKVCacheManager(KVCacheManager):
break
return pos
# ========== Ring Buffer CPU-primary support ==========
def allocate_cpu_only(self, seq: Sequence) -> None:
"""
Allocate CPU blocks for sequence (for ring buffer mode).
Unlike allocate(), here all blocks are allocated to CPU,
GPU is only used as ring buffer for computation.
Args:
seq: Sequence to allocate
@@ -1092,10 +1092,10 @@ class HybridKVCacheManager(KVCacheManager):
cpu_blocks.append(block.cpu_block_id)
else:
# If block is on GPU, it should have a corresponding CPU block
# In ring buffer mode, all data ultimately resides on CPU
raise RuntimeError(
f"Block {logical_id} not on CPU (location={block.location}). "
f"In ring buffer mode, all blocks should be on CPU."
)
return cpu_blocks
@@ -1171,8 +1171,8 @@ class HybridKVCacheManager(KVCacheManager):
""" """
Get GPU slot for writing new KV during chunked offload decode. Get GPU slot for writing new KV during chunked offload decode.
In three-region design, always use Decode region (slot 0) to write new KV. In ring buffer design, always use decode_slot (slot[0]) to write new KV.
This avoids conflicts with Compute/Prefetch region loading operations. This avoids conflicts with loading operations which use slots[1:].
Args: Args:
seq: Sequence seq: Sequence

View File

@@ -65,34 +65,30 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
# ========== Unified Ring Buffer configuration ==========
# Constraint checks
assert num_gpu_blocks >= 2, \
f"Need at least 2 GPU blocks for ring buffer, got {num_gpu_blocks}"
# Unified Ring Buffer: all slots cycle for prefill
# Prefill: use ALL slots as ring buffer (slot[chunk_idx % N])
# Decode: slot[0] as decode_slot, slots[1:] for loading previous chunks
self.num_ring_slots = num_gpu_blocks
self.ring_slots = list(range(num_gpu_blocks))
# Decode phase uses slot[0] for writing new token's KV
self.decode_slot = 0
# Decode phase uses slots[1:] for loading previous chunks from CPU
self.decode_load_slots = list(range(1, num_gpu_blocks))
self.num_decode_load_slots = len(self.decode_load_slots)
# Keep num_prefetch_blocks for compatibility (used as chunk size for loading)
self.num_prefetch_blocks = num_prefetch_blocks
self.num_gpu_slots = num_gpu_blocks  # alias
logger.info(f"Unified Ring Buffer: {self.num_ring_slots} slots total")
logger.info(f"  Prefill: all slots as ring buffer [0..{num_gpu_blocks-1}]")
logger.info(f"  Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
@@ -134,18 +130,27 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0
# ========== Ring Buffer dedicated stream and events ==========
self.transfer_stream_main = torch.cuda.Stream()  # Main transfer stream
# Decode offload event
self.decode_offload_done = torch.cuda.Event()
# ========== Per-slot Per-layer events for ring buffer ==========
# ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
# ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
self.ring_slot_ready = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
self.ring_slot_offload_done = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
# Per-slot events for all-layer operations (used in some legacy paths)
self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -560,7 +565,7 @@ class OffloadEngine:
f" kv_heads={self.num_kv_heads},\n" f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n" f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n" f" dtype={self.dtype},\n"
f" three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n" f" ring_buffer: {self.num_ring_slots} slots, decode_slot={self.decode_slot}, decode_load_slots={self.decode_load_slots},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n" f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n" f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")" f")"
@@ -570,174 +575,207 @@ class OffloadEngine:
"""Wait for all offload operations to complete.""" """Wait for all offload operations to complete."""
self.transfer_stream_main.synchronize() self.transfer_stream_main.synchronize()
# ========== Unified Ring Buffer methods ==========
# ----- Prefill: Ring Buffer slot management -----
def get_write_slot_for_prefill(self, chunk_idx: int) -> int:
"""
Get ring buffer slot for writing prefill chunk.
For prefill, ALL slots are used as ring buffer, cycling through.
Args:
chunk_idx: Current chunk index (0, 1, 2, ...)
Returns:
GPU slot index for writing
"""
return chunk_idx % self.num_ring_slots
def get_load_slots_for_prefill(self, write_slot_idx: int) -> List[int]:
"""
Get available slots for loading previous chunks during prefill.
Excludes the current write slot to avoid conflict.
Args:
write_slot_idx: Current write slot index
Returns:
List of slot indices available for loading (N-1 slots)
"""
return [i for i in range(self.num_ring_slots) if i != write_slot_idx]
# ----- Decode: slot management -----
def get_load_slots_for_decode(self) -> List[int]:
"""
Get slots available for loading during decode.
Excludes decode_slot (slot[0]) since it's used for writing new token's KV.
Returns:
List of slot indices for loading (slots[1:])
"""
return self.decode_load_slots
# ----- Per-slot Per-layer loading methods -----
def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.
This is the core building block for ring buffer pipelining.
Args:
slot_idx: Target GPU slot index
layer_id: Layer index to load
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
self.k_cache_gpu[layer_id, slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[layer_id, slot_idx].copy_(
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
"""
Wait for a slot's loading to complete for a specific layer.
Args:
slot_idx: GPU slot index to wait for
layer_id: Layer index to wait for
"""
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])
def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async load a CPU block to a ring buffer slot for ALL layers.
Args:
slot_idx: Target GPU slot index
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main):
self.k_cache_gpu[:, slot_idx].copy_(
self.k_cache_cpu[:, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[:, slot_idx].copy_(
self.v_cache_cpu[:, cpu_block_id], non_blocking=True
)
self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)
def wait_slot_all_layers(self, slot_idx: int) -> None:
"""Wait for a slot's loading to complete for ALL layers."""
self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])
# ----- Slot offload methods -----
def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU (all layers).
Args:
slot_idx: Source GPU slot index
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, slot_idx], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, slot_idx], non_blocking=True
)
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""
self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU for one layer.
Args:
slot_idx: Source GPU slot index
layer_id: Layer index to offload
cpu_block_id: Target CPU block ID
"""
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
)
self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)
def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
"""Wait for slot offload to complete for a specific layer."""
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])
# ----- KV access methods for ring buffer -----
def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
Get KV for a single ring buffer slot.
Args:
slot_idx: GPU slot index
layer_id: Layer ID
Returns:
(k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0) # [1, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
return k, v
def get_kv_for_slots(
self,
layer_id: int,
slot_indices: List[int],
) -> Tuple[Tensor, Tensor]:
"""
Get KV for multiple ring buffer slots.
Args:
layer_id: Layer ID
slot_indices: List of GPU slot indices
Returns:
(k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
"""
if not slot_indices:
return None, None
k = self.k_cache_gpu[layer_id, slot_indices]
v = self.v_cache_gpu[layer_id, slot_indices]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
# ========== Three-region GPU Buffer methods ==========
def load_to_compute(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Compute region.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready.record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.compute_ready.record(self.transfer_stream_main)
def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Prefetch region.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready.record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.prefetch_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.prefetch_ready.record(self.transfer_stream_main)
def wait_compute(self) -> None:
"""Wait for Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready)
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Load CPU blocks to Compute region for a single layer only.
This is used for per-layer chunked attention where each layer
independently loads its KV data.
Args:
layer_id: Layer index to load
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy only this layer (not all layers)
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
def wait_compute_layer(self, layer_id: int) -> None:
"""Wait for specific layer's Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready_per_layer[layer_id])
def wait_prefetch(self) -> None:
"""Wait for Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready)
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Load CPU blocks to Prefetch region for a single layer only.
This is used for per-layer chunked attention where each layer
independently loads its KV data.
Args:
layer_id: Layer index to load
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
logger.debug(f"Prefetch load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.prefetch_slots[i]
# Copy only this layer (not all layers)
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
def wait_prefetch_layer(self, layer_id: int) -> None:
"""Wait for specific layer's Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready_per_layer[layer_id])
def swap_compute_prefetch(self) -> None:
"""Swap roles of Compute region and Prefetch region."""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
# ----- Decode slot methods (kept for decode phase) -----
def offload_decode_slot(self, cpu_block_id: int) -> None:
"""
Offload KV from decode slot (slot[0]) to CPU.
Args:
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
@@ -750,61 +788,16 @@ class OffloadEngine:
self.decode_offload_done.record(self.transfer_stream_main)
def wait_decode_offload(self) -> None:
"""Wait for decode slot offload to complete."""
self.compute_stream.wait_event(self.decode_offload_done)
def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of blocks in Compute region.
Args:
layer_id: Layer ID
num_blocks: Number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.compute_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots] # [num_blocks, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of blocks in Prefetch region.
Args:
layer_id: Layer ID
num_blocks: Number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.prefetch_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_decode_slot(
self,
layer_id: int,
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV at specified position in decode slot.
Args:
layer_id: Layer ID
@@ -813,9 +806,9 @@ class OffloadEngine:
Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
@@ -825,10 +818,7 @@ class OffloadEngine:
num_tokens: int,
) -> Tuple[Tensor, Tensor]:
"""
Get accumulated KV in decode slot (positions 0 to num_tokens-1).
Args:
layer_id: Layer ID
@@ -837,35 +827,102 @@ class OffloadEngine:
Returns:
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]
v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v
# ----- Legacy compatibility methods (for decode double-buffering) -----
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses first half of decode_load_slots as 'compute' region.
"""
if not cpu_block_ids:
return
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half]
num_to_load = min(len(cpu_block_ids), len(slots))
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
def wait_compute_layer(self, layer_id: int) -> None:
"""Legacy: Wait for 'compute' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
if self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
Uses second half of decode_load_slots as 'prefetch' region.
"""
if not cpu_block_ids:
return
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots # Fallback if only 1-2 slots
num_to_load = min(len(cpu_block_ids), len(slots))
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)
def wait_prefetch_layer(self, layer_id: int) -> None:
"""Legacy: Wait for 'prefetch' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if slots:
self.wait_slot_layer(slots[0], layer_id)
elif self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)
def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half][:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots
slots = slots[:num_blocks]
return self.get_kv_for_slots(layer_id, slots)

View File

@@ -100,16 +100,19 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with unified ring buffer for chunked prefill.
Ring buffer design:
- Current chunk's KV is written to ring_slot[chunk_idx % N]
- Previous chunks' KV are loaded from CPU using N-1 available slots
- Pipeline: pre-fill slots, then process with overlapped load/compute
For each layer:
1. Current chunk's KV is in k_batched, v_batched (just written by model)
2. Load previous chunks from CPU using available slots (pipeline)
3. Compute attention against previous KV (no causal mask)
4. Compute attention against current KV (causal)
5. Merge all results using online softmax
""" """
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -122,50 +125,32 @@ class Attention(nn.Module):
o_acc = None
lse_acc = None
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
current_chunk_idx = context.current_chunk_idx
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks from previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
# Get write slot for current chunk and available load slots
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
pipeline_depth = len(load_slots)
if pipeline_depth == 0:
# Only 1 slot total, cannot pipeline - use sync loading
o_acc, lse_acc = self._sync_load_previous_chunks(
q_batched, cpu_block_table, offload_engine
)
else:
# Use ring buffer pipeline
o_acc, lse_acc = self._ring_buffer_pipeline_load(
q_batched, cpu_block_table, load_slots, offload_engine
)
# Compute attention against current chunk's KV (with causal mask)
current_o, current_lse = flash_attn_with_lse(
@@ -185,6 +170,91 @@ class Attention(nn.Module):
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
def _sync_load_previous_chunks(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
offload_engine,
):
"""Synchronous loading fallback when pipeline_depth=0."""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
o_acc, lse_acc = None, None
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Load to slot 0 (single slot)
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(0, self.layer_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
def _ring_buffer_pipeline_load(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
):
"""
Ring buffer synchronous loading for previous chunks.
For correctness, we use synchronous loading:
- Load one block at a time
- Wait for transfer, compute attention, then load next
This ensures no data races between transfer and computation.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
pipeline_depth = len(load_slots)
o_acc, lse_acc = None, None
# Process blocks one by one (synchronous)
for block_idx in range(num_blocks):
# Determine which slot to use (cycle through load_slots)
slot_idx = load_slots[block_idx % pipeline_depth]
cpu_block_id = cpu_block_table[block_idx]
# Load block to slot (async)
offload_engine.load_to_slot_layer(slot_idx, self.layer_id, cpu_block_id)
# Wait for transfer to complete
offload_engine.wait_slot_layer(slot_idx, self.layer_id)
# Get KV from slot and compute attention
prev_k, prev_v = offload_engine.get_kv_for_slot(slot_idx, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
def _chunked_decode_attention(
self,
q: torch.Tensor,
@@ -193,20 +263,24 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with double-buffering using decode_load_slots.
Decode uses:
- decode_slot (slot[0]): writes new token's KV
- decode_load_slots (slots[1:]): load previous chunks from CPU
Pipeline design:
- First half of decode_load_slots: 'compute' buffer
- Second half: 'prefetch' buffer
- Double-buffer between them for async overlap
Timeline:
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│Load C0→buf0 │ │Load C1→buf1 │ │Load C2→buf0 │ ...
└─────────────┘ └─────────────┘ └─────────────┘
       ↘               ↘               ↘
  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
  │  Attn(C0)   │ │  Attn(C1)   │ │  Attn(C2)   │
  └─────────────┘ └─────────────┘ └─────────────┘
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

View File

@@ -32,6 +32,8 @@ class Context:
# Starting position within block where decode tokens began (for accumulated token tracking)
# Used when batching decode offloads - we need to attend to all accumulated tokens
decode_start_pos_in_block: int = 0
# Current chunk index for ring buffer pipeline (prefill only)
current_chunk_idx: int = 0
_CONTEXT = Context()
@@ -57,6 +59,7 @@ def set_context(
chunked_seq=None,
decode_pos_in_block=0,
decode_start_pos_in_block=0,
current_chunk_idx=0,
):
global _CONTEXT
_CONTEXT = Context(
@@ -75,6 +78,7 @@ def set_context(
chunked_seq=chunked_seq,
decode_pos_in_block=decode_pos_in_block,
decode_start_pos_in_block=decode_start_pos_in_block,
current_chunk_idx=current_chunk_idx,
)
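For reference, a toy round-trip of the new field (a simplified stand-in for the real `Context`/`set_context`, just to show how `current_chunk_idx` reaches the attention layer):

```python
from dataclasses import dataclass

@dataclass
class Context:
    is_chunked_prefill: bool = False
    current_chunk_idx: int = 0        # chunk index for the ring buffer pipeline

_CONTEXT = Context()

def set_context(is_chunked_prefill=False, current_chunk_idx=0):
    global _CONTEXT
    _CONTEXT = Context(is_chunked_prefill, current_chunk_idx)

# Model runner side: announce which prefill chunk is being processed.
set_context(is_chunked_prefill=True, current_chunk_idx=3)

# Attention side: derive the ring buffer write slot from the context.
num_ring_slots = 4
print(_CONTEXT.current_chunk_idx % num_ring_slots)   # -> 3
```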