[refactor] Cleanup unused code after perf_opt merge
Removed ~460 lines of unused/redundant code from offload_engine.py:

- CUDA gather methods (gathered_h2d_*, update_gather_indices)
- Legacy async transfer methods (prefetch_block_async, offload_block_async)
- Legacy sync/wait methods (wait_for_block, wait_all_transfers, sync_indices)
- Legacy compatibility methods (load_to_compute_layer, wait_compute_layer)
- Unused gather_indices tensors and memory calculations

Updated class docstring to reflect current architecture.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
```diff
@@ -40,14 +40,13 @@ class OffloadEngine:
     High-performance CPU-GPU async transfer engine for KV cache offloading.
 
     Memory layout:
-    - GPU cache: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
+    - GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dimension)
     - CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned)
-    - Gather indices: [num_layers, num_gpu_blocks] (fixed address, variable content)
 
-    CUDA Graph compatibility:
-    - gathered_h2d_layer() can be captured into CUDA graphs
-    - update_gather_indices() is called outside graphs to prepare indices
-    - All tensor addresses remain fixed across graph replays
+    Features:
+    - Unified ring buffer for chunked prefill/decode
+    - Per-layer prefill buffer for async offload
+    - Cross-layer pipeline for decode with double-buffering
     """
 
     def __init__(
```
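As a point of reference for the updated docstring, here is a minimal sketch of the two cache layouts it describes; the sizes and dtype below are placeholders, not values taken from the engine config:

```python
import torch

# Illustrative sizes only; the real values come from the model/engine config.
num_layers, num_gpu_blocks, num_cpu_blocks = 32, 64, 1024
block_size, kv_heads, head_dim = 16, 8, 128
dtype = torch.float16

# GPU cache: shared across layers, no layer dimension (new layout).
k_cache_gpu = torch.empty(
    num_gpu_blocks, block_size, kv_heads, head_dim, dtype=dtype, device="cuda"
)
v_cache_gpu = torch.empty_like(k_cache_gpu)

# CPU cache: keeps the layer dimension and lives in pinned host memory so that
# copy_(..., non_blocking=True) H2D/D2H transfers can actually run asynchronously.
k_cache_cpu = torch.empty(
    num_layers, num_cpu_blocks, block_size, kv_heads, head_dim,
    dtype=dtype, device="cpu", pin_memory=True,
)
v_cache_cpu = torch.empty(
    num_layers, num_cpu_blocks, block_size, kv_heads, head_dim,
    dtype=dtype, device="cpu", pin_memory=True,
)
```

Pinning the CPU cache is what lets the `non_blocking=True` copies used throughout the engine overlap with compute instead of silently falling back to synchronous transfers.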
```diff
@@ -210,19 +209,6 @@ class OffloadEngine:
             dtype=dtype, device="cpu", pin_memory=True
         )
 
-        # ========== Fixed-address gather indices (content is variable) ==========
-        # gather_indices[layer][i] = CPU block id to copy to GPU slot i
-        # -1 means no-op (skip this slot)
-        self.gather_indices_cpu = torch.empty(
-            num_layers, num_gpu_blocks,
-            dtype=torch.int64, device="cpu", pin_memory=True
-        )
-        self.gather_indices_cpu.fill_(-1)
-        self.gather_indices_gpu = torch.full(
-            (num_layers, num_gpu_blocks), -1,
-            dtype=torch.int64, device="cuda"
-        )
-
         # Log memory allocation
         gpu_mem_mb = self.gpu_memory_bytes() / (1024 * 1024)
         cpu_mem_mb = self.cpu_memory_bytes() / (1024 * 1024)
```
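For context on what this hunk deletes: the gather indices were a fixed-address mapping table consumed by the (also removed) gathered H2D copy path. A rough stand-alone sketch of that bookkeeping, mirroring the removed code rather than any surviving API:

```python
import torch
from typing import List, Tuple

num_layers, num_gpu_blocks = 32, 64  # illustrative sizes

# gather_indices[layer][slot] = CPU block id to copy into GPU slot `slot`; -1 means no-op.
gather_indices_cpu = torch.empty(
    num_layers, num_gpu_blocks, dtype=torch.int64, device="cpu", pin_memory=True
)
gather_indices_cpu.fill_(-1)
gather_indices_gpu = torch.full(
    (num_layers, num_gpu_blocks), -1, dtype=torch.int64, device="cuda"
)

def update_gather_indices(layer_id: int, mappings: List[Tuple[int, int]]) -> None:
    # Called outside any CUDA graph: mutate the pinned host row, then refresh the GPU row
    # in place, so a captured kernel can read new contents at the same address on replay.
    for cpu_block_id, gpu_slot in mappings:
        gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
    gather_indices_gpu[layer_id].copy_(gather_indices_cpu[layer_id], non_blocking=True)
```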
```diff
@@ -277,321 +263,6 @@ class OffloadEngine:
         # ========== Sparse attention policy (set at construction time) ==========
         self.sparse_policy = sparse_policy
 
-    def _get_next_stream(self) -> torch.cuda.Stream:
-        """Round-robin stream selection for parallel transfers."""
-        stream = self.transfer_streams[self._stream_idx]
-        self._stream_idx = (self._stream_idx + 1) % len(self.transfer_streams)
-        return stream
-
-    # ========== CUDA Graph compatible methods ==========
-    # NOTE: These methods need to be updated for the new GPU cache architecture.
-    # GPU cache no longer has layer dimension, so gathered copy semantics change.
-    # For now, these are kept for reference but should not be used without updating.
-
-    def gathered_h2d_layer(self, layer_id: int) -> None:
-        """
-        Execute gathered H2D copy for a single layer.
-
-        WARNING: This method needs updating for new GPU cache architecture.
-        GPU cache no longer has layer dimension.
-        """
-        # GPU cache has no layer dimension - use flat indexing
-        # Source is CPU[layer_id], dest is GPU (shared across layers)
-        gathered_copy_kv(
-            k_src=self.k_cache_cpu[layer_id],
-            v_src=self.v_cache_cpu[layer_id],
-            k_dst=self.k_cache_gpu,  # No layer indexing
-            v_dst=self.v_cache_gpu,  # No layer indexing
-            indices=self.gather_indices_gpu[layer_id],
-        )
-
-    def gathered_h2d_all_layers(self) -> None:
-        """
-        Execute gathered H2D copy for all layers.
-
-        WARNING: In new architecture, GPU slots are shared across layers.
-        This method would overwrite slots multiple times. Not recommended.
-        """
-        for layer_id in range(self.num_layers):
-            self.gathered_h2d_layer(layer_id)
-
-    def update_gather_indices(
-        self,
-        layer_id: int,
-        mappings: List[Tuple[int, int]],
-    ) -> None:
-        """
-        Update gather indices for a layer (call OUTSIDE CUDA graph).
-
-        Args:
-            layer_id: Layer index
-            mappings: List of (cpu_block_id, gpu_slot) tuples
-                Only these slots will be updated; others keep their values
-        """
-        for cpu_block_id, gpu_slot in mappings:
-            self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
-
-        # Async copy to GPU
-        self.gather_indices_gpu[layer_id].copy_(
-            self.gather_indices_cpu[layer_id],
-            non_blocking=True
-        )
-
-    def update_gather_indices_all_layers(
-        self,
-        mappings_per_layer: List[List[Tuple[int, int]]],
-    ) -> None:
-        """
-        Update gather indices for all layers.
-
-        Args:
-            mappings_per_layer: mappings_per_layer[layer_id] = [(cpu_block_id, gpu_slot), ...]
-        """
-        for layer_id, mappings in enumerate(mappings_per_layer):
-            for cpu_block_id, gpu_slot in mappings:
-                self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
-
-        # Batch copy all layers
-        self.gather_indices_gpu.copy_(self.gather_indices_cpu, non_blocking=True)
-
-    def clear_gather_indices(self, layer_id: Optional[int] = None) -> None:
-        """
-        Clear gather indices (set all to -1, meaning no-op).
-
-        Args:
-            layer_id: If provided, clear only this layer; otherwise clear all
-        """
-        if layer_id is not None:
-            self.gather_indices_cpu[layer_id].fill_(-1)
-            self.gather_indices_gpu[layer_id].fill_(-1)
-        else:
-            self.gather_indices_cpu.fill_(-1)
-            self.gather_indices_gpu.fill_(-1)
-
-    # ========== Async transfer methods (for prefill, outside CUDA graph) ==========
-
-    def prefetch_block_async(
-        self,
-        layer_id: int,
-        cpu_block_id: int,
-        gpu_block_id: int,
-    ) -> torch.cuda.Event:
-        """
-        Async prefetch a single block from CPU to GPU.
-
-        GPU cache has no layer dimension - layer_id is for CPU cache indexing.
-
-        Args:
-            layer_id: Layer index (for CPU cache)
-            cpu_block_id: Source block in CPU cache
-            gpu_block_id: Destination slot in GPU cache
-
-        Returns:
-            CUDA event that signals completion
-        """
-        stream = self._get_next_stream()
-        event = torch.cuda.Event()
-
-        logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
-
-        with torch.cuda.stream(stream):
-            # GPU: no layer dimension, CPU: has layer dimension
-            self.k_cache_gpu[gpu_block_id].copy_(
-                self.k_cache_cpu[layer_id, cpu_block_id],
-                non_blocking=True
-            )
-            self.v_cache_gpu[gpu_block_id].copy_(
-                self.v_cache_cpu[layer_id, cpu_block_id],
-                non_blocking=True
-            )
-            event.record()
-
-        self.pending_events[(layer_id, gpu_block_id)] = event
-        return event
-
-    def prefetch_blocks_batch_async(
-        self,
-        transfers: List[Tuple[int, int, int]],  # [(layer_id, cpu_block_id, gpu_block_id), ...]
-    ) -> List[torch.cuda.Event]:
-        """
-        Batch async prefetch multiple blocks.
-
-        Args:
-            transfers: List of (layer_id, cpu_block_id, gpu_block_id) tuples
-
-        Returns:
-            List of CUDA events for each transfer
-        """
-        events = []
-        for layer_id, cpu_block_id, gpu_block_id in transfers:
-            event = self.prefetch_block_async(layer_id, cpu_block_id, gpu_block_id)
-            events.append(event)
-        return events
-
-    def offload_block_async(
-        self,
-        layer_id: int,
-        gpu_block_id: int,
-        cpu_block_id: int,
-    ) -> torch.cuda.Event:
-        """
-        Async offload a block from GPU to CPU.
-
-        GPU cache has no layer dimension - layer_id is for CPU cache indexing.
-
-        Args:
-            layer_id: Layer index (for CPU cache)
-            gpu_block_id: Source slot in GPU cache
-            cpu_block_id: Destination block in CPU cache
-
-        Returns:
-            CUDA event that signals completion
-        """
-        stream = self._get_next_stream()
-        event = torch.cuda.Event()
-
-        logger.debug(f"D2H offload: layer={layer_id}, GPU[{gpu_block_id}] -> CPU[{cpu_block_id}]")
-
-        with torch.cuda.stream(stream):
-            # Wait for any compute using this block
-            stream.wait_stream(self.compute_stream)
-
-            # GPU: no layer dimension, CPU: has layer dimension
-            self.k_cache_cpu[layer_id, cpu_block_id].copy_(
-                self.k_cache_gpu[gpu_block_id],
-                non_blocking=True
-            )
-            self.v_cache_cpu[layer_id, cpu_block_id].copy_(
-                self.v_cache_gpu[gpu_block_id],
-                non_blocking=True
-            )
-            event.record()
-
-        return event
-
-    def offload_blocks_batch_async(
-        self,
-        transfers: List[Tuple[int, int, int]],  # [(layer_id, gpu_block_id, cpu_block_id), ...]
-    ) -> List[torch.cuda.Event]:
-        """
-        Batch async offload multiple blocks.
-
-        Args:
-            transfers: List of (layer_id, gpu_block_id, cpu_block_id) tuples
-
-        Returns:
-            List of CUDA events
-        """
-        events = []
-        for layer_id, gpu_block_id, cpu_block_id in transfers:
-            event = self.offload_block_async(layer_id, gpu_block_id, cpu_block_id)
-            events.append(event)
-        return events
-
-    # ========== Chunked Decode: Load CPU blocks to GPU slots ==========
-
-    def load_cpu_blocks_to_gpu_slots(
-        self,
-        layer_id: int,
-        cpu_block_ids: List[int],
-        gpu_slot_ids: List[int],
-    ) -> None:
-        """
-        Load CPU blocks to specific GPU slots for chunked decode.
-
-        GPU cache has no layer dimension - layer_id is for CPU cache indexing.
-
-        Args:
-            layer_id: Layer index (for CPU cache)
-            cpu_block_ids: List of CPU block IDs to load
-            gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
-        """
-        assert len(cpu_block_ids) == len(gpu_slot_ids)
-
-        if cpu_block_ids:
-            logger.debug(f"H2D chunked load: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
-
-        stream = self._get_next_stream()
-
-        with torch.cuda.stream(stream):
-            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
-                # GPU: no layer dimension, CPU: has layer dimension
-                self.k_cache_gpu[gpu_slot].copy_(
-                    self.k_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-                self.v_cache_gpu[gpu_slot].copy_(
-                    self.v_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-
-        # Wait for transfer to complete
-        stream.synchronize()
-
-    def load_cpu_blocks_to_gpu_slots_async(
-        self,
-        layer_id: int,
-        cpu_block_ids: List[int],
-        gpu_slot_ids: List[int],
-    ) -> torch.cuda.Event:
-        """
-        Async version: Load CPU blocks to GPU slots.
-
-        GPU cache has no layer dimension - layer_id is for CPU cache indexing.
-
-        Args:
-            layer_id: Layer index (for CPU cache)
-            cpu_block_ids: List of CPU block IDs to load
-            gpu_slot_ids: List of GPU slot IDs to load into
-
-        Returns:
-            CUDA event to wait on
-        """
-        assert len(cpu_block_ids) == len(gpu_slot_ids)
-
-        if cpu_block_ids:
-            logger.debug(f"H2D chunked load async: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
-
-        stream = self._get_next_stream()
-        event = torch.cuda.Event()
-
-        with torch.cuda.stream(stream):
-            for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
-                # GPU: no layer dimension, CPU: has layer dimension
-                self.k_cache_gpu[gpu_slot].copy_(
-                    self.k_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-                self.v_cache_gpu[gpu_slot].copy_(
-                    self.v_cache_cpu[layer_id, cpu_block_id],
-                    non_blocking=True
-                )
-            event.record()
-
-        return event
-
-    # NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
-    # layer dimension. Each GPU slot holds data for ONE layer at a time.
-
-    # ========== Synchronization methods ==========
-
-    def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
-        """Wait for a specific block's transfer to complete."""
-        key = (layer_id, gpu_block_id)
-        if key in self.pending_events:
-            self.pending_events[key].synchronize()
-            del self.pending_events[key]
-
-    def wait_all_transfers(self) -> None:
-        """Wait for all pending transfers to complete."""
-        for stream in self.transfer_streams:
-            stream.synchronize()
-        self.pending_events.clear()
-
-    def sync_indices(self) -> None:
-        """Synchronize to ensure all index updates are complete."""
-        torch.cuda.default_stream().synchronize()
-
     # ========== Cache access methods ==========
 
     def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
```
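The methods removed above all follow the same stream/event pattern. A condensed, stand-alone sketch of that pattern (the function and argument names here are illustrative, not part of the remaining OffloadEngine API):

```python
import torch

def prefetch_block(
    k_cache_cpu: torch.Tensor,  # pinned, [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
    k_cache_gpu: torch.Tensor,  # [num_gpu_blocks, block_size, kv_heads, head_dim]
    layer_id: int,
    cpu_block_id: int,
    gpu_slot: int,
    stream: torch.cuda.Stream,
) -> torch.cuda.Event:
    """Copy one K block host-to-device on a side stream; the returned event marks completion."""
    event = torch.cuda.Event()
    with torch.cuda.stream(stream):
        # layer_id indexes only the CPU cache; the GPU cache has no layer dimension.
        k_cache_gpu[gpu_slot].copy_(k_cache_cpu[layer_id, cpu_block_id], non_blocking=True)
        event.record()
    return event

# Consumer side: gate the compute stream on the transfer without blocking the host.
# torch.cuda.current_stream().wait_event(event)
```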
```diff
@@ -605,54 +276,22 @@ class OffloadEngine:
             (k_cache, v_cache) tensors
             Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
         """
-        # GPU cache is shared across all layers (no layer dimension)
         return self.k_cache_gpu, self.v_cache_gpu
 
-    def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
-        """
-        Get full GPU K/V cache tensors.
-
-        NOTE: GPU cache has no layer dimension in the new architecture.
-
-        Returns:
-            (k_cache, v_cache) tensors
-            Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
-        """
-        return self.k_cache_gpu, self.v_cache_gpu
-
-    def get_cpu_block(
-        self,
-        layer_id: int,
-        cpu_block_id: int,
-    ) -> Tuple[Tensor, Tensor]:
-        """
-        Get a specific CPU block's K/V cache.
-
-        Returns:
-            (k_cache, v_cache) for the block
-            Shape: [block_size, kv_heads, head_dim]
-        """
-        return (
-            self.k_cache_cpu[layer_id, cpu_block_id],
-            self.v_cache_cpu[layer_id, cpu_block_id],
-        )
-
     # ========== Memory info ==========
 
     def gpu_memory_bytes(self) -> int:
         """Total GPU memory used by KV caches."""
         return (
             self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() +
-            self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() +
-            self.gather_indices_gpu.numel() * self.gather_indices_gpu.element_size()
+            self.v_cache_gpu.numel() * self.v_cache_gpu.element_size()
         )
 
     def cpu_memory_bytes(self) -> int:
         """Total CPU memory used by KV caches."""
         return (
             self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
-            self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() +
-            self.gather_indices_cpu.numel() * self.gather_indices_cpu.element_size()
+            self.v_cache_cpu.numel() * self.v_cache_cpu.element_size()
         )
 
     def __repr__(self) -> str:
```
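With the gather-index terms dropped, the accounting reduces to numel × element_size per cache tensor. A quick worked example with made-up sizes:

```python
import torch

# Hypothetical GPU K cache: 64 blocks x 16 tokens x 8 KV heads x 128 head dim, fp16.
k_cache_gpu = torch.empty(64, 16, 8, 128, dtype=torch.float16, device="cuda")

bytes_used = k_cache_gpu.numel() * k_cache_gpu.element_size()  # 1,048,576 elements * 2 bytes
print(f"K cache: {bytes_used / (1024 * 1024):.1f} MiB")        # -> 2.0 MiB
```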
```diff
@@ -955,102 +594,6 @@ class OffloadEngine:
             v = v.unsqueeze(0)
         return k, v
 
-    # ----- Legacy compatibility methods (for decode double-buffering) -----
-    # NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.
-
-    def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
-        """
-        Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
-
-        Uses first half of decode_load_slots as 'compute' region.
-        GPU cache has no layer dimension - layer_id is for CPU cache indexing.
-        """
-        if not cpu_block_ids:
-            return
-
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[:half]
-        num_to_load = min(len(cpu_block_ids), len(slots))
-
-        with torch.cuda.stream(self.transfer_stream_main):
-            for i in range(num_to_load):
-                cpu_id = cpu_block_ids[i]
-                gpu_slot = slots[i]
-                # GPU: no layer dimension, CPU: has layer dimension
-                self.k_cache_gpu[gpu_slot].copy_(
-                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
-                )
-                self.v_cache_gpu[gpu_slot].copy_(
-                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
-                )
-            if num_to_load > 0:
-                self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
-
-    def wait_compute_layer(self) -> None:
-        """Legacy: Wait for 'compute' region loading."""
-        if self.decode_load_slots:
-            self.wait_slot_layer(self.decode_load_slots[0])
-
-    def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
-        """
-        Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
-
-        Uses second half of decode_load_slots as 'prefetch' region.
-        GPU cache has no layer dimension - layer_id is for CPU cache indexing.
-        """
-        if not cpu_block_ids:
-            return
-
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[half:]
-        if not slots:
-            slots = self.decode_load_slots  # Fallback if only 1-2 slots
-        num_to_load = min(len(cpu_block_ids), len(slots))
-
-        with torch.cuda.stream(self.transfer_stream_main):
-            for i in range(num_to_load):
-                cpu_id = cpu_block_ids[i]
-                gpu_slot = slots[i]
-                # GPU: no layer dimension, CPU: has layer dimension
-                self.k_cache_gpu[gpu_slot].copy_(
-                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
-                )
-                self.v_cache_gpu[gpu_slot].copy_(
-                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
-                )
-            if num_to_load > 0:
-                self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
-
-    def wait_prefetch_layer(self) -> None:
-        """Legacy: Wait for 'prefetch' region loading."""
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[half:]
-        if slots:
-            self.wait_slot_layer(slots[0])
-        elif self.decode_load_slots:
-            self.wait_slot_layer(self.decode_load_slots[0])
-
-    def get_kv_for_compute(
-        self,
-        num_blocks: int,
-    ) -> Tuple[Tensor, Tensor]:
-        """Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[:half][:num_blocks]
-        return self.get_kv_for_slots(slots)
-
-    def get_kv_for_prefetch(
-        self,
-        num_blocks: int,
-    ) -> Tuple[Tensor, Tensor]:
-        """Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[half:]
-        if not slots:
-            slots = self.decode_load_slots
-        slots = slots[:num_blocks]
-        return self.get_kv_for_slots(slots)
-
     # ========== Debug Hook Interface ==========
     #
     # Minimal generic hook system for debugging.
```
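The deleted legacy helpers implemented decode double-buffering by splitting decode_load_slots into a 'compute' half and a 'prefetch' half. A schematic sketch of that overlap pattern, with hypothetical load_fn/attend_fn callables standing in for the engine's own loading and attention steps:

```python
import torch
from typing import Callable, List, Sequence

def decode_chunks_double_buffered(
    chunks: Sequence,                           # per-chunk lists of CPU block ids (illustrative)
    load_fn: Callable[[object, int], None],     # issues non_blocking H2D copies into one half
    attend_fn: Callable[[object, int], None],   # runs attention against that half
    transfer_stream: torch.cuda.Stream,
) -> None:
    """Overlap loading of chunk i+1 with attention over chunk i using two slot halves."""
    compute_stream = torch.cuda.current_stream()
    ready: List[torch.cuda.Event] = [torch.cuda.Event(), torch.cuda.Event()]

    def issue_load(idx: int) -> None:
        half = idx % 2
        with torch.cuda.stream(transfer_stream):
            # Do not overwrite a half until compute work already queued on it has drained.
            transfer_stream.wait_stream(compute_stream)
            load_fn(chunks[idx], half)
            ready[half].record()

    if not chunks:
        return
    issue_load(0)                                # prime the pipeline
    for i, chunk in enumerate(chunks):
        if i + 1 < len(chunks):
            issue_load(i + 1)                    # prefetch into the other half
        compute_stream.wait_event(ready[i % 2])  # gate compute on this half's transfer
        attend_fn(chunk, i % 2)                  # attention reads the half loaded earlier
```

The point of the two halves is that the H2D copies for chunk i+1 sit on the transfer stream while attention over chunk i runs on the compute stream; the events keep either stream from racing ahead of the other.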