78 changes: 53 additions & 25 deletions src/tabpfn/classifier.py
@@ -24,8 +24,9 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from typing_extensions import Self

import pandas as pd
import numpy as np
import numpy.typing as npt
import torch
from sklearn import config_context
from sklearn.base import BaseEstimator, ClassifierMixin, check_is_fitted
@@ -75,10 +76,31 @@
from tabpfn.architectures.interface import ArchitectureConfig
from tabpfn.config import ModelInterfaceConfig

try:
    from sklearn.base import Tags
except ImportError:
    # `Tags` is not available in this scikit-learn version; fall back to
    # `Any`, which is already imported from `typing` above.
    Tags = Any


def preprocess_input(X: pd.DataFrame) -> pd.DataFrame:
    """Safely handle missing values in an input DataFrame.

    - Numeric columns: fill NaN with 0
    - Categorical/text columns: fill NaN with "missing"
    """
    X = X.copy()
    for col in X.columns:
        if X[col].dtype == "object":  # categorical/text
            X[col] = X[col].fillna("missing").astype(str)
        else:  # numeric
            X[col] = X[col].fillna(0).astype(float)
    return X
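# A minimal usage sketch of `preprocess_input` (illustrative only, not part of
# the module's public API); the column names and values below are hypothetical:
#
#     raw = pd.DataFrame({"color": ["red", None], "size": [1.5, float("nan")]})
#     clean = preprocess_input(raw)
#     # clean["color"] -> ["red", "missing"]  (object column, NaN -> "missing")
#     # clean["size"]  -> [1.5, 0.0]          (numeric column, NaN -> 0.0)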




class TabPFNClassifier(ClassifierMixin, BaseEstimator):
@@ -102,7 +124,7 @@ class TabPFNClassifier(ClassifierMixin, BaseEstimator):
the first is used for inference.
"""

feature_names_in_: npt.NDArray[Any]
feature_names_in_: npt.NDArray[np.generic]
"""The feature names of the input data.

May not be set if the input data does not have feature names,
@@ -633,12 +655,15 @@ def fit_from_preprocessed(
@config_context(transform_output="default") # type: ignore
@track_model_call(model_method="fit", param_names=["X", "y"])
def fit(self, X: XType, y: YType) -> Self:
"""Fit the model.
"""
Fit the model with preprocessing of missing values.

Args:
X: The input data.
y: The target variable.
X: The input data.
y: The target variable.
"""
X = preprocess_input(X)

if self.fit_mode == "batched":
logging.warning(
"The model was in 'batched' mode, likely after finetuning. "
@@ -653,29 +678,30 @@ def fit(self, X: XType, y: YType) -> Self:
else: # already fitted and prompt_tuning mode: no cat. features
_, rng = infer_random_state(self.random_state)
_, _, byte_size = determine_precision(
    self.inference_precision, self.devices_
)

# Create the inference engine
self.executor_ = create_inference_engine(
    X_train=X,
    y_train=y,
    model=self.model_,
    ensemble_configs=ensemble_configs,
    cat_ix=self.inferred_categorical_indices_,
    fit_mode=self.fit_mode,
    devices_=self.devices_,
    rng=rng,
    n_jobs=self.n_jobs,
    byte_size=byte_size,
    forced_inference_dtype_=self.forced_inference_dtype_,
    memory_saving_mode=self.memory_saving_mode,
    use_autocast_=self.use_autocast_,
    inference_mode=not self.differentiable_input,
)

return self


def _raw_predict(self, X: XType, *, return_logits: bool) -> torch.Tensor:
"""Internal method to run prediction.

@@ -694,6 +720,7 @@ def _raw_predict(self, X: XType, *, return_logits: bool) -> torch.Tensor:
depending on `return_logits`.
"""
check_is_fitted(self)
X = preprocess_input(X)

if not self.differentiable_input:
X = validate_X_predict(X, self)
@@ -978,3 +1005,4 @@ def load_from_fit_state(
f"Attempting to load a '{est.__class__.__name__}' as '{cls.__name__}'"
)
return est

39 changes: 39 additions & 0 deletions tests/test_na_handling.py
@@ -0,0 +1,39 @@
"""Test to verify that TabPFNClassifier handles missing values gracefully."""

from __future__ import annotations

import pandas as pd
from sklearn.model_selection import train_test_split

from tabpfn import TabPFNClassifier


def test_classifier_handles_na_values() -> None:
"""Ensure that TabPFNClassifier can train and predict
with NA values in input data.
"""
data = {
"feature1": ["a", "b", pd.NA, "d"],
"feature2": [1, 2, 3, 4],
"target": [0, 1, 0, 1],
}
df = pd.DataFrame(data)
X = df[["feature1", "feature2"]]
y = df["target"]

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.25, random_state=42
)

clf = TabPFNClassifier(device="cpu")

# Train and predict
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)

# Assert predictions length matches
assert len(predictions) == len(y_test)

# Predictions should be valid labels only
assert set(predictions).issubset(set(y.unique()))
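

# A possible companion test (a sketch, not part of the submitted diff): it
# exercises the numeric branch of `preprocess_input` by putting NaN in a float
# column and checking that predict_proba still returns one row per sample.
# Column names and values here are illustrative only.
import numpy as np  # assumed available alongside pandas in this test suite


def test_classifier_handles_numeric_nan() -> None:
    """Sketch: NaN in a numeric column should not break fit or predict_proba."""
    X = pd.DataFrame(
        {
            "num1": [0.5, np.nan, 2.0, 3.5, 1.0, np.nan],
            "num2": [1, 2, 3, 4, 5, 6],
        }
    )
    y = pd.Series([0, 1, 0, 1, 0, 1])

    clf = TabPFNClassifier(device="cpu")
    clf.fit(X, y)

    # One probability row per sample, one column per class
    proba = clf.predict_proba(X)
    assert proba.shape == (len(X), 2)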