@@ -109,28 +109,27 @@ def _all_parameters(
109
109
110
110
111
111
def _save_param_ensemble_array_to_disk (
112
- ensemble : Ensemble ,
112
+ source_ensemble : Ensemble ,
113
+ target_ensemble : Ensemble ,
113
114
param_ensemble_array : npt .NDArray [np .float64 ],
114
115
param_group : str ,
115
116
iens_active_index : npt .NDArray [np .int_ ],
116
117
) -> None :
117
- config_node = ensemble .experiment .parameter_configuration [param_group ]
118
- for i , realization in enumerate (iens_active_index ):
119
- config_node .save_parameters (
120
- ensemble , param_group , realization , param_ensemble_array [:, i ]
121
- )
122
-
123
-
124
- def _load_param_ensemble_array_scalar (
125
- ensemble : Ensemble ,
126
- scalar_params : list [str ],
127
- iens_active_index : npt .NDArray [np .int_ ],
128
- ) -> npt .NDArray [np .float64 ]:
129
- config_node = ensemble .experiment .parameter_configuration ["SCALAR_PARAMETERS" ]
130
- assert isinstance (config_node , ScalarParameters )
131
- return config_node .load_parameters_scalar (
132
- ensemble , scalar_params , iens_active_index
133
- )
118
+ config_node = target_ensemble .experiment .parameter_configuration [param_group ]
119
+ if isinstance (config_node , ScalarParameters ):
120
+ for i , realization in enumerate (iens_active_index ):
121
+ config_node .save_updated_parameters_and_copy_remaining (
122
+ source_ensemble ,
123
+ target_ensemble ,
124
+ param_group ,
125
+ realization ,
126
+ param_ensemble_array [:, i ],
127
+ )
128
+ else :
129
+ for i , realization in enumerate (iens_active_index ):
130
+ config_node .save_parameters (
131
+ target_ensemble , param_group , realization , param_ensemble_array [:, i ]
132
+ )
134
133
135
134
136
135
def _load_param_ensemble_array (
@@ -140,7 +139,7 @@ def _load_param_ensemble_array(
140
139
) -> npt .NDArray [np .float64 ]:
141
140
config_node = ensemble .experiment .parameter_configuration [param_group ]
142
141
if isinstance (config_node , ScalarParameters ):
143
- raise ValueError ( "Config node is scalar!" )
142
+ return config_node . load_parameters_to_update ( ensemble , iens_active_index )
144
143
dataset = config_node .load_parameters (ensemble , param_group , iens_active_index )
145
144
assert isinstance (dataset , np .ndarray ), "dataset is not an numpy array"
146
145
return dataset
@@ -556,14 +555,9 @@ def correlation_callback(
556
555
cross_correlations_accumulator .append (cross_correlations_of_batch )
557
556
558
557
for param_group in parameters :
559
- if isinstance (param_group , list ):
560
- param_ensemble_array = _load_param_ensemble_array_scalar (
561
- source_ensemble , iens_active_index
562
- )
563
- else :
564
- param_ensemble_array = _load_param_ensemble_array (
565
- source_ensemble , param_group , iens_active_index
566
- )
558
+ param_ensemble_array = _load_param_ensemble_array (
559
+ source_ensemble , param_group , iens_active_index
560
+ )
567
561
if module .localization :
568
562
config_node = source_ensemble .experiment .parameter_configuration [
569
563
param_group
@@ -629,7 +623,11 @@ def correlation_callback(
629
623
start = time .time ()
630
624
631
625
_save_param_ensemble_array_to_disk (
632
- target_ensemble , param_ensemble_array , param_group , iens_active_index
626
+ source_ensemble ,
627
+ target_ensemble ,
628
+ param_ensemble_array ,
629
+ param_group ,
630
+ iens_active_index ,
633
631
)
634
632
logger .info (
635
633
f"Storing data for { param_group } completed in { (time .time () - start ) / 60 } minutes"
0 commit comments