Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions tobac/tests/test_utils_interest_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import numpy as np
import pandas as pd
import xarray as xr
import pytest

from tobac.utils.features_to_field import features_to_interest_field


def make_template():
time = np.arange(5)
x = np.arange(50)
y = np.arange(60)

data = xr.DataArray(
np.zeros((len(time), len(x), len(y))),
dims=("time", "x", "y"),
coords={"time": time, "x": x, "y": y},
name="interest",
)
return data


def test_empty_features_returns_zero_field():
template = make_template()

features = pd.DataFrame(columns=["frame", "hdim_1", "hdim_2", "threshold_value"])

out = features_to_interest_field(
features,
template,
position_mode="hdim",
amp_from="threshold_value",
)

assert out.shape == template.shape
assert float(out.max()) == 0.0
assert float(out.min()) == 0.0


def test_amplitude_scaling_from_threshold():
template = make_template()

features = pd.DataFrame(
[
{
"frame": 0,
"hdim_1": 20.0,
"hdim_2": 30.0,
"threshold_value": 0.3,
}
]
)

out = features_to_interest_field(
features,
template,
position_mode="hdim",
amp_from="threshold_value",
amp_factor=2.0,
sigma=3.0,
size_from=None,
mode="max",
)

assert np.isclose(float(out.max()), 0.6, atol=1e-6)


def test_blob_center_is_at_feature_position():
template = make_template()

features = pd.DataFrame(
[
{
"frame": 2,
"hdim_1": 10.0,
"hdim_2": 15.0,
"threshold_value": 1.0,
}
]
)

out = features_to_interest_field(
features,
template,
position_mode="hdim",
amp_from="threshold_value",
amp_factor=1.0,
sigma=2.0,
size_from=None,
mode="max",
)

val = out.isel(time=2, x=10, y=15).item()
assert np.isclose(val, 1.0, atol=1e-6)


def test_add_vs_max_overlap():
template = make_template()

features = pd.DataFrame(
[
{"frame": 0, "hdim_1": 25.0, "hdim_2": 30.0, "threshold_value": 1.0},
{"frame": 0, "hdim_1": 25.0, "hdim_2": 30.0, "threshold_value": 1.0},
]
)

out_add = features_to_interest_field(
features,
template,
position_mode="hdim",
amp_from="threshold_value",
amp_factor=1.0,
sigma=2.0,
size_from=None,
mode="add",
)

out_max = features_to_interest_field(
features,
template,
position_mode="hdim",
amp_from="threshold_value",
amp_factor=1.0,
sigma=2.0,
size_from=None,
mode="max",
)

assert float(out_add.max()) > float(out_max.max())
assert np.isclose(float(out_max.max()), 1.0, atol=1e-6)


def test_sigma_from_area_changes_blob_width():
template = make_template()

features = pd.DataFrame(
[
{
"frame": 0,
"hdim_1": 25.0,
"hdim_2": 30.0,
"threshold_value": 1.0,
"area": 100.0,
}
]
)

out = features_to_interest_field(
features,
template,
position_mode="hdim",
amp_from="threshold_value",
amp_factor=1.0,
sigma=1.0,
size_from="area",
mode="max",
)

center = out.isel(time=0, x=25, y=30).item()
off = out.isel(time=0, x=25, y=35).item()

assert off > 0.0
assert off < center


def test_invalid_position_mode_raises():
template = make_template()

features = pd.DataFrame(
[{"frame": 0, "hdim_1": 10.0, "hdim_2": 10.0, "threshold_value": 1.0}]
)

with pytest.raises(ValueError):
features_to_interest_field(
features,
template,
position_mode="invalid",
)
111 changes: 111 additions & 0 deletions tobac/utils/features_to_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Function for converting feature labels back into an artificial interest field."""

import xarray as xr
import numpy as np


def features_to_interest_field(
features,
template: xr.DataArray,
*,
position_mode="hdim", # "hdim" or "xy"
position_cols=None, # override
time_key="frame", # or "time"
blob="gaussian", # "gaussian" or "tophat"
mode="max", # "max" or "add"
amp_from="threshold_value", # column name or constant float
amp_factor=2.0, # amp = amp_factor * amp_from (if column)
size_from="area", # column name; if missing uses sigma
sigma=5.0, # default sigma (grid points for hdim)
min_sigma=1.0,
):
if not isinstance(template, xr.DataArray):
raise TypeError("template must be an xarray.DataArray")
if template.ndim < 3:
raise ValueError("template must be at least 3D (time + 2D space)")

tdim = template.dims[0]
d1, d2 = template.dims[-2], template.dims[-1]
out = xr.zeros_like(template, dtype=float)

n1, n2 = template.sizes[d1], template.sizes[d2]

# choose position columns
if position_cols is None:
if position_mode == "hdim":
position_cols = ("hdim_1", "hdim_2")
elif position_mode == "xy":
position_cols = ("x", "y")
else:
raise ValueError("position_mode must be 'hdim' or 'xy'")

p1_col, p2_col = position_cols

# coordinate grids
if position_mode == "hdim":
C1, C2 = np.meshgrid(np.arange(n1), np.arange(n2), indexing="ij")

# sigma is in grid points
def rr2(p1, p2):
return (C1 - p1) ** 2 + (C2 - p2) ** 2

else: # "xy" physical coords
c1 = template[d1].values
c2 = template[d2].values
C1, C2 = np.meshgrid(c1, c2, indexing="ij")

# sigma is in same units as coords
def rr2(p1, p2):
return (C1 - p1) ** 2 + (C2 - p2) ** 2

def amplitude(row):
if isinstance(amp_from, (int, float)):
return float(amp_from)
return amp_factor * float(row[amp_from])

def sigma_for_row(row):
if (
size_from is not None
and size_from in row.index
and np.isfinite(row[size_from])
):
area = float(row[size_from])
if area > 0:
r = np.sqrt(area / np.pi)
s = max(min_sigma, r / 2.0)
return s
return float(sigma)

for _, row in features.iterrows():
# time selection
if time_key == "frame":
tidx = int(row["frame"])
selector = {tdim: out[tdim].values[tidx]}
elif time_key == "time":
selector = {tdim: np.datetime64(row["time"])}
else:
raise ValueError("time_key must be 'frame' or 'time'")

p1 = float(row[p1_col])
p2 = float(row[p2_col])
amp = amplitude(row)
sig = sigma_for_row(row)

r2 = rr2(p1, p2)

if blob == "gaussian":
blob2d = amp * np.exp(-0.5 * r2 / (sig**2))
elif blob == "tophat":
blob2d = amp * (r2 <= (sig**2))
else:
raise ValueError("blob must be 'gaussian' or 'tophat'")

current = out.sel(selector).values
if mode == "add":
out.loc[selector] = current + blob2d
elif mode == "max":
out.loc[selector] = np.maximum(current, blob2d)
else:
raise ValueError("mode must be 'add' or 'max'")

return out
Loading