diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py
index 9de4a1d9c1..6638305194 100644
--- a/src/anomalib/cli/cli.py
+++ b/src/anomalib/cli/cli.py
@@ -16,6 +16,7 @@
 from rich import traceback
 
 from anomalib import TaskType, __version__
+from anomalib.cli.pipelines import PIPELINE_REGISTRY, pipeline_subcommands, run_pipeline
 from anomalib.cli.utils.help_formatter import CustomHelpFormatter, get_short_docstring
 from anomalib.cli.utils.openvino import add_openvino_export_arguments
 from anomalib.loggers import configure_logger
@@ -132,6 +133,13 @@ def add_subcommands(self, **kwargs) -> None:
             # add arguments to subcommand
             getattr(self, f"add_{subcommand}_arguments")(sub_parser)
 
+        # Add pipeline subcommands
+        if PIPELINE_REGISTRY is not None:
+            for subcommand, value in pipeline_subcommands().items():
+                sub_parser = PIPELINE_REGISTRY[subcommand].get_parser()
+                self.subcommand_parsers[subcommand] = sub_parser
+                parser_subcommands.add_subcommand(subcommand, sub_parser, help=value["description"])
+
     def add_arguments_to_parser(self, parser: ArgumentParser) -> None:
         """Extend trainer's arguments to add engine arguments.
 
@@ -353,6 +361,8 @@ def _run_subcommand(self) -> None:
             fn = getattr(self.engine, self.subcommand)
             fn_kwargs = self._prepare_subcommand_kwargs(self.subcommand)
             fn(**fn_kwargs)
+        elif PIPELINE_REGISTRY is not None and self.subcommand in pipeline_subcommands():
+            run_pipeline(self.config)
         else:
             self.config_init = self.parser.instantiate_classes(self.config)
             getattr(self, f"{self.subcommand}")()
diff --git a/src/anomalib/cli/pipelines.py b/src/anomalib/cli/pipelines.py
new file mode 100644
index 0000000000..4ec13cfc72
--- /dev/null
+++ b/src/anomalib/cli/pipelines.py
@@ -0,0 +1,41 @@
+"""Subcommand for pipelines."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+
+from jsonargparse import Namespace
+
+from anomalib.cli.utils.help_formatter import get_short_docstring
+from anomalib.utils.exceptions import try_import
+
+logger = logging.getLogger(__name__)
+
+if try_import("anomalib.pipelines"):
+    from anomalib.pipelines import Benchmark
+    from anomalib.pipelines.components.base import Pipeline
+
+    PIPELINE_REGISTRY: dict[str, type[Pipeline]] | None = {"benchmark": Benchmark}
+else:
+    PIPELINE_REGISTRY = None
+
+
+def pipeline_subcommands() -> dict[str, dict[str, str]]:
+    """Return subcommands for pipelines."""
+    if PIPELINE_REGISTRY is not None:
+        return {name: {"description": get_short_docstring(pipeline)} for name, pipeline in PIPELINE_REGISTRY.items()}
+    return {}
+
+
+def run_pipeline(args: Namespace) -> None:
+    """Run pipeline."""
+    logger.warning("This feature is experimental. It may change or be removed in the future.")
+    if PIPELINE_REGISTRY is not None:
+        subcommand = args.subcommand
+        config = args[subcommand]
+        PIPELINE_REGISTRY[subcommand]().run(config)
+    else:
+        msg = "Pipeline is not available"
+        raise ValueError(msg)
diff --git a/src/anomalib/cli/utils/help_formatter.py b/src/anomalib/cli/utils/help_formatter.py
index ea4ef825b6..db5b1a1bf6 100644
--- a/src/anomalib/cli/utils/help_formatter.py
+++ b/src/anomalib/cli/utils/help_formatter.py
@@ -6,7 +6,6 @@
 import argparse
 import re
 import sys
-from typing import TypeVar
 
 import docstring_parser
 from jsonargparse import DefaultHelpFormatter
@@ -36,11 +35,11 @@
     print("To use other subcommand using `anomalib install`")
 
 
-def get_short_docstring(component: TypeVar) -> str:
+def get_short_docstring(component: type) -> str:
     """Get the short description from the docstring.
 
     Args:
-        component (TypeVar): The component to get the docstring from
+        component (type): The component to get the docstring from
 
     Returns:
         str: The short description
diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py
index 85a4fd1589..1f937ede52 100644
--- a/src/anomalib/data/__init__.py
+++ b/src/anomalib/data/__init__.py
@@ -3,7 +3,6 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-
 import importlib
 import logging
 from enum import Enum
@@ -29,20 +28,35 @@
 )
 
 
-def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
+class UnknownDatamoduleError(ModuleNotFoundError):
+    """Raised when the requested datamodule cannot be found or imported."""
+
+
+def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule:
     """Get Anomaly Datamodule.
 
     Args:
-        config (DictConfig | ListConfig): Configuration of the anomaly model.
+        config (DictConfig | ListConfig | dict): Configuration of the anomaly model.
 
     Returns:
         PyTorch Lightning DataModule
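+
+    Example:
+        A minimal sketch; class paths without dots are resolved from ``anomalib.data``:
+
+        >>> config = {"class_path": "MVTec", "init_args": {"category": "bottle"}}
+        >>> datamodule = get_datamodule(config)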
     """
     logger.info("Loading the datamodule")
 
-    module = importlib.import_module(".".join(config.data.class_path.split(".")[:-1]))
-    dataclass = getattr(module, config.data.class_path.split(".")[-1])
-    init_args = {**config.data.get("init_args", {})}  # get dict
+    if isinstance(config, dict):
+        config = DictConfig(config)
+
+    try:
+        _config = config.data if "data" in config else config
+        if len(_config.class_path.split(".")) > 1:
+            module = importlib.import_module(".".join(_config.class_path.split(".")[:-1]))
+        else:
+            module = importlib.import_module("anomalib.data")
+    except ModuleNotFoundError as exception:
+        logger.exception(f"ModuleNotFoundError: {_config.class_path}")
+        raise UnknownDatamoduleError from exception
+    dataclass = getattr(module, _config.class_path.split(".")[-1])
+    init_args = {**_config.get("init_args", {})}  # get dict
     if "image_size" in init_args:
         init_args["image_size"] = to_tuple(init_args["image_size"])
 
diff --git a/src/anomalib/metrics/f1_score.py b/src/anomalib/metrics/f1_score.py
index 0477e8306d..f666542d32 100644
--- a/src/anomalib/metrics/f1_score.py
+++ b/src/anomalib/metrics/f1_score.py
@@ -6,7 +6,6 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-
 import logging
 from typing import Any, Literal
 
diff --git a/src/anomalib/pipelines/__init__.py b/src/anomalib/pipelines/__init__.py
new file mode 100644
index 0000000000..0ca537d4de
--- /dev/null
+++ b/src/anomalib/pipelines/__init__.py
@@ -0,0 +1,8 @@
+"""Pipelines for end-to-end usecases."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .benchmark import Benchmark
+
+__all__ = ["Benchmark"]
diff --git a/src/anomalib/pipelines/benchmark/__init__.py b/src/anomalib/pipelines/benchmark/__init__.py
new file mode 100644
index 0000000000..bfb34aded2
--- /dev/null
+++ b/src/anomalib/pipelines/benchmark/__init__.py
@@ -0,0 +1,8 @@
+"""Benchmarking."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .pipeline import Benchmark
+
+__all__ = ["Benchmark"]
diff --git a/src/anomalib/pipelines/benchmark/generator.py b/src/anomalib/pipelines/benchmark/generator.py
new file mode 100644
index 0000000000..40bd446a89
--- /dev/null
+++ b/src/anomalib/pipelines/benchmark/generator.py
@@ -0,0 +1,41 @@
+"""Benchmark job generator."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import Generator
+
+from anomalib.data import get_datamodule
+from anomalib.models import get_model
+from anomalib.pipelines.components import JobGenerator
+from anomalib.pipelines.components.utils import get_iterator_from_grid_dict
+from anomalib.utils.logging import hide_output
+
+from .job import BenchmarkJob
+
+
+class BenchmarkJobGenerator(JobGenerator):
+    """Generate BenchmarkJob.
+
+    Args:
+        accelerator (str): The accelerator to use.
+    """
+
+    def __init__(self, accelerator: str) -> None:
+        self.accelerator = accelerator
+
+    @property
+    def job_class(self) -> type:
+        """Return the job class."""
+        return BenchmarkJob
+
+    @hide_output
+    def generate_jobs(self, args: dict) -> Generator[BenchmarkJob, None, None]:
+        """Return iterator based on the arguments."""
+        for _container in get_iterator_from_grid_dict(args):
+            yield BenchmarkJob(
+                accelerator=self.accelerator,
+                seed=_container["seed"],
+                model=get_model(_container["model"]),
+                datamodule=get_datamodule(_container["data"]),
+            )
diff --git a/src/anomalib/pipelines/benchmark/job.py b/src/anomalib/pipelines/benchmark/job.py
new file mode 100644
index 0000000000..56b2e69d1a
--- /dev/null
+++ b/src/anomalib/pipelines/benchmark/job.py
@@ -0,0 +1,108 @@
+"""Benchmarking job."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from datetime import datetime
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Any
+
+import pandas as pd
+from lightning import seed_everything
+from rich.console import Console
+from rich.table import Table
+
+from anomalib.data import AnomalibDataModule
+from anomalib.engine import Engine
+from anomalib.models import AnomalyModule
+from anomalib.pipelines.components import Job
+from anomalib.utils.logging import hide_output
+
+logger = logging.getLogger(__name__)
+
+
+class BenchmarkJob(Job):
+    """Benchmarking job.
+
+    Args:
+        accelerator (str): The accelerator to use.
+        model (AnomalyModule): The model to use.
+        datamodule (AnomalibDataModule): The data module to use.
+        seed (int): The seed to use.
+    """
+
+    name = "benchmark"
+
+    def __init__(self, accelerator: str, model: AnomalyModule, datamodule: AnomalibDataModule, seed: int) -> None:
+        super().__init__()
+        self.accelerator = accelerator
+        self.model = model
+        self.datamodule = datamodule
+        self.seed = seed
+
+    @hide_output
+    def run(
+        self,
+        task_id: int | None = None,
+    ) -> dict[str, Any]:
+        """Run the benchmark."""
+        devices: str | list[int] = "auto"
+        if task_id is not None:
+            devices = [task_id]
+            logger.info(f"Running job {self.model.__class__.__name__} with device {task_id}")
+        with TemporaryDirectory() as temp_dir:
+            seed_everything(self.seed)
+            engine = Engine(
+                accelerator=self.accelerator,
+                devices=devices,
+                default_root_dir=temp_dir,
+            )
+            engine.fit(self.model, self.datamodule)
+            test_results = engine.test(self.model, self.datamodule)
+        # TODO(ashwinvaidya17): Restore throughput
+        # https://github.com/openvinotoolkit/anomalib/issues/2054
+        output = {
+            "seed": self.seed,
+            "accelerator": self.accelerator,
+            "model": self.model.__class__.__name__,
+            "data": self.datamodule.__class__.__name__,
+            "category": self.datamodule.category,
+            **test_results[0],
+        }
+        logger.info(f"Completed with result {output}")
+        return output
+
+    @staticmethod
+    def collect(results: list[dict[str, Any]]) -> pd.DataFrame:
+        """Gather the results returned from run."""
+        output: dict[str, Any] = {}
+        for key in results[0]:
+            output[key] = []
+        for result in results:
+            for key, value in result.items():
+                output[key].append(value)
+        return pd.DataFrame(output)
+
+    @staticmethod
+    def save(result: pd.DataFrame) -> None:
+        """Save the result to a csv file."""
+        BenchmarkJob._print_tabular_results(result)
+        file_path = Path("runs") / BenchmarkJob.name / datetime.now().strftime("%Y-%m-%d-%H:%M:%S") / "results.csv"
+        file_path.parent.mkdir(parents=True, exist_ok=True)
+        result.to_csv(file_path, index=False)
+        logger.info(f"Saved results to {file_path}")
+
+    @staticmethod
+    def _print_tabular_results(gathered_result: pd.DataFrame) -> None:
+        """Print the tabular results."""
+        if gathered_result is not None:
+            console = Console()
+            table = Table(title=f"{BenchmarkJob.name} Results", show_header=True, header_style="bold magenta")
+            _results = gathered_result.to_dict("list")
+            for column in _results:
+                table.add_column(column)
+            for row in zip(*_results.values(), strict=False):
+                table.add_row(*[str(value) for value in row])
+            console.print(table)
diff --git a/src/anomalib/pipelines/benchmark/pipeline.py b/src/anomalib/pipelines/benchmark/pipeline.py
new file mode 100644
index 0000000000..b4410f8094
--- /dev/null
+++ b/src/anomalib/pipelines/benchmark/pipeline.py
@@ -0,0 +1,29 @@
+"""Benchmarking."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from anomalib.pipelines.components.base import Pipeline, Runner
+from anomalib.pipelines.components.runners import ParallelRunner, SerialRunner
+
+from .generator import BenchmarkJobGenerator
+
+
+class Benchmark(Pipeline):
+    """Benchmarking pipeline."""
+
+    def _setup_runners(self, args: dict) -> list[Runner]:
+        """Setup the runners for the pipeline."""
+        accelerators = args["accelerator"] if isinstance(args["accelerator"], list) else [args["accelerator"]]
+        runners: list[Runner] = []
+        for accelerator in accelerators:
+            if accelerator == "cpu":
+                runners.append(SerialRunner(BenchmarkJobGenerator("cpu")))
+            elif accelerator == "cuda":
+                runners.append(ParallelRunner(BenchmarkJobGenerator("cuda"), n_jobs=torch.cuda.device_count()))
+            else:
+                msg = f"Unsupported accelerator: {accelerator}"
+                raise ValueError(msg)
+        return runners
diff --git a/src/anomalib/pipelines/components/__init__.py b/src/anomalib/pipelines/components/__init__.py
new file mode 100644
index 0000000000..1350937639
--- /dev/null
+++ b/src/anomalib/pipelines/components/__init__.py
@@ -0,0 +1,13 @@
+"""Utilities for the pipeline modules."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .base import Job, JobGenerator, Pipeline, Runner
+
+__all__ = [
+    "Job",
+    "JobGenerator",
+    "Pipeline",
+    "Runner",
+]
diff --git a/src/anomalib/pipelines/components/base/__init__.py b/src/anomalib/pipelines/components/base/__init__.py
new file mode 100644
index 0000000000..90682e9cd0
--- /dev/null
+++ b/src/anomalib/pipelines/components/base/__init__.py
@@ -0,0 +1,10 @@
+"""Base classes for pipelines."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .job import Job, JobGenerator
+from .pipeline import Pipeline
+from .runner import Runner
+
+__all__ = ["Job", "JobGenerator", "Runner", "Pipeline"]
diff --git a/src/anomalib/pipelines/components/base/job.py b/src/anomalib/pipelines/components/base/job.py
new file mode 100644
index 0000000000..422a83efa3
--- /dev/null
+++ b/src/anomalib/pipelines/components/base/job.py
@@ -0,0 +1,71 @@
+"""Job from which all the jobs inherit from."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from collections.abc import Generator
+
+from anomalib.pipelines.types import GATHERED_RESULTS, RUN_RESULTS
+
+
+class Job(ABC):
+    """A job is an atomic unit of work that can be run in parallel with other jobs."""
+
+    name: str
+
+    @abstractmethod
+    def run(self, task_id: int | None = None) -> RUN_RESULTS:
+        """A job is a single unit of work that can be run in parallel with other jobs.
+
+        ``task_id`` is optional and is only passed when the job is run in parallel.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def collect(results: list[RUN_RESULTS]) -> GATHERED_RESULTS:
+        """Gather the results returned from run.
+
+        This can be used to combine the results from multiple runs or to save/process individual job results.
+
+        Args:
+            results (list): List of results returned from run.
+
+        Returns:
+            (GATHERED_RESULTS): Collated results.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def save(results: GATHERED_RESULTS) -> None:
+        """Save the gathered results.
+
+        This can be used to save the results in a file or a database.
+
+        Args:
+            results: The gathered result returned from gather_results.
+        """
+
+
+class JobGenerator(ABC):
+    """Generate Job.
+
+    The runners consume a job generator. The task of this class is to parse the config and return an iterator of
+    concrete jobs.
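+
+    Example:
+        A minimal sketch of the call pattern used by the runners, shown here with the benchmark generator
+        (``args`` comes from the pipeline config):
+
+        >>> generator = BenchmarkJobGenerator("cpu")
+        >>> for job in generator(args):
+        ...     result = job.run()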
+    """
+
+    def __call__(self, args: dict | None = None) -> Generator[Job, None, None]:
+        """Calls the ``generate_jobs`` method."""
+        return self.generate_jobs(args)
+
+    @abstractmethod
+    def generate_jobs(self, args: dict | None = None) -> Generator[Job, None, None]:
+        """Return an iterator based on the arguments.
+
+        This can be used to generate the configurations that will be passed to run.
+        """
+
+    @property
+    @abstractmethod
+    def job_class(self) -> type[Job]:
+        """Return the job class that will be generated."""
diff --git a/src/anomalib/pipelines/components/base/pipeline.py b/src/anomalib/pipelines/components/base/pipeline.py
new file mode 100644
index 0000000000..8b4a6a4742
--- /dev/null
+++ b/src/anomalib/pipelines/components/base/pipeline.py
@@ -0,0 +1,77 @@
+"""Base class for pipeline."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+import yaml
+from jsonargparse import ArgumentParser, Namespace
+from rich import print, traceback
+
+from anomalib.utils.logging import redirect_logs
+
+from .runner import Runner
+
+traceback.install()
+
+log_file = "runs/pipeline.log"
+logger = logging.getLogger(__name__)
+
+
+class Pipeline(ABC):
+    """Base class for pipeline."""
+
+    def _get_args(self, args: Namespace) -> dict:
+        """Get pipeline arguments by parsing the config file.
+
+        Args:
+            args (Namespace): Arguments to run the pipeline. These are the args returned by ArgumentParser.
+
+        Returns:
+            dict: Pipeline arguments.
+        """
+        if args is None:
+            logger.warning("No arguments provided, parsing arguments from command line.")
+            parser = self.get_parser()
+            args = parser.parse_args()
+
+        with Path(args.config).open() as file:
+            return yaml.safe_load(file)
+
+    @abstractmethod
+    def _setup_runners(self, args: dict) -> list[Runner]:
+        """Setup the runners for the pipeline."""
+
+    def run(self, args: Namespace | None = None) -> None:
+        """Run the pipeline.
+
+        Args:
+            args (Namespace): Arguments to run the pipeline. These are the args returned by ArgumentParser.
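+
+        Example:
+            A minimal sketch using the benchmark pipeline; the config path is an assumption:
+
+            >>> from anomalib.pipelines import Benchmark
+            >>> pipeline = Benchmark()
+            >>> args = pipeline.get_parser().parse_args(["--config", "path/to/config.yaml"])
+            >>> pipeline.run(args)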
+        """
+        args = self._get_args(args)
+        runners = self._setup_runners(args)
+        redirect_logs(log_file)
+
+        for runner in runners:
+            try:
+                _args = args.get(runner.generator.job_class.name, None)
+                runner.run(_args)
+            except Exception:  # noqa: PERF203 catch all exception and allow try-catch in loop
+                logger.exception("An error occurred when running the runner.")
+                print(
+                    f"There were some errors when running [red]{runner.generator.job_class.name}[/red] with"
+                    f" [green]{runner.__class__.__name__}[/green]."
+                    f" Please check [magenta]{log_file}[/magenta] for more details.",
+                )
+
+    @staticmethod
+    def get_parser(parser: ArgumentParser | None = None) -> ArgumentParser:
+        """Create a new parser if none is provided."""
+        if parser is None:
+            parser = ArgumentParser()
+            parser.add_argument("--config", type=str | Path, help="Configuration file path.", required=True)
+
+        return parser
diff --git a/src/anomalib/pipelines/components/base/runner.py b/src/anomalib/pipelines/components/base/runner.py
new file mode 100644
index 0000000000..f3577d2801
--- /dev/null
+++ b/src/anomalib/pipelines/components/base/runner.py
@@ -0,0 +1,19 @@
+"""Base runner."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+
+from .job import JobGenerator
+
+
+class Runner(ABC):
+    """Base runner."""
+
+    def __init__(self, generator: JobGenerator) -> None:
+        self.generator = generator
+
+    @abstractmethod
+    def run(self, args: dict) -> None:
+        """Run the pipeline."""
diff --git a/src/anomalib/pipelines/components/runners/__init__.py b/src/anomalib/pipelines/components/runners/__init__.py
new file mode 100644
index 0000000000..27ef21046f
--- /dev/null
+++ b/src/anomalib/pipelines/components/runners/__init__.py
@@ -0,0 +1,9 @@
+"""Executor for running a single job."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .parallel import ParallelRunner
+from .serial import SerialRunner
+
+__all__ = ["SerialRunner", "ParallelRunner"]
diff --git a/src/anomalib/pipelines/components/runners/parallel.py b/src/anomalib/pipelines/components/runners/parallel.py
new file mode 100644
index 0000000000..c0da242af8
--- /dev/null
+++ b/src/anomalib/pipelines/components/runners/parallel.py
@@ -0,0 +1,98 @@
+"""Process pool executor."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import multiprocessing
+from concurrent.futures import ProcessPoolExecutor
+from typing import TYPE_CHECKING
+
+from rich import print
+from rich.progress import Progress, TaskID
+
+from anomalib.pipelines.components.base import JobGenerator, Runner
+
+if TYPE_CHECKING:
+    from concurrent.futures import Future
+
+logger = logging.getLogger(__name__)
+
+
+class ParallelExecutionError(Exception):
+    """Pool execution error should be raised when one or more jobs fail in the pool."""
+
+
+class ParallelRunner(Runner):
+    """Run the job in parallel using a process pool.
+
+    It creates a pool of processes and submits the jobs to the pool.
+    This is useful when you have a fixed set of resources that you want to reuse.
+    Once a job finishes, its slot in the pool is reused for the next job.
+
+    Args:
+        generator (JobGenerator): The generator that generates the jobs.
+        n_jobs (int): The number of jobs to run in parallel.
+
+    Example:
+        Create a pool with as many workers as there are available GPUs and submit jobs to it.
+        >>> ParallelRunner(generator, n_jobs=torch.cuda.device_count())
+        Each time a job is submitted to the pool, an additional parameter ``task_id`` is passed to the ``job.run``
+        method. The job can then use this ``task_id`` to assign a particular device to train on.
+        >>> def run(self, arg1: int, arg2: nn.Module, task_id: int) -> None:
+        >>>     device = torch.device(f"cuda:{task_id}")
+        >>>     model = arg2.to(device)
+        >>>     ...
+
+    """
+
+    def __init__(self, generator: JobGenerator, n_jobs: int) -> None:
+        super().__init__(generator)
+        self.n_jobs = n_jobs
+        self.processes: dict[int, Future | None] = {}
+        self.progress = Progress()
+        self.task_id: TaskID
+        self.results: list[dict] = []
+        self.failures = False
+
+    def run(self, args: dict) -> None:
+        """Run the job in parallel."""
+        self.task_id = self.progress.add_task(self.generator.job_class.name, total=None)
+        self.progress.start()
+        self.processes = {i: None for i in range(self.n_jobs)}
+
+        with ProcessPoolExecutor(max_workers=self.n_jobs, mp_context=multiprocessing.get_context("spawn")) as executor:
+            for job in self.generator.generate_jobs(args):
+                while None not in self.processes.values():
+                    self._await_cleanup_processes()
+                # get free index
+                index = next(i for i, p in self.processes.items() if p is None)
+                self.processes[index] = executor.submit(job.run, task_id=index)
+            self._await_cleanup_processes(blocking=True)
+
+        self.progress.update(self.task_id, completed=1, total=1)
+        self.progress.stop()
+        gathered_result = self.generator.job_class.collect(self.results)
+        self.generator.job_class.save(gathered_result)
+        if self.failures:
+            msg = f"[bold red]There were some errors with job {self.generator.job_class.name}[/bold red]"
+            print(msg)
+            logger.error(msg)
+            raise ParallelExecutionError(msg)
+        logger.info(f"Job {self.generator.job_class.name} completed successfully.")
+
+    def _await_cleanup_processes(self, blocking: bool = False) -> None:
+        """Wait for any one process to finish.
+
+        Args:
+            blocking (bool): If True, wait for all processes to finish.
+        """
+        for index, process in self.processes.items():
+            if process is not None and ((process.done() and not blocking) or blocking):
+                try:
+                    self.results.append(process.result())
+                except Exception:
+                    logger.exception("An exception occurred while getting the process result.")
+                    self.failures = True
+                self.processes[index] = None
+                self.progress.update(self.task_id, advance=1)
diff --git a/src/anomalib/pipelines/components/runners/serial.py b/src/anomalib/pipelines/components/runners/serial.py
new file mode 100644
index 0000000000..3633485168
--- /dev/null
+++ b/src/anomalib/pipelines/components/runners/serial.py
@@ -0,0 +1,44 @@
+"""Executor for running a job serially."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+from rich import print
+from rich.progress import track
+
+from anomalib.pipelines.components.base import JobGenerator, Runner
+
+logger = logging.getLogger(__name__)
+
+
+class SerialExecutionError(Exception):
+    """Error when running a job serially."""
+
+
+class SerialRunner(Runner):
+    """Serial executor for running a single job at a time."""
+
+    def __init__(self, generator: JobGenerator) -> None:
+        super().__init__(generator)
+
+    def run(self, args: dict) -> None:
+        """Run the job."""
+        results = []
+        failures = False
+        logger.info(f"Running job {self.generator.job_class.name}")
+        for job in track(self.generator(args), description=self.generator.job_class.name):
+            try:
+                results.append(job.run())
+            except Exception:  # noqa: PERF203
+                failures = True
+                logger.exception("Error running job.")
+        gathered_result = self.generator.job_class.collect(results)
+        self.generator.job_class.save(gathered_result)
+        if failures:
+            msg = f"[bold red]There were some errors with job {self.generator.job_class.name}[/bold red]"
+            print(msg)
+            logger.error(msg)
+            raise SerialExecutionError(msg)
+        logger.info(f"Job {self.generator.job_class.name} completed successfully.")
diff --git a/src/anomalib/pipelines/components/utils/__init__.py b/src/anomalib/pipelines/components/utils/__init__.py
new file mode 100644
index 0000000000..230edc6891
--- /dev/null
+++ b/src/anomalib/pipelines/components/utils/__init__.py
@@ -0,0 +1,8 @@
+"""Utils."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .grid_search import get_iterator_from_grid_dict
+
+__all__ = ["get_iterator_from_grid_dict"]
diff --git a/src/anomalib/pipelines/components/utils/grid_search.py b/src/anomalib/pipelines/components/utils/grid_search.py
new file mode 100644
index 0000000000..04e481ca6a
--- /dev/null
+++ b/src/anomalib/pipelines/components/utils/grid_search.py
@@ -0,0 +1,55 @@
+"""Utils for benchmarking."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import Generator
+from itertools import product
+from typing import Any
+
+from anomalib.utils.config import (
+    convert_valuesview_to_tuple,
+    flatten_dict,
+    to_nested_dict,
+)
+
+
+def get_iterator_from_grid_dict(container: dict) -> Generator[dict, Any, None]:
+    """Yields an iterator based on the grid search arguments.
+
+    Args:
+        container (dict): Container with grid search arguments.
+
+    Example:
+        >>> container = {
+                "seed": 42,
+                "data": {
+                    "root": ...,
+                    "category": {
+                        "grid": ["bottle", "carpet"],
+                        ...
+                    }
+                }
+            }
+        >>> next(get_iterator_from_grid_dict(container))
+        {
+            "seed": 42,
+            "data": {
+                "root": ...,
+                "category": "bottle",
+                ...
+            }
+        }
+
+    Yields:
+        Generator[dict, Any, None]: Dictionary for each combination of the grid-search arguments.
+    """
+    _container = flatten_dict(container)
+    grid_dict = {key: value for key, value in _container.items() if "grid" in key}
+    _container = {key: value for key, value in _container.items() if key not in grid_dict}
+    combinations = list(product(*convert_valuesview_to_tuple(grid_dict.values())))
+    for combination in combinations:
+        for key, value in zip(grid_dict.keys(), combination, strict=True):
+            _container[key.removesuffix(".grid")] = value
+        yield to_nested_dict(_container)
diff --git a/src/anomalib/pipelines/types.py b/src/anomalib/pipelines/types.py
new file mode 100644
index 0000000000..d542d93378
--- /dev/null
+++ b/src/anomalib/pipelines/types.py
@@ -0,0 +1,9 @@
+"""Types."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+RUN_RESULTS = Any
+GATHERED_RESULTS = Any
diff --git a/src/anomalib/utils/config.py b/src/anomalib/utils/config.py
index 113522819e..f41617f355 100644
--- a/src/anomalib/utils/config.py
+++ b/src/anomalib/utils/config.py
@@ -3,9 +3,8 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-
 import logging
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence, ValuesView
 from pathlib import Path
 from typing import Any, cast
 
@@ -29,6 +28,34 @@ def _convert_nested_path_to_str(config: Any) -> Any:  # noqa: ANN401
     return config
 
 
+def to_nested_dict(config: dict) -> dict:
+    """Convert the flattened dictionary to nested dictionary.
+
+    Examples:
+        >>> config = {
+                "dataset.category": "bottle",
+                "dataset.image_size": 224,
+                "model_name": "padim",
+            }
+        >>> to_nested_dict(config)
+        {
+            "dataset": {
+                "category": "bottle",
+                "image_size": 224,
+            },
+            "model_name": "padim",
+        }
+    """
+    out: dict[str, Any] = {}
+    for key, value in config.items():
+        keys = key.split(".")
+        _dict = out
+        for k in keys[:-1]:
+            _dict = _dict.setdefault(k, {})
+        _dict[keys[-1]] = value
+    return out
+
+
 def to_yaml(config: Namespace | ListConfig | DictConfig) -> str:
     """Convert the config to a yaml string.
 
@@ -78,6 +105,112 @@ def to_tuple(input_size: int | ListConfig) -> tuple[int, int]:
     return ret_val
 
 
+def convert_valuesview_to_tuple(values: ValuesView) -> list[tuple]:
+    """Convert a ValuesView object to a list of tuples.
+
+    This is useful to get the list of possible values for each parameter in the config, wrapping single values in a
+    tuple. It is typically used together with ``itertools.product``.
+
+    Example:
+        >>> params = DictConfig({
+                "dataset.category": [
+                    "bottle",
+                    "cable",
+                ],
+                "dataset.image_size": 224,
+                "model_name": ["padim"],
+            })
+        >>> convert_valuesview_to_tuple(params.values())
+        [('bottle', 'cable'), (224,), ('padim',)]
+        >>> list(itertools.product(*convert_valuesview_to_tuple(params.values())))
+        [('bottle', 224, 'padim'), ('cable', 224, 'padim')]
+
+    Args:
+        values (ValuesView): ValuesView object to be converted to a list of tuples.
+
+    Returns:
+        list[tuple]: List of tuples.
+    """
+    return_list = []
+    for value in values:
+        if isinstance(value, Iterable) and not isinstance(value, str):
+            return_list.append(tuple(value))
+        else:
+            return_list.append((value,))
+    return return_list
+
+
+def flatten_dict(config: dict, prefix: str = "") -> dict:
+    """Flatten the dictionary.
+
+    Examples:
+        >>> config = {
+                "dataset": {
+                    "category": "bottle",
+                    "image_size": 224,
+                },
+                "model_name": "padim",
+            }
+        >>> flatten_dict(config)
+        {
+            "dataset.category": "bottle",
+            "dataset.image_size": 224,
+            "model_name": "padim",
+        }
+    """
+    out = {}
+    for key, value in config.items():
+        if isinstance(value, dict):
+            out.update(flatten_dict(value, f"{prefix}{key}."))
+        else:
+            out[f"{prefix}{key}"] = value
+    return out
+
+
+def namespace_from_dict(container: dict) -> Namespace:
+    """Convert dictionary to Namespace recursively.
+
+    Examples:
+        >>> container = {
+                "dataset": {
+                    "category": "bottle",
+                    "image_size": 224,
+                },
+                "model_name": "padim",
+            }
+        >>> namespace_from_dict(container)
+        Namespace(dataset=Namespace(category='bottle', image_size=224), model_name='padim')
+    """
+    output = Namespace()
+    for k, v in container.items():
+        if isinstance(v, dict):
+            setattr(output, k, namespace_from_dict(v))
+        else:
+            setattr(output, k, v)
+    return output
+
+
+def dict_from_namespace(container: Namespace) -> dict:
+    """Convert Namespace to dictionary recursively.
+
+    Examples:
+        >>> from jsonargparse import Namespace
+        >>> ns = Namespace()
+        >>> ns.a = 1
+        >>> ns.b = Namespace()
+        >>> ns.b.c = 2
+        >>> dict_from_namespace(ns)
+        {'a': 1, 'b': {'c': 2}}
+    """
+    output = {}
+    for k, v in container.__dict__.items():
+        if isinstance(v, Namespace):
+            output[k] = dict_from_namespace(v)
+        else:
+            output[k] = v
+    return output
+
+
 def update_config(config: DictConfig | ListConfig | Namespace) -> DictConfig | ListConfig | Namespace:
     """Update config.
 
diff --git a/src/anomalib/utils/logging.py b/src/anomalib/utils/logging.py
new file mode 100644
index 0000000000..722b4d87b7
--- /dev/null
+++ b/src/anomalib/utils/logging.py
@@ -0,0 +1,85 @@
+"""Logging Utility functions."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import functools
+import io
+import logging
+import sys
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any
+
+
+class LoggerRedirectError(Exception):
+    """Exception occurred when executing function with outputs redirected to logger."""
+
+
+def hide_output(func: Callable[..., Any]) -> Callable[..., Any]:
+    """Hide output of the function.
+
+    Args:
+        func (Callable): The function whose stdout and stderr output is hidden.
+
+    Example:
+        >>> @hide_output
+        >>> def my_function():
+        >>>     print("This will not be printed")
+        >>> my_function()
+
+        >>> @hide_output
+        >>> def my_function():
+        >>>     1/0
+        >>> my_function()
+        Traceback (most recent call last):
+        File "<stdin>", line 1, in <module>
+        File "<stdin>", line 2, in my_fun
+        ZeroDivisionError: division by zero
+
+    Raises:
+        LoggerRedirectError: If the wrapped function raises an exception.
+
+    Returns:
+        The return value of the wrapped function.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
+        """Wrapper function."""
+        # capture stdout and stderr in in-memory buffers so the wrapped call produces no console output
+        stdout = sys.stdout
+        stderr = sys.stderr
+        sys.stdout = io.StringIO()
+        sys.stderr = io.StringIO()
+        try:
+            value = func(*args, **kwargs)
+        except Exception as exception:  # noqa: BLE001
+            msg = f"Error occurred while executing {func.__name__}"
+            raise LoggerRedirectError(msg) from exception
+        finally:
+            sys.stdout = stdout
+            sys.stderr = stderr
+        return value
+
+    return wrapper
+
+
+def redirect_logs(log_file: str) -> None:
+    """Add file handler to logger.
+
+    It also removes all other handlers from the loggers.
+
+    Note: This feature does not work well with multiprocessing and won't redirect logs from child processes.
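+
+    Example:
+        >>> redirect_logs("runs/pipeline.log")  # subsequent log records and warnings are written to this file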
+    """
+    Path(log_file).parent.mkdir(exist_ok=True, parents=True)
+    logger_file_handler = logging.FileHandler(log_file)
+    root_logger = logging.getLogger()
+    root_logger.setLevel(logging.DEBUG)
+    format_string = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    logging.basicConfig(format=format_string, level=logging.DEBUG, handlers=[logger_file_handler])
+    logging.captureWarnings(capture=True)
+    # remove other handlers from all loggers
+    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
+    for _logger in loggers:
+        _logger.handlers = [logger_file_handler]
diff --git a/tests/integration/pipelines/__init__.py b/tests/integration/pipelines/__init__.py
new file mode 100644
index 0000000000..a1eb07e53e
--- /dev/null
+++ b/tests/integration/pipelines/__init__.py
@@ -0,0 +1,4 @@
+"""Pipeline tests."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
diff --git a/tests/integration/pipelines/pipeline.yaml b/tests/integration/pipelines/pipeline.yaml
new file mode 100644
index 0000000000..114b125944
--- /dev/null
+++ b/tests/integration/pipelines/pipeline.yaml
@@ -0,0 +1,17 @@
+# mock.patch does not work with multiprocessing("spawn")
+# hence cuda is not part of the tests
+accelerator: cpu
+benchmark:
+  seed:
+    grid: [42, 51]
+  model:
+    class_path:
+      grid: [Padim, Patchcore]
+  data:
+    class_path: MVTec
+    init_args:
+      category:
+        grid:
+          - bottle
+          - capsule
+      image_size: [256, 256]
diff --git a/tests/integration/pipelines/test_benchmark.py b/tests/integration/pipelines/test_benchmark.py
new file mode 100644
index 0000000000..bf49de1764
--- /dev/null
+++ b/tests/integration/pipelines/test_benchmark.py
@@ -0,0 +1,21 @@
+"""Test benchmarking pipeline."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import MagicMock, patch
+
+from anomalib.pipelines import Benchmark
+
+
+@patch("anomalib.pipelines.benchmark.job.Engine", return_value=MagicMock(test=MagicMock(return_value=[{"test": 1}])))
+@patch("anomalib.pipelines.benchmark.generator.get_model", return_value=MagicMock())
+@patch("anomalib.pipelines.benchmark.generator.get_datamodule", return_value=MagicMock(category="dummy"))
+def test_benchmark_pipeline(datamodule: MagicMock, model: MagicMock, engine: MagicMock) -> None:  # noqa: ARG001 | this is needed for patching
+    """Test benchmarking pipeline."""
+    with patch("anomalib.pipelines.benchmark.job.BenchmarkJob.save", return_value=MagicMock()) as save_method:
+        benchmark = Benchmark()
+        benchmark_parser = benchmark.get_parser()
+        args = benchmark_parser.parse_args(["--config", "tests/integration/pipelines/pipeline.yaml"])
+        benchmark.run(args)
+        assert len(save_method.call_args.args[0]) == 8
diff --git a/tools/experimental/README.md b/tools/experimental/README.md
new file mode 100644
index 0000000000..7753314e58
--- /dev/null
+++ b/tools/experimental/README.md
@@ -0,0 +1,4 @@
+# Anomalib Experimental
+
+> [!WARNING]
+> These are experimental utilities under active development. They may change frequently or be removed.
diff --git a/tools/experimental/__init__.py b/tools/experimental/__init__.py
new file mode 100644
index 0000000000..0c99698557
--- /dev/null
+++ b/tools/experimental/__init__.py
@@ -0,0 +1,4 @@
+"""Independent entrypoint for runners."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
diff --git a/tools/experimental/benchmarking/README.md b/tools/experimental/benchmarking/README.md
new file mode 100644
index 0000000000..f869bbcbf0
--- /dev/null
+++ b/tools/experimental/benchmarking/README.md
@@ -0,0 +1,7 @@
+# Benchmarking Entrypoint
+
+## Usage
+
+```bash
+python tools/experimental/benchmarking/benchmark.py --config tools/experimental/benchmarking/sample.yaml
+```
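+
+The config uses the pipeline's grid-search format: any value nested under a `grid` key is swept over, and the
+benchmark runs every combination. A minimal sketch (see `sample.yaml` in this directory for a complete example):
+
+```yaml
+accelerator: cpu
+benchmark:
+  seed: 42
+  model:
+    class_path:
+      grid: [Padim, Patchcore]
+  data:
+    class_path: MVTec
+    init_args:
+      category:
+        grid: [bottle, capsule]
+```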
diff --git a/tools/experimental/benchmarking/__init__.py b/tools/experimental/benchmarking/__init__.py
new file mode 100644
index 0000000000..4e128bf782
--- /dev/null
+++ b/tools/experimental/benchmarking/__init__.py
@@ -0,0 +1,4 @@
+"""Benchmarking entrypoint."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
diff --git a/tools/experimental/benchmarking/benchmark.py b/tools/experimental/benchmarking/benchmark.py
new file mode 100644
index 0000000000..1953bfc0a5
--- /dev/null
+++ b/tools/experimental/benchmarking/benchmark.py
@@ -0,0 +1,14 @@
+"""Run benchmarking."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+from anomalib.pipelines.benchmark import Benchmark
+
+logger = logging.getLogger(__name__)
+
+if __name__ == "__main__":
+    logger.warning("This feature is experimental. It may change or be removed in the future.")
+    Benchmark().run()
diff --git a/tools/experimental/benchmarking/sample.yaml b/tools/experimental/benchmarking/sample.yaml
new file mode 100644
index 0000000000..6c0810c000
--- /dev/null
+++ b/tools/experimental/benchmarking/sample.yaml
@@ -0,0 +1,17 @@
+# Sample config showing a grid search over two categories.
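+# Any value nested under a `grid` key is expanded, and the benchmark runs every combination.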
+accelerator:
+  - cuda
+  - cpu
+benchmark:
+  seed: 42
+  model:
+    class_path:
+      grid: [Padim, Patchcore]
+  data:
+    class_path: MVTec
+    init_args:
+      category:
+        grid:
+          - bottle
+          - capsule
+      image_size: [256, 256]