Skip to content

Commit bfa9e03

Browse files
AdamHillierlgeiger
authored andcommitted
Components as decorators (#84)
* Make components and tasks decorators rather than classes. * Expose CLI. * Bug fix. * Fix another bug. * Make experiment loss optional. * Fix yet another bug. * Bug fix once more. * Another bug fix. * Update defaults. * WIP * Tidy-up. * Export `configure` correctly. * Update examples. * Make linting happy. * Update zookeeper/core/component.py Co-Authored-By: Lukas Geiger <[email protected]> * Remove erroneous prints. * Review suggestion. * Move `colorama` import to top.
1 parent 3a75bc6 commit bfa9e03

21 files changed

+1159
-1032
lines changed

examples/larq_experiment.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,20 @@
77

88
import larq as lq
99
import tensorflow as tf
10-
import tensorflow_datasets as tfds
1110

12-
from zookeeper import Dataset, Experiment, Model, Preprocessing, TFDSDataset
13-
from zookeeper.cli import add_task_to_cli, cli
11+
from zookeeper import cli, component, task
12+
from zookeeper.tf import Dataset, Experiment, Model, Preprocessing, TFDSDataset
1413

1514

15+
@component
1616
class Cifar10(TFDSDataset):
1717
name = "cifar10"
1818
# CIFAR-10 has only train and test, so validate on test.
19-
train_split = tfds.Split.TRAIN
20-
validation_split = tfds.Split.TEST
19+
train_split = "train"
20+
validation_split = "test"
2121

2222

23+
@component
2324
class PadCropAndFlip(Preprocessing):
2425
pad_size: int
2526
output_size: int
@@ -44,6 +45,7 @@ def output(self, data):
4445
return data["label"]
4546

4647

48+
@component
4749
class BinaryNet(Model):
4850
dataset: Dataset
4951
preprocessing: Preprocessing
@@ -107,7 +109,7 @@ def build(self, input_shape):
107109
)
108110

109111

110-
@add_task_to_cli
112+
@task
111113
class BinaryNetCifar10(Experiment):
112114
dataset = Cifar10()
113115
preprocessing = PadCropAndFlip(pad_size=40, output_size=32)

zookeeper/__init__.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
from zookeeper.component import Component
2-
from zookeeper.dataset import Dataset, TFDSDataset
3-
from zookeeper.experiment import Experiment
4-
from zookeeper.model import Model
5-
from zookeeper.preprocessing import Preprocessing
6-
from zookeeper.task import Task
1+
from zookeeper.core import cli, component, configure, task
72

8-
__all__ = [
9-
"Component",
10-
"Dataset",
11-
"Experiment",
12-
"Model",
13-
"Preprocessing",
14-
"Task",
15-
"TFDSDataset",
16-
]
3+
__all__ = ["cli", "component", "configure", "task"]

zookeeper/cli_test.py

-63
This file was deleted.

0 commit comments

Comments
 (0)