Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,7 @@ def fit(
) -> "XGBClassifier":
# pylint: disable = attribute-defined-outside-init,too-many-statements
with config_context(verbosity=self.verbosity):
encoded_classes = None
# We keep the n_classes_ as a simple member instead of loading it from
# booster in a Python property. This way we can have efficient and
# thread-safe prediction.
Expand All @@ -1744,22 +1745,28 @@ def fit(
elif _is_cupy_alike(y):
cp = import_cupy()

classes = cp.unique(y)
classes, encoded_classes = cp.unique(y, return_inverse=True)
self.n_classes_ = len(classes)
expected_classes = cp.array(self.classes_)
else:
classes = np.unique(np.asarray(y))
classes, encoded_classes = np.unique(np.asarray(y), return_inverse=True)
self.n_classes_ = len(classes)
expected_classes = self.classes_
if (
classes.shape != expected_classes.shape
or not (classes == expected_classes).all()
):
if classes.shape != expected_classes.shape:
raise ValueError(
f"Invalid classes inferred from unique values of `y`. "
f"Expected: {expected_classes}, got {classes}"
)

self.classes_labels: ArrayLike | None = None
# if classes are not label encoded as [0, 1, ..., n_classes - 1], then
# we need to transform them.
if not (classes == expected_classes).all():
self.classes_labels = classes
if encoded_classes is None:
encoded_classes = np.unique(np.asarray(y), return_inverse=True)[1]
y = encoded_classes

params = self.get_xgb_params()

if callable(self.objective):
Expand Down Expand Up @@ -1861,6 +1868,10 @@ def predict(
column_indexes = np.repeat(0, class_probs.shape[0])
column_indexes[class_probs > 0.5] = 1

# map back to original class labels
if self.classes_labels is not None:
return self.classes_labels[column_indexes]

return column_indexes

def predict_proba(
Expand Down