From 32e215519f816f31eb25dbea7873a8284115bef4 Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Mon, 17 Mar 2025 16:30:40 +0100 Subject: [PATCH 1/5] Validate workflow job args against hook runtime fixtures --- src/ert/config/ert_config.py | 39 ++++ src/ert/config/parsing/hook_runtime.py | 51 +++++ src/ert/plugins/ert_script.py | 8 + .../ert/unit_tests/config/test_ert_config.py | 174 +++++++++++++++++- 4 files changed, 271 insertions(+), 1 deletion(-) diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index 345b202d5f9..fa430bcfb0b 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -60,6 +60,7 @@ from .parsing import ( parse as parse_config, ) +from .parsing.hook_runtime import fixtures_per_runtime from .parsing.observations_parser import ( GenObsValues, HistoryValues, @@ -395,10 +396,48 @@ def workflows_from_dict( ) continue + wf = workflows[hook_name] + available_fixtures = fixtures_per_runtime[mode] + for job, _ in wf.cmd_list: + if job.ert_script is None: + continue + + ert_script_instance = job.ert_script() + requested_fixtures = ert_script_instance.requested_fixtures + + # Look for requested fixtures that are not available for the given + # mode + missing_fixtures = requested_fixtures - available_fixtures + + if missing_fixtures: + ok_modes = [ + m + for m in HookRuntime + if not requested_fixtures - fixtures_per_runtime[m] + ] + + message_start = ( + f"Workflow job {job.name} .run function expected " + f"fixtures: {missing_fixtures}, which are not available " + f"in the fixtures for the runtime {mode}: {available_fixtures}. " + ) + message_end = ( + f"It would work in these runtimes: {', '.join(map(str, ok_modes))}" + if len(ok_modes) > 0 + else "This fixture is not available in any of the runtimes." + ) + + errors.append( + ErrorInfo(message=message_start + message_end).set_context( + hook_name + ) + ) + hooked_workflows[mode].append(workflows[hook_name]) if errors: raise ConfigValidationError.from_collected(errors) + return workflow_jobs, workflows, hooked_workflows diff --git a/src/ert/config/parsing/hook_runtime.py b/src/ert/config/parsing/hook_runtime.py index 70c9cc740db..32b63b25ba5 100644 --- a/src/ert/config/parsing/hook_runtime.py +++ b/src/ert/config/parsing/hook_runtime.py @@ -9,3 +9,54 @@ class HookRuntime(StrEnum): PRE_FIRST_UPDATE = "PRE_FIRST_UPDATE" PRE_EXPERIMENT = "PRE_EXPERIMENT" POST_EXPERIMENT = "POST_EXPERIMENT" + + +fixtures_per_runtime = { + HookRuntime.PRE_EXPERIMENT: {"random_seed"}, + HookRuntime.PRE_SIMULATION: { + "storage", + "ensemble", + "reports_dir", + "random_seed", + "run_paths", + }, + HookRuntime.POST_SIMULATION: { + "storage", + "ensemble", + "reports_dir", + "random_seed", + "run_paths", + }, + HookRuntime.PRE_FIRST_UPDATE: { + "storage", + "ensemble", + "reports_dir", + "random_seed", + "es_settings", + "observation_settings", + "run_paths", + }, + HookRuntime.PRE_UPDATE: { + "storage", + "ensemble", + "reports_dir", + "random_seed", + "es_settings", + "observation_settings", + "run_paths", + }, + HookRuntime.POST_UPDATE: { + "storage", + "ensemble", + "reports_dir", + "random_seed", + "es_settings", + "observation_settings", + "run_paths", + }, + HookRuntime.POST_EXPERIMENT: { + "random_seed", + "storage", + "ensemble", + }, +} diff --git a/src/ert/plugins/ert_script.py b/src/ert/plugins/ert_script.py index e7bd06b70b1..d819ebf1208 100644 --- a/src/ert/plugins/ert_script.py +++ b/src/ert/plugins/ert_script.py @@ -107,6 +107,14 @@ def cleanup(self) -> None: Override to perform cleanup after a run. """ + @property + def requested_fixtures(self) -> set[str]: + return { + k + for k, v in inspect.signature(self.run).parameters.items() + if k in WorkflowFixtures.__annotations__ and k != "workflow_args" + } + def initializeAndRun( self, argument_types: list[type[Any]], diff --git a/tests/ert/unit_tests/config/test_ert_config.py b/tests/ert/unit_tests/config/test_ert_config.py index feac2ec18ac..2b899bf18f2 100644 --- a/tests/ert/unit_tests/config/test_ert_config.py +++ b/tests/ert/unit_tests/config/test_ert_config.py @@ -16,7 +16,8 @@ from hypothesis import strategies as st from pydantic import RootModel, TypeAdapter -from ert.config import ConfigValidationError, ErtConfig, HookRuntime +from ert import ErtScript, ErtScriptWorkflow +from ert.config import ConfigValidationError, ErtConfig, ESSettings, HookRuntime from ert.config.ert_config import _split_string_into_sections, create_forward_model_json from ert.config.forward_model_step import ForwardModelStep from ert.config.parsing import ConfigKeys, ConfigWarning @@ -31,6 +32,7 @@ from ert.config.parsing.queue_system import QueueSystem from ert.plugins import ErtPluginManager from ert.shared import ert_share_path +from ert.storage import LocalEnsemble, Storage from tests.ert.ui_tests.cli.analysis.test_design_matrix import _create_design_matrix from .config_dict_generator import config_generators @@ -2177,6 +2179,176 @@ def test_run_template_raises_configvalidationerror_with_more_than_two_arguments( ) +@pytest.mark.usefixtures("use_tmpdir") +def test_ert_script_hook_pre_experiment_but_asks_for_storage(): + workflow_file_path = os.path.join(os.getcwd(), "workflow") + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: + fh.write("TEST_SCRIPT") + + with open("config.ert", mode="w", encoding="utf-8") as fh: + fh.write( + dedent( + f""" + NUM_REALIZATIONS 1 + + LOAD_WORKFLOW {workflow_file_path} workflow_alias + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT + """ + ) + ) + + class SomeScript(ErtScript): + def run(self, storage: Storage): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + + with pytest.raises( + ConfigValidationError, + match=r"Workflow job TEST_SCRIPT.*" + r"expected fixtures.*storage.*", + ): + ErtConfig.from_file("config.ert") + + +@pytest.mark.usefixtures("use_tmpdir") +def test_ert_script_hook_pre_experiment_but_asks_for_ensemble(): + workflow_file_path = os.path.join(os.getcwd(), "workflow") + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: + fh.write("TEST_SCRIPT") + + with open("config.ert", mode="w", encoding="utf-8") as fh: + fh.write( + dedent( + f""" + NUM_REALIZATIONS 1 + + LOAD_WORKFLOW {workflow_file_path} workflow_alias + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT + """ + ) + ) + + class SomeScript(ErtScript): + def run(self, ensemble: LocalEnsemble): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + + with pytest.raises( + ConfigValidationError, + match=r"Workflow job TEST_SCRIPT.*" + r"expected fixtures.*ensemble.*", + ): + ErtConfig.from_file("config.ert") + + +@pytest.mark.usefixtures("use_tmpdir") +def test_ert_script_hook_pre_experiment_but_asks_for_random_seed(): + workflow_file_path = os.path.join(os.getcwd(), "workflow") + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: + fh.write("TEST_SCRIPT") + + with open("config.ert", mode="w", encoding="utf-8") as fh: + fh.write( + dedent( + f""" + NUM_REALIZATIONS 1 + + LOAD_WORKFLOW {workflow_file_path} workflow_alias + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT + """ + ) + ) + + class SomeScript(ErtScript): + def run(self, random_seed: int): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + + ErtConfig.from_file("config.ert") + + +@pytest.mark.usefixtures("use_tmpdir") +def test_ert_script_hook_pre_experiment_essettings_fails(): + workflow_file_path = os.path.join(os.getcwd(), "workflow") + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: + fh.write("TEST_SCRIPT") + + with open("config.ert", mode="w", encoding="utf-8") as fh: + fh.write( + dedent( + f""" + NUM_REALIZATIONS 1 + + LOAD_WORKFLOW {workflow_file_path} workflow_alias + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT + """ + ) + ) + + class SomeScript(ErtScript): + def run(self, es_settings: ESSettings): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + + with pytest.raises( + ConfigValidationError, + match=r".*It would work in these runtimes: PRE_UPDATE, POST_UPDATE, PRE_FIRST_UPDATE.*", + ): + ErtConfig.from_file("config.ert") + + +@pytest.mark.usefixtures("use_tmpdir") +def test_ert_script_hook_valid_essettings_succeed(): + workflow_file_path = os.path.join(os.getcwd(), "workflow") + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: + fh.write("TEST_SCRIPT") + + with open("config.ert", mode="w", encoding="utf-8") as fh: + fh.write( + dedent( + f""" + NUM_REALIZATIONS 1 + + LOAD_WORKFLOW {workflow_file_path} workflow_alias + HOOK_WORKFLOW workflow_alias PRE_UPDATE + HOOK_WORKFLOW workflow_alias POST_UPDATE + HOOK_WORKFLOW workflow_alias PRE_FIRST_UPDATE + """ + ) + ) + + class SomeScript(ErtScript): + def run(self, es_settings: ESSettings): + pass + + wfjob = ErtScriptWorkflow( + name="TEST_SCRIPT", + ertscript_class=SomeScript, + ) + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} + ErtConfig.from_file("config.ert") + + def test_queue_options_are_joined_after_option_name(): assert ( ErtConfig.from_file_contents( From a93b04b34a478827863b6374fcb85c21cb8b9050 Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Tue, 18 Mar 2025 13:57:20 +0100 Subject: [PATCH 2/5] fixup! Validate workflow job args against hook runtime fixtures address review --- .../ert/unit_tests/config/test_ert_config.py | 69 ++++--------------- 1 file changed, 12 insertions(+), 57 deletions(-) diff --git a/tests/ert/unit_tests/config/test_ert_config.py b/tests/ert/unit_tests/config/test_ert_config.py index 2b899bf18f2..5a3d57b6184 100644 --- a/tests/ert/unit_tests/config/test_ert_config.py +++ b/tests/ert/unit_tests/config/test_ert_config.py @@ -2179,8 +2179,8 @@ def test_run_template_raises_configvalidationerror_with_more_than_two_arguments( ) -@pytest.mark.usefixtures("use_tmpdir") -def test_ert_script_hook_pre_experiment_but_asks_for_storage(): +@pytest.fixture +def setup_workflow_file(): workflow_file_path = os.path.join(os.getcwd(), "workflow") with open(workflow_file_path, mode="w", encoding="utf-8") as fh: fh.write("TEST_SCRIPT") @@ -2189,14 +2189,17 @@ def test_ert_script_hook_pre_experiment_but_asks_for_storage(): fh.write( dedent( f""" - NUM_REALIZATIONS 1 + NUM_REALIZATIONS 1 - LOAD_WORKFLOW {workflow_file_path} workflow_alias - HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT - """ + LOAD_WORKFLOW {workflow_file_path} workflow_alias + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT + """ ) ) + +@pytest.mark.usefixtures("use_tmpdir", "setup_workflow_file") +def test_ert_script_hook_pre_experiment_but_asks_for_storage(): class SomeScript(ErtScript): def run(self, storage: Storage): pass @@ -2215,24 +2218,8 @@ def run(self, storage: Storage): ErtConfig.from_file("config.ert") -@pytest.mark.usefixtures("use_tmpdir") +@pytest.mark.usefixtures("use_tmpdir", "setup_workflow_file") def test_ert_script_hook_pre_experiment_but_asks_for_ensemble(): - workflow_file_path = os.path.join(os.getcwd(), "workflow") - with open(workflow_file_path, mode="w", encoding="utf-8") as fh: - fh.write("TEST_SCRIPT") - - with open("config.ert", mode="w", encoding="utf-8") as fh: - fh.write( - dedent( - f""" - NUM_REALIZATIONS 1 - - LOAD_WORKFLOW {workflow_file_path} workflow_alias - HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT - """ - ) - ) - class SomeScript(ErtScript): def run(self, ensemble: LocalEnsemble): pass @@ -2251,24 +2238,8 @@ def run(self, ensemble: LocalEnsemble): ErtConfig.from_file("config.ert") -@pytest.mark.usefixtures("use_tmpdir") +@pytest.mark.usefixtures("use_tmpdir", "setup_workflow_file") def test_ert_script_hook_pre_experiment_but_asks_for_random_seed(): - workflow_file_path = os.path.join(os.getcwd(), "workflow") - with open(workflow_file_path, mode="w", encoding="utf-8") as fh: - fh.write("TEST_SCRIPT") - - with open("config.ert", mode="w", encoding="utf-8") as fh: - fh.write( - dedent( - f""" - NUM_REALIZATIONS 1 - - LOAD_WORKFLOW {workflow_file_path} workflow_alias - HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT - """ - ) - ) - class SomeScript(ErtScript): def run(self, random_seed: int): pass @@ -2282,24 +2253,8 @@ def run(self, random_seed: int): ErtConfig.from_file("config.ert") -@pytest.mark.usefixtures("use_tmpdir") +@pytest.mark.usefixtures("use_tmpdir", "setup_workflow_file") def test_ert_script_hook_pre_experiment_essettings_fails(): - workflow_file_path = os.path.join(os.getcwd(), "workflow") - with open(workflow_file_path, mode="w", encoding="utf-8") as fh: - fh.write("TEST_SCRIPT") - - with open("config.ert", mode="w", encoding="utf-8") as fh: - fh.write( - dedent( - f""" - NUM_REALIZATIONS 1 - - LOAD_WORKFLOW {workflow_file_path} workflow_alias - HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT - """ - ) - ) - class SomeScript(ErtScript): def run(self, es_settings: ESSettings): pass From 33f99854ca049eb5d2f837d4b34d56132b19025d Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Fri, 21 Mar 2025 08:50:28 +0100 Subject: [PATCH 3/5] tmp --- src/ert/config/parsing/hook_runtime.py | 43 +++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/ert/config/parsing/hook_runtime.py b/src/ert/config/parsing/hook_runtime.py index 32b63b25ba5..3351f96d846 100644 --- a/src/ert/config/parsing/hook_runtime.py +++ b/src/ert/config/parsing/hook_runtime.py @@ -1,4 +1,10 @@ +from dataclasses import dataclass from enum import StrEnum +from typing import Literal + +from ert.config import ESSettings, UpdateSettings +from ert.runpaths import Runpaths +from ert.storage import Ensemble, Storage class HookRuntime(StrEnum): @@ -10,9 +16,44 @@ class HookRuntime(StrEnum): PRE_EXPERIMENT = "PRE_EXPERIMENT" POST_EXPERIMENT = "POST_EXPERIMENT" +@dataclass +class PreExperimentFixtures: + random_seed: int + hook: Literal["pre_experiment"] = HookRuntime.PRE_EXPERIMENT + + +class PostExperimentFixtures(PreExperimentFixtures): + storage: Storage + ensemble: Ensemble + hook: HookRuntime = HookRuntime.POST_EXPERIMENT + +class PreSimulationFixtures(PostExperimentFixtures): + reports_dir: str + run_paths: Runpaths + hook: HookRuntime = HookRuntime.PRE_SIMULATION + +class PostSimulationFixtures(PreSimulationFixtures): + hook: HookRuntime = HookRuntime.POST_SIMULATION + + +class PreFirstUpdateFixtures(PreSimulationFixtures): + es_settings: ESSettings + observation_settings: UpdateSettings + hook: HookRuntime = HookRuntime.PRE_FIRST_UPDATE + + +class PreUpdateFixtures(PreFirstUpdateFixtures): + hook: HookRuntime = HookRuntime.PRE_UPDATE + + +class PostUpdateFixtures(PreFirstUpdateFixtures): + hook: HookRuntime = HookRuntime.POST_UPDATE + + +WorkflowFixtures = PreExperimentFixtures | PostExperimentFixtures | PreSimulationFixtures | PostSimulationFixtures | PreFirstUpdateFixtures | PreUpdateFixtures | PostUpdateFixtures fixtures_per_runtime = { - HookRuntime.PRE_EXPERIMENT: {"random_seed"}, + HookRuntime.PRE_EXPERIMENT: PreExperimentFixtures.__, HookRuntime.PRE_SIMULATION: { "storage", "ensemble", From 054a3568d0ff6e8a1fcb00b76bcf341d40cbb5c2 Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Fri, 21 Mar 2025 11:45:11 +0100 Subject: [PATCH 4/5] tmp --- src/ert/config/ert_config.py | 6 +- src/ert/config/parsing/hook_runtime.py | 92 -------------------------- src/ert/plugins/workflow_fixtures.py | 90 ++++++++++++++++++++++--- 3 files changed, 82 insertions(+), 106 deletions(-) diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index fa430bcfb0b..d901f841780 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -27,6 +27,7 @@ from ert.plugins.workflow_config import ErtScriptWorkflow from ert.substitutions import Substitutions +from ..plugins.workflow_fixtures import get_available_fixtures_for_runtime from ._design_matrix_validator import DesignMatrixValidator from .analysis_config import AnalysisConfig from .ensemble_config import EnsembleConfig @@ -60,7 +61,6 @@ from .parsing import ( parse as parse_config, ) -from .parsing.hook_runtime import fixtures_per_runtime from .parsing.observations_parser import ( GenObsValues, HistoryValues, @@ -397,7 +397,7 @@ def workflows_from_dict( continue wf = workflows[hook_name] - available_fixtures = fixtures_per_runtime[mode] + available_fixtures = get_available_fixtures_for_runtime(hook_name) for job, _ in wf.cmd_list: if job.ert_script is None: continue @@ -413,7 +413,7 @@ def workflows_from_dict( ok_modes = [ m for m in HookRuntime - if not requested_fixtures - fixtures_per_runtime[m] + if not requested_fixtures - get_available_fixtures_for_runtime(m) ] message_start = ( diff --git a/src/ert/config/parsing/hook_runtime.py b/src/ert/config/parsing/hook_runtime.py index 3351f96d846..70c9cc740db 100644 --- a/src/ert/config/parsing/hook_runtime.py +++ b/src/ert/config/parsing/hook_runtime.py @@ -1,10 +1,4 @@ -from dataclasses import dataclass from enum import StrEnum -from typing import Literal - -from ert.config import ESSettings, UpdateSettings -from ert.runpaths import Runpaths -from ert.storage import Ensemble, Storage class HookRuntime(StrEnum): @@ -15,89 +9,3 @@ class HookRuntime(StrEnum): PRE_FIRST_UPDATE = "PRE_FIRST_UPDATE" PRE_EXPERIMENT = "PRE_EXPERIMENT" POST_EXPERIMENT = "POST_EXPERIMENT" - -@dataclass -class PreExperimentFixtures: - random_seed: int - hook: Literal["pre_experiment"] = HookRuntime.PRE_EXPERIMENT - - -class PostExperimentFixtures(PreExperimentFixtures): - storage: Storage - ensemble: Ensemble - hook: HookRuntime = HookRuntime.POST_EXPERIMENT - -class PreSimulationFixtures(PostExperimentFixtures): - reports_dir: str - run_paths: Runpaths - hook: HookRuntime = HookRuntime.PRE_SIMULATION - -class PostSimulationFixtures(PreSimulationFixtures): - hook: HookRuntime = HookRuntime.POST_SIMULATION - - -class PreFirstUpdateFixtures(PreSimulationFixtures): - es_settings: ESSettings - observation_settings: UpdateSettings - hook: HookRuntime = HookRuntime.PRE_FIRST_UPDATE - - -class PreUpdateFixtures(PreFirstUpdateFixtures): - hook: HookRuntime = HookRuntime.PRE_UPDATE - - -class PostUpdateFixtures(PreFirstUpdateFixtures): - hook: HookRuntime = HookRuntime.POST_UPDATE - - -WorkflowFixtures = PreExperimentFixtures | PostExperimentFixtures | PreSimulationFixtures | PostSimulationFixtures | PreFirstUpdateFixtures | PreUpdateFixtures | PostUpdateFixtures - -fixtures_per_runtime = { - HookRuntime.PRE_EXPERIMENT: PreExperimentFixtures.__, - HookRuntime.PRE_SIMULATION: { - "storage", - "ensemble", - "reports_dir", - "random_seed", - "run_paths", - }, - HookRuntime.POST_SIMULATION: { - "storage", - "ensemble", - "reports_dir", - "random_seed", - "run_paths", - }, - HookRuntime.PRE_FIRST_UPDATE: { - "storage", - "ensemble", - "reports_dir", - "random_seed", - "es_settings", - "observation_settings", - "run_paths", - }, - HookRuntime.PRE_UPDATE: { - "storage", - "ensemble", - "reports_dir", - "random_seed", - "es_settings", - "observation_settings", - "run_paths", - }, - HookRuntime.POST_UPDATE: { - "storage", - "ensemble", - "reports_dir", - "random_seed", - "es_settings", - "observation_settings", - "run_paths", - }, - HookRuntime.POST_EXPERIMENT: { - "random_seed", - "storage", - "ensemble", - }, -} diff --git a/src/ert/plugins/workflow_fixtures.py b/src/ert/plugins/workflow_fixtures.py index f73f5b6ce6d..656a4e81882 100644 --- a/src/ert/plugins/workflow_fixtures.py +++ b/src/ert/plugins/workflow_fixtures.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any - -from PyQt6.QtWidgets import QWidget -from typing_extensions import TypedDict +import typing +from typing import TYPE_CHECKING, TypedDict +from ert.config.parsing.hook_runtime import HookRuntime if TYPE_CHECKING: from ert.config import ESSettings, UpdateSettings @@ -11,13 +10,82 @@ from ert.storage import Ensemble, Storage -class WorkflowFixtures(TypedDict, total=False): - ensemble: Ensemble | None +class _BaseWorkflowFixtures(TypedDict): + random_seed: int + + +class _UpdateWorkflowFixtures(TypedDict): + es_settings: ESSettings + observation_settings: UpdateSettings + + +class _StorageWorkflowFixtures(TypedDict): storage: Storage - random_seed: int | None + ensemble: Ensemble + + +class PreExperimentFixtures(_BaseWorkflowFixtures): + random_seed: int + hook: HookRuntime.PRE_EXPERIMENT + + +class PostExperimentFixtures(_BaseWorkflowFixtures, _StorageWorkflowFixtures): + hook: HookRuntime.POST_EXPERIMENT + + +class PreSimulationFixtures(_BaseWorkflowFixtures, _StorageWorkflowFixtures): + hook: HookRuntime.PRE_SIMULATION reports_dir: str - observation_settings: UpdateSettings - es_settings: ESSettings run_paths: Runpaths - workflow_args: list[Any] - parent: QWidget | None + + +class PostSimulationFixtures(PreSimulationFixtures): + hook: HookRuntime.POST_SIMULATION + + +class PreFirstUpdateFixtures( + _BaseWorkflowFixtures, _UpdateWorkflowFixtures, _StorageWorkflowFixtures +): + hook: HookRuntime.PRE_FIRST_UPDATE + reports_dir: str + run_paths: Runpaths + + +class PreUpdateFixtures(PreFirstUpdateFixtures): + hook: HookRuntime.PRE_UPDATE + + +class PostUpdateFixtures(PreFirstUpdateFixtures): + hook: HookRuntime.POST_UPDATE + + +# Union Type Definition +WorkflowFixtures = ( + PreExperimentFixtures + | PostExperimentFixtures + | PreSimulationFixtures + | PostSimulationFixtures + | PreFirstUpdateFixtures + | PreUpdateFixtures + | PostUpdateFixtures +) + + +def matches_hook(cls, hook: HookRuntime): + return cls.__annotations__.get("hook") == hook + + +def available_fixtures(cls): + return set(cls.__annotations__.keys()) - {"hook"} + + +def get_available_fixtures_for_runtime(hook: HookRuntime) -> set[str]: + cls = next( + cls + for cls in typing.get_args(WorkflowFixtures) + if matches_hook(cls, hook) + ) + + ref = cls + + return available_fixtures(cls) From 366ffe1ad95e66dabc1853b2a835b09c9d3add88 Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Fri, 21 Mar 2025 14:57:30 +0100 Subject: [PATCH 5/5] Address review --- src/ert/config/ert_config.py | 7 +- src/ert/gui/tools/plugins/plugin.py | 4 +- src/ert/gui/tools/plugins/plugin_runner.py | 4 +- src/ert/plugins/__init__.py | 10 ++- src/ert/plugins/ert_script.py | 28 +++++-- src/ert/plugins/workflow_fixtures.py | 87 +++++++++++++--------- src/ert/run_models/base_run_model.py | 6 +- src/ert/workflow_runner.py | 6 +- 8 files changed, 92 insertions(+), 60 deletions(-) diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index d901f841780..d0c36b4fbb6 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -23,11 +23,10 @@ from pydantic import field_validator from pydantic.dataclasses import dataclass, rebuild_dataclass -from ert.plugins import ErtPluginManager +from ert.plugins import ErtPluginManager, fixtures_per_hook from ert.plugins.workflow_config import ErtScriptWorkflow from ert.substitutions import Substitutions -from ..plugins.workflow_fixtures import get_available_fixtures_for_runtime from ._design_matrix_validator import DesignMatrixValidator from .analysis_config import AnalysisConfig from .ensemble_config import EnsembleConfig @@ -397,7 +396,7 @@ def workflows_from_dict( continue wf = workflows[hook_name] - available_fixtures = get_available_fixtures_for_runtime(hook_name) + available_fixtures = fixtures_per_hook[mode] for job, _ in wf.cmd_list: if job.ert_script is None: continue @@ -413,7 +412,7 @@ def workflows_from_dict( ok_modes = [ m for m in HookRuntime - if not requested_fixtures - get_available_fixtures_for_runtime(m) + if not requested_fixtures - fixtures_per_hook[m] ] message_start = ( diff --git a/src/ert/gui/tools/plugins/plugin.py b/src/ert/gui/tools/plugins/plugin.py index da66e20d5e1..29624b31010 100644 --- a/src/ert/gui/tools/plugins/plugin.py +++ b/src/ert/gui/tools/plugins/plugin.py @@ -3,7 +3,7 @@ import inspect from typing import TYPE_CHECKING, Any -from ert.plugins import ErtPlugin, WorkflowFixtures +from ert.plugins import ErtPlugin, HookedWorkflowFixtures if TYPE_CHECKING: from PyQt6.QtWidgets import QWidget @@ -34,7 +34,7 @@ def getName(self) -> str: def getDescription(self) -> str: return self.__description - def getArguments(self, fixtures: WorkflowFixtures) -> list[Any]: + def getArguments(self, fixtures: HookedWorkflowFixtures) -> list[Any]: """ Returns a list of arguments. Either from GUI or from arbitrary code. If the user for example cancels in the GUI a CancelPluginException is raised. diff --git a/src/ert/gui/tools/plugins/plugin_runner.py b/src/ert/gui/tools/plugins/plugin_runner.py index 39a21acec26..f26cfcc1e55 100644 --- a/src/ert/gui/tools/plugins/plugin_runner.py +++ b/src/ert/gui/tools/plugins/plugin_runner.py @@ -6,7 +6,7 @@ from _ert.threading import ErtThread from ert.config import ErtConfig -from ert.plugins import CancelPluginException, WorkflowFixtures +from ert.plugins import CancelPluginException, HookedWorkflowFixtures from ert.runpaths import Runpaths from ert.workflow_runner import WorkflowJobRunner @@ -84,7 +84,7 @@ def run(self) -> None: print("Plugin cancelled before execution!") def __runWorkflowJob( - self, arguments: list[Any] | None, fixtures: WorkflowFixtures + self, arguments: list[Any] | None, fixtures: HookedWorkflowFixtures ) -> None: self.__result = self._runner.run(arguments, fixtures=fixtures) diff --git a/src/ert/plugins/__init__.py b/src/ert/plugins/__init__.py index bbf603bf6f0..8ce25276b97 100644 --- a/src/ert/plugins/__init__.py +++ b/src/ert/plugins/__init__.py @@ -13,7 +13,12 @@ ) from .plugin_response import PluginMetadata, PluginResponse from .workflow_config import ErtScriptWorkflow, WorkflowConfigs -from .workflow_fixtures import WorkflowFixtures +from .workflow_fixtures import ( + HookedWorkflowFixtures, + WorkflowFixtures, + all_hooked_workflow_fixtures, + fixtures_per_hook, +) P = ParamSpec("P") @@ -59,8 +64,11 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> Any: "ErtScript", "ErtScriptWorkflow", "ExternalErtScript", + "HookedWorkflowFixtures", "JobDoc", "WorkflowConfigs", "WorkflowFixtures", + "all_hooked_workflow_fixtures", + "fixtures_per_hook", "plugin", ] diff --git a/src/ert/plugins/ert_script.py b/src/ert/plugins/ert_script.py index d819ebf1208..8c59b673f85 100644 --- a/src/ert/plugins/ert_script.py +++ b/src/ert/plugins/ert_script.py @@ -9,11 +9,15 @@ from abc import abstractmethod from collections.abc import Callable from types import MappingProxyType, ModuleType -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import TYPE_CHECKING, Any, TypeAlias, cast from typing_extensions import deprecated -from .workflow_fixtures import WorkflowFixtures +from .workflow_fixtures import ( + HookedWorkflowFixtures, + WorkflowFixtures, + all_hooked_workflow_fixtures, +) if TYPE_CHECKING: from ert.config import ErtConfig @@ -112,16 +116,23 @@ def requested_fixtures(self) -> set[str]: return { k for k, v in inspect.signature(self.run).parameters.items() - if k in WorkflowFixtures.__annotations__ and k != "workflow_args" + if k in all_hooked_workflow_fixtures } def initializeAndRun( self, argument_types: list[type[Any]], argument_values: list[str], - fixtures: WorkflowFixtures | None = None, + fixtures: WorkflowFixtures | HookedWorkflowFixtures | None = None, ) -> Any: - fixtures = {} if fixtures is None else fixtures + fixtures_without_hook = {**(fixtures or {})} + + if "hook" in fixtures_without_hook: + fixtures_without_hook.pop("hook") + + complete_fixtures: WorkflowFixtures = cast( + WorkflowFixtures, fixtures_without_hook + ) arguments = [] for index, arg_value in enumerate(argument_values): arg_type = argument_types[index] if index < len(argument_types) else str @@ -130,14 +141,15 @@ def initializeAndRun( arguments.append(arg_type(arg_value)) else: arguments.append(None) - fixtures["workflow_args"] = arguments + + complete_fixtures["workflow_args"] = arguments try: func_args = inspect.signature(self.run).parameters - # If the user has specified *args, we skip injecting fixtures, and just + # If the user has specified *args, we skip injecting complete_fixtures, and just # pass the user configured arguments if not any(p.kind == p.VAR_POSITIONAL for p in func_args.values()): try: - arguments = self.insert_fixtures(func_args, fixtures) + arguments = self.insert_fixtures(func_args, complete_fixtures) except ValueError as e: # This is here for backwards compatibility, the user does not have *argv # but positional arguments. Can not be mixed with using fixtures. diff --git a/src/ert/plugins/workflow_fixtures.py b/src/ert/plugins/workflow_fixtures.py index 656a4e81882..58597ed9d6b 100644 --- a/src/ert/plugins/workflow_fixtures.py +++ b/src/ert/plugins/workflow_fixtures.py @@ -1,7 +1,11 @@ from __future__ import annotations +import functools import typing -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING, NotRequired, TypedDict + +from PyQt6.QtWidgets import QWidget + from ert.config.parsing.hook_runtime import HookRuntime if TYPE_CHECKING: @@ -10,57 +14,64 @@ from ert.storage import Ensemble, Storage -class _BaseWorkflowFixtures(TypedDict): +class PreExperimentFixtures(TypedDict): random_seed: int + hook: HookRuntime = HookRuntime.PRE_EXPERIMENT -class _UpdateWorkflowFixtures(TypedDict): - es_settings: ESSettings - observation_settings: UpdateSettings - - -class _StorageWorkflowFixtures(TypedDict): +class PostExperimentFixtures(TypedDict): + hook: HookRuntime = HookRuntime.POST_EXPERIMENT + random_seed: int storage: Storage ensemble: Ensemble -class PreExperimentFixtures(_BaseWorkflowFixtures): +class PreSimulationFixtures(TypedDict): + hook: HookRuntime = HookRuntime.PRE_SIMULATION random_seed: int - hook: HookRuntime.PRE_EXPERIMENT - - -class PostExperimentFixtures(_BaseWorkflowFixtures, _StorageWorkflowFixtures): - hook: HookRuntime.POST_EXPERIMENT - - -class PreSimulationFixtures(_BaseWorkflowFixtures, _StorageWorkflowFixtures): - hook: HookRuntime.PRE_SIMULATION reports_dir: str run_paths: Runpaths + storage: Storage + ensemble: Ensemble class PostSimulationFixtures(PreSimulationFixtures): - hook: HookRuntime.POST_SIMULATION + hook: HookRuntime = HookRuntime.POST_SIMULATION -class PreFirstUpdateFixtures( - _BaseWorkflowFixtures, _UpdateWorkflowFixtures, _StorageWorkflowFixtures -): - hook: HookRuntime.PRE_FIRST_UPDATE +class PreFirstUpdateFixtures(TypedDict): + hook: HookRuntime = HookRuntime.PRE_FIRST_UPDATE + random_seed: int reports_dir: str run_paths: Runpaths + storage: Storage + ensemble: Ensemble + es_settings: ESSettings + observation_settings: UpdateSettings class PreUpdateFixtures(PreFirstUpdateFixtures): - hook: HookRuntime.PRE_UPDATE + hook: HookRuntime = HookRuntime.PRE_UPDATE class PostUpdateFixtures(PreFirstUpdateFixtures): - hook: HookRuntime.POST_UPDATE + hook: HookRuntime = HookRuntime.POST_UPDATE + + +class WorkflowFixtures(TypedDict, total=False): + workflow_args: NotRequired[list[typing.Any]] + parent: NotRequired[QWidget] + random_seed: NotRequired[int] + reports_dir: NotRequired[str] + run_paths: NotRequired[Runpaths] + storage: NotRequired[Storage] + ensemble: NotRequired[Ensemble] + es_settings: NotRequired[ESSettings] + observation_settings: NotRequired[UpdateSettings] # Union Type Definition -WorkflowFixtures = ( +HookedWorkflowFixtures = ( PreExperimentFixtures | PostExperimentFixtures | PreSimulationFixtures @@ -71,21 +82,23 @@ class PostUpdateFixtures(PreFirstUpdateFixtures): ) -def matches_hook(cls, hook: HookRuntime): - return cls.__annotations__.get("hook") == hook - +def __all_workflow_fixtures() -> set[str]: + fixtures_per_runtime = ( + __get_available_fixtures_for_hook(hook) for hook in HookRuntime + ) -def available_fixtures(cls): - return set(cls.__annotations__.keys()) - {"hook"} + return functools.reduce(lambda a, b: a | b, fixtures_per_runtime) -def get_available_fixtures_for_runtime(hook: HookRuntime) -> set[str]: +def __get_available_fixtures_for_hook(hook: HookRuntime) -> set[str]: cls = next( - cls - for cls in typing.get_args(WorkflowFixtures) - if matches_hook(cls, hook) + cls for cls in typing.get_args(HookedWorkflowFixtures) if cls.hook == hook ) - ref = cls + return set(cls.__annotations__.keys()) - {"hook"} + - return available_fixtures(cls) +all_hooked_workflow_fixtures = __all_workflow_fixtures() +fixtures_per_hook = { + hook: __get_available_fixtures_for_hook(hook) for hook in HookRuntime +} diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index f7fa510ad18..b11805fdfe5 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -55,7 +55,7 @@ REALIZATION_STATE_FINISHED, ) from ert.mode_definitions import MODULE_MODE -from ert.plugins import WorkflowFixtures +from ert.plugins import HookedWorkflowFixtures from ert.runpaths import Runpaths from ert.storage import Ensemble, Storage from ert.substitutions import Substitutions @@ -737,7 +737,7 @@ def validate_successful_realizations_count(self) -> None: def run_workflows( self, runtime: HookRuntime, - fixtures: WorkflowFixtures, + fixtures: HookedWorkflowFixtures, ) -> None: for workflow in self._hooked_workflows[runtime]: WorkflowRunner(workflow=workflow, fixtures=fixtures).run_blocking() @@ -883,7 +883,7 @@ def update( ) ) - workflow_fixtures: WorkflowFixtures = { + workflow_fixtures: HookedWorkflowFixtures = { "storage": self._storage, "ensemble": prior, "observation_settings": self._update_settings, diff --git a/src/ert/workflow_runner.py b/src/ert/workflow_runner.py index c466f89378b..c6f69d39281 100644 --- a/src/ert/workflow_runner.py +++ b/src/ert/workflow_runner.py @@ -6,7 +6,7 @@ from typing import Any, Self from ert.config import Workflow, WorkflowJob -from ert.plugins import ErtScript, ExternalErtScript, WorkflowFixtures +from ert.plugins import ErtScript, ExternalErtScript, HookedWorkflowFixtures class WorkflowJobRunner: @@ -19,7 +19,7 @@ def __init__(self, workflow_job: WorkflowJob): def run( self, arguments: list[Any] | None = None, - fixtures: WorkflowFixtures | None = None, + fixtures: HookedWorkflowFixtures | None = None, ) -> Any: if arguments is None: arguments = [] @@ -101,7 +101,7 @@ class WorkflowRunner: def __init__( self, workflow: Workflow, - fixtures: WorkflowFixtures, + fixtures: HookedWorkflowFixtures, ) -> None: self.__workflow = workflow self.fixtures = fixtures