Skip to content

Commit f7a22ca

Browse files
committed
added tests
1 parent 1366838 commit f7a22ca

8 files changed

Lines changed: 1089 additions & 0 deletions

File tree

pixi.lock

Lines changed: 432 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pixi.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ pkoffee = { path = "." }
1414
ruff = ">=0.14.11,<0.15"
1515
numpy = ">=2.4.1,<3"
1616
matplotlib = ">=3.10.8,<4"
17+
pytest = ">=9.0.2,<10"
18+
pytest-cov = ">=7.0.0,<8"
1719

1820
[package]
1921
name = "pkoffee"

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit tests initialization."""

tests/test_data.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""Unit tests for data loading and preprocessing."""
2+
3+
from pathlib import Path
4+
5+
import numpy as np
6+
import pandas as pd
7+
import pytest
8+
9+
from pkoffee.data import (
10+
ColumnTypeError,
11+
CSVReadError,
12+
MissingColumnsError,
13+
RequiredColumn,
14+
curate,
15+
data_dtype,
16+
extract_arrays,
17+
load_csv,
18+
validate,
19+
)
20+
21+
22+
def test_validate() -> None:
    """Check that validate returns None for a DataFrame holding both required columns."""
    well_formed = pd.DataFrame({"cups": [0], "productivity": [1.2]})
    assert validate(well_formed) is None
25+
26+
27+
def test_validate_wrong_type() -> None:
    """Check that validate raises ColumnTypeError when "productivity" holds a string."""
    frame = pd.DataFrame({"cups": [0], "productivity": ["a"]})
    with pytest.raises(ColumnTypeError):
        validate(frame)
31+
32+
33+
def test_validate_missing_column() -> None:
    """Check that validate raises MissingColumnsError when a required column is absent."""
    # Only the "cups" column is present; the second column carries an unrelated name.
    frame = pd.DataFrame({f"{RequiredColumn.CUPS}": [1], "notproductivity": [1.2]})
    with pytest.raises(MissingColumnsError):
        validate(frame)
37+
38+
39+
@pytest.mark.parametrize(
    ("data", "expected"),
    [
        # NaNs spread over both columns: only the first row is complete.
        pytest.param(
            pd.DataFrame({"cups": [1, np.nan, 2], "productivity": [1.2, 2.1, np.nan]}),
            pd.DataFrame({"cups": [1.0], "productivity": [1.2]}),
        ),
        # Every "cups" value is NaN: no row is expected to remain.
        pytest.param(
            pd.DataFrame({"cups": [np.nan, np.nan, np.nan], "productivity": [1.2, 2.1, 3.4]}),
            pd.DataFrame({"cups": [], "productivity": []}),
        ),
        # No NaN at all: the data is expected back unchanged.
        pytest.param(
            pd.DataFrame({"cups": [1, 1, 4], "productivity": [1.2, 2.1, 0.5]}),
            pd.DataFrame({"cups": [1, 1, 4], "productivity": [1.2, 2.1, 0.5]}),
        ),
    ],
)
def test_currate(data: pd.DataFrame, expected: pd.DataFrame) -> None:
    """Check that curate turns each input DataFrame into the expected NaN-free one."""
    curated = curate(data)
    assert curated.equals(expected)
59+
60+
61+
def test_load_csv_valid_file(tmp_path: Path) -> None:
    """Write a small two-column CSV and check the names, values and dtypes load_csv returns."""
    cups = np.array([1, 2, 3], dtype=int)
    prod = np.array([2.3, 1.2, 4.8], dtype=data_dtype)
    csv_path = tmp_path / "valid.csv"
    table = np.stack([cups, prod], axis=1)
    column_header = f"{RequiredColumn.CUPS},{RequiredColumn.PRODUCTIVITY}"
    # comments="" keeps savetxt from prefixing the header line with its default "# ".
    np.savetxt(csv_path, table, fmt=["%d", "%10.4f"], delimiter=",", header=column_header, comments="")

    loaded = load_csv(csv_path)

    column_names = loaded.columns
    assert RequiredColumn.CUPS in column_names
    assert RequiredColumn.PRODUCTIVITY in column_names
    assert np.allclose(loaded[RequiredColumn.CUPS].to_numpy(), cups)
    assert np.allclose(loaded[RequiredColumn.PRODUCTIVITY].to_numpy(), prod)
    column_types = loaded.dtypes
    assert column_types[RequiredColumn.CUPS] == np.int64
    assert column_types[RequiredColumn.PRODUCTIVITY] == np.float64
82+
83+
84+
def test_load_csv_missing_file() -> None:
    """Check that load_csv raises FileNotFoundError for a path that does not exist."""
    absent = Path("nonexistent_file.csv")
    with pytest.raises(FileNotFoundError):
        load_csv(absent)
88+
89+
90+
def test_load_csv_missing_columns(tmp_path: Path) -> None:
    """Check that load_csv raises MissingColumnsError when a required header is absent."""
    csv_path = tmp_path / "missing_columns.csv"
    single_row = np.stack([[1], [2.3]], axis=1)
    # The second header is deliberately not the productivity column name.
    bad_header = f"{RequiredColumn.CUPS},wrong_column"
    np.savetxt(csv_path, single_row, fmt=["%d", "%10.4f"], delimiter=",", header=bad_header, comments="")

    with pytest.raises(MissingColumnsError, match="Missing required columns"):
        load_csv(csv_path)
104+
105+
106+
def test_load_data_with_nan_values(tmp_path: Path) -> None:
    """Check that load_csv leaves out a row whose productivity field is empty."""
    csv_path = tmp_path / "valid_with_nan.csv"
    csv_lines = [
        f"{RequiredColumn.CUPS},{RequiredColumn.PRODUCTIVITY}\n",
        "1,10.5\n",
        "2,\n",  # Missing productivity
        "3,18.2\n",
    ]
    with csv_path.open("w") as fh:
        fh.writelines(csv_lines)

    loaded = load_csv(csv_path)

    # Only the two complete rows are expected back.
    complete_rows = pd.DataFrame({RequiredColumn.CUPS: [1, 3], RequiredColumn.PRODUCTIVITY: [10.5, 18.2]})
    assert loaded.equals(complete_rows)
118+
119+
120+
def test_load_data_with_extra_values(tmp_path: Path) -> None:
    """Test that load_csv raises CSVReadError when the file is read while still open for writing.

    NOTE(review): the rows written below have not been flushed when load_csv
    runs, so the reader presumably sees an empty file -- confirm that this is
    the failure mode the test is meant to cover.
    """
    data_file = tmp_path / "valid_with_nan.csv"
    with data_file.open("w") as fh:
        fh.write(f"{RequiredColumn.CUPS},{RequiredColumn.PRODUCTIVITY}\n")
        fh.write("1,2.1\n")
        # Read the file while the write handle is still open (neither flushed nor closed).
        with pytest.raises(CSVReadError):
            load_csv(data_file)
129+
130+
131+
def test_extract_arrays() -> None:
    """Check that extract_arrays returns the cups and productivity values of a DataFrame."""
    reference = {
        RequiredColumn.CUPS: np.array([1, 2, 3], dtype=int),
        RequiredColumn.PRODUCTIVITY: np.array([10.5, 15.3, 18.2], dtype=np.float64),
    }

    cups, productivity = extract_arrays(pd.DataFrame(reference))

    assert np.allclose(reference[RequiredColumn.CUPS], cups)
    assert np.allclose(reference[RequiredColumn.PRODUCTIVITY], productivity)

tests/test_fit_model.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
"""Unit tests for model fitting."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from pkoffee.data import AnyShapeDataDtypeArray, neg_inf, pos_inf
7+
from pkoffee.data import data_dtype as dt
8+
from pkoffee.fit_model import (
9+
FunctionIdNotFoundInMappingError,
10+
FunctionNotFoundInMappingError,
11+
Model,
12+
ModelParsingError,
13+
fit_model,
14+
)
15+
from pkoffee.parametric_function import ParametersBounds
16+
17+
18+
class Linear:
    """Test function ``y = a * x + b`` with a fixed guess and unbounded parameters."""

    def __call__(self, x: AnyShapeDataDtypeArray, a: dt, b: dt) -> AnyShapeDataDtypeArray:
        """Return ``a * x + b``."""
        scaled = a * x
        return scaled + b  # pyright: ignore[reportReturnType] return dtype is data_dtype alright.

    @staticmethod
    def param_guess() -> dict[str, dt]:
        """Return the guess used by the tests: ``a = 2.0``, ``b = 1.0``."""
        return {name: dt(value) for name, value in (("a", 2.0), ("b", 1.0))}

    @staticmethod
    def param_bounds() -> ParametersBounds:
        """Return bounds spanning ``[-inf, +inf]`` for both parameters."""
        names = ("a", "b")
        return ParametersBounds(min=dict.fromkeys(names, neg_inf), max=dict.fromkeys(names, pos_inf))
34+
35+
36+
class LinearSQRT:
    """Test function ``y = sqrt(a) * x``: proportional to ``x``, coefficient is the root of ``a``."""

    def __call__(self, x: AnyShapeDataDtypeArray, a: dt) -> AnyShapeDataDtypeArray:
        """Return ``sqrt(a) * x``."""
        coefficient = np.sqrt(a)
        return coefficient * x

    @staticmethod
    def param_guess() -> dict[str, dt]:
        """Return the guess ``a = 1.0``."""
        guess = {"a": dt(1.0)}
        return guess

    @staticmethod
    def param_bounds() -> ParametersBounds:
        """Return bounds keeping ``a`` positive: from ``5e-7`` up to ``+inf``."""
        lower = {"a": dt(5e-7)}
        upper = {"a": pos_inf}
        return ParametersBounds(min=lower, max=upper)
52+
53+
54+
class PassThrough:
    """Parameter-free test function that hands its input back."""

    def __call__(self, x: AnyShapeDataDtypeArray) -> AnyShapeDataDtypeArray:
        """Return ``x`` itself."""
        return x

    @staticmethod
    def param_guess() -> dict[str, dt]:
        """Return an empty guess: there is no parameter."""
        no_params: dict[str, dt] = {}
        return no_params

    @staticmethod
    def param_bounds() -> ParametersBounds:
        """Return empty lower and upper bounds."""
        lower: dict[str, dt] = {}
        upper: dict[str, dt] = {}
        return ParametersBounds(min=lower, max=upper)
70+
71+
72+
def test_model_result_predict() -> None:
    """Check that Model.predict returns ``2 * x + 1`` for a Linear model holding its guess."""
    linear = Linear()
    model_result = Model(
        name="Linear",
        function=linear,
        params=Linear.param_guess(),
        bounds=Linear.param_bounds(),
        r_squared=dt(0.95),
    )

    predictions = model_result.predict(np.array([0.0, 1.0, 2.0], dtype=dt))

    # With a = 2 and b = 1 the inputs 0, 1, 2 map to 1, 3, 5.
    np.testing.assert_allclose(predictions, np.array([1.0, 3.0, 5.0]))
87+
88+
89+
def test_fit_model_success() -> None:
    """Fit a linear model to noisy ``y = 2 * x + 1`` samples and check the returned fit."""
    n_samples = 50
    x = np.linspace(0, 10, n_samples)
    # Fixed seed so the noise, and therefore the fit, is reproducible.
    rng = np.random.default_rng(1337)
    noise = rng.normal(0, 0.1, n_samples)
    y = 2.0 * x + 1.0 + noise

    names = ("a", "b")
    lin_model = Model(
        name="Test Linear",
        function=Linear(),
        params={"a": dt(1.0), "b": dt(0.0)},
        bounds=ParametersBounds(min=dict.fromkeys(names, neg_inf), max=dict.fromkeys(names, pos_inf)),
    )

    fitted, _ = fit_model(x, y, lin_model)

    assert fitted is not None
    assert fitted.name == "Test Linear"
    assert fitted.r_squared > 0.9  # Should fit well # noqa: PLR2004
    assert len(fitted.params) == 2  # noqa: PLR2004
    assert fitted.params["a"] == pytest.approx(2.0, abs=0.2)
110+
111+
112+
def test_fit_model_failure() -> None:
    """Test that fit_model raises ValueError when the model cannot be evaluated within its bounds.

    LinearSQRT computes ``sqrt(a)``, while both the initial value and the
    bounds given here restrict ``a`` to negative numbers.
    """
    x = np.array([1.0, 2.0, 3.0])
    y = np.array([1.0, 2.0, 3.0])

    # The initial value and the bounds keep "a" strictly negative, so sqrt(a) is never a real number.
    config = Model(
        name="Bad Model",
        function=LinearSQRT(),
        params={"a": dt(-10.0)},
        bounds=ParametersBounds(min={"a": dt(-100.0)}, max={"a": dt(-1.0)}),
    )

    # The ValueError comes from scipy (hence the PT011 waiver). pytest.warns additionally
    # expects the RuntimeWarning that np.sqrt emits for a negative argument.
    with pytest.raises(ValueError), pytest.warns(RuntimeWarning):  # noqa: PT011 error raised by scipy
        _ = fit_model(x, y, config)
128+
129+
130+
def test_model_result_repr() -> None:
    """Check the exact string that repr() produces for a Model."""
    no_bounds = ParametersBounds({}, {})
    model = Model(name="TestModel", function=PassThrough(), params={}, r_squared=dt(0.8765), bounds=no_bounds)

    # The representation shows the name and R² rounded to three decimals.
    expected = f"ModelFit(name='{model.name}', R²={model.r_squared:.3f})"

    assert repr(model) == expected
143+
144+
145+
def test_model_sort() -> None:
    """Test that Model.sort reorders a list of models in place by decreasing r_squared.

    The list is built with the lowest score first, so the assertions can only
    pass if Model.sort really reorders it; an input that is already in the
    expected order would also pass with a sort that does nothing.
    """
    r_squared_min = dt(0.1)
    r_squared_max = dt(0.3)
    model_list = [
        Model(
            name=name,
            function=PassThrough(),
            params={},
            r_squared=r_squared,
            bounds=ParametersBounds({}, {}),
        )
        for name, r_squared in (("b", r_squared_min), ("a", r_squared_max))
    ]

    Model.sort(model_list)

    assert model_list[0].r_squared == r_squared_max
    assert model_list[1].r_squared == r_squared_min
168+
169+
170+
def test_model_to_dict() -> None:
    """Test model conversion to dictionary.

    The function object is replaced by the identifier that the mapping passed
    to ``to_dict`` associates with its class.
    """
    r_squared = dt(0.2)
    linear_model = Model(
        name="test_linear",
        function=Linear(),
        params={"a": dt(1.0), "b": dt(0.5)},
        bounds=ParametersBounds(min={"a": dt(-1.0), "b": dt(-5.0)}, max={"a": dt(1.0), "b": dt(5.0)}),
        r_squared=r_squared,
    )
    assert linear_model.to_dict({Linear: "Linear"}) == {
        "name": "test_linear",
        "function": "Linear",
        "params": {"a": 1.0, "b": 0.5},
        "bounds": {"min": {"a": -1.0, "b": -5.0}, "max": {"a": 1.0, "b": 5.0}},
        # 0.2 is not exactly representable: derive the expected value from the data
        # dtype instead of hard-coding the decimal expansion of one specific precision.
        "r_squared": float(r_squared),
    }
186+
187+
188+
def test_model_to_dict_missing_mapping() -> None:
    """Check that to_dict raises FunctionNotFoundInMappingError when given an empty mapping."""
    no_bounds = ParametersBounds(min={}, max={})
    model = Model(name="Passthrough", function=PassThrough(), params={}, bounds=no_bounds, r_squared=neg_inf)
    empty_mapping = {}
    with pytest.raises(FunctionNotFoundInMappingError):
        model.to_dict(empty_mapping)
199+
200+
201+
def test_model_from_dict() -> None:
    """Check that Model.from_dict rebuilds name, function, params, bounds and r_squared."""

    def as_dt(values: dict) -> dict:
        """Convert every value of a plain dictionary to the data dtype."""
        return {key: dt(value) for key, value in values.items()}

    d = {
        "name": "test_linear",
        "function": "Linear",
        "params": {"a": 1.0, "b": 0.5},
        "bounds": {"min": {"a": -1.0, "b": -5.0}, "max": {"a": 1.0, "b": 5.0}},
        "r_squared": 0.20000000298023224,
    }

    linear_model = Model.from_dict(d, {"Linear": Linear})

    assert linear_model.name == d["name"]
    assert isinstance(linear_model.function, Linear)
    assert linear_model.params == as_dt(d["params"])
    expected_bounds = ParametersBounds(min=as_dt(d["bounds"]["min"]), max=as_dt(d["bounds"]["max"]))
    assert linear_model.bounds == expected_bounds
    assert linear_model.r_squared == d["r_squared"]
218+
219+
220+
def test_model_from_dict_missing_mapping() -> None:
    """Check that from_dict raises FunctionIdNotFoundInMappingError when given an empty mapping."""
    d = {
        "name": "passthrough",
        "function": "PassThrough",
        "params": {},
        "bounds": {"min": {}, "max": {}},
        "r_squared": neg_inf,
    }
    empty_mapping = {}
    with pytest.raises(FunctionIdNotFoundInMappingError):
        Model.from_dict(d, empty_mapping)
231+
232+
233+
def test_model_from_dict_bad_dict() -> None:
    """Check that from_dict raises ModelParsingError for a dictionary with no content."""
    empty_description = {}
    empty_mapping = {}
    with pytest.raises(ModelParsingError):
        Model.from_dict(empty_description, empty_mapping)

0 commit comments

Comments
 (0)