@@ -188,7 +188,7 @@ def forward(self, targets, condition, **kwargs):
188
188
condition_shape = tf .shape (condition )
189
189
190
190
# Needs to be concatinable with condition
191
- if tf . rank ( condition ) == 2 :
191
+ if len ( condition_shape ) == 2 :
192
192
shape_scale = (condition_shape [0 ], 1 )
193
193
else :
194
194
shape_scale = (condition_shape [0 ], condition_shape [1 ], 1 )
@@ -201,7 +201,7 @@ def forward(self, targets, condition, **kwargs):
201
201
noise_scale = tf .zeros (shape = shape_scale ) + self .soft_low
202
202
203
203
# Perturb data with noise (will broadcast to all dimensions)
204
- if len (shape_scale ) == 2 and tf . rank ( targets ) == 3 :
204
+ if len (shape_scale ) == 2 and len ( target_shape ) == 3 :
205
205
targets += tf .expand_dims (noise_scale , axis = 1 ) * tf .random .normal (shape = target_shape )
206
206
else :
207
207
targets += noise_scale * tf .random .normal (shape = target_shape )
@@ -228,7 +228,7 @@ def inverse(self, z, condition, **kwargs):
228
228
if self .soft_flow and condition is not None :
229
229
# Needs to be concatinable with condition
230
230
shape_scale = (
231
- (condition .shape [0 ], 1 ) if tf . rank (condition ) == 2 else (condition .shape [0 ], condition .shape [1 ], 1 )
231
+ (condition .shape [0 ], 1 ) if len (condition . shape ) == 2 else (condition .shape [0 ], condition .shape [1 ], 1 )
232
232
)
233
233
noise_scale = tf .zeros (shape = shape_scale ) + 2.0 * self .soft_low
234
234
0 commit comments