[refactor] Translate into English; avoid Chinese due to Claude.

Zijie Tian
2025-12-11 00:30:24 +08:00
parent e85c2b4776
commit babfa17354
9 changed files with 297 additions and 187 deletions


@@ -65,44 +65,44 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
- # ========== 三区域 GPU Buffer 配置 ==========
- # 约束检查
+ # ========== Three-region GPU Buffer configuration ==========
+ # Constraint checks
assert num_gpu_blocks >= 3, \
- f"至少需要3个GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
+ f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
assert num_prefetch_blocks >= 1, \
- f"至少需要1个prefetch block, got {num_prefetch_blocks}"
+ f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
- f"GPU blocks不足: 需要 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"
+ f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"
- # 三区域配置
- # Decode: [0] - 固定1个block用于写入新KV
+ # Three-region configuration
+ # Decode region: [0] - Fixed 1 block for writing new KV
self.decode_slot = 0
- # Compute: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
+ # Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
compute_start = 1
compute_end = num_gpu_blocks - num_prefetch_blocks
self.compute_slots = list(range(compute_start, compute_end))
self.num_compute_blocks = len(self.compute_slots)
- # Prefetch: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
+ # Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
prefetch_start = compute_end
self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
self.num_prefetch_blocks = num_prefetch_blocks
self.num_gpu_slots = num_gpu_blocks # alias
- # 保留旧的ping/pong属性以兼容，后续会移除
+ # Keep old ping/pong attributes for compatibility (will be removed later)
self.ping_size = self.num_compute_blocks
self.pong_size = self.num_prefetch_blocks
self.ping_slots = self.compute_slots.copy()
self.pong_slots = self.prefetch_slots.copy()
- logger.info(f"三区域 GPU Buffer: decode_slot={self.decode_slot}, "
+ logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
- # 使用 zeros 初始化以避免未初始化内存问题
+ # Use zeros initialization to avoid uninitialized memory issues
self.k_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
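For orientation, the slot layout the translated comments describe can be reproduced in isolation. A minimal sketch with assumed example values (num_gpu_blocks=8 and num_prefetch_blocks=3 are illustrative, not from this commit):

num_gpu_blocks, num_prefetch_blocks = 8, 3                 # assumed example values
decode_slot = 0                                            # fixed block for newly written KV
compute_end = num_gpu_blocks - num_prefetch_blocks
compute_slots = list(range(1, compute_end))                # [1, 2, 3, 4]
prefetch_slots = list(range(compute_end, num_gpu_blocks))  # [5, 6, 7]
assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks       # decode + compute + prefetch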
@@ -140,15 +140,15 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0
- # ========== 三区域专用 stream 和事件 ==========
- self.transfer_stream_main = torch.cuda.Stream() # 主传输stream
+ # ========== Three-region dedicated stream and events ==========
+ self.transfer_stream_main = torch.cuda.Stream() # Main transfer stream
- # 同步事件 - 三区域加载完成
+ # Sync events - three-region loading completion
self.compute_ready = torch.cuda.Event()
self.prefetch_ready = torch.cuda.Event()
self.decode_offload_done = torch.cuda.Event()
- # 保留旧的ping/pong事件以兼容，后续会移除
+ # Keep old ping/pong events for compatibility (will be removed later)
self.pingpong_stream = self.transfer_stream_main
self.ping_ready = self.compute_ready
self.pong_ready = self.prefetch_ready
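These events follow the standard PyTorch CUDA pattern: record on the producing stream, wait on the consuming stream, so the host never blocks. A minimal sketch of that pattern (illustrative, outside the diff):

import torch

transfer_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()
ready = torch.cuda.Event()

with torch.cuda.stream(transfer_stream):
    # enqueue async copies on the transfer stream here ...
    ready.record(transfer_stream)    # mark completion of the transfers

compute_stream.wait_event(ready)     # compute work is ordered after the copies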
@@ -568,20 +568,20 @@ class OffloadEngine:
f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n"
f" 三区域: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
f" three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
)
# ========== Ping-Pong 双缓冲方法 ==========
# ========== Ping-Pong double buffering methods ==========
def load_to_ping(self, cpu_block_ids: List[int]) -> None:
"""
异步加载CPU blocksPing buffer
Async load CPU blocks to Ping buffer.
Args:
cpu_block_ids: 要加载的CPU block IDs列表
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.ping_ready.record(self.pingpong_stream)
@@ -594,7 +594,7 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.ping_slots[i]
- # 所有层一起复制
+ # Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
@@ -605,10 +605,10 @@ class OffloadEngine:
def load_to_pong(self, cpu_block_ids: List[int]) -> None:
"""
- 异步加载CPU blocks到Pong buffer。
+ Async load CPU blocks to Pong buffer.
Args:
- cpu_block_ids: 要加载的CPU block IDs列表
+ cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.pong_ready.record(self.pingpong_stream)
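For context, a hypothetical driver loop showing how load_to_ping/load_to_pong, the wait_* calls, and the getters below compose into double buffering (chunks, layer_id, and the attention step are assumed caller-side state, not part of this commit):

engine.load_to_ping(chunks[0])                     # warm up the first buffer
for i in range(len(chunks)):
    current = "ping" if i % 2 == 0 else "pong"
    if i + 1 < len(chunks):                        # overlap: start the next transfer now
        next_load = engine.load_to_pong if current == "ping" else engine.load_to_ping
        next_load(chunks[i + 1])
    if current == "ping":
        engine.wait_ping()
        k, v = engine.get_kv_for_ping_slots(layer_id, len(chunks[i]))
    else:
        engine.wait_pong()
        k, v = engine.get_kv_for_pong_slots(layer_id, len(chunks[i]))
    # ... attend over (k, v) for this chunk ...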
@@ -630,11 +630,11 @@ class OffloadEngine:
self.pong_ready.record(self.pingpong_stream)
def wait_ping(self) -> None:
"""等待Ping buffer加载完成。"""
"""Wait for Ping buffer loading to complete."""
self.compute_stream.wait_event(self.ping_ready)
def wait_pong(self) -> None:
"""等待Pong buffer加载完成。"""
"""Wait for Pong buffer loading to complete."""
self.compute_stream.wait_event(self.pong_ready)
def offload_buffer_to_cpu(
@@ -643,11 +643,11 @@ class OffloadEngine:
cpu_block_ids: List[int],
) -> None:
"""
- 异步将buffer中的KV offload到CPU。
+ Async offload KV from buffer to CPU.
Args:
- buffer: "ping" 或 "pong"
- cpu_block_ids: 目标CPU block IDs列表
+ buffer: "ping" or "pong"
+ cpu_block_ids: List of target CPU block IDs
"""
slots = self.ping_slots if buffer == "ping" else self.pong_slots
event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done
@@ -660,7 +660,7 @@ class OffloadEngine:
logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.pingpong_stream):
- # 等待计算完成
+ # Wait for compute to complete
self.pingpong_stream.wait_stream(self.compute_stream)
for i in range(num_to_offload):
@@ -680,11 +680,11 @@ class OffloadEngine:
cpu_block_id: int,
) -> None:
"""
- 异步将单个GPU slot的KV offload到CPU。
+ Async offload a single GPU slot's KV to CPU.
Args:
gpu_slot: GPU slot ID
- cpu_block_id: 目标CPU block ID
+ cpu_block_id: Target CPU block ID
"""
logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")
@@ -698,15 +698,15 @@ class OffloadEngine:
)
def wait_ping_offload_done(self) -> None:
"""等待Ping buffer offload完成。"""
"""Wait for Ping buffer offload to complete."""
self.compute_stream.wait_event(self.ping_offload_done)
def wait_pong_offload_done(self) -> None:
"""等待Pong buffer offload完成。"""
"""Wait for Pong buffer offload to complete."""
self.compute_stream.wait_event(self.pong_offload_done)
def wait_all_offload_done(self) -> None:
"""等待所有offload完成。"""
"""Wait for all offload operations to complete."""
self.pingpong_stream.synchronize()
def get_kv_for_ping_slots(
@@ -715,14 +715,14 @@ class OffloadEngine:
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
- 获取Ping buffer中指定数量slots的KV。
+ Get KV for specified number of slots in Ping buffer.
Args:
- layer_id: 层ID
- num_slots: 需要的slot数量
+ layer_id: Layer ID
+ num_slots: Number of slots needed
Returns:
- (k_cache, v_cache)，shape: [1, num_slots * block_size, kv_heads, head_dim]
+ (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.ping_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim]
@@ -738,14 +738,14 @@ class OffloadEngine:
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
- 获取Pong buffer中指定数量slots的KV。
+ Get KV for specified number of slots in Pong buffer.
Args:
- layer_id: 层ID
- num_slots: 需要的slot数量
+ layer_id: Layer ID
+ num_slots: Number of slots needed
Returns:
- (k_cache, v_cache)，shape: [1, num_slots * block_size, kv_heads, head_dim]
+ (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.pong_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots]
@@ -760,14 +760,14 @@ class OffloadEngine:
gpu_slots: List[int],
) -> Tuple[Tensor, Tensor]:
"""
- 获取指定GPU slots的KV。
+ Get KV for specified GPU slots.
Args:
- layer_id: 层ID
- gpu_slots: GPU slot IDs列表
+ layer_id: Layer ID
+ gpu_slots: List of GPU slot IDs
Returns:
- (k_cache, v_cache)，shape: [1, len(slots) * block_size, kv_heads, head_dim]
+ (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
"""
if not gpu_slots:
return None, None
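The reshape in these getters is pure shape arithmetic: fancy indexing with a slot list gathers [num_slots, block_size, heads, dim], and reshape(1, -1, ...) flattens blocks and positions into one sequence axis. A quick check with illustrative sizes:

import torch

block_size, kv_heads, head_dim = 16, 4, 64
k_cache_gpu = torch.zeros(2, 8, block_size, kv_heads, head_dim)  # [layers, blocks, ...]

slots = [1, 2, 3]
k = k_cache_gpu[0, slots]                   # [3, 16, 4, 64]
k = k.reshape(1, -1, kv_heads, head_dim)    # [1, 48, 4, 64]
assert k.shape == (1, len(slots) * block_size, kv_heads, head_dim)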
@@ -777,14 +777,14 @@ class OffloadEngine:
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
- # ========== 三区域 GPU Buffer 方法 ==========
+ # ========== Three-region GPU Buffer methods ==========
def load_to_compute(self, cpu_block_ids: List[int]) -> None:
"""
- 异步加载CPU blocks到Compute区。
+ Async load CPU blocks to Compute region.
Args:
- cpu_block_ids: 要加载的CPU block IDs列表
+ cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready.record(self.transfer_stream_main)
@@ -797,7 +797,7 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
- # 所有层一起复制
+ # Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
@@ -808,10 +808,10 @@ class OffloadEngine:
def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
"""
- 异步加载CPU blocks到Prefetch区。
+ Async load CPU blocks to Prefetch region.
Args:
- cpu_block_ids: 要加载的CPU block IDs列表
+ cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready.record(self.transfer_stream_main)
@@ -833,25 +833,25 @@ class OffloadEngine:
self.prefetch_ready.record(self.transfer_stream_main)
def wait_compute(self) -> None:
"""等待Compute区加载完成。"""
"""Wait for Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready)
def wait_prefetch(self) -> None:
"""等待Prefetch区加载完成。"""
"""Wait for Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready)
def swap_compute_prefetch(self) -> None:
"""交换Compute区和Prefetch区的角色。"""
"""Swap roles of Compute region and Prefetch region."""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
# 同时更新旧的ping/pong slots以保持兼容
# Also update old ping/pong slots for compatibility
self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots
def offload_decode_slot(self, cpu_block_id: int) -> None:
"""
将Decode区的KV offload到CPU
Offload KV from Decode region to CPU.
Args:
cpu_block_id: 目标CPU block ID
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
@@ -866,7 +866,7 @@ class OffloadEngine:
self.decode_offload_done.record(self.transfer_stream_main)
def wait_decode_offload(self) -> None:
"""等待Decode区offload完成。"""
"""Wait for Decode region offload to complete."""
self.compute_stream.wait_event(self.decode_offload_done)
def get_kv_for_compute(
@@ -875,14 +875,14 @@ class OffloadEngine:
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
- 获取Compute区中指定数量blocks的KV。
+ Get KV for specified number of blocks in Compute region.
Args:
- layer_id: 层ID
- num_blocks: 需要的block数量
+ layer_id: Layer ID
+ num_blocks: Number of blocks needed
Returns:
- (k_cache, v_cache)，shape: [1, num_blocks * block_size, kv_heads, head_dim]
+ (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.compute_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots] # [num_blocks, block_size, heads, dim]
@@ -898,14 +898,14 @@ class OffloadEngine:
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
- 获取Prefetch区中指定数量blocks的KV。
+ Get KV for specified number of blocks in Prefetch region.
Args:
- layer_id: 层ID
- num_blocks: 需要的block数量
+ layer_id: Layer ID
+ num_blocks: Number of blocks needed
Returns:
- (k_cache, v_cache)，shape: [1, num_blocks * block_size, kv_heads, head_dim]
+ (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.prefetch_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots]
@@ -920,14 +920,14 @@ class OffloadEngine:
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
- 获取Decode区指定位置的KV（用于decode时的新token）。
+ Get KV at specified position in Decode region (for new token during decode).
Args:
- layer_id: 层ID
- pos_in_block: token在block内的位置 (0 to block_size-1)
+ layer_id: Layer ID
+ pos_in_block: Token position within block (0 to block_size-1)
Returns:
- (k_cache, v_cache)，shape: [1, 1, kv_heads, head_dim]
+ (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] # [1, heads, dim]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
@@ -935,12 +935,36 @@ class OffloadEngine:
v = v.unsqueeze(0)
return k, v
- def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
+ def get_kv_for_decode_slot_accumulated(
+ self,
+ layer_id: int,
+ num_tokens: int,
+ ) -> Tuple[Tensor, Tensor]:
"""
- 将Compute区的KV offload到CPU。
+ Get accumulated KV in Decode region (all tokens from position 0 to num_tokens-1).
+ Used when batching decode offloads - attend to all accumulated tokens,
+ not just the current one.
Args:
- cpu_block_ids: 目标CPU block IDs列表
+ layer_id: Layer ID
+ num_tokens: Number of accumulated tokens (1 to block_size)
+ Returns:
+ (k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
+ """
+ k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens] # [num_tokens, heads, dim]
+ v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
+ k = k.unsqueeze(0) # [1, num_tokens, heads, dim]
+ v = v.unsqueeze(0)
+ return k, v
+ def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
+ """
+ Offload KV from Compute region to CPU.
+ Args:
+ cpu_block_ids: List of target CPU block IDs
"""
if not cpu_block_ids:
return
@@ -949,7 +973,7 @@ class OffloadEngine:
logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.transfer_stream_main):
- # 等待计算完成
+ # Wait for compute to complete
self.transfer_stream_main.wait_stream(self.compute_stream)
for i in range(num_to_offload):
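The new get_kv_for_decode_slot_accumulated is what makes batched decode offloads safe: instead of flushing decode_slot after every token, the caller keeps appending and attends over the whole accumulated prefix. A hypothetical decode step (the position bookkeeping and cpu_block_id are assumed, not shown in this commit):

# Tokens currently sitting in decode_slot (assumed bookkeeping by the caller).
num_accumulated = (seq_len - 1) % block_size + 1
k_dec, v_dec = engine.get_kv_for_decode_slot_accumulated(layer_id, num_accumulated)
# k_dec/v_dec: [1, num_accumulated, kv_heads, head_dim]; concatenated with the
# Compute-region KV before attention so the new token sees every accumulated token.
if num_accumulated == block_size:
    engine.offload_decode_slot(cpu_block_id)  # flush the full block to CPU once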