A collection of different projects and ideas that use differentiable programming to explore self-organising systems with emergent pattern formation, primarily for my PhD research.
The general idea is to build auto-differentiable complex systems that can be efficiently trained to reproduce spatio-temporal patterns specified by data.
To run this code, install the dependencies with either of:
- `pip install -r requirements_<DEVICE>.txt`
- `conda env create -f env_<DEVICE>.yml`
where `<DEVICE>` is either `cpu` or `gpu`.
There are 3 main branches of work here:
- NCA (neural cellular automata)
- PDE (partial differential equations)
- ABM (agent based modelling)
There is considerable overlap between these, so they all inherit from base classes in `Common`. Everything that builds into a model is a subclass of `equinox.Module`, so that gradients propagate correctly through the whole model.
`Common.model.abstract_model.py` contains the `AbstractModel` class, a subclass of `equinox.Module`.
- This contains a few extra utility methods, such as saving to and loading from files.
`Common.model.spatial_operators.py` contains the `Ops` class, a subclass of `equinox.Module`.
- This performs finite-difference approximations of 2D vector calculus operations (divergence, gradient, Laplacian, etc.).
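As an illustration of the kind of finite-difference operators described above, here is a minimal sketch using second-order central differences on a grid with spacing `dx`. It assumes periodic boundaries (implemented with `jnp.roll`), which the repository's `Ops` class may not use; the function names are hypothetical.

```python
import jax.numpy as jnp


def laplacian(u, dx=1.0):
    # 5-point stencil on an (H, W) field, periodic boundaries via jnp.roll
    return (jnp.roll(u, 1, axis=0) + jnp.roll(u, -1, axis=0)
            + jnp.roll(u, 1, axis=1) + jnp.roll(u, -1, axis=1)
            - 4.0 * u) / dx**2


def gradient(u, dx=1.0):
    # Central differences; returns a (2, H, W) vector field (d/dx, d/dy)
    du_dx = (jnp.roll(u, -1, axis=0) - jnp.roll(u, 1, axis=0)) / (2.0 * dx)
    du_dy = (jnp.roll(u, -1, axis=1) - jnp.roll(u, 1, axis=1)) / (2.0 * dx)
    return jnp.stack([du_dx, du_dy])


def divergence(v, dx=1.0):
    # v has shape (2, H, W); sum of central differences of each component
    dvx_dx = (jnp.roll(v[0], -1, axis=0) - jnp.roll(v[0], 1, axis=0)) / (2.0 * dx)
    dvy_dy = (jnp.roll(v[1], -1, axis=1) - jnp.roll(v[1], 1, axis=1)) / (2.0 * dx)
    return dvx_dx + dvy_dy


# Quadratic test field: away from the wrapped edges the stencils are exact
x, y = jnp.meshgrid(jnp.arange(8.0), jnp.arange(8.0), indexing="ij")
u = x**2 + y**2
lap = laplacian(u)    # equals 4 in the interior
grad = gradient(u)    # equals (2x, 2y) in the interior
```

All three functions are pure `jax.numpy` expressions, so they can be differentiated through and composed inside a model's update step.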
`Common.model.boundary.py` contains the `model_boundary` class.
- This enforces complex boundary conditions on model states during training and evaluation.
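A common way to impose a boundary condition on an arbitrarily shaped region is to overwrite the masked cells after each update. The sketch below shows a Dirichlet-style condition with a boolean mask; `apply_dirichlet` is a hypothetical helper and is not claimed to match how `model_boundary` works internally.

```python
import jax.numpy as jnp


def apply_dirichlet(state, mask, boundary_value):
    """Hypothetical helper: fix masked cells of every channel to a constant.

    state: (C, H, W) model state
    mask: (H, W) bool array, True where the boundary value is imposed
    """
    # Broadcast the spatial mask over the channel axis
    return jnp.where(mask[None, :, :], boundary_value, state)


state = jnp.ones((2, 5, 5))
# Example mask: impose the condition along the top row only
mask = jnp.zeros((5, 5), dtype=bool).at[0, :].set(True)
out = apply_dirichlet(state, mask, 0.0)
```

Since `jnp.where` is differentiable with respect to the unmasked entries, applying it after every step keeps the constraint exact while gradients still flow through the interior of the domain.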
`Common.trainer.abstract_data_augmenter_tree.py` contains a `DataAugmenterAbstract` class.
- Subclasses of this handle any data augmentation during model training.
- Data should be a PyTree (list) of arrays, which allows simultaneous training on patterns of different sizes.
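To see why a PyTree of arrays helps with differently sized patterns, consider the sketch below: each leaf keeps its own shape, and an augmentation is applied leaf by leaf with its own random key. The random horizontal flip is only an illustrative augmentation, and `random_flip` / `augment_tree` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def random_flip(key, x):
    # Flip the whole batch along the last (width) axis with probability 0.5
    flip = jax.random.bernoulli(key)
    return jnp.where(flip, x[..., ::-1], x)


def augment_tree(key, data):
    # Give each leaf (pattern) its own key; leaves may have different shapes
    leaves, treedef = jax.tree_util.tree_flatten(data)
    keys = jax.random.split(key, len(leaves))
    return jax.tree_util.tree_unflatten(
        treedef, [random_flip(k, x) for k, x in zip(keys, leaves)]
    )


# Two (batch, channels, H, W) patterns of different spatial sizes in one list
data = [
    jnp.arange(2 * 1 * 4 * 4, dtype=jnp.float32).reshape(2, 1, 4, 4),
    jnp.arange(2 * 1 * 6 * 6, dtype=jnp.float32).reshape(2, 1, 6, 6),
]
augmented = augment_tree(jax.random.PRNGKey(0), data)
```

A single stacked array could not hold the 4×4 and 6×6 patterns together, whereas the list keeps them side by side without padding.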
`Common.trainer.abstract_data_augmenter_array.py` contains a `DataAugmenterAbstract` class.
- Same as above, but only accepts data as a single array. Does multi-GPU data parallelism through JAX sharding.
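Data parallelism through JAX sharding can be sketched by splitting the batch axis of a single array across all available devices. The helper `shard_batch` and the axis name `"batch"` are assumptions for illustration; the repository may configure its mesh differently.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_batch(batch):
    # Hypothetical helper: split axis 0 of `batch` evenly across all devices
    devices = np.array(jax.devices())
    if batch.shape[0] % len(devices) != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    mesh = Mesh(devices, axis_names=("batch",))
    # Shard axis 0 over the "batch" mesh axis; remaining axes are replicated
    sharding = NamedSharding(mesh, PartitionSpec("batch"))
    return jax.device_put(batch, sharding)


n_dev = len(jax.devices())
batch = jnp.arange(n_dev * 4 * 2 * 8 * 8, dtype=jnp.float32).reshape(n_dev * 4, 2, 8, 8)
sharded = shard_batch(batch)
```

With a single CPU device this is a no-op placement, but the same code splits the batch across GPUs when several are visible, and jit-compiled training steps applied to the sharded array will generally follow its sharding.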
`Common.trainer.abstract_tensorboard_log.py` contains a `Train_log` class.
- Subclasses of this save various aspects of model parameters/state during and after training.
`Common.trainer.loss.py` contains various custom loss functions.
`Common.utils.py` contains various helper functions for loading and processing specific datasets.
For the NCA branch, everything is subclassed from `Common`. Important details are that:
- Everything in `NCA.model` is an `AbstractModel` subclass that also uses the `Ops` class.
`NCA.trainer.NCA_trainer.py` includes the `NCA_Trainer` class.
- This uses Optax to fit the NCA models to data.
`NCA.trainer.data_augmenter_*` include various `DataAugmenter` subclasses, each for training on a different task.
`NCA.trainer.tensorboard_log.py` subclasses `Train_log` to visualise model parameters during training.
`NCA.NCA_visualiser.py` produces plots summarising the model parameters of an `NCA` model.
For the PDE branch, everything is likewise subclassed from `Common`. Important details are that:
`PDE.model.solver.semidiscrete_solver.py` contains the `PDE_solver` class, a subclass of `AbstractModel`.
- This uses `diffrax` to perform a fully auto-differentiable numerical ODE solve on a spatially discretised PDE (i.e. a system of ODEs).
- It needs to be initialised with the RHS of the PDE: an `equinox.Module` with call signature `F: t, X, args -> X`.
`PDE.model.reaction_diffusion_advection.update.py` includes an auto-differentiable multi-species reaction-diffusion-advection equation, parameterised by neural networks.
`PDE.model.reaction_diffusion_chemotaxis.update.py` includes an auto-differentiable multi-cell, multi-signal reaction-diffusion-chemotaxis equation, parameterised by neural networks.
`PDE.model.fixed_models.update_*` contains four example PDEs that exhibit pattern formation.
`PDE.trainer.PDE_trainer.py` includes the `PDE_Trainer` class, which uses Optax to fit PDE parameters such that the solutions of the PDE approximate a given time series.
`PDE/trainer/optimiser.py` includes a custom `optax.GradientTransformation()` that keeps diffusion coefficients non-negative.