12
12
use rand:: Rng ;
13
13
use crate :: { ziggurat_tables, Distribution , Open01 } ;
14
14
use crate :: utils:: ziggurat;
15
+ use num_traits:: Float ;
15
16
16
17
/// Samples floating-point numbers according to the normal distribution
17
18
/// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to
@@ -102,9 +103,9 @@ impl Distribution<f64> for StandardNormal {
102
103
///
103
104
/// [`StandardNormal`]: crate::StandardNormal
104
105
#[ derive( Clone , Copy , Debug ) ]
105
- pub struct Normal {
106
- mean : f64 ,
107
- std_dev : f64 ,
106
+ pub struct Normal < N > {
107
+ mean : N ,
108
+ std_dev : N ,
108
109
}
109
110
110
111
/// Error type returned from `Normal::new` and `LogNormal::new`.
@@ -114,12 +115,14 @@ pub enum Error {
114
115
StdDevTooSmall ,
115
116
}
116
117
117
- impl Normal {
118
+ impl < N : Float > Normal < N >
119
+ where StandardNormal : Distribution < N >
120
+ {
118
121
/// Construct a new `Normal` distribution with the given mean and
119
122
/// standard deviation.
120
123
#[ inline]
121
- pub fn new ( mean : f64 , std_dev : f64 ) -> Result < Normal , Error > {
122
- if !( std_dev >= 0.0 ) {
124
+ pub fn new ( mean : N , std_dev : N ) -> Result < Normal < N > , Error > {
125
+ if !( std_dev >= N :: zero ( ) ) {
123
126
return Err ( Error :: StdDevTooSmall ) ;
124
127
}
125
128
Ok ( Normal {
@@ -128,9 +131,12 @@ impl Normal {
128
131
} )
129
132
}
130
133
}
131
- impl Distribution < f64 > for Normal {
132
- fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> f64 {
133
- let n: f64 = rng. sample ( StandardNormal ) ;
134
+
135
+ impl < N : Float > Distribution < N > for Normal < N >
136
+ where StandardNormal : Distribution < N >
137
+ {
138
+ fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> N {
139
+ let n: N = rng. sample ( StandardNormal ) ;
134
140
self . mean + self . std_dev * n
135
141
}
136
142
}
@@ -152,23 +158,28 @@ impl Distribution<f64> for Normal {
152
158
/// println!("{} is from an ln N(2, 9) distribution", v)
153
159
/// ```
154
160
#[ derive( Clone , Copy , Debug ) ]
155
- pub struct LogNormal {
156
- norm : Normal
161
+ pub struct LogNormal < N > {
162
+ norm : Normal < N >
157
163
}
158
164
159
- impl LogNormal {
165
+ impl < N : Float > LogNormal < N >
166
+ where StandardNormal : Distribution < N >
167
+ {
160
168
/// Construct a new `LogNormal` distribution with the given mean
161
169
/// and standard deviation of the logarithm of the distribution.
162
170
#[ inline]
163
- pub fn new ( mean : f64 , std_dev : f64 ) -> Result < LogNormal , Error > {
164
- if !( std_dev >= 0.0 ) {
171
+ pub fn new ( mean : N , std_dev : N ) -> Result < LogNormal < N > , Error > {
172
+ if !( std_dev >= N :: zero ( ) ) {
165
173
return Err ( Error :: StdDevTooSmall ) ;
166
174
}
167
175
Ok ( LogNormal { norm : Normal :: new ( mean, std_dev) . unwrap ( ) } )
168
176
}
169
177
}
170
- impl Distribution < f64 > for LogNormal {
171
- fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> f64 {
178
+
179
+ impl < N : Float > Distribution < N > for LogNormal < N >
180
+ where StandardNormal : Distribution < N >
181
+ {
182
+ fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> N {
172
183
self . norm . sample ( rng) . exp ( )
173
184
}
174
185
}
0 commit comments