Skip to content
Closed
Changes from 35 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6ea2f64
refactor: remove evaluate_original method and update its usage in Het…
thomasckng May 23, 2025
b019e1e
fix: correct super() call in HeterodynedTransientLikelihoodFD class
thomasckng May 23, 2025
e909fd2
fix: update method signature and call in HeterodynedTransientLikeliho…
thomasckng May 23, 2025
bfcfbe9
fix: update method call in HeterodynedTransientLikelihoodFD class to …
thomasckng May 23, 2025
41e8825
fix: update method definition in HeterodynedTransientLikelihoodFD cla…
thomasckng May 23, 2025
bd60d33
fix the shape of frequency array in heterodyned
CharmaineWONG2 May 23, 2025
2bc05bc
Use slightly faster broadcasting
SSL32081 May 24, 2025
fce5b95
Change to use super call
SSL32081 May 24, 2025
3ad9098
Update array_equal to array_equiv assertion
SSL32081 May 26, 2025
91c3a1f
Minor: formatting
SSL32081 May 26, 2025
c12faac
Avoid editing the parameter dictionary in-place
SSL32081 May 29, 2025
775dc63
Make data optional in likelihood evaluate
SSL32081 May 29, 2025
80cf3b7
Add the None default value
SSL32081 May 29, 2025
1516549
Add comment regarding trapezoid integration
SSL32081 May 29, 2025
86c1dab
Fix heterodyne likelihood bug
SSL32081 May 29, 2025
7f6e02a
Remove explicit optional type
SSL32081 May 29, 2025
670cdc0
Fix formatting in HeterodynedTransientLikelihoodFD class
thomasckng Jun 4, 2025
8e6d348
Refactor initial sample generation and integrate into Jim and likelih…
thomasckng Jun 4, 2025
2af50cd
Fix formatting in sample_initial_condition method and clean up import…
thomasckng Jun 4, 2025
f874a44
Remove redundant import of generate_initial_samples in likelihood.py
thomasckng Jun 4, 2025
f996e88
Merge branch 'jim-dev' into fix-maxL
thomasckng Jun 4, 2025
be07cb2
Fix pre-commit bugs
SSL32081 Jun 4, 2025
7ddda36
Fix pre-commit bugs again
SSL32081 Jun 4, 2025
899b7b2
Use the frequecies attribute
SSL32081 Jun 4, 2025
474c1b8
Formatting
thomasckng Jun 5, 2025
61d9f43
Refactor type hints for sample_transforms and likelihood_transforms t…
thomasckng Jun 6, 2025
162e066
Refactor evaluation call in HeterodynedTransientLikelihoodFD to use s…
thomasckng Jun 27, 2025
458c6e3
Update import for welch function to use jax.scipy instead of scipy
thomasckng Jul 3, 2025
89faeb4
Merge branch 'jim-dev' into fix-maxL
thomasckng Jul 3, 2025
2880631
Merge utils.py improvements from add-2-step-optimisation
thomasckng Jul 4, 2025
08b3ae9
Formatting
thomasckng Jul 4, 2025
ddf1e5b
Merge branch 'jim-dev' into fix-maxL
thomasckng Jul 12, 2025
7783c88
Merge branch 'jim-dev' into fix-maxL
thomasckng Jul 22, 2025
24ce79c
Merge branch 'jim-dev' into fix-maxL
thomasckng Jul 22, 2025
d04ecdf
Minor change
thomasckng Jul 22, 2025
859d97f
Optimize frequency binning with vectorized operations in HeterodynedT…
thomasckng Jul 25, 2025
d88638f
Formatting
thomasckng Jul 25, 2025
ba2140d
Revert evaluation call in HeterodynedTransientLikelihoodFD to use sup…
thomasckng Jul 25, 2025
5567c99
Remove unused coefficient arrays in compute_coefficients method of He…
thomasckng Jul 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 26 additions & 28 deletions src/jimgw/core/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,17 +438,8 @@ def __init__(
if reference_waveform is None:
reference_waveform = waveform

# Get the original frequency grid
frequency_original = self.frequencies
# Get the grid of the relative binning scheme (contains the final endpoint)
# and the center points
freq_grid, self.freq_grid_center = self.make_binning_scheme(
jnp.array(frequency_original), n_bins
)
self.freq_grid_low = freq_grid[:-1]

if ref_params:
self.ref_params = ref_params
self.ref_params = ref_params.copy()
logging.info(f"Reference parameters provided, which are {self.ref_params}")
elif prior:
logging.info("No reference parameters are provided, finding it...")
Expand Down Expand Up @@ -484,12 +475,21 @@ def __init__(
self.B0_array = {}
self.B1_array = {}

# Get the original frequency grid
frequency_original = self.frequencies
# Get the grid of the relative binning scheme (contains the final endpoint)
# and the center points
freq_grid, self.freq_grid_center = self.make_binning_scheme(
jnp.array(frequency_original), n_bins
)
self.freq_grid_low = freq_grid[:-1]

Comment thread
thomasckng marked this conversation as resolved.
h_sky = reference_waveform(frequency_original, self.ref_params)

# Get frequency masks to be applied, for both original
# and heterodyne frequency grid
h_amp = jnp.sum(
jnp.array([jnp.abs(h_sky[key]) for key in h_sky.keys()]), axis=0
jnp.array([jnp.abs(h_sky[pol]) for pol in h_sky.keys()]), axis=0
)
f_valid = frequency_original[jnp.where(h_amp > 0)[0]]
f_max = jnp.max(f_valid)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tend to think this block of frequency mask may not be needed (up to Line 299), as it is trying to keep segments where the waveform is non-zero, but given that we are only generating the waveform segments as requested, I believe the whole segment will be accepted in almost all scenarios, other than perhaps at the point where frequency = 0.0.

Copy link
Copy Markdown
Collaborator

@SSL32081 SSL32081 Jul 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the latest version, it goes from l494 to l507, which refers to this block of code:

f_valid = frequency_original[jnp.where(h_amp > 0)[0]]
...
self.freq_grid_low = self.freq_grid_low[mask_heterodyne_low]
self.freq_grid_center = self.freq_grid_center[mask_heterodyne_center]

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kazewong @tsunhopang Any suggestions?

Expand Down Expand Up @@ -532,7 +532,6 @@ def __init__(
freq_grid,
self.freq_grid_center,
)

self.A0_array[detector.name] = A0[mask_heterodyne_center]
self.A1_array[detector.name] = A1[mask_heterodyne_center]
self.B0_array[detector.name] = B0[mask_heterodyne_center]
Expand All @@ -556,7 +555,7 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float:
frequencies_low, waveform_sky_low, params
)
waveform_center = detector.fd_response(
frequencies_low, waveform_sky_center, params
frequencies_center, waveform_sky_center, params
Comment thread
thomasckng marked this conversation as resolved.
)

r0 = waveform_center / self.waveform_center_ref[detector.name]
Expand Down Expand Up @@ -634,9 +633,8 @@ def make_binning_scheme(
The bin centers.
"""
phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=chi) # type: ignore
bin_f = interp1d(phase_diff_array, freqs)
phase_diff = jnp.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1)
f_bins = bin_f(phase_diff)
f_bins = interp1d(phase_diff_array, freqs)(phase_diff)
f_bins_center = (f_bins[:-1] + f_bins[1:]) / 2
return jnp.array(f_bins), jnp.array(f_bins_center)

Expand All @@ -653,15 +651,15 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center):
for i in range(len(f_bins) - 1):
f_index = jnp.where((freqs >= f_bins[i]) & (freqs < f_bins[i + 1]))[0]
freq_shift = freqs[f_index] - f_bins_center[i]
A0_array.append(4 * jnp.sum(data_prod[f_index]) * df)
A1_array.append(4 * jnp.sum(data_prod[f_index] * freq_shift) * df)
B0_array.append(4 * jnp.sum(self_prod[f_index]) * df)
B1_array.append(4 * jnp.sum(self_prod[f_index] * freq_shift) * df)

A0_array = jnp.array(A0_array)
A1_array = jnp.array(A1_array)
B0_array = jnp.array(B0_array)
B1_array = jnp.array(B1_array)
A0_array.append(jnp.sum(data_prod[f_index]))
A1_array.append(jnp.sum(data_prod[f_index] * freq_shift))
B0_array.append(jnp.sum(self_prod[f_index]))
B1_array.append(jnp.sum(self_prod[f_index] * freq_shift))

A0_array = 4 * df * jnp.array(A0_array)
A1_array = 4 * df * jnp.array(A1_array)
B0_array = 4 * df * jnp.array(B0_array)
B1_array = 4 * df * jnp.array(B1_array)
return A0_array, A1_array, B0_array, B1_array

def maximize_likelihood(
Expand All @@ -676,15 +674,15 @@ def maximize_likelihood(
for transform in sample_transforms:
parameter_names = transform.propagate_name(parameter_names)

super_obj = super(HeterodynedTransientLikelihoodFD, self)
Comment thread
thomasckng marked this conversation as resolved.
Outdated

def y(x: Float[Array, " n_dims"], data: dict) -> Float:
named_params = dict(zip(parameter_names, x))
for transform in reversed(sample_transforms):
named_params = transform.backward(named_params)
for transform in likelihood_transforms:
named_params = transform.forward(named_params)
return -super(HeterodynedTransientLikelihoodFD, self).evaluate(
named_params, data
)
return -super_obj.evaluate(named_params, data)

print("Starting the optimizer")

Expand Down Expand Up @@ -720,7 +718,7 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float:
non_finite_index[:common_length]
].set(guess[:common_length])

rng_key, best_fit, log_prob = optimizer.optimize(
_, best_fit, log_prob = optimizer.optimize(
jax.random.PRNGKey(12094), y, initial_position, {}
)

Expand Down
Loading