[refactor] Refactor current gpu and cpu block allocation strategy.

This commit is contained in:
Zijie Tian
2025-12-10 21:23:31 +08:00
parent 0a247ccb1b
commit 190df5f70d
7 changed files with 906 additions and 162 deletions

View File

@@ -81,20 +81,24 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: int,
block_size: int,
policy: Optional[EvictionPolicy] = None,
cpu_primary: bool = True,
):
"""
Initialize hybrid manager.
Args:
num_gpu_slots: Number of GPU buffer slots (working set)
num_cpu_blocks: Number of CPU pool blocks (overflow)
num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU)
cpu_primary: If True, use CPU as primary storage with Ping-Pong GPU buffer.
If False, use GPU as primary with CPU as overflow (legacy mode).
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
self.cpu_primary = cpu_primary # Ping-Pong mode flag
# Eviction policy
self.policy = policy or LRUPolicy()
@@ -321,12 +325,16 @@ class HybridKVCacheManager(KVCacheManager):
"""
Allocate logical blocks for prefill.
New blocks are allocated on GPU when possible. If GPU is full and all
GPU blocks belong to this sequence (can't evict), remaining blocks
are allocated to CPU for chunked prefill.
In cpu_primary mode (Ping-Pong): All blocks are allocated to CPU.
In legacy mode: Blocks are allocated to GPU when possible, overflow to CPU.
"""
assert not seq.block_table, "Sequence already has blocks"
# Ping-Pong模式所有blocks都分配到CPU
if self.cpu_primary:
return self.allocate_cpu_only(seq)
# Legacy模式GPU为主CPU为overflow
h = -1
cache_miss = False
@@ -451,13 +459,22 @@ class HybridKVCacheManager(KVCacheManager):
block.hash = -1
block.token_ids = []
# New decode blocks go to GPU
gpu_slot = self._allocate_gpu_slot()
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
self.gpu_slot_to_logical[gpu_slot] = logical_id
self.policy.on_block_allocated(gpu_slot, self.current_step)
if self.cpu_primary:
# Ping-Pong模式新block分配到CPU
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks for decode")
cpu_block_id = self.free_cpu_blocks.popleft()
block.location = BlockLocation.CPU
block.cpu_block_id = cpu_block_id
block.gpu_slot = -1
self.cpu_block_to_logical[cpu_block_id] = logical_id
else:
# Legacy模式新block分配到GPU
gpu_slot = self._allocate_gpu_slot()
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
self.gpu_slot_to_logical[gpu_slot] = logical_id
self.policy.on_block_allocated(gpu_slot, self.current_step)
block_table.append(logical_id)
@@ -993,6 +1010,158 @@ class HybridKVCacheManager(KVCacheManager):
break
return pos
# ========== Ping-Pong 双缓冲支持 ==========
def allocate_cpu_only(self, seq: Sequence) -> None:
"""
为序列分配 CPU blocks用于 Ping-Pong 模式)。
与 allocate() 不同,这里所有 blocks 都分配到 CPU
GPU 只用作工作缓冲区。
Args:
seq: 要分配的序列
"""
assert not seq.block_table, "Sequence already has blocks"
for i in range(seq.num_blocks):
# 分配 CPU block
if not self.free_cpu_blocks:
raise RuntimeError(
f"No free CPU blocks. Need {seq.num_blocks}, "
f"available: {len(self.free_cpu_blocks)}"
)
cpu_block_id = self.free_cpu_blocks.popleft()
# 分配逻辑块
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.location = BlockLocation.CPU
block.cpu_block_id = cpu_block_id
block.gpu_slot = -1
self.cpu_block_to_logical[cpu_block_id] = logical_id
seq.block_table.append(logical_id)
def get_cpu_block_table(self, seq: Sequence) -> List[int]:
"""
获取序列的 CPU block ID 列表。
Args:
seq: 序列
Returns:
CPU block IDs 列表,按序列顺序
"""
cpu_blocks = []
for logical_id in seq.block_table:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_blocks.append(block.cpu_block_id)
else:
# 如果 block 在 GPU 上,它应该有一个对应的 CPU block
# 在 Ping-Pong 模式下,所有数据最终都在 CPU 上
raise RuntimeError(
f"Block {logical_id} not on CPU (location={block.location}). "
f"In Ping-Pong mode, all blocks should be on CPU."
)
return cpu_blocks
def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]:
"""
获取序列的所有 CPU blocks 及其逻辑 ID。
Args:
seq: 序列
Returns:
(cpu_block_ids, logical_ids)
"""
cpu_block_ids = []
logical_ids = []
for logical_id in seq.block_table:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_block_ids.append(block.cpu_block_id)
logical_ids.append(logical_id)
return cpu_block_ids, logical_ids
def allocate_next_cpu_block(self, seq: Sequence) -> int:
"""
为序列分配下一个 CPU block用于 decode 时新 token
Args:
seq: 序列
Returns:
新分配的 CPU block ID
"""
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks")
cpu_block_id = self.free_cpu_blocks.popleft()
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.location = BlockLocation.CPU
block.cpu_block_id = cpu_block_id
block.gpu_slot = -1
self.cpu_block_to_logical[cpu_block_id] = logical_id
seq.block_table.append(logical_id)
return cpu_block_id
def get_last_cpu_block(self, seq: Sequence) -> int:
"""
获取序列最后一个 block 的 CPU block ID。
如果最后一个 block 不在 CPU 上,返回 -1。
Args:
seq: 序列
Returns:
CPU block ID如果不在 CPU 上则返回 -1
"""
if not seq.block_table:
return -1
last_logical_id = seq.block_table[-1]
block = self.logical_blocks[last_logical_id]
if block.location == BlockLocation.CPU:
return block.cpu_block_id
return -1
def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
"""
获取 Ping-Pong decode 时新 KV 写入的 GPU slot。
策略:使用序列所需 chunks 数决定最后用的是 Ping 还是 Pong buffer
然后使用该 buffer 的最后一个 slot。
Args:
seq: 序列
Returns:
GPU slot ID
"""
cpu_blocks, _ = self.get_all_cpu_blocks(seq)
ping_size = self.offload_engine.ping_size
num_chunks = (len(cpu_blocks) + ping_size - 1) // ping_size if cpu_blocks else 0
# 最后一个 chunk 用的是哪个 buffer
if num_chunks % 2 == 1 or num_chunks == 0:
# 奇数个 chunk或0个最后用的是 ping
return self.offload_engine.ping_slots[-1]
else:
# 偶数个 chunk最后用的是 pong
return self.offload_engine.pong_slots[-1]
def __repr__(self) -> str:
return (
f"HybridKVCacheManager(\n"

View File

@@ -64,6 +64,14 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
# ========== Ping-Pong 双缓冲配置 ==========
assert num_gpu_blocks >= 2, "Ping-Pong需要至少2个GPU blocks"
self.ping_size = num_gpu_blocks // 2
self.pong_size = num_gpu_blocks - self.ping_size
self.ping_slots = list(range(self.ping_size)) # [0, 1, 2, ...]
self.pong_slots = list(range(self.ping_size, num_gpu_blocks)) # [ping_size, ping_size+1, ...]
self.num_gpu_slots = num_gpu_blocks # alias
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
self.k_cache_gpu = torch.empty(
@@ -103,6 +111,17 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0
# ========== Ping-Pong 专用 stream 和事件 ==========
self.pingpong_stream = torch.cuda.Stream() # 专用于Ping-Pong传输
# 同步事件 - 加载完成
self.ping_ready = torch.cuda.Event()
self.pong_ready = torch.cuda.Event()
# 同步事件 - offload完成
self.ping_offload_done = torch.cuda.Event()
self.pong_offload_done = torch.cuda.Event()
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -516,7 +535,211 @@ class OffloadEngine:
f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n"
f" ping_size={self.ping_size}, pong_size={self.pong_size},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
)
)
# ========== Ping-Pong 双缓冲方法 ==========
def load_to_ping(self, cpu_block_ids: List[int]) -> None:
"""
异步加载CPU blocks到Ping buffer。
Args:
cpu_block_ids: 要加载的CPU block IDs列表
"""
if not cpu_block_ids:
self.ping_ready.record(self.pingpong_stream)
return
num_to_load = min(len(cpu_block_ids), self.ping_size)
logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}")
with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.ping_slots[i]
# 所有层一起复制
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.ping_ready.record(self.pingpong_stream)
def load_to_pong(self, cpu_block_ids: List[int]) -> None:
"""
异步加载CPU blocks到Pong buffer。
Args:
cpu_block_ids: 要加载的CPU block IDs列表
"""
if not cpu_block_ids:
self.pong_ready.record(self.pingpong_stream)
return
num_to_load = min(len(cpu_block_ids), self.pong_size)
logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}")
with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.pong_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.pong_ready.record(self.pingpong_stream)
def wait_ping(self) -> None:
"""等待Ping buffer加载完成。"""
self.compute_stream.wait_event(self.ping_ready)
def wait_pong(self) -> None:
"""等待Pong buffer加载完成。"""
self.compute_stream.wait_event(self.pong_ready)
def offload_buffer_to_cpu(
self,
buffer: str,
cpu_block_ids: List[int],
) -> None:
"""
异步将buffer中的KV offload到CPU。
Args:
buffer: "ping""pong"
cpu_block_ids: 目标CPU block IDs列表
"""
slots = self.ping_slots if buffer == "ping" else self.pong_slots
event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done
if not cpu_block_ids:
event.record(self.pingpong_stream)
return
num_to_offload = min(len(cpu_block_ids), len(slots))
logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.pingpong_stream):
# 等待计算完成
self.pingpong_stream.wait_stream(self.compute_stream)
for i in range(num_to_offload):
gpu_slot = slots[i]
cpu_id = cpu_block_ids[i]
self.k_cache_cpu[:, cpu_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
event.record(self.pingpong_stream)
def offload_slot_to_cpu(
self,
gpu_slot: int,
cpu_block_id: int,
) -> None:
"""
异步将单个GPU slot的KV offload到CPU。
Args:
gpu_slot: GPU slot ID
cpu_block_id: 目标CPU block ID
"""
logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.pingpong_stream):
self.pingpong_stream.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
def wait_ping_offload_done(self) -> None:
"""等待Ping buffer offload完成。"""
self.compute_stream.wait_event(self.ping_offload_done)
def wait_pong_offload_done(self) -> None:
"""等待Pong buffer offload完成。"""
self.compute_stream.wait_event(self.pong_offload_done)
def wait_all_offload_done(self) -> None:
"""等待所有offload完成。"""
self.pingpong_stream.synchronize()
def get_kv_for_ping_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
获取Ping buffer中指定数量slots的KV。
Args:
layer_id: 层ID
num_slots: 需要的slot数量
Returns:
(k_cache, v_cache)shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.ping_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_pong_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
获取Pong buffer中指定数量slots的KV。
Args:
layer_id: 层ID
num_slots: 需要的slot数量
Returns:
(k_cache, v_cache)shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.pong_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_slots(
self,
layer_id: int,
gpu_slots: List[int],
) -> Tuple[Tensor, Tensor]:
"""
获取指定GPU slots的KV。
Args:
layer_id: 层ID
gpu_slots: GPU slot IDs列表
Returns:
(k_cache, v_cache)shape: [1, len(slots) * block_size, kv_heads, head_dim]
"""
if not gpu_slots:
return None, None
k = self.k_cache_gpu[layer_id, gpu_slots]
v = self.v_cache_gpu[layer_id, gpu_slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v