@@ -249,25 +249,6 @@ def test_hetero_mixture_binomial(p_val, size):
249
249
(),
250
250
0 ,
251
251
),
252
- (
253
- (
254
- np .array (0 , dtype = aesara .config .floatX ),
255
- np .array (1 , dtype = aesara .config .floatX ),
256
- ),
257
- (
258
- np .array (0.5 , dtype = aesara .config .floatX ),
259
- np .array (0.5 , dtype = aesara .config .floatX ),
260
- ),
261
- (
262
- np .array (100 , dtype = aesara .config .floatX ),
263
- np .array (1 , dtype = aesara .config .floatX ),
264
- ),
265
- np .array ([0.1 , 0.5 , 0.4 ], dtype = aesara .config .floatX ),
266
- (),
267
- (),
268
- (),
269
- 0 ,
270
- ),
271
252
(
272
253
(
273
254
np .array (0 , dtype = aesara .config .floatX ),
@@ -713,14 +694,118 @@ def test_mixture_with_DiracDelta():
713
694
assert m_vv in logp_res
714
695
715
696
716
- @pytest .mark .parametrize ("op" , [at .switch , ifelse ])
717
- def test_switch_ifelse_mixture (op ):
697
+ @pytest .mark .parametrize (
698
+ "op, X_args, Y_args, p_val, comp_size, idx_size" ,
699
+ [
700
+ [op ] + list (test_args )
701
+ for op in [at .switch , ifelse ]
702
+ for test_args in [
703
+ (
704
+ (
705
+ np .array (- 10 , dtype = aesara .config .floatX ),
706
+ np .array (0.1 , dtype = aesara .config .floatX ),
707
+ ),
708
+ (
709
+ np .array (10 , dtype = aesara .config .floatX ),
710
+ np .array (0.1 , dtype = aesara .config .floatX ),
711
+ ),
712
+ np .array (0.5 , dtype = aesara .config .floatX ),
713
+ (),
714
+ (),
715
+ ),
716
+ (
717
+ (
718
+ np .array (- 10 , dtype = aesara .config .floatX ),
719
+ np .array (0.1 , dtype = aesara .config .floatX ),
720
+ ),
721
+ (
722
+ np .array (10 , dtype = aesara .config .floatX ),
723
+ np .array (0.1 , dtype = aesara .config .floatX ),
724
+ ),
725
+ np .array (0.5 , dtype = aesara .config .floatX ),
726
+ (),
727
+ (6 ,),
728
+ ),
729
+ (
730
+ (
731
+ np .array ([10 , 20 ], dtype = aesara .config .floatX ),
732
+ np .array (0.1 , dtype = aesara .config .floatX ),
733
+ ),
734
+ (
735
+ np .array ([- 10 , - 20 ], dtype = aesara .config .floatX ),
736
+ np .array (0.1 , dtype = aesara .config .floatX ),
737
+ ),
738
+ np .array ([0.9 , 0.1 ], dtype = aesara .config .floatX ),
739
+ (2 ,),
740
+ (2 ,),
741
+ ),
742
+ (
743
+ (
744
+ np .array ([10 , 20 ], dtype = aesara .config .floatX ),
745
+ np .array (0.1 , dtype = aesara .config .floatX ),
746
+ ),
747
+ (
748
+ np .array ([- 10 , - 20 ], dtype = aesara .config .floatX ),
749
+ np .array (0.1 , dtype = aesara .config .floatX ),
750
+ ),
751
+ np .array ([0.9 , 0.1 ], dtype = aesara .config .floatX ),
752
+ None ,
753
+ None ,
754
+ ),
755
+ (
756
+ (
757
+ np .array (- 10 , dtype = aesara .config .floatX ),
758
+ np .array (0.1 , dtype = aesara .config .floatX ),
759
+ ),
760
+ (
761
+ np .array (10 , dtype = aesara .config .floatX ),
762
+ np .array (0.1 , dtype = aesara .config .floatX ),
763
+ ),
764
+ np .array (0.5 , dtype = aesara .config .floatX ),
765
+ (2 , 3 ),
766
+ (2 , 3 ),
767
+ ),
768
+ (
769
+ (
770
+ np .array (10 , dtype = aesara .config .floatX ),
771
+ np .array (0.1 , dtype = aesara .config .floatX ),
772
+ ),
773
+ (
774
+ np .array (- 10 , dtype = aesara .config .floatX ),
775
+ np .array (0.1 , dtype = aesara .config .floatX ),
776
+ ),
777
+ np .array (0.5 , dtype = aesara .config .floatX ),
778
+ (2 , 3 ),
779
+ (),
780
+ ),
781
+ (
782
+ (
783
+ np .array (10 , dtype = aesara .config .floatX ),
784
+ np .array (0.1 , dtype = aesara .config .floatX ),
785
+ ),
786
+ (
787
+ np .array (- 10 , dtype = aesara .config .floatX ),
788
+ np .array (0.1 , dtype = aesara .config .floatX ),
789
+ ),
790
+ np .array (0.5 , dtype = aesara .config .floatX ),
791
+ (3 ,),
792
+ (3 ,),
793
+ ),
794
+ ]
795
+ if not ((test_args [- 1 ] is None or len (test_args [- 1 ]) > 0 ) and op == ifelse )
796
+ ],
797
+ )
798
+ def test_switch_ifelse_mixture (op , X_args , Y_args , p_val , comp_size , idx_size ):
799
+ """
800
+ The argument size is both the input to srng.normal and the expected
801
+ size of the mixture RV Z1_rv
802
+ """
718
803
srng = at .random .RandomStream (29833 )
719
804
720
- X_rv = srng .normal (- 10.0 , 0.1 , name = "X" )
721
- Y_rv = srng .normal (10.0 , 0.1 , name = "Y" )
805
+ X_rv = srng .normal (* X_args , size = comp_size , name = "X" )
806
+ Y_rv = srng .normal (* Y_args , size = comp_size , name = "Y" )
722
807
723
- I_rv = srng .bernoulli (0.5 , name = "I" )
808
+ I_rv = srng .bernoulli (p_val , size = idx_size , name = "I" )
724
809
i_vv = I_rv .clone ()
725
810
i_vv .name = "i"
726
811
@@ -752,9 +837,3 @@ def test_switch_ifelse_mixture(op):
752
837
753
838
z1_logp = joint_logprob ({Z1_rv : z_vv , I_rv : i_vv })
754
839
z2_logp = joint_logprob ({Z2_rv : z_vv , I_rv : i_vv })
755
-
756
- # below should follow immediately from the equal_computations assertion above
757
- assert equal_computations ([z1_logp ], [z2_logp ])
758
-
759
- np .testing .assert_almost_equal (0.69049938 , z1_logp .eval ({z_vv : - 10 , i_vv : 0 }))
760
- np .testing .assert_almost_equal (0.69049938 , z2_logp .eval ({z_vv : - 10 , i_vv : 0 }))
0 commit comments