Skip to content
Closed
Changes from 6 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6ea2f64
refactor: remove evaluate_original method and update its usage in Het…
thomasckng May 23, 2025
b019e1e
fix: correct super() call in HeterodynedTransientLikelihoodFD class
thomasckng May 23, 2025
e909fd2
fix: update method signature and call in HeterodynedTransientLikeliho…
thomasckng May 23, 2025
bfcfbe9
fix: update method call in HeterodynedTransientLikelihoodFD class to …
thomasckng May 23, 2025
41e8825
fix: update method definition in HeterodynedTransientLikelihoodFD cla…
thomasckng May 23, 2025
bd60d33
fix the shape of frequency array in heterodyned
CharmaineWONG2 May 23, 2025
2bc05bc
Use slightly faster broadcasting
SSL32081 May 24, 2025
fce5b95
Change to use super call
SSL32081 May 24, 2025
3ad9098
Update array_equal to array_equiv assertion
SSL32081 May 26, 2025
91c3a1f
Minor: formatting
SSL32081 May 26, 2025
c12faac
Avoid editing the parameter dictionary in-place
SSL32081 May 29, 2025
775dc63
Make data optional in likelihood evaluate
SSL32081 May 29, 2025
80cf3b7
Add the None default value
SSL32081 May 29, 2025
1516549
Add comment regarding trapezoid integration
SSL32081 May 29, 2025
86c1dab
Fix heterodyne likelihood bug
SSL32081 May 29, 2025
7f6e02a
Remove explicit optional type
SSL32081 May 29, 2025
670cdc0
Fix formatting in HeterodynedTransientLikelihoodFD class
thomasckng Jun 4, 2025
8e6d348
Refactor initial sample generation and integrate into Jim and likelih…
thomasckng Jun 4, 2025
2af50cd
Fix formatting in sample_initial_condition method and clean up import…
thomasckng Jun 4, 2025
f874a44
Remove redundant import of generate_initial_samples in likelihood.py
thomasckng Jun 4, 2025
f996e88
Merge branch 'jim-dev' into fix-maxL
thomasckng Jun 4, 2025
be07cb2
Fix pre-commit bugs
SSL32081 Jun 4, 2025
7ddda36
Fix pre-commit bugs again
SSL32081 Jun 4, 2025
899b7b2
Use the frequecies attribute
SSL32081 Jun 4, 2025
474c1b8
Formatting
thomasckng Jun 5, 2025
61d9f43
Refactor type hints for sample_transforms and likelihood_transforms t…
thomasckng Jun 6, 2025
162e066
Refactor evaluation call in HeterodynedTransientLikelihoodFD to use s…
thomasckng Jun 27, 2025
458c6e3
Update import for welch function to use jax.scipy instead of scipy
thomasckng Jul 3, 2025
89faeb4
Merge branch 'jim-dev' into fix-maxL
thomasckng Jul 3, 2025
2880631
Merge utils.py improvements from add-2-step-optimisation
thomasckng Jul 4, 2025
08b3ae9
Formatting
thomasckng Jul 4, 2025
ddf1e5b
Merge branch 'jim-dev' into fix-maxL
thomasckng Jul 12, 2025
7783c88
Merge branch 'jim-dev' into fix-maxL
thomasckng Jul 22, 2025
24ce79c
Merge branch 'jim-dev' into fix-maxL
thomasckng Jul 22, 2025
d04ecdf
Minor change
thomasckng Jul 22, 2025
859d97f
Optimize frequency binning with vectorized operations in HeterodynedT…
thomasckng Jul 25, 2025
d88638f
Formatting
thomasckng Jul 25, 2025
ba2140d
Revert evaluation call in HeterodynedTransientLikelihoodFD to use sup…
thomasckng Jul 25, 2025
5567c99
Remove unused coefficient arrays in compute_coefficients method of He…
thomasckng Jul 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 3 additions & 26 deletions src/jimgw/core/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,29 +357,6 @@ def evaluate(self, params: dict[str, Float], data: dict) -> Float:
)
return log_likelihood

def evaluate_original(
self, params: dict[str, Float], data: dict
) -> (
Float
): # TODO: Test whether we need to pass data in or with class changes is fine.
"""
Evaluate the likelihood for a given set of parameters.
"""
params["trigger_time"] = self.trigger_time
params["gmst"] = self.gmst
# adjust the params due to different marginalzation scheme
params = self.param_func(params)
# adjust the params due to fixing parameters
params = self.fixing_func(params)
# evaluate the waveform as usual
waveform_sky = self.waveform(self.frequencies, params)
return self.likelihood_function(
params,
waveform_sky,
self.detectors, # type: ignore
**self.kwargs,
)

@staticmethod
def max_phase_diff(
f: Float[Array, " n_freq"],
Expand Down Expand Up @@ -407,7 +384,7 @@ def max_phase_diff(
Maximum phase difference between the frequencies in the array.
"""
gamma = jnp.arange(-5, 6) / 3.0
f_2D = jnp.broadcast_to(f, (f.size, gamma.size))
f_2D = jnp.broadcast_to(f.reshape(f.size, 1), (f.size, gamma.size))
Comment thread
SSL32081 marked this conversation as resolved.
Outdated
f_star = jnp.where(gamma >= 0, f_high, f_low)
return (
2
Expand Down Expand Up @@ -488,7 +465,7 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float:
named_params = transform.backward(named_params)
for transform in likelihood_transforms:
named_params = transform.forward(named_params)
return -self.evaluate_original(named_params, data)
return -TransientLikelihoodFD.evaluate(self, named_params, data)

print("Starting the optimizer")

Expand Down Expand Up @@ -524,7 +501,7 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float:
non_finite_index[:common_length]
].set(guess[:common_length])

rng_key, best_fit, log_prob = optimizer.optimize(
_, best_fit, log_prob = optimizer.optimize(
jax.random.PRNGKey(12094), y, initial_position, {}
)

Expand Down
Loading