Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pycbc/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
# pylint: disable=unused-import
from . import (models, sampler, io)
from . import (burn_in, entropy, gelman_rubin, geweke, option_utils)
from .plugin import retrieve_model_plugins as _retrieve_model_plugins


_retrieve_model_plugins(models.models, autoinit=True)
67 changes: 67 additions & 0 deletions pycbc/inference/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (C) 2022 Collin Capano
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.


#
# =============================================================================
#
# Preamble
#
# =============================================================================
#

"""Utilities for plugin model discovery."""

_autoinit = False
def retrieve_model_plugins(model_dict, autoinit=False):
"""Retrieves and processes external model plugins.
"""
# Avoid circular auto initialization
global _autoinit
if autoinit and _autoinit:
return
elif autoinit:
_autoinit = True

import pkg_resources
# Check for fd waveforms
for plugin in pkg_resources.iter_entry_points('pycbc.inference.models'):
add_custom_model(plugin.resolve(), model_dict)


def add_custom_model(model, models, force=False):
"""Makes a custom model available to PyCBC.

The provided model will be added to the dictionary of models that PyCBC
knows about, using the model's ``name`` attribute. If the ``name`` is the
same as a model that already exists in PyCBC, a :py:exc:`RuntimeError` will
be raised unless the ``force`` option is set to ``True``.

Parameters
----------
model : pycbc.inference.models.base.BaseModel
The model to use. The model should be a sub-class of
:py:class:`BaseModel <pycbc.inference.models.base.BaseModel>` to ensure
it has the correct API for use within ``pycbc_inference``.
force : bool, optional
Add the model even if its ``name`` attribute is the same as a model
that is already in :py:data:`pycbc.inference.models.models`. Otherwise,
a :py:exc:`RuntimeError` will be raised. Default is ``False``.
"""
#from pycbc.inference.models import models
if model.name in models and not force:
raise RuntimeError("Cannot load plugin model {}; the name is already "
"in use.".format(model.name))
models[model.name] = model