
Commit 79dafa3

Allow using an ONNX model inside the sklearn interface

1 parent a537f40 commit 79dafa3

9 files changed: +388 −14 lines


src/tabpfn/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 from importlib.metadata import version
 
 from tabpfn.classifier import TabPFNClassifier
-from tabpfn.debug_versions import display_debug_info
+from tabpfn.misc.debug_versions import display_debug_info
 from tabpfn.regressor import TabPFNRegressor
 
 try:

src/tabpfn/base.py

Lines changed: 36 additions & 0 deletions

@@ -33,6 +33,7 @@
 if TYPE_CHECKING:
     import numpy as np
 
+    from tabpfn.misc.onnx_wrapper import ONNXModelWrapper
     from tabpfn.model.bar_distribution import FullSupportBarDistribution
     from tabpfn.model.config import InferenceConfig
     from tabpfn.model.transformer import PerFeatureTransformer
@@ -111,6 +112,36 @@ def initialize_tabpfn_model(
     return model, config_, bar_distribution
 
 
+def load_onnx_model(
+    model_path: str | Path,
+) -> ONNXModelWrapper:
+    """Load a TabPFN model in ONNX format.
+
+    Args:
+        model_path: Path to the ONNX model file.
+
+    Returns:
+        The loaded ONNX model wrapped in a PyTorch-compatible interface.
+
+    Raises:
+        ImportError: If onnxruntime is not installed.
+        FileNotFoundError: If the model file doesn't exist.
+    """
+    try:
+        from tabpfn.misc.onnx_wrapper import ONNXModelWrapper
+    except ImportError as err:
+        raise ImportError(
+            "onnxruntime is required to load ONNX models. "
+            "Install it with: pip install onnxruntime",
+        ) from err
+
+    model_path = Path(model_path)
+    if not model_path.exists():
+        raise FileNotFoundError(f"ONNX model not found at: {model_path}")
+
+    return ONNXModelWrapper(str(model_path))
+
+
 def determine_precision(
     inference_precision: torch.dtype | Literal["autocast", "auto"],
     device_: torch.device,
@@ -168,6 +199,7 @@ def create_inference_engine(  # noqa: PLR0913
     forced_inference_dtype_: torch.dtype | None,
     memory_saving_mode: bool | Literal["auto"] | float | int,
     use_autocast_: bool,
+    use_onnx: bool = False,
 ) -> InferenceEngine:
     """Creates the appropriate TabPFN inference engine based on `fit_mode`.
 
@@ -190,6 +222,7 @@ def create_inference_engine(  # noqa: PLR0913
         forced_inference_dtype_: If not None, the forced dtype for inference.
         memory_saving_mode: GPU/CPU memory saving settings.
         use_autocast_: Whether we use torch.autocast for inference.
+        use_onnx: Whether to use ONNX runtime for model inference.
     """
     engine: (
         InferenceEngineOnDemand
@@ -208,6 +241,7 @@ def create_inference_engine(  # noqa: PLR0913
             dtype_byte_size=byte_size,
             force_inference_dtype=forced_inference_dtype_,
             save_peak_mem=memory_saving_mode,
+            use_onnx=use_onnx,
         )
     elif fit_mode == "fit_preprocessors":
         engine = InferenceEngineCachePreprocessing.prepare(
@@ -221,6 +255,7 @@ def create_inference_engine(  # noqa: PLR0913
             dtype_byte_size=byte_size,
             force_inference_dtype=forced_inference_dtype_,
             save_peak_mem=memory_saving_mode,
+            use_onnx=use_onnx,
         )
     elif fit_mode == "fit_with_cache":
         engine = InferenceEngineCacheKV.prepare(
@@ -236,6 +271,7 @@ def create_inference_engine(  # noqa: PLR0913
             force_inference_dtype=forced_inference_dtype_,
             save_peak_mem=memory_saving_mode,
             autocast=use_autocast_,
+            use_onnx=use_onnx,
         )
     else:
         raise ValueError(f"Invalid fit_mode: {fit_mode}")
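The ONNXModelWrapper imported above lives in tabpfn/misc/onnx_wrapper.py, which is not part of the hunks shown in this commit. As a rough sketch of what a "PyTorch-compatible interface" over an onnxruntime session could look like — the class name and module path come from the diff, while the single-input/single-output layout and the __call__ signature are illustrative guesses:

# Hypothetical sketch of tabpfn/misc/onnx_wrapper.py; only the class name and
# module path appear in this commit. Everything else is an assumption.
import numpy as np
import onnxruntime as ort
import torch


class ONNXModelWrapper:
    """Expose an onnxruntime session behind a torch-style callable."""

    def __init__(self, model_path: str) -> None:
        self.session = ort.InferenceSession(model_path)
        # Assumes the graph has a single input; the real TabPFN graph may not.
        self.input_name = self.session.get_inputs()[0].name

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # onnxruntime consumes and produces numpy arrays, so convert at the edges.
        (out,) = self.session.run(
            None,
            {self.input_name: x.detach().cpu().numpy().astype(np.float32)},
        )
        return torch.from_numpy(out)

Note that load_onnx_model performs its import lazily inside a try/except, so onnxruntime stays an optional dependency: users who never set use_onnx=True need not install it.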

src/tabpfn/classifier.py

Lines changed: 16 additions & 6 deletions

@@ -32,6 +32,7 @@
     create_inference_engine,
     determine_precision,
     initialize_tabpfn_model,
+    load_onnx_model,
 )
 from tabpfn.config import ModelInterfaceConfig
 from tabpfn.constants import (
@@ -148,6 +149,7 @@ def __init__(  # noqa: PLR0913
         random_state: int | np.random.RandomState | np.random.Generator | None = 0,
         n_jobs: int = -1,
         inference_config: dict | ModelInterfaceConfig | None = None,
+        use_onnx: bool = False,
     ) -> None:
         """A TabPFN interface for classification.
 
@@ -337,6 +339,9 @@ def __init__(  # noqa: PLR0913
                 - If `dict`, the key-value pairs are used to update the default
                   `ModelInterfaceConfig`. Raises an error if an unknown key is passed.
                 - If `ModelInterfaceConfig`, the object is used as the configuration.
+
+            use_onnx:
+                Whether to use an ONNX compiled model.
         """
         super().__init__()
         self.n_estimators = n_estimators
@@ -359,6 +364,7 @@ def __init__(  # noqa: PLR0913
         self.random_state = random_state
         self.n_jobs = n_jobs
         self.inference_config = inference_config
+        self.use_onnx = use_onnx
 
     # TODO: We can remove this from scikit-learn lower bound of 1.6
     def _more_tags(self) -> dict[str, Any]:
@@ -383,12 +389,15 @@ def fit(self, X: XType, y: YType) -> Self:
         static_seed, rng = infer_random_state(self.random_state)
 
         # Load the model and config
-        self.model_, self.config_, _ = initialize_tabpfn_model(
-            model_path=self.model_path,
-            which="classifier",
-            fit_mode=self.fit_mode,
-            static_seed=static_seed,
-        )
+        if self.use_onnx:
+            self.model_ = load_onnx_model("model_classifier.onnx")
+        else:
+            self.model_, self.config_, _ = initialize_tabpfn_model(
+                model_path=self.model_path,
+                which="classifier",
+                fit_mode=self.fit_mode,
+                static_seed=static_seed,
+            )
 
         # Determine device and precision
         self.device_ = infer_device_and_type(self.device)
@@ -500,6 +509,7 @@ def fit(self, X: XType, y: YType) -> Self:
             forced_inference_dtype_=self.forced_inference_dtype_,
             memory_saving_mode=self.memory_saving_mode,
             use_autocast_=self.use_autocast_,
+            use_onnx=self.use_onnx,
         )
 
         return self
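Taken together with base.py, the flag is usable straight from the sklearn-style API. A minimal usage sketch, assuming onnxruntime is installed and an exported model_classifier.onnx sits in the current working directory (the path is hardcoded in fit above):

# Usage sketch for the new flag; requires onnxruntime and an exported
# model_classifier.onnx in the working directory (path hardcoded in fit).
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = TabPFNClassifier(use_onnx=True)  # new parameter added in this commit
clf.fit(X_train, y_train)
print(clf.predict(X_test))

One caveat visible in the diff: on the ONNX path, fit skips initialize_tabpfn_model, so self.config_ is never set and the model_path argument is ignored in favor of the hardcoded "model_classifier.onnx".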

src/tabpfn/inference.py

Lines changed: 6 additions & 1 deletion

@@ -219,9 +219,10 @@ class InferenceEngineCachePreprocessing(InferenceEngine):
     preprocessors: Sequence[SequentialFeatureTransformer]
     model: PerFeatureTransformer
     force_inference_dtype: torch.dtype | None
+    use_onnx: bool = False
 
     @classmethod
-    def prepare(
+    def prepare(  # noqa: PLR0913
         cls,
         X_train: np.ndarray,
         y_train: np.ndarray,
@@ -234,6 +235,7 @@ def prepare(
         dtype_byte_size: int,
         force_inference_dtype: torch.dtype | None,
         save_peak_mem: bool | Literal["auto"] | float | int,
+        use_onnx: bool = False,
     ) -> InferenceEngineCachePreprocessing:
         """Prepare the inference engine.
 
@@ -248,6 +250,7 @@ def prepare(
             dtype_byte_size: The byte size of the dtype.
             force_inference_dtype: The dtype to force inference to.
             save_peak_mem: Whether to save peak memory usage.
+            use_onnx: Whether to use ONNX for inference.
 
         Returns:
             The prepared inference engine.
@@ -272,6 +275,7 @@ def prepare(
             dtype_byte_size=dtype_byte_size,
             force_inference_dtype=force_inference_dtype,
             save_peak_mem=save_peak_mem,
+            use_onnx=use_onnx,
         )
 
     @override
@@ -315,6 +319,7 @@ def iter_outputs(
             device=device,
             dtype_byte_size=self.dtype_byte_size,
             safety_factor=1.2,  # TODO(Arjun): make customizable
+            use_onnx=self.use_onnx,
         )
 
         style = None
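Because InferenceEngineCachePreprocessing is a dataclass, the new use_onnx field gets a default of False and prepare() threads it through with the same default, so existing call sites keep working unchanged. A distilled, self-contained sketch of that pattern (the class and fields below are illustrative stand-ins, not TabPFN's real signatures):

# Pattern sketch: adding a defaulted feature flag to a dataclass-based engine
# without breaking existing callers. Names are illustrative, not TabPFN's.
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Engine:
    save_peak_mem: bool = True
    use_onnx: bool = False  # new field, defaulted for backward compatibility

    @classmethod
    def prepare(cls, *, save_peak_mem: bool = True, use_onnx: bool = False) -> Engine:
        return cls(save_peak_mem=save_peak_mem, use_onnx=use_onnx)


print(Engine.prepare().use_onnx)               # False: old call sites unchanged
print(Engine.prepare(use_onnx=True).use_onnx)  # True: new opt-in path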

src/tabpfn/misc/__init__.py

Whitespace-only changes.
One further file was renamed without changes.
