diff --git a/README.md b/README.md index bb66181b..8e7f5091 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ ripple is a JAX-based package for differentiable gravitational-wave waveform gen - IMRPhenomXAS_NRTidalv3 - IMRPhenomPv2 - IMRPhenomXPHM (MSA) +- IMRPhenomHM For a quick introduction, see the [Quick Start guide](https://ripplegw.readthedocs.io/en/stable/quickstart/). diff --git a/docs/index.md b/docs/index.md index 2d936906..347dfb83 100644 --- a/docs/index.md +++ b/docs/index.md @@ -15,6 +15,7 @@ ripple is a JAX-based package for differentiable gravitational-wave waveform gen - IMRPhenomXAS_NRTidalv3 - IMRPhenomPv2 - IMRPhenomXPHM (MSA) +- IMRPhenomHM !!! warning ripple has not yet reached v1.0.0 and the API may change. Use at your own risk. Consider pinning to a specific version if you need API stability. diff --git a/src/ripplegw/__init__.py b/src/ripplegw/__init__.py index 1b773093..80ca6efb 100644 --- a/src/ripplegw/__init__.py +++ b/src/ripplegw/__init__.py @@ -6,6 +6,7 @@ IMRPhenomXAS, IMRPhenomXAS_NRTidalv3, IMRPhenomXPHM, + IMRPhenomHM, SineGaussian, waveform_preset, ) @@ -18,6 +19,7 @@ "IMRPhenomXAS", "IMRPhenomXAS_NRTidalv3", "IMRPhenomXPHM", + "IMRPhenomHM", "SineGaussian", "waveform_preset", ] diff --git a/src/ripplegw/benchmarks/timings/timing.py b/src/ripplegw/benchmarks/timings/timing.py index 52f58386..616e48f1 100644 --- a/src/ripplegw/benchmarks/timings/timing.py +++ b/src/ripplegw/benchmarks/timings/timing.py @@ -387,6 +387,7 @@ def main(): "TaylorF2", "IMRPhenomD_NRTidalv2", "IMRPhenomXAS_NRTidalv3", + "IMRPhenomHM", ], help="Waveform approximant to time", ) diff --git a/src/ripplegw/interfaces.py b/src/ripplegw/interfaces.py index b1ee8bb1..877970d5 100644 --- a/src/ripplegw/interfaces.py +++ b/src/ripplegw/interfaces.py @@ -9,6 +9,7 @@ from .waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc from .waveforms.IMRPhenomXAS import gen_IMRPhenomXAS_hphc from .waveforms.IMRPhenomXAS_NRTidalv3 import gen_IMRPhenomXAS_NRTidalv3_hphc +from .waveforms.IMRPhenomHM import gen_IMRPhenomHM from .waveforms.SineGaussian import gen_SineGaussian_hphc from .waveforms.IMRPhenomXPHM import generate_xphm from .conversions import Mc_eta_to_ms @@ -328,6 +329,36 @@ def __repr__(self): return f"IMRPhenomXPHM(f_ref={self.f_ref})" +class IMRPhenomHM(Waveform): + f_ref: float + + def __init__(self, f_ref: float = 20.0): + self.f_ref = f_ref + + def __call__( + self, frequency: Float[Array, " n_freq"], params: dict[str, Float] + ) -> dict[str, Float[Array, " n_freq"]]: + output = {} + m1, m2 = Mc_eta_to_ms(jnp.array([params["M_c"], params["eta"]])) + hp, hc = gen_IMRPhenomHM( + frequency, + m1, + m2, + params["s1_z"], + params["s2_z"], + params["d_L"], + params["iota"], + params["phase_c"], + self.f_ref, + ) + output["p"] = hp + output["c"] = hc + return output + + def __repr__(self): + return f"IMRPhenomHM(f_ref={self.f_ref})" + + class SineGaussian(Waveform): def __init__(self): pass @@ -370,5 +401,6 @@ def __repr__(self): "IMRPhenomXAS": IMRPhenomXAS, "IMRPhenomXAS_NRTidalv3": IMRPhenomXAS_NRTidalv3, "IMRPhenomXPHM": IMRPhenomXPHM, + "IMRPhenomHM": IMRPhenomHM, "SineGaussian": SineGaussian, } diff --git a/src/ripplegw/waveforms/IMRPhenomHM.py b/src/ripplegw/waveforms/IMRPhenomHM.py new file mode 100644 index 00000000..45ecb4dc --- /dev/null +++ b/src/ripplegw/waveforms/IMRPhenomHM.py @@ -0,0 +1,126 @@ +import jax +import jax.numpy as jnp +from ..constants import MSUN, MTSUN, MRSUN, MPC +from jaxtyping import Array +from .spherical_harmonics import ( + compute_sminus2_l2, + compute_sminus2_l3, + compute_sminus2_l4, +) + + +# Some pre-XPHM ripple code +from .IMRPhenomXPHM import ( + XLALSimIMRPhenomHMGethlmModes, +) + + +def gen_IMRPhenomHM( + frequency_array, + mass_1, + mass_2, + chi1, + chi2, + distance, # in Mpc + inclination, + phi0, + reference_frequency, +): + """Generate IMRPhenomHM plus and cross polarizations.""" + + m1_SI = mass_1 * MSUN + m2_SI = mass_2 * MSUN + Mtot = mass_1 + mass_2 + + # Overall amplitude prefactor from LAL's XLALSimPhenomUtilsFDamp0: + # amp0 = Mtot * MRSUN * Mtot * MTSUN / distance + # where Mtot is in solar masses and distance is in meters + dist_m = distance * MPC # distance in meters + amp0 = Mtot * MRSUN * Mtot * MTSUN / dist_m + + extra_params = { + "ModeArray": jnp.array( + [[2, 1], [2, 2], [3, 2], [3, 3], [4, 4]], dtype=jnp.int32 + ) + } + + hlm = XLALSimIMRPhenomHMGethlmModes( + frequency_array, + m1_SI, + m2_SI, + 0, + 0, + chi1, + 0, + 0, + chi2, + phi0, + frequency_array[1] - frequency_array[0], + reference_frequency, + extra_params, + ) + + ells = extra_params["ModeArray"][:, 0] + minus1l = jnp.where(ells % 2 != 0, -1, 1) + mode_projections = jax.vmap( + get_phenomHMFD_mode_projection, + in_axes=(None, 0, 0, 0), + )( + inclination, + minus1l, + extra_params["ModeArray"][:, 0], + extra_params["ModeArray"][:, 1], + ) + + # Reshape to (n_modes, 2, 1) and (n_modes, 1, f_sampling) so they broadcast to (n_modes, 2, f_sampling) + projected = mode_projections[:, :, None] * hlm[:, None, :] * amp0 + hp, hc = jnp.sum(projected, axis=0) + + return hp, hc + + +def get_phenomHMFD_mode_projection( + theta: float, + minus1l: int | Array, + ell: int | Array, + m: int | Array, +) -> Array: + """ + Helper function to compute mode-by-mode plus- and cross-polarisation prefactors + """ + + Y = jax.lax.switch( + ell - 2, + [ + lambda: compute_sminus2_l2(theta, m), + lambda: compute_sminus2_l3(theta, m), + lambda: compute_sminus2_l4(theta, m), + ], + ) + + def sym_branch(): + # Equatorial symmetry: add in -m mode + Ymstar = jax.lax.switch( + ell - 2, + [ + lambda: compute_sminus2_l2(theta, -m), + lambda: compute_sminus2_l3(theta, -m), + lambda: compute_sminus2_l4(theta, -m), + ], + ) + Ymstar = jnp.conj(Ymstar) + factorp = 0.5 * (Y + minus1l * Ymstar) + factorc = -1j * 0.5 * (Y - minus1l * Ymstar) + return jnp.array([factorp, factorc]) + + def asym_branch(): # NOTE This is for hypothetical m=0 modes, not currently implemented. Structure is there in case we ever want to use it + # Not adding in the -m mode + factorp = Y + factorc = -1j * factorp + return jnp.array([factorp, factorc]) + + return jax.lax.select( + m == 0, + asym_branch(), + sym_branch(), + ) diff --git a/tests/cross_validation/test_lal_mismatch.py b/tests/cross_validation/test_lal_mismatch.py index 9fcf5a59..a7bc1d21 100644 --- a/tests/cross_validation/test_lal_mismatch.py +++ b/tests/cross_validation/test_lal_mismatch.py @@ -90,6 +90,7 @@ "TaylorF2": 1e-14, "IMRPhenomPv2": 1e-4, # see note above "IMRPhenomXPHM": 1e-6, + "IMRPhenomHM": 1e-6, } DEFAULT_MISMATCH_THRESHOLD = 1e-5 # fallback for unknown waveforms @@ -364,6 +365,7 @@ def psd_data(): pytest.param("TaylorF2", DEFAULT_BOUNDS, id="TaylorF2"), pytest.param("IMRPhenomPv2", BBH_BOUNDS, id="IMRPhenomPv2"), pytest.param("IMRPhenomXPHM", BBH_BOUNDS, id="IMRPhenomXPHM"), + pytest.param("IMRPhenomHM", BBH_BOUNDS, id="IMRPhenomHM"), ], ) def test_waveform_mismatch( @@ -483,7 +485,6 @@ def _compute_lal(i_theta): with ThreadPoolExecutor(max_workers=n_workers) as pool: # map() preserves input order, so result[i] matches theta_batch[i] lal_results = list(pool.map(_compute_lal, enumerate(theta_batch))) - lal_hp_list = [] lal_hc_list = [] theta_ripple_list = [] diff --git a/tests/integration/test_waveforms.py b/tests/integration/test_waveforms.py index 6b5fcaf1..6ad09c84 100644 --- a/tests/integration/test_waveforms.py +++ b/tests/integration/test_waveforms.py @@ -677,6 +677,7 @@ def test_all_keys_present(self): "IMRPhenomXAS", "IMRPhenomXAS_NRTidalv3", "IMRPhenomXPHM", + "IMRPhenomHM", "SineGaussian", } assert expected == set(waveform_preset.keys()) diff --git a/tests/utils.py b/tests/utils.py index 55e21816..fc101976 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -76,6 +76,7 @@ def check_is_tidal(waveform_name: str) -> bool: "IMRPhenomXAS", "IMRPhenomPv2", "IMRPhenomXPHM", + "IMRPhenomHM", "SineGaussian", ] @@ -199,6 +200,27 @@ def waveform(theta): ) return hp, hc + elif waveform_name == "IMRPhenomHM": + from ripplegw.waveforms.IMRPhenomHM import gen_IMRPhenomHM + from ripplegw.conversions import Mc_eta_to_ms + + @jax.jit + def waveform(theta): + # theta = [Mc, eta, s1z, s2z, dist_mpc, tc, phic, inclination] + m1, m2 = Mc_eta_to_ms(jnp.array([theta[0], theta[1]])) + hp, hc = gen_IMRPhenomHM( + fs, + m1, + m2, + theta[2], + theta[3], + theta[4], # distance in Mpc + theta[7], # inclination + theta[6], # phi0 + f_ref, + ) + return hp, hc + elif waveform_name == "SineGaussian": from ripplegw.waveforms.SineGaussian import gen_SineGaussian_hphc @@ -307,6 +329,47 @@ def _call_xphm(lalparams): # NNLO angles. The caller detects the exception and excludes the sample # from the mismatch assertion and histogram. hp, hc = _call_xphm(_make_xphm_params(222)) + + elif waveform_name == "IMRPhenomHM": + m1_kg = theta[0] * lal.MSUN_SI + m2_kg = theta[1] * lal.MSUN_SI + s1z = theta[2] + s2z = theta[3] + distance = theta[4] * 1e6 * lal.PC_SI + phi_ref = theta[6] + inclination = theta[7] + + def _call_hm(): + lalparams = lal.CreateDict() + ModeArray = lalsim.SimInspiralCreateModeArray() + for el, em in [(2, 1), (2, 2), (3, 2), (3, 3), (4, 4)]: + lalsim.SimInspiralModeArrayActivateMode(ModeArray, el, em) + lalsim.SimInspiralWaveformParamsInsertModeArray(lalparams, ModeArray) + return lalsim.SimInspiralChooseFDWaveform( + m1_kg, + m2_kg, + 0.0, + 0.0, + s1z, + 0.0, + 0.0, + s2z, + distance, + inclination, + phi_ref, + 0, + 0, + 0, + df, + f_l, + f_u, + f_ref, + lalparams, + approximant, + ) + + hp, hc = _call_hm() + elif is_precessing: # Precessing waveform: theta = [m1, m2, s1x, s1y, s1z, s2x, s2y, s2z, dist, tc, phic, inc] m1_kg = theta[0] * lal.MSUN_SI @@ -339,6 +402,7 @@ def _call_xphm(lalparams): None, approximant, ) + else: # Non-precessing waveform: theta = [m1, m2, s1z, s2z, (l1, l2), dist, tc, phic, inc] if is_tidal: diff --git a/timings/submit_condor.sh b/timings/submit_condor.sh index bc8bdf7c..d2000bbd 100755 --- a/timings/submit_condor.sh +++ b/timings/submit_condor.sh @@ -14,7 +14,7 @@ N_RUNS="50" mkdir -p "${OUTDIR}" PRECISIONS=("float32" "float64") -MODELS=("TaylorF2" "IMRPhenomD" "IMRPhenomXAS" "IMRPhenomPv2" "IMRPhenomXPHM" "IMRPhenomD_NRTidalv2" "IMRPhenomXAS_NRTidalv3") +MODELS=("TaylorF2" "IMRPhenomD" "IMRPhenomXAS" "IMRPhenomPv2" "IMRPhenomXPHM" "IMRPhenomD_NRTidalv2" "IMRPhenomXAS_NRTidalv3" "IMRPhenomHM") TIMING_SUB="${OUTDIR}/timing.sub" POSTPROCESS_SUB="${OUTDIR}/postprocess.sub" diff --git a/timings/submit_slurm.sh b/timings/submit_slurm.sh index 97116a49..e7654773 100755 --- a/timings/submit_slurm.sh +++ b/timings/submit_slurm.sh @@ -12,7 +12,7 @@ N_WAVEFORMS="10000" N_RUNS="50" PRECISIONS=("float32" "float64") -MODELS=("TaylorF2" "IMRPhenomD" "IMRPhenomXAS" "IMRPhenomPv2" "IMRPhenomXPHM" "IMRPhenomD_NRTidalv2" "IMRPhenomXAS_NRTidalv3") +MODELS=("TaylorF2" "IMRPhenomD" "IMRPhenomXAS" "IMRPhenomPv2" "IMRPhenomXPHM" "IMRPhenomD_NRTidalv2" "IMRPhenomXAS_NRTidalv3" "IMRPhenomHM") mkdir -p "${SCRIPT_DIR}/outdir"