Skip to content

Commit 19829a4

Browse files
committed
impl Distribution<f32> for Gamma distribution
1 parent 8e3ed11 commit 19829a4

File tree

1 file changed

+60
-45
lines changed

1 file changed

+60
-45
lines changed

rand_distr/src/gamma.rs

Lines changed: 60 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ use self::ChiSquaredRepr::*;
1414

1515
use rand::Rng;
1616
use crate::normal::StandardNormal;
17-
use crate::{Distribution, Exp, Open01};
17+
use crate::{Distribution, Exp1, Exp, Open01};
18+
use num_traits::Float;
1819

1920
/// The Gamma distribution `Gamma(shape, scale)` distribution.
2021
///
@@ -47,8 +48,8 @@ use crate::{Distribution, Exp, Open01};
4748
/// (September 2000), 363-372.
4849
/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
4950
#[derive(Clone, Copy, Debug)]
50-
pub struct Gamma {
51-
repr: GammaRepr,
51+
pub struct Gamma<N> {
52+
repr: GammaRepr<N>,
5253
}
5354

5455
/// Error type returned from `Gamma::new`.
@@ -63,10 +64,10 @@ pub enum Error {
6364
}
6465

6566
#[derive(Clone, Copy, Debug)]
66-
enum GammaRepr {
67-
Large(GammaLargeShape),
68-
One(Exp<f64>),
69-
Small(GammaSmallShape)
67+
enum GammaRepr<N> {
68+
Large(GammaLargeShape<N>),
69+
One(Exp<N>),
70+
Small(GammaSmallShape<N>)
7071
}
7172

7273
// These two helpers could be made public, but saving the
@@ -84,37 +85,39 @@ enum GammaRepr {
8485
/// See `Gamma` for sampling from a Gamma distribution with general
8586
/// shape parameters.
8687
#[derive(Clone, Copy, Debug)]
87-
struct GammaSmallShape {
88-
inv_shape: f64,
89-
large_shape: GammaLargeShape
88+
struct GammaSmallShape<N> {
89+
inv_shape: N,
90+
large_shape: GammaLargeShape<N>
9091
}
9192

9293
/// Gamma distribution where the shape parameter is larger than 1.
9394
///
9495
/// See `Gamma` for sampling from a Gamma distribution with general
9596
/// shape parameters.
9697
#[derive(Clone, Copy, Debug)]
97-
struct GammaLargeShape {
98-
scale: f64,
99-
c: f64,
100-
d: f64
98+
struct GammaLargeShape<N> {
99+
scale: N,
100+
c: N,
101+
d: N
101102
}
102103

103-
impl Gamma {
104+
impl<N: Float> Gamma<N>
105+
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
106+
{
104107
/// Construct an object representing the `Gamma(shape, scale)`
105108
/// distribution.
106109
#[inline]
107-
pub fn new(shape: f64, scale: f64) -> Result<Gamma, Error> {
108-
if !(shape > 0.0) {
110+
pub fn new(shape: N, scale: N) -> Result<Gamma<N>, Error> {
111+
if !(shape > N::zero()) {
109112
return Err(Error::ShapeTooSmall);
110113
}
111-
if !(scale > 0.0) {
114+
if !(scale > N::zero()) {
112115
return Err(Error::ScaleTooSmall);
113116
}
114117

115-
let repr = if shape == 1.0 {
116-
One(Exp::new(1.0 / scale).map_err(|_| Error::ScaleTooLarge)?)
117-
} else if shape < 1.0 {
118+
let repr = if shape == N::one() {
119+
One(Exp::new(N::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
120+
} else if shape < N::one() {
118121
Small(GammaSmallShape::new_raw(shape, scale))
119122
} else {
120123
Large(GammaLargeShape::new_raw(shape, scale))
@@ -123,57 +126,69 @@ impl Gamma {
123126
}
124127
}
125128

126-
impl GammaSmallShape {
127-
fn new_raw(shape: f64, scale: f64) -> GammaSmallShape {
129+
impl<N: Float> GammaSmallShape<N>
130+
where StandardNormal: Distribution<N>, Open01: Distribution<N>
131+
{
132+
fn new_raw(shape: N, scale: N) -> GammaSmallShape<N> {
128133
GammaSmallShape {
129-
inv_shape: 1. / shape,
130-
large_shape: GammaLargeShape::new_raw(shape + 1.0, scale)
134+
inv_shape: N::one() / shape,
135+
large_shape: GammaLargeShape::new_raw(shape + N::one(), scale)
131136
}
132137
}
133138
}
134139

135-
impl GammaLargeShape {
136-
fn new_raw(shape: f64, scale: f64) -> GammaLargeShape {
137-
let d = shape - 1. / 3.;
140+
impl<N: Float> GammaLargeShape<N>
141+
where StandardNormal: Distribution<N>, Open01: Distribution<N>
142+
{
143+
fn new_raw(shape: N, scale: N) -> GammaLargeShape<N> {
144+
let d = shape - N::from(1. / 3.).unwrap();
138145
GammaLargeShape {
139146
scale,
140-
c: 1. / (9. * d).sqrt(),
147+
c: N::one() / (N::from(9.).unwrap() * d).sqrt(),
141148
d
142149
}
143150
}
144151
}
145152

146-
impl Distribution<f64> for Gamma {
147-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
153+
impl<N: Float> Distribution<N> for Gamma<N>
154+
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
155+
{
156+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
148157
match self.repr {
149158
Small(ref g) => g.sample(rng),
150159
One(ref g) => g.sample(rng),
151160
Large(ref g) => g.sample(rng),
152161
}
153162
}
154163
}
155-
impl Distribution<f64> for GammaSmallShape {
156-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
157-
let u: f64 = rng.sample(Open01);
164+
impl<N: Float> Distribution<N> for GammaSmallShape<N>
165+
where StandardNormal: Distribution<N>, Open01: Distribution<N>
166+
{
167+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
168+
let u: N = rng.sample(Open01);
158169

159170
self.large_shape.sample(rng) * u.powf(self.inv_shape)
160171
}
161172
}
162-
impl Distribution<f64> for GammaLargeShape {
163-
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
173+
impl<N: Float> Distribution<N> for GammaLargeShape<N>
174+
where StandardNormal: Distribution<N>, Open01: Distribution<N>
175+
{
176+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
177+
// Marsaglia & Tsang method, 2000
164178
loop {
165-
let x: f64 = rng.sample(StandardNormal);
166-
let v_cbrt = 1.0 + self.c * x;
167-
if v_cbrt <= 0.0 { // a^3 <= 0 iff a <= 0
179+
let x: N = rng.sample(StandardNormal);
180+
let v_cbrt = N::one() + self.c * x;
181+
if v_cbrt <= N::zero() { // a^3 <= 0 iff a <= 0
168182
continue
169183
}
170184

171185
let v = v_cbrt * v_cbrt * v_cbrt;
172-
let u: f64 = rng.sample(Open01);
186+
let u: N = rng.sample(Open01);
173187

174188
let x_sqr = x * x;
175-
if u < 1.0 - 0.0331 * x_sqr * x_sqr ||
176-
u.ln() < 0.5 * x_sqr + self.d * (1.0 - v + v.ln()) {
189+
if u < N::one() - N::from(0.0331).unwrap() * x_sqr * x_sqr ||
190+
u.ln() < N::from(0.5).unwrap() * x_sqr + self.d * (N::one() - v + v.ln())
191+
{
177192
return self.d * v * self.scale
178193
}
179194
}
@@ -215,7 +230,7 @@ enum ChiSquaredRepr {
215230
// e.g. when alpha = 1/2 as it would be for this case, so special-
216231
// casing and using the definition of N(0,1)^2 is faster.
217232
DoFExactlyOne,
218-
DoFAnythingElse(Gamma),
233+
DoFAnythingElse(Gamma<f64>),
219234
}
220235

221236
impl ChiSquared {
@@ -350,8 +365,8 @@ impl Distribution<f64> for StudentT {
350365
/// ```
351366
#[derive(Clone, Copy, Debug)]
352367
pub struct Beta {
353-
gamma_a: Gamma,
354-
gamma_b: Gamma,
368+
gamma_a: Gamma<f64>,
369+
gamma_b: Gamma<f64>,
355370
}
356371

357372
/// Error type returned from `Beta::new`.

0 commit comments

Comments
 (0)