|
| 1 | +"""Unit tests for model fitting.""" |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pytest |
| 5 | + |
| 6 | +from pkoffee.data import AnyShapeDataDtypeArray, neg_inf, pos_inf |
| 7 | +from pkoffee.data import data_dtype as dt |
| 8 | +from pkoffee.fit_model import ( |
| 9 | + FunctionIdNotFoundInMappingError, |
| 10 | + FunctionNotFoundInMappingError, |
| 11 | + Model, |
| 12 | + ModelParsingError, |
| 13 | + fit_model, |
| 14 | +) |
| 15 | +from pkoffee.parametric_function import ParametersBounds |
| 16 | + |
| 17 | + |
class Linear:
    """Straight-line test function ``y = a * x + b``."""

    def __call__(self, x: AnyShapeDataDtypeArray, a: dt, b: dt) -> AnyShapeDataDtypeArray:
        """Return the line with slope ``a`` and intercept ``b`` evaluated at ``x``."""
        scaled = a * x
        return scaled + b  # pyright: ignore[reportReturnType] return dtype is data_dtype alright.

    @staticmethod
    def param_guess() -> dict[str, dt]:
        """Return the fixed guess used by the tests: ``a`` = 2.0 and ``b`` = 1.0."""
        return {
            "a": dt(2.0),
            "b": dt(1.0),
        }

    @staticmethod
    def param_bounds() -> ParametersBounds:
        """Return bounds leaving both parameters free in [-inf, +inf]."""
        names = ("a", "b")
        return ParametersBounds(min=dict.fromkeys(names, neg_inf), max=dict.fromkeys(names, pos_inf))
| 34 | + |
| 35 | + |
class LinearSQRT:
    """Proportional test function whose coefficient is the square root of its parameter."""

    def __call__(self, x: AnyShapeDataDtypeArray, a: dt) -> AnyShapeDataDtypeArray:
        """Return ``sqrt(a) * x``."""
        coefficient = np.sqrt(a)
        return coefficient * x

    @staticmethod
    def param_guess() -> dict[str, dt]:
        """Return the initial guess ``a`` = 1.0."""
        return {"a": dt(1.0)}

    @staticmethod
    def param_bounds() -> ParametersBounds:
        """Return bounds keeping ``a`` positive: from 5e-7 up to +inf."""
        lower = {"a": dt(5e-7)}
        upper = {"a": pos_inf}
        return ParametersBounds(min=lower, max=upper)
| 52 | + |
| 53 | + |
class PassThrough:
    """Parameter-free identity function used as a stand-in model function in tests."""

    def __call__(self, x: AnyShapeDataDtypeArray) -> AnyShapeDataDtypeArray:
        """Return ``x`` unchanged."""
        return x

    @staticmethod
    def param_guess() -> dict[str, dt]:
        """Return an empty guess, since the function has no parameters."""
        no_params: dict[str, dt] = {}
        return no_params

    @staticmethod
    def param_bounds() -> ParametersBounds:
        """Return empty lower and upper bounds, since the function has no parameters."""
        lower: dict[str, dt] = {}
        upper: dict[str, dt] = {}
        return ParametersBounds(min=lower, max=upper)
| 70 | + |
| 71 | + |
def test_model_result_predict() -> None:
    """Check that Model.predict evaluates the wrapped function with the stored parameters."""
    # With the Linear guess (a=2, b=1) the model is y = 2 * x + 1.
    linear = Model(
        name="Linear",
        function=Linear(),
        params=Linear.param_guess(),
        bounds=Linear.param_bounds(),
        r_squared=dt(0.95),
    )
    inputs = np.array([0.0, 1.0, 2.0], dtype=dt)

    outputs = linear.predict(inputs)

    np.testing.assert_allclose(outputs, np.array([1.0, 3.0, 5.0]))
| 87 | + |
| 88 | + |
def test_fit_model_success() -> None:
    """Test successful model fitting.

    Noisy samples of ``y = 2 * x + 1`` are fitted with a linear model started away from the solution.
    Both fitted parameters are checked: the R² threshold alone would still pass with a badly wrong
    intercept, so the slope and the intercept are each compared with the generating values.
    """
    # Create simple linear data; the seed is fixed so the noise is reproducible.
    x = np.linspace(0, 10, 50)
    rng = np.random.default_rng(1337)
    y = 2.0 * x + 1.0 + rng.normal(0, 0.1, 50)

    lin_model = Model(
        name="Test Linear",
        function=Linear(),
        params={"a": dt(1.0), "b": dt(0.0)},
        bounds=ParametersBounds(min=dict.fromkeys(["a", "b"], neg_inf), max=dict.fromkeys(["a", "b"], pos_inf)),
    )

    result, _ = fit_model(x, y, lin_model)

    assert result is not None
    assert result.name == "Test Linear"
    assert result.r_squared > 0.9  # Should fit well  # noqa: PLR2004
    assert len(result.params) == 2  # noqa: PLR2004
    assert result.params["a"] == pytest.approx(2.0, abs=0.2)
    # The noise has standard deviation 0.1, so the intercept is recovered well within this tolerance.
    assert result.params["b"] == pytest.approx(1.0, abs=0.2)
| 110 | + |
| 111 | + |
def test_fit_model_failure() -> None:
    """Check that fit_model lets a ValueError propagate when the fit cannot be performed."""
    x = np.array([1.0, 2.0, 3.0])
    y = x.copy()

    # The guess and both bounds are negative, so LinearSQRT takes the square root of a negative number.
    bad_model = Model(
        name="Bad Model",
        function=LinearSQRT(),
        params={"a": dt(-10.0)},
        bounds=ParametersBounds(min={"a": dt(-100.0)}, max={"a": dt(-1.0)}),
    )

    # pytest.warns requires the RuntimeWarning (presumably numpy's invalid-value warning from the sqrt of a
    # negative number) and keeps it out of the test output while the ValueError is checked.
    with pytest.raises(ValueError), pytest.warns(RuntimeWarning):  # noqa: PT011 error raised by scipy
        _ = fit_model(x, y, bad_model)
| 128 | + |
| 129 | + |
def test_model_result_repr() -> None:
    """Check that repr(Model) shows the model name and R² formatted to three decimals."""
    model = Model(
        name="TestModel",
        function=PassThrough(),
        params={},
        r_squared=dt(0.8765),
        bounds=ParametersBounds({}, {}),
    )
    expected = f"ModelFit(name='{model.name}', R²={model.r_squared:.3f})"

    assert repr(model) == expected
| 143 | + |
| 144 | + |
def test_model_sort() -> None:
    """Test Model.sort.

    The models are supplied with the lowest R² first, so the assertions only hold if Model.sort
    actually reorders the list. Feeding the list already in the expected order would let a no-op
    sort pass. This assumes Model.sort orders the list in place by decreasing R², as the
    assertions on the same list object imply.
    """
    r_squared_min = dt(0.1)
    r_squared_max = dt(0.3)
    # Deliberately the reverse of the expected result: worst fit first, best fit last.
    model_list = [
        Model(
            name="b",
            function=PassThrough(),
            params={},
            r_squared=r_squared_min,
            bounds=ParametersBounds({}, {}),
        ),
        Model(
            name="a",
            function=PassThrough(),
            params={},
            r_squared=r_squared_max,
            bounds=ParametersBounds({}, {}),
        ),
    ]
    Model.sort(model_list)
    assert model_list[0].r_squared == r_squared_max
    assert model_list[1].r_squared == r_squared_min
| 168 | + |
| 169 | + |
def test_model_to_dict() -> None:
    """Test model conversion to dictionary.

    The expected R² is derived from ``dt(0.2)`` instead of being written as a literal. 0.2 is not
    exactly representable in binary floating point, so a hard-coded value is only right for one
    particular ``data_dtype``. The other values (1.0, 0.5, 5.0) are exact in any float type.
    """
    r_squared = dt(0.2)
    linear_model = Model(
        name="test_linear",
        function=Linear(),
        params={"a": dt(1.0), "b": dt(0.5)},
        bounds=ParametersBounds(min={"a": dt(-1.0), "b": dt(-5.0)}, max={"a": dt(1.0), "b": dt(5.0)}),
        r_squared=r_squared,
    )
    assert linear_model.to_dict({Linear: "Linear"}) == {
        "name": "test_linear",
        "function": "Linear",
        "params": {"a": 1.0, "b": 0.5},
        "bounds": {"min": {"a": -1.0, "b": -5.0}, "max": {"a": 1.0, "b": 5.0}},
        "r_squared": float(r_squared),
    }
| 186 | + |
| 187 | + |
def test_model_to_dict_missing_mapping() -> None:
    """Check that to_dict raises when the function's class has no entry in the mapping."""
    passthrough_model = Model(
        name="Passthrough",
        function=PassThrough(),
        params={},
        bounds=ParametersBounds(min={}, max={}),
        r_squared=neg_inf,
    )
    # An empty mapping has no identifier for PassThrough.
    empty_mapping = {}
    with pytest.raises(FunctionNotFoundInMappingError):
        passthrough_model.to_dict(empty_mapping)
| 199 | + |
| 200 | + |
def test_model_from_dict() -> None:
    """Check that Model.from_dict rebuilds every field from a plain dictionary."""
    spec = {
        "name": "test_linear",
        "function": "Linear",
        "params": {"a": 1.0, "b": 0.5},
        "bounds": {"min": {"a": -1.0, "b": -5.0}, "max": {"a": 1.0, "b": 5.0}},
        "r_squared": 0.20000000298023224,
    }
    # Expected values are the plain floats of the dictionary converted to the data dtype.
    expected_params = {key: dt(value) for key, value in spec["params"].items()}
    expected_bounds = ParametersBounds(
        min={key: dt(value) for key, value in spec["bounds"]["min"].items()},
        max={key: dt(value) for key, value in spec["bounds"]["max"].items()},
    )

    linear_model = Model.from_dict(spec, {"Linear": Linear})

    assert linear_model.name == spec["name"]
    assert isinstance(linear_model.function, Linear)
    assert linear_model.params == expected_params
    assert linear_model.bounds == expected_bounds
    assert linear_model.r_squared == spec["r_squared"]
| 218 | + |
| 219 | + |
def test_model_from_dict_missing_mapping() -> None:
    """Check that from_dict raises when the function identifier has no entry in the mapping."""
    spec = {
        "name": "passthrough",
        "function": "PassThrough",
        "params": {},
        "bounds": {"min": {}, "max": {}},
        "r_squared": neg_inf,
    }
    # An empty mapping cannot resolve the "PassThrough" identifier.
    empty_mapping = {}
    with pytest.raises(FunctionIdNotFoundInMappingError):
        Model.from_dict(spec, empty_mapping)
| 231 | + |
| 232 | + |
def test_model_from_dict_bad_dict() -> None:
    """Check that from_dict raises ModelParsingError for a dictionary lacking the required content."""
    empty_spec = {}
    empty_mapping = {}
    with pytest.raises(ModelParsingError):
        Model.from_dict(empty_spec, empty_mapping)
0 commit comments