diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 4f20295285..4dba5fc774 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -4,11 +4,7 @@ from __future__ import annotations -import os -import sys -from collections import Counter -from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait -from itertools import chain +from concurrent.futures import Executor, ProcessPoolExecutor from multiprocessing.managers import BaseProxy, SyncManager from multiprocessing.reduction import ForkingPickler from pickle import PicklingError @@ -21,7 +17,6 @@ SharedMemoryDataset, ) from kedro.runner.runner import AbstractRunner -from kedro.runner.task import Task if TYPE_CHECKING: from collections.abc import Iterable @@ -31,9 +26,6 @@ from kedro.pipeline import Pipeline from kedro.pipeline.node import Node -# see https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L114 -_MAX_WINDOWS_WORKERS = 61 - class ParallelRunnerManager(SyncManager): """``ParallelRunnerManager`` is used to create shared ``MemoryDataset`` @@ -83,16 +75,7 @@ def __init__( self._manager = ParallelRunnerManager() self._manager.start() - # This code comes from the concurrent.futures library - # https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L588 - if max_workers is None: - # NOTE: `os.cpu_count` might return None in some weird cases. - # https://github.com/python/cpython/blob/3.7/Modules/posixmodule.c#L11431 - max_workers = os.cpu_count() or 1 - if sys.platform == "win32": - max_workers = min(_MAX_WINDOWS_WORKERS, max_workers) - - self._max_workers = max_workers + self._max_workers = self._validate_max_workers(max_workers) def __del__(self) -> None: self._manager.shutdown() @@ -189,14 +172,17 @@ def _get_required_workers_count(self, pipeline: Pipeline) -> int: return min(required_processes, self._max_workers) + def _get_executor(self, max_workers: int) -> Executor: + return ProcessPoolExecutor(max_workers=max_workers) + def _run( self, pipeline: Pipeline, catalog: CatalogProtocol, - hook_manager: PluginManager, + hook_manager: PluginManager | None = None, session_id: str | None = None, ) -> None: - """The abstract interface for running pipelines. + """The method implementing parallel pipeline running. Args: pipeline: The ``Pipeline`` to run. @@ -218,50 +204,8 @@ def _run( "for potential performance gains. https://docs.kedro.org/en/stable/nodes_and_pipelines/run_a_pipeline.html#load-and-save-asynchronously" ) - nodes = pipeline.nodes - self._validate_catalog(catalog, pipeline) - self._validate_nodes(nodes) - self._set_manager_datasets(catalog, pipeline) - load_counts = Counter(chain.from_iterable(n.inputs for n in nodes)) - node_dependencies = pipeline.node_dependencies - todo_nodes = set(node_dependencies.keys()) - done_nodes: set[Node] = set() - futures = set() - done = None - max_workers = self._get_required_workers_count(pipeline) - - with ProcessPoolExecutor(max_workers=max_workers) as pool: - while True: - ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes} - todo_nodes -= ready - for node in ready: - task = Task( - node=node, - catalog=catalog, - is_async=self._is_async, - session_id=session_id, - parallel=True, - ) - futures.add(pool.submit(task)) - if not futures: - if todo_nodes: - debug_data = { - "todo_nodes": todo_nodes, - "done_nodes": done_nodes, - "ready_nodes": ready, - "done_futures": done, - } - debug_data_str = "\n".join( - f"{k} = {v}" for k, v in debug_data.items() - ) - raise RuntimeError( - f"Unable to schedule new tasks although some nodes " - f"have not been run:\n{debug_data_str}" - ) - break # pragma: no cover - done, futures = wait(futures, return_when=FIRST_COMPLETED) - for future in done: - node = future.result() - done_nodes.add(node) - - self._release_datasets(node, catalog, load_counts, pipeline) + super()._run( + pipeline=pipeline, + catalog=catalog, + session_id=session_id, + ) diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index c3f31f40a4..f882e37249 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -6,17 +6,26 @@ import inspect import logging +import os +import sys import warnings from abc import ABC, abstractmethod -from collections import deque +from collections import Counter, deque +from concurrent.futures import FIRST_COMPLETED, Executor, ProcessPoolExecutor, wait +from itertools import chain from typing import TYPE_CHECKING, Any +from pluggy import PluginManager + from kedro import KedroDeprecationWarning from kedro.framework.hooks.manager import _NullPluginManager from kedro.io import CatalogProtocol, MemoryDataset, SharedMemoryDataset from kedro.pipeline import Pipeline from kedro.runner.task import Task +# see https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L114 +_MAX_WINDOWS_WORKERS = 61 + if TYPE_CHECKING: from collections.abc import Collection, Iterable @@ -166,25 +175,95 @@ def run_only_missing( return self.run(to_rerun, catalog, hook_manager) + @abstractmethod # pragma: no cover + def _get_executor(self, max_workers: int) -> Executor: + """Abstract method to provide the correct executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor).""" + pass + @abstractmethod # pragma: no cover def _run( self, pipeline: Pipeline, catalog: CatalogProtocol, - hook_manager: PluginManager, + hook_manager: PluginManager | None = None, session_id: str | None = None, ) -> None: """The abstract interface for running pipelines, assuming that the - inputs have already been checked and normalized by run(). + inputs have already been checked and normalized by run(). + This contains the Common pipeline execution logic using an executor. Args: pipeline: The ``Pipeline`` to run. catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. session_id: The id of the session. - """ - pass + + nodes = pipeline.nodes + + self._validate_catalog(catalog, pipeline) + self._validate_nodes(nodes) + self._set_manager_datasets(catalog, pipeline) + + load_counts = Counter(chain.from_iterable(n.inputs for n in pipeline.nodes)) + node_dependencies = pipeline.node_dependencies + todo_nodes = set(node_dependencies.keys()) + done_nodes: set[Node] = set() + futures = set() + done = None + max_workers = self._get_required_workers_count(pipeline) + + with self._get_executor(max_workers) as pool: + while True: + ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes} + todo_nodes -= ready + for node in ready: + task = Task( + node=node, + catalog=catalog, + hook_manager=hook_manager, + is_async=self._is_async, + session_id=session_id, + ) + if isinstance(pool, ProcessPoolExecutor): + task.parallel = True + futures.add(pool.submit(task)) + if not futures: + if todo_nodes: + self._raise_runtime_error(todo_nodes, done_nodes, ready, done) + break + done, futures = wait(futures, return_when=FIRST_COMPLETED) + for future in done: + try: + node = future.result() + except Exception: + self._suggest_resume_scenario(pipeline, done_nodes, catalog) + raise + done_nodes.add(node) + self._logger.info("Completed node: %s", node.name) + self._logger.info( + "Completed %d out of %d tasks", len(done_nodes), len(nodes) + ) + self._release_datasets(node, catalog, load_counts, pipeline) + + @staticmethod + def _raise_runtime_error( + todo_nodes: set[Node], + done_nodes: set[Node], + ready: set[Node], + done: set[Node] | None, + ) -> None: + debug_data = { + "todo_nodes": todo_nodes, + "done_nodes": done_nodes, + "ready_nodes": ready, + "done_futures": done, + } + debug_data_str = "\n".join(f"{k} = {v}" for k, v in debug_data.items()) + raise RuntimeError( + f"Unable to schedule new tasks although some nodes " + f"have not been run:\n{debug_data_str}" + ) def _suggest_resume_scenario( self, @@ -240,6 +319,47 @@ def _release_datasets( if load_counts[dataset] < 1 and dataset not in pipeline.outputs(): catalog.release(dataset) + def _validate_catalog(self, catalog: CatalogProtocol, pipeline: Pipeline) -> None: + # Add catalog validation logic here if needed + pass + + def _validate_nodes(self, node: Iterable[Node]) -> None: + # Add node validation logic here if needed + pass + + def _set_manager_datasets( + self, catalog: CatalogProtocol, pipeline: Pipeline + ) -> None: + # Set up any necessary manager datasets here + pass + + def _get_required_workers_count(self, pipeline: Pipeline) -> int: + return 1 + + @classmethod + def _validate_max_workers(cls, max_workers: int | None) -> int: + """ + Validates and returns the number of workers. Sets to os.cpu_count() or 1 if max_workers is None, + and limits max_workers to 61 on Windows. + + Args: + max_workers: Desired number of workers. If None, defaults to os.cpu_count() or 1. + + Returns: + A valid number of workers to use. + + Raises: + ValueError: If max_workers is set and is not positive. + """ + if max_workers is None: + max_workers = os.cpu_count() or 1 + if sys.platform == "win32": + max_workers = min(_MAX_WINDOWS_WORKERS, max_workers) + elif max_workers <= 0: + raise ValueError("max_workers should be positive") + + return max_workers + def _find_nodes_to_resume_from( pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: CatalogProtocol diff --git a/kedro/runner/sequential_runner.py b/kedro/runner/sequential_runner.py index 508f95234f..8e0fc92377 100644 --- a/kedro/runner/sequential_runner.py +++ b/kedro/runner/sequential_runner.py @@ -5,12 +5,13 @@ from __future__ import annotations -from collections import Counter -from itertools import chain +from concurrent.futures import ( + Executor, + ThreadPoolExecutor, +) from typing import TYPE_CHECKING, Any from kedro.runner.runner import AbstractRunner -from kedro.runner.task import Task if TYPE_CHECKING: from pluggy import PluginManager @@ -46,11 +47,16 @@ def __init__( is_async=is_async, extra_dataset_patterns=self._extra_dataset_patterns ) + def _get_executor(self, max_workers: int) -> Executor: + return ThreadPoolExecutor( + max_workers=1 + ) # Single-threaded for sequential execution + def _run( self, pipeline: Pipeline, catalog: CatalogProtocol, - hook_manager: PluginManager, + hook_manager: PluginManager | None = None, session_id: str | None = None, ) -> None: """The method implementing sequential pipeline running. @@ -69,27 +75,9 @@ def _run( "Using synchronous mode for loading and saving data. Use the --async flag " "for potential performance gains. https://docs.kedro.org/en/stable/nodes_and_pipelines/run_a_pipeline.html#load-and-save-asynchronously" ) - nodes = pipeline.nodes - done_nodes = set() - - load_counts = Counter(chain.from_iterable(n.inputs for n in nodes)) - - for exec_index, node in enumerate(nodes): - try: - Task( - node=node, - catalog=catalog, - hook_manager=hook_manager, - is_async=self._is_async, - session_id=session_id, - ).execute() - done_nodes.add(node) - except Exception: - self._suggest_resume_scenario(pipeline, done_nodes, catalog) - raise - - self._release_datasets(node, catalog, load_counts, pipeline) - - self._logger.info( - "Completed %d out of %d tasks", len(done_nodes), len(nodes) - ) + super()._run( + pipeline=pipeline, + catalog=catalog, + hook_manager=hook_manager, + session_id=session_id, + ) diff --git a/kedro/runner/thread_runner.py b/kedro/runner/thread_runner.py index 19cfaafdbd..b0194165b7 100644 --- a/kedro/runner/thread_runner.py +++ b/kedro/runner/thread_runner.py @@ -6,12 +6,9 @@ from __future__ import annotations import warnings -from collections import Counter -from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait -from itertools import chain +from concurrent.futures import Executor, ThreadPoolExecutor from typing import TYPE_CHECKING, Any -from kedro.runner import Task from kedro.runner.runner import AbstractRunner if TYPE_CHECKING: @@ -19,7 +16,6 @@ from kedro.io import CatalogProtocol from kedro.pipeline import Pipeline - from kedro.pipeline.node import Node class ThreadRunner(AbstractRunner): @@ -63,10 +59,7 @@ def __init__( is_async=False, extra_dataset_patterns=self._extra_dataset_patterns ) - if max_workers is not None and max_workers <= 0: - raise ValueError("max_workers should be positive") - - self._max_workers = max_workers + self._max_workers = self._validate_max_workers(max_workers) def _get_required_workers_count(self, pipeline: Pipeline) -> int: """ @@ -85,14 +78,17 @@ def _get_required_workers_count(self, pipeline: Pipeline) -> int: else required_threads ) + def _get_executor(self, max_workers: int) -> Executor: + return ThreadPoolExecutor(max_workers=max_workers) + def _run( self, pipeline: Pipeline, catalog: CatalogProtocol, - hook_manager: PluginManager, + hook_manager: PluginManager | None = None, session_id: str | None = None, ) -> None: - """The abstract interface for running pipelines. + """The method implementing threaded pipeline running. Args: pipeline: The ``Pipeline`` to run. @@ -104,42 +100,9 @@ def _run( Exception: in case of any downstream node failure. """ - nodes = pipeline.nodes - load_counts = Counter(chain.from_iterable(n.inputs for n in nodes)) - node_dependencies = pipeline.node_dependencies - todo_nodes = set(node_dependencies.keys()) - done_nodes: set[Node] = set() - futures = set() - done = None - max_workers = self._get_required_workers_count(pipeline) - - with ThreadPoolExecutor(max_workers=max_workers) as pool: - while True: - ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes} - todo_nodes -= ready - for node in ready: - task = Task( - node=node, - catalog=catalog, - hook_manager=hook_manager, - is_async=self._is_async, - session_id=session_id, - ) - futures.add(pool.submit(task)) - if not futures: - assert not todo_nodes, (todo_nodes, done_nodes, ready, done) # noqa: S101 - break - done, futures = wait(futures, return_when=FIRST_COMPLETED) - for future in done: - try: - node = future.result() - except Exception: - self._suggest_resume_scenario(pipeline, done_nodes, catalog) - raise - done_nodes.add(node) - self._logger.info("Completed node: %s", node.name) - self._logger.info( - "Completed %d out of %d tasks", len(done_nodes), len(nodes) - ) - - self._release_datasets(node, catalog, load_counts, pipeline) + super()._run( + pipeline=pipeline, + catalog=catalog, + hook_manager=hook_manager, + session_id=session_id, + ) diff --git a/tests/runner/test_parallel_runner.py b/tests/runner/test_parallel_runner.py index 0f989048f1..049b27f200 100644 --- a/tests/runner/test_parallel_runner.py +++ b/tests/runner/test_parallel_runner.py @@ -17,9 +17,9 @@ from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline from kedro.runner import ParallelRunner from kedro.runner.parallel_runner import ( - _MAX_WINDOWS_WORKERS, ParallelRunnerManager, ) +from kedro.runner.runner import _MAX_WINDOWS_WORKERS from tests.runner.conftest import ( exception_fn, identity, diff --git a/tests/runner/test_sequential_runner.py b/tests/runner/test_sequential_runner.py index 229518ecd4..e2424fbec6 100644 --- a/tests/runner/test_sequential_runner.py +++ b/tests/runner/test_sequential_runner.py @@ -271,9 +271,9 @@ class TestSuggestResumeScenario: [ (["node1_A", "node1_B"], r"No nodes ran."), (["node2"], r"(node1_A,node1_B|node1_B,node1_A)"), - (["node3_A"], r"(node3_A,node3_B|node3_B,node3_A)"), - (["node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"), - (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"), + (["node3_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"), + (["node4_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"), + (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"), (["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"), ], ) @@ -304,9 +304,9 @@ def test_suggest_resume_scenario( [ (["node1_A", "node1_B"], r"No nodes ran."), (["node2"], r'"node1_A,node1_B"'), - (["node3_A"], r'"node3_A,node3_B"'), - (["node4_A"], r'"node3_A,node3_B"'), - (["node3_A", "node4_A"], r'"node3_A,node3_B"'), + (["node3_A"], r"(node3_A,node3_B|node3_A)"), + (["node4_A"], r"(node3_A,node3_B|node3_A)"), + (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_A)"), (["node2", "node4_A"], r'"node1_A,node1_B"'), ], ) diff --git a/tests/runner/test_thread_runner.py b/tests/runner/test_thread_runner.py index 8374071cab..8da85624e0 100644 --- a/tests/runner/test_thread_runner.py +++ b/tests/runner/test_thread_runner.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from concurrent.futures import ThreadPoolExecutor from typing import Any @@ -249,3 +250,76 @@ def test_release_transcoded(self): # we want to see both datasets being released assert list(log) == [("release", "save"), ("load", "load"), ("release", "load")] + + +class TestSuggestResumeScenario: + @pytest.mark.parametrize( + "failing_node_names,expected_pattern", + [ + (["node1_A", "node1_B"], r"No nodes ran."), + (["node2"], r"(node1_A,node1_B|node1_B,node1_A)"), + (["node3_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"), + (["node4_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"), + (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A|node3_A)"), + (["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"), + ], + ) + def test_suggest_resume_scenario( + self, + caplog, + two_branches_crossed_pipeline, + persistent_dataset_catalog, + failing_node_names, + expected_pattern, + ): + nodes = {n.name: n for n in two_branches_crossed_pipeline.nodes} + for name in failing_node_names: + two_branches_crossed_pipeline -= modular_pipeline([nodes[name]]) + two_branches_crossed_pipeline += modular_pipeline( + [nodes[name]._copy(func=exception_fn)] + ) + with pytest.raises(Exception): + ThreadRunner().run( + two_branches_crossed_pipeline, + persistent_dataset_catalog, + hook_manager=_create_hook_manager(), + ) + assert re.search(expected_pattern, caplog.text) + + @pytest.mark.parametrize( + "failing_node_names,expected_pattern", + [ + (["node1_A", "node1_B"], r"No nodes ran."), + (["node2"], r'"node1_A,node1_B"'), + (["node3_A"], r"(node3_A,node3_B|node3_A)"), + (["node4_A"], r"(node3_A,node3_B|node3_A)"), + (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_A)"), + (["node2", "node4_A"], r'"node1_A,node1_B"'), + ], + ) + def test_stricter_suggest_resume_scenario( + self, + caplog, + two_branches_crossed_pipeline_variable_inputs, + persistent_dataset_catalog, + failing_node_names, + expected_pattern, + ): + """ + Stricter version of previous test. + Covers pipelines where inputs are shared across nodes. + """ + test_pipeline = two_branches_crossed_pipeline_variable_inputs + + nodes = {n.name: n for n in test_pipeline.nodes} + for name in failing_node_names: + test_pipeline -= modular_pipeline([nodes[name]]) + test_pipeline += modular_pipeline([nodes[name]._copy(func=exception_fn)]) + + with pytest.raises(Exception, match="test exception"): + ThreadRunner().run( + test_pipeline, + persistent_dataset_catalog, + hook_manager=_create_hook_manager(), + ) + assert re.search(expected_pattern, caplog.text)