21 changes: 16 additions & 5 deletions .github/workflows/pull_request.yml
@@ -36,6 +36,15 @@ jobs:
- os: windows-latest
python-version: "3.9"
dependency-set: minimum
- os: ubuntu-latest
python-version: "3.11"
dependency-set: direct-install
- os: macos-latest
python-version: "3.11"
dependency-set: direct-install
- os: windows-latest
python-version: "3.11"
dependency-set: direct-install
- os: ubuntu-latest
python-version: "3.13"
dependency-set: maximum
@@ -73,14 +82,16 @@ jobs:

- name: Install dependencies
run: |
uv pip install --system --no-deps .
# onnx is required for onnx export tests
# we don't install all dev dependencies here for speed
uv pip install --system -r requirements.txt
if [[ "${{ matrix.dependency-set }}" == "direct-install" ]]; then
uv pip install --system .
else
uv pip install --system --no-deps .
uv pip install --system -r requirements.txt
fi
uv pip install --system pytest psutil
# onnx is not supported on python 3.13 yet https://github.com/onnx/onnx/issues/6339
if [[ "${{ matrix.python-version }}" != "3.13" ]]; then
uv pip install --system onnx
uv pip install --system onnx onnxruntime
fi
shell: bash

1 change: 1 addition & 0 deletions pyproject.toml
@@ -63,6 +63,7 @@ dev = [
# Test
"pytest",
"onnx", # required for onnx export tests
"onnxruntime",
"psutil", # required for testing internal memory tool on windows
# Docs
"mkdocs",
64 changes: 56 additions & 8 deletions src/tabpfn/base.py
@@ -27,21 +27,26 @@
InferenceEngineCachePreprocessing,
InferenceEngineOnDemand,
)
from tabpfn.model.loading import load_model_criterion_config
from tabpfn.utils import infer_fp16_inference_mode
from tabpfn.model.loading import (
load_model_criterion_config,
)
from tabpfn.utils import (
infer_fp16_inference_mode,
)

if TYPE_CHECKING:
import numpy as np
import pandas as pd

from tabpfn.misc.compile_to_onnx import ONNXModelWrapper
from tabpfn.model.bar_distribution import FullSupportBarDistribution
from tabpfn.model.config import InferenceConfig
from tabpfn.model.transformer import PerFeatureTransformer


@overload
def initialize_tabpfn_model(
model_path: str | Path | Literal["auto"],
model_path: Path,
which: Literal["regressor"],
fit_mode: Literal["low_memory", "fit_preprocessors", "fit_with_cache"],
static_seed: int,
@@ -50,15 +55,15 @@ def initialize_tabpfn_model(

@overload
def initialize_tabpfn_model(
model_path: str | Path | Literal["auto"],
model_path: Path,
which: Literal["classifier"],
fit_mode: Literal["low_memory", "fit_preprocessors", "fit_with_cache"],
static_seed: int,
) -> tuple[PerFeatureTransformer, InferenceConfig, None]: ...


def initialize_tabpfn_model(
model_path: str | Path | Literal["auto"],
model_path: Path,
which: Literal["classifier", "regressor"],
fit_mode: Literal["low_memory", "fit_preprocessors", "fit_with_cache"],
static_seed: int,
@@ -79,8 +84,6 @@ def initialize_tabpfn_model(
"""
# Handle auto model_path
download = True
if isinstance(model_path, str) and model_path == "auto":
model_path = None # type: ignore

# Load model with potential caching
if which == "classifier":
@@ -112,6 +115,46 @@
return model, config_, bar_distribution


def load_onnx_model(
model_path: Path,
device: torch.device,
) -> ONNXModelWrapper:
"""Load a TabPFN model in ONNX format.

Args:
model_path: Path to the ONNX model file.
device: The device to run the model on.

Returns:
The loaded ONNX model wrapped in a PyTorch-compatible interface.

Raises:
ImportError: If onnxruntime is not installed.
FileNotFoundError: If the model file doesn't exist.
"""
try:
from tabpfn.misc.compile_to_onnx import ONNXModelWrapper
except ImportError as err:
raise ImportError(
"onnxruntime is required to load ONNX models. "
"Install it with: pip install onnxruntime-gpu"
"or pip install onnxruntime",
) from err

if not model_path.exists():
raise FileNotFoundError(
f"ONNX model not found at: {model_path}, "
"please compile the model by running "
"`from tabpfn.misc.compile_to_onnx import compile_onnx_models; "
"compile_onnx_models()`"
"or change `model_path`.",
)

return ONNXModelWrapper(model_path, device)


def determine_precision(
inference_precision: torch.dtype | Literal["autocast", "auto"],
device_: torch.device,
@@ -158,7 +201,7 @@ def create_inference_engine( # noqa: PLR0913
*,
X_train: np.ndarray,
y_train: np.ndarray,
model: PerFeatureTransformer,
model: PerFeatureTransformer | ONNXModelWrapper,
ensemble_configs: Any,
cat_ix: list[int],
fit_mode: Literal["low_memory", "fit_preprocessors", "fit_with_cache"],
@@ -169,6 +212,7 @@ def create_inference_engine( # noqa: PLR0913
forced_inference_dtype_: torch.dtype | None,
memory_saving_mode: bool | Literal["auto"] | float | int,
use_autocast_: bool,
use_onnx: bool = False,
) -> InferenceEngine:
"""Creates the appropriate TabPFN inference engine based on `fit_mode`.

@@ -191,6 +235,7 @@ def create_inference_engine( # noqa: PLR0913
forced_inference_dtype_: If not None, the forced dtype for inference.
memory_saving_mode: GPU/CPU memory saving settings.
use_autocast_: Whether we use torch.autocast for inference.
use_onnx: Whether to use ONNX runtime for model inference.
"""
engine: (
InferenceEngineOnDemand
@@ -209,6 +254,7 @@ def create_inference_engine( # noqa: PLR0913
dtype_byte_size=byte_size,
force_inference_dtype=forced_inference_dtype_,
save_peak_mem=memory_saving_mode,
use_onnx=use_onnx,
)
elif fit_mode == "fit_preprocessors":
engine = InferenceEngineCachePreprocessing.prepare(
@@ -222,6 +268,7 @@ def create_inference_engine( # noqa: PLR0913
dtype_byte_size=byte_size,
force_inference_dtype=forced_inference_dtype_,
save_peak_mem=memory_saving_mode,
use_onnx=use_onnx,
)
elif fit_mode == "fit_with_cache":
engine = InferenceEngineCacheKV.prepare(
@@ -237,6 +284,7 @@ def create_inference_engine( # noqa: PLR0913
force_inference_dtype=forced_inference_dtype_,
save_peak_mem=memory_saving_mode,
autocast=use_autocast_,
use_onnx=use_onnx,
)
else:
raise ValueError(f"Invalid fit_mode: {fit_mode}")
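For reference, a minimal sketch of how the new ONNX loading path in base.py could be exercised. compile_onnx_models() and load_onnx_model() are the functions touched by this diff; the model filename and the CPU device are placeholders, not values taken from the PR.

from pathlib import Path

import torch

from tabpfn.base import load_onnx_model
from tabpfn.misc.compile_to_onnx import compile_onnx_models

# One-time export of the checkpoints to ONNX (assumed to need onnx/onnxruntime installed).
compile_onnx_models()

# Load the exported model; raises FileNotFoundError if the path is wrong.
model = load_onnx_model(
    Path("tabpfn-v2-classifier.onnx"),  # placeholder filename
    device=torch.device("cpu"),
)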
49 changes: 41 additions & 8 deletions src/tabpfn/classifier.py
@@ -35,6 +35,7 @@
create_inference_engine,
determine_precision,
initialize_tabpfn_model,
load_onnx_model,
)
from tabpfn.config import ModelInterfaceConfig
from tabpfn.constants import (
@@ -43,6 +44,7 @@
XType,
YType,
)
from tabpfn.model.loading import resolve_model_path
from tabpfn.preprocessing import (
ClassifierEnsembleConfig,
EnsembleConfig,
@@ -68,7 +70,9 @@
from torch.types import _dtype

from tabpfn.inference import InferenceEngine
from tabpfn.misc.compile_to_onnx import ONNXModelWrapper
from tabpfn.model.config import InferenceConfig
from tabpfn.model.transformer import PerFeatureTransformer

try:
from sklearn.base import Tags
@@ -131,6 +135,9 @@ class TabPFNClassifier(ClassifierMixin, BaseEstimator):
preprocessor_: ColumnTransformer
"""The column transformer used to preprocess the input data to be numeric."""

model_: PerFeatureTransformer | ONNXModelWrapper
"""The loaded model used for inference."""

def __init__( # noqa: PLR0913
self,
*,
@@ -152,6 +159,7 @@ def __init__( # noqa: PLR0913
random_state: int | np.random.RandomState | np.random.Generator | None = 0,
n_jobs: int = -1,
inference_config: dict | ModelInterfaceConfig | None = None,
use_onnx: bool = False,
) -> None:
"""A TabPFN interface for classification.

@@ -341,6 +349,9 @@ def __init__( # noqa: PLR0913
- If `dict`, the key-value pairs are used to update the default
`ModelInterfaceConfig`. Raises an error if an unknown key is passed.
- If `ModelInterfaceConfig`, the object is used as the configuration.

use_onnx:
Whether to use an ONNX-compiled model for inference. Requires the
ONNX models to have been compiled beforehand via
`tabpfn.misc.compile_to_onnx.compile_onnx_models`.
"""
super().__init__()
self.n_estimators = n_estimators
@@ -363,6 +374,7 @@ def __init__( # noqa: PLR0913
self.random_state = random_state
self.n_jobs = n_jobs
self.inference_config = inference_config
self.use_onnx = use_onnx

# TODO: We can remove this from scikit-learn lower bound of 1.6
def _more_tags(self) -> dict[str, Any]:
@@ -387,20 +399,40 @@ def fit(self, X: XType, y: YType) -> Self:
"""
static_seed, rng = infer_random_state(self.random_state)

# Load the model and config
self.model_, self.config_, _ = initialize_tabpfn_model(
model_path=self.model_path,
which="classifier",
fit_mode=self.fit_mode,
static_seed=static_seed,
)

# Determine device and precision
self.device_ = infer_device_and_type(self.device)
(self.use_autocast_, self.forced_inference_dtype_, byte_size) = (
determine_precision(self.inference_precision, self.device_)
)

model_path, _, _ = resolve_model_path(
self.model_path,
which="classifier",
version="v2",
use_onnx=self.use_onnx,
)
# Load the model and config
if self.use_onnx:
# Reuse the existing ONNX session if a previous fit call already
# loaded the model for the same path and device.
if hasattr(self, "model_") and (model_path, self.device_) == (
getattr(self.model_, "model_path", None),
getattr(self.model_, "device", None),
):
print("Using same ONNX session as last fit call") # noqa: T201
else:
self.model_ = load_onnx_model(
model_path,
device=self.device_,
)
else:
self.model_, self.config_, _ = initialize_tabpfn_model(
model_path=model_path,
which="classifier",
fit_mode=self.fit_mode,
static_seed=static_seed,
)

# Build the interface_config
self.interface_config_ = ModelInterfaceConfig.from_user_input(
inference_config=self.inference_config,
@@ -515,6 +547,7 @@ def fit(self, X: XType, y: YType) -> Self:
forced_inference_dtype_=self.forced_inference_dtype_,
memory_saving_mode=self.memory_saving_mode,
use_autocast_=self.use_autocast_,
use_onnx=self.use_onnx,
)

return self
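A minimal sketch of the new `use_onnx` flag at the estimator level, assuming the ONNX models have been compiled first; X and y are placeholder arrays, and the import paths follow the package layout shown in this diff.

import numpy as np

from tabpfn.classifier import TabPFNClassifier
from tabpfn.misc.compile_to_onnx import compile_onnx_models

# Must run once; otherwise fit(use_onnx=True) fails with FileNotFoundError.
compile_onnx_models()

# Placeholder toy data for illustration.
X = np.random.rand(100, 5)
y = np.random.randint(0, 2, size=100)

clf = TabPFNClassifier(use_onnx=True)
clf.fit(X, y)
print(clf.predict_proba(X[:5]))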