Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions examples/bayes.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions examples/glm.ipynb

Large diffs are not rendered by default.

304 changes: 294 additions & 10 deletions examples/rl.ipynb

Large diffs are not rendered by default.

66 changes: 49 additions & 17 deletions pyem/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def fit(
convergence_custom: str | None = None,
convergence_crit: float = 1e-3,
convergence_precision: int = 6,
njobs: int = -1,
njobs: int = -2,
optim_method: str = "BFGS",
optim_options: dict | None = None,
max_restarts: int = 2,
Expand Down Expand Up @@ -367,11 +367,18 @@ def recover(self, true_params: np.ndarray, pr_inputs: List[str], simulate_func:
self._out = recovery_model._out
return recovery_dict

def plot_recovery(self, recovery_dict: dict, show_line: bool = True,
figsize: tuple = (10, 4), show: bool = True) -> plt.Figure:
def plot_recovery(
self,
recovery_dict: dict,
show_line: bool = True,
figsize: tuple | None = None,
show: bool = True
) -> plt.Figure:
"""
Plot parameter recovery as scatter plots of simulated vs estimated parameters.

Creates 3 columns with as many rows as needed, with compact spacing and
subplot sizes that scale with the grid.

Args:
recovery_dict: Output from recover() method, containing:
- 'true_params' (array-like, shape [n_sims, n_params])
Expand All @@ -387,37 +394,62 @@ def plot_recovery(self, recovery_dict: dict, show_line: bool = True,
estimated_params = recovery_dict['estimated_params']
nparams = true_params.shape[1]

# Create 1 x nparams layout (keep squeeze=False to always get 2D array, then ravel)
fig, axes = plt.subplots(1, nparams, figsize=figsize, squeeze=False)
axes = axes.ravel()
# Grid: 3 columns, compute rows
ncols = 3
nrows = int(np.ceil(nparams / ncols))

# Figure size: scale per-subplot to avoid tiny axes.
# Aim for 5x5 inches per subplot (square-ish data area works well here).
per_ax_w, per_ax_h = 3.5, 3.5
fig_w = per_ax_w * ncols
fig_h = per_ax_h * nrows
if figsize is None:
figsize = (fig_w, fig_h)

fig, axes = plt.subplots(
nrows, ncols,
figsize=figsize,
constrained_layout=True, # let Matplotlib handle spacing
squeeze=False
)

# In case self.param_names is longer than nparams
# Fine-tune constrained_layout paddings (reduces big gutters)
# w_pad/h_pad: padding around the figure edges; wspace/hspace: padding between subplots
fig.get_layout_engine().set() #h_pad=X, w_pad=Y, hspace=Z, wspace=W

axes = axes.ravel()
names = list(self.param_names)[:nparams]

for i, param_name in enumerate(names):
ax = axes[i]

# Use the shared plotting helper
plotting.plot_scatter(
true_params[:, i], f'True {param_name}',
estimated_params[:, i], f'Estimated {param_name}',
ax=ax,
show_line=show_line,
equal_limits=True,
s=75,
equal_limits=True, # still equalize limits (handled w/ box aspect below)
s=100, # slightly smaller markers to reduce overlap
alpha=0.6,
colorname='royalblue',
annotate=True,
)
# Title & tick/label sizing tuned so they don't collide with data
ax.tick_params(labelsize=12)
ax.xaxis.label.set_size(12)
ax.yaxis.label.set_size(12)

# Title
ax.set_title(f'{param_name}')
# Keep plots square without blowing up gutters
# (avoid ax.set_aspect('equal', adjustable='box') here)
try:
ax.set_box_aspect(1) # Matplotlib >=3.4
except Exception:
pass

# Hide any unused axes (just in case)
# Remove unused axes completely so they don't consume layout space
for j in range(nparams, len(axes)):
axes[j].set_visible(False)
axes[j].remove()

plt.tight_layout()
if show:
plt.show()

return fig
2 changes: 1 addition & 1 deletion pyem/core/em.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class EMConfig:
convergence_custom: Literal["relative_npl","running_average", None] = None
convergence_crit: float = 1e-3
convergence_precision: int = 6
njobs: int = -1
njobs: int = -2
optim: OptimConfig = field(default_factory=OptimConfig)
seed: int | None = None
max_subject_retries: int = 0 # additional retries if optimizer fails badly
Expand Down
5 changes: 2 additions & 3 deletions pyem/models/bayes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from __future__ import annotations
import numpy as np
from ..utils.math import norm2alpha, calc_fval

Expand All @@ -15,7 +14,7 @@ def _generate_fishp(lambda1: float, n_fish: int) -> np.ndarray:
fishp = np.eye(n_fish) * m + (1 - np.eye(n_fish)) * s
return fishp

def simulate(params: np.ndarray, nblocks: int = 10, ntrials: int = 15,
def bayes_sim(params: np.ndarray, nblocks: int = 10, ntrials: int = 15,
n_fish: int = 3) -> dict:
"""Simulate the fish task described in the repository documentation.

Expand Down Expand Up @@ -68,7 +67,7 @@ def simulate(params: np.ndarray, nblocks: int = 10, ntrials: int = 15,
"ponds": ponds,
}

def fit(params, choices, observations, prior=None, output: str = 'npl'):
def bayes_fit(params, choices, observations, prior=None, output: str = 'npl'):
"""Likelihood for the fish task.

Parameters are supplied in Gaussian space and transformed to ``lambda1``
Expand Down
Loading