Skip to content

Commit 3059a0a

Browse files
authored
Add check for alphas (#21)
1 parent e3f72a2 commit 3059a0a

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

dageo/data_assimilation.py

+5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1414
# License for the specific language governing permissions and limitations under
1515
# the License.
16+
import warnings
1617

1718
import numpy as np
1819

@@ -91,6 +92,10 @@ def esmda(model_prior, forward, data_obs, sigma, alphas=4, data_prior=None,
9192
alphas = np.zeros(alphas) + alphas
9293
else:
9394
alphas = np.asarray(alphas)
95+
if abs(np.sum(1/alphas)-1) > 0.01:
96+
warnings.warn(
97+
f"The sum of 1/alphas should be 1; provided: {np.sum(1/alphas)}."
98+
)
9499

95100
# Copy prior as start of post (output)
96101
model_post = model_prior.copy()

tests/test_data_assimilation.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
import numpy as np
23
from numpy.testing import assert_allclose
34

@@ -89,6 +90,20 @@ def cbp(x):
8990
assert_allclose(43420257, np.max(p), atol=1)
9091
assert_allclose(0.0, x[np.argmax(p)], atol=1e-8)
9192

93+
# Warning
94+
with pytest.warns(UserWarning, match='provided: 1.25'):
95+
lm_post3 = dageo.esmda(
96+
model_prior=mprior,
97+
forward=lin_fwd,
98+
data_obs=l_dobs,
99+
sigma=obs_std,
100+
alphas=[4, 4, 4, 4, 4],
101+
localization_matrix=np.array([[0.5]]),
102+
callback_post=cbp,
103+
return_post_data=False,
104+
random=3333,
105+
)
106+
92107

93108
def test_all_dir():
94109
assert set(data_assimilation.__all__) == set(dir(data_assimilation))

0 commit comments

Comments
 (0)