Skip to content

Commit

Permalink
Adding tests, fixing minor issues.
Browse files Browse the repository at this point in the history
- Fixed an issue where the DCT/DST parameters weren't passed in the right order
- Fixed an issue where the norm factor of the DCT wasn't computed correctly (incorrect delta)
- Generated a unit_fft file that tests various parameters of FFT. Values are generated using scipy.fft module for the FFT, DCT and DST functions
- Changed pocketfft::detail:: namespace usage to just pocketfft:: in the C++ code for more readability
  • Loading branch information
gabyfle committed Nov 12, 2024
1 parent aea4703 commit 0dbafb4
Show file tree
Hide file tree
Showing 5 changed files with 1,803 additions and 17 deletions.
8 changes: 4 additions & 4 deletions src/owl/fftpack/owl_fft_generic.mli
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ val dct
-> ('a, 'b) t
-> ('a, 'b) t
(** [dct ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input.
[ttype] is the DCT type to use for this transform. Default value is [II].
[ttype] is the DCT type to use for this transform. Default value is [Two].
[norm] is the normalization option. By default, [norm] is set to [Backward].
[ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *)

Expand All @@ -98,7 +98,7 @@ val idct
-> ('a, 'b) t
-> ('a, 'b) t
(** [idct ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input.
[ttype] is the DCT type to use for this transform. Default value is [II].
[ttype] is the DCT type to use for this transform. Default value is [Two].
[norm] is the normalization option. By default, [norm] is set to [Forward].
[ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *)

Expand All @@ -111,7 +111,7 @@ val dst
-> ('a, 'b) t
-> ('a, 'b) t
(** [dst ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Sine Transform (DCT) on a real input.
[ttype] is the DCT type to use for this transform. Default value is [II].
[ttype] is the DCT type to use for this transform. Default value is [Two].
[norm] is the normalization option. By default, [norm] is set to [Backward].
[ortho] constrols whether or not we should use the orthogonalized variant of the DST. *)

Expand All @@ -124,6 +124,6 @@ val idst
-> ('a, 'b) t
-> ('a, 'b) t
(** [idst ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input.
[ttype] is the DST type to use for this transform. Default value is [II].
[ttype] is the DST type to use for this transform. Default value is [Two].
[norm] is the normalization option. By default, [norm] is set to [Forward].
[ortho] constrols whether or not we should use the orthogonalized variant of the DST. *)
12 changes: 6 additions & 6 deletions src/owl/fftpack/owl_fftpack.ml
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ let _owl_dctb
-> int
-> unit
=
fun ityp x y ttype axis norm ortho nthreads ->
fun ityp x y axis ttype norm ortho nthreads ->
match ityp with
| Float32 -> owl_float32_dct x y (inverse_map ttype) axis norm ortho nthreads
| Float64 -> owl_float64_dct x y (inverse_map ttype) axis norm ortho nthreads
| Float32 -> owl_float32_dct x y axis (inverse_map ttype) norm ortho nthreads
| Float64 -> owl_float64_dct x y axis (inverse_map ttype) norm ortho nthreads
| _ -> failwith "_owl_dctb: unsupported operation"


Expand Down Expand Up @@ -241,8 +241,8 @@ let _owl_dstb
-> int
-> unit
=
fun ityp x y ttype axis norm ortho nthreads ->
fun ityp x y axis ttype norm ortho nthreads ->
match ityp with
| Float32 -> owl_float32_dst x y (inverse_map ttype) axis norm ortho nthreads
| Float64 -> owl_float64_dst x y (inverse_map ttype) axis norm ortho nthreads
| Float32 -> owl_float32_dst x y axis (inverse_map ttype) norm ortho nthreads
| Float64 -> owl_float64_dst x y axis (inverse_map ttype) norm ortho nthreads
| _ -> failwith "_owl_dstb: unsupported operation"
14 changes: 7 additions & 7 deletions src/owl/fftpack/owl_fftpack_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ extern "C"
Treal norm_factor = compute_norm_factor<Treal>(dims, axes, norm);
try
{
pocketfft::detail::c2c(dims, stride_in, stride_out, axes, forward,
X_data, Y_data, norm_factor, nthreads);
pocketfft::c2c(dims, stride_in, stride_out, axes, forward,
X_data, Y_data, norm_factor, nthreads);
}
catch (const std::exception &e)
{
Expand Down Expand Up @@ -294,12 +294,12 @@ extern "C"

shape_t axes{static_cast<size_t>(d)};
{
Treal norm_factor = (type == 1) ? compute_norm_factor<Treal>(dims, axes, norm, 2, 1)
Treal norm_factor = (type == 1) ? compute_norm_factor<Treal>(dims, axes, norm, 2, -1)
: compute_norm_factor<Treal>(dims, axes, norm, 2);
try
{
pocketfft::detail::dct(dims, stride_in, stride_out, axes, type,
X_data, Y_data, norm_factor, ortho, nthreads);
pocketfft::dct(dims, stride_in, stride_out, axes, type,
X_data, Y_data, norm_factor, ortho, nthreads);
}
catch (const std::exception &e)
{
Expand Down Expand Up @@ -364,8 +364,8 @@ extern "C"
: compute_norm_factor<Treal>(dims, axes, norm, 2);
try
{
pocketfft::detail::dst(dims, stride_in, stride_out, axes, type,
X_data, Y_data, norm_factor, ortho, nthreads);
pocketfft::dst(dims, stride_in, stride_out, axes, type,
X_data, Y_data, norm_factor, ortho, nthreads);
}
catch (const std::exception &e)
{
Expand Down
1 change: 1 addition & 0 deletions test/test_runner.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ let () =
; "base: complex", Unit_base_complex.test_set
; "base: ndarray core", Unit_base_ndarray_core.test_set
; "base: linalg", Unit_linalg_solver.test_set; "base: signal", Unit_signal.test_set
; "fft", Unit_fft.test_set
; ("algodiff matrix", Unit_algodiff_matrix.[ Reverse.test; Forward.test ]) ]
Loading

0 comments on commit 0dbafb4

Please sign in to comment.