-
Notifications
You must be signed in to change notification settings - Fork 33
Bugfix and enhancement in heterodyned likelihood #222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 35 commits
6ea2f64
b019e1e
e909fd2
bfcfbe9
41e8825
bd60d33
2bc05bc
fce5b95
3ad9098
91c3a1f
c12faac
775dc63
80cf3b7
1516549
86c1dab
7f6e02a
670cdc0
8e6d348
2af50cd
f874a44
f996e88
be07cb2
7ddda36
899b7b2
474c1b8
61d9f43
162e066
458c6e3
89faeb4
2880631
08b3ae9
ddf1e5b
7783c88
24ce79c
d04ecdf
859d97f
d88638f
ba2140d
5567c99
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -438,17 +438,8 @@ def __init__( | |
| if reference_waveform is None: | ||
| reference_waveform = waveform | ||
|
|
||
| # Get the original frequency grid | ||
| frequency_original = self.frequencies | ||
| # Get the grid of the relative binning scheme (contains the final endpoint) | ||
| # and the center points | ||
| freq_grid, self.freq_grid_center = self.make_binning_scheme( | ||
| jnp.array(frequency_original), n_bins | ||
| ) | ||
| self.freq_grid_low = freq_grid[:-1] | ||
|
|
||
| if ref_params: | ||
| self.ref_params = ref_params | ||
| self.ref_params = ref_params.copy() | ||
| logging.info(f"Reference parameters provided, which are {self.ref_params}") | ||
| elif prior: | ||
| logging.info("No reference parameters are provided, finding it...") | ||
|
|
@@ -484,12 +475,21 @@ def __init__( | |
| self.B0_array = {} | ||
| self.B1_array = {} | ||
|
|
||
| # Get the original frequency grid | ||
| frequency_original = self.frequencies | ||
| # Get the grid of the relative binning scheme (contains the final endpoint) | ||
| # and the center points | ||
| freq_grid, self.freq_grid_center = self.make_binning_scheme( | ||
| jnp.array(frequency_original), n_bins | ||
| ) | ||
| self.freq_grid_low = freq_grid[:-1] | ||
|
|
||
| h_sky = reference_waveform(frequency_original, self.ref_params) | ||
|
|
||
| # Get frequency masks to be applied, for both original | ||
| # and heterodyne frequency grid | ||
| h_amp = jnp.sum( | ||
| jnp.array([jnp.abs(h_sky[key]) for key in h_sky.keys()]), axis=0 | ||
| jnp.array([jnp.abs(h_sky[pol]) for pol in h_sky.keys()]), axis=0 | ||
| ) | ||
| f_valid = frequency_original[jnp.where(h_amp > 0)[0]] | ||
| f_max = jnp.max(f_valid) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tend to think this block of frequency mask may not be needed (up to Line 299), as it is trying to keep segments where the waveform is non-zero, but given that we are only generating the waveform segments as requested, I believe the whole segment will be accepted in almost all scenarios, other than perhaps at the point where
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the latest version, it goes from f_valid = frequency_original[jnp.where(h_amp > 0)[0]]
...
self.freq_grid_low = self.freq_grid_low[mask_heterodyne_low]
self.freq_grid_center = self.freq_grid_center[mask_heterodyne_center]
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kazewong @tsunhopang Any suggestions? |
||
|
|
@@ -532,7 +532,6 @@ def __init__( | |
| freq_grid, | ||
| self.freq_grid_center, | ||
| ) | ||
|
|
||
| self.A0_array[detector.name] = A0[mask_heterodyne_center] | ||
| self.A1_array[detector.name] = A1[mask_heterodyne_center] | ||
| self.B0_array[detector.name] = B0[mask_heterodyne_center] | ||
|
|
@@ -556,7 +555,7 @@ def _likelihood(self, params: dict[str, Float], data: dict) -> Float: | |
| frequencies_low, waveform_sky_low, params | ||
| ) | ||
| waveform_center = detector.fd_response( | ||
| frequencies_low, waveform_sky_center, params | ||
| frequencies_center, waveform_sky_center, params | ||
|
thomasckng marked this conversation as resolved.
|
||
| ) | ||
|
|
||
| r0 = waveform_center / self.waveform_center_ref[detector.name] | ||
|
|
@@ -634,9 +633,8 @@ def make_binning_scheme( | |
| The bin centers. | ||
| """ | ||
| phase_diff_array = self.max_phase_diff(freqs, freqs[0], freqs[-1], chi=chi) # type: ignore | ||
| bin_f = interp1d(phase_diff_array, freqs) | ||
| phase_diff = jnp.linspace(phase_diff_array[0], phase_diff_array[-1], n_bins + 1) | ||
| f_bins = bin_f(phase_diff) | ||
| f_bins = interp1d(phase_diff_array, freqs)(phase_diff) | ||
| f_bins_center = (f_bins[:-1] + f_bins[1:]) / 2 | ||
| return jnp.array(f_bins), jnp.array(f_bins_center) | ||
|
|
||
|
|
@@ -653,15 +651,15 @@ def compute_coefficients(data, h_ref, psd, freqs, f_bins, f_bins_center): | |
| for i in range(len(f_bins) - 1): | ||
| f_index = jnp.where((freqs >= f_bins[i]) & (freqs < f_bins[i + 1]))[0] | ||
| freq_shift = freqs[f_index] - f_bins_center[i] | ||
| A0_array.append(4 * jnp.sum(data_prod[f_index]) * df) | ||
| A1_array.append(4 * jnp.sum(data_prod[f_index] * freq_shift) * df) | ||
| B0_array.append(4 * jnp.sum(self_prod[f_index]) * df) | ||
| B1_array.append(4 * jnp.sum(self_prod[f_index] * freq_shift) * df) | ||
|
|
||
| A0_array = jnp.array(A0_array) | ||
| A1_array = jnp.array(A1_array) | ||
| B0_array = jnp.array(B0_array) | ||
| B1_array = jnp.array(B1_array) | ||
| A0_array.append(jnp.sum(data_prod[f_index])) | ||
| A1_array.append(jnp.sum(data_prod[f_index] * freq_shift)) | ||
| B0_array.append(jnp.sum(self_prod[f_index])) | ||
| B1_array.append(jnp.sum(self_prod[f_index] * freq_shift)) | ||
|
|
||
| A0_array = 4 * df * jnp.array(A0_array) | ||
| A1_array = 4 * df * jnp.array(A1_array) | ||
| B0_array = 4 * df * jnp.array(B0_array) | ||
| B1_array = 4 * df * jnp.array(B1_array) | ||
| return A0_array, A1_array, B0_array, B1_array | ||
|
|
||
| def maximize_likelihood( | ||
|
|
@@ -676,15 +674,15 @@ def maximize_likelihood( | |
| for transform in sample_transforms: | ||
| parameter_names = transform.propagate_name(parameter_names) | ||
|
|
||
| super_obj = super(HeterodynedTransientLikelihoodFD, self) | ||
|
thomasckng marked this conversation as resolved.
Outdated
|
||
|
|
||
| def y(x: Float[Array, " n_dims"], data: dict) -> Float: | ||
| named_params = dict(zip(parameter_names, x)) | ||
| for transform in reversed(sample_transforms): | ||
| named_params = transform.backward(named_params) | ||
| for transform in likelihood_transforms: | ||
| named_params = transform.forward(named_params) | ||
| return -super(HeterodynedTransientLikelihoodFD, self).evaluate( | ||
| named_params, data | ||
| ) | ||
| return -super_obj.evaluate(named_params, data) | ||
|
|
||
| print("Starting the optimizer") | ||
|
|
||
|
|
@@ -720,7 +718,7 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: | |
| non_finite_index[:common_length] | ||
| ].set(guess[:common_length]) | ||
|
|
||
| rng_key, best_fit, log_prob = optimizer.optimize( | ||
| _, best_fit, log_prob = optimizer.optimize( | ||
| jax.random.PRNGKey(12094), y, initial_position, {} | ||
| ) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.