|
2 | 2 | from datetime import datetime, timedelta
|
3 | 3 | from textwrap import dedent
|
4 | 4 |
|
| 5 | +import numpy as np |
5 | 6 | import pytest
|
| 7 | +from pandas import ExcelWriter |
6 | 8 | from pandas.core.frame import DataFrame
|
7 | 9 | from resdata.summary import Summary
|
8 | 10 |
|
9 | 11 | from ert.config import ErtConfig
|
10 |
| -from ert.enkf_main import sample_prior |
| 12 | +from ert.config.design_matrix import DESIGN_MATRIX_GROUP, DesignMatrix |
| 13 | +from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble |
11 | 14 | from ert.libres_facade import LibresFacade
|
12 | 15 | from ert.storage import open_storage
|
13 | 16 |
|
@@ -241,3 +244,53 @@ def test_load_gen_kw_not_sorted(storage, tmpdir, snapshot):
|
241 | 244 |
|
242 | 245 | data = ensemble.load_all_gen_kw_data()
|
243 | 246 | snapshot.assert_match(data.round(12).to_csv(), "gen_kw_unsorted")
|
| 247 | + |
| 248 | + |
| 249 | +@pytest.mark.parametrize( |
| 250 | + "reals, expect_error", |
| 251 | + [ |
| 252 | + pytest.param( |
| 253 | + list(range(10)), |
| 254 | + False, |
| 255 | + id="correct_active_realizations", |
| 256 | + ), |
| 257 | + pytest.param([10, 11], True, id="incorrect_active_realizations"), |
| 258 | + ], |
| 259 | +) |
| 260 | +def test_save_parameters_to_storage_from_design_dataframe( |
| 261 | + tmp_path, reals, expect_error |
| 262 | +): |
| 263 | + design_path = tmp_path / "design_matrix.xlsx" |
| 264 | + ensemble_size = 10 |
| 265 | + a_values = np.random.default_rng().uniform(-5, 5, 10) |
| 266 | + b_values = np.random.default_rng().uniform(-5, 5, 10) |
| 267 | + c_values = np.random.default_rng().uniform(-5, 5, 10) |
| 268 | + design_matrix_df = DataFrame({"a": a_values, "b": b_values, "c": c_values}) |
| 269 | + with ExcelWriter(design_path) as xl_write: |
| 270 | + design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01") |
| 271 | + DataFrame().to_excel( |
| 272 | + xl_write, index=False, sheet_name="DefaultValues", header=False |
| 273 | + ) |
| 274 | + design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues") |
| 275 | + design_matrix.read_design_matrix() |
| 276 | + with open_storage(tmp_path / "storage", mode="w") as storage: |
| 277 | + experiment_id = storage.create_experiment( |
| 278 | + parameters=[design_matrix.parameter_configuration[DESIGN_MATRIX_GROUP]] |
| 279 | + ) |
| 280 | + ensemble = storage.create_ensemble( |
| 281 | + experiment_id, name="default", ensemble_size=ensemble_size |
| 282 | + ) |
| 283 | + if expect_error: |
| 284 | + with pytest.raises(KeyError): |
| 285 | + save_design_matrix_to_ensemble( |
| 286 | + design_matrix.design_matrix_df, ensemble, reals |
| 287 | + ) |
| 288 | + else: |
| 289 | + save_design_matrix_to_ensemble( |
| 290 | + design_matrix.design_matrix_df, ensemble, reals |
| 291 | + ) |
| 292 | + params = ensemble.load_parameters(DESIGN_MATRIX_GROUP)["values"] |
| 293 | + all(params.names.values == ["a", "b", "c"]) |
| 294 | + np.testing.assert_array_almost_equal(params[:, 0], a_values) |
| 295 | + np.testing.assert_array_almost_equal(params[:, 1], b_values) |
| 296 | + np.testing.assert_array_almost_equal(params[:, 2], c_values) |
0 commit comments