Skip to content

Commit 7df38fb

Browse files
committed
Further transf
1 parent 98e6eb4 commit 7df38fb

File tree

1 file changed

+55
-4
lines changed

1 file changed

+55
-4
lines changed

src/ert/config/scalar_parameter.py

+55-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import warnings
12
from dataclasses import dataclass
23
from pathlib import Path
34
from typing import Literal
45

6+
import numpy as np
57
import xarray as xr
68
from scipy.stats import norm
79

@@ -11,7 +13,7 @@
1113

1214

1315
@dataclass
14-
class TransUniformfSettings:
16+
class TransUniformSettings:
1517
name: Literal["uniform"] = "uniform"
1618
min: float = 0.0
1719
max: float = 1.0
@@ -22,7 +24,7 @@ def trans(self, x: float) -> float:
2224

2325

2426
@dataclass
25-
class TransNormalfSettings:
27+
class TransNormalSettings:
2628
name: Literal["normal"] = "normal"
2729
mean: float = 0.0
2830
std: float = 1.0
@@ -48,6 +50,53 @@ def trans(self, _: float) -> float:
4850
return self.value
4951

5052

53+
@dataclass
54+
class TransErrfSettings:
55+
name: Literal["errf"] = "errf"
56+
min: float = 0.0
57+
max: float = 1.0
58+
skew: float = 0.0
59+
width: float = 1.0
60+
61+
def trans(self, x: float) -> float:
62+
y = norm(loc=0, scale=self.width).cdf(x + self.skew)
63+
if np.isnan(y):
64+
raise ValueError(
65+
"Output is nan, likely from triplet (x, skewness, width) "
66+
"leading to low/high-probability in normal CDF."
67+
)
68+
return self.min + y * (self.max - self.min)
69+
70+
71+
@dataclass
72+
class TransDerrfSettings:
73+
name: Literal["derrf"] = "derrf"
74+
steps: int = 1000
75+
min: float = 0.0
76+
max: float = 1.0
77+
skew: float = 0.0
78+
width: float = 1.0
79+
80+
def trans(self, x: float) -> float:
81+
q_values = np.linspace(start=0, stop=1, num=self.steps)
82+
q_checks = np.linspace(start=0, stop=1, num=self.steps + 1)[1:]
83+
y = TransErrfSettings(min=0, max=1, skew=self.skew, width=self.width).trans(x)
84+
bin_index = np.digitize(y, q_checks, right=True)
85+
y_binned = q_values[bin_index]
86+
result = self.min + y_binned * (self.max - self.min)
87+
if result > self.max or result < self.min:
88+
warnings.warn(
89+
"trans_derff suffered from catastrophic loss of precision, clamping to min,max",
90+
stacklevel=1,
91+
)
92+
return np.clip(result, self.min, self.max)
93+
if np.isnan(result):
94+
raise ValueError(
95+
"trans_derrf returns nan, check that input arguments are reasonable"
96+
)
97+
return float(result)
98+
99+
51100
@dataclass
52101
class PolarsData:
53102
name: Literal["polars"]
@@ -59,10 +108,12 @@ class ScalarParameter(ParameterConfig):
59108
name: str
60109
group: str
61110
distribution: (
62-
TransUniformfSettings
111+
TransUniformSettings
63112
| TransRawSettings
64113
| TransConstSettings
65-
| TransNormalfSettings
114+
| TransNormalSettings
115+
| TransErrfSettings
116+
| TransDerrfSettings
66117
)
67118
active: bool
68119
input_source: Literal["design_matrix", "sampled"]

0 commit comments

Comments
 (0)