"""Nanovllm model adapter for breakpoint alignment."""
from typing import Generator, Set, Optional, Dict, Any, List
import torch
from nanovllm.utils.context import set_context, reset_context
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel


class NanovllmSteppable(SteppableModel):
    """
    Steppable adapter for the nanovllm Qwen3 implementation.

    Uses PyTorch hooks to capture intermediate values during the forward pass,
    then yields them as breakpoints after execution completes.

    Key challenges handled:
    1. Shape difference: nanovllm uses [num_tokens, hidden] vs [batch, seq, hidden]
    2. Context-based attention: must call set_context() before forward
    3. Fused operations: decoder layer returns a (hidden_states, residual) tuple
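
    Example (illustrative sketch; assumes a loaded nanovllm Qwen3ForCausalLM as
    model and input_ids already on its device):

        steppable = NanovllmSteppable(model)
        gen = steppable.step(input_ids, is_prefill=True)
        try:
            while True:
                bp = next(gen)
                print(bp.name, tuple(bp.tensor.shape))
        except StopIteration as stop:
            logits = stop.value  # the generator returns the logits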
"""
def __init__(
self,
model,
enabled_breakpoints: Optional[Set[BreakpointType]] = None,
):
"""
Args:
model: Qwen3ForCausalLM from nanovllm
enabled_breakpoints: Set of breakpoint types to yield at
"""
super().__init__(enabled_breakpoints)
self.model = model
self.model.eval()
self._hooks: List[Any] = []
self._captured: Dict[str, torch.Tensor] = {}

    @property
    def num_layers(self) -> int:
        return len(self.model.model.layers)

    def _register_hooks(self):
        """Register forward hooks on all relevant modules."""
        self._hooks = []
        self._captured = {}

        # Hook for embedding output
        def embed_hook(module, input, output):
            self._captured["embed"] = output.detach().clone()

        self._hooks.append(
            self.model.model.embed_tokens.register_forward_hook(embed_hook)
        )

        # Hooks for each decoder layer
        for layer_idx in range(self.num_layers):
            layer = self.model.model.layers[layer_idx]
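
            # Bind layer_idx through a factory function: capturing it directly in
            # the hook closure would late-bind, and every hook would report the
            # last layer's index.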
            def make_layer_hook(idx):
                def hook(module, input, output):
                    # Decoder layer returns (hidden_states, residual)
                    hidden_states = output[0] if isinstance(output, tuple) else output
                    self._captured[f"layer_{idx}"] = hidden_states.detach().clone()
                return hook

            self._hooks.append(
                layer.register_forward_hook(make_layer_hook(layer_idx))
            )

        # Hook for final norm
        def final_norm_hook(module, input, output):
            # Final norm returns (hidden_states, _) for fused add
            hidden_states = output[0] if isinstance(output, tuple) else output
            self._captured["final_norm"] = hidden_states.detach().clone()

        self._hooks.append(
            self.model.model.norm.register_forward_hook(final_norm_hook)
        )

        # Hook for lm_head
        def lm_head_hook(module, input, output):
            self._captured["lm_head"] = output.detach().clone()

        self._hooks.append(
            self.model.lm_head.register_forward_hook(lm_head_hook)
        )

    def _remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def _setup_context(self, seq_len: int, device: torch.device, is_prefill: bool):
        """
        Set up the nanovllm context for the forward pass.

        For alignment testing we use a simple context without a real KV cache.
        """
        if is_prefill:
            # Prefill: process all tokens at once
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
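            # e.g. seq_len == 4 -> cu_seqlens == [0, 4]: a single packed sequence
            # covering tokens 0..3 (cumulative-lengths convention used by varlen attention)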
            # Use -1 for slot_mapping to skip KV cache writes
            slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
            set_context(
                is_prefill=True,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                slot_mapping=slot_mapping,
                is_chunked_prefill=False,
            )
        else:
            # Decode: single-token generation.
            # Decode needs context_lens and block_tables; for alignment testing
            # without a real KV cache, we use a minimal setup.
            context_lens = torch.tensor([seq_len - 1], dtype=torch.int32, device=device)
            # Single token slot
            slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
            # Empty block tables (no KV cache)
            block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
            set_context(
                is_prefill=False,
                slot_mapping=slot_mapping,
                context_lens=context_lens,
                block_tables=block_tables,
            )

    def _normalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Normalize a nanovllm tensor shape to [batch, seq_len, ...].

        nanovllm uses a [num_tokens, ...] format without a batch dimension;
        we add the batch dimension for comparison with the torch model.
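        For example, a [5, hidden_size] activation becomes [1, 5, hidden_size].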
"""
if tensor.dim() == 2: # [num_tokens, hidden_size]
return tensor.unsqueeze(0)
return tensor

    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Execute the nanovllm forward pass with hooks to capture breakpoints.

        Unlike the torch adapter, which manually steps through each component,
        we run the full forward pass and collect the captured values afterward.
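        The logits are delivered as the generator's return value (available as
        StopIteration.value when driven with next(), or via `yield from`).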
"""
# Ensure 1D for nanovllm (it expects [num_tokens])
if input_ids.dim() == 2:
input_ids = input_ids.squeeze(0)
seq_len = input_ids.numel()
device = input_ids.device
# Generate position IDs if not provided
if positions is None:
positions = torch.arange(seq_len, device=device)
elif positions.dim() == 2:
positions = positions.squeeze(0)
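        # (positions feed the rotary position embeddings in nanovllm's Qwen3 attention)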

        # Register hooks
        self._register_hooks()
        try:
            # Setup context for attention
            self._setup_context(seq_len, device, is_prefill)

            # Run forward pass (hooks capture everything)
            with torch.no_grad():
                hidden_states = self.model(input_ids, positions)
                logits = self.model.compute_logits(hidden_states)
            reset_context()

            # Yield breakpoints in order from captured data
            # EMBEDDING
            if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.EMBEDDING,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["embed"]),
                    name="Embedding",
                )

            # LAYER_OUTPUT for each layer
            if self.is_enabled(BreakpointType.LAYER_OUTPUT):
                for layer_idx in range(self.num_layers):
                    key = f"layer_{layer_idx}"
                    if key in self._captured:
                        yield Breakpoint(
                            bp_type=BreakpointType.LAYER_OUTPUT,
                            layer_idx=layer_idx,
                            tensor=self._normalize_tensor(self._captured[key]),
                            name=f"Layer {layer_idx}",
                        )

            # FINAL_NORM
            if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.FINAL_NORM,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["final_norm"]),
                    name="Final Norm",
                )

            # LM_HEAD
            if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.LM_HEAD,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["lm_head"]),
                    name="LM Head",
                )

            return logits
        finally:
            self._remove_hooks()
            self._captured = {}