Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: change warning into infolog #584

Merged
merged 2 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 7 additions & 17 deletions mapie/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import re
from typing import Any, Optional, Tuple

Expand Down Expand Up @@ -228,22 +229,24 @@ def test_valid_verbose(verbose: Any) -> None:
check_verbose(verbose)


def test_initial_low_high_pred() -> None:
def test_initial_low_high_pred(caplog) -> None:
"""Test lower/upper predictions of the quantiles regression crossing"""
y_preds = np.array([[4, 5, 2], [4, 4, 4], [2, 3, 4]])
with pytest.warns(UserWarning, match=r"WARNING: The predictions are*"):
with caplog.at_level(logging.INFO):
check_lower_upper_bounds(y_preds[0], y_preds[1], y_preds[2])
assert "The predictions are ill-sorted" in caplog.text


def test_final_low_high_pred() -> None:
def test_final_low_high_pred(caplog) -> None:
"""Test lower/upper predictions crossing"""
y_preds = np.array(
[[4, 3, 2], [3, 3, 3], [2, 3, 4]]
)
y_pred_low = np.array([4, 7, 2])
y_pred_up = np.array([3, 3, 3])
with pytest.warns(UserWarning, match=r"WARNING: The predictions are*"):
with caplog.at_level(logging.INFO):
check_lower_upper_bounds(y_pred_low, y_pred_up, y_preds[2])
assert "The predictions are ill-sorted" in caplog.text


def test_ensemble_in_predict() -> None:
Expand Down Expand Up @@ -331,19 +334,6 @@ def test_quantile_prefit_non_iterable(estimator: Any) -> None:
mapie_reg.fit([1, 2, 3], [4, 5, 6])


# def test_calib_set_no_Xy_but_sample_weight() -> None:
# """Test warning message if sample weight provided but no X y in calib."""
# X = np.array([4, 5, 6])
# y = np.array([4, 3, 2])
# sample_weight = np.array([4, 4, 4])
# sample_weight_calib = np.array([4, 3, 4])
# with pytest.warns(UserWarning, match=r"WARNING: sample weight*"):
# check_calib_set(
# X=X, y=y, sample_weight=sample_weight,
# sample_weight_calib=sample_weight_calib
# )


@pytest.mark.parametrize("strategy", ["quantile", "uniform", "array split"])
def test_binning_group_strategies(strategy: str) -> None:
"""Test that different strategies have the correct outputs."""
Expand Down
41 changes: 6 additions & 35 deletions mapie/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import warnings
from inspect import signature
from typing import Any, Iterable, Optional, Tuple, Union, cast
Expand Down Expand Up @@ -573,39 +574,6 @@ def check_lower_upper_bounds(
y_pred_up: NDArray,
y_preds: NDArray
) -> None:
"""
Check if lower or upper bounds and prediction are consistent.

Parameters
----------
y_pred_low: NDArray of shape (n_samples,)
Lower bound prediction.

y_pred_up: NDArray of shape (n_samples,)
Upper bound prediction.

y_preds: NDArray of shape (n_samples,)
Prediction.

Raises
------
Warning
If any of the predictions are ill-sorted.

Examples
--------
>>> import warnings
>>> warnings.filterwarnings("error")
>>> import numpy as np
>>> from mapie.utils import check_lower_upper_bounds
>>> y_preds = np.array([[4, 3, 2], [4, 4, 4], [2, 3, 4]])
>>> try:
... check_lower_upper_bounds(y_preds[0], y_preds[1], y_preds[2])
... except Exception as exception:
... print(exception)
...
WARNING: The predictions are ill-sorted.
"""
y_pred_low = column_or_1d(y_pred_low)
y_pred_up = column_or_1d(y_pred_up)
y_preds = column_or_1d(y_preds)
Expand All @@ -617,9 +585,12 @@ def check_lower_upper_bounds(
)

if any_inversion:
warnings.warn(
"WARNING: The predictions are ill-sorted."
initial_logger_level = logging.root.level
logging.basicConfig(level=logging.INFO)
logging.info(
"The predictions are ill-sorted."
)
logging.basicConfig(level=initial_logger_level)


def check_defined_variables_predict_cqr(
Expand Down
Loading