From 77ec74afd58391366de72109c14dc91b89c0eaba Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Thu, 1 Feb 2024 23:04:32 +0900 Subject: [PATCH 1/2] skip validation stage when val_split_mode is set to none Signed-off-by: Willy Fitra Hendria --- src/anomalib/utils/config.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/anomalib/utils/config.py b/src/anomalib/utils/config.py index 43b610ed72..d0024212f1 100644 --- a/src/anomalib/utils/config.py +++ b/src/anomalib/utils/config.py @@ -19,6 +19,8 @@ from jsonargparse import Path as JSONArgparsePath from omegaconf import DictConfig, ListConfig, OmegaConf +from anomalib.data.utils import ValSplitMode + logger = logging.getLogger(__name__) @@ -122,6 +124,7 @@ def update_config(config: DictConfig | ListConfig | Namespace) -> DictConfig | L config.results_dir.path = str(project_path) config = _update_nncf_config(config) + config = _update_val_config(config) # write the original config for eventual debug (modified config at the end of the function) (project_path / "config_original.yaml").write_text(to_yaml(config_original)) @@ -214,6 +217,21 @@ def _update_nncf_config(config: DictConfig | ListConfig) -> DictConfig | ListCon return config +def _update_val_config(config: DictConfig | ListConfig) -> DictConfig | ListConfig: + """Skip validation if `val_split_mode` is set to 'none'. + + Args: + config (DictConfig | ListConfig): Configurable parameters of the current run. + + Returns: + DictConfig | ListConfig: Updated configurable parameters in DictConfig object. + """ + if config.data.init_args.val_split_mode == ValSplitMode.NONE and config.trainer.limit_val_batches != 0.0: + logger.warning("Running without validation set. Setting trainer.limit_val_batches to 0.") + config.trainer.limit_val_batches = 0.0 + return config + + def _show_warnings(config: DictConfig | ListConfig | Namespace) -> None: """Show warnings if any based on the configuration settings. From a9de159f63c96e6b90800512a6e3b55cc7d9e246 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Thu, 1 Feb 2024 23:10:35 +0900 Subject: [PATCH 2/2] skip test stage when test_split_mode is set to none Signed-off-by: Willy Fitra Hendria --- src/anomalib/engine/engine.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index 79aba4a8c8..9a70bed174 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -23,6 +23,7 @@ from anomalib.callbacks.thresholding import _ThresholdCallback from anomalib.callbacks.visualizer import _VisualizationCallback from anomalib.data import AnomalibDataModule, AnomalibDataset, PredictDataset +from anomalib.data.utils import TestSplitMode from anomalib.deploy.export import ExportType, export_to_onnx, export_to_openvino, export_to_torch from anomalib.metrics.threshold import BaseThreshold from anomalib.models import AnomalyModule @@ -636,7 +637,7 @@ def train( test_dataloaders: EVAL_DATALOADERS | None = None, datamodule: AnomalibDataModule | None = None, ckpt_path: str | None = None, - ) -> _EVALUATE_OUTPUT: + ) -> _EVALUATE_OUTPUT | None: """Fits the model and then calls test on it. Args: @@ -653,6 +654,9 @@ def train( ckpt_path (str | None, optional): Checkpoint path. If provided, the model will be loaded from this path. Defaults to None. + Returns: + _EVALUATE_OUTPUT | None: A List of dictionaries containing the test results. 1 dict per dataloader. + CLI Usage: 1. you can pick a model, and you can run through the MVTec dataset. ```python @@ -674,7 +678,12 @@ def train( self.trainer.validate(model, val_dataloaders, None, verbose=False, datamodule=datamodule) else: self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) - self.trainer.test(model, test_dataloaders, ckpt_path=ckpt_path, datamodule=datamodule) + + if datamodule is not None and datamodule.test_split_mode == TestSplitMode.NONE: + logger.info(f"The test_split_mode is set to '{TestSplitMode.NONE}'. Skipping test stage.") + logger.warning(f"Found {len(datamodule.test_data)} images in the test set.") + return None + return self.trainer.test(model, test_dataloaders, ckpt_path=ckpt_path, datamodule=datamodule) def export( self,