From 0cd4e044c75a6c18708e2a04c33d61dc687a8a52 Mon Sep 17 00:00:00 2001 From: Gregor Lichtner Date: Wed, 4 Dec 2024 21:58:15 +0100 Subject: [PATCH] fix: unpickling of temporal count --- execution_engine/util/cohort_logic.py | 129 +++++++++--------- .../util/test_cohort_logic.py | 34 +++++ 2 files changed, 97 insertions(+), 66 deletions(-) diff --git a/execution_engine/util/cohort_logic.py b/execution_engine/util/cohort_logic.py index 4fb8bb1d..3f0cd08a 100644 --- a/execution_engine/util/cohort_logic.py +++ b/execution_engine/util/cohort_logic.py @@ -506,28 +506,27 @@ def __new__( self.interval_type = interval_type return self - # glichtner: this should now be handled by the base class using get_instance_variables() - # def __reduce__(self) -> tuple[Callable, tuple]: - # """ - # Reduce the expression to its arguments and category. - # - # Required for pickling (e.g. when using multiprocessing). - # - # :return: Tuple of the class, arguments, and category. - # """ - # return ( - # self._recreate, - # ( - # self.args, - # { - # "category": self.category, - # "threshold": self.count_min, - # "start_time": self.start_time, - # "end_time": self.end_time, - # "interval_type": self.interval_type, - # }, - # ), - # ) + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, arguments, and category. + """ + return ( + self._recreate, + ( + self.args, + { + "category": self.category, + "threshold": self.count_min, + "start_time": self.start_time, + "end_time": self.end_time, + "interval_type": self.interval_type, + }, + ), + ) def __repr__(self) -> str: """ @@ -568,28 +567,27 @@ def __new__( self.interval_type = interval_type return self - # glichtner: this should now be handled by the base class using get_instance_variables() - # def __reduce__(self) -> tuple[Callable, tuple]: - # """ - # Reduce the expression to its arguments and category. - # - # Required for pickling (e.g. when using multiprocessing). - # - # :return: Tuple of the class, arguments, and category. - # """ - # return ( - # self._recreate, - # ( - # self.args, - # { - # "category": self.category, - # "threshold": self.count_max, - # "start_time": self.start_time, - # "end_time": self.end_time, - # "interval_type": self.interval_type, - # }, - # ), - # ) + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, arguments, and category. + """ + return ( + self._recreate, + ( + self.args, + { + "category": self.category, + "threshold": self.count_max, + "start_time": self.start_time, + "end_time": self.end_time, + "interval_type": self.interval_type, + }, + ), + ) def __repr__(self) -> str: """ @@ -631,28 +629,27 @@ def __new__( self.interval_type = interval_type return self - # glichtner: this should now be handled by the base class using get_instance_variables() - # def __reduce__(self) -> tuple[Callable, tuple]: - # """ - # Reduce the expression to its arguments and category. - # - # Required for pickling (e.g. when using multiprocessing). - # - # :return: Tuple of the class, arguments, and category. - # """ - # return ( - # self._recreate, - # ( - # self.args, - # { - # "category": self.category, - # "threshold": self.count_min, - # "start_time": self.start_time, - # "end_time": self.end_time, - # "interval_type": self.interval_type, - # }, - # ), - # ) + def __reduce__(self) -> tuple[Callable, tuple]: + """ + Reduce the expression to its arguments and category. + + Required for pickling (e.g. when using multiprocessing). + + :return: Tuple of the class, arguments, and category. + """ + return ( + self._recreate, + ( + self.args, + { + "category": self.category, + "threshold": self.count_min, + "start_time": self.start_time, + "end_time": self.end_time, + "interval_type": self.interval_type, + }, + ), + ) def __repr__(self) -> str: """ diff --git a/tests/execution_engine/util/test_cohort_logic.py b/tests/execution_engine/util/test_cohort_logic.py index c10bedf7..731bf95f 100644 --- a/tests/execution_engine/util/test_cohort_logic.py +++ b/tests/execution_engine/util/test_cohort_logic.py @@ -4,6 +4,7 @@ import pytest from execution_engine.constants import CohortCategory +from execution_engine.omop.criterion.combination.temporal import TimeIntervalType from execution_engine.util.cohort_logic import ( AllOrNone, And, @@ -19,6 +20,9 @@ Not, Or, Symbol, + TemporalExactCount, + TemporalMaxCount, + TemporalMinCount, ) from tests.mocks.criterion import MockCriterion @@ -224,6 +228,36 @@ class TestSymbolMultiprocessing: NoDataPreservingAnd(1, 2, 3, category=CohortCategory.POPULATION), NoDataPreservingOr(1, 2, 3, category=CohortCategory.POPULATION), LeftDependentToggle(left=1, right=2, category=CohortCategory.POPULATION), + TemporalMinCount( + 1, + 2, + 3, + threshold=2, + start_time=None, + end_time=None, + interval_type=TimeIntervalType.DAY, + category=CohortCategory.POPULATION, + ), + TemporalMaxCount( + 1, + 2, + 3, + threshold=3, + start_time=None, + end_time=None, + interval_type=TimeIntervalType.MORNING_SHIFT, + category=CohortCategory.POPULATION, + ), + TemporalExactCount( + 1, + 2, + 3, + threshold=2, + start_time=None, + end_time=None, + interval_type=TimeIntervalType.NIGHT_SHIFT, + category=CohortCategory.POPULATION, + ), ], ids=lambda expr: expr.__class__.__name__, )