forked from equinor/ert
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor cli interface, make all modes available
- Loading branch information
Showing
11 changed files
with
578 additions
and
283 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,157 +1,182 @@ | ||
#!/usr/bin/env python | ||
import logging | ||
import os | ||
import sys | ||
from argparse import ArgumentTypeError | ||
|
||
from res.enkf import EnKFMain, ResConfig, ESUpdate, ErtRunContext | ||
from res.enkf.enums import RealizationStateEnum, HookRuntime | ||
from ecl.util.util import BoolVector | ||
import res | ||
from res.util import ResLog | ||
|
||
import logging | ||
|
||
|
||
def setup_fs(ert, target="default"): | ||
fs_manager = ert.getEnkfFsManager() | ||
|
||
src_fs = fs_manager.getCurrentFileSystem() | ||
tar_fs = fs_manager.getFileSystem(target) | ||
|
||
return src_fs, tar_fs | ||
|
||
|
||
def resconfig(config_file): | ||
return ResConfig(user_config_file=config_file) | ||
|
||
|
||
def _test_run(ert): | ||
source_fs, _ = setup_fs(ert) | ||
|
||
model_config = ert.getModelConfig() | ||
subst_list = ert.getDataKW() | ||
|
||
single_mask = BoolVector(default_value=False) | ||
single_mask[0] = True | ||
|
||
run_context = ErtRunContext.ensemble_experiment( | ||
sim_fs=source_fs, | ||
mask=single_mask, | ||
path_fmt=model_config.getRunpathFormat(), | ||
jobname_fmt=model_config.getJobnameFormat(), | ||
subst_list=subst_list, | ||
itr=0) | ||
|
||
sim_runner = ert.getEnkfSimulationRunner() | ||
_run_ensemble_experiment(ert, run_context, sim_runner) | ||
|
||
|
||
def _experiment_run(ert): | ||
source_fs, _ = setup_fs(ert) | ||
|
||
model_config = ert.getModelConfig() | ||
subst_list = ert.getDataKW() | ||
|
||
mask = BoolVector(default_value=True, initial_size=ert.getEnsembleSize()) | ||
|
||
run_context = ErtRunContext.ensemble_experiment( | ||
sim_fs=source_fs, | ||
mask=mask, | ||
path_fmt=model_config.getRunpathFormat(), | ||
jobname_fmt=model_config.getJobnameFormat(), | ||
subst_list=subst_list, | ||
itr=0) | ||
|
||
sim_runner = ert.getEnkfSimulationRunner() | ||
_run_ensemble_experiment(ert, run_context, sim_runner) | ||
|
||
|
||
def _ensemble_smoother_run(ert, target_case): | ||
source_fs, target_fs = setup_fs(ert, target_case) | ||
|
||
model_config = ert.getModelConfig() | ||
subst_list = ert.getDataKW() | ||
|
||
mask = BoolVector(default_value=True, initial_size=ert.getEnsembleSize()) | ||
|
||
prior_context = ErtRunContext.ensemble_smoother( | ||
sim_fs=source_fs, | ||
target_fs=target_fs, | ||
mask=mask, | ||
path_fmt=model_config.getRunpathFormat(), | ||
jobname_fmt=model_config.getJobnameFormat(), | ||
subst_list=subst_list, | ||
itr=0) | ||
|
||
sim_runner = ert.getEnkfSimulationRunner() | ||
_run_ensemble_experiment(ert, prior_context, sim_runner) | ||
sim_runner.runWorkflows( HookRuntime.PRE_UPDATE ) | ||
|
||
es_update = ESUpdate(ert) | ||
success = es_update.smootherUpdate(prior_context) | ||
if not success: | ||
raise AssertionError("Analysis of simulation failed!") | ||
|
||
sim_runner.runWorkflows( HookRuntime.POST_UPDATE ) | ||
|
||
ert.getEnkfFsManager().switchFileSystem(prior_context.get_target_fs()) | ||
|
||
sim_fs = prior_context.get_target_fs( ) | ||
state = RealizationStateEnum.STATE_HAS_DATA | RealizationStateEnum.STATE_INITIALIZED | ||
mask = sim_fs.getStateMap().createMask(state) | ||
|
||
rerun_context = ErtRunContext.ensemble_smoother( | ||
sim_fs=sim_fs, | ||
target_fs=None, | ||
mask=mask, | ||
path_fmt=model_config.getRunpathFormat(), | ||
jobname_fmt=model_config.getJobnameFormat(), | ||
subst_list=subst_list, | ||
itr=1) | ||
|
||
_run_ensemble_experiment(ert, rerun_context, sim_runner) | ||
|
||
|
||
def _run_ensemble_experiment(ert, run_context, sim_runner): | ||
sim_runner.createRunPath(run_context) | ||
sim_runner.runWorkflows(HookRuntime.PRE_SIMULATION) | ||
|
||
job_queue = ert.get_queue_config().create_job_queue() | ||
num_successful_realizations = sim_runner.runEnsembleExperiment(job_queue, run_context) | ||
_assert_minium_realizations_success(ert, num_successful_realizations) | ||
|
||
print("{} of the realizations were successful".format(num_successful_realizations)) | ||
sim_runner.runWorkflows( HookRuntime.POST_SIMULATION ) | ||
|
||
|
||
def _assert_minium_realizations_success(ert, num_successful_realizations): | ||
if num_successful_realizations == 0: | ||
raise AssertionError("Simulation failed! All realizations failed!") | ||
elif not ert.analysisConfig().haveEnoughRealisations(num_successful_realizations, ert.getEnsembleSize()): | ||
raise AssertionError("Too many simulations have failed! You can add/adjust MIN_REALIZATIONS to allow failures in your simulations.\n\n" | ||
"Check ERT log file '%s' or simulation folder for details." % ResLog.getFilename()) | ||
|
||
|
||
def main(): | ||
# The ert_cli script should be called from ert.in. The arguments are parsed and verified in ert.in | ||
if len(sys.argv) < 3: | ||
raise AssertionError("Required arguments are missing, the config-file, " | ||
"mode and target case must be provided") | ||
|
||
config_file, mode, target_case = sys.argv[1:] | ||
|
||
config = resconfig(config_file) | ||
ert = EnKFMain(config) | ||
|
||
if not ert._real_enkf_main().have_observations() and mode == 'ensemble_smoother': | ||
logging.error("No observations loaded. Unable to perform model update.") | ||
return | ||
|
||
if mode == "test_run": | ||
_test_run(ert) | ||
if mode == "ensemble_experiment": | ||
_experiment_run(ert) | ||
if mode == "ensemble_smoother": | ||
_ensemble_smoother_run(ert, target_case) | ||
|
||
from res.enkf import EnKFMain, ErtRunContext, ESUpdate, ResConfig | ||
|
||
from ert_gui import ERT | ||
from ert_gui.ertwidgets.models import ertmodel | ||
from ert_gui.ide.keywords.definitions import (NumberListStringArgument, | ||
RangeStringArgument) | ||
from ert_gui.simulation.models.ensemble_experiment import EnsembleExperiment | ||
from ert_gui.simulation.models.ensemble_smoother import EnsembleSmoother | ||
from ert_gui.simulation.models.iterated_ensemble_smoother import \ | ||
IteratedEnsembleSmoother | ||
from ert_gui.simulation.models.multiple_data_assimilation import \ | ||
MultipleDataAssimilation | ||
from ert_gui.simulation.models.single_test_run import SingleTestRun | ||
|
||
|
||
def run_cli(args): | ||
|
||
res_config = ResConfig(args.config) | ||
os.chdir(res_config.config_path) | ||
ert = EnKFMain(res_config, strict=True, verbose=args.verbose) | ||
notifier = ErtCliNotifier(ert, args.config) | ||
ERT.adapt(notifier) | ||
|
||
# Setup model | ||
if args.mode == 'test_run': | ||
model, argument = _setup_single_test_run() | ||
elif args.mode == 'ensemble_experiment': | ||
model, argument = _setup_ensemble_experiment(args) | ||
elif args.mode == 'ensemble_smoother': | ||
model, argument = _setup_ensemble_smoother(args) | ||
elif args.mode == 'es_mda': | ||
model, argument = _setup_multiple_data_assimilation(args) | ||
elif args.mode == 'iterated_ensemble_smoother': | ||
model, argument = _setup_iterated_ensemble_smoother(args) | ||
else: | ||
raise NotImplementedError( | ||
"Run type not supported {}".format(args.mode)) | ||
|
||
model.runSimulations(argument) | ||
|
||
|
||
def _setup_single_test_run(): | ||
model = SingleTestRun() | ||
simulations_argument = { | ||
"active_realizations": BoolVector(default_value=True, initial_size=1), | ||
} | ||
return model, simulations_argument | ||
|
||
|
||
def _setup_ensemble_experiment(args): | ||
|
||
model = EnsembleExperiment() | ||
simulations_argument = { | ||
"active_realizations": _realizations(args), | ||
} | ||
return model, simulations_argument | ||
|
||
|
||
def _setup_ensemble_smoother(args): | ||
model = EnsembleSmoother() | ||
|
||
simulations_argument = { | ||
"active_realizations": _realizations(args), | ||
"target_case": _target_case_name(args, format_mode=False) | ||
} | ||
return model, simulations_argument | ||
|
||
|
||
def _setup_multiple_data_assimilation(args): | ||
model = MultipleDataAssimilation() | ||
iterable = True | ||
active_name = ERT.ert.analysisConfig().activeModuleName() | ||
modules = ertmodel.getAnalysisModuleNames(iterable=iterable) | ||
simulations_argument = { | ||
"active_realizations": _realizations(args), | ||
"target_case": _target_case_name(args, format_mode=True), | ||
"analysis_module": _get_analysis_module_name(active_name, modules, iterable=False), | ||
"weights": args.weights | ||
} | ||
return model, simulations_argument | ||
|
||
|
||
def _setup_iterated_ensemble_smoother(args): | ||
if args.iterations is not None: | ||
ertmodel.setNumberOfIterations(args.iterations) | ||
|
||
model = IteratedEnsembleSmoother() | ||
iterable = True | ||
active_name = ERT.ert.analysisConfig().activeModuleName() | ||
modules = ertmodel.getAnalysisModuleNames(iterable=iterable) | ||
simulations_argument = { | ||
"active_realizations": _realizations(args), | ||
"target_case": _target_case_name(args, format_mode=True), | ||
"analysis_module": _get_analysis_module_name(active_name, modules, iterable=iterable) | ||
} | ||
return model, simulations_argument | ||
|
||
|
||
def _get_analysis_module_name(active_name, modules, iterable): | ||
|
||
if active_name in modules: | ||
return active_name | ||
elif "STD_ENKF" in modules and not iterable: | ||
return "STD_ENKF" | ||
elif "RML_ENKF" in modules and iterable: | ||
return "RML_ENKF" | ||
elif len(modules) > 0: | ||
return modules[0] | ||
|
||
return None | ||
|
||
|
||
def _realizations(args): | ||
ensemble_size = ERT.ert.getEnsembleSize() | ||
mask = BoolVector(default_value=False, initial_size=ensemble_size) | ||
if args.realizations is None: | ||
default = "0-{}".format(ensemble_size - 1) | ||
mask.updateActiveMask(default) | ||
return mask | ||
|
||
validator = RangeStringArgument(ensemble_size) | ||
validated = validator.validate(args.realizations) | ||
if validated.failed(): | ||
raise ArgumentTypeError( | ||
"Defined realizations is not within range of ensemble size: {}".format(args.realizations)) | ||
mask.updateActiveMask(args.realizations) | ||
return mask | ||
|
||
|
||
def _target_case_name(args, format_mode=False): | ||
""" @rtype: str """ | ||
if args.target_case is not None: | ||
return args.target_case | ||
|
||
if not format_mode: | ||
case_name = ertmodel.getCurrentCaseName() | ||
return "{}_smoother_update".format(case_name) | ||
|
||
if __name__ == "__main__": | ||
main() | ||
aic = ERT.ert.analysisConfig().getAnalysisIterConfig() | ||
if aic.caseFormatSet(): | ||
return aic.getCaseFormat() | ||
|
||
case_name = ertmodel.getCurrentCaseName() | ||
return "{}_%d".format(case_name) | ||
|
||
|
||
class ErtCliNotifier(): | ||
|
||
def __init__(self, ert, config_file): | ||
self._ert = ert | ||
self._config_file = config_file | ||
|
||
@property | ||
def ert(self): | ||
""" @rtype: EnKFMain """ | ||
if self._ert is None: | ||
raise ValueError("Ert is undefined.") | ||
return self._ert | ||
|
||
@property | ||
def config_file(self): | ||
""" @rtype: str """ | ||
if self._ert is None: | ||
raise ValueError("Ert is undefined.") | ||
return self._config_file | ||
|
||
@property | ||
def ertChanged(self): | ||
pass | ||
|
||
def emitErtChange(self): | ||
pass | ||
|
||
def reloadERT(self, config_file): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
class ErtAdapter(): | ||
|
||
def adapt(self, implementation): | ||
self._implementation = implementation | ||
|
||
@property | ||
def ertChanged(self): | ||
return self._implementation.ertChanged | ||
|
||
@property | ||
def ert(self): | ||
return self._implementation.ert | ||
|
||
@property | ||
def config_file(self): | ||
return self._implementation.config_file | ||
|
||
def emitErtChange(self): | ||
self._implementation.emitErtChange() | ||
|
||
def reloadERT(self, config_file): | ||
self._implementation.reloadERT(config_file) | ||
|
||
ERT = ErtAdapter() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.