[refactor] Implement real chunked prefill mechanism.
@@ -566,50 +566,151 @@ class HybridKVCacheManager(KVCacheManager):
                 cpu_blocks += 1
 
         return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots
 
-    def load_all_kv_for_layer(
-        self,
-        seq: Sequence,
-        layer_id: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Load ALL KV for a sequence from both GPU and CPU for a layer.
-
-        Used during chunked decode to compute full attention.
-
-        Returns:
-            (k, v) tensors with shape [1, total_tokens, kv_heads, head_dim]
-        """
-        k_chunks = []
-        v_chunks = []
-
-        for logical_id in seq.block_table:
-            block = self.logical_blocks[logical_id]
-            if block.location == BlockLocation.GPU:
-                # Get from GPU cache
-                k, v = self.offload_engine.get_layer_cache(layer_id)
-                # k, v shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
-                k_block = k[block.gpu_slot]  # [block_size, kv_heads, head_dim]
-                v_block = v[block.gpu_slot]
-                k_chunks.append(k_block)
-                v_chunks.append(v_block)
-            elif block.location == BlockLocation.CPU:
-                # Get from CPU cache
-                k_block, v_block = self.offload_engine.get_cpu_block(layer_id, block.cpu_block_id)
-                # Already [block_size, kv_heads, head_dim]
-                k_chunks.append(k_block.to("cuda", non_blocking=True))
-                v_chunks.append(v_block.to("cuda", non_blocking=True))
-
-        # Concatenate all chunks
-        k_all = torch.cat(k_chunks, dim=0)  # [total_tokens, kv_heads, head_dim]
-        v_all = torch.cat(v_chunks, dim=0)
-
-        # Add batch dimension
-        k_all = k_all.unsqueeze(0)  # [1, total_tokens, kv_heads, head_dim]
-        v_all = v_all.unsqueeze(0)
-
-        return k_all, v_all
+    # ========== Chunked Decode Support ==========
+
+    def get_decode_chunk_info(self, seq: Sequence) -> Tuple[List[int], List[int], int]:
+        """
+        Get information for chunked decode.
+
+        Returns:
+            (cpu_block_ids, cpu_logical_ids, num_chunks)
+            - cpu_block_ids: List of CPU block IDs in sequence order
+            - cpu_logical_ids: Corresponding logical block IDs
+            - num_chunks: Number of chunks needed
+        """
+        cpu_block_ids = []
+        cpu_logical_ids = []
+
+        for logical_id in seq.block_table:
+            block = self.logical_blocks[logical_id]
+            if block.location == BlockLocation.CPU:
+                cpu_block_ids.append(block.cpu_block_id)
+                cpu_logical_ids.append(logical_id)
+
+        # Each chunk uses the available GPU slots minus 1 (reserved for the write block)
+        usable_slots = self.num_gpu_slots - 1
+        num_chunks = (len(cpu_block_ids) + usable_slots - 1) // usable_slots if usable_slots > 0 else 0
+
+        return cpu_block_ids, cpu_logical_ids, num_chunks
+
+    def load_decode_chunk(
+        self,
+        seq: Sequence,
+        cpu_block_ids: List[int],
+        cpu_logical_ids: List[int],
+        chunk_idx: int,
+    ) -> List[int]:
+        """
+        Load one chunk of CPU blocks to GPU for chunked decode.
+
+        Similar to chunked prefill: uses GPU slots to hold a batch of blocks.
+
+        Args:
+            seq: Sequence being decoded
+            cpu_block_ids: All CPU block IDs for this sequence
+            cpu_logical_ids: Corresponding logical block IDs
+            chunk_idx: Which chunk to load (0-indexed)
+
+        Returns:
+            List of GPU slot IDs where the chunk was loaded
+        """
+        # Match get_decode_chunk_info: keep the last slot free for the reserved write block
+        chunk_size = self.num_gpu_slots - 1
+        start = chunk_idx * chunk_size
+        end = min(start + chunk_size, len(cpu_block_ids))
+
+        chunk_cpu_ids = cpu_block_ids[start:end]
+        chunk_logical_ids = cpu_logical_ids[start:end]
+
+        # Use GPU slots 0, 1, 2, ... for this chunk
+        gpu_slots = list(range(len(chunk_cpu_ids)))
+
+        # Load all layers at once using the offload_engine
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            chunk_cpu_ids, gpu_slots
+        )
+
+        return gpu_slots
+
+    def get_gpu_blocks_for_decode(self, seq: Sequence) -> Tuple[List[int], List[int]]:
+        """
+        Get blocks currently on GPU for this sequence.
+
+        Returns:
+            (gpu_slots, logical_ids) - GPU slot IDs and corresponding logical block IDs
+        """
+        gpu_slots = []
+        logical_ids = []
+
+        for logical_id in seq.block_table:
+            block = self.logical_blocks[logical_id]
+            if block.location == BlockLocation.GPU:
+                gpu_slots.append(block.gpu_slot)
+                logical_ids.append(logical_id)
+
+        return gpu_slots, logical_ids
+
+    def get_kv_for_gpu_slots(
+        self,
+        layer_id: int,
+        gpu_slots: List[int],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get KV tensors for specific GPU slots.
+
+        Args:
+            layer_id: Layer index
+            gpu_slots: List of GPU slot IDs
+
+        Returns:
+            (k, v) tensors with shape [1, num_tokens, kv_heads, head_dim]
+        """
+        k_cache, v_cache = self.offload_engine.get_layer_cache(layer_id)
+        # k_cache, v_cache shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
+
+        k_chunks = [k_cache[slot] for slot in gpu_slots]
+        v_chunks = [v_cache[slot] for slot in gpu_slots]
+
+        # Concatenate and add batch dimension
+        k = torch.cat(k_chunks, dim=0).unsqueeze(0)  # [1, tokens, heads, dim]
+        v = torch.cat(v_chunks, dim=0).unsqueeze(0)
+
+        return k, v
+
+    def ensure_last_block_on_gpu(self, seq: Sequence) -> int:
+        """
+        Ensure the last block is on GPU for writing new KV.
+
+        Uses a RESERVED slot (the last slot) to avoid conflicts with chunked decode,
+        which uses slots 0, 1, 2, ... for loading CPU blocks.
+
+        Returns:
+            GPU slot ID for the last block
+        """
+        last_logical_id = seq.block_table[-1]
+        block = self.logical_blocks[last_logical_id]
+
+        if block.location == BlockLocation.GPU:
+            return block.gpu_slot
+
+        # Use the last slot as the reserved slot for the write block.
+        # This avoids conflicts with chunked decode, which uses slots 0, 1, 2, ...
+        reserved_slot = self.num_gpu_slots - 1
+
+        # Load this block to GPU for all layers
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            [block.cpu_block_id], [reserved_slot]
+        )
+
+        # Update block state
+        self.free_cpu_blocks.append(block.cpu_block_id)
+        del self.cpu_block_to_logical[block.cpu_block_id]
+
+        self.gpu_slot_to_logical[reserved_slot] = last_logical_id
+        block.location = BlockLocation.GPU
+        block.gpu_slot = reserved_slot
+        block.cpu_block_id = -1
+
+        return reserved_slot
+
     def get_gpu_block_tables(
         self,
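
A minimal sketch of how a decode step might drive these helpers, for one layer of one sequence. This is not part of the commit: compute_attention_chunk and merge_attention_outputs are assumed helper names, and a real runner would hoist the chunk loading (which copies all layers at once) out of the per-layer path.

def chunked_decode_attention_for_layer(kv_manager, seq, q, layer_id):
    # Keep the block that will receive the new token's KV in the reserved slot.
    kv_manager.ensure_last_block_on_gpu(seq)

    partials = []

    # Attend over blocks already resident on GPU.
    gpu_slots, _ = kv_manager.get_gpu_blocks_for_decode(seq)
    if gpu_slots:
        k, v = kv_manager.get_kv_for_gpu_slots(layer_id, gpu_slots)
        partials.append(compute_attention_chunk(q, k, v))

    # Stream CPU-resident blocks through GPU slots 0 .. num_gpu_slots - 2, one chunk at a time.
    cpu_blocks, cpu_logicals, num_chunks = kv_manager.get_decode_chunk_info(seq)
    for chunk_idx in range(num_chunks):
        slots = kv_manager.load_decode_chunk(seq, cpu_blocks, cpu_logicals, chunk_idx)
        k, v = kv_manager.get_kv_for_gpu_slots(layer_id, slots)
        partials.append(compute_attention_chunk(q, k, v))

    # Combine the per-chunk partial results, e.g. with a log-sum-exp style merge.
    return merge_attention_outputs(partials)
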
@@ -308,6 +308,112 @@ class OffloadEngine:
             events.append(event)
         return events
 
+    # ========== Chunked Decode: Load CPU blocks to GPU slots ==========
+
+    def load_cpu_blocks_to_gpu_slots(
+        self,
+        layer_id: int,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> None:
+        """
+        Load CPU blocks into specific GPU slots for chunked decode.
+
+        Uses the main GPU KV cache slots, not a separate temp buffer.
+        This is the same mechanism that chunked prefill uses.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into (must be the same length)
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+
+        stream = self._get_next_stream()
+
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                # Copy from pinned CPU memory to the GPU KV cache slot
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+
+        # Wait for the transfer to complete
+        stream.synchronize()
+
+    def load_cpu_blocks_to_gpu_slots_async(
+        self,
+        layer_id: int,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> torch.cuda.Event:
+        """
+        Async version: load CPU blocks into GPU slots.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into
+
+        Returns:
+            CUDA event to wait on
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+
+        stream = self._get_next_stream()
+        event = torch.cuda.Event()
+
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                self.k_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[layer_id, gpu_slot].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+            event.record()
+
+        return event
+
+    def load_cpu_blocks_to_gpu_slots_all_layers(
+        self,
+        cpu_block_ids: List[int],
+        gpu_slot_ids: List[int],
+    ) -> None:
+        """
+        Load CPU blocks into GPU slots for ALL layers at once.
+
+        More efficient than per-layer loading when the mapping is known upfront.
+
+        Args:
+            cpu_block_ids: List of CPU block IDs to load
+            gpu_slot_ids: List of GPU slot IDs to load into
+        """
+        assert len(cpu_block_ids) == len(gpu_slot_ids)
+
+        stream = self._get_next_stream()
+
+        with torch.cuda.stream(stream):
+            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
+                # Copy all layers at once
+                self.k_cache_gpu[:, gpu_slot].copy_(
+                    self.k_cache_cpu[:, cpu_block_id],
+                    non_blocking=True
+                )
+                self.v_cache_gpu[:, gpu_slot].copy_(
+                    self.v_cache_cpu[:, cpu_block_id],
+                    non_blocking=True
+                )
+
+        stream.synchronize()
+
     # ========== Synchronization methods ==========
 
     def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
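
The event returned by load_cpu_blocks_to_gpu_slots_async makes it possible to overlap the next chunk's host-to-device copies with attention over the current chunk. The sketch below is illustrative only and not part of the commit: engine and attend_over_slots are assumed names, and the overlap is only safe when consecutive chunks target disjoint GPU slots (double buffering).

import torch

def pipelined_chunked_attention(engine, layer_id, chunks, attend_over_slots):
    # chunks: ordered list of (cpu_block_ids, gpu_slot_ids) pairs.
    if not chunks:
        return

    # Kick off the first transfer on a side stream.
    cpu_ids, slot_ids = chunks[0]
    event = engine.load_cpu_blocks_to_gpu_slots_async(layer_id, cpu_ids, slot_ids)

    for i, (_, slot_ids) in enumerate(chunks):
        # The compute stream waits for chunk i's copies; the host does not block.
        torch.cuda.current_stream().wait_event(event)

        # Prefetch chunk i + 1 while chunk i is being consumed.
        if i + 1 < len(chunks):
            next_cpu_ids, next_slot_ids = chunks[i + 1]
            event = engine.load_cpu_blocks_to_gpu_slots_async(layer_id, next_cpu_ids, next_slot_ids)

        attend_over_slots(layer_id, slot_ids)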