Skip to content

Commit f46374d

Browse files
committed
Bugfix ranks for various shapes 3
1 parent 6e652b7 commit f46374d

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

bayesflow/coupling_networks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,10 @@ def _semantic_spline_parameters(self, parameters):
436436
"""
437437

438438
shape = tf.shape(parameters)
439-
if len(shape) == 2:
439+
rank = len(shape)
440+
if rank == 2:
440441
new_shape = (shape[0], self.dim_out, -1)
441-
elif len(shape) == 3:
442+
elif rank == 3:
442443
new_shape = (shape[0], shape[1], self.dim_out, -1)
443444
else:
444445
raise NotImplementedError("Spline flows can currently only operate on 2D and 3D inputs!")

bayesflow/helper_networks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def call(self, target, condition, **kwargs):
112112

113113
# Handle 3D case for a set-flow and repeat condition over
114114
# the second `time` or `n_observations` axis of `target``
115-
if tf.rank(target) == 3 and tf.rank(condition) == 2:
115+
if len(tf.shape(target)) == 3 and len(tf.shape(condition)) == 2:
116116
shape = tf.shape(target)
117117
condition = tf.expand_dims(condition, 1)
118118
condition = tf.tile(condition, [1, shape[1], 1])

bayesflow/inference_networks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def forward(self, targets, condition, **kwargs):
188188
condition_shape = tf.shape(condition)
189189

190190
# Needs to be concatinable with condition
191-
if tf.rank(condition) == 2:
191+
if len(condition_shape) == 2:
192192
shape_scale = (condition_shape[0], 1)
193193
else:
194194
shape_scale = (condition_shape[0], condition_shape[1], 1)
@@ -201,7 +201,7 @@ def forward(self, targets, condition, **kwargs):
201201
noise_scale = tf.zeros(shape=shape_scale) + self.soft_low
202202

203203
# Perturb data with noise (will broadcast to all dimensions)
204-
if len(shape_scale) == 2 and tf.rank(targets) == 3:
204+
if len(shape_scale) == 2 and len(target_shape) == 3:
205205
targets += tf.expand_dims(noise_scale, axis=1) * tf.random.normal(shape=target_shape)
206206
else:
207207
targets += noise_scale * tf.random.normal(shape=target_shape)
@@ -228,7 +228,7 @@ def inverse(self, z, condition, **kwargs):
228228
if self.soft_flow and condition is not None:
229229
# Needs to be concatinable with condition
230230
shape_scale = (
231-
(condition.shape[0], 1) if tf.rank(condition) == 2 else (condition.shape[0], condition.shape[1], 1)
231+
(condition.shape[0], 1) if len(condition.shape) == 2 else (condition.shape[0], condition.shape[1], 1)
232232
)
233233
noise_scale = tf.zeros(shape=shape_scale) + 2.0 * self.soft_low
234234

0 commit comments

Comments
 (0)