ICD revalidation #196

Merged · 32 commits · Jan 30, 2025

Commits
ed14995
add function to extract ic matches
metinbicer Dec 17, 2024
e68ec5d
calculate icd absolute and relative errors
metinbicer Dec 17, 2024
4e9c2f2
implement icd emul pipeline
metinbicer Dec 17, 2024
0348e6a
add icd scorer functions
metinbicer Dec 17, 2024
344941d
import scorer functions to include in __all__
metinbicer Dec 17, 2024
613fbd4
icd val gen
metinbicer Dec 17, 2024
47974d0
initial version of icd analysis
metinbicer Dec 17, 2024
0d10e8d
update old algo original name
metinbicer Dec 18, 2024
9bb7895
implement performance on the pipeline level (runtime_s)
metinbicer Jan 20, 2025
e5fc3e6
minor fix for df empty check
metinbicer Jan 20, 2025
9b46147
refactor error function to use it in the final aggregator
metinbicer Jan 21, 2025
38dde1c
rename error prefix to tp
metinbicer Jan 23, 2025
890d85c
move functions from gs/eval to utils/eval
metinbicer Jan 23, 2025
34b4eea
update reval analysis
metinbicer Jan 24, 2025
3084f9e
update documentation
metinbicer Jan 24, 2025
db990cb
fix formatting
metinbicer Jan 24, 2025
7f3d4bb
fix comment
metinbicer Jan 24, 2025
2c95483
update error column names
metinbicer Jan 24, 2025
c3ede48
update branch name
metinbicer Jan 24, 2025
cbf5ce3
refactor matched icd error function name to clarify the use of true p…
metinbicer Jan 24, 2025
23667c4
present performance comparison table (old vs new)
metinbicer Jan 24, 2025
94678ac
Some minor renamings and wordings
AKuederle Jan 27, 2025
39702fa
switch pipeline runtime calc to wrapper
metinbicer Jan 28, 2025
b039f84
add docstring to functions moved to utils/eval
metinbicer Jan 28, 2025
290e2b4
use ref data to calc mean step time in ic timing error
metinbicer Jan 28, 2025
35a94ca
minor rename and comments
metinbicer Jan 28, 2025
50d76fe
minor fix for linting
metinbicer Jan 28, 2025
61575b4
Removed unneeded functionality from combine_detected_and_reference_met…
AKuederle Jan 29, 2025
5d19aaa
Some minor doc updates
AKuederle Jan 29, 2025
71b5c92
Add proper type annots for pipeline paras
AKuederle Jan 29, 2025
80a2058
Lint fix
AKuederle Jan 30, 2025
d9493fe
Switched back to main
AKuederle Jan 30, 2025
1 change: 1 addition & 0 deletions docs/conf.py
@@ -213,6 +213,7 @@ def substitute(matchobj) -> str:
"../examples/data_transform",
"../examples/dev_guides",
"../revalidation/gait_sequences",
"../revalidation/initial_contacts",
"../revalidation/stride_length",
]
),
32 changes: 32 additions & 0 deletions docs/modules/initial_contacts.rst
@@ -18,6 +18,16 @@ Algorithms
IcdIonescu
IcdHKLeeImproved

Pipelines
+++++++++
.. currentmodule:: mobgap.initial_contacts

.. autosummary::
:toctree: generated/initial_contacts
:template: class.rst

pipeline.IcdEmulationPipeline

Utils
+++++
.. currentmodule:: mobgap.initial_contacts
@@ -37,7 +47,29 @@ Evaluation
:template: func.rst

calculate_matched_icd_performance_metrics
calculate_true_positive_icd_error
categorize_ic_list
get_matching_ics

Evaluation Scores
+++++++++++++++++
These scores are expected to be used in combination with :class:`~mobgap.utils.evaluation.Evaluation` and
:class:`~mobgap.utils.evaluation.EvaluationCV` or directly with :func:`~tpcp.validate.cross_validate` and
:func:`~tpcp.validate.validate` (a minimal usage sketch follows this file's diff).

.. currentmodule:: mobgap.initial_contacts.evaluation

.. autosummary::
:toctree: generated/initial_contacts

icd_score

.. autosummary::
:toctree: generated/initial_contacts
:template: func.rst

icd_per_datapoint_score
icd_final_agg

Base Classes
++++++++++++
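
To make the intended wiring concrete, here is a minimal sketch using the ``Evaluation`` helper, mirroring its use in ``_02_icd_result_generation_no_exc.py`` further down in this PR. The ``evaluate_icd`` wrapper and its ``dataset`` argument are illustrative assumptions; any dataset with reference information works.

from mobgap.initial_contacts import IcdIonescu
from mobgap.initial_contacts.evaluation import icd_score
from mobgap.initial_contacts.pipeline import IcdEmulationPipeline
from mobgap.utils.evaluation import Evaluation


def evaluate_icd(dataset):
    # `dataset` is assumed to be a BaseGaitDatasetWithReference instance,
    # e.g. TVSFreeLivingDataset(...) as used later in this PR.
    eval_run = Evaluation(dataset, scoring=icd_score).run(
        IcdEmulationPipeline(IcdIonescu())
    )
    # One row per recording with the `single__*` metrics the scorer produces.
    return eval_run.get_single_results_as_df()
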
48 changes: 5 additions & 43 deletions mobgap/gait_sequences/evaluation.py
@@ -19,8 +19,10 @@
)
from mobgap.utils.evaluation import (
accuracy_score,
combine_detected_and_reference_metrics,
count_samples_in_intervals,
count_samples_in_match_intervals,
extract_tp_matches,
npv_score,
precision_recall_f1_score,
specificity_score,
@@ -721,10 +723,10 @@ def get_matching_intervals(

tp_matches = matches.query("match_type == 'tp'")

detected_matches = _extract_tp_matches(metrics_detected, tp_matches["gs_id_detected"])
reference_matches = _extract_tp_matches(metrics_reference, tp_matches["gs_id_reference"])
detected_matches = extract_tp_matches(metrics_detected, tp_matches["gs_id_detected"])
reference_matches = extract_tp_matches(metrics_reference, tp_matches["gs_id_reference"])

combined_matches = _combine_detected_and_reference_metrics(
combined_matches = combine_detected_and_reference_metrics(
detected_matches, reference_matches, tp_matches=tp_matches
)

@@ -748,46 +750,6 @@ def _check_gs_level_matches_sanity(matches: pd.DataFrame) -> pd.DataFrame:
return matches


def _extract_tp_matches(metrics: pd.DataFrame, match_indices: pd.Series) -> pd.DataFrame:
try:
matches = metrics.loc[match_indices]
except KeyError as e:
raise ValueError(
"The indices from the provided `matches` DataFrame do not fit to the metrics DataFrames. "
"Please ensure that the `matches` DataFrame is calculated based on the same data "
"as the `metrics` DataFrames and thus refers to valid indices."
) from e
return matches


def _combine_detected_and_reference_metrics(
detected: pd.DataFrame, reference: pd.DataFrame, tp_matches: Union[pd.DataFrame, None] = None
) -> pd.DataFrame:
# if wb_id in index, add it as a column to preserve it in the combined DataFrame
if "wb_id" in detected.index.names and "wb_id" in reference.index.names:
detected.insert(0, "wb_id", detected.index.get_level_values("wb_id"))
reference.insert(0, "wb_id", reference.index.get_level_values("wb_id"))

common_columns = list(set(reference.columns).intersection(detected.columns))
if len(common_columns) == 0:
raise ValueError("No common columns found in `metrics_detected` and `metrics_reference`.")

detected = detected[common_columns]
reference = reference[common_columns]

if tp_matches is not None:
detected.index = tp_matches.index
reference.index = tp_matches.index

matches = detected.merge(reference, left_index=True, right_index=True, suffixes=("_det", "_ref"))

# construct MultiIndex columns
matches.columns = pd.MultiIndex.from_product([["detected", "reference"], common_columns])
# make 'metrics' level the uppermost level and sort columns accordingly for readability
matches = matches.swaplevel(axis=1).sort_index(axis=1, level=0)
return matches


__all__ = [
"categorize_intervals_per_sample",
"categorize_intervals",
253 changes: 253 additions & 0 deletions mobgap/initial_contacts/_evaluation_scorer.py
@@ -0,0 +1,253 @@
import warnings

import pandas as pd
from tpcp.validate import Scorer, no_agg

from mobgap.data.base import BaseGaitDatasetWithReference
from mobgap.initial_contacts.pipeline import IcdEmulationPipeline


def icd_per_datapoint_score(pipeline: IcdEmulationPipeline, datapoint: BaseGaitDatasetWithReference) -> dict:
"""Evaluate the performance of an ICD algorithm on a single datapoint.

.. warning:: This function is not meant to be called directly, but as a scoring function in a
:class:`tpcp.validate.Scorer`.
If you are writing custom scoring functions, you can use this function as a template or wrap it in a new
function.

This function is used to evaluate the performance of an ICD algorithm on a single datapoint.
It calculates the performance metrics based on the detected initial contacts and the reference initial contacts.

The following performance metrics are calculated:

- all outputs of :func:`~mobgap.initial_contacts.evaluation.calculate_matched_icd_performance_metrics`
(will be averaged over all datapoints)
- ``matches``: The matched initial contacts calculated by
:func:`~mobgap.initial_contacts.evaluation.categorize_ic_list` (return as ``no_agg``)
- ``detected``: The detected initial contacts (return as ``no_agg``)
- ``reference``: The reference initial contacts (return as ``no_agg``)
- ``sampling_rate_hz``: The sampling rate of the data (return as ``no_agg``)

Parameters
----------
pipeline
An instance of ICD emulation pipeline that wraps the algorithm that should be evaluated.
datapoint
The datapoint to be evaluated.

Returns
-------
dict
A dictionary containing the performance metrics.
Note that some results are wrapped in a ``no_agg`` object or other aggregators.
The results of this function are not expected to be parsed manually; rather, the function is expected to be
used in the context of the :func:`~tpcp.validate.validate`/:func:`~tpcp.validate.cross_validate` functions or
similar as a scorer.
This function will aggregate the results and provide a summary of the performance metrics.

"""
from mobgap.initial_contacts.evaluation import (
calculate_matched_icd_performance_metrics,
calculate_true_positive_icd_error,
categorize_ic_list,
get_matching_ics,
)
from mobgap.utils.conversions import as_samples
from mobgap.utils.df_operations import create_multi_groupby

with warnings.catch_warnings():
# We know that these errors might happen, and they are usually not relevant for the evaluation
warnings.filterwarnings("ignore", message="Zero division", category=UserWarning)
warnings.filterwarnings("ignore", message="multiple ICs", category=UserWarning)

# Run the algorithm on the datapoint
pipeline.safe_run(datapoint)
detected_ic_list = pipeline.ic_list_
reference_ic_list = datapoint.reference_parameters_.ic_list
sampling_rate_hz = datapoint.sampling_rate_hz

# tolerance around the reference ic (this is a centered window - half window in both directions)
tolerance_s = 0.5
tolerance_samples = as_samples(tolerance_s, sampling_rate_hz)

# match types
matches_per_wb = create_multi_groupby(detected_ic_list, reference_ic_list, groupby="wb_id").apply(
lambda df1, df2: categorize_ic_list(
ic_list_detected=df1,
ic_list_reference=df2,
tolerance_samples=tolerance_samples,
multiindex_warning=False,
)
)
# check if matches_per_wb has the required columns
if matches_per_wb.empty:
# then it is an empty dataframe without required columns
matches_per_wb = pd.DataFrame(
{
"ic_id_detected": [],
"ic_id_reference": [],
"match_type": [],
"wb_id": [],
}
).set_index(["wb_id"])

# calculate run time on pipeline level
runtime_s = pipeline.perf_["runtime_s"]

# match initial contacts, get true positives
tp_ics = get_matching_ics(
metrics_detected=detected_ic_list,
metrics_reference=reference_ic_list,
matches=matches_per_wb,
)

# Calculate the performance metrics
performance_metrics = {
**calculate_matched_icd_performance_metrics(
matches_per_wb,
),
**calculate_true_positive_icd_error(
reference_ic_list,
tp_ics,
sampling_rate_hz,
),
"matches": no_agg(matches_per_wb),
"detected": no_agg(detected_ic_list),
"reference": no_agg(reference_ic_list),
"tp_ics": no_agg(tp_ics),
"sampling_rate_hz": no_agg(sampling_rate_hz),
"runtime_s": runtime_s,
}

return performance_metrics


def icd_final_agg(
agg_results: dict[str, float],
single_results: dict[str, list],
pipeline: IcdEmulationPipeline, # noqa: ARG001
dataset: BaseGaitDatasetWithReference,
) -> tuple[dict[str, any], dict[str, list[any]]]:
"""Aggregate the performance metrics of an ICD algorithm over multiple datapoints.

.. warning:: This function is not meant to be called directly, but as ``final_aggregator`` in a
:class:`tpcp.validate.Scorer`.
If you are writing custom scoring functions, you can use this function as a template or wrap it in a new
function.

This function aggregates the performance metrics as follows:

- All raw outputs (``detected``, ``reference``, ``sampling_rate_hz``) are concatenated into a single dataframe
each, to make them easier to work with, and are returned as part of the single results.
- We recalculate all performance metrics from
:func:`~mobgap.initial_contacts.evaluation.calculate_matched_icd_performance_metrics` on the combined data.
The results are prefixed with ``combined__``.
Compared to the per-datapoint results (which are calculated as errors per recording and then averaged over all
recordings), these metrics are calculated by combining all ICs from all recordings and then calculating the
performance metrics once.
Effectively, this means that in the `per_datapoint` version each recording is weighted equally, while in the
`combined` version each IC is weighted equally.

Parameters
----------
agg_results
The aggregated results from all datapoints (see :class:`~tpcp.validate.Scorer`).
single_results
The per-datapoint results (see :class:`~tpcp.validate.Scorer`).
pipeline
The pipeline that was passed to the scorer.
This is ignored in this function, but might be useful in custom final aggregators.
dataset
The dataset that was passed to the scorer.

Returns
-------
final_agg_results
The final aggregated results.
final_single_results
The per-datapoint results, that are not aggregated.

"""
from mobgap.initial_contacts.evaluation import (
calculate_matched_icd_performance_metrics,
calculate_true_positive_icd_error,
)

data_labels = [d.group_label for d in dataset]
data_label_names = data_labels[0]._fields
# We combine each to a combined dataframe
matches = single_results.pop("matches")
matches = pd.concat(matches, keys=data_labels, names=[*data_label_names, *matches[0].index.names])
detected = single_results.pop("detected")
detected = pd.concat(detected, keys=data_labels, names=[*data_label_names, *detected[0].index.names])
reference = single_results.pop("reference")
reference = pd.concat(reference, keys=data_labels, names=[*data_label_names, *reference[0].index.names])
tp_ics = single_results.pop("tp_ics")
tp_ics = pd.concat(tp_ics, keys=data_labels, names=[*data_label_names, *tp_ics[0].index.names])

aggregated_single_results = {
"raw__detected": detected,
"raw__reference": reference,
}

sampling_rate_hz = single_results.pop("sampling_rate_hz")
if set(sampling_rate_hz) != {sampling_rate_hz[0]}:
raise ValueError(
"Sampling rate is not the same for all datapoints in the dataset. "
"This not supported by this scorer. "
"Provide a custom scorer that can handle this case."
)

combined_matched = {
f"combined__{k}": v
for k, v in {
**calculate_matched_icd_performance_metrics(matches),
**calculate_true_positive_icd_error(reference, tp_ics, sampling_rate_hz[0]),
}.items()
}

# Note, that we pass the "aggregated_single_results" out via the single results and not the aggregated results
# The reason is that the aggregated results are expected to be a single value per metric, while the single results
# can be anything.
return {**agg_results, **combined_matched}, {**single_results, **aggregated_single_results}


#: :data:: icd_score
#: Scorer class instance for ICD algorithms.
icd_score = Scorer(icd_per_datapoint_score, final_aggregator=icd_final_agg)
icd_score.__doc__ = """Scorer for ICD algorithms.

This is a pre-configured :class:`~tpcp.validate.Scorer` object using the :func:`icd_per_datapoint_score` function as
per-datapoint scorer and the :func:`icd_final_agg` function as final aggregator.
For more information about Scorer, head to the tpcp documentation (:class:`~tpcp.validate.Scorer`).
For usage information in the context of mobgap, have a look at the :ref:`evaluation example <icd_evaluation>` for ICD.

The following metrics are calculated:

Raw metrics (part of the single results):

- ``single__raw__detected``: The detected initial contacts as a single dataframe with the datapoint labels as index.
- ``single__raw__reference``: The reference initial contacts as a single dataframe with the datapoint labels as index.

Metrics per datapoint (single results):
*These values are all provided as a list of values, one per datapoint.*

- All outputs of :func:`~mobgap.initial_contacts.evaluation.calculate_matched_icd_performance_metrics` and
:func:`~mobgap.initial_contacts.evaluation.calculate_true_positive_icd_error` averaged per
datapoint. These are stored as ``single__{metric_name}``
- ``single__runtime_s``: The runtime of the algorithm in seconds. If multiple WBs were processed, this is the
runtime it took to process all WBs.

Aggregated metrics (aggregated results):

- All single outputs averaged over all datapoints. These are stored as ``agg__{metric_name}``.
- All metrics from :func:`~mobgap.initial_contacts.evaluation.calculate_matched_icd_performance_metrics` and
:func:`~mobgap.initial_contacts.evaluation.calculate_true_positive_icd_error` recalculated on all detected ICs across
all datapoints. These are stored as ``combined__{metric_name}``.
Compared to the per-datapoint results (which are calculated as errors per recording and then averaged over all
recordings), these metrics are calculated by combining all ICs from all recordings and then calculating the
performance metrics once.
Effectively, this means that in the `per_datapoint` version each recording is weighted equally, while in the
`combined` version each IC is weighted equally.

"""
172 changes: 171 additions & 1 deletion mobgap/initial_contacts/evaluation.py
@@ -1,13 +1,23 @@
"""Class to evaluate initial contact detection algorithms."""

import warnings
from collections.abc import Hashable
from typing import Literal, Union

import numpy as np
import pandas as pd
from scipy.spatial import KDTree

from mobgap.utils.evaluation import precision_recall_f1_score
from mobgap.initial_contacts._evaluation_scorer import (
icd_final_agg,
icd_per_datapoint_score,
icd_score,
)
from mobgap.utils.evaluation import (
combine_detected_and_reference_metrics,
extract_tp_matches,
precision_recall_f1_score,
)


def calculate_matched_icd_performance_metrics(
@@ -72,6 +82,72 @@ def calculate_matched_icd_performance_metrics(
return icd_metrics


def calculate_true_positive_icd_error(
ic_list_reference: pd.DataFrame,
match_ics: pd.DataFrame,
sampling_rate_hz: float,
groupby: Union[Hashable, tuple[Hashable, ...]] = "wb_id",
) -> dict[str, Union[float, int]]:
"""
Calculate error metrics for initial contact detection results.
This function assumes that you already classified the detected initial contacts as true positive (tp), false
positive (fp), or false negative (fn) matches using the
:func:`~mobgap.initial_contacts.evaluation.categorize_ic_list` function.
The dataframe returned by categorize function can then be used as input to this function.
The following metrics are calculated for each true positive initial contact:
- `tp_absolute_timing_error_s`: Absolute time difference (in seconds) between the detected and reference initial
contact.
- `tp_relative_timing_error`: All absolute errors, within a walking bout, divided by the average step duration
estimated by the INDIP.
In case no ICs are detected, the error metrics will be 0.
Note, that this will introduce a bias when comparing these values, because algorithms that don't find any ICs will
have a lower error than algorithms that find ICs but with a higher error.
The value should always be considered together with the number of correctly detected ICs.
Parameters
----------
ic_list_reference: pd.DataFrame
The dataframe of reference initial contacts.
match_ics: pd.DataFrame
Initial contact true positives as output by :func:`~mobgap.initial_contacts.evaluation.get_matching_ics`.
sampling_rate_hz: float
Sampling rate of the data.
groupby
A valid pandas groupby argument to group the initial contacts by to calculate the average step duration.
Returns
-------
error_metrics: dict
"""
# calculate absolute error in seconds
tp_absolute_timing_error_s = abs(match_ics["ic"]["detected"] - match_ics["ic"]["reference"]) / sampling_rate_hz

# relative error (estimated by dividing all absolute errors, within a walking bout, by the average step duration
# estimated by the reference system)
mean_ref_step_time_s = (
ic_list_reference.groupby(groupby)["ic"].diff().dropna().groupby(groupby).mean() / sampling_rate_hz
)

tp_relative_timing_error = tp_absolute_timing_error_s / mean_ref_step_time_s

# return the mean after dropping NaNs; if a series is empty, return 0
error_metrics = {
"tp_absolute_timing_error_s": tp_absolute_timing_error_s.dropna().mean()
if not tp_absolute_timing_error_s.dropna().empty
else 0,
"tp_relative_timing_error": tp_relative_timing_error.dropna().mean()
if not tp_relative_timing_error.dropna().empty
else 0,
}

return error_metrics


def categorize_ic_list(
*,
ic_list_detected: pd.DataFrame,
@@ -224,6 +300,88 @@ def categorize_ic_list(
return matches


def get_matching_ics(
*, metrics_detected: pd.DataFrame, metrics_reference: pd.DataFrame, matches: pd.DataFrame
) -> pd.DataFrame:
"""
Extract the detected and reference initial contacts that are considered as matches sequence-by-sequence (tps).
The metrics of the detected and reference initial contacts are extracted and returned in a DataFrame
for further comparison.
Parameters
----------
metrics_detected
Each row corresponds to a detected initial contact interval as output from the ICD algorithms.
The columns contain the metrics estimated for each respective initial contact based on these detected intervals.
The columns present in both `metrics_detected` and `metrics_reference` are regarded for the matching,
while the other columns are discarded.
metrics_reference
Each row corresponds to a reference initial contact interval as retrieved from the reference system.
The columns contain the metrics estimated for each respective initial contact based on these reference intervals.
The columns present in both `metrics_detected` and `metrics_reference` are regarded for the matching,
while the other columns are discarded.
matches
A DataFrame containing the matched initial contacts
as output by :func:`~mobgap.initial_contacts.evaluation.calculate_matched_icd_performance_metrics`.
Must have been calculated based on the same interval data as `metrics_detected` and `metrics_reference`.
Expected to have the columns `ic_id_detected`, `ic_id_reference`, and `match_type`.

Returns
-------
matches: pd.DataFrame
The detected initial contacts that are considered matches, assigned to the reference sequences
they are matching with.
As index, the unique identifier for each matched initial contact assigned in the `matches` DataFrame is used.
The columns are two-level MultiIndex columns, consisting of a `metrics` and an `origin` level.
As first column level, all columns present in both `metrics_detected` and `metrics_reference` are included.
The second column level indicates the origin of the respective value, either `detected` or `reference` for
metrics that were estimated based on the detected or reference initial contacts, respectively.

Examples
--------
>>> from mobgap.initial_contacts.evaluation import (
... categorize_ic_list,
... get_matching_ics,
... )
>>> ic_detected = pd.DataFrame([11, 23, 30, 50], columns=["ic"]).rename_axis(
... "ic_id"
... )
>>> ic_reference = pd.DataFrame([10, 20, 32, 40], columns=["ic"]).rename_axis(
... "ic_id"
... )
>>> matches = categorize_ic_list(
... ic_list_detected=ic_detected,
... ic_list_reference=ic_reference,
... tolerance_samples=2,
... )
>>> match_ics = get_matching_ics(
... metrics_detected=ic_detected,
... metrics_reference=ic_reference,
... matches=matches,
... )
>>> match_ics
ic
detected reference
id
0 11 10
1 30 32
"""
matches = _check_matches_sanity(matches)

tp_matches = matches.query("match_type == 'tp'")

detected_matches = extract_tp_matches(metrics_detected, tp_matches["ic_id_detected"])
reference_matches = extract_tp_matches(metrics_reference, tp_matches["ic_id_reference"])

combined_matches = combine_detected_and_reference_metrics(
detected_matches, reference_matches, tp_matches=tp_matches
)

return combined_matches


def _match_label_lists(
list_left: np.ndarray, list_right: np.ndarray, tolerance_samples: Union[int, float] = 0
) -> tuple[np.ndarray, np.ndarray]:
@@ -337,3 +495,15 @@ def _sanitize_index(ic_list: pd.DataFrame, list_type: Literal["detected", "refer
if not ic_list.index.is_unique:
raise ValueError(f"The index of `ic_list_{list_type}` must be unique!")
return ic_list, is_multindex


__all__ = [
"calculate_matched_icd_performance_metrics",
"calculate_true_positive_icd_error",
"categorize_ic_list",
"_match_label_lists",
"icd_per_datapoint_score",
"icd_final_agg",
"icd_score",
"get_matching_ics",
]
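
Putting the new evaluation helpers together: a minimal sketch of the categorize → match → error chain, extending the doctest from ``get_matching_ics`` with hypothetical data for a single walking bout (``wb_id=0``) so that the default ``groupby="wb_id"`` of ``calculate_true_positive_icd_error`` applies.

import pandas as pd

from mobgap.initial_contacts.evaluation import (
    calculate_true_positive_icd_error,
    categorize_ic_list,
    get_matching_ics,
)

# Hypothetical detected/reference ICs (in samples) for one walking bout.
index = pd.MultiIndex.from_product([[0], range(4)], names=["wb_id", "ic_id"])
ic_detected = pd.DataFrame({"ic": [11, 23, 30, 50]}, index=index)
ic_reference = pd.DataFrame({"ic": [10, 20, 32, 40]}, index=index)

matches = categorize_ic_list(
    ic_list_detected=ic_detected,
    ic_list_reference=ic_reference,
    tolerance_samples=2,
    multiindex_warning=False,
)
match_ics = get_matching_ics(
    metrics_detected=ic_detected,
    metrics_reference=ic_reference,
    matches=matches,
)
# Mean absolute (s) and relative timing error of the true positives at 100 Hz.
errors = calculate_true_positive_icd_error(ic_reference, match_ics, 100.0)
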
127 changes: 127 additions & 0 deletions mobgap/initial_contacts/pipeline.py
@@ -0,0 +1,127 @@
"""Helpful Pipelines to wrap the ICD algorithms for optimization and evaluation."""

import warnings

import pandas as pd
from tpcp import OptimizableParameter, OptimizablePipeline
from typing_extensions import Self

from mobgap._utils_internal.misc import MeasureTimeResults, timed_action_method
from mobgap.data.base import BaseGaitDatasetWithReference
from mobgap.initial_contacts.base import BaseIcDetector, base_icd_docfiller
from mobgap.pipeline import GsIterator
from mobgap.utils.conversions import to_body_frame


def _conditionally_to_bf(data: pd.DataFrame, convert: bool) -> pd.DataFrame:
if convert:
return to_body_frame(data)
return data


@base_icd_docfiller
class IcdEmulationPipeline(OptimizablePipeline[BaseGaitDatasetWithReference]):
"""Run an ICD algorithm in isolation on a Gait Dataset.
This wraps any ICD algorithm and allows to apply it to a single datapoint of a Gait Dataset or optimize it
based on a whole dataset.
This pipeline can be used in combination with the ``tpcp.validate`` and ``tpcp.optimize`` modules to evaluate or
improve the performance of an ICD algorithm.
Parameters
----------
algo
The ICD algorithm that should be run/evaluated.
convert_to_body_frame
If True, the data will be converted to the body frame before running the algorithm.
This is the default, as most algorithms expect the data in the body frame.
If your data is explicitly not aligned and your algorithm supports sensor-frame/unaligned input, you might want
to set this to False.

Attributes
----------
%(ic_list_)s
per_wb_algo_
The ICD algo instances (one per WB) with all results after running the algorithm.
This can be helpful for debugging or further analysis.

Notes
-----
All emulation pipelines pass available metadata of the dataset to the algorithm.
This includes the recording metadata (``recording_metadata``) and the participant metadata
(``participant_metadata``), which are passed as keyword arguments to the ``detect`` method of the algorithm.
In addition, we pass the group label of the datapoint as ``dp_group`` to the algorithm.
This is usually not required by algorithms (because this would mean that the algorithm changes behaviour based on
the exact recording provided).
However, it can be helpful when working with "dummy" algorithms, that simply return some fixed pre-defined results
or to be used as cache key, when the algorithm has internal caching mechanisms.
For the `self_optimize` method, we pass the same metadata to the algorithm, but each value is actually a list of
values, one for each datapoint in the dataset.
"""

algo: OptimizableParameter[BaseIcDetector]
convert_to_body_frame: bool

per_wb_algo_: dict[str, BaseIcDetector]
ic_list_: pd.DataFrame
perf_: MeasureTimeResults

def __init__(self, algo: BaseIcDetector, *, convert_to_body_frame: bool = True) -> None:
self.algo = algo
self.convert_to_body_frame = convert_to_body_frame

@timed_action_method
def run(self, datapoint: BaseGaitDatasetWithReference) -> Self:
"""Run the pipeline on a single data point.
This extracts the imu_data (``data_ss``) and the sampling rate (``sampling_rate_hz``) from the datapoint and
uses the ``detect`` method of the ICD algorithm to detect the gait sequences.
Parameters
----------
datapoint
A single datapoint of a Gait Dataset with reference information.
Returns
-------
self
The pipeline instance with the detected initial contacts stored in the ``ic_list_`` attribute.
"""
imu_data = _conditionally_to_bf(datapoint.data_ss, self.convert_to_body_frame)
sampling_rate_hz = datapoint.sampling_rate_hz

kwargs = {
"sampling_rate_hz": sampling_rate_hz,
**datapoint.recording_metadata,
**datapoint.participant_metadata,
"dp_group": datapoint.group_label,
}

ref_paras = datapoint.reference_parameters_

if len(ref_paras.wb_list) == 0:
warnings.warn(
f"No walking bouts found in the reference data. {kwargs['dp_group']}", RuntimeWarning, stacklevel=1
)
self.per_wb_algo_ = {}
self.ic_list_ = pd.DataFrame({"wb_id": [], "step_id": [], "ic": []}).astype(
{"wb_id": float, "step_id": int, "ic": int}
)
self.ic_list_ = self.ic_list_.set_index(["wb_id", "step_id"])
return self
wb_iterator = GsIterator()
result_algo_list = {}
for (wb, data), r in wb_iterator.iterate(imu_data, ref_paras.wb_list):
algo = self.algo.clone().detect(data, **kwargs, current_gs=wb)
result_algo_list[wb.id] = algo
r.ic_list = algo.ic_list_
self.per_wb_algo_ = result_algo_list
self.ic_list_ = wb_iterator.results_.ic_list
return self


__all__ = ["IcdEmulationPipeline"]
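
A minimal sketch of the pipeline run in isolation; the ``detect_ics`` helper and its ``datapoint`` argument are assumptions (one row of any dataset with reference information).

from mobgap.initial_contacts import IcdIonescu
from mobgap.initial_contacts.pipeline import IcdEmulationPipeline


def detect_ics(datapoint):
    # `datapoint` is assumed to be a single row of a BaseGaitDatasetWithReference.
    pipeline = IcdEmulationPipeline(IcdIonescu()).safe_run(datapoint)
    # Detected ICs per reference WB; runtime_s is recorded by the
    # timed_action_method wrapper around `run`.
    return pipeline.ic_list_, pipeline.perf_["runtime_s"]
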
115 changes: 115 additions & 0 deletions mobgap/utils/evaluation.py
@@ -448,6 +448,119 @@ def count_samples_in_intervals(intervals: pd.DataFrame) -> int:
return int((intervals["end"] - intervals["start"] + 1).sum())


def extract_tp_matches(metrics: pd.DataFrame, match_indices: pd.Series) -> pd.DataFrame:
"""
Extract true positive (TP) matches from the metrics DataFrame based on provided match indices.
This function is used in :func:`~mobgap.gait_sequences.evaluation.get_matching_intervals` and
:func:`~mobgap.initial_contacts.evaluation.get_matching_ics` to filter out the TP matches from
a larger DataFrame of evaluation metrics (e.g., detected/reference gait sequences or initial
contacts). TP matches are calculated by :func:`~mobgap.gait_sequences.evaluation.categorize_intervals`
and :func:`~mobgap.initial_contacts.evaluation.categorize_ic_list` for gait sequence and initial
contact detection, respectively.
Parameters
----------
metrics: pd.DataFrame
The DataFrame containing metrics from which to extract the matches.
match_indices: pd.Series
A Series of indices indicating the rows in `metrics` to be extracted.

Returns
-------
matches: pd.DataFrame
A DataFrame containing the rows of `metrics` corresponding to the indices in `match_indices`.

Raises
------
ValueError
If any of the indices in `match_indices` are not found in the `metrics` DataFrame,
indicating that there is a mismatch between the provided indices and the data.

Examples
--------
>>> metrics = pd.DataFrame(
...     {"step_id": [0, 1, 2, 3], "metric": [597, 660, 720, 797]}
... ).set_index("step_id")
>>> match_indices = pd.Series([0, 2])
>>> extract_tp_matches(metrics, match_indices)
metric
step_id
0 597
2 720
"""
try:
matches = metrics.loc[match_indices]
except KeyError as e:
raise ValueError(
"The indices from the provided `matches` DataFrame do not fit to the metrics DataFrames. "
"Please ensure that the `matches` DataFrame is calculated based on the same data "
"as the `metrics` DataFrames and thus refers to valid indices."
) from e
return matches


def combine_detected_and_reference_metrics(
detected: pd.DataFrame, reference: pd.DataFrame, tp_matches: Union[pd.DataFrame, None] = None
) -> pd.DataFrame:
"""
Combine metrics from detected and reference DataFrames using a set of true positive matches to reindex.
This function is used in :func:`~mobgap.gait_sequences.evaluation.get_matching_intervals` and
:func:`~mobgap.initial_contacts.evaluation.get_matching_ics` to merge two DataFrames (`detected`
and `reference`) based on their common columns. The dataframes are obtained by
:func:`~mobgap.utils.evaluation.extract_tp_matches`. Optionally, if a set of true positive matches
(`tp_matches`) is provided, it will reindex both DataFrames before merging.
The result is a combined DataFrame where detected and reference metrics are placed side-by-side
for each matching index with multi-level columns.
Parameters
----------
detected : pd.DataFrame
The DataFrame containing the detected metrics (gait sequences or initial contacts).
reference : pd.DataFrame
The DataFrame containing the reference metrics.
tp_matches : Union[pd.DataFrame, None], optional
A DataFrame containing true positive matches, by default None. If provided, the indices
of the `detected` and `reference` DataFrames will be reindexed to match the true positive
matches before combining. If None, no reindexing is done.

Returns
-------
pd.DataFrame
A combined DataFrame with the detected and reference metrics side-by-side. The columns
are multi-indexed with levels corresponding to `detected` and `reference` for easier
comparison of corresponding metrics.

Raises
------
ValueError
If no common columns are found between `detected` and `reference`.

Notes
-----
- The merged DataFrame has multi-level columns, where the top level indicates the source
(`detected` or `reference`) and the lower level corresponds to the metric name.
- We add a new column `orig_index` to both DataFrames to keep track of the original index.
"""
common_columns = list(set(reference.columns).intersection(detected.columns))
if len(common_columns) == 0:
raise ValueError("No common columns found in `metrics_detected` and `metrics_reference`.")

detected = detected[common_columns].assign(orig_index=lambda df_: df_.index.to_list()).reset_index(drop=True)
reference = reference[common_columns].assign(orig_index=lambda df_: df_.index.to_list()).reset_index(drop=True)

combined = pd.concat({"detected": detected, "reference": reference}, axis=1)
# make 'metrics' level the uppermost level and sort columns accordingly for readability
combined = combined.swaplevel(axis=1).sort_index(axis=1, level=0)
if tp_matches is not None:
combined.index = tp_matches.index

return combined


def _estimate_number_tn_samples(matches_df: pd.DataFrame, n_overall_samples: Union[int, None], tn_warning: bool) -> int:
tn = count_samples_in_match_intervals(matches_df, "tn")
if tn > 0 and n_overall_samples is not None:
@@ -516,4 +629,6 @@ def _input_is_icd_matches_df(matches_df: pd.DataFrame) -> bool:
"Evaluation",
"EvaluationCV",
"save_evaluation_results",
"extract_tp_matches",
"combine_detected_and_reference_metrics",
]
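
To illustrate the column layout the two moved helpers produce, a small self-contained sketch with hypothetical values (mirroring the doctest in ``extract_tp_matches``).

import pandas as pd

from mobgap.utils.evaluation import (
    combine_detected_and_reference_metrics,
    extract_tp_matches,
)

detected = pd.DataFrame({"ic": [11, 23, 30, 50]}).rename_axis("ic_id")
reference = pd.DataFrame({"ic": [10, 20, 32, 40]}).rename_axis("ic_id")

# Keep only the rows matched as true positives (hypothetical indices 0 and 2).
detected_tp = extract_tp_matches(detected, pd.Series([0, 2]))
reference_tp = extract_tp_matches(reference, pd.Series([0, 2]))

combined = combine_detected_and_reference_metrics(detected_tp, reference_tp)
# `combined` has two-level columns such as ("ic", "detected") and
# ("ic", "reference"), plus "orig_index" columns preserving the input indices.
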
102 changes: 51 additions & 51 deletions poetry.lock
2 changes: 2 additions & 0 deletions revalidation/initial_contacts/README.rst
@@ -0,0 +1,2 @@
Initial Contacts
----------------
254 changes: 254 additions & 0 deletions revalidation/initial_contacts/_01_icd_analysis.py
@@ -0,0 +1,254 @@
"""
.. _icd_val_results:

Performance of the initial contact algorithms on the TVS dataset
================================================================

.. warning:: On this page you will find preliminary results for a standardized revalidation of the pipeline and all
of its algorithms.
The current state is **TECHNICAL EXPERIMENTATION**.
Don't use these results or make any assumptions based on them.
We will update this page incrementally and provide further information as soon as the state of any of the validation
steps changes.

The following provides an analysis and comparison of the ICD performance on the TVS dataset (lab and free-living).
We look into the actual performance of the algorithms compared to the reference data and compare these results with
the performance of the original matlab algorithms.

.. note:: If you are interested in how these results are calculated, head over to the
:ref:`processing page <icd_val_gen>`.

We focus on the `single_results` (aka the performance per trial) and will aggregate them over multiple levels.
"""

# %%
# Below is the list of algorithms that we will compare.
# Note that we use the prefix "new" to refer to the reimplemented python algorithms and "orig" to refer to the
# original matlab algorithms.

# Note also that the IcdIonescu algorithm is the reimplementation of the Ani_McCamley algorithm from the original
# matlab codebase.
# The other two algorithms (IcdShinImproved and IcdHKLeeImproved) are actually cadence algorithms.
# As they can also be used to detect initial contacts, we present their results as well.
# However, you should check the dedicated cadence analysis for a more detailed comparison of these algorithms.
algorithms = {
"IcdIonescu": ("IcdIonescu", "new"),
"IcdShinImproved": ("IcdShinImproved", "new"),
"IcdHKLeeImproved": ("IcdHKLeeImproved", "new"),
}
# We only load the matlab algorithms that we reimplemented
algorithms.update(
{
"matlab_Ani_McCamley": ("IcdIonescu", "orig"),
}
)

# %%
# The code below loads the data and prepares it for the analysis.
# By default, the data will be downloaded from an online repository (and cached locally).
# If you want to use a local copy of the data, you can set the `MOBGAP_VALIDATION_DATA_PATH` environment variable
# and set `MOBGAP_VALIDATION_USE_LOCAL_DATA` to `1`.
#
# The file download will print a couple of log messages, which can usually be ignored.
# You can also change the `version` parameter to load a different version of the data.
from pathlib import Path

import pandas as pd
from mobgap.data.validation_results import ValidationResultLoader
from mobgap.utils.misc import get_env_var

local_data_path = (
Path(get_env_var("MOBGAP_VALIDATION_DATA_PATH")) / "results"
if int(get_env_var("MOBGAP_VALIDATION_USE_LOCAL_DATA", 0))
else None
)
loader = ValidationResultLoader(
"icd", result_path=local_data_path, version="main"
)

free_living_index_cols = [
"cohort",
"participant_id",
"time_measure",
"recording",
"recording_name",
"recording_name_pretty",
]

results = {
v: loader.load_single_results(k, "free_living")
for k, v in algorithms.items()
}
results = pd.concat(results, names=["algo", "version", *free_living_index_cols])
results_long = results.reset_index().assign(
algo_with_version=lambda df: df["algo"] + " (" + df["version"] + ")",
_combined="combined",
)
cohort_order = ["HA", "CHF", "COPD", "MS", "PD", "PFF"]
# %%
# Performance metrics
# -------------------
# For each participant, performance metrics were calculated by classifying the detected initial contacts as TP, FP,
# or FN matches. Based on these values, recall (sensitivity), precision (positive predictive value), and F1 score
# were calculated.
# On top of that, absolute error for each true positive initial contact was calculated as the temporal difference
# between detected and reference values. Relative error was calculated by dividing all absolute errors, within a walking
# bout, by the average step duration estimated from the reference system.
# From these, we calculate the mean and confidence interval for each metric, as well as the bias and limits of
# agreement (LoA) of the timing errors between the algorithm output and the reference data.
#
# Below the functions that calculate these metrics are defined.
from functools import partial

from mobgap.pipeline.evaluation import CustomErrorAggregations as A
from mobgap.utils.df_operations import (
CustomOperation,
apply_aggregations,
apply_transformations,
)
from mobgap.utils.tables import FormatTransformer as F

custom_aggs = [
CustomOperation(
identifier=None,
function=A.n_datapoints,
column_name=[("n_datapoints", "all")],
),
("recall", ["mean", A.conf_intervals]),
("precision", ["mean", A.conf_intervals]),
("f1_score", ["mean", A.conf_intervals]),
("tp_absolute_timing_error_s", ["mean", A.loa]),
("tp_relative_timing_error", ["mean", A.loa]),
]

format_transforms = [
CustomOperation(
identifier=None,
function=lambda df_: df_[("n_datapoints", "all")].astype(int),
column_name=("General", "n_datapoints"),
),
*(
CustomOperation(
identifier=None,
function=partial(
F.value_with_range,
value_col=("mean", c),
range_col=("conf_intervals", c),
),
column_name=("ICD", c),
)
for c in [
"recall",
"precision",
"f1_score",
]
),
*(
CustomOperation(
identifier=None,
function=partial(
F.value_with_range,
value_col=("mean", c),
range_col=("loa", c),
),
column_name=("IC Timing", c),
)
for c in [
"tp_absolute_timing_error_s",
"tp_relative_timing_error",
]
),
]

final_names = {
"n_datapoints": "# recordings",
"recall": "Recall",
"precision": "Precision",
"f1_score": "F1 Score",
"tp_absolute_timing_error_s": "Abs. Error [s]",
"tp_relative_timing_error": "Bias and LoA",
}


def format_results(df: pd.DataFrame) -> pd.DataFrame:
return (
df.pipe(apply_transformations, format_transforms)
.rename(columns=final_names)
.loc[:, pd.IndexSlice[:, list(final_names.values())]]
)


# %%
# Free-Living Comparison
# ----------------------
# We focus the comparison on the free-living data, as this is the most relevant considering our final use-case.
# In the free-living data, there is one 2.5 hour recording per participant.
# This means, each datapoint in the plots below and in the summary statistics represents one participant.
#
# All results across all cohorts
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
import matplotlib.pyplot as plt
import seaborn as sns

hue_order = ["orig", "new"]

fig, ax = plt.subplots()
sns.boxplot(
data=results_long,
x="algo",
y="f1_score",
hue="version",
hue_order=hue_order,
ax=ax,
)
fig.show()

perf_metrics_all = (
results.groupby(["algo", "version"])
.apply(apply_aggregations, custom_aggs)
.pipe(format_results)
)
perf_metrics_all

# %%
# Per Cohort
# ~~~~~~~~~~
# While this provides a good overview, it does not fully reflect how these algorithms perform on the different cohorts.
fig, ax = plt.subplots()
sns.boxplot(
data=results_long, x="cohort", y="f1_score", hue="algo_with_version", ax=ax
)
fig.show()

perf_metrics_per_cohort = (
results.groupby(["cohort", "algo", "version"])
.apply(apply_aggregations, custom_aggs)
.pipe(format_results)
.loc[cohort_order]
)
perf_metrics_per_cohort


# %%
# Only relevant algorithms
# ~~~~~~~~~~~~~~~~~~~~~~~~
# Finally, we present a comparison of the old and new implementations of IcdIonescu. IcdShinImproved and
# IcdHKLeeImproved are excluded because they are cadence algorithms and we don't calculate ICs with these algos in
# the old Matlab implementation.
Ionescu_results = results_long.query("algo == 'IcdIonescu'")
fig, ax = plt.subplots()
sns.boxplot(
data=Ionescu_results,
x="cohort",
y="f1_score",
hue="algo_with_version",
ax=ax,
)
fig.show()

final_perf_metrics = perf_metrics_per_cohort.query(
"algo == 'IcdIonescu'"
).reset_index(level="algo", drop=True)

final_perf_metrics
320 changes: 320 additions & 0 deletions revalidation/initial_contacts/_02_icd_result_generation_no_exc.py
@@ -0,0 +1,320 @@
"""
.. _icd_val_gen:

Revalidation of the initial contact detection algorithms
========================================================

.. note:: This is the code to create the results! If you are interested in viewing the results, please check the
:ref:`results report <icd_val_results>`.

This script reproduces the validation results on the TVS dataset for the initial contact detection algorithms.
It loads results from the old matlab algorithms and runs the new algorithms on the same data.
Performance metrics are calculated on a per-trial/per-recording basis and aggregated (mean for most metrics)
over the whole dataset.
The raw detected initial contacts and all performance metrics are saved to disk.

.. warning:: Before you modify and re-run this script, read through our guide on :ref:`revalidation`.
In case you are planning to update the official results (either after a code change, or because an algorithm was
added), contact one of the core maintainers.
They can assist with the process.
"""

# %%
# Setting up "Dummy Algorithms" to load the old results
# -----------------------------------------------------
# Instead of just loading the results of the old matlab algorithms, we create dummy algorithms that respond with the
# precomputed results per trial.
# This way, we can be sure that the exact same structure is used for the evaluation of all algorithms.
#
# Note that this is not the most efficient way to do this, as we need to open the file repeatedly and also reload
# the data from the matlab files, even though the dummy algorithm does not need it.
from pathlib import Path
from typing import Any, Literal, Optional

import pandas as pd
from mobgap.initial_contacts.base import BaseIcDetector
from mobgap.pipeline import Region
from tpcp.caching import hybrid_cache
from typing_extensions import Self, Unpack


def load_old_icd_results(result_file_path: Path) -> pd.DataFrame:
assert result_file_path.exists(), result_file_path
data = pd.read_csv(result_file_path).astype({"participant_id": str})
data = data.set_index(data.columns[:-4].to_list())
return data


class DummyIcdAlgo(BaseIcDetector):
"""A dummy algorithm that responds with the precomputed results of the old pipeline.
This makes it convenient to compare the results of the old pipeline with the new pipeline, as we can simply use
the same code to evaluate both.
However, this also makes things a lot slower compared to just loading all the results, as we need to open the
file repeatedly and also reload the data from the matlab files, even though the dummy algorithm does not need it.
Parameters
----------
old_algo_name
Name of the algorithm for which we want to load the results.
This determines the name of the file to load.
base_result_folder
Base folder where the results are stored.
"""

def __init__(self, old_algo_name: str, base_result_folder: Path) -> None:
self.old_algo_name = old_algo_name
self.base_result_folder = base_result_folder

def detect(
self,
data: pd.DataFrame,
*,
sampling_rate_hz: float,
measurement_condition: Optional[
Literal["free_living", "laboratory"]
] = None,
dp_group: Optional[tuple[str, ...]] = None,
current_gs: Region = None,
**_: Unpack[dict[str, Any]],
) -> Self:
""" "Run" the algorithm."""
assert (
measurement_condition is not None
), "measurement_condition must be provided"
assert dp_group is not None, "dp_group must be provided"

cached_load_old_icd_results = hybrid_cache(lru_cache_maxsize=1)(
load_old_icd_results
)

all_results = cached_load_old_icd_results(
self.base_result_folder
/ measurement_condition
/ f"{self.old_algo_name}.csv"
)

unique_label = dp_group[:-2]
gs_start = current_gs.start
try:
ic_results = (
all_results.loc[unique_label]
.query(
"start == @gs_start | start == @gs_start + 1 | start == @gs_start - 1"
)
.copy()
)
ic = ic_results["ic_list_rel_to_wb"].apply(pd.eval).iloc[0]
ic_list = pd.DataFrame(
{"ic": ic},
)
ic_list.index.names = ["step_id"]
except Exception:
# returns an empty dataframe for all exceptions (missing algo result for the data or missing data)
ic_list = pd.DataFrame(columns=["ic"])

self.ic_list_ = ic_list
return self


# %%
# Setting up the algorithms
# -------------------------
# We use the :class:`~mobgap.initial_contacts.pipeline.IcdEmulationPipeline` to run the algorithms.
# We create an instance of this pipeline for each algorithm we want to evaluate and store them in a dictionary.
# The key is used to identify the algorithm in the results and used as folder name to store the results.
#
# .. note:: Set up your environment variables to point to the correct paths.
# The easiest way to do this is to create a `.env` file in the root of the repository with the following content.
# You need the paths to the root folder of the TVS dataset `MOBGAP_TVS_DATASET_PATH` and the path where revalidation
# results should be stored `MOBGAP_VALIDATION_DATA_PATH`.
# The path to the cache directory `MOBGAP_CACHE_DIR_PATH` is optional, when you don't want to store the memory cache
# in the default location.
from mobgap.initial_contacts.pipeline import IcdEmulationPipeline
from mobgap.utils.misc import get_env_var

matlab_algo_result_path = (
Path(get_env_var("MOBGAP_VALIDATION_DATA_PATH")) / "_extracted_results/icd"
)

pipelines = {}
for matlab_algo_name in [
"Ani_McCamley",
]:
pipelines[f"matlab_{matlab_algo_name}"] = IcdEmulationPipeline(
DummyIcdAlgo(
matlab_algo_name, base_result_folder=matlab_algo_result_path
)
)

# %%
# The reimplemented algorithms:
from mobgap.initial_contacts import (
IcdHKLeeImproved,
IcdIonescu,
IcdShinImproved,
)

pipelines["IcdIonescu"] = IcdEmulationPipeline(IcdIonescu())
pipelines["IcdShinImproved"] = IcdEmulationPipeline(IcdShinImproved())
pipelines["IcdHKLeeImproved"] = IcdEmulationPipeline(IcdHKLeeImproved())


# %%
# Setting up the dataset
# ----------------------
# We run the comparison on the Lab and the Free-Living part of the TVS dataset.
# We use the :class:`~mobgap.data.TVSFreeLivingDataset` and the :class:`~mobgap.data.TVSLabDataset` to load the data.
# Note that we use Memory caching to speed up the loading of the data.
# We also skip the recordings where the reference data is missing.
# In both cases, we compare against the INDIP reference system as done in the original validation as well.
#
# In the evaluation, each row of the dataset is treated as a separate recording.
# Results are calculated per recording.
# Aggregated results are calculated over the whole dataset, without considering the content of the individual
# recordings.
# Depending on how you want to interpret the results, you might not want to use the aggregated results, but rather
# perform custom aggregations over the provided "single_results".
from joblib import Memory
from mobgap import PACKAGE_ROOT
from mobgap.data import TVSFreeLivingDataset, TVSLabDataset

cache_dir = Path(
get_env_var("MOBGAP_CACHE_DIR_PATH", PACKAGE_ROOT.parent / ".cache")
)

datasets_free_living = TVSFreeLivingDataset(
get_env_var("MOBGAP_TVS_DATASET_PATH"),
reference_system="INDIP",
memory=Memory(cache_dir),
missing_reference_error_type="skip",
)
datasets_laboratory = TVSLabDataset(
get_env_var("MOBGAP_TVS_DATASET_PATH"),
reference_system="INDIP",
memory=Memory(cache_dir),
missing_reference_error_type="skip",
)


# %%
# Running the evaluation
# ----------------------
# We multiprocess the evaluation on the level of algorithms using joblib.
# Each algorithm pipeline is run using its own instance of the :class:`~mobgap.evaluation.Evaluation` class.
#
# The evaluation object iterates over the entire dataset, runs the algorithm on each recording and calculates the
# score using the :func:`~mobgap.initial_contacts._evaluation_scorer.icd_score` function.
import matplotlib.pyplot as plt
import seaborn as sns
from joblib import Parallel, delayed
from mobgap.initial_contacts.evaluation import icd_score
from mobgap.utils.evaluation import Evaluation

n_jobs = int(get_env_var("MOBGAP_N_JOBS", 3))
results_base_path = (
Path(get_env_var("MOBGAP_VALIDATION_DATA_PATH")) / "results/icd"
)


def run_evaluation(name, pipeline, ds):
eval_pipe = Evaluation(
ds,
scoring=icd_score,
).run(pipeline)
return name, eval_pipe


def eval_debug_plot(
results: dict[str, Evaluation[IcdEmulationPipeline]],
) -> None:
results_df = (
pd.concat({k: v.get_single_results_as_df() for k, v in results.items()})
.reset_index()
.rename(columns={"level_0": "algo_name"})
)

metrics = [
"precision",
"recall",
"f1_score",
"tp_absolute_timing_error_s",
"tp_relative_timing_error",
]
fig, axes = plt.subplots(2, 3, figsize=(18, 6))

for ax, metric in zip(axes.flatten(), metrics):
sns.boxplot(
data=results_df,
x="cohort",
y=metric,
hue="algo_name",
ax=ax,
showmeans=True,
)
ax.set_title(metric)

plt.tight_layout()
plt.show()


# %%
# Free-Living
# ~~~~~~~~~~~
# Let's start with the Free-Living part of the dataset.

with Parallel(n_jobs=n_jobs) as parallel:
results_free_living: dict[str, Evaluation[IcdEmulationPipeline]] = dict(
parallel(
delayed(run_evaluation)(name, pipeline, datasets_free_living)
for name, pipeline in pipelines.items()
)
)

# %%
# We create a quick plot for debugging.
# This is not meant to be a comprehensive analysis, but rather a quick check to see if the results are as expected.
eval_debug_plot(results_free_living)

# %%
# Then we save the results to disk.
from mobgap.utils.evaluation import save_evaluation_results

for k, v in results_free_living.items():
save_evaluation_results(
k,
v,
condition="free_living",
base_path=results_base_path,
raw_result_filter=["detected"],
)


# %%
# Laboratory
# ~~~~~~~~~~
# Now, we repeat the evaluation for the Laboratory part of the dataset.
with Parallel(n_jobs=n_jobs) as parallel:
results_laboratory: dict[str, Evaluation[IcdEmulationPipeline]] = dict(
parallel(
delayed(run_evaluation)(name, pipeline, datasets_laboratory)
for name, pipeline in pipelines.items()
)
)

# %%
# We create a quick plot for debugging.
eval_debug_plot(results_laboratory)

# %%
# Then we save the results to disk.
for k, v in results_laboratory.items():
save_evaluation_results(
k,
v,
condition="laboratory",
base_path=results_base_path,
raw_result_filter=["detected"],
)
File renamed without changes.
@@ -61,6 +61,14 @@
"name":"n_turns_reference",
"type":"integer"
},
{
"name":"orig_index_detected",
"type":"string"
},
{
"name":"orig_index_reference",
"type":"string"
},
{
"name":"start_detected",
"type":"integer"
@@ -92,14 +100,6 @@
{
"name":"walking_speed_mps_reference",
"type":"number"
},
{
"name":"wb_id_detected",
"type":"integer"
},
{
"name":"wb_id_reference",
"type":"integer"
}
],
"primaryKey":[
@@ -124,16 +124,26 @@
"n_steps_reference":8,
"n_turns_detected":1,
"n_turns_reference":0,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
0
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
0
],
"start_detected":0,
"start_reference":0,
"stride_duration_s_detected":2.3848070177,
"stride_duration_s_reference":1.58675,
"stride_length_m_detected":2.9073425098,
"stride_length_m_reference":2.51885,
"walking_speed_mps_detected":2.0616591597,
"walking_speed_mps_reference":1.16079,
"wb_id_detected":0,
"wb_id_reference":0
"walking_speed_mps_reference":1.16079
},
{
"index":1,
@@ -151,16 +161,26 @@
"n_steps_reference":12,
"n_turns_detected":3,
"n_turns_reference":2,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
2
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
2
],
"start_detected":20,
"start_reference":20,
"stride_duration_s_detected":3.472869224,
"stride_duration_s_reference":2.56092,
"stride_length_m_detected":2.572194542,
"stride_length_m_reference":1.66305,
"walking_speed_mps_detected":2.2559309823,
"walking_speed_mps_reference":1.60044,
"wb_id_detected":2,
"wb_id_reference":2
"walking_speed_mps_reference":1.60044
},
{
"index":2,
@@ -178,16 +198,26 @@
"n_steps_reference":8,
"n_turns_detected":0,
"n_turns_reference":0,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
4
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
4
],
"start_detected":40,
"start_reference":40,
"stride_duration_s_detected":3.3131197829,
"stride_duration_s_reference":2.35295,
"stride_length_m_detected":2.3382232273,
"stride_length_m_reference":1.95969,
"walking_speed_mps_detected":2.5481554328,
"walking_speed_mps_reference":2.3323,
"wb_id_detected":4,
"wb_id_reference":4
"walking_speed_mps_reference":2.3323
},
{
"index":3,
@@ -205,16 +235,26 @@
"n_steps_reference":17,
"n_turns_detected":1,
"n_turns_reference":1,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
6
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
6
],
"start_detected":60,
"start_reference":60,
"stride_duration_s_detected":2.6274958317,
"stride_duration_s_reference":2.00582,
"stride_length_m_detected":3.30913068,
"stride_length_m_reference":2.41782,
"walking_speed_mps_detected":2.7831202734,
"walking_speed_mps_reference":2.09538,
"wb_id_detected":6,
"wb_id_reference":6
"walking_speed_mps_reference":2.09538
}
]
}
@@ -22,84 +22,84 @@
"type":"integer"
},
{
"name":"c_a_d_e_n_c_e___s_p_m___d_e_t_e_c_t_e_d",
"name":"cadence_spm_detected",
"type":"number"
},
{
"name":"c_a_d_e_n_c_e___s_p_m___r_e_f_e_r_e_n_c_e",
"name":"cadence_spm_reference",
"type":"number"
},
{
"name":"d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d",
"name":"duration_s_detected",
"type":"number"
},
{
"name":"d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e",
"name":"duration_s_reference",
"type":"number"
},
{
"name":"e_n_d___d_e_t_e_c_t_e_d",
"name":"end_detected",
"type":"integer"
},
{
"name":"e_n_d___r_e_f_e_r_e_n_c_e",
"name":"end_reference",
"type":"integer"
},
{
"name":"n___s_t_e_p_s___d_e_t_e_c_t_e_d",
"name":"n_steps_detected",
"type":"integer"
},
{
"name":"n___s_t_e_p_s___r_e_f_e_r_e_n_c_e",
"name":"n_steps_reference",
"type":"integer"
},
{
"name":"n___t_u_r_n_s___d_e_t_e_c_t_e_d",
"name":"n_turns_detected",
"type":"integer"
},
{
"name":"n___t_u_r_n_s___r_e_f_e_r_e_n_c_e",
"name":"n_turns_reference",
"type":"integer"
},
{
"name":"s_t_a_r_t___d_e_t_e_c_t_e_d",
"type":"integer"
"name":"orig_index_detected",
"type":"string"
},
{
"name":"s_t_a_r_t___r_e_f_e_r_e_n_c_e",
"type":"integer"
"name":"orig_index_reference",
"type":"string"
},
{
"name":"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d",
"type":"number"
"name":"start_detected",
"type":"integer"
},
{
"name":"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e",
"type":"number"
"name":"start_reference",
"type":"integer"
},
{
"name":"s_t_r_i_d_e___l_e_n_g_t_h___m___d_e_t_e_c_t_e_d",
"name":"stride_duration_s_detected",
"type":"number"
},
{
"name":"s_t_r_i_d_e___l_e_n_g_t_h___m___r_e_f_e_r_e_n_c_e",
"name":"stride_duration_s_reference",
"type":"number"
},
{
"name":"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___d_e_t_e_c_t_e_d",
"name":"stride_length_m_detected",
"type":"number"
},
{
"name":"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___r_e_f_e_r_e_n_c_e",
"name":"stride_length_m_reference",
"type":"number"
},
{
"name":"w_b___i_d___d_e_t_e_c_t_e_d",
"type":"integer"
"name":"walking_speed_mps_detected",
"type":"number"
},
{
"name":"w_b___i_d___r_e_f_e_r_e_n_c_e",
"type":"integer"
"name":"walking_speed_mps_reference",
"type":"number"
}
],
"primaryKey":[
@@ -114,107 +114,147 @@
"participant_id":12345,
"measurement_date":"2023-01-01",
"match_id":0,
"c_a_d_e_n_c_e___s_p_m___d_e_t_e_c_t_e_d":100.2324317085,
"c_a_d_e_n_c_e___s_p_m___r_e_f_e_r_e_n_c_e":99.82188,
"d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d":5.1303145487,
"d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e":4.74702,
"e_n_d___d_e_t_e_c_t_e_d":5,
"e_n_d___r_e_f_e_r_e_n_c_e":4,
"n___s_t_e_p_s___d_e_t_e_c_t_e_d":8,
"n___s_t_e_p_s___r_e_f_e_r_e_n_c_e":8,
"n___t_u_r_n_s___d_e_t_e_c_t_e_d":1,
"n___t_u_r_n_s___r_e_f_e_r_e_n_c_e":0,
"s_t_a_r_t___d_e_t_e_c_t_e_d":0,
"s_t_a_r_t___r_e_f_e_r_e_n_c_e":0,
"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d":2.3848070177,
"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e":1.58675,
"s_t_r_i_d_e___l_e_n_g_t_h___m___d_e_t_e_c_t_e_d":2.9073425098,
"s_t_r_i_d_e___l_e_n_g_t_h___m___r_e_f_e_r_e_n_c_e":2.51885,
"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___d_e_t_e_c_t_e_d":2.0616591597,
"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___r_e_f_e_r_e_n_c_e":1.16079,
"w_b___i_d___d_e_t_e_c_t_e_d":0,
"w_b___i_d___r_e_f_e_r_e_n_c_e":0
"cadence_spm_detected":100.2324317085,
"cadence_spm_reference":99.82188,
"duration_s_detected":5.1303145487,
"duration_s_reference":4.74702,
"end_detected":5,
"end_reference":4,
"n_steps_detected":8,
"n_steps_reference":8,
"n_turns_detected":1,
"n_turns_reference":0,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
0
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
0
],
"start_detected":0,
"start_reference":0,
"stride_duration_s_detected":2.3848070177,
"stride_duration_s_reference":1.58675,
"stride_length_m_detected":2.9073425098,
"stride_length_m_reference":2.51885,
"walking_speed_mps_detected":2.0616591597,
"walking_speed_mps_reference":1.16079
},
{
"index":1,
"visit_type":"T1",
"participant_id":12345,
"measurement_date":"2023-01-01",
"match_id":2,
"c_a_d_e_n_c_e___s_p_m___d_e_t_e_c_t_e_d":87.4843290301,
"c_a_d_e_n_c_e___s_p_m___r_e_f_e_r_e_n_c_e":86.53527,
"d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d":9.1405757466,
"d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e":8.52727,
"e_n_d___d_e_t_e_c_t_e_d":25,
"e_n_d___r_e_f_e_r_e_n_c_e":24,
"n___s_t_e_p_s___d_e_t_e_c_t_e_d":12,
"n___s_t_e_p_s___r_e_f_e_r_e_n_c_e":12,
"n___t_u_r_n_s___d_e_t_e_c_t_e_d":3,
"n___t_u_r_n_s___r_e_f_e_r_e_n_c_e":2,
"s_t_a_r_t___d_e_t_e_c_t_e_d":20,
"s_t_a_r_t___r_e_f_e_r_e_n_c_e":20,
"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d":3.472869224,
"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e":2.56092,
"s_t_r_i_d_e___l_e_n_g_t_h___m___d_e_t_e_c_t_e_d":2.572194542,
"s_t_r_i_d_e___l_e_n_g_t_h___m___r_e_f_e_r_e_n_c_e":1.66305,
"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___d_e_t_e_c_t_e_d":2.2559309823,
"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___r_e_f_e_r_e_n_c_e":1.60044,
"w_b___i_d___d_e_t_e_c_t_e_d":2,
"w_b___i_d___r_e_f_e_r_e_n_c_e":2
"cadence_spm_detected":87.4843290301,
"cadence_spm_reference":86.53527,
"duration_s_detected":9.1405757466,
"duration_s_reference":8.52727,
"end_detected":25,
"end_reference":24,
"n_steps_detected":12,
"n_steps_reference":12,
"n_turns_detected":3,
"n_turns_reference":2,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
2
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
2
],
"start_detected":20,
"start_reference":20,
"stride_duration_s_detected":3.472869224,
"stride_duration_s_reference":2.56092,
"stride_length_m_detected":2.572194542,
"stride_length_m_reference":1.66305,
"walking_speed_mps_detected":2.2559309823,
"walking_speed_mps_reference":1.60044
},
{
"index":2,
"visit_type":"T1",
"participant_id":12345,
"measurement_date":"2023-01-01",
"match_id":4,
"c_a_d_e_n_c_e___s_p_m___d_e_t_e_c_t_e_d":93.9889414049,
"c_a_d_e_n_c_e___s_p_m___r_e_f_e_r_e_n_c_e":93.69895,
"d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d":6.2172281615,
"d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e":6.09907,
"e_n_d___d_e_t_e_c_t_e_d":45,
"e_n_d___r_e_f_e_r_e_n_c_e":44,
"n___s_t_e_p_s___d_e_t_e_c_t_e_d":9,
"n___s_t_e_p_s___r_e_f_e_r_e_n_c_e":8,
"n___t_u_r_n_s___d_e_t_e_c_t_e_d":0,
"n___t_u_r_n_s___r_e_f_e_r_e_n_c_e":0,
"s_t_a_r_t___d_e_t_e_c_t_e_d":40,
"s_t_a_r_t___r_e_f_e_r_e_n_c_e":40,
"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d":3.3131197829,
"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e":2.35295,
"s_t_r_i_d_e___l_e_n_g_t_h___m___d_e_t_e_c_t_e_d":2.3382232273,
"s_t_r_i_d_e___l_e_n_g_t_h___m___r_e_f_e_r_e_n_c_e":1.95969,
"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___d_e_t_e_c_t_e_d":2.5481554328,
"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___r_e_f_e_r_e_n_c_e":2.3323,
"w_b___i_d___d_e_t_e_c_t_e_d":4,
"w_b___i_d___r_e_f_e_r_e_n_c_e":4
"cadence_spm_detected":93.9889414049,
"cadence_spm_reference":93.69895,
"duration_s_detected":6.2172281615,
"duration_s_reference":6.09907,
"end_detected":45,
"end_reference":44,
"n_steps_detected":9,
"n_steps_reference":8,
"n_turns_detected":0,
"n_turns_reference":0,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
4
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
4
],
"start_detected":40,
"start_reference":40,
"stride_duration_s_detected":3.3131197829,
"stride_duration_s_reference":2.35295,
"stride_length_m_detected":2.3382232273,
"stride_length_m_reference":1.95969,
"walking_speed_mps_detected":2.5481554328,
"walking_speed_mps_reference":2.3323
},
{
"index":3,
"visit_type":"T1",
"participant_id":12345,
"measurement_date":"2023-01-01",
"match_id":6,
"c_a_d_e_n_c_e___s_p_m___d_e_t_e_c_t_e_d":87.4288802029,
"c_a_d_e_n_c_e___s_p_m___r_e_f_e_r_e_n_c_e":87.09312,
"d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d":14.0346487047,
"d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e":13.5208,
"e_n_d___d_e_t_e_c_t_e_d":65,
"e_n_d___r_e_f_e_r_e_n_c_e":64,
"n___s_t_e_p_s___d_e_t_e_c_t_e_d":17,
"n___s_t_e_p_s___r_e_f_e_r_e_n_c_e":17,
"n___t_u_r_n_s___d_e_t_e_c_t_e_d":1,
"n___t_u_r_n_s___r_e_f_e_r_e_n_c_e":1,
"s_t_a_r_t___d_e_t_e_c_t_e_d":60,
"s_t_a_r_t___r_e_f_e_r_e_n_c_e":60,
"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___d_e_t_e_c_t_e_d":2.6274958317,
"s_t_r_i_d_e___d_u_r_a_t_i_o_n___s___r_e_f_e_r_e_n_c_e":2.00582,
"s_t_r_i_d_e___l_e_n_g_t_h___m___d_e_t_e_c_t_e_d":3.30913068,
"s_t_r_i_d_e___l_e_n_g_t_h___m___r_e_f_e_r_e_n_c_e":2.41782,
"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___d_e_t_e_c_t_e_d":2.7831202734,
"w_a_l_k_i_n_g___s_p_e_e_d___m_p_s___r_e_f_e_r_e_n_c_e":2.09538,
"w_b___i_d___d_e_t_e_c_t_e_d":6,
"w_b___i_d___r_e_f_e_r_e_n_c_e":6
"cadence_spm_detected":87.4288802029,
"cadence_spm_reference":87.09312,
"duration_s_detected":14.0346487047,
"duration_s_reference":13.5208,
"end_detected":65,
"end_reference":64,
"n_steps_detected":17,
"n_steps_reference":17,
"n_turns_detected":1,
"n_turns_reference":1,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
6
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
6
],
"start_detected":60,
"start_reference":60,
"stride_duration_s_detected":2.6274958317,
"stride_duration_s_reference":2.00582,
"stride_length_m_detected":3.30913068,
"stride_length_m_reference":2.41782,
"walking_speed_mps_detected":2.7831202734,
"walking_speed_mps_reference":2.09538
}
]
}
@@ -57,6 +57,14 @@
"name":"n_turns_reference",
"type":"integer"
},
{
"name":"orig_index_detected",
"type":"string"
},
{
"name":"orig_index_reference",
"type":"string"
},
{
"name":"start_detected",
"type":"integer"
@@ -89,14 +97,6 @@
"name":"walking_speed_mps_reference",
"type":"number"
},
{
"name":"wb_id_detected",
"type":"integer"
},
{
"name":"wb_id_reference",
"type":"integer"
},
{
"name":"cadence_spm_error",
"type":"number"
@@ -234,6 +234,18 @@
"n_steps_reference":8,
"n_turns_detected":1,
"n_turns_reference":0,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
0
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
0
],
"start_detected":0,
"start_reference":0,
"stride_duration_s_detected":2.3848070177,
@@ -242,8 +254,6 @@
"stride_length_m_reference":2.51885,
"walking_speed_mps_detected":2.0616591597,
"walking_speed_mps_reference":1.16079,
"wb_id_detected":0,
"wb_id_reference":0,
"cadence_spm_error":0.4105517085,
"cadence_spm_rel_error":0.0041128429,
"cadence_spm_abs_error":0.4105517085,
@@ -288,6 +298,18 @@
"n_steps_reference":12,
"n_turns_detected":3,
"n_turns_reference":2,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
2
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
2
],
"start_detected":20,
"start_reference":20,
"stride_duration_s_detected":3.472869224,
@@ -296,8 +318,6 @@
"stride_length_m_reference":1.66305,
"walking_speed_mps_detected":2.2559309823,
"walking_speed_mps_reference":1.60044,
"wb_id_detected":2,
"wb_id_reference":2,
"cadence_spm_error":0.9490590301,
"cadence_spm_rel_error":0.0109673088,
"cadence_spm_abs_error":0.9490590301,
@@ -342,6 +362,18 @@
"n_steps_reference":8,
"n_turns_detected":0,
"n_turns_reference":0,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
4
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
4
],
"start_detected":40,
"start_reference":40,
"stride_duration_s_detected":3.3131197829,
@@ -350,8 +382,6 @@
"stride_length_m_reference":1.95969,
"walking_speed_mps_detected":2.5481554328,
"walking_speed_mps_reference":2.3323,
"wb_id_detected":4,
"wb_id_reference":4,
"cadence_spm_error":0.2899914049,
"cadence_spm_rel_error":0.0030949269,
"cadence_spm_abs_error":0.2899914049,
@@ -396,6 +426,18 @@
"n_steps_reference":17,
"n_turns_detected":1,
"n_turns_reference":1,
"orig_index_detected":[
"T1",
12345,
"2023-01-01",
6
],
"orig_index_reference":[
"T1",
12345,
"2023-01-01",
6
],
"start_detected":60,
"start_reference":60,
"stride_duration_s_detected":2.6274958317,
@@ -404,8 +446,6 @@
"stride_length_m_reference":2.41782,
"walking_speed_mps_detected":2.7831202734,
"walking_speed_mps_reference":2.09538,
"wb_id_detected":6,
"wb_id_reference":6,
"cadence_spm_error":0.3357602029,
"cadence_spm_rel_error":0.0038551863,
"cadence_spm_abs_error":0.3357602029,
5 changes: 1 addition & 4 deletions tests/test_examples/test_pipeline_examples.py
@@ -75,13 +75,10 @@ def test_dmo_evaluation_on_wb_level(snapshot):
# flatten multiindex columns as they are not supported by snapshot
wb_matches.columns = ["_".join(pair) for pair in wb_matches.columns]
snapshot.assert_match(wb_matches.reset_index(), "det_ref_daily")
snapshot.assert_match(wb_matches.reset_index(), "wb_matches")

snapshot.assert_match(wb_tp_fp_fn, "wb_tp_fp_fn")

# flatten multiindex columns as they are not supported by snapshot
wb_matches.columns = ["_".join(pair) for pair in wb_matches.columns]
snapshot.assert_match(wb_matches.reset_index(), "wb_matches")

# flatten multiindex columns as they are not supported by snapshot
wb_errors.columns = ["_".join(pair) for pair in wb_errors.columns]
snapshot.assert_match(wb_errors, "wb_errors")
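The duplicate flattening removed above is also what produced the letterspaced column names in the old snapshots (e.g. `c_a_d_e_n_c_e___s_p_m___d_e_t_e_c_t_e_d`): the first `"_".join(pair)` pass collapses the MultiIndex tuples into plain strings, so a second pass iterates each name character by character. A minimal sketch of the effect, using a made-up column for illustration:

```python
import pandas as pd

# One metric column with a (name, origin) MultiIndex, as in the matcher output.
df = pd.DataFrame(
    [[1.0]],
    columns=pd.MultiIndex.from_tuples([("cadence_spm", "detected")]),
)

# First pass: each `pair` is a tuple -> "cadence_spm_detected" (correct).
df.columns = ["_".join(pair) for pair in df.columns]
print(df.columns[0])  # cadence_spm_detected

# Second pass: each `pair` is now a str, so join runs over its characters.
df.columns = ["_".join(pair) for pair in df.columns]
print(df.columns[0])  # c_a_d_e_n_c_e___s_p_m___d_e_t_e_c_t_e_d
```

Dropping the repeated block means each frame is flattened exactly once, which is why the letterspaced names disappear from the regenerated snapshot files.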

This file was deleted.

@@ -0,0 +1,65 @@
{
"schema":{
"fields":[
{
"name":"index",
"type":"integer"
},
{
"name":"metric_a_detected",
"type":"integer"
},
{
"name":"metric_a_reference",
"type":"integer"
},
{
"name":"metric_b_detected",
"type":"integer"
},
{
"name":"metric_b_reference",
"type":"integer"
},
{
"name":"metric_c_detected",
"type":"integer"
},
{
"name":"metric_c_reference",
"type":"integer"
},
{
"name":"orig_index_detected",
"type":"string"
},
{
"name":"orig_index_reference",
"type":"string"
}
],
"primaryKey":[
"index"
],
"pandas_version":"1.4.0"
},
"data":[
{
"index":0,
"metric_a_detected":1,
"metric_a_reference":4,
"metric_b_detected":2,
"metric_b_reference":5,
"metric_c_detected":3,
"metric_c_reference":6,
"orig_index_detected":[
"a",
1
],
"orig_index_reference":[
"a",
2
]
}
]
}
18 changes: 11 additions & 7 deletions tests/test_gait_sequences/test_gsd_evaluation.py
@@ -627,17 +627,18 @@ def test_get_matching_gs_empty_df(self, dmo_df, matches_df):
)

def test_get_matching_gs_no_matches(self, dmo_df, matches_df):
matches_df["match_type"] = "fp" * len(matches_df)
matches_df["match_type"] = "fp"
combined = get_matching_intervals(
metrics_detected=dmo_df,
metrics_reference=dmo_df,
matches=pd.DataFrame(columns=matches_df.columns),
)
assert combined.empty
assert_array_equal(
list(combined.columns[:-2].to_numpy()), list(product(dmo_df.columns, ["detected", "reference"]))
)
assert_array_equal(list(combined.columns[-2:].to_numpy()), list(product(["wb_id"], ["detected", "reference"])))
cols = np.array(combined.columns.to_list())
cols.sort()
expected_cols = np.array(list(product([*dmo_df.columns, "orig_index"], ["detected", "reference"])))
expected_cols.sort()
assert_array_equal(cols, expected_cols)

def test_get_matching_gs_invalid_matches(self, dmo_df, matches_df):
with pytest.raises(TypeError):
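Two small fixes in the hunk above are easy to miss. First, `matches_df["match_type"] = "fp" * len(matches_df)` relied on string repetition: `"fp" * 3` is `"fpfpfp"`, and pandas broadcasts that single scalar into every row, so no row ever contained a plain `"fp"`; assigning the scalar `"fp"` directly gives the intended per-row value. Second, the column assertion now sorts both sides and expects `orig_index` instead of a trailing `wb_id` pair, making the check order-independent. A quick sketch of the broadcasting point, on a throwaway frame:

```python
import pandas as pd

df = pd.DataFrame({"match_type": ["tp", "tp", "tp"]})

# Old pattern: string repetition, then scalar broadcast -> nonsense values.
df["match_type"] = "fp" * len(df)
print(df["match_type"].tolist())  # ['fpfpfp', 'fpfpfp', 'fpfpfp']

# Fixed pattern: broadcast the plain scalar to every row.
df["match_type"] = "fp"
print(df["match_type"].tolist())  # ['fp', 'fp', 'fp']
```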
@@ -673,6 +674,9 @@ def test_get_matching_gs(self, snapshot, dmo_df, matches_df):
assert_array_equal(combined.index, matches_df.query("match_type == 'tp'").index)
assert_array_equal(
list(combined.columns.to_numpy()),
list(product(dmo_df.columns, ["detected", "reference"])) + [("wb_id", "detected"), ("wb_id", "reference")],
list(product(dmo_df.columns, ["detected", "reference"]))
+ [("orig_index", "detected"), ("orig_index", "reference")],
)
snapshot.assert_match(combined.to_numpy()[0], "combined")
# We combine the columns for the snapshot test
combined.columns = ["_".join(col) for col in combined.columns]
snapshot.assert_match(combined, "combined")