Files
nano-vllm/nanovllm/debug/adapters/torch_adapter.py
2026-01-03 22:36:40 +08:00

120 lines
3.9 KiB
Python

"""Torch reference model adapter for breakpoint alignment."""
from typing import Generator, Set, Optional
import torch
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel
class TorchSteppable(SteppableModel):
    """
    Steppable adapter for the torch reference Qwen3 implementation.

    Wraps tests/modeling_qwen3.py Qwen3ForCausalLM model and drives its
    submodules by hand (embedding -> decoder layers -> final norm -> LM head)
    so intermediate activations can be yielded as breakpoints.
    """

    def __init__(
        self,
        model,
        enabled_breakpoints: Optional[Set[BreakpointType]] = None,
    ):
        """
        Args:
            model: Qwen3ForCausalLM from tests/modeling_qwen3.py. It is put
                into eval mode here (this mutates the passed model in place).
            enabled_breakpoints: Set of breakpoint types to yield at;
                forwarded unchanged to ``SteppableModel.__init__``.
        """
        super().__init__(enabled_breakpoints)
        self.model = model
        self.model.eval()

    @property
    def num_layers(self) -> int:
        """Number of decoder layers in the wrapped model."""
        return len(self.model.model.layers)

    @staticmethod
    def _causal_mask(
        seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return an additive causal mask of shape ``[1, 1, seq_len, seq_len]``.

        Entries strictly above the diagonal (future positions) are ``-inf``;
        all other entries are ``0``.
        """
        # Built in the activation dtype so the mask matches the attention
        # scores of fp16/bf16 models instead of always being float32.
        mask = torch.full(
            (seq_len, seq_len), float("-inf"), device=device, dtype=dtype
        )
        return torch.triu(mask, diagonal=1)[None, None]

    # torch.no_grad() is applied as a decorator rather than a ``with`` block
    # in the body: PyTorch's context decorators recognise generator functions
    # and enter/exit the context around every resume. A ``with`` block that
    # spans the ``yield``s would leave grad disabled in the caller while the
    # generator is suspended (and until it is closed/collected if iteration
    # stops early), and could restore the wrong grad state when interleaved
    # with other generators doing the same.
    @torch.no_grad()
    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Generator that manually steps through the torch model.

        The torch model uses [batch, seq_len, hidden_size] shapes.

        Args:
            input_ids: Token ids, ``[seq_len]`` or ``[batch, seq_len]``.
                A 1-D input gets a leading batch dimension of 1.
            positions: Optional position ids, ``[seq_len]`` or
                ``[batch, seq_len]``. Defaults to ``0..seq_len-1`` for every
                batch row.
            is_prefill: Currently unused. Every call runs a full causal
                forward pass over ``input_ids`` without a KV cache.

        Yields:
            A ``Breakpoint`` for each enabled type, in forward order:
            EMBEDDING, LAYER_OUTPUT (once per decoder layer), FINAL_NORM,
            LM_HEAD. Each carries a detached copy of the activation.

        Returns:
            The logits from ``lm_head`` (the generator's return value).
        """
        # Ensure batch dimension
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Generate position IDs if not provided
        if positions is None:
            positions = (
                torch.arange(seq_len, device=device)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )
        elif positions.dim() == 1:
            positions = positions.unsqueeze(0)

        # === EMBEDDING ===
        hidden_states = self.model.model.embed_tokens(input_ids)
        if self.is_enabled(BreakpointType.EMBEDDING):
            yield Breakpoint(
                bp_type=BreakpointType.EMBEDDING,
                layer_idx=None,
                tensor=hidden_states.detach().clone(),
                name="Embedding",
            )

        # Square causal mask: no KV cache is used, so queries and keys both
        # cover exactly the current input sequence.
        attention_mask = self._causal_mask(seq_len, device, hidden_states.dtype)

        # === DECODER LAYERS ===
        for layer_idx, layer in enumerate(self.model.model.layers):
            # The reference layer returns a 3-tuple; only hidden states are used.
            hidden_states, _, _ = layer(
                hidden_states=hidden_states,
                position_ids=positions,
                attention_mask=attention_mask,
                past_key_value=None,
                use_cache=False,
                output_qkv=False,
            )
            if self.is_enabled(BreakpointType.LAYER_OUTPUT):
                yield Breakpoint(
                    bp_type=BreakpointType.LAYER_OUTPUT,
                    layer_idx=layer_idx,
                    tensor=hidden_states.detach().clone(),
                    name=f"Layer {layer_idx}",
                )

        # === FINAL NORM ===
        hidden_states = self.model.model.norm(hidden_states)
        if self.is_enabled(BreakpointType.FINAL_NORM):
            yield Breakpoint(
                bp_type=BreakpointType.FINAL_NORM,
                layer_idx=None,
                tensor=hidden_states.detach().clone(),
                name="Final Norm",
            )

        # === LM HEAD ===
        logits = self.model.lm_head(hidden_states)
        if self.is_enabled(BreakpointType.LM_HEAD):
            yield Breakpoint(
                bp_type=BreakpointType.LM_HEAD,
                layer_idx=None,
                tensor=logits.detach().clone(),
                name="LM Head",
            )
        return logits