[refactor] Refactor the current GPU and CPU block allocation strategy.
@@ -81,20 +81,24 @@ class HybridKVCacheManager(KVCacheManager):
        num_cpu_blocks: int,
        block_size: int,
        policy: Optional[EvictionPolicy] = None,
        cpu_primary: bool = True,
    ):
        """
        Initialize hybrid manager.

        Args:
            num_gpu_slots: Number of GPU buffer slots (working set)
            num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
            block_size: Tokens per block
            policy: Eviction policy (default: LRU)
            cpu_primary: If True, use CPU as primary storage with a Ping-Pong GPU buffer.
                If False, use GPU as primary with CPU as overflow (legacy mode).
        """
        self._block_size = block_size
        self.num_gpu_slots = num_gpu_slots
        self.num_cpu_blocks = num_cpu_blocks
        self.total_blocks = num_gpu_slots + num_cpu_blocks
        self.cpu_primary = cpu_primary  # Ping-Pong mode flag

        # Eviction policy
        self.policy = policy or LRUPolicy()
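# Illustrative construction of the two modes (a minimal sketch; the
# argument values are hypothetical, not part of this commit):
#
#   mgr = HybridKVCacheManager(
#       num_gpu_slots=8,      # small GPU working set (ping + pong buffers)
#       num_cpu_blocks=256,   # primary KV storage when cpu_primary=True
#       block_size=16,
#       cpu_primary=True,     # Ping-Pong mode; False selects legacy overflow mode
#   )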
@@ -321,12 +325,16 @@ class HybridKVCacheManager(KVCacheManager):
        """
        Allocate logical blocks for prefill.

        In cpu_primary mode (Ping-Pong): all blocks are allocated on CPU.
        In legacy mode: blocks are allocated on GPU when possible, with overflow to CPU.
        """
        assert not seq.block_table, "Sequence already has blocks"

        # Ping-Pong mode: allocate all blocks on CPU.
        if self.cpu_primary:
            return self.allocate_cpu_only(seq)

        # Legacy mode: GPU is primary, CPU is overflow.
        h = -1
        cache_miss = False

@@ -451,13 +459,22 @@ class HybridKVCacheManager(KVCacheManager):
            block.hash = -1
            block.token_ids = []

            if self.cpu_primary:
                # Ping-Pong mode: allocate the new decode block on CPU.
                if not self.free_cpu_blocks:
                    raise RuntimeError("No free CPU blocks for decode")
                cpu_block_id = self.free_cpu_blocks.popleft()
                block.location = BlockLocation.CPU
                block.cpu_block_id = cpu_block_id
                block.gpu_slot = -1
                self.cpu_block_to_logical[cpu_block_id] = logical_id
            else:
                # Legacy mode: allocate the new decode block on GPU.
                gpu_slot = self._allocate_gpu_slot()
                block.location = BlockLocation.GPU
                block.gpu_slot = gpu_slot
                self.gpu_slot_to_logical[gpu_slot] = logical_id
                self.policy.on_block_allocated(gpu_slot, self.current_step)

            block_table.append(logical_id)

@@ -993,6 +1010,158 @@ class HybridKVCacheManager(KVCacheManager):
                break
        return pos

    # ========== Ping-Pong double-buffering support ==========

    def allocate_cpu_only(self, seq: Sequence) -> None:
        """
        Allocate CPU blocks for a sequence (used in Ping-Pong mode).

        Unlike allocate(), every block is placed on CPU;
        the GPU serves only as a working buffer.

        Args:
            seq: Sequence to allocate blocks for
        """
        assert not seq.block_table, "Sequence already has blocks"

        for i in range(seq.num_blocks):
            # Allocate a CPU block.
            if not self.free_cpu_blocks:
                raise RuntimeError(
                    f"No free CPU blocks. Need {seq.num_blocks}, "
                    f"available: {len(self.free_cpu_blocks)}"
                )

            cpu_block_id = self.free_cpu_blocks.popleft()

            # Allocate a logical block.
            logical_id = self.free_logical_ids.popleft()
            block = self.logical_blocks[logical_id]
            block.ref_count = 1
            block.location = BlockLocation.CPU
            block.cpu_block_id = cpu_block_id
            block.gpu_slot = -1

            self.cpu_block_to_logical[cpu_block_id] = logical_id
            seq.block_table.append(logical_id)

    def get_cpu_block_table(self, seq: Sequence) -> List[int]:
        """
        Get the CPU block IDs of a sequence.

        Args:
            seq: Sequence

        Returns:
            List of CPU block IDs, in sequence order
        """
        cpu_blocks = []
        for logical_id in seq.block_table:
            block = self.logical_blocks[logical_id]
            if block.location == BlockLocation.CPU:
                cpu_blocks.append(block.cpu_block_id)
            else:
                # A block on GPU should have a corresponding CPU block;
                # in Ping-Pong mode all data ultimately lives on CPU.
                raise RuntimeError(
                    f"Block {logical_id} not on CPU (location={block.location}). "
                    f"In Ping-Pong mode, all blocks should be on CPU."
                )
        return cpu_blocks

    def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]:
        """
        Get all CPU blocks of a sequence along with their logical IDs.

        Args:
            seq: Sequence

        Returns:
            (cpu_block_ids, logical_ids)
        """
        cpu_block_ids = []
        logical_ids = []
        for logical_id in seq.block_table:
            block = self.logical_blocks[logical_id]
            if block.location == BlockLocation.CPU:
                cpu_block_ids.append(block.cpu_block_id)
                logical_ids.append(logical_id)
        return cpu_block_ids, logical_ids

    def allocate_next_cpu_block(self, seq: Sequence) -> int:
        """
        Allocate the next CPU block for a sequence (for new tokens during decode).

        Args:
            seq: Sequence

        Returns:
            The newly allocated CPU block ID
        """
        if not self.free_cpu_blocks:
            raise RuntimeError("No free CPU blocks")

        cpu_block_id = self.free_cpu_blocks.popleft()
        logical_id = self.free_logical_ids.popleft()

        block = self.logical_blocks[logical_id]
        block.ref_count = 1
        block.location = BlockLocation.CPU
        block.cpu_block_id = cpu_block_id
        block.gpu_slot = -1

        self.cpu_block_to_logical[cpu_block_id] = logical_id
        seq.block_table.append(logical_id)

        return cpu_block_id

    def get_last_cpu_block(self, seq: Sequence) -> int:
        """
        Get the CPU block ID of a sequence's last block.

        Returns -1 if the last block is not on CPU.

        Args:
            seq: Sequence

        Returns:
            CPU block ID, or -1 if the last block is not on CPU
        """
        if not seq.block_table:
            return -1

        last_logical_id = seq.block_table[-1]
        block = self.logical_blocks[last_logical_id]

        if block.location == BlockLocation.CPU:
            return block.cpu_block_id
        return -1

    def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
        """
        Get the GPU slot that new KV is written to during Ping-Pong decode.

        Strategy: use the number of chunks the sequence needs to determine
        whether the last chunk used the Ping or the Pong buffer, then use
        that buffer's last slot.

        Args:
            seq: Sequence

        Returns:
            GPU slot ID
        """
        cpu_blocks, _ = self.get_all_cpu_blocks(seq)
        ping_size = self.offload_engine.ping_size
        num_chunks = (len(cpu_blocks) + ping_size - 1) // ping_size if cpu_blocks else 0

        # Which buffer did the last chunk use?
        if num_chunks % 2 == 1 or num_chunks == 0:
            # Odd number of chunks (or zero): the last chunk used ping.
            return self.offload_engine.ping_slots[-1]
        else:
            # Even number of chunks: the last chunk used pong.
            return self.offload_engine.pong_slots[-1]

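    # A worked example of the parity rule above (hypothetical sizes, for
    # illustration only; _last_chunk_buffer is not part of this commit):
    #
    #   def _last_chunk_buffer(num_cpu_blocks: int, ping_size: int) -> str:
    #       num_chunks = (num_cpu_blocks + ping_size - 1) // ping_size if num_cpu_blocks else 0
    #       return "ping" if num_chunks % 2 == 1 or num_chunks == 0 else "pong"
    #
    #   _last_chunk_buffer(10, 4)  # 3 chunks: ping, pong, ping -> "ping"
    #   _last_chunk_buffer(8, 4)   # 2 chunks: ping, pong       -> "pong"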
    def __repr__(self) -> str:
        return (
            f"HybridKVCacheManager(\n"

@@ -64,6 +64,14 @@ class OffloadEngine:
        self.kv_dim = num_kv_heads * head_dim
        self.block_numel = block_size * self.kv_dim

        # ========== Ping-Pong double-buffer configuration ==========
        assert num_gpu_blocks >= 2, "Ping-Pong requires at least 2 GPU blocks"
        self.ping_size = num_gpu_blocks // 2
        self.pong_size = num_gpu_blocks - self.ping_size
        self.ping_slots = list(range(self.ping_size))  # [0, 1, ..., ping_size - 1]
        self.pong_slots = list(range(self.ping_size, num_gpu_blocks))  # [ping_size, ..., num_gpu_blocks - 1]
        self.num_gpu_slots = num_gpu_blocks  # alias

        # ========== Fixed-address GPU KV cache ==========
        # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
        self.k_cache_gpu = torch.empty(
@@ -103,6 +111,17 @@ class OffloadEngine:
        self.compute_stream = torch.cuda.current_stream()
        self._stream_idx = 0

        # ========== Dedicated Ping-Pong stream and events ==========
        self.pingpong_stream = torch.cuda.Stream()  # dedicated to Ping-Pong transfers

        # Sync events: load complete
        self.ping_ready = torch.cuda.Event()
        self.pong_ready = torch.cuda.Event()

        # Sync events: offload complete
        self.ping_offload_done = torch.cuda.Event()
        self.pong_offload_done = torch.cuda.Event()

        # ========== Event tracking for async transfers ==========
        self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}

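        # Illustrative slot partition (hypothetical size, not part of this
        # commit): with num_gpu_blocks = 8,
        #   ping_size = 4, pong_size = 4
        #   ping_slots = [0, 1, 2, 3]
        #   pong_slots = [4, 5, 6, 7]
        # With an odd count (e.g. 7), pong gets the extra slot:
        #   ping_slots = [0, 1, 2], pong_slots = [3, 4, 5, 6]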
@@ -516,7 +535,211 @@ class OffloadEngine:
            f" kv_heads={self.num_kv_heads},\n"
            f" head_dim={self.head_dim},\n"
            f" dtype={self.dtype},\n"
            f" ping_size={self.ping_size}, pong_size={self.pong_size},\n"
            f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
            f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
            f")"
        )

    # ========== Ping-Pong double-buffering methods ==========

    def load_to_ping(self, cpu_block_ids: List[int]) -> None:
        """
        Asynchronously load CPU blocks into the Ping buffer.

        Args:
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.ping_ready.record(self.pingpong_stream)
            return

        num_to_load = min(len(cpu_block_ids), self.ping_size)
        logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}")

        with torch.cuda.stream(self.pingpong_stream):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.ping_slots[i]
                # Copy all layers at once.
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_id], non_blocking=True
                )
            self.ping_ready.record(self.pingpong_stream)

    def load_to_pong(self, cpu_block_ids: List[int]) -> None:
        """
        Asynchronously load CPU blocks into the Pong buffer.

        Args:
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.pong_ready.record(self.pingpong_stream)
            return

        num_to_load = min(len(cpu_block_ids), self.pong_size)
        logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}")

        with torch.cuda.stream(self.pingpong_stream):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.pong_slots[i]
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_id], non_blocking=True
                )
            self.pong_ready.record(self.pingpong_stream)

    def wait_ping(self) -> None:
        """Wait for the Ping buffer load to complete."""
        self.compute_stream.wait_event(self.ping_ready)

    def wait_pong(self) -> None:
        """Wait for the Pong buffer load to complete."""
        self.compute_stream.wait_event(self.pong_ready)

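    # A minimal sketch of the intended decode loop over CPU-resident KV
    # (caller-side pseudocode; `engine` and `chunks` are assumptions, not
    # part of this commit -- `chunks` is a list of CPU-block-ID lists, each
    # no longer than ping_size/pong_size):
    #
    #   engine.load_to_ping(chunks[0])
    #   for i in range(len(chunks)):
    #       if i % 2 == 0:
    #           engine.wait_ping()                      # ping data is resident
    #           if i + 1 < len(chunks):
    #               engine.load_to_pong(chunks[i + 1])  # prefetch next chunk
    #           # ... attend over get_kv_for_ping_slots(layer_id, len(chunks[i])) ...
    #       else:
    #           engine.wait_pong()
    #           if i + 1 < len(chunks):
    #               engine.load_to_ping(chunks[i + 1])
    #           # ... attend over get_kv_for_pong_slots(layer_id, len(chunks[i])) ...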
    def offload_buffer_to_cpu(
        self,
        buffer: str,
        cpu_block_ids: List[int],
    ) -> None:
        """
        Asynchronously offload the KV in a buffer to CPU.

        Args:
            buffer: "ping" or "pong"
            cpu_block_ids: List of destination CPU block IDs
        """
        slots = self.ping_slots if buffer == "ping" else self.pong_slots
        event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done

        if not cpu_block_ids:
            event.record(self.pingpong_stream)
            return

        num_to_offload = min(len(cpu_block_ids), len(slots))
        logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")

        with torch.cuda.stream(self.pingpong_stream):
            # Wait for compute to finish before copying out of the buffer.
            self.pingpong_stream.wait_stream(self.compute_stream)

            for i in range(num_to_offload):
                gpu_slot = slots[i]
                cpu_id = cpu_block_ids[i]
                self.k_cache_cpu[:, cpu_id].copy_(
                    self.k_cache_gpu[:, gpu_slot], non_blocking=True
                )
                self.v_cache_cpu[:, cpu_id].copy_(
                    self.v_cache_gpu[:, gpu_slot], non_blocking=True
                )
            event.record(self.pingpong_stream)

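    # Sketch of the offload/reuse pairing this enables (caller-side
    # pseudocode; `engine` and `dst_cpu_ids` are assumptions, not part of
    # this commit):
    #
    #   engine.offload_buffer_to_cpu("ping", dst_cpu_ids)  # async on pingpong_stream
    #   # ... compute may proceed on the pong buffer meanwhile ...
    #   engine.wait_ping_offload_done()  # required before overwriting ping slots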
    def offload_slot_to_cpu(
        self,
        gpu_slot: int,
        cpu_block_id: int,
    ) -> None:
        """
        Asynchronously offload the KV of a single GPU slot to CPU.

        Args:
            gpu_slot: GPU slot ID
            cpu_block_id: Destination CPU block ID
        """
        logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")

        with torch.cuda.stream(self.pingpong_stream):
            self.pingpong_stream.wait_stream(self.compute_stream)
            self.k_cache_cpu[:, cpu_block_id].copy_(
                self.k_cache_gpu[:, gpu_slot], non_blocking=True
            )
            self.v_cache_cpu[:, cpu_block_id].copy_(
                self.v_cache_gpu[:, gpu_slot], non_blocking=True
            )

    def wait_ping_offload_done(self) -> None:
        """Wait for the Ping buffer offload to complete."""
        self.compute_stream.wait_event(self.ping_offload_done)

    def wait_pong_offload_done(self) -> None:
        """Wait for the Pong buffer offload to complete."""
        self.compute_stream.wait_event(self.pong_offload_done)

    def wait_all_offload_done(self) -> None:
        """Wait for all offloads to complete."""
        self.pingpong_stream.synchronize()

    def get_kv_for_ping_slots(
        self,
        layer_id: int,
        num_slots: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the KV for the first num_slots slots of the Ping buffer.

        Args:
            layer_id: Layer ID
            num_slots: Number of slots needed

        Returns:
            (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
        """
        slots = self.ping_slots[:num_slots]
        k = self.k_cache_gpu[layer_id, slots]  # [num_slots, block_size, heads, dim]
        v = self.v_cache_gpu[layer_id, slots]
        # Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    def get_kv_for_pong_slots(
        self,
        layer_id: int,
        num_slots: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the KV for the first num_slots slots of the Pong buffer.

        Args:
            layer_id: Layer ID
            num_slots: Number of slots needed

        Returns:
            (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
        """
        slots = self.pong_slots[:num_slots]
        k = self.k_cache_gpu[layer_id, slots]
        v = self.v_cache_gpu[layer_id, slots]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    def get_kv_for_slots(
        self,
        layer_id: int,
        gpu_slots: List[int],
    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """
        Get the KV for the given GPU slots.

        Args:
            layer_id: Layer ID
            gpu_slots: List of GPU slot IDs

        Returns:
            (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim],
            or (None, None) if gpu_slots is empty
        """
        if not gpu_slots:
            return None, None
        k = self.k_cache_gpu[layer_id, gpu_slots]
        v = self.v_cache_gpu[layer_id, gpu_slots]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v
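    # Shape check for the reshape contract above (toy dimensions, assumed
    # for illustration only):
    #
    #   import torch
    #   layers, blocks, block_size, heads, dim = 2, 4, 16, 8, 64
    #   k_cache_gpu = torch.empty(layers, blocks, block_size, heads, dim)
    #   k = k_cache_gpu[0, [0, 1]]        # [2, 16, 8, 64]
    #   k = k.reshape(1, -1, heads, dim)  # [1, 32, 8, 64]
    #   assert k.shape == (1, 2 * block_size, heads, dim)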