Skip to content

Commit d0c9f1e

Browse files
Merge pull request #89 from elseml/Development
Add optional moving average line of training losses
2 parents ba1d4d4 + fe4f7ba commit d0c9f1e

File tree

1 file changed

+51
-32
lines changed

1 file changed

+51
-32
lines changed

bayesflow/diagnostics.py

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,8 @@ def plot_posterior_2d(
788788
def plot_losses(
789789
train_losses,
790790
val_losses=None,
791+
moving_average=False,
792+
ma_window_fraction=0.01,
791793
fig_size=None,
792794
train_color="#8f2727",
793795
val_color="black",
@@ -803,31 +805,35 @@ def plot_losses(
803805
Parameters
804806
----------
805807
806-
train_losses : pd.DataFrame
808+
train_losses : pd.DataFrame
807809
The (plottable) history as returned by a train_[...] method of a ``Trainer`` instance.
808810
Alternatively, you can just pass a data frame of validation losses instead of train losses,
809811
if you only want to plot the validation loss.
810-
val_losses : pd.DataFrame or None, optional, default: None
812+
val_losses : pd.DataFrame or None, optional, default: None
811813
The (plottable) validation history as returned by a train_[...] method of a ``Trainer`` instance.
812814
If left ``None``, only train losses are plotted. Should have the same number of columns
813815
as ``train_losses``.
814-
fig_size : tuple or None, optional, default: None
816+
moving_average : bool, optional, default: False
817+
A flag for adding a moving average line of the train_losses.
818+
ma_window_fraction : int, optional, default: 0.01
819+
Window size for the moving average as a fraction of total training steps.
820+
fig_size : tuple or None, optional, default: None
815821
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
816-
train_color : str, optional, default: '#8f2727'
822+
train_color : str, optional, default: '#8f2727'
817823
The color for the train loss trajectory
818-
val_color : str, optional, default: black
824+
val_color : str, optional, default: black
819825
The color for the optional validation loss trajectory
820-
lw_train : int, optional, default: 2
826+
lw_train : int, optional, default: 2
821827
The linewidth for the training loss curve
822-
lw_val : int, optional, default: 3
828+
lw_val : int, optional, default: 3
823829
The linewidth for the validation loss curve
824-
grid_alpha : float, optional, default 0.5
830+
grid_alpha : float, optional, default 0.5
825831
The opacity factor for the background gridlines
826-
legend_fontsize : int, optional, default: 14
832+
legend_fontsize : int, optional, default: 14
827833
The font size of the legend text
828-
label_fontsize : int, optional, default: 14
834+
label_fontsize : int, optional, default: 14
829835
The font size of the y-label text
830-
title_fontsize : int, optional, default: 16
836+
title_fontsize : int, optional, default: 16
831837
The font size of the title text
832838
833839
Returns
@@ -864,6 +870,11 @@ def plot_losses(
864870
for i, ax in enumerate(looper):
865871
# Plot train curve
866872
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
873+
if moving_average:
874+
moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
875+
smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
876+
ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
877+
867878
# Plot optional val curve
868879
if val_losses is not None:
869880
if i < val_losses.shape[1]:
@@ -1172,10 +1183,10 @@ def plot_confusion_matrix(
11721183
ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
11731184
ax.set_xticklabels(model_names, fontsize=tick_fontsize)
11741185
if xtick_rotation:
1175-
plt.xticks(rotation=xtick_rotation, ha="right")
1186+
plt.xticks(rotation=xtick_rotation, ha="right")
11761187
ax.set_yticklabels(model_names, fontsize=tick_fontsize)
11771188
if ytick_rotation:
1178-
plt.yticks(rotation=ytick_rotation)
1189+
plt.yticks(rotation=ytick_rotation)
11791190
ax.set_xlabel("Predicted model", fontsize=tick_fontsize)
11801191
ax.set_ylabel("True model", fontsize=tick_fontsize)
11811192

@@ -1192,16 +1203,18 @@ def plot_confusion_matrix(
11921203
return fig
11931204

11941205

1195-
def plot_mmd_hypothesis_test(mmd_null,
1196-
mmd_observed=None,
1197-
alpha_level=0.05,
1198-
null_color=(0.16407, 0.020171, 0.577478),
1199-
observed_color="red",
1200-
alpha_color="orange",
1201-
truncate_vlines_at_kde=False,
1202-
xmin=None,
1203-
xmax=None,
1204-
bw_factor=1.5):
1206+
def plot_mmd_hypothesis_test(
1207+
mmd_null,
1208+
mmd_observed=None,
1209+
alpha_level=0.05,
1210+
null_color=(0.16407, 0.020171, 0.577478),
1211+
observed_color="red",
1212+
alpha_color="orange",
1213+
truncate_vlines_at_kde=False,
1214+
xmin=None,
1215+
xmax=None,
1216+
bw_factor=1.5,
1217+
):
12051218
"""
12061219
12071220
Parameters
@@ -1242,25 +1255,31 @@ def draw_vline_to_kde(x, kde_object, color, label=None, **kwargs):
12421255
def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs):
12431256
kde_x, kde_y = kde_object.lines[0].get_data()
12441257
if x_end is not None:
1245-
plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start) & (kde_x <= x_end),
1246-
interpolate=True, **kwargs)
1258+
plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start) & (kde_x <= x_end), interpolate=True, **kwargs)
12471259
else:
1248-
plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start),
1249-
interpolate=True, **kwargs)
1260+
plt.fill_between(kde_x, kde_y, where=(kde_x >= x_start), interpolate=True, **kwargs)
12501261

12511262
f = plt.figure(figsize=(8, 4))
12521263

12531264
kde = sns.kdeplot(mmd_null, fill=False, linewidth=0, bw_adjust=bw_factor)
1254-
sns.kdeplot(mmd_null, fill=True, alpha=.12, color=null_color, bw_adjust=bw_factor)
1265+
sns.kdeplot(mmd_null, fill=True, alpha=0.12, color=null_color, bw_adjust=bw_factor)
12551266

12561267
if truncate_vlines_at_kde:
12571268
draw_vline_to_kde(x=mmd_observed, kde_object=kde, color=observed_color, label=r"Observed data")
12581269
else:
1259-
plt.vlines(x=mmd_observed, ymin=0, ymax=plt.gca().get_ylim()[1], color=observed_color, linewidth=3,
1260-
label=r"Observed data")
1270+
plt.vlines(
1271+
x=mmd_observed,
1272+
ymin=0,
1273+
ymax=plt.gca().get_ylim()[1],
1274+
color=observed_color,
1275+
linewidth=3,
1276+
label=r"Observed data",
1277+
)
12611278

12621279
mmd_critical = np.quantile(mmd_null, 1 - alpha_level)
1263-
fill_area_under_kde(kde, mmd_critical, color=alpha_color, alpha=0.5, label=fr"{int(alpha_level*100)}% rejection area")
1280+
fill_area_under_kde(
1281+
kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level*100)}% rejection area"
1282+
)
12641283

12651284
if truncate_vlines_at_kde:
12661285
draw_vline_to_kde(x=mmd_critical, kde_object=kde, color=alpha_color)
@@ -1273,7 +1292,7 @@ def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs):
12731292
plt.ylabel("")
12741293
plt.yticks([])
12751294
plt.xlim(xmin, xmax)
1276-
plt.tick_params(axis='both', which='major', labelsize=16)
1295+
plt.tick_params(axis="both", which="major", labelsize=16)
12771296

12781297
plt.legend(fontsize=20)
12791298
sns.despine()

0 commit comments

Comments
 (0)