Skip to content

Commit c4e4ce6

Browse files
authored
ICD revalidation (#196)
2 parents 1667b09 + d9493fe commit c4e4ce6

19 files changed

+1670
-256
lines changed

docs/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def substitute(matchobj) -> str:
213213
"../examples/data_transform",
214214
"../examples/dev_guides",
215215
"../revalidation/gait_sequences",
216+
"../revalidation/initial_contacts",
216217
"../revalidation/stride_length",
217218
]
218219
),

docs/modules/initial_contacts.rst

+32
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ Algorithms
1818
IcdIonescu
1919
IcdHKLeeImproved
2020

21+
Pipelines
22+
+++++++++
23+
.. currentmodule:: mobgap.initial_contacts
24+
25+
.. autosummary::
26+
:toctree: generated/initial_contacts
27+
:template: class.rst
28+
29+
pipeline.IcdEmulationPipeline
30+
2131
Utils
2232
+++++
2333
.. currentmodule:: mobgap.initial_contacts
@@ -37,7 +47,29 @@ Evaluation
3747
:template: func.rst
3848

3949
calculate_matched_icd_performance_metrics
50+
calculate_true_positive_icd_error
4051
categorize_ic_list
52+
get_matching_ics
53+
54+
Evaluation Scores
55+
+++++++++++++++++
56+
These scores are expected to be used in combination with :class:`~mobgap.utils.evaluation.Evaluation` and
57+
:class:`~mobgap.utils.evaluation.EvaluationCV` or directly with :func:`~tpcp.validation.cross_validation` and
58+
:func:`~tpcp.validation.validation`.
59+
60+
.. currentmodule:: mobgap.initial_contacts.evaluation
61+
62+
.. autosummary::
63+
:toctree: generated/initial_contacts
64+
65+
icd_score
66+
67+
.. autosummary::
68+
:toctree: generated/initial_contacts
69+
:template: func.rst
70+
71+
icd_per_datapoint_score
72+
icd_final_agg
4173

4274
Base Classes
4375
++++++++++++

mobgap/gait_sequences/evaluation.py

+5-43
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
)
2020
from mobgap.utils.evaluation import (
2121
accuracy_score,
22+
combine_detected_and_reference_metrics,
2223
count_samples_in_intervals,
2324
count_samples_in_match_intervals,
25+
extract_tp_matches,
2426
npv_score,
2527
precision_recall_f1_score,
2628
specificity_score,
@@ -721,10 +723,10 @@ def get_matching_intervals(
721723

722724
tp_matches = matches.query("match_type == 'tp'")
723725

724-
detected_matches = _extract_tp_matches(metrics_detected, tp_matches["gs_id_detected"])
725-
reference_matches = _extract_tp_matches(metrics_reference, tp_matches["gs_id_reference"])
726+
detected_matches = extract_tp_matches(metrics_detected, tp_matches["gs_id_detected"])
727+
reference_matches = extract_tp_matches(metrics_reference, tp_matches["gs_id_reference"])
726728

727-
combined_matches = _combine_detected_and_reference_metrics(
729+
combined_matches = combine_detected_and_reference_metrics(
728730
detected_matches, reference_matches, tp_matches=tp_matches
729731
)
730732

@@ -748,46 +750,6 @@ def _check_gs_level_matches_sanity(matches: pd.DataFrame) -> pd.DataFrame:
748750
return matches
749751

750752

751-
def _extract_tp_matches(metrics: pd.DataFrame, match_indices: pd.Series) -> pd.DataFrame:
752-
try:
753-
matches = metrics.loc[match_indices]
754-
except KeyError as e:
755-
raise ValueError(
756-
"The indices from the provided `matches` DataFrame do not fit to the metrics DataFrames. "
757-
"Please ensure that the `matches` DataFrame is calculated based on the same data "
758-
"as the `metrics` DataFrames and thus refers to valid indices."
759-
) from e
760-
return matches
761-
762-
763-
def _combine_detected_and_reference_metrics(
764-
detected: pd.DataFrame, reference: pd.DataFrame, tp_matches: Union[pd.DataFrame, None] = None
765-
) -> pd.DataFrame:
766-
# if wb_id in index, add it as a column to preserve it in the combined DataFrame
767-
if "wb_id" in detected.index.names and "wb_id" in reference.index.names:
768-
detected.insert(0, "wb_id", detected.index.get_level_values("wb_id"))
769-
reference.insert(0, "wb_id", reference.index.get_level_values("wb_id"))
770-
771-
common_columns = list(set(reference.columns).intersection(detected.columns))
772-
if len(common_columns) == 0:
773-
raise ValueError("No common columns found in `metrics_detected` and `metrics_reference`.")
774-
775-
detected = detected[common_columns]
776-
reference = reference[common_columns]
777-
778-
if tp_matches is not None:
779-
detected.index = tp_matches.index
780-
reference.index = tp_matches.index
781-
782-
matches = detected.merge(reference, left_index=True, right_index=True, suffixes=("_det", "_ref"))
783-
784-
# construct MultiIndex columns
785-
matches.columns = pd.MultiIndex.from_product([["detected", "reference"], common_columns])
786-
# make 'metrics' level the uppermost level and sort columns accordingly for readability
787-
matches = matches.swaplevel(axis=1).sort_index(axis=1, level=0)
788-
return matches
789-
790-
791753
__all__ = [
792754
"categorize_intervals_per_sample",
793755
"categorize_intervals",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
import warnings
from typing import Any

import pandas as pd
from tpcp.validate import Scorer, no_agg

from mobgap.data.base import BaseGaitDatasetWithReference
from mobgap.initial_contacts.pipeline import IcdEmulationPipeline
8+
9+
10+
def icd_per_datapoint_score(pipeline: IcdEmulationPipeline, datapoint: BaseGaitDatasetWithReference) -> dict:
    """Evaluate the performance of an ICD algorithm on a single datapoint.

    .. warning:: This function is not meant to be called directly, but as a scoring function in a
       :class:`tpcp.validate.Scorer`.
       If you are writing custom scoring functions, you can use this function as a template or wrap it in a new
       function.

    This function is used to evaluate the performance of an ICD algorithm on a single datapoint.
    It calculates the performance metrics based on the detected initial contacts and the reference initial contacts.

    The following performance metrics are calculated:

    - all outputs of :func:`~mobgap.initial_contacts.evaluation.calculate_matched_icd_performance_metrics` and
      :func:`~mobgap.initial_contacts.evaluation.calculate_true_positive_icd_error`
      (will be averaged over all datapoints)
    - ``matches``: The matched initial contacts calculated by
      :func:`~mobgap.initial_contacts.evaluation.categorize_ic_list` (return as ``no_agg``)
    - ``detected``: The detected initial contacts (return as ``no_agg``)
    - ``reference``: The reference initial contacts (return as ``no_agg``)
    - ``tp_ics``: The matched true-positive initial contacts from
      :func:`~mobgap.initial_contacts.evaluation.get_matching_ics` (return as ``no_agg``)
    - ``sampling_rate_hz``: The sampling rate of the data (return as ``no_agg``)
    - ``runtime_s``: The runtime of the pipeline in seconds

    Parameters
    ----------
    pipeline
        An instance of ICD emulation pipeline that wraps the algorithm that should be evaluated.
    datapoint
        The datapoint to be evaluated.

    Returns
    -------
    dict
        A dictionary containing the performance metrics.
        Note, that some results are wrapped in a ``no_agg`` object or other aggregators.
        The results of this function are not expected to be parsed manually, but rather the function is expected to be
        used in the context of the :func:`~tpcp.validate.validate`/:func:`~tpcp.validate.cross_validate` functions or
        similar as scorer.
        This functions will aggregate the results and provide a summary of the performance metrics.

    """
    # Imported locally to avoid a circular import between the evaluation and pipeline modules.
    from mobgap.initial_contacts.evaluation import (
        calculate_matched_icd_performance_metrics,
        calculate_true_positive_icd_error,
        categorize_ic_list,
        get_matching_ics,
    )
    from mobgap.utils.conversions import as_samples
    from mobgap.utils.df_operations import create_multi_groupby

    with warnings.catch_warnings():
        # We know that these errors might happen, and they are usually not relevant for the evaluation
        warnings.filterwarnings("ignore", message="Zero division", category=UserWarning)
        warnings.filterwarnings("ignore", message="multiple ICs", category=UserWarning)

        # Run the algorithm on the datapoint
        pipeline.safe_run(datapoint)
        detected_ic_list = pipeline.ic_list_
        reference_ic_list = datapoint.reference_parameters_.ic_list
        sampling_rate_hz = datapoint.sampling_rate_hz

        # tolerance around the reference ic (this is a centered window - half window in both directions)
        tolerance_s = 0.5
        tolerance_samples = as_samples(tolerance_s, sampling_rate_hz)

        # Categorize each detected IC as tp/fp/fn, matching only within the same WB.
        matches_per_wb = create_multi_groupby(detected_ic_list, reference_ic_list, groupby="wb_id").apply(
            lambda df1, df2: categorize_ic_list(
                ic_list_detected=df1,
                ic_list_reference=df2,
                tolerance_samples=tolerance_samples,
                multiindex_warning=False,
            )
        )
        # If nothing was matched at all, the groupby-apply yields an empty frame WITHOUT the expected
        # columns/index. Rebuild it with the correct schema so the metric functions below don't fail.
        # (Fix: the original compared `.empty == 1`; `.empty` is already a bool.)
        if matches_per_wb.empty:
            matches_per_wb = pd.DataFrame(
                {
                    "ic_id_detected": [],
                    "ic_id_reference": [],
                    "match_type": [],
                    "wb_id": [],
                }
            ).set_index(["wb_id"])

        # calculate run time on pipeline level
        runtime_s = pipeline.perf_["runtime_s"]

        # match initial contacts, get true positives
        tp_ics = get_matching_ics(
            metrics_detected=detected_ic_list,
            metrics_reference=reference_ic_list,
            matches=matches_per_wb,
        )

        # Calculate the performance metrics
        performance_metrics = {
            **calculate_matched_icd_performance_metrics(
                matches_per_wb,
            ),
            **calculate_true_positive_icd_error(
                reference_ic_list,
                tp_ics,
                sampling_rate_hz,
            ),
            "matches": no_agg(matches_per_wb),
            "detected": no_agg(detected_ic_list),
            "reference": no_agg(reference_ic_list),
            "tp_ics": no_agg(tp_ics),
            "sampling_rate_hz": no_agg(sampling_rate_hz),
            "runtime_s": runtime_s,
        }

    return performance_metrics
123+
124+
125+
def icd_final_agg(
    agg_results: dict[str, float],
    single_results: dict[str, list],
    pipeline: IcdEmulationPipeline,  # noqa: ARG001
    dataset: BaseGaitDatasetWithReference,
) -> tuple[dict[str, Any], dict[str, list[Any]]]:
    """Aggregate the performance metrics of an ICD algorithm over multiple datapoints.

    .. warning:: This function is not meant to be called directly, but as ``final_aggregator`` in a
       :class:`tpcp.validate.Scorer`.
       If you are writing custom scoring functions, you can use this function as a template or wrap it in a new
       function.

    This function aggregates the performance metrics as follows:

    - All raw outputs (``detected``, ``reference``, ``sampling_rate_hz``) are concatenated to a single
      dataframe, to make it easier to work with and are returned as part of the single results.
    - We recalculate all performance metrics from
      :func:`~mobgap.initial_contacts.evaluation.calculate_matched_icd_performance_metrics` on the combined data.
      The results are prefixed with ``combined__``.
      Compared to the per-datapoint results (which are calculated, as errors per recording -> average over all
      recordings), these metrics are calculated as combining all ICDs from all recordings and then calculating the
      performance metrics.
      Effectively, this means, that in the `per_datapoint` version, each recording is weighted equally, while in the
      `combined` version, each IC is weighted equally.

    Parameters
    ----------
    agg_results
        The aggregated results from all datapoints (see :class:`~tpcp.validate.Scorer`).
    single_results
        The per-datapoint results (see :class:`~tpcp.validate.Scorer`).
    pipeline
        The pipeline that was passed to the scorer.
        This is ignored in this function, but might be useful in custom final aggregators.
    dataset
        The dataset that was passed to the scorer.

    Returns
    -------
    final_agg_results
        The final aggregated results.
    final_single_results
        The per-datapoint results, that are not aggregated.

    """
    # Imported locally to avoid a circular import between the evaluation and pipeline modules.
    from mobgap.initial_contacts.evaluation import (
        calculate_matched_icd_performance_metrics,
        calculate_true_positive_icd_error,
    )

    data_labels = [d.group_label for d in dataset]
    data_label_names = data_labels[0]._fields
    # We combine each per-datapoint output to a single dataframe, keyed by the datapoint labels.
    matches = single_results.pop("matches")
    matches = pd.concat(matches, keys=data_labels, names=[*data_label_names, *matches[0].index.names])
    detected = single_results.pop("detected")
    detected = pd.concat(detected, keys=data_labels, names=[*data_label_names, *detected[0].index.names])
    reference = single_results.pop("reference")
    reference = pd.concat(reference, keys=data_labels, names=[*data_label_names, *reference[0].index.names])
    tp_ics = single_results.pop("tp_ics")
    tp_ics = pd.concat(tp_ics, keys=data_labels, names=[*data_label_names, *tp_ics[0].index.names])

    aggregated_single_results = {
        "raw__detected": detected,
        "raw__reference": reference,
    }

    sampling_rate_hz = single_results.pop("sampling_rate_hz")
    # The combined error metrics are only meaningful if every recording shares one sampling rate.
    if len(set(sampling_rate_hz)) != 1:
        raise ValueError(
            "Sampling rate is not the same for all datapoints in the dataset. "
            "This not supported by this scorer. "
            "Provide a custom scorer that can handle this case."
        )

    combined_matched = {
        f"combined__{k}": v
        for k, v in {
            **calculate_matched_icd_performance_metrics(matches),
            **calculate_true_positive_icd_error(reference, tp_ics, sampling_rate_hz[0]),
        }.items()
    }

    # Note, that we pass the "aggregated_single_results" out via the single results and not the aggregated results
    # The reason is that the aggregated results are expected to be a single value per metric, while the single results
    # can be anything.
    return {**agg_results, **combined_matched}, {**single_results, **aggregated_single_results}
213+
214+
215+
#: :data:: icd_score
#: Pre-configured :class:`~tpcp.validate.Scorer` instance for ICD algorithms.
icd_score = Scorer(icd_per_datapoint_score, final_aggregator=icd_final_agg)
icd_score.__doc__ = """Scorer for ICD algorithms.

This is a pre-configured :class:`~tpcp.validate.Scorer` object using the :func:`icd_per_datapoint_score` function as
per-datapoint scorer and the :func:`icd_final_agg` function as final aggregator.
For more information about Scorer, head to the tpcp documentation (:class:`~tpcp.validate.Scorer`).
For usage information in the context of mobgap, have a look at the :ref:`evaluation example <icd_evaluation>` for ICD.

The following metrics are calculated:

Raw metrics (part of the single results):

- ``single__raw__detected``: The detected initial contacts as a single dataframe with the datapoint labels as index.
- ``single__raw__reference``: The reference initial contacts as a single dataframe with the datapoint labels as index.

Metrics per datapoint (single results):
*These values are all provided as a list of values, one per datapoint.*

- All outputs of :func:`~mobgap.initial_contacts.evaluation.calculate_matched_icd_performance_metrics` and
  :func:`~mobgap.initial_contacts.evaluation.calculate_true_positive_icd_error` averaged per
  datapoint. These are stored as ``single__{metric_name}``
- ``single__runtime_s``: The runtime of the algorithm in seconds. If multiple WBs were processed, is the runtime it
  took to process all WBs.

Aggregated metrics (aggregated results):

- All single outputs averaged over all datapoints. These are stored as ``agg__{metric_name}``.
- All metrics from :func:`~mobgap.initial_contacts.evaluation.calculate_matched_icd_performance_metrics` and
  :func:`~mobgap.initial_contacts.evaluation.calculate_true_positive_icd_error` recalculated on all detected ICs across
  all datapoints. These are stored as ``combined__{metric_name}``.
  Compared to the per-datapoint results (which are calculated, as errors per recording -> average over all
  recordings), these metrics are calculated as combining all ICDs from all recordings and then calculating the
  performance metrics.
  Effectively, this means, that in the `per_datapoint` version, each recording is weighted equally, while in the
  `combined` version, each IC is weighted equally.

"""

0 commit comments

Comments
 (0)