diff --git a/execution_engine/omop/criterion/point_in_time.py b/execution_engine/omop/criterion/point_in_time.py index a44fb267..658d7d70 100644 --- a/execution_engine/omop/criterion/point_in_time.py +++ b/execution_engine/omop/criterion/point_in_time.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Dict, cast from sqlalchemy import CTE, ColumnElement, Select, select @@ -21,10 +21,19 @@ class PointInTimeCriterion(ConceptCriterion): def __init__( self, + forward_fill: bool = True, *args: Any, **kwargs: Any, ): super().__init__(*args, **kwargs) + self._forward_fill = forward_fill + + @property + def forward_fill(self) -> bool: + """ + Return true is process_data should forward_fill the temporal intervals in the observation window. + """ + return self._forward_fill def _sql_interval_type_column(self, query: Select | CTE) -> ColumnElement: """ @@ -83,6 +92,22 @@ def _create_query(self) -> Select: return query + def dict(self) -> dict[str, Any]: + """ + Get a JSON representation of the criterion. + """ + from_super = super().dict() + return from_super | {"forward_fill": self._forward_fill} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PointInTimeCriterion": + """ + Create a criterion from a JSON representation. + """ + object = cast("PointInTimeCriterion", super().from_dict(data)) + object._forward_fill = data.get("forward_fill", True) # Backward compat + return object + def process_data( self, data: PersonIntervals, @@ -92,7 +117,8 @@ def process_data( """ Process the result of the SQL query. - Forward fill all intervals and in insert NO_DATA intervals for missing time in observation_window. + If configured via the forward_fill attribute, forward fill all intervals. + Insert NO_DATA intervals for missing time in observation_window. :param data: The result of the SQL query. :param base_data: The base data or None if this is the base criterion. @@ -105,7 +131,8 @@ def process_data( # because they are valid not only at the time of the measurement but also for a certain time after the # measurement possibly, one would need to define something like a "validity duration" for each # measurement value (or rather each measurement in each recommendation) - data = process.forward_fill(data, observation_window) + if self._forward_fill: + data = process.forward_fill(data, observation_window) no_data_intervals = process.complementary_intervals( data, diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index 8c326636..ed0750dd 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -11,14 +11,17 @@ ) from execution_engine.omop.criterion.condition_occurrence import ConditionOccurrence from execution_engine.omop.criterion.drug_exposure import DrugExposure +from execution_engine.omop.criterion.measurement import Measurement from execution_engine.omop.criterion.procedure_occurrence import ProcedureOccurrence from execution_engine.task.process import get_processing_module from execution_engine.util.types import Dosage, TimeRange from execution_engine.util.value import ValueNumber from tests._fixtures.concept import ( concept_artificial_respiration, + concept_body_weight, concept_covid19, concept_heparin_ingredient, + concept_unit_kg, concept_unit_mg, ) from tests._testdata import concepts @@ -26,6 +29,7 @@ from tests.functions import ( create_condition, create_drug_exposure, + create_measurement, create_procedure, create_visit, ) @@ -267,6 +271,23 @@ def test_add_all(self): concept=concept_artificial_respiration, ) +bodyweight_measurement_without_forward_fill = Measurement( + exclude=False, + category=CohortCategory.POPULATION, + concept=concept_body_weight, + value=ValueNumber.parse("<=110", unit=concept_unit_kg), + static=False, + forward_fill=False, +) + +bodyweight_measurement_with_forward_fill = Measurement( + exclude=False, + category=CohortCategory.POPULATION, + concept=concept_body_weight, + value=ValueNumber.parse("<=110", unit=concept_unit_kg), + static=False, +) + class TestCriterionCombinationDatabase(TestCriterion): """ @@ -284,12 +305,22 @@ def criteria(self, db_session): c1.id = 1 c2.id = 2 c3.id = 3 + bodyweight_measurement_without_forward_fill.id = 4 + bodyweight_measurement_with_forward_fill.id = 5 self.register_criterion(c1, db_session) self.register_criterion(c2, db_session) self.register_criterion(c3, db_session) - - return [c1, c2, c3] + self.register_criterion(bodyweight_measurement_without_forward_fill, db_session) + self.register_criterion(bodyweight_measurement_with_forward_fill, db_session) + + return [ + c1, + c2, + c3, + bodyweight_measurement_without_forward_fill, + bodyweight_measurement_with_forward_fill, + ] def run_criteria_test( self, @@ -1226,3 +1257,143 @@ def test_overlapping_combination_on_database( observation_window, persons, ) + + +class TestCriterionPointInTime(TestCriterionCombinationDatabase): + """ + Test class for testing the behavior of PointInTimeCriterion + classes. + + More precisely, the test ensures that point-in-time events like + measurements interact correctly with PointInTimeCriteria and + TemporalIndicatorCombinations. A particular failure mode of this + combination has been that single point-in-time event could lead to + POSITIVE result interval on subsequent days due to forward_fill + logic in PointInTimeCriterion. + """ + + @pytest.fixture + def observation_window(self) -> TimeRange: + return TimeRange( + start="2023-02-28 13:55:00Z", + end="2023-03-03 18:00:00Z", + name="observation", + ) + + def patient_events(self, db_session, visit_occurrence): + e1 = create_measurement( + vo=visit_occurrence, + measurement_concept_id=concept_body_weight.concept_id, + measurement_datetime=pendulum.parse("2023-03-01 09:00:00+01:00"), + value_as_number=100, + unit_concept_id=concept_unit_kg.concept_id, + ) + db_session.add_all([e1]) + db_session.commit() + + @pytest.mark.parametrize( + "combination,expected", + [ + ( + TemporalIndicatorCombination.MorningShift( + bodyweight_measurement_without_forward_fill, + category=CohortCategory.POPULATION, + ), + { + 1: { + ( + pendulum.parse("2023-03-01 06:00:00+01:00"), + pendulum.parse("2023-03-01 13:59:59+01:00"), + ), + } + }, + ), + ( + TemporalIndicatorCombination.AfternoonShift( + bodyweight_measurement_without_forward_fill, + category=CohortCategory.POPULATION, + ), + {1: set()}, + ), + ( + TemporalIndicatorCombination.MorningShift( + bodyweight_measurement_with_forward_fill, + category=CohortCategory.POPULATION, + ), + { + 1: { + ( + pendulum.parse("2023-03-01 06:00:00+01:00"), + pendulum.parse("2023-03-01 13:59:59+01:00"), + ), + ( + pendulum.parse("2023-03-02 06:00:00+01:00"), + pendulum.parse("2023-03-02 13:59:59+01:00"), + ), + ( + pendulum.parse("2023-03-03 06:00:00+01:00"), + pendulum.parse("2023-03-03 13:59:59+01:00"), + ), + } + }, + ), + ( + TemporalIndicatorCombination.AfternoonShift( + bodyweight_measurement_with_forward_fill, + category=CohortCategory.POPULATION, + ), + { + 1: { + ( + pendulum.parse("2023-03-01 14:00:00+01:00"), + pendulum.parse("2023-03-01 21:59:59+01:00"), + ), + ( + pendulum.parse("2023-03-02 14:00:00+01:00"), + pendulum.parse("2023-03-02 21:59:59+01:00"), + ), + ( + pendulum.parse("2023-03-03 14:00:00+01:00"), + pendulum.parse("2023-03-03 15:00:00+00:00"), + ), + } + }, + ), + ], + ) + def test_point_in_time_criterion_on_database( + self, + person, + db_session, + base_criterion, + combination, + expected, + observation_window, + criteria, + ): + persons = [person[0]] # only one person + vos = [ + create_visit( + person_id=person.person_id, + visit_start_datetime=observation_window.start + + datetime.timedelta(hours=3), + visit_end_datetime=observation_window.end - datetime.timedelta(hours=3), + visit_concept_id=concepts.INTENSIVE_CARE, + ) + for person in persons + ] + + self.patient_events(db_session, vos[0]) + + db_session.add_all(vos) + db_session.commit() + + self.run_criteria_test( + combination, + expected, + db_session, + criteria, + base_criterion, + observation_window, + persons, + )