-
Notifications
You must be signed in to change notification settings - Fork 1
rhat monitor #38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
rhat monitor #38
Changes from 4 commits
56ad37d
a750cb0
dfca853
0dc6410
a6ebeda
244d586
94a7917
dbca47c
1eacb2d
0662ed2
d5f175b
3e3d88b
c10ced2
974b659
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,387 @@ | ||||||
| // TO RUN | ||||||
| // clang++ -std=c++20 -O3 rhat_monitor.cpp -o rhat_monitor | ||||||
| // ./rhat_monitor | ||||||
|
|
||||||
| #include <atomic> | ||||||
| #include <chrono> | ||||||
| #include <cmath> | ||||||
| #include <fstream> | ||||||
| #include <functional> | ||||||
| #include <iostream> | ||||||
| #include <latch> | ||||||
| #include <random> | ||||||
| #include <stop_token> | ||||||
| #include <string> | ||||||
| #include <thread> | ||||||
| #include <tuple> | ||||||
| #include <vector> | ||||||
|
|
||||||
| #ifdef __APPLE__ | ||||||
| #include <pthread.h> | ||||||
| [[gnu::always_inline]] inline void interactive_qos() { | ||||||
| pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0); // best | ||||||
| } | ||||||
| [[gnu::always_inline]] inline void initiated_qos() { | ||||||
| pthread_set_qos_class_self_np(QOS_CLASS_USER_INITIATED, 0); // next best | ||||||
| } | ||||||
| #else | ||||||
| [[gnu::always_inline]] inline void interactive_qos() {} | ||||||
| [[gnu::always_inline]] inline void initiated_qos() {} | ||||||
| #endif | ||||||
|
|
||||||
| double sum(const std::vector<double>& xs) noexcept { | ||||||
| return std::transform_reduce(xs.begin(), xs.end(), 0.0, std::plus<>{}, | ||||||
| std::identity()); | ||||||
| } | ||||||
|
|
||||||
| double mean(const std::vector<double>& xs) noexcept { | ||||||
| return sum(xs) / xs.size(); | ||||||
| } | ||||||
|
|
||||||
| double variance(const std::vector<double>& xs) noexcept { | ||||||
| std::size_t N = xs.size(); | ||||||
| if (N < 2) { | ||||||
| return std::numeric_limits<double>::quiet_NaN(); | ||||||
| } | ||||||
| double mean_xs = mean(xs); | ||||||
| double sum = std::transform_reduce(xs.begin(), xs.end(), 0.0, std::plus<>{}, | ||||||
| [mean_xs](double x) { | ||||||
| double diff = x - mean_xs; | ||||||
| return diff * diff; | ||||||
| }); | ||||||
| return sum / (N - 1); | ||||||
| } | ||||||
|
|
||||||
| struct SampleStats { | ||||||
| std::size_t count; | ||||||
| double sample_mean; | ||||||
| double sample_var; | ||||||
| }; | ||||||
|
|
||||||
| class Sample { | ||||||
| public: | ||||||
| Sample(std::size_t chain_id, std::size_t D, std::size_t Nmax) | ||||||
| : chain_id_(chain_id), D_(D), Nmax_(Nmax) { | ||||||
| theta_.reserve(Nmax * D); | ||||||
| logp_.reserve(Nmax); | ||||||
| } | ||||||
|
|
||||||
| std::size_t dims() const noexcept { return D_; } | ||||||
|
|
||||||
| std::size_t num_draws() const noexcept { return logp_.size(); } | ||||||
|
|
||||||
| inline double logp(std::size_t n) const { return logp_[n]; } | ||||||
|
|
||||||
| inline double operator()(std::size_t n, std::size_t d) const { | ||||||
| return theta_[n * D_ + d]; | ||||||
| } | ||||||
|
|
||||||
| void write_csv(std::ostream& out, std::size_t dim) const { | ||||||
| for (std::size_t n = 0; n < num_draws(); ++n) { | ||||||
| out << chain_id_ << ',' << n << ',' << logp_[n]; | ||||||
| for (std::size_t d = 0; d < dim; ++d) { | ||||||
| out << ',' << operator()(n, d); | ||||||
| } | ||||||
| out << '\n'; | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| void append_draw(double logp, std::vector<double>& draw) { | ||||||
| logp_.emplace_back(logp); | ||||||
| theta_.insert(theta_.end(), draw.begin(), draw.end()); | ||||||
| } | ||||||
|
|
||||||
| private: | ||||||
| std::size_t chain_id_; | ||||||
| std::size_t D_; | ||||||
| std::size_t Nmax_; | ||||||
| std::vector<double> theta_; | ||||||
| std::vector<double> logp_; | ||||||
| }; | ||||||
|
|
||||||
| // see https://rigtorp.se/ringbuffer/ | ||||||
| template <class T, std::size_t Capacity> | ||||||
| class alignas(std::hardware_destructive_interference_size) RingBuffer { | ||||||
| public: | ||||||
| explicit RingBuffer() : data_(Capacity) {} | ||||||
|
|
||||||
| template <class... Args> | ||||||
| bool emplace(Args&&... args) noexcept { | ||||||
| auto write_idx = write_idx_.load(std::memory_order_relaxed); | ||||||
| auto next = write_idx + 1; | ||||||
| if (next == Capacity) { | ||||||
| next = 0; | ||||||
| } | ||||||
| if (next == read_idx_.load(std::memory_order_acquire)) { | ||||||
| return false; | ||||||
| } | ||||||
| data_[write_idx] = T(std::forward<Args>(args)...); | ||||||
| write_idx_.store(next, std::memory_order_release); | ||||||
| return true; | ||||||
| } | ||||||
|
|
||||||
| bool pop(T& out) noexcept { | ||||||
| auto read_idx = read_idx_.load(std::memory_order_relaxed); | ||||||
| if (read_idx == write_idx_.load(std::memory_order_acquire)) { | ||||||
| return false; | ||||||
| } | ||||||
| out = std::move(data_[read_idx]); | ||||||
| auto next = read_idx + 1; | ||||||
| if (next == Capacity) { | ||||||
| next = 0; | ||||||
| } | ||||||
| read_idx_.store(next, std::memory_order_release); | ||||||
| return true; | ||||||
| } | ||||||
|
|
||||||
| std::size_t capacity() const noexcept { return Capacity; } | ||||||
|
|
||||||
| private: | ||||||
| std::vector<T> data_; | ||||||
| alignas(std::hardware_destructive_interference_size) | ||||||
| std::atomic<std::size_t> read_idx_{0}; | ||||||
| alignas(std::hardware_destructive_interference_size) | ||||||
| std::atomic<std::size_t> write_idx_{0}; | ||||||
| }; | ||||||
bob-carpenter marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
|
|
||||||
| constexpr std::size_t RING_CAPACITY = 64; | ||||||
|
|
||||||
| using Queue = RingBuffer<SampleStats, RING_CAPACITY>; | ||||||
|
|
||||||
| static void write_csv_header(std::ofstream& out, std::size_t dim) { | ||||||
| out << "chain,iteration,log_density"; | ||||||
| for (std::size_t d = 0; d < dim; ++d) { | ||||||
| out << ",theta[" << d << "]"; | ||||||
| } | ||||||
| out << '\n'; | ||||||
| } | ||||||
|
|
||||||
| static void write_csv(const std::string& path, std::size_t dim, | ||||||
| const std::vector<Sample>& samples) { | ||||||
| std::ofstream out(path, std::ios::binary); // binary for Windows consistency | ||||||
| if (!out) { | ||||||
| throw std::runtime_error("could not open file: " + path); | ||||||
| } | ||||||
| write_csv_header(out, dim); | ||||||
| for (const auto& sample : samples) { | ||||||
| sample.write_csv(out, dim); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| class WelfordAccumulator { | ||||||
| public: | ||||||
| WelfordAccumulator() : n_(0), mean_(0.0), M2_(0.0) {} | ||||||
|
|
||||||
| void push(double x) { | ||||||
| ++n_; | ||||||
| const double delta = x - mean_; | ||||||
| mean_ += delta / static_cast<double>(n_); | ||||||
| const double delta2 = x - mean_; | ||||||
| M2_ += delta * delta2; | ||||||
| } | ||||||
bob-carpenter marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
|
|
||||||
| std::size_t count() const { return n_; } | ||||||
|
|
||||||
| double mean() const { return mean_; } | ||||||
|
|
||||||
| double sample_variance() const { | ||||||
| return n_ > 1 ? (M2_ / static_cast<double>(n_ - 1)) | ||||||
| : std::numeric_limits<double>::quiet_NaN(); | ||||||
| } | ||||||
|
|
||||||
| SampleStats sample_stats() { return {count(), mean(), sample_variance()}; } | ||||||
|
|
||||||
| void reset() { | ||||||
| n_ = 0; | ||||||
| mean_ = 0.0; | ||||||
| M2_ = 0.0; | ||||||
| } | ||||||
|
|
||||||
| private: | ||||||
| std::size_t n_; | ||||||
| double mean_; | ||||||
| double M2_; | ||||||
| }; | ||||||
|
|
||||||
| template <class Sampler> | ||||||
| class ChainTask { | ||||||
| public: | ||||||
| ChainTask(std::size_t chain_id, std::size_t draws_per_chain, Sampler& sampler, | ||||||
| Queue& q, std::latch& start_gate) | ||||||
| : chain_id_(chain_id), | ||||||
| draws_per_chain_(draws_per_chain), | ||||||
| sampler_(sampler), | ||||||
| sample_(chain_id, sampler.dim(), draws_per_chain), | ||||||
| q_(q), | ||||||
| start_gate_(start_gate) {} | ||||||
|
|
||||||
| void operator()(std::stop_token st) { | ||||||
| initiated_qos(); | ||||||
| start_gate_.arrive_and_wait(); | ||||||
| for (std::size_t iter = 0; iter < draws_per_chain_; ++iter) { | ||||||
| if ((iter + 1) % 100 == 0) { | ||||||
| std::this_thread::yield(); | ||||||
| } | ||||||
| auto [logp, theta] = sampler_(); | ||||||
| logp_stats_.push(logp); | ||||||
| sample_.append_draw(logp, theta); | ||||||
| q_.emplace(logp_stats_.sample_stats()); | ||||||
|
||||||
| q_.emplace(logp_stats_.sample_stats()); | |
| while(!q_.emplace(logp_stats_.sample_stats())) {}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I do that, the algorithm reliably hangs. I'm not busy-spinning the controller loop in the main body---it only updates the chain means, variances, and counts if it can pop them from the ring buffer. It keeps popping while it can to get to the most recent.
The individual chains busy spin after they have hit their max warmup to make sure their final update gets registered. I haven't seen the be a problem yet, but I keep thinking there may be more hidden race conditions.
I feel like an alternative might be to have another latch used to have everyone stop.
bob-carpenter marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
bob-carpenter marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
bob-carpenter marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is the only way out of the loop, this will also want some check for if all the workers hit their max number of iterations
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there a reason to run the controller thread on its own thread rather than on the main thread here?
Uh oh!
There was an error while loading. Please reload this page.