Skip to content

Commit f5f7a6a

Browse files
Merge pull request #116 from stefanradev93/Development
Update README.md with forum
2 parents ec61454 + b098a42 commit f5f7a6a

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ For starters, check out some of our walk-through notebooks:
1919
7. [Model comparison for cognitive models](examples/Model_Comparison_MPT.ipynb)
2020
8. [Hierarchical model comparison for cognitive models](examples/Hierarchical_Model_Comparison_MPT.ipynb)
2121

22-
## Project Documentation
22+
## Documentation \& Help
2323

24-
The project documentation is available at <https://bayesflow.org>
24+
The project documentation is available at <https://bayesflow.org>. Please use the [BayesFlow Forums](https://discuss.bayesflow.org/) for any BayesFlow-related questions and discussions, and [GitHub Issues](https://github.com/stefanradev93/BayesFlow/issues) for bug reports and feature requests.
2525

2626
## Installation
2727

bayesflow/diagnostics.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def plot_recovery(
5353
n_row=None,
5454
xlabel="Ground truth",
5555
ylabel="Estimated",
56-
**kwargs
56+
**kwargs,
5757
):
5858
"""Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
5959
The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
@@ -110,7 +110,7 @@ def plot_recovery(
110110
**kwargs : optional
111111
Additional keyword arguments passed to ax.errorbar or ax.scatter.
112112
Example: `rasterized=True` to reduce PDF file size with many dots
113-
113+
114114
Returns
115115
-------
116116
f : plt.Figure - the figure instance for optional saving
@@ -240,7 +240,7 @@ def plot_z_score_contraction(
240240
tick_fontsize=12,
241241
color="#8f2727",
242242
n_col=None,
243-
n_row=None
243+
n_row=None,
244244
):
245245
"""Implements a graphical check for global model sensitivity by plotting the posterior
246246
z-score over the posterior contraction for each set of posterior samples in ``post_samples``
@@ -567,7 +567,7 @@ def plot_sbc_histograms(
567567
tick_fontsize=12,
568568
hist_color="#a34f4f",
569569
n_row=None,
570-
n_col=None
570+
n_col=None,
571571
):
572572
"""Creates and plots publication-ready histograms of rank statistics for simulation-based calibration
573573
(SBC) checks according to [1].
@@ -910,7 +910,7 @@ def plot_losses(
910910
for i, ax in enumerate(looper):
911911
# Plot train curve
912912
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
913-
if moving_average:
913+
if moving_average and train_losses.columns[i] == "Loss":
914914
moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
915915
smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
916916
ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
@@ -929,7 +929,7 @@ def plot_losses(
929929
)
930930
# Schmuck
931931
ax.set_xlabel("Training step #", fontsize=label_fontsize)
932-
ax.set_ylabel("Loss value", fontsize=label_fontsize)
932+
ax.set_ylabel("Value", fontsize=label_fontsize)
933933
sns.despine(ax=ax)
934934
ax.grid(alpha=grid_alpha)
935935
ax.set_title(train_losses.columns[i], fontsize=title_fontsize)
@@ -1061,7 +1061,7 @@ def plot_calibration_curves(
10611061
fig_size=None,
10621062
color="#8f2727",
10631063
n_row=None,
1064-
n_col=None
1064+
n_col=None,
10651065
):
10661066
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
10671067
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
@@ -1114,7 +1114,6 @@ def plot_calibration_curves(
11141114
elif n_row is not None and n_col is None:
11151115
n_col = int(np.ceil(num_models / n_row))
11161116

1117-
11181117
# Compute calibration
11191118
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)
11201119

@@ -1273,7 +1272,13 @@ def plot_confusion_matrix(
12731272
for i in range(cm.shape[0]):
12741273
for j in range(cm.shape[1]):
12751274
ax.text(
1276-
j, i, format(cm[i, j], fmt), fontsize=value_fontsize, ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
1275+
j,
1276+
i,
1277+
format(cm[i, j], fmt),
1278+
fontsize=value_fontsize,
1279+
ha="center",
1280+
va="center",
1281+
color="white" if cm[i, j] > thresh else "black",
12771282
)
12781283
if title:
12791284
ax.set_title("Confusion Matrix", fontsize=title_fontsize)

bayesflow/summary_networks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def __init__(
140140
# Construct final attention layer, which will perform cross-attention
141141
# between the outputs ot the self-attention layers and the dynamic template
142142
if bidirectional:
143-
final_input_dim = template_dim*2
143+
final_input_dim = template_dim * 2
144144
else:
145145
final_input_dim = template_dim
146146
self.output_attention = MultiHeadAttentionBlock(
@@ -184,7 +184,8 @@ def call(self, x, **kwargs):
184184

185185
class SetTransformer(tf.keras.Model):
186186
"""Implements the set transformer architecture from [1] which ultimately represents
187-
a learnable permutation-invariant function.
187+
a learnable permutation-invariant function. Designed to naturally model interactions in
188+
the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.
188189
189190
[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
190191
Set transformer: A framework for attention-based permutation-invariant neural networks.

0 commit comments

Comments
 (0)