Skip to content

Commit 59bc4d2

Browse files
committed
impl Distribution<N> for Normal and LogNormal
Adds dependency on num-traits
1 parent 09dd014 commit 59bc4d2

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

rand_distr/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ appveyor = { repository = "rust-random/rand" }
2020

2121
[dependencies]
2222
rand = { path = "..", version = ">=0.5, <=0.7" }
23+
num-traits = "0.2"

rand_distr/src/normal.rs

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
use rand::Rng;
1313
use crate::{ziggurat_tables, Distribution, Open01};
1414
use crate::utils::ziggurat;
15+
use num_traits::Float;
1516

1617
/// Samples floating-point numbers according to the normal distribution
1718
/// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to
@@ -102,9 +103,9 @@ impl Distribution<f64> for StandardNormal {
102103
///
103104
/// [`StandardNormal`]: crate::StandardNormal
104105
#[derive(Clone, Copy, Debug)]
105-
pub struct Normal {
106-
mean: f64,
107-
std_dev: f64,
106+
pub struct Normal<N> {
107+
mean: N,
108+
std_dev: N,
108109
}
109110

110111
/// Error type returned from `Normal::new` and `LogNormal::new`.
@@ -114,12 +115,14 @@ pub enum Error {
114115
StdDevTooSmall,
115116
}
116117

117-
impl Normal {
118+
impl<N: Float> Normal<N>
119+
where StandardNormal: Distribution<N>
120+
{
118121
/// Construct a new `Normal` distribution with the given mean and
119122
/// standard deviation.
120123
#[inline]
121-
pub fn new(mean: f64, std_dev: f64) -> Result<Normal, Error> {
122-
if !(std_dev >= 0.0) {
124+
pub fn new(mean: N, std_dev: N) -> Result<Normal<N>, Error> {
125+
if !(std_dev >= N::zero()) {
123126
return Err(Error::StdDevTooSmall);
124127
}
125128
Ok(Normal {
@@ -128,9 +131,12 @@ impl Normal {
128131
})
129132
}
130133
}
131-
impl Distribution<f64> for Normal {
132-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
133-
let n: f64 = rng.sample(StandardNormal);
134+
135+
impl<N: Float> Distribution<N> for Normal<N>
136+
where StandardNormal: Distribution<N>
137+
{
138+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
139+
let n: N = rng.sample(StandardNormal);
134140
self.mean + self.std_dev * n
135141
}
136142
}
@@ -152,23 +158,28 @@ impl Distribution<f64> for Normal {
152158
/// println!("{} is from an ln N(2, 9) distribution", v)
153159
/// ```
154160
#[derive(Clone, Copy, Debug)]
155-
pub struct LogNormal {
156-
norm: Normal
161+
pub struct LogNormal<N> {
162+
norm: Normal<N>
157163
}
158164

159-
impl LogNormal {
165+
impl<N: Float> LogNormal<N>
166+
where StandardNormal: Distribution<N>
167+
{
160168
/// Construct a new `LogNormal` distribution with the given mean
161169
/// and standard deviation of the logarithm of the distribution.
162170
#[inline]
163-
pub fn new(mean: f64, std_dev: f64) -> Result<LogNormal, Error> {
164-
if !(std_dev >= 0.0) {
171+
pub fn new(mean: N, std_dev: N) -> Result<LogNormal<N>, Error> {
172+
if !(std_dev >= N::zero()) {
165173
return Err(Error::StdDevTooSmall);
166174
}
167175
Ok(LogNormal { norm: Normal::new(mean, std_dev).unwrap() })
168176
}
169177
}
170-
impl Distribution<f64> for LogNormal {
171-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
178+
179+
impl<N: Float> Distribution<N> for LogNormal<N>
180+
where StandardNormal: Distribution<N>
181+
{
182+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
172183
self.norm.sample(rng).exp()
173184
}
174185
}

0 commit comments

Comments
 (0)