111 changes: 111 additions & 0 deletions benchmarks/causal_discovery.py
@@ -0,0 +1,111 @@
import numpy as np
from pgmpy.base import DAG
from pgmpy.estimators import PC, GES
from pgmpy.metrics import SHD
from pgmpy.factors.continuous import LinearGaussianCPD
from pgmpy.models import LinearGaussianBayesianNetwork as LGBN

"""
Benchmarking Structural Hamming Distance (SHD) for Causal Discovery Algorithms: PC and GES

Algorithm Definitions:
----------------------
- PC (Peter-Clark) Algorithm:
    A constraint-based algorithm that starts with a complete undirected graph and removes edges
    based on conditional independence tests. It then orients edges using separation sets and
    rules like the collider rule:
        X → Z ← Y if X and Y are non-adjacent, both are adjacent to Z, and Z ∉ SepSet(X, Y)

- GES (Greedy Equivalence Search) Algorithm:
    A score-based algorithm that performs a greedy forward and backward search in the space
    of equivalence classes of DAGs to maximize a scoring criterion such as BIC.

    Scoring function (Bayesian Information Criterion - BIC):
        Score(G : D) = log P(D | G) - λ * |G|
    where log P(D | G) is the maximized log-likelihood, |G| is the number of free
    parameters of G, and λ = log(N) / 2 for N samples.

Metric:
-------
- SHD (Structural Hamming Distance):
    Measures the number of edge insertions, deletions, or reversals required to convert
    one DAG into another.
"""
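To make the SHD definition in the docstring concrete, here is a minimal pure-Python sketch that counts insertions, deletions, and reversals between two edge sets. The function name `shd` and the choice to count a reversed edge as a single operation are assumptions for illustration; this is not pgmpy's `SHD` implementation.

```python
def shd(true_edges, learned_edges):
    """Count edge insertions, deletions, and reversals between two DAGs.

    Both arguments are iterables of (parent, child) tuples.
    Assumption: a reversed edge counts as one operation, not two.
    """
    true_edges = set(true_edges)
    learned_edges = set(learned_edges)

    only_learned = learned_edges - true_edges
    only_true = true_edges - learned_edges

    # A learned edge u -> v is a reversal if the true graph has v -> u instead.
    reversals = {(u, v) for (u, v) in only_learned if (v, u) in only_true}

    # Remaining differences are plain insertions or deletions.
    insertions = only_learned - reversals
    deletions = only_true - {(v, u) for (u, v) in reversals}

    return len(insertions) + len(deletions) + len(reversals)


# Illustrative graphs (not benchmark output).
true_graph = [("A", "B"), ("B", "C"), ("A", "C")]
learned_graph = [("A", "B"), ("C", "B"), ("C", "D")]

# One reversal (B -> C), one deletion (A -> C), one insertion (C -> D).
print(shd(true_graph, learned_graph))
```

Counting a reversal as one operation keeps the metric from penalizing a wrongly oriented edge twice; a stricter variant would count it as a deletion plus an insertion.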

Contributor

@Vanshitaaa20, please add the definitions of GES and PC to the script as a docstring.

Author

added

def generate_random_dag(num_nodes: int, edge_prob: float = 0.3, seed: int = 0) -> DAG:
    dag = DAG.get_random(n_nodes=num_nodes, edge_prob=edge_prob, seed=seed)
    for i in range(num_nodes):
        dag.add_node(f"X_{i}")
Member

Why do we need to add the nodes here? Doesn't the get_random method already give the DAG on the specified number of nodes?

    return dag

# Benchmark parameters
num_trials = 10
shd_pc_list = []
shd_ges_list = []

Contributor

@Vanshitaaa20, add the algorithm equations as a docstring, as in the benchmarking script.

# Run trials
for trial in range(num_trials):
    np.random.seed(trial)
    print(f"\nTrial {trial + 1}/{num_trials}")

    true_dag = generate_random_dag(num_nodes=5, edge_prob=0.3, seed=trial)

    lgbn = LGBN(true_dag.edges())
    lgbn.add_nodes_from(true_dag.nodes())
    for node in true_dag.nodes():
        parents = list(lgbn.get_parents(node))
        beta = [0.0] + list(np.random.uniform(0.5, 1.5, size=len(parents)))
        cpd = LinearGaussianCPD(variable=node, beta=beta, std=1, evidence=parents)
        lgbn.add_cpds(cpd)
Member

LinearGaussianBayesianNetwork has a get_random method that should give a full randomly generated model.
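For context, the CPD loop in the diff defines each node as a zero-intercept linear function of its parents plus unit-variance Gaussian noise, with coefficients drawn from Uniform(0.5, 1.5). The sketch below reproduces that data-generating process with the standard library only. The helper name `sample_linear_gaussian` and the dict-based graph encoding are assumptions for illustration, not pgmpy's API.

```python
import random
from graphlib import TopologicalSorter


def sample_linear_gaussian(parents, weights, n, seed=0):
    """Draw n samples from a linear Gaussian model by ancestral sampling.

    parents: dict mapping each node to a list of its parents.
    weights: dict mapping (parent, child) to a regression coefficient.
    Assumes intercept 0 and noise standard deviation 1, as in the diff.
    """
    rng = random.Random(seed)
    # Parents always come before their children in this order.
    order = list(TopologicalSorter(parents).static_order())
    rows = []
    for _ in range(n):
        row = {}
        for node in order:
            mean = sum(weights[(p, node)] * row[p] for p in parents[node])
            row[node] = rng.gauss(mean, 1.0)
        rows.append(row)
    return rows


# Hypothetical three-node DAG: A -> B, A -> C, B -> C.
example_parents = {"A": [], "B": ["A"], "C": ["A", "B"]}
weight_rng = random.Random(0)
example_weights = {
    (p, child): weight_rng.uniform(0.5, 1.5)
    for child, ps in example_parents.items()
    for p in ps
}
samples = sample_linear_gaussian(example_parents, example_weights, n=1000)
print(len(samples), sorted(samples[0]))
```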


    data = lgbn.simulate(n=1000)

    # PC Estimation
    try:
Member

No need to do a try-except. Better to let it fail; it will help in detecting bugs.
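To illustrate the point, here is a small sketch of how a blanket `except Exception: continue` can hide a bug in a trial loop. The function names and the failing estimator are hypothetical and exist only for this example.

```python
def run_trials(estimate, num_trials, swallow_errors):
    """Run estimate(trial) for each trial and collect the results."""
    results = []
    for trial in range(num_trials):
        if swallow_errors:
            try:
                results.append(estimate(trial))
            except Exception as exc:
                # The failure is printed once and then forgotten.
                print(f"  trial {trial} failed: {exc}")
                continue
        else:
            # Any exception propagates with a full traceback.
            results.append(estimate(trial))
    return results


def buggy_estimate(trial):
    # Hypothetical bug: every odd-numbered trial fails.
    if trial % 2 == 1:
        raise KeyError(f"missing node in trial {trial}")
    return trial


# With errors swallowed, half of the trials silently vanish from the results.
print(run_trials(buggy_estimate, 6, swallow_errors=True))
```

With `swallow_errors=False` the same call raises `KeyError` on the first failing trial, which points directly at the bug instead of producing an average over fewer trials.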

        learned_dag_pc = PC(data).estimate(
            ci_test="pearsonr",
            variant="stable",
            return_type="dag",
        )
    except Exception as e:
        print(" PC estimation failed:", e)
        continue
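As background for the `ci_test="pearsonr"` choice above, the sketch below shows one way a partial-correlation independence test can work for continuous data: compute the correlation of X and Y after adjusting for Z, then apply a Fisher z-transform to get a p-value. This is an illustrative stand-in under stated assumptions (single conditioning variable, Fisher z-test), not pgmpy's implementation, and all helper names are hypothetical.

```python
import math
import random
from statistics import NormalDist, fmean


def pearson(x, y):
    """Plain Pearson correlation of two equal-length sequences."""
    mx, my = fmean(x), fmean(y)
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    return sxy / math.sqrt(sxx * syy)


def partial_corr(x, y, z):
    """Correlation of x and y given a single conditioning variable z."""
    r_xy, r_xz, r_yz = pearson(x, y), pearson(x, z), pearson(y, z)
    return (r_xy - r_xz * r_yz) / math.sqrt((1 - r_xz**2) * (1 - r_yz**2))


def fisher_z_pvalue(r, n, cond_set_size):
    """Two-sided p-value for H0: (partial) correlation is zero."""
    z = math.atanh(r) * math.sqrt(n - cond_set_size - 3)
    return 2 * (1 - NormalDist().cdf(abs(z)))


# Simulated chain X -> Z -> Y: X and Y are dependent, but independent given Z.
rng = random.Random(0)
n = 2000
x = [rng.gauss(0, 1) for _ in range(n)]
z = [xi + rng.gauss(0, 1) for xi in x]
y = [zi + rng.gauss(0, 1) for zi in z]

r_marginal = pearson(x, y)
r_partial = partial_corr(x, y, z)
print("marginal:", r_marginal, fisher_z_pvalue(r_marginal, n, 0))
print("given Z: ", r_partial, fisher_z_pvalue(r_partial, n, 1))
```

In PC, a large p-value for the conditional test would lead to removing the X–Y edge and recording Z in the separating set of X and Y.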

    # GES Estimation
    try:
Member

Same here.

        ges_out = GES(data).estimate(scoring_method="bic-g")
        learned_dag_ges = (
            ges_out["model"]
            if isinstance(ges_out, dict) and "model" in ges_out
            else (ges_out[0] if isinstance(ges_out, tuple) else ges_out)
        )
    except Exception as e:
        print(" GES estimation failed:", e)
        continue

    # Ensure node alignment
    all_nodes = sorted(set(true_dag.nodes()).union(
        set(learned_dag_pc.nodes())).union(set(learned_dag_ges.nodes())))
    true_dag.add_nodes_from(all_nodes)
    learned_dag_pc.add_nodes_from(all_nodes)
    learned_dag_ges.add_nodes_from(all_nodes)
Member

Why is this required? Both the true model and the learned DAG would have the same nodes/variables as the dataset, right?


    # Compute SHD using built-in method
    try:
        shd_pc = SHD(true_dag, learned_dag_pc)
Member

Try-except should only be used when we know in what situation the code will throw an error, and that is the expected behavior. Not required here.

Contributor

@Vanshitaaa20, as @ankurankan said, remove the try-except blocks and let the test cases fail, so that we can identify the potential causes and fix them.

        shd_ges = SHD(true_dag, learned_dag_ges)
    except Exception as e:
        print(" SHD computation failed:", e)
        print(" true_dag edges:", true_dag.edges())
        print(" learned_dag_pc edges:", learned_dag_pc.edges())
        print(" learned_dag_ges edges:", learned_dag_ges.edges())
        continue

    shd_pc_list.append(shd_pc)
    shd_ges_list.append(shd_ges)

    print(" SHD (PC):", shd_pc)
    print(" SHD (GES):", shd_ges)

Contributor

Write the results to a CSV file instead of printing them, and store it in the causalbench/results folder. Please make this change first. @Vanshitaaa20
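One possible shape for this change, sketched with the standard `csv` module. The function name `write_results_csv`, the column names, and the output file name are assumptions for illustration; only the `causalbench/results` folder comes from the comment above.

```python
import csv
from pathlib import Path


def write_results_csv(rows, path):
    """Write per-trial benchmark rows to a CSV file, creating parent folders.

    rows: iterable of dicts with keys "trial", "algorithm", and "shd"
    (hypothetical column names).
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["trial", "algorithm", "shd"])
        writer.writeheader()
        writer.writerows(rows)
    return path
```

Inside the trial loop, the script could append `{"trial": trial, "algorithm": "PC", "shd": shd_pc}` (and the GES equivalent) to a list, then call `write_results_csv(rows, "causalbench/results/shd_benchmark.csv")` once after the loop. The file name here is a placeholder.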

# Final Results
print(f"\nAverage SHD over {len(shd_pc_list)} successful trials:")
print(f" PC: {np.mean(shd_pc_list):.2f} ± {np.std(shd_pc_list):.2f}")
print(f" GES: {np.mean(shd_ges_list):.2f} ± {np.std(shd_ges_list):.2f}")
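A note on the `±` figures: `np.std` computes the population standard deviation by default (`ddof=0`), which slightly understates the spread for a small number of trials compared with the sample standard deviation. The sketch below shows the difference using the standard library; the SHD values are made up for illustration and are not benchmark results.

```python
import statistics

# Made-up SHD values for eight trials (illustrative only).
shd_values = [2, 4, 4, 4, 5, 5, 7, 9]

mean_shd = statistics.fmean(shd_values)
population_std = statistics.pstdev(shd_values)  # divides by n, like np.std default
sample_std = statistics.stdev(shd_values)       # divides by n - 1

print(f"mean: {mean_shd:.2f}")
print(f"population std: {population_std:.2f}")
print(f"sample std: {sample_std:.2f}")
```

With only 10 trials, reporting which convention is used (or passing `ddof=1` to `np.std`) makes the summary easier to compare across runs.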