"""Torch reference model adapter for breakpoint alignment."""
|
|
|
|
from typing import Generator, Set, Optional
|
|
import torch
|
|
|
|
from ..breakpoints import Breakpoint, BreakpointType
|
|
from .base import SteppableModel
|
|
|
|
|
|
class TorchSteppable(SteppableModel):
    """
    Steppable adapter for the torch reference Qwen3 implementation.

    Wraps the Qwen3ForCausalLM model from tests/modeling_qwen3.py.
    """

    def __init__(
        self,
        model,
        enabled_breakpoints: Optional[Set[BreakpointType]] = None,
    ):
        """
        Args:
            model: Qwen3ForCausalLM from tests/modeling_qwen3.py
            enabled_breakpoints: Set of breakpoint types to yield at
        """
        super().__init__(enabled_breakpoints)
        self.model = model
        self.model.eval()

    @property
    def num_layers(self) -> int:
        return len(self.model.model.layers)

    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Generator that manually steps through the torch model.

        The torch model uses [batch, seq_len, hidden_size] shapes.

        Args:
            input_ids: Token ids of shape [seq_len] or [batch, seq_len].
            positions: Optional position ids; generated from arange(seq_len) if omitted.
            is_prefill: Unused by this adapter; kept for interface parity.

        Yields:
            Breakpoint for each enabled breakpoint type, in forward order.

        Returns:
            Logits from lm_head, shape [batch, seq_len, vocab_size].
        """
        # Ensure batch dimension
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Generate position IDs if not provided
        if positions is None:
            positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        elif positions.dim() == 1:
            positions = positions.unsqueeze(0)

        with torch.no_grad():
            # === EMBEDDING ===
            hidden_states = self.model.model.embed_tokens(input_ids)

            if self.is_enabled(BreakpointType.EMBEDDING):
                yield Breakpoint(
                    bp_type=BreakpointType.EMBEDDING,
                    layer_idx=None,
                    tensor=hidden_states.detach().clone(),
                    name="Embedding",
                )

            # Create causal attention mask
            causal_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=device),
                diagonal=1,
            )
            attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)

            # === DECODER LAYERS ===
            for layer_idx, layer in enumerate(self.model.model.layers):
                hidden_states, _, _ = layer(
                    hidden_states=hidden_states,
                    position_ids=positions,
                    attention_mask=attention_mask,
                    past_key_value=None,
                    use_cache=False,
                    output_qkv=False,
                )

                if self.is_enabled(BreakpointType.LAYER_OUTPUT):
                    yield Breakpoint(
                        bp_type=BreakpointType.LAYER_OUTPUT,
                        layer_idx=layer_idx,
                        tensor=hidden_states.detach().clone(),
                        name=f"Layer {layer_idx}",
                    )

            # === FINAL NORM ===
            hidden_states = self.model.model.norm(hidden_states)

            if self.is_enabled(BreakpointType.FINAL_NORM):
                yield Breakpoint(
                    bp_type=BreakpointType.FINAL_NORM,
                    layer_idx=None,
                    tensor=hidden_states.detach().clone(),
                    name="Final Norm",
                )

            # === LM HEAD ===
            logits = self.model.lm_head(hidden_states)

            if self.is_enabled(BreakpointType.LM_HEAD):
                yield Breakpoint(
                    bp_type=BreakpointType.LM_HEAD,
                    layer_idx=None,
                    tensor=logits.detach().clone(),
                    name="LM Head",
                )

        return logits
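

# --- Usage sketch (illustrative only, not part of the adapter) --------------
# Driving the generator manually surfaces each Breakpoint as it is produced,
# and the final logits arrive as the generator's return value via
# StopIteration.value. This sketch assumes `model` is an already-loaded
# Qwen3ForCausalLM from tests/modeling_qwen3.py and that Breakpoint exposes
# the `name` and `tensor` fields it is constructed with above.
#
#     steppable = TorchSteppable(
#         model,
#         enabled_breakpoints={BreakpointType.LAYER_OUTPUT},
#     )
#     input_ids = torch.tensor([[1, 2, 3]])  # placeholder token ids
#     gen = steppable.step(input_ids)
#     try:
#         while True:
#             bp = next(gen)
#             print(bp.name, tuple(bp.tensor.shape))
#     except StopIteration as stop:
#         logits = stop.value  # [batch, seq_len, vocab_size]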