Skip to content

Commit a4368fe

Browse files
tomvdwThe TensorFlow Datasets Authors
authored and
The TensorFlow Datasets Authors
committed
Add a mechanism to register dataset builder providers.
This will allow to register dataset builders that are defined in different source code folders. It also allows registering datasets that are defined elsewhere, e.g., on disk or in a database. PiperOrigin-RevId: 669244349
1 parent 4786ed5 commit a4368fe

File tree

3 files changed

+93
-20
lines changed

3 files changed

+93
-20
lines changed

tensorflow_datasets/core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from tensorflow_datasets.core.lazy_imports_lib import lazy_imports
3434
from tensorflow_datasets.core.load import DatasetCollectionLoader
3535
from tensorflow_datasets.core.naming import ShardedFileTemplate
36+
from tensorflow_datasets.core.registered import add_dataset_builder_provider
37+
from tensorflow_datasets.core.registered import DatasetBuilderProvider
3638
from tensorflow_datasets.core.registered import DatasetNotFoundError
3739
from tensorflow_datasets.core.sequential_writer import SequentialWriter
3840
from tensorflow_datasets.core.split_builder import SplitGeneratorLegacy as SplitGenerator
@@ -61,12 +63,14 @@ def benchmark(*args, **kwargs):
6163

6264
__all__ = [
6365
"add_data_dir",
66+
"add_dataset_builder_provider",
6467
"as_path",
6568
"BenchmarkResult",
6669
"BeamBasedBuilder",
6770
"BeamMetadataDict",
6871
"BuilderConfig",
6972
"DatasetBuilder",
73+
"DatasetBuilderProvider",
7074
"DatasetCollectionLoader",
7175
"DatasetInfo",
7276
"DatasetIdentity",

tensorflow_datasets/core/registered.py

Lines changed: 77 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
"""Access registered datasets."""
1717

18+
from __future__ import annotations
19+
1820
import abc
1921
import collections
2022
from collections.abc import Iterator
@@ -24,7 +26,7 @@
2426
import inspect
2527
import os.path
2628
import time
27-
from typing import ClassVar, Type
29+
from typing import ClassVar, Protocol, Type
2830

2931
from absl import logging
3032
from etils import epath
@@ -37,11 +39,11 @@
3739
from tensorflow_datasets.core.utils import resource_utils
3840

3941
# Internal registry containing <str registered_name, DatasetBuilder subclass>
40-
_DATASET_REGISTRY = {}
42+
_DATASET_REGISTRY: dict[str, Type[RegisteredDataset]] = {}
4143

4244
# Internal registry containing:
4345
# <str snake_cased_name, abstract DatasetBuilder subclass>
44-
_ABSTRACT_DATASET_REGISTRY = {}
46+
_ABSTRACT_DATASET_REGISTRY: dict[str, Type[RegisteredDataset]] = {}
4547

4648
# Keep track of dict[str (module name), list[DatasetBuilder]]
4749
# This is directly accessed by `tfds.community.builder_cls_from_module` when
@@ -289,6 +291,75 @@ def __init_subclass__(cls, skip_registration=False, **kwargs): # pylint: disabl
289291
_DATASET_REGISTRY[cls.name] = cls
290292

291293

294+
class DatasetBuilderProvider(Protocol):
295+
296+
def has_dataset(self, name: str) -> bool:
297+
...
298+
299+
def get_builder_cls(self, name: str) -> Type[RegisteredDataset]:
300+
...
301+
302+
303+
class SourceDirDatasetBuilderProvider(DatasetBuilderProvider):
304+
"""Provider of dataset builders that are defined in the given source code folder."""
305+
306+
def __init__(self, datasets_dir: str):
307+
self._datasets_dir = datasets_dir
308+
self._registry: dict[str, Type[RegisteredDataset]] = {}
309+
310+
@functools.cached_property
311+
def dataset_packages(self) -> dict[str, tuple[epath.Path, str]]:
312+
"""Returns existing datasets.
313+
314+
Returns:
315+
{ds_name: (pkg_path, builder_module)}.
316+
For example: {'mnist': ('/lib/tensorflow_datasets/datasets/mnist',
317+
'tensorflow_datasets.datasets.mnist.builder')}
318+
"""
319+
datasets = {}
320+
datasets_dir_path = resource_utils.tfds_path(self._datasets_dir)
321+
if not datasets_dir_path.exists():
322+
return datasets
323+
ds_dir_pkg = '.'.join(
324+
['tensorflow_datasets'] + self._datasets_dir.split(os.path.sep)
325+
)
326+
for child in datasets_dir_path.iterdir():
327+
# Except for a few exceptions, all children of datasets/ directory are
328+
# packages of datasets, no needs to check child is a directory and
329+
# contains a `builder.py` module.
330+
exceptions = [
331+
'__init__.py',
332+
]
333+
if child.name not in exceptions:
334+
pkg_path = epath.Path(datasets_dir_path) / child.name
335+
builder_module = (
336+
f'{ds_dir_pkg}.{child.name}.{child.name}{_BUILDER_MODULE_SUFFIX}'
337+
)
338+
datasets[child.name] = (pkg_path, builder_module)
339+
return datasets
340+
341+
def has_dataset(self, name: str) -> bool:
342+
return name in self.dataset_packages
343+
344+
def get_builder_cls(self, name: str) -> Type[RegisteredDataset]:
345+
if name in self._registry:
346+
return self._registry[name]
347+
pkg_dir_path, builder_module = self.dataset_packages[name]
348+
cls = importlib.import_module(builder_module).Builder
349+
cls.pkg_dir_path = pkg_dir_path
350+
self._registry[name] = cls
351+
return cls
352+
353+
354+
_DATASET_PROVIDER_REGISTRY: list[DatasetBuilderProvider] = [
355+
SourceDirDatasetBuilderProvider(constants.DATASETS_TFDS_SRC_DIR)
356+
]
357+
358+
359+
def add_dataset_builder_provider(provider: DatasetBuilderProvider) -> None:
360+
_DATASET_PROVIDER_REGISTRY.append(provider)
361+
362+
292363
def _is_builder_available(builder_cls: Type[RegisteredDataset]) -> bool:
293364
"""Returns `True` is the builder is available."""
294365
return visibility.DatasetType.TFDS_PUBLIC.is_available()
@@ -346,14 +417,9 @@ def _get_existing_dataset_packages(
346417

347418
def imported_builder_cls(name: str) -> Type[RegisteredDataset]:
348419
"""Returns the Registered dataset class."""
349-
existing_ds_pkgs = _get_existing_dataset_packages(
350-
constants.DATASETS_TFDS_SRC_DIR
351-
)
352-
if name in existing_ds_pkgs:
353-
pkg_dir_path, builder_module = existing_ds_pkgs[name]
354-
cls = importlib.import_module(builder_module).Builder
355-
cls.pkg_dir_path = pkg_dir_path
356-
return cls
420+
for dataset_builder_provider in _DATASET_PROVIDER_REGISTRY:
421+
if dataset_builder_provider.has_dataset(name):
422+
return dataset_builder_provider.get_builder_cls(name)
357423

358424
if name in _ABSTRACT_DATASET_REGISTRY:
359425
# Will raise TypeError: Can't instantiate abstract class X with abstract

tensorflow_datasets/core/registered_test.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from unittest import mock
2121
import pytest
2222
from tensorflow_datasets import testing
23-
from tensorflow_datasets.core import constants
2423
from tensorflow_datasets.core import load
2524
from tensorflow_datasets.core import registered
2625
from tensorflow_datasets.core import splits
@@ -326,14 +325,16 @@ def test_name_inferred_from_pkg_level3():
326325
assert ds_builder.name == "dummy_ds_2"
327326

328327

329-
class ConfigBasedBuildersTest(testing.TestCase):
328+
class SourceDirDatasetBuilderProviderTest(testing.TestCase):
330329

331-
def test__get_existing_dataset_packages(self):
332-
ds_packages = registered._get_existing_dataset_packages(
330+
def test_provider(self):
331+
provider = registered.SourceDirDatasetBuilderProvider(
333332
"testing/dummy_config_based_datasets"
334333
)
335-
self.assertEqual(set(ds_packages.keys()), {"dummy_ds_1", "dummy_ds_2"})
336-
pkg_path, builder_module = ds_packages["dummy_ds_1"]
334+
self.assertEqual(
335+
set(provider.dataset_packages), {"dummy_ds_1", "dummy_ds_2"}
336+
)
337+
pkg_path, builder_module = provider.dataset_packages["dummy_ds_1"]
337338
self.assertEndsWith(
338339
str(pkg_path),
339340
"tensorflow_datasets/testing/dummy_config_based_datasets/dummy_ds_1",
@@ -343,10 +344,12 @@ def test__get_existing_dataset_packages(self):
343344
"tensorflow_datasets.testing.dummy_config_based_datasets.dummy_ds_1.dummy_ds_1_dataset_builder",
344345
)
345346

346-
@mock.patch.object(
347-
constants, "DATASETS_TFDS_SRC_DIR", "testing/dummy_config_based_datasets"
348-
)
349347
def test_imported_builder_cls(self):
348+
registered.add_dataset_builder_provider(
349+
registered.SourceDirDatasetBuilderProvider(
350+
"testing/dummy_config_based_datasets"
351+
)
352+
)
350353
builder = registered.imported_builder_cls("dummy_ds_1")
351354
self.assertEqual(builder.name, "dummy_ds_1")
352355

0 commit comments

Comments
 (0)