@@ -252,25 +252,6 @@ def test_hetero_mixture_binomial(p_val, size):
252
252
(),
253
253
0 ,
254
254
),
255
- (
256
- (
257
- np .array (0 , dtype = aesara .config .floatX ),
258
- np .array (1 , dtype = aesara .config .floatX ),
259
- ),
260
- (
261
- np .array (0.5 , dtype = aesara .config .floatX ),
262
- np .array (0.5 , dtype = aesara .config .floatX ),
263
- ),
264
- (
265
- np .array (100 , dtype = aesara .config .floatX ),
266
- np .array (1 , dtype = aesara .config .floatX ),
267
- ),
268
- np .array ([0.1 , 0.5 , 0.4 ], dtype = aesara .config .floatX ),
269
- (),
270
- (),
271
- (),
272
- 0 ,
273
- ),
274
255
(
275
256
(
276
257
np .array (0 , dtype = aesara .config .floatX ),
@@ -716,14 +697,118 @@ def test_mixture_with_DiracDelta():
716
697
assert m_vv in logp_res
717
698
718
699
719
- @pytest .mark .parametrize ("op" , [at .switch , ifelse ])
720
- def test_switch_ifelse_mixture (op ):
700
+ @pytest .mark .parametrize (
701
+ "op, X_args, Y_args, p_val, comp_size, idx_size" ,
702
+ [
703
+ [op ] + list (test_args )
704
+ for op in [at .stack , ifelse ]
705
+ for test_args in [
706
+ (
707
+ (
708
+ np .array (- 10 , dtype = aesara .config .floatX ),
709
+ np .array (0.1 , dtype = aesara .config .floatX ),
710
+ ),
711
+ (
712
+ np .array (10 , dtype = aesara .config .floatX ),
713
+ np .array (0.1 , dtype = aesara .config .floatX ),
714
+ ),
715
+ np .array (0.5 , dtype = aesara .config .floatX ),
716
+ (),
717
+ (),
718
+ ),
719
+ (
720
+ (
721
+ np .array (- 10 , dtype = aesara .config .floatX ),
722
+ np .array (0.1 , dtype = aesara .config .floatX ),
723
+ ),
724
+ (
725
+ np .array (10 , dtype = aesara .config .floatX ),
726
+ np .array (0.1 , dtype = aesara .config .floatX ),
727
+ ),
728
+ np .array (0.5 , dtype = aesara .config .floatX ),
729
+ (),
730
+ (6 ,),
731
+ ),
732
+ (
733
+ (
734
+ np .array ([10 , 20 ], dtype = aesara .config .floatX ),
735
+ np .array (0.1 , dtype = aesara .config .floatX ),
736
+ ),
737
+ (
738
+ np .array ([- 10 , - 20 ], dtype = aesara .config .floatX ),
739
+ np .array (0.1 , dtype = aesara .config .floatX ),
740
+ ),
741
+ np .array ([0.9 , 0.1 ], dtype = aesara .config .floatX ),
742
+ (2 ,),
743
+ (2 ,),
744
+ ),
745
+ (
746
+ (
747
+ np .array ([10 , 20 ], dtype = aesara .config .floatX ),
748
+ np .array (0.1 , dtype = aesara .config .floatX ),
749
+ ),
750
+ (
751
+ np .array ([- 10 , - 20 ], dtype = aesara .config .floatX ),
752
+ np .array (0.1 , dtype = aesara .config .floatX ),
753
+ ),
754
+ np .array ([0.9 , 0.1 ], dtype = aesara .config .floatX ),
755
+ None ,
756
+ None ,
757
+ ),
758
+ (
759
+ (
760
+ np .array (- 10 , dtype = aesara .config .floatX ),
761
+ np .array (0.1 , dtype = aesara .config .floatX ),
762
+ ),
763
+ (
764
+ np .array (10 , dtype = aesara .config .floatX ),
765
+ np .array (0.1 , dtype = aesara .config .floatX ),
766
+ ),
767
+ np .array (0.5 , dtype = aesara .config .floatX ),
768
+ (2 , 3 ),
769
+ (2 , 3 ),
770
+ ),
771
+ (
772
+ (
773
+ np .array (10 , dtype = aesara .config .floatX ),
774
+ np .array (0.1 , dtype = aesara .config .floatX ),
775
+ ),
776
+ (
777
+ np .array (- 10 , dtype = aesara .config .floatX ),
778
+ np .array (0.1 , dtype = aesara .config .floatX ),
779
+ ),
780
+ np .array (0.5 , dtype = aesara .config .floatX ),
781
+ (2 , 3 ),
782
+ (),
783
+ ),
784
+ (
785
+ (
786
+ np .array (10 , dtype = aesara .config .floatX ),
787
+ np .array (0.1 , dtype = aesara .config .floatX ),
788
+ ),
789
+ (
790
+ np .array (- 10 , dtype = aesara .config .floatX ),
791
+ np .array (0.1 , dtype = aesara .config .floatX ),
792
+ ),
793
+ np .array (0.5 , dtype = aesara .config .floatX ),
794
+ (3 ,),
795
+ (3 ,),
796
+ ),
797
+ ]
798
+ if not ((test_args [- 1 ] is None or len (test_args [- 1 ]) > 0 ) and op == ifelse )
799
+ ],
800
+ )
801
+ def test_switch_ifelse_mixture (op , X_args , Y_args , p_val , comp_size , idx_size ):
802
+ """
803
+ The argument size is both the input to srng.normal and the expected
804
+ size of the mixture RV Z1_rv
805
+ """
721
806
srng = at .random .RandomStream (29833 )
722
807
723
- X_rv = srng .normal (- 10.0 , 0.1 , name = "X" )
724
- Y_rv = srng .normal (10.0 , 0.1 , name = "Y" )
808
+ X_rv = srng .normal (* X_args , size = comp_size , name = "X" )
809
+ Y_rv = srng .normal (* Y_args , size = comp_size , name = "Y" )
725
810
726
- I_rv = srng .bernoulli (0.5 , name = "I" )
811
+ I_rv = srng .bernoulli (p_val , size = idx_size , name = "I" )
727
812
i_vv = I_rv .clone ()
728
813
i_vv .name = "i"
729
814
0 commit comments