diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index 2b4a7bc06..566baf856 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -438,17 +438,8 @@ def __init__( if reference_waveform is None: reference_waveform = waveform - # Get the original frequency grid - frequency_original = self.frequencies - # Get the grid of the relative binning scheme (contains the final endpoint) - # and the center points - freq_grid, self.freq_grid_center = self.make_binning_scheme( - jnp.array(frequency_original), n_bins - ) - self.freq_grid_low = freq_grid[:-1] - if ref_params: - self.ref_params = ref_params + self.ref_params = ref_params.copy() logging.info(f"Reference parameters provided, which are {self.ref_params}") elif prior: logging.info("No reference parameters are provided, finding it...") @@ -484,12 +475,21 @@ def __init__( self.B0_array = {} self.B1_array = {} + # Get the original frequency grid + frequency_original = self.frequencies + # Get the grid of the relative binning scheme (contains the final endpoint) + # and the center points + freq_grid, self.freq_grid_center = self.make_binning_scheme( + jnp.array(frequency_original), n_bins + ) + self.freq_grid_low = freq_grid[:-1] + h_sky = reference_waveform(frequency_original, self.ref_params) # Get frequency masks to be applied, for both original # and heterodyne frequency grid h_amp = jnp.sum( - jnp.array([jnp.abs(h_sky[key]) for key in h_sky.keys()]), axis=0 + jnp.array([jnp.abs(h_sky[pol]) for pol in h_sky.keys()]), axis=0 ) f_valid = frequency_original[jnp.where(h_amp > 0)[0]] f_max = jnp.max(f_valid) @@ -532,7 +532,6 @@ def __init__( freq_grid, self.freq_grid_center, ) - self.A0_array[detector.name] = A0[mask_heterodyne_center] self.A1_array[detector.name] = A1[mask_heterodyne_center] self.B0_array[detector.name] = B0[mask_heterodyne_center] @@ -556,7 +555,7 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: frequencies_low, waveform_sky_low, params ) waveform_center = detector.fd_response( - frequencies_low, waveform_sky_center, params + frequencies_center, waveform_sky_center, params ) r0 = waveform_center / self.waveform_center_ref[detector.name] @@ -634,34 +633,52 @@ def make_binning_scheme( The bin centers. """ phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=chi) # type: ignore - bin_f = interp1d(phase_diff_array, freqs) phase_diff = jnp.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1) - f_bins = bin_f(phase_diff) + f_bins = interp1d(phase_diff_array, freqs)(phase_diff) f_bins_center = (f_bins[:-1] + f_bins[1:]) / 2 return jnp.array(f_bins), jnp.array(f_bins_center) @staticmethod def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): - A0_array = [] - A1_array = [] - B0_array = [] - B1_array = [] - df = freqs[1] - freqs[0] data_prod = jnp.array(data * h_ref.conj()) / psd self_prod = jnp.array(h_ref * h_ref.conj()) / psd - for i in range(len(f_bins) - 1): - f_index = jnp.where((freqs >= f_bins[i]) & (freqs < f_bins[i + 1]))[0] - freq_shift = freqs[f_index] - f_bins_center[i] - A0_array.append(4 * jnp.sum(data_prod[f_index]) * df) - A1_array.append(4 * jnp.sum(data_prod[f_index] * freq_shift) * df) - B0_array.append(4 * jnp.sum(self_prod[f_index]) * df) - B1_array.append(4 * jnp.sum(self_prod[f_index] * freq_shift) * df) - - A0_array = jnp.array(A0_array) - A1_array = jnp.array(A1_array) - B0_array = jnp.array(B0_array) - B1_array = jnp.array(B1_array) + + # Vectorized binning using broadcasting + freq_bins_left = f_bins[:-1] # Shape: (len(f_bins)-1,) + freq_bins_right = f_bins[1:] # Shape: (len(f_bins)-1,) + + # Broadcast for vectorized comparison + freqs_broadcast = freqs[None, :] # Shape: (1, n_freqs) + left_bounds = freq_bins_left[:, None] # Shape: (len(f_bins)-1, 1) + right_bounds = freq_bins_right[:, None] # Shape: (len(f_bins)-1, 1) + + # Create mask matrix: True where frequency belongs to bin + mask = (freqs_broadcast >= left_bounds) & ( + freqs_broadcast < right_bounds + ) # Shape: (len(f_bins)-1, n_freqs) + + # Vectorized computation of frequency shifts + f_bins_center_broadcast = f_bins_center[:, None] # Shape: (len(f_bins)-1, 1) + freq_shift_matrix = ( + freqs_broadcast - f_bins_center_broadcast + ) * mask # Shape: (len(f_bins)-1, n_freqs) + + # Vectorized computation of coefficients + # For each bin, sum over the frequency dimension + A0_array = ( + 4 * jnp.sum(data_prod[None, :] * mask, axis=1) * df + ) # Shape: (len(f_bins)-1,) + A1_array = ( + 4 * jnp.sum(data_prod[None, :] * freq_shift_matrix, axis=1) * df + ) # Shape: (len(f_bins)-1,) + B0_array = ( + 4 * jnp.sum(self_prod[None, :] * mask, axis=1) * df + ) # Shape: (len(f_bins)-1,) + B1_array = ( + 4 * jnp.sum(self_prod[None, :] * freq_shift_matrix, axis=1) * df + ) # Shape: (len(f_bins)-1,) + return A0_array, A1_array, B0_array, B1_array def maximize_likelihood( @@ -720,7 +737,7 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: non_finite_index[:common_length] ].set(guess[:common_length]) - rng_key, best_fit, log_prob = optimizer.optimize( + _, best_fit, log_prob = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position, {} )