|
5 | 5 | import time
|
6 | 6 | from collections.abc import Callable, Iterable, Sequence
|
7 | 7 | from fnmatch import fnmatch
|
8 |
| -from typing import ( |
9 |
| - TYPE_CHECKING, |
10 |
| - Generic, |
11 |
| - Self, |
12 |
| - TypeVar, |
13 |
| -) |
| 8 | +from typing import TYPE_CHECKING, Generic, Self, TypeVar |
14 | 9 |
|
15 | 10 | import iterative_ensemble_smoother as ies
|
16 | 11 | import numpy as np
|
|
19 | 14 | import scipy
|
20 | 15 | from iterative_ensemble_smoother.experimental import AdaptiveESMDA
|
21 | 16 |
|
22 |
| -from ert.config import ESSettings, GenKwConfig, ObservationGroups, UpdateSettings |
| 17 | +from ert.config import ( |
| 18 | + ESSettings, |
| 19 | + GenKwConfig, |
| 20 | + ObservationGroups, |
| 21 | + ScalarParameters, |
| 22 | + UpdateSettings, |
| 23 | +) |
23 | 24 |
|
24 | 25 | from . import misfit_preprocessor
|
25 | 26 | from .event import (
|
|
31 | 32 | AnalysisTimeEvent,
|
32 | 33 | DataSection,
|
33 | 34 | )
|
34 |
| -from .snapshots import ( |
35 |
| - ObservationAndResponseSnapshot, |
36 |
| - SmootherSnapshot, |
37 |
| -) |
| 35 | +from .snapshots import ObservationAndResponseSnapshot, SmootherSnapshot |
38 | 36 |
|
39 | 37 | if TYPE_CHECKING:
|
40 | 38 | import numpy.typing as npt
|
@@ -123,13 +121,29 @@ def _save_param_ensemble_array_to_disk(
|
123 | 121 | )
|
124 | 122 |
|
125 | 123 |
|
| 124 | +def _load_param_ensemble_array_scalar( |
| 125 | + ensemble: Ensemble, |
| 126 | + scalar_params: list[str], |
| 127 | + iens_active_index: npt.NDArray[np.int_], |
| 128 | +) -> npt.NDArray[np.float64]: |
| 129 | + config_node = ensemble.experiment.parameter_configuration["SCALAR_PARAMETERS"] |
| 130 | + assert isinstance(config_node, ScalarParameters) |
| 131 | + return config_node.load_parameters_scalar( |
| 132 | + ensemble, scalar_params, iens_active_index |
| 133 | + ) |
| 134 | + |
| 135 | + |
126 | 136 | def _load_param_ensemble_array(
|
127 | 137 | ensemble: Ensemble,
|
128 | 138 | param_group: str,
|
129 | 139 | iens_active_index: npt.NDArray[np.int_],
|
130 | 140 | ) -> npt.NDArray[np.float64]:
|
131 | 141 | config_node = ensemble.experiment.parameter_configuration[param_group]
|
132 |
| - return config_node.load_parameters(ensemble, param_group, iens_active_index) |
| 142 | + if isinstance(config_node, ScalarParameters): |
| 143 | + raise ValueError("Config node is scalar!") |
| 144 | + dataset = config_node.load_parameters(ensemble, param_group, iens_active_index) |
| 145 | + assert isinstance(dataset, np.ndarray), "dataset is not an numpy array" |
| 146 | + return dataset |
133 | 147 |
|
134 | 148 |
|
135 | 149 | def _expand_wildcards(
|
@@ -542,9 +556,14 @@ def correlation_callback(
|
542 | 556 | cross_correlations_accumulator.append(cross_correlations_of_batch)
|
543 | 557 |
|
544 | 558 | for param_group in parameters:
|
545 |
| - param_ensemble_array = _load_param_ensemble_array( |
546 |
| - source_ensemble, param_group, iens_active_index |
547 |
| - ) |
| 559 | + if isinstance(param_group, list): |
| 560 | + param_ensemble_array = _load_param_ensemble_array_scalar( |
| 561 | + source_ensemble, iens_active_index |
| 562 | + ) |
| 563 | + else: |
| 564 | + param_ensemble_array = _load_param_ensemble_array( |
| 565 | + source_ensemble, param_group, iens_active_index |
| 566 | + ) |
548 | 567 | if module.localization:
|
549 | 568 | config_node = source_ensemble.experiment.parameter_configuration[
|
550 | 569 | param_group
|
|
0 commit comments