Skip to content
367 changes: 367 additions & 0 deletions examples/rhat_monitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,367 @@
// clang++ -std=c++20 -O3 -pthread rhat_monitor.cpp -o rhat_monitor
// ./rhat_monitor

#include <atomic>
#include <chrono>
#include <cmath>
#include <fstream>
#include <functional>
#include <iostream>
#include <latch>
#include <numeric>
#include <random>
#include <stop_token>
#include <string>
#include <thread>
#include <vector>

#define VERY_INLINE [[gnu::always_inline]] inline
#ifdef __APPLE__
#include <pthread.h>
VERY_INLINE void interactive_qos() {
pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0); // best
}
VERY_INLINE void initiated_qos() {
pthread_set_qos_class_self_np(QOS_CLASS_USER_INITIATED, 0); // next best
}
#else
VERY_INLINE void interactive_qos() {}
VERY_INLINE void initiated_qos() {}
#endif

double sum(const std::vector<double>& xs) noexcept {
return std::transform_reduce(xs.begin(), xs.end(), 0.0, std::plus<>{},
std::identity());
}

std::size_t sum(const std::vector<std::size_t>& xs) noexcept {
return std::transform_reduce(xs.begin(), xs.end(), 0, std::plus<>{},
std::identity());
}

double mean(const std::vector<double>& xs) noexcept {
return sum(xs) / xs.size();
}

double variance(const std::vector<double>& xs) noexcept {
std::size_t N = xs.size();
if (N < 2) {
return std::numeric_limits<double>::quiet_NaN();
}
double mean_xs = mean(xs);
double sum = std::transform_reduce(xs.begin(), xs.end(), 0.0, std::plus<>{},
[mean_xs](double x) {
double diff = x - mean_xs;
return diff * diff;
});
return sum / (N - 1);
}

struct ChainStats {
double sample_mean;
double sample_var;
std::size_t count;
};

class AtomicChainStats {
public:
AtomicChainStats() noexcept {
ChainStats init{std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN(), 0u};
data_.store(init, std::memory_order_relaxed);
}

void store(const ChainStats& p) noexcept {
data_.store(p, std::memory_order_release);
}

ChainStats load() const noexcept {
return data_.load(std::memory_order_acquire);
}

private:
std::atomic<ChainStats> data_;
};

class ChainRecord {
public:
ChainRecord(std::size_t chain_id, std::size_t D, std::size_t Nmax)
: chain_id_(chain_id), D_(D), Nmax_(Nmax) {
theta_.reserve(Nmax * D);
logp_.reserve(Nmax);
}

std::vector<double>& draws() noexcept { return theta_; }

std::size_t dims() const noexcept { return D_; }

std::size_t num_draws() const noexcept { return logp_.size(); }

inline double logp(std::size_t n) const { return logp_[n]; }

inline double operator()(std::size_t n, std::size_t d) const {
return theta_[n * D_ + d];
}

void write_csv(std::ostream& out) const {
auto dim = dims();
for (std::size_t n = 0; n < num_draws(); ++n) {
out << chain_id_ << ',' << n << ',' << logp_[n];
for (std::size_t d = 0; d < dim; ++d) {
out << ',' << operator()(n, d);
}
out << '\n';
}
}

void append_logp(double logp) { logp_.push_back(logp); }

private:
std::size_t chain_id_;
std::size_t D_;
std::size_t Nmax_;
std::vector<double> theta_;
std::vector<double> logp_;
};

static void write_csv_header(std::ofstream& out, std::size_t dim) {
out << "chain,iteration,log_density";
for (std::size_t d = 0; d < dim; ++d) {
out << ",theta[" << d << "]";
}
out << '\n';
}

static void write_csv(const std::string& path, std::size_t dim,
const std::vector<ChainRecord>& chain_records) {
std::ofstream out(path, std::ios::binary); // binary for Windows consistency
if (!out) {
throw std::runtime_error("could not open file: " + path);
}
write_csv_header(out, dim);
for (const auto& chain_record : chain_records) {
chain_record.write_csv(out);
}
}

class WelfordAccumulator {
public:
WelfordAccumulator() : n_(0), mean_(0.0), M2_(0.0) {}

void observe(double x) {
++n_;
const double delta = x - mean_;
mean_ += delta / static_cast<double>(n_);
const double delta2 = x - mean_;
M2_ += delta * delta2;
}

std::size_t count() const { return n_; }

double mean() const { return mean_; }

double sample_variance() const {
return n_ > 1 ? (M2_ / static_cast<double>(n_ - 1))
: std::numeric_limits<double>::quiet_NaN();
}

ChainStats sample_stats() { return {mean(), sample_variance(), count()}; }

void reset() {
n_ = 0;
mean_ = 0.0;
M2_ = 0.0;
}

private:
std::size_t n_;
double mean_;
double M2_;
};

template <class Sampler>
class ChainWorker {
public:
ChainWorker(std::size_t chain_id, std::size_t draws_per_chain,
Sampler& sampler, AtomicChainStats& acs, std::latch& start_gate)
: chain_id_(chain_id),
draws_per_chain_(draws_per_chain),
sampler_(sampler),
chain_record_(chain_id, sampler.dim(), draws_per_chain),
acs_(acs),
start_gate_(start_gate) {}

void operator()(std::stop_token st) {
interactive_qos();
start_gate_.get().arrive_and_wait();
for (std::size_t iter = 0; iter < draws_per_chain_; ++iter) {
if ((iter + 1) % 64 == 0) {
std::this_thread::yield();
}
double logp = sampler_.get().sample(chain_record_.draws());
chain_record_.append_logp(logp);
logp_stats_.observe(logp);
acs_.store(logp_stats_.sample_stats());
if (st.stop_requested()) {
break;
}
}
}

ChainRecord&& take_chain_record() { return std::move(chain_record_); }

private:
std::size_t chain_id_;
std::size_t draws_per_chain_;
std::reference_wrapper<Sampler> sampler_;
ChainRecord chain_record_;
WelfordAccumulator logp_stats_;
AtomicChainStats& acs_;
std::reference_wrapper<std::latch> start_gate_;
};

void debug_print(double variance_of_means, double mean_of_variances,
double num_draws, double r_hat,
const std::vector<std::size_t>& counts) {
auto M = counts.size();
std::cout << "RHAT: " << r_hat << " NUM_DRAWS: " << num_draws
<< " COUNTS: ";
for (std::size_t m = 0; m < M; ++m) {
if (m > 0) {
std::cout << ", ";
}
std::cout << counts[m];
}
std::cout << std::endl;
}

static void controller_loop(std::vector<AtomicChainStats>& chain_statses,
std::vector<std::jthread>& workers,
double rhat_threshold, std::latch& start_gate,
std::size_t max_draws_per_chain,
std::stop_source& stopper) {
interactive_qos();
start_gate.wait();
const std::size_t M = chain_statses.size();
std::vector<double> chain_means(M, std::numeric_limits<double>::quiet_NaN());
std::vector<double> chain_variances(M,
std::numeric_limits<double>::quiet_NaN());
std::vector<std::size_t> counts(M, 0);
while (true) {
for (std::size_t m = 0; m < M; ++m) {
ChainStats u = chain_statses[m].load();
chain_means[m] = u.sample_mean;
chain_variances[m] = u.sample_var;
counts[m] = u.count;
}
double variance_of_means = variance(chain_means);
double mean_of_variances = mean(chain_variances);
double r_hat = std::sqrt(1 + variance_of_means / mean_of_variances);
std::size_t num_draws = sum(counts);

debug_print(variance_of_means, mean_of_variances, num_draws, r_hat, counts);

if (r_hat <= rhat_threshold || num_draws == M * max_draws_per_chain) {
stopper.request_stop();
break;
}

std::this_thread::sleep_for(std::chrono::microseconds{16});
}
}

// Sampler requires: { double sample(vector<double>& draw); size_t dim(); }
template <typename Sampler>
std::vector<ChainRecord> sample(std::vector<Sampler>& samplers,
double rhat_threshold,
std::size_t max_draws_per_chain) {
std::size_t M = samplers.size();
std::vector<AtomicChainStats> chain_statses(M);
std::latch start_gate(M);
std::vector<ChainWorker<Sampler>> workers;
workers.reserve(M);
for (std::size_t m = 0; m < M; ++m) {
workers.emplace_back(m, max_draws_per_chain, samplers[m], chain_statses[m],
start_gate);
}
{
std::stop_source stopper;
std::vector<std::jthread> threads;
threads.reserve(M);
for (std::size_t m = 0; m < M; ++m) {
threads.emplace_back(std::ref(workers[m]), stopper.get_token());
}
controller_loop(chain_statses, threads, rhat_threshold, start_gate,
max_draws_per_chain, stopper);
}

std::vector<ChainRecord> chain_records;
chain_records.reserve(M);
for (auto& worker : workers) {
chain_records.emplace_back(worker.take_chain_record());
}
return chain_records;
}

// ****************** EXAMPLE USAGE AFTER HERE ************************

class StandardNormalSampler {
public:
explicit StandardNormalSampler(unsigned int seed, std::size_t dim)
: dim_(dim), engine_(seed), normal_dist_(0, 1) {}

double sample(std::vector<double>& draw) noexcept {
double lp = 0;
for (std::size_t i = 0; i < dim_; ++i) {
double x = normal_dist_(engine_);
draw.push_back(x);
lp += -0.5 * x * x; // unnomalized
}
return lp;
}

std::size_t dim() const noexcept { return dim_; }

private:
std::size_t dim_;
std::mt19937_64 engine_;
std::normal_distribution<double> normal_dist_;
};

int main() {
const std::string csv_path = "samples.csv";
const std::size_t D = 4;
std::size_t M = 16;
const std::size_t N = 1000;
double rhat_threshold = 1.001;

std::random_device rd;
std::vector<StandardNormalSampler> samplers;
samplers.reserve(M);
for (std::size_t m = 0; m < M; ++m) {
auto seed = rd();
samplers.emplace_back(seed, D);
}

std::vector<ChainRecord> chain_records = sample(samplers, rhat_threshold, N);
std::size_t rows = 0;
for (std::size_t m = 0; m < chain_records.size(); ++m) {
const auto& chain_record = chain_records[m];
std::size_t N_m = chain_record.num_draws();
std::vector<double> lps(N_m);
for (std::size_t n = 0; n < N_m; ++n) {
lps[n] = chain_record.logp(n);
}
rows += N_m;
std::cout << "Chain " << m << " count=" << N_m
<< " Final: mean(logp)=" << mean(lps)
<< " var(logp) [sample]=" << variance(lps) << '\n';
}
std::cout << "Number of draws: " << rows << '\n';

// UNCOMMENT TO DUMP CSV
// write_csv(csv_path, D, chain_records);
// std::cout << "Wrote draws to " << csv_path << '\n';

return 0;
}