diff --git a/examples/rhat_monitor.cpp b/examples/rhat_monitor.cpp new file mode 100644 index 0000000..39af262 --- /dev/null +++ b/examples/rhat_monitor.cpp @@ -0,0 +1,367 @@ +// clang++ -std=c++20 -O3 -pthread rhat_monitor.cpp -o rhat_monitor +// ./rhat_monitor + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define VERY_INLINE [[gnu::always_inline]] inline +#ifdef __APPLE__ +#include +VERY_INLINE void interactive_qos() { + pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0); // best +} +VERY_INLINE void initiated_qos() { + pthread_set_qos_class_self_np(QOS_CLASS_USER_INITIATED, 0); // next best +} +#else +VERY_INLINE void interactive_qos() {} +VERY_INLINE void initiated_qos() {} +#endif + +double sum(const std::vector& xs) noexcept { + return std::transform_reduce(xs.begin(), xs.end(), 0.0, std::plus<>{}, + std::identity()); +} + +std::size_t sum(const std::vector& xs) noexcept { + return std::transform_reduce(xs.begin(), xs.end(), 0, std::plus<>{}, + std::identity()); +} + +double mean(const std::vector& xs) noexcept { + return sum(xs) / xs.size(); +} + +double variance(const std::vector& xs) noexcept { + std::size_t N = xs.size(); + if (N < 2) { + return std::numeric_limits::quiet_NaN(); + } + double mean_xs = mean(xs); + double sum = std::transform_reduce(xs.begin(), xs.end(), 0.0, std::plus<>{}, + [mean_xs](double x) { + double diff = x - mean_xs; + return diff * diff; + }); + return sum / (N - 1); +} + +struct ChainStats { + double sample_mean; + double sample_var; + std::size_t count; +}; + +class AtomicChainStats { + public: + AtomicChainStats() noexcept { + ChainStats init{std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), 0u}; + data_.store(init, std::memory_order_relaxed); + } + + void store(const ChainStats& p) noexcept { + data_.store(p, std::memory_order_release); + } + + ChainStats load() const noexcept { + return data_.load(std::memory_order_acquire); + } + + private: + std::atomic data_; +}; + +class ChainRecord { + public: + ChainRecord(std::size_t chain_id, std::size_t D, std::size_t Nmax) + : chain_id_(chain_id), D_(D), Nmax_(Nmax) { + theta_.reserve(Nmax * D); + logp_.reserve(Nmax); + } + + std::vector& draws() noexcept { return theta_; } + + std::size_t dims() const noexcept { return D_; } + + std::size_t num_draws() const noexcept { return logp_.size(); } + + inline double logp(std::size_t n) const { return logp_[n]; } + + inline double operator()(std::size_t n, std::size_t d) const { + return theta_[n * D_ + d]; + } + + void write_csv(std::ostream& out) const { + auto dim = dims(); + for (std::size_t n = 0; n < num_draws(); ++n) { + out << chain_id_ << ',' << n << ',' << logp_[n]; + for (std::size_t d = 0; d < dim; ++d) { + out << ',' << operator()(n, d); + } + out << '\n'; + } + } + + void append_logp(double logp) { logp_.push_back(logp); } + + private: + std::size_t chain_id_; + std::size_t D_; + std::size_t Nmax_; + std::vector theta_; + std::vector logp_; +}; + +static void write_csv_header(std::ofstream& out, std::size_t dim) { + out << "chain,iteration,log_density"; + for (std::size_t d = 0; d < dim; ++d) { + out << ",theta[" << d << "]"; + } + out << '\n'; +} + +static void write_csv(const std::string& path, std::size_t dim, + const std::vector& chain_records) { + std::ofstream out(path, std::ios::binary); // binary for Windows consistency + if (!out) { + throw std::runtime_error("could not open file: " + path); + } + write_csv_header(out, dim); + for (const auto& chain_record : chain_records) { + chain_record.write_csv(out); + } +} + +class WelfordAccumulator { + public: + WelfordAccumulator() : n_(0), mean_(0.0), M2_(0.0) {} + + void observe(double x) { + ++n_; + const double delta = x - mean_; + mean_ += delta / static_cast(n_); + const double delta2 = x - mean_; + M2_ += delta * delta2; + } + + std::size_t count() const { return n_; } + + double mean() const { return mean_; } + + double sample_variance() const { + return n_ > 1 ? (M2_ / static_cast(n_ - 1)) + : std::numeric_limits::quiet_NaN(); + } + + ChainStats sample_stats() { return {mean(), sample_variance(), count()}; } + + void reset() { + n_ = 0; + mean_ = 0.0; + M2_ = 0.0; + } + + private: + std::size_t n_; + double mean_; + double M2_; +}; + +template +class ChainWorker { + public: + ChainWorker(std::size_t chain_id, std::size_t draws_per_chain, + Sampler& sampler, AtomicChainStats& acs, std::latch& start_gate) + : chain_id_(chain_id), + draws_per_chain_(draws_per_chain), + sampler_(sampler), + chain_record_(chain_id, sampler.dim(), draws_per_chain), + acs_(acs), + start_gate_(start_gate) {} + + void operator()(std::stop_token st) { + interactive_qos(); + start_gate_.get().arrive_and_wait(); + for (std::size_t iter = 0; iter < draws_per_chain_; ++iter) { + if ((iter + 1) % 64 == 0) { + std::this_thread::yield(); + } + double logp = sampler_.get().sample(chain_record_.draws()); + chain_record_.append_logp(logp); + logp_stats_.observe(logp); + acs_.store(logp_stats_.sample_stats()); + if (st.stop_requested()) { + break; + } + } + } + + ChainRecord&& take_chain_record() { return std::move(chain_record_); } + + private: + std::size_t chain_id_; + std::size_t draws_per_chain_; + std::reference_wrapper sampler_; + ChainRecord chain_record_; + WelfordAccumulator logp_stats_; + AtomicChainStats& acs_; + std::reference_wrapper start_gate_; +}; + +void debug_print(double variance_of_means, double mean_of_variances, + double num_draws, double r_hat, + const std::vector& counts) { + auto M = counts.size(); + std::cout << "RHAT: " << r_hat << " NUM_DRAWS: " << num_draws + << " COUNTS: "; + for (std::size_t m = 0; m < M; ++m) { + if (m > 0) { + std::cout << ", "; + } + std::cout << counts[m]; + } + std::cout << std::endl; +} + +static void controller_loop(std::vector& chain_statses, + std::vector& workers, + double rhat_threshold, std::latch& start_gate, + std::size_t max_draws_per_chain, + std::stop_source& stopper) { + interactive_qos(); + start_gate.wait(); + const std::size_t M = chain_statses.size(); + std::vector chain_means(M, std::numeric_limits::quiet_NaN()); + std::vector chain_variances(M, + std::numeric_limits::quiet_NaN()); + std::vector counts(M, 0); + while (true) { + for (std::size_t m = 0; m < M; ++m) { + ChainStats u = chain_statses[m].load(); + chain_means[m] = u.sample_mean; + chain_variances[m] = u.sample_var; + counts[m] = u.count; + } + double variance_of_means = variance(chain_means); + double mean_of_variances = mean(chain_variances); + double r_hat = std::sqrt(1 + variance_of_means / mean_of_variances); + std::size_t num_draws = sum(counts); + + debug_print(variance_of_means, mean_of_variances, num_draws, r_hat, counts); + + if (r_hat <= rhat_threshold || num_draws == M * max_draws_per_chain) { + stopper.request_stop(); + break; + } + + std::this_thread::sleep_for(std::chrono::microseconds{16}); + } +} + +// Sampler requires: { double sample(vector& draw); size_t dim(); } +template +std::vector sample(std::vector& samplers, + double rhat_threshold, + std::size_t max_draws_per_chain) { + std::size_t M = samplers.size(); + std::vector chain_statses(M); + std::latch start_gate(M); + std::vector> workers; + workers.reserve(M); + for (std::size_t m = 0; m < M; ++m) { + workers.emplace_back(m, max_draws_per_chain, samplers[m], chain_statses[m], + start_gate); + } + { + std::stop_source stopper; + std::vector threads; + threads.reserve(M); + for (std::size_t m = 0; m < M; ++m) { + threads.emplace_back(std::ref(workers[m]), stopper.get_token()); + } + controller_loop(chain_statses, threads, rhat_threshold, start_gate, + max_draws_per_chain, stopper); + } + + std::vector chain_records; + chain_records.reserve(M); + for (auto& worker : workers) { + chain_records.emplace_back(worker.take_chain_record()); + } + return chain_records; +} + +// ****************** EXAMPLE USAGE AFTER HERE ************************ + +class StandardNormalSampler { + public: + explicit StandardNormalSampler(unsigned int seed, std::size_t dim) + : dim_(dim), engine_(seed), normal_dist_(0, 1) {} + + double sample(std::vector& draw) noexcept { + double lp = 0; + for (std::size_t i = 0; i < dim_; ++i) { + double x = normal_dist_(engine_); + draw.push_back(x); + lp += -0.5 * x * x; // unnomalized + } + return lp; + } + + std::size_t dim() const noexcept { return dim_; } + + private: + std::size_t dim_; + std::mt19937_64 engine_; + std::normal_distribution normal_dist_; +}; + +int main() { + const std::string csv_path = "samples.csv"; + const std::size_t D = 4; + std::size_t M = 16; + const std::size_t N = 1000; + double rhat_threshold = 1.001; + + std::random_device rd; + std::vector samplers; + samplers.reserve(M); + for (std::size_t m = 0; m < M; ++m) { + auto seed = rd(); + samplers.emplace_back(seed, D); + } + + std::vector chain_records = sample(samplers, rhat_threshold, N); + std::size_t rows = 0; + for (std::size_t m = 0; m < chain_records.size(); ++m) { + const auto& chain_record = chain_records[m]; + std::size_t N_m = chain_record.num_draws(); + std::vector lps(N_m); + for (std::size_t n = 0; n < N_m; ++n) { + lps[n] = chain_record.logp(n); + } + rows += N_m; + std::cout << "Chain " << m << " count=" << N_m + << " Final: mean(logp)=" << mean(lps) + << " var(logp) [sample]=" << variance(lps) << '\n'; + } + std::cout << "Number of draws: " << rows << '\n'; + + // UNCOMMENT TO DUMP CSV + // write_csv(csv_path, D, chain_records); + // std::cout << "Wrote draws to " << csv_path << '\n'; + + return 0; +}