Files
nano-vllm/nanovllm/debug/aligner.py
2026-01-03 22:36:40 +08:00

212 lines
7.8 KiB
Python

"""Breakpoint aligner for comparing model outputs."""
from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import torch
from .breakpoints import Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .adapters.base import SteppableModel
@dataclass
class AlignmentResult:
    """Outcome of a single alignment run.

    Attributes:
        passed: Overall verdict of the run.
        all_comparisons: One ``(ref_breakpoint, test_breakpoint, comparison)``
            tuple for every breakpoint pair that was compared.
        failed_at: Reference breakpoint at which the run failed, when known.
        message: Human-readable description of the outcome.
    """

    passed: bool
    all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = field(default_factory=list)
    failed_at: Optional[Breakpoint] = None
    message: str = ""

    def __repr__(self) -> str:
        """Return the verdict together with passed/total breakpoint counts."""
        # Count only the comparisons whose own ``passed`` flag is truthy; the
        # overall verdict is reported separately from ``self.passed``.
        successes = [
            outcome
            for _ref_bp, _test_bp, outcome in self.all_comparisons
            if outcome.passed
        ]
        n_ok = len(successes)
        n_all = len(self.all_comparisons)
        if self.passed:
            verdict = "PASSED"
        else:
            verdict = "FAILED"
        return f"AlignmentResult({verdict}, {n_ok}/{n_all} breakpoints passed)"
class BreakpointAligner:
    """
    Orchestrates alternating execution of reference and test models,
    comparing outputs at each breakpoint.

    Both models are stepped in lockstep: one breakpoint is pulled from the
    reference, one from the test model, their type and layer index are
    checked for agreement, and their normalized tensors are handed to the
    comparator.

    Example:
        >>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
        >>> ref = TorchSteppable(torch_model)
        >>> test = NanovllmSteppable(nanovllm_model)
        >>> aligner = BreakpointAligner(ref, test)
        >>> result = aligner.align(input_ids)
        >>> print(result)
    """

    def __init__(
        self,
        ref_model: SteppableModel,
        test_model: SteppableModel,
        comparator: Optional[TensorComparator] = None,
        stop_on_error: bool = True,
        verbose: bool = True,
    ):
        """
        Args:
            ref_model: Reference (torch) steppable model
            test_model: Test (nanovllm) steppable model
            comparator: Tensor comparator instance (uses default if None)
            stop_on_error: If True, stop at first mismatch
            verbose: If True, print comparison results
        """
        self.ref_model = ref_model
        self.test_model = test_model
        self.comparator = comparator or TensorComparator()
        self.stop_on_error = stop_on_error
        self.verbose = verbose

    def _fail(
        self,
        message: str,
        comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]],
        failed_at: Optional[Breakpoint] = None,
        announce: bool = True,
    ) -> AlignmentResult:
        """Build a failing result, printing ``message`` when verbose and ``announce`` is set."""
        if announce and self.verbose:
            print(message)
        return AlignmentResult(
            passed=False,
            all_comparisons=comparisons,
            failed_at=failed_at,
            message=message,
        )

    @staticmethod
    def _close(gen) -> None:
        """Close ``gen`` if it exposes ``close()`` (generators do; plain iterators may not)."""
        close = getattr(gen, "close", None)
        if callable(close):
            close()

    def align(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> AlignmentResult:
        """
        Run both models with same input, comparing at each breakpoint.

        Structural problems (test model ending early, breakpoint type or layer
        index disagreeing, tensors with different element counts, extra test
        breakpoints) always end the run immediately. A failed tensor
        comparison ends the run only when ``stop_on_error`` is True; otherwise
        it is recorded and the run continues.

        If the reference yields no breakpoints and the test model yields none
        either, the result is reported as passed with zero comparisons.

        Args:
            input_ids: Input token IDs
            positions: Position IDs (optional)
            is_prefill: True for prefill phase, False for decode

        Returns:
            AlignmentResult with pass/fail status and details

        Raises:
            Exception: Anything raised by the models or the comparator is
                re-raised unchanged (after being printed when verbose).
        """
        comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = []

        if self.verbose:
            phase = "prefill" if is_prefill else "decode"
            print(f"\n{'='*60}")
            print(f"Alignment Test ({phase})")
            print(f"{'='*60}")

        # Start both generators
        ref_gen = self.ref_model.step(input_ids, positions, is_prefill)
        test_gen = self.test_model.step(input_ids, positions, is_prefill)

        try:
            for ref_bp in ref_gen:
                # Get corresponding breakpoint from test
                try:
                    test_bp = next(test_gen)
                except StopIteration:
                    return self._fail(
                        f"Test model ended early at {ref_bp.name}",
                        comparisons,
                        failed_at=ref_bp,
                    )

                # Verify breakpoints match
                if ref_bp.bp_type != test_bp.bp_type:
                    return self._fail(
                        f"Breakpoint type mismatch: {ref_bp.bp_type} vs {test_bp.bp_type}",
                        comparisons,
                        failed_at=ref_bp,
                    )
                if ref_bp.layer_idx != test_bp.layer_idx:
                    return self._fail(
                        f"Layer index mismatch: {ref_bp.layer_idx} vs {test_bp.layer_idx}",
                        comparisons,
                        failed_at=ref_bp,
                    )

                # Normalize shapes for comparison
                ref_t = ref_bp.normalize_shape()
                test_t = test_bp.normalize_shape()

                # Handle shape mismatches
                if ref_t.shape != test_t.shape:
                    if self.verbose:
                        print(f"[{ref_bp.name}] Shape mismatch: ref={ref_t.shape} vs test={test_t.shape}")
                    if ref_t.numel() != test_t.numel():
                        # Element counts differ: no reshape can reconcile them.
                        return self._fail(
                            f"Shape mismatch at {ref_bp.name}: {ref_t.shape} vs {test_t.shape}",
                            comparisons,
                            failed_at=ref_bp,
                            announce=False,
                        )
                    # Same element count: reinterpret the test tensor in the
                    # reference layout. reshape() is used instead of view()
                    # because view() raises on non-contiguous tensors, while
                    # reshape() returns a view when possible and copies otherwise.
                    test_t = test_t.reshape(ref_t.shape)

                # Compare tensors
                result = self.comparator.compare(ref_t, test_t, ref_bp.name)
                comparisons.append((ref_bp, test_bp, result))

                if self.verbose:
                    status = "\u2713" if result.passed else "\u2717"
                    print(f"{status} [{ref_bp.name}] cos={result.cosine_similarity:.6f}, max_diff={result.max_abs_diff:.2e}")

                if not result.passed and self.stop_on_error:
                    if self.verbose:
                        print(f"\nStopped at {ref_bp.name} (stop_on_error=True)")
                        print(result.message)
                    return self._fail(
                        f"Alignment failed at {ref_bp.name}",
                        comparisons,
                        failed_at=ref_bp,
                        announce=False,
                    )

            # Check for extra test breakpoints
            try:
                extra_bp = next(test_gen)
            except StopIteration:
                pass
            else:
                return self._fail(
                    f"Test model has extra breakpoints starting at {extra_bp.name}",
                    comparisons,
                )
        except Exception as e:
            msg = f"Exception during alignment: {str(e)}"
            if self.verbose:
                print(msg)
            raise
        finally:
            # Early returns and exceptions leave the generators suspended
            # mid-run; close them explicitly so their cleanup runs now rather
            # than whenever they happen to be finalized.
            try:
                self._close(ref_gen)
            finally:
                self._close(test_gen)

        # Summary
        passed_count = sum(1 for _, _, c in comparisons if c.passed)
        total = len(comparisons)
        all_passed = all(c.passed for _, _, c in comparisons)

        if self.verbose:
            print(f"{'='*60}")
            status = "PASSED" if all_passed else "FAILED"
            print(f"Result: {status} ({passed_count}/{total} breakpoints)")
            print(f"{'='*60}\n")

        return AlignmentResult(
            passed=all_passed,
            all_comparisons=comparisons,
            message="All breakpoints aligned" if all_passed else "Some breakpoints failed",
        )