|
1 | 1 | import numbers
|
2 | 2 |
|
3 | 3 | import dask.array as da
|
| 4 | +import numpy as np |
4 | 5 | import pytest
|
5 | 6 | import sklearn.metrics
|
| 7 | +from dask.array.utils import assert_eq |
6 | 8 |
|
7 | 9 | import dask_ml.metrics
|
8 | 10 | from dask_ml._compat import SK_024
|
@@ -72,3 +74,53 @@ def test_mean_squared_log_error():
|
72 | 74 | result = m1(a, b)
|
73 | 75 | expected = m2(a, b)
|
74 | 76 | assert abs(result - expected) < 1e-5
|
| 77 | + |
| 78 | + |
| 79 | +@pytest.mark.parametrize("multioutput", ["uniform_average", None]) |
| 80 | +def test_regression_metrics_unweighted_average_multioutput(metric_pairs, multioutput): |
| 81 | + m1, m2 = metric_pairs |
| 82 | + |
| 83 | + a = da.random.uniform(size=(100,), chunks=(25,)) |
| 84 | + b = da.random.uniform(size=(100,), chunks=(25,)) |
| 85 | + |
| 86 | + result = m1(a, b, multioutput=multioutput) |
| 87 | + expected = m2(a, b, multioutput=multioutput) |
| 88 | + assert abs(result - expected) < 1e-5 |
| 89 | + |
| 90 | + |
| 91 | +@pytest.mark.parametrize("compute", [True, False]) |
| 92 | +def test_regression_metrics_raw_values(metric_pairs, compute): |
| 93 | + m1, m2 = metric_pairs |
| 94 | + |
| 95 | + if m1.__name__ == "r2_score": |
| 96 | + pytest.skip("r2_score does not support multioutput='raw_values'") |
| 97 | + |
| 98 | + a = da.random.uniform(size=(100, 3), chunks=(25, 3)) |
| 99 | + b = da.random.uniform(size=(100, 3), chunks=(25, 3)) |
| 100 | + |
| 101 | + result = m1(a, b, multioutput="raw_values", compute=compute) |
| 102 | + expected = m2(a, b, multioutput="raw_values") |
| 103 | + |
| 104 | + if compute: |
| 105 | + assert isinstance(result, np.ndarray) |
| 106 | + else: |
| 107 | + assert isinstance(result, da.Array) |
| 108 | + |
| 109 | + assert_eq(result, expected) |
| 110 | + assert result.shape == (3,) |
| 111 | + |
| 112 | + |
| 113 | +def test_regression_metrics_do_not_support_weighted_multioutput(metric_pairs): |
| 114 | + m1, _ = metric_pairs |
| 115 | + |
| 116 | + a = da.random.uniform(size=(100, 3), chunks=(25, 3)) |
| 117 | + b = da.random.uniform(size=(100, 3), chunks=(25, 3)) |
| 118 | + weights = da.random.uniform(size=(3,)) |
| 119 | + |
| 120 | + if m1.__name__ == "r2_score": |
| 121 | + error_msg = "'multioutput' must be 'uniform_average'" |
| 122 | + else: |
| 123 | + error_msg = "Weighted 'multioutput' not supported." |
| 124 | + |
| 125 | + with pytest.raises((NotImplementedError, ValueError), match=error_msg): |
| 126 | + _ = m1(a, b, multioutput=weights) |
0 commit comments