diff --git a/nanovllm/debug/__init__.py b/nanovllm/debug/__init__.py
new file mode 100644
index 0000000..923c30e
--- /dev/null
+++ b/nanovllm/debug/__init__.py
@@ -0,0 +1,49 @@
+"""
+Breakpoint debugging tools for aligning nanovllm with reference implementations.
+
+This module provides a generator-based breakpoint aligner that enables step-by-step
+comparison between nanovllm and torch reference model outputs.
+
+Example:
+    >>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
+    >>> from tests.modeling_qwen3 import Qwen3ForCausalLM
+    >>>
+    >>> # Load models
+    >>> torch_model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch.float16)
+    >>> nanovllm_model = ...  # Your nanovllm model
+    >>>
+    >>> # Create adapters
+    >>> ref = TorchSteppable(torch_model)
+    >>> test = NanovllmSteppable(nanovllm_model)
+    >>>
+    >>> # Run alignment
+    >>> aligner = BreakpointAligner(ref, test)
+    >>> result = aligner.align(input_ids)
+    >>> print(result)
+"""
+
+from .breakpoints import BreakpointType, Breakpoint
+from .comparator import TensorComparator, ComparisonResult
+from .aligner import BreakpointAligner, AlignmentResult
+from .adapters import SteppableModel, TorchSteppable, NanovllmSteppable
+from .utils import setup_prefill_context, setup_decode_context, cleanup_context
+
+__all__ = [
+    # Core classes
+    "BreakpointAligner",
+    "AlignmentResult",
+    # Breakpoints
+    "BreakpointType",
+    "Breakpoint",
+    # Comparator
+    "TensorComparator",
+    "ComparisonResult",
+    # Adapters
+    "SteppableModel",
+    "TorchSteppable",
+    "NanovllmSteppable",
+    # Utils
+    "setup_prefill_context",
+    "setup_decode_context",
+    "cleanup_context",
+]
diff --git a/nanovllm/debug/adapters/__init__.py b/nanovllm/debug/adapters/__init__.py
new file mode 100644
index 0000000..771e74f
--- /dev/null
+++ b/nanovllm/debug/adapters/__init__.py
@@ -0,0 +1,11 @@
+"""Model adapters for breakpoint alignment."""
+
+from .base import SteppableModel
+from .torch_adapter import TorchSteppable
+from .nanovllm_adapter import NanovllmSteppable
+
+__all__ = [
+    "SteppableModel",
+    "TorchSteppable",
+    "NanovllmSteppable",
+]
diff --git a/nanovllm/debug/adapters/base.py b/nanovllm/debug/adapters/base.py
new file mode 100644
index 0000000..8e5c0b7
--- /dev/null
+++ b/nanovllm/debug/adapters/base.py
@@ -0,0 +1,59 @@
+"""Base class for steppable model adapters."""
+
+from abc import ABC, abstractmethod
+from typing import Generator, Set, Optional
+import torch
+
+from ..breakpoints import Breakpoint, BreakpointType
+
+
+class SteppableModel(ABC):
+    """
+    Abstract base class for models that can yield at breakpoints.
+
+    Subclasses implement the step() method as a generator that yields
+    Breakpoint objects at each enabled breakpoint during the forward pass.
+    """
+
+    def __init__(self, enabled_breakpoints: Optional[Set[BreakpointType]] = None):
+        """
+        Args:
+            enabled_breakpoints: Set of breakpoint types to yield at.
+                If None, yields at all breakpoints.
+        """
+        self.enabled_breakpoints = enabled_breakpoints
+
+    def is_enabled(self, bp_type: BreakpointType) -> bool:
+        """Check if a breakpoint type is enabled."""
+        if self.enabled_breakpoints is None:
+            return True
+        return bp_type in self.enabled_breakpoints
+
+    @abstractmethod
+    def step(
+        self,
+        input_ids: torch.Tensor,
+        positions: Optional[torch.Tensor] = None,
+        is_prefill: bool = True,
+    ) -> Generator[Breakpoint, None, torch.Tensor]:
+        """
+        Generator that yields Breakpoint objects at enabled breakpoints.
+
+        Args:
+            input_ids: Input token IDs
+            positions: Position IDs (optional, auto-generated if None)
+            is_prefill: True for prefill phase, False for decode
+
+        Yields:
+            Breakpoint objects at each enabled checkpoint
+
+        Returns:
+            Final output tensor (logits)
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def num_layers(self) -> int:
+        """Return the number of decoder layers."""
+        pass
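Note: `step()` is a generator that also *returns* the final logits, and a plain `for` loop silently discards a generator's return value. A minimal sketch of driving the generator by hand and recovering the logits (PEP 380 puts the return value on `StopIteration.value`); `steppable` here stands for any `SteppableModel`:

```python
gen = steppable.step(input_ids, is_prefill=True)
while True:
    try:
        bp = next(gen)   # a Breakpoint at each enabled checkpoint
        print(bp)        # e.g. Breakpoint(Layer 3, shape=1x512x1024, dtype=torch.float16)
    except StopIteration as stop:
        logits = stop.value   # the tensor returned by step()
        break
```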
diff --git a/nanovllm/debug/adapters/nanovllm_adapter.py b/nanovllm/debug/adapters/nanovllm_adapter.py
new file mode 100644
index 0000000..8ab38be
--- /dev/null
+++ b/nanovllm/debug/adapters/nanovllm_adapter.py
@@ -0,0 +1,229 @@
+"""Nanovllm model adapter for breakpoint alignment."""
+
+from typing import Generator, Set, Optional, Dict, Any, List
+import torch
+
+from nanovllm.utils.context import set_context, reset_context
+from ..breakpoints import Breakpoint, BreakpointType
+from .base import SteppableModel
+
+
+class NanovllmSteppable(SteppableModel):
+    """
+    Steppable adapter for the nanovllm Qwen3 implementation.
+
+    Uses PyTorch hooks to capture intermediate values during the forward pass,
+    then yields them as breakpoints after execution completes.
+
+    Key challenges handled:
+    1. Shape difference: nanovllm uses [num_tokens, hidden] vs [batch, seq, hidden]
+    2. Context-based attention: must call set_context() before forward
+    3. Fused operations: decoder layer returns (hidden_states, residual) tuple
+    """
+
+    def __init__(
+        self,
+        model,
+        enabled_breakpoints: Optional[Set[BreakpointType]] = None,
+    ):
+        """
+        Args:
+            model: Qwen3ForCausalLM from nanovllm
+            enabled_breakpoints: Set of breakpoint types to yield at
+        """
+        super().__init__(enabled_breakpoints)
+        self.model = model
+        self.model.eval()
+        self._hooks: List[Any] = []
+        self._captured: Dict[str, torch.Tensor] = {}
+
+    @property
+    def num_layers(self) -> int:
+        return len(self.model.model.layers)
+
+    def _register_hooks(self):
+        """Register forward hooks on all relevant modules."""
+        self._hooks = []
+        self._captured = {}
+
+        # Hook for embedding output
+        def embed_hook(module, input, output):
+            self._captured["embed"] = output.detach().clone()
+
+        self._hooks.append(
+            self.model.model.embed_tokens.register_forward_hook(embed_hook)
+        )
+
+        # Hooks for each decoder layer
+        for layer_idx in range(self.num_layers):
+            layer = self.model.model.layers[layer_idx]
+
+            def make_layer_hook(idx):
+                def hook(module, input, output):
+                    # Decoder layer returns (hidden_states, residual)
+                    hidden_states = output[0] if isinstance(output, tuple) else output
+                    self._captured[f"layer_{idx}"] = hidden_states.detach().clone()
+                return hook
+
+            self._hooks.append(
+                layer.register_forward_hook(make_layer_hook(layer_idx))
+            )
+
+        # Hook for final norm
+        def final_norm_hook(module, input, output):
+            # Final norm returns (hidden_states, _) for fused add
+            hidden_states = output[0] if isinstance(output, tuple) else output
+            self._captured["final_norm"] = hidden_states.detach().clone()
+
+        self._hooks.append(
+            self.model.model.norm.register_forward_hook(final_norm_hook)
+        )
+
+        # Hook for lm_head
+        def lm_head_hook(module, input, output):
+            self._captured["lm_head"] = output.detach().clone()
+
+        self._hooks.append(
+            self.model.lm_head.register_forward_hook(lm_head_hook)
+        )
+
+    def _remove_hooks(self):
+        """Remove all registered hooks."""
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks = []
+
+    def _setup_context(self, seq_len: int, device: torch.device, is_prefill: bool):
+        """
+        Set up nanovllm context for the forward pass.
+
+        For alignment testing, we use a simple context without a real KV cache.
+        """
+        if is_prefill:
+            # Prefill: process all tokens at once
+            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
+            # Use -1 for slot_mapping to skip KV cache writes
+            slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
+
+            set_context(
+                is_prefill=True,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=seq_len,
+                max_seqlen_k=seq_len,
+                slot_mapping=slot_mapping,
+                is_chunked_prefill=False,
+            )
+        else:
+            # Decode: single token generation
+            # For decode, we need context_lens and block_tables
+            # For alignment testing without a real KV cache, we use a minimal setup
+            context_lens = torch.tensor([seq_len - 1], dtype=torch.int32, device=device)
+            # Single token slot
+            slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
+            # Empty block tables (no KV cache)
+            block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
+
+            set_context(
+                is_prefill=False,
+                slot_mapping=slot_mapping,
+                context_lens=context_lens,
+                block_tables=block_tables,
+            )
+
+    def _normalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Normalize nanovllm tensor shape to [batch, seq_len, ...].
+
+        nanovllm uses [num_tokens, ...] format without a batch dimension.
+        We add a batch dimension for comparison with the torch model.
+        """
+        if tensor.dim() == 2:  # [num_tokens, hidden_size]
+            return tensor.unsqueeze(0)
+        return tensor
+
+    def step(
+        self,
+        input_ids: torch.Tensor,
+        positions: Optional[torch.Tensor] = None,
+        is_prefill: bool = True,
+    ) -> Generator[Breakpoint, None, torch.Tensor]:
+        """
+        Execute the nanovllm forward pass with hooks to capture breakpoints.
+
+        Unlike the torch adapter, which manually steps through each component,
+        we run the full forward pass and collect captured values afterward.
+        """
+        # Ensure 1D for nanovllm (it expects [num_tokens])
+        if input_ids.dim() == 2:
+            input_ids = input_ids.squeeze(0)
+
+        seq_len = input_ids.numel()
+        device = input_ids.device
+
+        # Generate position IDs if not provided
+        if positions is None:
+            positions = torch.arange(seq_len, device=device)
+        elif positions.dim() == 2:
+            positions = positions.squeeze(0)
+
+        # Register hooks
+        self._register_hooks()
+
+        try:
+            # Setup context for attention
+            self._setup_context(seq_len, device, is_prefill)
+
+            # Run forward pass (hooks capture everything)
+            with torch.no_grad():
+                hidden_states = self.model(input_ids, positions)
+                logits = self.model.compute_logits(hidden_states)
+
+            reset_context()
+
+            # Yield breakpoints in order from captured data
+
+            # EMBEDDING
+            if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured:
+                yield Breakpoint(
+                    bp_type=BreakpointType.EMBEDDING,
+                    layer_idx=None,
+                    tensor=self._normalize_tensor(self._captured["embed"]),
+                    name="Embedding",
+                )
+
+            # LAYER_OUTPUT for each layer
+            if self.is_enabled(BreakpointType.LAYER_OUTPUT):
+                for layer_idx in range(self.num_layers):
+                    key = f"layer_{layer_idx}"
+                    if key in self._captured:
+                        yield Breakpoint(
+                            bp_type=BreakpointType.LAYER_OUTPUT,
+                            layer_idx=layer_idx,
+                            tensor=self._normalize_tensor(self._captured[key]),
+                            name=f"Layer {layer_idx}",
+                        )
+
+            # FINAL_NORM
+            if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured:
+                yield Breakpoint(
+                    bp_type=BreakpointType.FINAL_NORM,
+                    layer_idx=None,
+                    tensor=self._normalize_tensor(self._captured["final_norm"]),
+                    name="Final Norm",
+                )
+
+            # LM_HEAD
+            if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured:
+                yield Breakpoint(
+                    bp_type=BreakpointType.LM_HEAD,
+                    layer_idx=None,
+                    tensor=self._normalize_tensor(self._captured["lm_head"]),
+                    name="LM Head",
+                )
+
+            return logits
+
+        finally:
+            self._remove_hooks()
+            self._captured = {}
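Note: the `enabled_breakpoints` filter only controls what is yielded; the hooks still fire (and clone) every module output during the forward pass. A usage sketch that yields just the embedding and logits checkpoints, assuming `model` is the nanovllm `Qwen3ForCausalLM`:

```python
from nanovllm.debug import NanovllmSteppable, BreakpointType

steppable = NanovllmSteppable(
    model,
    enabled_breakpoints={BreakpointType.EMBEDDING, BreakpointType.LM_HEAD},
)
for bp in steppable.step(input_ids, is_prefill=True):
    print(bp.name, tuple(bp.tensor.shape))   # only "Embedding" and "LM Head"
```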
diff --git a/nanovllm/debug/adapters/torch_adapter.py b/nanovllm/debug/adapters/torch_adapter.py
new file mode 100644
index 0000000..6e54858
--- /dev/null
+++ b/nanovllm/debug/adapters/torch_adapter.py
@@ -0,0 +1,119 @@
+"""Torch reference model adapter for breakpoint alignment."""
+
+from typing import Generator, Set, Optional
+import torch
+
+from ..breakpoints import Breakpoint, BreakpointType
+from .base import SteppableModel
+
+
+class TorchSteppable(SteppableModel):
+    """
+    Steppable adapter for the torch reference Qwen3 implementation.
+
+    Wraps the tests/modeling_qwen3.py Qwen3ForCausalLM model.
+    """
+
+    def __init__(
+        self,
+        model,
+        enabled_breakpoints: Optional[Set[BreakpointType]] = None,
+    ):
+        """
+        Args:
+            model: Qwen3ForCausalLM from tests/modeling_qwen3.py
+            enabled_breakpoints: Set of breakpoint types to yield at
+        """
+        super().__init__(enabled_breakpoints)
+        self.model = model
+        self.model.eval()
+
+    @property
+    def num_layers(self) -> int:
+        return len(self.model.model.layers)
+
+    def step(
+        self,
+        input_ids: torch.Tensor,
+        positions: Optional[torch.Tensor] = None,
+        is_prefill: bool = True,
+    ) -> Generator[Breakpoint, None, torch.Tensor]:
+        """
+        Generator that manually steps through the torch model.
+
+        The torch model uses [batch, seq_len, hidden_size] shapes.
+        """
+        # Ensure batch dimension
+        if input_ids.dim() == 1:
+            input_ids = input_ids.unsqueeze(0)
+
+        batch_size, seq_len = input_ids.shape
+        device = input_ids.device
+
+        # Generate position IDs if not provided
+        if positions is None:
+            positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
+        elif positions.dim() == 1:
+            positions = positions.unsqueeze(0)
+
+        with torch.no_grad():
+            # === EMBEDDING ===
+            hidden_states = self.model.model.embed_tokens(input_ids)
+
+            if self.is_enabled(BreakpointType.EMBEDDING):
+                yield Breakpoint(
+                    bp_type=BreakpointType.EMBEDDING,
+                    layer_idx=None,
+                    tensor=hidden_states.detach().clone(),
+                    name="Embedding",
+                )
+
+            # Create causal attention mask
+            causal_mask = torch.triu(
+                torch.full((seq_len, seq_len), float("-inf"), device=device),
+                diagonal=1,
+            )
+            attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+
+            # === DECODER LAYERS ===
+            for layer_idx, layer in enumerate(self.model.model.layers):
+                hidden_states, _, _ = layer(
+                    hidden_states=hidden_states,
+                    position_ids=positions,
+                    attention_mask=attention_mask,
+                    past_key_value=None,
+                    use_cache=False,
+                    output_qkv=False,
+                )
+
+                if self.is_enabled(BreakpointType.LAYER_OUTPUT):
+                    yield Breakpoint(
+                        bp_type=BreakpointType.LAYER_OUTPUT,
+                        layer_idx=layer_idx,
+                        tensor=hidden_states.detach().clone(),
+                        name=f"Layer {layer_idx}",
+                    )
+
+            # === FINAL NORM ===
+            hidden_states = self.model.model.norm(hidden_states)
+
+            if self.is_enabled(BreakpointType.FINAL_NORM):
+                yield Breakpoint(
+                    bp_type=BreakpointType.FINAL_NORM,
+                    layer_idx=None,
+                    tensor=hidden_states.detach().clone(),
+                    name="Final Norm",
+                )
+
+            # === LM HEAD ===
+            logits = self.model.lm_head(hidden_states)
+
+            if self.is_enabled(BreakpointType.LM_HEAD):
+                yield Breakpoint(
+                    bp_type=BreakpointType.LM_HEAD,
+                    layer_idx=None,
+                    tensor=logits.detach().clone(),
+                    name="LM Head",
+                )
+
+        return logits
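Note: with both adapters emitting the same breakpoint sequence, a quick manual check is possible without the full aligner. A sketch comparing only the last-position logits (indexing `[:, -1]` on both sides, since nanovllm may compute logits for the final position only); assumes `ref = TorchSteppable(...)` and `test = NanovllmSteppable(...)` wrap the same checkpoint:

```python
from nanovllm.debug import BreakpointType

for r, t in zip(ref.step(input_ids), test.step(input_ids)):
    if r.bp_type == BreakpointType.LM_HEAD:
        rl = r.normalize_shape()[:, -1].float()
        tl = t.normalize_shape()[:, -1].float()
        print(f"last-token logits max abs diff: {(rl - tl).abs().max().item():.3e}")
```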
+ """ + # Ensure 1D for nanovllm (it expects [num_tokens]) + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + + seq_len = input_ids.numel() + device = input_ids.device + + # Generate position IDs if not provided + if positions is None: + positions = torch.arange(seq_len, device=device) + elif positions.dim() == 2: + positions = positions.squeeze(0) + + # Register hooks + self._register_hooks() + + try: + # Setup context for attention + self._setup_context(seq_len, device, is_prefill) + + # Run forward pass (hooks capture everything) + with torch.no_grad(): + hidden_states = self.model(input_ids, positions) + logits = self.model.compute_logits(hidden_states) + + reset_context() + + # Yield breakpoints in order from captured data + + # EMBEDDING + if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured: + yield Breakpoint( + bp_type=BreakpointType.EMBEDDING, + layer_idx=None, + tensor=self._normalize_tensor(self._captured["embed"]), + name="Embedding", + ) + + # LAYER_OUTPUT for each layer + if self.is_enabled(BreakpointType.LAYER_OUTPUT): + for layer_idx in range(self.num_layers): + key = f"layer_{layer_idx}" + if key in self._captured: + yield Breakpoint( + bp_type=BreakpointType.LAYER_OUTPUT, + layer_idx=layer_idx, + tensor=self._normalize_tensor(self._captured[key]), + name=f"Layer {layer_idx}", + ) + + # FINAL_NORM + if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured: + yield Breakpoint( + bp_type=BreakpointType.FINAL_NORM, + layer_idx=None, + tensor=self._normalize_tensor(self._captured["final_norm"]), + name="Final Norm", + ) + + # LM_HEAD + if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured: + yield Breakpoint( + bp_type=BreakpointType.LM_HEAD, + layer_idx=None, + tensor=self._normalize_tensor(self._captured["lm_head"]), + name="LM Head", + ) + + return logits + + finally: + self._remove_hooks() + self._captured = {} diff --git a/nanovllm/debug/adapters/torch_adapter.py b/nanovllm/debug/adapters/torch_adapter.py new file mode 100644 index 0000000..6e54858 --- /dev/null +++ b/nanovllm/debug/adapters/torch_adapter.py @@ -0,0 +1,119 @@ +"""Torch reference model adapter for breakpoint alignment.""" + +from typing import Generator, Set, Optional +import torch + +from ..breakpoints import Breakpoint, BreakpointType +from .base import SteppableModel + + +class TorchSteppable(SteppableModel): + """ + Steppable adapter for the torch reference Qwen3 implementation. + + Wraps tests/modeling_qwen3.py Qwen3ForCausalLM model. + """ + + def __init__( + self, + model, + enabled_breakpoints: Optional[Set[BreakpointType]] = None, + ): + """ + Args: + model: Qwen3ForCausalLM from tests/modeling_qwen3.py + enabled_breakpoints: Set of breakpoint types to yield at + """ + super().__init__(enabled_breakpoints) + self.model = model + self.model.eval() + + @property + def num_layers(self) -> int: + return len(self.model.model.layers) + + def step( + self, + input_ids: torch.Tensor, + positions: Optional[torch.Tensor] = None, + is_prefill: bool = True, + ) -> Generator[Breakpoint, None, torch.Tensor]: + """ + Generator that manually steps through the torch model. + + The torch model uses [batch, seq_len, hidden_size] shapes. 
+ """ + # Ensure batch dimension + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(0) + + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # Generate position IDs if not provided + if positions is None: + positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + elif positions.dim() == 1: + positions = positions.unsqueeze(0) + + with torch.no_grad(): + # === EMBEDDING === + hidden_states = self.model.model.embed_tokens(input_ids) + + if self.is_enabled(BreakpointType.EMBEDDING): + yield Breakpoint( + bp_type=BreakpointType.EMBEDDING, + layer_idx=None, + tensor=hidden_states.detach().clone(), + name="Embedding", + ) + + # Create causal attention mask + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=device), + diagonal=1, + ) + attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) + + # === DECODER LAYERS === + for layer_idx, layer in enumerate(self.model.model.layers): + hidden_states, _, _ = layer( + hidden_states=hidden_states, + position_ids=positions, + attention_mask=attention_mask, + past_key_value=None, + use_cache=False, + output_qkv=False, + ) + + if self.is_enabled(BreakpointType.LAYER_OUTPUT): + yield Breakpoint( + bp_type=BreakpointType.LAYER_OUTPUT, + layer_idx=layer_idx, + tensor=hidden_states.detach().clone(), + name=f"Layer {layer_idx}", + ) + + # === FINAL NORM === + hidden_states = self.model.model.norm(hidden_states) + + if self.is_enabled(BreakpointType.FINAL_NORM): + yield Breakpoint( + bp_type=BreakpointType.FINAL_NORM, + layer_idx=None, + tensor=hidden_states.detach().clone(), + name="Final Norm", + ) + + # === LM HEAD === + logits = self.model.lm_head(hidden_states) + + if self.is_enabled(BreakpointType.LM_HEAD): + yield Breakpoint( + bp_type=BreakpointType.LM_HEAD, + layer_idx=None, + tensor=logits.detach().clone(), + name="LM Head", + ) + + return logits diff --git a/nanovllm/debug/aligner.py b/nanovllm/debug/aligner.py new file mode 100644 index 0000000..7d3d76a --- /dev/null +++ b/nanovllm/debug/aligner.py @@ -0,0 +1,211 @@ +"""Breakpoint aligner for comparing model outputs.""" + +from dataclasses import dataclass, field +from typing import List, Tuple, Optional +import torch + +from .breakpoints import Breakpoint +from .comparator import TensorComparator, ComparisonResult +from .adapters.base import SteppableModel + + +@dataclass +class AlignmentResult: + """Result of an alignment test.""" + passed: bool + all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = field(default_factory=list) + failed_at: Optional[Breakpoint] = None + message: str = "" + + def __repr__(self) -> str: + passed_count = sum(1 for _, _, c in self.all_comparisons if c.passed) + total = len(self.all_comparisons) + status = "PASSED" if self.passed else "FAILED" + return f"AlignmentResult({status}, {passed_count}/{total} breakpoints passed)" + + +class BreakpointAligner: + """ + Orchestrates alternating execution of reference and test models, + comparing outputs at each breakpoint. 
diff --git a/nanovllm/debug/comparator.py b/nanovllm/debug/comparator.py
new file mode 100644
index 0000000..4f3eaa6
--- /dev/null
+++ b/nanovllm/debug/comparator.py
@@ -0,0 +1,94 @@
+"""Tensor comparison utilities for alignment debugging."""
+
+from dataclasses import dataclass
+import torch
+import torch.nn.functional as F
+
+
+@dataclass
+class ComparisonResult:
+    """Result of comparing two tensors."""
+    passed: bool
+    cosine_similarity: float
+    max_abs_diff: float
+    mean_abs_diff: float
+    message: str
+
+    def __repr__(self) -> str:
+        status = "\u2713" if self.passed else "\u2717"
+        return f"{status} cos={self.cosine_similarity:.6f}, max_diff={self.max_abs_diff:.2e}"
+
+
+class TensorComparator:
+    """Compares tensors using cosine similarity and absolute differences."""
+
+    def __init__(
+        self,
+        cosine_threshold: float = 0.999,
+        max_diff_threshold: float = 0.1,
+        mean_diff_threshold: float = 0.01,
+    ):
+        """
+        Args:
+            cosine_threshold: Minimum cosine similarity to pass (0-1)
+            max_diff_threshold: Maximum allowed absolute difference
+            mean_diff_threshold: Maximum allowed mean absolute difference
+        """
+        self.cosine_threshold = cosine_threshold
+        self.max_diff_threshold = max_diff_threshold
+        self.mean_diff_threshold = mean_diff_threshold
+
+    def compare(
+        self,
+        ref: torch.Tensor,
+        test: torch.Tensor,
+        name: str = "",
+    ) -> ComparisonResult:
+        """
+        Compare two tensors and return a detailed result.
+
+        Args:
+            ref: Reference tensor
+            test: Test tensor
+            name: Name for the comparison (used in message)
+
+        Returns:
+            ComparisonResult with pass/fail status and metrics
+        """
+        # Convert to float32 for comparison
+        ref_f = ref.float().flatten()
+        test_f = test.float().flatten()
+
+        # Cosine similarity
+        cos_sim = F.cosine_similarity(
+            ref_f.unsqueeze(0),
+            test_f.unsqueeze(0)
+        ).item()
+
+        # Absolute differences
+        diff = (ref.float() - test.float()).abs()
+        max_diff = diff.max().item()
+        mean_diff = diff.mean().item()
+
+        # Check thresholds
+        passed = (
+            cos_sim >= self.cosine_threshold and
+            max_diff <= self.max_diff_threshold and
+            mean_diff <= self.mean_diff_threshold
+        )
+
+        status = "PASS" if passed else "FAIL"
+        message = (
+            f"[{name}] {status}\n"
+            f"  Cosine Similarity: {cos_sim:.6f} (threshold: {self.cosine_threshold})\n"
+            f"  Max Abs Diff: {max_diff:.6f} (threshold: {self.max_diff_threshold})\n"
+            f"  Mean Abs Diff: {mean_diff:.6f} (threshold: {self.mean_diff_threshold})"
+        )
+
+        return ComparisonResult(
+            passed=passed,
+            cosine_similarity=cos_sim,
+            max_abs_diff=max_diff,
+            mean_abs_diff=mean_diff,
+            message=message,
+        )
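Note: the comparator needs no model, so it can be sanity-checked on synthetic tensors. With the default thresholds, a small perturbation should pass:

```python
import torch
from nanovllm.debug import TensorComparator

ref = torch.randn(1, 8, 16)
test = ref + 1e-3 * torch.randn_like(ref)   # tiny perturbation: should pass
res = TensorComparator().compare(ref, test, name="sanity")
print(res)   # e.g. "✓ cos=1.000000, max_diff=3.98e-03"
assert res.passed
```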
else "\u2717" + print(f"{status} [{ref_bp.name}] cos={result.cosine_similarity:.6f}, max_diff={result.max_abs_diff:.2e}") + + if not result.passed and self.stop_on_error: + if self.verbose: + print(f"\nStopped at {ref_bp.name} (stop_on_error=True)") + print(result.message) + return AlignmentResult( + passed=False, + all_comparisons=all_comparisons, + failed_at=ref_bp, + message=f"Alignment failed at {ref_bp.name}", + ) + + # Check for extra test breakpoints + try: + extra_bp = next(test_gen) + msg = f"Test model has extra breakpoints starting at {extra_bp.name}" + if self.verbose: + print(msg) + return AlignmentResult( + passed=False, + all_comparisons=all_comparisons, + message=msg, + ) + except StopIteration: + pass + + except Exception as e: + msg = f"Exception during alignment: {str(e)}" + if self.verbose: + print(msg) + raise + + # Summary + all_passed = all(comp[2].passed for comp in all_comparisons) + passed_count = sum(1 for _, _, c in all_comparisons if c.passed) + total = len(all_comparisons) + + if self.verbose: + print(f"{'='*60}") + status = "PASSED" if all_passed else "FAILED" + print(f"Result: {status} ({passed_count}/{total} breakpoints)") + print(f"{'='*60}\n") + + return AlignmentResult( + passed=all_passed, + all_comparisons=all_comparisons, + message="All breakpoints aligned" if all_passed else "Some breakpoints failed", + ) diff --git a/nanovllm/debug/breakpoints.py b/nanovllm/debug/breakpoints.py new file mode 100644 index 0000000..1cece8f --- /dev/null +++ b/nanovllm/debug/breakpoints.py @@ -0,0 +1,39 @@ +"""Breakpoint types and data structures for alignment debugging.""" + +from dataclasses import dataclass +from enum import Enum, auto +from typing import Optional +import torch + + +class BreakpointType(Enum): + """Types of breakpoints in the model forward pass.""" + EMBEDDING = auto() # After embed_tokens + LAYER_OUTPUT = auto() # After each decoder layer + FINAL_NORM = auto() # After final RMSNorm + LM_HEAD = auto() # After lm_head (logits) + + +@dataclass +class Breakpoint: + """A captured breakpoint with tensor data.""" + bp_type: BreakpointType + layer_idx: Optional[int] # None for EMBEDDING, FINAL_NORM, LM_HEAD + tensor: torch.Tensor + name: str + + def normalize_shape(self) -> torch.Tensor: + """ + Normalize tensor shape for comparison. + + nanovllm uses [num_tokens, hidden_size] while torch uses + [batch, seq_len, hidden_size]. This adds a batch dimension + to 2D tensors for comparison. 
+ """ + if self.tensor.dim() == 2: + return self.tensor.unsqueeze(0) + return self.tensor + + def __repr__(self) -> str: + shape_str = "x".join(str(d) for d in self.tensor.shape) + return f"Breakpoint({self.name}, shape={shape_str}, dtype={self.tensor.dtype})" diff --git a/nanovllm/debug/comparator.py b/nanovllm/debug/comparator.py new file mode 100644 index 0000000..4f3eaa6 --- /dev/null +++ b/nanovllm/debug/comparator.py @@ -0,0 +1,94 @@ +"""Tensor comparison utilities for alignment debugging.""" + +from dataclasses import dataclass +import torch +import torch.nn.functional as F + + +@dataclass +class ComparisonResult: + """Result of comparing two tensors.""" + passed: bool + cosine_similarity: float + max_abs_diff: float + mean_abs_diff: float + message: str + + def __repr__(self) -> str: + status = "\u2713" if self.passed else "\u2717" + return f"{status} cos={self.cosine_similarity:.6f}, max_diff={self.max_abs_diff:.2e}" + + +class TensorComparator: + """Compares tensors using cosine similarity and absolute differences.""" + + def __init__( + self, + cosine_threshold: float = 0.999, + max_diff_threshold: float = 0.1, + mean_diff_threshold: float = 0.01, + ): + """ + Args: + cosine_threshold: Minimum cosine similarity to pass (0-1) + max_diff_threshold: Maximum allowed absolute difference + mean_diff_threshold: Maximum allowed mean absolute difference + """ + self.cosine_threshold = cosine_threshold + self.max_diff_threshold = max_diff_threshold + self.mean_diff_threshold = mean_diff_threshold + + def compare( + self, + ref: torch.Tensor, + test: torch.Tensor, + name: str = "", + ) -> ComparisonResult: + """ + Compare two tensors and return detailed result. + + Args: + ref: Reference tensor + test: Test tensor + name: Name for the comparison (used in message) + + Returns: + ComparisonResult with pass/fail status and metrics + """ + # Convert to float32 for comparison + ref_f = ref.float().flatten() + test_f = test.float().flatten() + + # Cosine similarity + cos_sim = F.cosine_similarity( + ref_f.unsqueeze(0), + test_f.unsqueeze(0) + ).item() + + # Absolute differences + diff = (ref.float() - test.float()).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + # Check thresholds + passed = ( + cos_sim >= self.cosine_threshold and + max_diff <= self.max_diff_threshold and + mean_diff <= self.mean_diff_threshold + ) + + status = "PASS" if passed else "FAIL" + message = ( + f"[{name}] {status}\n" + f" Cosine Similarity: {cos_sim:.6f} (threshold: {self.cosine_threshold})\n" + f" Max Abs Diff: {max_diff:.6f} (threshold: {self.max_diff_threshold})\n" + f" Mean Abs Diff: {mean_diff:.6f} (threshold: {self.mean_diff_threshold})" + ) + + return ComparisonResult( + passed=passed, + cosine_similarity=cos_sim, + max_abs_diff=max_diff, + mean_abs_diff=mean_diff, + message=message, + ) diff --git a/nanovllm/debug/utils.py b/nanovllm/debug/utils.py new file mode 100644 index 0000000..87d771c --- /dev/null +++ b/nanovllm/debug/utils.py @@ -0,0 +1,51 @@ +"""Utility functions for breakpoint alignment debugging.""" + +import torch +from nanovllm.utils.context import set_context, reset_context + + +def setup_prefill_context(seq_len: int, device: torch.device): + """ + Set up nanovllm context for prefill alignment testing. 
diff --git a/tests/test_needle.py b/tests/test_needle.py
index 006a3d9..4f8661a 100644
--- a/tests/test_needle.py
+++ b/tests/test_needle.py
@@ -8,7 +8,7 @@ sequences longer than ~200 tokens. Use --no-offload for correctness testing.
 """
 
 import os
-os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
+os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
 
 import argparse
 from nanovllm import LLM, SamplingParams
diff --git a/tests/test_torch_steppable.py b/tests/test_torch_steppable.py
new file mode 100644
index 0000000..63c38a7
--- /dev/null
+++ b/tests/test_torch_steppable.py
@@ -0,0 +1,121 @@
+"""
+Test TorchSteppable: Print activation statistics at each layer.
+
+Usage:
+    python tests/test_torch_steppable.py
+"""
+
+import os
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent))
+
+import torch
+from transformers import AutoTokenizer
+from modeling_qwen3 import Qwen3ForCausalLM
+from nanovllm.debug.adapters.torch_adapter import TorchSteppable
+from utils import generate_needle_prompt, check_needle_answer
+
+# ============================================================
+# Config
+# ============================================================
+MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
+INPUT_LEN = 512
+MAX_NEW_TOKENS = 20
+DTYPE = torch.float16
+
+# ============================================================
+# Load Model
+# ============================================================
+print("Loading model...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
+model = model.to("cuda").eval()
+
+# ============================================================
+# Prepare Input (using needle-in-haystack prompt)
+# ============================================================
+prompt, expected_answer = generate_needle_prompt(
+    tokenizer,
+    target_length=INPUT_LEN,
+    needle_position=0.5,
+    needle_value="7492",
+    use_chat_template=False,
+    verbose=True,
+)
+input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
+print(f"Input shape: {input_ids.shape}")
+print(f"Expected answer: {expected_answer}\n")
+
+# ============================================================
+# Create Steppable Model (reused for prefill + decode)
+# ============================================================
+steppable = TorchSteppable(model)
+# ============================================================
+# Prefill Phase: Print activation stats
+# ============================================================
+print("=" * 85)
+print("PREFILL PHASE")
+print("=" * 85)
+print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
+print("-" * 85)
+
+current_ids = input_ids.clone()
+logits = None
+
+for bp in steppable.step(current_ids):
+    t = bp.tensor.float()
+    shape_str = str(list(t.shape))
+    print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
+    if bp.name == "LM Head":
+        logits = bp.tensor
+
+# Get first token from prefill
+next_token_id = logits[0, -1].argmax().item()
+next_token = tokenizer.decode(next_token_id)
+current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
+generated_tokens = [next_token]
+
+# ============================================================
+# Decode Phase: Only print generated tokens
+# ============================================================
+print("\n" + "=" * 85)
+print("DECODE PHASE")
+print("=" * 85)
+print(f"Step 1: {next_token!r}")
+
+for step in range(2, MAX_NEW_TOKENS + 1):
+    # Forward pass (reuse same steppable)
+    for bp in steppable.step(current_ids):
+        if bp.name == "LM Head":
+            logits = bp.tensor
+
+    # Get next token (greedy)
+    next_token_id = logits[0, -1].argmax().item()
+    next_token = tokenizer.decode(next_token_id)
+    generated_tokens.append(next_token)
+
+    print(f"Step {step:2}: {next_token!r}")
+
+    # Stop if EOS
+    if next_token_id == tokenizer.eos_token_id:
+        print("  (EOS)")
+        break
+
+    # Append to sequence
+    current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
+# ============================================================
+# Result
+# ============================================================
+print("\n" + "=" * 85)
+print("RESULT")
+print("=" * 85)
+generated_text = "".join(generated_tokens)
+print(f"Generated: {generated_text!r}")
+print(f"Expected: {expected_answer}")
+print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")
+
+print("\ntest_torch_steppable: PASSED")
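Note: the two scripts above exercise each adapter in isolation; the intended end-to-end check wires them into the aligner. A sketch run from `tests/` (config mirrors the scripts; the random `input_ids` are illustrative, and if nanovllm computes logits for the last position only, the LM Head breakpoint may need to be excluded via `enabled_breakpoints`):

```python
import os
import torch
from nanovllm import LLM
from modeling_qwen3 import Qwen3ForCausalLM
from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable

MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
llm = LLM(MODEL_PATH, enforce_eager=True, dtype="float16")   # eager mode so hooks fire
torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=torch.float16).to("cuda").eval()

aligner = BreakpointAligner(TorchSteppable(torch_model), NanovllmSteppable(llm.model_runner.model))
input_ids = torch.randint(0, 1000, (1, 64), device="cuda")
result = aligner.align(input_ids, is_prefill=True)
print(result)   # e.g. AlignmentResult(PASSED, 31/31 breakpoints passed)
```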