Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ripplegw/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"IMRPhenomXAS",
"IMRPhenomXAS_NRTidalv3",
"IMRPhenomXPHM",
"IMRPhenomHM",
Comment thread
thomasckng marked this conversation as resolved.
"SineGaussian",
"waveform_preset",
]
Comment thread
thomasckng marked this conversation as resolved.
126 changes: 126 additions & 0 deletions src/ripplegw/waveforms/IMRPhenomHM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import jax
import jax.numpy as jnp
from ..constants import MSUN, MTSUN, MRSUN, MPC
from jaxtyping import Array
from .spherical_harmonics import (
compute_sminus2_l2,
compute_sminus2_l3,
compute_sminus2_l4,
)


# Some pre-XPHM ripple code
from .IMRPhenomXPHM import (
XLALSimIMRPhenomHMGethlmModes,
)
Comment thread
thomasckng marked this conversation as resolved.


def gen_IMRPhenomHM(
frequency_array,
mass_1,
mass_2,
chi1,
chi2,
distance, # in Mpc
inclination,
phi0,
reference_frequency,
):
"""Generate IMRPhenomHM plus and cross polarizations."""

m1_SI = mass_1 * MSUN
m2_SI = mass_2 * MSUN
Mtot = mass_1 + mass_2

# Overall amplitude prefactor from LAL's XLALSimPhenomUtilsFDamp0:
# amp0 = Mtot * MRSUN * Mtot * MTSUN / distance
# where Mtot is in solar masses and distance is in meters
dist_m = distance * MPC # distance in meters
amp0 = Mtot * MRSUN * Mtot * MTSUN / dist_m

extra_params = {
"ModeArray": jnp.array(
[[2, 1], [2, 2], [3, 2], [3, 3], [4, 4]], dtype=jnp.int32
)
}

hlm = XLALSimIMRPhenomHMGethlmModes(
frequency_array,
m1_SI,
m2_SI,
0,
0,
chi1,
0,
0,
chi2,
phi0,
frequency_array[1] - frequency_array[0],
reference_frequency,
extra_params,
)

ells = extra_params["ModeArray"][:, 0]
minus1l = jnp.where(ells % 2 != 0, -1, 1)
mode_projections = jax.vmap(
get_phenomHMFD_mode_projection,
in_axes=(None, 0, 0, 0),
)(
inclination,
minus1l,
extra_params["ModeArray"][:, 0],
extra_params["ModeArray"][:, 1],
)

# Reshape to (n_modes, 2, 1) and (n_modes, 1, f_sampling) so they broadcast to (n_modes, 2, f_sampling)
projected = mode_projections[:, :, None] * hlm[:, None, :] * amp0
hp, hc = jnp.sum(projected, axis=0)

return hp, hc


def get_phenomHMFD_mode_projection(
theta: float,
minus1l: int | Array,
ell: int | Array,
m: int | Array,
) -> Array:
"""
Helper function to compute mode-by-mode plus- and cross-polarisation prefactors
"""

Y = jax.lax.switch(
ell - 2,
[
lambda: compute_sminus2_l2(theta, m),
lambda: compute_sminus2_l3(theta, m),
lambda: compute_sminus2_l4(theta, m),
],
)

def sym_branch():
# Equatorial symmetry: add in -m mode
Ymstar = jax.lax.switch(
ell - 2,
[
lambda: compute_sminus2_l2(theta, -m),
lambda: compute_sminus2_l3(theta, -m),
lambda: compute_sminus2_l4(theta, -m),
],
)
Ymstar = jnp.conj(Ymstar)
factorp = 0.5 * (Y + minus1l * Ymstar)
factorc = -1j * 0.5 * (Y - minus1l * Ymstar)
return jnp.array([factorp, factorc])

def asym_branch(): # NOTE This is for hypothetical m=0 modes, not currently implemented. Structure is there in case we ever want to use it
# Not adding in the -m mode
factorp = Y
factorc = -1j * factorp
return jnp.array([factorp, factorc])

return jax.lax.select(
m == 0,
asym_branch(),
sym_branch(),
)
7 changes: 4 additions & 3 deletions tests/cross_validation/test_lal_mismatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"TaylorF2": 1e-14,
"IMRPhenomPv2": 1e-4, # see note above
"IMRPhenomXPHM": 1e-6,
"IMRPhenomHM": 1e-6,
}
DEFAULT_MISMATCH_THRESHOLD = 1e-5 # fallback for unknown waveforms

Expand Down Expand Up @@ -317,8 +318,8 @@ def freq_params(request):
T = float(request.config.getoption("--T"))
return {
"f_l": 20.0,
"f_u": 1024.0,
"f_sampling": 2048.0,
"f_u": 2048.0,
"f_sampling": 4096.0,
"T": T,
"f_ref": 20.0,
}
Comment thread
thomasckng marked this conversation as resolved.
Outdated
Expand Down Expand Up @@ -353,6 +354,7 @@ def psd_data():
pytest.param("TaylorF2", DEFAULT_BOUNDS, id="TaylorF2"),
pytest.param("IMRPhenomPv2", BBH_BOUNDS, id="IMRPhenomPv2"),
pytest.param("IMRPhenomXPHM", BBH_BOUNDS, id="IMRPhenomXPHM"),
pytest.param("IMRPhenomHM", BBH_BOUNDS, id="IMRPhenomHM"),
],
)
def test_waveform_mismatch(
Expand Down Expand Up @@ -472,7 +474,6 @@ def _compute_lal(i_theta):
with ThreadPoolExecutor(max_workers=n_workers) as pool:
# map() preserves input order, so result[i] matches theta_batch[i]
lal_results = list(pool.map(_compute_lal, enumerate(theta_batch)))

lal_hp_list = []
lal_hc_list = []
theta_ripple_list = []
Expand Down
68 changes: 67 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

LAL_AVAILABLE = True
except ImportError:
LAL_AVAILABLE = False
import traceback, sys; sys.stdout.write("IMPORT ERROR:\n"); traceback.print_exc(file=sys.stdout); LAL_AVAILABLE = False
Comment thread
thomasckng marked this conversation as resolved.
Outdated

Comment thread
thomasckng marked this conversation as resolved.

def check_lal_available():
Expand Down Expand Up @@ -76,6 +76,7 @@ def check_is_tidal(waveform_name: str) -> bool:
"IMRPhenomXAS",
"IMRPhenomPv2",
"IMRPhenomXPHM",
"IMRPhenomHM",
"SineGaussian",
]

Expand Down Expand Up @@ -198,6 +199,28 @@ def waveform(theta):
f_ref,
)
return hp, hc

elif waveform_name == "IMRPhenomHM":
from ripplegw.waveforms.IMRPhenomHM import gen_IMRPhenomHM as waveform_generator
from ripplegw.conversions import Mc_eta_to_ms

@jax.jit
def waveform(theta):
# theta = [Mc, eta, s1z, s2z, dist_mpc, tc, phic, inclination]
# consistent with the precessing-waveform convention used by this test suite
m1, m2 = Mc_eta_to_ms(jnp.array([theta[0], theta[1]]))
hp, hc = waveform_generator(
fs,
m1,
m2,
theta[2],
theta[3],
theta[4], # distance in Mpc
theta[7], # inclination
theta[6], # phi0
f_ref,
)
return hp, hc
Comment thread
thomasckng marked this conversation as resolved.

elif waveform_name == "SineGaussian":
from ripplegw.waveforms.SineGaussian import gen_SineGaussian_hphc
Expand Down Expand Up @@ -307,6 +330,48 @@ def _call_xphm(lalparams):
# NNLO angles. The caller detects the exception and excludes the sample
# from the mismatch assertion and histogram.
hp, hc = _call_xphm(_make_xphm_params(222))
elif waveform_name == "IMRPhenomHM":
# XPHM requires SimIMRPhenomXPHM directly with MSA prescription params.
# SimInspiralChooseFDWaveform cannot set the PhenomXPrecVersion flag needed
# to guarantee the MSA prescription that the ripple implementation uses.
# theta = [m1, m2, s1z, s2z, dist_mpc, tc, phic, inclination]
Comment thread
thomasckng marked this conversation as resolved.
Outdated
m1_kg = theta[0] * lal.MSUN_SI
m2_kg = theta[1] * lal.MSUN_SI
s1z = theta[2]
s2z = theta[3]
distance = theta[4] * 1e6 * lal.PC_SI
phi_ref = theta[6]
inclination = theta[7]

def _call_hm():
lalparams = lal.CreateDict()
ModeArray = lalsim.SimInspiralCreateModeArray()
for el, em in [(2, 1), (2, 2), (3, 2), (3, 3), (4, 4)]:
lalsim.SimInspiralModeArrayActivateMode(ModeArray, el, em)
lalsim.SimInspiralWaveformParamsInsertModeArray(lalparams, ModeArray)
return lalsim.SimInspiralChooseFDWaveform(
m1_kg,
m2_kg,
0.0,
0.0,
s1z,
0.0,
0.0,
s2z,
distance,
inclination,
phi_ref,
0,
0,
0,
df,
f_l,
f_u,
f_ref,
lalparams,
approximant,
)
hp, hc = _call_hm()
elif is_precessing:
# Precessing waveform: theta = [m1, m2, s1x, s1y, s1z, s2x, s2y, s2z, dist, tc, phic, inc]
m1_kg = theta[0] * lal.MSUN_SI
Expand Down Expand Up @@ -567,3 +632,4 @@ def generate_random_params(
)

return theta
print('LAL_AVAILABLE is', LAL_AVAILABLE)
Comment thread
thomasckng marked this conversation as resolved.
Outdated
Loading