diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index 345b202d5f9..d0c36b4fbb6 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -23,7 +23,7 @@ from pydantic import field_validator from pydantic.dataclasses import dataclass, rebuild_dataclass -from ert.plugins import ErtPluginManager +from ert.plugins import ErtPluginManager, fixtures_per_hook from ert.plugins.workflow_config import ErtScriptWorkflow from ert.substitutions import Substitutions @@ -395,10 +395,48 @@ def workflows_from_dict( ) continue + wf = workflows[hook_name] + available_fixtures = fixtures_per_hook[mode] + for job, _ in wf.cmd_list: + if job.ert_script is None: + continue + + ert_script_instance = job.ert_script() + requested_fixtures = ert_script_instance.requested_fixtures + + # Look for requested fixtures that are not available for the given + # mode + missing_fixtures = requested_fixtures - available_fixtures + + if missing_fixtures: + ok_modes = [ + m + for m in HookRuntime + if not requested_fixtures - fixtures_per_hook[m] + ] + + message_start = ( + f"Workflow job {job.name} .run function expected " + f"fixtures: {missing_fixtures}, which are not available " + f"in the fixtures for the runtime {mode}: {available_fixtures}. " + ) + message_end = ( + f"It would work in these runtimes: {', '.join(map(str, ok_modes))}" + if len(ok_modes) > 0 + else "This fixture is not available in any of the runtimes." + ) + + errors.append( + ErrorInfo(message=message_start + message_end).set_context( + hook_name + ) + ) + hooked_workflows[mode].append(workflows[hook_name]) if errors: raise ConfigValidationError.from_collected(errors) + return workflow_jobs, workflows, hooked_workflows diff --git a/src/ert/gui/tools/plugins/plugin.py b/src/ert/gui/tools/plugins/plugin.py index da66e20d5e1..29624b31010 100644 --- a/src/ert/gui/tools/plugins/plugin.py +++ b/src/ert/gui/tools/plugins/plugin.py @@ -3,7 +3,7 @@ import inspect from typing import TYPE_CHECKING, Any -from ert.plugins import ErtPlugin, WorkflowFixtures +from ert.plugins import ErtPlugin, HookedWorkflowFixtures if TYPE_CHECKING: from PyQt6.QtWidgets import QWidget @@ -34,7 +34,7 @@ def getName(self) -> str: def getDescription(self) -> str: return self.__description - def getArguments(self, fixtures: WorkflowFixtures) -> list[Any]: + def getArguments(self, fixtures: HookedWorkflowFixtures) -> list[Any]: """ Returns a list of arguments. Either from GUI or from arbitrary code. If the user for example cancels in the GUI a CancelPluginException is raised. diff --git a/src/ert/gui/tools/plugins/plugin_runner.py b/src/ert/gui/tools/plugins/plugin_runner.py index 39a21acec26..f26cfcc1e55 100644 --- a/src/ert/gui/tools/plugins/plugin_runner.py +++ b/src/ert/gui/tools/plugins/plugin_runner.py @@ -6,7 +6,7 @@ from _ert.threading import ErtThread from ert.config import ErtConfig -from ert.plugins import CancelPluginException, WorkflowFixtures +from ert.plugins import CancelPluginException, HookedWorkflowFixtures from ert.runpaths import Runpaths from ert.workflow_runner import WorkflowJobRunner @@ -84,7 +84,7 @@ def run(self) -> None: print("Plugin cancelled before execution!") def __runWorkflowJob( - self, arguments: list[Any] | None, fixtures: WorkflowFixtures + self, arguments: list[Any] | None, fixtures: HookedWorkflowFixtures ) -> None: self.__result = self._runner.run(arguments, fixtures=fixtures) diff --git a/src/ert/plugins/__init__.py b/src/ert/plugins/__init__.py index bbf603bf6f0..8ce25276b97 100644 --- a/src/ert/plugins/__init__.py +++ b/src/ert/plugins/__init__.py @@ -13,7 +13,12 @@ ) from .plugin_response import PluginMetadata, PluginResponse from .workflow_config import ErtScriptWorkflow, WorkflowConfigs -from .workflow_fixtures import WorkflowFixtures +from .workflow_fixtures import ( + HookedWorkflowFixtures, + WorkflowFixtures, + all_hooked_workflow_fixtures, + fixtures_per_hook, +) P = ParamSpec("P") @@ -59,8 +64,11 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> Any: "ErtScript", "ErtScriptWorkflow", "ExternalErtScript", + "HookedWorkflowFixtures", "JobDoc", "WorkflowConfigs", "WorkflowFixtures", + "all_hooked_workflow_fixtures", + "fixtures_per_hook", "plugin", ] diff --git a/src/ert/plugins/ert_script.py b/src/ert/plugins/ert_script.py index e7bd06b70b1..8c59b673f85 100644 --- a/src/ert/plugins/ert_script.py +++ b/src/ert/plugins/ert_script.py @@ -9,11 +9,15 @@ from abc import abstractmethod from collections.abc import Callable from types import MappingProxyType, ModuleType -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias, cast from typing_extensions import deprecated -from .workflow_fixtures import WorkflowFixtures +from .workflow_fixtures import ( + HookedWorkflowFixtures, + WorkflowFixtures, + all_hooked_workflow_fixtures, +) if TYPE_CHECKING: from ert.config import ErtConfig @@ -107,13 +111,28 @@ def cleanup(self) -> None: Override to perform cleanup after a run. """ + @property + def requested_fixtures(self) -> set[str]: + return { + k + for k, v in inspect.signature(self.run).parameters.items() + if k in all_hooked_workflow_fixtures + } + def initializeAndRun( self, argument_types: list[type[Any]], argument_values: list[str], - fixtures: WorkflowFixtures | None = None, + fixtures: WorkflowFixtures | HookedWorkflowFixtures | None = None, ) -> Any: - fixtures = {} if fixtures is None else fixtures + fixtures_without_hook = {**(fixtures or {})} + + if "hook" in fixtures_without_hook: + fixtures_without_hook.pop("hook") + + complete_fixtures: WorkflowFixtures = cast( + WorkflowFixtures, fixtures_without_hook + ) arguments = [] for index, arg_value in enumerate(argument_values): arg_type = argument_types[index] if index < len(argument_types) else str @@ -122,14 +141,15 @@ def initializeAndRun( arguments.append(arg_type(arg_value)) else: arguments.append(None) - fixtures["workflow_args"] = arguments + + complete_fixtures["workflow_args"] = arguments try: func_args = inspect.signature(self.run).parameters - # If the user has specified *args, we skip injecting fixtures, and just + # If the user has specified *args, we skip injecting complete_fixtures, and just # pass the user configured arguments if not any(p.kind == p.VAR_POSITIONAL for p in func_args.values()): try: - arguments = self.insert_fixtures(func_args, fixtures) + arguments = self.insert_fixtures(func_args, complete_fixtures) except ValueError as e: # This is here for backwards compatibility, the user does not have *argv # but positional arguments. Can not be mixed with using fixtures. diff --git a/src/ert/plugins/workflow_fixtures.py b/src/ert/plugins/workflow_fixtures.py index f73f5b6ce6d..58597ed9d6b 100644 --- a/src/ert/plugins/workflow_fixtures.py +++ b/src/ert/plugins/workflow_fixtures.py @@ -1,9 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +import functools +import typing +from typing import TYPE_CHECKING, NotRequired, TypedDict from PyQt6.QtWidgets import QWidget -from typing_extensions import TypedDict + +from ert.config.parsing.hook_runtime import HookRuntime if TYPE_CHECKING: from ert.config import ESSettings, UpdateSettings @@ -11,13 +14,91 @@ from ert.storage import Ensemble, Storage -class WorkflowFixtures(TypedDict, total=False): - ensemble: Ensemble | None +class PreExperimentFixtures(TypedDict): + random_seed: int + hook: HookRuntime = HookRuntime.PRE_EXPERIMENT + + +class PostExperimentFixtures(TypedDict): + hook: HookRuntime = HookRuntime.POST_EXPERIMENT + random_seed: int storage: Storage - random_seed: int | None + ensemble: Ensemble + + +class PreSimulationFixtures(TypedDict): + hook: HookRuntime = HookRuntime.PRE_SIMULATION + random_seed: int + reports_dir: str + run_paths: Runpaths + storage: Storage + ensemble: Ensemble + + +class PostSimulationFixtures(PreSimulationFixtures): + hook: HookRuntime = HookRuntime.POST_SIMULATION + + +class PreFirstUpdateFixtures(TypedDict): + hook: HookRuntime = HookRuntime.PRE_FIRST_UPDATE + random_seed: int reports_dir: str - observation_settings: UpdateSettings - es_settings: ESSettings run_paths: Runpaths - workflow_args: list[Any] - parent: QWidget | None + storage: Storage + ensemble: Ensemble + es_settings: ESSettings + observation_settings: UpdateSettings + + +class PreUpdateFixtures(PreFirstUpdateFixtures): + hook: HookRuntime = HookRuntime.PRE_UPDATE + + +class PostUpdateFixtures(PreFirstUpdateFixtures): + hook: HookRuntime = HookRuntime.POST_UPDATE + + +class WorkflowFixtures(TypedDict, total=False): + workflow_args: NotRequired[list[typing.Any]] + parent: NotRequired[QWidget] + random_seed: NotRequired[int] + reports_dir: NotRequired[str] + run_paths: NotRequired[Runpaths] + storage: NotRequired[Storage] + ensemble: NotRequired[Ensemble] + es_settings: NotRequired[ESSettings] + observation_settings: NotRequired[UpdateSettings] + + +# Union Type Definition +HookedWorkflowFixtures = ( + PreExperimentFixtures + | PostExperimentFixtures + | PreSimulationFixtures + | PostSimulationFixtures + | PreFirstUpdateFixtures + | PreUpdateFixtures + | PostUpdateFixtures +) + + +def __all_workflow_fixtures() -> set[str]: + fixtures_per_runtime = ( + __get_available_fixtures_for_hook(hook) for hook in HookRuntime + ) + + return functools.reduce(lambda a, b: a | b, fixtures_per_runtime) + + +def __get_available_fixtures_for_hook(hook: HookRuntime) -> set[str]: + cls = next( + cls for cls in typing.get_args(HookedWorkflowFixtures) if cls.hook == hook + ) + + return set(cls.__annotations__.keys()) - {"hook"} + + +all_hooked_workflow_fixtures = __all_workflow_fixtures() +fixtures_per_hook = { + hook: __get_available_fixtures_for_hook(hook) for hook in HookRuntime +} diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index f7fa510ad18..b11805fdfe5 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -55,7 +55,7 @@ REALIZATION_STATE_FINISHED, ) from ert.mode_definitions import MODULE_MODE -from ert.plugins import WorkflowFixtures +from ert.plugins import HookedWorkflowFixtures from ert.runpaths import Runpaths from ert.storage import Ensemble, Storage from ert.substitutions import Substitutions @@ -737,7 +737,7 @@ def validate_successful_realizations_count(self) -> None: def run_workflows( self, runtime: HookRuntime, - fixtures: WorkflowFixtures, + fixtures: HookedWorkflowFixtures, ) -> None: for workflow in self._hooked_workflows[runtime]: WorkflowRunner(workflow=workflow, fixtures=fixtures).run_blocking() @@ -883,7 +883,7 @@ def update( ) ) - workflow_fixtures: WorkflowFixtures = { + workflow_fixtures: HookedWorkflowFixtures = { "storage": self._storage, "ensemble": prior, "observation_settings": self._update_settings, diff --git a/src/ert/workflow_runner.py b/src/ert/workflow_runner.py index c466f89378b..c6f69d39281 100644 --- a/src/ert/workflow_runner.py +++ b/src/ert/workflow_runner.py @@ -6,7 +6,7 @@ from typing import Any, Self from ert.config import Workflow, WorkflowJob -from ert.plugins import ErtScript, ExternalErtScript, WorkflowFixtures +from ert.plugins import ErtScript, ExternalErtScript, HookedWorkflowFixtures class WorkflowJobRunner: @@ -19,7 +19,7 @@ def __init__(self, workflow_job: WorkflowJob): def run( self, arguments: list[Any] | None = None, - fixtures: WorkflowFixtures | None = None, + fixtures: HookedWorkflowFixtures | None = None, ) -> Any: if arguments is None: arguments = [] @@ -101,7 +101,7 @@ class WorkflowRunner: def __init__( self, workflow: Workflow, - fixtures: WorkflowFixtures, + fixtures: HookedWorkflowFixtures, ) -> None: self.__workflow = workflow self.fixtures = fixtures diff --git a/tests/ert/unit_tests/config/test_ert_config.py b/tests/ert/unit_tests/config/test_ert_config.py index feac2ec18ac..5a3d57b6184 100644 --- a/tests/ert/unit_tests/config/test_ert_config.py +++ b/tests/ert/unit_tests/config/test_ert_config.py @@ -16,7 +16,8 @@ from hypothesis import strategies as st from pydantic import RootModel, TypeAdapter -from ert.config import ConfigValidationError, ErtConfig, HookRuntime +from ert import ErtScript, ErtScriptWorkflow +from ert.config import ConfigValidationError, ErtConfig, ESSettings, HookRuntime from ert.config.ert_config import _split_string_into_sections, create_forward_model_json from ert.config.forward_model_step import ForwardModelStep from ert.config.parsing import ConfigKeys, ConfigWarning @@ -31,6 +32,7 @@ from ert.config.parsing.queue_system import QueueSystem from ert.plugins import ErtPluginManager from ert.shared import ert_share_path +from ert.storage import LocalEnsemble, Storage from tests.ert.ui_tests.cli.analysis.test_design_matrix import _create_design_matrix from .config_dict_generator import config_generators @@ -2177,6 +2179,131 @@ def test_run_template_raises_configvalidationerror_with_more_than_two_arguments( ) +@pytest.fixture +def setup_workflow_file(): + workflow_file_path = os.path.join(os.getcwd(), "workflow") + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: + fh.write("TEST_SCRIPT") + + with open("config.ert", mode="w", encoding="utf-8") as fh: + fh.write( + dedent( + f""" + NUM_REALIZATIONS 1 + + LOAD_WORKFLOW {workflow_file_path} workflow_alias + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT + """ + ) + ) + + +@pytest.mark.usefixtures("use_tmpdir", "setup_workflow_file") +def test_ert_script_hook_pre_experiment_but_asks_for_storage(): + class SomeScript(ErtScript): + def run(self, storage: Storage): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + + with pytest.raises( + ConfigValidationError, + match=r"Workflow job TEST_SCRIPT.*" + r"expected fixtures.*storage.*", + ): + ErtConfig.from_file("config.ert") + + +@pytest.mark.usefixtures("use_tmpdir", "setup_workflow_file") +def test_ert_script_hook_pre_experiment_but_asks_for_ensemble(): + class SomeScript(ErtScript): + def run(self, ensemble: LocalEnsemble): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + + with pytest.raises( + ConfigValidationError, + match=r"Workflow job TEST_SCRIPT.*" + r"expected fixtures.*ensemble.*", + ): + ErtConfig.from_file("config.ert") + + +@pytest.mark.usefixtures("use_tmpdir", "setup_workflow_file") +def test_ert_script_hook_pre_experiment_but_asks_for_random_seed(): + class SomeScript(ErtScript): + def run(self, random_seed: int): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + + ErtConfig.from_file("config.ert") + + +@pytest.mark.usefixtures("use_tmpdir", "setup_workflow_file") +def test_ert_script_hook_pre_experiment_essettings_fails(): + class SomeScript(ErtScript): + def run(self, es_settings: ESSettings): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + + with pytest.raises( + ConfigValidationError, + match=r".*It would work in these runtimes: PRE_UPDATE, POST_UPDATE, PRE_FIRST_UPDATE.*", + ): + ErtConfig.from_file("config.ert") + + +@pytest.mark.usefixtures("use_tmpdir") +def test_ert_script_hook_valid_essettings_succeed(): + workflow_file_path = os.path.join(os.getcwd(), "workflow") + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: + fh.write("TEST_SCRIPT") + + with open("config.ert", mode="w", encoding="utf-8") as fh: + fh.write( + dedent( + f""" + NUM_REALIZATIONS 1 + + LOAD_WORKFLOW {workflow_file_path} workflow_alias + HOOK_WORKFLOW workflow_alias PRE_UPDATE + HOOK_WORKFLOW workflow_alias POST_UPDATE + HOOK_WORKFLOW workflow_alias PRE_FIRST_UPDATE + """ + ) + ) + + class SomeScript(ErtScript): + def run(self, es_settings: ESSettings): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + ErtConfig.from_file("config.ert") + + def test_queue_options_are_joined_after_option_name(): assert ( ErtConfig.from_file_contents(