"""Nanovllm model adapter for breakpoint alignment."""
|
|
|
|
from typing import Generator, Set, Optional, Dict, Any, List
|
|
import torch
|
|
|
|
from nanovllm.utils.context import set_context, reset_context
|
|
from ..breakpoints import Breakpoint, BreakpointType
|
|
from .base import SteppableModel
|
|
|
|
|
|


class NanovllmSteppable(SteppableModel):
    """
    Steppable adapter for the nanovllm Qwen3 implementation.

    Uses PyTorch hooks to capture intermediate values during the forward pass,
    then yields them as breakpoints after execution completes.

    Key challenges handled:
    1. Shape difference: nanovllm uses [num_tokens, hidden] vs [batch, seq, hidden]
    2. Context-based attention: set_context() must be called before forward
    3. Fused operations: the decoder layer returns a (hidden_states, residual) tuple

    See the usage sketch at the bottom of this module.
    """
    def __init__(
        self,
        model,
        enabled_breakpoints: Optional[Set[BreakpointType]] = None,
    ):
        """
        Args:
            model: Qwen3ForCausalLM from nanovllm
            enabled_breakpoints: Set of breakpoint types to yield at
        """
        super().__init__(enabled_breakpoints)
        self.model = model
        self.model.eval()
        self._hooks: List[Any] = []
        self._captured: Dict[str, torch.Tensor] = {}

    @property
    def num_layers(self) -> int:
        return len(self.model.model.layers)
    def _register_hooks(self):
        """Register forward hooks on all relevant modules."""
        self._hooks = []
        self._captured = {}

        # Hook for embedding output
        def embed_hook(module, input, output):
            self._captured["embed"] = output.detach().clone()

        self._hooks.append(
            self.model.model.embed_tokens.register_forward_hook(embed_hook)
        )

        # Hooks for each decoder layer
        for layer_idx in range(self.num_layers):
            layer = self.model.model.layers[layer_idx]

            def make_layer_hook(idx):
                def hook(module, input, output):
                    # Decoder layer returns (hidden_states, residual)
                    hidden_states = output[0] if isinstance(output, tuple) else output
                    self._captured[f"layer_{idx}"] = hidden_states.detach().clone()
                return hook

            self._hooks.append(
                layer.register_forward_hook(make_layer_hook(layer_idx))
            )

        # Hook for final norm
        def final_norm_hook(module, input, output):
            # Final norm returns (hidden_states, _) for fused add
            hidden_states = output[0] if isinstance(output, tuple) else output
            self._captured["final_norm"] = hidden_states.detach().clone()

        self._hooks.append(
            self.model.model.norm.register_forward_hook(final_norm_hook)
        )

        # Hook for lm_head
        def lm_head_hook(module, input, output):
            self._captured["lm_head"] = output.detach().clone()

        self._hooks.append(
            self.model.lm_head.register_forward_hook(lm_head_hook)
        )
    def _remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []
    def _setup_context(self, seq_len: int, device: torch.device, is_prefill: bool):
        """
        Set up the nanovllm context for the forward pass.

        For alignment testing, we use a simple context without a real KV cache.
        """
        if is_prefill:
            # Prefill: process all tokens at once
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
            # Use -1 for slot_mapping to skip KV cache writes
            slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)

            set_context(
                is_prefill=True,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                slot_mapping=slot_mapping,
                is_chunked_prefill=False,
            )
        else:
            # Decode: single token generation
            # For decode, we need context_lens and block_tables
            # For alignment testing without real KV cache, we use minimal setup
            context_lens = torch.tensor([seq_len - 1], dtype=torch.int32, device=device)
            # Single token slot
            slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
            # Empty block tables (no KV cache)
            block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)

            set_context(
                is_prefill=False,
                slot_mapping=slot_mapping,
                context_lens=context_lens,
                block_tables=block_tables,
            )
    def _normalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Normalize a nanovllm tensor shape to [batch, seq_len, ...].

        nanovllm uses a [num_tokens, ...] format without a batch dimension.
        We add the batch dimension for comparison with the torch model.
        """
        if tensor.dim() == 2:  # [num_tokens, hidden_size]
            return tensor.unsqueeze(0)
        return tensor
    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Execute the nanovllm forward pass with hooks to capture breakpoints.

        Unlike the torch adapter, which manually steps through each component,
        we run the full forward pass and collect the captured values afterward.
        """
        # Ensure 1D for nanovllm (it expects [num_tokens])
        if input_ids.dim() == 2:
            input_ids = input_ids.squeeze(0)

        seq_len = input_ids.numel()
        device = input_ids.device

        # Generate position IDs if not provided
        if positions is None:
            positions = torch.arange(seq_len, device=device)
        elif positions.dim() == 2:
            positions = positions.squeeze(0)

        # Register hooks
        self._register_hooks()

        try:
            # Setup context for attention
            self._setup_context(seq_len, device, is_prefill)

            # Run forward pass (hooks capture everything)
            with torch.no_grad():
                hidden_states = self.model(input_ids, positions)
                logits = self.model.compute_logits(hidden_states)

            reset_context()

            # Yield breakpoints in order from captured data

            # EMBEDDING
            if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.EMBEDDING,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["embed"]),
                    name="Embedding",
                )

            # LAYER_OUTPUT for each layer
            if self.is_enabled(BreakpointType.LAYER_OUTPUT):
                for layer_idx in range(self.num_layers):
                    key = f"layer_{layer_idx}"
                    if key in self._captured:
                        yield Breakpoint(
                            bp_type=BreakpointType.LAYER_OUTPUT,
                            layer_idx=layer_idx,
                            tensor=self._normalize_tensor(self._captured[key]),
                            name=f"Layer {layer_idx}",
                        )

            # FINAL_NORM
            if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.FINAL_NORM,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["final_norm"]),
                    name="Final Norm",
                )

            # LM_HEAD
            if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.LM_HEAD,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["lm_head"]),
                    name="LM Head",
                )

            return logits

        finally:
            self._remove_hooks()
            self._captured = {}
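

# Usage sketch (illustrative only; not part of the adapter API). It shows how a
# caller might drive step() and recover the logits that the generator returns
# via StopIteration.value. The helper name below is hypothetical, and it assumes
# Breakpoint exposes its constructor fields (name, tensor) as attributes.
def _example_alignment_run(steppable: NanovllmSteppable, input_ids: torch.Tensor) -> torch.Tensor:
    """Drive step() once, print each captured breakpoint, and return the logits."""
    gen = steppable.step(input_ids, is_prefill=True)
    while True:
        try:
            bp = next(gen)
        except StopIteration as stop:
            # A generator's return value is carried on StopIteration.value.
            return stop.value
        print(f"{bp.name}: {tuple(bp.tensor.shape)}")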