Skip to content

Commit b284e35

Browse files
larryshamalamabrandonwillard
authored andcommitted
Add tests with non-null size for ifelse and switch mixtures
1 parent dc39ef8 commit b284e35

File tree

2 files changed

+110
-31
lines changed

2 files changed

+110
-31
lines changed

aeppl/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def switch_ifelse_mixture_replace(fgraph, node):
345345
*([NoneConst, as_nontensor_scalar(node.inputs[0])] + mixture_rvs)
346346
)
347347
else:
348-
new_node = mix_op.make_node(*([NoneConst, node.inputs[0]] + mixture_rvs))
348+
new_node = mix_op.make_node(*([at.constant(0), node.inputs[0]] + mixture_rvs))
349349

350350
new_mixture_rv = new_node.default_output()
351351

tests/test_mixture.py

Lines changed: 109 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -252,25 +252,6 @@ def test_hetero_mixture_binomial(p_val, size):
252252
(),
253253
0,
254254
),
255-
(
256-
(
257-
np.array(0, dtype=aesara.config.floatX),
258-
np.array(1, dtype=aesara.config.floatX),
259-
),
260-
(
261-
np.array(0.5, dtype=aesara.config.floatX),
262-
np.array(0.5, dtype=aesara.config.floatX),
263-
),
264-
(
265-
np.array(100, dtype=aesara.config.floatX),
266-
np.array(1, dtype=aesara.config.floatX),
267-
),
268-
np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX),
269-
(),
270-
(),
271-
(),
272-
0,
273-
),
274255
(
275256
(
276257
np.array(0, dtype=aesara.config.floatX),
@@ -716,14 +697,118 @@ def test_mixture_with_DiracDelta():
716697
assert m_vv in logp_res
717698

718699

719-
@pytest.mark.parametrize("op", [at.switch, ifelse])
720-
def test_switch_ifelse_mixture(op):
700+
@pytest.mark.parametrize(
701+
"op, X_args, Y_args, p_val, comp_size, idx_size",
702+
[
703+
[op] + list(test_args)
704+
for op in [at.switch, ifelse]
705+
for test_args in [
706+
(
707+
(
708+
np.array(-10, dtype=aesara.config.floatX),
709+
np.array(0.1, dtype=aesara.config.floatX),
710+
),
711+
(
712+
np.array(10, dtype=aesara.config.floatX),
713+
np.array(0.1, dtype=aesara.config.floatX),
714+
),
715+
np.array(0.5, dtype=aesara.config.floatX),
716+
(),
717+
(),
718+
),
719+
(
720+
(
721+
np.array(-10, dtype=aesara.config.floatX),
722+
np.array(0.1, dtype=aesara.config.floatX),
723+
),
724+
(
725+
np.array(10, dtype=aesara.config.floatX),
726+
np.array(0.1, dtype=aesara.config.floatX),
727+
),
728+
np.array(0.5, dtype=aesara.config.floatX),
729+
(),
730+
(6,),
731+
),
732+
(
733+
(
734+
np.array([10, 20], dtype=aesara.config.floatX),
735+
np.array(0.1, dtype=aesara.config.floatX),
736+
),
737+
(
738+
np.array([-10, -20], dtype=aesara.config.floatX),
739+
np.array(0.1, dtype=aesara.config.floatX),
740+
),
741+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
742+
(2,),
743+
(2,),
744+
),
745+
(
746+
(
747+
np.array([10, 20], dtype=aesara.config.floatX),
748+
np.array(0.1, dtype=aesara.config.floatX),
749+
),
750+
(
751+
np.array([-10, -20], dtype=aesara.config.floatX),
752+
np.array(0.1, dtype=aesara.config.floatX),
753+
),
754+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
755+
None,
756+
None,
757+
),
758+
(
759+
(
760+
np.array(-10, dtype=aesara.config.floatX),
761+
np.array(0.1, dtype=aesara.config.floatX),
762+
),
763+
(
764+
np.array(10, dtype=aesara.config.floatX),
765+
np.array(0.1, dtype=aesara.config.floatX),
766+
),
767+
np.array(0.5, dtype=aesara.config.floatX),
768+
(2, 3),
769+
(2, 3),
770+
),
771+
(
772+
(
773+
np.array(10, dtype=aesara.config.floatX),
774+
np.array(0.1, dtype=aesara.config.floatX),
775+
),
776+
(
777+
np.array(-10, dtype=aesara.config.floatX),
778+
np.array(0.1, dtype=aesara.config.floatX),
779+
),
780+
np.array(0.5, dtype=aesara.config.floatX),
781+
(2, 3),
782+
(),
783+
),
784+
(
785+
(
786+
np.array(10, dtype=aesara.config.floatX),
787+
np.array(0.1, dtype=aesara.config.floatX),
788+
),
789+
(
790+
np.array(-10, dtype=aesara.config.floatX),
791+
np.array(0.1, dtype=aesara.config.floatX),
792+
),
793+
np.array(0.5, dtype=aesara.config.floatX),
794+
(3,),
795+
(3,),
796+
),
797+
]
798+
if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse)
799+
],
800+
)
801+
def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size):
802+
"""
803+
The argument size is both the input to srng.normal and the expected
804+
size of the mixture RV Z1_rv
805+
"""
721806
srng = at.random.RandomStream(29833)
722807

723-
X_rv = srng.normal(-10.0, 0.1, name="X")
724-
Y_rv = srng.normal(10.0, 0.1, name="Y")
808+
X_rv = srng.normal(*X_args, size=comp_size, name="X")
809+
Y_rv = srng.normal(*Y_args, size=comp_size, name="Y")
725810

726-
I_rv = srng.bernoulli(0.5, name="I")
811+
I_rv = srng.bernoulli(p_val, size=idx_size, name="I")
727812
i_vv = I_rv.clone()
728813
i_vv.name = "i"
729814

@@ -755,9 +840,3 @@ def test_switch_ifelse_mixture(op):
755840

756841
z1_logp = joint_logprob({Z1_rv: z_vv, I_rv: i_vv})
757842
z2_logp = joint_logprob({Z2_rv: z_vv, I_rv: i_vv})
758-
759-
# below should follow immediately from the equal_computations assertion above
760-
assert equal_computations([z1_logp], [z2_logp])
761-
762-
np.testing.assert_almost_equal(0.69049938, z1_logp.eval({z_vv: -10, i_vv: 0}))
763-
np.testing.assert_almost_equal(0.69049938, z2_logp.eval({z_vv: -10, i_vv: 0}))

0 commit comments

Comments
 (0)