rhat monitor #38

Open: bob-carpenter wants to merge 14 commits into main from auto-rhat
Commits (14)
56ad37d rhat monitor draft
a750cb0 remove unused args
dfca853 fixes per code review
0dc6410 pause normal sim
a6ebeda test chain completion; sampler refs; controller main thread
244d586 reference_wrapper for ChainTask
94a7917 lp ref to return
dbca47c shared stop token
1eacb2d rename Welford push() to observe()
0662ed2 include updates
d5f175b remove unused fun
3e3d88b join workers to allow completion
c10ced2 name consistency
974b659 atomic chain state
rhat_monitor.cpp (new file):

```cpp
// clang++ -std=c++20 -O3 -pthread rhat_monitor.cpp -o rhat_monitor
// ./rhat_monitor

#include <atomic>
#include <chrono>
#include <cmath>
#include <cstddef>    // std::size_t
#include <fstream>
#include <functional>
#include <iostream>
#include <latch>
#include <limits>     // std::numeric_limits
#include <numeric>
#include <random>
#include <stdexcept>  // std::runtime_error
#include <stop_token>
#include <string>
#include <thread>
#include <vector>

#define VERY_INLINE [[gnu::always_inline]] inline
#ifdef __APPLE__
#include <pthread.h>
VERY_INLINE void interactive_qos() {
  pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0);  // best
}
VERY_INLINE void initiated_qos() {
  pthread_set_qos_class_self_np(QOS_CLASS_USER_INITIATED, 0);  // next best
}
#else
VERY_INLINE void interactive_qos() {}
VERY_INLINE void initiated_qos() {}
#endif

double sum(const std::vector<double>& xs) noexcept {
  return std::transform_reduce(xs.begin(), xs.end(), 0.0, std::plus<>{},
                               std::identity());
}

std::size_t sum(const std::vector<std::size_t>& xs) noexcept {
  // size_t init keeps the accumulation in size_t rather than int
  return std::transform_reduce(xs.begin(), xs.end(), std::size_t{0},
                               std::plus<>{}, std::identity());
}

double mean(const std::vector<double>& xs) noexcept {
  return sum(xs) / xs.size();
}

double variance(const std::vector<double>& xs) noexcept {
  std::size_t N = xs.size();
  if (N < 2) {
    return std::numeric_limits<double>::quiet_NaN();
  }
  double mean_xs = mean(xs);
  double sum = std::transform_reduce(xs.begin(), xs.end(), 0.0, std::plus<>{},
                                     [mean_xs](double x) {
                                       double diff = x - mean_xs;
                                       return diff * diff;
                                     });
  return sum / (N - 1);
}

struct ChainStats {
  double sample_mean;
  double sample_var;
  std::size_t count;
};

class AtomicChainStats {
 public:
  AtomicChainStats() noexcept {
    ChainStats init{std::numeric_limits<double>::quiet_NaN(),
                    std::numeric_limits<double>::quiet_NaN(), 0u};
    data_.store(init, std::memory_order_relaxed);
  }

  void store(const ChainStats& p) noexcept {
    data_.store(p, std::memory_order_release);
  }

  ChainStats load() const noexcept {
    return data_.load(std::memory_order_acquire);
  }

 private:
  std::atomic<ChainStats> data_;
};

class ChainRecord {
 public:
  ChainRecord(std::size_t chain_id, std::size_t D, std::size_t Nmax)
      : chain_id_(chain_id), D_(D), Nmax_(Nmax) {
    theta_.reserve(Nmax * D);
    logp_.reserve(Nmax);
  }

  std::vector<double>& draws() noexcept { return theta_; }

  std::size_t dims() const noexcept { return D_; }

  std::size_t num_draws() const noexcept { return logp_.size(); }

  inline double logp(std::size_t n) const { return logp_[n]; }

  inline double operator()(std::size_t n, std::size_t d) const {
    return theta_[n * D_ + d];
  }

  void write_csv(std::ostream& out) const {
    auto dim = dims();
    for (std::size_t n = 0; n < num_draws(); ++n) {
      out << chain_id_ << ',' << n << ',' << logp_[n];
      for (std::size_t d = 0; d < dim; ++d) {
        out << ',' << operator()(n, d);
      }
      out << '\n';
    }
  }

  void append_logp(double logp) { logp_.push_back(logp); }

 private:
  std::size_t chain_id_;
  std::size_t D_;
  std::size_t Nmax_;
  std::vector<double> theta_;
  std::vector<double> logp_;
};

static void write_csv_header(std::ofstream& out, std::size_t dim) {
  out << "chain,iteration,log_density";
  for (std::size_t d = 0; d < dim; ++d) {
    out << ",theta[" << d << "]";
  }
  out << '\n';
}

static void write_csv(const std::string& path, std::size_t dim,
                      const std::vector<ChainRecord>& chain_records) {
  std::ofstream out(path, std::ios::binary);  // binary for Windows consistency
  if (!out) {
    throw std::runtime_error("could not open file: " + path);
  }
  write_csv_header(out, dim);
  for (const auto& chain_record : chain_records) {
    chain_record.write_csv(out);
  }
}

class WelfordAccumulator {
 public:
  WelfordAccumulator() : n_(0), mean_(0.0), M2_(0.0) {}

  void observe(double x) {
    ++n_;
    const double delta = x - mean_;
    mean_ += delta / static_cast<double>(n_);
    const double delta2 = x - mean_;
    M2_ += delta * delta2;
  }

  std::size_t count() const { return n_; }

  double mean() const { return mean_; }

  double sample_variance() const {
    return n_ > 1 ? (M2_ / static_cast<double>(n_ - 1))
                  : std::numeric_limits<double>::quiet_NaN();
  }

  ChainStats sample_stats() { return {mean(), sample_variance(), count()}; }

  void reset() {
    n_ = 0;
    mean_ = 0.0;
    M2_ = 0.0;
  }

 private:
  std::size_t n_;
  double mean_;
  double M2_;
};

template <class Sampler>
class ChainWorker {
 public:
  ChainWorker(std::size_t chain_id, std::size_t draws_per_chain,
              Sampler& sampler, AtomicChainStats& acs, std::latch& start_gate)
      : chain_id_(chain_id),
        draws_per_chain_(draws_per_chain),
        sampler_(sampler),
        chain_record_(chain_id, sampler.dim(), draws_per_chain),
        acs_(acs),
        start_gate_(start_gate) {}

  void operator()(std::stop_token st) {
    interactive_qos();
    start_gate_.get().arrive_and_wait();
    for (std::size_t iter = 0; iter < draws_per_chain_; ++iter) {
      if ((iter + 1) % 64 == 0) {
        std::this_thread::yield();
      }
      double logp = sampler_.get().sample(chain_record_.draws());
      chain_record_.append_logp(logp);
      logp_stats_.observe(logp);
      acs_.store(logp_stats_.sample_stats());
      if (st.stop_requested()) {
        break;
      }
    }
  }

  ChainRecord&& take_chain_record() { return std::move(chain_record_); }

 private:
  std::size_t chain_id_;
  std::size_t draws_per_chain_;
  std::reference_wrapper<Sampler> sampler_;
  ChainRecord chain_record_;
  WelfordAccumulator logp_stats_;
  AtomicChainStats& acs_;
  std::reference_wrapper<std::latch> start_gate_;
};

void debug_print(double variance_of_means, double mean_of_variances,
                 double num_draws, double r_hat,
                 const std::vector<std::size_t>& counts) {
  auto M = counts.size();
  std::cout << "RHAT: " << r_hat << " NUM_DRAWS: " << num_draws
            << " COUNTS: ";
  for (std::size_t m = 0; m < M; ++m) {
    if (m > 0) {
      std::cout << ", ";
    }
    std::cout << counts[m];
  }
  std::cout << std::endl;
}

static void controller_loop(std::vector<AtomicChainStats>& chain_statses,
                            std::vector<std::jthread>& workers,
                            double rhat_threshold, std::latch& start_gate,
                            std::size_t max_draws_per_chain,
                            std::stop_source& stopper) {
  interactive_qos();
  start_gate.wait();
  const std::size_t M = chain_statses.size();
  std::vector<double> chain_means(M, std::numeric_limits<double>::quiet_NaN());
  std::vector<double> chain_variances(M,
                                      std::numeric_limits<double>::quiet_NaN());
  std::vector<std::size_t> counts(M, 0);
  while (true) {
    for (std::size_t m = 0; m < M; ++m) {
      ChainStats u = chain_statses[m].load();
      chain_means[m] = u.sample_mean;
      chain_variances[m] = u.sample_var;
      counts[m] = u.count;
    }
    double variance_of_means = variance(chain_means);
    double mean_of_variances = mean(chain_variances);
    double r_hat = std::sqrt(1 + variance_of_means / mean_of_variances);
    std::size_t num_draws = sum(counts);

    debug_print(variance_of_means, mean_of_variances, num_draws, r_hat, counts);

    if (r_hat <= rhat_threshold || num_draws == M * max_draws_per_chain) {
      stopper.request_stop();
      break;
    }

    std::this_thread::sleep_for(std::chrono::microseconds{16});
  }
}

// Sampler requires: { double sample(vector<double>& draw); size_t dim(); }
template <typename Sampler>
std::vector<ChainRecord> sample(std::vector<Sampler>& samplers,
                                double rhat_threshold,
                                std::size_t max_draws_per_chain) {
  std::size_t M = samplers.size();
  std::vector<AtomicChainStats> chain_statses(M);
  std::latch start_gate(M);
  std::vector<ChainWorker<Sampler>> workers;
  workers.reserve(M);
  for (std::size_t m = 0; m < M; ++m) {
    workers.emplace_back(m, max_draws_per_chain, samplers[m], chain_statses[m],
                         start_gate);
  }
  {
    std::stop_source stopper;
    std::vector<std::jthread> threads;
    threads.reserve(M);
    for (std::size_t m = 0; m < M; ++m) {
      threads.emplace_back(std::ref(workers[m]), stopper.get_token());
    }
    controller_loop(chain_statses, threads, rhat_threshold, start_gate,
                    max_draws_per_chain, stopper);
  }

  std::vector<ChainRecord> chain_records;
  chain_records.reserve(M);
  for (auto& worker : workers) {
    chain_records.emplace_back(worker.take_chain_record());
  }
  return chain_records;
}

// ****************** EXAMPLE USAGE AFTER HERE ************************

class StandardNormalSampler {
 public:
  explicit StandardNormalSampler(unsigned int seed, std::size_t dim)
      : dim_(dim), engine_(seed), normal_dist_(0, 1) {}

  double sample(std::vector<double>& draw) noexcept {
    double lp = 0;
    for (std::size_t i = 0; i < dim_; ++i) {
      double x = normal_dist_(engine_);
      draw.push_back(x);
      lp += -0.5 * x * x;  // unnormalized
    }
    return lp;
  }

  std::size_t dim() const noexcept { return dim_; }

 private:
  std::size_t dim_;
  std::mt19937_64 engine_;
  std::normal_distribution<double> normal_dist_;
};

int main() {
  const std::string csv_path = "samples.csv";
  const std::size_t D = 4;
  std::size_t M = 16;
  const std::size_t N = 1000;
  double rhat_threshold = 1.001;

  std::random_device rd;
  std::vector<StandardNormalSampler> samplers;
  samplers.reserve(M);
  for (std::size_t m = 0; m < M; ++m) {
    auto seed = rd();
    samplers.emplace_back(seed, D);
  }

  std::vector<ChainRecord> chain_records = sample(samplers, rhat_threshold, N);
  std::size_t rows = 0;
  for (std::size_t m = 0; m < chain_records.size(); ++m) {
    const auto& chain_record = chain_records[m];
    std::size_t N_m = chain_record.num_draws();
    std::vector<double> lps(N_m);
    for (std::size_t n = 0; n < N_m; ++n) {
      lps[n] = chain_record.logp(n);
    }
    rows += N_m;
    std::cout << "Chain " << m << " count=" << N_m
              << " Final: mean(logp)=" << mean(lps)
              << " var(logp) [sample]=" << variance(lps) << '\n';
  }
  std::cout << "Number of draws: " << rows << '\n';

  // UNCOMMENT TO DUMP CSV
  // write_csv(csv_path, D, chain_records);
  // std::cout << "Wrote draws to " << csv_path << '\n';

  return 0;
}
```
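For reference, the stopping rule in `controller_loop` can be stated compactly. If chain m has reported a running sample mean $\bar{x}_m$ and sample variance $s_m^2$ of its log density over $n_m$ draws, the controller computes

$$
\hat{R} = \sqrt{1 + \frac{\operatorname{Var}_m(\bar{x}_m)}{\tfrac{1}{M}\sum_{m=1}^{M} s_m^2}}
$$

and requests a stop once $\hat{R} \le$ `rhat_threshold` or once $\sum_m n_m$ reaches `M * max_draws_per_chain`. This only restates what the code above does: a simplified, non-split relative of the usual potential scale reduction diagnostic, computed on the log density alone.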
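The `sample()` driver is duck-typed on the interface stated in its comment: `double sample(std::vector<double>& draw)` and `std::size_t dim()`. As a minimal sketch of plugging in something other than the i.i.d. example, here is a hypothetical random-walk Metropolis sampler for the same standard-normal target; the class name, step size, and initialization are illustrative assumptions, not part of this PR.

```cpp
// Hypothetical sampler satisfying the duck-typed Sampler interface used by
// sample(): double sample(std::vector<double>&) and std::size_t dim().
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

class RandomWalkMetropolisSampler {
 public:
  explicit RandomWalkMetropolisSampler(unsigned int seed, std::size_t dim)
      : dim_(dim), engine_(seed), proposal_(0.0, 0.5), unif_(0.0, 1.0),
        state_(dim, 0.0), logp_(0.0) {}

  // Appends one D-dimensional draw to `draw` and returns its unnormalized
  // log density under a standard normal target.
  double sample(std::vector<double>& draw) {
    std::vector<double> prop(dim_);
    double prop_logp = 0.0;
    for (std::size_t i = 0; i < dim_; ++i) {
      prop[i] = state_[i] + proposal_(engine_);   // symmetric random-walk move
      prop_logp += -0.5 * prop[i] * prop[i];
    }
    if (std::log(unif_(engine_)) < prop_logp - logp_) {  // Metropolis accept
      state_ = prop;
      logp_ = prop_logp;
    }
    for (double x : state_) draw.push_back(x);
    return logp_;
  }

  std::size_t dim() const { return dim_; }

 private:
  std::size_t dim_;
  std::mt19937_64 engine_;
  std::normal_distribution<double> proposal_;
  std::uniform_real_distribution<double> unif_;
  std::vector<double> state_;
  double logp_;
};
```

Using it mirrors `main()` above: build a `std::vector<RandomWalkMetropolisSampler>`, seed each element from `std::random_device`, and pass the vector to `sample()` with the same threshold and per-chain cap.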