Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 279 additions & 8 deletions examples/rl.ipynb

Large diffs are not rendered by default.

64 changes: 48 additions & 16 deletions pyem/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,18 @@ def recover(self, true_params: np.ndarray, pr_inputs: List[str], simulate_func:
self._out = recovery_model._out
return recovery_dict

def plot_recovery(self, recovery_dict: dict, show_line: bool = True,
figsize: tuple = (10, 4), show: bool = True) -> plt.Figure:
def plot_recovery(
self,
recovery_dict: dict,
show_line: bool = True,
figsize: tuple | None = None,
show: bool = True
) -> plt.Figure:
"""
Plot parameter recovery as scatter plots of simulated vs estimated parameters.

Creates 3 columns with as many rows as needed, with compact spacing and
subplot sizes that scale with the grid.

Args:
recovery_dict: Output from recover() method, containing:
- 'true_params' (array-like, shape [n_sims, n_params])
Expand All @@ -387,37 +394,62 @@ def plot_recovery(self, recovery_dict: dict, show_line: bool = True,
estimated_params = recovery_dict['estimated_params']
nparams = true_params.shape[1]

# Create 1 x nparams layout (keep squeeze=False to always get 2D array, then ravel)
fig, axes = plt.subplots(1, nparams, figsize=figsize, squeeze=False)
axes = axes.ravel()
# Grid: 3 columns, compute rows
ncols = 3
nrows = int(np.ceil(nparams / ncols))

# Figure size: scale per-subplot to avoid tiny axes.
# Aim for 5x5 inches per subplot (square-ish data area works well here).
per_ax_w, per_ax_h = 3.5, 3.5
fig_w = per_ax_w * ncols
fig_h = per_ax_h * nrows
if figsize is None:
figsize = (fig_w, fig_h)

fig, axes = plt.subplots(
nrows, ncols,
figsize=figsize,
constrained_layout=True, # let Matplotlib handle spacing
squeeze=False
)

# In case self.param_names is longer than nparams
# Fine-tune constrained_layout paddings (reduces big gutters)
# w_pad/h_pad: padding around the figure edges; wspace/hspace: padding between subplots
fig.set_constrained_layout_pads()

axes = axes.ravel()
names = list(self.param_names)[:nparams]

for i, param_name in enumerate(names):
ax = axes[i]

# Use the shared plotting helper
plotting.plot_scatter(
true_params[:, i], f'True {param_name}',
estimated_params[:, i], f'Estimated {param_name}',
ax=ax,
show_line=show_line,
equal_limits=True,
s=75,
equal_limits=True, # still equalize limits (handled w/ box aspect below)
s=100, # slightly smaller markers to reduce overlap
alpha=0.6,
colorname='royalblue',
annotate=True,
)
# Title & tick/label sizing tuned so they don't collide with data
ax.tick_params(labelsize=12)
ax.xaxis.label.set_size(12)
ax.yaxis.label.set_size(12)

# Title
ax.set_title(f'{param_name}')
# Keep plots square without blowing up gutters
# (avoid ax.set_aspect('equal', adjustable='box') here)
try:
ax.set_box_aspect(1) # Matplotlib >=3.4
except Exception:
pass

# Hide any unused axes (just in case)
# Remove unused axes completely so they don't consume layout space
for j in range(nparams, len(axes)):
axes[j].set_visible(False)
axes[j].remove()

plt.tight_layout()
if show:
plt.show()

return fig
Loading