1
+ import warnings
1
2
from dataclasses import dataclass
2
3
from pathlib import Path
3
4
from typing import Literal
4
5
6
+ import numpy as np
5
7
import xarray as xr
6
8
from scipy .stats import norm
7
9
11
13
12
14
13
15
@dataclass
14
- class TransUniformfSettings :
16
+ class TransUniformSettings :
15
17
name : Literal ["uniform" ] = "uniform"
16
18
min : float = 0.0
17
19
max : float = 1.0
@@ -22,7 +24,7 @@ def trans(self, x: float) -> float:
22
24
23
25
24
26
@dataclass
25
- class TransNormalfSettings :
27
+ class TransNormalSettings :
26
28
name : Literal ["normal" ] = "normal"
27
29
mean : float = 0.0
28
30
std : float = 1.0
@@ -48,6 +50,53 @@ def trans(self, _: float) -> float:
48
50
return self .value
49
51
50
52
53
+ @dataclass
54
+ class TransErrfSettings :
55
+ name : Literal ["errf" ] = "errf"
56
+ min : float = 0.0
57
+ max : float = 1.0
58
+ skew : float = 0.0
59
+ width : float = 1.0
60
+
61
+ def trans (self , x : float ) -> float :
62
+ y = norm (loc = 0 , scale = self .width ).cdf (x + self .skew )
63
+ if np .isnan (y ):
64
+ raise ValueError (
65
+ "Output is nan, likely from triplet (x, skewness, width) "
66
+ "leading to low/high-probability in normal CDF."
67
+ )
68
+ return self .min + y * (self .max - self .min )
69
+
70
+
71
+ @dataclass
72
+ class TransDerrfSettings :
73
+ name : Literal ["derrf" ] = "derrf"
74
+ steps : int = 1000
75
+ min : float = 0.0
76
+ max : float = 1.0
77
+ skew : float = 0.0
78
+ width : float = 1.0
79
+
80
+ def trans (self , x : float ) -> float :
81
+ q_values = np .linspace (start = 0 , stop = 1 , num = self .steps )
82
+ q_checks = np .linspace (start = 0 , stop = 1 , num = self .steps + 1 )[1 :]
83
+ y = TransErrfSettings (min = 0 , max = 1 , skew = self .skew , width = self .width ).trans (x )
84
+ bin_index = np .digitize (y , q_checks , right = True )
85
+ y_binned = q_values [bin_index ]
86
+ result = self .min + y_binned * (self .max - self .min )
87
+ if result > self .max or result < self .min :
88
+ warnings .warn (
89
+ "trans_derff suffered from catastrophic loss of precision, clamping to min,max" ,
90
+ stacklevel = 1 ,
91
+ )
92
+ return np .clip (result , self .min , self .max )
93
+ if np .isnan (result ):
94
+ raise ValueError (
95
+ "trans_derrf returns nan, check that input arguments are reasonable"
96
+ )
97
+ return float (result )
98
+
99
+
51
100
@dataclass
52
101
class PolarsData :
53
102
name : Literal ["polars" ]
@@ -59,10 +108,12 @@ class ScalarParameter(ParameterConfig):
59
108
name : str
60
109
group : str
61
110
distribution : (
62
- TransUniformfSettings
111
+ TransUniformSettings
63
112
| TransRawSettings
64
113
| TransConstSettings
65
- | TransNormalfSettings
114
+ | TransNormalSettings
115
+ | TransErrfSettings
116
+ | TransDerrfSettings
66
117
)
67
118
active : bool
68
119
input_source : Literal ["design_matrix" , "sampled" ]
0 commit comments