[feat] Added debug hook to offload_engine.py.
@@ -1,6 +1,7 @@
 import os
 from dataclasses import dataclass
 from transformers import AutoConfig
+import torch
 
 
 @dataclass
@@ -16,6 +17,7 @@ class Config:
     eos: int = -1
     kvcache_block_size: int = 4096
     num_kvcache_blocks: int = -1
+    dtype: str | None = None  # "float16", "bfloat16", or None (use model default)
 
     # CPU Offload configuration
     enable_cpu_offload: bool = False
@@ -41,3 +43,17 @@ class Config:
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
         assert self.max_num_batched_tokens >= self.max_model_len
+
+        # Override torch_dtype if user specified
+        if self.dtype is not None:
+            dtype_map = {
+                "float16": torch.float16,
+                "fp16": torch.float16,
+                "bfloat16": torch.bfloat16,
+                "bf16": torch.bfloat16,
+                "float32": torch.float32,
+                "fp32": torch.float32,
+            }
+            if self.dtype not in dtype_map:
+                raise ValueError(f"Invalid dtype: {self.dtype}. Choose from: {list(dtype_map.keys())}")
+            self.hf_config.torch_dtype = dtype_map[self.dtype]
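
The new dtype option is resolved in __post_init__: the string is mapped to a torch dtype and written into hf_config.torch_dtype, and an unknown string raises ValueError. A minimal usage sketch, assuming Config is a dataclass whose model field is the path read by AutoConfig.from_pretrained above (the model path and other values here are placeholders, not part of this commit):

    config = Config(
        model="path/to/model",   # placeholder
        dtype="bf16",            # accepts float16/fp16, bfloat16/bf16, float32/fp32
        enable_cpu_offload=True,
    )
    # After __post_init__, config.hf_config.torch_dtype == torch.bfloat16.
    # Config(model="path/to/model", dtype="int8") would raise ValueError listing the valid keys.
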
@@ -69,15 +69,19 @@ class HybridKVCacheManager(KVCacheManager):
 
     Architecture (CPU-primary mode):
     - CPU pool: Primary storage for all KV cache (num_cpu_blocks)
-    - GPU buffer: Ring buffer for computation (num_gpu_slots)
-    - Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks)
+    - GPU buffer: Ring buffer for computation only (num_gpu_slots)
+    - Logical blocks: What sequences reference (num_cpu_blocks)
 
     Design:
     - All KV cache is stored on CPU as primary storage
-    - GPU is used as a ring buffer for computation only
+    - GPU is used as a ring buffer for computation only (no persistent data)
     - During prefill: KV is written to GPU ring slot, then offloaded to CPU
     - During decode: Previous KV is loaded from CPU to GPU for attention
     - Ring buffer enables pipelined H2D transfers overlapped with computation
+
+    Note:
+    - Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
+    - GPU slots are transient compute buffers, not tracked in logical blocks
     """
 
     def __init__(
@@ -102,20 +106,22 @@ class HybridKVCacheManager(KVCacheManager):
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
         self.num_cpu_blocks = num_cpu_blocks
-        self.total_blocks = num_gpu_slots + num_cpu_blocks
+        # In CPU-primary mode, logical blocks map 1:1 with CPU blocks
+        # GPU slots are transient compute buffers, not tracked as logical blocks
+        self.total_blocks = num_cpu_blocks
 
         # Eviction policy
         self.policy = policy or LRUPolicy()
 
-        # Logical blocks (what sequences reference)
+        # Logical blocks (what sequences reference) - one per CPU block
         self.logical_blocks: List[LogicalBlock] = [
             LogicalBlock(i) for i in range(self.total_blocks)
         ]
         self.free_logical_ids: deque[int] = deque(range(self.total_blocks))
 
-        # GPU slot management (slots are fixed, mapping is variable)
+        # GPU slot management (kept for potential future use, but not used in CPU-primary mode)
         self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots))
-        self.gpu_slot_to_logical: Dict[int, int] = {}  # gpu_slot -> logical_id
+        self.gpu_slot_to_logical: Dict[int, int] = {}  # gpu_slot -> logical_id (unused in CPU-primary mode)
 
         # CPU block management
         self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks))
@@ -212,7 +218,9 @@ class HybridKVCacheManager(KVCacheManager):
         block.ref_count -= 1
 
         if block.ref_count == 0:
-            # Free physical block
+            # Free physical block based on location
+            # Note: In CPU-primary mode, blocks are always on CPU.
+            # GPU branch kept for potential future hybrid mode support.
            if block.location == BlockLocation.GPU:
                 self.free_gpu_slots.append(block.gpu_slot)
                 del self.gpu_slot_to_logical[block.gpu_slot]
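
To make the new block accounting concrete, a rough sketch with made-up sizes; the constructor argument names follow the assignments in __init__ above, but the real signature (and the units of block_size) may differ:

    # Hypothetical sizes: 8 GPU staging slots, 256 CPU blocks.
    manager = HybridKVCacheManager(block_size=4096, num_gpu_slots=8, num_cpu_blocks=256)
    assert manager.total_blocks == 256        # logical blocks == CPU blocks only
    assert len(manager.free_gpu_slots) == 8   # staging buffers, not counted as capacity
    # Cache capacity is num_cpu_blocks * block_size = 256 * 4096 positions;
    # under the old accounting, total_blocks would have been 8 + 256.
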
@@ -193,6 +193,10 @@ class OffloadEngine:
         # ========== Event tracking for async transfers ==========
         self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
 
+        # ========== Debug hook mode ==========
+        self._debug_mode = False
+        self._debug_hooks: List = []  # External hooks for debug events
+
     def _get_next_stream(self) -> torch.cuda.Stream:
         """Round-robin stream selection for parallel transfers."""
         stream = self.transfer_streams[self._stream_idx]
@@ -1022,4 +1026,71 @@ class OffloadEngine:
         if not slots:
             slots = self.decode_load_slots
         slots = slots[:num_blocks]
         return self.get_kv_for_slots(layer_id, slots)
+
+    # ========== Debug Hook Interface ==========
+    #
+    # Minimal generic hook system for debugging.
+    # Framework only provides hook registration and tensor access.
+    # All verification logic is external.
+
+    def enable_debug_mode(self) -> None:
+        """Enable debug mode."""
+        self._debug_mode = True
+        logger.info("OffloadEngine debug mode ENABLED")
+
+    def disable_debug_mode(self) -> None:
+        """Disable debug mode and clear all hooks."""
+        self._debug_mode = False
+        self._debug_hooks.clear()
+        logger.info("OffloadEngine debug mode DISABLED")
+
+    @property
+    def debug_mode(self) -> bool:
+        """Check if debug mode is enabled."""
+        return self._debug_mode
+
+    def register_debug_hook(self, hook_fn) -> None:
+        """
+        Register a debug hook.
+
+        The hook is called after H2D load completes (after wait_slot_layer),
+        receiving the loaded tensor for inspection.
+
+        Args:
+            hook_fn: Callable with signature:
+                (slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None
+                - k, v: GPU tensor views for the loaded slot
+
+        Example:
+            def my_hook(slot_idx, layer_id, cpu_block_id, k, v):
+                if layer_id == 0:
+                    k_val = k.float().mean().item()
+                    print(f"Loaded block {cpu_block_id}, K mean = {k_val}")
+
+            offload_engine.register_debug_hook(my_hook)
+        """
+        self._debug_hooks.append(hook_fn)
+
+    def remove_debug_hook(self, hook_fn) -> None:
+        """Remove a registered debug hook."""
+        if hook_fn in self._debug_hooks:
+            self._debug_hooks.remove(hook_fn)
+
+    def _call_debug_hooks(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
+        """
+        Call all registered debug hooks with loaded tensor (internal use).
+
+        Called by attention.py after wait_slot_layer completes.
+        """
+        if not self._debug_mode or not self._debug_hooks:
+            return
+
+        k = self.k_cache_gpu[layer_id, slot_idx]
+        v = self.v_cache_gpu[layer_id, slot_idx]
+
+        for hook in self._debug_hooks:
+            try:
+                hook(slot_idx, layer_id, cpu_block_id, k, v)
+            except Exception as e:
+                logger.warning(f"Debug hook error: {e}")
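
Since the engine only provides hook registration and tensor access, verification logic lives in external hooks. A small sketch of one such hook that sanity-checks every slot after its H2D load (the checks and messages are illustrative, not part of this commit):

    import torch

    def check_loaded_kv(slot_idx, layer_id, cpu_block_id, k, v):
        # k and v are GPU views of the slot just filled from CPU block cpu_block_id.
        if torch.isnan(k).any() or torch.isnan(v).any():
            print(f"NaN in loaded KV: slot={slot_idx} layer={layer_id} block={cpu_block_id}")
        elif k.abs().sum() == 0 and v.abs().sum() == 0:
            print(f"All-zero KV loaded: slot={slot_idx} layer={layer_id} block={cpu_block_id}")

    offload_engine.enable_debug_mode()
    offload_engine.register_debug_hook(check_loaded_kv)
    # ... run prefill/decode; hooks fire after each wait_slot_layer ...
    offload_engine.disable_debug_mode()  # also clears all registered hooks
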
@@ -287,9 +287,15 @@ class Attention(nn.Module):
             slot = load_slots[0]
             compute_stream = offload_engine.compute_stream
             for block_idx in range(num_blocks):
-                offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
+                cpu_block_id = cpu_block_table[block_idx]
+                offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
                 offload_engine.wait_slot_layer(slot, self.layer_id)
+
                 with torch.cuda.stream(compute_stream):
+                    # Debug: call hooks on compute_stream (synchronized with transfer)
+                    if offload_engine.debug_mode:
+                        offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
+
                     prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
                     prev_o, prev_lse = flash_attn_with_lse(
                         q_batched, prev_k, prev_v,
@@ -323,6 +329,7 @@ class Attention(nn.Module):
 
                 # Cycle through slots: slot[block_idx % num_slots]
                 current_slot = load_slots[block_idx % num_slots]
+                cpu_block_id = cpu_block_table[block_idx]
 
                 # Wait for current slot's transfer to complete (on compute_stream)
                 offload_engine.wait_slot_layer(current_slot, self.layer_id)
@@ -330,6 +337,10 @@ class Attention(nn.Module):
                 # Compute attention on current slot's data
                 # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
                 with torch.cuda.stream(compute_stream):
+                    # Debug: call hooks on compute_stream (synchronized with transfer)
+                    if offload_engine.debug_mode:
+                        offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
+
                     torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
                     prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
                     prev_o, prev_lse = flash_attn_with_lse(
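
Because the attention paths now thread cpu_block_id through to _call_debug_hooks, an external hook can also reconstruct the exact load order seen by each layer during decode. A small sketch (names are illustrative):

    from collections import defaultdict

    load_order = defaultdict(list)  # layer_id -> [cpu_block_id, ...] in load order

    def record_loads(slot_idx, layer_id, cpu_block_id, k, v):
        load_order[layer_id].append(cpu_block_id)

    offload_engine.enable_debug_mode()
    offload_engine.register_debug_hook(record_loads)
    # After a decode step, load_order[layer_id] should match the order of
    # cpu_block_table entries iterated by that layer's attention loop.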