diff --git a/execution_engine/omop/cohort/population_intervention_pair.py b/execution_engine/omop/cohort/population_intervention_pair.py index ab7100dd..c072b883 100644 --- a/execution_engine/omop/cohort/population_intervention_pair.py +++ b/execution_engine/omop/cohort/population_intervention_pair.py @@ -99,14 +99,44 @@ def url(self) -> str: """ return self._url + @classmethod + def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr: + """ + Filter (=AND-combine) all symbols by the given filter expression + + Used to filter all intervention criteria (symbols) by the population output in order to exclude + all intervention events outside the population intervals, which may otherwise interfere with the correct + determination of temporal combinations, i.e. the presence of an intervention event during some time window. + """ + + if isinstance(node, logic.Symbol): + return logic.And(node, filter_, category=CohortCategory.INTERVENTION) + + if hasattr(node, "args") and isinstance(node.args, tuple): + converted_args = [cls.filter_symbols(a, filter_) for a in node.args] + + if any(a is not b for a, b in zip(node.args, converted_args)): + node.args = tuple(converted_args) + + return node + def execution_graph(self) -> ExecutionGraph: """ Get the execution graph for the population/intervention pair. """ + p = ExecutionGraph.combination_to_expression(self._population) + i = ExecutionGraph.combination_to_expression(self._intervention) + + # filter all intervention criteria by the output of the population - this is performed to filter out + # intervention events that lie outside of the population intervals (i.e. the time windows during which + # patients are part of the population) as otherwise events outside of the population time may be picked up + # by Temporal criteria that determine the presence of some event or condition during a specific time window. 
+ i = self.filter_symbols(i, filter_=p) + pi = logic.LeftDependentToggle( - ExecutionGraph.combination_to_expression(self._population), - ExecutionGraph.combination_to_expression(self._intervention), + p, + i, category=CohortCategory.POPULATION_INTERVENTION, ) pi_graph = ExecutionGraph.from_expression(pi, self._base_criterion) diff --git a/execution_engine/omop/cohort/recommendation.py b/execution_engine/omop/cohort/recommendation.py index 120fd13e..81be0997 100644 --- a/execution_engine/omop/cohort/recommendation.py +++ b/execution_engine/omop/cohort/recommendation.py @@ -116,7 +116,6 @@ def execution_graph(self) -> ExecutionGraph: """ p_nodes = [] - i_nodes = [] pi_nodes = [] pi_graphs = [] @@ -124,16 +123,12 @@ def execution_graph(self) -> ExecutionGraph: pi_graph = pi_pair.execution_graph() p_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION)) - i_nodes.append(pi_graph.sink_node(CohortCategory.INTERVENTION)) pi_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION_INTERVENTION)) pi_graphs.append(pi_graph) p_combination_node = logic.NoDataPreservingOr( *p_nodes, category=CohortCategory.POPULATION ) - i_combination_node = logic.NoDataPreservingOr( - *i_nodes, category=CohortCategory.INTERVENTION - ) pi_combination_node = logic.NoDataPreservingAnd( *pi_nodes, category=CohortCategory.POPULATION_INTERVENTION ) @@ -143,9 +138,6 @@ def execution_graph(self) -> ExecutionGraph: common_graph.add_node( p_combination_node, store_result=True, category=CohortCategory.POPULATION ) - common_graph.add_node( - i_combination_node, store_result=True, category=CohortCategory.INTERVENTION - ) common_graph.add_node( pi_combination_node, @@ -154,7 +146,6 @@ def execution_graph(self) -> ExecutionGraph: ) common_graph.add_edges_from((src, p_combination_node) for src in p_nodes) - common_graph.add_edges_from((src, i_combination_node) for src in i_nodes) common_graph.add_edges_from((src, pi_combination_node) for src in pi_nodes) return common_graph diff --git 
a/tests/recommendation/test_recommendation_base.py b/tests/recommendation/test_recommendation_base.py index 68749e33..8758e127 100644 --- a/tests/recommendation/test_recommendation_base.py +++ b/tests/recommendation/test_recommendation_base.py @@ -709,6 +709,14 @@ def assemble_daily_recommendation_evaluation( df, group["intervention"] ) + # filter intervention by population + if df[f"i_{group_name}"].dtype == bool: + df[f"i_{group_name}"] &= df[f"p_{group_name}"] + else: + df[f"i_{group_name}"] = ( + df[f"p_{group_name}"] == IntervalType.POSITIVE + ) & (df[f"i_{group_name}"] == IntervalType.POSITIVE) + # expressions like "Eq(a+b+c, 1)" (at least one criterion) yield boolean columns and must # be converted to IntervalType if df[(f"p_{group_name}", "")].dtype == bool: @@ -984,7 +992,11 @@ def process_result(df_result): df_result_p_i = omopdb.query( get_query( partial_day_coverage, - category=[CohortCategory.BASE, CohortCategory.POPULATION], + category=[ + CohortCategory.BASE, + CohortCategory.POPULATION, + CohortCategory.INTERVENTION, + ], ) ) @@ -993,7 +1005,6 @@ def process_result(df_result): get_query( full_day_coverage, category=[ - CohortCategory.INTERVENTION, CohortCategory.POPULATION_INTERVENTION, ], ) diff --git a/tests/recommendation/test_recommendation_base_v2.py b/tests/recommendation/test_recommendation_base_v2.py index 2e6a6504..5bc1a584 100644 --- a/tests/recommendation/test_recommendation_base_v2.py +++ b/tests/recommendation/test_recommendation_base_v2.py @@ -485,6 +485,14 @@ def assemble_daily_recommendation_evaluation( df[f"p_{group_name}"] = evaluate_expression(group["population"], df) df[f"i_{group_name}"] = evaluate_expression(group["intervention"], df) + # filter intervention by population + if df[f"i_{group_name}"].dtype == bool: + df[f"i_{group_name}"] &= df[f"p_{group_name}"] + else: + df[f"i_{group_name}"] = ( + df[f"p_{group_name}"] == IntervalType.POSITIVE + ) & (df[f"i_{group_name}"] == IntervalType.POSITIVE) + # expressions like 
"Eq(a+b+c, 1)" (at least one criterion) yield boolean columns and must # be converted to IntervalType if df[f"p_{group_name}"].dtype == bool: diff --git a/tests/recommendation/utils/result_comparator.py b/tests/recommendation/utils/result_comparator.py index e9bbb2cb..01c05593 100644 --- a/tests/recommendation/utils/result_comparator.py +++ b/tests/recommendation/utils/result_comparator.py @@ -42,16 +42,14 @@ def plan_names(self) -> list[str]: return [col[2:] for col in self.df.columns if col.startswith("i_")] def plan_name_column_names(self) -> list[str]: - cols = [ - "_".join(i) for i in itertools.product(["p", "i", "p_i"], self.plan_names) - ] - return cols + ["p", "i", "p_i"] + cols = ["_".join(i) for i in itertools.product(["p", "p_i"], self.plan_names)] + return cols + ["p", "p_i"] def derive_database_result(self, df: pd.DataFrame) -> "ResultComparator": df = df.copy() - df.loc[ - :, [c for c in self.plan_name_column_names() if c not in df.columns] - ] = False + df.loc[:, [c for c in self.plan_name_column_names() if c not in df.columns]] = ( + False + ) return ResultComparator(name="db", df=df)