[feat] Added debug hook to offload_engine.py.
@@ -1,6 +1,7 @@
 import os
 from dataclasses import dataclass
 from transformers import AutoConfig
+import torch


 @dataclass
@@ -16,6 +17,7 @@ class Config:
     eos: int = -1
     kvcache_block_size: int = 4096
     num_kvcache_blocks: int = -1
+    dtype: str | None = None  # "float16", "bfloat16", or None (use model default)

     # CPU Offload configuration
     enable_cpu_offload: bool = False
@@ -41,3 +43,17 @@ class Config:
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
         assert self.max_num_batched_tokens >= self.max_model_len
+
+        # Override torch_dtype if user specified
+        if self.dtype is not None:
+            dtype_map = {
+                "float16": torch.float16,
+                "fp16": torch.float16,
+                "bfloat16": torch.bfloat16,
+                "bf16": torch.bfloat16,
+                "float32": torch.float32,
+                "fp32": torch.float32,
+            }
+            if self.dtype not in dtype_map:
+                raise ValueError(f"Invalid dtype: {self.dtype}. Choose from: {list(dtype_map.keys())}")
+            self.hf_config.torch_dtype = dtype_map[self.dtype]
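For orientation, a rough usage sketch of the new dtype override. The import path and the exact Config constructor arguments are assumptions for illustration; only the dtype aliases and the torch_dtype assignment come from the hunk above.

    # Hypothetical usage sketch (not part of the diff)
    import torch
    from nanovllm.config import Config  # import path assumed

    cfg = Config(model="/path/to/model", dtype="bf16")   # "bf16"/"bfloat16" map to torch.bfloat16
    assert cfg.hf_config.torch_dtype == torch.bfloat16   # applied when the config is initialized
    # Config(model="/path/to/model", dtype="int8")       # would raise ValueError: Invalid dtype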
@@ -69,15 +69,19 @@ class HybridKVCacheManager(KVCacheManager):

     Architecture (CPU-primary mode):
     - CPU pool: Primary storage for all KV cache (num_cpu_blocks)
-    - GPU buffer: Ring buffer for computation (num_gpu_slots)
-    - Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks)
+    - GPU buffer: Ring buffer for computation only (num_gpu_slots)
+    - Logical blocks: What sequences reference (num_cpu_blocks)

     Design:
     - All KV cache is stored on CPU as primary storage
-    - GPU is used as a ring buffer for computation only
+    - GPU is used as a ring buffer for computation only (no persistent data)
     - During prefill: KV is written to GPU ring slot, then offloaded to CPU
     - During decode: Previous KV is loaded from CPU to GPU for attention
     - Ring buffer enables pipelined H2D transfers overlapped with computation
+
+    Note:
+    - Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
+    - GPU slots are transient compute buffers, not tracked in logical blocks
     """

     def __init__(
@@ -102,20 +106,22 @@ class HybridKVCacheManager(KVCacheManager):
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
         self.num_cpu_blocks = num_cpu_blocks
-        self.total_blocks = num_gpu_slots + num_cpu_blocks
+        # In CPU-primary mode, logical blocks map 1:1 with CPU blocks
+        # GPU slots are transient compute buffers, not tracked as logical blocks
+        self.total_blocks = num_cpu_blocks

         # Eviction policy
         self.policy = policy or LRUPolicy()

-        # Logical blocks (what sequences reference)
+        # Logical blocks (what sequences reference) - one per CPU block
         self.logical_blocks: List[LogicalBlock] = [
             LogicalBlock(i) for i in range(self.total_blocks)
         ]
         self.free_logical_ids: deque[int] = deque(range(self.total_blocks))

-        # GPU slot management (slots are fixed, mapping is variable)
+        # GPU slot management (kept for potential future use, but not used in CPU-primary mode)
         self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots))
-        self.gpu_slot_to_logical: Dict[int, int] = {}  # gpu_slot -> logical_id
+        self.gpu_slot_to_logical: Dict[int, int] = {}  # gpu_slot -> logical_id (unused in CPU-primary mode)

         # CPU block management
         self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks))
@@ -212,7 +218,9 @@ class HybridKVCacheManager(KVCacheManager):
         block.ref_count -= 1

         if block.ref_count == 0:
-            # Free physical block
+            # Free physical block based on location
+            # Note: In CPU-primary mode, blocks are always on CPU.
+            # GPU branch kept for potential future hybrid mode support.
             if block.location == BlockLocation.GPU:
                 self.free_gpu_slots.append(block.gpu_slot)
                 del self.gpu_slot_to_logical[block.gpu_slot]
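The bookkeeping change above is easiest to see in isolation. A minimal sketch of the CPU-primary accounting, using the same deque idea as __init__ but with simplified standalone names (this is not the real manager class):

    from collections import deque

    num_gpu_slots, num_cpu_blocks = 4, 32

    total_blocks = num_cpu_blocks                  # was num_gpu_slots + num_cpu_blocks
    free_logical_ids = deque(range(total_blocks))  # logical block i <-> CPU block i
    free_cpu_blocks = deque(range(num_cpu_blocks))
    free_gpu_slots = deque(range(num_gpu_slots))   # transient compute buffers, never logical blocks

    # Allocating a logical block always consumes a CPU block; GPU slots are untouched.
    logical_id = free_logical_ids.popleft()
    cpu_block = free_cpu_blocks.popleft()
    assert len(free_gpu_slots) == num_gpu_slots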
@@ -193,6 +193,10 @@ class OffloadEngine:
         # ========== Event tracking for async transfers ==========
         self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}

+        # ========== Debug hook mode ==========
+        self._debug_mode = False
+        self._debug_hooks: List = []  # External hooks for debug events
+
     def _get_next_stream(self) -> torch.cuda.Stream:
         """Round-robin stream selection for parallel transfers."""
         stream = self.transfer_streams[self._stream_idx]
@@ -1022,4 +1026,71 @@ class OffloadEngine:
         if not slots:
             slots = self.decode_load_slots
         slots = slots[:num_blocks]
         return self.get_kv_for_slots(layer_id, slots)
+
+    # ========== Debug Hook Interface ==========
+    #
+    # Minimal generic hook system for debugging.
+    # Framework only provides hook registration and tensor access.
+    # All verification logic is external.
+
+    def enable_debug_mode(self) -> None:
+        """Enable debug mode."""
+        self._debug_mode = True
+        logger.info("OffloadEngine debug mode ENABLED")
+
+    def disable_debug_mode(self) -> None:
+        """Disable debug mode and clear all hooks."""
+        self._debug_mode = False
+        self._debug_hooks.clear()
+        logger.info("OffloadEngine debug mode DISABLED")
+
+    @property
+    def debug_mode(self) -> bool:
+        """Check if debug mode is enabled."""
+        return self._debug_mode
+
+    def register_debug_hook(self, hook_fn) -> None:
+        """
+        Register a debug hook.
+
+        The hook is called after H2D load completes (after wait_slot_layer),
+        receiving the loaded tensor for inspection.
+
+        Args:
+            hook_fn: Callable with signature:
+                (slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None
+                - k, v: GPU tensor views for the loaded slot
+
+        Example:
+            def my_hook(slot_idx, layer_id, cpu_block_id, k, v):
+                if layer_id == 0:
+                    k_val = k.float().mean().item()
+                    print(f"Loaded block {cpu_block_id}, K mean = {k_val}")

+            offload_engine.register_debug_hook(my_hook)
+        """
+        self._debug_hooks.append(hook_fn)
+
+    def remove_debug_hook(self, hook_fn) -> None:
+        """Remove a registered debug hook."""
+        if hook_fn in self._debug_hooks:
+            self._debug_hooks.remove(hook_fn)
+
+    def _call_debug_hooks(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
+        """
+        Call all registered debug hooks with loaded tensor (internal use).
+
+        Called by attention.py after wait_slot_layer completes.
+        """
+        if not self._debug_mode or not self._debug_hooks:
+            return
+
+        k = self.k_cache_gpu[layer_id, slot_idx]
+        v = self.v_cache_gpu[layer_id, slot_idx]
+
+        for hook in self._debug_hooks:
+            try:
+                hook(slot_idx, layer_id, cpu_block_id, k, v)
+            except Exception as e:
+                logger.warning(f"Debug hook error: {e}")
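A short usage sketch of the hook interface added above, assuming an already constructed engine instance named engine (the hook signature is the one documented in register_debug_hook):

    def log_k_mean(slot_idx, layer_id, cpu_block_id, k, v):
        # k and v are GPU tensor views for the slot, valid right after the H2D load completes
        if layer_id == 0:
            print(f"slot={slot_idx} block={cpu_block_id} K mean={k.float().mean().item():.3f}")

    engine.enable_debug_mode()
    engine.register_debug_hook(log_k_mean)
    # ... run prefill / decode ...
    engine.remove_debug_hook(log_k_mean)
    engine.disable_debug_mode()  # also clears any remaining hooks

Hook exceptions are caught and logged as warnings, so a faulty hook cannot crash the transfer path.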
@@ -287,9 +287,15 @@ class Attention(nn.Module):
         slot = load_slots[0]
         compute_stream = offload_engine.compute_stream
         for block_idx in range(num_blocks):
-            offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
+            cpu_block_id = cpu_block_table[block_idx]
+            offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
             offload_engine.wait_slot_layer(slot, self.layer_id)
+
             with torch.cuda.stream(compute_stream):
+                # Debug: call hooks on compute_stream (synchronized with transfer)
+                if offload_engine.debug_mode:
+                    offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
+
                 prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
                 prev_o, prev_lse = flash_attn_with_lse(
                     q_batched, prev_k, prev_v,
@@ -323,6 +329,7 @@ class Attention(nn.Module):

             # Cycle through slots: slot[block_idx % num_slots]
             current_slot = load_slots[block_idx % num_slots]
+            cpu_block_id = cpu_block_table[block_idx]

             # Wait for current slot's transfer to complete (on compute_stream)
             offload_engine.wait_slot_layer(current_slot, self.layer_id)
@@ -330,6 +337,10 @@ class Attention(nn.Module):
             # Compute attention on current slot's data
             # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
             with torch.cuda.stream(compute_stream):
+                # Debug: call hooks on compute_stream (synchronized with transfer)
+                if offload_engine.debug_mode:
+                    offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
+
                 torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
                 prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
                 prev_o, prev_lse = flash_attn_with_lse(
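The multi-slot branch above cycles the ring buffer: each previous block lands in slot block_idx % num_slots, the compute stream waits on that slot's transfer event, and attention runs on the dedicated stream while later transfers proceed. A sequential schematic with hypothetical stubs, just to show the ordering (the real code overlaps these steps via CUDA streams and events; exactly where the next load is issued is assumed here):

    num_slots, num_blocks = 2, 6

    def load_to_slot(slot, block):   # stand-in for the async H2D copy
        print(f"load block {block} -> slot {slot}")

    def wait_slot(slot):             # stand-in for waiting on the slot's transfer event
        print(f"wait slot {slot}")

    def attend(slot, block):         # stand-in for flash_attn_with_lse on that slot's K/V
        print(f"attend block {block} in slot {slot}")

    for b in range(min(num_slots, num_blocks)):  # prime the ring buffer
        load_to_slot(b % num_slots, b)

    for b in range(num_blocks):
        s = b % num_slots
        wait_slot(s)
        attend(s, b)
        nxt = b + num_slots
        if nxt < num_blocks:
            load_to_slot(s, nxt)     # refill the slot just consumed; overlaps with later compute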
tests/test_debug_verification.py (new file, 267 lines)
@@ -0,0 +1,267 @@
+"""
+Test script for verifying KV cache offload correctness using debug hooks.
+
+Strategy:
+1. Inject distinctive K/V values (K=chunk_idx+1, V=-(chunk_idx+1))
+2. Register debug hook to receive loaded tensor
+3. Hook reads tensor values to verify correct block was loaded
+4. No verification logic in framework - all external
+
+This tests the framework's normal async execution path.
+"""
+
+import os
+os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
+
+from random import randint, seed
+from typing import Dict, List, Tuple
+import torch
+from torch import Tensor
+from nanovllm import LLM, SamplingParams
+from nanovllm.utils.context import get_context
+
+
+# ============================================================
+# Configuration
+# ============================================================
+
+MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
+MAX_MODEL_LEN = 32 * 1024
+NUM_GPU_BLOCKS = 4
+INPUT_LEN = 32 * 1024
+BLOCK_SIZE = 1024
+
+
+# ============================================================
+# External state (managed by test, not framework)
+# ============================================================
+
+# Record all load operations: list of {cpu_block_id, k_value, v_value, ...}
+load_log: List[Dict] = []
+
+# Track current chunk for grouping loads
+current_chunk: List[int] = [0]  # mutable container
+
+
+# ============================================================
+# Debug hook - receives loaded tensor directly
+# ============================================================
+
+def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
+    """
+    Debug hook called after each H2D load.
+    Reads tensor values to verify which block was actually loaded.
+    """
+    # Only record layer 0 for efficiency
+    if layer_id != 0:
+        return
+
+    # Read tensor values (the distinctive pattern we injected)
+    k_val = k.float().mean().item()
+    v_val = v.float().mean().item()
+
+    load_log.append({
+        "chunk_idx": current_chunk[0],
+        "slot_idx": slot_idx,
+        "cpu_block_id": cpu_block_id,
+        "k_value": k_val,
+        "v_value": v_val,
+    })
+
+
+# ============================================================
+# Pattern injection hook - injects distinctive values into K/V
+# ============================================================
+
+def make_pattern_injection_hook(layer_id):
+    """Inject distinctive patterns: K = chunk_idx + 1, V = -(chunk_idx + 1)"""
+    def hook(module, inputs):
+        ctx = get_context()
+        if not ctx.is_prefill:
+            return inputs
+
+        if layer_id != 0:
+            return inputs
+
+        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
+        current_chunk[0] = chunk_idx  # Update for debug_load_hook
+
+        if len(inputs) >= 3:
+            q, k, v = inputs[0], inputs[1], inputs[2]
+            k_pattern = float(chunk_idx + 1)
+            v_pattern = float(-(chunk_idx + 1))
+            k_new = torch.full_like(k, k_pattern)
+            v_new = torch.full_like(v, v_pattern)
+            return (q, k_new, v_new) + inputs[3:]
+
+        return inputs
+    return hook
+
+
+# ============================================================
+# Verification functions (all external, not in framework)
+# ============================================================
+
+def verify_load_order() -> Tuple[int, int, List[Dict]]:
+    """Verify blocks were loaded in correct order by checking K values."""
+    # Group loads by chunk
+    chunk_loads: Dict[int, List[Tuple[int, float]]] = {}
+    for log in load_log:
+        chunk = log["chunk_idx"]
+        if chunk not in chunk_loads:
+            chunk_loads[chunk] = []
+        chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"]))
+
+    correct = 0
+    incorrect = 0
+    errors = []
+
+    for chunk in sorted(chunk_loads.keys()):
+        loads = chunk_loads[chunk]
+        # Expected: blocks [0, 1, ..., chunk-1] with K values [1, 2, ..., chunk]
+        expected_blocks = list(range(chunk))
+        actual_blocks = [block_id for block_id, _ in loads]
+
+        # Also verify K values match expected pattern
+        k_values = [k_val for _, k_val in loads]
+        expected_k_values = [float(b + 1) for b in expected_blocks]
+
+        blocks_ok = actual_blocks == expected_blocks
+        # Check K values with tolerance
+        k_ok = all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k_values)) if len(k_values) == len(expected_k_values) else False
+
+        if blocks_ok and k_ok:
+            correct += 1
+        else:
+            incorrect += 1
+            errors.append({
+                "chunk_idx": chunk,
+                "expected_blocks": expected_blocks,
+                "actual_blocks": actual_blocks,
+                "expected_k": expected_k_values,
+                "actual_k": k_values,
+            })
+
+    return correct, incorrect, errors
+
+
+def print_verification_summary():
+    """Print verification results."""
+    correct, incorrect, errors = verify_load_order()
+
+    # Group for display
+    chunk_loads: Dict[int, List[int]] = {}
+    for log in load_log:
+        chunk = log["chunk_idx"]
+        if chunk not in chunk_loads:
+            chunk_loads[chunk] = []
+        chunk_loads[chunk].append(log["cpu_block_id"])
+
+    print(f"\n{'='*60}")
+    print("Debug Verification Summary")
+    print(f"{'='*60}")
+
+    print(f"\n1. Load Operations:")
+    print(f"   Total H2D loads recorded: {len(load_log)}")
+    print(f"   Chunks with correct order: {correct}")
+    print(f"   Chunks with incorrect order: {incorrect}")
+
+    if incorrect > 0:
+        print(f"\n   Errors:")
+        for err in errors[:5]:
+            print(f"   Chunk {err['chunk_idx']}:")
+            print(f"     Expected blocks: {err['expected_blocks']}")
+            print(f"     Actual blocks: {err['actual_blocks']}")
+            print(f"     K values: {[f'{v:.1f}' for v in err['actual_k']]}")
+
+    print(f"\n2. Load Order Sample (first 5 and last 2 chunks):")
+    sorted_chunks = sorted(chunk_loads.keys())
+    display_chunks = sorted_chunks[:5] + sorted_chunks[-2:] if len(sorted_chunks) > 7 else sorted_chunks
+    for chunk in display_chunks:
+        blocks = chunk_loads[chunk]
+        expected = list(range(chunk))
+        status = "OK" if blocks == expected else "WRONG"
+        print(f"   Chunk {chunk}: {blocks} [{status}]")
+
+    print(f"\n{'='*60}")
+
+
+# ============================================================
+# Main Test Script
+# ============================================================
+
+print("Initializing LLM with CPU offload...")
+llm = LLM(
+    MODEL_PATH,
+    enforce_eager=True,
+    max_model_len=MAX_MODEL_LEN,
+    max_num_batched_tokens=MAX_MODEL_LEN,
+    enable_cpu_offload=True,
+    kvcache_block_size=BLOCK_SIZE,
+    num_gpu_blocks=NUM_GPU_BLOCKS,
+    dtype="float16",
+)
+
+# Get offload engine and enable debug mode
+kvcache_manager = llm.model_runner.kvcache_manager
+offload_engine = kvcache_manager.offload_engine
+offload_engine.enable_debug_mode()
+
+# Register our debug hook
+offload_engine.register_debug_hook(debug_load_hook)
+print("Debug mode enabled with custom hook")
+
+# Register pattern injection hooks
+hooks = []
+model = llm.model_runner.model
+for layer_idx, decoder_layer in enumerate(model.model.layers):
+    attn_module = decoder_layer.self_attn.attn
+    pre_hook = attn_module.register_forward_pre_hook(make_pattern_injection_hook(layer_idx))
+    hooks.append(pre_hook)
+print(f"Registered {len(hooks)} pattern injection hooks")
+
+# Generate input
+seed(42)
+prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
+num_chunks = INPUT_LEN // BLOCK_SIZE
+print(f"\nInput: {INPUT_LEN} tokens, {num_chunks} chunks expected")
+print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
+
+# Run prefill
+print("\n" + "=" * 60)
+print("Starting Prefill...")
+print("=" * 60)
+
+sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
+outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
+
+# Remove hooks
+for hook in hooks:
+    hook.remove()
+offload_engine.remove_debug_hook(debug_load_hook)
+
+# Verify and print
+print("\n" + "=" * 60)
+print("Post-Execution Verification")
+print("=" * 60)
+print_verification_summary()
+
+# Final verdict
+correct, incorrect, _ = verify_load_order()
+expected_loads = num_chunks * (num_chunks - 1) // 2
+actual_loads = len(load_log)
+
+print(f"\nResults:")
+print(f"  Total loads: {actual_loads} (expected: {expected_loads})")
+print(f"  Order verification: {correct} correct, {incorrect} incorrect")
+
+print("\n" + "=" * 60)
+all_passed = incorrect == 0 and actual_loads == expected_loads
+
+if all_passed:
+    print("test_debug_verification: PASSED")
+else:
+    print("test_debug_verification: FAILED")
+print("=" * 60)
+
+offload_engine.disable_debug_mode()
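With the configuration in this test (INPUT_LEN = 32 * 1024 tokens, BLOCK_SIZE = 1024, so 32 prefill chunks), chunk i has to re-load the i blocks offloaded by earlier chunks, and the hook only records layer 0. That is where the expected_loads formula comes from:

    num_chunks = (32 * 1024) // 1024                      # 32 chunks
    expected_loads = num_chunks * (num_chunks - 1) // 2   # 0 + 1 + ... + 31
    print(expected_loads)                                 # 496 layer-0 H2D loads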