Files
nano-vllm/nanovllm/debug/breakpoints.py
2026-01-03 22:36:40 +08:00

40 lines
1.2 KiB
Python

"""Breakpoint types and data structures for alignment debugging."""
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional
import torch
class BreakpointType(Enum):
    """Types of breakpoints in the model forward pass."""

    EMBEDDING = auto()  # After embed_tokens
    LAYER_OUTPUT = auto()  # After each decoder layer
    FINAL_NORM = auto()  # After final RMSNorm
    LM_HEAD = auto()  # After lm_head (logits)


# eq=False: the auto-generated __eq__ would compare the ``tensor`` fields with
# ``==``, which returns an elementwise tensor whose truth value is ambiguous
# (RuntimeError) for multi-element tensors. Identity equality is the only
# well-defined comparison here, and it also keeps instances hashable.
# Compare captured tensors explicitly (e.g. torch.allclose) instead.
@dataclass(eq=False)
class Breakpoint:
    """A captured breakpoint with tensor data.

    Attributes:
        bp_type: Which point of the forward pass the tensor was captured at.
        layer_idx: Layer index for LAYER_OUTPUT breakpoints; None for
            EMBEDDING, FINAL_NORM and LM_HEAD.
        tensor: The captured tensor.
        name: Label for this breakpoint; shown in ``repr``.
    """

    bp_type: BreakpointType
    layer_idx: Optional[int]  # None for EMBEDDING, FINAL_NORM, LM_HEAD
    tensor: torch.Tensor
    name: str

    def normalize_shape(self) -> torch.Tensor:
        """
        Normalize tensor shape for comparison.

        nanovllm uses [num_tokens, hidden_size] while torch uses
        [batch, seq_len, hidden_size]. This adds a batch dimension
        to 2D tensors for comparison.

        Tensors of any other rank are returned unchanged (the same object).
        For 2D input the result is an ``unsqueeze`` view sharing storage
        with ``self.tensor``, not a copy.
        """
        if self.tensor.dim() == 2:
            return self.tensor.unsqueeze(0)
        return self.tensor

    def __repr__(self) -> str:
        dims = self.tensor.shape
        # A 0-dim tensor has an empty shape; "x".join would yield "".
        shape_str = "x".join(str(d) for d in dims) if len(dims) else "scalar"
        return (
            f"{type(self).__name__}({self.name}, shape={shape_str}, "
            f"dtype={self.tensor.dtype})"
        )