This README covers a running example of training an AVICI model in causal discovery for a custom data-generating process. The components we provide here can be used as a starting point for a new project. As an illustrative example, we implement AVICI trained on SCMs with sinusoidal functions and random tree graphs.
This folder contains the following three files: func.py, domain.yaml, and train.py.
func.py contains all custom classes that make up the generative model
of our domain.
All functions should be written using standard numpy, not jax.numpy. This is both faster and avoids
resource conflicts between the CPU workers that continually update the training data buffers
and the hardware accelerators used by jax for the actual network training.
The provided func.py implements two example classes for sampling random trees and SCMs with sinusoidal functions,
respectively, which are not implemented already in avici.synthetic.
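To give a flavor of what a random-tree sampler might compute, here is a minimal numpy sketch. It is not taken from func.py: the helper name `sample_random_tree`, the sampling scheme (each node picks one parent among the nodes placed before it), and the edge convention `g[i, j] = 1` for `i -> j` are assumptions made for illustration. This scheme yields a valid directed tree but is not uniform over all trees.

```python
import numpy as np

def sample_random_tree(rng, n_vars):
    """Sample a random directed tree as a binary adjacency matrix.

    Hypothetical helper, not the func.py implementation: nodes are placed in a
    random order, and every node except the first picks one parent uniformly
    among the nodes placed before it.
    """
    g = np.zeros((n_vars, n_vars), dtype=int)
    order = rng.permutation(n_vars)
    for k in range(1, n_vars):
        parent = order[rng.integers(k)]  # uniform over the k earlier nodes
        g[parent, order[k]] = 1          # assumed convention: g[i, j] = 1 means i -> j
    return g

rng = np.random.default_rng(0)
g = sample_random_tree(rng, n_vars=6)
```

A tree over `n_vars` nodes has exactly `n_vars - 1` edges, and every node except the root has exactly one parent, which is easy to verify from the column sums of `g`.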
Each custom data-generating process must subclass one of the following two abstract base classes,
GraphModel or MechanismModel, and implement the __call__ method with the correct signature:
Subclasses of GraphModel implement functionality for sampling training graphs and
can be used to define (part of) the causal graph distribution p(G).
Each child class has to implement __call__ accepting two arguments:
- rng (np.random.Generator) – numpy pseudorandom number generator
- n_vars (int) – number of nodes in the graph
Returns:
- ndarray – binary adjacency matrix of shape [n_vars, n_vars]
Example:

    import numpy as onp
    from avici.synthetic import GraphModel

    class DummyGraph(GraphModel):
        def __call__(self, rng, n_vars):
            # empty graph: no edges between any of the n_vars nodes
            return onp.zeros((n_vars, n_vars))

Subclasses of MechanismModel implement functionality for sampling observational
and interventional data given a causal graph. These classes can be used to define (part of)
the data-generating distribution p(D | G).
Each child class has to implement __call__ accepting four arguments:
- rng (np.random.Generator) – numpy pseudorandom number generator
- g (ndarray) – binary adjacency matrix of shape [n_vars, n_vars] as generated by a GraphModel subclass
- n_observations_obs (int) – number of observational data points to be sampled
- n_observations_int (int) – number of interventional data points to be sampled
Returns:
- avici.synthetic.Data – namedtuple containing x_obs, x_int, and the boolean is_count_data.
  The data matrices x_obs and x_int must have shapes [n_observations_obs, n_vars, 2] and [n_observations_int, n_vars, 2], respectively. The first entry of the last axis (i.e., x_int[..., 0]) contains the values, and the second entry (i.e., x_int[..., 1]) contains either 0 or 1, indicating which nodes were intervened upon in which observations.
  Accordingly, x_obs[..., 1] contains only zeros, since x_obs always holds observational data.
  is_count_data determines how the data is standardized. (The default for all real-valued data should be False, which implies the usual z-standardization.)
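The array layout described above can be illustrated with a short numpy sketch. The sizes, the Gaussian values, and the choice of node 2 as the intervened node are placeholders chosen for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_int, n_vars = 4, 3, 5  # placeholder sizes

# observational data: values in channel 0, all-zero intervention mask in channel 1
values_obs = rng.normal(size=(n_obs, n_vars))
x_obs = np.stack([values_obs, np.zeros_like(values_obs)], axis=-1)

# interventional data: node 2 is (hypothetically) intervened upon in every row
values_int = rng.normal(size=(n_int, n_vars))
mask_int = np.zeros((n_int, n_vars))
mask_int[:, 2] = 1.0
x_int = np.stack([values_int, mask_int], axis=-1)
```

Stacking along the last axis keeps each value next to its intervention indicator, so `x_int[..., 1]` can be read directly as a per-observation, per-node mask.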
Example:

    import numpy as onp
    from avici.synthetic import MechanismModel, Data

    class DummyMechanism(MechanismModel):
        def __call__(self, rng, g, n_observations_obs, n_observations_int):
            n_vars = g.shape[-1]
            # all-zero values and all-zero intervention masks
            return Data(
                x_obs=onp.zeros((n_observations_obs, n_vars, 2)),
                x_int=onp.zeros((n_observations_int, n_vars, 2)),
                is_count_data=False,
            )

Both GraphModel and MechanismModel subclasses can be initialized with and
store an arbitrary number of arguments for later use inside __call__,
like function parameters or other sampling functions.
For MechanismModel, this is also where additional details on the interventions ought to be specified,
e.g., how many nodes are intervened upon and in what fashion.
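As a rough illustration of what the body of a MechanismModel's __call__ might do for sinusoidal mechanisms, the following numpy sketch samples data by ancestral sampling. It is not the func.py implementation: the function names, the parameterization (one weight and phase per edge), the noise scale, and the intervention scheme (one randomly chosen node per interventional sample, clamped to a fixed value) are all assumptions. Inside a real subclass, the two returned arrays would be wrapped in Data(x_obs=..., x_int=..., is_count_data=False) as in DummyMechanism.

```python
import numpy as np

def topological_order(g):
    """Return the nodes of the DAG g (g[i, j] = 1 means i -> j) in causal order."""
    order, remaining = [], list(range(g.shape[0]))
    while remaining:
        # nodes without incoming edges from any node that is still unprocessed
        roots = [j for j in remaining if g[remaining, j].sum() == 0]
        if not roots:
            raise ValueError("graph contains a cycle")
        order.extend(roots)
        remaining = [j for j in remaining if j not in roots]
    return order

def sample_sinusoidal(rng, g, n_obs, n_int, noise_scale=0.1, interv_value=2.0):
    """Hypothetical sampler: x_j = sum_i sin(w_ij * x_i + phase_ij) + noise."""
    n_vars = g.shape[-1]
    order = topological_order(g)
    w = rng.uniform(1.0, 3.0, size=(n_vars, n_vars))
    phase = rng.uniform(0.0, 2 * np.pi, size=(n_vars, n_vars))

    def forward(n, mask):
        x = np.zeros((n, n_vars))
        for j in order:
            parents = np.flatnonzero(g[:, j])
            mean = np.sin(x[:, parents] * w[parents, j] + phase[parents, j]).sum(axis=1)
            x[:, j] = mean + noise_scale * rng.normal(size=n)
            # clamp intervened entries before any child of j is computed
            x[:, j] = np.where(mask[:, j] == 1, interv_value, x[:, j])
        return np.stack([x, mask], axis=-1)

    mask_obs = np.zeros((n_obs, n_vars))
    # assumed scheme: one uniformly chosen intervened node per interventional sample
    mask_int = np.zeros((n_int, n_vars))
    mask_int[np.arange(n_int), rng.integers(n_vars, size=n_int)] = 1
    return forward(n_obs, mask_obs), forward(n_int, mask_int)

rng = np.random.default_rng(0)
g = np.array([[0, 1, 1], [0, 0, 0], [0, 0, 0]])  # tree: 0 -> 1, 0 -> 2
x_obs, x_int = sample_sinusoidal(rng, g, n_obs=300, n_int=100)
```

Processing nodes in causal order ensures that each node's parents are already sampled (or clamped by an intervention) when its own value is computed.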
domain.yaml is the configuration file that defines the distribution over datasets on which our structure learning model is trained. The file has to be structured in the following way:

    ---
    train_n_vars: [5, 10]
    test_n_vars: [20]
    test_n_datasets: 10
    additional_modules:
      - "./func.py"
    data:
      - n_observations_obs: 300
        n_observations_int: 100
        graph:
          - __class__: ErdosRenyi
            edges_per_var: [ 1.0, 2.0, 3.0 ]
        mechanism:
          - __class__: LinearAdditive
            param:
              - __class__: SignedUniform
                low: 1.0
                high: 3.0
            bias: ...
            noise: ...
            noise_scale: ...
            n_interv_vars: ...
            interv_dist: ...
      - ...

The top-level keywords specify the following:
- train_n_vars – list of integers specifying the numbers of variables in the causal graphs and datasets during training
- test_n_vars – list of integers specifying the numbers of variables used for validation
- test_n_datasets – number of validation datasets
- additional_modules – list of paths (relative or absolute) defining additional data-generating processes (e.g., our func.py file)
- data – nested combination of dicts and lists specifying the full data-generating distribution
The data entry specifies the distribution over training datasets. During training, we continually generate fresh data for data buffers of the different numbers of variables according to this distribution. The configuration of the data field maintains the following invariants:
- If any (nested) part of the data tree is a list, one configuration of it is selected uniformly at random in each new sample. For example, in the above configuration, all graphs are Erdos-Renyi, with the expected number of edges per node being either 1, 2, or 3, selected randomly for each new dataset. Internally, the nested dict of lists is expanded into a single list of all possible combinations of dicts, so be careful not to specify too many combinations (>1000).
- Each (list) element in the top level of data needs to specify graph, mechanism, n_observations_obs, and n_observations_int (satisfying the avici.synthetic.SyntheticSpec signature). The integers n_observations_obs and n_observations_int specify the number of data points generated for each dataset. At training time, these observations are subbatched further depending on the optimization parameters.
- Each (list) element in the top level of data.graph needs to define a GraphModel subclass, and each (list) element of data.mechanism a MechanismModel subclass. The class name is specified via the __class__ key. All other arguments the class expects at initialization time (via __init__) are specified alongside. Please refer to the signature of avici.synthetic.LinearAdditive to verify this in the above example.
  The class arguments may be (lists of) classes themselves, defined recursively in the same way. For example, avici.synthetic.Distribution subclasses specify how the weights and noise of the linear function SCM LinearAdditive are sampled. Likewise, avici.synthetic.NoiseModel subclasses specify the noise scale in the SCM.
- All classes not available inside avici.synthetic need to be defined in other files and specified via their path in the additional_modules field. When specified this way, they can be used in the configuration exactly like all other members already provided in avici.synthetic.
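The expansion of a nested dict of lists into all combinations can be sketched in a few lines of Python. This is an illustration of the idea only, not the code avici uses internally: the function name `expand` is hypothetical, and the sketch treats every list as a set of alternatives and every dict as a product over its keys.

```python
import itertools
import numpy as np

def expand(node):
    """Return every concrete configuration encoded by a nested dict/list tree."""
    if isinstance(node, list):
        # a list is a choice: collect the expansions of each alternative
        return [cfg for child in node for cfg in expand(child)]
    if isinstance(node, dict):
        # a dict is a product: combine one expansion per key
        keys = list(node)
        options = [expand(node[k]) for k in keys]
        return [dict(zip(keys, combo)) for combo in itertools.product(*options)]
    return [node]  # leaf value

# the data field of the example configuration, with the elided arguments left out
data = [
    {
        "n_observations_obs": 300,
        "n_observations_int": 100,
        "graph": [{"__class__": "ErdosRenyi", "edges_per_var": [1.0, 2.0, 3.0]}],
        "mechanism": [
            {
                "__class__": "LinearAdditive",
                "param": [{"__class__": "SignedUniform", "low": 1.0, "high": 3.0}],
            }
        ],
    }
]

specs = expand(data)  # one entry per combination
rng = np.random.default_rng(0)
spec = specs[rng.integers(len(specs))]  # uniform choice for a new dataset
```

Here the only list with more than one alternative is edges_per_var, so the expansion contains three configurations. Because the number of combinations multiplies across every list in the tree, a handful of longer lists is enough to exceed the recommended limit.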
The easiest way of understanding how domain.yaml is configured is to look at a few examples.
The following configurations define the training distributions of the models trained
in Lorch et al. (2022), whose checkpoints are
available for download via avici.load_pretrained:
linear.yaml, rff.yaml, gene.yaml.
These config files directly correspond to the Tables given in Appendix A of the paper.
Currently, we provide the following data-generating processes in avici.synthetic:
- GraphModel subclasses: ErdosRenyi, ScaleFree, ScaleFreeTranspose, WattsStrogatz, SBM, GRG, Yeast, Ecoli
- MechanismModel subclasses: LinearAdditive, RFFAdditive, GRNSergio
- Distribution subclasses: Gaussian, Laplace, Cauchy, Uniform, SignedUniform, RandInt, Beta
- NoiseModel subclasses: SimpleNoise, HeteroscedasticRFFNoise
These classes can be used in a domain.yaml configuration out-of-the-box and without further specifications.
train.py is the main training script.
Our provided script automatically performs multi-device training. Hence, if you run this
script on a machine with multiple GPUs, all accelerators will be used directly via jax.pmap
and corresponding functions.
Given our above domain configuration,
we can train a first (small) model to check the script
by changing directory to example-custom/ and running

    python train.py --config "./domain.yaml"

where --config specifies an (absolute or relative) path to our YAML domain configuration.
The above call uses --smoke_test true by default, which sets the network and training
parameters to small dummy values for testing.
For further information about the other command line arguments,
run python train.py --help.
To train a large model with the same hyperparameters as Lorch et al. (2022), run

    python train.py --config "./domain.yaml" --smoke_test false

Each different n_vars of the training data distribution requires
a separate jax.jit compilation. Therefore, it is normal that the first
few steps of training take relatively long.
After each n_vars has been seen once, update steps are fast (~0.5 sec/step for the
full model on Quadro RTX 6000 GPUs).
The script automatically generates checkpoints, which can be used
both for continuing training and for downstream predictions.
By default, the checkpoints are stored in ./checkpoints/.
To re-start training with a checkpoint,
simply re-run train.py with the same checkpoint directory
and the code will automatically detect the most recent checkpoint.
Analogous to the pretrained checkpoints we provide for automatic download,
the checkpoints created during training with this script can be
loaded using the avici.load_pretrained function:

    import avici
    model = avici.load_pretrained(checkpoint_dir="path/to/checkpoint", expects_counts=False)