Skip to content

Commit b3c562c

Browse files
committed
Fix MMD shape determination for offline autograph mode.
1 parent 7648dd6 commit b3c562c

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

bayesflow/losses.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
# SOFTWARE.
2020

2121
import tensorflow as tf
22-
import tensorflow_probability as tfp
2322

2423
from bayesflow.computational_utilities import maximum_mean_discrepancy
2524

@@ -62,7 +61,7 @@ def kl_latent_space_student(v, z, log_det_J):
6261
z : tf.Tensor of shape (batch_size, ...)
6362
The (latent transformed) target variables
6463
log_det_J : tf.Tensor of shape (batch_size, ...)
65-
The logartihm of the Jacobian determinant of the transformation.
64+
The logarithm of the Jacobian determinant of the transformation.
6665
6766
Returns
6867
-------
@@ -131,7 +130,7 @@ def mmd_summary_space(summary_outputs, z_dist=tf.random.normal, kernel="gaussian
131130
The kernel function to use for MMD computation.
132131
"""
133132

134-
z_samples = z_dist(summary_outputs.shape)
133+
z_samples = z_dist(tf.shape(summary_outputs))
135134
mmd_loss = maximum_mean_discrepancy(summary_outputs, z_samples, kernel)
136135
return mmd_loss
137136

0 commit comments

Comments
 (0)