Skip to content

Commit 3f5ae48

Browse files
Add tests with non-null size for ifelse and switch mixtures
1 parent 15f6f64 commit 3f5ae48

File tree

2 files changed

+110
-31
lines changed

2 files changed

+110
-31
lines changed

aeppl/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def switch_ifelse_mixture_replace(fgraph, node):
345345
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
346346
)
347347
else:
348-
new_node = mix_op.make_node(*([NoneConst, node.inputs[0]] + mixture_rvs))
348+
new_node = mix_op.make_node(*([at.constant(0), node.inputs[0]] + mixture_rvs))
349349

350350
new_mixture_rv = new_node.default_output()
351351

tests/test_mixture.py

Lines changed: 109 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -249,25 +249,6 @@ def test_hetero_mixture_binomial(p_val, size):
249249
(),
250250
0,
251251
),
252-
(
253-
(
254-
np.array(0, dtype=aesara.config.floatX),
255-
np.array(1, dtype=aesara.config.floatX),
256-
),
257-
(
258-
np.array(0.5, dtype=aesara.config.floatX),
259-
np.array(0.5, dtype=aesara.config.floatX),
260-
),
261-
(
262-
np.array(100, dtype=aesara.config.floatX),
263-
np.array(1, dtype=aesara.config.floatX),
264-
),
265-
np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX),
266-
(),
267-
(),
268-
(),
269-
0,
270-
),
271252
(
272253
(
273254
np.array(0, dtype=aesara.config.floatX),
@@ -713,14 +694,118 @@ def test_mixture_with_DiracDelta():
713694
assert m_vv in logp_res
714695

715696

716-
@pytest.mark.parametrize("op", [at.switch, ifelse])
717-
def test_switch_ifelse_mixture(op):
697+
@pytest.mark.parametrize(
698+
"op, X_args, Y_args, p_val, comp_size, idx_size",
699+
[
700+
[op] + list(test_args)
701+
for op in [at.switch, ifelse]
702+
for test_args in [
703+
(
704+
(
705+
np.array(-10, dtype=aesara.config.floatX),
706+
np.array(0.1, dtype=aesara.config.floatX),
707+
),
708+
(
709+
np.array(10, dtype=aesara.config.floatX),
710+
np.array(0.1, dtype=aesara.config.floatX),
711+
),
712+
np.array(0.5, dtype=aesara.config.floatX),
713+
(),
714+
(),
715+
),
716+
(
717+
(
718+
np.array(-10, dtype=aesara.config.floatX),
719+
np.array(0.1, dtype=aesara.config.floatX),
720+
),
721+
(
722+
np.array(10, dtype=aesara.config.floatX),
723+
np.array(0.1, dtype=aesara.config.floatX),
724+
),
725+
np.array(0.5, dtype=aesara.config.floatX),
726+
(),
727+
(6,),
728+
),
729+
(
730+
(
731+
np.array([10, 20], dtype=aesara.config.floatX),
732+
np.array(0.1, dtype=aesara.config.floatX),
733+
),
734+
(
735+
np.array([-10, -20], dtype=aesara.config.floatX),
736+
np.array(0.1, dtype=aesara.config.floatX),
737+
),
738+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
739+
(2,),
740+
(2,),
741+
),
742+
(
743+
(
744+
np.array([10, 20], dtype=aesara.config.floatX),
745+
np.array(0.1, dtype=aesara.config.floatX),
746+
),
747+
(
748+
np.array([-10, -20], dtype=aesara.config.floatX),
749+
np.array(0.1, dtype=aesara.config.floatX),
750+
),
751+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
752+
None,
753+
None,
754+
),
755+
(
756+
(
757+
np.array(-10, dtype=aesara.config.floatX),
758+
np.array(0.1, dtype=aesara.config.floatX),
759+
),
760+
(
761+
np.array(10, dtype=aesara.config.floatX),
762+
np.array(0.1, dtype=aesara.config.floatX),
763+
),
764+
np.array(0.5, dtype=aesara.config.floatX),
765+
(2, 3),
766+
(2, 3),
767+
),
768+
(
769+
(
770+
np.array(10, dtype=aesara.config.floatX),
771+
np.array(0.1, dtype=aesara.config.floatX),
772+
),
773+
(
774+
np.array(-10, dtype=aesara.config.floatX),
775+
np.array(0.1, dtype=aesara.config.floatX),
776+
),
777+
np.array(0.5, dtype=aesara.config.floatX),
778+
(2, 3),
779+
(),
780+
),
781+
(
782+
(
783+
np.array(10, dtype=aesara.config.floatX),
784+
np.array(0.1, dtype=aesara.config.floatX),
785+
),
786+
(
787+
np.array(-10, dtype=aesara.config.floatX),
788+
np.array(0.1, dtype=aesara.config.floatX),
789+
),
790+
np.array(0.5, dtype=aesara.config.floatX),
791+
(3,),
792+
(3,),
793+
),
794+
]
795+
if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse)
796+
],
797+
)
798+
def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size):
799+
"""
800+
The argument size is both the input to srng.normal and the expected
801+
size of the mixture RV Z1_rv
802+
"""
718803
srng = at.random.RandomStream(29833)
719804

720-
X_rv = srng.normal(-10.0, 0.1, name="X")
721-
Y_rv = srng.normal(10.0, 0.1, name="Y")
805+
X_rv = srng.normal(*X_args, size=comp_size, name="X")
806+
Y_rv = srng.normal(*Y_args, size=comp_size, name="Y")
722807

723-
I_rv = srng.bernoulli(0.5, name="I")
808+
I_rv = srng.bernoulli(p_val, size=idx_size, name="I")
724809
i_vv = I_rv.clone()
725810
i_vv.name = "i"
726811

@@ -752,9 +837,3 @@ def test_switch_ifelse_mixture(op):
752837

753838
z1_logp = joint_logprob({Z1_rv: z_vv, I_rv: i_vv})
754839
z2_logp = joint_logprob({Z2_rv: z_vv, I_rv: i_vv})
755-
756-
# below should follow immediately from the equal_computations assertion above
757-
assert equal_computations([z1_logp], [z2_logp])
758-
759-
np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
760-
np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))

0 commit comments

Comments
 (0)