forked from leejet/stable-diffusion.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrng_mt19937.hpp
More file actions
147 lines (130 loc) · 5.41 KB
/
rng_mt19937.hpp
File metadata and controls
147 lines (130 loc) · 5.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#ifndef __RNG_MT19937_HPP__
#define __RNG_MT19937_HPP__
#include <array>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>
#include "rng.hpp"
// RNG imitating torch randn on the CPU.
// Port from pytorch, original license: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/LICENSE
// Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/core/TransformationHelper.h, for uniform_real
// Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/native/cpu/DistributionTemplates.h, for normal_kernel/normal_fill/normal_fill_16
// Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/core/MT19937RNGEngine.h, for mt19937_engine
// Ref: https://github.com/pytorch/pytorch/blob/d01a7b0241ed1c4cded7e7ca097249feb343f072/aten/src/ATen/core/DistributionsHelper.h, for uniform_real_distribution/normal_distribution
// MT19937 (Mersenne Twister) based generator whose randn() output is meant to
// reproduce what PyTorch generates on the CPU for the same seed (see the Ref
// links above). Because the goal is reproducibility, the exact order in which
// random words are drawn and consumed is part of the contract of this class.
class MT19937RNG : public RNG {
    // Mersenne Twister parameters.
    static constexpr int N               = 624;          // number of 32-bit words in the state
    static constexpr int M               = 397;          // offset used by the recurrence
    static constexpr uint32_t MATRIX_A   = 0x9908b0dfU;  // twist constant
    static constexpr uint32_t UMASK      = 0x80000000U;  // most significant bit
    static constexpr uint32_t LMASK      = 0x7fffffffU;  // least significant 31 bits
    static constexpr double PI           = 3.14159265358979323846;

    struct State {
        uint64_t seed_ = 0;
        int left_      = 1;      // words left before the state must be regenerated
        bool seeded_   = false;
        uint32_t next_ = 0;      // index of the next word to hand out
        std::array<uint32_t, N> state_{};
        // Second Box-Muller sample produced by normal_double_value(), stored
        // *unscaled* (standard normal); mean/stddev are applied when it is consumed.
        bool has_next_gauss = false;
        double next_gauss   = 0.0;
    };

    State s;

    // Combines the top bit of u with the low 31 bits of v.
    static uint32_t mix_bits(uint32_t u, uint32_t v) { return (u & UMASK) | (v & LMASK); }

    // One step of the twist transformation.
    static uint32_t twist(uint32_t u, uint32_t v) { return (mix_bits(u, v) >> 1) ^ ((v & 1) ? MATRIX_A : 0); }

    // Regenerates all N state words in place and rewinds the read cursor.
    void next_state() {
        uint32_t* p = s.state_.data();
        s.left_     = N;
        s.next_     = 0;
        // First N - M words: partner word lies M positions ahead.
        for (int j = N - M + 1; --j; p++)
            p[0] = p[M] ^ twist(p[0], p[1]);
        // Next M - 1 words: partner word wraps around (M - N is negative).
        for (int j = M; --j; p++)
            p[0] = p[M - N] ^ twist(p[0], p[1]);
        // Last word: its successor is the first word of the state.
        p[0] = p[M - N] ^ twist(p[0], s.state_[0]);
    }

    // Returns the next tempered 32-bit word, regenerating the state when it is exhausted.
    // manual_seed() sets left_ to 1, so the first call after seeding always regenerates.
    uint32_t rand_uint32() {
        if (--s.left_ == 0)
            next_state();
        uint32_t y = s.state_[s.next_++];
        y ^= (y >> 11);
        y ^= (y << 7) & 0x9d2c5680U;
        y ^= (y << 15) & 0xefc60000U;
        y ^= (y >> 18);
        return y;
    }

    // Builds a 64-bit value from two consecutive 32-bit draws; the first draw is the high half.
    uint64_t rand_uint64() {
        const uint64_t high = static_cast<uint64_t>(rand_uint32());
        const uint64_t low  = static_cast<uint64_t>(rand_uint32());
        return (high << 32) | low;
    }

    // Maps the low `digits` bits of val (digits = mantissa width of T) to a value in
    // [from, to) by scaling them into [0, 1) first.
    template <typename T, typename V>
    static T uniform_real(V val, T from, T to) {
        constexpr auto MASK    = static_cast<V>((static_cast<uint64_t>(1) << std::numeric_limits<T>::digits) - 1);
        constexpr auto DIVISOR = static_cast<T>(1) / (static_cast<uint64_t>(1) << std::numeric_limits<T>::digits);
        T x                    = (val & MASK) * DIVISOR;
        return (x * (to - from) + from);
    }

    // Scalar Box-Muller sampler in double precision. Each fresh computation consumes two
    // 64-bit draws and yields two samples: one is returned, the other is cached and
    // returned by the next call. The cache holds the unscaled sample so that mean/stddev
    // of the *consuming* call are applied (the previous code baked in the parameters of
    // the producing call, which would be wrong if they differed between calls).
    double normal_double_value(double mean, double stddev) {
        if (s.has_next_gauss) {
            s.has_next_gauss = false;
            return s.next_gauss * stddev + mean;
        }
        const double u1    = uniform_real(rand_uint64(), 0., 1.);
        const double u2    = uniform_real(rand_uint64(), 0., 1.);
        const double r     = std::sqrt(-2.0 * std::log1p(-u2));
        const double theta = 2.0 * PI * u1;
        s.next_gauss       = r * std::sin(theta);
        s.has_next_gauss   = true;
        return r * std::cos(theta) * stddev + mean;
    }

    // Converts 16 uniform values stored in data[0..15] into 16 normal values in place.
    // Element j is paired with element j + 8 (Box-Muller); data[j] becomes the cosine
    // branch and data[j + 8] the sine branch.
    void normal_fill_16(float* data, float mean, float stddev) {
        for (int j = 0; j < 8; ++j) {
            const float u1 = 1.0f - data[j];  // 1 - u keeps the log argument away from 0
            const float u2 = data[j + 8];
            const float r  = std::sqrt(-2.0f * std::log(u1));
            // The product is evaluated in double (PI is double) and then narrowed to float.
            const float theta = 2.0f * PI * u2;
            data[j]           = r * std::cos(theta) * stddev + mean;
            data[j + 8]       = r * std::sin(theta) * stddev + mean;
        }
    }

    // Fills data[0..size) with normally distributed floats.
    //  - size >= 16: fill everything with uniforms, transform in chunks of 16, and if size
    //    is not a multiple of 16 redraw and transform the final 16 elements (overlapping
    //    the last full chunk).
    //  - size < 16: draw each element with the scalar double-precision sampler.
    void randn(float* data, int64_t size, float mean = 0.0f, float stddev = 1.0f) {
        if (size >= 16) {
            for (int64_t i = 0; i < size; i++) {
                data[i] = uniform_real(rand_uint32(), 0.f, 1.f);
            }
            for (int64_t i = 0; i < size - 15; i += 16) {
                normal_fill_16(data + i, mean, stddev);
            }
            if (size % 16 != 0) {
                // Recompute the last 16 values.
                float* tail = data + size - 16;
                for (int64_t i = 0; i < 16; i++) {
                    tail[i] = uniform_real(rand_uint32(), 0.f, 1.f);
                }
                normal_fill_16(tail, mean, stddev);
            }
        } else {
            // The small-size path intentionally uses a different sampler; this is kept
            // as is to stay consistent with PyTorch.
            for (int64_t i = 0; i < size; i++) {
                data[i] = static_cast<float>(normal_double_value(mean, stddev));
            }
        }
    }

public:
    // Constructs the generator and seeds it with `seed`.
    MT19937RNG(uint64_t seed = 0) { manual_seed(seed); }

    // Re-seeds the generator. Only the low 32 bits of `seed` initialise the state words;
    // any cached normal sample is discarded.
    void manual_seed(uint64_t seed) override {
        s.seed_     = seed;
        s.seeded_   = true;
        s.state_[0] = static_cast<uint32_t>(seed & 0xffffffffU);
        for (int j = 1; j < N; j++) {
            const uint32_t prev = s.state_[j - 1];
            s.state_[j]         = 1812433253U * (prev ^ (prev >> 30)) + j;
        }
        s.left_          = 1;  // forces next_state() on the first draw
        s.next_          = 0;
        s.has_next_gauss = false;
    }

    // Returns n samples drawn with mean 0 and standard deviation 1.
    std::vector<float> randn(uint32_t n) override {
        std::vector<float> out(n);
        randn(out.data(), static_cast<int64_t>(out.size()));
        return out;
    }
};
#endif // __RNG_MT19937_HPP__