diff --git a/execution_engine/omop/cohort/recommendation.py b/execution_engine/omop/cohort/recommendation.py index 695805b6..f132c6d2 100644 --- a/execution_engine/omop/cohort/recommendation.py +++ b/execution_engine/omop/cohort/recommendation.py @@ -210,10 +210,10 @@ def reset_state(self) -> None: Sets all _id attributes to None in the recommendation and all its population/intervention pairs and criteria. """ - self._id = None + self.reset_id() for pi_pair in self.population_intervention_pairs(): - pi_pair._id = None + pi_pair.reset_id() for criterion in self.atoms(): - criterion._id = None + criterion.reset_id() diff --git a/execution_engine/omop/db/celida/tables.py b/execution_engine/omop/db/celida/tables.py index b434d038..444bf5e7 100644 --- a/execution_engine/omop/db/celida/tables.py +++ b/execution_engine/omop/db/celida/tables.py @@ -146,7 +146,7 @@ class ResultInterval(Base): # noqa: D101 ) result_id: Mapped[int] = mapped_column( - Integer, primary_key=True, index=True, autoincrement=True + BigInteger, primary_key=True, index=True, autoincrement=True ) run_id = mapped_column( ForeignKey(f"{SCHEMA_NAME}.execution_run.run_id"), @@ -169,9 +169,7 @@ class ResultInterval(Base): # noqa: D101 interval_start: Mapped[datetime] interval_end: Mapped[datetime] interval_type = mapped_column(IntervalTypeEnum) - interval_ratio: Mapped[float] = mapped_column( - nullable=True - ) + interval_ratio: Mapped[float] = mapped_column(nullable=True) execution_run: Mapped["ExecutionRun"] = relationship( primaryjoin="ResultInterval.run_id == ExecutionRun.run_id", ) diff --git a/execution_engine/task/task.py b/execution_engine/task/task.py index 78d541c4..85d079bd 100644 --- a/execution_engine/task/task.py +++ b/execution_engine/task/task.py @@ -51,6 +51,7 @@ def default_interval_union_with_count( interval_type, interval_count = IntervalType.NEGATIVE, 0 else: interval_type, interval_count = interval.type, cast(int, interval.count) + if ( ( interval_type is IntervalType.POSITIVE @@ -199,14 +200,15 @@ def receives_only_count_inputs(self) -> bool: Indicates whether this tasks only receives inputs from expression that perform counting and thus return IntervalWithCount. """ - # all arguments are count types - if all(isinstance(parent, COUNT_TYPES) for parent in self.expr.args): - return True - - # all arguments are logic.BinaryNonCommutativeOperator, and all of their "right" children are count types + # all arguments are either count types or logic.BinaryNonCommutativeOperator + # with their "right" child being a count type, or they have a custom counting function (count_intervals()) if all( - isinstance(parent, logic.BinaryNonCommutativeOperator) - and isinstance(parent.right, COUNT_TYPES) + ( + isinstance(parent, logic.BinaryNonCommutativeOperator) + and isinstance(parent.right, COUNT_TYPES) + ) + or (isinstance(parent, COUNT_TYPES)) + or (hasattr(parent, "count_intervals")) for parent in self.expr.args ): return True diff --git a/execution_engine/util/serializable.py b/execution_engine/util/serializable.py index 768fe13f..6844aaa6 100644 --- a/execution_engine/util/serializable.py +++ b/execution_engine/util/serializable.py @@ -170,7 +170,14 @@ def set_id(self, value: int, overwrite: bool = False) -> None: """ if self._id is not None and not overwrite: raise ValueError("Database ID has already been set!") - self._id = value + object.__setattr__(self, "_id", value) + + def reset_id(self) -> None: + """ + Resets the database ID. + """ + # Circumvents the immutable __setattr__ + object.__setattr__(self, "_id", None) @property def id(self) -> int: