diff --git a/execution_engine/execution_engine.py b/execution_engine/execution_engine.py index 6f505aee..bb591600 100644 --- a/execution_engine/execution_engine.py +++ b/execution_engine/execution_engine.py @@ -118,7 +118,7 @@ def execute( """ Executes a recommendation and stores the results in the result database. - :param recommendation: The Recommendation object (loaded from ExectionEngine.load_recommendation). + :param recommendation: The Recommendation object (loaded from ExecutionEngine.load_recommendation). :param start_datetime: The start of the observation window. :param end_datetime: The end of the observation window. If None, the current time is used. :return: The ID of the run. @@ -135,7 +135,12 @@ def execute( ) with self._db.begin(): - self.register_recommendation(recommendation) + # If the recommendation has been loaded from the + # database, its _id slot is not None. Otherwise, register + # the recommendation to store it into the database and + # assign an id. + if recommendation._id is None: + self.register_recommendation(recommendation) run_id = self.register_run( recommendation, start_datetime=start_datetime, end_datetime=end_datetime ) @@ -186,7 +191,14 @@ def load_recommendation_from_database( recommendation = cohort.Recommendation.from_json( rec_db.recommendation_json.decode() ) - recommendation.id = rec_db.recommendation_id + # All objects in the deserialized object graph must have + # an id. 
+ assert recommendation._id is not None + assert recommendation._base_criterion._id is not None + for pi_pair in recommendation._pi_pairs: + assert pi_pair._id is not None + for criterion in pi_pair.flatten(): + assert criterion._id is not None return recommendation return None @@ -198,14 +210,29 @@ def _hash(obj: Serializable) -> tuple[bytes, str]: def register_recommendation(self, recommendation: cohort.Recommendation) -> None: """Registers the Recommendation in the result database.""" - + # We don't want to include any ids in the hash since ids are + # "accidental" in the sense that they depend on, at least, the + # order in which recommendations are inserted into the + # database. + assert recommendation._id is None + assert recommendation._base_criterion._id is None + for pi_pair in recommendation._pi_pairs: + assert pi_pair._id is None + for criterion in pi_pair.flatten(): + assert criterion._id is None + # Get the hash but ignore the JSON representation for now + # since we will compute and insert a complete JSON + # representation later when we know all ids. + _, rec_hash = self._hash(recommendation) recommendation_table = result_db.Recommendation - - rec_json, rec_hash = self._hash(recommendation) + # We look for a recommendation with the computed hash in the + # database. If there is one, set the id of our recommendation + # to the stored id. Otherwise, store our recommendation + # (without the JSON representation) in the database and + # receive the fresh id. 
query = select(recommendation_table).where( recommendation_table.recommendation_hash == rec_hash ) - with self._db.begin() as con: rec_db = con.execute(query).fetchone() @@ -221,7 +248,7 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None recommendation_version=recommendation.version, recommendation_package_version=recommendation.package_version, recommendation_hash=rec_hash, - recommendation_json=rec_json, + recommendation_json=bytes(), # updated later create_datetime=datetime.now(), ) .returning(recommendation_table.recommendation_id) @@ -229,25 +256,43 @@ def register_recommendation(self, recommendation: cohort.Recommendation) -> None result = con.execute(query) recommendation.id = result.fetchone().recommendation_id - + # Register all child objects. After that, the recommendation + # and all child objects have valid ids (either restored or + # fresh). + self.register_criterion(recommendation._base_criterion) for pi_pair in recommendation.population_intervention_pairs(): self.register_population_intervention_pair( pi_pair, recommendation_id=recommendation.id ) - for criterion in pi_pair.flatten(): self.register_criterion(criterion) + assert recommendation.id is not None + # TODO(jmoringe): mypy doesn't like this one. Not sure why. 
+ # assert recommendation._base_criterion._id is not None + for pi_pair in recommendation._pi_pairs: + assert pi_pair._id is not None + for criterion in pi_pair.flatten(): + assert criterion._id is not None + + # Update the recommendation in the database with the final + # JSON representation and execution graph (now that + # recommendation id, criteria ids and pi pair ids are known) + # TODO(jmoringe): only when necessary with self._db.begin() as con: - # update recommendation with execution graph (now that criterion & pi pair is are known) rec_graph: bytes = json.dumps( recommendation.execution_graph().to_cytoscape_dict(), sort_keys=True ).encode() + rec_json: bytes = recommendation.json() + logging.info(f"Storing recommendation {recommendation}") update_query = ( update(recommendation_table) .where(recommendation_table.recommendation_id == recommendation.id) - .values(recommendation_execution_graph=rec_graph) + .values( + recommendation_json=rec_json, + recommendation_execution_graph=rec_graph, + ) ) con.execute(update_query) @@ -261,6 +306,8 @@ def register_population_intervention_pair( :param pi_pair: The Population/Intervention Pair. :param recommendation_id: The ID of the Population/Intervention Pair. 
""" + # We don't want to include the id in the hash + assert pi_pair._id is None _, pi_pair_hash = self._hash(pi_pair) query = select(result_db.PopulationInterventionPair).where( result_db.PopulationInterventionPair.pi_pair_hash == pi_pair_hash @@ -314,6 +361,8 @@ def register_criterion(self, criterion: Criterion) -> None: result = con.execute(query) criterion.id = result.fetchone().criterion_id + assert criterion.id is not None + def register_run( self, recommendation: cohort.Recommendation, diff --git a/execution_engine/omop/cohort/population_intervention_pair.py b/execution_engine/omop/cohort/population_intervention_pair.py index 5e3d7343..ab7100dd 100644 --- a/execution_engine/omop/cohort/population_intervention_pair.py +++ b/execution_engine/omop/cohort/population_intervention_pair.py @@ -38,7 +38,7 @@ class PopulationInterventionPair(Serializable): (e.g. "has condition X and lab value Y >= Z"). """ - _id: int | None = None + _id: int | None _name: str _population: CriterionCombination _intervention: CriterionCombination @@ -50,7 +50,9 @@ def __init__( base_criterion: Criterion, population: CriterionCombination | None = None, intervention: CriterionCombination | None = None, + id: int | None = None, ) -> None: + self._id = id self._name = name self._url = url self._base_criterion = base_criterion @@ -231,6 +233,7 @@ def dict(self) -> dict[str, Any]: population = self._population intervention = self._intervention return { + "id": self._id, "name": self.name, "url": self.url, "base_criterion": { @@ -262,12 +265,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "PopulationInterventionPair": CriterionCombination, criterion_factory(**data["intervention"]) ) object = cls( + id=data["id"], name=data["name"], url=data["url"], base_criterion=base_criterion, ) # The constructor initializes the population and intervention - # slots in a particular way but we want to use whatever we + # slots in a particular way, but we want to use whatever we # have deserialized instead. 
This is a bit inefficient because # we discard the values that were assigned to the two slots in # the constructor. diff --git a/execution_engine/omop/cohort/recommendation.py b/execution_engine/omop/cohort/recommendation.py index a6c7a45a..120fd13e 100644 --- a/execution_engine/omop/cohort/recommendation.py +++ b/execution_engine/omop/cohort/recommendation.py @@ -249,10 +249,14 @@ def dict(self) -> dict: """ Get the combination as a dictionary. """ + base_criterion = self._base_criterion return { "id": self._id, "population_intervention_pairs": [c.dict() for c in self._pi_pairs], - "base_criterion": self._base_criterion.dict(), + "base_criterion": { + "class_name": base_criterion.__class__.__name__, + "data": base_criterion.dict(), + }, "recommendation_name": self._name, "recommendation_title": self._title, "recommendation_url": self._url, @@ -283,5 +287,5 @@ def from_dict(cls, data: Dict[str, Any]) -> Self: version=data["recommendation_version"], description=data["recommendation_description"], package_version=data["recommendation_package_version"], - recommendation_id=data["id"] if "id" in data else None, + recommendation_id=data["id"], ) diff --git a/execution_engine/omop/criterion/abstract.py b/execution_engine/omop/criterion/abstract.py index d80f58b4..0d4cf0c4 100644 --- a/execution_engine/omop/criterion/abstract.py +++ b/execution_engine/omop/criterion/abstract.py @@ -93,8 +93,10 @@ class AbstractCriterion(Serializable, ABC): Abstract base class for Criterion and CriterionCombination. """ - def __init__(self, category: CohortCategory) -> None: - self._id = None + def __init__( + self, category: CohortCategory, id: int | None = None + ) -> None: + self._id = id assert isinstance( category, CohortCategory @@ -210,8 +212,10 @@ class Criterion(AbstractCriterion): Flag to indicate whether the filter_datetime function has been called. 
""" - def __init__(self, category: CohortCategory) -> None: - super().__init__(category=category) + def __init__( + self, category: CohortCategory, id: int | None = None + ) -> None: + super().__init__(category=category, id=id) def _set_omop_variables_from_domain(self, domain_id: str) -> None: """ diff --git a/execution_engine/omop/criterion/combination/combination.py b/execution_engine/omop/criterion/combination/combination.py index b671584a..79740e65 100644 --- a/execution_engine/omop/criterion/combination/combination.py +++ b/execution_engine/omop/criterion/combination/combination.py @@ -143,9 +143,13 @@ def dict(self) -> dict[str, Any]: "threshold": self._operator.threshold, "category": self._category.value, "criteria": [ - {"class_name": criterion.__class__.__name__, "data": criterion.dict()} + { + "class_name": criterion.__class__.__name__, + "data": criterion.dict(), + } for criterion in self._criteria ], + "root": self._root, } def __invert__(self) -> AbstractCriterion: @@ -192,6 +196,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CriterionCombination": combination = cls( operator=operator, category=category, + root_combination=data["root"], ) for criterion in data["criteria"]: diff --git a/execution_engine/omop/criterion/concept.py b/execution_engine/omop/criterion/concept.py index 838474dc..2474db4d 100644 --- a/execution_engine/omop/criterion/concept.py +++ b/execution_engine/omop/criterion/concept.py @@ -43,12 +43,13 @@ def __init__( self, category: CohortCategory, concept: Concept, + id: int | None = None, value: Value | None = None, static: bool | None = None, timing: Timing | None = None, override_value_required: bool | None = None, ): - super().__init__(category=category) + super().__init__(category=category, id=id) self._set_omop_variables_from_domain(concept.domain_id) self._concept = concept @@ -135,6 +136,7 @@ def dict(self) -> dict[str, Any]: Get a JSON representation of the criterion. 
""" return { + "id": self._id, "category": self._category.value, "concept": self._concept.model_dump(), "value": ( @@ -157,6 +159,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ConceptCriterion": """ return cls( + id=data["id"], category=CohortCategory(data["category"]), concept=Concept(**data["concept"]), value=( diff --git a/execution_engine/omop/criterion/drug_exposure.py b/execution_engine/omop/criterion/drug_exposure.py index e51c8a5c..762e6ffa 100644 --- a/execution_engine/omop/criterion/drug_exposure.py +++ b/execution_engine/omop/criterion/drug_exposure.py @@ -33,11 +33,12 @@ def __init__( ingredient_concept: Concept, dose: Dosage | None, route: Concept | None, + id: int | None = None, ) -> None: """ Initialize the drug administration action. """ - super().__init__(category=category) + super().__init__(category=category, id=id) self._set_omop_variables_from_domain("drug") self._ingredient_concept = ingredient_concept @@ -356,6 +357,7 @@ def dict(self) -> dict[str, Any]: Return a dictionary representation of the criterion. 
""" return { + "id": self._id, "category": self._category.value, "ingredient_concept": self._ingredient_concept.model_dump(), "dose": ( @@ -377,6 +379,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "DrugExposure": assert dose is None or isinstance(dose, Dosage), "Dose must be a Dosage or None" return cls( + id=data["id"], category=CohortCategory(data["category"]), ingredient_concept=Concept(**data["ingredient_concept"]), dose=dose, diff --git a/execution_engine/omop/criterion/procedure_occurrence.py b/execution_engine/omop/criterion/procedure_occurrence.py index 9355f951..58f2adb8 100644 --- a/execution_engine/omop/criterion/procedure_occurrence.py +++ b/execution_engine/omop/criterion/procedure_occurrence.py @@ -29,12 +29,14 @@ def __init__( value: ValueNumber | None = None, timing: Timing | None = None, static: bool | None = None, + id: int | None = None, ) -> None: super().__init__( category=category, concept=concept, value=value, static=static, + id=id, ) self._set_omop_variables_from_domain("procedure") @@ -158,6 +160,7 @@ def dict(self) -> dict[str, Any]: assert self._concept is not None, "Concept must be set" return { + "id": self._id, "category": self._category.value, "concept": self._concept.model_dump(), "value": ( @@ -189,6 +192,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ProcedureOccurrence": ), "timing must be a ValueNumber" return cls( + id=data["id"], category=CohortCategory(data["category"]), concept=Concept(**data["concept"]), value=value, diff --git a/execution_engine/omop/criterion/visit_occurrence.py b/execution_engine/omop/criterion/visit_occurrence.py index d28180d6..fd9a9cfe 100644 --- a/execution_engine/omop/criterion/visit_occurrence.py +++ b/execution_engine/omop/criterion/visit_occurrence.py @@ -23,8 +23,10 @@ class ActivePatients(VisitOccurrence): Select only patients who are still hospitalized. """ - def __init__(self) -> None: + def __init__(self, id: int | None = None) -> None: + # TODO(jmoringe): why not use the constructor? 
super().__init__(id=id) self._category = CohortCategory.BASE + self._id = id if get_config().episode_of_care_visit_detail: self._set_omop_variables_from_domain("visit_detail") @@ -86,14 +88,14 @@ def dict(self) -> dict[str, Any]: """ Get a JSON representation of the criterion. """ - return {"class_name": self.__class__.__name__, "data": {}} + return {"id": self._id} @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ActivePatients": """ Create a criterion from a JSON representation. """ - return cls() + return cls(id=data["id"]) class PatientsActiveDuringPeriod(ActivePatients): diff --git a/execution_engine/omop/serializable.py b/execution_engine/omop/serializable.py index 703c135c..a7d4f4fb 100644 --- a/execution_engine/omop/serializable.py +++ b/execution_engine/omop/serializable.py @@ -51,13 +51,10 @@ def json(self) -> bytes: s_json = self.dict() - if "id" in s_json: - del s_json["id"] - return json.dumps(s_json, sort_keys=True).encode() @classmethod - def from_json(cls, data: str) -> Self: + def from_json(cls, data: str | bytes) -> Self: """ Create a combination from a JSON string. 
""" diff --git a/tests/execution_engine/omop/cohort/test_cohort_recommendation.py b/tests/execution_engine/omop/cohort/test_cohort_recommendation.py new file mode 100644 index 00000000..87defac8 --- /dev/null +++ b/tests/execution_engine/omop/cohort/test_cohort_recommendation.py @@ -0,0 +1,60 @@ +import pytest + +from execution_engine.omop.cohort.recommendation import Recommendation +from execution_engine.omop.concepts import Concept +from execution_engine.omop.criterion.visit_occurrence import ActivePatients +from tests.mocks.criterion import MockCriterion + + +class TestRecommendation: + + def test_serialization(self): + # Register the mock criterion class + from execution_engine.omop.criterion import factory + + factory.register_criterion_class("MockCriterion", MockCriterion) + + original = Recommendation( + pi_pairs=[], + base_criterion=MockCriterion("c"), + name="foo", + title="bar", + url="baz", + version="1.0", + description="hi", + ) + + json = original.json() + deserialized = Recommendation.from_json(json) + assert original == deserialized + + @pytest.fixture + def concept(self): + return Concept( + concept_id=32037, + concept_name="Intensive Care", + domain_id="Visit", + vocabulary_id="Visit", + concept_class_id="Visit", + standard_concept="S", + concept_code="OMOP4822460", + invalid_reason=None, + ) + + def test_serialization_with_active_patients(self, concept): + # Test with a ActivePatients as the base criterion + # specifically because there used to be a problem with the + # serialization of that combination. 
+ original = Recommendation( + pi_pairs=[], + base_criterion=ActivePatients(), + name="foo", + title="bar", + url="baz", + version="1.0", + description="hi", + ) + + json = original.json() + deserialized = Recommendation.from_json(json) + assert original == deserialized diff --git a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py index 0a8ba06a..e088a603 100644 --- a/tests/execution_engine/omop/criterion/combination/test_logical_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_logical_combination.py @@ -109,6 +109,7 @@ def test_criterion_combination_dict(self, mock_criteria): {"class_name": "MockCriterion", "data": criterion.dict()} for criterion in mock_criteria ], + "root": False, } def test_criterion_combination_from_dict(self, mock_criteria): @@ -123,6 +124,7 @@ def test_criterion_combination_from_dict(self, mock_criteria): {"class_name": "MockCriterion", "data": criterion.dict()} for criterion in mock_criteria ], + "root": False, } # Register the mock criterion class diff --git a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py index e4e58c49..cae1e71e 100644 --- a/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py +++ b/tests/execution_engine/omop/criterion/combination/test_temporal_combination.py @@ -115,6 +115,7 @@ def test_criterion_combination_dict(self, mock_criteria): {"class_name": "MockCriterion", "data": criterion.dict()} for criterion in mock_criteria ], + "root": False, } def test_criterion_combination_from_dict(self, mock_criteria): @@ -128,6 +129,7 @@ def test_criterion_combination_from_dict(self, mock_criteria): ) combination_data = { + "id": None, "operator": "AT_LEAST", "threshold": 1, "category": "POPULATION_INTERVENTION",