|
2 | 2 |
|
3 | 3 | from .base_client import BaseClient |
4 | 4 | from .. import repositories, models, constants, utils |
5 | | -from ..sdk_exceptions import InvalidParametersError |
| 5 | +from ..sdk_exceptions import ResourceCreatingDataError, InvalidParametersError |
6 | 6 | from ..validation_messages import EXPERIMENT_MODEL_PATH_VALIDATION_ERROR |
7 | 7 |
|
8 | 8 |
|
@@ -72,6 +72,8 @@ def create_single_node( |
72 | 72 | if not is_preemptible: |
73 | 73 | is_preemptible = None |
74 | 74 |
|
| 75 | + datasets = self._dataset_dicts_to_instances(datasets) |
| 76 | + |
75 | 77 | experiment = models.SingleNodeExperiment( |
76 | 78 | experiment_type_id=constants.ExperimentType.SINGLE_NODE, |
77 | 79 | name=name, |
@@ -191,6 +193,8 @@ def create_multi_node( |
191 | 193 | if not is_preemptible: |
192 | 194 | is_preemptible = None |
193 | 195 |
|
| 196 | + datasets = self._dataset_dicts_to_instances(datasets) |
| 197 | + |
194 | 198 | experiment = models.MultiNodeExperiment( |
195 | 199 | name=name, |
196 | 200 | project_id=project_id, |
@@ -314,6 +318,7 @@ def create_mpi_multi_node( |
314 | 318 | if not is_preemptible: |
315 | 319 | is_preemptible = None |
316 | 320 |
|
| 321 | + datasets = self._dataset_dicts_to_instances(datasets) |
317 | 322 | experiment_type_id = constants.ExperimentType.MPI_MULTI_NODE |
318 | 323 |
|
319 | 324 | experiment = models.MpiMultiNodeExperiment( |
@@ -421,6 +426,8 @@ def run_single_node( |
421 | 426 | if not is_preemptible: |
422 | 427 | is_preemptible = None |
423 | 428 |
|
| 429 | + datasets = self._dataset_dicts_to_instances(datasets) |
| 430 | + |
424 | 431 | experiment = models.SingleNodeExperiment( |
425 | 432 | experiment_type_id=constants.ExperimentType.SINGLE_NODE, |
426 | 433 | name=name, |
@@ -538,6 +545,8 @@ def run_multi_node( |
538 | 545 | if not is_preemptible: |
539 | 546 | is_preemptible = None |
540 | 547 |
|
| 548 | + datasets = self._dataset_dicts_to_instances(datasets) |
| 549 | + |
541 | 550 | experiment = models.MultiNodeExperiment( |
542 | 551 | name=name, |
543 | 552 | project_id=project_id, |
@@ -661,6 +670,8 @@ def run_mpi_multi_node( |
661 | 670 | if not is_preemptible: |
662 | 671 | is_preemptible = None |
663 | 672 |
|
| 673 | + datasets = self._dataset_dicts_to_instances(datasets) |
| 674 | + |
664 | 675 | experiment_type_id = constants.ExperimentType.MPI_MULTI_NODE |
665 | 676 |
|
666 | 677 | experiment = models.MpiMultiNodeExperiment( |
@@ -840,3 +851,23 @@ def stream_metrics(self, experiment_id, interval="30s", built_in_metrics=None): |
840 | 851 | built_in_metrics=built_in_metrics, |
841 | 852 | ) |
842 | 853 | return metrics |
| 854 | + |
| 855 | + def _dataset_dicts_to_instances(self, datasets): |
| 856 | + if not datasets: |
| 857 | + return None |
| 858 | + |
| 859 | + if isinstance(datasets, dict): |
| 860 | + datasets = [datasets] |
| 861 | + |
| 862 | + for ds in datasets: |
| 863 | + if not ds.get("uri"): |
| 864 | + raise ResourceCreatingDataError("Error while creating experiment with dataset: " |
| 865 | + "\"uri\" key is required and it's value must be a valid S3 URI") |
| 866 | + |
| 867 | + for ds in datasets: |
| 868 | + volume_options = ds.setdefault("volume_options", {}) |
| 869 | + volume_options.setdefault("kind", ds.pop("volume_kind", None)) |
| 870 | + volume_options.setdefault("size", ds.pop("volume_size", None)) |
| 871 | + |
| 872 | + datasets = [models.Dataset(**ds) for ds in datasets] |
| 873 | + return datasets |
0 commit comments