diff --git a/docs/_include/models-table.py b/docs/_include/models-table.py index 9df816d6dcd..863b87ed73c 100644 --- a/docs/_include/models-table.py +++ b/docs/_include/models-table.py @@ -19,10 +19,10 @@ """Prints an RST table of available models from the inference.models module. """ -from pycbc.inference.models import models +from pycbc.inference.models import get_models from _dict_to_rst import (rst_dict_table, format_class) -tbl = rst_dict_table(models, key_format='``\'{0}\'``'.format, +tbl = rst_dict_table(get_models(), key_format='``\'{0}\'``'.format, header=('Name', 'Class'), val_format=format_class) diff --git a/pycbc/inference/__init__.py b/pycbc/inference/__init__.py index a52dc2e94f7..55c3d58ec84 100644 --- a/pycbc/inference/__init__.py +++ b/pycbc/inference/__init__.py @@ -1,4 +1,3 @@ - # pylint: disable=unused-import from . import (models, sampler, io) from . import (burn_in, entropy, gelman_rubin, geweke, option_utils) diff --git a/pycbc/inference/models/__init__.py b/pycbc/inference/models/__init__.py index 0f47348fa95..b12f3879c16 100644 --- a/pycbc/inference/models/__init__.py +++ b/pycbc/inference/models/__init__.py @@ -20,6 +20,10 @@ """ +import logging +from pkg_resources import iter_entry_points as _iter_entry_points +from .base import BaseModel +from .base_data import BaseDataModel from .analytic import (TestEggbox, TestNormal, TestRosenbrock, TestVolcano, TestPrior, TestPosterior) from .gaussian_noise import GaussianNoise @@ -176,10 +180,10 @@ def read_from_config(cp, **kwargs): """ # use the name to get the distribution name = cp.get("model", "name") - return models[name].from_config(cp, **kwargs) + return get_model(name).from_config(cp, **kwargs) -models = {_cls.name: _cls for _cls in ( +_models = {_cls.name: _cls for _cls in ( TestEggbox, TestNormal, TestRosenbrock, @@ -197,3 +201,146 @@ def read_from_config(cp, **kwargs): Relative, HierarchicalModel, )} + + +class _ModelManager(dict): + """Sub-classes dictionary to manage the collection of available models. + + The first time this is called, any plugin models that are available will be + added to the dictionary before returning. + """ + def __init__(self, *args, **kwargs): + self.retrieve_plugins = True + super().__init__(*args, **kwargs) + + def add_model(self, model): + """Adds a model to the dictionary. + + If the given model has the same name as a model already in the + dictionary, the original model will be overridden. A warning will be + printed in that case. + """ + if super().__contains__(model.name): + logging.warning("Custom model %s will override a model of the " + "same name. If you don't want this, change the " + "model's name attribute and restart.", model.name) + self[model.name] = model + + def add_plugins(self): + """Adds any plugin models that are available. + + This will only add the plugins if ``self.retrieve_plugins = True``. + After this runs, ``self.retrieve_plugins`` is set to ``False``, so that + subsequent calls to this will no re-add models. + """ + if self.retrieve_plugins: + for plugin in _iter_entry_points('pycbc.inference.models'): + self.add_model(plugin.resolve()) + self.retrieve_plugins = False + + def __len__(self): + self.add_plugins() + super().__len__() + + def __contains__(self, key): + self.add_plugins() + return super().__contains__(key) + + def get(self, *args): + self.add_plugins() + return super().get(*args) + + def popitem(self): + self.add_plugins() + return super().popitem() + + def pop(self, *args): + try: + return super().pop(*args) + except KeyError: + self.add_plugins() + return super().pop(*args) + + def keys(self): + self.add_plugins() + return super().keys() + + def values(self): + self.add_plugins() + return super().values() + + def items(self): + self.add_plugins() + return super().items() + + def __iter__(self): + self.add_plugins() + return super().__iter__() + + def __repr__(self): + self.add_plugins() + return super().__repr__() + + def __getitem__(self, item): + try: + return super().__getitem__(item) + except KeyError: + self.add_plugins() + return super().__getitem__(item) + + def __delitem__(self, *args, **kwargs): + try: + super().__delitem__(*args, **kwargs) + except KeyError: + self.add_plugins() + super().__delitem__(*args, **kwargs) + + +models = _ModelManager(_models) + + +def get_models(): + """Returns the dictionary of current models. + + Ensures that plugins are added to the dictionary first. + """ + models.add_plugins() + return models + + +def get_model(model_name): + """Retrieve the given model. + + Parameters + ---------- + model_name : str + The name of the model to get. + + Returns + ------- + model : + The requested model. + """ + return get_models()[model_name] + + +def available_models(): + """List the currently available models.""" + return list(get_models().keys()) + + +def register_model(model): + """Makes a custom model available to PyCBC. + + The provided model will be added to the dictionary of models that PyCBC + knows about, using the model's ``name`` attribute. If the ``name`` is the + same as a model that already exists in PyCBC, a warning will be printed. + + Parameters + ---------- + model : pycbc.inference.models.base.BaseModel + The model to use. The model should be a sub-class of + :py:class:`BaseModel ` to ensure + it has the correct API for use within ``pycbc_inference``. + """ + get_models().add_model(model)