diff --git a/src/ert/config/parsing/__init__.py b/src/ert/config/parsing/__init__.py index 6ac0114fe67..36f4e1fcd91 100644 --- a/src/ert/config/parsing/__init__.py +++ b/src/ert/config/parsing/__init__.py @@ -1,4 +1,5 @@ from .analysis_mode import AnalysisMode +from .base_model_context import BaseModelWithContextSupport from .config_dict import ConfigDict from .config_errors import ConfigValidationError, ConfigWarning from .config_keywords import ConfigKeys @@ -20,6 +21,7 @@ __all__ = [ "AnalysisMode", + "BaseModelWithContextSupport", "ConfigDict", "ConfigKeys", "ConfigValidationError", diff --git a/src/ert/config/parsing/base_model_context.py b/src/ert/config/parsing/base_model_context.py new file mode 100644 index 00000000000..29bdf17d4cb --- /dev/null +++ b/src/ert/config/parsing/base_model_context.py @@ -0,0 +1,26 @@ +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any + +from pydantic import BaseModel + +init_context_var = ContextVar("_init_context_var", default=None) + + +@contextmanager +def init_context(value: dict[str, Any]) -> Iterator[None]: + token = init_context_var.set(value) # type: ignore + try: + yield + finally: + init_context_var.reset(token) + + +class BaseModelWithContextSupport(BaseModel): + def __init__(__pydantic_self__, **data: Any) -> None: + __pydantic_self__.__pydantic_validator__.validate_python( + data, + self_instance=__pydantic_self__, + context=init_context_var.get(), + ) diff --git a/src/ert/config/queue_config.py b/src/ert/config/queue_config.py index c0b58482324..d8e04734968 100644 --- a/src/ert/config/queue_config.py +++ b/src/ert/config/queue_config.py @@ -5,14 +5,16 @@ import re import shutil from abc import abstractmethod -from dataclasses import asdict, field, fields from typing import Annotated, Any, Literal, no_type_check import pydantic +from pydantic import Field, field_validator from pydantic.dataclasses import dataclass +from pydantic_core.core_schema import ValidationInfo from ._get_num_cpu import get_num_cpu_from_data_file from .parsing import ( + BaseModelWithContextSupport, ConfigDict, ConfigKeys, ConfigValidationError, @@ -37,20 +39,30 @@ def activate_script() -> str: return "" -@pydantic.dataclasses.dataclass( - config={ - "extra": "forbid", - "validate_assignment": True, - "use_enum_values": True, - "validate_default": True, - } -) -class QueueOptions: +class QueueOptions( + BaseModelWithContextSupport, + validate_assignment=True, + extra="forbid", + use_enum_values=True, + validate_default=True, +): name: QueueSystem max_running: pydantic.NonNegativeInt = 0 submit_sleep: pydantic.NonNegativeFloat = 0.0 project_code: str | None = None - activate_script: str = field(default_factory=activate_script) + activate_script: str | None = Field(default=None, validate_default=True) + + @field_validator("activate_script", mode="before") + @classmethod + def inject_site_config_script(cls, v: str, info: ValidationInfo) -> str: + # User value gets highest priority + if isinstance(v, str): + return v + # Use from plugin system if user has not specified + plugin_script = None + if info.context: + plugin_script = info.context.get(info.field_name) + return plugin_script or activate_script() # Return default value @staticmethod def create_queue_options( @@ -78,12 +90,12 @@ def create_queue_options( return None def add_global_queue_options(self, config_dict: ConfigDict) -> None: - for generic_option in fields(QueueOptions): + for name, generic_option in QueueOptions.model_fields.items(): if ( - generic_value := config_dict.get(generic_option.name.upper(), None) # type: ignore - ) and self.__dict__[generic_option.name] == generic_option.default: + generic_value := config_dict.get(name.upper(), None) # type: ignore + ) and self.__dict__[name] == generic_option.default: try: - setattr(self, generic_option.name, generic_value) + setattr(self, name, generic_value) except pydantic.ValidationError as exception: for error in exception.errors(): _throw_error_or_warning( @@ -98,7 +110,6 @@ def driver_options(self) -> dict[str, Any]: """Translate the queue options to the key-value API provided by each driver""" -@pydantic.dataclasses.dataclass class LocalQueueOptions(QueueOptions): name: Literal[QueueSystem.LOCAL] = QueueSystem.LOCAL @@ -107,7 +118,6 @@ def driver_options(self) -> dict[str, Any]: return {} -@pydantic.dataclasses.dataclass class LsfQueueOptions(QueueOptions): name: Literal[QueueSystem.LSF] = QueueSystem.LSF bhist_cmd: NonEmptyString | None = None @@ -120,17 +130,13 @@ class LsfQueueOptions(QueueOptions): @property def driver_options(self) -> dict[str, Any]: - driver_dict = asdict(self) - driver_dict.pop("name") + driver_dict = self.model_dump(exclude={"name", "submit_sleep", "max_running"}) driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host") driver_dict["queue_name"] = driver_dict.pop("lsf_queue") driver_dict["resource_requirement"] = driver_dict.pop("lsf_resource") - driver_dict.pop("submit_sleep") - driver_dict.pop("max_running") return driver_dict -@pydantic.dataclasses.dataclass class TorqueQueueOptions(QueueOptions): name: Literal[QueueSystem.TORQUE] = QueueSystem.TORQUE qsub_cmd: NonEmptyString | None = None @@ -143,15 +149,19 @@ class TorqueQueueOptions(QueueOptions): @property def driver_options(self) -> dict[str, Any]: - driver_dict = asdict(self) - driver_dict.pop("name") + driver_dict = self.model_dump( + exclude={ + "name", + "max_running", + "submit_sleep", + "qstat_options", + "queue_query_timeout", + } + ) driver_dict["queue_name"] = driver_dict.pop("queue") - driver_dict.pop("max_running") - driver_dict.pop("submit_sleep") return driver_dict -@pydantic.dataclasses.dataclass class SlurmQueueOptions(QueueOptions): name: Literal[QueueSystem.SLURM] = QueueSystem.SLURM sbatch: NonEmptyString = "sbatch" @@ -167,8 +177,7 @@ class SlurmQueueOptions(QueueOptions): @property def driver_options(self) -> dict[str, Any]: - driver_dict = asdict(self) - driver_dict.pop("name") + driver_dict = self.model_dump(exclude={"name", "max_running", "submit_sleep"}) driver_dict["sbatch_cmd"] = driver_dict.pop("sbatch") driver_dict["scancel_cmd"] = driver_dict.pop("scancel") driver_dict["scontrol_cmd"] = driver_dict.pop("scontrol") @@ -177,8 +186,6 @@ def driver_options(self) -> dict[str, Any]: driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host") driver_dict["include_hosts"] = driver_dict.pop("include_host") driver_dict["queue_name"] = driver_dict.pop("partition") - driver_dict.pop("max_running") - driver_dict.pop("submit_sleep") return driver_dict @@ -203,12 +210,12 @@ def validate(self, mem_str_format: str | None) -> bool: ) valid_options: dict[str, list[str]] = { - QueueSystem.LOCAL: [field.name.upper() for field in fields(LocalQueueOptions)], - QueueSystem.LSF: [field.name.upper() for field in fields(LsfQueueOptions)], - QueueSystem.SLURM: [field.name.upper() for field in fields(SlurmQueueOptions)], - QueueSystem.TORQUE: [field.name.upper() for field in fields(TorqueQueueOptions)], + QueueSystem.LOCAL: [field.upper() for field in LocalQueueOptions.model_fields], + QueueSystem.LSF: [field.upper() for field in LsfQueueOptions.model_fields], + QueueSystem.SLURM: [field.upper() for field in SlurmQueueOptions.model_fields], + QueueSystem.TORQUE: [field.upper() for field in TorqueQueueOptions.model_fields], QueueSystemWithGeneric.GENERIC: [ - field.name.upper() for field in fields(QueueOptions) + field.upper() for field in QueueOptions.model_fields ], } diff --git a/src/ert/gui/simulation/experiment_panel.py b/src/ert/gui/simulation/experiment_panel.py index 7be4a058423..17ea6399d96 100644 --- a/src/ert/gui/simulation/experiment_panel.py +++ b/src/ert/gui/simulation/experiment_panel.py @@ -3,7 +3,6 @@ import os import platform from collections import OrderedDict -from dataclasses import fields from datetime import datetime from pathlib import Path from queue import SimpleQueue @@ -379,10 +378,10 @@ def populate_clipboard_debug_info(self) -> None: if isinstance(self.get_current_experiment_type(), SingleTestRun): queue_opts = LocalQueueOptions(max_running=1) - for field in fields(queue_opts): - field_value = getattr(queue_opts, field.name) + for name in queue_opts.model_fields: + field_value = getattr(queue_opts, name) if field_value is not None: - kv[field.name.replace("_", " ").capitalize()] = str(field_value) + kv[name.replace("_", " ").capitalize()] = str(field_value) kv["**Status**"] = "" kv["Trace ID"] = get_trace_id() diff --git a/src/everest/config/everest_config.py b/src/everest/config/everest_config.py index d9dc43c9657..00b657f1aa5 100644 --- a/src/everest/config/everest_config.py +++ b/src/everest/config/everest_config.py @@ -25,9 +25,13 @@ field_validator, model_validator, ) +from pydantic_core.core_schema import ValidationInfo from ruamel.yaml import YAML, YAMLError -from ert.config import ErtConfig +from ert.config import ErtConfig, QueueConfig +from ert.config.parsing import BaseModelWithContextSupport +from ert.config.parsing.base_model_context import init_context +from ert.plugins import ErtPluginManager from everest.config.control_variable_config import ControlVariableGuessListConfig from everest.config.install_template_config import InstallTemplateConfig from everest.config.server_config import ServerConfig @@ -69,18 +73,6 @@ from pydantic_core import ErrorDetails -def _dummy_ert_config(): - site_config = ErtConfig.read_site_config() - dummy_config = {"NUM_REALIZATIONS": 1, "ENSPATH": "."} - dummy_config.update(site_config) - return ErtConfig.with_plugins().from_dict(config_dict=dummy_config) - - -def get_system_installed_jobs(): - """Returns list of all system installed job names""" - return list(_dummy_ert_config().installed_forward_model_steps.keys()) - - # Makes property.setter work # Based on https://github.com/pydantic/pydantic/issues/1577#issuecomment-790506164 # We should use computed_property instead of this, when upgrading to pydantic 2. @@ -134,7 +126,7 @@ class HasName(Protocol): name: str -class EverestConfig(BaseModelWithPropertySupport): # type: ignore +class EverestConfig(BaseModelWithPropertySupport, BaseModelWithContextSupport): # type: ignore controls: Annotated[list[ControlConfig], AfterValidator(unique_items)] = Field( description="""Defines a list of controls. Controls should have unique names each control defines @@ -192,7 +184,7 @@ class EverestConfig(BaseModelWithPropertySupport): # type: ignore default=None, description="A list of output constraints with unique names." ) install_jobs: list[InstallJobConfig] | None = Field( - default=None, description="A list of jobs to install" + default=None, description="A list of jobs to install", validate_default=True ) install_workflow_jobs: list[InstallJobConfig] | None = Field( default=None, description="A list of workflow jobs to install" @@ -267,7 +259,7 @@ def validate_queue_system(self) -> Self: # pylint: disable=E0213 return self @model_validator(mode="after") - def validate_forward_model_job_name_installed(self) -> Self: # pylint: disable=E0213 + def validate_forward_model_job_name_installed(self, info: ValidationInfo) -> Self: # pylint: disable=E0213 install_jobs = self.install_jobs forward_model_jobs = self.forward_model if install_jobs is None: @@ -276,7 +268,8 @@ def validate_forward_model_job_name_installed(self) -> Self: # pylint: disable= return self installed_jobs_name = [job.name for job in install_jobs] installed_jobs_name += list(script_names) # default jobs - installed_jobs_name += get_system_installed_jobs() # system jobs + if info.context: # Add plugin jobs + installed_jobs_name += info.context.get("install_jobs", {}).keys() errors = [] for fm_job in forward_model_jobs: @@ -761,7 +754,7 @@ def with_defaults(cls, **kwargs): "model": {"realizations": [0]}, } - return cls.model_validate({**defaults, **kwargs}) + return cls.with_plugins({**defaults, **kwargs}) @staticmethod def lint_config_dict(config: dict) -> list["ErrorDetails"]: @@ -778,8 +771,8 @@ def lint_config_dict_with_raise(config: dict): # more understandable EverestConfig.model_validate(config) - @staticmethod - def load_file(config_file: str) -> "EverestConfig": + @classmethod + def load_file(cls, config_file: str): config_path = os.path.realpath(config_file) if not os.path.isfile(config_path): @@ -787,7 +780,7 @@ def load_file(config_file: str) -> "EverestConfig": config_dict = yaml_file_to_substituted_config_dict(config_path) try: - return EverestConfig.model_validate(config_dict) + return cls.with_plugins(config_dict) except ValidationError as error: exp = EverestValidationError() file_content = [] @@ -807,6 +800,23 @@ def load_file(config_file: str) -> "EverestConfig": raise exp from error + @classmethod + def with_plugins(cls, config_dict): + site_config = ErtConfig.read_site_config() + ert_config: ErtConfig = ErtConfig.with_plugins().from_dict( + config_dict=site_config + ) + context = { + "install_jobs": ert_config.installed_forward_model_steps, + } + activate_script = ErtPluginManager().activate_script() + if site_config: + context["queue_system"] = QueueConfig.from_dict(site_config).queue_options + if activate_script: + context["activate_script"] = ErtPluginManager().activate_script() + with init_context(context): + return cls(**config_dict) + @staticmethod def load_file_with_argparser( config_path, parser: ArgumentParser diff --git a/src/everest/config/server_config.py b/src/everest/config/server_config.py index f4f7bd9b27a..de4c9691100 100644 --- a/src/everest/config/server_config.py +++ b/src/everest/config/server_config.py @@ -2,7 +2,7 @@ import os from typing import Any -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from ert.config.queue_config import ( LocalQueueOptions, @@ -10,7 +10,6 @@ SlurmQueueOptions, TorqueQueueOptions, ) -from ert.plugins import ErtPluginManager from ..strings import ( CERTIFICATE_DIR, @@ -38,15 +37,6 @@ class ServerConfig(BaseModel): # type: ignore extra="forbid", ) - @field_validator("queue_system", mode="before") - @classmethod - def default_local_queue(cls, v): - if v is None: - return v - elif "activate_script" not in v and ErtPluginManager().activate_script(): - v["activate_script"] = ErtPluginManager().activate_script() - return v - @model_validator(mode="before") @classmethod def check_old_config(cls, data: Any) -> Any: diff --git a/src/everest/config/simulator_config.py b/src/everest/config/simulator_config.py index 7797dbd46a9..776a1f3ff9c 100644 --- a/src/everest/config/simulator_config.py +++ b/src/everest/config/simulator_config.py @@ -1,21 +1,21 @@ from typing import Any from pydantic import ( - BaseModel, Field, NonNegativeInt, PositiveInt, field_validator, model_validator, ) +from pydantic_core.core_schema import ValidationInfo +from ert.config.parsing import BaseModelWithContextSupport from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, SlurmQueueOptions, TorqueQueueOptions, ) -from ert.plugins import ErtPluginManager simulator_example = {"queue_system": {"name": "local", "max_running": 3}} @@ -29,11 +29,11 @@ def check_removed_config(queue_system): } if isinstance(queue_system, str) and queue_system in queue_systems: raise ValueError( - f"Queue system configuration has changed, valid options for {queue_system} are: {list(queue_systems[queue_system].__dataclass_fields__.keys())}" + f"Queue system configuration has changed, valid options for {queue_system} are: {list(queue_systems[queue_system].model_fields.keys())}" ) -class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore +class SimulatorConfig(BaseModelWithContextSupport, extra="forbid"): # type: ignore cores_per_node: PositiveInt | None = Field( default=None, description="""defines the number of CPUs when running @@ -94,13 +94,12 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore @field_validator("queue_system", mode="before") @classmethod - def default_local_queue(cls, v): + def default_local_queue(cls, v, info: ValidationInfo): if v is None: - return LocalQueueOptions(max_running=8) - if "activate_script" not in v and ( - active_script := ErtPluginManager().activate_script() - ): - v["activate_script"] = active_script + options = None + if info.context: + options = info.context.get(info.field_name) + return options or LocalQueueOptions(max_running=8) return v @model_validator(mode="before") diff --git a/tests/ert/unit_tests/config/config_dict_generator.py b/tests/ert/unit_tests/config/config_dict_generator.py index 82b85e9a4fa..ae2064cb399 100644 --- a/tests/ert/unit_tests/config/config_dict_generator.py +++ b/tests/ert/unit_tests/config/config_dict_generator.py @@ -4,9 +4,9 @@ import os.path import stat from collections import defaultdict -from dataclasses import dataclass, fields +from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, get_args, get_origin from warnings import filterwarnings import hypothesis.strategies as st @@ -127,37 +127,52 @@ def memory_with_unit_lsf(draw): def valid_queue_options(queue_system: str): return [ - field.name.upper() - for field in fields( - queue_systems_and_options[QueueSystemWithGeneric(queue_system)] - ) - if field.name != "name" + name.upper() + for name in queue_systems_and_options[ + QueueSystemWithGeneric(queue_system) + ].model_fields + if name != "name" ] +def has_base_type( + field_type, base_type: type[int] | bool | type[str] | type[float] +) -> bool: + if field_type is base_type: + return True + origin = get_origin(field_type) + if origin: + args = get_args(field_type) + if any(arg is base_type for arg in args): + return True + return any(has_base_type(arg, base_type) for arg in args) + return False + + queue_options_by_type: dict[str, dict[str, list[str]]] = defaultdict(dict) for system, options in queue_systems_and_options.items(): queue_options_by_type["string"][system.name] = [ - field.name.upper() - for field in fields(options) - if ("String" in field.type or "str" in field.type) - and "memory" not in field.name + name.upper() + for name, field in options.model_fields.items() + if has_base_type(field.annotation, str) and "memory" not in name ] queue_options_by_type["bool"][system.name] = [ - field.name.upper() for field in fields(options) if field.type == "bool" + name.upper() + for name, field in options.model_fields.items() + if has_base_type(field.annotation, bool) ] queue_options_by_type["posint"][system.name] = [ - field.name.upper() - for field in fields(options) - if "PositiveInt" in field.type or "NonNegativeInt" in field.type + name.upper() + for name, field in options.model_fields.items() + if has_base_type(field.annotation, int) ] queue_options_by_type["posfloat"][system.name] = [ - field.name.upper() - for field in fields(options) - if "NonNegativeFloat" in field.type or "PositiveFloat" in field.type + name.upper() + for name, field in options.model_fields.items() + if has_base_type(field.annotation, float) ] queue_options_by_type["memory"][system.name] = [ - field.name.upper() for field in fields(options) if "memory" in field.name + name.upper() for name in options.model_fields if "memory" in name ] diff --git a/tests/ert/unit_tests/config/test_queue_config.py b/tests/ert/unit_tests/config/test_queue_config.py index b138c8a15e5..a908bf9d01b 100644 --- a/tests/ert/unit_tests/config/test_queue_config.py +++ b/tests/ert/unit_tests/config/test_queue_config.py @@ -15,7 +15,6 @@ from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, - QueueOptions, SlurmQueueOptions, TorqueQueueOptions, ) @@ -422,7 +421,7 @@ def test_default_activate_script_generation(expected, monkeypatch, venv): def test_conda_activate_script_generation(expected, monkeypatch, env): monkeypatch.setenv("VIRTUAL_ENV", "") monkeypatch.setenv("CONDA_ENV", env) - options = QueueOptions(name="local") + options = LocalQueueOptions(name="local") assert options.activate_script == expected @@ -433,7 +432,7 @@ def test_conda_activate_script_generation(expected, monkeypatch, env): def test_multiple_activate_script_generation(expected, monkeypatch, env): monkeypatch.setenv("VIRTUAL_ENV", env) monkeypatch.setenv("CONDA_ENV", env) - options = QueueOptions(name="local") + options = LocalQueueOptions(name="local") assert options.activate_script == expected diff --git a/tests/everest/test_detached.py b/tests/everest/test_detached.py index 064c881301d..523010aafd2 100644 --- a/tests/everest/test_detached.py +++ b/tests/everest/test_detached.py @@ -5,9 +5,10 @@ import pytest import requests +import yaml import everest -from ert.config import ErtConfig +from ert.config import QueueSystem from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, @@ -20,7 +21,6 @@ from everest.config.server_config import ServerConfig from everest.config.simulator_config import SimulatorConfig from everest.detached import ( - _EVERSERVER_JOB_PATH, PROXY, ServerStatus, everserver_status, @@ -31,12 +31,6 @@ wait_for_server, wait_for_server_to_stop, ) -from everest.strings import ( - DEFAULT_OUTPUT_DIR, - DETACHED_NODE_DIR, - EVEREST_SERVER_CONFIG, - SIMULATION_DIR, -) from everest.util import makedirs_if_needed @@ -150,43 +144,11 @@ def test_wait_for_server(server_is_running_mock, caplog, monkeypatch): assert not caplog.messages -def _get_reference_config(): - everest_config = EverestConfig.load_file("config_minimal.yml") - reference_config = ErtConfig.read_site_config() - cwd = os.getcwd() - reference_config.update( - { - "INSTALL_JOB": [(EVEREST_SERVER_CONFIG, _EVERSERVER_JOB_PATH)], - "QUEUE_SYSTEM": "LOCAL", - "JOBNAME": EVEREST_SERVER_CONFIG, - "MAX_SUBMIT": 1, - "NUM_REALIZATIONS": 1, - "RUNPATH": os.path.join( - cwd, - DEFAULT_OUTPUT_DIR, - DETACHED_NODE_DIR, - SIMULATION_DIR, - ), - "FORWARD_MODEL": [ - [ - EVEREST_SERVER_CONFIG, - "--config-file", - os.path.join(cwd, "config_minimal.yml"), - ], - ], - "ENSPATH": os.path.join( - cwd, DEFAULT_OUTPUT_DIR, DETACHED_NODE_DIR, EVEREST_SERVER_CONFIG - ), - "RUNPATH_FILE": os.path.join( - cwd, DEFAULT_OUTPUT_DIR, DETACHED_NODE_DIR, ".res_runpath_list" - ), - } - ) - return everest_config, reference_config - - -def test_detached_mode_config_base(copy_math_func_test_data_to_tmp): - everest_config, _ = _get_reference_config() +def test_detached_mode_config_base(min_config, monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + with open("config.yml", "w", encoding="utf-8") as fout: + yaml.dump(min_config, fout) + everest_config = EverestConfig.load_file("config.yml") assert everest_config.simulator.queue_system == LocalQueueOptions(max_running=8) @@ -285,7 +247,7 @@ def test_generate_queue_options_use_simulator_values( queue_options, expected_result, monkeypatch ): monkeypatch.setattr( - everest.config.server_config.ErtPluginManager, + everest.config.everest_config.ErtPluginManager, "activate_script", MagicMock(return_value=activate_script()), ) @@ -295,6 +257,62 @@ def test_generate_queue_options_use_simulator_values( assert config.server.queue_system == expected_result +@pytest.mark.parametrize("use_plugin", (True, False)) +@pytest.mark.parametrize( + "queue_options", + [ + {"name": "slurm", "activate_script": "From user"}, + {"name": "slurm"}, + ], +) +def test_queue_options_site_config(queue_options, use_plugin, monkeypatch, min_config): + plugin_result = "From plugin" + if "activate_script" in queue_options: + expected_result = queue_options["activate_script"] + elif use_plugin: + expected_result = plugin_result + else: + expected_result = activate_script() + + if use_plugin: + monkeypatch.setattr( + everest.config.everest_config.ErtPluginManager, + "activate_script", + MagicMock(return_value=plugin_result), + ) + config = EverestConfig.with_plugins( + {"simulator": {"queue_system": queue_options}} | min_config + ) + assert config.simulator.queue_system.activate_script == expected_result + + +@pytest.mark.parametrize("use_plugin", (True, False)) +@pytest.mark.parametrize( + "queue_options", + [ + {"queue_system": {"name": "slurm"}}, + {}, + ], +) +def test_simulator_queue_system_site_config( + queue_options, use_plugin, monkeypatch, min_config +): + if queue_options: + expected_result = SlurmQueueOptions # User specified + elif use_plugin: + expected_result = LsfQueueOptions # Mock site config + else: + expected_result = LocalQueueOptions # Default value + if use_plugin: + monkeypatch.setattr( + everest.config.everest_config.ErtConfig, + "read_site_config", + MagicMock(return_value={"QUEUE_SYSTEM": QueueSystem.LSF}), + ) + config = EverestConfig.with_plugins({"simulator": queue_options} | min_config) + assert isinstance(config.simulator.queue_system, expected_result) + + @pytest.mark.timeout(5) # Simulation might not finish @pytest.mark.integration_test @pytest.mark.xdist_group(name="starts_everest") diff --git a/tests/everest/test_egg_simulation.py b/tests/everest/test_egg_simulation.py index d20de3a3f9d..64c007dd233 100644 --- a/tests/everest/test_egg_simulation.py +++ b/tests/everest/test_egg_simulation.py @@ -594,7 +594,6 @@ def test_opm_fail_default_summary_keys(copy_egg_test_data_to_tmp): config = EverestConfig.load_file(CONFIG_FILE) # The Everest config file will fail to load as an Eclipse data file config.model.data_file = os.path.realpath(CONFIG_FILE) - assert len(EverestConfig.lint_config_dict(config.to_dict())) == 0 ert_config = _everest_to_ert_config_dict(config) diff --git a/tests/everest/test_res_initialization.py b/tests/everest/test_res_initialization.py index 27fc2e5db81..e08c60c4e68 100644 --- a/tests/everest/test_res_initialization.py +++ b/tests/everest/test_res_initialization.py @@ -369,3 +369,29 @@ def test_user_config_jobs_precedence(tmp_path, monkeypatch): .executable == "echo" ) + + +def test_that_queue_settings_are_taken_from_site_config( + min_config, monkeypatch, tmp_path +): + monkeypatch.chdir(tmp_path) + assert "simulator" not in min_config # Double check + Path("site-config").write_text( + dedent(""" + QUEUE_SYSTEM LSF + QUEUE_OPTION LSF LSF_RESOURCE my_resource + QUEUE_OPTION LSF LSF_QUEUE my_queue + """), + encoding="utf-8", + ) + with open("config.yml", "w", encoding="utf-8") as f: + yaml.dump(min_config, f) + monkeypatch.setenv("ERT_SITE_CONFIG", "site-config") + config = EverestConfig.load_file("config.yml") + assert config.simulator.queue_system == LsfQueueOptions( + lsf_queue="my_queue", lsf_resource="my_resource" + ) + ert_config = everest_to_ert_config(config) + assert ert_config.queue_config.queue_options == LsfQueueOptions( + lsf_queue="my_queue", lsf_resource="my_resource" + ) diff --git a/tests/everest/test_util.py b/tests/everest/test_util.py index 60e3275811b..a97b0e5670f 100644 --- a/tests/everest/test_util.py +++ b/tests/everest/test_util.py @@ -7,7 +7,6 @@ from everest import util from everest.bin.utils import report_on_previous_run from everest.config import EverestConfig, ServerConfig -from everest.config.everest_config import get_system_installed_jobs from everest.detached import ServerStatus from everest.strings import SERVER_STATUS from tests.everest.utils import ( @@ -133,13 +132,6 @@ def test_get_everserver_status_path(copy_math_func_test_data_to_tmp): assert path == expected_path -def test_get_system_installed_job_names(): - job_names = get_system_installed_jobs() - assert job_names is not None - assert isinstance(job_names, list) - assert len(job_names) > 0 - - @patch( "everest.bin.utils.everserver_status", return_value={"status": ServerStatus.failed, "message": "mock error"},