Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where site config was not propagated to Everest config #9719

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ert/config/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .analysis_mode import AnalysisMode
from .base_model_context import BaseModelWithContextSupport
from .config_dict import ConfigDict
from .config_errors import ConfigValidationError, ConfigWarning
from .config_keywords import ConfigKeys
Expand All @@ -20,6 +21,7 @@

__all__ = [
"AnalysisMode",
"BaseModelWithContextSupport",
"ConfigDict",
"ConfigKeys",
"ConfigValidationError",
Expand Down
26 changes: 26 additions & 0 deletions src/ert/config/parsing/base_model_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any

from pydantic import BaseModel

init_context_var = ContextVar("_init_context_var", default=None)


@contextmanager
def init_context(value: dict[str, Any]) -> Iterator[None]:
token = init_context_var.set(value) # type: ignore
try:
yield
finally:
init_context_var.reset(token)


class BaseModelWithContextSupport(BaseModel):
def __init__(__pydantic_self__, **data: Any) -> None:
__pydantic_self__.__pydantic_validator__.validate_python(
data,
self_instance=__pydantic_self__,
context=init_context_var.get(),
)
79 changes: 43 additions & 36 deletions src/ert/config/queue_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
import re
import shutil
from abc import abstractmethod
from dataclasses import asdict, field, fields
from typing import Annotated, Any, Literal, no_type_check

import pydantic
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from pydantic_core.core_schema import ValidationInfo

from ._get_num_cpu import get_num_cpu_from_data_file
from .parsing import (
BaseModelWithContextSupport,
ConfigDict,
ConfigKeys,
ConfigValidationError,
Expand All @@ -37,20 +39,30 @@ def activate_script() -> str:
return ""


@pydantic.dataclasses.dataclass(
config={
"extra": "forbid",
"validate_assignment": True,
"use_enum_values": True,
"validate_default": True,
}
)
class QueueOptions:
class QueueOptions(
BaseModelWithContextSupport,
validate_assignment=True,
extra="forbid",
use_enum_values=True,
validate_default=True,
):
name: QueueSystem
max_running: pydantic.NonNegativeInt = 0
submit_sleep: pydantic.NonNegativeFloat = 0.0
project_code: str | None = None
activate_script: str = field(default_factory=activate_script)
activate_script: str | None = Field(default=None, validate_default=True)

@field_validator("activate_script", mode="before")
@classmethod
def inject_site_config_script(cls, v: str, info: ValidationInfo) -> str:
# User value gets highest priority
if isinstance(v, str):
return v
# Use from plugin system if user has not specified
plugin_script = None
if info.context:
plugin_script = info.context.get(info.field_name)
Copy link
Contributor

@yngve-sk yngve-sk Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a test that runs this line? EDIT: I see running test_detached triggers it, but do any of the ERT tests run it? Is this mainly meant for things we run via Everest?

return plugin_script or activate_script() # Return default value

@staticmethod
def create_queue_options(
Expand Down Expand Up @@ -78,12 +90,12 @@ def create_queue_options(
return None

def add_global_queue_options(self, config_dict: ConfigDict) -> None:
for generic_option in fields(QueueOptions):
for name, generic_option in QueueOptions.model_fields.items():
if (
generic_value := config_dict.get(generic_option.name.upper(), None) # type: ignore
) and self.__dict__[generic_option.name] == generic_option.default:
generic_value := config_dict.get(name.upper(), None) # type: ignore
) and self.__dict__[name] == generic_option.default:
try:
setattr(self, generic_option.name, generic_value)
setattr(self, name, generic_value)
except pydantic.ValidationError as exception:
for error in exception.errors():
_throw_error_or_warning(
Expand All @@ -98,7 +110,6 @@ def driver_options(self) -> dict[str, Any]:
"""Translate the queue options to the key-value API provided by each driver"""


@pydantic.dataclasses.dataclass
class LocalQueueOptions(QueueOptions):
name: Literal[QueueSystem.LOCAL] = QueueSystem.LOCAL

Expand All @@ -107,7 +118,6 @@ def driver_options(self) -> dict[str, Any]:
return {}


@pydantic.dataclasses.dataclass
class LsfQueueOptions(QueueOptions):
name: Literal[QueueSystem.LSF] = QueueSystem.LSF
bhist_cmd: NonEmptyString | None = None
Expand All @@ -120,17 +130,13 @@ class LsfQueueOptions(QueueOptions):

@property
def driver_options(self) -> dict[str, Any]:
driver_dict = asdict(self)
driver_dict.pop("name")
driver_dict = self.model_dump(exclude={"name", "submit_sleep", "max_running"})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice ❤️

driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host")
driver_dict["queue_name"] = driver_dict.pop("lsf_queue")
driver_dict["resource_requirement"] = driver_dict.pop("lsf_resource")
driver_dict.pop("submit_sleep")
driver_dict.pop("max_running")
return driver_dict


@pydantic.dataclasses.dataclass
class TorqueQueueOptions(QueueOptions):
name: Literal[QueueSystem.TORQUE] = QueueSystem.TORQUE
qsub_cmd: NonEmptyString | None = None
Expand All @@ -143,15 +149,19 @@ class TorqueQueueOptions(QueueOptions):

@property
def driver_options(self) -> dict[str, Any]:
driver_dict = asdict(self)
driver_dict.pop("name")
driver_dict = self.model_dump(
exclude={
"name",
"max_running",
"submit_sleep",
"qstat_options",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Were these: qstat_options and queue_query_timeout excluded before? If not, should maybe be excluded in a separate commit for clarity

"queue_query_timeout",
}
)
driver_dict["queue_name"] = driver_dict.pop("queue")
driver_dict.pop("max_running")
driver_dict.pop("submit_sleep")
return driver_dict


@pydantic.dataclasses.dataclass
class SlurmQueueOptions(QueueOptions):
name: Literal[QueueSystem.SLURM] = QueueSystem.SLURM
sbatch: NonEmptyString = "sbatch"
Expand All @@ -167,8 +177,7 @@ class SlurmQueueOptions(QueueOptions):

@property
def driver_options(self) -> dict[str, Any]:
driver_dict = asdict(self)
driver_dict.pop("name")
driver_dict = self.model_dump(exclude={"name", "max_running", "submit_sleep"})
driver_dict["sbatch_cmd"] = driver_dict.pop("sbatch")
driver_dict["scancel_cmd"] = driver_dict.pop("scancel")
driver_dict["scontrol_cmd"] = driver_dict.pop("scontrol")
Expand All @@ -177,8 +186,6 @@ def driver_options(self) -> dict[str, Any]:
driver_dict["exclude_hosts"] = driver_dict.pop("exclude_host")
driver_dict["include_hosts"] = driver_dict.pop("include_host")
driver_dict["queue_name"] = driver_dict.pop("partition")
driver_dict.pop("max_running")
driver_dict.pop("submit_sleep")
return driver_dict


Expand All @@ -203,12 +210,12 @@ def validate(self, mem_str_format: str | None) -> bool:
)

valid_options: dict[str, list[str]] = {
QueueSystem.LOCAL: [field.name.upper() for field in fields(LocalQueueOptions)],
QueueSystem.LSF: [field.name.upper() for field in fields(LsfQueueOptions)],
QueueSystem.SLURM: [field.name.upper() for field in fields(SlurmQueueOptions)],
QueueSystem.TORQUE: [field.name.upper() for field in fields(TorqueQueueOptions)],
QueueSystem.LOCAL: [field.upper() for field in LocalQueueOptions.model_fields],
QueueSystem.LSF: [field.upper() for field in LsfQueueOptions.model_fields],
QueueSystem.SLURM: [field.upper() for field in SlurmQueueOptions.model_fields],
QueueSystem.TORQUE: [field.upper() for field in TorqueQueueOptions.model_fields],
QueueSystemWithGeneric.GENERIC: [
field.name.upper() for field in fields(QueueOptions)
field.upper() for field in QueueOptions.model_fields
],
}

Expand Down
7 changes: 3 additions & 4 deletions src/ert/gui/simulation/experiment_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import platform
from collections import OrderedDict
from dataclasses import fields
from datetime import datetime
from pathlib import Path
from queue import SimpleQueue
Expand Down Expand Up @@ -379,10 +378,10 @@ def populate_clipboard_debug_info(self) -> None:
if isinstance(self.get_current_experiment_type(), SingleTestRun):
queue_opts = LocalQueueOptions(max_running=1)

for field in fields(queue_opts):
field_value = getattr(queue_opts, field.name)
for name in queue_opts.model_fields:
field_value = getattr(queue_opts, name)
if field_value is not None:
kv[field.name.replace("_", " ").capitalize()] = str(field_value)
kv[name.replace("_", " ").capitalize()] = str(field_value)

kv["**Status**"] = ""
kv["Trace ID"] = get_trace_id()
Expand Down
22 changes: 19 additions & 3 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
field_validator,
model_validator,
)
from pydantic_core.core_schema import ValidationInfo
from ruamel.yaml import YAML, YAMLError

from ert.config import ErtConfig
from ert.config import ErtConfig, QueueConfig
from ert.config.parsing import BaseModelWithContextSupport
from ert.config.parsing.base_model_context import init_context
from ert.plugins import ErtPluginManager
from everest.config.control_variable_config import ControlVariableGuessListConfig
from everest.config.install_template_config import InstallTemplateConfig
from everest.config.server_config import ServerConfig
Expand Down Expand Up @@ -134,7 +138,7 @@ class HasName(Protocol):
name: str


class EverestConfig(BaseModelWithPropertySupport): # type: ignore
class EverestConfig(BaseModelWithPropertySupport, BaseModelWithContextSupport): # type: ignore
controls: Annotated[list[ControlConfig], AfterValidator(unique_items)] = Field(
description="""Defines a list of controls.
Controls should have unique names each control defines
Expand Down Expand Up @@ -267,7 +271,7 @@ def validate_queue_system(self) -> Self: # pylint: disable=E0213
return self

@model_validator(mode="after")
def validate_forward_model_job_name_installed(self) -> Self: # pylint: disable=E0213
def validate_forward_model_job_name_installed(self, info: ValidationInfo) -> Self: # pylint: disable=E0213
install_jobs = self.install_jobs
forward_model_jobs = self.forward_model
if install_jobs is None:
Expand Down Expand Up @@ -807,6 +811,18 @@ def load_file(config_file: str) -> "EverestConfig":

raise exp from error

@classmethod
def with_plugins(cls, config_dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit confused as to why mypy is not complaining about the missing type for the argument config_dict shouldn't there be a : dict[str, Any] or something similar?

context = {}
activate_script = ErtPluginManager().activate_script()
site_config = ErtConfig.read_site_config()
if site_config:
context["queue_system"] = QueueConfig.from_dict(site_config).queue_options
if activate_script:
context["activate_script"] = ErtPluginManager().activate_script()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just have

context["activate_script"] = activate_script

in place of

context["activate_script"] = ErtPluginManager().activate_script()

Copy link
Contributor

@yngve-sk yngve-sk Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a need to re-invoke ErtPluginManager().activate_script() after having stored it in a variable?

with init_context(context):
return cls(**config_dict)

@staticmethod
def load_file_with_argparser(
config_path, parser: ArgumentParser
Expand Down
12 changes: 1 addition & 11 deletions src/everest/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator

from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager

from ..strings import (
CERTIFICATE_DIR,
Expand Down Expand Up @@ -38,15 +37,6 @@ class ServerConfig(BaseModel): # type: ignore
extra="forbid",
)

@field_validator("queue_system", mode="before")
@classmethod
def default_local_queue(cls, v):
if v is None:
return v
elif "activate_script" not in v and ErtPluginManager().activate_script():
v["activate_script"] = ErtPluginManager().activate_script()
return v

@model_validator(mode="before")
@classmethod
def check_old_config(cls, data: Any) -> Any:
Expand Down
19 changes: 9 additions & 10 deletions src/everest/config/simulator_config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from typing import Any

from pydantic import (
BaseModel,
Field,
NonNegativeInt,
PositiveInt,
field_validator,
model_validator,
)
from pydantic_core.core_schema import ValidationInfo

from ert.config.parsing import BaseModelWithContextSupport
from ert.config.queue_config import (
LocalQueueOptions,
LsfQueueOptions,
SlurmQueueOptions,
TorqueQueueOptions,
)
from ert.plugins import ErtPluginManager

simulator_example = {"queue_system": {"name": "local", "max_running": 3}}

Expand All @@ -29,11 +29,11 @@ def check_removed_config(queue_system):
}
if isinstance(queue_system, str) and queue_system in queue_systems:
raise ValueError(
f"Queue system configuration has changed, valid options for {queue_system} are: {list(queue_systems[queue_system].__dataclass_fields__.keys())}"
f"Queue system configuration has changed, valid options for {queue_system} are: {list(queue_systems[queue_system].model_fields.keys())}"
)


class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore
class SimulatorConfig(BaseModelWithContextSupport, extra="forbid"): # type: ignore
cores_per_node: PositiveInt | None = Field(
default=None,
description="""defines the number of CPUs when running
Expand Down Expand Up @@ -94,13 +94,12 @@ class SimulatorConfig(BaseModel, extra="forbid"): # type: ignore

@field_validator("queue_system", mode="before")
@classmethod
def default_local_queue(cls, v):
def default_local_queue(cls, v, info: ValidationInfo):
if v is None:
return LocalQueueOptions(max_running=8)
if "activate_script" not in v and (
active_script := ErtPluginManager().activate_script()
):
v["activate_script"] = active_script
options = None
if info.context:
options = info.context.get(info.field_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't info.field_name always be queue_system here? Similar question for other instances of this

return options or LocalQueueOptions(max_running=8)
return v

@model_validator(mode="before")
Expand Down
Loading