From c4fb9414270a35efa73f797b1a224cbc6bbe2366 Mon Sep 17 00:00:00 2001
From: Patrick Nicodemus
Date: Wed, 6 Sep 2023 14:42:25 -0400
Subject: [PATCH] Changed def of ssqr_diff' to not modify inputs. Added two tests.

---
Review notes (text in this section is ignored when the patch is applied):

* The ACCFN macros now compute the difference in a scratch variable `diff`
  declared by INIT, instead of writing the intermediate result back into
  the input element X.
* complex32/complex64: the imaginary part is accumulated from the current
  difference (2 * diff.r * diff.i).  The earlier form, 2 * A.r * A.i, read
  the accumulator itself; because A.i is initialised to 0.0 it could never
  become non-zero.
* The two new tests check the float32/float64 result (12.) and that both
  inputs still equal the copies taken before the call.

 src/owl/core/owl_ndarray_maths_stub.c | 16 ++++++++--------
 test/unit_dense_ndarray.ml            | 24 +++++++++++++++++++++++-
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/src/owl/core/owl_ndarray_maths_stub.c b/src/owl/core/owl_ndarray_maths_stub.c
index b17e13d08..c2842f517 100644
--- a/src/owl/core/owl_ndarray_maths_stub.c
+++ b/src/owl/core/owl_ndarray_maths_stub.c
@@ -4992,34 +4992,34 @@
 // ssqr_diff
 
 #define FUN11 float32_ssqr_diff
-#define INIT float r = 0.
+#define INIT float r = 0.; float diff
 #define NUMBER float
 #define NUMBER1 float
-#define ACCFN(A,X,Y) X -= Y; X *= X; A += X
+#define ACCFN(A,X,Y) diff=X-Y; diff*=diff; A+=diff
 #define COPYNUM(A) (caml_copy_double(A))
 #include OWL_NDARRAY_MATHS_FOLD
 
 #define FUN11 float64_ssqr_diff
-#define INIT double r = 0.
+#define INIT double r = 0.; double diff
 #define NUMBER double
 #define NUMBER1 double
-#define ACCFN(A,X,Y) X -= Y; X *= X; A += X
+#define ACCFN(A,X,Y) diff=X-Y; diff*=diff; A+=diff
 #define COPYNUM(A) (caml_copy_double(A))
 #include OWL_NDARRAY_MATHS_FOLD
 
 #define FUN11 complex32_ssqr_diff
-#define INIT complex_float r = { 0.0, 0.0 }
+#define INIT complex_float r = { 0.0, 0.0 }; complex_float diff
 #define NUMBER complex_float
 #define NUMBER1 complex_float
-#define ACCFN(A,X,Y) X.r -= Y.r; X.i -= Y.i; A.r += (X.r - X.i) * (X.r + X.i); A.i += 2 * A.r * A.i
+#define ACCFN(A,X,Y) diff.r = X.r - Y.r; diff.i = X.i - Y.i; A.r += (diff.r - diff.i) * (diff.r + diff.i); A.i += 2 * diff.r * diff.i
 #define COPYNUM(A) (cp_two_doubles(A.r, A.i))
 #include OWL_NDARRAY_MATHS_FOLD
 
 #define FUN11 complex64_ssqr_diff
-#define INIT complex_double r = { 0.0, 0.0 }
+#define INIT complex_double r = { 0.0, 0.0 }; complex_double diff
 #define NUMBER complex_double
 #define NUMBER1 complex_double
-#define ACCFN(A,X,Y) X.r -= Y.r; X.i -= Y.i; A.r += (X.r - X.i) * (X.r + X.i); A.i += 2 * A.r * A.i
+#define ACCFN(A,X,Y) diff.r = X.r - Y.r; diff.i = X.i - Y.i; A.r += (diff.r - diff.i) * (diff.r + diff.i); A.i += 2 * diff.r * diff.i
 #define COPYNUM(A) (cp_two_doubles(A.r, A.i))
 #include OWL_NDARRAY_MATHS_FOLD
 
diff --git a/test/unit_dense_ndarray.ml b/test/unit_dense_ndarray.ml
index 364fe38fe..3a1c53e5f 100644
--- a/test/unit_dense_ndarray.ml
+++ b/test/unit_dense_ndarray.ml
@@ -132,7 +132,24 @@ module To_test = struct
 
   let sum_reduce () =
     M.sum_reduce ~axis:[| 0; 2 |] x4 = M.of_array Float64 [| 8.; 8.; 8. |] [| 1; 3; 1 |]
-
+  let ssqr_diff32 () =
+    let a = M.of_array Float32 [| 3.; 4.; 5.; |] [| 1; 3 |] in
+    let a' = M.copy a in
+    let b = M.of_array Float32 [| 1.; 2.; 3.; |] [| 1; 3 |] in
+    let b' = M.copy b in
+    let ssqrdiff = M.ssqr_diff' a b in
+    ssqrdiff = 12. && a = a' && b = b'
+
+  let ssqr_diff64 () =
+    let a = M.of_array Float64 [| 3.; 4.; 5.; |] [| 1; 3 |] in
+    let a' = M.copy a in
+    let b = M.of_array Float64 [| 1.; 2.; 3.; |] [| 1; 3 |] in
+    let b' = M.copy b in
+    let ssqrdiff = M.ssqr_diff' a b in
+    ssqrdiff = 12. && a = a' && b = b'
+
+
+
   let min' () = M.min' x0 = 0.
 
   let max' () = M.max' x0 = 3.
@@ -530,6 +547,10 @@ let sort1 () = Alcotest.(check bool) "sort1" true (To_test.sort1 ())
 
 let sum_reduce () = Alcotest.(check bool) "sum_reduce" true (To_test.sum_reduce ())
 
+let ssqr_diff32 () = Alcotest.(check bool) "ssqr_diff32" true (To_test.ssqr_diff32 ())
+
+let ssqr_diff64 () = Alcotest.(check bool) "ssqr_diff64" true (To_test.ssqr_diff64 ())
+
 let min' () = Alcotest.(check bool) "min'" true (To_test.min' ())
 
 let max' () = Alcotest.(check bool) "max'" true (To_test.max' ())
@@ -674,6 +695,7 @@ let test_set =
   ; "mul", `Slow, mul; "add_scalar", `Slow, add_scalar; "mul_scalar", `Slow, mul_scalar
   ; "abs", `Slow, abs; "neg", `Slow, neg; "sum'", `Slow, sum'; "median'", `Slow, median'
   ; "median", `Slow, median; "sort1", `Slow, sort1; "sum_reduce", `Slow, sum_reduce
+  ; "ssqr_diff32", `Slow, ssqr_diff32; "ssqr_diff64", `Slow, ssqr_diff64
   ; "min'", `Slow, min'; "max'", `Slow, max'; "minmax_i", `Slow, minmax_i
   ; "init_nd", `Slow, init_nd; "is_zero", `Slow, is_zero
   ; "is_positive", `Slow, is_positive; "is_negative", `Slow, is_negative