Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ripple is a JAX-based package for differentiable gravitational-wave waveform gen
- IMRPhenomXAS_NRTidalv3
- IMRPhenomPv2
- IMRPhenomXPHM (MSA)
- IMRPhenomHM

For a quick introduction, see the [Quick Start guide](https://ripplegw.readthedocs.io/en/stable/quickstart/).

Expand Down
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ripple is a JAX-based package for differentiable gravitational-wave waveform gen
- IMRPhenomXAS_NRTidalv3
- IMRPhenomPv2
- IMRPhenomXPHM (MSA)
- IMRPhenomHM

!!! warning
ripple has not yet reached v1.0.0 and the API may change. Use at your own risk. Consider pinning to a specific version if you need API stability.
Expand Down
2 changes: 2 additions & 0 deletions src/ripplegw/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
IMRPhenomXAS,
IMRPhenomXAS_NRTidalv3,
IMRPhenomXPHM,
IMRPhenomHM,
SineGaussian,
waveform_preset,
)
Expand All @@ -18,6 +19,7 @@
"IMRPhenomXAS",
"IMRPhenomXAS_NRTidalv3",
"IMRPhenomXPHM",
"IMRPhenomHM",
Comment thread
thomasckng marked this conversation as resolved.
"SineGaussian",
"waveform_preset",
]
Comment thread
thomasckng marked this conversation as resolved.
1 change: 1 addition & 0 deletions src/ripplegw/benchmarks/timings/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def main():
"TaylorF2",
"IMRPhenomD_NRTidalv2",
"IMRPhenomXAS_NRTidalv3",
"IMRPhenomHM",
],
help="Waveform approximant to time",
)
Expand Down
32 changes: 32 additions & 0 deletions src/ripplegw/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .waveforms.IMRPhenomD_NRTidalv2 import gen_IMRPhenomD_NRTidalv2_hphc
from .waveforms.IMRPhenomXAS import gen_IMRPhenomXAS_hphc
from .waveforms.IMRPhenomXAS_NRTidalv3 import gen_IMRPhenomXAS_NRTidalv3_hphc
from .waveforms.IMRPhenomHM import gen_IMRPhenomHM
from .waveforms.SineGaussian import gen_SineGaussian_hphc
from .waveforms.IMRPhenomXPHM import generate_xphm
from .conversions import Mc_eta_to_ms
Expand Down Expand Up @@ -328,6 +329,36 @@ def __repr__(self):
return f"IMRPhenomXPHM(f_ref={self.f_ref})"


class IMRPhenomHM(Waveform):
f_ref: float

def __init__(self, f_ref: float = 20.0):
self.f_ref = f_ref

def __call__(
self, frequency: Float[Array, " n_freq"], params: dict[str, Float]
) -> dict[str, Float[Array, " n_freq"]]:
output = {}
m1, m2 = Mc_eta_to_ms(jnp.array([params["M_c"], params["eta"]]))
hp, hc = gen_IMRPhenomHM(
frequency,
m1,
m2,
params["s1_z"],
params["s2_z"],
params["d_L"],
params["iota"],
params["phase_c"],
self.f_ref,
)
output["p"] = hp
output["c"] = hc
return output

def __repr__(self):
return f"IMRPhenomHM(f_ref={self.f_ref})"


class SineGaussian(Waveform):
def __init__(self):
pass
Expand Down Expand Up @@ -370,5 +401,6 @@ def __repr__(self):
"IMRPhenomXAS": IMRPhenomXAS,
"IMRPhenomXAS_NRTidalv3": IMRPhenomXAS_NRTidalv3,
"IMRPhenomXPHM": IMRPhenomXPHM,
"IMRPhenomHM": IMRPhenomHM,
"SineGaussian": SineGaussian,
}
126 changes: 126 additions & 0 deletions src/ripplegw/waveforms/IMRPhenomHM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import jax
import jax.numpy as jnp
from ..constants import MSUN, MTSUN, MRSUN, MPC
from jaxtyping import Array
from .spherical_harmonics import (
compute_sminus2_l2,
compute_sminus2_l3,
compute_sminus2_l4,
)


# Some pre-XPHM ripple code
from .IMRPhenomXPHM import (
XLALSimIMRPhenomHMGethlmModes,
)
Comment thread
thomasckng marked this conversation as resolved.


def gen_IMRPhenomHM(
frequency_array,
mass_1,
mass_2,
chi1,
chi2,
distance, # in Mpc
inclination,
phi0,
reference_frequency,
):
"""Generate IMRPhenomHM plus and cross polarizations."""

m1_SI = mass_1 * MSUN
m2_SI = mass_2 * MSUN
Mtot = mass_1 + mass_2

# Overall amplitude prefactor from LAL's XLALSimPhenomUtilsFDamp0:
# amp0 = Mtot * MRSUN * Mtot * MTSUN / distance
# where Mtot is in solar masses and distance is in meters
dist_m = distance * MPC # distance in meters
amp0 = Mtot * MRSUN * Mtot * MTSUN / dist_m

extra_params = {
"ModeArray": jnp.array(
[[2, 1], [2, 2], [3, 2], [3, 3], [4, 4]], dtype=jnp.int32
)
}

hlm = XLALSimIMRPhenomHMGethlmModes(
frequency_array,
m1_SI,
m2_SI,
0,
0,
chi1,
0,
0,
chi2,
phi0,
frequency_array[1] - frequency_array[0],
reference_frequency,
extra_params,
)

ells = extra_params["ModeArray"][:, 0]
minus1l = jnp.where(ells % 2 != 0, -1, 1)
mode_projections = jax.vmap(
get_phenomHMFD_mode_projection,
in_axes=(None, 0, 0, 0),
)(
inclination,
minus1l,
extra_params["ModeArray"][:, 0],
extra_params["ModeArray"][:, 1],
)

# Reshape to (n_modes, 2, 1) and (n_modes, 1, f_sampling) so they broadcast to (n_modes, 2, f_sampling)
projected = mode_projections[:, :, None] * hlm[:, None, :] * amp0
hp, hc = jnp.sum(projected, axis=0)

return hp, hc


def get_phenomHMFD_mode_projection(
theta: float,
minus1l: int | Array,
ell: int | Array,
m: int | Array,
) -> Array:
"""
Helper function to compute mode-by-mode plus- and cross-polarisation prefactors
"""

Y = jax.lax.switch(
ell - 2,
[
lambda: compute_sminus2_l2(theta, m),
lambda: compute_sminus2_l3(theta, m),
lambda: compute_sminus2_l4(theta, m),
],
)

def sym_branch():
# Equatorial symmetry: add in -m mode
Ymstar = jax.lax.switch(
ell - 2,
[
lambda: compute_sminus2_l2(theta, -m),
lambda: compute_sminus2_l3(theta, -m),
lambda: compute_sminus2_l4(theta, -m),
],
)
Ymstar = jnp.conj(Ymstar)
factorp = 0.5 * (Y + minus1l * Ymstar)
factorc = -1j * 0.5 * (Y - minus1l * Ymstar)
return jnp.array([factorp, factorc])

def asym_branch(): # NOTE This is for hypothetical m=0 modes, not currently implemented. Structure is there in case we ever want to use it
# Not adding in the -m mode
factorp = Y
factorc = -1j * factorp
return jnp.array([factorp, factorc])

return jax.lax.select(
m == 0,
asym_branch(),
sym_branch(),
)
3 changes: 2 additions & 1 deletion tests/cross_validation/test_lal_mismatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"TaylorF2": 1e-14,
"IMRPhenomPv2": 1e-4, # see note above
"IMRPhenomXPHM": 1e-6,
"IMRPhenomHM": 1e-6,
}
DEFAULT_MISMATCH_THRESHOLD = 1e-5 # fallback for unknown waveforms

Expand Down Expand Up @@ -364,6 +365,7 @@ def psd_data():
pytest.param("TaylorF2", DEFAULT_BOUNDS, id="TaylorF2"),
pytest.param("IMRPhenomPv2", BBH_BOUNDS, id="IMRPhenomPv2"),
pytest.param("IMRPhenomXPHM", BBH_BOUNDS, id="IMRPhenomXPHM"),
pytest.param("IMRPhenomHM", BBH_BOUNDS, id="IMRPhenomHM"),
],
)
def test_waveform_mismatch(
Expand Down Expand Up @@ -483,7 +485,6 @@ def _compute_lal(i_theta):
with ThreadPoolExecutor(max_workers=n_workers) as pool:
# map() preserves input order, so result[i] matches theta_batch[i]
lal_results = list(pool.map(_compute_lal, enumerate(theta_batch)))

lal_hp_list = []
lal_hc_list = []
theta_ripple_list = []
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,7 @@ def test_all_keys_present(self):
"IMRPhenomXAS",
"IMRPhenomXAS_NRTidalv3",
"IMRPhenomXPHM",
"IMRPhenomHM",
"SineGaussian",
}
assert expected == set(waveform_preset.keys())
Expand Down
64 changes: 64 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def check_is_tidal(waveform_name: str) -> bool:
"IMRPhenomXAS",
"IMRPhenomPv2",
"IMRPhenomXPHM",
"IMRPhenomHM",
"SineGaussian",
]

Expand Down Expand Up @@ -199,6 +200,27 @@ def waveform(theta):
)
return hp, hc

elif waveform_name == "IMRPhenomHM":
from ripplegw.waveforms.IMRPhenomHM import gen_IMRPhenomHM
from ripplegw.conversions import Mc_eta_to_ms

@jax.jit
def waveform(theta):
# theta = [Mc, eta, s1z, s2z, dist_mpc, tc, phic, inclination]
m1, m2 = Mc_eta_to_ms(jnp.array([theta[0], theta[1]]))
hp, hc = gen_IMRPhenomHM(
fs,
m1,
m2,
theta[2],
theta[3],
theta[4], # distance in Mpc
theta[7], # inclination
theta[6], # phi0
f_ref,
)
return hp, hc
Comment thread
thomasckng marked this conversation as resolved.

elif waveform_name == "SineGaussian":
from ripplegw.waveforms.SineGaussian import gen_SineGaussian_hphc

Expand Down Expand Up @@ -307,6 +329,47 @@ def _call_xphm(lalparams):
# NNLO angles. The caller detects the exception and excludes the sample
# from the mismatch assertion and histogram.
hp, hc = _call_xphm(_make_xphm_params(222))

elif waveform_name == "IMRPhenomHM":
m1_kg = theta[0] * lal.MSUN_SI
m2_kg = theta[1] * lal.MSUN_SI
s1z = theta[2]
s2z = theta[3]
distance = theta[4] * 1e6 * lal.PC_SI
phi_ref = theta[6]
inclination = theta[7]

def _call_hm():
lalparams = lal.CreateDict()
ModeArray = lalsim.SimInspiralCreateModeArray()
for el, em in [(2, 1), (2, 2), (3, 2), (3, 3), (4, 4)]:
lalsim.SimInspiralModeArrayActivateMode(ModeArray, el, em)
lalsim.SimInspiralWaveformParamsInsertModeArray(lalparams, ModeArray)
return lalsim.SimInspiralChooseFDWaveform(
m1_kg,
m2_kg,
0.0,
0.0,
s1z,
0.0,
0.0,
s2z,
distance,
inclination,
phi_ref,
0,
0,
0,
df,
f_l,
f_u,
f_ref,
lalparams,
approximant,
)

hp, hc = _call_hm()

elif is_precessing:
# Precessing waveform: theta = [m1, m2, s1x, s1y, s1z, s2x, s2y, s2z, dist, tc, phic, inc]
m1_kg = theta[0] * lal.MSUN_SI
Expand Down Expand Up @@ -339,6 +402,7 @@ def _call_xphm(lalparams):
None,
approximant,
)

else:
# Non-precessing waveform: theta = [m1, m2, s1z, s2z, (l1, l2), dist, tc, phic, inc]
if is_tidal:
Expand Down
2 changes: 1 addition & 1 deletion timings/submit_condor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ N_RUNS="50"
mkdir -p "${OUTDIR}"

PRECISIONS=("float32" "float64")
MODELS=("TaylorF2" "IMRPhenomD" "IMRPhenomXAS" "IMRPhenomPv2" "IMRPhenomXPHM" "IMRPhenomD_NRTidalv2" "IMRPhenomXAS_NRTidalv3")
MODELS=("TaylorF2" "IMRPhenomD" "IMRPhenomXAS" "IMRPhenomPv2" "IMRPhenomXPHM" "IMRPhenomD_NRTidalv2" "IMRPhenomXAS_NRTidalv3" "IMRPhenomHM")

TIMING_SUB="${OUTDIR}/timing.sub"
POSTPROCESS_SUB="${OUTDIR}/postprocess.sub"
Expand Down
2 changes: 1 addition & 1 deletion timings/submit_slurm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ N_WAVEFORMS="10000"
N_RUNS="50"

PRECISIONS=("float32" "float64")
MODELS=("TaylorF2" "IMRPhenomD" "IMRPhenomXAS" "IMRPhenomPv2" "IMRPhenomXPHM" "IMRPhenomD_NRTidalv2" "IMRPhenomXAS_NRTidalv3")
MODELS=("TaylorF2" "IMRPhenomD" "IMRPhenomXAS" "IMRPhenomPv2" "IMRPhenomXPHM" "IMRPhenomD_NRTidalv2" "IMRPhenomXAS_NRTidalv3" "IMRPhenomHM")

mkdir -p "${SCRIPT_DIR}/outdir"

Expand Down