Draft causal discovery GES implementation #1517
base: main
    @model_validator(mode="before")
    def set_defaults(cls, values):
        """
        Set default values for optional parameters.

        Parameters
        ----------
        values : dict
            Dictionary of input values

        Returns
        -------
        dict
            Updated values with defaults set
        """
        # Ensure that no_descendants and no_ascendents are lists.
        if values.get("no_descendants") is None:
            values["no_descendants"] = []
        if values.get("no_ascendents") is None:
            values["no_ascendents"] = []
        return values

    def __init__(self, **data):
        """
        Initialize the causal discovery model.

        Parameters
        ----------
        **data : dict
            Keyword arguments matching class attributes
        """
        super().__init__(**data)
        # Process the data and initialize internal state.
        self.data = self.data.drop(columns=["date"], errors="ignore")
        self.nodes = self.data.columns.tolist()
        self.graph = {node: [] for node in self.nodes}
        self.no_descendants = set(self.no_descendants)
        self.no_ascendents = set(self.no_ascendents)
        self.local_score_cache = {}
        self.inference_cache = {}
No need for any of this. Just use default_factory
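A minimal sketch of what this suggestion might look like, assuming Pydantic v2. The class name `DiscoveryConfig` is hypothetical and stands in for the PR's base class; only the two list fields from the snippet above are shown.

```python
from pydantic import BaseModel, Field


class DiscoveryConfig(BaseModel):
    """Hypothetical stand-in for the PR's base class (illustration only)."""

    # default_factory builds a fresh empty list for every instance,
    # so no "before" validator is needed to fill in missing values.
    no_descendants: list[str] = Field(default_factory=list)
    no_ascendents: list[str] = Field(default_factory=list)


first = DiscoveryConfig()
second = DiscoveryConfig(no_descendants=["sales"])

# Each instance gets its own list; mutating one does not affect the other.
first.no_ascendents.append("price")
third = DiscoveryConfig()
```

One behavioral difference to keep in mind: the validator in the diff also turns an explicit `None` into `[]`, whereas with the annotation above an explicit `None` would fail validation unless the field type is widened.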
        dot.render(output_filename, format=output_format, view=view)
        return dot

    def visualize(
How about to_graphviz, to be aligned with the ModelBuilder method name?
if parents:
    intercept = pm.Normal(f"intercept_{node}", mu=0, sigma=1)
    betas = pm.Normal(f"beta_{node}", mu=0, sigma=1, shape=len(parents))
    parent_data = self.data[parents].values
    mu_node = intercept + pt.dot(parent_data, betas)
else:
    intercept = pm.Normal(f"intercept_{node}", mu=0, sigma=1)
    mu_node = intercept

sigma = pm.HalfNormal(f"sigma_{node}", sigma=1)
pm.Normal(
    f"obs_{node}", mu=mu_node, sigma=sigma, observed=self.data[node].values
)
Would the goal be to vectorize this across the set of DAG combinations?
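One possible reading of this question, sketched in plain NumPy rather than PyMC: encode each candidate DAG as a binary parent mask and compute every node's linear mean for every graph in one contraction. The names `masks`, `beta`, and the shapes used here are assumptions for illustration and do not come from the PR.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_nodes, n_graphs = 50, 3, 4

X = rng.normal(size=(n_obs, n_nodes))        # observed data, one column per node
intercept = rng.normal(size=n_nodes)         # one intercept per node
beta = rng.normal(size=(n_nodes, n_nodes))   # beta[i, j]: weight of node i in node j's mean

# masks[g, i, j] == 1 means "node i is a parent of node j" in candidate graph g.
masks = np.zeros((n_graphs, n_nodes, n_nodes))
masks[1, 0, 1] = 1                           # graph 1: 0 -> 1
masks[2, 0, 1] = 1                           # graph 2: 0 -> 1 -> 2
masks[2, 1, 2] = 1
masks[3, 0, 2] = 1                           # graph 3: 0 -> 2 <- 1
masks[3, 1, 2] = 1

# mu[g, :, j] = intercept[j] + sum_i X[:, i] * masks[g, i, j] * beta[i, j]
# Masked-out coefficients contribute nothing, so a node without parents
# falls back to its intercept, matching the else-branch in the diff.
mu = intercept + np.einsum("ni,gij->gnj", X, masks * beta)
```

The per-node `if parents:` branch disappears because the mask handles the no-parent case. Whether the same masking idea is practical inside the PR's PyMC model is an open design question, not something the diff settles.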
Description
BayesianCausalDiscoveryBase: leverages Pydantic for parameter validation and encapsulates common functionality for Bayesian causal discovery.
BayesianGreedySearch: performs a greedy search with modular forward and backward phases.
📚 Documentation preview 📚: https://pymc-marketing--1517.org.readthedocs.build/en/1517/
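The forward and backward phases named in the description can be illustrated with a toy hill-climb over DAGs. This is a sketch under stated assumptions, not the PR's implementation: the function names, the parent-set graph representation, and the toy score are all hypothetical, and a full GES searches over equivalence classes of DAGs rather than over individual DAGs as done here.

```python
from itertools import permutations


def creates_cycle(parents, child, new_parent):
    """Return True if adding new_parent -> child would create a directed cycle."""
    # A cycle appears exactly when child is already an ancestor of new_parent.
    stack, seen = [new_parent], set()
    while stack:
        node = stack.pop()
        if node == child:
            return True
        if node not in seen:
            seen.add(node)
            stack.extend(parents[node])
    return False


def greedy_search(nodes, score):
    """Toy two-phase greedy search; `score` maps a parent dict to a number."""
    parents = {n: set() for n in nodes}
    best = score(parents)

    # Forward phase: keep adding the single edge that raises the score most.
    while True:
        candidates = []
        for u, v in permutations(nodes, 2):
            if u in parents[v] or creates_cycle(parents, v, u):
                continue
            parents[v].add(u)
            candidates.append((score(parents), u, v))
            parents[v].remove(u)
        if not candidates or max(candidates)[0] <= best:
            break
        best, u, v = max(candidates)
        parents[v].add(u)

    # Backward phase: keep removing the single edge whose deletion raises the score most.
    while True:
        candidates = []
        for v in nodes:
            for u in list(parents[v]):
                parents[v].remove(u)
                candidates.append((score(parents), u, v))
                parents[v].add(u)
        if not candidates or max(candidates)[0] <= best:
            break
        best, u, v = max(candidates)
        parents[v].remove(u)

    return parents, best


# Toy score (assumption): +1 for each edge in a known target graph, -1 otherwise.
TARGET_EDGES = {("a", "b"), ("b", "c")}


def toy_score(parents):
    edges = {(u, v) for v, ps in parents.items() for u in ps}
    return sum(1 if e in TARGET_EDGES else -1 for e in edges)


found_parents, found_score = greedy_search(["a", "b", "c"], toy_score)
```

Keeping the two phases as separate loops over a shared `score` callable mirrors the "modular" wording in the description: either phase could be swapped or constrained (for example by `no_descendants`) without touching the other.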