Draft causal discovery GES implementation #1517

Draft
wants to merge 1 commit into
base: main

cetagostini
Contributor

@cetagostini cetagostini commented Feb 22, 2025

Description

  • Introduces a new base class BayesianCausalDiscoveryBase that leverages Pydantic for parameter validation and encapsulates common functionality for Bayesian causal discovery.
  • Implements key graph utilities including DAG visualization, acyclicity checks, and conversion to CPDAG.
  • Provides a subclass BayesianGreedySearch that performs a greedy search with modular forward and backward phases.
  • Supports two scoring metrics: the default Penalized Log-Likelihood Score (PLLS) and a BIC-based score.
  • Improves efficiency by caching local scores and inference data, and accepts ground-truth constraints (no_descendants and no_ascendents: nodes known to have no descendants or no ancestors) to reduce the search space.
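
To make the acyclicity check mentioned in the description concrete, here is a minimal sketch using Kahn's algorithm. It assumes the graph is stored as a dict mapping each node to a list of its parents; that layout, the function names `is_acyclic` and `can_add_edge`, and the node names are illustrative assumptions, not taken from the PR.

```python
from collections import deque


def is_acyclic(parents: dict[str, list[str]]) -> bool:
    """Return True if the graph has no directed cycle (Kahn's algorithm).

    `parents` maps each node to the list of its parents. This layout is an
    assumption for illustration, not necessarily what the PR uses.
    """
    # The in-degree of a node is its number of parents.
    in_degree = {node: len(ps) for node, ps in parents.items()}
    # Invert the mapping so we can walk from a parent to its children.
    children: dict[str, list[str]] = {node: [] for node in parents}
    for node, ps in parents.items():
        for p in ps:
            children[p].append(node)

    queue = deque(node for node, deg in in_degree.items() if deg == 0)
    visited = 0
    while queue:
        node = queue.popleft()
        visited += 1
        for child in children[node]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)
    # Every node gets removed exactly when the graph is a DAG.
    return visited == len(parents)


def can_add_edge(parents: dict[str, list[str]], source: str, target: str) -> bool:
    """Check whether adding source -> target keeps the graph acyclic."""
    if source == target or source in parents[target]:
        return False
    trial = {node: list(ps) for node, ps in parents.items()}
    trial[target].append(source)
    return is_acyclic(trial)


# Hypothetical three-node chain: tv -> search -> sales.
graph = {"tv": [], "search": ["tv"], "sales": ["search"]}
forward_ok = can_add_edge(graph, "tv", "sales")   # shortcut edge, still a DAG
backward_ok = can_add_edge(graph, "sales", "tv")  # would close a cycle
```

A forward phase would typically run a check like `can_add_edge` on each candidate edge before scoring it, so that cyclic candidates are never fitted.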

Related Issue

  • Closes #
  • Related to #

Checklist


📚 Documentation preview 📚: https://pymc-marketing--1517.org.readthedocs.build/en/1517/

@github-actions github-actions bot added the MMM label Feb 22, 2025
Comment on lines +215 to +254
    @model_validator(mode="before")
    def set_defaults(cls, values):
        """
        Set default values for optional parameters.

        Parameters
        ----------
        values : dict
            Dictionary of input values

        Returns
        -------
        dict
            Updated values with defaults set
        """
        # Ensure that no_descendants and no_ascendents are lists.
        if values.get("no_descendants") is None:
            values["no_descendants"] = []
        if values.get("no_ascendents") is None:
            values["no_ascendents"] = []
        return values

    def __init__(self, **data):
        """
        Initialize the causal discovery model.

        Parameters
        ----------
        **data : dict
            Keyword arguments matching class attributes
        """
        super().__init__(**data)
        # Process the data and initialize internal state.
        self.data = self.data.drop(columns=["date"], errors="ignore")
        self.nodes = self.data.columns.tolist()
        self.graph = {node: [] for node in self.nodes}
        self.no_descendants = set(self.no_descendants)
        self.no_ascendents = set(self.no_ascendents)
        self.local_score_cache = {}
        self.inference_cache = {}
Contributor

No need for any of this. Just use default_factory
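
A minimal sketch of what this suggestion could look like, assuming the constraint and cache attributes are declared as Pydantic fields. The class name `DiscoveryConfig` and the field annotations are hypothetical; only the attribute names come from the hunk above.

```python
from pydantic import BaseModel, Field


class DiscoveryConfig(BaseModel):
    """Hypothetical stand-in for the constraint and cache fields in the PR."""

    # default_factory builds a fresh container per instance, so neither the
    # "before" validator nor the manual assignments in __init__ are needed
    # for these attributes.
    no_descendants: set[str] = Field(default_factory=set)
    no_ascendents: set[str] = Field(default_factory=set)
    local_score_cache: dict = Field(default_factory=dict)
    inference_cache: dict = Field(default_factory=dict)


default = DiscoveryConfig()
# A list passed by the caller is validated into a set.
constrained = DiscoveryConfig(no_descendants=["sales", "sales"])
```

One behavioral difference to keep in mind: the original validator turns an explicit `None` into an empty list, whereas with these annotations an explicit `None` would fail validation unless the field type is made optional. Derived state such as `nodes` and `graph`, which depends on `data`, is outside what `default_factory` alone covers.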

        dot.render(output_filename, format=output_format, view=view)
        return dot

    def visualize(
Contributor

How about naming this to_graphviz, to align with the ModelBuilder method name?

Comment on lines +408 to +420
if parents:
    intercept = pm.Normal(f"intercept_{node}", mu=0, sigma=1)
    betas = pm.Normal(f"beta_{node}", mu=0, sigma=1, shape=len(parents))
    parent_data = self.data[parents].values
    mu_node = intercept + pt.dot(parent_data, betas)
else:
    intercept = pm.Normal(f"intercept_{node}", mu=0, sigma=1)
    mu_node = intercept

sigma = pm.HalfNormal(f"sigma_{node}", sigma=1)
pm.Normal(
    f"obs_{node}", mu=mu_node, sigma=sigma, observed=self.data[node].values
)
Contributor

Would the goal be to vectorize this across a set of DAG combinations?
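
For context on the hunk above: it fits one linear-Gaussian regression per node given its parents, so the score decomposes into per-node terms that can be cached by (node, parent set). The sketch below is a rough, non-Bayesian analogue using ordinary least squares and a BIC with additive constants dropped. It is not the PR's scoring code; `bic_local_score`, `CachedScorer`, and the column names are hypothetical.

```python
import numpy as np


def bic_local_score(data: dict, node: str, parents: tuple) -> float:
    """BIC of an OLS fit of `node` on `parents` (lower is better).

    Constants that do not depend on the parent set are dropped.
    """
    y = data[node]
    n = y.shape[0]
    # Design matrix: an intercept column plus one column per parent.
    X = np.column_stack([np.ones(n)] + [data[p] for p in parents])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    rss = float(np.sum((y - X @ coef) ** 2))
    k = X.shape[1] + 1  # regression coefficients plus the noise variance
    return float(n * np.log(rss / n) + k * np.log(n))


class CachedScorer:
    """Memoize local scores by (node, parent set); an illustrative assumption."""

    def __init__(self, data: dict):
        self.data = data
        self.cache: dict = {}
        self.n_fits = 0

    def score(self, node: str, parents) -> float:
        # frozenset makes the key independent of parent ordering.
        key = (node, frozenset(parents))
        if key not in self.cache:
            self.n_fits += 1
            self.cache[key] = bic_local_score(
                self.data, node, tuple(sorted(parents))
            )
        return self.cache[key]


# Synthetic data with hypothetical column names: sales depends linearly on tv.
rng = np.random.default_rng(0)
tv = rng.normal(size=500)
data = {"tv": tv, "sales": 2.0 * tv + rng.normal(size=500)}

scorer = CachedScorer(data)
with_parent = scorer.score("sales", ["tv"])
without_parent = scorer.score("sales", [])
scorer.score("sales", ["tv"])  # served from the cache, no refit
```

Because each candidate edge changes the parent set of only one node, a greedy step only needs the local score of that node, and the cache avoids refitting parent sets that were already evaluated.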

2 participants