diff --git a/nanovllm/config.py b/nanovllm/config.py
index e59a5eb..3087d3c 100644
--- a/nanovllm/config.py
+++ b/nanovllm/config.py
@@ -1,6 +1,7 @@
 import os
 from dataclasses import dataclass
 from transformers import AutoConfig
+import torch
 
 
 @dataclass
@@ -16,6 +17,7 @@ class Config:
     eos: int = -1
     kvcache_block_size: int = 4096
     num_kvcache_blocks: int = -1
+    dtype: str | None = None  # "float16", "bfloat16", or None (use model default)
 
     # CPU Offload configuration
     enable_cpu_offload: bool = False
@@ -41,3 +43,17 @@ class Config:
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
         assert self.max_num_batched_tokens >= self.max_model_len
+
+        # Override torch_dtype if user specified
+        if self.dtype is not None:
+            dtype_map = {
+                "float16": torch.float16,
+                "fp16": torch.float16,
+                "bfloat16": torch.bfloat16,
+                "bf16": torch.bfloat16,
+                "float32": torch.float32,
+                "fp32": torch.float32,
+            }
+            if self.dtype not in dtype_map:
+                raise ValueError(f"Invalid dtype: {self.dtype}. Choose from: {list(dtype_map.keys())}")
+            self.hf_config.torch_dtype = dtype_map[self.dtype]
diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py
index 937c626..eddd270 100644
--- a/nanovllm/kvcache/hybrid_manager.py
+++ b/nanovllm/kvcache/hybrid_manager.py
@@ -69,15 +69,19 @@ class HybridKVCacheManager(KVCacheManager):
 
     Architecture (CPU-primary mode):
     - CPU pool: Primary storage for all KV cache (num_cpu_blocks)
-    - GPU buffer: Ring buffer for computation (num_gpu_slots)
-    - Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks)
+    - GPU buffer: Ring buffer for computation only (num_gpu_slots)
+    - Logical blocks: What sequences reference (num_cpu_blocks)
 
     Design:
     - All KV cache is stored on CPU as primary storage
-    - GPU is used as a ring buffer for computation only
+    - GPU is used as a ring buffer for computation only (no persistent data)
     - During prefill: KV is written to GPU ring slot, then offloaded to CPU
     - During decode: Previous KV is loaded from CPU to GPU for attention
     - Ring buffer enables pipelined H2D transfers overlapped with computation
+
+    Note:
+    - Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
+    - GPU slots are transient compute buffers, not tracked in logical blocks
     """
 
     def __init__(
@@ -102,20 +106,22 @@ class HybridKVCacheManager(KVCacheManager):
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
         self.num_cpu_blocks = num_cpu_blocks
-        self.total_blocks = num_gpu_slots + num_cpu_blocks
+        # In CPU-primary mode, logical blocks map 1:1 with CPU blocks
+        # GPU slots are transient compute buffers, not tracked as logical blocks
+        self.total_blocks = num_cpu_blocks
 
         # Eviction policy
         self.policy = policy or LRUPolicy()
 
-        # Logical blocks (what sequences reference)
+        # Logical blocks (what sequences reference) - one per CPU block
        self.logical_blocks: List[LogicalBlock] = [
             LogicalBlock(i) for i in range(self.total_blocks)
         ]
         self.free_logical_ids: deque[int] = deque(range(self.total_blocks))
 
-        # GPU slot management (slots are fixed, mapping is variable)
+        # GPU slot management (kept for potential future use, but not used in CPU-primary mode)
         self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots))
-        self.gpu_slot_to_logical: Dict[int, int] = {}  # gpu_slot -> logical_id
+        self.gpu_slot_to_logical: Dict[int, int] = {}  # gpu_slot -> logical_id (unused in CPU-primary mode)
 
         # CPU block management
         self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks))
@@ -212,7 +218,9 @@ class HybridKVCacheManager(KVCacheManager):
             block.ref_count -= 1
 
             if block.ref_count == 0:
-                # Free physical block
+                # Free physical block based on location
+                # Note: In CPU-primary mode, blocks are always on CPU.
+                # GPU branch kept for potential future hybrid mode support.
                 if block.location == BlockLocation.GPU:
                     self.free_gpu_slots.append(block.gpu_slot)
                     del self.gpu_slot_to_logical[block.gpu_slot]
diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py
index bf8cc16..9f2d4f8 100644
--- a/nanovllm/kvcache/offload_engine.py
+++ b/nanovllm/kvcache/offload_engine.py
@@ -193,6 +193,10 @@ class OffloadEngine:
         # ========== Event tracking for async transfers ==========
         self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
 
+        # ========== Debug hook mode ==========
+        self._debug_mode = False
+        self._debug_hooks: List = []  # External hooks for debug events
+
     def _get_next_stream(self) -> torch.cuda.Stream:
         """Round-robin stream selection for parallel transfers."""
         stream = self.transfer_streams[self._stream_idx]
@@ -1022,4 +1026,71 @@ class OffloadEngine:
         if not slots:
             slots = self.decode_load_slots
         slots = slots[:num_blocks]
-        return self.get_kv_for_slots(layer_id, slots)
\ No newline at end of file
+        return self.get_kv_for_slots(layer_id, slots)
+
+    # ========== Debug Hook Interface ==========
+    #
+    # Minimal generic hook system for debugging.
+    # Framework only provides hook registration and tensor access.
+    # All verification logic is external.
+
+    def enable_debug_mode(self) -> None:
+        """Enable debug mode."""
+        self._debug_mode = True
+        logger.info("OffloadEngine debug mode ENABLED")
+
+    def disable_debug_mode(self) -> None:
+        """Disable debug mode and clear all hooks."""
+        self._debug_mode = False
+        self._debug_hooks.clear()
+        logger.info("OffloadEngine debug mode DISABLED")
+
+    @property
+    def debug_mode(self) -> bool:
+        """Check if debug mode is enabled."""
+        return self._debug_mode
+
+    def register_debug_hook(self, hook_fn) -> None:
+        """
+        Register a debug hook.
+
+        The hook is called after H2D load completes (after wait_slot_layer),
+        receiving the loaded tensor for inspection.
+
+        Args:
+            hook_fn: Callable with signature:
+                (slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None
+                - k, v: GPU tensor views for the loaded slot
+
+        Example:
+            def my_hook(slot_idx, layer_id, cpu_block_id, k, v):
+                if layer_id == 0:
+                    k_val = k.float().mean().item()
+                    print(f"Loaded block {cpu_block_id}, K mean = {k_val}")
+
+            offload_engine.register_debug_hook(my_hook)
+        """
+        self._debug_hooks.append(hook_fn)
+
+    def remove_debug_hook(self, hook_fn) -> None:
+        """Remove a registered debug hook."""
+        if hook_fn in self._debug_hooks:
+            self._debug_hooks.remove(hook_fn)
+
+    def _call_debug_hooks(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
+        """
+        Call all registered debug hooks with loaded tensor (internal use).
+
+        Called by attention.py after wait_slot_layer completes.
+ """ + if not self._debug_mode or not self._debug_hooks: + return + + k = self.k_cache_gpu[layer_id, slot_idx] + v = self.v_cache_gpu[layer_id, slot_idx] + + for hook in self._debug_hooks: + try: + hook(slot_idx, layer_id, cpu_block_id, k, v) + except Exception as e: + logger.warning(f"Debug hook error: {e}") \ No newline at end of file diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 2caac7e..4171ad8 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -287,9 +287,15 @@ class Attention(nn.Module): slot = load_slots[0] compute_stream = offload_engine.compute_stream for block_idx in range(num_blocks): - offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx]) + cpu_block_id = cpu_block_table[block_idx] + offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id) offload_engine.wait_slot_layer(slot, self.layer_id) + with torch.cuda.stream(compute_stream): + # Debug: call hooks on compute_stream (synchronized with transfer) + if offload_engine.debug_mode: + offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id) + prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id) prev_o, prev_lse = flash_attn_with_lse( q_batched, prev_k, prev_v, @@ -323,6 +329,7 @@ class Attention(nn.Module): # Cycle through slots: slot[block_idx % num_slots] current_slot = load_slots[block_idx % num_slots] + cpu_block_id = cpu_block_table[block_idx] # Wait for current slot's transfer to complete (on compute_stream) offload_engine.wait_slot_layer(current_slot, self.layer_id) @@ -330,6 +337,10 @@ class Attention(nn.Module): # Compute attention on current slot's data # IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream with torch.cuda.stream(compute_stream): + # Debug: call hooks on compute_stream (synchronized with transfer) + if offload_engine.debug_mode: + offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id) + torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}") prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id) prev_o, prev_lse = flash_attn_with_lse( diff --git a/tests/test_debug_verification.py b/tests/test_debug_verification.py new file mode 100644 index 0000000..bf1d0cc --- /dev/null +++ b/tests/test_debug_verification.py @@ -0,0 +1,267 @@ +""" +Test script for verifying KV cache offload correctness using debug hooks. + +Strategy: +1. Inject distinctive K/V values (K=chunk_idx+1, V=-(chunk_idx+1)) +2. Register debug hook to receive loaded tensor +3. Hook reads tensor values to verify correct block was loaded +4. No verification logic in framework - all external + +This tests the framework's normal async execution path. 
+""" + +import os +os.environ["NANOVLLM_LOG_LEVEL"] = "INFO" + +from random import randint, seed +from typing import Dict, List, Tuple +import torch +from torch import Tensor +from nanovllm import LLM, SamplingParams +from nanovllm.utils.context import get_context + + +# ============================================================ +# Configuration +# ============================================================ + +MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") +MAX_MODEL_LEN = 32 * 1024 +NUM_GPU_BLOCKS = 4 +INPUT_LEN = 32 * 1024 +BLOCK_SIZE = 1024 + + +# ============================================================ +# External state (managed by test, not framework) +# ============================================================ + +# Record all load operations: list of {cpu_block_id, k_value, v_value, ...} +load_log: List[Dict] = [] + +# Track current chunk for grouping loads +current_chunk: List[int] = [0] # mutable container + + +# ============================================================ +# Debug hook - receives loaded tensor directly +# ============================================================ + +def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None: + """ + Debug hook called after each H2D load. + Reads tensor values to verify which block was actually loaded. + """ + # Only record layer 0 for efficiency + if layer_id != 0: + return + + # Read tensor values (the distinctive pattern we injected) + k_val = k.float().mean().item() + v_val = v.float().mean().item() + + load_log.append({ + "chunk_idx": current_chunk[0], + "slot_idx": slot_idx, + "cpu_block_id": cpu_block_id, + "k_value": k_val, + "v_value": v_val, + }) + + +# ============================================================ +# Pattern injection hook - injects distinctive values into K/V +# ============================================================ + +def make_pattern_injection_hook(layer_id): + """Inject distinctive patterns: K = chunk_idx + 1, V = -(chunk_idx + 1)""" + def hook(module, inputs): + ctx = get_context() + if not ctx.is_prefill: + return inputs + + if layer_id != 0: + return inputs + + chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 + current_chunk[0] = chunk_idx # Update for debug_load_hook + + if len(inputs) >= 3: + q, k, v = inputs[0], inputs[1], inputs[2] + k_pattern = float(chunk_idx + 1) + v_pattern = float(-(chunk_idx + 1)) + k_new = torch.full_like(k, k_pattern) + v_new = torch.full_like(v, v_pattern) + return (q, k_new, v_new) + inputs[3:] + + return inputs + return hook + + +# ============================================================ +# Verification functions (all external, not in framework) +# ============================================================ + +def verify_load_order() -> Tuple[int, int, List[Dict]]: + """Verify blocks were loaded in correct order by checking K values.""" + # Group loads by chunk + chunk_loads: Dict[int, List[Tuple[int, float]]] = {} + for log in load_log: + chunk = log["chunk_idx"] + if chunk not in chunk_loads: + chunk_loads[chunk] = [] + chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"])) + + correct = 0 + incorrect = 0 + errors = [] + + for chunk in sorted(chunk_loads.keys()): + loads = chunk_loads[chunk] + # Expected: blocks [0, 1, ..., chunk-1] with K values [1, 2, ..., chunk] + expected_blocks = list(range(chunk)) + actual_blocks = [block_id for block_id, _ in loads] + + # Also verify K values match expected pattern + k_values = [k_val for _, k_val in loads] + 
expected_k_values = [float(b + 1) for b in expected_blocks] + + blocks_ok = actual_blocks == expected_blocks + # Check K values with tolerance + k_ok = all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k_values)) if len(k_values) == len(expected_k_values) else False + + if blocks_ok and k_ok: + correct += 1 + else: + incorrect += 1 + errors.append({ + "chunk_idx": chunk, + "expected_blocks": expected_blocks, + "actual_blocks": actual_blocks, + "expected_k": expected_k_values, + "actual_k": k_values, + }) + + return correct, incorrect, errors + + +def print_verification_summary(): + """Print verification results.""" + correct, incorrect, errors = verify_load_order() + + # Group for display + chunk_loads: Dict[int, List[int]] = {} + for log in load_log: + chunk = log["chunk_idx"] + if chunk not in chunk_loads: + chunk_loads[chunk] = [] + chunk_loads[chunk].append(log["cpu_block_id"]) + + print(f"\n{'='*60}") + print("Debug Verification Summary") + print(f"{'='*60}") + + print(f"\n1. Load Operations:") + print(f" Total H2D loads recorded: {len(load_log)}") + print(f" Chunks with correct order: {correct}") + print(f" Chunks with incorrect order: {incorrect}") + + if incorrect > 0: + print(f"\n Errors:") + for err in errors[:5]: + print(f" Chunk {err['chunk_idx']}:") + print(f" Expected blocks: {err['expected_blocks']}") + print(f" Actual blocks: {err['actual_blocks']}") + print(f" K values: {[f'{v:.1f}' for v in err['actual_k']]}") + + print(f"\n2. Load Order Sample (first 5 and last 2 chunks):") + sorted_chunks = sorted(chunk_loads.keys()) + display_chunks = sorted_chunks[:5] + sorted_chunks[-2:] if len(sorted_chunks) > 7 else sorted_chunks + for chunk in display_chunks: + blocks = chunk_loads[chunk] + expected = list(range(chunk)) + status = "OK" if blocks == expected else "WRONG" + print(f" Chunk {chunk}: {blocks} [{status}]") + + print(f"\n{'='*60}") + + +# ============================================================ +# Main Test Script +# ============================================================ + +print("Initializing LLM with CPU offload...") +llm = LLM( + MODEL_PATH, + enforce_eager=True, + max_model_len=MAX_MODEL_LEN, + max_num_batched_tokens=MAX_MODEL_LEN, + enable_cpu_offload=True, + kvcache_block_size=BLOCK_SIZE, + num_gpu_blocks=NUM_GPU_BLOCKS, + dtype="float16", +) + +# Get offload engine and enable debug mode +kvcache_manager = llm.model_runner.kvcache_manager +offload_engine = kvcache_manager.offload_engine +offload_engine.enable_debug_mode() + +# Register our debug hook +offload_engine.register_debug_hook(debug_load_hook) +print("Debug mode enabled with custom hook") + +# Register pattern injection hooks +hooks = [] +model = llm.model_runner.model +for layer_idx, decoder_layer in enumerate(model.model.layers): + attn_module = decoder_layer.self_attn.attn + pre_hook = attn_module.register_forward_pre_hook(make_pattern_injection_hook(layer_idx)) + hooks.append(pre_hook) +print(f"Registered {len(hooks)} pattern injection hooks") + +# Generate input +seed(42) +prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] +num_chunks = INPUT_LEN // BLOCK_SIZE +print(f"\nInput: {INPUT_LEN} tokens, {num_chunks} chunks expected") +print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}") + +# Run prefill +print("\n" + "=" * 60) +print("Starting Prefill...") +print("=" * 60) + +sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1) +outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) + +# Remove hooks +for hook in 
+    hook.remove()
+offload_engine.remove_debug_hook(debug_load_hook)
+
+# Verify and print
+print("\n" + "=" * 60)
+print("Post-Execution Verification")
+print("=" * 60)
+print_verification_summary()
+
+# Final verdict
+correct, incorrect, _ = verify_load_order()
+expected_loads = num_chunks * (num_chunks - 1) // 2
+actual_loads = len(load_log)
+
+print(f"\nResults:")
+print(f"  Total loads: {actual_loads} (expected: {expected_loads})")
+print(f"  Order verification: {correct} correct, {incorrect} incorrect")
+
+print("\n" + "=" * 60)
+all_passed = incorrect == 0 and actual_loads == expected_loads
+
+if all_passed:
+    print("test_debug_verification: PASSED")
+else:
+    print("test_debug_verification: FAILED")
+print("=" * 60)
+
+offload_engine.disable_debug_mode()
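For reference, the `expected_loads` figure used in the final verdict follows from the chunked-prefill pattern the test exercises: when chunk `i` is prefilled, attention reloads each of the `i` previously offloaded blocks once, so the layer-0 hook should fire a triangular number of times. A quick check of that arithmetic with the test's own settings (illustrative only, mirroring the constants above):

```python
INPUT_LEN = 32 * 1024
BLOCK_SIZE = 1024

num_chunks = INPUT_LEN // BLOCK_SIZE  # 32 prefill chunks
# Chunk i reloads blocks 0..i-1, so total layer-0 loads = 0 + 1 + ... + (num_chunks - 1).
expected_loads = num_chunks * (num_chunks - 1) // 2
print(num_chunks, expected_loads)  # 32 496
```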