Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate workflow job args against hook runtime fixtures #10341

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pydantic import field_validator
from pydantic.dataclasses import dataclass, rebuild_dataclass

from ert.plugins import ErtPluginManager
from ert.plugins import ErtPluginManager, fixtures_per_hook
from ert.plugins.workflow_config import ErtScriptWorkflow
from ert.substitutions import Substitutions

Expand Down Expand Up @@ -395,10 +395,48 @@ def workflows_from_dict(
)
continue

wf = workflows[hook_name]
available_fixtures = fixtures_per_hook[mode]
for job, _ in wf.cmd_list:
if job.ert_script is None:
continue

ert_script_instance = job.ert_script()
requested_fixtures = ert_script_instance.requested_fixtures

# Look for requested fixtures that are not available for the given
# mode
missing_fixtures = requested_fixtures - available_fixtures

if missing_fixtures:
ok_modes = [
m
for m in HookRuntime
if not requested_fixtures - fixtures_per_hook[m]
]

message_start = (
f"Workflow job {job.name} .run function expected "
f"fixtures: {missing_fixtures}, which are not available "
f"in the fixtures for the runtime {mode}: {available_fixtures}. "
)
message_end = (
f"It would work in these runtimes: {', '.join(map(str, ok_modes))}"
if len(ok_modes) > 0
else "This fixture is not available in any of the runtimes."
)

errors.append(
ErrorInfo(message=message_start + message_end).set_context(
hook_name
)
)

hooked_workflows[mode].append(workflows[hook_name])

if errors:
raise ConfigValidationError.from_collected(errors)

return workflow_jobs, workflows, hooked_workflows


Expand Down
4 changes: 2 additions & 2 deletions src/ert/gui/tools/plugins/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
from typing import TYPE_CHECKING, Any

from ert.plugins import ErtPlugin, WorkflowFixtures
from ert.plugins import ErtPlugin, HookedWorkflowFixtures

if TYPE_CHECKING:
from PyQt6.QtWidgets import QWidget
Expand Down Expand Up @@ -34,7 +34,7 @@ def getName(self) -> str:
def getDescription(self) -> str:
return self.__description

def getArguments(self, fixtures: WorkflowFixtures) -> list[Any]:
def getArguments(self, fixtures: HookedWorkflowFixtures) -> list[Any]:
"""
Returns a list of arguments. Either from GUI or from arbitrary code.
If the user for example cancels in the GUI a CancelPluginException is raised.
Expand Down
4 changes: 2 additions & 2 deletions src/ert/gui/tools/plugins/plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from _ert.threading import ErtThread
from ert.config import ErtConfig
from ert.plugins import CancelPluginException, WorkflowFixtures
from ert.plugins import CancelPluginException, HookedWorkflowFixtures
from ert.runpaths import Runpaths
from ert.workflow_runner import WorkflowJobRunner

Expand Down Expand Up @@ -84,7 +84,7 @@ def run(self) -> None:
print("Plugin cancelled before execution!")

def __runWorkflowJob(
self, arguments: list[Any] | None, fixtures: WorkflowFixtures
self, arguments: list[Any] | None, fixtures: HookedWorkflowFixtures
) -> None:
self.__result = self._runner.run(arguments, fixtures=fixtures)

Expand Down
10 changes: 9 additions & 1 deletion src/ert/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
)
from .plugin_response import PluginMetadata, PluginResponse
from .workflow_config import ErtScriptWorkflow, WorkflowConfigs
from .workflow_fixtures import WorkflowFixtures
from .workflow_fixtures import (
HookedWorkflowFixtures,
WorkflowFixtures,
all_hooked_workflow_fixtures,
fixtures_per_hook,
)

P = ParamSpec("P")

Expand Down Expand Up @@ -59,8 +64,11 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> Any:
"ErtScript",
"ErtScriptWorkflow",
"ExternalErtScript",
"HookedWorkflowFixtures",
"JobDoc",
"WorkflowConfigs",
"WorkflowFixtures",
"all_hooked_workflow_fixtures",
"fixtures_per_hook",
"plugin",
]
34 changes: 27 additions & 7 deletions src/ert/plugins/ert_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
from abc import abstractmethod
from collections.abc import Callable
from types import MappingProxyType, ModuleType
from typing import TYPE_CHECKING, Any, TypeAlias
from typing import TYPE_CHECKING, Any, TypeAlias, cast

from typing_extensions import deprecated

from .workflow_fixtures import WorkflowFixtures
from .workflow_fixtures import (
HookedWorkflowFixtures,
WorkflowFixtures,
all_hooked_workflow_fixtures,
)

if TYPE_CHECKING:
from ert.config import ErtConfig
Expand Down Expand Up @@ -107,13 +111,28 @@ def cleanup(self) -> None:
Override to perform cleanup after a run.
"""

@property
def requested_fixtures(self) -> set[str]:
return {
k
for k, v in inspect.signature(self.run).parameters.items()
if k in all_hooked_workflow_fixtures
}

def initializeAndRun(
self,
argument_types: list[type[Any]],
argument_values: list[str],
fixtures: WorkflowFixtures | None = None,
fixtures: WorkflowFixtures | HookedWorkflowFixtures | None = None,
) -> Any:
fixtures = {} if fixtures is None else fixtures
fixtures_without_hook = {**(fixtures or {})}

if "hook" in fixtures_without_hook:
fixtures_without_hook.pop("hook")

complete_fixtures: WorkflowFixtures = cast(
WorkflowFixtures, fixtures_without_hook
)
arguments = []
for index, arg_value in enumerate(argument_values):
arg_type = argument_types[index] if index < len(argument_types) else str
Expand All @@ -122,14 +141,15 @@ def initializeAndRun(
arguments.append(arg_type(arg_value))
else:
arguments.append(None)
fixtures["workflow_args"] = arguments

complete_fixtures["workflow_args"] = arguments
try:
func_args = inspect.signature(self.run).parameters
# If the user has specified *args, we skip injecting fixtures, and just
# If the user has specified *args, we skip injecting complete_fixtures, and just
# pass the user configured arguments
if not any(p.kind == p.VAR_POSITIONAL for p in func_args.values()):
try:
arguments = self.insert_fixtures(func_args, fixtures)
arguments = self.insert_fixtures(func_args, complete_fixtures)
except ValueError as e:
# This is here for backwards compatibility, the user does not have *argv
# but positional arguments. Can not be mixed with using fixtures.
Expand Down
99 changes: 90 additions & 9 deletions src/ert/plugins/workflow_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,104 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
import functools
import typing
from typing import TYPE_CHECKING, NotRequired, TypedDict

from PyQt6.QtWidgets import QWidget
from typing_extensions import TypedDict

from ert.config.parsing.hook_runtime import HookRuntime

if TYPE_CHECKING:
from ert.config import ESSettings, UpdateSettings
from ert.runpaths import Runpaths
from ert.storage import Ensemble, Storage


class WorkflowFixtures(TypedDict, total=False):
ensemble: Ensemble | None
class PreExperimentFixtures(TypedDict):
random_seed: int
hook: HookRuntime = HookRuntime.PRE_EXPERIMENT

Check failure on line 19 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Right hand side values are not supported in TypedDict


class PostExperimentFixtures(TypedDict):
hook: HookRuntime = HookRuntime.POST_EXPERIMENT

Check failure on line 23 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Right hand side values are not supported in TypedDict
random_seed: int
storage: Storage
random_seed: int | None
ensemble: Ensemble


class PreSimulationFixtures(TypedDict):
hook: HookRuntime = HookRuntime.PRE_SIMULATION

Check failure on line 30 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Right hand side values are not supported in TypedDict
random_seed: int
reports_dir: str
run_paths: Runpaths
storage: Storage
ensemble: Ensemble


class PostSimulationFixtures(PreSimulationFixtures):
hook: HookRuntime = HookRuntime.POST_SIMULATION

Check failure on line 39 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Overwriting TypedDict field "hook" while extending

Check failure on line 39 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Right hand side values are not supported in TypedDict


class PreFirstUpdateFixtures(TypedDict):
hook: HookRuntime = HookRuntime.PRE_FIRST_UPDATE

Check failure on line 43 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Right hand side values are not supported in TypedDict
random_seed: int
reports_dir: str
observation_settings: UpdateSettings
es_settings: ESSettings
run_paths: Runpaths
workflow_args: list[Any]
parent: QWidget | None
storage: Storage
ensemble: Ensemble
es_settings: ESSettings
observation_settings: UpdateSettings


class PreUpdateFixtures(PreFirstUpdateFixtures):
hook: HookRuntime = HookRuntime.PRE_UPDATE

Check failure on line 54 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Overwriting TypedDict field "hook" while extending

Check failure on line 54 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Right hand side values are not supported in TypedDict


class PostUpdateFixtures(PreFirstUpdateFixtures):
hook: HookRuntime = HookRuntime.POST_UPDATE

Check failure on line 58 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Overwriting TypedDict field "hook" while extending

Check failure on line 58 in src/ert/plugins/workflow_fixtures.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Right hand side values are not supported in TypedDict


class WorkflowFixtures(TypedDict, total=False):
workflow_args: NotRequired[list[typing.Any]]
parent: NotRequired[QWidget]
random_seed: NotRequired[int]
reports_dir: NotRequired[str]
run_paths: NotRequired[Runpaths]
storage: NotRequired[Storage]
ensemble: NotRequired[Ensemble]
es_settings: NotRequired[ESSettings]
observation_settings: NotRequired[UpdateSettings]


# Union Type Definition
HookedWorkflowFixtures = (
PreExperimentFixtures
| PostExperimentFixtures
| PreSimulationFixtures
| PostSimulationFixtures
| PreFirstUpdateFixtures
| PreUpdateFixtures
| PostUpdateFixtures
)


def __all_workflow_fixtures() -> set[str]:
fixtures_per_runtime = (
__get_available_fixtures_for_hook(hook) for hook in HookRuntime
)

return functools.reduce(lambda a, b: a | b, fixtures_per_runtime)


def __get_available_fixtures_for_hook(hook: HookRuntime) -> set[str]:
cls = next(
cls for cls in typing.get_args(HookedWorkflowFixtures) if cls.hook == hook
)

return set(cls.__annotations__.keys()) - {"hook"}


all_hooked_workflow_fixtures = __all_workflow_fixtures()
fixtures_per_hook = {
hook: __get_available_fixtures_for_hook(hook) for hook in HookRuntime
}
6 changes: 3 additions & 3 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
REALIZATION_STATE_FINISHED,
)
from ert.mode_definitions import MODULE_MODE
from ert.plugins import WorkflowFixtures
from ert.plugins import HookedWorkflowFixtures
from ert.runpaths import Runpaths
from ert.storage import Ensemble, Storage
from ert.substitutions import Substitutions
Expand Down Expand Up @@ -737,7 +737,7 @@ def validate_successful_realizations_count(self) -> None:
def run_workflows(
self,
runtime: HookRuntime,
fixtures: WorkflowFixtures,
fixtures: HookedWorkflowFixtures,
) -> None:
for workflow in self._hooked_workflows[runtime]:
WorkflowRunner(workflow=workflow, fixtures=fixtures).run_blocking()
Expand Down Expand Up @@ -883,7 +883,7 @@ def update(
)
)

workflow_fixtures: WorkflowFixtures = {
workflow_fixtures: HookedWorkflowFixtures = {
"storage": self._storage,
"ensemble": prior,
"observation_settings": self._update_settings,
Expand Down
6 changes: 3 additions & 3 deletions src/ert/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Self

from ert.config import Workflow, WorkflowJob
from ert.plugins import ErtScript, ExternalErtScript, WorkflowFixtures
from ert.plugins import ErtScript, ExternalErtScript, HookedWorkflowFixtures


class WorkflowJobRunner:
Expand All @@ -19,7 +19,7 @@ def __init__(self, workflow_job: WorkflowJob):
def run(
self,
arguments: list[Any] | None = None,
fixtures: WorkflowFixtures | None = None,
fixtures: HookedWorkflowFixtures | None = None,
) -> Any:
if arguments is None:
arguments = []
Expand Down Expand Up @@ -101,7 +101,7 @@ class WorkflowRunner:
def __init__(
self,
workflow: Workflow,
fixtures: WorkflowFixtures,
fixtures: HookedWorkflowFixtures,
) -> None:
self.__workflow = workflow
self.fixtures = fixtures
Expand Down
Loading