diff --git a/example/workbench.py b/example/workbench.py index fb225fc57..fe294ce88 100644 --- a/example/workbench.py +++ b/example/workbench.py @@ -15,7 +15,7 @@ UniformSpherePrior, ) from jimgw.core.single_event.detector import get_H1, get_L1, get_V1 -from jimgw.core.single_event.likelihood import TransientLikelihoodFD +from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD from jimgw.core.single_event.data import Data from jimgw.core.single_event.waveform import RippleIMRPhenomPv2 from jimgw.core.transforms import BoundToUnbound @@ -131,7 +131,9 @@ gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax ), GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]), - GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]), + GeocentricArrivalTimeToDetectorArrivalTimeTransform( + tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0] + ), SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos), BoundToUnbound( name_mapping=(["M_c"], ["M_c_unbounded"]), @@ -207,13 +209,12 @@ ] -likelihood = TransientLikelihoodFD( +likelihood = BaseTransientLikelihoodFD( ifos, waveform=waveform, trigger_time=gps, f_min=fmin, f_max=fmax, - # marginalization="time", ) jim = Jim( diff --git a/jim_dagster/InjectionRecovery/assets.py b/jim_dagster/InjectionRecovery/assets.py index becfe7107..60e8d825c 100644 --- a/jim_dagster/InjectionRecovery/assets.py +++ b/jim_dagster/InjectionRecovery/assets.py @@ -19,6 +19,7 @@ # Sample a fiducial population + @dg.asset( group_name="prerun", key_prefix="InjectionRecovery", @@ -32,10 +33,12 @@ def sample_population(): path_prefix="./data/", ) + # TODO: Add diagnostics regarding the sampled population. # Create asset group for run and configuration + @dg.asset( group_name="prerun", description="Configuration file for the run.", @@ -76,10 +79,17 @@ def config_file(): run.local_data_prefix = f"./data/runs/{idx}/strains/" run.serialize(f"./data/runs/{idx}/config.yaml") + @dg.multi_asset( specs=[ - dg.AssetSpec(key=["InjectionRecovery", "strain"], deps=[["InjectionRecovery", "config_file"]]), - dg.AssetSpec(key=["InjectionRecovery", "psd"], deps=[["InjectionRecovery", "config_file"]]), + dg.AssetSpec( + key=["InjectionRecovery", "strain"], + deps=[["InjectionRecovery", "config_file"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "psd"], + deps=[["InjectionRecovery", "config_file"]], + ), ], group_name="prerun", ) @@ -122,19 +132,53 @@ def raw_data(): detector.data.to_file(f"./data/runs/{idx}/strains/{ifo}_data") detector.psd.to_file(f"./data/runs/{idx}/strains/{ifo}_psd") + @dg.multi_asset( specs=[ - dg.AssetSpec(key=["InjectionRecovery", "training_chains"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "training_log_prob"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "training_local_acceptance"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "training_global_acceptance"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "training_loss"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "production_chains"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "production_log_prob"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "production_local_acceptance"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "production_global_acceptance"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "auxiliary_nf_samples"], deps=[["InjectionRecovery", "raw_data"]]), - dg.AssetSpec(key=["InjectionRecovery", "auxiliary_prior_samples"], deps=[["InjectionRecovery", "raw_data"]]), + dg.AssetSpec( + key=["InjectionRecovery", "training_chains"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "training_log_prob"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "training_local_acceptance"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "training_global_acceptance"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "training_loss"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "production_chains"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "production_log_prob"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "production_local_acceptance"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "production_global_acceptance"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "auxiliary_nf_samples"], + deps=[["InjectionRecovery", "raw_data"]], + ), + dg.AssetSpec( + key=["InjectionRecovery", "auxiliary_prior_samples"], + deps=[["InjectionRecovery", "raw_data"]], + ), ], group_name="run", ) @@ -145,64 +189,140 @@ def run(): """ pass + # Create asset group for diagnostics -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_loss"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_loss"]], + key_prefix="InjectionRecovery", +) def loss_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_chains"]], + key_prefix="InjectionRecovery", +) def training_chains_corner_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_chains"]], + key_prefix="InjectionRecovery", +) def training_chains_trace_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_chains"]], + key_prefix="InjectionRecovery", +) def training_chains_rhat_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_log_prob"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_log_prob"]], + key_prefix="InjectionRecovery", +) def training_log_prob_distribution(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_log_prob"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_log_prob"]], + key_prefix="InjectionRecovery", +) def training_log_prob_evolution(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_local_acceptance"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_local_acceptance"]], + key_prefix="InjectionRecovery", +) def training_local_acceptance_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_global_acceptance"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "training_global_acceptance"]], + key_prefix="InjectionRecovery", +) def training_global_acceptance_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_chains"]], + key_prefix="InjectionRecovery", +) def production_chains_corner_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_chains"]], + key_prefix="InjectionRecovery", +) def production_chains_trace_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_chains"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_chains"]], + key_prefix="InjectionRecovery", +) def production_chains_rhat_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_log_prob"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_log_prob"]], + key_prefix="InjectionRecovery", +) def production_log_prob_distribution(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_log_prob"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_log_prob"]], + key_prefix="InjectionRecovery", +) def production_log_prob_evolution(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_local_acceptance"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_local_acceptance"]], + key_prefix="InjectionRecovery", +) def production_local_acceptance_plot(): pass -@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_global_acceptance"]], key_prefix="InjectionRecovery") + +@dg.asset( + group_name="diagnostics", + deps=[["InjectionRecovery", "production_global_acceptance"]], + key_prefix="InjectionRecovery", +) def production_global_acceptance_plot(): pass diff --git a/jim_dagster/RealDataCatalog/assets.py b/jim_dagster/RealDataCatalog/assets.py index 7e2bc40a4..d35c1029c 100644 --- a/jim_dagster/RealDataCatalog/assets.py +++ b/jim_dagster/RealDataCatalog/assets.py @@ -12,20 +12,21 @@ event_partitions_def = DynamicPartitionsDefinition(name="event_name") + @dg.asset( key_prefix="RealDataCatalog", group_name="prerun", description="Fetch all confident events and their gps time", ) def event_list(context: AssetExecutionContext): - catalogs = ['GWTC-1-confident', 'GWTC-2.1-confident', 'GWTC-3-confident'] + catalogs = ["GWTC-1-confident", "GWTC-2.1-confident", "GWTC-3-confident"] result = [] event_names = [] for catalog in catalogs: - event_list = gwosc.api.fetch_catalog_json(catalog)['events'] + event_list = gwosc.api.fetch_catalog_json(catalog)["events"] for event in event_list.values(): - name = event['commonName'] - gps_time = event['GPS'] + name = event["commonName"] + gps_time = event["GPS"] result.append((name, gps_time)) event_names.append(name) os.makedirs("data", exist_ok=True) @@ -39,7 +40,7 @@ def event_list(context: AssetExecutionContext): # We should be able to partition this asset and run it in parallel for each event. @dg.multi_asset( specs=[ - dg.AssetSpec(["RealDataCatalog","strain"], deps=[event_list]), + dg.AssetSpec(["RealDataCatalog", "strain"], deps=[event_list]), dg.AssetSpec(["RealDataCatalog", "psd"], deps=[event_list]), ], group_name="prerun", @@ -61,7 +62,9 @@ def raw_data(context: AssetExecutionContext): data = Data.from_gwosc(ifo, start, end) data.to_file(os.path.join(event_dir, f"{ifo}_data")) # TODO: Perhaps we should make sure the PSD estimation window are the same accross all IFOs? - psd_data = Data.from_gwosc(ifo, start - 4098, end -2) # This needs to be changed at some point + psd_data = Data.from_gwosc( + ifo, start - 4098, end - 2 + ) # This needs to be changed at some point if np.isnan(psd_data.td).any(): psd_data = Data.from_gwosc(ifo, start + 2, end + 4098) if np.isnan(psd_data.td).any(): @@ -91,6 +94,7 @@ def raw_data_plot(context: AssetExecutionContext): Plot the raw strain data for each IFO for the event. """ import matplotlib.pyplot as plt + event_name = context.partition_key event_dir = os.path.join("data", event_name, "raw") plots_dir = os.path.join("data", event_name, "plots") @@ -101,7 +105,7 @@ def raw_data_plot(context: AssetExecutionContext): data_file = os.path.join(event_dir, f"{ifo}_data.npz") if os.path.exists(data_file): data = np.load(data_file) - t = data["epoch"] + np.arange(data["td"].shape[0]) * data['dt'] + t = data["epoch"] + np.arange(data["td"].shape[0]) * data["dt"] td = data["td"] if t is not None and td is not None: plt.figure() @@ -115,6 +119,7 @@ def raw_data_plot(context: AssetExecutionContext): plot_paths.append(plot_path) return plot_paths + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "psd"]], @@ -126,6 +131,7 @@ def psd_plot(context: AssetExecutionContext): Plot the PSD for each IFO for the event. """ import matplotlib.pyplot as plt + event_name = context.partition_key event_dir = os.path.join("data", event_name, "raw") plots_dir = os.path.join("data", event_name, "plots") @@ -172,7 +178,9 @@ def config_file(context: AssetExecutionContext): if os.path.exists(data_file) and os.path.exists(psd_file): available_ifos.append(ifo) if available_ifos == []: - raise RuntimeError(f"No IFOs with both data and PSD found for event {event_name}") + raise RuntimeError( + f"No IFOs with both data and PSD found for event {event_name}" + ) run = IMRPhenomPv2StandardCBCRunDefinition( n_chains=500, n_local_steps=100, @@ -219,15 +227,34 @@ def config_file(context: AssetExecutionContext): run.local_data_prefix = os.path.join(run_dir, "raw/") run.serialize(os.path.join(run_dir, "config.yaml")) + @dg.multi_asset( specs=[ - dg.AssetSpec(key=["RealDataCatalog", "training_loss"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "production_chains"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "production_log_prob"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "production_local_acceptance"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "production_global_acceptance"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "auxiliary_nf_samples"], deps=[raw_data, config_file]), - dg.AssetSpec(key=["RealDataCatalog", "auxiliary_prior_samples"], deps=[raw_data, config_file]), + dg.AssetSpec( + key=["RealDataCatalog", "training_loss"], deps=[raw_data, config_file] + ), + dg.AssetSpec( + key=["RealDataCatalog", "production_chains"], deps=[raw_data, config_file] + ), + dg.AssetSpec( + key=["RealDataCatalog", "production_log_prob"], deps=[raw_data, config_file] + ), + dg.AssetSpec( + key=["RealDataCatalog", "production_local_acceptance"], + deps=[raw_data, config_file], + ), + dg.AssetSpec( + key=["RealDataCatalog", "production_global_acceptance"], + deps=[raw_data, config_file], + ), + dg.AssetSpec( + key=["RealDataCatalog", "auxiliary_nf_samples"], + deps=[raw_data, config_file], + ), + dg.AssetSpec( + key=["RealDataCatalog", "auxiliary_prior_samples"], + deps=[raw_data, config_file], + ), ], group_name="run", partitions_def=event_partitions_def, @@ -251,6 +278,7 @@ def loss_plot(context: AssetExecutionContext): Generate and save a loss plot from the training_loss asset. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -272,6 +300,7 @@ def loss_plot(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_chains"]], @@ -284,6 +313,7 @@ def production_chains_corner_plot(context: AssetExecutionContext): """ import matplotlib.pyplot as plt import corner + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -294,7 +324,22 @@ def production_chains_corner_plot(context: AssetExecutionContext): results = np.load(results_path, allow_pickle=True) chains = results["chains"].item() # keys = np.sort(list(chains.keys())) - keys = ['M_c', 'q', 's1_mag', 's1_theta', 's1_phi', 's2_mag', 's2_theta', 's2_phi', 'iota', 'd_L', 'phase_c', 'psi', 'ra', 'dec'] + keys = [ + "M_c", + "q", + "s1_mag", + "s1_theta", + "s1_phi", + "s2_mag", + "s2_theta", + "s2_phi", + "iota", + "d_L", + "phase_c", + "psi", + "ra", + "dec", + ] samples = np.array([chains[key] for key in keys]).T fig = corner.corner(samples[::10], labels=keys) plot_path = os.path.join(plots_dir, "production_chains_corner.png") @@ -302,6 +347,7 @@ def production_chains_corner_plot(context: AssetExecutionContext): plt.close(fig) return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "auxiliary_nf_samples"]], @@ -314,6 +360,7 @@ def nf_samples_corner_plot(context: AssetExecutionContext): """ import matplotlib.pyplot as plt import corner + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -324,7 +371,22 @@ def nf_samples_corner_plot(context: AssetExecutionContext): results = np.load(results_path, allow_pickle=True) nf_samples = results["nf_samples"].item() # keys = np.sort(list(nf_samples.keys())) - keys = ['M_c', 'q', 's1_mag', 's1_theta', 's1_phi', 's2_mag', 's2_theta', 's2_phi', 'iota', 'd_L', 'phase_c', 'psi', 'ra', 'dec'] + keys = [ + "M_c", + "q", + "s1_mag", + "s1_theta", + "s1_phi", + "s2_mag", + "s2_theta", + "s2_phi", + "iota", + "d_L", + "phase_c", + "psi", + "ra", + "dec", + ] nf_samples = np.array([nf_samples[key] for key in keys]).T fig = corner.corner(nf_samples, labels=keys) # Thinning for better visualization plot_path = os.path.join(plots_dir, "nf_samples_corner.png") @@ -332,6 +394,7 @@ def nf_samples_corner_plot(context: AssetExecutionContext): plt.close(fig) return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "auxiliary_prior_samples"]], @@ -344,6 +407,7 @@ def prior_samples_corner_plot(context: AssetExecutionContext): """ import matplotlib.pyplot as plt import corner + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -354,7 +418,22 @@ def prior_samples_corner_plot(context: AssetExecutionContext): results = np.load(results_path, allow_pickle=True) prior_samples = results["prior_samples"].item() # keys = np.sort(list(prior_samples.keys())) - keys = ['M_c', 'q', 's1_mag', 's1_theta', 's1_phi', 's2_mag', 's2_theta', 's2_phi', 'iota', 'd_L', 'phase_c', 'psi', 'ra', 'dec'] + keys = [ + "M_c", + "q", + "s1_mag", + "s1_theta", + "s1_phi", + "s2_mag", + "s2_theta", + "s2_phi", + "iota", + "d_L", + "phase_c", + "psi", + "ra", + "dec", + ] prior_samples = np.array([prior_samples[key] for key in keys]).T fig = corner.corner(prior_samples, labels=keys) # Thinning for better visualization plot_path = os.path.join(plots_dir, "prior_samples_corner.png") @@ -362,6 +441,7 @@ def prior_samples_corner_plot(context: AssetExecutionContext): plt.close(fig) return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_chains"]], @@ -373,6 +453,7 @@ def production_chains_trace_plot(context: AssetExecutionContext): Generate and save a trace plot for the production chains. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -382,7 +463,22 @@ def production_chains_trace_plot(context: AssetExecutionContext): raise FileNotFoundError(f"Results file not found: {results_path}") results = np.load(results_path, allow_pickle=True) chains = results["chains"].item() - keys = ['M_c', 'q', 's1_mag', 's1_theta', 's1_phi', 's2_mag', 's2_theta', 's2_phi', 'iota', 'd_L', 'phase_c', 'psi', 'ra', 'dec'] + keys = [ + "M_c", + "q", + "s1_mag", + "s1_theta", + "s1_phi", + "s2_mag", + "s2_theta", + "s2_phi", + "iota", + "d_L", + "phase_c", + "psi", + "ra", + "dec", + ] n_params = len(keys) samples = [chains[key] for key in keys] @@ -400,6 +496,7 @@ def production_chains_trace_plot(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_log_prob"]], @@ -411,6 +508,7 @@ def production_log_prob_distribution(context: AssetExecutionContext): Generate and save a histogram of the production log probability. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -432,6 +530,7 @@ def production_log_prob_distribution(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_log_prob"]], @@ -443,6 +542,7 @@ def production_log_prob_evolution(context: AssetExecutionContext): Generate and save a plot of the evolution of the production log probability. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -464,6 +564,7 @@ def production_log_prob_evolution(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_local_acceptance"]], @@ -475,6 +576,7 @@ def production_local_acceptance_plot(context: AssetExecutionContext): Generate and save a plot of the local acceptance rate. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") @@ -496,6 +598,7 @@ def production_local_acceptance_plot(context: AssetExecutionContext): plt.close() return plot_path + @dg.asset( group_name="diagnostics", deps=[["RealDataCatalog", "production_global_acceptance"]], @@ -507,6 +610,7 @@ def production_global_acceptance_plot(context: AssetExecutionContext): Generate and save a plot of the global acceptance rate. """ import matplotlib.pyplot as plt + event_name = context.partition_key run_dir = os.path.join("data", event_name) plots_dir = os.path.join(run_dir, "plots") diff --git a/src/jimgw/core/single_event/detector.py b/src/jimgw/core/single_event/detector.py index 92d11958c..3eb9d93cb 100644 --- a/src/jimgw/core/single_event/detector.py +++ b/src/jimgw/core/single_event/detector.py @@ -631,16 +631,13 @@ def inject_signal( self.set_frequency_bounds() masked_signal = projected_strain[self.frequency_mask] - df = self.sliced_frequencies[1] - self.sliced_frequencies[0] + df = self.sliced_frequencies[1] - self.sliced_frequencies[0] _optimal_snr_sq = inner_product( masked_signal, masked_signal, self.sliced_psd, df ) optimal_snr = _optimal_snr_sq**0.5 match_filtered_snr = complex_inner_product( - masked_signal, - self.sliced_fd_data, - self.sliced_psd, - df + masked_signal, self.sliced_fd_data, self.sliced_psd, df ) match_filtered_snr /= optimal_snr diff --git a/src/jimgw/core/single_event/likelihood.py b/src/jimgw/core/single_event/likelihood.py index d97466175..2b4a7bc06 100644 --- a/src/jimgw/core/single_event/likelihood.py +++ b/src/jimgw/core/single_event/likelihood.py @@ -2,7 +2,7 @@ import jax.numpy as jnp from flowMC.strategy.optimization import AdamOptimization from jax.scipy.special import logsumexp -from jaxtyping import Array, Float, Complex +from jaxtyping import Array, Float from typing import Optional from scipy.interpolate import interp1d from jimgw.core.utils import log_i0 @@ -17,15 +17,45 @@ ) import logging from typing import Sequence +from abc import abstractmethod class SingleEventLikelihood(LikelihoodBase): detectors: Sequence[Detector] waveform: Waveform + fixed_parameters: dict[str, Float] = {} - def __init__(self, detectors: Sequence[Detector], waveform: Waveform) -> None: + @property + def duration(self) -> Float: + return self.detectors[0].data.duration + + @property + def detector_names(self): + """The interferometers for the likelihood.""" + return [detector.name for detector in self.detectors] + + def __init__( + self, + detectors: Sequence[Detector], + waveform: Waveform, + fixed_parameters: Optional[dict[str, Float]] = None, + ) -> None: self.detectors = detectors self.waveform = waveform + self.fixed_parameters = fixed_parameters if fixed_parameters is not None else {} + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + """Evaluate the likelihood for a given set of parameters. + + This is a template method that calls the core likelihood evaluation method + """ + params.update(self.fixed_parameters) + return self._likelihood(params, data) + + @abstractmethod + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + """Core likelihood evaluation method to be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement this method.") class ZeroLikelihood(LikelihoodBase): @@ -33,23 +63,54 @@ def __init__(self): pass def evaluate(self, params: dict[str, Float], data: dict) -> Float: + """Evaluate the likelihood, which is always zero.""" return 0.0 -class TransientLikelihoodFD(SingleEventLikelihood): +class BaseTransientLikelihoodFD(SingleEventLikelihood): + """Base class for frequency-domain transient gravitational wave likelihood. + + This class provides the basic likelihood evaluation for gravitational wave transient events + in the frequency domain, using matched filtering across multiple detectors. + + Attributes: + frequencies (Float[Array]): The frequency array used for likelihood evaluation. + trigger_time (Float): The GPS time of the event trigger. + gmst (Float): Greenwich Mean Sidereal Time computed from the trigger time. + + Args: + detectors (Sequence[Detector]): List of detector objects containing data and metadata. + waveform (Waveform): Waveform model to evaluate. + f_min (Float, optional): Minimum frequency for likelihood evaluation. Defaults to 0. + f_max (Float, optional): Maximum frequency for likelihood evaluation. Defaults to infinity. + trigger_time (Float, optional): GPS time of the event trigger. Defaults to 0. + + Example: + >>> likelihood = BaseTransientLikelihoodFD(detectors, waveform, f_min=20, f_max=1024, trigger_time=1234567890) + >>> logL = likelihood.evaluate(params, data) + """ + def __init__( self, detectors: Sequence[Detector], waveform: Waveform, + fixed_parameters: Optional[dict[str, Float]] = None, f_min: Float = 0, f_max: Float = float("inf"), trigger_time: Float = 0, - **kwargs, ) -> None: - # NOTE: having 'kwargs' here makes it very difficult to diagnose - # errors and keep track of what's going on, would be better to list - # explicitly what the arguments are accepted + """Initializes the BaseTransientLikelihoodFD class. + Sets up the frequency bounds for the detectors and computes the Greenwich Mean Sidereal Time. + + Args: + detectors (Sequence[Detector]): List of detector objects. + waveform (Waveform): Waveform model. + f_min (Float, optional): Minimum frequency. Defaults to 0. + f_max (Float, optional): Maximum frequency. Defaults to infinity. + trigger_time (Float, optional): Event trigger time. Defaults to 0. + """ + super().__init__(detectors, waveform, fixed_parameters) # Set the frequency bounds for the detectors _frequencies = [] for detector in detectors: @@ -59,100 +120,273 @@ def __init__( assert jnp.all( jnp.array(_frequencies)[:-1] == jnp.array(_frequencies)[1:] ), "The frequency arrays are not all the same." - - self.detectors = detectors self.frequencies = _frequencies[0] - self.duration = self.detectors[0].data.duration - self.waveform = waveform self.trigger_time = trigger_time self.gmst = compute_gmst(self.trigger_time) - self.kwargs = kwargs - if "marginalization" in self.kwargs: - marginalization = self.kwargs["marginalization"] - assert marginalization in [ - "phase", - "phase-time", - "time", - ], "Only support time, phase and phase+time marginalzation" - self.marginalization = marginalization - if self.marginalization == "phase-time": - self.param_func = lambda x: {**x, "phase_c": 0.0, "t_c": 0.0} - self.likelihood_function = phase_time_marginalized_likelihood - logging.info("Marginalizing over phase and time") - elif self.marginalization == "time": - self.param_func = lambda x: {**x, "t_c": 0.0} - self.likelihood_function = time_marginalized_likelihood - logging.info("Marginalizing over time") - elif self.marginalization == "phase": - self.param_func = lambda x: {**x, "phase_c": 0.0} - self.likelihood_function = phase_marginalized_likelihood - logging.info("Marginalizing over phase") - if "time" in self.marginalization: - fs = self.detectors[0].data.sampling_frequency - duration = self.detectors[0].data.duration - self.kwargs["tc_array"] = jnp.fft.fftfreq( - int(duration * fs / 2), 1.0 / duration - ) - self.kwargs["pad_low"] = jnp.zeros(int(self.frequencies[0] * duration)) - if jnp.isclose(self.frequencies[-1], fs / 2.0 - 1.0 / duration): - self.kwargs["pad_high"] = jnp.array([]) - else: - self.kwargs["pad_high"] = jnp.zeros( - int( - (fs / 2.0 - 1.0 / duration - self.frequencies[-1]) - * duration - ) - ) - print() - else: - self.param_func = lambda x: x - self.likelihood_function = original_likelihood - self.marginalization = "" - - # the fixing_parameters is expected to be a dictionary - # with key as parameter name and value is the fixed value - # e.g. {'M_c': 1.1975, 't_c': 0} - if "fixing_parameters" in self.kwargs: - fixing_parameters = self.kwargs["fixing_parameters"] - print(f"Parameters are fixed {fixing_parameters}") - # check for conflict with the marginalization - assert not ( - "t_c" in fixing_parameters and "time" in self.marginalization - ), "Cannot have t_c fixed while having the marginalization of t_c turned on" - assert not ( - "phase_c" in fixing_parameters and "phase" in self.marginalization - ), "Cannot have phase_c fixed while having the marginalization of phase_c turned on" - # if the same key exists in both dictionary, - # the later one will overwrite the former one - self.fixing_func = lambda x: {**x, **fixing_parameters} + + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + """Evaluate the log-likelihood for a given set of parameters. + + Computes the log-likelihood by matched filtering the model waveform against the data + for each detector, using the frequency-domain inner product. + + Args: + params (dict[str, Float]): Dictionary of model parameters. + data (dict): Dictionary containing data (not used in this implementation). + + Returns: + Float: The log-likelihood value. + """ + params.update(self.fixed_parameters) + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + log_likelihood = self._likelihood(params, data) + return log_likelihood + + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + """Core likelihood evaluation method for frequency-domain transient events.""" + waveform_sky = self.waveform(self.frequencies, params) + log_likelihood = 0.0 + df = ( + self.detectors[0].sliced_frequencies[1] + - self.detectors[0].sliced_frequencies[0] + ) + for ifo in self.detectors: + freqs, ifo_data, psd = ( + ifo.sliced_frequencies, + ifo.sliced_fd_data, + ifo.sliced_psd, + ) + h_dec = ifo.fd_response(freqs, waveform_sky, params) + match_filter_SNR = inner_product(h_dec, ifo_data, psd, df) + optimal_SNR = inner_product(h_dec, h_dec, psd, df) + log_likelihood += match_filter_SNR - optimal_SNR / 2 + return log_likelihood + + +class TimeMarginalizedLikelihoodFD(BaseTransientLikelihoodFD): + """Frequency-domain likelihood class with analytic marginalization over coalescence time. + + This class implements a likelihood function for gravitational wave transient events, + marginalized over the coalescence time parameter (`t_c`). The marginalization is performed + using a fast Fourier transform (FFT) over the frequency domain inner product between the + model and the data. The likelihood is computed for a set of detectors and a waveform model. + + Attributes: + tc_range (tuple[Float, Float]): The range of coalescence times to marginalize over. + tc_array (Float[Array, "duration*f_sample/2"]): Array of time shifts corresponding to FFT bins. + pad_low (Float[Array, "n_pad_low"]): Zero-padding array for frequencies below the minimum frequency. + pad_high (Float[Array, "n_pad_high"]): Zero-padding array for frequencies above the maximum frequency. + + Args: + detectors (Sequence[Detector]): List of detector objects containing data and metadata. + waveform (Waveform): Waveform model to evaluate. + f_min (Float, optional): Minimum frequency for likelihood evaluation. Defaults to 0. + f_max (Float, optional): Maximum frequency for likelihood evaluation. Defaults to infinity. + trigger_time (Float, optional): GPS time of the event trigger. Defaults to 0. + tc_range (tuple[Float, Float], optional): Range of coalescence times to marginalize over. Defaults to (-0.12, 0.12). + + Example: + >>> likelihood = TimeMarginalizedLikelihoodFD(detectors, waveform, f_min=20, f_max=1024, trigger_time=1234567890) + >>> logL = likelihood.evaluate(params, data) + """ + + tc_range: tuple[Float, Float] + tc_array: Float[Array, " duration*f_sample/2"] + pad_low: Float[Array, " n_pad_low"] + pad_high: Float[Array, " n_pad_high"] + + def __init__( + self, + detectors: Sequence[Detector], + waveform: Waveform, + fixed_parameters: Optional[dict[str, Float]] = None, + f_min: Float = 0, + f_max: Float = float("inf"), + trigger_time: Float = 0, + tc_range: tuple[Float, Float] = (-0.12, 0.12), + ) -> None: + """Initializes the TimeMarginalizedLikelihoodFD class. + + Sets up the frequency bounds, coalescence time range, FFT time array, and zero-padding + arrays for the likelihood calculation. + + Args: + detectors (Sequence[Detector]): List of detector objects. + waveform (Waveform): Waveform model. + f_min (Float, optional): Minimum frequency. Defaults to 0. + f_max (Float, optional): Maximum frequency. Defaults to infinity. + trigger_time (Float, optional): Event trigger time. Defaults to 0. + tc_range (tuple[Float, Float], optional): Marginalization range for coalescence time. Defaults to (-0.12, 0.12). + """ + super().__init__( + detectors, waveform, fixed_parameters, f_min, f_max, trigger_time + ) + assert ( + "t_c" not in self.fixed_parameters + ), "Cannot have t_c fixed while marginalizing over t_c" + self.tc_range = tc_range + fs = self.detectors[0].data.sampling_frequency + duration = self.detectors[0].data.duration + self.tc_array = jnp.fft.fftfreq(int(duration * fs / 2), 1.0 / duration) + self.pad_low = jnp.zeros(int(self.frequencies[0] * duration)) + if jnp.isclose(self.frequencies[-1], fs / 2.0 - 1.0 / duration): + self.pad_high = jnp.array([]) else: - self.fixing_func = lambda x: x + self.pad_high = jnp.zeros( + int((fs / 2.0 - 1.0 / duration - self.frequencies[-1]) * duration) + ) - @property - def detector_names(self): - """The interferometers for the likelihood.""" - return [detector.name for detector in self.detectors] + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + params.update(self.fixed_parameters) + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + params["t_c"] = 0.0 # Fixing t_c to 0 for time marginalization + log_likelihood = self._likelihood(params, data) + return log_likelihood + + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + """Evaluate the time-marginalized likelihood for a given set of parameters. + Computes the log-likelihood marginalized over coalescence time by: + - Calculating the frequency-domain inner product between the model and data for each detector. + - Padding the inner product array to cover the full frequency range. + - Applying FFT to obtain the likelihood as a function of coalescence time. + - Restricting the FFT output to the specified `tc_range`. + - Marginalizing using logsumexp over the allowed coalescence times. + Args: + params (dict[str, Float]): Dictionary of model parameters. + data (dict): Dictionary containing data (not used in this implementation). + Returns: + Float: The marginalized log-likelihood value. + """ + + log_likelihood = 0.0 + complex_h_inner_d = jnp.zeros_like(self.detectors[0].sliced_frequencies) + df = ( + self.detectors[0].sliced_frequencies[1] + - self.detectors[0].sliced_frequencies[0] + ) + waveform_sky = self.waveform(self.frequencies, params) + for ifo in self.detectors: + freqs, ifo_data, psd = ( + ifo.sliced_frequencies, + ifo.sliced_fd_data, + ifo.sliced_psd, + ) + h_dec = ifo.fd_response(freqs, waveform_sky, params) + # using instead of + complex_h_inner_d += 4 * h_dec * jnp.conj(ifo_data) / psd * df + optimal_SNR = inner_product(h_dec, h_dec, psd, df) + log_likelihood += -optimal_SNR / 2 + + # Padding the complex_h_inner_d to cover the full frequency range + complex_h_inner_d_positive_f = jnp.concatenate( + (self.pad_low, complex_h_inner_d, self.pad_high) + ) + + # FFT to obtain exp(-i2πf t_c) as a function of t_c + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") + + # Restrict FFT output to the allowed tc_range, set others to -inf + fft_h_inner_d = jnp.where( + (self.tc_array > self.tc_range[0]) & (self.tc_array < self.tc_range[1]), + fft_h_inner_d.real, + jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, + ) + + # Marginalize over t_c using logsumexp + log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(self.tc_array)) + return log_likelihood + + +class PhaseMarginalizedLikelihoodFD(BaseTransientLikelihoodFD): + """This has not been tested by a human yet.""" def evaluate(self, params: dict[str, Float], data: dict) -> Float: - # TODO: Test whether we need to pass data in or with class changes is fine. - """Evaluate the likelihood for a given set of parameters.""" + params.update(self.fixed_parameters) + params["phase_c"] = 0.0 # Fixing phase_c to 0 for phase marginalization params["trigger_time"] = self.trigger_time params["gmst"] = self.gmst - # adjust the params due to different marginalzation scheme - params = self.param_func(params) - # adjust the params due to fixing parameters - params = self.fixing_func(params) - # evaluate the waveform as usual + log_likelihood = self._likelihood(params, data) + return log_likelihood + + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + log_likelihood = 0.0 + complex_d_inner_h = 0.0 + 0.0j + waveform_sky = self.waveform(self.frequencies, params) - return self.likelihood_function( - params, - waveform_sky, - self.detectors, # type: ignore - **self.kwargs, + df = ( + self.detectors[0].sliced_frequencies[1] + - self.detectors[0].sliced_frequencies[0] ) + for ifo in self.detectors: + freqs, ifo_data, psd = ( + ifo.sliced_frequencies, + ifo.sliced_fd_data, + ifo.sliced_psd, + ) + h_dec = ifo.fd_response(freqs, waveform_sky, params) + complex_d_inner_h += complex_inner_product(h_dec, ifo_data, psd, df) + optimal_SNR = inner_product(h_dec, h_dec, psd, df) + log_likelihood += -optimal_SNR / 2 + + log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) + return log_likelihood + +class PhaseTimeMarginalizedLikelihoodFD(TimeMarginalizedLikelihoodFD): + """This has not been tested by a human yet.""" -class HeterodynedTransientLikelihoodFD(TransientLikelihoodFD): + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + params.update(self.fixed_parameters) + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + params["t_c"] = 0.0 # Fix t_c for marginalization + params["phase_c"] = 0.0 + return self._likelihood(params, data) + + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + # Refactored: use self.detectors, self.frequencies, self.tc_array, self.pad_low, self.pad_high, self.tc_range + log_likelihood = 0.0 + complex_h_inner_d = 0.0 + 0.0j + + df = ( + self.detectors[0].sliced_frequencies[1] + - self.detectors[0].sliced_frequencies[0] + ) + waveform_sky = self.waveform(self.frequencies, params) + for ifo in self.detectors: + freqs, ifo_data, psd = ( + ifo.sliced_frequencies, + ifo.sliced_fd_data, + ifo.sliced_psd, + ) + h_dec = ifo.fd_response(freqs, waveform_sky, params) + complex_h_inner_d += complex_inner_product(h_dec, ifo_data, psd, df) + optimal_SNR = inner_product(h_dec, h_dec, psd, df) + log_likelihood += -optimal_SNR / 2 + + # Pad the complex_h_inner_d to cover the full frequency range + complex_h_inner_d_positive_f = jnp.concatenate( + (self.pad_low, complex_h_inner_d, self.pad_high) + ) + + # FFT to obtain exp(-i2πf t_c) as a function of t_c + fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") + + # Restrict FFT output to the allowed tc_range, set others to -inf + log_i0_abs_fft = jnp.where( + (self.tc_array > self.tc_range[0]) & (self.tc_array < self.tc_range[1]), + log_i0(jnp.absolute(fft_h_inner_d)), + jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, + ) + + # Marginalize over t_c using logsumexp + log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(self.tc_array)) + return log_likelihood + + +class HeterodynedTransientLikelihoodFD(BaseTransientLikelihoodFD): n_bins: int # Number of bins to use for the likelihood ref_params: dict # Reference parameters for the likelihood freq_grid_low: Array # Heterodyned frequency grid @@ -180,10 +414,11 @@ def __init__( self, detectors: Sequence[Detector], waveform: Waveform, + fixed_parameters: Optional[dict[str, Float]] = None, f_min: Float = 0, f_max: Float = float("inf"), - n_bins: int = 100, trigger_time: float = 0, + n_bins: int = 100, popsize: int = 100, n_steps: int = 2000, ref_params: dict = {}, @@ -191,9 +426,11 @@ def __init__( prior: Optional[Prior] = None, sample_transforms: list[BijectiveTransform] = [], likelihood_transforms: list[NtoMTransform] = [], - **kwargs, - ) -> None: - super().__init__(detectors, waveform, f_min, f_max, trigger_time) + ): + + super().__init__( + detectors, waveform, fixed_parameters, f_min, f_max, trigger_time + ) logging.info("Initializing heterodyned likelihood..") @@ -201,45 +438,6 @@ def __init__( if reference_waveform is None: reference_waveform = waveform - self.kwargs = kwargs - if "marginalization" in self.kwargs: - marginalization = self.kwargs["marginalization"] - assert marginalization in [ - "phase", - ], "Heterodyned likelihood only support phase marginalzation" - self.marginalization = marginalization - if self.marginalization == "phase": - self.param_func = lambda x: {**x, "phase_c": 0.0} - self.likelihood_function = phase_marginalized_likelihood - self.rb_likelihood_function = ( - phase_marginalized_relative_binning_likelihood - ) - logging.info("Marginalizing over phase") - else: - self.param_func = lambda x: x - self.likelihood_function = original_likelihood - self.rb_likelihood_function = original_relative_binning_likelihood - self.marginalization = "" - - # the fixing_parameters is expected to be a dictionary - # with key as parameter name and value is the fixed value - # e.g. {'M_c': 1.1975, 't_c': 0} - if "fixing_parameters" in self.kwargs: - fixing_parameters = self.kwargs["fixing_parameters"] - logging.info(f"Parameters are fixed {fixing_parameters}") - # check for conflict with the marginalization - assert not ( - "t_c" in fixing_parameters and "time" in self.marginalization - ), "Cannot have t_c fixed while marginalizing over t_c" - assert not ( - "phase_c" in fixing_parameters and "phase" in self.marginalization - ), "Cannot have phase_c fixed while marginalizing over phase_c" - # if the same key exists in both dictionary, - # the later one will overwrite the former one - self.fixing_func = lambda x: {**x, **fixing_parameters} - else: - self.fixing_func = lambda x: x - # Get the original frequency grid frequency_original = self.frequencies # Get the grid of the relative binning scheme (contains the final endpoint) @@ -278,10 +476,6 @@ def __init__( self.ref_params["trigger_time"] = self.trigger_time self.ref_params["gmst"] = self.gmst - # adjust the params due to different marginalzation scheme - self.ref_params = self.param_func(self.ref_params) - # adjust the params due to fixing parameters - self.ref_params = self.fixing_func(self.ref_params) self.waveform_low_ref = {} self.waveform_center_ref = {} @@ -345,60 +539,45 @@ def __init__( self.B1_array[detector.name] = B1[mask_heterodyne_center] def evaluate(self, params: dict[str, Float], data: dict) -> Float: - frequencies_low = self.freq_grid_low - frequencies_center = self.freq_grid_center params["trigger_time"] = self.trigger_time params["gmst"] = self.gmst - # adjust the params due to different marginalzation scheme - params = self.param_func(params) - # adjust the params due to fixing parameters - params = self.fixing_func(params) + params.update(self.fixed_parameters) # evaluate the waveforms as usual + return self._likelihood(params, data) + + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + frequencies_low = self.freq_grid_low + frequencies_center = self.freq_grid_center + log_likelihood = 0.0 waveform_sky_low = self.waveform(frequencies_low, params) waveform_sky_center = self.waveform(frequencies_center, params) - log_likelihood = self.rb_likelihood_function( - params, - self.A0_array, - self.A1_array, - self.B0_array, - self.B1_array, - waveform_sky_low, - waveform_sky_center, - self.waveform_low_ref, - self.waveform_center_ref, - self.detectors, - frequencies_low, - frequencies_center, - **self.kwargs, - ) - return log_likelihood + for detector in self.detectors: + waveform_low = detector.fd_response( + frequencies_low, waveform_sky_low, params + ) + waveform_center = detector.fd_response( + frequencies_low, waveform_sky_center, params + ) - def evaluate_original( - self, params: dict[str, Float], data: dict - ) -> ( - Float - ): # TODO: Test whether we need to pass data in or with class changes is fine. - """ - Evaluate the likelihood for a given set of parameters. - """ - params["trigger_time"] = self.trigger_time - params["gmst"] = self.gmst - # adjust the params due to different marginalzation scheme - params = self.param_func(params) - # adjust the params due to fixing parameters - params = self.fixing_func(params) - # evaluate the waveform as usual - waveform_sky = self.waveform(self.frequencies, params) - return self.likelihood_function( - params, - waveform_sky, - self.detectors, # type: ignore - **self.kwargs, - ) + r0 = waveform_center / self.waveform_center_ref[detector.name] + r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( + frequencies_low - frequencies_center + ) + match_filter_SNR = jnp.sum( + self.A0_array[detector.name] * r0.conj() + + self.A1_array[detector.name] * r1.conj() + ) + optimal_SNR = jnp.sum( + self.B0_array[detector.name] * jnp.abs(r0) ** 2 + + 2 * self.B1_array[detector.name] * (r0 * r1.conj()).real + ) + log_likelihood += (match_filter_SNR - optimal_SNR / 2).real + + return log_likelihood @staticmethod def max_phase_diff( - f: Float[Array, " n_freq"], + freqs: Float[Array, " n_freq"], f_low: float, f_high: float, chi: float = 1.0, @@ -406,9 +585,11 @@ def max_phase_diff( """ Compute the maximum phase difference between the frequencies in the array. + See Eq.(7) in arXiv:2302.05333. + Parameters ---------- - f: Float[Array, "n_dims"] + freqs: Float[Array, "n_freq"] Array of frequencies to be binned. f_low: float Lower frequency bound. @@ -419,18 +600,15 @@ def max_phase_diff( Returns ------- - Float[Array, "n_dims"] + Float[Array, "n_freq"] Maximum phase difference between the frequencies in the array. """ gamma = jnp.arange(-5, 6) / 3.0 - f_2D = jnp.broadcast_to(f, (f.size, gamma.size)) + # Promotes freqs to 2D with shape (n_freq, 10) for later f/f_star + freq_2D = jax.lax.broadcast_in_dim(freqs, (freqs.size, gamma.size), [0]) f_star = jnp.where(gamma >= 0, f_high, f_low) - return ( - 2 - * jnp.pi - * chi - * jnp.sum((f_2D / f_star) ** gamma * jnp.sign(gamma), axis=1) - ) + summand = (freq_2D / f_star) ** gamma * jnp.sign(gamma) + return 2 * jnp.pi * chi * jnp.sum(summand, axis=1) def make_binning_scheme( self, freqs: Float[Array, " n_freq"], n_bins: int, chi: float = 1 @@ -504,7 +682,9 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: named_params = transform.backward(named_params) for transform in likelihood_transforms: named_params = transform.forward(named_params) - return -self.evaluate_original(named_params, data) + return -super(HeterodynedTransientLikelihoodFD, self).evaluate( + named_params, data + ) print("Starting the optimizer") @@ -552,240 +732,55 @@ def y(x: Float[Array, " n_dims"], data: dict) -> Float: return named_params -likelihood_presets = { - "TransientLikelihoodFD": TransientLikelihoodFD, - "HeterodynedTransientLikelihoodFD": HeterodynedTransientLikelihoodFD, -} +class HeterodynedPhaseMarginalizedLikelihoodFD(HeterodynedTransientLikelihoodFD): + def evaluate(self, params: dict[str, Float], data: dict) -> Float: + params.update(self.fixed_parameters) + params["phase_c"] = 0.0 + params["trigger_time"] = self.trigger_time + params["gmst"] = self.gmst + log_likelihood = self._likelihood(params, data) + return log_likelihood -def original_likelihood( - params: dict[str, Float], - h_sky: dict[str, Complex[Array, " n_dim"]], - detectors: list[Detector], - **kwargs, -) -> Float: - log_likelihood = 0.0 - df = detectors[0].sliced_frequencies[1] - detectors[0].sliced_frequencies[0] - for ifo in detectors: - freqs, data, psd = ifo.sliced_frequencies, ifo.sliced_fd_data, ifo.sliced_psd - h_dec = ifo.fd_response(freqs, h_sky, params) - match_filter_SNR = inner_product(h_dec, data, psd, df) - optimal_SNR = inner_product(h_dec, h_dec, psd, df) - log_likelihood += match_filter_SNR - optimal_SNR / 2 - - return log_likelihood - - -def phase_marginalized_likelihood( - params: dict[str, Float], - h_sky: dict[str, Complex[Array, " n_dim"]], - detectors: list[Detector], - **kwargs, -) -> Float: - log_likelihood = 0.0 - complex_d_inner_h = 0.0 + 0.0j - df = detectors[0].sliced_frequencies[1] - detectors[0].sliced_frequencies[0] - for ifo in detectors: - freqs, data, psd = ifo.sliced_frequencies, ifo.sliced_fd_data, ifo.sliced_psd - h_dec = ifo.fd_response(freqs, h_sky, params) - complex_d_inner_h += complex_inner_product(h_dec, data, psd, df) - optimal_SNR = inner_product(h_dec, h_dec, psd, df) - log_likelihood += -optimal_SNR / 2 - - log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) - return log_likelihood - - -def _get_tc_array(duration: Float, sampling_rate: Float): - return jnp.fft.fftfreq(int(duration * sampling_rate / 2), 1 / duration) - - -def _get_frequencies_pads(detector: Detector, fs: Float) -> tuple[Float, Float]: - f_low, f_high = detector.frequency_bounds - duration = detector.data.duration - delta_f = 1 / duration - - pad_low = jnp.zeros(int(f_low * duration)) - - f_Nyquist_diff = fs / 2.0 - delta_f - f_high - if jnp.isclose(f_Nyquist_diff, 0): - pad_high = jnp.array([]) - else: - pad_high = jnp.zeros(int(f_Nyquist_diff * duration)) - return pad_low, pad_high - - -def time_marginalized_likelihood( - params: dict[str, Float], - h_sky: dict[str, Complex[Array, " n_dim"]], - detectors: list[Detector], - **kwargs, -) -> Float: - log_likelihood = 0.0 - complex_h_inner_d = jnp.zeros_like(detectors[0].sliced_frequencies) - df = detectors[0].sliced_frequencies[1] - detectors[0].sliced_frequencies[0] - for ifo in detectors: - freqs, data, psd = ifo.sliced_frequencies, ifo.sliced_fd_data, ifo.sliced_psd - h_dec = ifo.fd_response(freqs, h_sky, params) - # using instead of - complex_h_inner_d += 4 * h_dec * jnp.conj(data) / psd * df - optimal_SNR = inner_product(h_dec, h_dec, psd, df) - log_likelihood += -optimal_SNR / 2 - - # fetch the tc range tc_array, lower padding and higher padding - tc_range = [-0.12, 0.12] # TODO: This is hard coded right now, need to update. - tc_array = kwargs["tc_array"] - pad_low = kwargs["pad_low"] - pad_high = kwargs["pad_high"] - - # padding the complex_h_inner_d - # this array is the hd*/S for f in [0, fs / 2 - df] - complex_h_inner_d_positive_f = jnp.concatenate( - (pad_low, complex_h_inner_d, pad_high) - ) - - # make use of the fft - # which then return the exp(-i2pift_c) - # w.r.t. the tc_array - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - - # set the values to -inf when it is outside the tc range - # so that they will disappear after the logsumexp - fft_h_inner_d = jnp.where( - (tc_array > tc_range[0]) & (tc_array < tc_range[1]), - fft_h_inner_d.real, - jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, - ) - - # using the logsumexp to marginalize over the tc prior range - log_likelihood += logsumexp(fft_h_inner_d) - jnp.log(len(tc_array)) - return log_likelihood - - -def phase_time_marginalized_likelihood( - params: dict[str, Float], - h_sky: dict[str, Complex[Array, " n_dim"]], - detectors: list[Detector], - **kwargs, -) -> Float: - log_likelihood = 0.0 - complex_h_inner_d = 0.0 + 0.0j - df = detectors[0].sliced_frequencies[1] - detectors[0].sliced_frequencies[0] - for ifo in detectors: - freqs, data, psd = ifo.sliced_frequencies, ifo.sliced_fd_data, ifo.sliced_psd - h_dec = ifo.fd_response(freqs, h_sky, params) - # using instead of - complex_h_inner_d += complex_inner_product(data, h_dec, psd, df) - optimal_SNR = inner_product(h_dec, h_dec, psd, df) - log_likelihood += -optimal_SNR / 2 - duration = detectors[0].data.duration - - # fetch the tc range tc_array, lower padding and higher padding - tc_range = kwargs["tc_range"] - fs = kwargs["sampling_rate"] - tc_array = _get_tc_array(duration, fs) - pad_low, pad_high = _get_frequencies_pads(detectors[0], fs=fs) - - # padding the complex_h_inner_d - # this array is the hd*/S for f in [0, fs / 2 - df] - complex_h_inner_d_positive_f = jnp.concatenate( - (pad_low, complex_h_inner_d, pad_high) - ) - - # make use of the fft - # which then return the exp(-i2pift_c) - # w.r.t. the tc_array - fft_h_inner_d = jnp.fft.fft(complex_h_inner_d_positive_f, norm="backward") - - # set the values to -inf when it is outside the tc range - # so that they will disappear after the logsumexp - log_i0_abs_fft = jnp.where( - (tc_array > tc_range[0]) & (tc_array < tc_range[1]), - log_i0(jnp.absolute(fft_h_inner_d)), - jnp.zeros_like(fft_h_inner_d.real) - jnp.inf, - ) - - # using the logsumexp to marginalize over the tc prior range - log_likelihood += logsumexp(log_i0_abs_fft) - jnp.log(len(tc_array)) - return log_likelihood - - -def original_relative_binning_likelihood( - params, - A0_array, - A1_array, - B0_array, - B1_array, - waveform_sky_low, - waveform_sky_center, - waveform_low_ref, - waveform_center_ref, - detectors, - frequencies_low, - frequencies_center, - **kwargs, -): - log_likelihood = 0.0 - - for detector in detectors: - waveform_low = detector.fd_response(frequencies_low, waveform_sky_low, params) - waveform_center = detector.fd_response( - frequencies_low, waveform_sky_center, params - ) + def _likelihood(self, params: dict[str, Float], data: dict) -> Float: + frequencies_low = self.freq_grid_low + frequencies_center = self.freq_grid_center + waveform_sky_low = self.waveform(frequencies_low, params) + waveform_sky_center = self.waveform(frequencies_center, params) + log_likelihood = 0.0 + complex_d_inner_h = 0.0 - r0 = waveform_center / waveform_center_ref[detector.name] - r1 = (waveform_low / waveform_low_ref[detector.name] - r0) / ( - frequencies_low - frequencies_center - ) - match_filter_SNR = jnp.sum( - A0_array[detector.name] * r0.conj() + A1_array[detector.name] * r1.conj() - ) - optimal_SNR = jnp.sum( - B0_array[detector.name] * jnp.abs(r0) ** 2 - + 2 * B1_array[detector.name] * (r0 * r1.conj()).real - ) - log_likelihood += (match_filter_SNR - optimal_SNR / 2).real - - return log_likelihood - - -def phase_marginalized_relative_binning_likelihood( - params, - A0_array, - A1_array, - B0_array, - B1_array, - waveform_sky_low, - waveform_sky_center, - waveform_low_ref, - waveform_center_ref, - detectors, - frequencies_low, - frequencies_center, - **kwargs, -): - log_likelihood = 0.0 - complex_d_inner_h = 0.0 - - for detector in detectors: - waveform_low = detector.fd_response(frequencies_low, waveform_sky_low, params) - waveform_center = detector.fd_response( - frequencies_center, waveform_sky_center, params - ) + for detector in self.detectors: + waveform_low = detector.fd_response( + frequencies_low, waveform_sky_low, params + ) + waveform_center = detector.fd_response( + frequencies_center, waveform_sky_center, params + ) + r0 = waveform_center / self.waveform_center_ref[detector.name] + r1 = (waveform_low / self.waveform_low_ref[detector.name] - r0) / ( + frequencies_low - frequencies_center + ) + complex_d_inner_h += jnp.sum( + self.A0_array[detector.name] * r0.conj() + + self.A1_array[detector.name] * r1.conj() + ) + optimal_SNR = jnp.sum( + self.B0_array[detector.name] * jnp.abs(r0) ** 2 + + 2 * self.B1_array[detector.name] * (r0 * r1.conj()).real + ) + log_likelihood += -optimal_SNR.real / 2 - r0 = waveform_center / waveform_center_ref[detector.name] - r1 = (waveform_low / waveform_low_ref[detector.name] - r0) / ( - frequencies_low - frequencies_center - ) - complex_d_inner_h += jnp.sum( - A0_array[detector.name] * r0.conj() + A1_array[detector.name] * r1.conj() - ) - optimal_SNR = jnp.sum( - B0_array[detector.name] * jnp.abs(r0) ** 2 - + 2 * B1_array[detector.name] * (r0 * r1.conj()).real - ) - log_likelihood += -optimal_SNR.real / 2 + log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) + + return log_likelihood - log_likelihood += log_i0(jnp.absolute(complex_d_inner_h)) - return log_likelihood +likelihood_presets = { + "BaseTransientLikelihoodFD": BaseTransientLikelihoodFD, + "TimeMarginalizedLikelihoodFD": TimeMarginalizedLikelihoodFD, + "PhaseMarginalizedLikelihoodFD": PhaseMarginalizedLikelihoodFD, + "PhaseTimeMarginalizedLikelihoodFD": PhaseTimeMarginalizedLikelihoodFD, + "HeterodynedTransientLikelihoodFD": HeterodynedTransientLikelihoodFD, + "PhaseMarginalizedHeterodynedLikelihoodFD": HeterodynedPhaseMarginalizedLikelihoodFD, +} diff --git a/src/jimgw/core/single_event/utils.py b/src/jimgw/core/single_event/utils.py index 2b7be90de..6975afbc7 100644 --- a/src/jimgw/core/single_event/utils.py +++ b/src/jimgw/core/single_event/utils.py @@ -1,6 +1,5 @@ import jax.numpy as jnp from jaxtyping import Array, Float, Complex -from typing import Optional from jimgw.core.constants import MTSUN from jimgw.core.utils import safe_arctan2, carte_to_spherical_angles diff --git a/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py b/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py index fdf2fef55..81498e405 100644 --- a/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py +++ b/src/jimgw/run/library/IMRPhenomPv2_standard_cbc.py @@ -13,7 +13,7 @@ from jimgw.core.single_event.data import Data, PowerSpectrum from jimgw.core.single_event.detector import get_detector_preset -from jimgw.core.single_event.likelihood import TransientLikelihoodFD, ZeroLikelihood +from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD, ZeroLikelihood from jimgw.core.single_event.waveform import RippleIMRPhenomPv2 from jimgw.core.transforms import BoundToUnbound, BijectiveTransform, NtoMTransform from jimgw.core.single_event.transforms import ( @@ -22,6 +22,7 @@ MassRatioToSymmetricMassRatioTransform, DistanceToSNRWeightedDistanceTransform, GeocentricArrivalPhaseToDetectorArrivalPhaseTransform, + GeocentricArrivalTimeToDetectorArrivalTimeTransform, ) from typing import Optional, Sequence, Self @@ -54,7 +55,7 @@ def __init__( max_s2: float, iota_range: tuple[float, float], dL_range: tuple[float, float], - # t_c_range: tuple[float, float], + t_c_range: tuple[float, float], phase_c_range: tuple[float, float], psi_range: tuple[float, float], ra_range: tuple[float, float], @@ -69,7 +70,7 @@ def __init__( self.max_s2 = max_s2 self.iota_range = iota_range self.dL_range = dL_range - # self.t_c_range = t_c_range + self.t_c_range = t_c_range self.phase_c_range = phase_c_range self.psi_range = psi_range self.ra_range = ra_range @@ -85,7 +86,7 @@ def initialize_jim_objects(self): def initialize_likelihood( self, local_data_prefix: Optional[str] = None - ) -> TransientLikelihoodFD: + ) -> BaseTransientLikelihoodFD: logging.info("Initializing likelihood...") gps = self.gps @@ -123,13 +124,12 @@ def initialize_likelihood( waveform = RippleIMRPhenomPv2(f_ref=self.f_ref) - likelihood = TransientLikelihoodFD( + likelihood = BaseTransientLikelihoodFD( detectors=self.ifos, waveform=waveform, trigger_time=gps, f_min=self.f_min, f_max=self.f_max, - marginalization="time", ) return likelihood @@ -152,9 +152,9 @@ def initialize_prior(self) -> CombinePrior: 2.0, parameter_names=["d_L"], ) - # t_c_prior = UniformPrior( - # self.t_c_range[0], self.t_c_range[1], parameter_names=["t_c"] - # ) + t_c_prior = UniformPrior( + self.t_c_range[0], self.t_c_range[1], parameter_names=["t_c"] + ) phase_c_prior = UniformPrior( self.phase_c_range[0], self.phase_c_range[1], parameter_names=["phase_c"] ) @@ -173,7 +173,7 @@ def initialize_prior(self) -> CombinePrior: s2_prior, iota_prior, dL_prior, - # t_c_prior, + t_c_prior, phase_c_prior, psi_prior, ra_prior, @@ -202,12 +202,12 @@ def initialize_sample_transforms(self) -> Sequence[BijectiveTransform]: GeocentricArrivalPhaseToDetectorArrivalPhaseTransform( gps_time=self.gps, ifo=self.ifos[0] ), - # GeocentricArrivalTimeToDetectorArrivalTimeTransform( - # tc_min=self.t_c_range[0], - # tc_max=self.t_c_range[1], - # gps_time=self.gps, - # ifo=self.ifos[0], - # ), + GeocentricArrivalTimeToDetectorArrivalTimeTransform( + tc_min=self.t_c_range[0], + tc_max=self.t_c_range[1], + gps_time=self.gps, + ifo=self.ifos[0], + ), SkyFrameToDetectorFrameSkyPositionTransform( gps_time=self.gps, ifos=self.ifos ), @@ -289,7 +289,7 @@ def serialize(self, path: str = "./") -> dict: "max_s2": self.max_s2, "iota_range": list(self.iota_range), "dL_range": list(self.dL_range), - # "t_c_range": list(self.t_c_range), + "t_c_range": list(self.t_c_range), "phase_c_range": list(self.phase_c_range), "psi_range": list(self.psi_range), "ra_range": list(self.ra_range), @@ -322,7 +322,7 @@ def deserialize(cls, path: str) -> Self: max_s2=run_dict["max_s2"], iota_range=tuple(run_dict["iota_range"]), dL_range=tuple(run_dict["dL_range"]), - # t_c_range=tuple(run_dict["t_c_range"]), + t_c_range=tuple(run_dict["t_c_range"]), phase_c_range=tuple(run_dict["phase_c_range"]), psi_range=tuple(run_dict["psi_range"]), ra_range=tuple(run_dict["ra_range"]), @@ -354,7 +354,7 @@ def __init__(self): max_s2=0.99, iota_range=(0.0, jnp.pi), dL_range=(1.0, 10000.0), - # t_c_range=(-0.05, 0.05), + t_c_range=(-0.05, 0.05), phase_c_range=(0.0, 2 * jnp.pi), psi_range=(0.0, jnp.pi), ra_range=(0.0, 2 * jnp.pi), diff --git a/test/unit/test_likelhood.py b/test/unit/test_likelhood.py new file mode 100644 index 000000000..0bb2566e8 --- /dev/null +++ b/test/unit/test_likelhood.py @@ -0,0 +1,145 @@ +import pytest +import numpy as np +from jimgw.core.single_event.likelihood import ( + SingleEventLikelihood, + ZeroLikelihood, + BaseTransientLikelihoodFD, + TimeMarginalizedLikelihoodFD, + PhaseMarginalizedLikelihoodFD, + PhaseTimeMarginalizedLikelihoodFD, + HeterodynedTransientLikelihoodFD, + HeterodynedPhaseMarginalizedLikelihoodFD, +) +from jimgw.core.single_event.detector import get_H1, get_L1 +from jimgw.core.single_event.waveform import RippleIMRPhenomD +from jimgw.core.single_event.data import Data + + +@pytest.fixture +def detectors_and_waveform(): + gps = 1126259462.4 + start = gps - 2 + end = gps + 2 + psd_start = gps - 2048 + psd_end = gps + 2048 + fmin = 20.0 + fmax = 1024.0 + ifos = [get_H1(), get_L1()] + for ifo in ifos: + data = Data.from_gwosc(ifo.name, start, end) + ifo.set_data(data) + psd_data = Data.from_gwosc(ifo.name, psd_start, psd_end) + psd_fftlength = data.duration * data.sampling_frequency + ifo.set_psd(psd_data.to_psd(nperseg=psd_fftlength)) + waveform = RippleIMRPhenomD(f_ref=20.0) + return ifos, waveform, fmin, fmax, gps + + +def example_params(gmst): + return { + "M_c": 30.0, + "eta": 0.249, + "s1_z": 0.0, + "s2_z": 0.0, + "d_L": 400.0, + "phase_c": 0.0, + "t_c": 0.0, + "iota": 0.0, + "ra": 1.375, + "dec": -1.2108, + "gmst": gmst, + "psi": 0.0, + } + + +class TestZeroLikelihood: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = ZeroLikelihood() + assert isinstance(likelihood, ZeroLikelihood) + params = example_params(gps) + result = likelihood.evaluate(params, {}) + assert result == 0.0 + + +class TestBaseTransientLikelihoodFD: + def test_initialization(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = BaseTransientLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps + ) + assert isinstance(likelihood, BaseTransientLikelihoodFD) + assert np.allclose(likelihood.frequencies, [20.0, (20.0 + 1024.0) / 2, 1024.0]) + assert likelihood.trigger_time == 1126259462.4 + assert hasattr(likelihood, "gmst") + + def test_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = BaseTransientLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps + ) + params = example_params(likelihood.gmst) + log_likelihood = likelihood.evaluate(params, {}) + assert np.isfinite(log_likelihood), "Log likelihood should be finite" + + +class TestTimeMarginalizedLikelihoodFD: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = TimeMarginalizedLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps, tc_range=(-0.15, 0.15) + ) + assert isinstance(likelihood, TimeMarginalizedLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + + +class TestPhaseMarginalizedLikelihoodFD: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = PhaseMarginalizedLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps + ) + assert isinstance(likelihood, PhaseMarginalizedLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + + +class TestPhaseTimeMarginalizedLikelihoodFD: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = PhaseTimeMarginalizedLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps, tc_range=(-0.15, 0.15) + ) + assert isinstance(likelihood, PhaseTimeMarginalizedLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + + +class TestHeterodynedTransientLikelihoodFD: + def test_initialization_and_evaluation(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = HeterodynedTransientLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps, ref_params=example_params(gps) + ) + assert isinstance(likelihood, HeterodynedTransientLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + + +class TestHeterodynedPhaseMarginalizedLikelihoodFD: + def test_initialization_and_likelihood(self, detectors_and_waveform): + ifos, waveform, fmin, fmax, gps = detectors_and_waveform + likelihood = HeterodynedPhaseMarginalizedLikelihoodFD( + detectors=ifos, waveform=waveform, f_min=fmin, f_max=fmax, trigger_time=gps, ref_params=example_params(gps) + ) + assert isinstance(likelihood, HeterodynedPhaseMarginalizedLikelihoodFD) + params = example_params(likelihood.gmst) + result = likelihood.evaluate(params, {}) + assert np.isfinite(result) + +# Need to add tests for running the heterodyned likelihood with different parameters