From 8e634dda3dc34c47fbf37fbb561705152c99793f Mon Sep 17 00:00:00 2001 From: Collin Capano Date: Fri, 27 May 2022 16:31:31 -0400 Subject: [PATCH 1/6] add ability to load plugin models --- pycbc/inference/models/__init__.py | 46 ++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/pycbc/inference/models/__init__.py b/pycbc/inference/models/__init__.py index 0f47348fa95..05c8c9b0387 100644 --- a/pycbc/inference/models/__init__.py +++ b/pycbc/inference/models/__init__.py @@ -20,6 +20,8 @@ """ +from .base import BaseModel +from .base_data import BaseDataModel from .analytic import (TestEggbox, TestNormal, TestRosenbrock, TestVolcano, TestPrior, TestPosterior) from .gaussian_noise import GaussianNoise @@ -197,3 +199,47 @@ def read_from_config(cp, **kwargs): Relative, HierarchicalModel, )} + + +# +# ============================================================================= +# +# Plugin utilities +# +# ============================================================================= +# +def add_custom_model(model, force=False): + """Makes a custom model available to PyCBC. + + The provided model will be added to the dictionary of models that PyCBC + knows about, using the model's ``name`` attribute. If the ``name`` is the + same as a model that already exists in PyCBC, a :py:exc:`RuntimeError` will + be raised unless the ``force`` option is set to ``True``. + + Parameters + ---------- + model : pycbc.inference.models.base.BaseModel + The model to use. The model should be a sub-class of + :py:class:`BaseModel ` to ensure + it has the correct API for use within ``pycbc_inference``. + force : bool, optional + Add the model even if its ``name`` attribute is the same as a model + that is already in :py:data:`pycbc.inference.models.models`. Otherwise, + a :py:exc:`RuntimeError` will be raised. Default is ``False``. + """ + if model.name in models and not force: + raise RuntimeError("Cannot load plugin model {}; the name is already " + "in use.".format(model.name)) + models[model.name] = model + + +def retrieve_model_plugins(): + """Retrieves and processes external model plugins. + """ + import pkg_resources + # Check for fd waveforms + for plugin in pkg_resources.iter_entry_points('pycbc.inference.models'): + add_custom_model(plugin.resolve()) + + +retrieve_model_plugins() From 4e8499d9fefa37df912b3e16eb9bfa4a12d9d318 Mon Sep 17 00:00:00 2001 From: Collin Capano Date: Sat, 28 May 2022 09:31:57 -0400 Subject: [PATCH 2/6] try to avoid circular imports --- pycbc/inference/__init__.py | 4 ++ pycbc/inference/models/__init__.py | 46 ----------------------- pycbc/inference/plugin.py | 60 ++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 46 deletions(-) create mode 100644 pycbc/inference/plugin.py diff --git a/pycbc/inference/__init__.py b/pycbc/inference/__init__.py index a52dc2e94f7..9b8f0be602c 100644 --- a/pycbc/inference/__init__.py +++ b/pycbc/inference/__init__.py @@ -2,3 +2,7 @@ # pylint: disable=unused-import from . import (models, sampler, io) from . import (burn_in, entropy, gelman_rubin, geweke, option_utils) +from .plugin import retrieve_model_plugins as _retrieve_model_plugins + + +_retrieve_model_plugins(models.models) diff --git a/pycbc/inference/models/__init__.py b/pycbc/inference/models/__init__.py index 05c8c9b0387..0f47348fa95 100644 --- a/pycbc/inference/models/__init__.py +++ b/pycbc/inference/models/__init__.py @@ -20,8 +20,6 @@ """ -from .base import BaseModel -from .base_data import BaseDataModel from .analytic import (TestEggbox, TestNormal, TestRosenbrock, TestVolcano, TestPrior, TestPosterior) from .gaussian_noise import GaussianNoise @@ -199,47 +197,3 @@ def read_from_config(cp, **kwargs): Relative, HierarchicalModel, )} - - -# -# ============================================================================= -# -# Plugin utilities -# -# ============================================================================= -# -def add_custom_model(model, force=False): - """Makes a custom model available to PyCBC. - - The provided model will be added to the dictionary of models that PyCBC - knows about, using the model's ``name`` attribute. If the ``name`` is the - same as a model that already exists in PyCBC, a :py:exc:`RuntimeError` will - be raised unless the ``force`` option is set to ``True``. - - Parameters - ---------- - model : pycbc.inference.models.base.BaseModel - The model to use. The model should be a sub-class of - :py:class:`BaseModel ` to ensure - it has the correct API for use within ``pycbc_inference``. - force : bool, optional - Add the model even if its ``name`` attribute is the same as a model - that is already in :py:data:`pycbc.inference.models.models`. Otherwise, - a :py:exc:`RuntimeError` will be raised. Default is ``False``. - """ - if model.name in models and not force: - raise RuntimeError("Cannot load plugin model {}; the name is already " - "in use.".format(model.name)) - models[model.name] = model - - -def retrieve_model_plugins(): - """Retrieves and processes external model plugins. - """ - import pkg_resources - # Check for fd waveforms - for plugin in pkg_resources.iter_entry_points('pycbc.inference.models'): - add_custom_model(plugin.resolve()) - - -retrieve_model_plugins() diff --git a/pycbc/inference/plugin.py b/pycbc/inference/plugin.py new file mode 100644 index 00000000000..37da51d3434 --- /dev/null +++ b/pycbc/inference/plugin.py @@ -0,0 +1,60 @@ +# Copyright (C) 2022 Collin Capano +# This program is free software; you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the +# Free Software Foundation; either version 3 of the License, or (at your +# option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General +# Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + + +# +# ============================================================================= +# +# Preamble +# +# ============================================================================= +# + +"""Utilities for plugin model discovery.""" + + +def retrieve_model_plugins(model_dict): + """Retrieves and processes external model plugins. + """ + import pkg_resources + # Check for fd waveforms + for plugin in pkg_resources.iter_entry_points('pycbc.inference.models'): + add_custom_model(plugin.resolve(), model_dict) + + +def add_custom_model(model, models, force=False): + """Makes a custom model available to PyCBC. + + The provided model will be added to the dictionary of models that PyCBC + knows about, using the model's ``name`` attribute. If the ``name`` is the + same as a model that already exists in PyCBC, a :py:exc:`RuntimeError` will + be raised unless the ``force`` option is set to ``True``. + + Parameters + ---------- + model : pycbc.inference.models.base.BaseModel + The model to use. The model should be a sub-class of + :py:class:`BaseModel ` to ensure + it has the correct API for use within ``pycbc_inference``. + force : bool, optional + Add the model even if its ``name`` attribute is the same as a model + that is already in :py:data:`pycbc.inference.models.models`. Otherwise, + a :py:exc:`RuntimeError` will be raised. Default is ``False``. + """ + #from pycbc.inference.models import models + if model.name in models and not force: + raise RuntimeError("Cannot load plugin model {}; the name is already " + "in use.".format(model.name)) + models[model.name] = model From 361ca43daac9e4991e4bb256e83c0c97ec46ef67 Mon Sep 17 00:00:00 2001 From: Collin Capano Date: Sat, 28 May 2022 23:11:52 -0400 Subject: [PATCH 3/6] fix circular import by using functions to retrieve models --- pycbc/inference/__init__.py | 5 -- pycbc/inference/models/__init__.py | 79 +++++++++++++++++++++++++++++- pycbc/inference/plugin.py | 60 ----------------------- 3 files changed, 77 insertions(+), 67 deletions(-) delete mode 100644 pycbc/inference/plugin.py diff --git a/pycbc/inference/__init__.py b/pycbc/inference/__init__.py index 9b8f0be602c..55c3d58ec84 100644 --- a/pycbc/inference/__init__.py +++ b/pycbc/inference/__init__.py @@ -1,8 +1,3 @@ - # pylint: disable=unused-import from . import (models, sampler, io) from . import (burn_in, entropy, gelman_rubin, geweke, option_utils) -from .plugin import retrieve_model_plugins as _retrieve_model_plugins - - -_retrieve_model_plugins(models.models) diff --git a/pycbc/inference/models/__init__.py b/pycbc/inference/models/__init__.py index 0f47348fa95..3e36ccd0d8d 100644 --- a/pycbc/inference/models/__init__.py +++ b/pycbc/inference/models/__init__.py @@ -20,6 +20,9 @@ """ +from pkg_resources import iter_entry_points as _iter_entry_points +from .base import BaseModel +from .base_data import BaseDataModel from .analytic import (TestEggbox, TestNormal, TestRosenbrock, TestVolcano, TestPrior, TestPosterior) from .gaussian_noise import GaussianNoise @@ -176,10 +179,10 @@ def read_from_config(cp, **kwargs): """ # use the name to get the distribution name = cp.get("model", "name") - return models[name].from_config(cp, **kwargs) + return get_model(name).from_config(cp, **kwargs) -models = {_cls.name: _cls for _cls in ( +_models = {_cls.name: _cls for _cls in ( TestEggbox, TestNormal, TestRosenbrock, @@ -197,3 +200,75 @@ def read_from_config(cp, **kwargs): Relative, HierarchicalModel, )} + + + +def register_model(model, force=False): + """Makes a custom model available to PyCBC. + + The provided model will be added to the dictionary of models that PyCBC + knows about, using the model's ``name`` attribute. If the ``name`` is the + same as a model that already exists in PyCBC, a :py:exc:`RuntimeError` will + be raised unless the ``force`` option is set to ``True``. + + Parameters + ---------- + model : pycbc.inference.models.base.BaseModel + The model to use. The model should be a sub-class of + :py:class:`BaseModel ` to ensure + it has the correct API for use within ``pycbc_inference``. + force : bool, optional + Add the model even if its ``name`` attribute is the same as a model + that is already in :py:data:`pycbc.inference.models.models`. Otherwise, + a :py:exc:`RuntimeError` will be raised. Default is ``False``. + """ + if model.name in _models and not force: + raise RuntimeError("Cannot add model {}; the name is already in use." + .format(model.name)) + _models[model.name] = model + + +class _ModelManager: + """Retrieve the dictionary of available models. + + The first time this is called, any plugin models that are available will be + added to the dictionary before returning. + + Returns + ------- + dict : + Dictionary of model names -> models. + """ + def __init__(self): + self.retrieve_plugins = True + + def __call__(self): + if self.retrieve_plugins: + for plugin in _iter_entry_points('pycbc.inference.models'): + register_model(plugin.resolve()) + self.retrieve_plugins = False + return _models + + +get_models = _ModelManager() + + +def get_model(model_name): + """Retrieve the given model. + + Parameters + ---------- + model_name : str + The name of the model to get. + + Returns + ------- + model : + The requested model. + """ + return get_models()[model_name] + + +def available_models(): + """List the currently available models.""" + return list(get_models().keys()) diff --git a/pycbc/inference/plugin.py b/pycbc/inference/plugin.py deleted file mode 100644 index 37da51d3434..00000000000 --- a/pycbc/inference/plugin.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (C) 2022 Collin Capano -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - - -# -# ============================================================================= -# -# Preamble -# -# ============================================================================= -# - -"""Utilities for plugin model discovery.""" - - -def retrieve_model_plugins(model_dict): - """Retrieves and processes external model plugins. - """ - import pkg_resources - # Check for fd waveforms - for plugin in pkg_resources.iter_entry_points('pycbc.inference.models'): - add_custom_model(plugin.resolve(), model_dict) - - -def add_custom_model(model, models, force=False): - """Makes a custom model available to PyCBC. - - The provided model will be added to the dictionary of models that PyCBC - knows about, using the model's ``name`` attribute. If the ``name`` is the - same as a model that already exists in PyCBC, a :py:exc:`RuntimeError` will - be raised unless the ``force`` option is set to ``True``. - - Parameters - ---------- - model : pycbc.inference.models.base.BaseModel - The model to use. The model should be a sub-class of - :py:class:`BaseModel ` to ensure - it has the correct API for use within ``pycbc_inference``. - force : bool, optional - Add the model even if its ``name`` attribute is the same as a model - that is already in :py:data:`pycbc.inference.models.models`. Otherwise, - a :py:exc:`RuntimeError` will be raised. Default is ``False``. - """ - #from pycbc.inference.models import models - if model.name in models and not force: - raise RuntimeError("Cannot load plugin model {}; the name is already " - "in use.".format(model.name)) - models[model.name] = model From ca962e0485cd35946060846ae0868d496375ec67 Mon Sep 17 00:00:00 2001 From: Collin Capano Date: Sat, 28 May 2022 23:14:31 -0400 Subject: [PATCH 4/6] use get_models for building the models table --- docs/_include/models-table.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/_include/models-table.py b/docs/_include/models-table.py index 9df816d6dcd..863b87ed73c 100644 --- a/docs/_include/models-table.py +++ b/docs/_include/models-table.py @@ -19,10 +19,10 @@ """Prints an RST table of available models from the inference.models module. """ -from pycbc.inference.models import models +from pycbc.inference.models import get_models from _dict_to_rst import (rst_dict_table, format_class) -tbl = rst_dict_table(models, key_format='``\'{0}\'``'.format, +tbl = rst_dict_table(get_models(), key_format='``\'{0}\'``'.format, header=('Name', 'Class'), val_format=format_class) From 24734fd1ed8ea7e5c3538d83d6c976029a773424 Mon Sep 17 00:00:00 2001 From: Collin Capano Date: Wed, 1 Jun 2022 20:50:57 -0400 Subject: [PATCH 5/6] duck type models dict --- pycbc/inference/models/__init__.py | 151 +++++++++++++++++++++-------- 1 file changed, 112 insertions(+), 39 deletions(-) diff --git a/pycbc/inference/models/__init__.py b/pycbc/inference/models/__init__.py index 3e36ccd0d8d..54ae5ec3756 100644 --- a/pycbc/inference/models/__init__.py +++ b/pycbc/inference/models/__init__.py @@ -20,7 +20,9 @@ """ +import logging from pkg_resources import iter_entry_points as _iter_entry_points +from collections import UserDict as _UserDict from .base import BaseModel from .base_data import BaseDataModel from .analytic import (TestEggbox, TestNormal, TestRosenbrock, TestVolcano, @@ -202,55 +204,109 @@ def read_from_config(cp, **kwargs): )} - -def register_model(model, force=False): - """Makes a custom model available to PyCBC. - - The provided model will be added to the dictionary of models that PyCBC - knows about, using the model's ``name`` attribute. If the ``name`` is the - same as a model that already exists in PyCBC, a :py:exc:`RuntimeError` will - be raised unless the ``force`` option is set to ``True``. - - Parameters - ---------- - model : pycbc.inference.models.base.BaseModel - The model to use. The model should be a sub-class of - :py:class:`BaseModel ` to ensure - it has the correct API for use within ``pycbc_inference``. - force : bool, optional - Add the model even if its ``name`` attribute is the same as a model - that is already in :py:data:`pycbc.inference.models.models`. Otherwise, - a :py:exc:`RuntimeError` will be raised. Default is ``False``. - """ - if model.name in _models and not force: - raise RuntimeError("Cannot add model {}; the name is already in use." - .format(model.name)) - _models[model.name] = model - - -class _ModelManager: - """Retrieve the dictionary of available models. +class _ModelManager(dict): + """Sub-classes dictionary to manage the collection of available models. The first time this is called, any plugin models that are available will be added to the dictionary before returning. - - Returns - ------- - dict : - Dictionary of model names -> models. """ - def __init__(self): + def __init__(self, *args, **kwargs): self.retrieve_plugins = True + super().__init__(*args, **kwargs) + + def add_model(self, model): + """Adds a model to the dictionary. - def __call__(self): + If the given model has the same name as a model already in the + dictionary, the original model will be overridden. A warning will be + printed in that case. + """ + if super().__contains__(model.name): + logging.warn("Custom model %s will override a model of the same " + "name. If you don't want this, change the model's " + "name attribute and restart.", model.name) + self[model.name] = model + + def add_plugins(self): + """Adds any plugin models that are available. + + This will only add the plugins if ``self.retrieve_plugins = True``. + After this runs, ``self.retrieve_plugins`` is set to ``False``, so that + subsequent calls to this will no re-add models. + """ if self.retrieve_plugins: for plugin in _iter_entry_points('pycbc.inference.models'): - register_model(plugin.resolve()) + self.add_model(plugin.resolve()) self.retrieve_plugins = False - return _models - -get_models = _ModelManager() + def __len__(self): + self.add_plugins() + super().__len__() + + def __contains__(self, key): + self.add_plugins() + return super().__contains__(key) + + def get(self, *args): + self.add_plugins() + return super().get(*args) + + def popitem(self): + self.add_plugins() + return super().popitem() + + def pop(self, *args): + try: + return super().pop(*args) + except KeyError: + self.add_plugins() + return super().pop(*args) + + def keys(self): + self.add_plugins() + return super().keys() + + def values(self): + self.add_plugins() + return super().values() + + def items(self): + self.add_plugins() + return super().items() + + def __iter__(self): + self.add_plugins() + return super().__iter__() + + def __repr__(self): + self.add_plugins() + return super().__repr__() + + def __getitem__(self, item): + try: + return super().__getitem__(item) + except KeyError: + self.add_plugins() + return super().__getitem__(item) + + def __delitem__(self, *args, **kwargs): + try: + super().__delitem__(*args, **kwargs) + except KeyError: + self.add_plugins() + super().__delitem__(*args, **kwargs) + + +models = _ModelManager(_models) + + +def get_models(): + """Returns the dictionary of current models. + + Ensures that plugins are added to the dictionary first. + """ + models.add_plugins() + return models def get_model(model_name): @@ -272,3 +328,20 @@ def get_model(model_name): def available_models(): """List the currently available models.""" return list(get_models().keys()) + + +def register_model(model): + """Makes a custom model available to PyCBC. + + The provided model will be added to the dictionary of models that PyCBC + knows about, using the model's ``name`` attribute. If the ``name`` is the + same as a model that already exists in PyCBC, a warning will be printed. + + Parameters + ---------- + model : pycbc.inference.models.base.BaseModel + The model to use. The model should be a sub-class of + :py:class:`BaseModel ` to ensure + it has the correct API for use within ``pycbc_inference``. + """ + get_models().add_model(model) From cdd98c6c295e557a30356cd4873991a2d6a06bff Mon Sep 17 00:00:00 2001 From: Collin Capano Date: Thu, 2 Jun 2022 08:54:08 -0400 Subject: [PATCH 6/6] code climate --- pycbc/inference/models/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pycbc/inference/models/__init__.py b/pycbc/inference/models/__init__.py index 54ae5ec3756..b12f3879c16 100644 --- a/pycbc/inference/models/__init__.py +++ b/pycbc/inference/models/__init__.py @@ -22,7 +22,6 @@ import logging from pkg_resources import iter_entry_points as _iter_entry_points -from collections import UserDict as _UserDict from .base import BaseModel from .base_data import BaseDataModel from .analytic import (TestEggbox, TestNormal, TestRosenbrock, TestVolcano, @@ -222,9 +221,9 @@ def add_model(self, model): printed in that case. """ if super().__contains__(model.name): - logging.warn("Custom model %s will override a model of the same " - "name. If you don't want this, change the model's " - "name attribute and restart.", model.name) + logging.warning("Custom model %s will override a model of the " + "same name. If you don't want this, change the " + "model's name attribute and restart.", model.name) self[model.name] = model def add_plugins(self):