Skip to content

Commit 28b20db

Browse files
Add tests with non-null size for ifelse and switch mixtures
1 parent bbf9cd4 commit 28b20db

File tree

1 file changed

+109
-24
lines changed

1 file changed

+109
-24
lines changed

tests/test_mixture.py

Lines changed: 109 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -252,25 +252,6 @@ def test_hetero_mixture_binomial(p_val, size):
252252
(),
253253
0,
254254
),
255-
(
256-
(
257-
np.array(0, dtype=aesara.config.floatX),
258-
np.array(1, dtype=aesara.config.floatX),
259-
),
260-
(
261-
np.array(0.5, dtype=aesara.config.floatX),
262-
np.array(0.5, dtype=aesara.config.floatX),
263-
),
264-
(
265-
np.array(100, dtype=aesara.config.floatX),
266-
np.array(1, dtype=aesara.config.floatX),
267-
),
268-
np.array([0.1, 0.5, 0.4], dtype=aesara.config.floatX),
269-
(),
270-
(),
271-
(),
272-
0,
273-
),
274255
(
275256
(
276257
np.array(0, dtype=aesara.config.floatX),
@@ -716,14 +697,118 @@ def test_mixture_with_DiracDelta():
716697
assert m_vv in logp_res
717698

718699

719-
@pytest.mark.parametrize("op", [at.switch, ifelse])
720-
def test_switch_ifelse_mixture(op):
700+
@pytest.mark.parametrize(
701+
"op, X_args, Y_args, p_val, comp_size, idx_size",
702+
[
703+
[op] + list(test_args)
704+
for op in [at.stack, ifelse]
705+
for test_args in [
706+
(
707+
(
708+
np.array(-10, dtype=aesara.config.floatX),
709+
np.array(0.1, dtype=aesara.config.floatX),
710+
),
711+
(
712+
np.array(10, dtype=aesara.config.floatX),
713+
np.array(0.1, dtype=aesara.config.floatX),
714+
),
715+
np.array(0.5, dtype=aesara.config.floatX),
716+
(),
717+
(),
718+
),
719+
(
720+
(
721+
np.array(-10, dtype=aesara.config.floatX),
722+
np.array(0.1, dtype=aesara.config.floatX),
723+
),
724+
(
725+
np.array(10, dtype=aesara.config.floatX),
726+
np.array(0.1, dtype=aesara.config.floatX),
727+
),
728+
np.array(0.5, dtype=aesara.config.floatX),
729+
(),
730+
(6,),
731+
),
732+
(
733+
(
734+
np.array([10, 20], dtype=aesara.config.floatX),
735+
np.array(0.1, dtype=aesara.config.floatX),
736+
),
737+
(
738+
np.array([-10, -20], dtype=aesara.config.floatX),
739+
np.array(0.1, dtype=aesara.config.floatX),
740+
),
741+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
742+
(2,),
743+
(2,),
744+
),
745+
(
746+
(
747+
np.array([10, 20], dtype=aesara.config.floatX),
748+
np.array(0.1, dtype=aesara.config.floatX),
749+
),
750+
(
751+
np.array([-10, -20], dtype=aesara.config.floatX),
752+
np.array(0.1, dtype=aesara.config.floatX),
753+
),
754+
np.array([0.9, 0.1], dtype=aesara.config.floatX),
755+
None,
756+
None,
757+
),
758+
(
759+
(
760+
np.array(-10, dtype=aesara.config.floatX),
761+
np.array(0.1, dtype=aesara.config.floatX),
762+
),
763+
(
764+
np.array(10, dtype=aesara.config.floatX),
765+
np.array(0.1, dtype=aesara.config.floatX),
766+
),
767+
np.array(0.5, dtype=aesara.config.floatX),
768+
(2, 3),
769+
(2, 3),
770+
),
771+
(
772+
(
773+
np.array(10, dtype=aesara.config.floatX),
774+
np.array(0.1, dtype=aesara.config.floatX),
775+
),
776+
(
777+
np.array(-10, dtype=aesara.config.floatX),
778+
np.array(0.1, dtype=aesara.config.floatX),
779+
),
780+
np.array(0.5, dtype=aesara.config.floatX),
781+
(2, 3),
782+
(),
783+
),
784+
(
785+
(
786+
np.array(10, dtype=aesara.config.floatX),
787+
np.array(0.1, dtype=aesara.config.floatX),
788+
),
789+
(
790+
np.array(-10, dtype=aesara.config.floatX),
791+
np.array(0.1, dtype=aesara.config.floatX),
792+
),
793+
np.array(0.5, dtype=aesara.config.floatX),
794+
(3,),
795+
(3,),
796+
),
797+
]
798+
if not ((test_args[-1] is None or len(test_args[-1]) > 0) and op == ifelse)
799+
],
800+
)
801+
def test_switch_ifelse_mixture(op, X_args, Y_args, p_val, comp_size, idx_size):
802+
"""
803+
The argument size is both the input to srng.normal and the expected
804+
size of the mixture RV Z1_rv
805+
"""
721806
srng = at.random.RandomStream(29833)
722807

723-
X_rv = srng.normal(-10.0, 0.1, name="X")
724-
Y_rv = srng.normal(10.0, 0.1, name="Y")
808+
X_rv = srng.normal(*X_args, size=comp_size, name="X")
809+
Y_rv = srng.normal(*Y_args, size=comp_size, name="Y")
725810

726-
I_rv = srng.bernoulli(0.5, name="I")
811+
I_rv = srng.bernoulli(p_val, size=idx_size, name="I")
727812
i_vv = I_rv.clone()
728813
i_vv.name = "i"
729814

0 commit comments

Comments
 (0)