-
Notifications
You must be signed in to change notification settings - Fork 24
Block design #24
Block design #24
Changes from 35 commits
72a2bd9
6d6f4d9
7b2d811
4ee38bb
aea2444
d859057
b84e7de
3e1fbc0
aec74ff
8a215de
52d26f6
94ffa99
a19116f
e48ec42
0ebe2a0
1d654dd
6d51079
0d17bb6
fba82b6
e7db4dd
af01c4e
a94c7eb
d06a26d
b7fcb59
8b03d98
cb29984
b49ffa1
c7a49a6
80b426f
e2bc1dc
79278fd
fed9702
030caad
ed9e13b
f78b81d
13701af
0953d9b
e1a5c23
fb48faf
b4088ce
64f971a
41c9144
6f5631a
41009e9
7fb7f67
63d4ea6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,144 @@ | ||
| # Standard | ||
| from abc import ABC, abstractmethod | ||
| from typing import Any, Dict, Iterable, List, Optional, Union | ||
|
|
||
| # Third Party | ||
| from datasets import Dataset | ||
| import pandas as pd | ||
|
|
||
| DATASET_ROW_TYPE = Union[Dict, pd.Series] | ||
| DATASET_TYPE = Union[Iterable[DATASET_ROW_TYPE], pd.DataFrame, Dataset] | ||
|
|
||
|
|
||
| class BaseBlock(ABC): | ||
| """Base Class for all Blocks""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| name: str = None, | ||
| arg_fields: List[str] = None, | ||
| kwarg_fields: List[str] = None, | ||
| result_field: str = None, | ||
| ) -> None: | ||
|
|
||
| if not (isinstance(arg_fields, list) or arg_fields is None): | ||
|
||
| raise TypeError(f"arg_fields must be of type 'list'") | ||
|
||
| if not (isinstance(kwarg_fields, list) or kwarg_fields is None): | ||
| raise TypeError(f"kwarg_fields must be of type 'list'") | ||
| if not (isinstance(result_field, str) or result_field is None): | ||
| raise TypeError(f"result_field must be of type 'str'") | ||
|
|
||
| self._name = name | ||
|
|
||
| self._arg_fields = arg_fields | ||
| self._kwarg_fields = kwarg_fields | ||
| self._result_field = result_field | ||
|
|
||
| @property | ||
| def name(self): | ||
| return self._name | ||
|
|
||
| @property | ||
| def arg_fields(self): | ||
| return self._arg_fields | ||
|
|
||
| @property | ||
| def kwarg_fields(self): | ||
| return self._kwarg_fields | ||
|
|
||
| @property | ||
| def result_field(self): | ||
| return self._result_field | ||
|
|
||
| def get_args_kwargs( | ||
| self, | ||
| inp: BLOCK_ROW_TYPE, | ||
| arg_fields: Optional[List[str]] = None, | ||
| kwarg_fields: Optional[List[str]] = None, | ||
| ): | ||
|
|
||
| arg_fields = arg_fields or self.arg_fields or [] | ||
| kwarg_fields = kwarg_fields or self.kwarg_fields or [] | ||
|
|
||
| if isinstance(inp, (dict, pd.DataFrame, Dataset)): | ||
| args = [inp.get(arg) for arg in arg_fields] | ||
| kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} | ||
| else: | ||
| raise TypeError(f"Unexpected input type: {type(inp)}") | ||
|
|
||
| return args, kwargs | ||
mvcrouse marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| def write_result( | ||
| self, | ||
| inp: BLOCK_ROW_TYPE, | ||
| res: Any, | ||
gabe-l-hart marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| result_field: str = None, | ||
mvcrouse marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ): | ||
| result_field = result_field or self.result_field | ||
|
|
||
| assert result_field is not None, "Result field cannot be None!" | ||
|
|
||
| if isinstance(inp, (dict, pd.DataFrame, Dataset)): | ||
| inp[result_field] = res | ||
| else: | ||
| raise TypeError(f"Unexpected input type: {type(inp)}") | ||
mvcrouse marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @abstractmethod | ||
| def generate( | ||
| self, | ||
| inputs: BLOCK_INPUT_TYPE, | ||
| *, | ||
| arg_fields: Optional[List[str]] = None, | ||
| kwarg_fields: Optional[List[str]] = None, | ||
| result_field: Optional[str] = None, | ||
gabe-l-hart marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| **kwargs, | ||
| ): | ||
| """The generate function is the primary interface to a Block | ||
|
|
||
| args: | ||
| inputs (BLOCK_INPUT_TYPE): A block operates over a logical iterable | ||
| of rows with named columns (see BLOCK_INPUT_TYPE) | ||
|
|
||
| kwargs: | ||
| arg_fields (Optional[List[str]]): Names of fields within the rows of | ||
| the inputs that should be extracted and passed as positional | ||
| args to the underlying implementation methods. | ||
| kwarg_fields (Optional[List[str]]): Names of fields within the rows | ||
| of the inputs that should be extracted and passed as keyword | ||
| args to the underlying implementation methods. | ||
| **kwargs: Additional keyword args that may be passed to the derived | ||
| block's generate function | ||
| """ | ||
|
|
||
|
|
||
| class BaseUtilityBlock(BaseBlock): | ||
| pass | ||
|
|
||
|
|
||
| class BaseGeneratorBlock(BaseBlock): | ||
| pass | ||
|
||
|
|
||
|
|
||
| class BaseValidatorBlock(BaseBlock): | ||
| def __init__(self, filter: bool = False, **kwargs: Any) -> None: | ||
| super().__init__(**kwargs) | ||
| self._filter_invalids = filter | ||
|
|
||
| def generate( | ||
| self, | ||
| inputs: BLOCK_INPUT_TYPE, | ||
| arg_fields: Optional[List[str]] = None, | ||
mvcrouse marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| kwarg_fields: Optional[List[str]] = None, | ||
| result_field: Optional[List[str]] = None, | ||
| ): | ||
| outputs = [] | ||
| for x in inputs: | ||
| inp_args, inp_kwargs = self.get_args_kwargs(x, arg_fields, kwarg_fields) | ||
| res = self._validate(*inp_args, **inp_kwargs) | ||
| if res or not self._filter_invalids: | ||
| self.write_result(x, res, result_field) | ||
| outputs.append(x) | ||
| return outputs | ||
|
|
||
| def _validate(self, *args: Any, **kwargs: Any) -> bool: | ||
| raise NotImplementedError | ||
mvcrouse marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Related to the question below, I think the
Dicthere could be further restricted asDict[str, _something_]. If there are any restrictions on the types for the value, that_something_could be itself a bigUnionor type def, or it could just beAny.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll go with Dict[str, Any]. I frequently pass a dictionary with the SDG object I'm building up as one of the values