[refactor] Translate comments into English, avoiding Chinese, due to Claude.
@@ -95,16 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
             num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
             block_size: Tokens per block
             policy: Eviction policy (default: LRU)
-            cpu_primary: If True, use CPU as primary storage with 三区域 GPU buffer.
+            cpu_primary: If True, use CPU as primary storage with three-region GPU buffer.
                 If False, use GPU as primary with CPU as overflow (legacy mode).
-            num_prefetch_blocks: Number of prefetch blocks for 三区域 GPU buffer design
+            num_prefetch_blocks: Number of prefetch blocks for three-region GPU buffer design
         """
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
         self.num_cpu_blocks = num_cpu_blocks
         self.total_blocks = num_gpu_slots + num_cpu_blocks
-        self.cpu_primary = cpu_primary  # 三区域 mode flag
-        self.num_prefetch_blocks = num_prefetch_blocks  # 三区域设计参数
+        self.cpu_primary = cpu_primary  # Three-region mode flag
+        self.num_prefetch_blocks = num_prefetch_blocks  # Three-region design parameter

         # Eviction policy
         self.policy = policy or LRUPolicy()
@@ -138,6 +138,10 @@ class HybridKVCacheManager(KVCacheManager):
         # Track blocks that have been prefilled (KV written) for chunked prefill
         self.prefilled_blocks: Set[int] = set()  # logical_ids

+        # Track decode starting position within block (for batched offload optimization)
+        # Key: sequence id, Value: starting position where decode began in current block
+        self._decode_start_pos: Dict[int, int] = {}
+
     @property
     def block_size(self) -> int:
         return self._block_size
@@ -337,11 +341,11 @@ class HybridKVCacheManager(KVCacheManager):
         """
         assert not seq.block_table, "Sequence already has blocks"

-        # Ping-Pong模式:所有blocks都分配到CPU
+        # Three-region mode: all blocks are allocated to CPU
         if self.cpu_primary:
             return self.allocate_cpu_only(seq)

-        # Legacy模式:GPU为主,CPU为overflow
+        # Legacy mode: GPU as primary, CPU as overflow
         h = -1
         cache_miss = False

@@ -467,7 +471,7 @@ class HybridKVCacheManager(KVCacheManager):
             block.token_ids = []

             if self.cpu_primary:
-                # Ping-Pong模式:新block分配到CPU
+                # Three-region mode: new block allocated to CPU
                 if not self.free_cpu_blocks:
                     raise RuntimeError("No free CPU blocks for decode")
                 cpu_block_id = self.free_cpu_blocks.popleft()
@@ -476,7 +480,7 @@ class HybridKVCacheManager(KVCacheManager):
                 block.gpu_slot = -1
                 self.cpu_block_to_logical[cpu_block_id] = logical_id
             else:
-                # Legacy模式:新block分配到GPU
+                # Legacy mode: new block allocated to GPU
                 gpu_slot = self._allocate_gpu_slot()
                 block.location = BlockLocation.GPU
                 block.gpu_slot = gpu_slot
@@ -1021,22 +1025,22 @@ class HybridKVCacheManager(KVCacheManager):
                 break
         return pos

-    # ========== Ping-Pong 双缓冲支持 ==========
+    # ========== Three-region double-buffering support ==========

     def allocate_cpu_only(self, seq: Sequence) -> None:
         """
-        为序列分配 CPU blocks(用于 Ping-Pong 模式)。
+        Allocate CPU blocks for a sequence (for three-region mode).

-        与 allocate() 不同,这里所有 blocks 都分配到 CPU,
-        GPU 只用作工作缓冲区。
+        Unlike allocate(), all blocks are allocated on CPU here;
+        the GPU is used only as a working buffer.

         Args:
-            seq: 要分配的序列
+            seq: Sequence to allocate
         """
         assert not seq.block_table, "Sequence already has blocks"

         for i in range(seq.num_blocks):
-            # 分配 CPU block
+            # Allocate CPU block
             if not self.free_cpu_blocks:
                 raise RuntimeError(
                     f"No free CPU blocks. Need {seq.num_blocks}, "
@@ -1045,7 +1049,7 @@ class HybridKVCacheManager(KVCacheManager):

             cpu_block_id = self.free_cpu_blocks.popleft()

-            # 分配逻辑块
+            # Allocate logical block
             logical_id = self.free_logical_ids.popleft()
             block = self.logical_blocks[logical_id]
             block.ref_count = 1
@@ -1058,13 +1062,13 @@ class HybridKVCacheManager(KVCacheManager):

     def get_cpu_block_table(self, seq: Sequence) -> List[int]:
         """
-        获取序列的 CPU block ID 列表。
+        Get the list of CPU block IDs for a sequence.

         Args:
-            seq: 序列
+            seq: Sequence

         Returns:
-            CPU block IDs 列表,按序列顺序
+            List of CPU block IDs, in sequence order
         """
         cpu_blocks = []
         for logical_id in seq.block_table:
@@ -1072,20 +1076,20 @@ class HybridKVCacheManager(KVCacheManager):
             if block.location == BlockLocation.CPU:
                 cpu_blocks.append(block.cpu_block_id)
             else:
-                # 如果 block 在 GPU 上,它应该有一个对应的 CPU block
-                # 在 Ping-Pong 模式下,所有数据最终都在 CPU 上
+                # If block is on GPU, it should have a corresponding CPU block
+                # In three-region mode, all data ultimately resides on CPU
                 raise RuntimeError(
                     f"Block {logical_id} not on CPU (location={block.location}). "
-                    f"In Ping-Pong mode, all blocks should be on CPU."
+                    f"In three-region mode, all blocks should be on CPU."
                 )
         return cpu_blocks

     def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]:
         """
-        获取序列的所有 CPU blocks 及其逻辑 ID。
+        Get all CPU blocks and their logical IDs for a sequence.

         Args:
-            seq: 序列
+            seq: Sequence

         Returns:
             (cpu_block_ids, logical_ids)
@@ -1101,13 +1105,13 @@ class HybridKVCacheManager(KVCacheManager):

     def allocate_next_cpu_block(self, seq: Sequence) -> int:
         """
-        为序列分配下一个 CPU block(用于 decode 时新 token)。
+        Allocate the next CPU block for a sequence (for a new token during decode).

         Args:
-            seq: 序列
+            seq: Sequence

         Returns:
-            新分配的 CPU block ID
+            Newly allocated CPU block ID
         """
         if not self.free_cpu_blocks:
             raise RuntimeError("No free CPU blocks")
@@ -1128,15 +1132,15 @@ class HybridKVCacheManager(KVCacheManager):

     def get_last_cpu_block(self, seq: Sequence) -> int:
         """
-        获取序列最后一个 block 的 CPU block ID。
+        Get the CPU block ID of the last block in a sequence.

-        如果最后一个 block 不在 CPU 上,返回 -1。
+        Returns -1 if the last block is not on CPU.

         Args:
-            seq: 序列
+            seq: Sequence

         Returns:
-            CPU block ID,如果不在 CPU 上则返回 -1
+            CPU block ID, or -1 if not on CPU
         """
         if not seq.block_table:
             return -1
@@ -1150,19 +1154,65 @@ class HybridKVCacheManager(KVCacheManager):

     def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
         """
-        获取三区域 decode 时新 KV 写入的 GPU slot。
+        Get the GPU slot for writing new KV during three-region decode.

-        在三区域设计中,永远使用 Decode区 (slot 0) 写入新 KV。
-        这样可以避免与 Compute/Prefetch区 的加载操作冲突。
+        In the three-region design, the Decode region (slot 0) is always used to write new KV.
+        This avoids conflicts with Compute/Prefetch region loading operations.

         Args:
-            seq: 序列
+            seq: Sequence

         Returns:
-            GPU slot ID (永远是 decode_slot = 0)
+            GPU slot ID (always decode_slot = 0)
         """
         return self.offload_engine.decode_slot

+    def get_decode_start_pos(self, seq: Sequence) -> int:
+        """
+        Get the starting position within the block where decode tokens began.
+
+        This is used for the batched offload optimization - we need to attend to all
+        accumulated tokens in the decode slot, not just the current one.
+
+        Args:
+            seq: Sequence
+
+        Returns:
+            Starting position within block (0 to block_size-1)
+        """
+        seq_id = id(seq)
+        if seq_id not in self._decode_start_pos:
+            # First decode step - compute the starting position.
+            # After prefill, the last block has some tokens filled;
+            # decode starts at the next position.
+            prefill_len = len(seq) - 1  # Current len includes the new decode token
+            self._decode_start_pos[seq_id] = prefill_len % self._block_size
+        return self._decode_start_pos[seq_id]
+
+    def reset_decode_start_pos(self, seq: Sequence) -> None:
+        """
+        Reset the decode start position for a sequence.
+
+        Called when a block is full and offloaded - the next decode starts at position 0.
+
+        Args:
+            seq: Sequence
+        """
+        seq_id = id(seq)
+        self._decode_start_pos[seq_id] = 0
+
+    def clear_decode_tracking(self, seq: Sequence) -> None:
+        """
+        Clear decode position tracking for a sequence.
+
+        Called when a sequence is deallocated.
+
+        Args:
+            seq: Sequence
+        """
+        seq_id = id(seq)
+        self._decode_start_pos.pop(seq_id, None)
+
     def __repr__(self) -> str:
         return (
             f"HybridKVCacheManager(\n"
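
The decode-position tracking added in the hunk above supports the batched offload optimization: rather than offloading after every generated token, the manager remembers where decode started within the current block and attends to everything accumulated since. A minimal sketch of the arithmetic, assuming block_size=4; first_decode_pos is a hypothetical standalone helper, not part of the diff:

    # len(seq) already includes the newly sampled decode token, so the
    # prefill length is len(seq) - 1 on the first decode step.
    def first_decode_pos(seq_len: int, block_size: int) -> int:
        prefill_len = seq_len - 1
        return prefill_len % block_size

    assert first_decode_pos(9, 4) == 0   # 8 prefill tokens fill two blocks; decode opens a new block
    assert first_decode_pos(10, 4) == 1  # 9 prefill tokens leave one token in the third block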
@@ -65,44 +65,44 @@ class OffloadEngine:
         self.kv_dim = num_kv_heads * head_dim
         self.block_numel = block_size * self.kv_dim

-        # ========== 三区域 GPU Buffer 配置 ==========
-        # 约束检查
+        # ========== Three-region GPU Buffer configuration ==========
+        # Constraint checks
         assert num_gpu_blocks >= 3, \
-            f"至少需要3个GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
+            f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
         assert num_prefetch_blocks >= 1, \
-            f"至少需要1个prefetch block, got {num_prefetch_blocks}"
+            f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
         assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
-            f"GPU blocks不足: 需要 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"
+            f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"

-        # 三区域配置
-        # Decode区: [0] - 固定1个block用于写入新KV
+        # Three-region configuration
+        # Decode region: [0] - Fixed 1 block for writing new KV
         self.decode_slot = 0

-        # Compute区: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
+        # Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
         compute_start = 1
         compute_end = num_gpu_blocks - num_prefetch_blocks
         self.compute_slots = list(range(compute_start, compute_end))
         self.num_compute_blocks = len(self.compute_slots)

-        # Prefetch区: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
+        # Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
         prefetch_start = compute_end
         self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
         self.num_prefetch_blocks = num_prefetch_blocks

         self.num_gpu_slots = num_gpu_blocks  # alias

-        # 保留旧的ping/pong属性以兼容(后续会移除)
+        # Keep old ping/pong attributes for compatibility (will be removed later)
         self.ping_size = self.num_compute_blocks
         self.pong_size = self.num_prefetch_blocks
         self.ping_slots = self.compute_slots.copy()
         self.pong_slots = self.prefetch_slots.copy()

-        logger.info(f"三区域 GPU Buffer: decode_slot={self.decode_slot}, "
+        logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
                     f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")

         # ========== Fixed-address GPU KV cache ==========
         # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
-        # 使用 zeros 初始化以避免未初始化内存问题
+        # Use zeros initialization to avoid uninitialized memory issues
         self.k_cache_gpu = torch.zeros(
             num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
             dtype=dtype, device="cuda"
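
To make the partition concrete: a short sketch of the slot layout computed above, using illustrative values (num_gpu_blocks=8 and num_prefetch_blocks=3 are assumptions, not from the diff):

    num_gpu_blocks, num_prefetch_blocks = 8, 3
    decode_slot = 0                                             # Decode region: write new KV
    compute_end = num_gpu_blocks - num_prefetch_blocks          # 5
    compute_slots = list(range(1, compute_end))                 # Compute region: [1, 2, 3, 4]
    prefetch_slots = list(range(compute_end, num_gpu_blocks))   # Prefetch region: [5, 6, 7]
    assert 1 + len(compute_slots) + len(prefetch_slots) == num_gpu_blocks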
@@ -140,15 +140,15 @@ class OffloadEngine:
         self.compute_stream = torch.cuda.current_stream()
         self._stream_idx = 0

-        # ========== 三区域专用 stream 和事件 ==========
-        self.transfer_stream_main = torch.cuda.Stream()  # 主传输stream
+        # ========== Three-region dedicated stream and events ==========
+        self.transfer_stream_main = torch.cuda.Stream()  # Main transfer stream

-        # 同步事件 - 三区域加载完成
+        # Sync events - three-region loading completion
         self.compute_ready = torch.cuda.Event()
         self.prefetch_ready = torch.cuda.Event()
         self.decode_offload_done = torch.cuda.Event()

-        # 保留旧的ping/pong事件以兼容(后续会移除)
+        # Keep old ping/pong events for compatibility (will be removed later)
         self.pingpong_stream = self.transfer_stream_main
         self.ping_ready = self.compute_ready
         self.pong_ready = self.prefetch_ready
@@ -568,20 +568,20 @@ class OffloadEngine:
             f" kv_heads={self.num_kv_heads},\n"
             f" head_dim={self.head_dim},\n"
             f" dtype={self.dtype},\n"
-            f" 三区域: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
+            f" three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
             f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
             f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
             f")"
         )

-    # ========== Ping-Pong 双缓冲方法 ==========
+    # ========== Ping-Pong double buffering methods ==========

     def load_to_ping(self, cpu_block_ids: List[int]) -> None:
         """
-        异步加载CPU blocks到Ping buffer。
+        Async load CPU blocks to Ping buffer.

         Args:
-            cpu_block_ids: 要加载的CPU block IDs列表
+            cpu_block_ids: List of CPU block IDs to load
         """
         if not cpu_block_ids:
             self.ping_ready.record(self.pingpong_stream)
@@ -594,7 +594,7 @@ class OffloadEngine:
         for i in range(num_to_load):
             cpu_id = cpu_block_ids[i]
             gpu_slot = self.ping_slots[i]
-            # 所有层一起复制
+            # Copy all layers together
             self.k_cache_gpu[:, gpu_slot].copy_(
                 self.k_cache_cpu[:, cpu_id], non_blocking=True
             )
@@ -605,10 +605,10 @@ class OffloadEngine:

     def load_to_pong(self, cpu_block_ids: List[int]) -> None:
         """
-        异步加载CPU blocks到Pong buffer。
+        Async load CPU blocks to Pong buffer.

         Args:
-            cpu_block_ids: 要加载的CPU block IDs列表
+            cpu_block_ids: List of CPU block IDs to load
         """
         if not cpu_block_ids:
             self.pong_ready.record(self.pingpong_stream)
@@ -630,11 +630,11 @@ class OffloadEngine:
         self.pong_ready.record(self.pingpong_stream)

     def wait_ping(self) -> None:
-        """等待Ping buffer加载完成。"""
+        """Wait for Ping buffer loading to complete."""
         self.compute_stream.wait_event(self.ping_ready)

     def wait_pong(self) -> None:
-        """等待Pong buffer加载完成。"""
+        """Wait for Pong buffer loading to complete."""
         self.compute_stream.wait_event(self.pong_ready)

     def offload_buffer_to_cpu(
@@ -643,11 +643,11 @@ class OffloadEngine:
         cpu_block_ids: List[int],
     ) -> None:
         """
-        异步将buffer中的KV offload到CPU。
+        Async offload KV from buffer to CPU.

         Args:
-            buffer: "ping" 或 "pong"
-            cpu_block_ids: 目标CPU block IDs列表
+            buffer: "ping" or "pong"
+            cpu_block_ids: List of target CPU block IDs
         """
         slots = self.ping_slots if buffer == "ping" else self.pong_slots
         event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done
@@ -660,7 +660,7 @@ class OffloadEngine:
         logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")

         with torch.cuda.stream(self.pingpong_stream):
-            # 等待计算完成
+            # Wait for compute to complete
             self.pingpong_stream.wait_stream(self.compute_stream)

             for i in range(num_to_offload):
@@ -680,11 +680,11 @@ class OffloadEngine:
         cpu_block_id: int,
     ) -> None:
         """
-        异步将单个GPU slot的KV offload到CPU。
+        Async offload a single GPU slot's KV to CPU.

         Args:
             gpu_slot: GPU slot ID
-            cpu_block_id: 目标CPU block ID
+            cpu_block_id: Target CPU block ID
         """
         logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")

@@ -698,15 +698,15 @@ class OffloadEngine:
         )

     def wait_ping_offload_done(self) -> None:
-        """等待Ping buffer offload完成。"""
+        """Wait for Ping buffer offload to complete."""
         self.compute_stream.wait_event(self.ping_offload_done)

     def wait_pong_offload_done(self) -> None:
-        """等待Pong buffer offload完成。"""
+        """Wait for Pong buffer offload to complete."""
         self.compute_stream.wait_event(self.pong_offload_done)

     def wait_all_offload_done(self) -> None:
-        """等待所有offload完成。"""
+        """Wait for all offload operations to complete."""
         self.pingpong_stream.synchronize()

     def get_kv_for_ping_slots(
@@ -715,14 +715,14 @@ class OffloadEngine:
         num_slots: int,
     ) -> Tuple[Tensor, Tensor]:
         """
-        获取Ping buffer中指定数量slots的KV。
+        Get KV for specified number of slots in Ping buffer.

         Args:
-            layer_id: 层ID
-            num_slots: 需要的slot数量
+            layer_id: Layer ID
+            num_slots: Number of slots needed

         Returns:
-            (k_cache, v_cache),shape: [1, num_slots * block_size, kv_heads, head_dim]
+            (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
         """
         slots = self.ping_slots[:num_slots]
         k = self.k_cache_gpu[layer_id, slots]  # [num_slots, block_size, heads, dim]
@@ -738,14 +738,14 @@ class OffloadEngine:
         num_slots: int,
     ) -> Tuple[Tensor, Tensor]:
         """
-        获取Pong buffer中指定数量slots的KV。
+        Get KV for specified number of slots in Pong buffer.

         Args:
-            layer_id: 层ID
-            num_slots: 需要的slot数量
+            layer_id: Layer ID
+            num_slots: Number of slots needed

         Returns:
-            (k_cache, v_cache),shape: [1, num_slots * block_size, kv_heads, head_dim]
+            (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
         """
         slots = self.pong_slots[:num_slots]
         k = self.k_cache_gpu[layer_id, slots]
@@ -760,14 +760,14 @@ class OffloadEngine:
         gpu_slots: List[int],
     ) -> Tuple[Tensor, Tensor]:
         """
-        获取指定GPU slots的KV。
+        Get KV for specified GPU slots.

         Args:
-            layer_id: 层ID
-            gpu_slots: GPU slot IDs列表
+            layer_id: Layer ID
+            gpu_slots: List of GPU slot IDs

         Returns:
-            (k_cache, v_cache),shape: [1, len(slots) * block_size, kv_heads, head_dim]
+            (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
         """
         if not gpu_slots:
             return None, None
@@ -777,14 +777,14 @@ class OffloadEngine:
         v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
         return k, v

-    # ========== 三区域 GPU Buffer 方法 ==========
+    # ========== Three-region GPU Buffer methods ==========

     def load_to_compute(self, cpu_block_ids: List[int]) -> None:
         """
-        异步加载CPU blocks到Compute区。
+        Async load CPU blocks to Compute region.

         Args:
-            cpu_block_ids: 要加载的CPU block IDs列表
+            cpu_block_ids: List of CPU block IDs to load
         """
         if not cpu_block_ids:
             self.compute_ready.record(self.transfer_stream_main)
@@ -797,7 +797,7 @@ class OffloadEngine:
         for i in range(num_to_load):
             cpu_id = cpu_block_ids[i]
             gpu_slot = self.compute_slots[i]
-            # 所有层一起复制
+            # Copy all layers together
             self.k_cache_gpu[:, gpu_slot].copy_(
                 self.k_cache_cpu[:, cpu_id], non_blocking=True
             )
@@ -808,10 +808,10 @@ class OffloadEngine:

     def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
         """
-        异步加载CPU blocks到Prefetch区。
+        Async load CPU blocks to Prefetch region.

         Args:
-            cpu_block_ids: 要加载的CPU block IDs列表
+            cpu_block_ids: List of CPU block IDs to load
         """
         if not cpu_block_ids:
             self.prefetch_ready.record(self.transfer_stream_main)
@@ -833,25 +833,25 @@ class OffloadEngine:
         self.prefetch_ready.record(self.transfer_stream_main)

     def wait_compute(self) -> None:
-        """等待Compute区加载完成。"""
+        """Wait for Compute region loading to complete."""
         self.compute_stream.wait_event(self.compute_ready)

     def wait_prefetch(self) -> None:
-        """等待Prefetch区加载完成。"""
+        """Wait for Prefetch region loading to complete."""
         self.compute_stream.wait_event(self.prefetch_ready)

     def swap_compute_prefetch(self) -> None:
-        """交换Compute区和Prefetch区的角色。"""
+        """Swap roles of Compute region and Prefetch region."""
         self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
-        # 同时更新旧的ping/pong slots以保持兼容
+        # Also update old ping/pong slots for compatibility
         self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots

     def offload_decode_slot(self, cpu_block_id: int) -> None:
         """
-        将Decode区的KV offload到CPU。
+        Offload KV from Decode region to CPU.

         Args:
-            cpu_block_id: 目标CPU block ID
+            cpu_block_id: Target CPU block ID
         """
         logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
@@ -866,7 +866,7 @@ class OffloadEngine:
         self.decode_offload_done.record(self.transfer_stream_main)

     def wait_decode_offload(self) -> None:
-        """等待Decode区offload完成。"""
+        """Wait for Decode region offload to complete."""
         self.compute_stream.wait_event(self.decode_offload_done)

     def get_kv_for_compute(
@@ -875,14 +875,14 @@ class OffloadEngine:
         num_blocks: int,
     ) -> Tuple[Tensor, Tensor]:
         """
-        获取Compute区中指定数量blocks的KV。
+        Get KV for specified number of blocks in Compute region.

         Args:
-            layer_id: 层ID
-            num_blocks: 需要的block数量
+            layer_id: Layer ID
+            num_blocks: Number of blocks needed

         Returns:
-            (k_cache, v_cache),shape: [1, num_blocks * block_size, kv_heads, head_dim]
+            (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
         """
         slots = self.compute_slots[:num_blocks]
         k = self.k_cache_gpu[layer_id, slots]  # [num_blocks, block_size, heads, dim]
@@ -898,14 +898,14 @@ class OffloadEngine:
         num_blocks: int,
     ) -> Tuple[Tensor, Tensor]:
         """
-        获取Prefetch区中指定数量blocks的KV。
+        Get KV for specified number of blocks in Prefetch region.

         Args:
-            layer_id: 层ID
-            num_blocks: 需要的block数量
+            layer_id: Layer ID
+            num_blocks: Number of blocks needed

         Returns:
-            (k_cache, v_cache),shape: [1, num_blocks * block_size, kv_heads, head_dim]
+            (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
         """
         slots = self.prefetch_slots[:num_blocks]
         k = self.k_cache_gpu[layer_id, slots]
@@ -920,14 +920,14 @@ class OffloadEngine:
         pos_in_block: int,
     ) -> Tuple[Tensor, Tensor]:
         """
-        获取Decode区指定位置的KV(用于decode时的新token)。
+        Get KV at specified position in Decode region (for new token during decode).

         Args:
-            layer_id: 层ID
-            pos_in_block: token在block内的位置 (0 to block_size-1)
+            layer_id: Layer ID
+            pos_in_block: Token position within block (0 to block_size-1)

         Returns:
-            (k_cache, v_cache),shape: [1, 1, kv_heads, head_dim]
+            (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
         """
         k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]  # [1, heads, dim]
         v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
@@ -935,12 +935,36 @@ class OffloadEngine:
         v = v.unsqueeze(0)
         return k, v

-    def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
+    def get_kv_for_decode_slot_accumulated(
+        self,
+        layer_id: int,
+        num_tokens: int,
+    ) -> Tuple[Tensor, Tensor]:
         """
-        将Compute区的KV offload到CPU。
+        Get accumulated KV in Decode region (all tokens from position 0 to num_tokens-1).

+        Used when batching decode offloads - attend to all accumulated tokens,
+        not just the current one.
+
         Args:
-            cpu_block_ids: 目标CPU block IDs列表
+            layer_id: Layer ID
+            num_tokens: Number of accumulated tokens (1 to block_size)
+
+        Returns:
+            (k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
         """
+        k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens]  # [num_tokens, heads, dim]
+        v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
+        k = k.unsqueeze(0)  # [1, num_tokens, heads, dim]
+        v = v.unsqueeze(0)
+        return k, v
+
+    def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
+        """
+        Offload KV from Compute region to CPU.
+
+        Args:
+            cpu_block_ids: List of target CPU block IDs
+        """
         if not cpu_block_ids:
             return
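
The accumulated variant above exists because, with batched offloads, several decode tokens pile up in the single decode slot before it is flushed to CPU; attention must therefore cover positions 0..num_tokens-1, not just the newest token. A sketch of the resulting shapes, with illustrative sizes (block_size=16, 8 KV heads, head_dim=64 are assumptions):

    import torch

    num_tokens, heads, dim = 5, 8, 64
    k_slot = torch.zeros(16, heads, dim)   # the decode-slot block for one layer, block_size=16
    k = k_slot[:num_tokens].unsqueeze(0)   # [1, num_tokens, heads, dim], as in the method above
    assert k.shape == (1, 5, 8, 64)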
@@ -949,7 +973,7 @@ class OffloadEngine:
         logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")

         with torch.cuda.stream(self.transfer_stream_main):
-            # 等待计算完成
+            # Wait for compute to complete
             self.transfer_stream_main.wait_stream(self.compute_stream)

             for i in range(num_to_offload):
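
Taken together, these methods support a pipelined decode loop: attention runs over blocks staged in the Compute region while the next blocks stream into the Prefetch region, and the two regions then swap roles. A minimal sketch of such a driver, assuming an OffloadEngine instance `engine` and a CPU block table pre-split into chunks no larger than the smaller of the two regions; the driver itself is hypothetical, but the method names are from this commit:

    def attend_over_cpu_blocks(engine, chunks, layer_id):
        # chunks: List[List[int]] of CPU block IDs for one sequence
        engine.load_to_compute(chunks[0])
        for i, chunk in enumerate(chunks):
            if i + 1 < len(chunks):
                # Overlap the next host-to-device transfer with this step's compute.
                engine.load_to_prefetch(chunks[i + 1])
            engine.wait_compute()
            k, v = engine.get_kv_for_compute(layer_id, len(chunk))
            # ... run attention against (k, v) here ...
            if i + 1 < len(chunks):
                engine.wait_prefetch()
                engine.swap_compute_prefetch()  # prefetched blocks become the new compute set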