From 2d504656726644304bc6e797e7b55ced47459395 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 28 Jul 2025 20:54:02 +0800 Subject: [PATCH 1/3] Refactor frequency assertion checks to use jnp.allclose for improved performance --- src/jimgw/core/single_event/data.py | 8 +++++--- src/jimgw/core/single_event/detector.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/jimgw/core/single_event/data.py b/src/jimgw/core/single_event/data.py index 998fa35f3..f1322b93d 100644 --- a/src/jimgw/core/single_event/data.py +++ b/src/jimgw/core/single_event/data.py @@ -325,9 +325,11 @@ def from_fd( # This ensures the newly constructed Data in FD fully # represents the input FD data. d_new, f_new = data.frequency_slice(frequencies[0], frequencies[-1]) - assert all(jnp.equal(d_new, fd)), "Data do not match after slicing" - assert all( - jnp.equal(f_new, frequencies) + assert jnp.allclose( + d_new, fd, rtol=1e-10, atol=1e-15 + ), "Data do not match after slicing" + assert jnp.allclose( + f_new, frequencies, rtol=1e-10, atol=1e-15 ), "Frequencies do not match after slicing" return data diff --git a/src/jimgw/core/single_event/detector.py b/src/jimgw/core/single_event/detector.py index 3eb9d93cb..a29bdacfb 100644 --- a/src/jimgw/core/single_event/detector.py +++ b/src/jimgw/core/single_event/detector.py @@ -142,8 +142,8 @@ def set_frequency_bounds( data, freqs_1 = self.data.frequency_slice(*self.frequency_bounds) psd, freqs_2 = self.psd.frequency_slice(*self.frequency_bounds) - assert all( - freqs_1 == freqs_2 + assert jnp.allclose( + freqs_1, freqs_2, rtol=1e-10, atol=1e-15 ), f"The {self.name} data and PSD must have same frequencies" self._sliced_frequencies = freqs_1 From 222aa6e5d7f2acdbc5bbc0b603ff19738a948380 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 28 Jul 2025 22:05:09 +0800 Subject: [PATCH 2/3] Replace jnp.allclose with jnp.array_equal for frequency assertions in Data and Detector classes --- src/jimgw/core/single_event/data.py | 8 +++----- src/jimgw/core/single_event/detector.py | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/jimgw/core/single_event/data.py b/src/jimgw/core/single_event/data.py index f1322b93d..d5a2ae5f9 100644 --- a/src/jimgw/core/single_event/data.py +++ b/src/jimgw/core/single_event/data.py @@ -325,11 +325,9 @@ def from_fd( # This ensures the newly constructed Data in FD fully # represents the input FD data. d_new, f_new = data.frequency_slice(frequencies[0], frequencies[-1]) - assert jnp.allclose( - d_new, fd, rtol=1e-10, atol=1e-15 - ), "Data do not match after slicing" - assert jnp.allclose( - f_new, frequencies, rtol=1e-10, atol=1e-15 + assert jnp.array_equal(d_new, fd), "Data do not match after slicing" + assert jnp.array_equal( + f_new, frequencies ), "Frequencies do not match after slicing" return data diff --git a/src/jimgw/core/single_event/detector.py b/src/jimgw/core/single_event/detector.py index a29bdacfb..df52023f6 100644 --- a/src/jimgw/core/single_event/detector.py +++ b/src/jimgw/core/single_event/detector.py @@ -142,8 +142,8 @@ def set_frequency_bounds( data, freqs_1 = self.data.frequency_slice(*self.frequency_bounds) psd, freqs_2 = self.psd.frequency_slice(*self.frequency_bounds) - assert jnp.allclose( - freqs_1, freqs_2, rtol=1e-10, atol=1e-15 + assert jnp.array_equal( + freqs_1, freqs_2 ), f"The {self.name} data and PSD must have same frequencies" self._sliced_frequencies = freqs_1 From 459ff23aad8cdb8007cb719456caee7664c968da Mon Sep 17 00:00:00 2001 From: Samson Leong <55839002+SSL32081@users.noreply.github.com> Date: Mon, 28 Jul 2025 22:55:28 +0800 Subject: [PATCH 3/3] Replace allclose with arrayequal for frequencies check --- src/jimgw/core/single_event/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jimgw/core/single_event/data.py b/src/jimgw/core/single_event/data.py index d5a2ae5f9..48570c00c 100644 --- a/src/jimgw/core/single_event/data.py +++ b/src/jimgw/core/single_event/data.py @@ -315,7 +315,7 @@ def from_fd( delta_t = 1 / (2 * fnyq) data_td_full = jnp.fft.irfft(data_fd_full) / delta_t # check frequencies - assert jnp.allclose( + assert jnp.array_equal( f, jnp.fft.rfftfreq(len(data_td_full), delta_t) ), "Generated frequencies do not match the input frequencies" # create a Data object