Skip to content

Commit b8cfb56

Browse files
hcho3jameslambtrivialfis
authored
[backport] Compatibility fixes for scikit-learn 1.6 (dmlc#11021, dmlc#11162) (dmlc#11205)
* Adapt to scikit-learn 1.6 estimator tag changes (dmlc#11021) * More sklearn tag support. (dmlc#11162) * [CI] Unpin scikit-learn * Remove test_doc_link() test --------- Co-authored-by: James Lamb <[email protected]> Co-authored-by: Jiaming Yuan <[email protected]>
1 parent 30a7fd5 commit b8cfb56

File tree

14 files changed

+197
-33
lines changed

14 files changed

+197
-33
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,13 @@ credentials.csv
139139
.bloop
140140

141141
# python tests
142+
*.bin
142143
demo/**/*.txt
143144
*.dmatrix
144145
.hypothesis
145146
__MACOSX/
146147
model*.json
148+
/tests/python/models/models/
147149

148150
# R tests
149151
*.htm

python-package/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ extension-pkg-whitelist = ["numpy"]
6262
disable = [
6363
"attribute-defined-outside-init",
6464
"import-outside-toplevel",
65+
"too-few-public-methods",
66+
"too-many-ancestors",
6567
"too-many-nested-blocks",
6668
"unexpected-special-method-signature",
6769
"unsubscriptable-object",

python-package/xgboost/compat.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,43 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:
4545

4646
# sklearn
4747
try:
48+
from sklearn import __version__ as _sklearn_version
4849
from sklearn.base import BaseEstimator as XGBModelBase
4950
from sklearn.base import ClassifierMixin as XGBClassifierBase
5051
from sklearn.base import RegressorMixin as XGBRegressorBase
51-
from sklearn.preprocessing import LabelEncoder
5252

5353
try:
54-
from sklearn.model_selection import KFold as XGBKFold
5554
from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
5655
except ImportError:
57-
from sklearn.cross_validation import KFold as XGBKFold
5856
from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold
5957

58+
# sklearn.utils Tags types can be imported unconditionally once
59+
# xgboost's minimum scikit-learn version is 1.6 or higher
60+
try:
61+
from sklearn.utils import Tags as _sklearn_Tags
62+
except ImportError:
63+
_sklearn_Tags = object
64+
6065
SKLEARN_INSTALLED = True
6166

6267
except ImportError:
6368
SKLEARN_INSTALLED = False
6469

6570
# used for compatibility without sklearn
66-
XGBModelBase = object
67-
XGBClassifierBase = object
68-
XGBRegressorBase = object
69-
LabelEncoder = object
71+
class XGBModelBase: # type: ignore[no-redef]
72+
"""Dummy class for sklearn.base.BaseEstimator."""
73+
74+
class XGBClassifierBase: # type: ignore[no-redef]
75+
"""Dummy class for sklearn.base.ClassifierMixin."""
76+
77+
class XGBRegressorBase: # type: ignore[no-redef]
78+
"""Dummy class for sklearn.base.RegressorMixin."""
7079

71-
XGBKFold = None
7280
XGBStratifiedKFold = None
7381

82+
_sklearn_Tags = object
83+
_sklearn_version = object
84+
7485

7586
_logger = logging.getLogger(__name__)
7687

python-package/xgboost/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def c_array(
410410
def from_array_interface(interface: dict) -> NumpyOrCupy:
411411
"""Convert array interface to numpy or cupy array"""
412412

413-
class Array: # pylint: disable=too-few-public-methods
413+
class Array:
414414
"""Wrapper type for communicating with numpy and cupy."""
415415

416416
_interface: Optional[dict] = None

python-package/xgboost/dask/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# pylint: disable=too-many-arguments, too-many-locals
22
# pylint: disable=missing-class-docstring, invalid-name
33
# pylint: disable=too-many-lines
4-
# pylint: disable=too-few-public-methods
5-
# pylint: disable=import-error
64
"""
75
Dask extensions for distributed training
86
----------------------------------------

python-package/xgboost/sklearn.py

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,14 @@
2929

3030
# Do not use class names on scikit-learn directly. Re-define the classes on
3131
# .compat to guarantee the behavior without scikit-learn
32-
from .compat import SKLEARN_INSTALLED, XGBClassifierBase, XGBModelBase, XGBRegressorBase
32+
from .compat import (
33+
SKLEARN_INSTALLED,
34+
XGBClassifierBase,
35+
XGBModelBase,
36+
XGBRegressorBase,
37+
_sklearn_Tags,
38+
_sklearn_version,
39+
)
3340
from .config import config_context
3441
from .core import (
3542
Booster,
@@ -45,7 +52,7 @@
4552
from .training import train
4653

4754

48-
class XGBRankerMixIn: # pylint: disable=too-few-public-methods
55+
class XGBRankerMixIn:
4956
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
5057
base classes.
5158
@@ -69,7 +76,7 @@ def _can_use_qdm(tree_method: Optional[str]) -> bool:
6976
return tree_method in ("hist", "gpu_hist", None, "auto")
7077

7178

72-
class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods
79+
class _SklObjWProto(Protocol):
7380
def __call__(
7481
self,
7582
y_true: ArrayLike,
@@ -782,11 +789,52 @@ def __init__(
782789

783790
def _more_tags(self) -> Dict[str, bool]:
784791
"""Tags used for scikit-learn data validation."""
785-
tags = {"allow_nan": True, "no_validation": True}
792+
tags = {"allow_nan": True, "no_validation": True, "sparse": True}
786793
if hasattr(self, "kwargs") and self.kwargs.get("updater") == "shotgun":
787794
tags["non_deterministic"] = True
795+
796+
tags["categorical"] = self.enable_categorical
797+
return tags
798+
799+
@staticmethod
800+
def _update_sklearn_tags_from_dict(
801+
*,
802+
tags: _sklearn_Tags,
803+
tags_dict: Dict[str, bool],
804+
) -> _sklearn_Tags:
805+
"""Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes.
806+
807+
``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags.
808+
ref: https://github.com/scikit-learn/scikit-learn/pull/29677
809+
810+
This method handles updating that instance based on the values in
811+
``self._more_tags()``.
812+
813+
"""
814+
tags.non_deterministic = tags_dict.get("non_deterministic", False)
815+
tags.no_validation = tags_dict["no_validation"]
816+
tags.input_tags.allow_nan = tags_dict["allow_nan"]
817+
tags.input_tags.sparse = tags_dict["sparse"]
818+
tags.input_tags.categorical = tags_dict["categorical"]
788819
return tags
789820

821+
def __sklearn_tags__(self) -> _sklearn_Tags:
822+
# XGBModelBase.__sklearn_tags__() cannot be called unconditionally,
823+
# because that method isn't defined for scikit-learn<1.6
824+
if not hasattr(XGBModelBase, "__sklearn_tags__"):
825+
err_msg = (
826+
"__sklearn_tags__() should not be called when using scikit-learn<1.6. "
827+
f"Detected version: {_sklearn_version}"
828+
)
829+
raise AttributeError(err_msg)
830+
831+
# take whatever tags are provided by BaseEstimator, then modify
832+
# them with XGBoost-specific values
833+
return self._update_sklearn_tags_from_dict(
834+
tags=super().__sklearn_tags__(), # pylint: disable=no-member
835+
tags_dict=self._more_tags(),
836+
)
837+
790838
def __sklearn_is_fitted__(self) -> bool:
791839
return hasattr(self, "_Booster")
792840

@@ -841,13 +889,27 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
841889
"""Get parameters."""
842890
# Based on: https://stackoverflow.com/questions/59248211
843891
# The basic flow in `get_params` is:
844-
# 0. Return parameters in subclass first, by using inspect.
845-
# 1. Return parameters in `XGBModel` (the base class).
892+
# 0. Return parameters in subclass (self.__class__) first, by using inspect.
893+
# 1. Return parameters in all parent classes (especially `XGBModel`).
846894
# 2. Return whatever in `**kwargs`.
847895
# 3. Merge them.
896+
#
897+
# This needs to accommodate being called recursively in the following
898+
# inheritance graphs (and similar for classification and ranking):
899+
#
900+
# XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator
901+
# XGBRegressor -> XGBModel -> BaseEstimator
902+
# XGBModel -> BaseEstimator
903+
#
848904
params = super().get_params(deep)
849905
cp = copy.copy(self)
850-
cp.__class__ = cp.__class__.__bases__[0]
906+
# If the immediate parent defines get_params(), use that.
907+
if callable(getattr(cp.__class__.__bases__[0], "get_params", None)):
908+
cp.__class__ = cp.__class__.__bases__[0]
909+
# Otherwise, skip it and assume the next class will have it.
910+
# This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
911+
else:
912+
cp.__class__ = cp.__class__.__bases__[1]
851913
params.update(cp.__class__.get_params(cp, deep))
852914
# if kwargs is a dict, update params accordingly
853915
if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):
@@ -1431,7 +1493,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
14311493
Number of boosting rounds.
14321494
""",
14331495
)
1434-
class XGBClassifier(XGBModel, XGBClassifierBase):
1496+
class XGBClassifier(XGBClassifierBase, XGBModel):
14351497
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
14361498
@_deprecate_positional_args
14371499
def __init__(
@@ -1447,6 +1509,12 @@ def _more_tags(self) -> Dict[str, bool]:
14471509
tags["multilabel"] = True
14481510
return tags
14491511

1512+
def __sklearn_tags__(self) -> _sklearn_Tags:
1513+
tags = super().__sklearn_tags__()
1514+
tags_dict = self._more_tags()
1515+
tags.classifier_tags.multi_label = tags_dict["multilabel"]
1516+
return tags
1517+
14501518
@_deprecate_positional_args
14511519
def fit(
14521520
self,
@@ -1717,7 +1785,7 @@ def fit(
17171785
"Implementation of the scikit-learn API for XGBoost regression.",
17181786
["estimators", "model", "objective"],
17191787
)
1720-
class XGBRegressor(XGBModel, XGBRegressorBase):
1788+
class XGBRegressor(XGBRegressorBase, XGBModel):
17211789
# pylint: disable=missing-docstring
17221790
@_deprecate_positional_args
17231791
def __init__(
@@ -1731,6 +1799,13 @@ def _more_tags(self) -> Dict[str, bool]:
17311799
tags["multioutput_only"] = False
17321800
return tags
17331801

1802+
def __sklearn_tags__(self) -> _sklearn_Tags:
1803+
tags = super().__sklearn_tags__()
1804+
tags_dict = self._more_tags()
1805+
tags.target_tags.multi_output = tags_dict["multioutput"]
1806+
tags.target_tags.single_output = not tags_dict["multioutput_only"]
1807+
return tags
1808+
17341809

17351810
@xgboost_model_doc(
17361811
"scikit-learn API for XGBoost random forest regression.",
@@ -1858,7 +1933,7 @@ def _get_qid(
18581933
`qid` can be a special column of input `X` instead of a separated parameter, see
18591934
:py:meth:`fit` for more info.""",
18601935
)
1861-
class XGBRanker(XGBModel, XGBRankerMixIn):
1936+
class XGBRanker(XGBRankerMixIn, XGBModel):
18621937
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
18631938
@_deprecate_positional_args
18641939
def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any):

python-package/xgboost/spark/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import base64
44

5-
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
6-
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
5+
# pylint: disable=fixme, protected-access, no-member, invalid-name
6+
# pylint: disable=too-many-lines, too-many-branches
77
import json
88
import logging
99
import os

python-package/xgboost/spark/estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Xgboost pyspark integration submodule for estimator API."""
22

3-
# pylint: disable=too-many-ancestors
4-
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
3+
# pylint: disable=fixme, protected-access, no-member, invalid-name
54
# pylint: disable=unused-argument, too-many-locals
65

76
import warnings

python-package/xgboost/spark/params.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from typing import Dict
44

5-
# pylint: disable=too-few-public-methods
65
from pyspark.ml.param import TypeConverters
76
from pyspark.ml.param.shared import Param, Params
87

python-package/xgboost/spark/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _get_default_params_from_func(
4343
return filtered_params_dict
4444

4545

46-
class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
46+
class CommunicatorContext(CCtx):
4747
"""Context with PySpark specific task ID."""
4848

4949
def __init__(self, context: BarrierTaskContext, **args: Any) -> None:

python-package/xgboost/testing/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def is_binary(self) -> bool:
564564
return self.max_rel == 1
565565

566566

567-
class PBM: # pylint: disable=too-few-public-methods
567+
class PBM:
568568
"""Simulate click data with position bias model. There are other models available in
569569
`ULTRA <https://github.com/ULTR-Community/ULTRA.git>`_ like the cascading model.
570570

tests/ci_build/Dockerfile.gpu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ RUN \
2727
"nccl>=${NCCL_SHORT_VER}" \
2828
dask \
2929
dask-cuda=$RAPIDS_VERSION_ARG* dask-cudf=$RAPIDS_VERSION_ARG* cupy \
30-
numpy pytest pytest-timeout scipy \
31-
"scikit-learn<=1.5.2" \
30+
numpy pytest pytest-timeout scipy scikit-learn \
3231
pandas matplotlib wheel python-kubernetes urllib3 graphviz "hypothesis<=6.112" \
3332
"pyspark>=3.4.0" cloudpickle cuda-python && \
3433
mamba clean --all && \

tests/ci_build/conda_env/win64_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ dependencies:
66
- numpy
77
- scipy
88
- matplotlib
9-
- scikit-learn<=1.5.2
9+
- scikit-learn
1010
- pandas
1111
- pytest
1212
- boto3

0 commit comments

Comments
 (0)