Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/_include/models-table.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
"""Prints an RST table of available models from the inference.models
module.
"""
from pycbc.inference.models import models
from pycbc.inference.models import get_models
from _dict_to_rst import (rst_dict_table, format_class)

tbl = rst_dict_table(models, key_format='``\'{0}\'``'.format,
tbl = rst_dict_table(get_models(), key_format='``\'{0}\'``'.format,
header=('Name', 'Class'),
val_format=format_class)

Expand Down
1 change: 0 additions & 1 deletion pycbc/inference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# pylint: disable=unused-import
from . import (models, sampler, io)
from . import (burn_in, entropy, gelman_rubin, geweke, option_utils)
151 changes: 149 additions & 2 deletions pycbc/inference/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
"""


import logging
from pkg_resources import iter_entry_points as _iter_entry_points
from .base import BaseModel
from .base_data import BaseDataModel
from .analytic import (TestEggbox, TestNormal, TestRosenbrock, TestVolcano,
TestPrior, TestPosterior)
from .gaussian_noise import GaussianNoise
Expand Down Expand Up @@ -176,10 +180,10 @@ def read_from_config(cp, **kwargs):
"""
# use the name to get the distribution
name = cp.get("model", "name")
return models[name].from_config(cp, **kwargs)
return get_model(name).from_config(cp, **kwargs)


models = {_cls.name: _cls for _cls in (
_models = {_cls.name: _cls for _cls in (
TestEggbox,
TestNormal,
TestRosenbrock,
Expand All @@ -197,3 +201,146 @@ def read_from_config(cp, **kwargs):
Relative,
HierarchicalModel,
)}


class _ModelManager(dict):
"""Sub-classes dictionary to manage the collection of available models.

The first time this is called, any plugin models that are available will be
added to the dictionary before returning.
"""
def __init__(self, *args, **kwargs):
self.retrieve_plugins = True
super().__init__(*args, **kwargs)

def add_model(self, model):
"""Adds a model to the dictionary.

If the given model has the same name as a model already in the
dictionary, the original model will be overridden. A warning will be
printed in that case.
"""
if super().__contains__(model.name):
logging.warning("Custom model %s will override a model of the "
"same name. If you don't want this, change the "
"model's name attribute and restart.", model.name)
self[model.name] = model

def add_plugins(self):
"""Adds any plugin models that are available.

This will only add the plugins if ``self.retrieve_plugins = True``.
After this runs, ``self.retrieve_plugins`` is set to ``False``, so that
subsequent calls to this will no re-add models.
"""
if self.retrieve_plugins:
for plugin in _iter_entry_points('pycbc.inference.models'):
self.add_model(plugin.resolve())
self.retrieve_plugins = False

def __len__(self):
self.add_plugins()
super().__len__()

def __contains__(self, key):
self.add_plugins()
return super().__contains__(key)

def get(self, *args):
self.add_plugins()
return super().get(*args)

def popitem(self):
self.add_plugins()
return super().popitem()

def pop(self, *args):
try:
return super().pop(*args)
except KeyError:
self.add_plugins()
return super().pop(*args)

def keys(self):
self.add_plugins()
return super().keys()

def values(self):
self.add_plugins()
return super().values()

def items(self):
self.add_plugins()
return super().items()

def __iter__(self):
self.add_plugins()
return super().__iter__()

def __repr__(self):
self.add_plugins()
return super().__repr__()

def __getitem__(self, item):
try:
return super().__getitem__(item)
except KeyError:
self.add_plugins()
return super().__getitem__(item)

def __delitem__(self, *args, **kwargs):
try:
super().__delitem__(*args, **kwargs)
except KeyError:
self.add_plugins()
super().__delitem__(*args, **kwargs)


models = _ModelManager(_models)


def get_models():
"""Returns the dictionary of current models.

Ensures that plugins are added to the dictionary first.
"""
models.add_plugins()
return models


def get_model(model_name):
"""Retrieve the given model.

Parameters
----------
model_name : str
The name of the model to get.

Returns
-------
model :
The requested model.
"""
return get_models()[model_name]


def available_models():
"""List the currently available models."""
return list(get_models().keys())


def register_model(model):
"""Makes a custom model available to PyCBC.

The provided model will be added to the dictionary of models that PyCBC
knows about, using the model's ``name`` attribute. If the ``name`` is the
same as a model that already exists in PyCBC, a warning will be printed.

Parameters
----------
model : pycbc.inference.models.base.BaseModel
The model to use. The model should be a sub-class of
:py:class:`BaseModel <pycbc.inference.models.base.BaseModel>` to ensure
it has the correct API for use within ``pycbc_inference``.
"""
get_models().add_model(model)