"""Breakpoint types and data structures for alignment debugging."""

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import torch


class BreakpointType(Enum):
    """Stages of the model forward pass at which a tensor can be captured."""

    # Output of the embed_tokens embedding lookup.
    EMBEDDING = auto()
    # Output of an individual decoder layer.
    LAYER_OUTPUT = auto()
    # Output of the final RMSNorm.
    FINAL_NORM = auto()
    # Logits produced by lm_head.
    LM_HEAD = auto()


@dataclass
class Breakpoint:
    """A single captured activation from the model forward pass."""

    # Which stage of the forward pass this capture came from.
    bp_type: BreakpointType
    # Decoder layer index for LAYER_OUTPUT; None for the other stages.
    layer_idx: Optional[int]
    # The captured activation tensor.
    tensor: torch.Tensor
    # Human-readable label, shown in the repr.
    name: str

    def normalize_shape(self) -> torch.Tensor:
        """Return the tensor with a leading batch dimension for comparison.

        nanovllm produces [num_tokens, hidden_size] activations while
        torch produces [batch, seq_len, hidden_size]; a batch axis is
        prepended to 2D tensors so the two layouts compare directly.
        """
        t = self.tensor
        return t.unsqueeze(0) if t.dim() == 2 else t

    def __repr__(self) -> str:
        dims = "x".join(map(str, self.tensor.shape))
        return f"Breakpoint({self.name}, shape={dims}, dtype={self.tensor.dtype})"