diff --git a/tobac/tests/test_utils_interest_field.py b/tobac/tests/test_utils_interest_field.py
new file mode 100644
index 00000000..8dfa49ab
--- /dev/null
+++ b/tobac/tests/test_utils_interest_field.py
@@ -0,0 +1,190 @@
+import numpy as np
+import pandas as pd
+import xarray as xr
+import pytest
+
+from tobac.utils.features_to_field import features_to_interest_field
+
+
+def make_template():
+    time = np.arange(5)
+    x = np.arange(50)
+    y = np.arange(60)
+
+    data = xr.DataArray(
+        np.zeros((len(time), len(x), len(y))),
+        dims=("time", "x", "y"),
+        coords={"time": time, "x": x, "y": y},
+        name="interest",
+    )
+    return data
+
+
+def test_empty_features_returns_zero_field():
+    template = make_template()
+
+    features = pd.DataFrame(columns=["frame", "hdim_1", "hdim_2", "threshold_value"])
+
+    out = features_to_interest_field(
+        features,
+        template,
+        position_mode="hdim",
+        amp_from="threshold_value",
+    )
+
+    assert out.shape == template.shape
+    assert float(out.max()) == 0.0
+    assert float(out.min()) == 0.0
+
+
+def test_amplitude_scaling_from_threshold():
+    template = make_template()
+
+    features = pd.DataFrame(
+        [
+            {
+                "frame": 0,
+                "hdim_1": 20.0,
+                "hdim_2": 30.0,
+                "threshold_value": 0.3,
+            }
+        ]
+    )
+
+    out = features_to_interest_field(
+        features,
+        template,
+        position_mode="hdim",
+        amp_from="threshold_value",
+        amp_factor=2.0,
+        sigma=3.0,
+        size_from=None,
+        mode="max",
+    )
+
+    assert np.isclose(float(out.max()), 0.6, atol=1e-6)
+
+
+def test_blob_center_is_at_feature_position():
+    template = make_template()
+
+    features = pd.DataFrame(
+        [
+            {
+                "frame": 2,
+                "hdim_1": 10.0,
+                "hdim_2": 15.0,
+                "threshold_value": 1.0,
+            }
+        ]
+    )
+
+    out = features_to_interest_field(
+        features,
+        template,
+        position_mode="hdim",
+        amp_from="threshold_value",
+        amp_factor=1.0,
+        sigma=2.0,
+        size_from=None,
+        mode="max",
+    )
+
+    val = out.isel(time=2, x=10, y=15).item()
+    assert np.isclose(val, 1.0, atol=1e-6)
+
+
+def test_add_vs_max_overlap():
+    template = make_template()
+
+    features = pd.DataFrame(
+        [
+            {"frame": 0, "hdim_1": 25.0, "hdim_2": 30.0, "threshold_value": 1.0},
+            {"frame": 0, "hdim_1": 25.0, "hdim_2": 30.0, "threshold_value": 1.0},
+        ]
+    )
+
+    out_add = features_to_interest_field(
+        features,
+        template,
+        position_mode="hdim",
+        amp_from="threshold_value",
+        amp_factor=1.0,
+        sigma=2.0,
+        size_from=None,
+        mode="add",
+    )
+
+    out_max = features_to_interest_field(
+        features,
+        template,
+        position_mode="hdim",
+        amp_from="threshold_value",
+        amp_factor=1.0,
+        sigma=2.0,
+        size_from=None,
+        mode="max",
+    )
+
+    assert float(out_add.max()) > float(out_max.max())
+    assert np.isclose(float(out_max.max()), 1.0, atol=1e-6)
+
+
+def test_sigma_from_area_changes_blob_width():
+    template = make_template()
+
+    features = pd.DataFrame(
+        [
+            {
+                "frame": 0,
+                "hdim_1": 25.0,
+                "hdim_2": 30.0,
+                "threshold_value": 1.0,
+                "area": 100.0,
+            }
+        ]
+    )
+
+    out = features_to_interest_field(
+        features,
+        template,
+        position_mode="hdim",
+        amp_from="threshold_value",
+        amp_factor=1.0,
+        sigma=1.0,
+        size_from="area",
+        mode="max",
+    )
+
+    center = out.isel(time=0, x=25, y=30).item()
+    off = out.isel(time=0, x=25, y=35).item()
+
+    assert off > 0.0
+    assert off < center
+
+
+def test_invalid_position_mode_raises():
+    template = make_template()
+
+    features = pd.DataFrame(
+        [{"frame": 0, "hdim_1": 10.0, "hdim_2": 10.0, "threshold_value": 1.0}]
+    )
+
+    with pytest.raises(ValueError):
+        features_to_interest_field(
+            features,
+            template,
+            position_mode="invalid",
+        )
+
+
+def test_invalid_options_raise_for_empty_features():
+    template = make_template()
+
+    features = pd.DataFrame(columns=["frame", "hdim_1", "hdim_2", "threshold_value"])
+
+    with pytest.raises(ValueError):
+        features_to_interest_field(features, template, mode="invalid")
+
+    with pytest.raises(ValueError):
+        features_to_interest_field(features, template, blob="invalid")
diff --git a/tobac/utils/features_to_field.py b/tobac/utils/features_to_field.py
new file mode 100644
index 00000000..99a38bb4
--- /dev/null
+++ b/tobac/utils/features_to_field.py
@@ -0,0 +1,158 @@
+"""Function for converting feature labels back into an artificial interest field."""
+
+import xarray as xr
+import numpy as np
+
+
+def features_to_interest_field(
+    features,
+    template: xr.DataArray,
+    *,
+    position_mode="hdim",  # "hdim" or "xy"
+    position_cols=None,  # override
+    time_key="frame",  # or "time"
+    blob="gaussian",  # "gaussian" or "tophat"
+    mode="max",  # "max" or "add"
+    amp_from="threshold_value",  # column name or constant float
+    amp_factor=2.0,  # amp = amp_factor * amp_from (if column)
+    size_from="area",  # column name; if missing uses sigma
+    sigma=5.0,  # default sigma (grid points for hdim)
+    min_sigma=1.0,
+) -> xr.DataArray:
+    """Paint one blob per feature onto a zero field shaped like ``template``.
+
+    Parameters
+    ----------
+    features : pandas.DataFrame
+        One row per feature. Needs the position columns, the column named
+        by ``time_key`` and, if ``amp_from`` is a column name, that column.
+    template : xarray.DataArray
+        Defines shape, dimensions and coordinates of the result. Its first
+        dimension is treated as time, its last two as the horizontal plane.
+    position_mode : {"hdim", "xy"}, optional
+        ``"hdim"`` measures distances in grid indices of the last two
+        dimensions, ``"xy"`` in the coordinate values of those dimensions.
+    position_cols : tuple of str, optional
+        Names of the two position columns. Defaults to
+        ``("hdim_1", "hdim_2")`` or ``("x", "y")`` depending on
+        ``position_mode``.
+    time_key : {"frame", "time"}, optional
+        ``"frame"`` reads an integer position along the time dimension from
+        the ``frame`` column; ``"time"`` looks up the value of the ``time``
+        column in the time coordinate.
+    blob : {"gaussian", "tophat"}, optional
+        Shape of each blob: ``amp * exp(-r**2 / (2 * sigma**2))`` or ``amp``
+        wherever ``r <= sigma``.
+    mode : {"max", "add"}, optional
+        How a blob is merged into what is already in its time slice.
+    amp_from : str or number, optional
+        Column holding the amplitude, or a constant amplitude. A constant
+        is used as is, without ``amp_factor``.
+    amp_factor : float, optional
+        Multiplier applied to the ``amp_from`` column.
+    size_from : str or None, optional
+        Column holding a feature area. A finite, positive area ``A`` gives
+        ``sigma = max(min_sigma, sqrt(A / pi) / 2)``; otherwise, or if the
+        column is absent, ``sigma`` is used.
+    sigma : float, optional
+        Fallback blob width, in the units selected by ``position_mode``.
+    min_sigma : float, optional
+        Lower bound for widths derived from ``size_from``.
+
+    Returns
+    -------
+    xarray.DataArray
+        Float array like ``template`` holding the painted blobs.
+
+    Raises
+    ------
+    TypeError
+        If ``template`` is not an ``xarray.DataArray``.
+    ValueError
+        If ``template`` has fewer than three dimensions, if
+        ``position_mode``, ``time_key``, ``blob`` or ``mode`` is not one of
+        the listed options, or if ``sigma`` is not positive.
+    """
+    if not isinstance(template, xr.DataArray):
+        raise TypeError("template must be an xarray.DataArray")
+    if template.ndim < 3:
+        raise ValueError("template must be at least 3D (time + 2D space)")
+
+    # Check all options up front so a bad argument is reported even when
+    # ``features`` is empty and the loop below never runs.
+    if position_mode not in ("hdim", "xy"):
+        raise ValueError("position_mode must be 'hdim' or 'xy'")
+    if time_key not in ("frame", "time"):
+        raise ValueError("time_key must be 'frame' or 'time'")
+    if blob not in ("gaussian", "tophat"):
+        raise ValueError("blob must be 'gaussian' or 'tophat'")
+    if mode not in ("add", "max"):
+        raise ValueError("mode must be 'add' or 'max'")
+    # A zero width would put 0/0 = NaN at the centre of a gaussian blob.
+    if not sigma > 0:
+        raise ValueError("sigma must be positive")
+
+    tdim = template.dims[0]
+    d1, d2 = template.dims[-2], template.dims[-1]
+    out = xr.zeros_like(template, dtype=float)
+
+    if position_cols is None:
+        if position_mode == "hdim":
+            position_cols = ("hdim_1", "hdim_2")
+        else:
+            position_cols = ("x", "y")
+    p1_col, p2_col = position_cols
+
+    # Grid on which distances to a feature are measured; this is what fixes
+    # the units of ``sigma`` (grid points for "hdim", coordinate units for "xy").
+    if position_mode == "hdim":
+        c1 = np.arange(template.sizes[d1])
+        c2 = np.arange(template.sizes[d2])
+    else:
+        c1 = template[d1].values
+        c2 = template[d2].values
+    C1, C2 = np.meshgrid(c1, c2, indexing="ij")
+
+    # Per-call constants, resolved once instead of once per row.
+    amp_is_constant = isinstance(amp_from, (int, float, np.integer, np.floating))
+    use_size = size_from is not None and size_from in features.columns
+
+    for _, row in features.iterrows():
+        if time_key == "frame":
+            # Positional index, so time labels need not be unique.
+            key = {tdim: int(row["frame"])}
+            current = out.isel(key).values
+        else:
+            key = {tdim: np.datetime64(row["time"])}
+            current = out.sel(key).values
+
+        if amp_is_constant:
+            amp = float(amp_from)
+        else:
+            amp = amp_factor * float(row[amp_from])
+
+        sig = float(sigma)
+        if use_size and np.isfinite(row[size_from]):
+            area = float(row[size_from])
+            if area > 0:
+                # Half the radius of the circle with the same area.
+                sig = max(min_sigma, np.sqrt(area / np.pi) / 2.0)
+
+        r2 = (C1 - float(row[p1_col])) ** 2 + (C2 - float(row[p2_col])) ** 2
+
+        if blob == "gaussian":
+            blob2d = amp * np.exp(-0.5 * r2 / (sig**2))
+        else:
+            blob2d = amp * (r2 <= (sig**2))
+
+        if mode == "add":
+            merged = current + blob2d
+        else:
+            merged = np.maximum(current, blob2d)
+
+        if time_key == "frame":
+            out[key] = merged
+        else:
+            out.loc[key] = merged
+
+    return out