
Conversation

bob-carpenter
Collaborator

I coded up something to run a Stan program until the ESS values for E[theta], E[theta^2], and E[theta^4] are all above a given threshold, where theta includes all the parameters and the log density lp__. It just iterates blocks of draws, keeping a running sum of ESS values and a running average of the moment estimates. It then dumps the results in scientific notation in JSON to 10 decimal places, which is more precision than we should need for our test cases. Ideally, we'd dump to within a number of standard deviations of the mean, but I think this should be OK unless we have examples with very small standard deviations relative to their means.

I largely used ChatGPT to code this, so I'm more than happy to get feedback on the Python.
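Roughly, the idea is something like the following sketch (illustrative only; the function and variable names here are mine rather than the actual reference-moments.py, and it assumes cmdstanpy, arviz, and xarray are installed):

```python
# Illustrative sketch of the block-iteration idea, not the actual script.
import arviz as az
import numpy as np
import xarray as xr
from cmdstanpy import CmdStanModel

def column_ess(a: np.ndarray) -> np.ndarray:
    """Bulk ESS for each column of a single-chain (draws, vars) array."""
    da = xr.DataArray(a[np.newaxis, :, :], dims=("chain", "draw", "var"))
    ess_ds = az.ess(da, method="bulk")
    var_name = next(iter(ess_ds.data_vars))
    return np.asarray(np.squeeze(ess_ds[var_name].values), dtype=np.float64)

def run_until_ess(stan_file, data_file, min_ess_target=1e6,
                  block_size=10_000, max_blocks=10_000, seed=643889):
    model = CmdStanModel(stan_file=stan_file)
    sum_moments = 0.0   # running sums of X, X^2, X^4 over all draws so far
    sum_ess = 0.0       # running sums of per-block ESS for X, X^2, X^4
    n_draws = 0
    for block in range(1, max_blocks + 1):
        fit = model.sample(data=data_file, chains=1, seed=seed + block,
                           iter_sampling=block_size, show_progress=False)
        # a real script would likely select the parameter columns plus lp__
        # via fit.column_names; this sketch just takes every column
        draws = fit.draws(concat_chains=True)            # (block_size, n_cols)
        powers = np.stack([draws, draws**2, draws**4])   # X, X^2, X^4
        sum_moments = sum_moments + powers.sum(axis=1)   # accumulate moment sums
        sum_ess = sum_ess + np.stack([column_ess(p) for p in powers])
        n_draws += block_size
        if np.min(sum_ess) >= min_ess_target:
            break
    return sum_moments / n_draws, sum_ess   # running averages, accumulated ESS
```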

Here's the output for a 10-dimensional normal target with scales 1 to 10.

build (benchmarks)$ time python3 ../examples/reference-moments.py

STAN PROGRAM: stan_file = '../examples/diag_scale_target.stan'
    DATA FILE: data_file = '../examples/diag_scale_target.json'
    OUTPUT FILE: out_file = '../examples/diag_scale_target_out.json'
         min_ess_target = 1000000.0
         block_size = 10000
         max_blocks = 10000
         seed = 643889

    0.  min(ESS) = 0.00e+00
   10.  min(ESS) = 4.02e+04
   20.  min(ESS) = 7.88e+04
   30.  min(ESS) = 1.17e+05
   40.  min(ESS) = 1.57e+05
   50.  min(ESS) = 1.97e+05
   60.  min(ESS) = 2.36e+05
   70.  min(ESS) = 2.75e+05
   80.  min(ESS) = 3.14e+05
   90.  min(ESS) = 3.53e+05
  100.  min(ESS) = 3.92e+05
  110.  min(ESS) = 4.31e+05
  120.  min(ESS) = 4.72e+05
  130.  min(ESS) = 5.11e+05
  140.  min(ESS) = 5.50e+05
  150.  min(ESS) = 5.90e+05
  160.  min(ESS) = 6.30e+05
  170.  min(ESS) = 6.69e+05
  180.  min(ESS) = 7.09e+05
  190.  min(ESS) = 7.48e+05
  200.  min(ESS) = 7.87e+05
  210.  min(ESS) = 8.26e+05
  220.  min(ESS) = 8.64e+05
  230.  min(ESS) = 9.03e+05
  240.  min(ESS) = 9.44e+05
  250.  min(ESS) = 9.83e+05
  255.  min(ESS) = 1.00e+06

***** ACHIEVED MINIMUM ESS TARGET *****

X                   E[X]         E[X^2]         E[X^4]       ESS[X]     ESS[X^2]     ESS[X^4]
theta[0]     9.26343e-05    1.00122e+00    3.00575e+00  2.59284e+06  1.11440e+06  1.11440e+06
theta[1]    -9.47696e-05    4.00412e+00    4.80899e+01  2.88753e+06  1.03804e+06  1.03804e+06
theta[2]     1.08066e-03    9.01607e+00    2.44025e+02  3.90151e+06  1.00352e+06  1.00352e+06
theta[3]    -1.46590e-03    1.60158e+01    7.70857e+02  3.79204e+06  1.07977e+06  1.07977e+06
theta[4]    -1.12083e-03    2.50006e+01    1.87645e+03  3.39769e+06  1.23626e+06  1.23626e+06
theta[5]     1.49413e-03    3.59929e+01    3.89271e+03  2.75145e+06  1.23811e+06  1.23811e+06
theta[6]     3.63230e-03    4.91345e+01    7.24702e+03  2.34881e+06  1.19165e+06  1.19165e+06
theta[7]    -4.64606e-03    6.39767e+01    1.22822e+04  2.01754e+06  1.21023e+06  1.21023e+06
theta[8]     5.71366e-03    8.11992e+01    1.97758e+04  1.70365e+06  1.23581e+06  1.23581e+06
theta[9]     3.51204e-04    9.99416e+01    3.00887e+04  1.45531e+06  1.24614e+06  1.24614e+06
lp          -5.00456e+00    3.00560e+01    1.68729e+03  1.08288e+06  1.08288e+06  1.08288e+06
python3 ../examples/reference-moments.py  109.36s user 14.48s system 104% cpu 1:58.91 total

It produces this JSON output:

{
  "vars": [
    9.2634264000e-05,
    -9.4769614000e-05,
    1.0806610000e-03,
    -1.4658990000e-03,
    -1.1208311000e-03,
    1.4941326000e-03,
    3.6323028000e-03,
    -4.6460588000e-03,
    5.7136589000e-03,
    3.5120391000e-04,
    -5.0045560000e+00
  ],
  "vars_sq": [
    1.0012219000e+00,
    4.0041225000e+00,
    9.0160672000e+00,
    1.6015831000e+01,
    2.5000648000e+01,
    3.5992883000e+01,
    4.9134498000e+01,
    6.3976726000e+01,
    8.1199190000e+01,
    9.9941619000e+01,
    3.0055958000e+01
  ],
  "vars_fourth": [
    3.0057456000e+00,
    4.8089912000e+01,
    2.4402527000e+02,
    7.7085734000e+02,
    1.8764548000e+03,
    3.8927089000e+03,
    7.2470225000e+03,
    1.2282180000e+04,
    1.9775846000e+04,
    3.0088670000e+04,
    1.6872879000e+03
  ],
  "ess": [
    2.5928430710e+06,
    2.8875322166e+06,
    3.9015117136e+06,
    3.7920353054e+06,
    3.3976893429e+06,
    2.7514505746e+06,
    2.3488147882e+06,
    2.0175434890e+06,
    1.7036521274e+06,
    1.4553086826e+06,
    1.0828808333e+06
  ],
  "ess_sq": [
    1.1144011606e+06,
    1.0380386513e+06,
    1.0035197614e+06,
    1.0797689318e+06,
    1.2362565887e+06,
    1.2381111553e+06,
    1.1916465762e+06,
    1.2102278759e+06,
    1.2358123834e+06,
    1.2461369728e+06,
    1.0828808333e+06
  ],
  "ess_fourth": [
    1.1144011606e+06,
    1.0380386513e+06,
    1.0035197614e+06,
    1.0797689318e+06,
    1.2362565887e+06,
    1.2381111553e+06,
    1.1916465762e+06,
    1.2102278759e+06,
    1.2358123834e+06,
    1.2461369728e+06,
    1.0828808333e+06
  ]
}

@WardBrian
Collaborator

I'm happy to take a look at this, but do you think it would be better off in a different repository?

@bob-carpenter
Collaborator Author

bob-carpenter commented Aug 5, 2025 via email

@bob-carpenter changed the title from "Reference moment estimation" to "Benchmarks Branch [DO NOT MERGE]" on Aug 5, 2025
@bob-carpenter
Collaborator Author

Also, I'm not sure why this would cause a failure, as I didn't plug in any new builds or new tests.

@WardBrian
Collaborator

It looks like you added the stationarity example to cmake, but it’s missing a #include it needs

@bob-carpenter
Collaborator Author

Thanks---I hadn't looked into the cause, but I just added the include, and hopefully that will fix it.


Hi @bob-carpenter -- you might want to check out #26, which has a target very similar to this with a few more features already

std::string sample_csv_file_numbered = prefix + "-walnuts-draws-" + std::to_string(trial) + ".csv";
test_adaptive_walnuts(stan_model, sample_csv_file_numbered, trial_seed, iter_warmup, iter_sampling);
}
std::quick_exit(0); // crashes without this---not stan_model dtor, prob dlclose_deleter

Interesting, does the preexisting example-stan program crash at exit as well? This probably isn't relevant for the purposes of evaluation, but it could be important for later Python bindings, etc.


I wonder if this could be related to roualdes/bridgestan#111

Also, when Stan is compiled using zig cc it segfaults on exit, and since Zig adds flags that trap on undefined behavior, maybe that is related as well?


Yeah, seems likely. We should probably just leak the handle for now, and try to figure out roualdes/bridgestan#111. My guess is we need a bs_finalize() function that calls tbb::finalize() and maybe also some math library specific stuff.

@WardBrian left a comment

A few things that stood out to me, mostly focusing on the files mentioned in "III. Gradients until within error bound"


print("Sampling")
print(f"{0:5d}. min(ESS) = {0:.1e}")
for b in range(1, max_blocks + 1):

Note that this, as currently coded, isn't much different from one big run in terms of disk usage -- cmdstanpy cleans up its temporary directory at exit, not when the fit object is gc'd.

To fix this, you could use the TemporaryDirectory context manager and the output_dir argument to sample, as in the sketch below.
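For instance, something along these lines (a minimal sketch; sample_block is a name I made up for illustration, not something from the PR):

```python
# Sketch only: sample one block into a TemporaryDirectory so the CmdStan CSV
# files are deleted as soon as the draws have been read into memory.
import tempfile

import numpy as np
from cmdstanpy import CmdStanModel

def sample_block(model: CmdStanModel, data_file: str, block_size: int,
                 seed: int) -> np.ndarray:
    with tempfile.TemporaryDirectory() as tmp_dir:
        fit = model.sample(data=data_file, chains=1, seed=seed,
                           iter_sampling=block_size, show_progress=False,
                           output_dir=tmp_dir)
        # read the draws while the CSVs still exist; the directory (and the
        # CSVs in it) is removed when the with-block exits
        return fit.draws(concat_chains=True)
```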

Comment on lines +47 to +51
da = xr.DataArray(a[np.newaxis, :, :], dims=("chain", "draw", "var"))
ds = az.ess(da, method="bulk")
data_var = next(iter(ds.data_vars))
vec = ds[data_var].values
return np.asarray(np.squeeze(vec), dtype=np.float64)

This is pretty hard to read and verify without manually running it a few times. Assuming you're confident it's doing the right thing, that's probably fine, but I just wanted to flag it.
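For what it's worth, here's essentially the same logic with the intermediate steps named and commented, which might be a bit easier to check by eye (assuming a is a single-chain array of shape (draws, vars)):

```python
import arviz as az
import numpy as np
import xarray as xr

def column_bulk_ess(a: np.ndarray) -> np.ndarray:
    """Bulk ESS for each column of a single-chain (draws, vars) array."""
    # ArviZ expects a leading chain dimension, so add one of size 1
    da = xr.DataArray(a[np.newaxis, :, :], dims=("chain", "draw", "var"))
    # az.ess returns a Dataset with a single data variable holding the
    # per-"var" ESS values
    ess_ds = az.ess(da, method="bulk")
    var_name = next(iter(ess_ds.data_vars))
    return np.asarray(np.squeeze(ess_ds[var_name].values), dtype=np.float64)
```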

DynamicStanModel stan_model(model_so_file_c_str, data_json_file_c_str, seed);
for (int trial = 0; trial < trials; ++trial) {
std::cout << "trial = " << trial << std::endl;
unsigned int trial_seed = seed + static_cast<unsigned int>(17 * (trial + 1));

Unsure whether it matters, but I'm curious why you're doing seed + 17 * trial here but seed + trial in eval-nuts.py?

for (Integer n = 0; n < iter_sampling; ++n) {
auto draw = sampler();
model.constrain_draw(draw, draws.col(n));
lp_grads[static_cast<std::size_t>(n)] = logp_grad_calls;

I'm curious why these static casts are necessary, and why you wouldn't just declare the loop as for (size_t n ... instead.

* `examples/stan-warmup.py`


## III. Gradients until within error bound

I would recommend collecting these steps into a single bash or Python script at some point, to make running an experiment a bit less manual.
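A hypothetical driver along those lines (a sketch only; the script names and paths are placeholders based on files mentioned in this thread, not a verified pipeline):

```python
# Hypothetical driver: run the experiment steps in order and stop on failure.
import subprocess
import sys

# Placeholder steps; substitute the actual scripts and arguments.
STEPS = [
    ["python3", "examples/reference-moments.py"],
    ["python3", "examples/stan-warmup.py"],
    ["python3", "examples/eval-nuts.py"],
]

def main() -> int:
    for step in STEPS:
        print(">>>", " ".join(step), flush=True)
        result = subprocess.run(step)
        if result.returncode != 0:
            print(f"step failed with exit code {result.returncode}",
                  file=sys.stderr)
            return result.returncode
    return 0

if __name__ == "__main__":
    sys.exit(main())
```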
