[feat] Added debug hook to offload_engine.py.

This commit is contained in:
Zijie Tian
2025-12-31 19:44:39 +08:00
parent 7af721c12c
commit 484d0de9f9
5 changed files with 383 additions and 10 deletions

View File

@@ -69,15 +69,19 @@ class HybridKVCacheManager(KVCacheManager):
Architecture (CPU-primary mode):
- CPU pool: Primary storage for all KV cache (num_cpu_blocks)
- GPU buffer: Ring buffer for computation (num_gpu_slots)
- Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks)
- GPU buffer: Ring buffer for computation only (num_gpu_slots)
- Logical blocks: What sequences reference (num_cpu_blocks)
Design:
- All KV cache is stored on CPU as primary storage
- GPU is used as a ring buffer for computation only
- GPU is used as a ring buffer for computation only (no persistent data)
- During prefill: KV is written to GPU ring slot, then offloaded to CPU
- During decode: Previous KV is loaded from CPU to GPU for attention
- Ring buffer enables pipelined H2D transfers overlapped with computation
Note:
- Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
- GPU slots are transient compute buffers, not tracked in logical blocks
"""
def __init__(
@@ -102,20 +106,22 @@ class HybridKVCacheManager(KVCacheManager):
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
# In CPU-primary mode, logical blocks map 1:1 with CPU blocks
# GPU slots are transient compute buffers, not tracked as logical blocks
self.total_blocks = num_cpu_blocks
# Eviction policy
self.policy = policy or LRUPolicy()
# Logical blocks (what sequences reference)
# Logical blocks (what sequences reference) - one per CPU block
self.logical_blocks: List[LogicalBlock] = [
LogicalBlock(i) for i in range(self.total_blocks)
]
self.free_logical_ids: deque[int] = deque(range(self.total_blocks))
# GPU slot management (slots are fixed, mapping is variable)
# GPU slot management (kept for potential future use, but not used in CPU-primary mode)
self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots))
self.gpu_slot_to_logical: Dict[int, int] = {} # gpu_slot -> logical_id
self.gpu_slot_to_logical: Dict[int, int] = {} # gpu_slot -> logical_id (unused in CPU-primary mode)
# CPU block management
self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks))
@@ -212,7 +218,9 @@ class HybridKVCacheManager(KVCacheManager):
block.ref_count -= 1
if block.ref_count == 0:
# Free physical block
# Free physical block based on location
# Note: In CPU-primary mode, blocks are always on CPU.
# GPU branch kept for potential future hybrid mode support.
if block.location == BlockLocation.GPU:
self.free_gpu_slots.append(block.gpu_slot)
del self.gpu_slot_to_logical[block.gpu_slot]

View File

@@ -193,6 +193,10 @@ class OffloadEngine:
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
# ========== Debug hook mode ==========
self._debug_mode = False
self._debug_hooks: List = [] # External hooks for debug events
def _get_next_stream(self) -> torch.cuda.Stream:
"""Round-robin stream selection for parallel transfers."""
stream = self.transfer_streams[self._stream_idx]
@@ -1022,4 +1026,71 @@ class OffloadEngine:
if not slots:
slots = self.decode_load_slots
slots = slots[:num_blocks]
return self.get_kv_for_slots(layer_id, slots)
return self.get_kv_for_slots(layer_id, slots)
# ========== Debug Hook Interface ==========
#
# Minimal generic hook system for debugging.
# Framework only provides hook registration and tensor access.
# All verification logic is external.
def enable_debug_mode(self) -> None:
    """
    Turn on debug mode.

    Sets the internal debug flag and emits an info-level log line.
    Hooks that are already registered are left untouched.
    """
    self._debug_mode = True
    logger.info("OffloadEngine debug mode ENABLED")
def disable_debug_mode(self) -> None:
    """
    Turn off debug mode and forget every registered hook.

    After this call the debug flag is False and the hook list is empty,
    so hooks must be registered again if debug mode is re-enabled.
    """
    self._debug_mode = False
    # Empty the existing list in place rather than rebinding the attribute.
    del self._debug_hooks[:]
    logger.info("OffloadEngine debug mode DISABLED")
@property
def debug_mode(self) -> bool:
    """Current value of the debug flag (True while debug mode is enabled)."""
    return self._debug_mode
def register_debug_hook(self, hook_fn) -> None:
    """
    Append a callable to the list of debug hooks.

    Registered hooks are invoked after an H2D load completes (after
    wait_slot_layer) and receive the freshly loaded tensors so that
    external code can inspect them.

    Args:
        hook_fn: Callable with signature:
            (slot_idx: int, layer_id: int, cpu_block_id: int,
             k: Tensor, v: Tensor) -> None
            - k, v: GPU tensor views for the loaded slot

    Note:
        No de-duplication is performed: registering the same callable
        twice stores it twice.

    Example:
        def my_hook(slot_idx, layer_id, cpu_block_id, k, v):
            if layer_id == 0:
                k_val = k.float().mean().item()
                print(f"Loaded block {cpu_block_id}, K mean = {k_val}")

        offload_engine.register_debug_hook(my_hook)
    """
    registered = self._debug_hooks
    registered.append(hook_fn)
def remove_debug_hook(self, hook_fn) -> None:
    """
    Unregister a debug hook.

    Removes the first occurrence of ``hook_fn`` from the hook list.
    Asking to remove a hook that is not registered is a silent no-op.
    """
    try:
        self._debug_hooks.remove(hook_fn)
    except ValueError:
        # list.remove raises ValueError when the item is absent; nothing to do.
        pass
def _call_debug_hooks(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
    """
    Invoke every registered debug hook with the tensors of a loaded slot
    (internal use).

    Called by attention.py after wait_slot_layer completes.

    Args:
        slot_idx: Index of the GPU slot that was loaded.
        layer_id: Index of the layer whose KV was loaded.
        cpu_block_id: ID of the CPU block the data came from; passed
            through to the hooks unchanged.

    Does nothing unless debug mode is enabled and at least one hook is
    registered. An exception raised by a hook is logged as a warning and
    does not prevent the remaining hooks from running.
    """
    if not self._debug_mode or not self._debug_hooks:
        return

    k = self.k_cache_gpu[layer_id, slot_idx]
    v = self.v_cache_gpu[layer_id, slot_idx]

    # Iterate over a snapshot: a hook may register/remove hooks (including
    # itself) while running, and mutating the live list during iteration
    # would silently skip the next hook.
    for hook in tuple(self._debug_hooks):
        try:
            hook(slot_idx, layer_id, cpu_block_id, k, v)
        except Exception as e:
            logger.warning(f"Debug hook error: {e}")