Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 7be0092

Browse files
committedFeb 18, 2025
More transf II.
1 parent 7df38fb commit 7be0092

File tree

1 file changed

+55
-6
lines changed

1 file changed

+55
-6
lines changed
 

‎src/ert/config/scalar_parameter.py

+55-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
import math
12
import warnings
23
from dataclasses import dataclass
34
from pathlib import Path
45
from typing import Literal
56

67
import numpy as np
7-
import xarray as xr
88
from scipy.stats import norm
99

1010
from .parameter_config import ParameterConfig
@@ -13,8 +13,8 @@
1313

1414

1515
@dataclass
16-
class TransUniformSettings:
17-
name: Literal["uniform"] = "uniform"
16+
class TransUnifSettings:
17+
name: Literal["unif"] = "unif"
1818
min: float = 0.0
1919
max: float = 1.0
2020

@@ -23,6 +23,20 @@ def trans(self, x: float) -> float:
2323
return y * (self.max - self.min) + self.min
2424

2525

26+
@dataclass
27+
class TransDUnifSettings:
28+
name: Literal["dunif"] = "dunif"
29+
steps: int = 1000
30+
min: float = 0.0
31+
max: float = 1.0
32+
33+
def trans(self, x: float) -> float:
34+
y = float(norm.cdf(x))
35+
return (math.floor(y * self.steps) / (self.steps - 1)) * (
36+
self.max - self.min
37+
) + self.min
38+
39+
2640
@dataclass
2741
class TransNormalSettings:
2842
name: Literal["normal"] = "normal"
@@ -33,6 +47,19 @@ def trans(self, x: float) -> float:
3347
return x * self.std + self.mean
3448

3549

50+
@dataclass
51+
class TransTruncNormalSettings:
52+
name: Literal["trunc_normal"] = "trunc_normal"
53+
mean: float = 0.0
54+
std: float = 1.0
55+
min: float = 0.0
56+
max: float = 1.0
57+
58+
def trans(self, x: float) -> float:
59+
y = x * self.std + self.mean
60+
return max(min(y, self.max), self.min) # clamp
61+
62+
3663
@dataclass
3764
class TransRawSettings:
3865
name: Literal["raw"] = "raw"
@@ -50,6 +77,25 @@ def trans(self, _: float) -> float:
5077
return self.value
5178

5279

80+
@dataclass
81+
class TransTriangularSettings:
82+
name: Literal["triangular"] = "triangular"
83+
min: float = 0.0
84+
mode: float = 0.5
85+
max: float = 1.0
86+
87+
def trans(self, x: float) -> float:
88+
inv_norm_left = (self.max - self.min) * (self.mode - self.min)
89+
inv_norm_right = (self.max - self.min) * (self.max - self.mode)
90+
ymode = (self.mode - self.min) / (self.max - self.min)
91+
y = norm.cdf(x)
92+
93+
if y < ymode:
94+
return self.min + math.sqrt(y * inv_norm_left)
95+
else:
96+
return self.max - math.sqrt((1 - y) * inv_norm_right)
97+
98+
5399
@dataclass
54100
class TransErrfSettings:
55101
name: Literal["errf"] = "errf"
@@ -105,16 +151,19 @@ class PolarsData:
105151

106152
@dataclass
107153
class ScalarParameter(ParameterConfig):
108-
name: str
154+
# name: str
109155
group: str
110156
distribution: (
111-
TransUniformSettings
157+
TransUnifSettings
158+
| TransDUnifSettings
112159
| TransRawSettings
113160
| TransConstSettings
114161
| TransNormalSettings
162+
| TransTruncNormalSettings
115163
| TransErrfSettings
116164
| TransDerrfSettings
165+
| TransTriangularSettings
117166
)
118167
active: bool
119168
input_source: Literal["design_matrix", "sampled"]
120-
dataset_file: PolarsData | xr.Dataset
169+
dataset_file: PolarsData

0 commit comments

Comments
 (0)
Please sign in to comment.