Skip to content

Commit 15f6f64

Browse files
Allow vector-valued indices for switch/ifelse mixture sub-graphs
1 parent d74ba9c commit 15f6f64

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

aeppl/mixture.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,15 @@ def switch_ifelse_mixture_replace(fgraph, node):
337337
old_mixture_rv.dtype,
338338
old_mixture_rv.broadcastable,
339339
)
340-
new_node = mix_op.make_node(
341-
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
342-
)
340+
341+
if node.inputs[0].ndim == 0:
342+
# as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs
343+
# created using at.stack and Subtensor indexing
344+
new_node = mix_op.make_node(
345+
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
346+
)
347+
else:
348+
new_node = mix_op.make_node(*([NoneConst, node.inputs[0]] + mixture_rvs))
343349

344350
new_mixture_rv = new_node.default_output()
345351

0 commit comments

Comments
 (0)