@@ -788,6 +788,8 @@ def plot_posterior_2d(
788
788
def plot_losses (
789
789
train_losses ,
790
790
val_losses = None ,
791
+ moving_average = False ,
792
+ ma_window_fraction = 0.01 ,
791
793
fig_size = None ,
792
794
train_color = "#8f2727" ,
793
795
val_color = "black" ,
@@ -803,31 +805,35 @@ def plot_losses(
803
805
Parameters
804
806
----------
805
807
806
- train_losses : pd.DataFrame
808
+ train_losses : pd.DataFrame
807
809
The (plottable) history as returned by a train_[...] method of a ``Trainer`` instance.
808
810
Alternatively, you can just pass a data frame of validation losses instead of train losses,
809
811
if you only want to plot the validation loss.
810
- val_losses : pd.DataFrame or None, optional, default: None
812
+ val_losses : pd.DataFrame or None, optional, default: None
811
813
The (plottable) validation history as returned by a train_[...] method of a ``Trainer`` instance.
812
814
If left ``None``, only train losses are plotted. Should have the same number of columns
813
815
as ``train_losses``.
814
- fig_size : tuple or None, optional, default: None
816
+ moving_average : bool, optional, default: False
817
+ A flag for adding a moving average line of the train_losses.
818
+ ma_window_fraction : int, optional, default: 0.01
819
+ Window size for the moving average as a fraction of total training steps.
820
+ fig_size : tuple or None, optional, default: None
815
821
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
816
- train_color : str, optional, default: '#8f2727'
822
+ train_color : str, optional, default: '#8f2727'
817
823
The color for the train loss trajectory
818
- val_color : str, optional, default: black
824
+ val_color : str, optional, default: black
819
825
The color for the optional validation loss trajectory
820
- lw_train : int, optional, default: 2
826
+ lw_train : int, optional, default: 2
821
827
The linewidth for the training loss curve
822
- lw_val : int, optional, default: 3
828
+ lw_val : int, optional, default: 3
823
829
The linewidth for the validation loss curve
824
- grid_alpha : float, optional, default 0.5
830
+ grid_alpha : float, optional, default 0.5
825
831
The opacity factor for the background gridlines
826
- legend_fontsize : int, optional, default: 14
832
+ legend_fontsize : int, optional, default: 14
827
833
The font size of the legend text
828
- label_fontsize : int, optional, default: 14
834
+ label_fontsize : int, optional, default: 14
829
835
The font size of the y-label text
830
- title_fontsize : int, optional, default: 16
836
+ title_fontsize : int, optional, default: 16
831
837
The font size of the title text
832
838
833
839
Returns
@@ -864,6 +870,11 @@ def plot_losses(
864
870
for i , ax in enumerate (looper ):
865
871
# Plot train curve
866
872
ax .plot (train_step_index , train_losses .iloc [:, i ], color = train_color , lw = lw_train , alpha = 0.9 , label = "Training" )
873
+ if moving_average :
874
+ moving_average_window = int (train_losses .shape [0 ] * ma_window_fraction )
875
+ smoothed_loss = train_losses .iloc [:, i ].rolling (window = moving_average_window ).mean ()
876
+ ax .plot (train_step_index , smoothed_loss , color = "grey" , lw = lw_train , label = "Training (Moving Average)" )
877
+
867
878
# Plot optional val curve
868
879
if val_losses is not None :
869
880
if i < val_losses .shape [1 ]:
@@ -1172,10 +1183,10 @@ def plot_confusion_matrix(
1172
1183
ax .set (xticks = np .arange (cm .shape [1 ]), yticks = np .arange (cm .shape [0 ]))
1173
1184
ax .set_xticklabels (model_names , fontsize = tick_fontsize )
1174
1185
if xtick_rotation :
1175
- plt .xticks (rotation = xtick_rotation , ha = "right" )
1186
+ plt .xticks (rotation = xtick_rotation , ha = "right" )
1176
1187
ax .set_yticklabels (model_names , fontsize = tick_fontsize )
1177
1188
if ytick_rotation :
1178
- plt .yticks (rotation = ytick_rotation )
1189
+ plt .yticks (rotation = ytick_rotation )
1179
1190
ax .set_xlabel ("Predicted model" , fontsize = tick_fontsize )
1180
1191
ax .set_ylabel ("True model" , fontsize = tick_fontsize )
1181
1192
@@ -1192,16 +1203,18 @@ def plot_confusion_matrix(
1192
1203
return fig
1193
1204
1194
1205
1195
- def plot_mmd_hypothesis_test (mmd_null ,
1196
- mmd_observed = None ,
1197
- alpha_level = 0.05 ,
1198
- null_color = (0.16407 , 0.020171 , 0.577478 ),
1199
- observed_color = "red" ,
1200
- alpha_color = "orange" ,
1201
- truncate_vlines_at_kde = False ,
1202
- xmin = None ,
1203
- xmax = None ,
1204
- bw_factor = 1.5 ):
1206
+ def plot_mmd_hypothesis_test (
1207
+ mmd_null ,
1208
+ mmd_observed = None ,
1209
+ alpha_level = 0.05 ,
1210
+ null_color = (0.16407 , 0.020171 , 0.577478 ),
1211
+ observed_color = "red" ,
1212
+ alpha_color = "orange" ,
1213
+ truncate_vlines_at_kde = False ,
1214
+ xmin = None ,
1215
+ xmax = None ,
1216
+ bw_factor = 1.5 ,
1217
+ ):
1205
1218
"""
1206
1219
1207
1220
Parameters
@@ -1242,25 +1255,31 @@ def draw_vline_to_kde(x, kde_object, color, label=None, **kwargs):
1242
1255
def fill_area_under_kde (kde_object , x_start , x_end = None , ** kwargs ):
1243
1256
kde_x , kde_y = kde_object .lines [0 ].get_data ()
1244
1257
if x_end is not None :
1245
- plt .fill_between (kde_x , kde_y , where = (kde_x >= x_start ) & (kde_x <= x_end ),
1246
- interpolate = True , ** kwargs )
1258
+ plt .fill_between (kde_x , kde_y , where = (kde_x >= x_start ) & (kde_x <= x_end ), interpolate = True , ** kwargs )
1247
1259
else :
1248
- plt .fill_between (kde_x , kde_y , where = (kde_x >= x_start ),
1249
- interpolate = True , ** kwargs )
1260
+ plt .fill_between (kde_x , kde_y , where = (kde_x >= x_start ), interpolate = True , ** kwargs )
1250
1261
1251
1262
f = plt .figure (figsize = (8 , 4 ))
1252
1263
1253
1264
kde = sns .kdeplot (mmd_null , fill = False , linewidth = 0 , bw_adjust = bw_factor )
1254
- sns .kdeplot (mmd_null , fill = True , alpha = .12 , color = null_color , bw_adjust = bw_factor )
1265
+ sns .kdeplot (mmd_null , fill = True , alpha = 0 .12 , color = null_color , bw_adjust = bw_factor )
1255
1266
1256
1267
if truncate_vlines_at_kde :
1257
1268
draw_vline_to_kde (x = mmd_observed , kde_object = kde , color = observed_color , label = r"Observed data" )
1258
1269
else :
1259
- plt .vlines (x = mmd_observed , ymin = 0 , ymax = plt .gca ().get_ylim ()[1 ], color = observed_color , linewidth = 3 ,
1260
- label = r"Observed data" )
1270
+ plt .vlines (
1271
+ x = mmd_observed ,
1272
+ ymin = 0 ,
1273
+ ymax = plt .gca ().get_ylim ()[1 ],
1274
+ color = observed_color ,
1275
+ linewidth = 3 ,
1276
+ label = r"Observed data" ,
1277
+ )
1261
1278
1262
1279
mmd_critical = np .quantile (mmd_null , 1 - alpha_level )
1263
- fill_area_under_kde (kde , mmd_critical , color = alpha_color , alpha = 0.5 , label = fr"{ int (alpha_level * 100 )} % rejection area" )
1280
+ fill_area_under_kde (
1281
+ kde , mmd_critical , color = alpha_color , alpha = 0.5 , label = rf"{ int (alpha_level * 100 )} % rejection area"
1282
+ )
1264
1283
1265
1284
if truncate_vlines_at_kde :
1266
1285
draw_vline_to_kde (x = mmd_critical , kde_object = kde , color = alpha_color )
@@ -1273,7 +1292,7 @@ def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs):
1273
1292
plt .ylabel ("" )
1274
1293
plt .yticks ([])
1275
1294
plt .xlim (xmin , xmax )
1276
- plt .tick_params (axis = ' both' , which = ' major' , labelsize = 16 )
1295
+ plt .tick_params (axis = " both" , which = " major" , labelsize = 16 )
1277
1296
1278
1297
plt .legend (fontsize = 20 )
1279
1298
sns .despine ()
0 commit comments