-
Notifications
You must be signed in to change notification settings - Fork 110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug where site config was not propagated to Everest config #9719
base: main
Are you sure you want to change the base?
Changes from 3 commits
f5b6de1
6337e95
584a340
7bd65e4
1017e5d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from collections.abc import Iterator | ||
from contextlib import contextmanager | ||
from contextvars import ContextVar | ||
from typing import Any | ||
|
||
from pydantic import BaseModel | ||
|
||
init_context_var = ContextVar("_init_context_var", default=None) | ||
|
||
|
||
@contextmanager | ||
def init_context(value: dict[str, Any]) -> Iterator[None]: | ||
token = init_context_var.set(value) # type: ignore | ||
try: | ||
yield | ||
finally: | ||
init_context_var.reset(token) | ||
|
||
|
||
class BaseModelWithContextSupport(BaseModel): | ||
def __init__(__pydantic_self__, **data: Any) -> None: | ||
__pydantic_self__.__pydantic_validator__.validate_python( | ||
data, | ||
self_instance=__pydantic_self__, | ||
context=init_context_var.get(), | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,14 +5,16 @@ | |
import re | ||
import shutil | ||
from abc import abstractmethod | ||
from dataclasses import asdict, field, fields | ||
from typing import Annotated, Any, Literal, no_type_check | ||
|
||
import pydantic | ||
from pydantic import Field, field_validator | ||
from pydantic.dataclasses import dataclass | ||
from pydantic_core.core_schema import ValidationInfo | ||
|
||
from ._get_num_cpu import get_num_cpu_from_data_file | ||
from .parsing import ( | ||
BaseModelWithContextSupport, | ||
ConfigDict, | ||
ConfigKeys, | ||
ConfigValidationError, | ||
|
@@ -37,20 +39,30 @@ def activate_script() -> str: | |
return "" | ||
|
||
|
||
@pydantic.dataclasses.dataclass( | ||
config={ | ||
"extra": "forbid", | ||
"validate_assignment": True, | ||
"use_enum_values": True, | ||
"validate_default": True, | ||
} | ||
) | ||
class QueueOptions: | ||
class QueueOptions( | ||
BaseModelWithContextSupport, | ||
validate_assignment=True, | ||
extra="forbid", | ||
use_enum_values=True, | ||
validate_default=True, | ||
): | ||
name: QueueSystem | ||
max_running: pydantic.NonNegativeInt = 0 | ||
submit_sleep: pydantic.NonNegativeFloat = 0.0 | ||
project_code: str | None = None | ||
activate_script: str = field(default_factory=activate_script) | ||
activate_script: str | None = Field(default=None, validate_default=True) | ||
|
||
@field_validator("activate_script", mode="before") | ||
@classmethod | ||
def inject_site_config_script(cls, v: str, info: ValidationInfo) -> str: | ||
# User value gets highest priority | ||
if isinstance(v, str): | ||
return v | ||
# Use from plugin system if user has not specified | ||
plugin_script = None | ||
if info.context: | ||
plugin_script = info.context.get(info.field_name) | ||
return plugin_script or activate_script() # Return default value | ||
|
||
@staticmethod | ||
def create_queue_options( | ||
|
@@ -78,12 +90,12 @@ def create_queue_options( | |
return None | ||
|
||
def add_global_queue_options(self, config_dict: ConfigDict) -> None: | ||
for generic_option in fields(QueueOptions): | ||
for name, generic_option in QueueOptions.model_fields.items(): | ||
if ( | ||
generic_value := config_dict.get(generic_option.name.upper(), None) # type: ignore | ||
) and self.__dict__[generic_option.name] == generic_option.default: | ||
generic_value := config_dict.get(name.upper(), None) # type: ignore | ||
) and self.__dict__[name] == generic_option.default: | ||
try: | ||
setattr(self, generic_option.name, generic_value) | ||
setattr(self, name, generic_value) | ||
except pydantic.ValidationError as exception: | ||
for error in exception.errors(): | ||
_throw_error_or_warning( | ||
|
@@ -98,7 +110,6 @@ def driver_options(self) -> dict[str, Any]: | |
"""Translate the queue options to the key-value API provided by each driver""" | ||
|
||
|
||
@pydantic.dataclasses.dataclass | ||
class LocalQueueOptions(QueueOptions): | ||
name: Literal[QueueSystem.LOCAL] = QueueSystem.LOCAL | ||
|
||
|
@@ -107,7 +118,6 @@ def driver_options(self) -> dict[str, Any]: | |
return {} | ||
|
||
|
||
@pydantic.dataclasses.dataclass | ||
class LsfQueueOptions(QueueOptions): | ||
name: Literal[QueueSystem.LSF] = QueueSystem.LSF | ||
bhist_cmd: NonEmptyString | None = None | ||
|
@@ -120,17 +130,13 @@ class LsfQueueOptions(QueueOptions): | |
|
||
@property | ||
def driver_options(self) -> dict[str, Any]: | ||
driver_dict = asdict(self) | ||
driver_dict.pop("name") | ||
driver_dict = self.model_dump(exclude={"name", "submit_sleep", "max_running"}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice ❤️ |
||
driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host") | ||
driver_dict["queue_name"] = driver_dict.pop("lsf_queue") | ||
driver_dict["resource_requirement"] = driver_dict.pop("lsf_resource") | ||
driver_dict.pop("submit_sleep") | ||
driver_dict.pop("max_running") | ||
return driver_dict | ||
|
||
|
||
@pydantic.dataclasses.dataclass | ||
class TorqueQueueOptions(QueueOptions): | ||
name: Literal[QueueSystem.TORQUE] = QueueSystem.TORQUE | ||
qsub_cmd: NonEmptyString | None = None | ||
|
@@ -143,15 +149,19 @@ class TorqueQueueOptions(QueueOptions): | |
|
||
@property | ||
def driver_options(self) -> dict[str, Any]: | ||
driver_dict = asdict(self) | ||
driver_dict.pop("name") | ||
driver_dict = self.model_dump( | ||
exclude={ | ||
"name", | ||
"max_running", | ||
"submit_sleep", | ||
"qstat_options", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Were these: |
||
"queue_query_timeout", | ||
} | ||
) | ||
driver_dict["queue_name"] = driver_dict.pop("queue") | ||
driver_dict.pop("max_running") | ||
driver_dict.pop("submit_sleep") | ||
return driver_dict | ||
|
||
|
||
@pydantic.dataclasses.dataclass | ||
class SlurmQueueOptions(QueueOptions): | ||
name: Literal[QueueSystem.SLURM] = QueueSystem.SLURM | ||
sbatch: NonEmptyString = "sbatch" | ||
|
@@ -167,8 +177,7 @@ class SlurmQueueOptions(QueueOptions): | |
|
||
@property | ||
def driver_options(self) -> dict[str, Any]: | ||
driver_dict = asdict(self) | ||
driver_dict.pop("name") | ||
driver_dict = self.model_dump(exclude={"name", "max_running", "submit_sleep"}) | ||
driver_dict["sbatch_cmd"] = driver_dict.pop("sbatch") | ||
driver_dict["scancel_cmd"] = driver_dict.pop("scancel") | ||
driver_dict["scontrol_cmd"] = driver_dict.pop("scontrol") | ||
|
@@ -177,8 +186,6 @@ def driver_options(self) -> dict[str, Any]: | |
driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host") | ||
driver_dict["include_hosts"] = driver_dict.pop("include_host") | ||
driver_dict["queue_name"] = driver_dict.pop("partition") | ||
driver_dict.pop("max_running") | ||
driver_dict.pop("submit_sleep") | ||
return driver_dict | ||
|
||
|
||
|
@@ -203,12 +210,12 @@ def validate(self, mem_str_format: str | None) -> bool: | |
) | ||
|
||
valid_options: dict[str, list[str]] = { | ||
QueueSystem.LOCAL: [field.name.upper() for field in fields(LocalQueueOptions)], | ||
QueueSystem.LSF: [field.name.upper() for field in fields(LsfQueueOptions)], | ||
QueueSystem.SLURM: [field.name.upper() for field in fields(SlurmQueueOptions)], | ||
QueueSystem.TORQUE: [field.name.upper() for field in fields(TorqueQueueOptions)], | ||
QueueSystem.LOCAL: [field.upper() for field in LocalQueueOptions.model_fields], | ||
QueueSystem.LSF: [field.upper() for field in LsfQueueOptions.model_fields], | ||
QueueSystem.SLURM: [field.upper() for field in SlurmQueueOptions.model_fields], | ||
QueueSystem.TORQUE: [field.upper() for field in TorqueQueueOptions.model_fields], | ||
QueueSystemWithGeneric.GENERIC: [ | ||
field.name.upper() for field in fields(QueueOptions) | ||
field.upper() for field in QueueOptions.model_fields | ||
], | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,9 +25,13 @@ | |
field_validator, | ||
model_validator, | ||
) | ||
from pydantic_core.core_schema import ValidationInfo | ||
from ruamel.yaml import YAML, YAMLError | ||
|
||
from ert.config import ErtConfig | ||
from ert.config import ErtConfig, QueueConfig | ||
from ert.config.parsing import BaseModelWithContextSupport | ||
from ert.config.parsing.base_model_context import init_context | ||
from ert.plugins import ErtPluginManager | ||
from everest.config.control_variable_config import ControlVariableGuessListConfig | ||
from everest.config.install_template_config import InstallTemplateConfig | ||
from everest.config.server_config import ServerConfig | ||
|
@@ -134,7 +138,7 @@ class HasName(Protocol): | |
name: str | ||
|
||
|
||
class EverestConfig(BaseModelWithPropertySupport): # type: ignore | ||
class EverestConfig(BaseModelWithPropertySupport, BaseModelWithContextSupport): # type: ignore | ||
controls: Annotated[list[ControlConfig], AfterValidator(unique_items)] = Field( | ||
description="""Defines a list of controls. | ||
Controls should have unique names each control defines | ||
|
@@ -267,7 +271,7 @@ def validate_queue_system(self) -> Self: # pylint: disable=E0213 | |
return self | ||
|
||
@model_validator(mode="after") | ||
def validate_forward_model_job_name_installed(self) -> Self: # pylint: disable=E0213 | ||
def validate_forward_model_job_name_installed(self, info: ValidationInfo) -> Self: # pylint: disable=E0213 | ||
install_jobs = self.install_jobs | ||
forward_model_jobs = self.forward_model | ||
if install_jobs is None: | ||
|
@@ -807,6 +811,18 @@ def load_file(config_file: str) -> "EverestConfig": | |
|
||
raise exp from error | ||
|
||
@classmethod | ||
def with_plugins(cls, config_dict): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am a bit confused as to why mypy is not complaining about the missing type for the argument |
||
context = {} | ||
activate_script = ErtPluginManager().activate_script() | ||
site_config = ErtConfig.read_site_config() | ||
if site_config: | ||
context["queue_system"] = QueueConfig.from_dict(site_config).queue_options | ||
if activate_script: | ||
context["activate_script"] = ErtPluginManager().activate_script() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can just have context["activate_script"] = activate_script in place of context["activate_script"] = ErtPluginManager().activate_script() There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a need to re-invoke |
||
with init_context(context): | ||
return cls(**config_dict) | ||
|
||
@staticmethod | ||
def load_file_with_argparser( | ||
config_path, parser: ArgumentParser | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,21 @@ | ||
from typing import Any | ||
|
||
from pydantic import ( | ||
BaseModel, | ||
Field, | ||
NonNegativeInt, | ||
PositiveInt, | ||
field_validator, | ||
model_validator, | ||
) | ||
from pydantic_core.core_schema import ValidationInfo | ||
|
||
from ert.config.parsing import BaseModelWithContextSupport | ||
from ert.config.queue_config import ( | ||
LocalQueueOptions, | ||
LsfQueueOptions, | ||
SlurmQueueOptions, | ||
TorqueQueueOptions, | ||
) | ||
from ert.plugins import ErtPluginManager | ||
|
||
simulator_example = {"queue_system": {"name": "local", "max_running": 3}} | ||
|
||
|
@@ -29,11 +29,11 @@ def check_removed_config(queue_system): | |
} | ||
if isinstance(queue_system, str) and queue_system in queue_systems: | ||
raise ValueError( | ||
f"Queue system configuration has changed, valid options for {queue_system} are: {list(queue_systems[queue_system].__dataclass_fields__.keys())}" | ||
f"Queue system configuration has changed, valid options for {queue_system} are: {list(queue_systems[queue_system].model_fields.keys())}" | ||
) | ||
|
||
|
||
class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore | ||
class SimulatorConfig(BaseModelWithContextSupport, extra="forbid"): # type: ignore | ||
cores_per_node: PositiveInt | None = Field( | ||
default=None, | ||
description="""defines the number of CPUs when running | ||
|
@@ -94,13 +94,12 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore | |
|
||
@field_validator("queue_system", mode="before") | ||
@classmethod | ||
def default_local_queue(cls, v): | ||
def default_local_queue(cls, v, info: ValidationInfo): | ||
if v is None: | ||
return LocalQueueOptions(max_running=8) | ||
if "activate_script" not in v and ( | ||
active_script := ErtPluginManager().activate_script() | ||
): | ||
v["activate_script"] = active_script | ||
options = None | ||
if info.context: | ||
options = info.context.get(info.field_name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't |
||
return options or LocalQueueOptions(max_running=8) | ||
return v | ||
|
||
@model_validator(mode="before") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a test that runs this line? EDIT: I see running
test_detached
triggers it, but do any of the ERT tests run it? Is this mainly meant for things we run via Everest?