Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions execution_engine/omop/cohort/population_intervention_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,44 @@ def url(self) -> str:
"""
return self._url

@classmethod
def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr:
"""
Filter (=AND-combine) all symbols by the applied filter function

Used to filter all intervention criteria (symbols) by the population output in order to exclude
all intervention events outside the population intervals, which may otherwise interfere with corrected
determination of temporal combination, i.e. the presence of an intervention event during some time window.
"""

if isinstance(node, logic.Symbol):
return logic.And(node, filter_, category=CohortCategory.INTERVENTION)

if hasattr(node, "args") and isinstance(node.args, tuple):
converted_args = [cls.filter_symbols(a, filter_) for a in node.args]

if any(a is not b for a, b in zip(node.args, converted_args)):
node.args = tuple(converted_args)

return node

def execution_graph(self) -> ExecutionGraph:
"""
Get the execution graph for the population/intervention pair.
"""

p = ExecutionGraph.combination_to_expression(self._population)
i = ExecutionGraph.combination_to_expression(self._intervention)

# filter all intervention criteria by the output of the population - this is performed to filter out
# intervention events that outside of the population intervals (i.e. the time windows during which
# patients are part of the population) as otherwise events outside of the population time may be picked up
# by Temporal criteria that determine the presence of some event or condition during a specific time window.
i = self.filter_symbols(i, filter_=p)

pi = logic.LeftDependentToggle(
ExecutionGraph.combination_to_expression(self._population),
ExecutionGraph.combination_to_expression(self._intervention),
p,
i,
category=CohortCategory.POPULATION_INTERVENTION,
)
pi_graph = ExecutionGraph.from_expression(pi, self._base_criterion)
Expand Down
9 changes: 0 additions & 9 deletions execution_engine/omop/cohort/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,24 +116,19 @@ def execution_graph(self) -> ExecutionGraph:
"""

p_nodes = []
i_nodes = []
pi_nodes = []
pi_graphs = []

for pi_pair in self._pi_pairs:
pi_graph = pi_pair.execution_graph()

p_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION))
i_nodes.append(pi_graph.sink_node(CohortCategory.INTERVENTION))
pi_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION_INTERVENTION))
pi_graphs.append(pi_graph)

p_combination_node = logic.NoDataPreservingOr(
*p_nodes, category=CohortCategory.POPULATION
)
i_combination_node = logic.NoDataPreservingOr(
*i_nodes, category=CohortCategory.INTERVENTION
)
pi_combination_node = logic.NoDataPreservingAnd(
*pi_nodes, category=CohortCategory.POPULATION_INTERVENTION
)
Expand All @@ -143,9 +138,6 @@ def execution_graph(self) -> ExecutionGraph:
common_graph.add_node(
p_combination_node, store_result=True, category=CohortCategory.POPULATION
)
common_graph.add_node(
i_combination_node, store_result=True, category=CohortCategory.INTERVENTION
)

common_graph.add_node(
pi_combination_node,
Expand All @@ -154,7 +146,6 @@ def execution_graph(self) -> ExecutionGraph:
)

common_graph.add_edges_from((src, p_combination_node) for src in p_nodes)
common_graph.add_edges_from((src, i_combination_node) for src in i_nodes)
common_graph.add_edges_from((src, pi_combination_node) for src in pi_nodes)

return common_graph
Expand Down
15 changes: 13 additions & 2 deletions tests/recommendation/test_recommendation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,14 @@ def assemble_daily_recommendation_evaluation(
df, group["intervention"]
)

# filter intervention by population
if df[f"i_{group_name}"].dtype == bool:
df[f"i_{group_name}"] &= df[f"p_{group_name}"]
else:
df[f"i_{group_name}"] = (
df[f"p_{group_name}"] == IntervalType.POSITIVE
) & (df[f"i_{group_name}"] == IntervalType.POSITIVE)

# expressions like "Eq(a+b+c, 1)" (at least one criterion) yield boolean columns and must
# be converted to IntervalType
if df[(f"p_{group_name}", "")].dtype == bool:
Expand Down Expand Up @@ -984,7 +992,11 @@ def process_result(df_result):
df_result_p_i = omopdb.query(
get_query(
partial_day_coverage,
category=[CohortCategory.BASE, CohortCategory.POPULATION],
category=[
CohortCategory.BASE,
CohortCategory.POPULATION,
CohortCategory.INTERVENTION,
],
)
)

Expand All @@ -993,7 +1005,6 @@ def process_result(df_result):
get_query(
full_day_coverage,
category=[
CohortCategory.INTERVENTION,
CohortCategory.POPULATION_INTERVENTION,
],
)
Expand Down
8 changes: 8 additions & 0 deletions tests/recommendation/test_recommendation_base_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,14 @@ def assemble_daily_recommendation_evaluation(
df[f"p_{group_name}"] = evaluate_expression(group["population"], df)
df[f"i_{group_name}"] = evaluate_expression(group["intervention"], df)

# filter intervention by population
if df[f"i_{group_name}"].dtype == bool:
df[f"i_{group_name}"] &= df[f"p_{group_name}"]
else:
df[f"i_{group_name}"] = (
df[f"p_{group_name}"] == IntervalType.POSITIVE
) & (df[f"i_{group_name}"] == IntervalType.POSITIVE)

# expressions like "Eq(a+b+c, 1)" (at least one criterion) yield boolean columns and must
# be converted to IntervalType
if df[f"p_{group_name}"].dtype == bool:
Expand Down
12 changes: 5 additions & 7 deletions tests/recommendation/utils/result_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,14 @@ def plan_names(self) -> list[str]:
return [col[2:] for col in self.df.columns if col.startswith("i_")]

def plan_name_column_names(self) -> list[str]:
cols = [
"_".join(i) for i in itertools.product(["p", "i", "p_i"], self.plan_names)
]
return cols + ["p", "i", "p_i"]
cols = ["_".join(i) for i in itertools.product(["p", "p_i"], self.plan_names)]
return cols + ["p", "p_i"]

def derive_database_result(self, df: pd.DataFrame) -> "ResultComparator":
df = df.copy()
df.loc[
:, [c for c in self.plan_name_column_names() if c not in df.columns]
] = False
df.loc[:, [c for c in self.plan_name_column_names() if c not in df.columns]] = (
False
)

return ResultComparator(name="db", df=df)

Expand Down
Loading