[feat] Add debug tools for breakpoint-based model alignment.
49
nanovllm/debug/__init__.py
Normal file
@@ -0,0 +1,49 @@
"""
Breakpoint debugging tools for aligning nanovllm with reference implementations.

This module provides a generator-based breakpoint aligner that enables step-by-step
comparison between nanovllm and torch reference model outputs.

Example:
    >>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
    >>> from tests.modeling_qwen3 import Qwen3ForCausalLM
    >>>
    >>> # Load models
    >>> torch_model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch.float16)
    >>> nanovllm_model = ...  # Your nanovllm model
    >>>
    >>> # Create adapters
    >>> ref = TorchSteppable(torch_model)
    >>> test = NanovllmSteppable(nanovllm_model)
    >>>
    >>> # Run alignment
    >>> aligner = BreakpointAligner(ref, test)
    >>> result = aligner.align(input_ids)
    >>> print(result)
"""

from .breakpoints import BreakpointType, Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .aligner import BreakpointAligner, AlignmentResult
from .adapters import SteppableModel, TorchSteppable, NanovllmSteppable
from .utils import setup_prefill_context, setup_decode_context, cleanup_context

__all__ = [
    # Core classes
    "BreakpointAligner",
    "AlignmentResult",
    # Breakpoints
    "BreakpointType",
    "Breakpoint",
    # Comparator
    "TensorComparator",
    "ComparisonResult",
    # Adapters
    "SteppableModel",
    "TorchSteppable",
    "NanovllmSteppable",
    # Utils
    "setup_prefill_context",
    "setup_decode_context",
    "cleanup_context",
]
11
nanovllm/debug/adapters/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""Model adapters for breakpoint alignment."""

from .base import SteppableModel
from .torch_adapter import TorchSteppable
from .nanovllm_adapter import NanovllmSteppable

__all__ = [
    "SteppableModel",
    "TorchSteppable",
    "NanovllmSteppable",
]
59
nanovllm/debug/adapters/base.py
Normal file
@@ -0,0 +1,59 @@
"""Base class for steppable model adapters."""

from abc import ABC, abstractmethod
from typing import Generator, Set, Optional
import torch

from ..breakpoints import Breakpoint, BreakpointType


class SteppableModel(ABC):
    """
    Abstract base class for models that can yield at breakpoints.

    Subclasses implement the step() method as a generator that yields
    Breakpoint objects at each enabled breakpoint during the forward pass.
    """

    def __init__(self, enabled_breakpoints: Optional[Set[BreakpointType]] = None):
        """
        Args:
            enabled_breakpoints: Set of breakpoint types to yield at.
                If None, yields at all breakpoints.
        """
        self.enabled_breakpoints = enabled_breakpoints

    def is_enabled(self, bp_type: BreakpointType) -> bool:
        """Check if a breakpoint type is enabled."""
        if self.enabled_breakpoints is None:
            return True
        return bp_type in self.enabled_breakpoints

    @abstractmethod
    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Generator that yields Breakpoint objects at enabled breakpoints.

        Args:
            input_ids: Input token IDs
            positions: Position IDs (optional, auto-generated if None)
            is_prefill: True for prefill phase, False for decode

        Yields:
            Breakpoint objects at each enabled checkpoint

        Returns:
            Final output tensor (logits)
        """
        pass

    @property
    @abstractmethod
    def num_layers(self) -> int:
        """Return the number of decoder layers."""
        pass
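Because step() is typed Generator[Breakpoint, None, torch.Tensor], the final logits are returned via StopIteration.value rather than yielded. A minimal sketch of driving any concrete adapter by hand; the drive helper is illustrative, not part of this commit:

import torch
from nanovllm.debug.adapters.base import SteppableModel

def drive(adapter: SteppableModel, input_ids: torch.Tensor) -> torch.Tensor:
    """Consume step() manually; the final logits ride on StopIteration.value."""
    gen = adapter.step(input_ids, is_prefill=True)
    while True:
        try:
            bp = next(gen)        # one Breakpoint per enabled checkpoint
            print(bp)             # e.g. Breakpoint(Layer 0, shape=1x8x1024, ...)
        except StopIteration as stop:
            return stop.value     # the generator's return value: the logits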
229
nanovllm/debug/adapters/nanovllm_adapter.py
Normal file
@@ -0,0 +1,229 @@
"""Nanovllm model adapter for breakpoint alignment."""

from typing import Generator, Set, Optional, Dict, Any, List
import torch

from nanovllm.utils.context import set_context, reset_context
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel


class NanovllmSteppable(SteppableModel):
    """
    Steppable adapter for the nanovllm Qwen3 implementation.

    Uses PyTorch hooks to capture intermediate values during the forward pass,
    then yields them as breakpoints after execution completes.

    Key challenges handled:
    1. Shape difference: nanovllm uses [num_tokens, hidden] vs [batch, seq, hidden]
    2. Context-based attention: must call set_context() before forward
    3. Fused operations: the decoder layer returns a (hidden_states, residual) tuple
    """

    def __init__(
        self,
        model,
        enabled_breakpoints: Optional[Set[BreakpointType]] = None,
    ):
        """
        Args:
            model: Qwen3ForCausalLM from nanovllm
            enabled_breakpoints: Set of breakpoint types to yield at
        """
        super().__init__(enabled_breakpoints)
        self.model = model
        self.model.eval()
        self._hooks: List[Any] = []
        self._captured: Dict[str, torch.Tensor] = {}

    @property
    def num_layers(self) -> int:
        return len(self.model.model.layers)

    def _register_hooks(self):
        """Register forward hooks on all relevant modules."""
        self._hooks = []
        self._captured = {}

        # Hook for embedding output
        def embed_hook(module, input, output):
            self._captured["embed"] = output.detach().clone()

        self._hooks.append(
            self.model.model.embed_tokens.register_forward_hook(embed_hook)
        )

        # Hooks for each decoder layer
        for layer_idx in range(self.num_layers):
            layer = self.model.model.layers[layer_idx]

            def make_layer_hook(idx):
                def hook(module, input, output):
                    # Decoder layer returns (hidden_states, residual)
                    hidden_states = output[0] if isinstance(output, tuple) else output
                    self._captured[f"layer_{idx}"] = hidden_states.detach().clone()
                return hook

            self._hooks.append(
                layer.register_forward_hook(make_layer_hook(layer_idx))
            )

        # Hook for final norm
        def final_norm_hook(module, input, output):
            # Final norm returns (hidden_states, _) for fused add
            hidden_states = output[0] if isinstance(output, tuple) else output
            self._captured["final_norm"] = hidden_states.detach().clone()

        self._hooks.append(
            self.model.model.norm.register_forward_hook(final_norm_hook)
        )

        # Hook for lm_head
        def lm_head_hook(module, input, output):
            self._captured["lm_head"] = output.detach().clone()

        self._hooks.append(
            self.model.lm_head.register_forward_hook(lm_head_hook)
        )

    def _remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def _setup_context(self, seq_len: int, device: torch.device, is_prefill: bool):
        """
        Set up the nanovllm context for the forward pass.

        For alignment testing, we use a simple context without a real KV cache.
        """
        if is_prefill:
            # Prefill: process all tokens at once
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
            # Use -1 for slot_mapping to skip KV cache writes
            slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)

            set_context(
                is_prefill=True,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                slot_mapping=slot_mapping,
                is_chunked_prefill=False,
            )
        else:
            # Decode: single-token generation.
            # Decode needs context_lens and block_tables; for alignment testing
            # without a real KV cache, we use a minimal setup.
            context_lens = torch.tensor([seq_len - 1], dtype=torch.int32, device=device)
            # Single token slot
            slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
            # Empty block tables (no KV cache)
            block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)

            set_context(
                is_prefill=False,
                slot_mapping=slot_mapping,
                context_lens=context_lens,
                block_tables=block_tables,
            )

    def _normalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Normalize a nanovllm tensor shape to [batch, seq_len, ...].

        nanovllm uses a [num_tokens, ...] format without a batch dimension.
        We add the batch dimension for comparison with the torch model.
        """
        if tensor.dim() == 2:  # [num_tokens, hidden_size]
            return tensor.unsqueeze(0)
        return tensor

    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Execute the nanovllm forward pass with hooks to capture breakpoints.

        Unlike the torch adapter, which manually steps through each component,
        we run the full forward pass and collect the captured values afterward.
        """
        # Ensure 1D for nanovllm (it expects [num_tokens])
        if input_ids.dim() == 2:
            input_ids = input_ids.squeeze(0)

        seq_len = input_ids.numel()
        device = input_ids.device

        # Generate position IDs if not provided
        if positions is None:
            positions = torch.arange(seq_len, device=device)
        elif positions.dim() == 2:
            positions = positions.squeeze(0)

        # Register hooks
        self._register_hooks()

        try:
            # Set up context for attention
            self._setup_context(seq_len, device, is_prefill)

            # Run forward pass (hooks capture everything)
            with torch.no_grad():
                hidden_states = self.model(input_ids, positions)
                logits = self.model.compute_logits(hidden_states)

            reset_context()

            # Yield breakpoints in order from captured data

            # EMBEDDING
            if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.EMBEDDING,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["embed"]),
                    name="Embedding",
                )

            # LAYER_OUTPUT for each layer
            if self.is_enabled(BreakpointType.LAYER_OUTPUT):
                for layer_idx in range(self.num_layers):
                    key = f"layer_{layer_idx}"
                    if key in self._captured:
                        yield Breakpoint(
                            bp_type=BreakpointType.LAYER_OUTPUT,
                            layer_idx=layer_idx,
                            tensor=self._normalize_tensor(self._captured[key]),
                            name=f"Layer {layer_idx}",
                        )

            # FINAL_NORM
            if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.FINAL_NORM,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["final_norm"]),
                    name="Final Norm",
                )

            # LM_HEAD
            if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.LM_HEAD,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["lm_head"]),
                    name="LM Head",
                )

            return logits

        finally:
            self._remove_hooks()
            self._captured = {}
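The enabled_breakpoints argument filters which checkpoints the generator yields. A sketch that inspects only per-layer outputs; it assumes a loaded nanovllm model and 1D input_ids on CUDA, with the engine built under enforce_eager=True so the forward hooks fire:

from nanovllm.debug import NanovllmSteppable, BreakpointType

# Yield only LAYER_OUTPUT breakpoints; EMBEDDING, FINAL_NORM, LM_HEAD are skipped
steppable = NanovllmSteppable(model, enabled_breakpoints={BreakpointType.LAYER_OUTPUT})
for bp in steppable.step(input_ids, is_prefill=True):
    t = bp.tensor.float()
    print(f"layer {bp.layer_idx}: shape={tuple(bp.tensor.shape)}, mean={t.mean():.4f}")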
119
nanovllm/debug/adapters/torch_adapter.py
Normal file
@@ -0,0 +1,119 @@
"""Torch reference model adapter for breakpoint alignment."""

from typing import Generator, Set, Optional
import torch

from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel


class TorchSteppable(SteppableModel):
    """
    Steppable adapter for the torch reference Qwen3 implementation.

    Wraps the Qwen3ForCausalLM model from tests/modeling_qwen3.py.
    """

    def __init__(
        self,
        model,
        enabled_breakpoints: Optional[Set[BreakpointType]] = None,
    ):
        """
        Args:
            model: Qwen3ForCausalLM from tests/modeling_qwen3.py
            enabled_breakpoints: Set of breakpoint types to yield at
        """
        super().__init__(enabled_breakpoints)
        self.model = model
        self.model.eval()

    @property
    def num_layers(self) -> int:
        return len(self.model.model.layers)

    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Generator that manually steps through the torch model.

        The torch model uses [batch, seq_len, hidden_size] shapes.
        """
        # Ensure batch dimension
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Generate position IDs if not provided
        if positions is None:
            positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        elif positions.dim() == 1:
            positions = positions.unsqueeze(0)

        with torch.no_grad():
            # === EMBEDDING ===
            hidden_states = self.model.model.embed_tokens(input_ids)

            if self.is_enabled(BreakpointType.EMBEDDING):
                yield Breakpoint(
                    bp_type=BreakpointType.EMBEDDING,
                    layer_idx=None,
                    tensor=hidden_states.detach().clone(),
                    name="Embedding",
                )

            # Create causal attention mask
            causal_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=device),
                diagonal=1,
            )
            attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)

            # === DECODER LAYERS ===
            for layer_idx, layer in enumerate(self.model.model.layers):
                hidden_states, _, _ = layer(
                    hidden_states=hidden_states,
                    position_ids=positions,
                    attention_mask=attention_mask,
                    past_key_value=None,
                    use_cache=False,
                    output_qkv=False,
                )

                if self.is_enabled(BreakpointType.LAYER_OUTPUT):
                    yield Breakpoint(
                        bp_type=BreakpointType.LAYER_OUTPUT,
                        layer_idx=layer_idx,
                        tensor=hidden_states.detach().clone(),
                        name=f"Layer {layer_idx}",
                    )

            # === FINAL NORM ===
            hidden_states = self.model.model.norm(hidden_states)

            if self.is_enabled(BreakpointType.FINAL_NORM):
                yield Breakpoint(
                    bp_type=BreakpointType.FINAL_NORM,
                    layer_idx=None,
                    tensor=hidden_states.detach().clone(),
                    name="Final Norm",
                )

            # === LM HEAD ===
            logits = self.model.lm_head(hidden_states)

            if self.is_enabled(BreakpointType.LM_HEAD):
                yield Breakpoint(
                    bp_type=BreakpointType.LM_HEAD,
                    layer_idx=None,
                    tensor=logits.detach().clone(),
                    name="LM Head",
                )

        return logits
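For reference, this is the additive causal mask the adapter builds before the decoder loop, shown concretely for seq_len = 4 (a standalone sketch, safe to run on CPU):

import torch

seq_len = 4
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")),
    diagonal=1,
)
print(causal_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
# unsqueeze twice -> [1, 1, seq, seq], broadcast over batch and heads
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)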
211
nanovllm/debug/aligner.py
Normal file
@@ -0,0 +1,211 @@
"""Breakpoint aligner for comparing model outputs."""

from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import torch

from .breakpoints import Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .adapters.base import SteppableModel


@dataclass
class AlignmentResult:
    """Result of an alignment test."""
    passed: bool
    all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = field(default_factory=list)
    failed_at: Optional[Breakpoint] = None
    message: str = ""

    def __repr__(self) -> str:
        passed_count = sum(1 for _, _, c in self.all_comparisons if c.passed)
        total = len(self.all_comparisons)
        status = "PASSED" if self.passed else "FAILED"
        return f"AlignmentResult({status}, {passed_count}/{total} breakpoints passed)"


class BreakpointAligner:
    """
    Orchestrates alternating execution of reference and test models,
    comparing outputs at each breakpoint.

    Example:
        >>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
        >>> ref = TorchSteppable(torch_model)
        >>> test = NanovllmSteppable(nanovllm_model)
        >>> aligner = BreakpointAligner(ref, test)
        >>> result = aligner.align(input_ids)
        >>> print(result)
    """

    def __init__(
        self,
        ref_model: SteppableModel,
        test_model: SteppableModel,
        comparator: Optional[TensorComparator] = None,
        stop_on_error: bool = True,
        verbose: bool = True,
    ):
        """
        Args:
            ref_model: Reference (torch) steppable model
            test_model: Test (nanovllm) steppable model
            comparator: Tensor comparator instance (uses default if None)
            stop_on_error: If True, stop at the first mismatch
            verbose: If True, print comparison results
        """
        self.ref_model = ref_model
        self.test_model = test_model
        self.comparator = comparator or TensorComparator()
        self.stop_on_error = stop_on_error
        self.verbose = verbose

    def align(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> AlignmentResult:
        """
        Run both models with the same input, comparing at each breakpoint.

        Args:
            input_ids: Input token IDs
            positions: Position IDs (optional)
            is_prefill: True for prefill phase, False for decode

        Returns:
            AlignmentResult with pass/fail status and details
        """
        all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = []

        if self.verbose:
            phase = "prefill" if is_prefill else "decode"
            print(f"\n{'='*60}")
            print(f"Alignment Test ({phase})")
            print(f"{'='*60}")

        # Start both generators
        ref_gen = self.ref_model.step(input_ids, positions, is_prefill)
        test_gen = self.test_model.step(input_ids, positions, is_prefill)

        try:
            while True:
                # Get the next breakpoint from the reference
                try:
                    ref_bp = next(ref_gen)
                except StopIteration:
                    break

                # Get the corresponding breakpoint from the test model
                try:
                    test_bp = next(test_gen)
                except StopIteration:
                    if self.verbose:
                        print(f"Test model ended early at {ref_bp.name}")
                    return AlignmentResult(
                        passed=False,
                        all_comparisons=all_comparisons,
                        failed_at=ref_bp,
                        message=f"Test model ended early at {ref_bp.name}",
                    )

                # Verify breakpoints match
                if ref_bp.bp_type != test_bp.bp_type:
                    msg = f"Breakpoint type mismatch: {ref_bp.bp_type} vs {test_bp.bp_type}"
                    if self.verbose:
                        print(msg)
                    return AlignmentResult(
                        passed=False,
                        all_comparisons=all_comparisons,
                        failed_at=ref_bp,
                        message=msg,
                    )

                if ref_bp.layer_idx != test_bp.layer_idx:
                    msg = f"Layer index mismatch: {ref_bp.layer_idx} vs {test_bp.layer_idx}"
                    if self.verbose:
                        print(msg)
                    return AlignmentResult(
                        passed=False,
                        all_comparisons=all_comparisons,
                        failed_at=ref_bp,
                        message=msg,
                    )

                # Normalize shapes for comparison
                ref_t = ref_bp.normalize_shape()
                test_t = test_bp.normalize_shape()

                # Handle shape mismatches
                if ref_t.shape != test_t.shape:
                    if self.verbose:
                        print(f"[{ref_bp.name}] Shape mismatch: ref={ref_t.shape} vs test={test_t.shape}")

                    # Try to reshape if the element count matches
                    if ref_t.numel() == test_t.numel():
                        test_t = test_t.view(ref_t.shape)
                    else:
                        msg = f"Shape mismatch at {ref_bp.name}: {ref_t.shape} vs {test_t.shape}"
                        return AlignmentResult(
                            passed=False,
                            all_comparisons=all_comparisons,
                            failed_at=ref_bp,
                            message=msg,
                        )

                # Compare tensors
                result = self.comparator.compare(ref_t, test_t, ref_bp.name)
                all_comparisons.append((ref_bp, test_bp, result))

                if self.verbose:
                    status = "\u2713" if result.passed else "\u2717"
                    print(f"{status} [{ref_bp.name}] cos={result.cosine_similarity:.6f}, max_diff={result.max_abs_diff:.2e}")

                if not result.passed and self.stop_on_error:
                    if self.verbose:
                        print(f"\nStopped at {ref_bp.name} (stop_on_error=True)")
                        print(result.message)
                    return AlignmentResult(
                        passed=False,
                        all_comparisons=all_comparisons,
                        failed_at=ref_bp,
                        message=f"Alignment failed at {ref_bp.name}",
                    )

            # Check for extra test breakpoints
            try:
                extra_bp = next(test_gen)
                msg = f"Test model has extra breakpoints starting at {extra_bp.name}"
                if self.verbose:
                    print(msg)
                return AlignmentResult(
                    passed=False,
                    all_comparisons=all_comparisons,
                    message=msg,
                )
            except StopIteration:
                pass

        except Exception as e:
            msg = f"Exception during alignment: {str(e)}"
            if self.verbose:
                print(msg)
            raise

        # Summary
        all_passed = all(comp[2].passed for comp in all_comparisons)
        passed_count = sum(1 for _, _, c in all_comparisons if c.passed)
        total = len(all_comparisons)

        if self.verbose:
            print(f"{'='*60}")
            status = "PASSED" if all_passed else "FAILED"
            print(f"Result: {status} ({passed_count}/{total} breakpoints)")
            print(f"{'='*60}\n")

        return AlignmentResult(
            passed=all_passed,
            all_comparisons=all_comparisons,
            message="All breakpoints aligned" if all_passed else "Some breakpoints failed",
        )
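With stop_on_error=False the aligner records every breakpoint rather than bailing on the first mismatch, so the first divergence can be pulled out of all_comparisons afterward. A sketch, assuming ref and test adapters and input_ids already exist:

from nanovllm.debug import BreakpointAligner

aligner = BreakpointAligner(ref, test, stop_on_error=False, verbose=False)
result = aligner.align(input_ids)

# all_comparisons holds (ref_breakpoint, test_breakpoint, ComparisonResult) triples
for ref_bp, _test_bp, cmp in result.all_comparisons:
    if not cmp.passed:
        print(f"First divergence at {ref_bp.name}:\n{cmp.message}")
        break
else:
    print("All breakpoints aligned")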
39
nanovllm/debug/breakpoints.py
Normal file
@@ -0,0 +1,39 @@
"""Breakpoint types and data structures for alignment debugging."""

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional
import torch


class BreakpointType(Enum):
    """Types of breakpoints in the model forward pass."""
    EMBEDDING = auto()      # After embed_tokens
    LAYER_OUTPUT = auto()   # After each decoder layer
    FINAL_NORM = auto()     # After final RMSNorm
    LM_HEAD = auto()        # After lm_head (logits)


@dataclass
class Breakpoint:
    """A captured breakpoint with tensor data."""
    bp_type: BreakpointType
    layer_idx: Optional[int]  # None for EMBEDDING, FINAL_NORM, LM_HEAD
    tensor: torch.Tensor
    name: str

    def normalize_shape(self) -> torch.Tensor:
        """
        Normalize tensor shape for comparison.

        nanovllm uses [num_tokens, hidden_size] while torch uses
        [batch, seq_len, hidden_size]. This adds a batch dimension
        to 2D tensors for comparison.
        """
        if self.tensor.dim() == 2:
            return self.tensor.unsqueeze(0)
        return self.tensor

    def __repr__(self) -> str:
        shape_str = "x".join(str(d) for d in self.tensor.shape)
        return f"Breakpoint({self.name}, shape={shape_str}, dtype={self.tensor.dtype})"
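normalize_shape() is what lets the two layouts meet in the middle: a 2D nanovllm-style tensor gains a batch dimension, while a 3D torch-style tensor passes through unchanged. A standalone sketch:

import torch
from nanovllm.debug.breakpoints import Breakpoint, BreakpointType

bp = Breakpoint(
    bp_type=BreakpointType.EMBEDDING,
    layer_idx=None,
    tensor=torch.randn(8, 1024),  # nanovllm-style [num_tokens, hidden_size]
    name="Embedding",
)
print(bp.normalize_shape().shape)  # torch.Size([1, 8, 1024])
print(bp)                          # Breakpoint(Embedding, shape=8x1024, dtype=torch.float32)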
94
nanovllm/debug/comparator.py
Normal file
@@ -0,0 +1,94 @@
"""Tensor comparison utilities for alignment debugging."""

from dataclasses import dataclass
import torch
import torch.nn.functional as F


@dataclass
class ComparisonResult:
    """Result of comparing two tensors."""
    passed: bool
    cosine_similarity: float
    max_abs_diff: float
    mean_abs_diff: float
    message: str

    def __repr__(self) -> str:
        status = "\u2713" if self.passed else "\u2717"
        return f"{status} cos={self.cosine_similarity:.6f}, max_diff={self.max_abs_diff:.2e}"


class TensorComparator:
    """Compares tensors using cosine similarity and absolute differences."""

    def __init__(
        self,
        cosine_threshold: float = 0.999,
        max_diff_threshold: float = 0.1,
        mean_diff_threshold: float = 0.01,
    ):
        """
        Args:
            cosine_threshold: Minimum cosine similarity to pass (0-1)
            max_diff_threshold: Maximum allowed absolute difference
            mean_diff_threshold: Maximum allowed mean absolute difference
        """
        self.cosine_threshold = cosine_threshold
        self.max_diff_threshold = max_diff_threshold
        self.mean_diff_threshold = mean_diff_threshold

    def compare(
        self,
        ref: torch.Tensor,
        test: torch.Tensor,
        name: str = "",
    ) -> ComparisonResult:
        """
        Compare two tensors and return a detailed result.

        Args:
            ref: Reference tensor
            test: Test tensor
            name: Name for the comparison (used in the message)

        Returns:
            ComparisonResult with pass/fail status and metrics
        """
        # Convert to float32 for comparison
        ref_f = ref.float().flatten()
        test_f = test.float().flatten()

        # Cosine similarity
        cos_sim = F.cosine_similarity(
            ref_f.unsqueeze(0),
            test_f.unsqueeze(0)
        ).item()

        # Absolute differences
        diff = (ref.float() - test.float()).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        # Check thresholds
        passed = (
            cos_sim >= self.cosine_threshold and
            max_diff <= self.max_diff_threshold and
            mean_diff <= self.mean_diff_threshold
        )

        status = "PASS" if passed else "FAIL"
        message = (
            f"[{name}] {status}\n"
            f"  Cosine Similarity: {cos_sim:.6f} (threshold: {self.cosine_threshold})\n"
            f"  Max Abs Diff: {max_diff:.6f} (threshold: {self.max_diff_threshold})\n"
            f"  Mean Abs Diff: {mean_diff:.6f} (threshold: {self.mean_diff_threshold})"
        )

        return ComparisonResult(
            passed=passed,
            cosine_similarity=cos_sim,
            max_abs_diff=max_diff,
            mean_abs_diff=mean_diff,
            message=message,
        )
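Thresholds are set per comparator instance, so a looser comparator can be built where some numeric drift is expected (e.g. fp16 logits). A standalone sketch; the tolerance values here are illustrative assumptions, not recommendations from this commit:

import torch
from nanovllm.debug.comparator import TensorComparator

# Loosened tolerances (assumed values for illustration only)
comparator = TensorComparator(
    cosine_threshold=0.99,
    max_diff_threshold=0.5,
    mean_diff_threshold=0.05,
)

ref = torch.randn(1, 8, 1024)
test = ref + 1e-3 * torch.randn_like(ref)  # small synthetic perturbation
print(comparator.compare(ref, test, name="demo"))
# e.g. ✓ cos=1.000000, max_diff=4.50e-03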
51
nanovllm/debug/utils.py
Normal file
@@ -0,0 +1,51 @@
"""Utility functions for breakpoint alignment debugging."""

import torch
from nanovllm.utils.context import set_context, reset_context


def setup_prefill_context(seq_len: int, device: torch.device):
    """
    Set up nanovllm context for prefill alignment testing.

    Args:
        seq_len: Sequence length
        device: Target device
    """
    cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
    slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)

    set_context(
        is_prefill=True,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        slot_mapping=slot_mapping,
        is_chunked_prefill=False,
    )


def setup_decode_context(context_len: int, device: torch.device):
    """
    Set up nanovllm context for decode alignment testing.

    Args:
        context_len: Context length (number of previous tokens)
        device: Target device
    """
    context_lens = torch.tensor([context_len], dtype=torch.int32, device=device)
    slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
    block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)

    set_context(
        is_prefill=False,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
    )


def cleanup_context():
    """Reset nanovllm context after alignment testing."""
    reset_context()
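These helpers wrap the same set_context()/reset_context() calls the nanovllm adapter performs internally, for callers who want to run a raw forward pass themselves. A sketch, assuming a loaded nanovllm model and 1D input_ids on CUDA:

import torch
from nanovllm.debug.utils import setup_prefill_context, cleanup_context

seq_len = input_ids.numel()
positions = torch.arange(seq_len, device=input_ids.device)

setup_prefill_context(seq_len, input_ids.device)
try:
    with torch.no_grad():
        hidden_states = model(input_ids, positions)  # [num_tokens, hidden_size]
finally:
    cleanup_context()  # always reset the global context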
134
tests/test_nanovllm_steppable.py
Normal file
@@ -0,0 +1,134 @@
"""
Test NanovllmSteppable: Print activation statistics at each layer.

Usage:
    python tests/test_nanovllm_steppable.py
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))

import torch
from transformers import AutoTokenizer
from nanovllm import LLM
from nanovllm.debug.adapters.nanovllm_adapter import NanovllmSteppable
from utils import generate_needle_prompt, check_needle_answer

# ============================================================
# Config
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 32768  # Longer context to test offload
MAX_NEW_TOKENS = 20
DTYPE = torch.float16
ENABLE_CPU_OFFLOAD = True  # Test offload mode

# ============================================================
# Load Model
# ============================================================
print(f"Loading nanovllm model (cpu_offload={ENABLE_CPU_OFFLOAD})...")
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,  # Required for hooks to work
    max_model_len=40960,
    max_num_batched_tokens=40960,
    enable_cpu_offload=ENABLE_CPU_OFFLOAD,
    dtype="float16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Get the underlying model for steppable
model = llm.model_runner.model

# ============================================================
# Prepare Input (using needle-in-haystack prompt)
# ============================================================
prompt, expected_answer = generate_needle_prompt(
    tokenizer,
    target_length=INPUT_LEN,
    needle_position=0.5,
    needle_value="7492",
    use_chat_template=False,
    verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
print(f"Expected answer: {expected_answer}\n")

# ============================================================
# Create Steppable Model (reused for prefill + decode)
# ============================================================
steppable = NanovllmSteppable(model)

# ============================================================
# Prefill Phase: Print activation stats
# ============================================================
print("=" * 85)
print("PREFILL PHASE")
print("=" * 85)
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
print("-" * 85)

current_ids = input_ids.clone()
logits = None

for bp in steppable.step(current_ids, is_prefill=True):
    t = bp.tensor.float()
    shape_str = str(list(t.shape))
    print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
    if bp.name == "LM Head":
        logits = bp.tensor

# Get first token from prefill
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
generated_tokens = [next_token]

# ============================================================
# Decode Phase: Only print generated tokens
# ============================================================
print("\n" + "=" * 85)
print("DECODE PHASE")
print("=" * 85)
print(f"Step 1: {next_token!r}")

for step in range(2, MAX_NEW_TOKENS + 1):
    # Forward pass with full sequence (reuse same steppable)
    # Note: nanovllm without KV cache needs the full sequence for each decode
    for bp in steppable.step(current_ids, is_prefill=True):
        if bp.name == "LM Head":
            logits = bp.tensor

    # Get next token (greedy)
    next_token_id = logits[0, -1].argmax().item()
    next_token = tokenizer.decode(next_token_id)
    generated_tokens.append(next_token)

    print(f"Step {step:2}: {next_token!r}")

    # Stop if EOS
    if next_token_id == tokenizer.eos_token_id:
        print(" (EOS)")
        break

    # Append to sequence
    current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)

# ============================================================
# Result
# ============================================================
print("\n" + "=" * 85)
print("RESULT")
print("=" * 85)
generated_text = "".join(generated_tokens)
print(f"Generated: {generated_text!r}")
print(f"Expected: {expected_answer}")
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")

print("\ntest_nanovllm_steppable: PASSED")
@@ -8,7 +8,7 @@ sequences longer than ~200 tokens. Use --no-offload for correctness testing.
 """
 
 import os
-os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
+os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
 
 import argparse
 from nanovllm import LLM, SamplingParams
121
tests/test_torch_steppable.py
Normal file
@@ -0,0 +1,121 @@
"""
Test TorchSteppable: Print activation statistics at each layer.

Usage:
    python tests/test_torch_steppable.py
"""

import os
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))

import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from nanovllm.debug.adapters.torch_adapter import TorchSteppable
from utils import generate_needle_prompt, check_needle_answer

# ============================================================
# Config
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 512
MAX_NEW_TOKENS = 20
DTYPE = torch.float16

# ============================================================
# Load Model
# ============================================================
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
model = model.to("cuda").eval()

# ============================================================
# Prepare Input (using needle-in-haystack prompt)
# ============================================================
prompt, expected_answer = generate_needle_prompt(
    tokenizer,
    target_length=INPUT_LEN,
    needle_position=0.5,
    needle_value="7492",
    use_chat_template=False,
    verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
print(f"Expected answer: {expected_answer}\n")

# ============================================================
# Create Steppable Model (reused for prefill + decode)
# ============================================================
steppable = TorchSteppable(model)

# ============================================================
# Prefill Phase: Print activation stats
# ============================================================
print("=" * 85)
print("PREFILL PHASE")
print("=" * 85)
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
print("-" * 85)

current_ids = input_ids.clone()
logits = None

for bp in steppable.step(current_ids):
    t = bp.tensor.float()
    shape_str = str(list(t.shape))
    print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
    if bp.name == "LM Head":
        logits = bp.tensor

# Get first token from prefill
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
generated_tokens = [next_token]

# ============================================================
# Decode Phase: Only print generated tokens
# ============================================================
print("\n" + "=" * 85)
print("DECODE PHASE")
print("=" * 85)
print(f"Step 1: {next_token!r}")

for step in range(2, MAX_NEW_TOKENS + 1):
    # Forward pass (reuse same steppable)
    for bp in steppable.step(current_ids):
        if bp.name == "LM Head":
            logits = bp.tensor

    # Get next token (greedy)
    next_token_id = logits[0, -1].argmax().item()
    next_token = tokenizer.decode(next_token_id)
    generated_tokens.append(next_token)

    print(f"Step {step:2}: {next_token!r}")

    # Stop if EOS
    if next_token_id == tokenizer.eos_token_id:
        print(" (EOS)")
        break

    # Append to sequence
    current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)

# ============================================================
# Result
# ============================================================
print("\n" + "=" * 85)
print("RESULT")
print("=" * 85)
generated_text = "".join(generated_tokens)
print(f"Generated: {generated_text!r}")
print(f"Expected: {expected_answer}")
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")

print("\ntest_torch_steppable: PASSED")