Skip to content
152 changes: 59 additions & 93 deletions examples/rhat_monitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@
#include <thread>
#include <vector>

// Shorthand for "always inline this function" (GCC/Clang attribute + inline).
#define VERY_INLINE [[gnu::always_inline]] inline
#ifdef __APPLE__
#include <pthread.h>
// Requests the USER_INTERACTIVE quality-of-service class for the calling
// thread.
VERY_INLINE void interactive_qos() {
  pthread_set_qos_class_self_np(QOS_CLASS_USER_INTERACTIVE, 0);  // best
}
// Requests the USER_INITIATED quality-of-service class for the calling
// thread.
VERY_INLINE void initiated_qos() {
  pthread_set_qos_class_self_np(QOS_CLASS_USER_INITIATED, 0);  // next best
}
#else
// No-op fallbacks so callers do not need their own platform checks.
VERY_INLINE void interactive_qos() {}
VERY_INLINE void initiated_qos() {}
#endif

double sum(const std::vector<double>& xs) noexcept {
Expand Down Expand Up @@ -57,9 +58,29 @@ double variance(const std::vector<double>& xs) noexcept {
}

// Running summary of one chain: sample mean, sample variance and the number
// of observations they are based on.
// Field order matters: the struct is brace-initialized positionally as
// {sample_mean, sample_var, count}.
struct ChainStats {
  double sample_mean;
  double sample_var;
  std::size_t count;
};

// Holds the latest ChainStats snapshot inside a std::atomic so one thread can
// overwrite it while another copies it out as a whole value.
// NOTE(review): whether std::atomic<ChainStats> is lock-free depends on the
// target platform -- confirm if that matters for this use.
class AtomicChainStats {
 public:
  // Starts out with NaN mean, NaN variance and a count of zero.
  AtomicChainStats() noexcept
      : data_(ChainStats{std::numeric_limits<double>::quiet_NaN(),
                         std::numeric_limits<double>::quiet_NaN(), 0u}) {}

  // Publishes `p` as the current snapshot (release ordering).
  void store(const ChainStats& p) noexcept {
    data_.store(p, std::memory_order_release);
  }

  // Returns a copy of the current snapshot (acquire ordering).
  ChainStats load() const noexcept {
    const ChainStats snapshot = data_.load(std::memory_order_acquire);
    return snapshot;
  }

 private:
  std::atomic<ChainStats> data_;
};

class ChainRecord {
Expand Down Expand Up @@ -103,57 +124,6 @@ class ChainRecord {
std::vector<double> logp_;
};

// placeholder ring buffer; see https://rigtorp.se/ringbuffer/
//
// Fixed-capacity FIFO built on a vector of `Capacity` slots and two atomic
// indices. One slot is always left empty to tell "full" from "empty", so at
// most Capacity - 1 elements are queued at a time.
//
// Each index is advanced with a plain load + store (no compare-exchange), so
// this is only correct with a single thread calling try_emplace() and a
// single thread calling pop().

// Alignment used to keep the two indices on separate cache lines. Falls back
// to 64 bytes where the standard library does not provide the constant.
#ifdef __cpp_lib_hardware_interference_size
inline constexpr std::size_t RING_CACHELINE =
    std::hardware_destructive_interference_size;
#else
inline constexpr std::size_t RING_CACHELINE = 64;
#endif

template <class T, std::size_t Capacity>
class alignas(RING_CACHELINE) RingBuffer {
  // Capacity 0 would index an empty vector; Capacity 1 could never accept an
  // element because the single slot is the reserved empty one.
  static_assert(Capacity >= 2, "RingBuffer needs at least two slots");

 public:
  explicit RingBuffer() : data_(Capacity) {}
  RingBuffer(const RingBuffer&) = delete;
  RingBuffer& operator=(const RingBuffer&) = delete;
  RingBuffer(RingBuffer&&) = delete;
  RingBuffer& operator=(RingBuffer&&) = delete;

  // Constructs a T from `args` and enqueues it.
  // Returns false, leaving the buffer untouched, when the buffer is full.
  // Declared noexcept: an exception from T's constructor or assignment
  // terminates the program.
  template <class... Args>
  bool try_emplace(Args&&... args) noexcept {
    const std::size_t write_idx = write_idx_.load(std::memory_order_relaxed);
    const std::size_t next = advance(write_idx);
    if (next == read_idx_.load(std::memory_order_acquire)) {
      return false;  // writing would catch up with the reader: full
    }
    data_[write_idx] = T(std::forward<Args>(args)...);
    // Release so the slot contents are visible before the new index is.
    write_idx_.store(next, std::memory_order_release);
    return true;
  }

  // Moves the oldest element into `out`.
  // Returns false, leaving `out` untouched, when the buffer is empty.
  bool pop(T& out) noexcept {
    const std::size_t read_idx = read_idx_.load(std::memory_order_relaxed);
    if (read_idx == write_idx_.load(std::memory_order_acquire)) {
      return false;  // nothing published yet: empty
    }
    out = std::move(data_[read_idx]);
    // Release so the slot is handed back only after it has been moved from.
    read_idx_.store(advance(read_idx), std::memory_order_release);
    return true;
  }

 private:
  // Next slot index after `i`, wrapping to 0 at Capacity.
  static constexpr std::size_t advance(std::size_t i) noexcept {
    return (i + 1 == Capacity) ? 0 : i + 1;
  }

  std::vector<T> data_;
  alignas(RING_CACHELINE) std::atomic<std::size_t> read_idx_{0};
  alignas(RING_CACHELINE) std::atomic<std::size_t> write_idx_{0};
};

// Number of slots in each per-chain ring buffer. RingBuffer::try_emplace
// rejects a write whose next index would equal the read index, so one slot
// always stays unused and at most RING_CAPACITY - 1 entries are queued.
constexpr std::size_t RING_CAPACITY = 64;

// Queue of ChainStats values backed by a RING_CAPACITY-slot RingBuffer.
using Queue = RingBuffer<ChainStats, RING_CAPACITY>;

static void write_csv_header(std::ofstream& out, std::size_t dim) {
out << "chain,iteration,log_density";
for (std::size_t d = 0; d < dim; ++d) {
Expand Down Expand Up @@ -195,7 +165,7 @@ class WelfordAccumulator {
: std::numeric_limits<double>::quiet_NaN();
}

ChainStats sample_stats() { return {count(), mean(), sample_variance()}; }
ChainStats sample_stats() { return {mean(), sample_variance(), count()}; }

void reset() {
n_ = 0;
Expand All @@ -212,34 +182,30 @@ class WelfordAccumulator {
template <class Sampler>
class ChainWorker {
public:
ChainWorker(std::size_t chain_id, std::size_t draws_per_chain, Sampler& sampler,
Queue& q, std::latch& start_gate)
ChainWorker(std::size_t chain_id, std::size_t draws_per_chain,
Sampler& sampler, AtomicChainStats& acs, std::latch& start_gate)
: chain_id_(chain_id),
draws_per_chain_(draws_per_chain),
sampler_(sampler),
chain_record_(chain_id, sampler.dim(), draws_per_chain),
q_(q),
acs_(acs),
start_gate_(start_gate) {}

void operator()(std::stop_token st) {
initiated_qos();
interactive_qos();
start_gate_.get().arrive_and_wait();
for (std::size_t iter = 0; iter < draws_per_chain_; ++iter) {
if ((iter + 1) % 16 == 0) {
if ((iter + 1) % 64 == 0) {
std::this_thread::yield();
}
double logp = sampler_.get().sample(chain_record_.draws());
chain_record_.append_logp(logp);
logp_stats_.observe(logp);
// ignore return of emplace
q_.get().try_emplace(logp_stats_.sample_stats());
acs_.store(logp_stats_.sample_stats());
if (st.stop_requested()) {
break;
}
}
// make sure final update sticks
while (!st.stop_requested()
&& !q_.get().try_emplace(logp_stats_.sample_stats()));
}

ChainRecord&& take_chain_record() { return std::move(chain_record_); }
Expand All @@ -250,17 +216,16 @@ class ChainWorker {
std::reference_wrapper<Sampler> sampler_;
ChainRecord chain_record_;
WelfordAccumulator logp_stats_;
std::reference_wrapper<Queue> q_;
AtomicChainStats& acs_;
std::reference_wrapper<std::latch> start_gate_;
};

void debug_print(double variance_of_means, double mean_of_variances,
double num_draws, double r_hat,
const std::vector<std::size_t>& counts) {
auto M = counts.size();
std::cout << "RHAT: " << r_hat
<< " NUM_DRAWS: " << num_draws
<< " COUNTS: ";
std::cout << "RHAT: " << r_hat << " NUM_DRAWS: " << num_draws
<< " COUNTS: ";
for (std::size_t m = 0; m < M; ++m) {
if (m > 0) {
std::cout << ", ";
Expand All @@ -270,26 +235,24 @@ void debug_print(double variance_of_means, double mean_of_variances,
std::cout << std::endl;
}

static void controller_loop(std::vector<Queue>& queues,
static void controller_loop(std::vector<AtomicChainStats>& chain_statses,
std::vector<std::jthread>& workers,
double rhat_threshold, std::latch& start_gate,
std::size_t max_draws_per_chain,
std::stop_source& stopper) {
std::stop_source& stopper) {
interactive_qos();
start_gate.wait();
const std::size_t M = queues.size();
const std::size_t M = chain_statses.size();
std::vector<double> chain_means(M, std::numeric_limits<double>::quiet_NaN());
std::vector<double> chain_variances(M,
std::numeric_limits<double>::quiet_NaN());
std::vector<std::size_t> counts(M, 0);
while (true) {
for (std::size_t m = 0; m < M; ++m) {
ChainStats u;
while (queues[m].pop(u)) {
chain_means[m] = u.sample_mean;
chain_variances[m] = u.sample_var;
counts[m] = u.count;
}
ChainStats u = chain_statses[m].load();
chain_means[m] = u.sample_mean;
chain_variances[m] = u.sample_var;
counts[m] = u.count;
}
double variance_of_means = variance(chain_means);
double mean_of_variances = mean(chain_variances);
Expand All @@ -302,38 +265,40 @@ static void controller_loop(std::vector<Queue>& queues,
stopper.request_stop();
break;
}

std::this_thread::sleep_for(std::chrono::microseconds{16});
}
}

// Sampler requires: { double sample(vector<double>& draw); size_t dim(); }
template <typename Sampler>
std::vector<ChainRecord> sample(std::vector<Sampler>& samplers,
double rhat_threshold,
std::size_t max_draws_per_chain) {
double rhat_threshold,
std::size_t max_draws_per_chain) {
std::size_t M = samplers.size();
std::vector<Queue> queues(M);
std::vector<AtomicChainStats> chain_statses(M);
std::latch start_gate(M);
std::vector<ChainWorker<Sampler>> workers;
workers.reserve(M);
for (std::size_t m = 0; m < M; ++m) {
workers.emplace_back(m, max_draws_per_chain, samplers[m], queues[m],
start_gate);
workers.emplace_back(m, max_draws_per_chain, samplers[m], chain_statses[m],
start_gate);
}
{
{
std::stop_source stopper;
std::vector<std::jthread> threads;
threads.reserve(M);
for (std::size_t m = 0; m < M; ++m) {
threads.emplace_back(std::ref(workers[m]), stopper.get_token());
}
controller_loop(queues, threads, rhat_threshold, start_gate,
max_draws_per_chain, stopper);
controller_loop(chain_statses, threads, rhat_threshold, start_gate,
max_draws_per_chain, stopper);
}

std::vector<ChainRecord> chain_records;
chain_records.reserve(M);
for (auto& worker : workers) {
chain_records.emplace_back(worker.take_chain_record());
chain_records.emplace_back(worker.take_chain_record());
}
return chain_records;
}
Expand Down Expand Up @@ -365,7 +330,7 @@ class StandardNormalSampler {

int main() {
const std::string csv_path = "samples.csv";
const std::size_t D = 16;
const std::size_t D = 4;
std::size_t M = 16;
const std::size_t N = 1000;
double rhat_threshold = 1.001;
Expand Down Expand Up @@ -394,8 +359,9 @@ int main() {
}
std::cout << "Number of draws: " << rows << '\n';

write_csv(csv_path, D, chain_records);
std::cout << "Wrote draws to " << csv_path << '\n';
// UNCOMMENT TO DUMP CSV
// write_csv(csv_path, D, chain_records);
// std::cout << "Wrote draws to " << csv_path << '\n';

return 0;
}