[refactor] Implement real chunked prefill mechanism.

Zijie Tian
2025-12-10 18:34:01 +08:00
parent 0b6f19242d
commit 87055cc5ce
4 changed files with 313 additions and 85 deletions

View File

@@ -566,50 +566,151 @@ class HybridKVCacheManager(KVCacheManager):
                 cpu_blocks += 1
         return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots
-    def load_all_kv_for_layer(
-        self,
-        seq: Sequence,
-        layer_id: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Load ALL KV for a sequence from both GPU and CPU for a layer.
+    # ========== Chunked Decode Support ==========
-        Used during chunked decode to compute full attention.
+    def get_decode_chunk_info(self, seq: Sequence) -> Tuple[List[int], List[int], int]:
+        """
+        Get information for chunked decode.
         Returns:
-            (k, v) tensors with shape [1, total_tokens, kv_heads, head_dim]
+            (cpu_block_ids, cpu_logical_ids, num_chunks)
+            - cpu_block_ids: List of CPU block IDs in sequence order
+            - cpu_logical_ids: Corresponding logical block IDs
+            - num_chunks: Number of chunks needed
         """
-        k_chunks = []
-        v_chunks = []
+        cpu_block_ids = []
+        cpu_logical_ids = []
         for logical_id in seq.block_table:
             block = self.logical_blocks[logical_id]
+            if block.location == BlockLocation.CPU:
+                cpu_block_ids.append(block.cpu_block_id)
+                cpu_logical_ids.append(logical_id)
+        # Each chunk uses the available GPU slots minus 1 (reserved for the write block)
+        usable_slots = self.num_gpu_slots - 1
+        num_chunks = (len(cpu_block_ids) + usable_slots - 1) // usable_slots if usable_slots > 0 else 0
+        return cpu_block_ids, cpu_logical_ids, num_chunks
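
The chunk count above is plain ceiling division over the usable slots. A worked example with hypothetical sizes (num_gpu_slots = 9 and 20 spilled blocks are not from this diff):

num_gpu_slots = 9                  # hypothetical GPU slot pool
cpu_blocks = 20                    # blocks that spilled to CPU
usable_slots = num_gpu_slots - 1   # last slot reserved for the write block
num_chunks = (cpu_blocks + usable_slots - 1) // usable_slots
assert num_chunks == 3             # chunks of 8, 8, and 4 blocks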
+
+    def load_decode_chunk(
+        self,
+        seq: Sequence,
+        cpu_block_ids: List[int],
+        cpu_logical_ids: List[int],
+        chunk_idx: int,
+    ) -> List[int]:
+        """
+        Load one chunk of CPU blocks to GPU for chunked decode.
+
+        Similar to chunked prefill: uses GPU slots to hold a batch of blocks.
+
+        Args:
+            seq: Sequence being decoded
+            cpu_block_ids: All CPU block IDs for this sequence
+            cpu_logical_ids: Corresponding logical block IDs
+            chunk_idx: Which chunk to load (0-indexed)
+
+        Returns:
+            List of GPU slot IDs where the chunk was loaded
+        """
+        # Must match get_decode_chunk_info: the last slot is reserved for the
+        # write block, so each chunk may use at most num_gpu_slots - 1 slots.
+        chunk_size = self.num_gpu_slots - 1
+        start = chunk_idx * chunk_size
+        end = min(start + chunk_size, len(cpu_block_ids))
+        chunk_cpu_ids = cpu_block_ids[start:end]
+        chunk_logical_ids = cpu_logical_ids[start:end]
+        # Use GPU slots 0, 1, 2, ... for this chunk
+        gpu_slots = list(range(len(chunk_cpu_ids)))
+        # Load all layers at once using offload_engine
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            chunk_cpu_ids, gpu_slots
+        )
+        return gpu_slots
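
For context, a minimal sketch of how a caller might drive these methods during one decode step; `manager`, `seq`, and `num_layers` are assumed to exist, and the attention call is a placeholder rather than anything defined in this diff:

cpu_ids, logical_ids, num_chunks = manager.get_decode_chunk_info(seq)
for chunk_idx in range(num_chunks):
    # Stage one chunk of CPU-resident KV into GPU slots 0..k-1.
    gpu_slots = manager.load_decode_chunk(seq, cpu_ids, logical_ids, chunk_idx)
    for layer_id in range(num_layers):
        k, v = manager.get_kv_for_gpu_slots(layer_id, gpu_slots)
        # ... compute partial attention against (k, v) for this chunk ...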
+
+    def get_gpu_blocks_for_decode(self, seq: Sequence) -> Tuple[List[int], List[int]]:
+        """
+        Get blocks currently on GPU for this sequence.
+
+        Returns:
+            (gpu_slots, logical_ids) - GPU slot IDs and corresponding logical block IDs
+        """
+        gpu_slots = []
+        logical_ids = []
+        for logical_id in seq.block_table:
+            block = self.logical_blocks[logical_id]
             if block.location == BlockLocation.GPU:
-                # Get from GPU cache
-                k, v = self.offload_engine.get_layer_cache(layer_id)
-                # k, v shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
-                k_block = k[block.gpu_slot]  # [block_size, kv_heads, head_dim]
-                v_block = v[block.gpu_slot]
-                k_chunks.append(k_block)
-                v_chunks.append(v_block)
+                gpu_slots.append(block.gpu_slot)
+                logical_ids.append(logical_id)
-            elif block.location == BlockLocation.CPU:
-                # Get from CPU cache
-                k_block, v_block = self.offload_engine.get_cpu_block(layer_id, block.cpu_block_id)
-                # Already [block_size, kv_heads, head_dim]
-                k_chunks.append(k_block.to("cuda", non_blocking=True))
-                v_chunks.append(v_block.to("cuda", non_blocking=True))
+        return gpu_slots, logical_ids
-        # Concatenate all chunks
-        k_all = torch.cat(k_chunks, dim=0)  # [total_tokens, kv_heads, head_dim]
-        v_all = torch.cat(v_chunks, dim=0)
+
+    def get_kv_for_gpu_slots(
+        self,
+        layer_id: int,
+        gpu_slots: List[int],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get KV tensors for specific GPU slots.
-        # Add batch dimension
-        k_all = k_all.unsqueeze(0)  # [1, total_tokens, kv_heads, head_dim]
-        v_all = v_all.unsqueeze(0)
+
+        Args:
+            layer_id: Layer index
+            gpu_slots: List of GPU slot IDs
-        return k_all, v_all
+
+        Returns:
+            (k, v) tensors with shape [1, num_tokens, kv_heads, head_dim]
+        """
+        k_cache, v_cache = self.offload_engine.get_layer_cache(layer_id)
+        # k_cache, v_cache shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
+        k_chunks = [k_cache[slot] for slot in gpu_slots]
+        v_chunks = [v_cache[slot] for slot in gpu_slots]
+        # Concatenate and add batch dimension
+        k = torch.cat(k_chunks, dim=0).unsqueeze(0)  # [1, tokens, heads, dim]
+        v = torch.cat(v_chunks, dim=0).unsqueeze(0)
+        return k, v
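
Since each call returns only the KV for one chunk, per-chunk attention outputs have to be merged across chunks, e.g. with the usual log-sum-exp rescaling from online softmax. That merge is not part of this diff; a self-contained sketch:

import torch

def merge_partial_attn(o1, lse1, o2, lse2):
    """Combine two partial attention results over disjoint KV chunks.

    o1, o2:     [1, heads, q_len, head_dim] partial outputs
    lse1, lse2: [1, heads, q_len] log-sum-exp of each chunk's logits
    """
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse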
+
+    def ensure_last_block_on_gpu(self, seq: Sequence) -> int:
+        """
+        Ensure the last block is on GPU for writing new KV.
+
+        Uses a RESERVED slot (the last slot) to avoid conflicts with chunked
+        decode, which uses slots 0, 1, 2, ... for loading CPU blocks.
+
+        Returns:
+            GPU slot ID for the last block
+        """
+        last_logical_id = seq.block_table[-1]
+        block = self.logical_blocks[last_logical_id]
+        if block.location == BlockLocation.GPU:
+            return block.gpu_slot
+        # Use the last slot as the reserved slot for the write block.
+        # This avoids conflicts with chunked decode, which uses slots 0, 1, 2, ...
+        reserved_slot = self.num_gpu_slots - 1
+        # Load this block to GPU for all layers
+        self.offload_engine.load_cpu_blocks_to_gpu_slots_all_layers(
+            [block.cpu_block_id], [reserved_slot]
+        )
+        # Update block state
+        self.free_cpu_blocks.append(block.cpu_block_id)
+        del self.cpu_block_to_logical[block.cpu_block_id]
+        self.gpu_slot_to_logical[reserved_slot] = last_logical_id
+        block.location = BlockLocation.GPU
+        block.gpu_slot = reserved_slot
+        block.cpu_block_id = -1
+        return reserved_slot
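
The reserved-slot convention gives a simple invariant: after this call the last logical block is always GPU-resident, and chunk loads cannot clobber it because load_decode_chunk only hands out slots 0 through num_gpu_slots - 2. A small check, assuming the same `manager`/`seq` objects as in the earlier sketch:

write_slot = manager.ensure_last_block_on_gpu(seq)
blk = manager.logical_blocks[seq.block_table[-1]]
assert blk.location == BlockLocation.GPU and blk.gpu_slot == write_slot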
+
+    def get_gpu_block_tables(
+        self,