diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py
index 34adc8c0e..af872d8e7 100644
--- a/src/tabpfn/classifier.py
+++ b/src/tabpfn/classifier.py
@@ -24,8 +24,10 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal
 from typing_extensions import Self
 
 import numpy as np
+import numpy.typing as npt
+import pandas as pd
 import torch
 from sklearn import config_context
 from sklearn.base import BaseEstimator, ClassifierMixin, check_is_fitted
@@ -75,10 +77,29 @@
 from tabpfn.architectures.interface import ArchitectureConfig
 from tabpfn.config import ModelInterfaceConfig
 
 try:
     from sklearn.base import Tags
 except ImportError:
     Tags = Any
 
 
+def preprocess_input(X: XType) -> XType:
+    """Fill missing values in a pandas DataFrame before fit/predict.
+
+    - Object (categorical/text) columns: NaN is filled with "missing".
+    - All other (numeric) columns: NaN is filled with 0.
+
+    Inputs that are not DataFrames (e.g. numpy arrays) are returned unchanged.
+    """
+    if not isinstance(X, pd.DataFrame):
+        return X
+    X = X.copy()
+    for col in X.columns:
+        if X[col].dtype == "object":  # categorical/text
+            X[col] = X[col].fillna("missing").astype(str)
+        else:  # numeric
+            X[col] = X[col].fillna(0).astype(float)
+    return X
+
+
 class TabPFNClassifier(ClassifierMixin, BaseEstimator):
@@ -102,7 +123,7 @@ class TabPFNClassifier(ClassifierMixin, BaseEstimator):
     the first is used for inference.
     """
 
-    feature_names_in_: npt.NDArray[Any]
+    feature_names_in_: npt.NDArray[np.generic]
     """The feature names of the input data.
 
     May not be set if the input data does not have feature names,
@@ -633,12 +654,14 @@ def fit_from_preprocessed(
     @config_context(transform_output="default")  # type: ignore
     @track_model_call(model_method="fit", param_names=["X", "y"])
     def fit(self, X: XType, y: YType) -> Self:
-        """Fit the model.
+        """Fit the model, filling missing values in the input first.
 
         Args:
             X: The input data.
             y: The target variable.
         """
+        X = preprocess_input(X)
+
         if self.fit_mode == "batched":
             logging.warning(
                 "The model was in 'batched' mode, likely after finetuning. "
@@ -694,6 +717,7 @@ def _raw_predict(self, X: XType, *, return_logits: bool) -> torch.Tensor:
             depending on `return_logits`.
         """
         check_is_fitted(self)
+        X = preprocess_input(X)
 
         if not self.differentiable_input:
             X = validate_X_predict(X, self)
diff --git a/tests/test_na_handling.py b/tests/test_na_handling.py
new file mode 100644
index 000000000..03ed114f1
--- /dev/null
+++ b/tests/test_na_handling.py
@@ -0,0 +1,39 @@
+"""Test to verify that TabPFNClassifier handles missing values gracefully."""
+
+from __future__ import annotations
+
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+from tabpfn import TabPFNClassifier
+
+
+def test_classifier_handles_na_values() -> None:
+    """Ensure that TabPFNClassifier can train and predict
+    with NA values in the input data.
+    """
+    data = {
+        "feature1": ["a", "b", pd.NA, "d"],
+        "feature2": [1, 2, 3, 4],
+        "target": [0, 1, 0, 1],
+    }
+    df = pd.DataFrame(data)
+    X = df[["feature1", "feature2"]]
+    y = df["target"]
+
+    # Train-test split
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=0.25, random_state=42
+    )
+
+    clf = TabPFNClassifier(device="cpu")
+
+    # Train and predict
+    clf.fit(X_train, y_train)
+    predictions = clf.predict(X_test)
+
+    # Prediction length matches the test set
+    assert len(predictions) == len(y_test)
+
+    # Predictions should be valid labels only
+    assert set(predictions).issubset(set(y.unique()))
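For quick review, here is a minimal standalone sketch of what the new `preprocess_input` helper does to a mixed-type frame. It assumes the patch above is applied, so the helper is importable from `tabpfn.classifier`; the column names and values are made up for illustration.

```python
import numpy as np
import pandas as pd

from tabpfn.classifier import preprocess_input  # helper added by this patch

# A mixed-type frame with missing entries in an object column and a numeric column.
X = pd.DataFrame(
    {
        "color": ["red", None, "blue"],  # object (categorical/text) column
        "size": [1.0, np.nan, 3.0],      # numeric column
    }
)

X_clean = preprocess_input(X)

# Object NaNs become the string "missing"; numeric NaNs become 0.0.
assert X_clean.loc[1, "color"] == "missing"
assert X_clean.loc[1, "size"] == 0.0

# Non-DataFrame inputs (e.g. numpy arrays) pass through unchanged.
arr = np.array([[1.0, np.nan]])
assert preprocess_input(arr) is arr
```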
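Because `_raw_predict` is patched as well, missing values are also filled at inference time. Below is a hedged end-to-end sketch of that path with made-up data; it assumes `predict_proba` routes through `_raw_predict` (its "Internal method to run prediction" docstring suggests it does) and that a small CPU run of the default model is acceptable for a smoke test.

```python
import pandas as pd

from tabpfn import TabPFNClassifier

# Training data already contains missing values, so the "missing" fill value
# is seen during fit as well as at prediction time.
X_train = pd.DataFrame(
    {
        "f_num": [1.0, None, 3.0, 4.0],
        "f_cat": ["a", "b", None, "b"],
    }
)
y_train = pd.Series([0, 1, 0, 1])

X_test = pd.DataFrame({"f_num": [2.5, None], "f_cat": [None, "a"]})

clf = TabPFNClassifier(device="cpu")
clf.fit(X_train, y_train)

proba = clf.predict_proba(X_test)  # should not raise on the NaN/None cells
assert proba.shape == (2, 2)       # (n_samples, n_classes)
```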