Files
nano-vllm/nanovllm/debug/adapters/base.py
2026-01-03 22:36:40 +08:00

60 lines
1.7 KiB
Python

"""Base class for steppable model adapters."""
from abc import ABC, abstractmethod
from typing import Generator, Set, Optional
import torch
from ..breakpoints import Breakpoint, BreakpointType
class SteppableModel(ABC):
    """
    Interface for model adapters whose forward pass can pause at breakpoints.

    Concrete adapters provide ``step()`` as a generator: it yields a
    ``Breakpoint`` each time the forward pass reaches a breakpoint type that
    is currently enabled, and its return value is the final output tensor.
    """

    def __init__(self, enabled_breakpoints: Optional[Set[BreakpointType]] = None):
        """
        Record which breakpoint types this adapter should stop at.

        Args:
            enabled_breakpoints: Breakpoint types to yield at. ``None``
                (the default) means every breakpoint type is enabled.
        """
        # The set is stored as given (no copy is made here).
        self.enabled_breakpoints = enabled_breakpoints

    def is_enabled(self, bp_type: BreakpointType) -> bool:
        """Return True when ``bp_type`` is a breakpoint type to yield at."""
        active = self.enabled_breakpoints
        # A missing filter (None) enables everything; otherwise the type
        # must be a member of the configured set.
        return active is None or bp_type in active

    @abstractmethod
    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Run the forward pass as a generator, pausing at enabled breakpoints.

        Args:
            input_ids: Input token IDs.
            positions: Position IDs; optional, implementations are expected
                to auto-generate them when this is None.
            is_prefill: True for the prefill phase, False for decode.

        Yields:
            A Breakpoint object at each enabled checkpoint.

        Returns:
            The final output tensor (logits).
        """
        ...

    @property
    @abstractmethod
    def num_layers(self) -> int:
        """Number of decoder layers in the wrapped model."""
        ...