[refactor] Refactor offload code to multi-chunk.
@@ -91,12 +91,6 @@ class OffloadEngine:
         self.num_gpu_slots = num_gpu_blocks  # alias

-        # Keep old ping/pong attributes for compatibility (will be removed later)
-        self.ping_size = self.num_compute_blocks
-        self.pong_size = self.num_prefetch_blocks
-        self.ping_slots = self.compute_slots.copy()
-        self.pong_slots = self.prefetch_slots.copy()
-
         logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
                     f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")

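Editor's note: the three regions named in the surviving logger.info line partition the GPU slot pool into one decode slot plus a compute region and a prefetch region. A minimal sketch of that layout, assuming an illustrative pool size and an even split (neither is confirmed by this commit):

    # Hypothetical partition of the GPU slot pool into the three regions
    # named by the log line above; sizes here are illustrative only.
    num_gpu_blocks = 17
    decode_slot = 0                            # fixed slot for decode-step KV
    rest = list(range(1, num_gpu_blocks))
    compute_slots = rest[:len(rest) // 2]      # chunk currently being attended to
    prefetch_slots = rest[len(rest) // 2:]     # next chunk, loading in the background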
@@ -148,13 +142,6 @@ class OffloadEngine:
         self.prefetch_ready = torch.cuda.Event()
         self.decode_offload_done = torch.cuda.Event()

-        # Keep old ping/pong events for compatibility (will be removed later)
-        self.pingpong_stream = self.transfer_stream_main
-        self.ping_ready = self.compute_ready
-        self.pong_ready = self.prefetch_ready
-        self.ping_offload_done = torch.cuda.Event()
-        self.pong_offload_done = torch.cuda.Event()
-
         # ========== Per-layer events for chunked attention ==========
         # Each layer has its own event for synchronization
         self.compute_ready_per_layer = [torch.cuda.Event() for _ in range(num_layers)]
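Editor's note: the per-layer events kept here are what let chunked attention overlap transfers with compute: the transfer stream records an event as each layer's KV lands, so layer i's attention never has to wait for layer i+1's copy. A minimal usage sketch; every name except compute_ready_per_layer is an assumption, not this engine's real API:

    import torch

    def load_chunk_per_layer(k_cpu, v_cpu, k_gpu, v_gpu,
                             cpu_blocks, gpu_slots, events, transfer_stream):
        # Record one CUDA event per layer on the transfer stream so each
        # layer's attention can start as soon as its own KV copy lands.
        num_layers = k_gpu.shape[0]
        with torch.cuda.stream(transfer_stream):
            for layer_id in range(num_layers):
                for gpu_slot, cpu_block in zip(gpu_slots, cpu_blocks):
                    k_gpu[layer_id, gpu_slot].copy_(k_cpu[layer_id, cpu_block], non_blocking=True)
                    v_gpu[layer_id, gpu_slot].copy_(v_cpu[layer_id, cpu_block], non_blocking=True)
                events[layer_id].record(transfer_stream)

    # Compute side: gate each layer on its own event only, e.g.
    #   compute_stream.wait_event(events[layer_id])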
@@ -579,185 +566,9 @@ class OffloadEngine:
             f")"
         )

-    # ========== Ping-Pong double buffering methods ==========
-
-    def load_to_ping(self, cpu_block_ids: List[int]) -> None:
-        """
-        Asynchronously load CPU blocks into the Ping buffer.
-
-        Args:
-            cpu_block_ids: List of CPU block IDs to load
-        """
-        if not cpu_block_ids:
-            self.ping_ready.record(self.pingpong_stream)
-            return
-
-        num_to_load = min(len(cpu_block_ids), self.ping_size)
-        logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}")
-
-        with torch.cuda.stream(self.pingpong_stream):
-            for i in range(num_to_load):
-                cpu_id = cpu_block_ids[i]
-                gpu_slot = self.ping_slots[i]
-                # Copy all layers together
-                self.k_cache_gpu[:, gpu_slot].copy_(
-                    self.k_cache_cpu[:, cpu_id], non_blocking=True
-                )
-                self.v_cache_gpu[:, gpu_slot].copy_(
-                    self.v_cache_cpu[:, cpu_id], non_blocking=True
-                )
-            self.ping_ready.record(self.pingpong_stream)
-
-    def load_to_pong(self, cpu_block_ids: List[int]) -> None:
-        """
-        Asynchronously load CPU blocks into the Pong buffer.
-
-        Args:
-            cpu_block_ids: List of CPU block IDs to load
-        """
-        if not cpu_block_ids:
-            self.pong_ready.record(self.pingpong_stream)
-            return
-
-        num_to_load = min(len(cpu_block_ids), self.pong_size)
-        logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}")
-
-        with torch.cuda.stream(self.pingpong_stream):
-            for i in range(num_to_load):
-                cpu_id = cpu_block_ids[i]
-                gpu_slot = self.pong_slots[i]
-                self.k_cache_gpu[:, gpu_slot].copy_(
-                    self.k_cache_cpu[:, cpu_id], non_blocking=True
-                )
-                self.v_cache_gpu[:, gpu_slot].copy_(
-                    self.v_cache_cpu[:, cpu_id], non_blocking=True
-                )
-            self.pong_ready.record(self.pingpong_stream)
-
-    def wait_ping(self) -> None:
-        """Wait for Ping buffer loading to complete."""
-        self.compute_stream.wait_event(self.ping_ready)
-
-    def wait_pong(self) -> None:
-        """Wait for Pong buffer loading to complete."""
-        self.compute_stream.wait_event(self.pong_ready)
-
-    def offload_buffer_to_cpu(
-        self,
-        buffer: str,
-        cpu_block_ids: List[int],
-    ) -> None:
-        """
-        Asynchronously offload KV from a buffer to CPU.
-
-        Args:
-            buffer: "ping" or "pong"
-            cpu_block_ids: List of target CPU block IDs
-        """
-        slots = self.ping_slots if buffer == "ping" else self.pong_slots
-        event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done
-
-        if not cpu_block_ids:
-            event.record(self.pingpong_stream)
-            return
-
-        num_to_offload = min(len(cpu_block_ids), len(slots))
-        logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
-
-        with torch.cuda.stream(self.pingpong_stream):
-            # Wait for compute to complete
-            self.pingpong_stream.wait_stream(self.compute_stream)
-
-            for i in range(num_to_offload):
-                gpu_slot = slots[i]
-                cpu_id = cpu_block_ids[i]
-                self.k_cache_cpu[:, cpu_id].copy_(
-                    self.k_cache_gpu[:, gpu_slot], non_blocking=True
-                )
-                self.v_cache_cpu[:, cpu_id].copy_(
-                    self.v_cache_gpu[:, gpu_slot], non_blocking=True
-                )
-            event.record(self.pingpong_stream)
-
-    def offload_slot_to_cpu(
-        self,
-        gpu_slot: int,
-        cpu_block_id: int,
-    ) -> None:
-        """
-        Asynchronously offload a single GPU slot's KV to CPU.
-
-        Args:
-            gpu_slot: GPU slot ID
-            cpu_block_id: Target CPU block ID
-        """
-        logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")
-
-        with torch.cuda.stream(self.pingpong_stream):
-            self.pingpong_stream.wait_stream(self.compute_stream)
-            self.k_cache_cpu[:, cpu_block_id].copy_(
-                self.k_cache_gpu[:, gpu_slot], non_blocking=True
-            )
-            self.v_cache_cpu[:, cpu_block_id].copy_(
-                self.v_cache_gpu[:, gpu_slot], non_blocking=True
-            )
-
-    def wait_ping_offload_done(self) -> None:
-        """Wait for Ping buffer offload to complete."""
-        self.compute_stream.wait_event(self.ping_offload_done)
-
-    def wait_pong_offload_done(self) -> None:
-        """Wait for Pong buffer offload to complete."""
-        self.compute_stream.wait_event(self.pong_offload_done)
-
     def wait_all_offload_done(self) -> None:
         """Wait for all offload operations to complete."""
-        self.pingpong_stream.synchronize()
-
-    def get_kv_for_ping_slots(
-        self,
-        layer_id: int,
-        num_slots: int,
-    ) -> Tuple[Tensor, Tensor]:
-        """
-        Get KV for the first num_slots slots of the Ping buffer.
-
-        Args:
-            layer_id: Layer ID
-            num_slots: Number of slots needed
-
-        Returns:
-            (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
-        """
-        slots = self.ping_slots[:num_slots]
-        k = self.k_cache_gpu[layer_id, slots]  # [num_slots, block_size, heads, dim]
-        v = self.v_cache_gpu[layer_id, slots]
-        # Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim]
-        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
-        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
-        return k, v
-
-    def get_kv_for_pong_slots(
-        self,
-        layer_id: int,
-        num_slots: int,
-    ) -> Tuple[Tensor, Tensor]:
-        """
-        Get KV for the first num_slots slots of the Pong buffer.
-
-        Args:
-            layer_id: Layer ID
-            num_slots: Number of slots needed
-
-        Returns:
-            (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
-        """
-        slots = self.pong_slots[:num_slots]
-        k = self.k_cache_gpu[layer_id, slots]
-        v = self.v_cache_gpu[layer_id, slots]
-        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
-        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
-        return k, v
+        self.transfer_stream_main.synchronize()

     def get_kv_for_slots(
         self,
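Editor's note: taken together, the deleted methods formed a double-buffered pipeline: load one buffer while computing on the other, write results back, then alternate. A reconstruction of that loop from the removed signatures; engine, chunks, and attend_chunk are hypothetical stand-ins, not code from this repository:

    def pingpong_pipeline(engine, chunks, attend_chunk):
        # Reconstruction of the retired double-buffered loop (stand-in names).
        load = {"ping": engine.load_to_ping, "pong": engine.load_to_pong}
        wait = {"ping": engine.wait_ping, "pong": engine.wait_pong}
        load["ping"](chunks[0])                        # warm up the first buffer
        for i, chunk in enumerate(chunks):
            cur, nxt = ("ping", "pong") if i % 2 == 0 else ("pong", "ping")
            if i + 1 < len(chunks):
                load[nxt](chunks[i + 1])               # overlap the next chunk's load
            wait[cur]()                                # gate compute on the current load
            attend_chunk(cur)                          # attention over this buffer
            engine.offload_buffer_to_cpu(cur, chunk)   # write updated KV back to CPU
        engine.wait_all_offload_done()

The multi-chunk refactor keeps the same overlap but expresses it through the compute/prefetch regions, so the buffer pair no longer needs dedicated method names.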
@@ -918,8 +729,6 @@ class OffloadEngine:
     def swap_compute_prefetch(self) -> None:
         """Swap roles of Compute region and Prefetch region."""
         self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
-        # Also update old ping/pong slots for compatibility
-        self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots

     def offload_decode_slot(self, cpu_block_id: int) -> None:
         """