Skip to content

Commit fe696a9

Browse files
committed
Add pointwise error computation to method and add to forecasting unit test
1 parent dcdab7a commit fe696a9

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

deepsensor/eval/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .metrics import *

deepsensor/eval/metrics.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import xarray as xr
2+
from deepsensor.model.pred import Prediction
3+
4+
5+
def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset:
6+
"""
7+
Compute errors between predictions and targets.
8+
9+
Args:
10+
pred: Prediction object.
11+
target: Target data.
12+
13+
Returns:
14+
xr.Dataset: Dataset of pointwise differences between predictions and targets
15+
at the same valid time in the predictions. Note, the difference is positive
16+
when the prediction is greater than the target.
17+
"""
18+
errors = {}
19+
for var_ID, pred_var in pred.items():
20+
target_var = target[var_ID]
21+
error = pred_var["mean"] - target_var.sel(time=pred_var.time)
22+
error.name = f"{var_ID}"
23+
errors[var_ID] = error
24+
return xr.Dataset(errors)

tests/test_model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from deepsensor.data.loader import TaskLoader
1919
from deepsensor.model.convnp import ConvNP
2020
from deepsensor.train.train import Trainer
21+
from deepsensor.eval.metrics import compute_errors
2122

2223
from tests.utils import gen_random_data_xr, gen_random_data_pandas
2324

@@ -686,9 +687,15 @@ def test_forecasting_model_predict_return_valid_times(self):
686687

687688
if isinstance(pred_var, xr.Dataset):
688689
# Check we can compute errors using the valid time coord ('time')
689-
errors = pred_var["mean"] - self.da.sel(time=pred_var.time)
690-
assert errors.dims == ("lead_time", "init_time", "x1", "x2")
691-
assert errors.shape == pred_var["mean"].shape
690+
errors = compute_errors(pred, self.da.to_dataset())
691+
for var_ID in errors.keys():
692+
assert tuple(errors[var_ID].dims) == (
693+
"lead_time",
694+
"init_time",
695+
"x1",
696+
"x2",
697+
)
698+
assert errors[var_ID].shape == pred[var_ID]["mean"].shape
692699
elif isinstance(pred_var, pd.DataFrame):
693700
# Makes coordinate checking easier by avoiding repeat values
694701
pred_var = pred_var.to_xarray().isel(x1=0, x2=0)

0 commit comments

Comments
 (0)