[refactor] Refactor offload code to multi-chunk.

This commit is contained in:
Zijie Tian
2025-12-15 01:13:58 +08:00
parent 5949537faf
commit 1081ab51ea
7 changed files with 36 additions and 233 deletions

View File

@@ -91,12 +91,6 @@ class OffloadEngine:
self.num_gpu_slots = num_gpu_blocks # alias
# Keep old ping/pong attributes for compatibility (will be removed later)
self.ping_size = self.num_compute_blocks
self.pong_size = self.num_prefetch_blocks
self.ping_slots = self.compute_slots.copy()
self.pong_slots = self.prefetch_slots.copy()
logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
@@ -148,13 +142,6 @@ class OffloadEngine:
self.prefetch_ready = torch.cuda.Event()
self.decode_offload_done = torch.cuda.Event()
# Keep old ping/pong events for compatibility (will be removed later)
self.pingpong_stream = self.transfer_stream_main
self.ping_ready = self.compute_ready
self.pong_ready = self.prefetch_ready
self.ping_offload_done = torch.cuda.Event()
self.pong_offload_done = torch.cuda.Event()
# ========== Per-layer events for chunked attention ==========
# Each layer has its own event for synchronization
self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
@@ -579,185 +566,9 @@ class OffloadEngine:
f")"
)
# ========== Ping-Pong double buffering methods ==========
def load_to_ping(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Ping buffer.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.ping_ready.record(self.pingpong_stream)
return
num_to_load = min(len(cpu_block_ids), self.ping_size)
logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}")
with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.ping_slots[i]
# Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.ping_ready.record(self.pingpong_stream)
def load_to_pong(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Pong buffer.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.pong_ready.record(self.pingpong_stream)
return
num_to_load = min(len(cpu_block_ids), self.pong_size)
logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}")
with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.pong_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.pong_ready.record(self.pingpong_stream)
def wait_ping(self) -> None:
"""Wait for Ping buffer loading to complete."""
self.compute_stream.wait_event(self.ping_ready)
def wait_pong(self) -> None:
"""Wait for Pong buffer loading to complete."""
self.compute_stream.wait_event(self.pong_ready)
def offload_buffer_to_cpu(
self,
buffer: str,
cpu_block_ids: List[int],
) -> None:
"""
Async offload KV from buffer to CPU.
Args:
buffer: "ping" or "pong"
cpu_block_ids: Target CPU block IDs list
"""
slots = self.ping_slots if buffer == "ping" else self.pong_slots
event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done
if not cpu_block_ids:
event.record(self.pingpong_stream)
return
num_to_offload = min(len(cpu_block_ids), len(slots))
logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.pingpong_stream):
# Wait for compute to complete
self.pingpong_stream.wait_stream(self.compute_stream)
for i in range(num_to_offload):
gpu_slot = slots[i]
cpu_id = cpu_block_ids[i]
self.k_cache_cpu[:, cpu_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
event.record(self.pingpong_stream)
def offload_slot_to_cpu(
self,
gpu_slot: int,
cpu_block_id: int,
) -> None:
"""
Async offload a single GPU slot's KV to CPU.
Args:
gpu_slot: GPU slot ID
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.pingpong_stream):
self.pingpong_stream.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
def wait_ping_offload_done(self) -> None:
"""Wait for Ping buffer offload to complete."""
self.compute_stream.wait_event(self.ping_offload_done)
def wait_pong_offload_done(self) -> None:
"""Wait for Pong buffer offload to complete."""
self.compute_stream.wait_event(self.pong_offload_done)
def wait_all_offload_done(self) -> None:
"""Wait for all offload operations to complete."""
self.pingpong_stream.synchronize()
def get_kv_for_ping_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of slots in Ping buffer.
Args:
layer_id: Layer ID
num_slots: Number of slots needed
Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.ping_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_pong_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for specified number of slots in Pong buffer.
Args:
layer_id: Layer ID
num_slots: Number of slots needed
Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.pong_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
self.transfer_stream_main.synchronize()
def get_kv_for_slots(
self,
@@ -918,8 +729,6 @@ class OffloadEngine:
def swap_compute_prefetch(self) -> None:
"""Swap roles of Compute region and Prefetch region."""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
# Also update old ping/pong slots for compatibility
self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots
def offload_decode_slot(self, cpu_block_id: int) -> None:
"""