[feat] Needs to be optimized with async prefetch.
@@ -44,74 +44,101 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
When `enable_cpu_offload=True`, KV cache is stored on CPU with a small GPU buffer for computation. This enables long-context inference with limited GPU memory.

### Three-Region GPU Buffer Design
### Unified Ring Buffer Design

```
GPU Slots: [0]      [1, 2, 3]    [4, 5]
            ↑           ↑           ↑
          decode     compute     prefetch
         (1 slot)   (N slots)   (M slots)
GPU Slots: [0] [1] [2] [3] [4] ...
           ←────────────────────────────→
           All slots as ring buffer

- Decode slot: New token's KV written here during decode
- Compute region: Load CPU blocks for current chunk computation
- Prefetch region: Async load next chunk while computing current
Prefill: ALL slots cycle as ring buffer [slot = chunk_idx % N]
Decode: slot[0] = decode_slot, slots[1:] = load slots for previous chunks
```

**File**: `nanovllm/kvcache/offload_engine.py`

Key attributes:
- `num_ring_slots`: Total GPU slots (= num_gpu_blocks)
- `ring_slots`: List of all GPU slot indices [0, 1, 2, ...]
- `decode_slot = 0`: Fixed slot for decode KV writes
- `compute_slots`: List of GPU slots for compute region
- `prefetch_slots`: List of GPU slots for prefetch region
- `decode_load_slots`: Slots[1:] for loading previous chunks during decode
- `k_cache_gpu/v_cache_gpu`: Shape `[num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]`
- `k_cache_cpu/v_cache_cpu`: Shape `[num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]` (pinned memory)
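To make the slot arithmetic concrete, here is a minimal sketch of the layout described above. Only `num_gpu_blocks` and the `chunk_idx % N` rule come from the design above; the helper names are illustrative, not the engine's actual API.

```python
# Sketch of the unified ring buffer slot layout (helper names are hypothetical).

def ring_slots(num_gpu_blocks: int) -> list[int]:
    # Prefill: every GPU slot participates in the ring buffer.
    return list(range(num_gpu_blocks))

def prefill_write_slot(chunk_idx: int, num_gpu_blocks: int) -> int:
    # Prefill chunk K writes its KV into slot K % N.
    return chunk_idx % num_gpu_blocks

def decode_slot_layout(num_gpu_blocks: int) -> tuple[int, list[int]]:
    # Decode: slot 0 is reserved for the new token's KV, slots[1:] load old chunks.
    return 0, list(range(1, num_gpu_blocks))

assert prefill_write_slot(7, num_gpu_blocks=4) == 3
assert decode_slot_layout(4) == (0, [1, 2, 3])
```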
### Per-Layer Loading (Critical Design)

**Problem solved**: Original design had layer 0 load ALL layers' KV at once. When layer 0 processed chunk 1, it overwrote chunk 0's data before layer 1+ could read it.

**Solution**: Each layer independently loads only its own KV data:
Key methods:
```python
# Per-layer methods in OffloadEngine
load_to_compute_layer(layer_id, cpu_block_ids)   # Load single layer to compute region
wait_compute_layer(layer_id)                     # Wait for layer's transfer
load_to_prefetch_layer(layer_id, cpu_block_ids)  # Load single layer to prefetch region
wait_prefetch_layer(layer_id)                    # Wait for layer's prefetch
# Prefill: get write slot and load slots
get_write_slot_for_prefill(chunk_idx)            # Returns chunk_idx % num_ring_slots
get_load_slots_for_prefill(write_slot_idx)       # Returns all slots except write_slot

# Decode: get load slots (excludes decode_slot)
get_load_slots_for_decode()                      # Returns slots[1:]

# Per-slot per-layer operations
load_to_slot_layer(slot_idx, layer_id, cpu_block_id)  # Async load single block
wait_slot_layer(slot_idx, layer_id)                   # Wait for layer's transfer
offload_slot_to_cpu(slot_idx, cpu_block_id)           # Async offload to CPU
```

### Chunked Prefill Flow
### Per-Slot Per-Layer Events (Critical Design)

Each slot has per-layer CUDA events for fine-grained synchronization:
- `ring_slot_ready[slot_idx][layer_id]`: H2D transfer completion
- `ring_slot_offload_done[slot_idx][layer_id]`: D2H transfer completion

This enables:
1. Overlapped H2D transfer with attention computation
2. Each layer independently waits for its own data
3. Pipeline depth = N-1 for prefill (N slots, 1 for writing)
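The event discipline behind this is roughly the following (a hedged sketch: `transfer_stream` / `compute_stream` mirror the engine's attribute names, everything else here is illustrative):

```python
import torch

# One CUDA event per (slot, layer): recorded on the transfer stream after the
# H2D copy, waited on by the compute stream before attention reads that slot.
num_slots, num_layers = 4, 2
transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()
ready = [[torch.cuda.Event() for _ in range(num_layers)] for _ in range(num_slots)]

def async_load(slot: int, layer: int, dst: torch.Tensor, src: torch.Tensor) -> None:
    with torch.cuda.stream(transfer_stream):
        dst.copy_(src, non_blocking=True)            # H2D copy off the compute stream
        ready[slot][layer].record(transfer_stream)   # mark this (slot, layer) as ready

def wait_ready(slot: int, layer: int) -> None:
    compute_stream.wait_event(ready[slot][layer])    # each layer waits only for its own data
```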
### Chunked Prefill Flow (Ring Buffer Pipeline)

**File**: `nanovllm/layers/attention.py` - `_chunked_prefill_attention()`

```
For each prefill chunk:
1. Current chunk's KV is written to GPU (compute region slots)
2. Load previous chunks' KV from CPU to prefetch region
For prefill chunk K:
1. Current chunk's KV written to ring_slot[K % N]
2. Load previous chunks from CPU using N-1 available slots (pipeline)
3. Compute attention against previous KV (no causal mask)
4. Compute attention against current KV (causal mask)
5. Merge results using online softmax (LSE)
6. Offload current chunk's KV to CPU
6. Offload current slot to CPU

Pipeline Timeline (with 4 slots, processing chunk 3):
write_slot = 3, load_slots = [0, 1, 2]

┌─────────────┐  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
│ Load B0→S0  │  │ Load B1→S1  │  │ Load B2→S2  │  │   (wait)    │
└─────────────┘  └─────────────┘  └─────────────┘  └─────────────┘
        ↘                ↘                ↘
  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
  │  Attn(B0)   │  │  Attn(B1)   │  │  Attn(B2)   │
  └─────────────┘  └─────────────┘  └─────────────┘
```

**Important**: Prefill uses ONLY prefetch region to avoid conflict with current chunk's KV being written to compute region.
**Key**: Write slot cycles through ALL slots, load slots = all except write slot.
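Put together, the per-layer loop over previous chunks looks roughly like this (a sketch built from the engine methods listed earlier; `attend` and `merge` stand in for `flash_attn_with_lse` / `merge_attention_outputs`):

```python
def attend_to_previous_chunks(layer_id, q, cpu_blocks, write_slot, engine, attend, merge):
    # Sketch of the ring-buffer prefill loop for a single layer.
    load_slots = engine.get_load_slots_for_prefill(write_slot)      # N-1 usable slots
    o, lse = None, None
    for i, cpu_block_id in enumerate(cpu_blocks):
        slot = load_slots[i % len(load_slots)]                      # cycle through load slots
        engine.load_to_slot_layer(slot, layer_id, cpu_block_id)     # async H2D for this layer
        engine.wait_slot_layer(slot, layer_id)                      # sync before reading the slot
        k, v = engine.get_kv_for_slot(slot, layer_id)
        o_i, lse_i = attend(q, k, v, causal=False)                  # no causal mask vs. old chunks
        o, lse = (o_i, lse_i) if o is None else merge(o, lse, o_i, lse_i)
    return o, lse
```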
### Chunked Decode Flow (Double Buffering)

**File**: `nanovllm/layers/attention.py` - `_chunked_decode_attention()`

Decode uses legacy double-buffering with `decode_load_slots`:
- First half of decode_load_slots: 'compute' buffer
- Second half: 'prefetch' buffer

```
Timeline (async double buffering):
Timeline:
         ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
Load:    │C0 → Compute │  │C1 → Prefetch│  │C2 → Compute │
Load:    │C0 → buf0    │  │C1 → buf1    │  │C2 → buf0    │
         └─────────────┘  └─────────────┘  └─────────────┘
                 ↘                ↘                ↘
Compute:        [C0]             [C1]             [C2]

1. Pre-load first chunk to compute region
2. Wait for current buffer, trigger async prefetch of next chunk to OTHER buffer
1. Pre-load first chunk to compute buffer
2. Wait for current buffer, trigger async prefetch to OTHER buffer
3. Compute attention, merge results
4. Swap buffers, repeat
5. Finally attend to decode slot (new token's KV)
5. Finally attend to decode_slot (new token's KV)
```
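In loop form the double buffering alternates between the two halves of `decode_load_slots`; a minimal sketch under that assumption (the `load`/`wait`/`attend`/`merge` callables are stand-ins, not the repo's API):

```python
def decode_attend_chunks(chunks, load, wait, attend, merge):
    # Sketch: prefetch chunk i+1 into the other buffer while attending to chunk i.
    bufs = ("buf0", "buf1")
    o, lse = None, None
    load(bufs[0], chunks[0])                          # pre-load first chunk
    for i, chunk in enumerate(chunks):
        cur, nxt = bufs[i % 2], bufs[(i + 1) % 2]
        wait(cur)                                     # wait for current buffer's transfer
        if i + 1 < len(chunks):
            load(nxt, chunks[i + 1])                  # async prefetch into the OTHER buffer
        o_i, lse_i = attend(cur, chunk)
        o, lse = (o_i, lse_i) if o is None else merge(o, lse, o_i, lse_i)
    return o, lse                                     # caller still attends to decode_slot last
```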
### HybridKVCacheManager

@@ -120,7 +147,7 @@ Compute: [C0] [C1] [C2]

Manages both GPU and CPU blocks:
- `allocate()`: Allocate GPU block first, fallback to CPU
- `allocate_cpu_only()`: Force CPU allocation (for chunked offload mode)
- `allocate_cpu_only()`: Force CPU allocation (for ring buffer mode)
- `get_all_cpu_blocks(seq)`: Get all CPU block IDs for a sequence
- `get_prefilled_cpu_blocks(seq)`: Get CPU blocks from previous chunks
- `get_write_slot_for_chunked_offload(seq)`: Get GPU slot for writing new KV (returns decode_slot)
@@ -136,9 +163,7 @@ def merge_attention_outputs(o1, lse1, o2, lse2):
# Uses LSE to correctly weight and combine partial attention outputs
```
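The merge itself is the standard log-sum-exp combination of two partial softmax-attention results; a hedged reference sketch (tensor layout assumed to be `[batch, tokens, heads]` for `lse` and `[batch, tokens, heads, dim]` for `o`, which may differ from the repo's actual layout):

```python
import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    # Combine two partial attention outputs computed over disjoint key sets.
    # Each o_i is normalized by its own softmax denominator, whose log is lse_i;
    # reweight by exp(lse_i - lse) so the merged result uses the joint denominator.
    lse = torch.logaddexp(lse1, lse2)            # log(exp(lse1) + exp(lse2))
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)     # weight of the first partial result
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse
```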
### Ring Buffer Design (Future Optimization)
### Pipeline Depth

Current double-buffering limits pipeline depth. Planned improvement:
- Unified ring buffer using all GPU slots (except decode)
- Per-slot per-layer CUDA events for fine-grained sync
- Deeper pipeline: prefetch N-1 blocks ahead (vs 1 chunk)
- **Prefill**: Pipeline depth = N-1 (where N = num_gpu_blocks)
- **Decode**: Pipeline depth = (N-1)/2 (double buffering within decode_load_slots)

@@ -41,8 +41,8 @@ def main():
max_model_len=128 * 1024,
max_num_batched_tokens=128 * 1024,
enable_cpu_offload=True,
num_gpu_blocks=6,
num_prefetch_blocks=2,
num_gpu_blocks=120,
num_prefetch_blocks=4,
)

# Warmup
@@ -54,12 +54,12 @@ def main():
# bench_prefill(llm, num_seqs=1, input_len=1024)
# bench_prefill(llm, num_seqs=1, input_len=2048)
# bench_prefill(llm, num_seqs=1, input_len=4096)
bench_prefill(llm, num_seqs=1, input_len=64 * 1024)
bench_prefill(llm, num_seqs=1, input_len=16 * 1024)

print("=" * 60)
print("Decode Benchmark (CPU Offload)")
print("=" * 60)
bench_decode(llm, num_seqs=1, input_len=64 * 1024, max_output_len=128)
bench_decode(llm, num_seqs=1, input_len=16 * 1024, max_output_len=128)
# bench_decode(llm, num_seqs=1, input_len=2048, max_output_len=128)

@@ -14,7 +14,7 @@ class Config:
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
kvcache_block_size: int = 256
kvcache_block_size: int = 4096
num_kvcache_blocks: int = -1

# CPU Offload configuration

@@ -630,29 +630,31 @@ class ModelRunner:

def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with three-region GPU buffer (CPU is primary storage).
Run prefill with unified ring buffer (CPU is primary storage).

Flow:
1. All blocks are allocated to CPU (primary storage)
2. Process tokens in chunks using Compute region GPU buffer
3. After each chunk, offload from Compute region to CPU
4. Prefetch region is used to load previous KV (if any)
2. Each chunk writes KV to ring buffer slot[chunk_idx % N]
3. After each chunk, offload from ring buffer slot to CPU
4. All N-1 other slots are used to load previous chunks for attention
"""
import sys

assert len(seqs) == 1, "Three-region prefill only supports single sequence"
assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
seq = seqs[0]

offload_engine = self.kvcache_manager.offload_engine
compute_size = offload_engine.num_compute_blocks
tokens_per_chunk = compute_size * self.block_size
# Each chunk uses 1 ring buffer slot = 1 block
tokens_per_chunk = self.block_size

total_tokens = len(seq)
print(f"[Three-region Prefill] Starting: {total_tokens} tokens, "
f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens",
num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
f"total_chunks={num_chunks}",
file=sys.stderr)

chunk_num = 0
chunk_idx = 0
logits = None
processed_tokens = 0

@@ -660,27 +662,22 @@ class ModelRunner:
cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)

while processed_tokens < total_tokens:
chunk_num += 1
chunk_start = processed_tokens
chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)
chunk_tokens = chunk_end - chunk_start

# Calculate which CPU blocks this chunk covers
start_block_idx = chunk_start // self.block_size
end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
num_blocks = end_block_idx - start_block_idx
# Get ring buffer slot for this chunk
write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)

print(f"[Three-region Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
f"blocks {start_block_idx}-{end_block_idx-1}, "
f"compute_slots={offload_engine.compute_slots[:num_blocks]}",
# CPU block index for this chunk
block_idx = chunk_idx

print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
f"write_slot={write_slot}",
file=sys.stderr)

# Get GPU slots for this chunk (using Compute region)
gpu_slots = offload_engine.compute_slots[:num_blocks]

# Prepare inputs
input_ids, positions = self._prepare_chunked_offload_chunk(
seq, chunk_start, chunk_end, gpu_slots, start_block_idx
seq, chunk_start, chunk_end, write_slot, block_idx, chunk_idx
)

if input_ids.numel() == 0:
@@ -690,24 +687,27 @@ class ModelRunner:
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()

# Mark blocks as prefilled
for i in range(start_block_idx, min(end_block_idx, len(seq.block_table))):
logical_id = seq.block_table[i]
# Mark block as prefilled
if block_idx < len(seq.block_table):
logical_id = seq.block_table[block_idx]
self.kvcache_manager.prefilled_blocks.add(logical_id)

# Offload this chunk from Compute region to CPU (async)
chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
offload_engine.offload_compute_to_cpu(chunk_cpu_blocks)
# Offload this chunk's ring buffer slot to CPU (async)
if block_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[block_idx]
offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)

# Wait for offload to complete before next chunk
offload_engine.wait_all_offload_done()
# (slot will be reused after N chunks)
offload_engine.wait_slot_offload(write_slot)

processed_tokens = chunk_end
chunk_idx += 1

# Wait for all offloads to complete
offload_engine.wait_all_offload_done()

print(f"[Three-region Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)

# Sample from last logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
@@ -723,34 +723,24 @@ class ModelRunner:
seq: Sequence,
chunk_start: int,
chunk_end: int,
gpu_slots: list[int],
start_block_idx: int,
write_slot: int,
block_idx: int,
chunk_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare inputs for a chunked offload prefill chunk."""
"""Prepare inputs for a chunked offload prefill chunk (ring buffer design)."""
# Input tokens for this chunk
input_ids = seq[chunk_start:chunk_end]
positions = list(range(chunk_start, chunk_end))

# Create slot mapping pointing to GPU slots
# Create slot mapping pointing to the single write_slot
slot_mapping = []
num_tokens = chunk_end - chunk_start

token_idx = 0
for i, gpu_slot in enumerate(gpu_slots):
block_idx = start_block_idx + i
block_start = block_idx * self.block_size
block_end = min(block_start + self.block_size, len(seq))

# How many tokens in this block for this chunk
overlap_start = max(chunk_start, block_start)
overlap_end = min(chunk_end, block_end)

for pos in range(overlap_start, overlap_end):
for pos in range(chunk_start, chunk_end):
pos_in_block = pos % self.block_size
slot = gpu_slot * self.block_size + pos_in_block
slot = write_slot * self.block_size + pos_in_block
slot_mapping.append(slot)

# Convert to tensors
num_tokens = chunk_end - chunk_start
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -770,21 +760,23 @@ class ModelRunner:
is_chunked_prefill=True,
kvcache_manager=self.kvcache_manager,
chunked_seq=seq,
current_chunk_idx=chunk_idx,  # Pass chunk index for ring buffer pipeline
)

return input_ids, positions

def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with three-region GPU buffer.
Run decode with ring buffer (CPU is primary storage).

All KV is on CPU. Uses Decode region to write new KV, Compute/Prefetch region to load KV chunks.
New token's KV is written to Decode region (slot 0) then offloaded to CPU only when block is full.
All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
New token's KV is written to decode_slot then offloaded to CPU only when block is full.

Key: Decode region is never overwritten by Compute/Prefetch, dedicated to writing new KV.
Key: decode_slot is dedicated to writing new KV, never used for loading.
Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
"""
assert len(seqs) == 1, "Three-region decode only supports single sequence"
assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
seq = seqs[0]

offload_engine = self.kvcache_manager.offload_engine
@@ -12,7 +12,7 @@ class SequenceStatus(Enum):


class Sequence:
block_size = 256
block_size = 4096
counter = count()

def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
@@ -95,16 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU)
cpu_primary: If True, use CPU as primary storage with three-region GPU buffer.
cpu_primary: If True, use CPU as primary storage with ring buffer GPU design.
If False, use GPU as primary with CPU as overflow (legacy mode).
num_prefetch_blocks: Number of prefetch blocks for three-region GPU buffer design
num_prefetch_blocks: Number of blocks for ring buffer pipeline (deprecated, ring_slots = num_gpu_slots)
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
self.cpu_primary = cpu_primary  # Three-region mode flag
self.num_prefetch_blocks = num_prefetch_blocks  # Three-region design parameter
self.cpu_primary = cpu_primary  # Ring buffer mode flag
self.num_prefetch_blocks = num_prefetch_blocks  # Ring buffer design parameter (deprecated)

# Eviction policy
self.policy = policy or LRUPolicy()
@@ -341,7 +341,7 @@ class HybridKVCacheManager(KVCacheManager):
"""
assert not seq.block_table, "Sequence already has blocks"

# Three-region mode: all blocks are allocated to CPU
# Ring buffer mode: all blocks are allocated to CPU
if self.cpu_primary:
return self.allocate_cpu_only(seq)

@@ -471,7 +471,7 @@ class HybridKVCacheManager(KVCacheManager):
block.token_ids = []

if self.cpu_primary:
# Three-region mode: new block allocated to CPU
# Ring buffer mode: new block allocated to CPU
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks for decode")
cpu_block_id = self.free_cpu_blocks.popleft()
@@ -1025,14 +1025,14 @@ class HybridKVCacheManager(KVCacheManager):
break
return pos

# ========== Three-region double buffering support ==========
# ========== Ring Buffer CPU-primary support ==========

def allocate_cpu_only(self, seq: Sequence) -> None:
"""
Allocate CPU blocks for sequence (for three-region mode).
Allocate CPU blocks for sequence (for ring buffer mode).

Unlike allocate(), here all blocks are allocated to CPU,
GPU is only used as working buffer.
GPU is only used as ring buffer for computation.

Args:
seq: Sequence to allocate
@@ -1092,10 +1092,10 @@ class HybridKVCacheManager(KVCacheManager):
cpu_blocks.append(block.cpu_block_id)
else:
# If block is on GPU, it should have a corresponding CPU block
# In three-region mode, all data ultimately resides on CPU
# In ring buffer mode, all data ultimately resides on CPU
raise RuntimeError(
f"Block {logical_id} not on CPU (location={block.location}). "
f"In three-region mode, all blocks should be on CPU."
f"In ring buffer mode, all blocks should be on CPU."
)
return cpu_blocks

@@ -1171,8 +1171,8 @@ class HybridKVCacheManager(KVCacheManager):
"""
Get GPU slot for writing new KV during chunked offload decode.

In three-region design, always use Decode region (slot 0) to write new KV.
This avoids conflicts with Compute/Prefetch region loading operations.
In ring buffer design, always use decode_slot (slot[0]) to write new KV.
This avoids conflicts with loading operations which use slots[1:].

Args:
seq: Sequence
@@ -65,34 +65,30 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim

# ========== Three-region GPU Buffer configuration ==========
# ========== Unified Ring Buffer configuration ==========
# Constraint checks
assert num_gpu_blocks >= 3, \
f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
assert num_prefetch_blocks >= 1, \
f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"
assert num_gpu_blocks >= 2, \
f"Need at least 2 GPU blocks for ring buffer, got {num_gpu_blocks}"

# Three-region configuration
# Decode region: [0] - Fixed 1 block for writing new KV
# Unified Ring Buffer: all slots cycle for prefill
# Prefill: use ALL slots as ring buffer (slot[chunk_idx % N])
# Decode: slot[0] as decode_slot, slots[1:] for loading previous chunks
self.num_ring_slots = num_gpu_blocks
self.ring_slots = list(range(num_gpu_blocks))

# Decode phase uses slot[0] for writing new token's KV
self.decode_slot = 0
# Decode phase uses slots[1:] for loading previous chunks from CPU
self.decode_load_slots = list(range(1, num_gpu_blocks))
self.num_decode_load_slots = len(self.decode_load_slots)

# Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
compute_start = 1
compute_end = num_gpu_blocks - num_prefetch_blocks
self.compute_slots = list(range(compute_start, compute_end))
self.num_compute_blocks = len(self.compute_slots)

# Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
prefetch_start = compute_end
self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
# Keep num_prefetch_blocks for compatibility (used as chunk size for loading)
self.num_prefetch_blocks = num_prefetch_blocks

self.num_gpu_slots = num_gpu_blocks  # alias

logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
logger.info(f"Unified Ring Buffer: {self.num_ring_slots} slots total")
logger.info(f"  Prefill: all slots as ring buffer [0..{num_gpu_blocks-1}]")
logger.info(f"  Decode: slot[0] as decode_slot, slots[1..{num_gpu_blocks-1}] for loading")

# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
@@ -134,18 +130,27 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0

# ========== Three-region dedicated stream and events ==========
# ========== Ring Buffer dedicated stream and events ==========
self.transfer_stream_main = torch.cuda.Stream()  # Main transfer stream

# Sync events - three-region loading completion
self.compute_ready = torch.cuda.Event()
self.prefetch_ready = torch.cuda.Event()
# Decode offload event
self.decode_offload_done = torch.cuda.Event()

# ========== Per-layer events for chunked attention ==========
# Each layer has its own event for synchronization
self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
self.prefetch_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
# ========== Per-slot Per-layer events for ring buffer ==========
# ring_slot_ready[slot_idx][layer_id] = CUDA Event for H2D completion
# ring_slot_offload_done[slot_idx][layer_id] = CUDA Event for D2H completion
self.ring_slot_ready = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]
self.ring_slot_offload_done = [
[torch.cuda.Event() for _ in range(num_layers)]
for _ in range(self.num_ring_slots)
]

# Per-slot events for all-layer operations (used in some legacy paths)
self.ring_slot_all_layers_ready = [torch.cuda.Event() for _ in range(self.num_ring_slots)]
self.ring_slot_all_layers_offload_done = [torch.cuda.Event() for _ in range(self.num_ring_slots)]

# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -560,7 +565,7 @@ class OffloadEngine:
f"  kv_heads={self.num_kv_heads},\n"
f"  head_dim={self.head_dim},\n"
f"  dtype={self.dtype},\n"
f"  three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
f"  ring_buffer: {self.num_ring_slots} slots, decode_slot={self.decode_slot}, decode_load_slots={self.decode_load_slots},\n"
f"  gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f"  cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
@@ -570,174 +575,207 @@ class OffloadEngine:
"""Wait for all offload operations to complete."""
self.transfer_stream_main.synchronize()

# ========== Unified Ring Buffer methods ==========

# ----- Prefill: Ring Buffer slot management -----

def get_write_slot_for_prefill(self, chunk_idx: int) -> int:
"""
Get ring buffer slot for writing prefill chunk.

For prefill, ALL slots are used as ring buffer, cycling through.

Args:
chunk_idx: Current chunk index (0, 1, 2, ...)

Returns:
GPU slot index for writing
"""
return chunk_idx % self.num_ring_slots

def get_load_slots_for_prefill(self, write_slot_idx: int) -> List[int]:
"""
Get available slots for loading previous chunks during prefill.

Excludes the current write slot to avoid conflict.

Args:
write_slot_idx: Current write slot index

Returns:
List of slot indices available for loading (N-1 slots)
"""
return [i for i in range(self.num_ring_slots) if i != write_slot_idx]

# ----- Decode: slot management -----

def get_load_slots_for_decode(self) -> List[int]:
"""
Get slots available for loading during decode.

Excludes decode_slot (slot[0]) since it's used for writing new token's KV.

Returns:
List of slot indices for loading (slots[1:])
"""
return self.decode_load_slots

# ----- Per-slot Per-layer loading methods -----

def load_to_slot_layer(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async load a single CPU block to a ring buffer slot for one layer.

This is the core building block for ring buffer pipelining.

Args:
slot_idx: Target GPU slot index
layer_id: Layer index to load
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")

with torch.cuda.stream(self.transfer_stream_main):
self.k_cache_gpu[layer_id, slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[layer_id, slot_idx].copy_(
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)

def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
"""
Wait for a slot's loading to complete for a specific layer.

Args:
slot_idx: GPU slot index to wait for
layer_id: Layer index to wait for
"""
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx][layer_id])

def load_to_slot_all_layers(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async load a CPU block to a ring buffer slot for ALL layers.

Args:
slot_idx: Target GPU slot index
cpu_block_id: Source CPU block ID
"""
logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")

with torch.cuda.stream(self.transfer_stream_main):
self.k_cache_gpu[:, slot_idx].copy_(
self.k_cache_cpu[:, cpu_block_id], non_blocking=True
)
self.v_cache_gpu[:, slot_idx].copy_(
self.v_cache_cpu[:, cpu_block_id], non_blocking=True
)
self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)

def wait_slot_all_layers(self, slot_idx: int) -> None:
"""Wait for a slot's loading to complete for ALL layers."""
self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])

# ----- Slot offload methods -----

def offload_slot_to_cpu(self, slot_idx: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU (all layers).

Args:
slot_idx: Source GPU slot index
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")

with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, slot_idx], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, slot_idx], non_blocking=True
)
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)

def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""
self.compute_stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])

def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
"""
Async offload a ring buffer slot to CPU for one layer.

Args:
slot_idx: Source GPU slot index
layer_id: Layer index to offload
cpu_block_id: Target CPU block ID
"""
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
)
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
self.v_cache_gpu[layer_id, slot_idx], non_blocking=True
)
self.ring_slot_offload_done[slot_idx][layer_id].record(self.transfer_stream_main)

def wait_slot_offload_layer(self, slot_idx: int, layer_id: int) -> None:
"""Wait for slot offload to complete for a specific layer."""
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx][layer_id])

# ----- KV access methods for ring buffer -----

def get_kv_for_slot(self, slot_idx: int, layer_id: int) -> Tuple[Tensor, Tensor]:
"""
Get KV for a single ring buffer slot.

Args:
slot_idx: GPU slot index
layer_id: Layer ID

Returns:
(k_cache, v_cache), shape: [1, block_size, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, slot_idx].unsqueeze(0)  # [1, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slot_idx].unsqueeze(0)
return k, v

def get_kv_for_slots(
self,
layer_id: int,
gpu_slots: List[int],
slot_indices: List[int],
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified GPU slots.
Get KV for multiple ring buffer slots.

Args:
layer_id: Layer ID
gpu_slots: List of GPU slot IDs
slot_indices: List of GPU slot indices

Returns:
(k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
"""
if not gpu_slots:
if not slot_indices:
return None, None
k = self.k_cache_gpu[layer_id, gpu_slots]
v = self.v_cache_gpu[layer_id, gpu_slots]
k = self.k_cache_gpu[layer_id, slot_indices]
v = self.v_cache_gpu[layer_id, slot_indices]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v

# ========== Three-region GPU Buffer methods ==========

def load_to_compute(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Compute region.

Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready.record(self.transfer_stream_main)
return

num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")

with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.compute_ready.record(self.transfer_stream_main)

def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Prefetch region.

Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready.record(self.transfer_stream_main)
return

num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")

with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.prefetch_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.prefetch_ready.record(self.transfer_stream_main)

def wait_compute(self) -> None:
"""Wait for Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready)

def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Load CPU blocks to Compute region for a single layer only.

This is used for per-layer chunked attention where each layer
independently loads its KV data.

Args:
layer_id: Layer index to load
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)
return

num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")

with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy only this layer (not all layers)
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.compute_ready_per_layer[layer_id].record(self.transfer_stream_main)

def wait_compute_layer(self, layer_id: int) -> None:
"""Wait for specific layer's Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready_per_layer[layer_id])

def wait_prefetch(self) -> None:
"""Wait for Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready)

def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Load CPU blocks to Prefetch region for a single layer only.

This is used for per-layer chunked attention where each layer
independently loads its KV data.

Args:
layer_id: Layer index to load
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)
return

num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
logger.debug(f"Prefetch load (layer {layer_id}): CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")

with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.prefetch_slots[i]
# Copy only this layer (not all layers)
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.prefetch_ready_per_layer[layer_id].record(self.transfer_stream_main)

def wait_prefetch_layer(self, layer_id: int) -> None:
"""Wait for specific layer's Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready_per_layer[layer_id])

def swap_compute_prefetch(self) -> None:
"""Swap roles of Compute region and Prefetch region."""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
# ----- Decode slot methods (kept for decode phase) -----

def offload_decode_slot(self, cpu_block_id: int) -> None:
"""
Offload KV from Decode region to CPU.
Offload KV from decode slot (slot[0]) to CPU.

Args:
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
logger.debug(f"Decode offload: GPU slot[{self.decode_slot}] -> CPU[{cpu_block_id}]")

with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
@@ -750,61 +788,16 @@ class OffloadEngine:
self.decode_offload_done.record(self.transfer_stream_main)

def wait_decode_offload(self) -> None:
"""Wait for Decode region offload to complete."""
"""Wait for decode slot offload to complete."""
self.compute_stream.wait_event(self.decode_offload_done)

def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of blocks in Compute region.

Args:
layer_id: Layer ID
num_blocks: Number of blocks needed

Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.compute_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots]  # [num_blocks, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v

def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of blocks in Prefetch region.

Args:
layer_id: Layer ID
num_blocks: Number of blocks needed

Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.prefetch_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v

def get_kv_for_decode_slot(
self,
layer_id: int,
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV at specified position in Decode region (for new token during decode).
Get KV at specified position in decode slot.

Args:
layer_id: Layer ID
@@ -813,9 +806,9 @@ class OffloadEngine:
Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]  # [1, heads, dim]
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
k = k.unsqueeze(0)  # [1, 1, heads, dim]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v

@@ -825,10 +818,7 @@ class OffloadEngine:
num_tokens: int,
) -> Tuple[Tensor, Tensor]:
"""
Get accumulated KV in Decode region (all tokens from position 0 to num_tokens-1).

Used when batching decode offloads - attend to all accumulated tokens,
not just the current one.
Get accumulated KV in decode slot (positions 0 to num_tokens-1).

Args:
layer_id: Layer ID
@@ -837,35 +827,102 @@ class OffloadEngine:
Returns:
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]  # [num_tokens, heads, dim]
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]
v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
k = k.unsqueeze(0)  # [1, num_tokens, heads, dim]
k = k.unsqueeze(0)
v = v.unsqueeze(0)
return k, v

def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
"""
Offload KV from Compute region to CPU.
# ----- Legacy compatibility methods (for decode double-buffering) -----

Args:
cpu_block_ids: Target CPU block IDs list
def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.

Uses first half of decode_load_slots as 'compute' region.
"""
if not cpu_block_ids:
return

num_to_offload = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half]
num_to_load = min(len(cpu_block_ids), len(slots))

with torch.cuda.stream(self.transfer_stream_main):
# Wait for compute to complete
self.transfer_stream_main.wait_stream(self.compute_stream)

for i in range(num_to_offload):
gpu_slot = self.compute_slots[i]
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
self.k_cache_cpu[:, cpu_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_cpu[:, cpu_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)

def wait_compute_layer(self, layer_id: int) -> None:
"""Legacy: Wait for 'compute' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
if self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)

def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
"""
Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.

Uses second half of decode_load_slots as 'prefetch' region.
"""
if not cpu_block_ids:
return

half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots  # Fallback if only 1-2 slots
num_to_load = min(len(cpu_block_ids), len(slots))

with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = slots[i]
self.k_cache_gpu[layer_id, gpu_slot].copy_(
self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
)
self.v_cache_gpu[layer_id, gpu_slot].copy_(
self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
)
if num_to_load > 0:
self.ring_slot_ready[slots[0]][layer_id].record(self.transfer_stream_main)

def wait_prefetch_layer(self, layer_id: int) -> None:
"""Legacy: Wait for 'prefetch' region loading."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if slots:
self.wait_slot_layer(slots[0], layer_id)
elif self.decode_load_slots:
self.wait_slot_layer(self.decode_load_slots[0], layer_id)

def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[:half][:num_blocks]
return self.get_kv_for_slots(layer_id, slots)

def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
half = max(1, len(self.decode_load_slots) // 2)
slots = self.decode_load_slots[half:]
if not slots:
slots = self.decode_load_slots
slots = slots[:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
@@ -100,16 +100,19 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with three-region GPU buffer for chunked prefill.
Compute attention with unified ring buffer for chunked prefill.

For chunked prefill:
1. Load previous KV from CPU using Compute/Prefetch region (if any previous chunks)
2. Compute attention against previous KV chunks (no causal mask)
3. Compute attention against current chunk's KV (causal)
4. Merge all results using online softmax
Ring buffer design:
- Current chunk's KV is written to ring_slot[chunk_idx % N]
- Previous chunks' KV are loaded from CPU using N-1 available slots
- Pipeline: pre-fill slots, then process with overlapped load/compute

Three-region design guarantees: current chunk's KV is in Compute region, previous KV is loaded
from CPU to Prefetch region, so write and load regions never overlap.
For each layer:
1. Current chunk's KV is in k_batched, v_batched (just written by model)
2. Load previous chunks from CPU using available slots (pipeline)
3. Compute attention against previous KV (no causal mask)
4. Compute attention against current KV (causal)
5. Merge all results using online softmax
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

@@ -122,50 +125,32 @@ class Attention(nn.Module):
o_acc = None
lse_acc = None

# Load previous KV from CPU using Compute/Prefetch region
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
current_chunk_idx = context.current_chunk_idx

if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks already written in previous chunks)
# Get prefilled CPU blocks (blocks from previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
# For prefill: ONLY use Prefetch region to avoid conflict with
# current chunk's KV being written to Compute region slots
# Use synchronous per-layer loading (async would conflict with writes)
chunk_size = offload_engine.num_prefetch_blocks
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size

for chunk_idx in range(num_chunks):
start = chunk_idx * chunk_size
end = min(start + chunk_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]
# Get write slot for current chunk and available load slots
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
pipeline_depth = len(load_slots)

# Load to Prefetch region (per-layer, sync)
offload_engine.load_to_prefetch_layer(self.layer_id, chunk_ids)
offload_engine.wait_prefetch_layer(self.layer_id)

prev_k, prev_v = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
if pipeline_depth == 0:
# Only 1 slot total, cannot pipeline - use sync loading
o_acc, lse_acc = self._sync_load_previous_chunks(
q_batched, cpu_block_table, offload_engine
)

# Compute attention against this chunk (no causal mask)
prev_o, prev_lse = flash_attn_with_lse(
q_batched,
prev_k,
prev_v,
softmax_scale=self.scale,
causal=False,
)

# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Use ring buffer pipeline
o_acc, lse_acc = self._ring_buffer_pipeline_load(
q_batched, cpu_block_table, load_slots, offload_engine
)

# Compute attention against current chunk's KV (with causal mask)
current_o, current_lse = flash_attn_with_lse(
@@ -185,6 +170,91 @@ class Attention(nn.Module):
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)

def _sync_load_previous_chunks(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
offload_engine,
):
"""Synchronous loading fallback when pipeline_depth=0."""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

o_acc, lse_acc = None, None

for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Load to slot 0 (single slot)
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(0, self.layer_id)

prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)

prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)

if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

return o_acc, lse_acc

def _ring_buffer_pipeline_load(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
):
"""
Ring buffer synchronous loading for previous chunks.

For correctness, we use synchronous loading:
- Load one block at a time
- Wait for transfer, compute attention, then load next

This ensures no data races between transfer and computation.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None

pipeline_depth = len(load_slots)
o_acc, lse_acc = None, None

# Process blocks one by one (synchronous)
for block_idx in range(num_blocks):
# Determine which slot to use (cycle through load_slots)
slot_idx = load_slots[block_idx % pipeline_depth]
cpu_block_id = cpu_block_table[block_idx]

# Load block to slot (async)
offload_engine.load_to_slot_layer(slot_idx, self.layer_id, cpu_block_id)

# Wait for transfer to complete
offload_engine.wait_slot_layer(slot_idx, self.layer_id)

# Get KV from slot and compute attention
prev_k, prev_v = offload_engine.get_kv_for_slot(slot_idx, self.layer_id)

prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)

# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

return o_acc, lse_acc

def _chunked_decode_attention(
self,
q: torch.Tensor,
@@ -193,20 +263,24 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with async double-buffering using Compute and Prefetch regions.
Compute decode attention with double-buffering using decode_load_slots.

Decode uses:
- decode_slot (slot[0]): writes new token's KV
- decode_load_slots (slots[1:]): load previous chunks from CPU

Pipeline design:
- Compute region: holds current chunk being computed
- Prefetch region: async loads next chunk while current is computing
- After computation, swap roles of the two regions
- First half of decode_load_slots: 'compute' buffer
- Second half: 'prefetch' buffer
- Double-buffer between them for async overlap

Timeline:
┌─────────────┐  ┌─────────────┐  ┌─────────────┐
│Load C0→Comp │  │Load C1→Pref │  │Load C2→Comp │ ...
│Load C0→buf0 │  │Load C1→buf1 │  │Load C2→buf0 │ ...
└─────────────┘  └─────────────┘  └─────────────┘
        ↘                ↘                ↘
  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐
  │ Compute C0  │  │ Compute C1  │  │ Compute C2  │
  │ Attn(C0)    │  │ Attn(C1)    │  │ Attn(C2)    │
  └─────────────┘  └─────────────┘  └─────────────┘
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

@@ -32,6 +32,8 @@ class Context:
# Starting position within block where decode tokens began (for accumulated token tracking)
# Used when batching decode offloads - we need to attend to all accumulated tokens
decode_start_pos_in_block: int = 0
# Current chunk index for ring buffer pipeline (prefill only)
current_chunk_idx: int = 0


_CONTEXT = Context()
@@ -57,6 +59,7 @@ def set_context(
chunked_seq=None,
decode_pos_in_block=0,
decode_start_pos_in_block=0,
current_chunk_idx=0,
):
global _CONTEXT
_CONTEXT = Context(
@@ -75,6 +78,7 @@ def set_context(
chunked_seq=chunked_seq,
decode_pos_in_block=decode_pos_in_block,
decode_start_pos_in_block=decode_start_pos_in_block,
current_chunk_idx=current_chunk_idx,
)