[feat] Added debug hook to offload_engine.py.

commit 484d0de9f9 (parent 7af721c12c)
Author: Zijie Tian
Date: 2025-12-31 19:44:39 +08:00
5 changed files with 383 additions and 10 deletions

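The hunks below show only the call sites added in the attention layer; the offload_engine.py side of the hook (the commit's headline change) is not included in this excerpt. For orientation, a registry along the following lines would be enough to back the debug_mode flag and the _call_debug_hooks() calls seen below. This is a sketch, not the repo's actual code, and register_debug_hook is a hypothetical name.

from typing import Callable, List

# Hook signature assumed from the call sites in the diff below:
# (slot, layer_id, cpu_block_id) identify the GPU slot that was just filled
# and the CPU block it was loaded from.
DebugHook = Callable[[int, int, int], None]

class OffloadEngine:
    # Sketch only: not the real OffloadEngine from this repo.
    def __init__(self, debug_mode: bool = False):
        self.debug_mode = debug_mode
        self._debug_hooks: List[DebugHook] = []

    def register_debug_hook(self, hook: DebugHook) -> None:
        # Hypothetical registration helper; the real API may differ.
        self._debug_hooks.append(hook)

    def _call_debug_hooks(self, slot: int, layer_id: int, cpu_block_id: int) -> None:
        # Called from the attention layer on compute_stream, after
        # wait_slot_layer() has ordered that stream behind the H2D copy.
        for hook in self._debug_hooks:
            hook(slot, layer_id, cpu_block_id)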

@@ -287,9 +287,15 @@ class Attention(nn.Module):
             slot = load_slots[0]
             compute_stream = offload_engine.compute_stream
             for block_idx in range(num_blocks):
-                offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
+                cpu_block_id = cpu_block_table[block_idx]
+                offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
                 offload_engine.wait_slot_layer(slot, self.layer_id)
                 with torch.cuda.stream(compute_stream):
+                    # Debug: call hooks on compute_stream (synchronized with transfer)
+                    if offload_engine.debug_mode:
+                        offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
                     prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
                     prev_o, prev_lse = flash_attn_with_lse(
                         q_batched, prev_k, prev_v,
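With a registry like the one sketched above, a hook could, for example, verify each block right after it lands in its GPU slot. The sketch below assumes pinned CPU caches indexed as [block_id, layer_id, ...] (a made-up layout) and the hypothetical register_debug_hook; only get_kv_for_slot and debug_mode actually appear in this diff.

import torch

def make_verify_hook(offload_engine, cpu_k_cache, cpu_v_cache):
    # cpu_k_cache / cpu_v_cache: hypothetical pinned CPU tensors indexed as
    # [block_id, layer_id, ...]; the real layout in this repo is not shown here.
    def verify_loaded_block(slot, layer_id, cpu_block_id):
        k_gpu, v_gpu = offload_engine.get_kv_for_slot(slot, layer_id)
        k_ref = cpu_k_cache[cpu_block_id, layer_id].to(k_gpu.device)
        v_ref = cpu_v_cache[cpu_block_id, layer_id].to(v_gpu.device)
        # torch.equal forces a sync, which is acceptable in a debug path.
        assert torch.equal(k_gpu, k_ref), f"K mismatch: slot={slot} layer={layer_id} block={cpu_block_id}"
        assert torch.equal(v_gpu, v_ref), f"V mismatch: slot={slot} layer={layer_id} block={cpu_block_id}"
    return verify_loaded_block

# Hypothetical wiring:
# offload_engine.debug_mode = True
# offload_engine.register_debug_hook(make_verify_hook(offload_engine, cpu_k_cache, cpu_v_cache))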
@@ -323,6 +329,7 @@ class Attention(nn.Module):
                 # Cycle through slots: slot[block_idx % num_slots]
                 current_slot = load_slots[block_idx % num_slots]
+                cpu_block_id = cpu_block_table[block_idx]
                 # Wait for current slot's transfer to complete (on compute_stream)
                 offload_engine.wait_slot_layer(current_slot, self.layer_id)
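The only functional change in this hunk is capturing cpu_block_id so the hook call later in the loop can reuse it. For context, the surrounding loop cycles a small ring of GPU slots over a longer list of CPU blocks; the snippet below merely illustrates that indexing pattern and is not taken from the repo (it also omits the event handling a real prefetcher needs before reusing a slot).

def plan_slot_cycle(load_slots, cpu_block_table):
    # For each CPU block, pick the GPU slot it will be staged into.
    # With num_slots == 2 this degenerates to classic double buffering.
    num_slots = len(load_slots)
    return [
        (load_slots[block_idx % num_slots], cpu_block_id)
        for block_idx, cpu_block_id in enumerate(cpu_block_table)
    ]

# Example: 2 slots, 5 CPU blocks -> slots alternate 0, 1, 0, 1, 0
# plan_slot_cycle([0, 1], [10, 11, 12, 13, 14])
# == [(0, 10), (1, 11), (0, 12), (1, 13), (0, 14)]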
@@ -330,6 +337,10 @@ class Attention(nn.Module):
                 # Compute attention on current slot's data
                 # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
                 with torch.cuda.stream(compute_stream):
+                    # Debug: call hooks on compute_stream (synchronized with transfer)
+                    if offload_engine.debug_mode:
+                        offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
                     torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
                     prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
                     prev_o, prev_lse = flash_attn_with_lse(
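The hook is deliberately invoked inside the torch.cuda.stream(compute_stream) block, after wait_slot_layer(), so anything the hook launches is ordered after the host-to-device copy that filled the slot. The following is a generic PyTorch sketch of that event-based stream ordering, not the engine's actual implementation (requires a CUDA device).

import torch

copy_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()
ready = torch.cuda.Event()

gpu_slot = torch.empty(1024, device="cuda")
cpu_block = torch.randn(1024, pin_memory=True)

with torch.cuda.stream(copy_stream):
    gpu_slot.copy_(cpu_block, non_blocking=True)  # async H2D copy on the copy stream
    ready.record(copy_stream)                     # mark "slot is filled"

compute_stream.wait_event(ready)                  # roughly what wait_slot_layer is doing conceptually
with torch.cuda.stream(compute_stream):
    # Work (and debug hooks) launched here is ordered after the copy,
    # so it observes the fully transferred block.
    s = gpu_slot.sum()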