[feat] Add debug hook to offload_engine.py.
This commit is contained in:
@@ -193,6 +193,10 @@ class OffloadEngine:
|
||||
# ========== Event tracking for async transfers ==========
|
||||
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
|
||||
|
||||
# ========== Debug hook mode ==========
|
||||
self._debug_mode = False
|
||||
self._debug_hooks: List = [] # External hooks for debug events
|
||||
|
||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Round-robin stream selection for parallel transfers."""
|
||||
stream = self.transfer_streams[self._stream_idx]
|
||||
@@ -1022,4 +1026,71 @@ class OffloadEngine:
|
||||
if not slots:
|
||||
slots = self.decode_load_slots
|
||||
slots = slots[:num_blocks]
|
||||
return self.get_kv_for_slots(layer_id, slots)
|
||||
return self.get_kv_for_slots(layer_id, slots)
|
||||
|
||||
# ========== Debug Hook Interface ==========
|
||||
#
|
||||
# Minimal generic hook system for debugging.
|
||||
# Framework only provides hook registration and tensor access.
|
||||
# All verification logic is external.
|
||||
|
||||
def enable_debug_mode(self) -> None:
    """Turn on debug mode so registered hooks fire after H2D loads."""
    self._debug_mode = True
    logger.info("OffloadEngine debug mode ENABLED")
|
||||
|
||||
def disable_debug_mode(self) -> None:
    """Turn off debug mode, dropping every registered hook."""
    # Clearing hooks and flipping the flag are independent; order is free.
    self._debug_hooks.clear()
    self._debug_mode = False
    logger.info("OffloadEngine debug mode DISABLED")
|
||||
|
||||
@property
def debug_mode(self) -> bool:
    """Whether debug mode is currently enabled."""
    return self._debug_mode
|
||||
|
||||
def register_debug_hook(self, hook_fn) -> None:
    """
    Add *hook_fn* to the list of debug hooks.

    Hooks run after an H2D load completes (after wait_slot_layer) and
    receive views of the freshly loaded GPU tensors for inspection.

    Args:
        hook_fn: Callable invoked as
            hook_fn(slot_idx: int, layer_id: int, cpu_block_id: int,
                    k: Tensor, v: Tensor) -> None
            where k and v are the GPU tensor views for the loaded slot.

    Example:
        def my_hook(slot_idx, layer_id, cpu_block_id, k, v):
            if layer_id == 0:
                k_val = k.float().mean().item()
                print(f"Loaded block {cpu_block_id}, K mean = {k_val}")

        offload_engine.register_debug_hook(my_hook)
    """
    self._debug_hooks.append(hook_fn)
|
||||
|
||||
def remove_debug_hook(self, hook_fn) -> None:
    """Unregister *hook_fn*; silently a no-op if it was never registered."""
    try:
        # EAFP: list.remove drops the first occurrence, matching the
        # original membership-test-then-remove exactly.
        self._debug_hooks.remove(hook_fn)
    except ValueError:
        pass
|
||||
|
||||
def _call_debug_hooks(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
||||
"""
|
||||
Call all registered debug hooks with loaded tensor (internal use).
|
||||
|
||||
Called by attention.py after wait_slot_layer completes.
|
||||
"""
|
||||
if not self._debug_mode or not self._debug_hooks:
|
||||
return
|
||||
|
||||
k = self.k_cache_gpu[layer_id, slot_idx]
|
||||
v = self.v_cache_gpu[layer_id, slot_idx]
|
||||
|
||||
for hook in self._debug_hooks:
|
||||
try:
|
||||
hook(slot_idx, layer_id, cpu_block_id, k, v)
|
||||
except Exception as e:
|
||||
logger.warning(f"Debug hook error: {e}")
|
||||
Reference in New Issue
Block a user