[feat] Added debug tools.

Zijie Tian
2026-01-03 22:36:40 +08:00
parent 9b52d25866
commit 00ed17c640
12 changed files with 1118 additions and 1 deletion

nanovllm/debug/__init__.py
View File

@@ -0,0 +1,49 @@
"""
Breakpoint debugging tools for aligning nanovllm with reference implementations.
This module provides a generator-based breakpoint aligner that enables step-by-step
comparison between nanovllm and torch reference model outputs.
Example:
>>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
>>> from tests.modeling_qwen3 import Qwen3ForCausalLM
>>>
>>> # Load models
>>> torch_model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch.float16)
>>> nanovllm_model = ... # Your nanovllm model
>>>
>>> # Create adapters
>>> ref = TorchSteppable(torch_model)
>>> test = NanovllmSteppable(nanovllm_model)
>>>
>>> # Run alignment
>>> aligner = BreakpointAligner(ref, test)
>>> result = aligner.align(input_ids)
>>> print(result)
"""
from .breakpoints import BreakpointType, Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .aligner import BreakpointAligner, AlignmentResult
from .adapters import SteppableModel, TorchSteppable, NanovllmSteppable
from .utils import setup_prefill_context, setup_decode_context, cleanup_context
__all__ = [
# Core classes
"BreakpointAligner",
"AlignmentResult",
# Breakpoints
"BreakpointType",
"Breakpoint",
# Comparator
"TensorComparator",
"ComparisonResult",
# Adapters
"SteppableModel",
"TorchSteppable",
"NanovllmSteppable",
# Utils
"setup_prefill_context",
"setup_decode_context",
"cleanup_context",
]

nanovllm/debug/adapters/__init__.py
View File

@@ -0,0 +1,11 @@
"""Model adapters for breakpoint alignment."""
from .base import SteppableModel
from .torch_adapter import TorchSteppable
from .nanovllm_adapter import NanovllmSteppable
__all__ = [
"SteppableModel",
"TorchSteppable",
"NanovllmSteppable",
]

nanovllm/debug/adapters/base.py
View File

@@ -0,0 +1,59 @@
"""Base class for steppable model adapters."""
from abc import ABC, abstractmethod
from typing import Generator, Set, Optional
import torch
from ..breakpoints import Breakpoint, BreakpointType
class SteppableModel(ABC):
"""
Abstract base class for models that can yield at breakpoints.
    Subclasses implement the step() method as a generator that yields
    Breakpoint objects at each enabled breakpoint during the forward pass.
"""
def __init__(self, enabled_breakpoints: Optional[Set[BreakpointType]] = None):
"""
Args:
enabled_breakpoints: Set of breakpoint types to yield at.
If None, yields at all breakpoints.
"""
self.enabled_breakpoints = enabled_breakpoints
def is_enabled(self, bp_type: BreakpointType) -> bool:
"""Check if a breakpoint type is enabled."""
if self.enabled_breakpoints is None:
return True
return bp_type in self.enabled_breakpoints
@abstractmethod
def step(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> Generator[Breakpoint, None, torch.Tensor]:
"""
Generator that yields Breakpoint objects at enabled breakpoints.
Args:
input_ids: Input token IDs
positions: Position IDs (optional, auto-generated if None)
is_prefill: True for prefill phase, False for decode
Yields:
Breakpoint objects at each enabled checkpoint
Returns:
Final output tensor (logits)
"""
pass
@property
@abstractmethod
def num_layers(self) -> int:
"""Return the number of decoder layers."""
pass
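Because step() is a generator whose return value carries the final logits, a caller must catch StopIteration to recover them (PEP 380 semantics). A minimal consumption sketch (the run_to_completion helper is illustrative, not part of this commit):

import torch
from nanovllm.debug import SteppableModel

def run_to_completion(model: SteppableModel, input_ids: torch.Tensor):
    """Drain a steppable generator and return (breakpoints, logits)."""
    gen = model.step(input_ids)
    breakpoints = []
    while True:
        try:
            breakpoints.append(next(gen))
        except StopIteration as stop:
            # A generator's `return logits` surfaces as StopIteration.value
            return breakpoints, stop.value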

nanovllm/debug/adapters/nanovllm_adapter.py
View File

@@ -0,0 +1,229 @@
"""Nanovllm model adapter for breakpoint alignment."""
from typing import Generator, Set, Optional, Dict, Any, List
import torch
from nanovllm.utils.context import set_context, reset_context
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel
class NanovllmSteppable(SteppableModel):
"""
Steppable adapter for nanovllm Qwen3 implementation.
    Uses PyTorch hooks to capture intermediate values during the forward pass,
then yields them as breakpoints after execution completes.
Key challenges handled:
1. Shape difference: nanovllm uses [num_tokens, hidden] vs [batch, seq, hidden]
2. Context-based attention: must call set_context() before forward
3. Fused operations: decoder layer returns (hidden_states, residual) tuple
"""
def __init__(
self,
model,
enabled_breakpoints: Optional[Set[BreakpointType]] = None,
):
"""
Args:
model: Qwen3ForCausalLM from nanovllm
enabled_breakpoints: Set of breakpoint types to yield at
"""
super().__init__(enabled_breakpoints)
self.model = model
self.model.eval()
self._hooks: List[Any] = []
self._captured: Dict[str, torch.Tensor] = {}
@property
def num_layers(self) -> int:
return len(self.model.model.layers)
def _register_hooks(self):
"""Register forward hooks on all relevant modules."""
self._hooks = []
self._captured = {}
# Hook for embedding output
def embed_hook(module, input, output):
self._captured["embed"] = output.detach().clone()
self._hooks.append(
self.model.model.embed_tokens.register_forward_hook(embed_hook)
)
# Hooks for each decoder layer
for layer_idx in range(self.num_layers):
layer = self.model.model.layers[layer_idx]
def make_layer_hook(idx):
def hook(module, input, output):
# Decoder layer returns (hidden_states, residual)
hidden_states = output[0] if isinstance(output, tuple) else output
self._captured[f"layer_{idx}"] = hidden_states.detach().clone()
return hook
self._hooks.append(
layer.register_forward_hook(make_layer_hook(layer_idx))
)
# Hook for final norm
def final_norm_hook(module, input, output):
# Final norm returns (hidden_states, _) for fused add
hidden_states = output[0] if isinstance(output, tuple) else output
self._captured["final_norm"] = hidden_states.detach().clone()
self._hooks.append(
self.model.model.norm.register_forward_hook(final_norm_hook)
)
# Hook for lm_head
def lm_head_hook(module, input, output):
self._captured["lm_head"] = output.detach().clone()
self._hooks.append(
self.model.lm_head.register_forward_hook(lm_head_hook)
)
def _remove_hooks(self):
"""Remove all registered hooks."""
for hook in self._hooks:
hook.remove()
self._hooks = []
def _setup_context(self, seq_len: int, device: torch.device, is_prefill: bool):
"""
Set up nanovllm context for forward pass.
        For alignment testing, we use a simple context without a real KV cache.
"""
if is_prefill:
# Prefill: process all tokens at once
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
# Use -1 for slot_mapping to skip KV cache writes
slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
slot_mapping=slot_mapping,
is_chunked_prefill=False,
)
else:
# Decode: single token generation
            # Decode requires context_lens and block_tables; for alignment
            # testing without a real KV cache we use a minimal setup
context_lens = torch.tensor([seq_len - 1], dtype=torch.int32, device=device)
# Single token slot
slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
# Empty block tables (no KV cache)
block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
)
def _normalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Normalize nanovllm tensor shape to [batch, seq_len, ...].
        nanovllm uses a [num_tokens, ...] format without a batch dimension.
        We add a batch dimension for comparison with the torch model.
"""
if tensor.dim() == 2: # [num_tokens, hidden_size]
return tensor.unsqueeze(0)
return tensor
def step(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> Generator[Breakpoint, None, torch.Tensor]:
"""
        Execute the nanovllm forward pass with hooks that capture breakpoints.
        Unlike the torch adapter, which manually steps through each component,
        we run the full forward pass and collect the captured values afterward.
"""
# Ensure 1D for nanovllm (it expects [num_tokens])
if input_ids.dim() == 2:
input_ids = input_ids.squeeze(0)
seq_len = input_ids.numel()
device = input_ids.device
# Generate position IDs if not provided
if positions is None:
positions = torch.arange(seq_len, device=device)
elif positions.dim() == 2:
positions = positions.squeeze(0)
# Register hooks
self._register_hooks()
try:
# Setup context for attention
self._setup_context(seq_len, device, is_prefill)
# Run forward pass (hooks capture everything)
with torch.no_grad():
hidden_states = self.model(input_ids, positions)
logits = self.model.compute_logits(hidden_states)
# Yield breakpoints in order from captured data
# EMBEDDING
if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured:
yield Breakpoint(
bp_type=BreakpointType.EMBEDDING,
layer_idx=None,
tensor=self._normalize_tensor(self._captured["embed"]),
name="Embedding",
)
# LAYER_OUTPUT for each layer
if self.is_enabled(BreakpointType.LAYER_OUTPUT):
for layer_idx in range(self.num_layers):
key = f"layer_{layer_idx}"
if key in self._captured:
yield Breakpoint(
bp_type=BreakpointType.LAYER_OUTPUT,
layer_idx=layer_idx,
tensor=self._normalize_tensor(self._captured[key]),
name=f"Layer {layer_idx}",
)
# FINAL_NORM
if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured:
yield Breakpoint(
bp_type=BreakpointType.FINAL_NORM,
layer_idx=None,
tensor=self._normalize_tensor(self._captured["final_norm"]),
name="Final Norm",
)
# LM_HEAD
if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured:
yield Breakpoint(
bp_type=BreakpointType.LM_HEAD,
layer_idx=None,
tensor=self._normalize_tensor(self._captured["lm_head"]),
name="LM Head",
)
return logits
        finally:
            # Restore the global context even if the forward pass raised
            reset_context()
            self._remove_hooks()
            self._captured = {}
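A hedged usage sketch: assuming `model` is a loaded nanovllm Qwen3ForCausalLM and `input_ids` is a 1D CUDA tensor of token IDs, the breakpoint filter restricts what the generator yields:

from nanovllm.debug import BreakpointType, NanovllmSteppable

steppable = NanovllmSteppable(model, enabled_breakpoints={BreakpointType.LM_HEAD})
for bp in steppable.step(input_ids, is_prefill=True):
    print(bp)  # only LM_HEAD breakpoints are yielded under this filter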

nanovllm/debug/adapters/torch_adapter.py
View File

@@ -0,0 +1,119 @@
"""Torch reference model adapter for breakpoint alignment."""
from typing import Generator, Set, Optional
import torch
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel
class TorchSteppable(SteppableModel):
"""
Steppable adapter for the torch reference Qwen3 implementation.
Wraps tests/modeling_qwen3.py Qwen3ForCausalLM model.
"""
def __init__(
self,
model,
enabled_breakpoints: Optional[Set[BreakpointType]] = None,
):
"""
Args:
model: Qwen3ForCausalLM from tests/modeling_qwen3.py
enabled_breakpoints: Set of breakpoint types to yield at
"""
super().__init__(enabled_breakpoints)
self.model = model
self.model.eval()
@property
def num_layers(self) -> int:
return len(self.model.model.layers)
def step(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> Generator[Breakpoint, None, torch.Tensor]:
"""
Generator that manually steps through the torch model.
The torch model uses [batch, seq_len, hidden_size] shapes.
"""
# Ensure batch dimension
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Generate position IDs if not provided
if positions is None:
positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
elif positions.dim() == 1:
positions = positions.unsqueeze(0)
with torch.no_grad():
# === EMBEDDING ===
hidden_states = self.model.model.embed_tokens(input_ids)
if self.is_enabled(BreakpointType.EMBEDDING):
yield Breakpoint(
bp_type=BreakpointType.EMBEDDING,
layer_idx=None,
tensor=hidden_states.detach().clone(),
name="Embedding",
)
# Create causal attention mask
causal_mask = torch.triu(
torch.full((seq_len, seq_len), float("-inf"), device=device),
diagonal=1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)
# === DECODER LAYERS ===
for layer_idx, layer in enumerate(self.model.model.layers):
hidden_states, _, _ = layer(
hidden_states=hidden_states,
position_ids=positions,
attention_mask=attention_mask,
past_key_value=None,
use_cache=False,
output_qkv=False,
)
if self.is_enabled(BreakpointType.LAYER_OUTPUT):
yield Breakpoint(
bp_type=BreakpointType.LAYER_OUTPUT,
layer_idx=layer_idx,
tensor=hidden_states.detach().clone(),
name=f"Layer {layer_idx}",
)
# === FINAL NORM ===
hidden_states = self.model.model.norm(hidden_states)
if self.is_enabled(BreakpointType.FINAL_NORM):
yield Breakpoint(
bp_type=BreakpointType.FINAL_NORM,
layer_idx=None,
tensor=hidden_states.detach().clone(),
name="Final Norm",
)
# === LM HEAD ===
logits = self.model.lm_head(hidden_states)
if self.is_enabled(BreakpointType.LM_HEAD):
yield Breakpoint(
bp_type=BreakpointType.LM_HEAD,
layer_idx=None,
tensor=logits.detach().clone(),
name="LM Head",
)
return logits
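The same filtering works on the reference side; a sketch that captures only per-layer outputs (torch_model loading elided; names follow the module docstring):

from nanovllm.debug import BreakpointType, TorchSteppable

ref = TorchSteppable(torch_model, enabled_breakpoints={BreakpointType.LAYER_OUTPUT})
for bp in ref.step(input_ids):
    t = bp.tensor.float()
    print(f"{bp.name}: mean={t.mean():.4f}, std={t.std():.4f}")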

nanovllm/debug/aligner.py
View File

@@ -0,0 +1,211 @@
"""Breakpoint aligner for comparing model outputs."""
from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import torch
from .breakpoints import Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .adapters.base import SteppableModel
@dataclass
class AlignmentResult:
"""Result of an alignment test."""
passed: bool
all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = field(default_factory=list)
failed_at: Optional[Breakpoint] = None
message: str = ""
def __repr__(self) -> str:
passed_count = sum(1 for _, _, c in self.all_comparisons if c.passed)
total = len(self.all_comparisons)
status = "PASSED" if self.passed else "FAILED"
return f"AlignmentResult({status}, {passed_count}/{total} breakpoints passed)"
class BreakpointAligner:
"""
Orchestrates alternating execution of reference and test models,
comparing outputs at each breakpoint.
Example:
>>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
>>> ref = TorchSteppable(torch_model)
>>> test = NanovllmSteppable(nanovllm_model)
>>> aligner = BreakpointAligner(ref, test)
>>> result = aligner.align(input_ids)
>>> print(result)
"""
def __init__(
self,
ref_model: SteppableModel,
test_model: SteppableModel,
comparator: Optional[TensorComparator] = None,
stop_on_error: bool = True,
verbose: bool = True,
):
"""
Args:
ref_model: Reference (torch) steppable model
test_model: Test (nanovllm) steppable model
comparator: Tensor comparator instance (uses default if None)
stop_on_error: If True, stop at first mismatch
verbose: If True, print comparison results
"""
self.ref_model = ref_model
self.test_model = test_model
self.comparator = comparator or TensorComparator()
self.stop_on_error = stop_on_error
self.verbose = verbose
def align(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> AlignmentResult:
"""
        Run both models on the same input, comparing outputs at each breakpoint.
Args:
input_ids: Input token IDs
positions: Position IDs (optional)
is_prefill: True for prefill phase, False for decode
Returns:
AlignmentResult with pass/fail status and details
"""
all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = []
if self.verbose:
phase = "prefill" if is_prefill else "decode"
print(f"\n{'='*60}")
print(f"Alignment Test ({phase})")
print(f"{'='*60}")
# Start both generators
ref_gen = self.ref_model.step(input_ids, positions, is_prefill)
test_gen = self.test_model.step(input_ids, positions, is_prefill)
try:
while True:
# Get next breakpoint from reference
try:
ref_bp = next(ref_gen)
except StopIteration:
break
# Get corresponding breakpoint from test
try:
test_bp = next(test_gen)
except StopIteration:
if self.verbose:
print(f"Test model ended early at {ref_bp.name}")
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=f"Test model ended early at {ref_bp.name}",
)
# Verify breakpoints match
if ref_bp.bp_type != test_bp.bp_type:
msg = f"Breakpoint type mismatch: {ref_bp.bp_type} vs {test_bp.bp_type}"
if self.verbose:
print(msg)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=msg,
)
if ref_bp.layer_idx != test_bp.layer_idx:
msg = f"Layer index mismatch: {ref_bp.layer_idx} vs {test_bp.layer_idx}"
if self.verbose:
print(msg)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=msg,
)
# Normalize shapes for comparison
ref_t = ref_bp.normalize_shape()
test_t = test_bp.normalize_shape()
# Handle shape mismatches
if ref_t.shape != test_t.shape:
if self.verbose:
print(f"[{ref_bp.name}] Shape mismatch: ref={ref_t.shape} vs test={test_t.shape}")
# Try to reshape if element count matches
if ref_t.numel() == test_t.numel():
test_t = test_t.view(ref_t.shape)
else:
msg = f"Shape mismatch at {ref_bp.name}: {ref_t.shape} vs {test_t.shape}"
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=msg,
)
# Compare tensors
result = self.comparator.compare(ref_t, test_t, ref_bp.name)
all_comparisons.append((ref_bp, test_bp, result))
if self.verbose:
status = "\u2713" if result.passed else "\u2717"
print(f"{status} [{ref_bp.name}] cos={result.cosine_similarity:.6f}, max_diff={result.max_abs_diff:.2e}")
if not result.passed and self.stop_on_error:
if self.verbose:
print(f"\nStopped at {ref_bp.name} (stop_on_error=True)")
print(result.message)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=f"Alignment failed at {ref_bp.name}",
)
# Check for extra test breakpoints
try:
extra_bp = next(test_gen)
msg = f"Test model has extra breakpoints starting at {extra_bp.name}"
if self.verbose:
print(msg)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
message=msg,
)
except StopIteration:
pass
except Exception as e:
msg = f"Exception during alignment: {str(e)}"
if self.verbose:
print(msg)
raise
# Summary
all_passed = all(comp[2].passed for comp in all_comparisons)
passed_count = sum(1 for _, _, c in all_comparisons if c.passed)
total = len(all_comparisons)
if self.verbose:
print(f"{'='*60}")
status = "PASSED" if all_passed else "FAILED"
print(f"Result: {status} ({passed_count}/{total} breakpoints)")
print(f"{'='*60}\n")
return AlignmentResult(
passed=all_passed,
all_comparisons=all_comparisons,
message="All breakpoints aligned" if all_passed else "Some breakpoints failed",
)
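With stop_on_error=False every breakpoint is compared, so failures can be inspected after the fact; a sketch assuming the ref/test adapters and input_ids from the class docstring:

from nanovllm.debug import BreakpointAligner

aligner = BreakpointAligner(ref, test, stop_on_error=False, verbose=False)
result = aligner.align(input_ids)
if not result.passed:
    for ref_bp, test_bp, cmp in result.all_comparisons:
        if not cmp.passed:
            print(cmp.message)  # per-metric breakdown for each failing breakpoint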

nanovllm/debug/breakpoints.py
View File

@@ -0,0 +1,39 @@
"""Breakpoint types and data structures for alignment debugging."""
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional
import torch
class BreakpointType(Enum):
"""Types of breakpoints in the model forward pass."""
EMBEDDING = auto() # After embed_tokens
LAYER_OUTPUT = auto() # After each decoder layer
FINAL_NORM = auto() # After final RMSNorm
LM_HEAD = auto() # After lm_head (logits)
@dataclass
class Breakpoint:
"""A captured breakpoint with tensor data."""
bp_type: BreakpointType
layer_idx: Optional[int] # None for EMBEDDING, FINAL_NORM, LM_HEAD
tensor: torch.Tensor
name: str
def normalize_shape(self) -> torch.Tensor:
"""
Normalize tensor shape for comparison.
nanovllm uses [num_tokens, hidden_size] while torch uses
[batch, seq_len, hidden_size]. This adds a batch dimension
to 2D tensors for comparison.
"""
if self.tensor.dim() == 2:
return self.tensor.unsqueeze(0)
return self.tensor
def __repr__(self) -> str:
shape_str = "x".join(str(d) for d in self.tensor.shape)
return f"Breakpoint({self.name}, shape={shape_str}, dtype={self.tensor.dtype})"

nanovllm/debug/comparator.py
View File

@@ -0,0 +1,94 @@
"""Tensor comparison utilities for alignment debugging."""
from dataclasses import dataclass
import torch
import torch.nn.functional as F
@dataclass
class ComparisonResult:
"""Result of comparing two tensors."""
passed: bool
cosine_similarity: float
max_abs_diff: float
mean_abs_diff: float
message: str
def __repr__(self) -> str:
status = "\u2713" if self.passed else "\u2717"
return f"{status} cos={self.cosine_similarity:.6f}, max_diff={self.max_abs_diff:.2e}"
class TensorComparator:
"""Compares tensors using cosine similarity and absolute differences."""
def __init__(
self,
cosine_threshold: float = 0.999,
max_diff_threshold: float = 0.1,
mean_diff_threshold: float = 0.01,
):
"""
Args:
cosine_threshold: Minimum cosine similarity to pass (0-1)
max_diff_threshold: Maximum allowed absolute difference
mean_diff_threshold: Maximum allowed mean absolute difference
"""
self.cosine_threshold = cosine_threshold
self.max_diff_threshold = max_diff_threshold
self.mean_diff_threshold = mean_diff_threshold
def compare(
self,
ref: torch.Tensor,
test: torch.Tensor,
name: str = "",
) -> ComparisonResult:
"""
Compare two tensors and return detailed result.
Args:
ref: Reference tensor
test: Test tensor
name: Name for the comparison (used in message)
Returns:
ComparisonResult with pass/fail status and metrics
"""
# Convert to float32 for comparison
ref_f = ref.float().flatten()
test_f = test.float().flatten()
# Cosine similarity
cos_sim = F.cosine_similarity(
ref_f.unsqueeze(0),
test_f.unsqueeze(0)
).item()
# Absolute differences
diff = (ref.float() - test.float()).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
# Check thresholds
passed = (
cos_sim >= self.cosine_threshold and
max_diff <= self.max_diff_threshold and
mean_diff <= self.mean_diff_threshold
)
status = "PASS" if passed else "FAIL"
message = (
f"[{name}] {status}\n"
f" Cosine Similarity: {cos_sim:.6f} (threshold: {self.cosine_threshold})\n"
f" Max Abs Diff: {max_diff:.6f} (threshold: {self.max_diff_threshold})\n"
f" Mean Abs Diff: {mean_diff:.6f} (threshold: {self.mean_diff_threshold})"
)
return ComparisonResult(
passed=passed,
cosine_similarity=cos_sim,
max_abs_diff=max_diff,
mean_abs_diff=mean_diff,
message=message,
)
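A standalone sketch of the comparator (thresholds and shapes are illustrative, not a recommendation):

import torch
from nanovllm.debug import TensorComparator

comparator = TensorComparator(cosine_threshold=0.9999, max_diff_threshold=0.05)
ref = torch.randn(1, 8, 1024)
test = ref + 1e-4 * torch.randn_like(ref)  # simulate small numerical drift
result = comparator.compare(ref, test, name="Layer 0")
print(result)          # one-line summary via __repr__
print(result.message)  # full per-metric report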

nanovllm/debug/utils.py
View File

@@ -0,0 +1,51 @@
"""Utility functions for breakpoint alignment debugging."""
import torch
from nanovllm.utils.context import set_context, reset_context
def setup_prefill_context(seq_len: int, device: torch.device):
"""
Set up nanovllm context for prefill alignment testing.
Args:
seq_len: Sequence length
device: Target device
"""
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
slot_mapping=slot_mapping,
is_chunked_prefill=False,
)
def setup_decode_context(context_len: int, device: torch.device):
"""
Set up nanovllm context for decode alignment testing.
Args:
context_len: Context length (number of previous tokens)
device: Target device
"""
context_lens = torch.tensor([context_len], dtype=torch.int32, device=device)
slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
)
def cleanup_context():
"""Reset nanovllm context after alignment testing."""
reset_context()
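These helpers allow a bare nanovllm forward pass outside the engine; a hedged sketch, assuming a loaded nanovllm model and 1D input_ids/positions tensors on the same device:

import torch
from nanovllm.debug import setup_prefill_context, cleanup_context

setup_prefill_context(seq_len=input_ids.numel(), device=input_ids.device)
try:
    with torch.no_grad():
        hidden_states = model(input_ids, positions)
finally:
    cleanup_context()  # always restore the global context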

tests/test_nanovllm_steppable.py
View File

@@ -0,0 +1,134 @@
"""
Test NanovllmSteppable: Print activation statistics at each layer.
Usage:
python tests/test_nanovllm_steppable.py
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import torch
from transformers import AutoTokenizer
from nanovllm import LLM
from nanovllm.debug.adapters.nanovllm_adapter import NanovllmSteppable
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Config
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 32768 # Longer context to test offload
MAX_NEW_TOKENS = 20
DTYPE = torch.float16
ENABLE_CPU_OFFLOAD = True # Test offload mode
# ============================================================
# Load Model
# ============================================================
print(f"Loading nanovllm model (cpu_offload={ENABLE_CPU_OFFLOAD})...")
llm = LLM(
MODEL_PATH,
enforce_eager=True, # Required for hooks to work
max_model_len=40960,
max_num_batched_tokens=40960,
enable_cpu_offload=ENABLE_CPU_OFFLOAD,
dtype="float16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Get the underlying model for steppable
model = llm.model_runner.model
# ============================================================
# Prepare Input (using needle-in-haystack prompt)
# ============================================================
prompt, expected_answer = generate_needle_prompt(
tokenizer,
target_length=INPUT_LEN,
needle_position=0.5,
needle_value="7492",
use_chat_template=False,
verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
print(f"Expected answer: {expected_answer}\n")
# ============================================================
# Create Steppable Model (reused for prefill + decode)
# ============================================================
steppable = NanovllmSteppable(model)
# ============================================================
# Prefill Phase: Print activation stats
# ============================================================
print("=" * 85)
print("PREFILL PHASE")
print("=" * 85)
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
print("-" * 85)
current_ids = input_ids.clone()
logits = None
for bp in steppable.step(current_ids, is_prefill=True):
t = bp.tensor.float()
shape_str = str(list(t.shape))
print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
if bp.name == "LM Head":
logits = bp.tensor
# Get first token from prefill
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
generated_tokens = [next_token]
# ============================================================
# Decode Phase: Only print generated tokens
# ============================================================
print("\n" + "=" * 85)
print("DECODE PHASE")
print("=" * 85)
print(f"Step 1: {next_token!r}")
for step in range(2, MAX_NEW_TOKENS + 1):
    # Forward pass over the full sequence (reuse the same steppable)
    # Note: without a KV cache, nanovllm needs the full sequence for each decode step
for bp in steppable.step(current_ids, is_prefill=True):
if bp.name == "LM Head":
logits = bp.tensor
# Get next token (greedy)
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
generated_tokens.append(next_token)
print(f"Step {step:2}: {next_token!r}")
# Stop if EOS
if next_token_id == tokenizer.eos_token_id:
print(" (EOS)")
break
# Append to sequence
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
# ============================================================
# Result
# ============================================================
print("\n" + "=" * 85)
print("RESULT")
print("=" * 85)
generated_text = "".join(generated_tokens)
print(f"Generated: {generated_text!r}")
print(f"Expected: {expected_answer}")
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")
print("\ntest_nanovllm_steppable: PASSED")

View File

@@ -8,7 +8,7 @@ sequences longer than ~200 tokens. Use --no-offload for correctness testing.
"""
import os
-os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
+os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams

tests/test_torch_steppable.py
View File

@@ -0,0 +1,121 @@
"""
Test TorchSteppable: Print activation statistics at each layer.
Usage:
python tests/test_torch_steppable.py
"""
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from nanovllm.debug.adapters.torch_adapter import TorchSteppable
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Config
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 512
MAX_NEW_TOKENS = 20
DTYPE = torch.float16
# ============================================================
# Load Model
# ============================================================
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
model = model.to("cuda").eval()
# ============================================================
# Prepare Input (using needle-in-haystack prompt)
# ============================================================
prompt, expected_answer = generate_needle_prompt(
tokenizer,
target_length=INPUT_LEN,
needle_position=0.5,
needle_value="7492",
use_chat_template=False,
verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
print(f"Expected answer: {expected_answer}\n")
# ============================================================
# Create Steppable Model (reused for prefill + decode)
# ============================================================
steppable = TorchSteppable(model)
# ============================================================
# Prefill Phase: Print activation stats
# ============================================================
print("=" * 85)
print("PREFILL PHASE")
print("=" * 85)
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
print("-" * 85)
current_ids = input_ids.clone()
logits = None
for bp in steppable.step(current_ids):
t = bp.tensor.float()
shape_str = str(list(t.shape))
print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
if bp.name == "LM Head":
logits = bp.tensor
# Get first token from prefill
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
generated_tokens = [next_token]
# ============================================================
# Decode Phase: Only print generated tokens
# ============================================================
print("\n" + "=" * 85)
print("DECODE PHASE")
print("=" * 85)
print(f"Step 1: {next_token!r}")
for step in range(2, MAX_NEW_TOKENS + 1):
# Forward pass (reuse same steppable)
for bp in steppable.step(current_ids):
if bp.name == "LM Head":
logits = bp.tensor
# Get next token (greedy)
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
generated_tokens.append(next_token)
print(f"Step {step:2}: {next_token!r}")
# Stop if EOS
if next_token_id == tokenizer.eos_token_id:
print(" (EOS)")
break
# Append to sequence
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
# ============================================================
# Result
# ============================================================
print("\n" + "=" * 85)
print("RESULT")
print("=" * 85)
generated_text = "".join(generated_tokens)
print(f"Generated: {generated_text!r}")
print(f"Expected: {expected_answer}")
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")
print("\ntest_torch_steppable: PASSED")