diff --git a/aeppl/mixture.py b/aeppl/mixture.py index ae08f277..a260eca4 100644 --- a/aeppl/mixture.py +++ b/aeppl/mixture.py @@ -10,7 +10,7 @@ node_rewriter, pre_greedy_node_rewriter, ) -from aesara.ifelse import ifelse +from aesara.ifelse import IfElse, ifelse from aesara.scalar.basic import Switch from aesara.tensor.basic import Join, MakeVector from aesara.tensor.elemwise import Elemwise @@ -309,6 +309,31 @@ def mixture_replace(fgraph, node): def switch_mixture_replace(fgraph, node): rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) + if rv_map_feature is None: + return None # pragma: no cover + + old_mixture_rv = node.default_output() + + # Add an extra dimension to the indices so that the `MixtureRV` we + # construct represents a valid + # `at.stack(node.inputs[1:])[f(node.inputs[0])]`, for some function `f`, + # that's equivalent to `at.switch(*node.inputs)`. + out_shape = at.broadcast_shape( + *(tuple(v.shape) for v in node.inputs[1:]), arrays_are_shapes=True + ) + switch_indices = (node.inputs[0],) + tuple(at.arange(s) for s in out_shape) + + # Construct the proxy/intermediate mixture representation + switch_stack = at.stack(node.inputs[::-1])[switch_indices] + switch_stack.name = old_mixture_rv.name + + return mixture_replace.transform(fgraph, switch_stack.owner) + + +@node_rewriter((IfElse,)) +def ifelse_mixture_replace(fgraph, node): + rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) + if rv_map_feature is None: return None # pragma: no cover @@ -332,14 +357,25 @@ def switch_mixture_replace(fgraph, node): new_comp_rv = new_node.outputs[out_idx] mixture_rvs.append(new_comp_rv) + """ + Unlike mixtures generated via at.stack, there is only one condition, i.e. index + for switch/ifelse-defined mixture sub-graphs. However, this condition can be + non-scalar for Switch Ops. + """ mix_op = MixtureRV( 2, old_mixture_rv.type.dtype, old_mixture_rv.type.shape, ) - new_node = mix_op.make_node( - *([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs) - ) + + if node.inputs[0].ndim == 0: + # as_nontensor_scalar to allow graphs to be identical to mixture sub-graphs + # created using at.stack and Subtensor indexing + new_node = mix_op.make_node( + *([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs) + ) + else: + new_node = mix_op.make_node(*([at.constant(0), node.inputs[0]] + mixture_rvs)) new_mixture_rv = new_node.default_output() @@ -420,7 +456,7 @@ def logprob_MixtureRV( logprob_rewrites_db.register( "mixture_replace", EquilibriumGraphRewriter( - [mixture_replace, switch_mixture_replace], + [mixture_replace, switch_mixture_replace, ifelse_mixture_replace], max_use_ratio=aesara.config.optdb__max_use_ratio, ), "basic", diff --git a/tests/test_mixture.py b/tests/test_mixture.py index 0005f25e..a6e6f9c6 100644 --- a/tests/test_mixture.py +++ b/tests/test_mixture.py @@ -4,6 +4,7 @@ import pytest import scipy.stats.distributions as sp from aesara.graph.basic import Variable, equal_computations +from aesara.ifelse import ifelse from aesara.tensor.random.basic import CategoricalRV from aesara.tensor.shape import shape_tuple from aesara.tensor.subtensor import as_index_constant @@ -232,25 +233,6 @@ def test_hetero_mixture_binomial(p_val, size): (), 0, ), - ( - ( - np.array(0, dtype=aesara.config.floatX), - np.array(1, dtype=aesara.config.floatX), - ), - ( - np.array(0.5, dtype=aesara.config.floatX), - np.array(0.5, dtype=aesara.config.floatX), - ), - ( - np.array(100, dtype=aesara.config.floatX), - np.array(1, dtype=aesara.config.floatX), - ), - np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX), - (), - (), - (), - 0, - ), ( ( np.array(0, dtype=aesara.config.floatX), @@ -682,17 +664,122 @@ def test_mixture_with_DiracDelta(): assert M_rv in logp_res -def test_switch_mixture(): +@pytest.mark.parametrize( + "op, X_args, Y_args, p_val, comp_size, idx_size", + [ + [op] + list(test_args) + for op in [at.switch, ifelse] + for test_args in [ + ( + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (), + (), + ), + ( + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (), + (6,), + ), + ( + ( + np.array([10, 20], dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array([-10, -20], dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array([0.9, 0.1], dtype=aesara.config.floatX), + (2,), + (2,), + ), + ( + ( + np.array([10, 20], dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array([-10, -20], dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array([0.9, 0.1], dtype=aesara.config.floatX), + None, + None, + ), + ( + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (2, 3), + (2, 3), + ), + ( + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (2, 3), + (), + ), + ( + ( + np.array(10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + ( + np.array(-10, dtype=aesara.config.floatX), + np.array(0.1, dtype=aesara.config.floatX), + ), + np.array(0.5, dtype=aesara.config.floatX), + (3,), + (3,), + ), + ] + if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse) + ], +) +def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size): + """ + The argument size is both the input to srng.normal and the expected + size of the mixture RV Z1_rv + """ srng = at.random.RandomStream(29833) - X_rv = srng.normal(-10.0, 0.1, name="X") - Y_rv = srng.normal(10.0, 0.1, name="Y") + X_rv = srng.normal(*X_args, size=comp_size, name="X") + Y_rv = srng.normal(*Y_args, size=comp_size, name="Y") - I_rv = srng.bernoulli(0.5, name="I") + I_rv = srng.bernoulli(p_val, size=idx_size, name="I") i_vv = I_rv.clone() i_vv.name = "i" - Z1_rv = at.switch(I_rv, X_rv, Y_rv) + Z1_rv = op(I_rv, X_rv, Y_rv) z_vv = Z1_rv.clone() z_vv.name = "z1"