Skip to content

Commit e161baa

Browse files
committed
Validate workflow job args against hook runtime fixtures
1 parent 92834bf commit e161baa

File tree

4 files changed

+266
-1
lines changed

4 files changed

+266
-1
lines changed

src/ert/config/ert_config.py

+39
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from .parsing import (
6161
parse as parse_config,
6262
)
63+
from .parsing.hook_runtime import fixtures_per_runtime
6364
from .parsing.observations_parser import (
6465
GenObsValues,
6566
HistoryValues,
@@ -395,10 +396,48 @@ def workflows_from_dict(
395396
)
396397
continue
397398

399+
wf = workflows[hook_name]
400+
available_fixtures = fixtures_per_runtime[mode]
401+
for job, _ in wf.cmd_list:
402+
if job.ert_script is None:
403+
continue
404+
405+
ert_script_instance = job.ert_script()
406+
requested_fixtures = ert_script_instance.requested_fixtures
407+
408+
# Look for requested fixtures that are not available for the given
409+
# mode
410+
missing_fixtures = set(requested_fixtures) - available_fixtures
411+
412+
if missing_fixtures:
413+
ok_modes = [
414+
m
415+
for m in HookRuntime
416+
if not set(requested_fixtures) - fixtures_per_runtime[m]
417+
]
418+
419+
message_start = (
420+
f"Workflow job {job.name} .run function expected "
421+
f"fixtures: {missing_fixtures}, which are not available "
422+
f"in the fixtures for the runtime {mode}: {available_fixtures}. "
423+
)
424+
message_end = (
425+
f"It would work in these runtimes: {', '.join(map(str, ok_modes))}"
426+
if len(ok_modes) > 0
427+
else "This fixture is not available in any of the runtimes."
428+
)
429+
430+
errors.append(
431+
ErrorInfo(message=message_start + message_end).set_context(
432+
hook_name
433+
)
434+
)
435+
398436
hooked_workflows[mode].append(workflows[hook_name])
399437

400438
if errors:
401439
raise ConfigValidationError.from_collected(errors)
440+
402441
return workflow_jobs, workflows, hooked_workflows
403442

404443

src/ert/config/parsing/hook_runtime.py

+51
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,54 @@ class HookRuntime(StrEnum):
99
PRE_FIRST_UPDATE = "PRE_FIRST_UPDATE"
1010
PRE_EXPERIMENT = "PRE_EXPERIMENT"
1111
POST_EXPERIMENT = "POST_EXPERIMENT"
12+
13+
14+
fixtures_per_runtime = {
15+
HookRuntime.PRE_EXPERIMENT: {"random_seed"},
16+
HookRuntime.PRE_SIMULATION: {
17+
"storage",
18+
"ensemble",
19+
"reports_dir",
20+
"random_seed",
21+
"run_paths",
22+
},
23+
HookRuntime.POST_SIMULATION: {
24+
"storage",
25+
"ensemble",
26+
"reports_dir",
27+
"random_seed",
28+
"run_paths",
29+
},
30+
HookRuntime.PRE_FIRST_UPDATE: {
31+
"storage",
32+
"ensemble",
33+
"reports_dir",
34+
"random_seed",
35+
"es_settings",
36+
"observation_settings",
37+
"run_paths",
38+
},
39+
HookRuntime.PRE_UPDATE: {
40+
"storage",
41+
"ensemble",
42+
"reports_dir",
43+
"random_seed",
44+
"es_settings",
45+
"observation_settings",
46+
"run_paths",
47+
},
48+
HookRuntime.POST_UPDATE: {
49+
"storage",
50+
"ensemble",
51+
"reports_dir",
52+
"random_seed",
53+
"es_settings",
54+
"observation_settings",
55+
"run_paths",
56+
},
57+
HookRuntime.POST_EXPERIMENT: {
58+
"random_seed",
59+
"storage",
60+
"ensemble",
61+
},
62+
}

src/ert/plugins/ert_script.py

+8
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ def cleanup(self) -> None:
107107
Override to perform cleanup after a run.
108108
"""
109109

110+
@property
111+
def requested_fixtures(self):
112+
return {
113+
k: v
114+
for k, v in inspect.signature(self.run).parameters.items()
115+
if k in WorkflowFixtures.__annotations__
116+
}
117+
110118
def initializeAndRun(
111119
self,
112120
argument_types: list[type[Any]],

tests/ert/unit_tests/config/test_ert_config.py

+168-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from hypothesis import strategies as st
1616
from pydantic import RootModel, TypeAdapter
1717

18-
from ert.config import ConfigValidationError, ErtConfig, HookRuntime
18+
from ert import ErtScript, ErtScriptWorkflow
19+
from ert.config import ConfigValidationError, ErtConfig, ESSettings, HookRuntime
1920
from ert.config.ert_config import _split_string_into_sections, create_forward_model_json
2021
from ert.config.parsing import ConfigKeys, ConfigWarning
2122
from ert.config.parsing.context_values import (
@@ -29,6 +30,7 @@
2930
from ert.config.parsing.queue_system import QueueSystem
3031
from ert.plugins import ErtPluginManager
3132
from ert.shared import ert_share_path
33+
from ert.storage import LocalEnsemble, Storage
3234

3335
from .config_dict_generator import config_generators
3436

@@ -1991,3 +1993,168 @@ def run(self, *args):
19911993

19921994
assert ert_config.substitutions["<FOO>"] == "ertconfig_foo"
19931995
assert ert_config.substitutions["<FOO2>"] == "ertconfig_foo2"
1996+
1997+
1998+
def test_ert_script_hook_pre_experiment_but_asks_for_storage():
1999+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2000+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2001+
fh.write("TEST_SCRIPT")
2002+
2003+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2004+
fh.write(
2005+
dedent(
2006+
f"""
2007+
NUM_REALIZATIONS 1
2008+
2009+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2010+
HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT
2011+
"""
2012+
)
2013+
)
2014+
2015+
class SomeScript(ErtScript):
2016+
def run(self, storage: Storage):
2017+
pass
2018+
2019+
wfjob = ErtScriptWorkflow(
2020+
name="TEST_SCRIPT",
2021+
ertscript_class=SomeScript,
2022+
)
2023+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2024+
2025+
with pytest.raises(
2026+
ConfigValidationError,
2027+
match=r"Workflow job TEST_SCRIPT.*"
2028+
r"expected fixtures.*storage.*",
2029+
):
2030+
ErtConfig.from_file("config.ert")
2031+
2032+
2033+
def test_ert_script_hook_pre_experiment_but_asks_for_ensemble():
2034+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2035+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2036+
fh.write("TEST_SCRIPT")
2037+
2038+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2039+
fh.write(
2040+
dedent(
2041+
f"""
2042+
NUM_REALIZATIONS 1
2043+
2044+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2045+
HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT
2046+
"""
2047+
)
2048+
)
2049+
2050+
class SomeScript(ErtScript):
2051+
def run(self, ensemble: LocalEnsemble):
2052+
pass
2053+
2054+
wfjob = ErtScriptWorkflow(
2055+
name="TEST_SCRIPT",
2056+
ertscript_class=SomeScript,
2057+
)
2058+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2059+
2060+
with pytest.raises(
2061+
ConfigValidationError,
2062+
match=r"Workflow job TEST_SCRIPT.*"
2063+
r"expected fixtures.*ensemble.*",
2064+
):
2065+
ErtConfig.from_file("config.ert")
2066+
2067+
2068+
def test_ert_script_hook_pre_experiment_but_asks_for_random_seed():
2069+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2070+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2071+
fh.write("TEST_SCRIPT")
2072+
2073+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2074+
fh.write(
2075+
dedent(
2076+
f"""
2077+
NUM_REALIZATIONS 1
2078+
2079+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2080+
HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT
2081+
"""
2082+
)
2083+
)
2084+
2085+
class SomeScript(ErtScript):
2086+
def run(self, random_seed: int):
2087+
pass
2088+
2089+
wfjob = ErtScriptWorkflow(
2090+
name="TEST_SCRIPT",
2091+
ertscript_class=SomeScript,
2092+
)
2093+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2094+
2095+
ErtConfig.from_file("config.ert")
2096+
2097+
2098+
def test_ert_script_hook_pre_experiment_essettings_fails():
2099+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2100+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2101+
fh.write("TEST_SCRIPT")
2102+
2103+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2104+
fh.write(
2105+
dedent(
2106+
f"""
2107+
NUM_REALIZATIONS 1
2108+
2109+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2110+
HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT
2111+
"""
2112+
)
2113+
)
2114+
2115+
class SomeScript(ErtScript):
2116+
def run(self, es_settings: ESSettings):
2117+
pass
2118+
2119+
wfjob = ErtScriptWorkflow(
2120+
name="TEST_SCRIPT",
2121+
ertscript_class=SomeScript,
2122+
)
2123+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2124+
2125+
with pytest.raises(
2126+
ConfigValidationError,
2127+
match=r".*It would work in these runtimes: PRE_UPDATE, POST_UPDATE, PRE_FIRST_UPDATE.*",
2128+
):
2129+
ErtConfig.from_file("config.ert")
2130+
2131+
2132+
def test_ert_script_hook_valid_essettings_succeed():
2133+
workflow_file_path = os.path.join(os.getcwd(), "workflow")
2134+
with open(workflow_file_path, mode="w", encoding="utf-8") as fh:
2135+
fh.write("TEST_SCRIPT")
2136+
2137+
with open("config.ert", mode="w", encoding="utf-8") as fh:
2138+
fh.write(
2139+
dedent(
2140+
f"""
2141+
NUM_REALIZATIONS 1
2142+
2143+
LOAD_WORKFLOW {workflow_file_path} workflow_alias
2144+
HOOK_WORKFLOW workflow_alias PRE_UPDATE
2145+
HOOK_WORKFLOW workflow_alias POST_UPDATE
2146+
HOOK_WORKFLOW workflow_alias PRE_FIRST_UPDATE
2147+
"""
2148+
)
2149+
)
2150+
2151+
class SomeScript(ErtScript):
2152+
def run(self, es_settings: ESSettings):
2153+
pass
2154+
2155+
wfjob = ErtScriptWorkflow(
2156+
name="TEST_SCRIPT",
2157+
ertscript_class=SomeScript,
2158+
)
2159+
ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob}
2160+
ErtConfig.from_file("config.ert")

0 commit comments

Comments
 (0)