[feat] Added debug tools.
This commit is contained in:
39
nanovllm/debug/breakpoints.py
Normal file
39
nanovllm/debug/breakpoints.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Breakpoint types and data structures for alignment debugging."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
class BreakpointType(Enum):
    """Types of breakpoints in the model forward pass.

    Members are listed in the order the corresponding stages appear in the
    forward pass. Values are assigned by ``auto()``, so only member identity
    and name should be relied upon, not the underlying integers.
    """

    EMBEDDING = auto()  # After embed_tokens
    LAYER_OUTPUT = auto()  # After each decoder layer
    FINAL_NORM = auto()  # After final RMSNorm
    LM_HEAD = auto()  # After lm_head (logits)
|
||||
|
||||
|
||||
@dataclass
class Breakpoint:
    """A captured breakpoint with tensor data.

    NOTE: ``__repr__`` is defined explicitly below, so ``@dataclass`` does not
    generate one. The dataclass-generated ``__eq__`` compares the fields as a
    tuple, which applies ``==`` to ``tensor``; compare the tensors explicitly
    instead of comparing ``Breakpoint`` objects with ``==``.
    """

    bp_type: BreakpointType  # Which kind of breakpoint this capture belongs to
    layer_idx: Optional[int]  # None for EMBEDDING, FINAL_NORM, LM_HEAD
    tensor: torch.Tensor  # The captured tensor data
    name: str  # Label shown in __repr__

    def normalize_shape(self) -> torch.Tensor:
        """Normalize tensor shape for comparison.

        nanovllm uses [num_tokens, hidden_size] while torch uses
        [batch, seq_len, hidden_size]. This adds a batch dimension
        to 2D tensors for comparison.

        Returns:
            ``self.tensor`` with a size-1 dimension inserted at position 0
            when it is 2D; otherwise ``self.tensor`` itself, unchanged.
            ``self.tensor`` is never modified in place.
        """
        if self.tensor.dim() == 2:
            return self.tensor.unsqueeze(0)
        return self.tensor

    def __repr__(self) -> str:
        """Return a compact summary: name, "x"-joined shape, and dtype."""
        shape_str = "x".join(str(d) for d in self.tensor.shape)
        return f"Breakpoint({self.name}, shape={shape_str}, dtype={self.tensor.dtype})"
|
||||
Reference in New Issue
Block a user