|
15 | 15 |
|
16 | 16 | """Access registered datasets."""
|
17 | 17 |
|
| 18 | +from __future__ import annotations |
| 19 | + |
18 | 20 | import abc
|
19 | 21 | import collections
|
20 | 22 | from collections.abc import Iterator
|
|
24 | 26 | import inspect
|
25 | 27 | import os.path
|
26 | 28 | import time
|
27 |
| -from typing import ClassVar, Type |
| 29 | +from typing import ClassVar, Protocol, Type |
28 | 30 |
|
29 | 31 | from absl import logging
|
30 | 32 | from etils import epath
|
|
37 | 39 | from tensorflow_datasets.core.utils import resource_utils
|
38 | 40 |
|
39 | 41 | # Internal registry containing <str registered_name, DatasetBuilder subclass>
|
40 |
| -_DATASET_REGISTRY = {} |
| 42 | +_DATASET_REGISTRY: dict[str, Type[RegisteredDataset]] = {} |
41 | 43 |
|
42 | 44 | # Internal registry containing:
|
43 | 45 | # <str snake_cased_name, abstract DatasetBuilder subclass>
|
44 |
| -_ABSTRACT_DATASET_REGISTRY = {} |
| 46 | +_ABSTRACT_DATASET_REGISTRY: dict[str, Type[RegisteredDataset]] = {} |
45 | 47 |
|
46 | 48 | # Keep track of dict[str (module name), list[DatasetBuilder]]
|
47 | 49 | # This is directly accessed by `tfds.community.builder_cls_from_module` when
|
@@ -289,6 +291,75 @@ def __init_subclass__(cls, skip_registration=False, **kwargs): # pylint: disabl
|
289 | 291 | _DATASET_REGISTRY[cls.name] = cls
|
290 | 292 |
|
291 | 293 |
|
class DatasetBuilderProvider(Protocol):
  """Structural interface for objects that can look up dataset builder classes.

  Implementations are registered in `_DATASET_PROVIDER_REGISTRY` (via
  `add_dataset_builder_provider`) and are consulted in order to resolve a
  dataset name to its `RegisteredDataset` subclass.
  """

  def has_dataset(self, name: str) -> bool:
    """Returns whether this provider can supply a builder class for `name`."""
    ...

  def get_builder_cls(self, name: str) -> Type[RegisteredDataset]:
    """Returns the builder class for `name`; only valid if `has_dataset(name)`."""
    ...
| 302 | + |
class SourceDirDatasetBuilderProvider(DatasetBuilderProvider):
  """Provider of dataset builders that are defined in the given source code folder.

  Dataset packages under `datasets_dir` are discovered lazily (and only once,
  via `functools.cached_property`); imported builder classes are memoized in
  `self._registry` so each builder module is imported at most once.
  """

  def __init__(self, datasets_dir: str):
    # Path of the datasets folder, relative to the tensorflow_datasets package.
    self._datasets_dir = datasets_dir
    # Memoization cache: dataset name -> already-imported builder class.
    self._registry: dict[str, Type[RegisteredDataset]] = {}

  @functools.cached_property
  def dataset_packages(self) -> dict[str, tuple[epath.Path, str]]:
    """Returns existing datasets.

    Returns:
      {ds_name: (pkg_path, builder_module)}.
      For example: {'mnist': ('/lib/tensorflow_datasets/datasets/mnist',
      'tensorflow_datasets.datasets.mnist.builder')}
    """
    datasets: dict[str, tuple[epath.Path, str]] = {}
    datasets_dir_path = resource_utils.tfds_path(self._datasets_dir)
    if not datasets_dir_path.exists():
      return datasets
    ds_dir_pkg = '.'.join(
        ['tensorflow_datasets'] + self._datasets_dir.split(os.path.sep)
    )
    # Except for a few exceptions, all children of the datasets/ directory are
    # packages of datasets, so there is no need to check that a child is a
    # directory containing a `builder.py` module. The set is loop-invariant,
    # so it is built once instead of on every iteration.
    exceptions = frozenset(['__init__.py'])
    for child in datasets_dir_path.iterdir():
      if child.name in exceptions:
        continue
      pkg_path = epath.Path(datasets_dir_path) / child.name
      builder_module = (
          f'{ds_dir_pkg}.{child.name}.{child.name}{_BUILDER_MODULE_SUFFIX}'
      )
      datasets[child.name] = (pkg_path, builder_module)
    return datasets

  def has_dataset(self, name: str) -> bool:
    """Returns whether a dataset package named `name` exists in the folder."""
    return name in self.dataset_packages

  def get_builder_cls(self, name: str) -> Type[RegisteredDataset]:
    """Imports (once) and returns the builder class for `name`.

    Raises:
      KeyError: if `name` is not in `dataset_packages`.
    """
    if name in self._registry:
      return self._registry[name]
    pkg_dir_path, builder_module = self.dataset_packages[name]
    cls = importlib.import_module(builder_module).Builder
    # Record where the dataset package lives on the imported class so the
    # builder can later locate files shipped alongside its module.
    cls.pkg_dir_path = pkg_dir_path
    self._registry[name] = cls
    return cls
| 352 | + |
| 353 | + |
# Ordered list of providers consulted by `imported_builder_cls`; the first
# provider whose `has_dataset(name)` returns True wins.
_DATASET_PROVIDER_REGISTRY: list[DatasetBuilderProvider] = [
    # Default provider: dataset packages shipped inside the TFDS source tree.
    SourceDirDatasetBuilderProvider(constants.DATASETS_TFDS_SRC_DIR)
]


def add_dataset_builder_provider(provider: DatasetBuilderProvider) -> None:
  """Registers `provider` as an additional source of dataset builder classes.

  The provider is appended after the existing ones, so it is only consulted
  for names the earlier providers do not have.

  Args:
    provider: object implementing the `DatasetBuilderProvider` protocol.
  """
  _DATASET_PROVIDER_REGISTRY.append(provider)
| 361 | + |
| 362 | + |
def _is_builder_available(builder_cls: Type[RegisteredDataset]) -> bool:
  """Returns `True` if the builder is available."""
  # NOTE(review): `builder_cls` is currently unused — availability depends only
  # on whether TFDS public datasets are visible in the current visibility
  # scope. Confirm whether per-builder checks were intended here.
  return visibility.DatasetType.TFDS_PUBLIC.is_available()
|
@@ -346,14 +417,9 @@ def _get_existing_dataset_packages(
|
346 | 417 |
|
347 | 418 | def imported_builder_cls(name: str) -> Type[RegisteredDataset]:
|
348 | 419 | """Returns the Registered dataset class."""
|
349 |
| - existing_ds_pkgs = _get_existing_dataset_packages( |
350 |
| - constants.DATASETS_TFDS_SRC_DIR |
351 |
| - ) |
352 |
| - if name in existing_ds_pkgs: |
353 |
| - pkg_dir_path, builder_module = existing_ds_pkgs[name] |
354 |
| - cls = importlib.import_module(builder_module).Builder |
355 |
| - cls.pkg_dir_path = pkg_dir_path |
356 |
| - return cls |
| 420 | + for dataset_builder_provider in _DATASET_PROVIDER_REGISTRY: |
| 421 | + if dataset_builder_provider.has_dataset(name): |
| 422 | + return dataset_builder_provider.get_builder_cls(name) |
357 | 423 |
|
358 | 424 | if name in _ABSTRACT_DATASET_REGISTRY:
|
359 | 425 | # Will raise TypeError: Can't instantiate abstract class X with abstract
|
|
0 commit comments