Skip to content

Commit db2e7d5

Browse files
authored
make regression metrics 'multioutput' behavior consistent with scikit-learn (fixes dask#818) (dask#820)
* make regression metrics 'multioutput' behavior consistent with scikit-learn * add test on error messages * linting
1 parent 27d8d37 commit db2e7d5

File tree

2 files changed

+64
-12
lines changed

2 files changed

+64
-12
lines changed

dask_ml/metrics/regression.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def _check_sample_weight(sample_weight: Optional[ArrayLike]):
1616
def _check_reg_targets(
1717
y_true: ArrayLike, y_pred: ArrayLike, multioutput: Optional[str]
1818
):
19-
if multioutput != "uniform_average":
19+
if multioutput is not None and multioutput != "uniform_average":
2020
raise NotImplementedError("'multioutput' must be 'uniform_average'")
2121

2222
if y_true.ndim == 1:
@@ -40,12 +40,12 @@ def mean_squared_error(
4040
_check_sample_weight(sample_weight)
4141
output_errors = ((y_pred - y_true) ** 2).mean(axis=0)
4242

43-
if isinstance(multioutput, str):
43+
if isinstance(multioutput, str) or multioutput is None:
4444
if multioutput == "raw_values":
45-
return output_errors
46-
elif multioutput == "uniform_average":
47-
# pass None as weights to np.average: uniform mean
48-
multioutput = None
45+
if compute:
46+
return output_errors.compute()
47+
else:
48+
return output_errors
4949
else:
5050
raise ValueError("Weighted 'multioutput' not supported.")
5151
result = output_errors.mean()
@@ -67,12 +67,12 @@ def mean_absolute_error(
6767
_check_sample_weight(sample_weight)
6868
output_errors = abs(y_pred - y_true).mean(axis=0)
6969

70-
if isinstance(multioutput, str):
70+
if isinstance(multioutput, str) or multioutput is None:
7171
if multioutput == "raw_values":
72-
return output_errors
73-
elif multioutput == "uniform_average":
74-
# pass None as weights to np.average: uniform mean
75-
multioutput = None
72+
if compute:
73+
return output_errors.compute()
74+
else:
75+
return output_errors
7676
else:
7777
raise ValueError("Weighted 'multioutput' not supported.")
7878
result = output_errors.mean()
@@ -153,7 +153,7 @@ def r2_score(
153153
compute: bool = True,
154154
) -> ArrayLike:
155155
_check_sample_weight(sample_weight)
156-
_, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)
156+
_, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, multioutput)
157157
weight = 1.0
158158

159159
numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8")

tests/metrics/test_regression.py

+52
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numbers
22

33
import dask.array as da
4+
import numpy as np
45
import pytest
56
import sklearn.metrics
7+
from dask.array.utils import assert_eq
68

79
import dask_ml.metrics
810
from dask_ml._compat import SK_024
@@ -72,3 +74,53 @@ def test_mean_squared_log_error():
7274
result = m1(a, b)
7375
expected = m2(a, b)
7476
assert abs(result - expected) < 1e-5
77+
78+
79+
@pytest.mark.parametrize("multioutput", ["uniform_average", None])
80+
def test_regression_metrics_unweighted_average_multioutput(metric_pairs, multioutput):
81+
m1, m2 = metric_pairs
82+
83+
a = da.random.uniform(size=(100,), chunks=(25,))
84+
b = da.random.uniform(size=(100,), chunks=(25,))
85+
86+
result = m1(a, b, multioutput=multioutput)
87+
expected = m2(a, b, multioutput=multioutput)
88+
assert abs(result - expected) < 1e-5
89+
90+
91+
@pytest.mark.parametrize("compute", [True, False])
92+
def test_regression_metrics_raw_values(metric_pairs, compute):
93+
m1, m2 = metric_pairs
94+
95+
if m1.__name__ == "r2_score":
96+
pytest.skip("r2_score does not support multioutput='raw_values'")
97+
98+
a = da.random.uniform(size=(100, 3), chunks=(25, 3))
99+
b = da.random.uniform(size=(100, 3), chunks=(25, 3))
100+
101+
result = m1(a, b, multioutput="raw_values", compute=compute)
102+
expected = m2(a, b, multioutput="raw_values")
103+
104+
if compute:
105+
assert isinstance(result, np.ndarray)
106+
else:
107+
assert isinstance(result, da.Array)
108+
109+
assert_eq(result, expected)
110+
assert result.shape == (3,)
111+
112+
113+
def test_regression_metrics_do_not_support_weighted_multioutput(metric_pairs):
114+
m1, _ = metric_pairs
115+
116+
a = da.random.uniform(size=(100, 3), chunks=(25, 3))
117+
b = da.random.uniform(size=(100, 3), chunks=(25, 3))
118+
weights = da.random.uniform(size=(3,))
119+
120+
if m1.__name__ == "r2_score":
121+
error_msg = "'multioutput' must be 'uniform_average'"
122+
else:
123+
error_msg = "Weighted 'multioutput' not supported."
124+
125+
with pytest.raises((NotImplementedError, ValueError), match=error_msg):
126+
_ = m1(a, b, multioutput=weights)

0 commit comments

Comments
 (0)