@@ -803,3 +803,106 @@ def test_slogdet_kronecker_rewrite():
803
803
atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
804
804
rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
805
805
)
806
+
807
+
808
+ def test_cholesky_eye_rewrite ():
809
+ x = pt .eye (10 )
810
+ L = pt .linalg .cholesky (x )
811
+ f_rewritten = function ([], L , mode = "FAST_RUN" )
812
+ nodes = f_rewritten .maker .fgraph .apply_nodes
813
+
814
+ # Rewrite Test
815
+ assert not any (isinstance (node .op , Cholesky ) for node in nodes )
816
+
817
+ # Value Test
818
+ x_test = np .eye (10 )
819
+ L = np .linalg .cholesky (x_test )
820
+ rewritten_val = f_rewritten ()
821
+
822
+ assert_allclose (
823
+ L ,
824
+ rewritten_val ,
825
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
826
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
827
+ )
828
+
829
+
830
+ @pytest .mark .parametrize (
831
+ "shape" ,
832
+ [(), (7 ,), (7 , 7 ), (5 , 7 , 7 )],
833
+ ids = ["scalar" , "vector" , "matrix" , "batched" ],
834
+ )
835
+ def test_cholesky_diag_from_eye_mul (shape ):
836
+ # Initializing x based on scalar/vector/matrix
837
+ x = pt .tensor ("x" , shape = shape )
838
+ y = pt .eye (7 ) * x
839
+ # Performing cholesky decomposition using pt.linalg.cholesky
840
+ z_cholesky = pt .linalg .cholesky (y )
841
+
842
+ # REWRITE TEST
843
+ f_rewritten = function ([x ], z_cholesky , mode = "FAST_RUN" )
844
+ nodes = f_rewritten .maker .fgraph .apply_nodes
845
+ assert not any (isinstance (node .op , Cholesky ) for node in nodes )
846
+
847
+ # NUMERIC VALUE TEST
848
+ if len (shape ) == 0 :
849
+ x_test = np .array (np .random .rand ()).astype (config .floatX )
850
+ elif len (shape ) == 1 :
851
+ x_test = np .random .rand (* shape ).astype (config .floatX )
852
+ else :
853
+ x_test = np .random .rand (* shape ).astype (config .floatX )
854
+ x_test_matrix = np .eye (7 ) * x_test
855
+ cholesky_val = np .linalg .cholesky (x_test_matrix )
856
+ rewritten_val = f_rewritten (x_test )
857
+
858
+ assert_allclose (
859
+ cholesky_val ,
860
+ rewritten_val ,
861
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
862
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
863
+ )
864
+
865
+
866
+ def test_cholesky_diag_from_diag ():
867
+ x = pt .dvector ("x" )
868
+ x_diag = pt .diag (x )
869
+ x_cholesky = pt .linalg .cholesky (x_diag )
870
+
871
+ # REWRITE TEST
872
+ f_rewritten = function ([x ], x_cholesky , mode = "FAST_RUN" )
873
+ nodes = f_rewritten .maker .fgraph .apply_nodes
874
+
875
+ assert not any (isinstance (node .op , Cholesky ) for node in nodes )
876
+
877
+ # NUMERIC VALUE TEST
878
+ x_test = np .random .rand (10 )
879
+ x_test_matrix = np .eye (10 ) * x_test
880
+ cholesky_val = np .linalg .cholesky (x_test_matrix )
881
+ rewritten_cholesky = f_rewritten (x_test )
882
+
883
+ assert_allclose (
884
+ cholesky_val ,
885
+ rewritten_cholesky ,
886
+ atol = 1e-3 if config .floatX == "float32" else 1e-8 ,
887
+ rtol = 1e-3 if config .floatX == "float32" else 1e-8 ,
888
+ )
889
+
890
+
891
+ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied ():
892
+ # Case 1 : y is not a diagonal matrix because of k = -1
893
+ x = pt .tensor ("x" , shape = (7 , 7 ))
894
+ y = pt .eye (7 , k = - 1 ) * x
895
+ z_cholesky = pt .linalg .cholesky (y )
896
+
897
+ # REWRITE TEST (should not be applied)
898
+ f_rewritten = function ([x ], z_cholesky , mode = "FAST_RUN" )
899
+ nodes = f_rewritten .maker .fgraph .apply_nodes
900
+ assert any (isinstance (node .op , Cholesky ) for node in nodes )
901
+
902
+ # Case 2 : eye is degenerate
903
+ x = pt .scalar ("x" )
904
+ y = pt .eye (1 ) * x
905
+ z_cholesky = pt .linalg .cholesky (y )
906
+ f_rewritten = function ([x ], z_cholesky , mode = "FAST_RUN" )
907
+ nodes = f_rewritten .maker .fgraph .apply_nodes
908
+ assert any (isinstance (node .op , Cholesky ) for node in nodes )
0 commit comments