Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b157fcc

Browse files
committedMar 18, 2025·
Validate workflow job args against hook runtime fixtures
1 parent 92834bf commit b157fcc

File tree

4 files changed

+271
-1
lines changed

4 files changed

+271
-1
lines changed
 

‎src/ert/config/ert_config.py

+39
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from .parsing import (
6161
parse as parse_config,
6262
)
63+
from .parsing.hook_runtime import fixtures_per_runtime
6364
from .parsing.observations_parser import (
6465
GenObsValues,
6566
HistoryValues,
@@ -395,10 +396,48 @@ def workflows_from_dict(
395396
)
396397
continue
397398

399+
wf = workflows[hook_name]
400+
available_fixtures = fixtures_per_runtime[mode]
401+
for job, _ in wf.cmd_list:
402+
if job.ert_script is None:
403+
continue
404+
405+
ert_script_instance = job.ert_script()
406+
requested_fixtures = ert_script_instance.requested_fixtures
407+
408+
# Look for requested fixtures that are not available for the given
409+
# mode
410+
missing_fixtures = requested_fixtures - available_fixtures
411+
412+
if missing_fixtures:
413+
ok_modes = [
414+
m
415+
for m in HookRuntime
416+
if not requested_fixtures - fixtures_per_runtime[m]
417+
]
418+
419+
message_start = (
420+
f"Workflow job {job.name} .run function expected "
421+
f"fixtures: {missing_fixtures}, which are not available "
422+
f"in the fixtures for the runtime {mode}: {available_fixtures}. "
423+
)
424+
message_end = (
425+
f"It would work in these runtimes: {', '.join(map(str, ok_modes))}"
426+
if len(ok_modes) > 0
427+
else "This fixture is not available in any of the runtimes."
428+
)
429+
430+
errors.append(
431+
ErrorInfo(message=message_start + message_end).set_context(
432+
hook_name
433+
)
434+
)
435+
398436
hooked_workflows[mode].append(workflows[hook_name])
399437

400438
if errors:
401439
raise ConfigValidationError.from_collected(errors)
440+
402441
return workflow_jobs, workflows, hooked_workflows
403442

404443

‎src/ert/config/parsing/hook_runtime.py

+51
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,54 @@ class HookRuntime(StrEnum):
99
PRE_FIRST_UPDATE = "PRE_FIRST_UPDATE"
1010
PRE_EXPERIMENT = "PRE_EXPERIMENT"
1111
POST_EXPERIMENT = "POST_EXPERIMENT"
12+
13+
14+
fixtures_per_runtime = {
15+
HookRuntime.PRE_EXPERIMENT: {"random_seed"},
16+
HookRuntime.PRE_SIMULATION: {
17+
"storage",
18+
"ensemble",
19+
"reports_dir",
20+
"random_seed",
21+
"run_paths",
22+
},
23+
HookRuntime.POST_SIMULATION: {
24+
"storage",
25+
"ensemble",
26+
"reports_dir",
27+
"random_seed",
28+
"run_paths",
29+
},
30+
HookRuntime.PRE_FIRST_UPDATE: {
31+
"storage",
32+
"ensemble",
33+
"reports_dir",
34+
"random_seed",
35+
"es_settings",
36+
"observation_settings",
37+
"run_paths",
38+
},
39+
HookRuntime.PRE_UPDATE: {
40+
"storage",
41+
"ensemble",
42+
"reports_dir",
43+
"random_seed",
44+
"es_settings",
45+
"observation_settings",
46+
"run_paths",
47+
},
48+
HookRuntime.POST_UPDATE: {
49+
"storage",
50+
"ensemble",
51+
"reports_dir",
52+
"random_seed",
53+
"es_settings",
54+
"observation_settings",
55+
"run_paths",
56+
},
57+
HookRuntime.POST_EXPERIMENT: {
58+
"random_seed",
59+
"storage",
60+
"ensemble",
61+
},
62+
}

‎src/ert/plugins/ert_script.py

+8
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ def cleanup(self) -> None:
107107
Override to perform cleanup after a run.
108108
"""
109109

110+
@property
111+
def requested_fixtures(self) -> set[str]:
112+
return {
113+
k
114+
for k, v in inspect.signature(self.run).parameters.items()
115+
if k in WorkflowFixtures.__annotations__ and k != "workflow_args"
116+
}
117+
110118
def initializeAndRun(
111119
self,
112120
argument_types: list[type[Any]],

‎tests/ert/unit_tests/config/test_ert_config.py

+173-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from hypothesis import strategies as st
1616
from pydantic import RootModel, TypeAdapter
1717

18-
from ert.config import ConfigValidationError, ErtConfig, HookRuntime
18+
from ert import ErtScript, ErtScriptWorkflow
19+
from ert.config import ConfigValidationError, ErtConfig, ESSettings, HookRuntime
1920
from ert.config.ert_config import _split_string_into_sections, create_forward_model_json
2021
from ert.config.parsing import ConfigKeys, ConfigWarning
2122
from ert.config.parsing.context_values import (
@@ -29,6 +30,7 @@
2930
from ert.config.parsing.queue_system import QueueSystem
3031
from ert.plugins import ErtPluginManager
3132
from ert.shared import ert_share_path
33+
from ert.storage import LocalEnsemble, Storage
3234

3335
from .config_dict_generator import config_generators
3436

@@ -1991,3 +1993,173 @@ def run(self, *args):
19911993

19921994
assert ert_config.substitutions["<FOO>"] == "ertconfig_foo"
19931995
assert ert_config.substitutions["<FOO2>"] == "ertconfig_foo2"
1996+
1997+
1998+
@pytest.mark.usefixtures("use_tmpdir")
1999+
def test_ert_script_hook_pre_experiment_but_asks_for_storage():
2000+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2001+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2002+
fh.write("TEST_SCRIPT")
2003+
2004+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2005+
fh.write(
2006+
dedent(
2007+
f"""
2008+
NUM_REALIZATIONS 1
2009+
2010+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2011+
HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT
2012+
"""
2013+
)
2014+
)
2015+
2016+
class SomeScript(ErtScript):
2017+
def run(self, storage: Storage):
2018+
pass
2019+
2020+
wfjob = ErtScriptWorkflow(
2021+
name="TEST_SCRIPT",
2022+
ertscript_class=SomeScript,
2023+
)
2024+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2025+
2026+
with pytest.raises(
2027+
ConfigValidationError,
2028+
match=r"Workflow job TEST_SCRIPT.*"
2029+
r"expected fixtures.*storage.*",
2030+
):
2031+
ErtConfig.from_file("config.ert")
2032+
2033+
2034+
@pytest.mark.usefixtures("use_tmpdir")
2035+
def test_ert_script_hook_pre_experiment_but_asks_for_ensemble():
2036+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2037+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2038+
fh.write("TEST_SCRIPT")
2039+
2040+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2041+
fh.write(
2042+
dedent(
2043+
f"""
2044+
NUM_REALIZATIONS 1
2045+
2046+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2047+
HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT
2048+
"""
2049+
)
2050+
)
2051+
2052+
class SomeScript(ErtScript):
2053+
def run(self, ensemble: LocalEnsemble):
2054+
pass
2055+
2056+
wfjob = ErtScriptWorkflow(
2057+
name="TEST_SCRIPT",
2058+
ertscript_class=SomeScript,
2059+
)
2060+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2061+
2062+
with pytest.raises(
2063+
ConfigValidationError,
2064+
match=r"Workflow job TEST_SCRIPT.*"
2065+
r"expected fixtures.*ensemble.*",
2066+
):
2067+
ErtConfig.from_file("config.ert")
2068+
2069+
2070+
@pytest.mark.usefixtures("use_tmpdir")
2071+
def test_ert_script_hook_pre_experiment_but_asks_for_random_seed():
2072+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2073+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2074+
fh.write("TEST_SCRIPT")
2075+
2076+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2077+
fh.write(
2078+
dedent(
2079+
f"""
2080+
NUM_REALIZATIONS 1
2081+
2082+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2083+
HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT
2084+
"""
2085+
)
2086+
)
2087+
2088+
class SomeScript(ErtScript):
2089+
def run(self, random_seed: int):
2090+
pass
2091+
2092+
wfjob = ErtScriptWorkflow(
2093+
name="TEST_SCRIPT",
2094+
ertscript_class=SomeScript,
2095+
)
2096+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2097+
2098+
ErtConfig.from_file("config.ert")
2099+
2100+
2101+
@pytest.mark.usefixtures("use_tmpdir")
2102+
def test_ert_script_hook_pre_experiment_essettings_fails():
2103+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2104+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2105+
fh.write("TEST_SCRIPT")
2106+
2107+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2108+
fh.write(
2109+
dedent(
2110+
f"""
2111+
NUM_REALIZATIONS 1
2112+
2113+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2114+
HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT
2115+
"""
2116+
)
2117+
)
2118+
2119+
class SomeScript(ErtScript):
2120+
def run(self, es_settings: ESSettings):
2121+
pass
2122+
2123+
wfjob = ErtScriptWorkflow(
2124+
name="TEST_SCRIPT",
2125+
ertscript_class=SomeScript,
2126+
)
2127+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2128+
2129+
with pytest.raises(
2130+
ConfigValidationError,
2131+
match=r".*It would work in these runtimes: PRE_UPDATE, POST_UPDATE, PRE_FIRST_UPDATE.*",
2132+
):
2133+
ErtConfig.from_file("config.ert")
2134+
2135+
2136+
@pytest.mark.usefixtures("use_tmpdir")
2137+
def test_ert_script_hook_valid_essettings_succeed():
2138+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2139+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2140+
fh.write("TEST_SCRIPT")
2141+
2142+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2143+
fh.write(
2144+
dedent(
2145+
f"""
2146+
NUM_REALIZATIONS 1
2147+
2148+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2149+
HOOK_WORKFLOW workflow_alias PRE_UPDATE
2150+
HOOK_WORKFLOW workflow_alias POST_UPDATE
2151+
HOOK_WORKFLOW workflow_alias PRE_FIRST_UPDATE
2152+
"""
2153+
)
2154+
)
2155+
2156+
class SomeScript(ErtScript):
2157+
def run(self, es_settings: ESSettings):
2158+
pass
2159+
2160+
wfjob = ErtScriptWorkflow(
2161+
name="TEST_SCRIPT",
2162+
ertscript_class=SomeScript,
2163+
)
2164+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2165+
ErtConfig.from_file("config.ert")

0 commit comments

Comments
 (0)
Please sign in to comment.