|
15 | 15 | from hypothesis import strategies as st
|
16 | 16 | from pydantic import RootModel, TypeAdapter
|
17 | 17 |
|
18 |
| -from ert.config import ConfigValidationError, ErtConfig, HookRuntime |
| 18 | +from ert import ErtScript, ErtScriptWorkflow |
| 19 | +from ert.config import ConfigValidationError, ErtConfig, ESSettings, HookRuntime |
19 | 20 | from ert.config.ert_config import _split_string_into_sections, create_forward_model_json
|
20 | 21 | from ert.config.parsing import ConfigKeys, ConfigWarning
|
21 | 22 | from ert.config.parsing.context_values import (
|
|
29 | 30 | from ert.config.parsing.queue_system import QueueSystem
|
30 | 31 | from ert.plugins import ErtPluginManager
|
31 | 32 | from ert.shared import ert_share_path
|
| 33 | +from ert.storage import LocalEnsemble, Storage |
32 | 34 |
|
33 | 35 | from .config_dict_generator import config_generators
|
34 | 36 |
|
@@ -1991,3 +1993,168 @@ def run(self, *args):
|
1991 | 1993 |
|
1992 | 1994 | assert ert_config.substitutions["<FOO>"] == "ertconfig_foo"
|
1993 | 1995 | assert ert_config.substitutions["<FOO2>"] == "ertconfig_foo2"
|
| 1996 | + |
| 1997 | + |
| 1998 | +def test_ert_script_hook_pre_experiment_but_asks_for_storage(): |
| 1999 | + workflow_file_path = os.path.join(os.getcwd(), "workflow") |
| 2000 | + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: |
| 2001 | + fh.write("TEST_SCRIPT") |
| 2002 | + |
| 2003 | + with open("config.ert", mode="w", encoding="utf-8") as fh: |
| 2004 | + fh.write( |
| 2005 | + dedent( |
| 2006 | + f""" |
| 2007 | + NUM_REALIZATIONS 1 |
| 2008 | +
|
| 2009 | + LOAD_WORKFLOW {workflow_file_path} workflow_alias |
| 2010 | + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT |
| 2011 | + """ |
| 2012 | + ) |
| 2013 | + ) |
| 2014 | + |
| 2015 | + class SomeScript(ErtScript): |
| 2016 | + def run(self, storage: Storage): |
| 2017 | + pass |
| 2018 | + |
| 2019 | + wfjob = ErtScriptWorkflow( |
| 2020 | + name="TEST_SCRIPT", |
| 2021 | + ertscript_class=SomeScript, |
| 2022 | + ) |
| 2023 | + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} |
| 2024 | + |
| 2025 | + with pytest.raises( |
| 2026 | + ConfigValidationError, |
| 2027 | + match=r"Workflow job TEST_SCRIPT.*" |
| 2028 | + r"expected fixtures.*storage.*", |
| 2029 | + ): |
| 2030 | + ErtConfig.from_file("config.ert") |
| 2031 | + |
| 2032 | + |
| 2033 | +def test_ert_script_hook_pre_experiment_but_asks_for_ensemble(): |
| 2034 | + workflow_file_path = os.path.join(os.getcwd(), "workflow") |
| 2035 | + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: |
| 2036 | + fh.write("TEST_SCRIPT") |
| 2037 | + |
| 2038 | + with open("config.ert", mode="w", encoding="utf-8") as fh: |
| 2039 | + fh.write( |
| 2040 | + dedent( |
| 2041 | + f""" |
| 2042 | + NUM_REALIZATIONS 1 |
| 2043 | +
|
| 2044 | + LOAD_WORKFLOW {workflow_file_path} workflow_alias |
| 2045 | + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT |
| 2046 | + """ |
| 2047 | + ) |
| 2048 | + ) |
| 2049 | + |
| 2050 | + class SomeScript(ErtScript): |
| 2051 | + def run(self, ensemble: LocalEnsemble): |
| 2052 | + pass |
| 2053 | + |
| 2054 | + wfjob = ErtScriptWorkflow( |
| 2055 | + name="TEST_SCRIPT", |
| 2056 | + ertscript_class=SomeScript, |
| 2057 | + ) |
| 2058 | + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} |
| 2059 | + |
| 2060 | + with pytest.raises( |
| 2061 | + ConfigValidationError, |
| 2062 | + match=r"Workflow job TEST_SCRIPT.*" |
| 2063 | + r"expected fixtures.*ensemble.*", |
| 2064 | + ): |
| 2065 | + ErtConfig.from_file("config.ert") |
| 2066 | + |
| 2067 | + |
| 2068 | +def test_ert_script_hook_pre_experiment_but_asks_for_random_seed(): |
| 2069 | + workflow_file_path = os.path.join(os.getcwd(), "workflow") |
| 2070 | + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: |
| 2071 | + fh.write("TEST_SCRIPT") |
| 2072 | + |
| 2073 | + with open("config.ert", mode="w", encoding="utf-8") as fh: |
| 2074 | + fh.write( |
| 2075 | + dedent( |
| 2076 | + f""" |
| 2077 | + NUM_REALIZATIONS 1 |
| 2078 | +
|
| 2079 | + LOAD_WORKFLOW {workflow_file_path} workflow_alias |
| 2080 | + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT |
| 2081 | + """ |
| 2082 | + ) |
| 2083 | + ) |
| 2084 | + |
| 2085 | + class SomeScript(ErtScript): |
| 2086 | + def run(self, random_seed: int): |
| 2087 | + pass |
| 2088 | + |
| 2089 | + wfjob = ErtScriptWorkflow( |
| 2090 | + name="TEST_SCRIPT", |
| 2091 | + ertscript_class=SomeScript, |
| 2092 | + ) |
| 2093 | + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} |
| 2094 | + |
| 2095 | + ErtConfig.from_file("config.ert") |
| 2096 | + |
| 2097 | + |
| 2098 | +def test_ert_script_hook_pre_experiment_essettings_fails(): |
| 2099 | + workflow_file_path = os.path.join(os.getcwd(), "workflow") |
| 2100 | + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: |
| 2101 | + fh.write("TEST_SCRIPT") |
| 2102 | + |
| 2103 | + with open("config.ert", mode="w", encoding="utf-8") as fh: |
| 2104 | + fh.write( |
| 2105 | + dedent( |
| 2106 | + f""" |
| 2107 | + NUM_REALIZATIONS 1 |
| 2108 | +
|
| 2109 | + LOAD_WORKFLOW {workflow_file_path} workflow_alias |
| 2110 | + HOOK_WORKFLOW workflow_alias PRE_EXPERIMENT |
| 2111 | + """ |
| 2112 | + ) |
| 2113 | + ) |
| 2114 | + |
| 2115 | + class SomeScript(ErtScript): |
| 2116 | + def run(self, es_settings: ESSettings): |
| 2117 | + pass |
| 2118 | + |
| 2119 | + wfjob = ErtScriptWorkflow( |
| 2120 | + name="TEST_SCRIPT", |
| 2121 | + ertscript_class=SomeScript, |
| 2122 | + ) |
| 2123 | + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} |
| 2124 | + |
| 2125 | + with pytest.raises( |
| 2126 | + ConfigValidationError, |
| 2127 | + match=r".*It would work in these runtimes: PRE_UPDATE, POST_UPDATE, PRE_FIRST_UPDATE.*", |
| 2128 | + ): |
| 2129 | + ErtConfig.from_file("config.ert") |
| 2130 | + |
| 2131 | + |
| 2132 | +def test_ert_script_hook_valid_essettings_succeed(): |
| 2133 | + workflow_file_path = os.path.join(os.getcwd(), "workflow") |
| 2134 | + with open(workflow_file_path, mode="w", encoding="utf-8") as fh: |
| 2135 | + fh.write("TEST_SCRIPT") |
| 2136 | + |
| 2137 | + with open("config.ert", mode="w", encoding="utf-8") as fh: |
| 2138 | + fh.write( |
| 2139 | + dedent( |
| 2140 | + f""" |
| 2141 | + NUM_REALIZATIONS 1 |
| 2142 | +
|
| 2143 | + LOAD_WORKFLOW {workflow_file_path} workflow_alias |
| 2144 | + HOOK_WORKFLOW workflow_alias PRE_UPDATE |
| 2145 | + HOOK_WORKFLOW workflow_alias POST_UPDATE |
| 2146 | + HOOK_WORKFLOW workflow_alias PRE_FIRST_UPDATE |
| 2147 | + """ |
| 2148 | + ) |
| 2149 | + ) |
| 2150 | + |
| 2151 | + class SomeScript(ErtScript): |
| 2152 | + def run(self, es_settings: ESSettings): |
| 2153 | + pass |
| 2154 | + |
| 2155 | + wfjob = ErtScriptWorkflow( |
| 2156 | + name="TEST_SCRIPT", |
| 2157 | + ertscript_class=SomeScript, |
| 2158 | + ) |
| 2159 | + ErtConfig.PREINSTALLED_WORKFLOWS = {"TEST_SCRIPT": wfjob} |
| 2160 | + ErtConfig.from_file("config.ert") |
0 commit comments