From bfa9e03f056d9963bffad4fb92f7fbae5fe5119c Mon Sep 17 00:00:00 2001 From: Adam Hillier <7688302+AdamHillier@users.noreply.github.com> Date: Thu, 5 Dec 2019 14:34:39 +0000 Subject: [PATCH] Components as decorators (#84) * Make components and tasks decorators rather than classes. * Expose CLI. * Bug fix. * Fix another bug. * Make experiment loss optional. * Fix yet another bug. * Bug fix once more. * Another bug fix. * Update defaults. * WIP * Tidy-up. * Export `configure` correctly. * Update examples. * Make linting happy. * Update zookeeper/core/component.py Co-Authored-By: Lukas Geiger * Remove erroneous prints. * Review suggestion. * Move `colorama` import to top. --- examples/larq_experiment.py | 14 +- zookeeper/__init__.py | 17 +- zookeeper/cli_test.py | 63 ---- zookeeper/component.py | 426 ---------------------- zookeeper/component_test.py | 292 --------------- zookeeper/core/__init__.py | 5 + zookeeper/{ => core}/cli.py | 27 +- zookeeper/core/cli_test.py | 55 +++ zookeeper/core/component.py | 543 ++++++++++++++++++++++++++++ zookeeper/core/component_test.py | 351 ++++++++++++++++++ zookeeper/core/task.py | 56 +++ zookeeper/core/task_test.py | 45 +++ zookeeper/core/utils.py | 71 ++++ zookeeper/task.py | 27 -- zookeeper/task_test.py | 36 -- zookeeper/tf/__init__.py | 13 + zookeeper/{ => tf}/dataset.py | 37 +- zookeeper/{ => tf}/experiment.py | 11 +- zookeeper/{ => tf}/model.py | 4 +- zookeeper/{ => tf}/preprocessing.py | 4 +- zookeeper/utils.py | 94 ----- 21 files changed, 1159 insertions(+), 1032 deletions(-) delete mode 100644 zookeeper/cli_test.py delete mode 100644 zookeeper/component.py delete mode 100644 zookeeper/component_test.py create mode 100644 zookeeper/core/__init__.py rename zookeeper/{ => core}/cli.py (54%) create mode 100644 zookeeper/core/cli_test.py create mode 100644 zookeeper/core/component.py create mode 100644 zookeeper/core/component_test.py create mode 100644 zookeeper/core/task.py create mode 100644 zookeeper/core/task_test.py 
create mode 100644 zookeeper/core/utils.py delete mode 100644 zookeeper/task.py delete mode 100644 zookeeper/task_test.py create mode 100644 zookeeper/tf/__init__.py rename zookeeper/{ => tf}/dataset.py (79%) rename zookeeper/{ => tf}/experiment.py (73%) rename zookeeper/{ => tf}/model.py (82%) rename zookeeper/{ => tf}/preprocessing.py (96%) delete mode 100644 zookeeper/utils.py diff --git a/examples/larq_experiment.py b/examples/larq_experiment.py index e58fd61..260b45b 100644 --- a/examples/larq_experiment.py +++ b/examples/larq_experiment.py @@ -7,19 +7,20 @@ import larq as lq import tensorflow as tf -import tensorflow_datasets as tfds -from zookeeper import Dataset, Experiment, Model, Preprocessing, TFDSDataset -from zookeeper.cli import add_task_to_cli, cli +from zookeeper import cli, component, task +from zookeeper.tf import Dataset, Experiment, Model, Preprocessing, TFDSDataset +@component class Cifar10(TFDSDataset): name = "cifar10" # CIFAR-10 has only train and test, so validate on test. 
- train_split = tfds.Split.TRAIN - validation_split = tfds.Split.TEST + train_split = "train" + validation_split = "test" +@component class PadCropAndFlip(Preprocessing): pad_size: int output_size: int @@ -44,6 +45,7 @@ def output(self, data): return data["label"] +@component class BinaryNet(Model): dataset: Dataset preprocessing: Preprocessing @@ -107,7 +109,7 @@ def build(self, input_shape): ) -@add_task_to_cli +@task class BinaryNetCifar10(Experiment): dataset = Cifar10() preprocessing = PadCropAndFlip(pad_size=40, output_size=32) diff --git a/zookeeper/__init__.py b/zookeeper/__init__.py index 46e3d6e..628c4b7 100644 --- a/zookeeper/__init__.py +++ b/zookeeper/__init__.py @@ -1,16 +1,3 @@ -from zookeeper.component import Component -from zookeeper.dataset import Dataset, TFDSDataset -from zookeeper.experiment import Experiment -from zookeeper.model import Model -from zookeeper.preprocessing import Preprocessing -from zookeeper.task import Task +from zookeeper.core import cli, component, configure, task -__all__ = [ - "Component", - "Dataset", - "Experiment", - "Model", - "Preprocessing", - "Task", - "TFDSDataset", -] +__all__ = ["cli", "component", "configure", "task"] diff --git a/zookeeper/cli_test.py b/zookeeper/cli_test.py deleted file mode 100644 index ac5a59f..0000000 --- a/zookeeper/cli_test.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -from click.testing import CliRunner - -from zookeeper.cli import add_task_to_cli, cli -from zookeeper.task import Task - - -@pytest.fixture -def test_task(): - # We have to define `TestTask` inside a pytest fixture so that it gets - # reinstantiated for every test. - - @add_task_to_cli - class TestTask(Task): - a: int - b: str = "foo" - - def run(self): - print(self.a, self.b) - - return None - - -runner = CliRunner(mix_stderr=False) - - -def test_pass_param_values(test_task): - # We should be able to pass parameter values through the CLI. 
- result = runner.invoke(cli, ["test_task", "a=5"]) - assert result.exit_code == 0 - assert result.output == "5 foo\n" - - -def test_param_key_valid_characters(test_task): - # We should be able to pass keys with underscores and full stops and - # capitals. It's okay here that the param with name `x.y_z.A` doesn't - # actually exist. - result = runner.invoke(cli, ["test_task", "a=5", "x.y_z.A=1.0"]) - assert result.exit_code == 0 - - -def test_param_key_invalid_characters(test_task): - # Keys with invalid characters such as '-' or '@' should not be accepted. - result = runner.invoke(cli, ["test_task", "a=5", "x-y=1.0"]) - assert result.exit_code == 2 - result = runner.invoke(cli, ["test_task", "a=5", "x@y=1.0"]) - assert result.exit_code == 2 - - -def test_override_param_values(test_task): - # We should be able to override existing parameter values through the CLI. - result = runner.invoke(cli, ["test_task", "a=5", "b=bar"]) - assert result.exit_code == 0 - assert result.output == "5 bar\n" - - -def test_override_param_complex_string(test_task): - # We should be able to pass complex strings, including paths. 
- result = runner.invoke( - cli, ["test_task", "a=5", "b=https://some-path/foo/bar@somewhere"] - ) - assert result.exit_code == 0 - assert result.output == "5 https://some-path/foo/bar@somewhere\n" diff --git a/zookeeper/component.py b/zookeeper/component.py deleted file mode 100644 index 2218e8b..0000000 --- a/zookeeper/component.py +++ /dev/null @@ -1,426 +0,0 @@ -import inspect -from typing import Any, Dict, Optional - -from prompt_toolkit import print_formatted_text -from typeguard import check_type - -from zookeeper.utils import ( - convert_to_snake_case, - get_concrete_subclasses, - prompt_for_component, - promt_for_param_value, - type_name_str, -) - -try: # pragma: no cover - from colorama import Fore - - BLUE, YELLOW, RESET = Fore.BLUE, Fore.YELLOW, Fore.RESET -except ImportError: # pragma: no cover - BLUE = YELLOW = RESET = "" - -# Indent for nesting in the string representation -INDENT = " " * 4 - - -def defined_on_self_or_ancestor(self, name): - """ - Test if the annotation `name` exists on `self` or a component ancestor of - `self` with a defined value. If so, return the instance on which `name` is - defined. Otherwise, return `None`. - """ - - # Using `hasattr` is not safe, as it is implemented with `getattr` wrapped - # in a try-catch (Python is terrible), so we need to check `dir(self)`. 
- if name in self.__component_annotations__ and name in dir(self): - return self - if self.__component_parent__: - return defined_on_self_or_ancestor(self.__component_parent__, name) - return None - - -def str_key_val(key, value, color=True, single_line=False): - if isinstance(value, Component): - if single_line: - value = repr(value) - else: - value = f"\n{INDENT}".join(str(value).split("\n")) - elif callable(value): - value = "" - elif type(value) == str: - value = f'"{value}"' - space = "" if single_line else " " - return ( - f"{BLUE}{key}{RESET}{space}={space}{YELLOW}{value}{RESET}" - if color - else f"{key}{space}={space}{value}" - ) - - -class Component: - """ - A generic, modular component class designed to be easily configurable. - - Components can have configurable parameters, which can be either generic - Python objects or nested sub-components. These are declared with class-level - Python type annotations, in the same way that elements of - [dataclasses](https://docs.python.org/3/library/dataclasses.html) are - declared. After instantiation, components are 'configured' with a - configuration dictionary; this process automatically injects the correct - parameters into the component and all subcomponents. Component parameters - can have defaults set, either in the class definition or passed via - `__init__`, but configuration values passed to `configure` will always take - precedence over these values. - - If a nested sub-component child declares a parameter with the same name as a - parameter in one of its ancestors, it will receive the same configured value - as the parent does. Howevever, configuration is scoped: if the parameter on - the child, or on a _closer anscestor_, is configured with a different value, - then that value will override the one from the original parent. - - Configuration can be interactive. In this case, the method will prompt for - missing parameters via the CLI. 
- - The following example illustrates the configuration mechanism with scoped - configuration: - - ``` - class A(Component): - x: int - z: float - - def __call__(self): - return str(self.x) + "_" + str(self.z) - - class B(Component): - a: A - y: str = "foo" - - def __call__(self): - return self.y + " / " + self.a() - - class C(Component): - b: B - x: int - z: float = 3.14 - - def __call__(self): - return str(self.x) + "_" + str(self.z) + " / " + self.b() - - - c = C() - c.configure({ - "x": 5, # (1) - "b.x": 10, # (2) - "b.a.x": 15, # (3) - - "b.y": "foo", # (4) - - "b.z": 2.71 # (5) - }) - print(c) - - >> C( - b = B( - a = A( - x = 15, # (3) overrides (2) overrides (1) - z = 2.71 # Inherits from parent: (5) - ), - y = "foo" # (4) overrides the default - ), - x = 5, # Only (1) applies - z = 3.14 # The default is taken - ) - ``` - """ - - # The name of the component. - __component_name__ = None - - # If this instance is nested in another component, a reference to that - # parent instance. - __component_parent__ = None - - # All annotations which apply to the class, including those inherited from - # superclasses. This is populated in `__init__`. - __component_annotations__ = {} - - # Whether or not the component has been configured. - __component_configured__ = False - - def __init__(self, **kwargs): - """ - `kwargs` may only contain argument names corresponding to component - annotations. The passed values will be set on the instance. - """ - - self.__component_name__ = self.__class__.__name__ - - # Populate `self.__component_annotations__` with all annotations set on - # this class and all superclasses. We have to go through the MRO chain - # and collect them in reverse order so that they are correctly - # overriden. 
- annotations = {} - for base_class in reversed(inspect.getmro(self.__class__)): - annotations.update(getattr(base_class, "__annotations__", {})) - annotations.update(getattr(self, "__annotations__", {})) - self.__component_annotations__ = annotations - - for k, v in kwargs.items(): - if k in self.__component_annotations__: - setattr(self, k, v) - else: - raise ValueError( - f"Argument '{k}' passed to `__init__` does not correspond to " - f"any annotation of '{type_name_str(self.__class__)}'." - ) - - def __init_subclass__(cls, *args, **kwargs): - # Ensure that `__init__` does not accept any positional arguments. - for i, (name, param) in enumerate( - inspect.signature(cls.__init__).parameters.items() - ): - if ( - i > 0 - and param.default == inspect.Parameter.empty - and param.kind - not in [inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD] - ): - raise ValueError( - "The `__init__` method of a `Component` sub-class must not accept " - "any positional arguments, as the component configuration process " - "requires component classes to be instantiable without arguments. " - f"Please set a default value for the parameter '{name}' of " - f"`{type_name_str(cls)}.__init__`." - ) - - def __getattr__(self, name): - # This is only called if the attribute doesn't exist on the instance - # (i.e. on `self`, on the class, or on any superclass). When this - # happens, if `name` is a declared annotation which is also declared on - # some ancestor with a defined value for `name`, return that value. - if name in self.__component_annotations__: - ancestor = defined_on_self_or_ancestor(self, name) - if ancestor is not None: - return getattr(ancestor, name) - raise AttributeError( - f"Component '{self.__component_name__}' does not have any attribute " - f"'{name}'." - ) - - def __setattr__(self, name, value): - # Type-check annotated values. 
- if name in self.__component_annotations__: - annotated_type = self.__component_annotations__[name] - try: - check_type(name, value, annotated_type) - # Because boolean `True` and `False` are coercible to ints and - # floats, `typeguard.check_type` doesn't throw if we e.g. pass - # `True` to a value expecting a float. This would, however, - # likely be a user error, so explicitly check for this. - if annotated_type in [float, int] and isinstance(value, bool): - raise TypeError - except TypeError: - raise TypeError( - f"Attempting to set parameter '{self.__component_name__}.{name}' " - f"which has annotated type '{type_name_str(annotated_type)}' with " - f"value '{value}'." - ) from None - super().__setattr__(name, value) - - def __str__(self): - params = f",\n{INDENT}".join( - [str_key_val(k, getattr(self, k)) for k in self.__component_annotations__] - ) - return f"{self.__class__.__name__}(\n{INDENT}{params}\n)" - - def __repr__(self): - params = ", ".join( - [ - str_key_val(k, getattr(self, k), color=False, single_line=True) - for k in self.__component_annotations__ - ] - ) - return f"{self.__class__.__name__}({params})" - - def configure( - self, - conf: Dict[str, Any], - name: Optional[str] = None, - parent: Optional["Component"] = None, - interactive: bool = False, - ): - """ - Configure the component instance with parameters from the `conf` dict. - - Configuration passed through `conf` takes precedence over and will - overwrite any values already set on the instance - either class defaults - or those passed via `__init__`. - """ - - # Configuration can only happen once. - if self.__component_configured__: - raise ValueError( - f"Component '{self.__component_name__}' has already been configured." - ) - - if name is not None: - self.__component_name__ = name - self.__component_parent__ = parent - - # Divide the annotations into those which are and those which are not - # nested components. 
We will process the non-component parameters first, - # because nested components may depend on parameter (non-component) - # values set in the parent. - non_component_annotations = [] - component_annotations = [] - - for k, v in self.__component_annotations__.items(): - # We have to be careful because `v` can be a `typing.Type` subclass - # e.g. `typing.List[float]`. - # - # In Python 3.7+, this will cause `issubclass(v, Component)` to - # throw, but `inspect.isclass(v)` will be `False`. - # - # In Python 3.6, `inspect.isclass(v)` will be `True`, but - # fortunately `issubclass(v, Component)` won't throw. - if inspect.isclass(v) and issubclass(v, Component): - component_annotations.append((k, v)) - else: - non_component_annotations.append((k, v)) - - # Process non-component annotations. - for k, v in non_component_annotations: - param_name = f"{self.__component_name__}.{k}" - param_type_name = v.__name__ if inspect.isclass(v) else v - - # The value from the `conf` dict takes priority. - if k in conf: - param_value = conf[k] - setattr(self, k, param_value) - - # If there's no config value but a value is already set on the - # instance (or a parent), no action needs to be taken. `__getattr__` - # is overriden to give the instance access to such values. - elif defined_on_self_or_ancestor(self, k) is not None: - pass - - # If we are running interactively, prompt for the missing value. Add - # it to the configuration so that it gets passed to any children. - elif interactive: - param_value = promt_for_param_value(param_name, v) - setattr(self, k, param_value) - conf[k] = param_value - - # If we're not running interactively and there's no value anywhere, - # raise an error. - else: - raise ValueError( - "No configuration value found for annotated parameter " - f"'{param_name}' of type '{param_type_name}'." - ) - - # Process nested component annotations. 
- for k, v in component_annotations: - param_name = f"{self.__component_name__}.{k}" - param_type_name = type_name_str(v) - concrete_subclasses = get_concrete_subclasses(v) - - # The value from the `conf` dict takes priority. - if k in conf: - instance = conf[k] - # The value might have been parsed from command-line arguments, - # in which case we expect a string naming the class. - if isinstance(instance, str): - for component_cls in concrete_subclasses: - if ( - instance == component_cls.__name__ - or instance == component_cls.__qualname__ - or instance == convert_to_snake_case(component_cls.__name__) - ): - instance = component_cls() - conf[k] = instance - break - setattr(self, k, instance) - - # If there's no config value but a value is already set on the - # instance (or a parent), no action needs to be taken. `__getattr__` - # is overriden to give the instance access to such values. - elif defined_on_self_or_ancestor(self, k) is not None: - pass - - # If there is no concrete subclass of `v`, raise an error. - elif len(concrete_subclasses) == 0: - raise ValueError( - "There is no defined, non-abstract class that can be instantiated " - f"to satisfy the annotated parameter '{param_name}' of type " - f"'{param_type_name}'." - ) - - # If there is only one concrete subclass of `v`, instantiate an - # instance of that class. - elif len(concrete_subclasses) == 1: - component_cls = list(concrete_subclasses)[0] - print_formatted_text( - f"'{type_name_str(component_cls)}' is the only concrete component " - "class that satisfies the type of the annotated parameter " - f"'{param_name}'. Using an instance of this class by default." - ) - # This is safe because we ban overriding `__init__`. - instance = component_cls() - setattr(self, k, instance) - - # If we are running interactively and there is more than one - # concrete subclass of `v`, prompt for the concrete subclass to - # instantiate. 
Add the instance to the configuation so that is can - # get passed to any children. - elif interactive: - component_cls = prompt_for_component(param_name, v) - # The is safe because we ban overriding `__init__`. - instance = component_cls() - setattr(self, k, instance) - - # If we're not running interactively and there is more than one - # concrete subclass of `v`, raise an error. - else: - raise ValueError( - f"Annotated parameter '{param_name}' of type '{param_type_name}' " - f"has no configured value. Please configure '{param_name}' with one " - f"of the following concrete subclasses of '{param_type_name}':\n " - + "\n ".join(list(type_name_str(c) for c in concrete_subclasses)) - ) - - # Recursively configure the nested sub-components. - for k, v in component_annotations: - param_name = f"{self.__component_name__}.{k}" - sub_component = getattr(self, k) - if not sub_component.__component_configured__: - # Configure the sub-component. The configuration we use consists - # of all non-scoped keys and any keys scoped to `k`, where the - # keys scoped to `k` override the non-scoped keys. - non_scoped_conf = {a: b for a, b in conf.items() if "." not in a} - k_scoped_conf = { - a[len(f"{k}.") :]: b - for a, b in conf.items() - if a.startswith(f"{k}.") - } - nested_conf = {**non_scoped_conf, **k_scoped_conf} - sub_component.configure( - nested_conf, name=param_name, parent=self, interactive=interactive - ) - - self.validate_configuration() - - self.__component_configured__ = True - - def validate_configuration(self): - """ - Called automatically at the end of `configure`. Subclasses should - override this method to provide fine-grained parameter validation. - Invalid configuration should be flagged by raising an error with a - descriptive error message. - """ - - # Checking for missing values is done in `configure`. Type-checking is - # done in `__setattr__`. 
- pass diff --git a/zookeeper/component_test.py b/zookeeper/component_test.py deleted file mode 100644 index 8c4051e..0000000 --- a/zookeeper/component_test.py +++ /dev/null @@ -1,292 +0,0 @@ -import re -from abc import ABC, abstractmethod -from typing import List -from unittest.mock import patch - -import pytest -from click import unstyle - -from zookeeper.component import Component - - -# Specify this as a fixture because we want the class to be re-defined each -# time. This is because we want class-level default attributes to be re-defined -# and re-initialised. -@pytest.fixture -def Parent(): - class AbstractGrandChild(Component, ABC): - a: int - b: str - c: List[float] - - @abstractmethod - def __call__(self): - pass - - class GrandChild1(AbstractGrandChild): - def __call__(self): - return f"grand_child_1_{self.b}_{self.c}" - - class GrandChild2(AbstractGrandChild): - def __call__(self): - return f"grand_child_2_{self.a}" - - class Child(Component): - b: str = "bar" - - grand_child: AbstractGrandChild = GrandChild1() - - def __call__(self): - return f"child_{self.b}_{self.grand_child()}" - - class Parent(Component): - a: int - b: str - - child: Child - - def __call__(self): - return f"root_{self.a}_{self.child()}" - - return Parent - - -def test_positional_args_init_error(): - # Defining a subclass which overrides `__init__` with positional arguments - # should raise a ValueError. - - with pytest.raises( - ValueError, - match=r"^The `__init__` method of a `Component` sub-class must not accept", - ): - - class C(Component): - def __init__(self, c, **kwargs): - super().__init__(**kwargs) - self.c = c - - -def test_init(Parent): - # Kwargs passed to `__init__` corresponding to annotated parameter should - # set those values, overriding the defaults. - - p = Parent(a=5) - assert p.a == 5 - - # Kwargs not corresponding to an annotated parameter should cause a - # ValueError to be raised. 
- with pytest.raises( - ValueError, - match=r"Argument 'd' passed to `__init__` does not correspond to any annotation", - ): - p = Parent(d="baz") - - -def test_configure_non_interactive_missing_param(Parent): - # Hydrating without configuring a value for an attribute without a default - # should raise a ValueError. - - p = Parent() - p_conf = {"a": 5} - with pytest.raises( - ValueError, match=r"^No configuration value found for annotated parameter" - ): - p.configure(p_conf, name="parent") - - -def test_configure_override_values(Parent): - # A configured instance should have its attributes corrected overriden. - - p = Parent() - p_conf = {"a": 10, "b": "foo", "c": [1.5, -1.2]} - - p.configure(p_conf, name="parent") - - # `a` should be correctly overrriden, `b` should take its default value. - assert p.a == 10 - assert p.b == "foo" - - -def test_configure_scoped_override(Parent): - # Configuration values should be correctly scoped. - - p = Parent() - p_conf = { - "a": 10, - "child.a": 15, - "b": "foo", - "child.grand_child.b": "baz", - "c": [1.5, -1.2], - "child.c": [-17.2], - "child.grand_child.c": [0, 4.2], - } - - p.configure(p_conf, name="parent") - - # The parent `p` should have the value `a` = 10. Even though a config value - # is declared for its scope, `p.child` should have no `a` value set, as it - # doesn't declare `a` as a dependency. Despite this, `p.child.grand_child` - # should get the value `a` = 15, as it lives inside the configuration scope - # of its parent, `p.child`. - assert p.a == 10 - assert not hasattr(p.child, "a") - assert p.child.grand_child.a == 15 - - # `b` is declared as a dependency at all three levels. The `baz` value - # should be scoped only to the grandchild, so `foo` will apply to both - # higher levels. - assert p.b == "foo" - assert p.child.b == "foo" - assert p.child.grand_child.b == "baz" - - # `c` is declared as a dependency only in the grandchild. The more specific - # scopes override the more general. 
- assert p.child.grand_child.c == [0, 4.2] - - -def test_configure_one_possible_component(): - # If there's only a single defined, non-abstract class that satisfies a - # declared sub-component depency of a component, then we expect `configure` - # to instantiate an instance of this class by default without prompting, but - # also warn that it has done so. - class A(Component): - def __call__(self): - return "hello world" - - class Parent(Component): - a: A - - def __call__(self): - return self.a.__call__() - - p = Parent() - - with patch("zookeeper.component.print_formatted_text") as print_formatted_text: - p.configure({}) - - print_formatted_text.assert_called_once() - assert len(print_formatted_text.call_args[0]) == 1 - assert re.search( - r" is the only concrete component class that satisfies the type of the " - "annotated parameter 'Parent.a'. Using an instance of this class by default.$", - print_formatted_text.call_args[0][0], - ) - - -def test_configure_interactive_prompt_for_missing_value(Parent): - # Configure with all configuration values specified apart from `c`. When - # running in interactive mode, we expect to be prompted to input this value. - - p = Parent() - p_conf = {"a": 10, "b": "foo"} - - c_value = [3.14, 2.7] - - with patch("zookeeper.utils.prompt", return_value=str(c_value)) as prompt: - p.configure(p_conf, name="parent", interactive=True) - - assert p.child.grand_child.c == c_value - prompt.assert_called_once() - - -def test_configure_interactive_prompt_for_subcomponent_choice(): - # Configure a parent with an unspecified child subcomponent. In interactive - # mode, we expect to be prompted to choose from the list of defined, - # concrete sub-components. 
- - class A(Component): - a: int = 5 - - def __call__(self): - return self.a - - class B(A): - def __call__(self): - return super().__call__() ** 3 - - class B2(B): - def __call__(self): - return super().__call__() + 1 - - class C(A): - def __call__(self): - return super().__call__() * 2 - - class Parent(Component): - child: A - - def __call__(self): - return self.child.__call__() - - p = Parent() - p_conf = {} - - # The prompt lists the concrete subclasses (alphabetically) and asks for an - # an integer input corresponding to an index in this list. The response '3' - # therefore selects `B2`. - with patch("zookeeper.utils.prompt", return_value=str(3)) as prompt: - p.configure(p_conf, interactive=True) - - assert isinstance(p.child, B2) - assert p() == 126 - prompt.assert_called_once() - - -def test_str_and_repr(Parent): - # `__str__` and `__repr__` should give formatted strings that represent - # nested components nicely. - - p = Parent() - p_conf = {"a": 10, "b": "foo", "c": [1.5, -1.2]} - - p.configure(p_conf, name="parent") - - assert ( - unstyle(repr(p)) - == """Parent(a=10, b="foo", child=Child(b="foo", grand_child=GrandChild1(a=10, b="foo", c=[1.5, -1.2])))""" - ) - assert ( - unstyle(str(p)) - == """Parent( - a = 10, - b = "foo", - child = Child( - b = "foo", - grand_child = GrandChild1( - a = 10, - b = "foo", - c = [1.5, -1.2] - ) - ) -)""" - ) - - -def test_type_check(): - class A(Component): - a: int = 0 - b: float = 1.5 - c: str = "foo" - - a = A() - - # Attempting to set an int parameter with a float. - with pytest.raises( - TypeError, - match=r"^Attempting to set parameter 'A.a' which has annotated type 'int' with value '4.5'.$", - ): - a.configure({"a": 4.5}) - - # Attempting to set a float parameter with a bool. - with pytest.raises( - TypeError, - match=r"^Attempting to set parameter 'A.b' which has annotated type 'float' with value 'True'.$", - ): - a.configure({"b": True}) - - # Attempting to set a string parameter with an int. 
- with pytest.raises( - TypeError, - match=r"^Attempting to set parameter 'A.c' which has annotated type 'str' with value '8'.$", - ): - a.configure({"c": 8}) diff --git a/zookeeper/core/__init__.py b/zookeeper/core/__init__.py new file mode 100644 index 0000000..ea0c24d --- /dev/null +++ b/zookeeper/core/__init__.py @@ -0,0 +1,5 @@ +from zookeeper.core.cli import cli +from zookeeper.core.component import component, configure +from zookeeper.core.task import task + +__all__ = ["component", "configure", "cli", "task"] diff --git a/zookeeper/cli.py b/zookeeper/core/cli.py similarity index 54% rename from zookeeper/cli.py rename to zookeeper/core/cli.py index 8cd3b35..707b19c 100644 --- a/zookeeper/cli.py +++ b/zookeeper/core/cli.py @@ -1,10 +1,8 @@ import re -from inspect import isclass import click -from zookeeper.task import Task -from zookeeper.utils import convert_to_snake_case, parse_value_from_string +from zookeeper.core.utils import parse_value_from_string @click.group() @@ -39,26 +37,3 @@ def convert(self, str_value, param, ctx): ) return key, value - - -def add_task_to_cli(task_cls: type): - """A decorator which adds a CLI command to run the Task.""" - - if not isclass(task_cls) or not issubclass(task_cls, Task): - raise ValueError( - "The decorator `add_task_to_cli` can only be applied to `zookeeper.Task` " - "subclasses." 
- ) - - task_name = convert_to_snake_case(task_cls.__name__) - - @cli.command(task_name) - @click.argument("config", type=ConfigParam(), nargs=-1) - @click.option("-i", "--interactive", is_flag=True, default=False) - def command(config, interactive): - config = {k: v for k, v in config} - task_instance = task_cls() - task_instance.configure(config, interactive=interactive) - task_instance.run() - - return task_cls diff --git a/zookeeper/core/cli_test.py b/zookeeper/core/cli_test.py new file mode 100644 index 0000000..e429a71 --- /dev/null +++ b/zookeeper/core/cli_test.py @@ -0,0 +1,55 @@ +from click import testing + +from zookeeper.core.cli import cli +from zookeeper.core.task import task + + +@task +class Task: + a: int + b: str = "foo" + + def run(self): + print(self.a, self.b) + + +runner = testing.CliRunner(mix_stderr=False) + + +def test_pass_param_values(): + # We should be able to pass parameter values through the CLI. + result = runner.invoke(cli, ["task", "a=5"]) + assert result.exit_code == 0 + assert result.output == "5 foo\n" + + +def test_param_key_valid_characters(): + # We should be able to pass keys with underscores and full stops and + # capitals. It's okay here that the param with name `x.y_z.A` doesn't + # actually exist. + result = runner.invoke(cli, ["task", "a=5", "x.y_z.A=1.0"]) + assert result.exit_code == 0 + + +def test_param_key_invalid_characters(): + # Keys with invalid characters such as '-' or '@' should not be accepted. + result = runner.invoke(cli, ["task", "a=5", "x-y=1.0"]) + assert result.exit_code == 2 + result = runner.invoke(cli, ["task", "a=5", "x@y=1.0"]) + assert result.exit_code == 2 + + +def test_override_param_values(): + # We should be able to override existing parameter values through the CLI. 
+ result = runner.invoke(cli, ["task", "a=5", "b=bar"]) + assert result.exit_code == 0 + assert result.output == "5 bar\n" + + +def test_override_param_complex_string(): + # We should be able to pass complex strings, including paths. + result = runner.invoke( + cli, ["task", "a=5", "b=https://some-path/foo/bar@somewhere"] + ) + assert result.exit_code == 0 + assert result.output == "5 https://some-path/foo/bar@somewhere\n" diff --git a/zookeeper/core/component.py b/zookeeper/core/component.py new file mode 100644 index 0000000..820090f --- /dev/null +++ b/zookeeper/core/component.py @@ -0,0 +1,543 @@ +""" +Components are generic, modular classes designed to be easily configurable. + +Components have configurable fields, which can contain either generic Python +objects or nested sub-components. These are declared with class-level Python +type annotations, in the same way that fields of +[dataclasses](https://docs.python.org/3/library/dataclasses.html) are declared. +After instantiation, components are 'configured' with a configuration +dictionary; this process automatically injects the correct field values into the +component and all sub-components. Component fields can have defaults set, either +in the class definition or passed via `__init__`, but field values passed to +`configure` will always take precedence over these default values. + +If a nested sub-component declares a field with the same name as a field in one +of its ancestors, it will receive the same configured field value as the parent +does. Howevever, configuration is scoped: if the field on the child, or on a +_closer anscestor_, is configured with a different value, then that value will +override the one from the original parent. + +Configuration can be interactive. In this case, the method will prompt for +missing fields via the CLI. 
+ +The following example illustrates the configuration mechanism with and +configuration scoping: + +``` +@component +class A: + x: int + z: float + +@component +class B: + a: A + y: str = "foo" + +@component +class C: + b: B + x: int + z: float = 3.14 + +c = C() +configure( + c, + { + "x": 5, # (1) + "b.x": 10, # (2) + "b.a.x": 15, # (3) + + "b.y": "foo", # (4) + + "b.z": 2.71 # (5) + } +) +print(c) + +>> C( + b = B( + a = A( + x = 15, # (3) overrides (2) overrides (1) + z = 2.71 # Inherits from parent: (5) + ), + y = "foo" # (4) overrides the default + ), + x = 5, # Only (1) applies + z = 3.14 # The default is taken + ) +``` +""" + +import inspect +from typing import Any, Dict, Optional + +from prompt_toolkit import print_formatted_text +from typeguard import check_type + +from zookeeper.core.utils import ( + convert_to_snake_case, + prompt_for_component_subclass, + prompt_for_value, + type_name_str, +) + +try: # pragma: no cover + from colorama import Fore + + YELLOW, GREEN, RED, RESET = Fore.YELLOW, Fore.GREEN, Fore.RED, Fore.RESET +except ImportError: # pragma: no cover + YELLOW = GREEN = RED = RESET = "" + + +def is_component_class(cls): + try: + return "__component_name__" in cls.__dict__ + except AttributeError: + return False + + +def generate_subclasses(cls): + """Recursively find subclasses of `cls`.""" + + if not inspect.isclass(cls): + return + yield cls + for s in cls.__subclasses__(): + yield from generate_subclasses(s) + + +def generate_component_subclasses(cls): + """Find component subclasses of `cls`.""" + + for subclass in generate_subclasses(cls): + if is_component_class(subclass) and not inspect.isabstract(subclass): + yield subclass + + +##################### +# Component fields. 
#####################


class Field:
    """Base record describing one annotated field of a component.

    Args:
        annotated_type: The type annotation declared for the field.
    """

    def __init__(self, annotated_type):
        self.annotated_type = annotated_type


class EmptyField(Field):
    """A field which has neither a default nor a configured value."""


class InheritedField(Field):
    """A field whose value is read from the parent component.

    Args:
        annotated_type: The type annotation declared for the field.
        is_overriden: Whether the ancestor's value came from explicit
            configuration rather than from a default.
    """

    def __init__(self, annotated_type, is_overriden=False):
        super().__init__(annotated_type)
        self.is_overriden = is_overriden


class NonEmptyField(Field):
    """A field whose value is stored directly on the component instance.

    Args:
        annotated_type: The type annotation declared for the field.
        is_overriden: Flag recorded alongside the field; `set_field_value`
            always stores `False` here.
    """

    def __init__(self, annotated_type, is_overriden):
        super().__init__(annotated_type)
        self.is_overriden = is_overriden


class EmptyFieldError(AttributeError):
    """Raised when a field without any value is accessed.

    Args:
        component: The component instance (or class) owning the field; its
            `__component_name__` is used in the message.
        field_name: The name of the field that has no value.
    """

    def __init__(self, component, field_name):
        message = (
            f"The component `{component.__component_name__}` has no default or "
            f"configured value for field `{field_name}`. Please configure the "
            "component to provide a value."
        )
        super().__init__(message)


# Constants which are used internally during component configuration. They are
# used as placeholders to indicate to a nested sub-component that an ancestor
# component has a value for a given field name.
OVERRIDEN_CONF_VALUE = object()
NON_OVERRIDEN_CONF_VALUE = object()


def set_field_value(instance, name, value):
    """Record `value` for the field `name` of an unconfigured component.

    Three cases are distinguished:

    * `value` is `OVERRIDEN_CONF_VALUE`: the field becomes an
      `InheritedField` marked as overriden.
    * `value` is `NON_OVERRIDEN_CONF_VALUE`: the field becomes an
      `InheritedField` only if it is currently empty; an existing value is
      left untouched.
    * anything else: the raw value is stored on the instance and the field
      becomes a `NonEmptyField`.

    Args:
        instance: A component instance that has not yet been configured.
        name: The name of a field declared on `instance`.
        value: Either a raw value or one of the two placeholder constants.
    """

    assert not instance.__component_configured__
    assert name in instance.__component_fields__
    field = instance.__component_fields__[name]

    # The placeholders are unique sentinel objects, so they must be matched by
    # identity. Using `==` would defer to the `__eq__` of arbitrary user values
    # (e.g. array-like objects), which can give wrong answers or raise.
    if value is OVERRIDEN_CONF_VALUE:
        instance.__component_fields__[name] = InheritedField(
            annotated_type=field.annotated_type, is_overriden=True
        )
    elif value is NON_OVERRIDEN_CONF_VALUE:
        if isinstance(field, EmptyField):
            instance.__component_fields__[name] = InheritedField(
                annotated_type=field.annotated_type, is_overriden=False
            )
    else:
        # Bypass the component's own `__setattr__`, which prohibits setting
        # field values directly.
        object.__setattr__(instance, name, value)
        instance.__component_fields__[name] = NonEmptyField(
            annotated_type=field.annotated_type, is_overriden=False
        )


####################################
# Component class method wrappers.
####################################


def init_wrapper(init_fn):
    """Wrap a class's `__init__` so that instances get component fields.

    Args:
        init_fn: The original `__init__` of the class being decorated.

    Returns:
        A replacement `__init__` accepting only keyword arguments. It calls
        `init_fn` (unless that is `object.__init__`), populates
        `__component_fields__` from the class annotations, and, when the class
        has no `__init__` of its own, treats the keyword arguments as field
        values.

    Raises:
        ValueError: If `init_fn` has a parameter (other than the first) that
            has no default and is not `*args` / `**kwargs`.
    """

    # Components need to be instantiable without arguments, so check that
    # `init_fn` does not accept any positional arguments without default values.
    parameters = inspect.signature(init_fn).parameters.values()
    for index, param in enumerate(parameters):
        if (
            index > 0
            # `Parameter.empty` is a sentinel: compare by identity so that a
            # default value with an overloaded `__eq__` is not mistaken for
            # "no default".
            and param.default is inspect.Parameter.empty
            and param.kind
            not in [inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD]
        ):
            raise ValueError(
                "The `__init__` method of a component must not accept any "
                "positional arguments, as the component configuration process "
                "requires component classes to be instantiable without arguments."
            )

    has_custom_init = init_fn is not object.__init__

    def __component_init__(instance, **kwargs):
        # Fake 'super' call.
        if has_custom_init:
            init_fn(instance, **kwargs)

        # Populate `__component_fields__` with all annotations set on this class
        # and all superclasses. We have to go through the MRO chain and collect
        # them in reverse order so that they are correctly overriden.
        annotations = {}
        for base_class in reversed(inspect.getmro(instance.__class__)):
            annotations.update(getattr(base_class, "__annotations__", {}))

        instance.__component_fields__ = {}
        # `hasattr` would go through the component's `__getattribute__`, so
        # check membership in the plain object `__dir__` instead.
        existing_names = set(object.__dir__(instance))  # type: ignore
        for name, annotated_type in annotations.items():
            if name in existing_names:
                instance.__component_fields__[name] = NonEmptyField(
                    annotated_type, is_overriden=False
                )
            else:
                instance.__component_fields__[name] = EmptyField(annotated_type)

        # Without a user-defined `__init__`, key-word arguments are interpreted
        # as field values.
        if not has_custom_init:
            for name, value in kwargs.items():
                if name in instance.__component_fields__:
                    set_field_value(instance, name, value)
                else:
                    raise ValueError(
                        f"Argument '{name}' does not correspond to any annotated field "
                        f"of '{type_name_str(instance.__class__)}'."
                    )

    return __component_init__


def dir_wrapper(dir_fn):
    """Wrap `__dir__` so that the result also contains the field names."""

    def __component_dir__(instance):
        return set(dir_fn(instance)) | set(instance.__component_fields__.keys())

    return __component_dir__


def getattribute_wrapper(getattr_fn):
    """Wrap `__getattribute__` so that field look-ups respect the field state.

    Non-field names are delegated to `getattr_fn`. For field names: an empty
    field raises `EmptyFieldError`, a non-empty field is read from the instance
    itself, and an inherited field is read from `__component_parent__`.
    """

    def __component_getattr__(instance, name):
        component_fields = object.__getattribute__(instance, "__component_fields__")
        if name not in component_fields:
            return getattr_fn(instance, name)

        field = component_fields[name]
        if isinstance(field, EmptyField):
            raise EmptyFieldError(instance, name)
        if isinstance(field, NonEmptyField):
            return object.__getattribute__(instance, name)
        if isinstance(field, InheritedField):
            return getattr(instance.__component_parent__, name)
        # Only reachable if the field record is of an unknown kind.
        raise AttributeError(name)

    return __component_getattr__


def setattr_wrapper(setattr_fn):
    """Wrap `__setattr__` so that field values cannot be assigned directly."""

    def __component_setattr__(instance, name, value):
        if name in instance.__component_fields__:
            raise ValueError(
                "Setting component field values directly is prohibited. Use Zookeeper "
                "component configuration to set field values."
            )
        return setattr_fn(instance, name, value)

    return __component_setattr__


def delattr_wrapper(delattr_fn):
    """Wrap `__delattr__` so that component fields cannot be deleted."""

    def __component_delattr__(instance, name):
        if name in instance.__component_fields__:
            raise ValueError("Deleting component fields is prohibited.")
        return delattr_fn(instance, name)

    return __component_delattr__


##################################
# Pretty string representations.
# +################################## + + +# Indent for nesting in the string representation +INDENT = " " * 4 + + +def str_key_val(key, value, color=True, single_line=False): + if is_component_class(value.__class__): + if single_line: + value = repr(value) + else: + value = f"\n{INDENT}".join(str(value).split("\n")) + elif callable(value): + value = "" + elif type(value) == str: + value = f'"{value}"' + space = "" if single_line else " " + return ( + f"{YELLOW}{key}{RESET}{space}={space}{YELLOW}{value}{RESET}" + if color + else f"{key}{space}={space}{value}" + ) + + +def __component_repr__(instance): + fields = ", ".join( + [ + str_key_val( + field_name, getattr(instance, field_name), color=False, single_line=True + ) + for field_name in sorted(instance.__component_fields__) + ] + ) + return f"{instance.__class__.__name__}({fields})" + + +def __component_str__(instance): + fields = f",\n{INDENT}".join( + [ + str_key_val(field_name, getattr(instance, field_name)) + for field_name in sorted(instance.__component_fields__) + ] + ) + return f"{instance.__class__.__name__}(\n{INDENT}{fields}\n)" + + +####################### +# Exported functions. # +####################### + + +def component(cls): + """A decorater which turns a class into a Zookeeper component.""" + + if not inspect.isclass(cls): + raise ValueError("Only classes can be decorated with @component.") + + if inspect.isabstract(cls): + raise ValueError("Abstract classes cannot be decorated with @component.") + + if is_component_class(cls): + raise ValueError( + f"The class {cls} is already a component; the @component decorator cannot " + "be applied again." + ) + + cls.__component_name__ = convert_to_snake_case(cls.__name__) + cls.__component_parent__ = None + cls.__component_configured__ = False + cls.__component_fields__ = {} + + # Override `__getattribute__`, `__setattr__`, and `__delattr__` to correctly + # manage getting, setting, and deleting component fields. 
+ cls.__getattribute__ = getattribute_wrapper(cls.__getattribute__) # type: ignore + cls.__setattr__ = setattr_wrapper(cls.__setattr__) + cls.__delattr__ = delattr_wrapper(cls.__delattr__) + + # Override `__dir__` so that field names are included. + cls.__dir__ = dir_wrapper(cls.__dir__) + + # Override `__init__` to perform component initialisation and (potentially) + # set key-word args as field values. + cls.__init__ = init_wrapper(cls.__init__) + + # Components should have nice `__str__` and `__repr__` methods. + cls.__str__ = __component_str__ + cls.__repr__ = __component_repr__ + + return cls + + +def configure( + instance, + conf: Dict[str, Any], + name: Optional[str] = None, + interactive: bool = False, +): + """ + Configure the component instance with parameters from the `conf` dict. + + Configuration passed through `conf` takes precedence over and will + overwrite any values already set on the instance - either class defaults + or those set in `__init__`. + """ + + # Configuration can only happen once. + if instance.__component_configured__: + raise ValueError( + f"Component '{instance.__component_name__}' has already been configured." + ) + + if name is not None: + instance.__component_name__ = name + + # Set the correct value for each field. + for field_name, field in instance.__component_fields__.items(): + full_name = f"{instance.__component_name__}.{field_name}" + field_type_name = ( + field.annotated_type.__name__ + if inspect.isclass(field.annotated_type) + else str(field.annotated_type) + ) + component_subclasses = list(generate_component_subclasses(field.annotated_type)) + + if field_name in conf: + field_value = conf[field_name] + # The configuration value could be a string specifying a component + # class to instantiate. 
+ if len(component_subclasses) > 0 and isinstance(field_value, str): + for subclass in component_subclasses: + if ( + field_value == subclass.__name__ + or field_value == subclass.__qualname__ + or convert_to_snake_case(field_value) + == convert_to_snake_case(subclass.__name__) + ): + field_value = subclass() + break + + set_field_value(instance, field_name, field_value) + + # If this is a 'raw' value, add a placeholder to `conf` so that it + # gets picked up correctly in sub-components. + if ( + field_value != OVERRIDEN_CONF_VALUE + and field_value != NON_OVERRIDEN_CONF_VALUE + ): + conf[field_name] = OVERRIDEN_CONF_VALUE + + # If there's no config value but a value is already set on the instance, + # we only need to add a placeholder to `conf` to make sure that the + # value will be accessible from sub-components. `hasattr` isn't safe so + # we have to check membership directly. + elif field_name in object.__dir__(instance): # type: ignore + conf[field_name] = NON_OVERRIDEN_CONF_VALUE + + # If there is only one concrete component subclass of the annotated + # type, we assume the user must intend to use that subclass, and so + # instantiate and use an instance automatically. + elif len(component_subclasses) == 1: + component_cls = list(component_subclasses)[0] + print_formatted_text( + f"'{type_name_str(component_cls)}' is the only concrete component " + f"class that satisfies the type of the annotated field '{full_name}'. " + "Using an instance of this class by default." + ) + # This is safe because we don't allow `__init__` to have any + # positional arguments without defaults. + field_value = component_cls() + + set_field_value(instance, field_name, field_value) + + # Add a placeholder to `conf` to so that this value can be accessed + # by sub-components. + conf[field_name] = NON_OVERRIDEN_CONF_VALUE + + # If we are running interactively, prompt for a value. 
+ elif interactive: + if len(component_subclasses) > 0: + component_cls = prompt_for_component_subclass( + full_name, component_subclasses + ) + # This is safe because we don't allow `__init__` to have any + # positional arguments without defaults. + field_value = component_cls() + else: + field_value = prompt_for_value(full_name, field.annotated_type) + + set_field_value(instance, field_name, field_value) + + # Add a placeholder to `conf` so that this value can be accessed by + # sub-components. + conf[field_name] = OVERRIDEN_CONF_VALUE + + # Otherwise, raise an appropriate error. + else: + if len(component_subclasses) > 0: + raise ValueError( + f"Annotated field '{full_name}' of type '{field_type_name}' " + f"has no configured value. Please configure '{full_name}' with " + f"one of the following component subclasses of '{field_type_name}':" + + "\n ".join( + [""] + list(type_name_str(c) for c in component_subclasses) + ) + ) + raise ValueError( + "No configuration value found for annotated field " + f"'{full_name}' of type '{field_type_name}'." + ) + + # Recursively configure any sub-components. + for field_name, field_type in instance.__component_fields__.items(): + field_value = getattr(instance, field_name) + full_name = f"{instance.__component_name__}.{field_name}" + + if ( + is_component_class(field_value.__class__) + and not field_value.__component_configured__ + ): + # Set the component parent so that inherited fields function + # correctly. + field_value.__component_parent__ = instance + + # Configure the nested sub-component. The configuration we use + # consists of all non-scoped keys and any keys scoped to + # `field_name`, where the keys scoped to `field_name` override the + # non-scoped keys. + non_scoped_conf = {a: b for a, b in conf.items() if "." 
not in a} + field_name_scoped_conf = { + a[len(f"{field_name}.") :]: b + for a, b in conf.items() + if a.startswith(f"{field_name}.") + } + nested_conf = {**non_scoped_conf, **field_name_scoped_conf} + configure(field_value, nested_conf, name=full_name, interactive=interactive) + + # Type check all fields. + for field_name, field in instance.__component_fields__.items(): + assert field_name in instance.__component_fields__ + field_value = getattr(instance, field_name) + try: + check_type(field_name, field_value, field.annotated_type) + # Because boolean `True` and `False` are coercible to ints and + # floats, `typeguard.check_type` doesn't throw if we e.g. pass + # `True` to a value expecting a float. This would, however, likely + # be a user error, so explicitly check for this. + if field.annotated_type in [float, int] and isinstance(field_value, bool): + raise TypeError + except TypeError: + raise TypeError( + f"Attempting to set field '{instance.__component_name__}.{field_name}' " + f"which has annotated type '{type_name_str(field.annotated_type)}' " + f"with value '{field_value}'." + ) from None + + instance.__component_configured__ = True diff --git a/zookeeper/core/component_test.py b/zookeeper/core/component_test.py new file mode 100644 index 0000000..f6d0ebb --- /dev/null +++ b/zookeeper/core/component_test.py @@ -0,0 +1,351 @@ +import abc +from typing import List +from unittest.mock import patch + +import pytest +from click import unstyle + +from zookeeper.core.component import component, configure + + +@pytest.fixture +def ExampleComponentClass(): + @component + class A: + a: int + b: str = "foo" + + return A + + +def test_non_class_decorate_error(): + """ + An error should be raised when attempting to decorate a non-class object. 
+ """ + + with pytest.raises( + ValueError, match="Only classes can be decorated with @component.", + ): + + @component + def fn(): + pass + + +def test_abstract_class_decorate_error(): + """ + An error should be raised when attempting to decorate an abstract class. + """ + + with pytest.raises( + ValueError, match="Abstract classes cannot be decorated with @component.", + ): + + @component + class A(abc.ABC): + @abc.abstractmethod + def foo(self): + pass + + +def test_positional_args_init_decorate_error(): + """ + An error should be raised when attempting to decorate a class with an + `__init__` methods that takes positional arguments. + """ + + with pytest.raises( + ValueError, + match=r"^The `__init__` method of a component must not accept any positional arguments", + ): + + @component + class A: + def __init__(self, a, b=5): + self.a = a + self.b = b + + +def test_existing_init(): + """ + If the decorated class has an `__init__` method, that method should be + called on instantiation. + """ + + @component + class A: + def __init__(self, foo="bar"): + self.foo = foo + + assert A().foo == "bar" + + +def test_no_init(ExampleComponentClass): + """ + If the decorated class does not have an `__init__` method, the decorated + class should define an `__init__` which accepts kwargs to set field values, + and raises appropriate arguments when other values are passed. 
+ """ + + x = ExampleComponentClass(a=2) + assert x.a == 2 + assert x.b == "foo" + + x = ExampleComponentClass(a=0, b="bar") + assert x.a == 0 + assert x.b == "bar" + + with pytest.raises( + TypeError, + match=r"__component_init__\(\) takes 1 positional argument but 2 were given", + ): + ExampleComponentClass("foobar") + + with pytest.raises( + ValueError, + match=r"^Argument 'some_other_field_name' does not correspond to any annotated field of", + ): + ExampleComponentClass(some_other_field_name=0) + + +def test_configure_override_field_values(ExampleComponentClass): + """Component fields should be overriden correctly.""" + + x = ExampleComponentClass() + configure(x, {"a": 0, "b": "bar"}) + assert x.a == 0 + assert x.b == "bar" + + +def test_configure_scoped_override_field_values(): + """Field overriding should respect component scope.""" + + @component + class Child: + a: int + b: str + c: List[float] + + @component + class Parent: + b: str = "bar" + child: Child = Child() + + @component + class GrandParent: + a: int + b: str + parent: Parent = Parent() + + grand_parent = GrandParent() + + configure( + grand_parent, + { + "a": 10, + "parent.a": 15, + "b": "foo", + "parent.child.b": "baz", + "c": [1.5, -1.2], + "parent.c": [-17.2], + "parent.child.c": [0, 4.2], + }, + ) + + # The grand-parent `grand_parent` should have the value `a` = 10. Even + # though a config value is declared for its scope, `grand_parent.child` + # should have no `a` value set, as it doesn't declare `a` as a field. + # Despite this, `grand_parent.parent.child` should get the value `a` = 15, + # as it lives inside the configuration scope of its parent, + # `grand_parent.parent`. + assert grand_parent.a == 10 + assert not hasattr(grand_parent.parent, "a") + assert grand_parent.parent.child.a == 15 + + # `b` is declared as a field at all three levels. The 'baz' value should be + # scoped only to the child, so 'foo' will apply to both the parent and + # grand-parent. 
+ assert grand_parent.b == "foo" + assert grand_parent.parent.b == "foo" + assert grand_parent.parent.child.b == "baz" + + # `c` is declared as a field only in the child. The more specific scopes + # override the more general. + assert grand_parent.parent.child.c == [0, 4.2] + + +def test_configure_automatically_instantiate_subcomponent(): + """ + If there is only a single component subclass of a field type, an instance of + the class should be automatically instantiated during configuration. + """ + + class AbstractChild: + pass + + @component + class Child1(AbstractChild): + pass + + @component + class Parent: + child: AbstractChild + + # There is only a single defined component subclass of `AbstractChild`, + # `Child1`, so we should be able to configure an instance of `Parent` and + # have an instance automatically instantiated in the process. + + p = Parent() + configure(p, {}) + assert isinstance(p.child, Child1) + + @component + class Child2(AbstractChild): + pass + + # Now there is another defined component subclass of `AbstractChild`, + # so configuration will now fail (as we cannot choose between the two). + + p = Parent() + with pytest.raises( + ValueError, + match="Annotated field 'parent.child' of type 'AbstractChild' has no configured value. Please configure 'parent.child' with one of the following component subclasses", + ): + configure(p, {}) + + +def test_configure_non_interactive_missing_field_value(ExampleComponentClass): + """ + When not configuring interactively, an error should be raised if a field has + neither a default nor a configured value. 
+ """ + + with pytest.raises( + ValueError, + match=r"^No configuration value found for annotated field 'FAKE_NAME.a' of type 'int'.", + ): + configure(ExampleComponentClass(), {"b": "bar"}, name="FAKE_NAME") + + +def test_configure_interactive_prompt_missing_field_value(ExampleComponentClass): + """ + When configuring interactively, fields without default or configured values + should prompt for value input through the CLI. + """ + + x = ExampleComponentClass() + a_value = 42 + + with patch("zookeeper.core.utils.prompt", return_value=str(a_value)) as prompt: + configure(x, {"b": "bar"}, name="FAKE_NAME", interactive=True) + + assert x.a == a_value + assert x.b == "bar" + prompt.assert_called_once() + + +def test_configure_interactive_prompt_for_subcomponent_choice(): + """ + When configuring interactively, sub-component fields without default or + configured values should prompt for a choice of subcomponents to instantiate + through the CLI. + """ + + class AbstractChild: + pass + + @component + class Child1(AbstractChild): + pass + + @component + class Child2(AbstractChild): + pass + + class Child3_Abstract(AbstractChild): + pass + + @component + class Child3A(Child3_Abstract): + pass + + @component + class Child3B(Child3_Abstract): + pass + + @component + class Parent: + child: AbstractChild + + # The prompt lists the concrete component subclasses (alphabetically) and + # asks for an an integer input corresponding to an index in this list. 
+ + # We expect the list to therefore be as follows (`AbstractChild` and + # `Child3_Abstract` are excluded because although they live in the subclass + # hierarchy, neither is a component): + expected_class_choices = [Child1, Child2, Child3A, Child3B] + + for i, expected_choice in enumerate(expected_class_choices): + p = Parent() + + with patch("zookeeper.core.utils.prompt", return_value=str(i + 1)) as prompt: + configure(p, {}, interactive=True) + + assert isinstance(p.child, expected_choice) + prompt.assert_called_once() + + +def test_str_and_repr(): + """ + `__str__` and `__repr__` should give formatted strings that represent nested + components nicely. + """ + + @component + class Child: + a: int + b: str + c: List[float] + + @component + class Parent: + b: str = "bar" + child: Child = Child() + + p = Parent() + + configure(p, {"a": 10, "b": "foo", "c": [1.5, -1.2]}, name="parent") + + assert ( + unstyle(repr(p)) + == """Parent(b="foo", child=Child(a=10, b="foo", c=[1.5, -1.2]))""" + ) + assert ( + unstyle(str(p)) + == """Parent( + b = "foo", + child = Child( + a = 10, + b = "foo", + c = [1.5, -1.2] + ) +)""" + ) + + +def test_type_check(ExampleComponentClass): + """During configuration we should type-check all field values.""" + + # Attempting to set an int field with a float. + with pytest.raises( + TypeError, + match=r"^Attempting to set field 'x.a' which has annotated type 'int' with value '4.5'.$", + ): + configure(ExampleComponentClass(), {"a": 4.5}, name="x") + + # Attempting to set a str field with a bool. 
+ with pytest.raises( + TypeError, + match=r"^Attempting to set field 'x.b' which has annotated type 'str' with value 'True'.$", + ): + configure(ExampleComponentClass(), {"a": 3, "b": True}, name="x") diff --git a/zookeeper/core/task.py b/zookeeper/core/task.py new file mode 100644 index 0000000..6f33c00 --- /dev/null +++ b/zookeeper/core/task.py @@ -0,0 +1,56 @@ +import inspect + +import click + +from zookeeper.core.cli import ConfigParam, cli +from zookeeper.core.component import component, configure +from zookeeper.core.utils import convert_to_snake_case + + +def task(cls): + """ + A decorator which turns a class into a Zookeeper task, which is a Zookeeper + method with an argument-less `run` method. + + Tasks are runnable through the CLI. Upon execution, the task is instantiated + and all component fields are configured using configuration passed as CLI + arguments of the form `field_name=field_value`, and then the `run` method is + called. + """ + + cls = component(cls) + + if not (hasattr(cls, "run") and callable(cls.run)): + raise ValueError("Classes decorated with @task must define a `run` method.") + + # Enforce argument-less `run` + + call_args = inspect.signature(cls.run).parameters + if len(call_args) > 1 or len(call_args) == 1 and "self" not in call_args: + raise ValueError( + "A task class must define a `run` method taking no arguments except " + f"`self`, which runs the task, but `{cls.__name__}.run` accepts arguments " + f"{call_args}." + ) + + # Register a CLI command to run the task. + + task_name = convert_to_snake_case(cls.__name__) + if task_name in cli.commands: + raise ValueError( + f"Task naming conflict. Task with name '{task_name}' already registered. " + "Note that the task name is the name of the class that the @task decorator " + "is applied to, normalised to 'snake case', e.g. 'FooBarTask' -> " + "'foo_bar_task'." 
+ ) + + @cli.command(task_name) + @click.argument("config", type=ConfigParam(), nargs=-1) + @click.option("-i", "--interactive", is_flag=True, default=False) + def command(config, interactive): + config = {k: v for k, v in config} + task_instance = cls() + configure(task_instance, config, interactive=interactive) + task_instance.run() + + return cls diff --git a/zookeeper/core/task_test.py b/zookeeper/core/task_test.py new file mode 100644 index 0000000..b34cbae --- /dev/null +++ b/zookeeper/core/task_test.py @@ -0,0 +1,45 @@ +import pytest + +from zookeeper.core.task import task + + +def test_with_argumentless_run(): + """Tasks with argument-less `run` should not cause errors.""" + + @task + class T1: + def run(self): + pass + + @task + class T2: + @classmethod + def run(cls): + pass + + @task + class T3: + @staticmethod + def run(): + pass + + +def test_no_run_error(): + """Tasks without `run` should cause an error.""" + + with pytest.raises( + ValueError, match="Classes decorated with @task must define a `run` method.", + ): + + @task + class T: + pass + + +def test_run_with_args_error(): + """ + Defining a subclass which has a `run` that takes any arguments should raise + a ValueError. 
+ """ + + pass diff --git a/zookeeper/core/utils.py b/zookeeper/core/utils.py new file mode 100644 index 0000000..0a6b1aa --- /dev/null +++ b/zookeeper/core/utils.py @@ -0,0 +1,71 @@ +import re +from ast import literal_eval +from typing import Sequence, Type + +from prompt_toolkit import print_formatted_text, prompt + + +def type_name_str(type) -> str: + try: + if hasattr(type, "__qualname__"): + return str(type.__qualname__) + if hasattr(type, "__name__"): + return str(type.__name__) + return str(type) + except Exception: + return "" + + +def convert_to_snake_case(name): + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def parse_value_from_string(string: str): + try: + value = literal_eval(string) + except (ValueError, SyntaxError): + # Parse as string if above raises a ValueError or SyntaxError. + value = str(string) + except Exception: + raise ValueError(f"Could not parse '{string}'.") + return value + + +def prompt_for_value(field_name: str, field_type): + """Promt the user to input a value for the parameter `field_name`.""" + + print_formatted_text( + f"No value found for field '{field_name}' of type '{field_type}'. 
" + "Please enter a value for this parameter:" + ) + response = prompt("> ") + while response == "": + print_formatted_text(f"No input received, please enter a value:") + response = prompt("> ") + return parse_value_from_string(response) + + +def prompt_for_component_subclass(component_name: str, classes: Sequence[Type]) -> Type: + """Prompt the user to choose a compnent subclass from `classes`.""" + + print_formatted_text(f"No instance found for nested component '{component_name}'.") + choices = {cls.__qualname__: cls for cls in classes} + names = sorted(list(choices.keys())) + print_formatted_text( + f"Please choose from one of the following component subclasses to instantiate:\n" + + "\n".join([f"{i + 1}) {o}" for i, o in enumerate(names)]) + ) + response = prompt("> ") + while True: + try: + response = int(response) - 1 + except ValueError: + response = -1 + if 0 <= response < len(names): + break + print_formatted_text( + f"Invalid input. Please enter a number between 1 and {len(names)}:" + ) + response = prompt("> ") + return choices[names[response]] diff --git a/zookeeper/task.py b/zookeeper/task.py deleted file mode 100644 index ea66d57..0000000 --- a/zookeeper/task.py +++ /dev/null @@ -1,27 +0,0 @@ -from abc import ABC, abstractmethod -from inspect import signature - -from zookeeper.component import Component - - -class Task(Component, ABC): - """ - A 'Task' component that performs a task on `run`. - """ - - def __init_subclass__(cls, *args, **kwargs): - super().__init_subclass__(*args, **kwargs) - # Enforce argument-less `run` - if hasattr(cls, "run"): - call_args = signature(cls.run).parameters - if len(call_args) == 0 or len(call_args) == 1 and "self" in call_args: - return - raise ValueError( - "A `Task` subclass must define a `run` method taking no positional " - f"arguments which runs the task, but {cls.__name__}.run accepts " - f"positional arguments {call_args}." 
- ) - - @abstractmethod - def run(self): - pass diff --git a/zookeeper/task_test.py b/zookeeper/task_test.py deleted file mode 100644 index 400d446..0000000 --- a/zookeeper/task_test.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -from zookeeper.task import Task - - -def test_override_call_with_args_error(): - - # Defining a subclass which overrides `run` with positional arguments - # should raise a ValueError. - with pytest.raises( - ValueError, - match=r"^A `Task` subclass must define a `run` method taking no positional", - ): - - class J(Task): - def run(self, a, b): - pass - - # Overriding `run` without positional arguments should not raise an - # error. - - class J1(Task): - def run(self): - pass - - # The same should be true if `run` is a static method or class method. - - class J2(Task): - @classmethod - def run(cls): - pass - - class J3(Task): - @staticmethod - def run(): - pass diff --git a/zookeeper/tf/__init__.py b/zookeeper/tf/__init__.py new file mode 100644 index 0000000..32dea3f --- /dev/null +++ b/zookeeper/tf/__init__.py @@ -0,0 +1,13 @@ +from zookeeper.tf.dataset import Dataset, MultiTFDSDataset, TFDSDataset +from zookeeper.tf.experiment import Experiment +from zookeeper.tf.model import Model +from zookeeper.tf.preprocessing import Preprocessing + +__all__ = [ + "Dataset", + "Experiment", + "Model", + "MultiTFDSDataset", + "Preprocessing", + "TFDSDataset", +] diff --git a/zookeeper/dataset.py b/zookeeper/tf/dataset.py similarity index 79% rename from zookeeper/dataset.py rename to zookeeper/tf/dataset.py index af0b3ea..d39eb52 100644 --- a/zookeeper/dataset.py +++ b/zookeeper/tf/dataset.py @@ -4,10 +4,8 @@ import tensorflow as tf import tensorflow_datasets as tfds -from zookeeper.component import Component - -class Dataset(Component, ABC): +class Dataset(ABC): """ An abstract base class to encapsulate a dataset. Concrete sub-classes must implement the `train` method, and optionally the `validation` method. 
@@ -30,7 +28,7 @@ def validation(self, decoders=None) -> Tuple[tf.data.Dataset, int]: """ raise ValueError( - f"Dataset '{self.__component_name__}' is not configured with validation " + f"Dataset '{self.__class__.__name__}' is not configured with validation " "data." ) @@ -65,37 +63,6 @@ class TFDSDataset(Dataset): train_split: str validation_split: Optional[str] = None - def validate_configuration(self): - super().validate_configuration() - - # Check that the name corresponds to a valid TensorFlow dataset. - builder_names = tfds.list_builders() - if self.name.split(":")[0].split("/")[0] not in builder_names: - raise ValueError( - f"'{self.__component_name__}.name' has invalid value '{self.name}'. " - "Valid dataset names:\n " + ",\n ".join(builder_names) - ) - - # Check that the `train_split` is valid. - if self.train_split is None or any( - s not in self.splits for s in base_splits(self.train_split) - ): - raise ValueError( - f"'{self.__component_name__}.train_split' has invalid value " - f"'{self.train_split}'. Valid values:\n " - + ",\n ".join(self.splits.keys()) - ) - - # Check that `validation_split` is valid (`None` is allowed). - if self.validation_split is not None and any( - s not in self.splits for s in base_splits(self.train_split) - ): - raise ValueError( - f"'{self.__component_name__}.train_split' has invalid value " - f"'{self.train_split}'. 
Valid values:\n " - + ",\n ".join([None] + self.splits.keys()) - ) - @property def info(self): if not hasattr(self, "_info"): diff --git a/zookeeper/experiment.py b/zookeeper/tf/experiment.py similarity index 73% rename from zookeeper/experiment.py rename to zookeeper/tf/experiment.py index f5710c2..7e9dfb4 100644 --- a/zookeeper/experiment.py +++ b/zookeeper/tf/experiment.py @@ -2,13 +2,12 @@ from tensorflow import keras -from zookeeper.dataset import Dataset -from zookeeper.model import Model -from zookeeper.preprocessing import Preprocessing -from zookeeper.task import Task +from zookeeper.tf.dataset import Dataset +from zookeeper.tf.model import Model +from zookeeper.tf.preprocessing import Preprocessing -class Experiment(Task): +class Experiment: """ A wrapper around a Keras experiment. Subclasses must implement their training loop in `run`. @@ -23,7 +22,7 @@ class Experiment(Task): epochs: int batch_size: int metrics: List[Union[keras.metrics.Metric, Callable, str]] = [] - loss: Union[keras.losses.Loss, str] + loss: Optional[Union[keras.losses.Loss, str]] optimizer: Union[keras.optimizers.Optimizer, str] learning_rate_schedule: Optional[Callable] = None callbacks: List[Union[keras.callbacks.Callback, Callable]] = [] diff --git a/zookeeper/model.py b/zookeeper/tf/model.py similarity index 82% rename from zookeeper/model.py rename to zookeeper/tf/model.py index 031ae1b..abe3c98 100644 --- a/zookeeper/model.py +++ b/zookeeper/tf/model.py @@ -3,10 +3,8 @@ from tensorflow import keras -from zookeeper.component import Component - -class Model(Component, ABC): +class Model(ABC): """ A wrapper around a Keras model. Subclasses must implement `build` to build and return a Keras model. 
diff --git a/zookeeper/preprocessing.py b/zookeeper/tf/preprocessing.py similarity index 96% rename from zookeeper/preprocessing.py rename to zookeeper/tf/preprocessing.py index b6a62bf..470d5fc 100644 --- a/zookeeper/preprocessing.py +++ b/zookeeper/tf/preprocessing.py @@ -4,8 +4,6 @@ import tensorflow as tf -from zookeeper.component import Component - def pass_training_kwarg(function, training=False): if "training" in signature(function).parameters: @@ -13,7 +11,7 @@ def pass_training_kwarg(function, training=False): return function -class Preprocessing(Component): +class Preprocessing: """A wrapper around `tf.data` preprocessing.""" def input(self, data, training) -> tf.Tensor: diff --git a/zookeeper/utils.py b/zookeeper/utils.py deleted file mode 100644 index 3325ae3..0000000 --- a/zookeeper/utils.py +++ /dev/null @@ -1,94 +0,0 @@ -import re -from ast import literal_eval -from inspect import isabstract -from typing import Set - -from prompt_toolkit import print_formatted_text, prompt - - -def type_name_str(type) -> str: - try: - if hasattr(type, "__qualname__"): - return str(type.__qualname__) - if hasattr(type, "__name__"): - return str(type.__name__) - return str(type) - except Exception: - return "" - - -def convert_to_snake_case(name): - s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - - -def get_concrete_subclasses(cls) -> Set[type]: - """Return a set of all non-abstract classes which inherit from `cls`.""" - subclasses = set([cls] if not isabstract(cls) else []) - for s in cls.__subclasses__(): - if not isabstract(s): - subclasses.add(s) - subclasses.update(get_concrete_subclasses(s)) - return subclasses - - -def parse_value_from_string(string: str): - try: - value = literal_eval(string) - except (ValueError, SyntaxError): - # Parse as string if above raises a ValueError or SyntaxError. 
- value = str(string) - except Exception: - raise ValueError(f"Could not parse '{string}'.") - return value - - -def promt_for_param_value(param_name: str, param_type): - """Promt the user to input a value for the parameter `param_name`.""" - print_formatted_text( - f"No value found for parameter '{param_name}' of type '{param_type}'. " - "Please enter a value for this parameter:" - ) - response = prompt("> ") - while response == "": - print_formatted_text(f"No input received, please enter a value:") - response = prompt("> ") - return parse_value_from_string(response) - - -def prompt_for_component(component_name: str, component_cls: type) -> type: - print_formatted_text( - f"No instance found for nested component '{component_name}' of type " - f"'{component_cls.__qualname__}'." - ) - - component_options = { - cls.__qualname__: cls for cls in get_concrete_subclasses(component_cls) - } - - if len(component_options) == 0: - raise ValueError( - f"'{component_cls}' has no defined concrete subclass implementation." - ) - - component_names = sorted(list(component_options.keys())) - - print_formatted_text( - f"Please choose from one of the following concrete subclasses of " - f"'{component_cls.__qualname__}' to instantiate:\n" - + "\n".join([f"{i + 1}) {o}" for i, o in enumerate(component_names)]) - ) - response = prompt("> ") - while True: - try: - response = int(response) - 1 - except ValueError: - response = -1 - if 0 <= response < len(component_names): - break - print_formatted_text( - f"Invalid input. Please enter a number between 1 and {len(component_names)}:" - ) - response = prompt("> ") - - return component_options[component_names[response]]