Skip to content

Commit 06e4f1c

Browse files
Implementing str method
Switched name from dataset to dataloader Switched name Prediction to Predict removed available keyword and instead write None if not available Switched from unknown to NA
1 parent 5bf252e commit 06e4f1c

File tree

2 files changed

+38
-38
lines changed

2 files changed

+38
-38
lines changed

src/lightning/pytorch/core/datamodule.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -255,16 +255,16 @@ def __str__(self) -> str:
255255
"""
256256

257257
class dataset_info:
258-
def __init__(self, available: str, length: str) -> None:
258+
def __init__(self, available: bool, length: str) -> None:
259259
self.available = available
260260
self.length = length
261261

262262
def retrieve_dataset_info(loader: DataLoader) -> dataset_info:
263263
"""Helper function to compute dataset information."""
264264
dataset = loader.dataset
265-
size: str = str(len(dataset)) if isinstance(dataset, Sized) else "unknown"
265+
size: str = str(len(dataset)) if isinstance(dataset, Sized) else "NA"
266266

267-
return dataset_info("yes", size)
267+
return dataset_info(True, size)
268268

269269
def loader_info(
270270
loader: Union[DataLoader, Iterable[DataLoader]],
@@ -282,7 +282,7 @@ def extract_loader_info(methods: list[tuple[str, str]]) -> dict:
282282
loader = loader_callback()
283283
info[loader_name] = loader_info(loader)
284284
except Exception:
285-
info[loader_name] = dataset_info("no", "unknown")
285+
info[loader_name] = dataset_info(False, "")
286286

287287
return info
288288

@@ -292,11 +292,11 @@ def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info
292292
for loader_name, loader_info in info.items():
293293
# Single dataset
294294
if isinstance(loader_info, dataset_info):
295-
loader_info_formatted = f"available={loader_info.available}, size={loader_info.length}"
295+
loader_info_formatted = "None" if not loader_info.available else f"size={loader_info.length}"
296296
# Iterable of datasets
297297
else:
298298
loader_info_formatted = " ; ".join(
299-
f"{i}. available={loader_info_i.available}, size={loader_info_i.length}"
299+
"None" if not loader_info_i.available else f"{i}. size={loader_info_i.length}"
300300
for i, loader_info_i in enumerate(loader_info, start=1)
301301
)
302302

@@ -306,10 +306,10 @@ def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info
306306

307307
# Available dataloader methods
308308
datamodule_loader_methods: list[tuple[str, str]] = [
309-
("Train dataset", "train_dataloader"),
310-
("Validation dataset", "val_dataloader"),
311-
("Test dataset", "test_dataloader"),
312-
("Prediction dataset", "predict_dataloader"),
309+
("Train dataloader", "train_dataloader"),
310+
("Validation dataloader", "val_dataloader"),
311+
("Test dataloader", "test_dataloader"),
312+
("Predict dataloader", "predict_dataloader"),
313313
]
314314

315315
# Retrieve information for each dataloader method

tests/tests_pytorch/core/test_datamodules.py

+28-28
Original file line numberDiff line numberDiff line change
@@ -522,10 +522,10 @@ def test_datamodule_string_not_available():
522522
dm = BoringDataModule()
523523

524524
expected_output = (
525-
f"{{Train dataset: available=no, size=unknown}}{os.linesep}"
526-
f"{{Validation dataset: available=no, size=unknown}}{os.linesep}"
527-
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
528-
f"{{Prediction dataset: available=no, size=unknown}}"
525+
f"{{Train dataloader: None}}{os.linesep}"
526+
f"{{Validation dataloader: None}}{os.linesep}"
527+
f"{{Test dataloader: None}}{os.linesep}"
528+
f"{{Predict dataloader: None}}"
529529
)
530530
out = str(dm)
531531

@@ -537,10 +537,10 @@ def test_datamodule_string_fit_setup():
537537
dm.setup(stage="fit")
538538

539539
expected_output = (
540-
f"{{Train dataset: available=yes, size=64}}{os.linesep}"
541-
f"{{Validation dataset: available=yes, size=64}}{os.linesep}"
542-
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
543-
f"{{Prediction dataset: available=no, size=unknown}}"
540+
f"{{Train dataloader: size=64}}{os.linesep}"
541+
f"{{Validation dataloader: size=64}}{os.linesep}"
542+
f"{{Test dataloader: None}}{os.linesep}"
543+
f"{{Predict dataloader: None}}"
544544
)
545545
output = str(dm)
546546

@@ -552,10 +552,10 @@ def test_datamodule_string_validation_setup():
552552
dm.setup(stage="validate")
553553

554554
expected_output = (
555-
f"{{Train dataset: available=no, size=unknown}}{os.linesep}"
556-
f"{{Validation dataset: available=yes, size=64}}{os.linesep}"
557-
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
558-
f"{{Prediction dataset: available=no, size=unknown}}"
555+
f"{{Train dataloader: None}}{os.linesep}"
556+
f"{{Validation dataloader: size=64}}{os.linesep}"
557+
f"{{Test dataloader: None}}{os.linesep}"
558+
f"{{Predict dataloader: None}}"
559559
)
560560
output = str(dm)
561561

@@ -567,10 +567,10 @@ def test_datamodule_string_test_setup():
567567
dm.setup(stage="test")
568568

569569
expected_output = (
570-
f"{{Train dataset: available=no, size=unknown}}{os.linesep}"
571-
f"{{Validation dataset: available=no, size=unknown}}{os.linesep}"
572-
f"{{Test dataset: available=yes, size=64}}{os.linesep}"
573-
f"{{Prediction dataset: available=no, size=unknown}}"
570+
f"{{Train dataloader: None}}{os.linesep}"
571+
f"{{Validation dataloader: None}}{os.linesep}"
572+
f"{{Test dataloader: size=64}}{os.linesep}"
573+
f"{{Predict dataloader: None}}"
574574
)
575575
output = str(dm)
576576

@@ -582,10 +582,10 @@ def test_datamodule_string_predict_setup():
582582
dm.setup(stage="predict")
583583

584584
expected_output = (
585-
f"{{Train dataset: available=no, size=unknown}}{os.linesep}"
586-
f"{{Validation dataset: available=no, size=unknown}}{os.linesep}"
587-
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
588-
f"{{Prediction dataset: available=yes, size=64}}"
585+
f"{{Train dataloader: None}}{os.linesep}"
586+
f"{{Validation dataloader: None}}{os.linesep}"
587+
f"{{Test dataloader: None}}{os.linesep}"
588+
f"{{Predict dataloader: size=64}}"
589589
)
590590
output = str(dm)
591591

@@ -597,10 +597,10 @@ def test_datamodule_string_no_len():
597597
dm.setup("fit")
598598

599599
expected_output = (
600-
f"{{Train dataset: available=yes, size=unknown}}{os.linesep}"
601-
f"{{Validation dataset: available=yes, size=unknown}}{os.linesep}"
602-
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
603-
f"{{Prediction dataset: available=no, size=unknown}}"
600+
f"{{Train dataloader: size=NA}}{os.linesep}"
601+
f"{{Validation dataloader: size=NA}}{os.linesep}"
602+
f"{{Test dataloader: None}}{os.linesep}"
603+
f"{{Predict dataloader: None}}"
604604
)
605605
output = str(dm)
606606

@@ -612,10 +612,10 @@ def test_datamodule_string_iterable():
612612
dm.setup("fit")
613613

614614
expected_output = (
615-
f"{{Train dataset: 1. available=yes, size=16 ; 2. available=yes, size=unknown}}{os.linesep}"
616-
f"{{Validation dataset: 1. available=yes, size=32 ; 2. available=yes, size=unknown}}{os.linesep}"
617-
f"{{Test dataset: available=no, size=unknown}}{os.linesep}"
618-
f"{{Prediction dataset: available=no, size=unknown}}"
615+
f"{{Train dataloader: 1. size=16 ; 2. size=NA}}{os.linesep}"
616+
f"{{Validation dataloader: 1. size=32 ; 2. size=NA}}{os.linesep}"
617+
f"{{Test dataloader: None}}{os.linesep}"
618+
f"{{Predict dataloader: None}}"
619619
)
620620
output = str(dm)
621621

0 commit comments

Comments
 (0)