Skip to content

Commit

Permalink
fix: various registry and entry points related fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Sep 1, 2024
1 parent 3beeb6c commit f78f6bb
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 13 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Support using config files with scripts without a dedicated section header
- Disable configparser interpolation (% symbol)
- Better support for escaped strings in config files
- Various registry-related fixes

### Added

Expand Down
46 changes: 33 additions & 13 deletions confit/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
to_legacy_error,
)

try:
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata

PYDANTIC_V1 = pydantic.VERSION.split(".")[0] == "1"


Expand Down Expand Up @@ -309,6 +314,7 @@ def wrapper_function(*args: Any, **kwargs: Any) -> Any:
"__annotations__",
"__defaults__",
"__kwdefaults__",
"__signature__",
),
)
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -427,7 +433,6 @@ def invoke(func, params):
return resolved

def wrap_and_register(fn: catalogue.InFunc) -> catalogue.InFunc:

if save_params is not None:
_check_signature_for_save_params(
fn if not isinstance(fn, type) else fn.__init__
Expand Down Expand Up @@ -463,7 +468,18 @@ def deprecated_fn(*args, **kwargs):
else:
return wrap_and_register

def get(self, name: str) -> catalogue.InFunc:
def get_entry_points(self):
"""Get registered entry points from other packages for this namespace.
RETURNS (Dict[str, Any]): Entry points, keyed by name.
"""
entrypoints = importlib_metadata.entry_points()
if hasattr(entrypoints, "select"):
return entrypoints.select(group=self.entry_point_namespace)
else: # dict
return entrypoints.get(self.entry_point_namespace, [])

def get(self, name: str):
"""
Get the registered function for a given name.
Expand All @@ -480,17 +496,21 @@ def get(self, name: str) -> catalogue.InFunc:
-------
catalogue.InFunc
"""
if self.entry_points:
from_entry_point = self.get_entry_point(name)
if from_entry_point:
return from_entry_point
namespace = list(self.namespace) + [name]
if not catalogue.check_exists(*namespace):
raise catalogue.RegistryError(
f"Can't find '{name}' in registry {' -> '.join(self.namespace)}. "
f"Available names: {', '.join(sorted(self.get_available())) or 'none'}"
)
return catalogue._get(namespace)
path = list(self.namespace) + [name]
try:
return catalogue._get(path)
except catalogue.RegistryError:
if self.entry_points:
from_entry_point = self.get_entry_point(name)
if from_entry_point:
return from_entry_point
if not catalogue.check_exists(*path):
raise catalogue.RegistryError(
f"Can't find '{name}' in registry {' -> '.join(self.namespace)}. "
f"Available names: "
f"{', '.join(sorted(self.get_available())) or 'none'}"
)
return catalogue._get(path)

def get_available(self) -> Sequence[str]:
"""Get all functions for a given namespace.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ fixable = ["E", "F", "W", "I"]

[tool.ruff.isort]
known-first-party = ["confit"]
known-third-party = ["build"]

[tool.mypy]
plugins = "pydantic.mypy"
Expand Down

0 comments on commit f78f6bb

Please sign in to comment.