[refactor] Remove legacy mode path.
@@ -173,13 +173,16 @@ Compute: [C0] [C1] [C2]
 **File**: `nanovllm/kvcache/hybrid_manager.py`
 
-Manages both GPU and CPU blocks:
-- `allocate()`: Allocate GPU block first, fallback to CPU
-- `allocate_cpu_only()`: Force CPU allocation (for ring buffer mode)
+CPU-primary KV cache manager with GPU ring buffer design:
+- All KV cache is stored on CPU as primary storage
+- GPU is used as a ring buffer for computation only
+- Ring buffer enables pipelined H2D transfers overlapped with computation
+
+Key methods:
+- `allocate()` / `allocate_cpu_only()`: Allocate all blocks to CPU
 - `get_all_cpu_blocks(seq)`: Get all CPU block IDs for a sequence
 - `get_prefilled_cpu_blocks(seq)`: Get CPU blocks from previous chunks
 - `get_write_slot_for_chunked_offload(seq)`: Get GPU slot for writing new KV (returns decode_slot)
-- `may_offload()`: Offload GPU blocks to CPU when decode slot fills
 
 ### Online Softmax Merge
 
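The "Online Softmax Merge" section referenced above is what lets per-chunk attention results be combined exactly. A minimal sketch of that merge step, assuming each partial attention call returns an output together with its log-sum-exp (flash-attn style); the function name and tensor shapes here are illustrative, not the repository's actual API:

import torch

def merge_attention_partials(o1, lse1, o2, lse2):
    """Merge two partial attention outputs computed over disjoint KV chunks.

    o1, o2:     [num_tokens, num_heads, head_dim] partial outputs
    lse1, lse2: [num_tokens, num_heads] log-sum-exp of the attention scores
    """
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)   # rescale by the shared max for numerical stability
    w2 = torch.exp(lse2 - max_lse)
    denom = w1 + w2
    merged = o1 * (w1 / denom).unsqueeze(-1) + o2 * (w2 / denom).unsqueeze(-1)
    merged_lse = max_lse + torch.log(denom)
    return merged, merged_lse

Because the rescaling uses the shared running max, merging chunk by chunk gives the same result as attending over all KV at once, without ever materializing the full score matrix.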
@@ -388,26 +388,6 @@ class ModelRunner:
             else:
                 return self.run_chunked_offload_decode(seqs)
 
-        # Check if chunked prefill is needed (legacy path)
-        if is_prefill and hasattr(self, 'kvcache_manager'):
-            needs_chunked = any(
-                hasattr(self.kvcache_manager, 'needs_chunked_prefill') and
-                self.kvcache_manager.needs_chunked_prefill(seq)
-                for seq in seqs if seq.block_table
-            )
-            if needs_chunked:
-                return self.run_chunked_prefill(seqs)
-
-        # Check if chunked decode is needed (legacy path)
-        if not is_prefill and hasattr(self, 'kvcache_manager'):
-            needs_chunked = any(
-                hasattr(self.kvcache_manager, 'needs_chunked_decode') and
-                self.kvcache_manager.needs_chunked_decode(seq)
-                for seq in seqs if seq.block_table
-            )
-            if needs_chunked:
-                return self.run_chunked_decode(seqs)
-
         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
         temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
         logits = self.run_model(input_ids, positions, is_prefill)
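With the legacy checks gone, the dispatch in ModelRunner only has to pick between the chunked-offload paths and the standard path. A rough sketch of what remains, modeled on the upstream nano-vllm run() structure; the `use_chunked_offload` guard is a made-up name standing in for whatever condition the fork actually checks:

def run(self, seqs, is_prefill):
    # Hypothetical guard; the real condition lives in ModelRunner.
    if self.use_chunked_offload:
        if is_prefill:
            return self.run_chunked_offload_prefill(seqs)
        return self.run_chunked_offload_decode(seqs)

    # Standard path, unchanged by this commit.
    input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
    temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
    logits = self.run_model(input_ids, positions, is_prefill)
    return self.sampler(logits, temperatures).tolist() if self.rank == 0 else None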
@@ -445,194 +425,6 @@ class ModelRunner:
         return False
 
-    def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
-        """
-        Run prefill in chunks when sequences exceed GPU capacity.
-
-        For each chunk:
-        1. Process tokens through model forward pass
-        2. At each attention layer:
-           - Load previous KV from CPU (handled by attention layer)
-           - Compute attention with online softmax merging
-           - Store current KV to GPU cache
-        3. After chunk completes, offload KV to CPU
-        4. Load next chunk's blocks to GPU
-        """
-        import sys
-
-        # Currently only supporting single sequence for chunked prefill
-        assert len(seqs) == 1, "Chunked prefill only supports single sequence"
-        seq = seqs[0]
-
-        total_blocks = seq.num_blocks
-        print(f"[Chunked Prefill] Starting: {total_blocks} total blocks, "
-              f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
-
-        chunk_num = 0
-        logits = None
-
-        while True:
-            # Get chunk info (which blocks are on GPU and not yet prefilled)
-            chunk_info = self.kvcache_manager.get_gpu_block_tables_partial(seqs)
-            gpu_blocks, start_block_idx, end_block_idx = chunk_info[0]
-
-            if not gpu_blocks:
-                # No more blocks to process
-                break
-
-            chunk_num += 1
-            chunk_tokens = (end_block_idx - start_block_idx) * self.block_size
-            if end_block_idx == seq.num_blocks:
-                # Last block may be partial
-                chunk_tokens = len(seq) - start_block_idx * self.block_size
-
-            print(f"[Chunked Prefill] Chunk {chunk_num}: blocks {start_block_idx}-{end_block_idx-1}, "
-                  f"~{chunk_tokens} tokens", file=sys.stderr)
-
-            # Prepare inputs for this chunk
-            input_ids, positions = self._prepare_chunked_prefill(seq, gpu_blocks, start_block_idx, end_block_idx)
-
-            if input_ids.numel() == 0:
-                print(f"[Chunked Prefill] No input tokens, breaking", file=sys.stderr)
-                break
-
-            print(f"[Chunked Prefill] Running model with {input_ids.numel()} tokens...", file=sys.stderr)
-
-            # Run model forward pass
-            logits = self.run_model(input_ids, positions, is_prefill=True)
-            reset_context()
-
-            print(f"[Chunked Prefill] Model forward complete", file=sys.stderr)
-
-            # Check if this is the last chunk
-            # Mark current chunk as prefilled and offload to CPU
-            self.kvcache_manager.complete_prefill_chunk(seq)
-
-            # Check if more chunks needed
-            if not self.kvcache_manager.needs_chunked_prefill(seq):
-                print(f"[Chunked Prefill] All chunks done, sampling", file=sys.stderr)
-                break
-
-            print(f"[Chunked Prefill] Chunk transfer complete, loading next...", file=sys.stderr)
-
-        # Sample from the last chunk's logits
-        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
-        if logits is not None:
-            token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
-        else:
-            token_ids = [0] if self.rank == 0 else None
-
-        return token_ids
-
-    def run_chunked_decode(self, seqs: list[Sequence]) -> list[int]:
-        """
-        Run decode with chunked attention when sequence exceeds GPU capacity.
-
-        For decode, we need attention over ALL previous tokens. With CPU offload,
-        we load KV chunks and compute attention incrementally per-layer.
-
-        Flow:
-        1. Ensure last block is on GPU (for writing new KV token)
-        2. Run model forward - each attention layer:
-           a. Compute attention on GPU blocks
-           b. Load CPU blocks in chunks, compute + merge
-        3. Sample from output
-        """
-        # Currently only supporting single sequence for chunked decode
-        assert len(seqs) == 1, "Chunked decode only supports single sequence"
-        seq = seqs[0]
-
-        # Prepare inputs
-        input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
-        positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
-
-        # Ensure last block is on GPU for writing new KV token
-        last_gpu_slot = self.kvcache_manager.ensure_last_block_on_gpu(seq)
-        slot = last_gpu_slot * self.block_size + seq.last_block_num_tokens - 1
-        slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-
-        # Set up context for chunked decode
-        set_context(
-            is_prefill=False,  # Decode mode
-            slot_mapping=slot_mapping,
-            context_lens=context_len,
-            is_chunked_prefill=True,  # Use chunked attention path
-            kvcache_manager=self.kvcache_manager,
-            chunked_seq=seq,
-        )
-
-        # Run model forward pass
-        # Each attention layer will handle chunked KV loading internally
-        logits = self.run_model(input_ids, positions, is_prefill=False)
-        reset_context()
-
-        # Sample
-        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
-        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
-
-        return token_ids
-
-    def _prepare_chunked_prefill(
-        self,
-        seq: Sequence,
-        gpu_blocks: list[int],
-        start_block_idx: int,
-        end_block_idx: int,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """
-        Prepare inputs for a single chunk in chunked prefill.
-
-        Sets up context with is_chunked_prefill=True so attention layers
-        know to load previous KV from CPU.
-        """
-        # Calculate token range for this chunk
-        start_token = start_block_idx * self.block_size
-        end_token = min(end_block_idx * self.block_size, len(seq))
-
-        # Input tokens for this chunk
-        input_ids = seq[start_token:end_token]
-        positions = list(range(start_token, end_token))
-
-        # Slot mapping for storing KV cache
-        slot_mapping = []
-        for i, gpu_block_id in enumerate(gpu_blocks):
-            block_idx = start_block_idx + i
-            start = gpu_block_id * self.block_size
-            if block_idx != seq.num_blocks - 1:
-                end = start + self.block_size
-            else:
-                end = start + seq.last_block_num_tokens
-            slot_mapping.extend(list(range(start, end)))
-
-        # Trim slot_mapping to match actual token count
-        actual_tokens = end_token - start_token
-        slot_mapping = slot_mapping[:actual_tokens]
-
-        # Convert to tensors
-        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
-        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
-        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-
-        # Set up context for chunked prefill
-        seqlen = actual_tokens
-        cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-
-        set_context(
-            is_prefill=True,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=seqlen,
-            max_seqlen_k=seqlen,
-            slot_mapping=slot_mapping,
-            is_chunked_prefill=True,
-            kvcache_manager=self.kvcache_manager,  # Pass manager for loading previous KV
-            chunked_seq=seq,  # Pass sequence for loading previous KV
-        )
-
-        return input_ids, positions
-
     def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
         """
         Run prefill with unified ring buffer (CPU is primary storage).
@@ -3,7 +3,7 @@ KV Cache management module.
 
 This module provides pluggable KV cache management strategies:
 - GPUOnlyManager: Pure GPU (default, current nano-vllm behavior)
-- HybridKVCacheManager: CPU offload with CUDA Graph support
+- HybridKVCacheManager: CPU-primary storage with GPU ring buffer for computation
 
 Usage:
     from nanovllm.kvcache import create_kvcache_manager
@@ -65,22 +65,19 @@ class LogicalBlock:
 
 class HybridKVCacheManager(KVCacheManager):
     """
-    Hybrid CPU-GPU KV cache manager with CUDA Graph support.
+    Hybrid CPU-GPU KV cache manager with ring buffer design.
 
-    Architecture:
-    - GPU buffer: Fixed-size working set (num_gpu_slots)
-    - CPU pool: Overflow storage (num_cpu_blocks)
+    Architecture (CPU-primary mode):
+    - CPU pool: Primary storage for all KV cache (num_cpu_blocks)
+    - GPU buffer: Ring buffer for computation (num_gpu_slots)
     - Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks)
 
-    CUDA Graph compatibility:
-    - All tensor addresses fixed at init time
-    - prepare_for_attention() updates gather_indices (outside graph)
-    - gathered_h2d_layer() executes transfer (inside graph)
-
-    Strategy:
-    1. New KV data written to GPU slots
-    2. Cold blocks evicted to CPU using configurable policy
-    3. Needed blocks prefetched back to GPU before attention
+    Design:
+    - All KV cache is stored on CPU as primary storage
+    - GPU is used as a ring buffer for computation only
+    - During prefill: KV is written to GPU ring slot, then offloaded to CPU
+    - During decode: Previous KV is loaded from CPU to GPU for attention
+    - Ring buffer enables pipelined H2D transfers overlapped with computation
     """
 
     def __init__(
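The "pipelined H2D transfers overlapped with computation" line in the new docstring boils down to a two-stream pattern: while the default stream computes over the ring slot that is already resident, a side stream copies the next block. A minimal PyTorch sketch; the tensor names, block layout, and slot indexing are assumptions for illustration, and the real transfers live in the offload engine:

import torch

copy_stream = torch.cuda.Stream()

def prefetch_block_async(cpu_block: torch.Tensor, gpu_slot: torch.Tensor) -> torch.cuda.Event:
    """Start an async H2D copy of one pinned CPU block into a GPU ring slot."""
    assert cpu_block.is_pinned()  # pinned host memory is required for a truly async copy
    done = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        gpu_slot.copy_(cpu_block, non_blocking=True)
        done.record()  # fires once the copy has finished on copy_stream
    return done

# Usage pattern: kick off the copy for chunk i+1, run attention over chunk i on the
# default stream, then make the default stream wait before touching slot i+1:
#   evt = prefetch_block_async(cpu_blocks[i + 1], ring_slots[(i + 1) % num_slots])
#   ... attention over ring_slots[i % num_slots] ...
#   torch.cuda.current_stream().wait_event(evt)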
@@ -89,26 +86,25 @@ class HybridKVCacheManager(KVCacheManager):
         num_cpu_blocks: int,
         block_size: int,
         policy: Optional[EvictionPolicy] = None,
-        cpu_primary: bool = True,
         num_prefetch_blocks: int = 2,
     ):
         """
-        Initialize hybrid manager.
+        Initialize hybrid manager with CPU-primary ring buffer design.
+
+        All KV cache is stored on CPU as primary storage. GPU slots are used
+        as a ring buffer for computation only.
 
         Args:
-            num_gpu_slots: Number of GPU buffer slots (working set)
-            num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
+            num_gpu_slots: Number of GPU buffer slots (ring buffer for computation)
+            num_cpu_blocks: Number of CPU pool blocks (primary storage)
             block_size: Tokens per block
-            policy: Eviction policy (default: LRU)
-            cpu_primary: If True, use CPU as primary storage with ring buffer GPU design.
-                If False, use GPU as primary with CPU as overflow (legacy mode).
+            policy: Eviction policy (default: LRU, used for prefix cache management)
             num_prefetch_blocks: Number of blocks for ring buffer pipeline (deprecated, ring_slots = num_gpu_slots)
         """
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
         self.num_cpu_blocks = num_cpu_blocks
         self.total_blocks = num_gpu_slots + num_cpu_blocks
-        self.cpu_primary = cpu_primary  # Ring buffer mode flag
         self.num_prefetch_blocks = num_prefetch_blocks  # Ring buffer design parameter (deprecated)
 
         # Eviction policy
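With `cpu_primary` removed, call sites construct the manager the same way in every mode. A usage sketch with placeholder sizes (the numbers below are examples, not defaults from the repository):

manager = HybridKVCacheManager(
    num_gpu_slots=8,       # GPU ring buffer, used only for computation
    num_cpu_blocks=1024,   # CPU pool: primary KV storage
    block_size=16,
    # policy (default: LRU) and num_prefetch_blocks (deprecated) keep their defaults
)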
@@ -200,160 +196,6 @@ class HybridKVCacheManager(KVCacheManager):
         self.sparse_policy = policy
         logger.info(f"Sparse attention policy set: {policy}")
 
-    def _allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
-        """
-        Get a free GPU slot, evicting if necessary.
-
-        Args:
-            protected_logical_ids: Logical block IDs that cannot be evicted
-
-        Returns:
-            GPU slot ID
-
-        Raises:
-            RuntimeError: If no GPU slot is available
-        """
-        if self.free_gpu_slots:
-            return self.free_gpu_slots.popleft()
-
-        # Need to evict - find victim using policy
-        return self._evict_to_cpu(protected_logical_ids)
-
-    def _try_allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> Optional[int]:
-        """
-        Try to get a free GPU slot, evicting if necessary.
-
-        Unlike _allocate_gpu_slot(), returns None instead of raising if no eviction possible.
-
-        Args:
-            protected_logical_ids: Logical block IDs that cannot be evicted
-
-        Returns:
-            GPU slot ID, or None if no slot available
-        """
-        if self.free_gpu_slots:
-            return self.free_gpu_slots.popleft()
-
-        # Check if we can evict
-        protected = protected_logical_ids or set()
-        for gpu_slot, logical_id in self.gpu_slot_to_logical.items():
-            if logical_id not in protected:
-                block = self.logical_blocks[logical_id]
-                if block.ref_count > 0:
-                    # Found evictable block
-                    return self._evict_to_cpu(protected_logical_ids)
-
-        # No evictable blocks
-        return None
-
-    def _evict_to_cpu(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
-        """
-        Evict a GPU block to CPU to make room.
-
-        Args:
-            protected_logical_ids: Logical block IDs that cannot be evicted
-
-        Returns:
-            The freed GPU slot ID
-        """
-        protected = protected_logical_ids or set()
-
-        # Find candidates (blocks currently on GPU with ref_count > 0, excluding protected)
-        candidates: Set[int] = set()
-        for gpu_slot, logical_id in self.gpu_slot_to_logical.items():
-            if logical_id in protected:
-                continue  # Skip protected blocks
-            block = self.logical_blocks[logical_id]
-            if block.ref_count > 0:  # Only evict blocks still in use
-                candidates.add(gpu_slot)
-
-        if not candidates:
-            raise RuntimeError(
-                f"No GPU slots available for eviction. "
-                f"GPU slots: {self.num_gpu_slots}, protected: {len(protected)}, "
-                f"need more GPU memory or reduce sequence length"
-            )
-
-        # Use policy to select victim
-        victim_gpu_slot = self.policy.select_victim(candidates)
-        logical_id = self.gpu_slot_to_logical[victim_gpu_slot]
-        block = self.logical_blocks[logical_id]
-
-        # Allocate CPU block
-        if not self.free_cpu_blocks:
-            raise RuntimeError("Both GPU and CPU are full")
-        cpu_block_id = self.free_cpu_blocks.popleft()
-
-        # Async offload GPU -> CPU
-        self.offload_engine.offload_block_async(
-            layer_id=0,  # TODO: handle per-layer offloading
-            gpu_block_id=victim_gpu_slot,
-            cpu_block_id=cpu_block_id,
-        )
-
-        # Update mappings
-        del self.gpu_slot_to_logical[victim_gpu_slot]
-        self.cpu_block_to_logical[cpu_block_id] = logical_id
-
-        block.location = BlockLocation.CPU
-        block.gpu_slot = -1
-        block.cpu_block_id = cpu_block_id
-
-        # Notify policy
-        self.policy.on_block_evicted(victim_gpu_slot)
-
-        return victim_gpu_slot
-
-    def _ensure_on_gpu(
-        self,
-        logical_id: int,
-        protected_logical_ids: Optional[Set[int]] = None,
-    ) -> int:
-        """
-        Ensure a logical block is on GPU.
-
-        Args:
-            logical_id: Logical block ID
-            protected_logical_ids: Logical block IDs that cannot be evicted
-
-        Returns:
-            GPU slot ID where the block is/will be
-        """
-        block = self.logical_blocks[logical_id]
-
-        if block.location == BlockLocation.GPU:
-            # Already on GPU, update policy
-            self.policy.on_block_access(block.gpu_slot, self.current_step)
-            return block.gpu_slot
-
-        if block.location == BlockLocation.CPU:
-            # Need to prefetch from CPU
-            gpu_slot = self._allocate_gpu_slot(protected_logical_ids)
-
-            # Async prefetch CPU -> GPU
-            self.offload_engine.prefetch_block_async(
-                layer_id=0,  # TODO: handle per-layer
-                cpu_block_id=block.cpu_block_id,
-                gpu_block_id=gpu_slot,
-            )
-
-            # Update mappings
-            self.free_cpu_blocks.append(block.cpu_block_id)
-            del self.cpu_block_to_logical[block.cpu_block_id]
-
-            self.gpu_slot_to_logical[gpu_slot] = logical_id
-
-            block.location = BlockLocation.GPU
-            block.gpu_slot = gpu_slot
-            block.cpu_block_id = -1
-
-            # Notify policy
-            self.policy.on_block_prefetched(gpu_slot, self.current_step)
-
-            return gpu_slot
-
-        raise RuntimeError(f"Block {logical_id} is in invalid state")
-
     def can_allocate(self, seq: Sequence) -> bool:
         """Check if we can allocate blocks for a new sequence."""
         return len(self.free_logical_ids) >= seq.num_blocks
@@ -362,89 +204,10 @@ class HybridKVCacheManager(KVCacheManager):
         """
         Allocate logical blocks for prefill.
 
-        In cpu_primary mode (Chunked Offload): All blocks are allocated to CPU.
-        In legacy mode: Blocks are allocated to GPU when possible, overflow to CPU.
+        All blocks are allocated to CPU (primary storage).
+        GPU is used as ring buffer for computation only.
         """
-        assert not seq.block_table, "Sequence already has blocks"
-
-        # Ring buffer mode: all blocks are allocated to CPU
-        if self.cpu_primary:
-            return self.allocate_cpu_only(seq)
-
-        # Legacy mode: GPU as primary, CPU as overflow
-        h = -1
-        cache_miss = False
-
-        # Track blocks allocated for this sequence to protect them from eviction
-        allocated_for_seq: Set[int] = set()
-
-        for i in range(seq.num_blocks):
-            token_ids = seq.block(i)
-
-            # Hash for full blocks only
-            if len(token_ids) == self._block_size:
-                h = self.compute_hash(token_ids, h)
-            else:
-                h = -1
-
-            # Check prefix cache
-            cached_logical_id = self.hash_to_logical_id.get(h, -1)
-            if cached_logical_id != -1:
-                cached_block = self.logical_blocks[cached_logical_id]
-                if cached_block.token_ids == token_ids and cached_block.ref_count > 0:
-                    # Cache hit
-                    cached_block.ref_count += 1
-                    seq.num_cached_tokens += self._block_size
-                    seq.block_table.append(cached_logical_id)
-                    allocated_for_seq.add(cached_logical_id)
-
-                    # Ensure block is on GPU (protect already allocated blocks)
-                    if cached_block.location == BlockLocation.CPU:
-                        self._ensure_on_gpu(cached_logical_id, allocated_for_seq)
-
-                    continue
-
-            cache_miss = True
-
-            # Allocate new logical block
-            logical_id = self.free_logical_ids.popleft()
-            block = self.logical_blocks[logical_id]
-            block.ref_count = 1
-            block.hash = h
-            block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else []
-
-            # Try to allocate GPU slot
-            gpu_slot = self._try_allocate_gpu_slot(allocated_for_seq)
-            if gpu_slot is not None:
-                # Got GPU slot
-                block.location = BlockLocation.GPU
-                block.gpu_slot = gpu_slot
-                block.cpu_block_id = -1
-                self.gpu_slot_to_logical[gpu_slot] = logical_id
-            else:
-                # GPU full and can't evict (all protected) - allocate to CPU
-                # This block will be written via chunked prefill
-                if not self.free_cpu_blocks:
-                    raise RuntimeError(
-                        f"Both GPU and CPU are full. Need {seq.num_blocks} blocks, "
-                        f"GPU has {self.num_gpu_slots}, CPU has {self.num_cpu_blocks}"
-                    )
-                cpu_block_id = self.free_cpu_blocks.popleft()
-                block.location = BlockLocation.CPU
-                block.gpu_slot = -1
-                block.cpu_block_id = cpu_block_id
-                self.cpu_block_to_logical[cpu_block_id] = logical_id
-
-            allocated_for_seq.add(logical_id)
-
-            # Update prefix cache
-            if h != -1:
-                self.hash_to_logical_id[h] = logical_id
-
-            # Notify policy
-            self.policy.on_block_allocated(gpu_slot, self.current_step)
-
-            seq.block_table.append(logical_id)
+        return self.allocate_cpu_only(seq)
 
     def deallocate(self, seq: Sequence) -> None:
         """Release all blocks for a sequence."""
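allocate() now simply forwards to allocate_cpu_only(), whose body is outside this diff. Based on the CPU-side bookkeeping visible elsewhere in the commit (free_logical_ids, free_cpu_blocks, cpu_block_to_logical, BlockLocation.CPU), it plausibly does something like the following; treat this as a hypothetical reconstruction, not the actual implementation:

def allocate_cpu_only(self, seq):
    # Hypothetical sketch -- the real method is not shown in this hunk.
    assert not seq.block_table, "Sequence already has blocks"
    for _ in range(seq.num_blocks):
        if not self.free_cpu_blocks:
            raise RuntimeError("No free CPU blocks for prefill")
        logical_id = self.free_logical_ids.popleft()
        cpu_block_id = self.free_cpu_blocks.popleft()
        block = self.logical_blocks[logical_id]
        block.ref_count = 1
        block.location = BlockLocation.CPU
        block.cpu_block_id = cpu_block_id
        block.gpu_slot = -1
        self.cpu_block_to_logical[cpu_block_id] = logical_id
        seq.block_table.append(logical_id)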
@@ -496,22 +259,14 @@ class HybridKVCacheManager(KVCacheManager):
             block.hash = -1
             block.token_ids = []
 
-            if self.cpu_primary:
-                # Ring buffer mode: new block allocated to CPU
-                if not self.free_cpu_blocks:
-                    raise RuntimeError("No free CPU blocks for decode")
-                cpu_block_id = self.free_cpu_blocks.popleft()
-                block.location = BlockLocation.CPU
-                block.cpu_block_id = cpu_block_id
-                block.gpu_slot = -1
-                self.cpu_block_to_logical[cpu_block_id] = logical_id
-            else:
-                # Legacy mode: new block allocated to GPU
-                gpu_slot = self._allocate_gpu_slot()
-                block.location = BlockLocation.GPU
-                block.gpu_slot = gpu_slot
-                self.gpu_slot_to_logical[gpu_slot] = logical_id
-                self.policy.on_block_allocated(gpu_slot, self.current_step)
+            # Allocate new block to CPU (ring buffer mode)
+            if not self.free_cpu_blocks:
+                raise RuntimeError("No free CPU blocks for decode")
+            cpu_block_id = self.free_cpu_blocks.popleft()
+            block.location = BlockLocation.CPU
+            block.cpu_block_id = cpu_block_id
+            block.gpu_slot = -1
+            self.cpu_block_to_logical[cpu_block_id] = logical_id
 
             block_table.append(logical_id)
 
@@ -536,235 +291,10 @@ class HybridKVCacheManager(KVCacheManager):
         """
         Prepare KV cache for attention computation.
 
-        For prefill: async prefetch blocks from CPU to GPU.
-        For decode: update gather_indices for CUDA graph.
+        In ring buffer mode, this is a no-op because chunked offload
+        paths handle H2D transfers directly in the attention layer.
         """
-        self.current_step += 1
-
-        # Collect all needed logical blocks
-        needed_logical_ids: Set[int] = set()
-        for seq in seqs:
-            needed_logical_ids.update(seq.block_table)
-
-        if is_prefill:
-            # Prefill: ensure all blocks on GPU (async prefetch)
-            # Pass needed_logical_ids as protected to prevent evicting blocks we need
-            for logical_id in needed_logical_ids:
-                block = self.logical_blocks[logical_id]
-                if block.location == BlockLocation.CPU:
-                    self._ensure_on_gpu(logical_id, needed_logical_ids)
-
-            # Wait for all prefetches to complete
-            self.offload_engine.wait_all_transfers()
-
-        else:
-            # Decode: Check if we need chunked decode
-            cpu_blocks_count = sum(
-                1 for lid in needed_logical_ids
-                if self.logical_blocks[lid].location == BlockLocation.CPU
-            )
-
-            if cpu_blocks_count > self.num_gpu_slots:
-                # Too many blocks on CPU - will use chunked decode
-                # Don't try to load all blocks now
-                return
-
-            # Standard decode: prepare gather_indices for CUDA graph
-            # Identify blocks needing transfer
-            self.pending_gpu_loads.clear()
-            mappings_per_layer: List[List[Tuple[int, int]]] = [
-                [] for _ in range(self.offload_engine.num_layers)
-            ]
-
-            for logical_id in needed_logical_ids:
-                block = self.logical_blocks[logical_id]
-                if block.location == BlockLocation.CPU:
-                    # Allocate GPU slot (protect needed blocks from eviction)
-                    gpu_slot = self._allocate_gpu_slot(needed_logical_ids)
-
-                    # Record mapping for each layer
-                    for layer_id in range(self.offload_engine.num_layers):
-                        mappings_per_layer[layer_id].append(
-                            (block.cpu_block_id, gpu_slot)
-                        )
-
-                    # Update block state
-                    self.free_cpu_blocks.append(block.cpu_block_id)
-                    del self.cpu_block_to_logical[block.cpu_block_id]
-
-                    self.gpu_slot_to_logical[gpu_slot] = logical_id
-                    block.location = BlockLocation.GPU
-                    block.gpu_slot = gpu_slot
-                    block.cpu_block_id = -1
-
-                    self.pending_gpu_loads.add(logical_id)
-                    self.policy.on_block_prefetched(gpu_slot, self.current_step)
-
-                elif block.location == BlockLocation.GPU:
-                    self.policy.on_block_access(block.gpu_slot, self.current_step)
-
-            # Update gather indices (outside graph)
-            self.offload_engine.update_gather_indices_all_layers(mappings_per_layer)
-            self.offload_engine.sync_indices()
-
-    def needs_chunked_decode(self, seq: Sequence) -> bool:
-        """
-        Check if sequence needs chunked decode.
-
-        Returns True if there are blocks on CPU and total blocks exceed GPU capacity.
-        """
-        cpu_blocks = 0
-        for logical_id in seq.block_table:
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.CPU:
-                cpu_blocks += 1
-        return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots
-
-    # ========== Chunked Decode Support ==========
-
-    def get_decode_chunk_info(self, seq: Sequence) -> Tuple[List[int], List[int], int]:
-        """
-        Get information for chunked decode.
-
-        Returns:
-            (cpu_block_ids, cpu_logical_ids, num_chunks)
-            - cpu_block_ids: List of CPU block IDs in sequence order
-            - cpu_logical_ids: Corresponding logical block IDs
-            - num_chunks: Number of chunks needed
-        """
-        cpu_block_ids = []
-        cpu_logical_ids = []
-
-        for logical_id in seq.block_table:
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.CPU:
-                cpu_block_ids.append(block.cpu_block_id)
-                cpu_logical_ids.append(logical_id)
-
-        # Each chunk uses available GPU slots minus 1 (reserved for write block)
-        usable_slots = self.num_gpu_slots - 1
-        num_chunks = (len(cpu_block_ids) + usable_slots - 1) // usable_slots if usable_slots > 0 else 0
-
-        return cpu_block_ids, cpu_logical_ids, num_chunks
-
-    def load_decode_chunk(
-        self,
-        seq: Sequence,
-        cpu_block_ids: List[int],
-        cpu_logical_ids: List[int],
-        chunk_idx: int,
-    ) -> List[int]:
-        """
-        Load one chunk of CPU blocks to GPU for chunked decode.
-
-        Similar to chunked prefill: uses GPU slots to hold a batch of blocks.
-
-        Args:
-            seq: Sequence being decoded
-            cpu_block_ids: All CPU block IDs for this sequence
-            cpu_logical_ids: Corresponding logical block IDs
-            chunk_idx: Which chunk to load (0-indexed)
-
-        Returns:
-            List of GPU slot IDs where the chunk was loaded
-        """
-        chunk_size = self.num_gpu_slots
-        start = chunk_idx * chunk_size
-        end = min(start + chunk_size, len(cpu_block_ids))
-
-        chunk_cpu_ids = cpu_block_ids[start:end]
-        chunk_logical_ids = cpu_logical_ids[start:end]
-
-        # Use GPU slots 0, 1, 2, ... for this chunk
-        gpu_slots = list(range(len(chunk_cpu_ids)))
-
-        # Load all layers at once using offload_engine
-        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
-            chunk_cpu_ids, gpu_slots
-        )
-
-        return gpu_slots
-
-    def get_gpu_blocks_for_decode(self, seq: Sequence) -> Tuple[List[int], List[int]]:
-        """
-        Get blocks currently on GPU for this sequence.
-
-        Returns:
-            (gpu_slots, logical_ids) - GPU slot IDs and corresponding logical block IDs
-        """
-        gpu_slots = []
-        logical_ids = []
-
-        for logical_id in seq.block_table:
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.GPU:
-                gpu_slots.append(block.gpu_slot)
-                logical_ids.append(logical_id)
-
-        return gpu_slots, logical_ids
-
-    def get_kv_for_gpu_slots(
-        self,
-        layer_id: int,
-        gpu_slots: List[int],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Get KV tensors for specific GPU slots.
-
-        Args:
-            layer_id: Layer index
-            gpu_slots: List of GPU slot IDs
-
-        Returns:
-            (k, v) tensors with shape [1, num_tokens, kv_heads, head_dim]
-        """
-        k_cache, v_cache = self.offload_engine.get_layer_cache(layer_id)
-        # k_cache, v_cache shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
-
-        k_chunks = [k_cache[slot] for slot in gpu_slots]
-        v_chunks = [v_cache[slot] for slot in gpu_slots]
-
-        # Concatenate and add batch dimension
-        k = torch.cat(k_chunks, dim=0).unsqueeze(0)  # [1, tokens, heads, dim]
-        v = torch.cat(v_chunks, dim=0).unsqueeze(0)
-
-        return k, v
-
-    def ensure_last_block_on_gpu(self, seq: Sequence) -> int:
-        """
-        Ensure the last block is on GPU for writing new KV.
-
-        Uses a RESERVED slot (last slot) to avoid conflicts with chunked decode
-        which uses slots 0, 1, 2, ... for loading CPU blocks.
-
-        Returns:
-            GPU slot ID for the last block
-        """
-        last_logical_id = seq.block_table[-1]
-        block = self.logical_blocks[last_logical_id]
-
-        if block.location == BlockLocation.GPU:
-            return block.gpu_slot
-
-        # Use last slot as reserved slot for write block
-        # This avoids conflicts with chunked decode which uses slots 0, 1, 2...
-        reserved_slot = self.num_gpu_slots - 1
-
-        # Load this block to GPU for all layers
-        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
-            [block.cpu_block_id], [reserved_slot]
-        )
-
-        # Update block state
-        self.free_cpu_blocks.append(block.cpu_block_id)
-        del self.cpu_block_to_logical[block.cpu_block_id]
-
-        self.gpu_slot_to_logical[reserved_slot] = last_logical_id
-        block.location = BlockLocation.GPU
-        block.gpu_slot = reserved_slot
-        block.cpu_block_id = -1
-
-        return reserved_slot
+        pass
 
     def get_gpu_block_tables(
         self,
@@ -773,19 +303,13 @@ class HybridKVCacheManager(KVCacheManager):
         """
         Get GPU slot tables for sequences.
 
-        Returns GPU slot IDs, which may differ from logical block IDs.
+        In ring buffer mode, all blocks are on CPU, so this raises an error
+        if called. Use run_chunked_offload_* methods instead.
         """
-        result = []
-        for seq in seqs:
-            gpu_table = []
-            for logical_id in seq.block_table:
-                block = self.logical_blocks[logical_id]
-                assert block.location == BlockLocation.GPU, (
-                    f"Block {logical_id} not on GPU (location={block.location})"
-                )
-                gpu_table.append(block.gpu_slot)
-            result.append(gpu_table)
-        return result
+        raise RuntimeError(
+            "get_gpu_block_tables should not be called in ring buffer mode. "
+            "Use run_chunked_offload_prefill/decode instead."
+        )
 
     def post_attention_cleanup(
         self,
@@ -795,180 +319,12 @@ class HybridKVCacheManager(KVCacheManager):
         """
         Cleanup after attention.
 
-        Clear pending loads and optionally proactive offload.
+        In ring buffer mode, this is a no-op because offload is handled
+        directly in the chunked prefill/decode paths.
         """
-        self.pending_gpu_loads.clear()
+        pass
 
-    # ========== Chunked Prefill Support ==========
+    # ========== Ring Buffer CPU-primary Chunked Prefill Support ==========
 
-    def needs_chunked_prefill(self, seq: Sequence) -> bool:
-        """
-        Check if sequence needs chunked prefill.
-
-        Returns True if there are unprefilled blocks that are on CPU.
-        This indicates we need to process in chunks because not all blocks fit on GPU.
-        """
-        for logical_id in seq.block_table:
-            if logical_id not in self.prefilled_blocks:
-                block = self.logical_blocks[logical_id]
-                if block.location == BlockLocation.CPU:
-                    return True
-        return False
-
-    def get_gpu_block_count(self, seq: Sequence) -> int:
-        """Get number of blocks currently on GPU for this sequence."""
-        count = 0
-        for logical_id in seq.block_table:
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.GPU:
-                count += 1
-        return count
-
-    def get_prefill_chunk_info(self, seq: Sequence) -> Tuple[int, int, List[int]]:
-        """
-        Get information for current prefill chunk.
-
-        Returns:
-            (start_block_idx, end_block_idx, gpu_block_ids)
-            - start_block_idx: First block index in this chunk
-            - end_block_idx: Last block index (exclusive) in this chunk
-            - gpu_block_ids: GPU slot IDs for blocks in this chunk
-        """
-        start_idx = -1
-        end_idx = -1
-        gpu_block_ids = []
-
-        for i, logical_id in enumerate(seq.block_table):
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.GPU:
-                if start_idx == -1:
-                    start_idx = i
-                end_idx = i + 1
-                gpu_block_ids.append(block.gpu_slot)
-            elif start_idx != -1:
-                # Found CPU block after GPU blocks - stop here
-                break
-
-        if start_idx == -1:
-            return (0, 0, [])
-
-        return (start_idx, end_idx, gpu_block_ids)
-
-    def complete_prefill_chunk(self, seq: Sequence) -> bool:
-        """
-        Complete a prefill chunk: mark blocks as prefilled, offload to CPU, load next chunk.
-
-        Returns:
-            True if there are more chunks to process, False if done.
-        """
-        # Find blocks currently on GPU that were just prefilled
-        gpu_blocks_to_offload = []
-        for logical_id in seq.block_table:
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.GPU and logical_id not in self.prefilled_blocks:
-                # Mark as prefilled
-                self.prefilled_blocks.add(logical_id)
-                gpu_blocks_to_offload.append(logical_id)
-
-        # Offload prefilled GPU blocks to CPU
-        for logical_id in gpu_blocks_to_offload:
-            block = self.logical_blocks[logical_id]
-            if not self.free_cpu_blocks:
-                raise RuntimeError("No free CPU blocks for offload")
-
-            cpu_block_id = self.free_cpu_blocks.popleft()
-
-            # Async offload all layers
-            for layer_id in range(self.offload_engine.num_layers):
-                self.offload_engine.offload_block_async(
-                    layer_id=layer_id,
-                    gpu_block_id=block.gpu_slot,
-                    cpu_block_id=cpu_block_id,
-                )
-
-            # Update mappings
-            self.free_gpu_slots.append(block.gpu_slot)
-            del self.gpu_slot_to_logical[block.gpu_slot]
-            self.cpu_block_to_logical[cpu_block_id] = logical_id
-
-            block.location = BlockLocation.CPU
-            block.cpu_block_id = cpu_block_id
-            block.gpu_slot = -1
-
-        # Wait for offload to complete
-        self.offload_engine.wait_all_transfers()
-
-        # Find next UNPREFILLED CPU blocks and bring them to GPU
-        cpu_blocks_to_load = []
-        for logical_id in seq.block_table:
-            if logical_id in self.prefilled_blocks:
-                continue  # Skip already prefilled
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.CPU:
-                if len(cpu_blocks_to_load) >= self.num_gpu_slots:
-                    break  # GPU is full
-                cpu_blocks_to_load.append(logical_id)
-
-        if not cpu_blocks_to_load:
-            return False  # All blocks have been prefilled
-
-        # Load unprefilled CPU blocks to GPU
-        for logical_id in cpu_blocks_to_load:
-            block = self.logical_blocks[logical_id]
-            gpu_slot = self.free_gpu_slots.popleft()
-
-            # Note: We're NOT prefetching existing data - these blocks are being
-            # loaded for the first time, so we just need to assign GPU slots
-            # The model will write new KV cache data to these slots
-
-            # Update mappings
-            self.free_cpu_blocks.append(block.cpu_block_id)
-            del self.cpu_block_to_logical[block.cpu_block_id]
-            self.gpu_slot_to_logical[gpu_slot] = logical_id
-
-            block.location = BlockLocation.GPU
-            block.gpu_slot = gpu_slot
-            block.cpu_block_id = -1
-
-        return True  # More chunks to process
-
-    def get_gpu_block_tables_partial(
-        self,
-        seqs: List[Sequence],
-    ) -> List[Tuple[List[int], int, int]]:
-        """
-        Get GPU block tables for chunked prefill.
-
-        Returns list of (gpu_block_ids, start_block_idx, end_block_idx) per sequence.
-        Only includes blocks that are currently on GPU AND haven't been prefilled yet.
-        """
-        result = []
-        for seq in seqs:
-            gpu_table = []
-            start_idx = -1
-            end_idx = -1
-
-            for i, logical_id in enumerate(seq.block_table):
-                # Skip already prefilled blocks
-                if logical_id in self.prefilled_blocks:
-                    continue
-
-                block = self.logical_blocks[logical_id]
-                if block.location == BlockLocation.GPU:
-                    if start_idx == -1:
-                        start_idx = i
-                    end_idx = i + 1
-                    gpu_table.append(block.gpu_slot)
-                elif start_idx != -1:
-                    # Stop at first non-GPU block after GPU blocks
-                    break
-
-            if start_idx == -1:
-                start_idx = 0
-                end_idx = 0
-
-            result.append((gpu_table, start_idx, end_idx))
-        return result
 
     def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]:
         """
@@ -991,66 +347,6 @@ class HybridKVCacheManager(KVCacheManager):
         )
         return cpu_blocks
 
-    def load_prev_kv_for_layer(
-        self,
-        seq: Sequence,
-        layer_id: int,
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-        """
-        Load previous prefilled KV from CPU for a specific layer.
-
-        This concatenates KV from all previously prefilled blocks for use
-        during chunked prefill attention.
-
-        Args:
-            seq: Sequence to load KV for
-            layer_id: Layer index
-
-        Returns:
-            (k, v) tensors with shape [1, total_prev_tokens, kv_heads, head_dim]
-            or (None, None) if no previous KV exists
-        """
-        cpu_blocks = self.get_prefilled_cpu_blocks(seq)
-        if not cpu_blocks:
-            return None, None
-
-        k_chunks = []
-        v_chunks = []
-
-        for cpu_block_id in cpu_blocks:
-            k, v = self.offload_engine.get_cpu_block(layer_id, cpu_block_id)
-            # k, v shape: [block_size, kv_heads, head_dim]
-            k_chunks.append(k)
-            v_chunks.append(v)
-
-        # Concatenate all chunks
-        k_prev = torch.cat(k_chunks, dim=0)  # [total_prev_tokens, kv_heads, head_dim]
-        v_prev = torch.cat(v_chunks, dim=0)
-
-        # Move to GPU and add batch dimension
-        k_prev = k_prev.to("cuda", non_blocking=True).unsqueeze(0)  # [1, tokens, heads, dim]
-        v_prev = v_prev.to("cuda", non_blocking=True).unsqueeze(0)
-
-        return k_prev, v_prev
-
-    def get_chunk_start_position(self, seq: Sequence) -> int:
-        """
-        Get the starting token position for the current chunk.
-
-        This is the total number of tokens in previously prefilled blocks.
-
-        Returns:
-            Token position offset for current chunk
-        """
-        pos = 0
-        for logical_id in seq.block_table:
-            if logical_id in self.prefilled_blocks:
-                # Full block's worth of tokens
-                pos += self._block_size
-            else:
-                break
-        return pos
-
     # ========== Ring Buffer CPU-primary support ==========
 
     def allocate_cpu_only(self, seq: Sequence) -> None: