We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d74ba9c commit 15f6f64Copy full SHA for 15f6f64
aeppl/mixture.py
@@ -337,9 +337,15 @@ def switch_ifelse_mixture_replace(fgraph, node):
337
old_mixture_rv.dtype,
338
old_mixture_rv.broadcastable,
339
)
340
- new_node = mix_op.make_node(
341
- *([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
342
- )
+
+ if node.inputs[0].ndim == 0:
+ # as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs
343
+ # created using at.stack and Subtensor indexing
344
+ new_node = mix_op.make_node(
345
+ *([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
346
+ )
347
+ else:
348
+ new_node = mix_op.make_node(*([NoneConst, node.inputs[0]] + mixture_rvs))
349
350
new_mixture_rv = new_node.default_output()
351
0 commit comments