Skip to content

Commit 0f1b1ff

Browse files
authored
Merge pull request #793 from dhardy/distr
Generic distributions: use custom trait
2 parents b664e64 + b330c21 commit 0f1b1ff

14 files changed

+288
-166
lines changed

rand_distr/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,3 @@ appveyor = { repository = "rust-random/rand" }
2020

2121
[dependencies]
2222
rand = { path = "..", version = ">=0.5, <=0.7" }
23-
num-traits = "0.2"

rand_distr/src/cauchy.rs

+17-13
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
//! The Cauchy distribution.
1111
1212
use rand::Rng;
13-
use crate::Distribution;
14-
use std::f64::consts::PI;
13+
use crate::{Distribution, Standard};
14+
use crate::utils::Float;
1515

1616
/// The Cauchy distribution `Cauchy(median, scale)`.
1717
///
@@ -28,9 +28,9 @@ use std::f64::consts::PI;
2828
/// println!("{} is from a Cauchy(2, 5) distribution", v);
2929
/// ```
3030
#[derive(Clone, Copy, Debug)]
31-
pub struct Cauchy {
32-
median: f64,
33-
scale: f64
31+
pub struct Cauchy<N> {
32+
median: N,
33+
scale: N,
3434
}
3535

3636
/// Error type returned from `Cauchy::new`.
@@ -40,11 +40,13 @@ pub enum Error {
4040
ScaleTooSmall,
4141
}
4242

43-
impl Cauchy {
43+
impl<N: Float> Cauchy<N>
44+
where Standard: Distribution<N>
45+
{
4446
/// Construct a new `Cauchy` with the given shape parameters
4547
/// `median` the peak location and `scale` the scale factor.
46-
pub fn new(median: f64, scale: f64) -> Result<Cauchy, Error> {
47-
if !(scale > 0.0) {
48+
pub fn new(median: N, scale: N) -> Result<Cauchy<N>, Error> {
49+
if !(scale > N::from(0.0)) {
4850
return Err(Error::ScaleTooSmall);
4951
}
5052
Ok(Cauchy {
@@ -54,13 +56,15 @@ impl Cauchy {
5456
}
5557
}
5658

57-
impl Distribution<f64> for Cauchy {
58-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
59+
impl<N: Float> Distribution<N> for Cauchy<N>
60+
where Standard: Distribution<N>
61+
{
62+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
5963
// sample from [0, 1)
60-
let x = rng.gen::<f64>();
64+
let x = Standard.sample(rng);
6165
// get standard cauchy random number
6266
// note that π/2 is not exactly representable, even if x=0.5 the result is finite
63-
let comp_dev = (PI * x).tan();
67+
let comp_dev = (N::pi() * x).tan();
6468
// shift and scale according to parameters
6569
let result = self.median + self.scale * comp_dev;
6670
result
@@ -99,7 +103,7 @@ mod test {
99103
fn test_cauchy_mean() {
100104
let cauchy = Cauchy::new(10.0, 5.0).unwrap();
101105
let mut rng = crate::test::rng(123);
102-
let mut sum = 0.0;
106+
let mut sum = 0.0f64;
103107
for _ in 0..1000 {
104108
sum += cauchy.sample(&mut rng);
105109
}

rand_distr/src/dirichlet.rs

+19-15
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
//! The dirichlet distribution.
1111
1212
use rand::Rng;
13-
use crate::Distribution;
14-
use crate::gamma::Gamma;
13+
use crate::{Distribution, Gamma, StandardNormal, Exp1, Open01};
14+
use crate::utils::Float;
1515

1616
/// The dirichelet distribution `Dirichlet(alpha)`.
1717
///
@@ -30,9 +30,9 @@ use crate::gamma::Gamma;
3030
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
3131
/// ```
3232
#[derive(Clone, Debug)]
33-
pub struct Dirichlet {
33+
pub struct Dirichlet<N> {
3434
/// Concentration parameters (alpha)
35-
alpha: Vec<f64>,
35+
alpha: Vec<N>,
3636
}
3737

3838
/// Error type returned from `Dirchlet::new`.
@@ -46,18 +46,20 @@ pub enum Error {
4646
SizeTooSmall,
4747
}
4848

49-
impl Dirichlet {
49+
impl<N: Float> Dirichlet<N>
50+
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
51+
{
5052
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
5153
///
5254
/// Requires `alpha.len() >= 2`.
5355
#[inline]
54-
pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Result<Dirichlet, Error> {
56+
pub fn new<V: Into<Vec<N>>>(alpha: V) -> Result<Dirichlet<N>, Error> {
5557
let a = alpha.into();
5658
if a.len() < 2 {
5759
return Err(Error::AlphaTooShort);
5860
}
5961
for i in 0..a.len() {
60-
if !(a[i] > 0.0) {
62+
if !(a[i] > N::from(0.0)) {
6163
return Err(Error::AlphaTooSmall);
6264
}
6365
}
@@ -69,8 +71,8 @@ impl Dirichlet {
6971
///
7072
/// Requires `size >= 2`.
7173
#[inline]
72-
pub fn new_with_size(alpha: f64, size: usize) -> Result<Dirichlet, Error> {
73-
if !(alpha > 0.0) {
74+
pub fn new_with_size(alpha: N, size: usize) -> Result<Dirichlet<N>, Error> {
75+
if !(alpha > N::from(0.0)) {
7476
return Err(Error::AlphaTooSmall);
7577
}
7678
if size < 2 {
@@ -82,18 +84,20 @@ impl Dirichlet {
8284
}
8385
}
8486

85-
impl Distribution<Vec<f64>> for Dirichlet {
86-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
87+
impl<N: Float> Distribution<Vec<N>> for Dirichlet<N>
88+
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
89+
{
90+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<N> {
8791
let n = self.alpha.len();
88-
let mut samples = vec![0.0f64; n];
89-
let mut sum = 0.0f64;
92+
let mut samples = vec![N::from(0.0); n];
93+
let mut sum = N::from(0.0);
9094

9195
for i in 0..n {
92-
let g = Gamma::new(self.alpha[i], 1.0).unwrap();
96+
let g = Gamma::new(self.alpha[i], N::from(1.0)).unwrap();
9397
samples[i] = g.sample(rng);
9498
sum += samples[i];
9599
}
96-
let invacc = 1.0 / sum;
100+
let invacc = N::from(1.0) / sum;
97101
for i in 0..n {
98102
samples[i] *= invacc;
99103
}

rand_distr/src/exponential.rs

+3-4
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
1212
use rand::Rng;
1313
use crate::{ziggurat_tables, Distribution};
14-
use crate::utils::ziggurat;
15-
use num_traits::Float;
14+
use crate::utils::{ziggurat, Float};
1615

1716
/// Samples floating-point numbers according to the exponential distribution,
1817
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
@@ -105,10 +104,10 @@ where Exp1: Distribution<N>
105104
/// `lambda`.
106105
#[inline]
107106
pub fn new(lambda: N) -> Result<Exp<N>, Error> {
108-
if !(lambda > N::zero()) {
107+
if !(lambda > N::from(0.0)) {
109108
return Err(Error::LambdaTooSmall);
110109
}
111-
Ok(Exp { lambda_inverse: N::one() / lambda })
110+
Ok(Exp { lambda_inverse: N::from(1.0) / lambda })
112111
}
113112
}
114113

0 commit comments

Comments
 (0)