1
+ import math
1
2
import warnings
2
3
from dataclasses import dataclass
3
4
from pathlib import Path
4
5
from typing import Literal
5
6
6
7
import numpy as np
7
- import xarray as xr
8
8
from scipy .stats import norm
9
9
10
10
from .parameter_config import ParameterConfig
13
13
14
14
15
15
@dataclass
16
- class TransUniformSettings :
17
- name : Literal ["uniform " ] = "uniform "
16
+ class TransUnifSettings :
17
+ name : Literal ["unif " ] = "unif "
18
18
min : float = 0.0
19
19
max : float = 1.0
20
20
@@ -23,6 +23,20 @@ def trans(self, x: float) -> float:
23
23
return y * (self .max - self .min ) + self .min
24
24
25
25
26
+ @dataclass
27
+ class TransDUnifSettings :
28
+ name : Literal ["dunif" ] = "dunif"
29
+ steps : int = 1000
30
+ min : float = 0.0
31
+ max : float = 1.0
32
+
33
+ def trans (self , x : float ) -> float :
34
+ y = float (norm .cdf (x ))
35
+ return (math .floor (y * self .steps ) / (self .steps - 1 )) * (
36
+ self .max - self .min
37
+ ) + self .min
38
+
39
+
26
40
@dataclass
27
41
class TransNormalSettings :
28
42
name : Literal ["normal" ] = "normal"
@@ -33,6 +47,19 @@ def trans(self, x: float) -> float:
33
47
return x * self .std + self .mean
34
48
35
49
50
+ @dataclass
51
+ class TransTruncNormalSettings :
52
+ name : Literal ["trunc_normal" ] = "trunc_normal"
53
+ mean : float = 0.0
54
+ std : float = 1.0
55
+ min : float = 0.0
56
+ max : float = 1.0
57
+
58
+ def trans (self , x : float ) -> float :
59
+ y = x * self .std + self .mean
60
+ return max (min (y , self .max ), self .min ) # clamp
61
+
62
+
36
63
@dataclass
37
64
class TransRawSettings :
38
65
name : Literal ["raw" ] = "raw"
@@ -50,6 +77,25 @@ def trans(self, _: float) -> float:
50
77
return self .value
51
78
52
79
80
+ @dataclass
81
+ class TransTriangularSettings :
82
+ name : Literal ["triangular" ] = "triangular"
83
+ min : float = 0.0
84
+ mode : float = 0.5
85
+ max : float = 1.0
86
+
87
+ def trans (self , x : float ) -> float :
88
+ inv_norm_left = (self .max - self .min ) * (self .mode - self .min )
89
+ inv_norm_right = (self .max - self .min ) * (self .max - self .mode )
90
+ ymode = (self .mode - self .min ) / (self .max - self .min )
91
+ y = norm .cdf (x )
92
+
93
+ if y < ymode :
94
+ return self .min + math .sqrt (y * inv_norm_left )
95
+ else :
96
+ return self .max - math .sqrt ((1 - y ) * inv_norm_right )
97
+
98
+
53
99
@dataclass
54
100
class TransErrfSettings :
55
101
name : Literal ["errf" ] = "errf"
@@ -105,16 +151,19 @@ class PolarsData:
105
151
106
152
@dataclass
107
153
class ScalarParameter (ParameterConfig ):
108
- name : str
154
+ # name: str
109
155
group : str
110
156
distribution : (
111
- TransUniformSettings
157
+ TransUnifSettings
158
+ | TransDUnifSettings
112
159
| TransRawSettings
113
160
| TransConstSettings
114
161
| TransNormalSettings
162
+ | TransTruncNormalSettings
115
163
| TransErrfSettings
116
164
| TransDerrfSettings
165
+ | TransTriangularSettings
117
166
)
118
167
active : bool
119
168
input_source : Literal ["design_matrix" , "sampled" ]
120
- dataset_file : PolarsData | xr . Dataset
169
+ dataset_file : PolarsData
0 commit comments