From 26b65c6dd2ae1d49e3acb27a8662e3edf3e6a473 Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Tue, 10 Mar 2026 14:15:09 +0100 Subject: [PATCH 01/13] bug fix in constraint building, simplified MR evaluation, updated plotting --- nmma/eos/eos_likelihood.py | 164 +++++++++++++++++++++++-------------- 1 file changed, 101 insertions(+), 63 deletions(-) diff --git a/nmma/eos/eos_likelihood.py b/nmma/eos/eos_likelihood.py index 289d4745..e1dc6455 100644 --- a/nmma/eos/eos_likelihood.py +++ b/nmma/eos/eos_likelihood.py @@ -4,10 +4,12 @@ import shutil import json from ast import literal_eval +import matplotlib from tqdm.contrib.concurrent import process_map import numpy as np from scipy.special import logsumexp -from scipy.stats import norm, gaussian_kde +from scipy.stats import norm +from scipy.ndimage import gaussian_filter from matplotlib import pyplot as plt from bilby.core.prior import WeightedCategorical, PriorDict from .eos_processing import EoSConverter @@ -60,27 +62,40 @@ def setup_submodel_conversion(self): self.conv_functions.append(self.sub_model.parameter_conversion) - def final_diagnostics(self, bestfit_params, args, result=None): + def final_diagnostics(self, bestfit_params, args, result=None, fig = None): + matplotlib.rcParams.update({'font.size': 16, 'font.family': 'serif'}) + # matplotlib.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif'] bestfit_params =self.parameter_conversion(bestfit_params) - color_densities = [] radii, masses, lambdas = self.sub_model.eos_converter.macro_parameters.values() - x_lim = (np.min(radii)-0.3, np.max(radii)+0.3) - y_lim = (masses[0], masses[-1]+0.1) - fig, ax = plt.subplots(figsize=(10, 6)) - ax.set_xlim(x_lim) - ax.set_ylim(y_lim) - for constraint in self.sub_model.constraints: - color = ax._get_lines.get_next_color() - ax = constraint.plot(ax=ax, resolution=100, color=color) - ax.plot(radii, masses, label='Best fit EOS', zorder=10) - - ax.set_xlabel(r'Radius [km]') - ax.set_ylabel(r'Mass [M$_\odot$]') - ax.legend() - fig.savefig(os.path.join(args.outdir, f"{args.label}_mr_curve.png")) - - plt.show() + if fig is None: + x_lim = (min(np.min(radii)-0.3, 9), max(np.max(radii)+0.3, 15)) + y_lim = (masses[0], masses[-1]+0.1) + fig, ax = plt.subplots(figsize=(10, 10)) + ax.set_xlim(x_lim) + ax.set_ylim(y_lim) + ax.set_xlabel(r'Radius [km]') + ax.set_ylabel(r'Mass [M$_\odot$]') + + for constraint in self.sub_model.constraints: + color = ax._get_lines.get_next_color() + ax = constraint.plot(ax=ax, color=color) + else: + ax = fig.axes[0] + labels = [line.get_label() for line in fig.legends[0].legend_handles] + for constraint in self.sub_model.constraints: + if constraint.name in labels: + continue + color = ax._get_lines.get_next_color() + ax = constraint.plot(ax=ax, color=color) + fig.legends.clear() ## remove old legend to avoid duplicates + ax.plot(radii, masses, label=f'{args.label} Best fit EoS', zorder=10) + + fig.legend(ncols=2, loc='upper center', + bbox_to_anchor=(0.5, 0.00), handlelength=2) + fig.tight_layout() + fig.savefig(os.path.join(args.outdir, f"{args.label}_mr_curve.png"),bbox_inches='tight') + return fig def compose_eos_constraints(args, constraint_kinds=['lower_mtov', 'upper_mtov', 'mass_radius']): @@ -200,7 +215,7 @@ def initialise_from_dict(self, constraint_dict): raise ValueError('Unknown type of EoS Constraint. Must be "lower_mtov", \ "upper_mtov", "mass-radius" or "micro\ ') - return constraint_list + return constraint_list def parameter_conversion(self, parameters): return self.eos_converter.parameter_conversion(parameters) @@ -300,6 +315,7 @@ def __init__(self, measured_mass, measure_error, name=None, arxiv_ref=None, logn self.error = measure_error self.repr_add = f'of {measured_mass}+-{measure_error} M_sun' self.lognorm_method = lognorm_method + self.linestyle = '--' def __repr__(self): @@ -320,16 +336,16 @@ def log_likelihood(self, parameters, local_parameters=None): def plot(self, ax, resolution = 100, **kwargs): """Plot the mass constraint on the given figure.""" x_lim = ax.get_xlim() - y_lim = ax.get_ylim() - M_grid = np.linspace(*y_lim, num=resolution) - show_x, show_y = np.meshgrid(np.linspace(*x_lim, resolution), M_grid) - line = ax.hlines(self.mass, *x_lim, linestyle='--', linewidth=1.5, zorder=3, label=self.name, **kwargs) + dummy_line = ax.plot([], [], label=self.name, linestyle=self.linestyle, linewidth=2.5, **kwargs) + line = ax.hlines(self.mass, *x_lim, linestyle=self.linestyle, linewidth=2.5, zorder=3, **kwargs) cmap = fading_cmap(line.get_color()) - shade_profile = norm.pdf(M_grid, loc=self.mass, scale=self.error) - shading_matrix = np.repeat(shade_profile[:, np.newaxis], resolution, axis=1) - ax.contourf(show_x, show_y, shading_matrix, levels=50, cmap=cmap) + levels = [0.95, 0.68] + for i, level in enumerate(levels): + ax.hlines(self.mass + (i+1)*self.error, *x_lim, color=cmap(0.9*level), linewidth=1.5, zorder=1) + ax.hlines(self.mass - (i+1)*self.error, *x_lim, color=cmap(0.9*level), linewidth=1.5, zorder=1) return ax + class LowerMTOVConstraint(MassConstraint): '''Constraint that an EOS supports at least a certain TOV mass(within Gaussian uncertainty)''' def __init__(self, measured_mass, measure_error, name=None, arxiv_ref=None): @@ -364,6 +380,7 @@ def __init__(self, measured_mass, measure_error, name=None, arxiv_ref=None): Identifier of a relevant source """ super().__init__(measured_mass, measure_error, name, arxiv_ref, lognorm_method=norm.logsf) + self.linestyle = ':' class MassRadiusConstraint(EoSConstraint): @@ -389,22 +406,9 @@ def __init__(self, mass_array=None, radius_array=None, weights = None, file_path mass_array, radius_array, weights = self.read_data(file_path) elif mass_array is None or radius_array is None: raise ValueError('Must provide data for masses and radii as arrays or file from which to load') - - if len(radius_array) > 10000: - ratio = len(radius_array) // 10000 - else: - ratio = 1 - - radius = radius_array[::ratio] - mass = mass_array[::ratio] - if weights is not None: - weights = weights[::ratio] - - self.KDE = gaussian_kde((radius, mass), weights=weights) - self.test_masses= np.linspace(start=1., stop=2.5, num=150 ) # 1 to 2.5 Msun - self.rng = np.random.default_rng() - - + self.set_grid(mass_array, radius_array, weights) + self.test_masses = np.linspace(1.2, 2.5, 151) + def read_data(self, file_path): """Read mass-radius data from a file.""" data = np.loadtxt(file_path, unpack=True) @@ -443,12 +447,45 @@ def read_data(self, file_path): return masses, radius, weights - def log_likelihood(self, parameters, local_parameters): + def set_grid(self, masses, radii, weights, mass_step = 0.01, radius_step = 0.03): + """Set up a grid upon which to build a histogram of mass-radius data to approximate the pdf. + Note that when using multiple mass-radius measurements, all measurements should use the same stepsizes! + Parameters + ---------- + masses: np.array + Array with mass posterior of M-R measurement + radii: np.array + Array with radius posterior of M-R measurement, must be specified along an equal-length mass_array + weights: np.array, optional + Array with weights of the M-R samples, must be specified along an equal-length mass_array + mass_step: float + step size for mass grid in solar masses, default is 0.01 Msun + radius_step: float + step size for radius grid in km, default is 0.02 km (20 m) + """ + mass_bins = self.set_bins(masses, mass_step) + rad_bins = self.set_bins(radii, radius_step) + if 3*len(mass_bins)*len(rad_bins) > len(masses): + print("Warning: The histogram might be to sparsely populated to get meaningful results.") + + histogram, self.rad_edges, self.mass_edges = np.histogram2d(radii, masses, bins=[rad_bins, mass_bins], weights=weights, density=True) + drad = self.rad_edges[1] - self.rad_edges[0] + dmass = self.mass_edges[1] - self.mass_edges[0] + + self.histogram = gaussian_filter(histogram*dmass*drad, sigma=3) + + def set_bins(self, array, step_size, sensitivity=0.001): + low, high = np.quantile(array, [sensitivity, 1.- sensitivity]) + bins = np.arange(0.95*low, 1.05*high, step_size, dtype=np.float64) + return bins + + def log_likelihood(self, parameters, local_parameters): try: tov_mass = parameters.get('TOV_mass', local_parameters['masses'][-1]) return self.single_logl(tov_mass, local_parameters['masses'], local_parameters['radii']) except (ValueError, IndexError): + self.single_logl(tov_mass, local_parameters['masses'], local_parameters['radii']) return [ self.single_logl(masses[-1], masses, local_parameters['radii'][i]) for i, masses in enumerate(local_parameters['masses']) @@ -458,30 +495,31 @@ def single_logl(self, tov_mass, masses, radii): ## interpolate radii along equally spaced mass grid up to MTov test_mass_range=self.test_masses[self.test_masses Date: Tue, 10 Mar 2026 14:17:12 +0100 Subject: [PATCH 02/13] split eos conversion, added pulsar conversions --- nmma/core/conversion.py | 66 ++++++++++++++++++----------- nmma/eos/eos_processing.py | 86 ++++++++++++++++++++++---------------- 2 files changed, 91 insertions(+), 61 deletions(-) diff --git a/nmma/core/conversion.py b/nmma/core/conversion.py index 6279a99a..d2ba1f56 100644 --- a/nmma/core/conversion.py +++ b/nmma/core/conversion.py @@ -4,7 +4,7 @@ from scipy.integrate import simpson from astropy import units from astropy import cosmology as cosmo -from .constants import geom_msun_km, msun_to_ergs, get_cosmology, set_cosmology +from .constants import geom_msun_km, msun_to_ergs, msun_s, get_cosmology, set_cosmology from bilby.gw.conversion import ( component_masses_to_chirp_mass, @@ -191,32 +191,50 @@ def convert_mtot_mni(params): params["mrp_c"] = (params["xmix"]*(params["mtot"]-params["mni"])-params["mrp"]) return params +############################## pulsar timing conversions #################################### +def binary_mass_function(m_obs, m_comp, sin_i): + return (m_comp * sin_i)**3 / (m_obs + m_comp)**2 + +def shapiro_delay(m_comp, sin_i): + "see https://arxiv.org/pdf/1007.0933.pdf" + shapiro_range = msun_s*1.e6 * m_comp # in microseconds + orthometric_ratio = sin_i/(1+np.sqrt(1-sin_i**2)) + return shapiro_range * orthometric_ratio**3 + +def einstein_delay_orbital_factor(orbital_period, eccentricity): + "see, e.g., 10.1007/978-3-662-62110-3_1, p.12 " + return msun_s**(2/3) * eccentricity * (orbital_period /2/np.pi)**(1/3) +def simplified_einstein_delay(m_psr, m_comp, einstein_factor): + "see, e.g., 10.1007/978-3-662-62110-3_1, p.12 " + return einstein_factor *m_comp * (m_psr + 2*m_comp) / (m_psr + m_comp)**(4/3) + +def einstein_delay(m_psr, m_comp, orbital_period, eccentricity): + "see, e.g., 10.1007/978-3-662-62110-3_1, p.12 " + einstein_delay_factor = einstein_delay_orbital_factor(orbital_period, eccentricity) + return simplified_einstein_delay(m_psr, m_comp, einstein_delay_factor) + +def mass_parameters_to_sini(total_mass, mass_function, m_comp): + "Invert the binary mass function to get sin(i) for a given total mass and mass function" + return np.cbrt(mass_function * total_mass**2)/m_comp + ############################## EOS-related conversions #################################### -def EOS2Parameters(radii, masses, lambdas, m1_source, m2_source): - ### FIXME: Under what circumstance would these not simply be mass_val[-1], radius_val[-1]? +def EOS_to_ns_parameters(radii, masses, lambdas): TOV_mass = masses.max(axis=-1) TOV_radius = radii[np.argmax(masses)] + R_14, R_16 = np.interp(x=[1.4, 1.6], xp=masses, fp=radii, left=0, right=0) + return TOV_mass, TOV_radius, R_14, R_16 + +def EOS_to_system_parameters(radii, masses, lambdas, m1_source, m2_source): (log_lambda_1, log_lambda_2) = np.interp(x=[m1_source, m2_source], xp= masses, fp=np.log(lambdas), left=-np.inf, right=-np.inf) lambda_1 = np.exp(log_lambda_1) lambda_2 = np.exp(log_lambda_2) - try: - (radius_1, radius_2, R_14, R_16) = np.interp( - x=[m1_source, m2_source, 1.4, 1.6], - xp=masses, fp= radii, left =0, right=0) - - return TOV_mass, TOV_radius, lambda_1, lambda_2, radius_1, radius_2, R_14, R_16 - ## radius interpolation will raise an error if dealing with multiple sources at once - # In that case we return all values as corresponding arrays - except ValueError: - ref = np.ones_like(lambda_1) - (radius_1, radius_2, R_14, R_16) = np.interp( - x=[m1_source, m2_source, 1.4*ref, 1.6*ref], - xp=masses, fp= radii, left =0, right=0) + (radius_1, radius_2) = np.interp( x=[m1_source, m2_source], + xp=masses, fp= radii, left =0, right=0) - return ref*TOV_mass, ref*TOV_radius, lambda_1, lambda_2, radius_1, radius_2, R_14, R_16 + return lambda_1, lambda_2, radius_1, radius_2 def radii_from_qur(parameters): mass_1_source = parameters["mass_1_source"] @@ -807,8 +825,8 @@ def identity_conversion(self, parameters): 'log10_E0' : r'$\log_{10}(E_0{\rm [erg]})$', 'ratio_zeta' : r'$\zeta$', 'alpha' : r'$\alpha$', - 'KNtheta' : r'$\theta_{KN}$', - 'KNphi' : r'$\phi_{KN}$', + 'KNtheta' : r'$\theta_{KN} [^\circ]$', + 'KNphi' : r'$\phi_{KN} [^\circ]$', ## GRB parameters ## 'log10_E0' : r'$\log_{10}(E_0{\rm [erg]})$', 'ratio_epsilon' : r'$\epsilon$', @@ -826,11 +844,11 @@ def identity_conversion(self, parameters): 'mni_c' : r'$M_{\rm{Ni}}/M_{\rm{tot}}$', 'mrp_c' : r'$M_{\rm{rp,c}}{\rm [M_{\odot}]}$', ### EOS parameters ### - 'L_sym' : r'$L_{\rm{sym}}{\rm [MeV]}$', - 'K_sym' : r'$K_{\rm{sym}}{\rm [MeV]}$', - 'K_sat' : r'$K_{\rm{sat}}{\rm [MeV]}$', - '3n_sat' : r'$c_{3n_{\rm{sat}}{\rm [c]}$', - '5n_sat' : r'$c_{5n_{\rm{sat}}{\rm [c]}$', + 'L_sym' : r'$L_{\rm sym}{\rm [MeV]}$', + 'K_sym' : r'$K_{\rm sym}{\rm [MeV]}$', + 'K_sat' : r'$K_{\rm sat}{\rm [MeV]}$', + '3n_sat' : r'$c^2_{3n_{\rm sat}}{\rm [c^2]}$', + '5n_sat' : r'$c^2_{5n_{\rm sat}}{\rm [c^2]}$', 'TOV_mass' : r'$M_{\rm{TOV}}{\rm [M_{\odot}]}$', 'R_14' : r'$R_{1.4}{\rm[km]}$', 'lambda_tilde' : r'$\tilde{\Lambda}$', diff --git a/nmma/eos/eos_processing.py b/nmma/eos/eos_processing.py index c42555d8..0907fc77 100644 --- a/nmma/eos/eos_processing.py +++ b/nmma/eos/eos_processing.py @@ -6,18 +6,22 @@ import joblib from ast import literal_eval import keras as k -from ..core.conversion import radii_from_qur, EOS2Parameters +from ..core.conversion import radii_from_qur, EOS_to_ns_parameters, EOS_to_system_parameters def setup_eos_generator(args): - try: - with open(args.emulator_metadata, 'r') as f: - meta_dict = json.load(f) - except TypeError: - meta_dict = args.emulator_metadata - except FileNotFoundError: - meta_dict = literal_eval(args.emulator_metadata) - - eos_model_type = args.micro_eos_model.lower() + if isinstance(args, dict): + meta_dict = args + eos_model_type = meta_dict['micro_eos_model'].lower() + else: + try: + with open(args.emulator_metadata, 'r') as f: + meta_dict = json.load(f) + except TypeError: + meta_dict = args.emulator_metadata + except FileNotFoundError: + meta_dict = literal_eval(args.emulator_metadata) + + eos_model_type = args.micro_eos_model.lower() if eos_model_type == 'nep': return NEPEoSGenerator(meta_dict) @@ -230,10 +234,11 @@ class LEC13EoSGenerator(LECEoSGenerator): class EoSConverter: def __init__(self, args, method=None): - if getattr(args, 'eos_file', None) or getattr(args, 'eos_data', None): - method = "tabulated" - elif getattr(args, 'emulator_metadata', None): - method = "emulated" + if method is None: + if getattr(args, 'eos_file', None) or getattr(args, 'eos_data', None): + method = "tabulated" + elif getattr(args, 'emulator_metadata', None): + method = "emulated" self.parameter_conversion = self.full_eos_conversion # Case 1: eos is generated from emulator on the fly @@ -295,55 +300,62 @@ def eos_from_ram(self, converted_parameters): def single_eos_from_ram(self, _): return self.eos_data + def full_eos_conversion(self, parameters): + parameters =self.compute_macro_parameters(parameters) + return self.system_props_from_eos(parameters) + def compute_macro_parameters(self, parameters): + eos_macro_keys = ["TOV_mass", "TOV_radius", "R_14", "R_16"] eos_data = self.macro_conversion(parameters) + if len(eos_data) ==1: radii, masses, lambdas = eos_data[0] + for key, val in zip(eos_macro_keys, + EOS_to_ns_parameters(radii, masses, lambdas) + ): + parameters[key] = val else: radii, masses, lambdas = map(list, zip(*eos_data)) + TOV_mass_list, TOV_radius_list, R_14_list, R_16_list = [], [], [], [] + for rad, mass, lam in zip(radii, masses, lambdas): + TOV_mass, TOV_radius, R_14, R_16 = EOS_to_ns_parameters(rad, mass, lam) + TOV_mass_list.append(TOV_mass) + TOV_radius_list.append(TOV_radius) + R_14_list.append(R_14) + R_16_list.append(R_16) + for key, _list in zip(eos_macro_keys, [ + TOV_mass_list, TOV_radius_list, R_14_list, R_16_list + ]): + parameters[key] = np.array(_list) self.macro_parameters = {'radii': radii, 'masses': masses, 'lambdas': lambdas} return parameters - - def full_eos_conversion(self, parameters): - self.compute_macro_parameters(parameters) - return self.macro_props_from_eos(parameters) - - def macro_props_from_eos(self, converted_parameters): - eos_keys = ["TOV_mass", "TOV_radius", "lambda_1", "lambda_2", - "radius_1", "radius_2", "R_14", "R_16"] + def system_props_from_eos(self, converted_parameters): + system_keys = ["lambda_1", "lambda_2", "radius_1", "radius_2"] m1_source = converted_parameters["mass_1_source"] m2_source = converted_parameters["mass_2_source"] radii, masses, lambdas = self.macro_parameters.values() if isinstance(radii, np.ndarray): # single eos case - for key, val_array in zip(eos_keys, - EOS2Parameters(radii, masses, lambdas, m1_source, m2_source) + for key, val_array in zip(system_keys, + EOS_to_system_parameters(radii, masses, lambdas, m1_source, m2_source) ): converted_parameters[key] = val_array else: - ### assuming TOV mass and radius are the last entries of the respective arrays - TOV_mass_list, TOV_radius_list, R_14_list, R_16_list = [], [], [], [] lambda_1_list, lambda_2_list, radius_1_list, radius_2_list = [], [], [], [] for i, rad in enumerate(radii): - (TOV_mass, TOV_radius, lambda_1, lambda_2, radius_1, - radius_2, R_14, R_16 - ) = EOS2Parameters(rad, masses[i], lambdas[i], m1_source[i], m2_source[i] ) + (lambda_1, lambda_2, radius_1, radius_2 ) = EOS_to_system_parameters( + rad, masses[i], lambdas[i], m1_source[i], m2_source[i] ) - TOV_radius_list.append(TOV_radius) - TOV_mass_list.append(TOV_mass) lambda_1_list.append(lambda_1) lambda_2_list.append(lambda_2) radius_1_list.append(radius_1) radius_2_list.append(radius_2) - R_14_list.append(R_14) - R_16_list.append(R_16) - - for key, _list in zip(eos_keys, [ - TOV_mass_list, TOV_radius_list, lambda_1_list, - lambda_2_list, radius_1_list, radius_2_list, R_14_list, R_16_list + + for key, _list in zip(system_keys, [ + lambda_1_list, lambda_2_list, radius_1_list, radius_2_list ]): converted_parameters[key] = np.array(_list) From 5ad28f6a362b44685cb40ebcdf39adfff714d4a1 Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Tue, 10 Mar 2026 14:22:36 +0100 Subject: [PATCH 03/13] minor convenience fixes --- nmma/core/constants.py | 6 ++++++ nmma/core/parsing.py | 4 +++- nmma/core/utils.py | 1 + nmma/em/io.py | 6 +++++- nmma/em/prior.py | 2 +- nmma/em/utils.py | 4 +--- 6 files changed, 17 insertions(+), 6 deletions(-) diff --git a/nmma/core/constants.py b/nmma/core/constants.py index ae8682d3..26249f3b 100644 --- a/nmma/core/constants.py +++ b/nmma/core/constants.py @@ -8,6 +8,7 @@ # helpers mc2 = const.M_sun*const.c**2 G_per_c2 = const.G/const.c**2 +seconds_a_day = 24*3600 @@ -34,6 +35,11 @@ msun_to_ergs = mc2.cgs.value MeV_per_fm3_to_Msun_per_km3 = 1e54/((mc2).to(u.MeV).value) # 1 MeV/fm**3 is 8.9653E-7 Msun/km**3 +## pulsar timing constants +msun_s = (const.M_sun * const.G / const.c**3).value # geometrised Msun in seconds +msun_mus = msun_s * 1e6 # Msun in microseconds, used for pulsar timing +einstein_factor = (const.G*const.M_sun/const.c**3).value**(2/3) + default_cosmology = cosmology.Planck18 def set_cosmology(cosmology_input=None): """Set the cosmology for the NMMA package. diff --git a/nmma/core/parsing.py b/nmma/core/parsing.py index d0002162..484f58f4 100644 --- a/nmma/core/parsing.py +++ b/nmma/core/parsing.py @@ -132,7 +132,9 @@ def single_messenger_analysis_parsing(parser): help="To use reactive sampling in ultranest (default: False)") parser.add_argument("--skip-sampling", action='store_true', help="If analysis has already run, skip bilby sampling and compute results from checkpoint files. Combine with --plot to make plots from these files.") - + parser.add_argument("--bestfit", "--best-fit", action='store_true', + help="Save the best fit parameters to JSON") + return parser def base_injection_parsing(parser): diff --git a/nmma/core/utils.py b/nmma/core/utils.py index cb866f38..be7c5a85 100644 --- a/nmma/core/utils.py +++ b/nmma/core/utils.py @@ -190,6 +190,7 @@ def sig_lims(values, quantiles=None, sig_unc=2): fmt = f".{ord_error}f" return f"${{{q_mean:{fmt}}}}_{{-{low_err:{fmt}}}}^{{+{high_err:{fmt}}}}$" else: + q_mean, low_err, high_err =np.around([q_mean, low_err, high_err], ord_error) return f"${{{int(q_mean)}}}_{{-{int(low_err)}}}^{{+{int(high_err)}}}$" diff --git a/nmma/em/io.py b/nmma/em/io.py index b1846b73..83bee8df 100644 --- a/nmma/em/io.py +++ b/nmma/em/io.py @@ -34,6 +34,8 @@ def load_em_observations(filename, args=None, format='observations'): Returns: - data (dict): Dictionary containing the lightcurve data from the file. The keys are generally 'time' and each of the filters in the file as well as their accompanying error values. """ + if isinstance(filename, dict): + return filename # assume it is already in the correct format if isinstance(filename, argparse.Namespace): args = filename filename = args.light_curve_data @@ -78,7 +80,9 @@ def read_lc_from_csv(filename, args, format): data = {} for line in lines: - lineSplit = line.split(" ") + if line.startswith("#") or line.startswith("time") or line.startswith("mjd"): + continue + lineSplit = line.split(None) lineSplit = list(filter(None, lineSplit)) try: mjd = Time(lineSplit[0]).mjd diff --git a/nmma/em/prior.py b/nmma/em/prior.py index f054641a..7e88ab30 100644 --- a/nmma/em/prior.py +++ b/nmma/em/prior.py @@ -229,7 +229,7 @@ def create_prior_from_args(args, systematics_handler): lc_model : nmma.em.model.LightCurveModelContainer Light curve model object to compute light curves """ - priors = PriorDict(args.prior_file) + priors = PriorDict(args.prior_file) if getattr(args, 'prior_file', None) else PriorDict(args.prior) priors = adjust_hubble_prior(priors, args) priors = extinction_prior(priors, args) diff --git a/nmma/em/utils.py b/nmma/em/utils.py index 72cb62ee..dfdaa0e9 100644 --- a/nmma/em/utils.py +++ b/nmma/em/utils.py @@ -77,7 +77,7 @@ def setup_sample_times(args): return np.arange(tmin, tmax + args.em_tstep, args.em_tstep) # otherwise, create the array based on selected scale - if 'lin' in args.em_timescale: + if 'lin' in args.em_timescale or tmin<=0.: return np.linspace(tmin, tmax, args.em_nsteps) elif any(scale in args.em_timescale for scale in ['log', 'geo']): return np.geomspace(tmin, tmax, args.em_nsteps) @@ -198,7 +198,6 @@ def set_filter_associated_dict(quantity, filters, default_limit = np.inf): def cut_data_to_time_range(data, args, trigger_time, tmin = 0, tmax = np.inf): - tmin = getattr(args, "data_tmin", tmin) tmax = getattr(args, "data_tmax", tmax) @@ -251,7 +250,6 @@ def setup_filtered_lc_data(light_curve_data, trigger_time): def check_model_time_consistency(light_curve_data, light_curve_model, priors, injection = None): (lc_times, lc_mags, lc_uncertainties, trigger_time) = light_curve_data - data_tmin = np.min([lc_times[key].min() for key in lc_times.keys()]) data_tmax = np.max([lc_times[key].max() for key in lc_times.keys()]) From 07b237b812d67f5eec01e2dc3aea24fce32e8e11 Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Tue, 10 Mar 2026 14:25:33 +0100 Subject: [PATCH 04/13] ENH: improved sncosmo connection --- nmma/core/base.py | 28 +++++--- nmma/em/analysis.py | 9 ++- nmma/em/em_parsing.py | 5 +- nmma/em/lightcurve_generation.py | 22 +++--- nmma/em/model.py | 115 +++++++++++++++++++++---------- 5 files changed, 114 insertions(+), 65 deletions(-) diff --git a/nmma/core/base.py b/nmma/core/base.py index 942806dc..7eef70cc 100644 --- a/nmma/core/base.py +++ b/nmma/core/base.py @@ -2,6 +2,7 @@ import io import contextlib import h5py +from argparse import Namespace from ast import literal_eval import numpy as np import pandas as pd @@ -15,6 +16,7 @@ from .utils import input_obj_to_str, read_bestfit_from_posterior from .constants import set_cosmology from .conversion import cosmology_to_distance +from .parsing import single_messenger_analysis_parsing, nmma_base_parsing def initialisation_args_from_signature_and_namespace(_callable, namespace, prefixes = []): prefixes.append('') @@ -98,12 +100,15 @@ def final_diagnostics(self, bestfit_params, args, result=None): The figure object containing the plot """ - pass + try: + self.sub_model.final_diagnostics(bestfit_params, args, result) + except AttributeError: + pass def post_process_bestfit(self, args, result=None): bestfit_params = read_bestfit_from_posterior(args) bestfit_params = self.parameter_conversion(bestfit_params) - self.final_diagnostics(bestfit_params, args, result) + return self.final_diagnostics(bestfit_params, args, result) def check_parameter_equivalencies(self, parameter_names): """Check for equivalent parameters and terminate if found""" @@ -272,28 +277,32 @@ def check_priors_and_likelihood_for_nmma(priors, likelihood): constraints = {k: priors.pop(k) for k in priors.copy().keys() if isinstance(priors[k], Constraint)} likelihood.constraints.update(constraints) + test_draw = priors.sample(1) test_conversion = priors.conversion_function(test_draw) if len(set(test_conversion.keys()) ) != len(test_conversion.keys()): priors.conversion_function = priors.default_conversion_function likelihood.conv_functions.append(likelihood.priors.conversion_function) + # add final conversions likelihood.setup_parameter_conversion() return priors, likelihood def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): priors, likelihood = check_priors_and_likelihood_for_nmma(priors, likelihood) - + if isinstance(args, dict): + def_args = nmma_base_parsing(single_messenger_analysis_parsing) + def_args.__dict__.update(args) + args = def_args # fetch the additional sampler kwargs - if isinstance(args.sampler_kwargs, str): - sampler_kwargs = literal_eval(args.sampler_kwargs) - else: - sampler_kwargs = args.sampler_kwargs + sampler_kwargs = getattr(args, "sampler_kwargs", {}) + if isinstance(sampler_kwargs, str): + sampler_kwargs = literal_eval(sampler_kwargs) print("Running with the following additional sampler_kwargs:") print(sampler_kwargs) # check if it is running with reactive sampler - nlive = None if args.reactive_sampling else args.nlive + nlive = None if getattr(args, 'reactive_sampling', False) else args.nlive if nlive is None and args.sampler != "ultranest": raise ValueError("reactive sampling is only available for ultranest, " "please set nlive or use ultranest sampler") @@ -358,7 +367,8 @@ def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): result.plot_corner(injection_parameters, priors) if args.bestfit or args.plot: - likelihood.post_process_bestfit(args, result) + result.posterior = likelihood.posterior_conversion(result.posterior) + return likelihood.post_process_bestfit(args, result) def multi_analysis_loop(args, analysis_setup): diff --git a/nmma/em/analysis.py b/nmma/em/analysis.py index 8778e4d5..1f03f3f7 100644 --- a/nmma/em/analysis.py +++ b/nmma/em/analysis.py @@ -97,7 +97,6 @@ def analysis_setup(args): filters = utils.set_filters(args) detection_limit = utils.create_detection_limit(args, filters) - try: # load observational data data = io.load_em_observations(args, format='observations') @@ -110,7 +109,6 @@ def analysis_setup(args): except FileNotFoundError: # If the injection file is not found, raise an error raise FileNotFoundError("Injection file not found.") - data = utils.cut_data_to_time_range(data, args, trigger_time) data = check_detections(data, args.remove_nondetections) filters_to_analyze = set_analysis_filters(filters, data) @@ -252,7 +250,14 @@ def nnanalysis(args): print('saved posterior plot') def main(args=None): + if isinstance(args, dict): + non_default = args.copy() + args = [] + else: + non_default = {} args = parsing_and_logging(multi_wavelength_analysis_parser, args) + args.__dict__.update(non_default) + if args.sampler == 'neuralnet': nnanalysis(args) else: diff --git a/nmma/em/em_parsing.py b/nmma/em/em_parsing.py index cef22c68..12d1e3f8 100644 --- a/nmma/em/em_parsing.py +++ b/nmma/em/em_parsing.py @@ -38,7 +38,6 @@ def basic_em_only_parsing(parser): parser.add_argument("--verbose", action='store_true', help="print out log likelihoods" ) return parser - def basic_em_only_analysis_parsing(parser): parser = single_messenger_analysis_parsing(parser) @@ -47,8 +46,6 @@ def basic_em_only_analysis_parsing(parser): parser.add_argument("--light-curve-data", "--data", help="Path to data in [time filter magnitude error] format, time format will be inferred, but can be explicitly adjusted with --time-format. If not given, will try to generate data from the injection file.") parser.add_argument("--time-format", help="Time format of the light curve data, e.g. isot, mjd, see https://docs.astropy.org/en/stable/time/#time-format") - parser.add_argument("--bestfit", "--best-fit", action='store_true', - help="Save the best fit parameters and magnitudes to JSON") return parser @@ -73,6 +70,8 @@ def em_model_parsing(parser): help="Name of the model-type to be used, can be a comma-seperated list for joint lightcurve models" ) em_model_parser.add_argument("--em-model", "--kilonova-model","--model", type=yaml_parse, nargs="*", help="Name of the transient model to be used") + em_model_parser.add_argument("--em-model-kwargs", "--model-parameters", type=yaml_parse, + help="Additional keyword arguments for the transient model, given like a python-dict ") em_model_parser.add_argument("--interpolation-type", "--gptype", default="keras", help="Interpolation library to be used for EM "\ "transient model. Default: keras, further options: tensorflow, sklearn_gp, api_gp" ) diff --git a/nmma/em/lightcurve_generation.py b/nmma/em/lightcurve_generation.py index 11e7cc2f..a5e04387 100644 --- a/nmma/em/lightcurve_generation.py +++ b/nmma/em/lightcurve_generation.py @@ -288,20 +288,14 @@ def host_lc(sample_times, parameters, filters, host_mag): ## supernova model def sn_lc(sample_times_stretched, sn_model, filters, lambdas): mag = {} - - for filt, lambda_A in zip(filters, lambdas): - # convert back to AA - lambda_AA = 1e10 * lambda_A - if lambda_AA < sn_model.minwave() or lambda_AA > sn_model.maxwave(): - mag[filt] = np.full_like(sample_times_stretched,np.inf) - else: - try: - flux = sn_model.flux(sample_times_stretched, [lambda_AA])[:, 0] - # see https://en.wikipedia.org/wiki/AB_magnitude - flux_jy = 3.34e4 * np.power(lambda_AA, 2.0) * flux - mag[filt] = utils.flux_to_ABmag(flux_jy, unit='Jy') - except Exception: - return {} + for filt, lambda_ in zip(filters, lambdas): + try: + mag[filt] = sn_model.bandmag(filt, 'ab', sample_times_stretched) + except ValueError: + lambda_AA = 1e10 * lambda_ + flux_AA = sn_model.flux(sample_times_stretched, lambda_AA) + # see https://en.wikipedia.org/wiki/AB_magnitude + mag[filt] = utils.flux_to_ABmag(flux_AA*3.34e4 * lambda_AA**2, unit='Jy') return mag ## shock-cooling lightcurve diff --git a/nmma/em/model.py b/nmma/em/model.py index d5242b48..fe708e12 100644 --- a/nmma/em/model.py +++ b/nmma/em/model.py @@ -351,7 +351,7 @@ def combine_detector_data(self, model_lc, observable_times): # apparent_magnitude = utils.autocomplete_data( # observable_times, observable_times[use_mask], apparent_magnitude[use_mask]) else: #no meaningful inter-/extrapolation possible - apparent_magnitude = np.full_like(self.model_times, np.inf) + apparent_magnitude = np.full_like(observable_times, np.inf) lc_data[filt] = apparent_magnitude return (observable_times, lc_data) @@ -448,6 +448,8 @@ class SimpleBolometricLightCurveModel(LightCurveModelContainer): ---------- model: string, optional Name of the model. Can be either "Arnett" (default) or "Arnett_modified" + em_model_kwargs: optional + Additional keyword arguments, not used in Arnett SN model. Returns ------- @@ -455,7 +457,7 @@ class SimpleBolometricLightCurveModel(LightCurveModelContainer): A light curve model object to evaluate the light curve from a set of parameters """ - def __init__(self, model="Arnett", sample_times=None): + def __init__(self, model="Arnett", sample_times=None, **em_model_kwargs): super().__init__(model, sample_times=sample_times) if model == "Arnett": self.lc_func = lc_gen.arnett_lc @@ -504,7 +506,8 @@ class SVDLightCurveModel(LightCurveModelContainer): List of filters to create model for. Defaults to all available filters. local_only: bool, optional If True, only local models will be used. - + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying emulator. Returns ------- LightCurveModel: `nmma.em.model.SVDLightCurveModel` @@ -523,6 +526,7 @@ def __init__( filters=None, sample_times = None, local_only=False, + **em_model_kwargs ): # Some models have underscores. Keep those, but drop '_tf' if it exists model_name_components = model.split("_") @@ -546,7 +550,7 @@ def __init__( ##FIXME Does this make sense for api_gp, too? filters = self.get_model_data(core_model_name, filters) try: - svd_mag_model = joblib.load(modelfile) + svd_mag_model = joblib.load(modelfile, **em_model_kwargs) # temporary fix, moving towards permament setting of sncosmo filter names self.svd_mag_model = {k.replace("_", ":"): v for k, v in svd_mag_model.items()} self.svd_lbol_model= None # FIXME: this is not yet implemented @@ -667,11 +671,13 @@ class FiestaKilonovaModel(FiestaModel): Parameters ---------- model: str, optional - Name of the model. Default is "Bu2025_MLP". + Name of the model. Default is "Bu2026_MLP". filters: list of str, optional List of filters to create model for. Defaults to all available filters. surrogate_dir: str, optional path to the directory containing the surrogate models. + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying fiesta model. Returns ------- @@ -679,19 +685,15 @@ class FiestaKilonovaModel(FiestaModel): A light curve model object to evaluate the light curve from a set of parameters. """ - def __init__(self, model="Bu2025_MLP", filters=None, surrogate_dir=None, **kwargs): + def __init__(self, model="Bu2026_MLP", filters=None, surrogate_dir=None, **em_model_kwargs): if model.endswith("_lc"): from fiesta.inference.lightcurve_model import BullaLightcurveModel as BullaSurrogate else: from fiesta.inference.lightcurve_model import BullaFlux as BullaSurrogate - fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir,) - try: - fiesta_model = BullaSurrogate(**fiesta_kwargs) - except OSError: - fiesta_kwargs['directory'] = f'{surrogate_dir}/KN/{model}/model' - fiesta_model = BullaSurrogate(**fiesta_kwargs) + fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir,**em_model_kwargs) + fiesta_model = BullaSurrogate(**fiesta_kwargs) - super().__init__(fiesta_model, filters, sample_times=kwargs.get('sample_times', None)) + super().__init__(fiesta_model, filters, sample_times=em_model_kwargs.get('sample_times', None)) class GRBMixin: def __init__(self, *args, resolution=12, **kwargs): @@ -741,6 +743,8 @@ class FiestaGRBModel(GRBMixin,FiestaModel): List of filters to create model for. Defaults to all available filters. surrogate_dir: str, optional path to the directory containing the surrogate models. + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying Fiesta GRB model. Returns ------- @@ -748,16 +752,12 @@ class FiestaGRBModel(GRBMixin,FiestaModel): A light curve model object to evaluate the light curve from a set of parameters. """ - def __init__(self, model="afgpy_gaussian_CVAE", filters=None, surrogate_dir=None, **kwargs): + def __init__(self, model="afgpy_gaussian_CVAE", filters=None, surrogate_dir=None, **em_model_kwargs): from fiesta.inference.lightcurve_model import AfterglowFlux - fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir,) - try: - fiesta_model = AfterglowFlux(**fiesta_kwargs) - except OSError: - fiesta_kwargs['directory'] = f'{surrogate_dir}/GRB/{model}/model' - fiesta_model = AfterglowFlux(**fiesta_kwargs) - - super().__init__(fiesta_model, filters, sample_times=kwargs.get('sample_times', None)) + fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir, **em_model_kwargs) + fiesta_model = AfterglowFlux(**fiesta_kwargs) + + super().__init__(fiesta_model, filters, sample_times=em_model_kwargs.get('sample_times', None)) class GRBLightCurveModel(GRBMixin, LightCurveModelContainer): @@ -778,6 +778,8 @@ class GRBLightCurveModel(GRBMixin, LightCurveModelContainer): Type of jet for the GRB model. Default is 0. filters: list of str, optional List of filters to create model for. Defaults to all available filters. + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying afterglowpy model. Returns ------- @@ -793,10 +795,11 @@ def __init__( jet_type=0, filters=None, sample_times=None, + **em_model_kwargs ): super().__init__(model, filters, model_parameters, sample_times, resolution=resolution) self.jet_type = jet_type - self.default_parameters = {"xi_N": 1.0, "d_L": 3.086e19, "jetType": jet_type, "specType": 0} # d_L=10pc in cm + self.default_parameters = {"xi_N": 1.0, "d_L": 3.086e19, "jetType": jet_type, "specType": 0, **em_model_kwargs} # d_L=10pc in cm self.def_keys = self.default_parameters.keys() #keys we typically sample in log space, but need to convert to linear space self.log_sampling_keys = ["E0", "n0", "epsilon_e", "epsilon_B"] @@ -874,6 +877,8 @@ class HostGalaxyLightCurveModel(LightCurveModelContainer): Magnitude of the host galaxy. Default is 23.9. model_parameters: list, optional List of alternative model parameters. If not specified, default will be used. + em_model_kwargs: optional + Additional keyword arguments, not used, but provided for consistency. Returns ------- @@ -888,7 +893,8 @@ def __init__( sample_times=None, # host_mag is the magnitude of the host galaxy in the filters host_mag=23.9, # value for case of arxiv:2303.12849 - model_parameters=None + model_parameters=None, + **em_model_kwargs ): super().__init__(model, filters, model_parameters, sample_times=sample_times) if isinstance(host_mag, (float, int)): @@ -915,6 +921,8 @@ class SupernovaLightCurveModel(LightCurveModelContainer): List of filters to create model for. Defaults to all available filters. model_parameters: list, optional List of alternative model parameters. If not specified, default will be used. + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying sncosmo model. Returns ------- @@ -927,15 +935,25 @@ def __init__( model="nugent-hyper", filters=None, sample_times=None, - model_parameters=None + model_parameters=None, + **em_model_kwargs ): - self.sn_model = sncosmo.Model(source=model) + if isinstance(model, str): + self.sn_model = sncosmo.Model(source=model, **em_model_kwargs) + else: + self.sn_model = model if model_parameters is not None: print("Warning: model_parameters are ignored for SupernovaLightCurveModel, using sncosmo defaults.") model_parameters = self.sn_model.param_names if sample_times is None: sample_times = np.linspace(self.sn_model.mintime(), self.sn_model.maxtime(), 200) + + if (sample_times < 0).any(): + # NOTE: We assume this means the sncosmo model is relative to peak time. + sample_times += sample_times[0] + print(f"Warning: Some supernova models are relative to the peak, some relative to the explosion time, " + "but nmma always expects times relative to the explosion time. Adjust your t0 prior accordingly." ) super().__init__(model, filters, model_parameters, sample_times) def em_parameter_setup(self, parameters): @@ -943,22 +961,37 @@ def em_parameter_setup(self, parameters): self.sn_model.set(**lc_pars) def combine_lc_params(self, parameters): - # FIXME: This should probably be removed, use sncosmo parameters instead + self.stretch = parameters.get('supernova_mag_stretch', 1.) + parameters["t0"] = parameters.get("t0", 0.) parameters["z"] = self.redshift return {p: parameters.get(p, self.sn_model.get(p)) for p in self.model_parameters} - def generate_lightcurve(self, sample_times, parameters): - self.em_parameter_setup(parameters) - + def gen_detector_lc(self, parameters = None, sample_times=None): + """Generate a light curve for given parameters as observable in detector frame. + Parameters + ---------- + parameters: dict + Parameters of the Supernova model. + sample_times: times at which to explore the light curve. If None, uses the default times for the model.""" - mag = lc_gen.sn_lc(sample_times / self.stretch, self.sn_model, + if sample_times is None: + sample_times = self.model_times + + # convert the parameters to the fiesta model parameters + self.em_parameter_setup(parameters) + mag = lc_gen.sn_lc(sample_times / self.stretch/(1 + self.redshift), self.sn_model, self.default_filts, self.lambdas) - - return {filt: filt_mag for filt, filt_mag in mag.items()} - + + # apply the extinction correction + ext_mag = self.get_extinction_mags() + obs_mags = self.apply_extinction_correction(mag, ext_mag, self.default_filts) + + # we are in observer frame, but still need to add the timeshift + return (sample_times + self.timeshift, obs_mags) + class ShockCoolingLightCurveModel(LightCurveModelContainer): def __init__( @@ -974,6 +1007,10 @@ def __init__( Name of the model. Default is "Piro2021". filters: list of str, optional List of filters to create model for. Defaults to all available filters. + model_parameters: list, optional + List of alternative model parameters. If not specified, default will be used. + em_model_kwargs: optional + Additional keyword arguments, not used, but provided for consistency. Returns ------- @@ -1020,6 +1057,8 @@ class SimpleKilonovaLightCurveModel(LightCurveModelContainer): Name of the model. Default is "Me2017". filters: list of str, optional List of filters to create model for. Defaults to all available filters. + em_model_kwargs: optional + Additional keyword arguments, not used, but provided for consistency. Returns ------- @@ -1028,7 +1067,7 @@ class SimpleKilonovaLightCurveModel(LightCurveModelContainer): from a set of parameters. """ def __init__( - self, model="Me2017", filters=None, sample_times=None + self, model="Me2017", filters=None, sample_times=None, **em_model_kwargs ): super().__init__(model, filters, sample_times=sample_times) lc_dict={ @@ -1311,10 +1350,12 @@ def single_model_from_args(model_class, model_name, args, # update explicit args model_args = default_model_args | dict(filters=filters, - sample_times=utils.setup_sample_times(args), + sample_times=utils.setup_sample_times(args) ) - if model_name is not None: + if model_name: model_args["model"] = model_name.strip() + if args.em_model_kwargs: + model_args|= args.em_model_kwargs return model_class(**model_args) def create_light_curve_model_from_args( From ca3c8f0a87e4f6e63a535cf2095e49312622bf72 Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Tue, 7 Apr 2026 15:11:24 +0200 Subject: [PATCH 05/13] make dynesty-mpi available in single-messenger analysis --- nmma/core/base.py | 82 ++++---- .../analysis_run.py => core/mpi_setup.py} | 187 +++++++++++++----- nmma/core/parsing.py | 61 +++++- nmma/joint/generation.py | 4 +- nmma/joint/injection_handling.py | 4 +- nmma/joint/main.py | 157 ++++++--------- nmma/joint/multi_parsing.py | 69 +------ 7 files changed, 297 insertions(+), 267 deletions(-) rename nmma/{joint/analysis_run.py => core/mpi_setup.py} (83%) diff --git a/nmma/core/base.py b/nmma/core/base.py index 7eef70cc..2c0ca1ec 100644 --- a/nmma/core/base.py +++ b/nmma/core/base.py @@ -1,8 +1,6 @@ import inspect -import io -import contextlib +import os import h5py -from argparse import Namespace from ast import literal_eval import numpy as np import pandas as pd @@ -101,7 +99,7 @@ def final_diagnostics(self, bestfit_params, args, result=None): """ try: - self.sub_model.final_diagnostics(bestfit_params, args, result) + return self.sub_model.final_diagnostics(bestfit_params, args, result) except AttributeError: pass @@ -289,15 +287,12 @@ def check_priors_and_likelihood_for_nmma(priors, likelihood): return priors, likelihood def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): - priors, likelihood = check_priors_and_likelihood_for_nmma(priors, likelihood) if isinstance(args, dict): def_args = nmma_base_parsing(single_messenger_analysis_parsing) def_args.__dict__.update(args) args = def_args # fetch the additional sampler kwargs sampler_kwargs = getattr(args, "sampler_kwargs", {}) - if isinstance(sampler_kwargs, str): - sampler_kwargs = literal_eval(sampler_kwargs) print("Running with the following additional sampler_kwargs:") print(sampler_kwargs) @@ -368,50 +363,51 @@ def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): if args.bestfit or args.plot: result.posterior = likelihood.posterior_conversion(result.posterior) - return likelihood.post_process_bestfit(args, result) + likelihood.post_process_bestfit(args, result) + return result def multi_analysis_loop(args, analysis_setup): - + USE_MPI = False # check if it is running under mpi try: from mpi4py import MPI rank = MPI.COMM_WORLD.Get_rank() + USE_MPI = True except ImportError: rank = 0 - if rank != 0: - # Create buffers - stdout_buffer = io.StringIO() - stderr_buffer = io.StringIO() - - # Redirect python output into buffers - redirect_out = contextlib.redirect_stdout(stdout_buffer) - redirect_err = contextlib.redirect_stderr(stderr_buffer) - else: - redirect_out = contextlib.nullcontext() - redirect_err = contextlib.nullcontext() + # if rank != 0: + # devnull = os.open(os.devnull, os.O_WRONLY) + # os.dup2(devnull, 1) + # os.dup2(devnull, 2) - with redirect_out, redirect_err: - if getattr(args, 'multi', None): - sub_runs = [] - if len(args.multi) == 1: - arg, vals = list(args.multi.items())[0] - for i, val in enumerate(vals): - run_args = deepcopy(args) - setattr(run_args, arg, val) - setattr(run_args, 'label', f"{args.label}_{i}") - sub_runs.append(run_args) - else: - for run_name, changes in args.multi.items(): - run_args = deepcopy(args) - setattr(run_args, 'label', f"{args.label}_{run_name}") - for key, value in changes.items(): - if key not in args: - raise KeyError(f"{key} not a known argument... please remove") - setattr(run_args, key, value) - sub_runs.append(run_args) + if getattr(args, 'multi', None): + sub_runs = [] + if len(args.multi) == 1: + arg, vals = list(args.multi.items())[0] + for i, val in enumerate(vals): + run_args = deepcopy(args) + setattr(run_args, arg, val) + setattr(run_args, 'label', f"{args.label}_{i}") + sub_runs.append(run_args) + else: + for run_name, changes in args.multi.items(): + run_args = deepcopy(args) + setattr(run_args, 'label', f"{args.label}_{run_name}") + for key, value in changes.items(): + if key not in args: + raise KeyError(f"{key} not a known argument... please remove") + setattr(run_args, key, value) + sub_runs.append(run_args) + else: + sub_runs = [args] + for run_args in sub_runs: + priors, likelihood, injection_parameters = analysis_setup(run_args) + priors, likelihood = check_priors_and_likelihood_for_nmma(priors, likelihood) + if USE_MPI and run_args.sampler =='dynesty': + from .mpi_setup import pbilby_sampling + run_function = pbilby_sampling else: - sub_runs = [args] - for run_args in sub_runs: - priors, likelihood, injection_parameters = analysis_setup(run_args) - bilby_sampling(likelihood, priors, run_args, injection_parameters, rank) \ No newline at end of file + run_function = bilby_sampling + out = run_function(likelihood, priors, run_args, injection_parameters, rank) + return out \ No newline at end of file diff --git a/nmma/joint/analysis_run.py b/nmma/core/mpi_setup.py similarity index 83% rename from nmma/joint/analysis_run.py rename to nmma/core/mpi_setup.py index 9f57d954..ec2dc380 100644 --- a/nmma/joint/analysis_run.py +++ b/nmma/core/mpi_setup.py @@ -2,9 +2,9 @@ import sys import traceback from io import BufferedWriter -from glob import glob from copy import deepcopy import pickle +import signal from functools import wraps from time import time from datetime import timedelta @@ -12,16 +12,16 @@ import numpy as np from pandas import DataFrame from numpy.random import Generator, PCG64, SeedSequence +from schwimmbad import MPIPool, MultiPool -from bilby.core.prior import PriorDict from bilby.core.sampler import base_sampler as bs, dynesty3_utils as dy_utils from bilby.core.sampler.dynesty import dynesty_stats_plot import dynesty from dynesty.plotting import traceplot, runplot -from ..core.conversion import label_mapping -from ..core.utils import rejection_sample, read_bestfit_from_posterior, logger -from .joint_likelihood import MultiMessengerLikelihood +from .conversion import label_mapping +from .utils import rejection_sample, read_bestfit_from_posterior, logger +from .parsing import process_sampler_kwargs def time_storage(func): @@ -42,45 +42,25 @@ class Worker(bs.NestedSampler): Parameters: data_dump: a pickle-file containing all relevant data to create priors and likelihoods. """ - def __init__(self, data_dump, - outdir, label, plot = False, - skip_import_verification = True, + def __init__( + self, args, prior, likelihood, + injection_parameters, plot = False, + skip_import_verification = True, ): - ## Load the data dump - if not data_dump.endswith("_dump.pickle"): - test_out = os.path.join(os.getcwd(), data_dump) - test_dump = glob(f"{test_out}/data/*_dump.pickle") - data_dump = test_dump[0] - with open(data_dump, "rb") as file: - data_dump = pickle.load(file) - - ## Set properties from the data dump - self.data_dump = data_dump - args = data_dump["args"] args.plot = plot self.args = args + self.outdir = args.outdir + self.label = args.label - # If the run dir has not been specified, get it from the args - if outdir is None: - outdir = self.args.outdir - os.makedirs(outdir, exist_ok=True) + os.makedirs(args.outdir, exist_ok=True) - # If the label has not been specified, get it from the args - if label is None: - label = self.args.label - priors = PriorDict.from_json(data_dump["prior_file"]) - - ## Set up the likelihood - likelihood = MultiMessengerLikelihood.setup_from_args( - data_dump, priors, self.args, logger) - super().__init__( - likelihood, priors, outdir, label, - injection_parameters= data_dump.get("injection_parameters", None), + likelihood, prior, self.outdir, self.label, + injection_parameters, skip_import_verification = skip_import_verification, - plot=self.args.plot, + plot= plot, soft_init=True, ) @@ -150,7 +130,8 @@ class Dynesty(Worker): def __init__( self, - data_dump, outdir, label, + args, prior, likelihood, + injection_parameters = None, maxmcmc=5000, naccept=60, nact=2, @@ -158,13 +139,12 @@ def __init__( sampler_kwargs={}, sampler_init_kwargs={}, plot= False, - ): - super().__init__(data_dump, outdir, label, plot, - skip_import_verification = False) - # for handler in logger.handlers: - # if isinstance(handler, logging.StreamHandler): - # handler.stream = sys.stdout - + meta_data = {}, + ): + + super().__init__(args, prior, likelihood, injection_parameters, + plot, skip_import_verification = False) + self.resume_file = f"{self.outdir}/{self.label}_checkpoint_resume.pickle" self.samples_file= f'{self.outdir}/{self.label}_samples.parquet' @@ -177,9 +157,9 @@ def __init__( # dynesty3 sampler kwargs self.dlogz = sampler_kwargs['dlogz'] self.sampler_kwargs = sampler_kwargs - self._init_sampler_kwargs(sampler_init_kwargs, nact, naccept, maxmcmc) self.nlive = sampler_init_kwargs['nlive'] + self.meta_data = meta_data def _init_sampler_kwargs(self, kwargs, nact, naccept, maxmcmc): @@ -503,13 +483,7 @@ def plot_current_state(self): plt.close("all") def storable_metadata(self): - meta_data = self.data_dump.copy() - waveform_generator = meta_data.pop("waveform_generator", None) - if waveform_generator is not None: - meta_data["waveform_generator"] = waveform_generator.__repr__() - ifo_list = meta_data.pop("ifo_list", None) - if ifo_list is not None: - meta_data["ifo_list"] = [ifo.__repr__() for ifo in ifo_list] + meta_data = self.meta_data meta_data["args"] = vars(self.args) # convert Namespace to dict for storing meta_data["likelihood"] = self.likelihood.meta_data @@ -604,3 +578,114 @@ def format_result( logger.warning(f"Failed to create diagnostic plots: {e} \n{traceback.format_exc()}") logger.info("Finished formatting result.") return result + + + +def pbilby_sampling( + likelihood, prior, args, + injection_parameters, rank, + pool_type = 'mpi', + meta_data = {}, + **kwargs +): + + default_kwargs = dict( + sampler_kwargs={}, + sampling_seed=42, + plot=True, + # + maxmcmc=5000, + naccept=60, + nact=2, + check_point_delta_t=1800, + n_check_point=2000, + max_its=1e10, + max_run_time=1e10, + checkpoint_plot=False, + # + rejection_sample_posterior=True, + result_format="hdf5", + ) + + # priority: kwargs > args > defaults + use_kwargs = {} + for key in default_kwargs.keys(): + if key in kwargs: + use_kwargs[key] = kwargs[key] + elif hasattr(args, key): + use_kwargs[key] = getattr(args, key) + else: + use_kwargs[key] = default_kwargs[key] + + use_kwargs |= kwargs # in case there are additional kwargs not in default_kwargs + + # Initialise a worker. this needs a global scope to allow + # persistence of states beyond the pool's scope. + # Otherwise emulators retrace on each evaluation. + global worker + if rank == 0: + sampler_init_kwargs, run_kwargs = process_sampler_kwargs( + kwargs.pop('sampler_kwargs', {}), use_kwargs) + + worker = Dynesty( + args, prior, likelihood, + injection_parameters, + maxmcmc=use_kwargs['maxmcmc'], + nact=use_kwargs['nact'], + naccept=use_kwargs['naccept'], + sampling_seed=use_kwargs['sampling_seed'], + sampler_kwargs = run_kwargs, + sampler_init_kwargs=sampler_init_kwargs, + plot=use_kwargs['plot'], + meta_data=meta_data, + ) + + else: + worker = Worker(args, prior, likelihood, + injection_parameters, plot = use_kwargs['plot']) + + ## graceful handling of preemptive shutdowns + def handle_sigterm(signum, frame): + try: + worker.checkpointing(False, + 'Received termination signal. Checkpointing and exiting gracefully.') + sys.exit() + except Exception: + pass + + signal.signal(signal.SIGTERM, handle_sigterm) + signal.signal(signal.SIGINT , handle_sigterm) + signal.signal(signal.SIGUSR1, handle_sigterm) + + POOL = MPIPool if pool_type == 'mpi' else MultiPool + with POOL() as pool: + result = None + if pool.is_master(): + worker.start_sampler( + pool, + pooled_log_likelihood, + pooled_prior_transform, + pooled_initial_point_from_prior) + + results = worker.run_sampler( + check_point_delta_t=use_kwargs['check_point_delta_t'], + n_check_point=use_kwargs['n_check_point'], + max_its=use_kwargs['max_its'], + max_run_time=use_kwargs['max_run_time'], + checkpoint_plot=use_kwargs['checkpoint_plot'] + ) + result = worker.format_result( + results, use_kwargs['result_format'], + use_kwargs['rejection_sample_posterior']) + return result + + +# Worker functions. These are read in the global scope by each worker +def pooled_initial_point_from_prior(args): + return worker.get_initial_point_from_prior(args) + +def pooled_log_likelihood(v_array): + return worker.log_likelihood(v_array) + +def pooled_prior_transform(u_array): + return worker.prior_transform(u_array) diff --git a/nmma/core/parsing.py b/nmma/core/parsing.py index 484f58f4..467ee971 100644 --- a/nmma/core/parsing.py +++ b/nmma/core/parsing.py @@ -108,10 +108,45 @@ def base_analysis_parsing(parser): parser.add_argument("--cosmology", help="Name of the cosmology to be used, see astropy.cosmology for available cosmologies (implicit default: Planck18)") parser.add_argument("--sampling-seed","--seed", type=int, default=42, help="Sampling seed (default: 42)") + parser.add_argument("--sampler-kwargs", default="{}", type = yaml_parse, + help="Additional keyword arguments to pass to the sampler as a dictionary" ) return parser +def dynesty_parsing(parser): + dynesty_group = parser.add_argument_group(title="Dynesty Settings") + dynesty_group.add_argument("--n-check-point", default=1000, type=int, + help="Steps to take before attempting checkpoint") + dynesty_group.add_argument("--max-its", default=10**10, type=int, + help="Maximum number of iterations to sample for (default=1.e10)") + dynesty_group.add_argument("--max-run-time", default=1.0e10, type=float, + help="Maximum time to run for (default=1.e10 s)") + dynesty_group.add_argument("--rejection-sample-posterior", action='store_false', help=( + "Whether to generate the posterior samples by rejection sampling the " + "nested samples or resampling with replacement" ) ) + dynesty_group.add_argument( "--walks", default=100, type=int, + help="Minimum number of walks, defaults to 100" ) + dynesty_group.add_argument( "--proposals", action="append", + help="The jump proposals to use, the options are 'diff' and 'volumetric'" ) + dynesty_group.add_argument("--maxmcmc", default=5000, type=int, + help="Maximum number of walks, defaults to 5000" ) + dynesty_group.add_argument( "--nact", default=2, type=int, + help="Number of autocorrelation times to take, defaults to 2") + dynesty_group.add_argument("--naccept", default=60, type=int, + help="The average number of accepted steps per MCMC chain, defaults to 60") + dynesty_group.add_argument("--min-eff", default=10, type=float, + help="The minimum efficiency at which to switch from uniform sampling.") + + + dynesty_group.add_argument("--facc", default=0.5, type=float, + help="See dynesty.NestedSampler") + dynesty_group.add_argument("--enlarge", default=1.5, type=float, + help="See dynesty.NestedSampler") + return parser + + def single_messenger_analysis_parsing(parser): parser = base_analysis_parsing(parser) + parser = dynesty_parsing(parser) parser.add_argument("--config", help="Name of the configuration file containing parameter values.") parser.add_argument("-o", "--outdir", default="outdir", help="Path to the output directory") @@ -121,8 +156,8 @@ def single_messenger_analysis_parsing(parser): parser.add_argument("--prior-file","--prior", help="Path to the prior file") parser.add_argument("--sampler", default="pymultinest", help="Sampler to be used (default: pymultinest)") - parser.add_argument("--sampler-kwargs", default="{}", - help="Additional kwargs (e.g. {'evidence_tolerance':0.5}) for bilby.run_sampler, put a double quotation marks around the dictionary") + parser.add_argument("--dlogz", default=0.1, type=float, + help="Stopping criteria: remaining evidence, (default=0.1)" ) parser.add_argument("--soft-init", action='store_true', help="To start the sampler softly (without any checking, default: False)") parser.add_argument("--cpus", type=int, default=1, @@ -137,6 +172,7 @@ def single_messenger_analysis_parsing(parser): return parser + def base_injection_parsing(parser): parser.add_argument("-f", "--injection-file","--filename", default="injection", help="Path to the file with injection parameters, default: 'injection'.") @@ -231,6 +267,27 @@ def slurm_analysis_parser(parser): return parser +def process_sampler_kwargs(sampler_kwargs, kwargs): + # Set defaults here to avoid inconsistent values + default_kwargs = dict(dlogz=0.1, save_bounds=False, + min_eff=10, sample="acceptance-walk", nlive=1000, bound="live", + walks=100, facc=0.5, enlarge=1.5) + + run_sampler_kwargs = {key: kwargs.get(key, default_kwargs[key]) + for key in ['dlogz', 'save_bounds']} + + + def_init_kwargs = {key: kwargs.get(key, default_kwargs[key]) + for key in ['min_eff','sample', 'nlive', 'bound', 'walks', 'facc', 'enlarge']} + + sampler_init_kwargs = def_init_kwargs | sampler_kwargs + sampler_init_kwargs['first_update'] = dict(min_eff=sampler_init_kwargs.pop('min_eff'), + min_ncall= 2 * sampler_init_kwargs['nlive']) + + return sampler_init_kwargs, run_sampler_kwargs + + + ############# UTILS ############# def process_multi_condition_string(multi_condition_string): # Supported operators diff --git a/nmma/joint/generation.py b/nmma/joint/generation.py index 270d63bb..55152047 100644 --- a/nmma/joint/generation.py +++ b/nmma/joint/generation.py @@ -209,9 +209,9 @@ def __init__(self, args, unknown_args, logger=None): self.adjust_priors_and_data(args, logger) #test-build likelihood - lhood = MultiMessengerLikelihood.setup_from_args( + self.lhood = MultiMessengerLikelihood.setup_from_args( self.data_dump, self._priors, self.args, logger) - lhood.log_likelihood(self._priors.sample()) + self.lhood.log_likelihood(self._priors.sample()) self.save_data_dump() diff --git a/nmma/joint/injection_handling.py b/nmma/joint/injection_handling.py index 1fb58ceb..9e3f9fcc 100644 --- a/nmma/joint/injection_handling.py +++ b/nmma/joint/injection_handling.py @@ -108,7 +108,9 @@ def setup_post_processing(self, args): postprocess_methods.append(self.prepare_lightcurves) if not postprocess_methods: - postprocess_methods.append(lambda df: df) # No-op if no postprocessing is needed + def dummy_postprocess(df): + return df + postprocess_methods.append(dummy_postprocess) # No-op if no postprocessing is needed self.postprocessing = postprocess_methods def generate_injection_file(self): diff --git a/nmma/joint/main.py b/nmma/joint/main.py index c64fc181..5239b6fc 100644 --- a/nmma/joint/main.py +++ b/nmma/joint/main.py @@ -1,127 +1,80 @@ -""" -Module to run parallel bilby using MPI -""" - import os os.environ["OMP_NUM_THREADS"] = "1" os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" -import signal -import io -import contextlib +from glob import glob +import pickle + try: from mpi4py import MPI rank = MPI.COMM_WORLD.Get_rank() except ImportError: rank = 0 -if rank != 0: - # Create buffers - stdout_buffer = io.StringIO() - stderr_buffer = io.StringIO() - - # Redirect python output into buffers - redirect_out = contextlib.redirect_stdout(stdout_buffer) - redirect_err = contextlib.redirect_stderr(stderr_buffer) -else: - redirect_out = contextlib.nullcontext() - redirect_err = contextlib.nullcontext() - - -with redirect_out, redirect_err: - from schwimmbad import MPIPool, MultiPool - from .multi_parsing import create_nmma_analysis_parser, parse_analysis_args, process_sampler_kwargs - from .analysis_run import Dynesty, Worker +if rank != 0: + devnull = os.open(os.devnull, os.O_WRONLY) + os.dup2(devnull, 1) + os.dup2(devnull, 2) + +from bilby.core.prior import PriorDict +from ..core.mpi_setup import pbilby_sampling +from .multi_parsing import create_nmma_analysis_parser, parse_analysis_args +from .joint_likelihood import MultiMessengerLikelihood +from ..core.utils import logger def analysis_runner( data_dump, outdir=None, label=None, - maxmcmc=5000, - naccept=60, - nact=2, - init_sampler_kwargs={}, - run_sampler_kwargs={}, - sampling_seed=42, - plot=True, - # - check_point_delta_t=1800, - n_check_point=2000, - max_its=1e10, - max_run_time=1e10, - checkpoint_plot=False, - # - rejection_sample_posterior=True, - result_format="hdf5", - pool_type ='mpi', + plot=False, **kwargs, ): """ API for running the analysis from Python instead of the command line. It takes all the same options as the CLI, specified as keyword arguments. """ - # Initialise a worker. this needs a global scope to allow - # persistence of states beyond the pool's scope. - # Otherwise emulators retrace on each evaluation. - global worker - with redirect_out, redirect_err: - if rank == 0: - init_sampler_kwargs, run_sampler_kwargs = process_sampler_kwargs( - init_sampler_kwargs, run_sampler_kwargs, kwargs) - - worker = Dynesty( - data_dump, outdir, label, - maxmcmc=maxmcmc, - nact=nact, - naccept=naccept, - sampling_seed=sampling_seed, - sampler_kwargs = run_sampler_kwargs, - sampler_init_kwargs=init_sampler_kwargs, - plot=plot, - ) - - else: - worker = Worker(data_dump, outdir, label) - - ## graceful handling of preemptive shutdowns - def handle_sigterm(signum, frame): - try: - worker.checkpointing(False, - 'Received termination signal. Checkpointing and exiting gracefully.') - except Exception: - pass - - signal.signal(signal.SIGTERM, handle_sigterm) - signal.signal(signal.SIGINT , handle_sigterm) - signal.signal(signal.SIGUSR1, handle_sigterm) - POOL = MPIPool if pool_type == 'mpi' else MultiPool - with POOL() as pool: - result = None - if pool.is_master(): - worker.start_sampler( - pool, - pooled_log_likelihood, - pooled_prior_transform, - pooled_initial_point_from_prior) - - results = worker.run_sampler( - check_point_delta_t, n_check_point, max_its, - max_run_time, checkpoint_plot) - result = worker.format_result(results, result_format, - rejection_sample_posterior) - return result - - -# Worker functions. These are read in the global scope by each worker -def pooled_initial_point_from_prior(args): - return worker.get_initial_point_from_prior(args) - -def pooled_log_likelihood(v_array): - return worker.log_likelihood(v_array) + ## Load the data dump + if not data_dump.endswith("_dump.pickle"): + test_out = os.path.join(os.getcwd(), data_dump) + test_dump = glob(f"{test_out}/data/*_dump.pickle") + data_dump = test_dump[0] + with open(data_dump, "rb") as file: + data_dump = pickle.load(file) + + ## Set properties from the data dump + args = data_dump["args"] + args.plot = plot + + # If the run dir has not been specified, get it from the args + if outdir: + args.outdir = outdir + + # If the label has not been specified, get it from the args + if label: + args.label = label + + priors = PriorDict.from_json(data_dump["prior_file"]) + + ## Set up the likelihood + likelihood = MultiMessengerLikelihood.setup_from_args( + data_dump, priors, args, logger) + + ## adjust meta data to storable format + meta_data = data_dump.copy() + waveform_generator = meta_data.pop("waveform_generator", None) + if waveform_generator is not None: + meta_data["waveform_generator"] = waveform_generator.__repr__() + ifo_list = meta_data.pop("ifo_list", None) + if ifo_list is not None: + meta_data["ifo_list"] = [ifo.__repr__() for ifo in ifo_list] -def pooled_prior_transform(u_array): - return worker.prior_transform(u_array) + + pbilby_sampling( + likelihood, priors, args, + data_dump.get("injection_parameters", None), rank, + plot=plot, meta_data=meta_data, + **kwargs) def nmma_analysis(): """ @@ -136,4 +89,4 @@ def nmma_analysis(): # Run the analysis analysis_runner(**vars(input_args)) - + diff --git a/nmma/joint/multi_parsing.py b/nmma/joint/multi_parsing.py index c57827b9..d27dc74d 100644 --- a/nmma/joint/multi_parsing.py +++ b/nmma/joint/multi_parsing.py @@ -3,7 +3,7 @@ import bilby from bilby_pipe import parser as bp_parser -from nmma.core.parsing import base_analysis_parsing, check_for_config, nonestr +from nmma.core.parsing import base_analysis_parsing, dynesty_parsing, check_for_config, nonestr from nmma.joint.joint_parsing import joint_likelihood_parsing from nmma.em.em_parsing import em_analysis_parsing from nmma.eos.eos_parsing import eos_parsing, tabulated_eos_parsing @@ -11,7 +11,6 @@ from .. import __version__ # noqa: E402 -from numpy import inf logger = bilby.core.utils.logger def _create_base_nmma_parser(sampler="dynesty", parents=[]): @@ -23,7 +22,7 @@ def _create_base_nmma_parser(sampler="dynesty", parents=[]): ) if sampler in ["all", "dynesty"]: - base_parser = sampler_parsing(base_parser) + base_parser = multi_sampler_parsing(base_parser) base_parser = dynesty_parsing(base_parser) base_parser = em_settings_parsing(base_parser) @@ -46,56 +45,20 @@ def em_settings_parsing(parser): return parser -def sampler_parsing(parser): +def multi_sampler_parsing(parser): sampler_group = parser.add_argument_group(title = "Setting for the Sampler") - sampler_group.add_argument("--init-sampler-kwargs", default="{}", - help="Additional keyword arguments to pass to the sampler as a dictionary" ) sampler_group.add_argument("--sampler", choices=["dynesty"], default="dynesty", help="The parallelised sampler to use, defaults to dynesty") sampler_group.add_argument( "-n", "--nlive", default=1000, type=int, help="Number of live points" ) sampler_group.add_argument("--dlogz", default=0.1, type=float, help="Stopping criteria: remaining evidence, (default=0.1)" ) - sampler_group.add_argument("--n-effective", default=inf, type=float, - help="Stopping criteria: effective number of samples, (default=inf)" ) sampler_group.add_argument("--bound","--dynesty-bound", default="live", help="Dynesty bounding method (default=live)" ) sampler_group.add_argument( "--sample", "--dynesty-sample", default="acceptance-walk", help="sampling method (default=acceptance-walk).") return parser -def dynesty_parsing(parser): - dynesty_group = parser.add_argument_group(title="Dynesty Settings") - dynesty_group.add_argument("--n-check-point", default=1000, type=int, - help="Steps to take before attempting checkpoint") - dynesty_group.add_argument("--max-its", default=10**10, type=int, - help="Maximum number of iterations to sample for (default=1.e10)") - dynesty_group.add_argument("--max-run-time", default=1.0e10, type=float, - help="Maximum time to run for (default=1.e10 s)") - dynesty_group.add_argument("--rejection-sample-posterior", action='store_false', help=( - "Whether to generate the posterior samples by rejection sampling the " - "nested samples or resampling with replacement" ) ) - dynesty_group.add_argument( "--walks", default=100, type=int, - help="Minimum number of walks, defaults to 100" ) - dynesty_group.add_argument( "--proposals", action="append", - help="The jump proposals to use, the options are 'diff' and 'volumetric'" ) - dynesty_group.add_argument("--maxmcmc", default=5000, type=int, - help="Maximum number of walks, defaults to 5000" ) - dynesty_group.add_argument( "--nact", default=2, type=int, - help="Number of autocorrelation times to take, defaults to 2") - dynesty_group.add_argument("--naccept", default=60, type=int, - help="The average number of accepted steps per MCMC chain, defaults to 60") - dynesty_group.add_argument("--min-eff", default=10, type=float, - help="The minimum efficiency at which to switch from uniform sampling.") - - - dynesty_group.add_argument("--facc", default=0.5, type=float, - help="See dynesty.NestedSampler") - dynesty_group.add_argument("--enlarge", default=1.5, type=float, - help="See dynesty.NestedSampler") - return parser - - def add_misc_settings(parser): misc_group = parser.add_argument_group(title="Misc. Settings") misc_group.add_argument("-c", "--clean", action='store_true', help="Run clean: ignore any resume files") @@ -189,8 +152,6 @@ def _create_reduced_bilby_pipe_parser(): return bilby_pipe_parser - - def create_nmma_generation_parser(): """Parser for nmma_generation""" bilby_pipe_parser = _create_reduced_bilby_pipe_parser() @@ -230,30 +191,6 @@ def parse_generation_args(cli_args=[""]): args = generation_parser.parse_args(args=cli_args) return args, generation_parser -def process_sampler_kwargs(init_sampler_kwargs, run_sampler_kwargs, kwargs): - # Set defaults here to avoid inconsistent values between main.py and MainRun - default_kwargs = dict(dlogz=0.1, save_bounds=False, - min_eff=10, sample="acceptance-walk", nlive=1000, bound="live", - walks=100, facc=0.5, enlarge=1.5) - - def_run_kwargs = {key: kwargs.get(key, default_kwargs[key]) - for key in ['dlogz', 'save_bounds']} - if isinstance(run_sampler_kwargs, str): - run_sampler_kwargs = literal_eval(run_sampler_kwargs) - run_sampler_kwargs = def_run_kwargs | run_sampler_kwargs - - - def_init_kwargs = {key: kwargs.get(key, default_kwargs[key]) - for key in ['min_eff','sample', 'nlive', 'bound', 'walks', 'facc', 'enlarge']} - if isinstance(init_sampler_kwargs, str): - init_sampler_kwargs = literal_eval(init_sampler_kwargs) - init_sampler_kwargs = def_init_kwargs | init_sampler_kwargs - init_sampler_kwargs['first_update'] = dict(min_eff=init_sampler_kwargs.pop('min_eff'), - min_ncall= 2 * init_sampler_kwargs['nlive']) - - return init_sampler_kwargs, run_sampler_kwargs - - def create_nmma_analysis_parser(sampler="dynesty"): """Parser for nmma_analysis""" parser = _create_base_nmma_parser(sampler=sampler) From b456dc0f13e3c2dc2b87d85c9dda111353c19fff Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Tue, 7 Apr 2026 15:31:05 +0200 Subject: [PATCH 06/13] various bug fixes --- nmma/core/utils.py | 7 +-- nmma/em/analysis.py | 7 +-- nmma/em/em_likelihood.py | 6 +-- nmma/em/em_parsing.py | 3 +- nmma/em/lightcurve_generation.py | 27 +++++++---- nmma/em/model.py | 17 ++++++- nmma/eos/eos_parsing.py | 2 +- nmma/post_processing/hubble_estimates.py | 2 +- .../maximum_mass_constraint.py | 45 +++---------------- 9 files changed, 56 insertions(+), 60 deletions(-) diff --git a/nmma/core/utils.py b/nmma/core/utils.py index be7c5a85..c5f3c913 100644 --- a/nmma/core/utils.py +++ b/nmma/core/utils.py @@ -151,19 +151,20 @@ def set_filename(basename, args, identifier=''): return f"{base}{identifier}{ext}" -def read_bestfit_from_posterior(args, mode = 'max_likelihood'): +def read_bestfit_from_posterior(args, mode = 'max_likelihood', return_posterior=False): posterior_samples = get_posteriors(args) if mode == 'max_likelihood': bestfit = posterior_samples.loc[posterior_samples.log_likelihood.idxmax()] elif mode == 'max_posterior': - bestfit = posterior_samples.loc[(posterior_samples.log_likelihood*posterior_samples.log_prior).idxmax()] + bestfit = posterior_samples.loc[(posterior_samples.log_likelihood + posterior_samples.log_prior).idxmax()] else: raise ValueError(f"Mode {mode} not recognized. Use 'max_likelihood' or 'max_posterior'.") bestfit_params = bestfit.to_dict() bestfit_idx = bestfit.name print(f"Best fit parameters: {str(bestfit_params)}\nBest fit index: {bestfit_idx}") bestfit_params["best_fit_index"] = int(bestfit_idx) - return bestfit_params + + return (bestfit_params, posterior_samples) if return_posterior else bestfit_params def read_bestfit_from_json(bestfit_file_json, cols, verbose=False): df = pd.read_json(bestfit_file_json, typ="series") diff --git a/nmma/em/analysis.py b/nmma/em/analysis.py index 1f03f3f7..3a2862ee 100644 --- a/nmma/em/analysis.py +++ b/nmma/em/analysis.py @@ -52,7 +52,7 @@ def check_detections(data, remove_nondetections=False): def set_analysis_filters(filters, data): if filters is None: - filters = list(data.keys()) + return list(data.keys()) filters_to_analyze = [filt for filt in data.keys() if filt in filters] print(f"Running with filters {filters_to_analyze}") @@ -94,15 +94,15 @@ def bolometric_setup(args): return priors, likelihood, injection_parameters def analysis_setup(args): - + filters = utils.set_filters(args) - detection_limit = utils.create_detection_limit(args, filters) try: # load observational data data = io.load_em_observations(args, format='observations') trigger_time = read_trigger_time(None,args) injection_parameters = None except ValueError: + detection_limit = utils.create_detection_limit(args, filters) # try to work with injection data instead data, injection_parameters = data_from_injection(args, filters, detection_limit) trigger_time = injection_parameters.get('trigger_time',0) @@ -112,6 +112,7 @@ def analysis_setup(args): data = utils.cut_data_to_time_range(data, args, trigger_time) data = check_detections(data, args.remove_nondetections) filters_to_analyze = set_analysis_filters(filters, data) + detection_limit = utils.create_detection_limit(args, filters_to_analyze) # initialize light curve model print("Creating light curve model for inference") diff --git a/nmma/em/em_likelihood.py b/nmma/em/em_likelihood.py index e9a56cf1..a0e4c1f9 100644 --- a/nmma/em/em_likelihood.py +++ b/nmma/em/em_likelihood.py @@ -119,7 +119,7 @@ def final_diagnostics(self, bestfit_params, args, result=None): The figure object containing the plot """ - self.sub_model.final_diagnostics(bestfit_params, args, result) + return self.sub_model.final_diagnostics(bestfit_params, args, result) def posterior_conversion(self, posterior_samples): if 'log10_mej_dyn' in posterior_samples and 'log10_mej_wind' in posterior_samples: @@ -259,7 +259,7 @@ def final_diagnostics(self, bestfit_params, args, result=None): if result is None: save_path = f'{args.outdir}/{args.label}_bol_lightcurve.png' save_path = f'{result.outdir}/{result.label}_bol_lightcurve.png' - bolometric_lc_plot(self, obs_times, obs_lc, save_path = save_path) + return bolometric_lc_plot(self, obs_times, obs_lc, save_path = save_path) class MultiFilterTransient(BasicEMTransient): @@ -351,4 +351,4 @@ def band_log_likelihood(self, expected_mags, obs_error): return minus_chisquare_total + gaussprob_total def final_diagnostics(self, bestfit_params, args, result=None): - lch_bestfit(self, bestfit_params, args, result) \ No newline at end of file + return lch_bestfit(self, bestfit_params, args, result) \ No newline at end of file diff --git a/nmma/em/em_parsing.py b/nmma/em/em_parsing.py index 12d1e3f8..cb722e0b 100644 --- a/nmma/em/em_parsing.py +++ b/nmma/em/em_parsing.py @@ -232,7 +232,8 @@ def modified_em_prior_parsing(parser): mod_em_prior_parser.add_argument("--em-error-budget", "--kilonova-error", help="Additional statistical error (mag) to be introduced in each filter," \ " can be passed as list or dict. Will only be used if em_syserr is not given in prior") - mod_em_prior_parser.add_argument("--systematics-file", help="Path to systematics configuration file") + mod_em_prior_parser.add_argument("--systematics-file", type=yaml_parse, + help="Path to systematics configuration file") return parser def em_analysis_parsing(parser): diff --git a/nmma/em/lightcurve_generation.py b/nmma/em/lightcurve_generation.py index a5e04387..bd6983e9 100644 --- a/nmma/em/lightcurve_generation.py +++ b/nmma/em/lightcurve_generation.py @@ -36,12 +36,15 @@ def inner(func): ################################################################# +def dummy_add(nu): + return 0.0 + def bb_flux_from_inv_temp(nu, inv_temp, R_photo, dist_squared = abs_mag_dist_factor): exponent = np.clip(h * nu * inv_temp / kb, None, 700) # to avoid overflow in exp bb_factor = 2.* h/ c_cgs**2 return bb_factor * nu**3 /np.expm1(exponent) * R_photo * R_photo / dist_squared -def mag_dict_for_blackbody(filters, inv_temp, R_photo, nu_host, add = lambda x: 0.): +def mag_dict_for_blackbody(filters, inv_temp, R_photo, nu_host, add = dummy_add): mag = {} # nu_host = nu_obs * (1 + redshift) for idx, filt in enumerate(filters): @@ -57,10 +60,11 @@ def mag_dict_for_blackbody(filters, inv_temp, R_photo, nu_host, add = lambda x: ################################################################# ######################### LC MODELS ############################# ################################################################# - ## Arnett model convenience functions def arnett_lc_get_int_A_non_vec(x, y): - r = quad(lambda z: 2 * z * np.exp(-2 * z * y + z**2), 0, x) + def arnett_func(z): + return 2 * z * np.exp(-2 * z * y + z**2) + r = quad(arnett_func, 0, x) return r[0] @@ -68,7 +72,9 @@ def arnett_lc_get_int_A_non_vec(x, y): def arnett_lc_get_int_B_non_vec(x, y, s): - r = quad(lambda z: 2 * z * np.exp(-2 * z * y + 2 * z * s + z**2), 0, x) + def arnett_func(z): + return 2 * z * np.exp(-2 * z * y + 2 * z * s + z**2) + r = quad(arnett_func, 0, x) return r[0] @@ -235,13 +241,14 @@ def flux_density_on_E0_array(default_time, obs_frequencies, param_dict): time_scale = np.log10(default_time / t_end) log10_E0[mask] = log10_Eend + energy_exponential * time_scale[mask] E0 = 10 ** log10_E0 - vec_func = np.vectorize( - lambda i: fluxDensity( + def helper(i): + return fluxDensity( default_time[i], obs_frequencies, E0=E0[i], **param_dict - ), + ) + vec_func = np.vectorize(helper, otypes=[np.ndarray] ) mJys = vec_func(np.arange(len(default_time))) @@ -293,7 +300,11 @@ def sn_lc(sample_times_stretched, sn_model, filters, lambdas): mag[filt] = sn_model.bandmag(filt, 'ab', sample_times_stretched) except ValueError: lambda_AA = 1e10 * lambda_ - flux_AA = sn_model.flux(sample_times_stretched, lambda_AA) + if lambda_AA < sn_model.minwave() or lambda_AA > sn_model.maxwave(): + mag[filt] = np.full_like(sample_times_stretched, np.inf) + continue + #NOTE: workaround for potential bug in sncosmo: buffer error if lambdaa as float + flux_AA = sn_model.flux(sample_times_stretched, [lambda_AA]).flatten() # see https://en.wikipedia.org/wiki/AB_magnitude mag[filt] = utils.flux_to_ABmag(flux_AA*3.34e4 * lambda_AA**2, unit='Jy') return mag diff --git a/nmma/em/model.py b/nmma/em/model.py index fe708e12..c334656c 100644 --- a/nmma/em/model.py +++ b/nmma/em/model.py @@ -690,11 +690,24 @@ def __init__(self, model="Bu2026_MLP", filters=None, surrogate_dir=None, **em_mo from fiesta.inference.lightcurve_model import BullaLightcurveModel as BullaSurrogate else: from fiesta.inference.lightcurve_model import BullaFlux as BullaSurrogate - fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir,**em_model_kwargs) + fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir) fiesta_model = BullaSurrogate(**fiesta_kwargs) super().__init__(fiesta_model, filters, sample_times=em_model_kwargs.get('sample_times', None)) + def parameter_conversion(self, parameters): + if "kappa_Ye" in parameters: + if "Ye_wind" in parameters: + parameters["Ye_dyn"] = parameters["kappa_Ye"] * parameters["Ye_wind"] + else: + parameters["Ye_wind"] = parameters["Ye_dyn"] / parameters["kappa_Ye"] + if "kappa_v" in parameters: + if "v_ej_wind" in parameters: + parameters["v_ej_dyn"] = parameters["kappa_v"] * parameters["v_ej_wind"] + else: + parameters["v_ej_wind"] = parameters["v_ej_dyn"] / parameters["kappa_v"] + return super().parameter_conversion(parameters) + class GRBMixin: def __init__(self, *args, resolution=12, **kwargs): self.resolution = resolution @@ -754,7 +767,7 @@ class FiestaGRBModel(GRBMixin,FiestaModel): """ def __init__(self, model="afgpy_gaussian_CVAE", filters=None, surrogate_dir=None, **em_model_kwargs): from fiesta.inference.lightcurve_model import AfterglowFlux - fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir, **em_model_kwargs) + fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir) fiesta_model = AfterglowFlux(**fiesta_kwargs) super().__init__(fiesta_model, filters, sample_times=em_model_kwargs.get('sample_times', None)) diff --git a/nmma/eos/eos_parsing.py b/nmma/eos/eos_parsing.py index 8a2caf45..cee07a4e 100644 --- a/nmma/eos/eos_parsing.py +++ b/nmma/eos/eos_parsing.py @@ -56,7 +56,7 @@ def eos_parsing(parser): help= "dict with additional mass-radius constraints to consider, using style: {'name':{'file_path':path_to_R_M_posterior,[, 'arxiv':'arxiv_id']},...}") eos_input_parser.add( "--mass-radius-name", nargs ="*", help= "list of identifiers for further mass-radius-posteriors to consider") - eos_input_parser.add("--mass-radius-posterior", "--mass-radius-file-path",nargs ="*", + eos_input_parser.add("--mass-radius-file-path", "--mass-radius-posterior", nargs ="*", help= "list of files with additional radius-mass posteriors to consider") eos_input_parser.add( "--mass-radius-arxiv", nargs ="*", help= "list of arxiv-ids for additional R-M posteriors to consider") diff --git a/nmma/post_processing/hubble_estimates.py b/nmma/post_processing/hubble_estimates.py index b1652d74..7b279a98 100644 --- a/nmma/post_processing/hubble_estimates.py +++ b/nmma/post_processing/hubble_estimates.py @@ -21,7 +21,7 @@ def generate_logprob(probs, H0sample, index): for idx, i in enumerate(index): logprob_combined+= probs[i].logpdf(H0sample) if idx!=0: - logprob_combined+= + 3 * np.log(H0sample) + logprob_combined+=3 * np.log(H0sample) logprob_combined-= scipy.special.logsumexp(logprob_combined) log_prob_list.append(logprob_combined) diff --git a/nmma/post_processing/maximum_mass_constraint.py b/nmma/post_processing/maximum_mass_constraint.py index 5b7a7fb4..21c12b7f 100644 --- a/nmma/post_processing/maximum_mass_constraint.py +++ b/nmma/post_processing/maximum_mass_constraint.py @@ -19,39 +19,6 @@ from ..core.constants import MeV_per_fm3_to_Msun_per_km3, geom_msun_km, particle_mass -def fileno(file_or_fd): - fd = getattr(file_or_fd, 'fileno', lambda: file_or_fd)() - if not isinstance(fd, int): - raise ValueError("Expected a file (`.fileno()`) or a file descriptor") - return fd - -@contextlib.contextmanager -def stdout_redirected(to=os.devnull, stdout=None): - """ - https://stackoverflow.com/a/22434262/190597 (J.F. Sebastian) - """ - if stdout is None: - stdout = sys.stdout - - stdout_fd = fileno(stdout) - # copy stdout_fd before it is overwritten - #NOTE: `copied` is inheritable on Windows when duplicating a standard stream - with os.fdopen(os.dup(stdout_fd), 'wb') as copied: - stdout.flush() # flush library buffers that dup2 knows nothing about - try: - os.dup2(fileno(to), stdout_fd) # $ exec >&to - except ValueError: # filename - with open(to, 'wb') as to_file: - os.dup2(to_file.fileno(), stdout_fd) # $ exec > to - try: - yield stdout # allow code to be run with the redirected stdout - finally: - # restore stdout to its previous value - #NOTE: dup2 makes stdout_fd inheritable unconditionally - stdout.flush() - os.dup2(copied.fileno(), stdout_fd) # $ exec >&copied - - def baryonic_mass(gravitational_mass, EOS, eos_path_macro, eos_path_micro): @@ -76,8 +43,7 @@ def TOVeq(y, x): x = np.arange(dr, r+dr, dr) y0 = [p0, m0] - with stdout_redirected(): - p_solv, m_solv = scipy.integrate.odeint(TOVeq, y0=y0, t = x).T + p_solv, m_solv = scipy.integrate.odeint(TOVeq, y0=y0, t = x).T n_solv = np.interp(p_solv, P, N) @@ -145,7 +111,6 @@ def __init__(self, prior, posterior_samples, Neos, eos_path_macro, eos_path_micr log10_mdisk = self.posterior_samples.log10_mdisk.to_numpy() self.KDE = scipy.stats.gaussian_kde((chirp_mass, eta_star, EOS, log10_mdisk, log10_mej_dyn)) - super().__init__(**kwargs) def Prior(self, x): @@ -171,7 +136,12 @@ def LogLikelihood(self, x): mdisk = 10**log10_mdisk mej_dyn = 10**log10_mej_dyn - m_rem_b = baryonic_mass(mass_1, EOS, self.eos_path_macro, self.eos_path_micro) + baryonic_mass(mass_2, EOS, self.eos_path_macro, self.eos_path_micro) - mdisk - mej_dyn #calculate the baryonic remnant mass + try: + b1 = baryonic_mass(mass_1, EOS, self.eos_path_macro, self.eos_path_micro) + b2 = baryonic_mass(mass_2, EOS, self.eos_path_macro, self.eos_path_micro) + m_rem_b = b1 + b2 - mdisk - mej_dyn #calculate the baryonic remnant mass + except: + breakpoint() if self.use_M_max: m_threshold = baryonic_Kepler_mass(mTOV, R_14, ratio_R, delta) #if the Kepler limit is the threshold, use the quasiuniversal relation @@ -223,7 +193,6 @@ def maximum_mass_resampling(args): class PostmergerInference(PostmergerInferenceMixIn, Solver): pass - solution = PostmergerInference(prior, posterior_samples, Neos, args.eos_path_macro, args.eos_path_micro, args.use_M_Kepler, **pymulti_kwargs) From 07bdfa350975aeb570787cc42bea2cdee2391bcf Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Tue, 7 Apr 2026 16:07:08 +0200 Subject: [PATCH 07/13] adjust GRB conversions --- nmma/core/conversion.py | 54 +++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/nmma/core/conversion.py b/nmma/core/conversion.py index d2ba1f56..79cd09f0 100644 --- a/nmma/core/conversion.py +++ b/nmma/core/conversion.py @@ -656,7 +656,7 @@ def chiBH_fitting( return chi_BH - def bns_parameter_conversion(self, converted_parameters): + def bns_ejecta_conversion(self, converted_parameters): # prevent the output message flooded by these warning messages old = np.seterr() @@ -692,26 +692,48 @@ def bns_parameter_conversion(self, converted_parameters): # total eject mass total_ejeta_mass = 10**log10_mej_dyn + 10**log10_mej_wind + np.seterr(**old) + return log10_mej_dyn, log10_mej_wind, np.log10(total_ejeta_mass), log10_mdisk_fit + + def grb_energy_conversion(self, converted_parameters, log10_mdisk_fit): + # GRB afterglow energy - log10_Ejet = converted_parameters.get("log10_E0", ( - np.log10(converted_parameters.get("ratio_epsilon", 2e-4)) - + np.log10(1.0 - converted_parameters["ratio_zeta"]) - + log10_mdisk_fit + np.log10(msun_to_ergs) ) - ) + log10_Ejet = np.log10(converted_parameters.get("ratio_epsilon", 2e-4)) + log10_Ejet += np.log10(1.0 - converted_parameters["ratio_zeta"]) + log10_Ejet += log10_mdisk_fit + np.log10(msun_to_ergs) - thetaCore = converted_parameters.get("thetaCore", 0.105) + thetaCore = converted_parameters.get("thetaCore", 0.105) ## default about 6 degree, see arxiv:2210.05695 + + if not any(key in converted_parameters for key in ["thetaWing", "alphaWing", "b"]): + return log10_Ejet - np.log10(np.sin(thetaCore/2)**2) + + if "alphaWing" in converted_parameters: + alphaWing = converted_parameters['alphaWing'] + else: + alphaWing = converted_parameters["thetaWing"] / converted_parameters["thetaCore"] + if "b" in converted_parameters: # power law jet - alphaWing = converted_parameters.get("alphaWing", converted_parameters["thetaWing"] / converted_parameters["thetaCore"]) - log10_E0 = np.log10(powerlaw_jet_energy_to_central_isotropic_energy_equivalent(10**log10_Ejet, thetaCore, alphaWing, converted_parameters["b"])) - elif "b" not in converted_parameters and ("thetaWing" in converted_parameters or "alphaWing" in converted_parameters): # gaussian jet - alphaWing = converted_parameters.get("alphaWing", converted_parameters["thetaWing"] / converted_parameters["thetaCore"]) - log10_E0 = np.log10(gaussian_jet_energy_to_central_isotropic_energy_equivalent(10**log10_Ejet, thetaCore, alphaWing)) + jet_func = powerlaw_jet_energy_to_central_isotropic_energy_equivalent + data = np.column_stack((10**log10_Ejet, thetaCore, alphaWing, converted_parameters["b"])) + else: - log10_E0 = log10_Ejet - np.log10(np.sin(thetaCore/2)**2) + jet_func = gaussian_jet_energy_to_central_isotropic_energy_equivalent + data = np.column_stack((10**log10_Ejet, thetaCore, alphaWing)) + + return np.log10([jet_func(*row) for row in data]) + - np.seterr(**old) - converted_ejecta = np.stack((log10_mej_dyn, log10_mej_wind, np.log10(total_ejeta_mass), log10_E0 )) + + def bns_parameter_conversion(self, parameters): + log10_mej_dyn, log10_mej_wind, log10_mej_total, log10_mdisk_fit = self.bns_ejecta_conversion(parameters) + + if "log10_E0" in parameters: + log10_E0 = parameters["log10_E0"] + else: + log10_E0 = self.grb_energy_conversion(parameters, log10_mdisk_fit) + + converted_ejecta = (log10_mej_dyn, log10_mej_wind, log10_mej_total, log10_E0) return np.where(np.isfinite(converted_ejecta), converted_ejecta, -np.inf) @@ -827,6 +849,8 @@ def identity_conversion(self, parameters): 'alpha' : r'$\alpha$', 'KNtheta' : r'$\theta_{KN} [^\circ]$', 'KNphi' : r'$\phi_{KN} [^\circ]$', + 'kappa_Ye' : r'$\kappa_{\rm{Y_e}}$', + 'kappa_v' : r'$\kappa_{v}$', ## GRB parameters ## 'log10_E0' : r'$\log_{10}(E_0{\rm [erg]})$', 'ratio_epsilon' : r'$\epsilon$', From af3babc7eab48259c83d40165af12371a39f6864 Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Mon, 4 May 2026 13:43:52 +0200 Subject: [PATCH 08/13] ENH: restructure MPI sampling to include other samplers --- nmma/core/base.py | 12 ++++++------ nmma/core/mpi_setup.py | 17 +++++++++++++---- nmma/core/parsing.py | 19 ++++++++++--------- nmma/joint/main.py | 17 +++++++++++------ nmma/joint/multi_parsing.py | 7 ++----- 5 files changed, 42 insertions(+), 30 deletions(-) diff --git a/nmma/core/base.py b/nmma/core/base.py index 2c0ca1ec..c151a11a 100644 --- a/nmma/core/base.py +++ b/nmma/core/base.py @@ -311,7 +311,7 @@ def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): sampler_kwargs["niter"] = 1 elif args.sampler == "dynesty": sampler_kwargs["maxiter"] = 1 - + result = run_sampler( likelihood, priors, @@ -331,6 +331,7 @@ def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): return try: + result.posterior = likelihood.posterior_conversion(result.posterior) result.save_to_file() result.save_posterior_samples() except FileMovedError: @@ -362,7 +363,6 @@ def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): result.plot_corner(injection_parameters, priors) if args.bestfit or args.plot: - result.posterior = likelihood.posterior_conversion(result.posterior) likelihood.post_process_bestfit(args, result) return result @@ -376,10 +376,10 @@ def multi_analysis_loop(args, analysis_setup): except ImportError: rank = 0 - # if rank != 0: - # devnull = os.open(os.devnull, os.O_WRONLY) - # os.dup2(devnull, 1) - # os.dup2(devnull, 2) + if rank != 0 and not getattr(args, 'verbose', False): + devnull = os.open(os.devnull, os.O_WRONLY) + os.dup2(devnull, 1) + os.dup2(devnull, 2) if getattr(args, 'multi', None): sub_runs = [] diff --git a/nmma/core/mpi_setup.py b/nmma/core/mpi_setup.py index ec2dc380..6517af32 100644 --- a/nmma/core/mpi_setup.py +++ b/nmma/core/mpi_setup.py @@ -58,10 +58,11 @@ def __init__( super().__init__( likelihood, prior, self.outdir, self.label, - injection_parameters, + injection_parameters = injection_parameters, skip_import_verification = skip_import_verification, plot= plot, soft_init=True, + use_ratio = True, ) @@ -141,7 +142,7 @@ def __init__( plot= False, meta_data = {}, ): - + breakpoint() super().__init__(args, prior, likelihood, injection_parameters, plot, skip_import_verification = False) @@ -484,12 +485,20 @@ def plot_current_state(self): def storable_metadata(self): meta_data = self.meta_data - - meta_data["args"] = vars(self.args) # convert Namespace to dict for storing + meta_data["args"] = vars(self.args).copy() # convert Namespace to dict for storing meta_data["likelihood"] = self.likelihood.meta_data meta_data["sampler_kwargs"] = self.init_sampler_kwargs meta_data["run_sampler_kwargs"] = self.sampler_kwargs + meta_data = self.floatify_dict(meta_data) return meta_data + + def floatify_dict(self, d): + for k, v in d.items(): + if isinstance(v, dict): + d[k] = self.floatify_dict(v) + elif isinstance(v, np.floating): + d[k] = float(v) + return d def format_result( self, diff --git a/nmma/core/parsing.py b/nmma/core/parsing.py index 467ee971..7e431576 100644 --- a/nmma/core/parsing.py +++ b/nmma/core/parsing.py @@ -110,6 +110,15 @@ def base_analysis_parsing(parser): help="Sampling seed (default: 42)") parser.add_argument("--sampler-kwargs", default="{}", type = yaml_parse, help="Additional keyword arguments to pass to the sampler as a dictionary" ) + parser.add_argument("--skip-sampling", action='store_true', + help="If analysis has already run, skip bilby sampling and compute results from checkpoint files. Combine with --plot to make plots from these files.") + parser.add_argument("--dlogz", default=0.1, type=float, + help="Stopping criteria: remaining evidence, (default=0.1)" ) + parser.add_argument("--soft-init", action='store_true', + help="To start the sampler softly (without any checking, default: False)") + parser.add_argument("--cpus", type=int, default=1, + help="Number of cores to be used, only needed for dynesty (default: 1)") + parser.add_argument("-n","--nlive", "--n-live", type=int, default=2048, help="Number of live points (default: 2048)") return parser def dynesty_parsing(parser): @@ -152,21 +161,13 @@ def single_messenger_analysis_parsing(parser): parser.add_argument("-o", "--outdir", default="outdir", help="Path to the output directory") parser.add_argument("--label", default ="nmma_transient", help="Label for the run") parser.add_argument("--plot", action='store_true', help="create characteristic plot") + parser.add_argument("--plot-kwargs", type=yaml_parse, default={}, help="Additional keyword arguments to pass to the plotting routine as a dictionary" ) parser.add_argument("--verbose", action='store_true', help="print out log likelihoods" ) parser.add_argument("--prior-file","--prior", help="Path to the prior file") parser.add_argument("--sampler", default="pymultinest", help="Sampler to be used (default: pymultinest)") - parser.add_argument("--dlogz", default=0.1, type=float, - help="Stopping criteria: remaining evidence, (default=0.1)" ) - parser.add_argument("--soft-init", action='store_true', - help="To start the sampler softly (without any checking, default: False)") - parser.add_argument("--cpus", type=int, default=1, - help="Number of cores to be used, only needed for dynesty (default: 1)") - parser.add_argument("-n","--nlive", "--n-live", type=int, default=2048, help="Number of live points (default: 2048)") parser.add_argument("--reactive-sampling", action='store_true', help="To use reactive sampling in ultranest (default: False)") - parser.add_argument("--skip-sampling", action='store_true', - help="If analysis has already run, skip bilby sampling and compute results from checkpoint files. Combine with --plot to make plots from these files.") parser.add_argument("--bestfit", "--best-fit", action='store_true', help="Save the best fit parameters to JSON") diff --git a/nmma/joint/main.py b/nmma/joint/main.py index 5239b6fc..d2ac7d53 100644 --- a/nmma/joint/main.py +++ b/nmma/joint/main.py @@ -18,6 +18,7 @@ from bilby.core.prior import PriorDict from ..core.mpi_setup import pbilby_sampling +from ..core.base import bilby_sampling from .multi_parsing import create_nmma_analysis_parser, parse_analysis_args from .joint_likelihood import MultiMessengerLikelihood from ..core.utils import logger @@ -69,12 +70,16 @@ def analysis_runner( if ifo_list is not None: meta_data["ifo_list"] = [ifo.__repr__() for ifo in ifo_list] - - pbilby_sampling( - likelihood, priors, args, - data_dump.get("injection_parameters", None), rank, - plot=plot, meta_data=meta_data, - **kwargs) + if args.sampler == "dynesty": + logger.info("Using dynesty sampler") + return pbilby_sampling( + likelihood, priors, args, + data_dump.get("injection_parameters", None), rank, + plot=plot, meta_data=meta_data, **kwargs) + else: + return bilby_sampling( + likelihood, priors, args, + data_dump.get("injection_parameters", None), rank) def nmma_analysis(): """ diff --git a/nmma/joint/multi_parsing.py b/nmma/joint/multi_parsing.py index d27dc74d..46de5308 100644 --- a/nmma/joint/multi_parsing.py +++ b/nmma/joint/multi_parsing.py @@ -21,8 +21,8 @@ def _create_base_nmma_parser(sampler="dynesty", parents=[]): version=f"%(prog)s={__version__}\nbilby={bilby.__version__}", ) + base_parser = multi_sampler_parsing(base_parser) if sampler in ["all", "dynesty"]: - base_parser = multi_sampler_parsing(base_parser) base_parser = dynesty_parsing(base_parser) base_parser = em_settings_parsing(base_parser) @@ -48,11 +48,8 @@ def em_settings_parsing(parser): def multi_sampler_parsing(parser): sampler_group = parser.add_argument_group(title = "Setting for the Sampler") - sampler_group.add_argument("--sampler", choices=["dynesty"], default="dynesty", + sampler_group.add_argument("--sampler", default="dynesty", help="The parallelised sampler to use, defaults to dynesty") - sampler_group.add_argument( "-n", "--nlive", default=1000, type=int, help="Number of live points" ) - sampler_group.add_argument("--dlogz", default=0.1, type=float, - help="Stopping criteria: remaining evidence, (default=0.1)" ) sampler_group.add_argument("--bound","--dynesty-bound", default="live", help="Dynesty bounding method (default=live)" ) sampler_group.add_argument( "--sample", "--dynesty-sample", default="acceptance-walk", From d9d2035cdc753d880603e444d4084010bd0ac256 Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Mon, 4 May 2026 13:45:04 +0200 Subject: [PATCH 09/13] ENH: Simplify injection handling in joint analysis --- nmma/core/conversion.py | 7 +++ nmma/gw/gw_inputs.py | 36 ++++++++++++ nmma/joint/generation.py | 100 +++++++++++++++++++++------------ nmma/joint/joint_likelihood.py | 5 +- 4 files changed, 109 insertions(+), 39 deletions(-) create mode 100644 nmma/gw/gw_inputs.py diff --git a/nmma/core/conversion.py b/nmma/core/conversion.py index 79cd09f0..b63271c2 100644 --- a/nmma/core/conversion.py +++ b/nmma/core/conversion.py @@ -793,6 +793,9 @@ def from_dict(cls, instruction_dict): if 'em' in instruction_dict: conversions.append(instruction_dict['em']) + + if 'custom' in instruction_dict: + conversions.append(instruction_dict['custom']) return cls(*conversions) @@ -849,6 +852,10 @@ def identity_conversion(self, parameters): 'alpha' : r'$\alpha$', 'KNtheta' : r'$\theta_{KN} [^\circ]$', 'KNphi' : r'$\phi_{KN} [^\circ]$', + # Bu parameters ## + 'vej_dyn' : r'$v_{\rm{dyn}}{\rm [c]}$', + 'vej_wind' : r'$v_{\rm{wind}}{\rm [c]}$', + 'Ye_dyn' : r'$Y_{e,{\rm{dyn}}}$', 'kappa_Ye' : r'$\kappa_{\rm{Y_e}}$', 'kappa_v' : r'$\kappa_{v}$', ## GRB parameters ## diff --git a/nmma/gw/gw_inputs.py b/nmma/gw/gw_inputs.py new file mode 100644 index 00000000..5346433e --- /dev/null +++ b/nmma/gw/gw_inputs.py @@ -0,0 +1,36 @@ +# FIXME: This is a hacky subclass to adapt the bilby_pipe data generation to +# our needs. We should at some point get rid of bilby pipe + +from bilby_pipe.data_generation import DataGenerationInput as DataInput +class NMMAGravitationalWaveInput(DataInput): + def __init__(self, args, unknown_args): + args.calibration_correction_type = 'data' + super().__init__(args, unknown_args) + self.interferometers.plot_data(outdir=self.data_directory, label=self.label) + if self.likelihood_type == "ROQGravitationalWaveTransient": + self.save_roq_weights() + + args.gw_likelihood_type = self.likelihood_type + 'Quick wrapper to fix some issues with the bilby_pipe data generation' + + @DataInput.interferometers.setter + def interferometers(self, interferometers): + self._detectors = [ifo.name for ifo in interferometers] + self.minimum_frequency_dict = self.reset_frequency_dict( + self.minimum_frequency_dict) + self.maximum_frequency_dict = self.reset_frequency_dict( + self.maximum_frequency_dict) + DataInput.interferometers.fset(self, interferometers) + + def reset_frequency_dict(self, frequency_dict): + out_dict = {} + for det in self.detectors: + if det in frequency_dict: + out_dict[det] = frequency_dict[det] + # eg. ET1, ET2, ET3 were given by ET + elif det[:-1] in frequency_dict: + out_dict[det] = frequency_dict[det[:-1]] + else: + raise ValueError( + f"Detector {det} not found in frequency dict {frequency_dict}") + return out_dict \ No newline at end of file diff --git a/nmma/joint/generation.py b/nmma/joint/generation.py index 55152047..81768754 100644 --- a/nmma/joint/generation.py +++ b/nmma/joint/generation.py @@ -18,6 +18,11 @@ from .multi_parsing import parse_generation_args +from ..core.constants import set_cosmology +from ..core.conversion import KilonovaEjectaFitting +from ..core.base import adjust_priors_for_nmma, adjust_hubble_prior +from ..core.utils import read_trigger_time +from ..gw.gw_inputs import NMMAGravitationalWaveInput from ..em.prior import extinction_prior from ..em.io import load_em_observations from ..em.model import create_injection_model @@ -26,9 +31,6 @@ from ..em import utils as em_utils from ..eos.eos_likelihood import (compose_eos_constraints, EoSConverter, JointEoSConstraint, setup_tabulated_eos_priors) -from ..core.constants import set_cosmology -from ..core.base import adjust_priors_for_nmma, adjust_hubble_prior -from ..core.utils import read_trigger_time from .joint_likelihood import MultiMessengerLikelihood import matplotlib @@ -186,14 +188,16 @@ def __init__(self, args, unknown_args, logger=None): self.data_dump_file = f"{self.data_directory}/{self.label}_data_dump.pickle" self.data_set = False - self.injection = args.injection self.injection_numbers = args.injection_numbers self.injection_file = args.injection_file self.injection_dict = args.injection_dict - if self.injection: + if self.injection_file or self.injection_dict: + self.injection = True + args.injection = True self.injection_parameters = self.injection_df.iloc[self.idx].to_dict() else: - self.injection_parameters=None + args.injection = False + self.injection_parameters = None self.trigger_time = read_trigger_time(self.injection_parameters, args, 'gps') args.trigger_time = self.trigger_time @@ -204,10 +208,10 @@ def __init__(self, args, unknown_args, logger=None): **get_version_info(), command_line_args=args.__dict__, unknown_command_line_args=self.unknown_args, - injection_parameters= self.injection_parameters, + injection_parameters=self.injection_parameters, ) self.adjust_priors_and_data(args, logger) - + #test-build likelihood self.lhood = MultiMessengerLikelihood.setup_from_args( self.data_dump, self._priors, self.args, logger) @@ -218,19 +222,21 @@ def __init__(self, args, unknown_args, logger=None): def adjust_priors_and_data(self, args, logger): messengers, analysis_modifiers = [], [] data_dump = dict(injection_parameters = self.injection_parameters) + if self.injection_parameters: + converted_injection = self.injection_parameters.copy() + # GW SETUP if args.detectors: messengers.append("gw") # get a BBHPriorDict only if GW parameters are present priors = super()._get_priors() - self.gw_inputs= bilby_pipe.data_generation.DataGenerationInput(args, self.unknown_args) - #### FIXME resetting likelihood type is an unpleasant bilby_pipe remnant - self.gw_inputs.interferometers.plot_data(outdir=self.data_directory, label=self.label) - args.gw_likelihood_type = self.gw_inputs.likelihood_type - if args.gw_likelihood_type == "ROQGravitationalWaveTransient": - self.gw_inputs.save_roq_weights() - data_dump |= dict(waveform_generator=self.gw_inputs.waveform_generator, - ifo_list=self.gw_inputs.interferometers) + if self.injection_parameters: + converted_injection = bilby.gw.conversion.convert_to_lal_binary_neutron_star_parameters( + converted_injection)[0] + else: + self.gw_inputs= NMMAGravitationalWaveInput(args, self.unknown_args) + data_dump |= dict(waveform_generator=self.gw_inputs.waveform_generator, + ifo_list=self.gw_inputs.interferometers) else: priors = self._get_priors() @@ -239,15 +245,56 @@ def adjust_priors_and_data(self, args, logger): if args.Hubble or any(['hubble' in key.lower() for key in priors.keys()]): analysis_modifiers.append("Hubble") + # EOS Setup + if args.emulator_metadata: + messengers.append("eos") + logger.info("Setting up EOS constraints") + data_dump |= dict(eos_constraint_dict= compose_eos_constraints(args)) + eos_converter = EoSConverter(args, 'emulated') + + elif args.eos_data: + analysis_modifiers.append('tabulated_eos') + eos_constraint_dict = compose_eos_constraints(args) + if eos_constraint_dict: + eos_converter = EoSConverter(args, 'tabulated') + constraint = JointEoSConstraint(eos_constraint_dict, eos_converter=eos_converter) + args.eos_weight, args.eos_data, args.Neos = constraint.tabulate_weighted_eos( + args.Neos, args.outdir, args.eos_weight) + priors = setup_tabulated_eos_priors(args, priors, logger) + + if self.injection_parameters: + try: + converted_injection = eos_converter(converted_injection) + converted_injection = KilonovaEjectaFitting()(converted_injection) + except Exception as e: + logger.warning("eos and ejecta fitting failed for injection parameters. Continuing without conversion.") + logger.warning(f"Error was {e}") + pass + finally: + # correct injection only once lambdas are properly set + # some routines return np.float32 which raises errors downstream in results + # processing, so we convert to float here. Should be handled more elegantly + args.injection_dict = {k: float(v) for k, v in converted_injection.items()} + self.gw_inputs= NMMAGravitationalWaveInput(args, self.unknown_args) + data_dump |= dict(waveform_generator=self.gw_inputs.waveform_generator, + ifo_list=self.gw_inputs.interferometers) # EM SETUP if args.em_model or args.em_transient_class: messengers.append('em') if self.injection_parameters: injection_model = create_injection_model(args) + converted_injection = injection_model.parameter_conversion( + converted_injection) + for param in injection_model.model_parameters: + if param not in self.injection_parameters: + try: + self.injection_parameters[param] = converted_injection[param] + except KeyError: + raise KeyError(f"Required parameter {param} could not be derived from conversion.") light_curve_data = create_light_curve_data( - self.injection_parameters, args, injection_model - ) + self.injection_parameters, args, injection_model + ) else: light_curve_data = load_em_observations(args) @@ -262,23 +309,6 @@ def adjust_priors_and_data(self, args, logger): priors = extinction_prior(priors, args) data_dump |= dict(light_curve_data=light_curve_data, filters = filters, systematics_dict = sys_handler.systematics_dict) - - - # EOS Setup - if args.emulator_metadata: - messengers.append("eos") - logger.info("Setting up EOS constraints") - data_dump |= dict(eos_constraint_dict= compose_eos_constraints(args)) - - elif args.eos_data: - analysis_modifiers.append('tabulated_eos') - eos_constraint_dict = compose_eos_constraints(args) - if eos_constraint_dict: - eos_converter = EoSConverter(args, 'tabulated') - constraint = JointEoSConstraint(eos_constraint_dict, eos_converter=eos_converter) - args.eos_weight, args.eos_data, args.Neos = constraint.tabulate_weighted_eos( - args.Neos, args.outdir, args.eos_weight) - priors = setup_tabulated_eos_priors(args, priors, logger) self.args = args self.messengers = messengers diff --git a/nmma/joint/joint_likelihood.py b/nmma/joint/joint_likelihood.py index 1b5c6e49..bff3e7a8 100644 --- a/nmma/joint/joint_likelihood.py +++ b/nmma/joint/joint_likelihood.py @@ -113,6 +113,7 @@ def setup_from_args(cls, data_dump, priors, args, logger=None): logger.info("Setting up GW likelihood") gw_kwargs = setup_gw_kwargs(data_dump, args, logger) messenger_lhoods.append(GravitationalWaveTransientLikelihood(priors, **gw_kwargs)) + priors.convert_floats_to_delta_functions() conversion_instructions['gw'] = True # placeholder if "eos" in messengers: @@ -152,10 +153,6 @@ def setup_from_args(cls, data_dump, priors, args, logger=None): if "log10_mej_wind" in priors or "log10_mej_dyn" in priors or args.ejecta_conversion: conversion_instructions['ejecta'] = True - # if "spec" in messengers: # FUTURE - # spec_kwargs = setup_spectroscopy_kwargs(data_dump, args, ...) - # messenger_lhoods.append(SpectroscopicLikelihood(priors, **spec_kwargs)) - if len(messenger_lhoods) == 0: raise ValueError("No messenger likelihoods were set up.") elif len(messenger_lhoods) == 1: From 57e6e7962b19606c8894884c8c1008dc059a07ea Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Mon, 4 May 2026 13:46:30 +0200 Subject: [PATCH 10/13] ENH: more versatile plotting of inference results --- nmma/core/plotting_utils.py | 48 +++- nmma/core/utils.py | 25 ++- nmma/em/lightcurve_handling.py | 28 ++- nmma/em/plotting_utils.py | 256 ++++++++++++++-------- nmma/eos/eos_likelihood.py | 50 ++++- nmma/post_processing/plotting_routines.py | 176 +++++++++++---- 6 files changed, 407 insertions(+), 176 deletions(-) diff --git a/nmma/core/plotting_utils.py b/nmma/core/plotting_utils.py index 7d480b63..e6953760 100644 --- a/nmma/core/plotting_utils.py +++ b/nmma/core/plotting_utils.py @@ -2,8 +2,11 @@ from bilby.core.prior import PriorDict, DeltaFunction import numpy as np import matplotlib +from matplotlib.colors import LinearSegmentedColormap import os -matplotlib.use("Agg") +# matplotlib.use("Agg") +from matplotlib import pyplot as plt +import itertools def fig_setup(): fig_width_pt = 750.0 # Get this from LaTeX using \showthe\columnwidth @@ -18,7 +21,7 @@ def fig_setup(): "legend.fontsize": 18, "xtick.labelsize": 18, "ytick.labelsize": 18, - "font.family": "Times New Roman", + "font.family": "serif", "figure.figsize": fig_size, "mathtext.fontset": "stix", } @@ -45,7 +48,7 @@ def fig_setup(): matplotlib.rcParams.update(params) matplotlib.rcParams['text.usetex'] = (os.environ.get("CI") != 'true') - return [ + color_array = [ "#22ADFC", # blue "#F42969", # red "#F4A429", # orange @@ -58,7 +61,7 @@ def fig_setup(): "#8B4513", # brown "#FF6347", # tomato ] - + return itertools.cycle(color_array) def plotting_parameters_from_priors(priors, keys=None): """ @@ -83,4 +86,39 @@ def plotting_parameters_from_priors(priors, keys=None): if keys is None: keys = priors.keys() - return {k: v.latex_label for k, v in priors.items() if k in keys and not isinstance(v, DeltaFunction)} \ No newline at end of file + return {k: v.latex_label for k, v in priors.items() if k in keys and not isinstance(v, DeltaFunction)} + +def setup_multi_axes(num_axes, sharex=False, sharey=False, ncols=None): + "Set up a multi-panel figure with the specified number of axes, essentially stolen from corner.py" + + if ncols is None: + ncols = np.min([5, np.ceil(np.sqrt(num_axes)).astype(int)]) + nrows = np.ceil(num_axes / ncols).astype(int) + + + factor = 2.0 # size of one side of one panel + lbdim = 0.5 * factor # size of left/bottom margin + trdim = 0.2 * factor # size of top/right margin + whspace = 0.05 # w/hspace size + whspace = trdim + rowdim = lbdim + factor * ncols + factor * (ncols - 1.0) * whspace + trdim + coldim = lbdim + factor * nrows + factor * (nrows - 1.0) * whspace + trdim + + + + # Create a new figure if one wasn't provided. + fig, axes = plt.subplots(nrows, ncols, figsize=(rowdim, coldim), + sharex=sharex, sharey=sharey, constrained_layout=True) + + return fig, axes.flatten() + + +def fading_cmap(color): + cmap = LinearSegmentedColormap.from_list("custom_cmap", ["white",color], gamma = 2) + + cdict = cmap._segmentdata.copy() + vals = cdict['alpha'][:, 0] + alpha = np.linspace(0, 1, len(vals)) + cdict['alpha'] = np.column_stack([vals, alpha, alpha]) + + return LinearSegmentedColormap(cmap.name, cdict, cmap.N, cmap._gamma) \ No newline at end of file diff --git a/nmma/core/utils.py b/nmma/core/utils.py index c5f3c913..88456ffb 100644 --- a/nmma/core/utils.py +++ b/nmma/core/utils.py @@ -8,7 +8,6 @@ from bilby.core.result import read_in_result from astropy import time -from matplotlib.colors import LinearSegmentedColormap from pathlib import Path import yaml @@ -69,7 +68,8 @@ def read_trigger_time(parameters=None, args=None, out_format = 'mjd'): format = 'gps' trigger_time= time.Time(args.trigger_time, format=format) trigger_time # this fails if not a valid time - + elif args is not None: + args.trigger_time = trigger_time.mjd if out_format == 'mjd' else trigger_time.gps if out_format == 'mjd': return trigger_time.mjd @@ -224,12 +224,15 @@ def input_obj_to_str(input_obj, ref_name= None): else: raise TypeError("Input object could not be identified.") -def fading_cmap(color): - cmap = LinearSegmentedColormap.from_list("custom_cmap", ["white",color], gamma = 2) - - cdict = cmap._segmentdata.copy() - vals = cdict['alpha'][:, 0] - alpha = np.linspace(0, 1, len(vals)) - cdict['alpha'] = np.column_stack([vals, alpha, alpha]) - - return LinearSegmentedColormap(cmap.name, cdict, cmap.N, cmap._gamma) \ No newline at end of file +def nan_level(data, level, weights=None): + nans, clean_data = np.isnan(data), data[~np.isnan(data)] + if weights is not None: + weights = np.array(weights)[~nans] + weights = weights / np.sum(weights) + nan_share = np.sum((nans)) / len(data) + if nan_share > level: + return [np.nan, np.nan] + rest_level = level - nan_share + low = np.quantile(clean_data, (1-rest_level)/2, weights=weights, method='inverted_cdf') + up = np.quantile(clean_data, 1-(1-rest_level)/2, weights=weights, method='inverted_cdf') + return [low, up] diff --git a/nmma/em/lightcurve_handling.py b/nmma/em/lightcurve_handling.py index ee077acd..40d92872 100644 --- a/nmma/em/lightcurve_handling.py +++ b/nmma/em/lightcurve_handling.py @@ -63,12 +63,12 @@ def post_process_bestfit(transient, bestfit_params, args, result=None): print(f"Saved bestfit parameters and magnitudes to {bestfit_file}") if args.plot: - filters_to_plot = [ - filt for filt in transient.observed_filters - if not np.isnan(transient.light_curves[filt]).all() - ] - plot_error = {filt: model_error[filt] for filt in filters_to_plot} - mags_to_plot = {filt: best_mags[filt] for filt in filters_to_plot} + plot_kwargs = getattr(args, "plot_kwargs", {}) + if (plot_filters := plot_kwargs.pop("filters", None)) is None: + plot_filters = {filt: filt for filt in transient.observed_filters if not np.isnan(transient.light_curves[filt]).all()} + + plot_error = {k: model_error[k] for k in plot_filters.keys()} + mags_to_plot = {k: best_mags[k] for k in plot_filters.keys()} mags_to_plot["time"] = observable_times if isinstance(lc_model, model.CombinedLightCurveModelContainer): @@ -85,7 +85,7 @@ def post_process_bestfit(transient, bestfit_params, args, result=None): } plot_errors = [] plot_mags = [] - for filt in filters_to_plot: + for filt in plot_filters.keys(): try: plot_mags.append(utils.get_filtered_mag(mag_all[i], filt)) except KeyError: @@ -98,11 +98,15 @@ def post_process_bestfit(transient, bestfit_params, args, result=None): sub_model_plot_props[sub_model.model]['plot_errors'] = plot_errors else: sub_model_plot_props = None - - basic_em_analysis_plot( - transient, mags_to_plot, plot_error, chi2_dict, mismatches, - sub_model_plot_props, xlim = args.xlim, ylim = args.ylim, - save_path = os.path.join(args.outdir, f"{args.label}_lightcurves.png") + save_path = plot_kwargs.pop("save_path", + os.path.join(args.outdir, + f"{args.label}_bestfit_lightcurves.png")) + return basic_em_analysis_plot( + transient, plot_filters, mags_to_plot, plot_error, + chi2_dict, mismatches, sub_model_plot_props, + xlim = args.xlim, ylim = args.ylim, + save_path = save_path, + fig = getattr(args, "fig", None), **plot_kwargs ) diff --git a/nmma/em/plotting_utils.py b/nmma/em/plotting_utils.py index 797e2708..e194f293 100644 --- a/nmma/em/plotting_utils.py +++ b/nmma/em/plotting_utils.py @@ -1,120 +1,203 @@ import matplotlib.pyplot as plt import matplotlib +from matplotlib.ticker import MaxNLocator + from mpl_toolkits.axes_grid1 import make_axes_locatable import numpy as np import os -matplotlib.use("agg") params = { "backend": "pdf", "figure.figsize": [18, 25], } matplotlib.rcParams.update(params) matplotlib.rcParams['text.usetex'] = (os.environ.get("CI") != 'true') +matplotlib.rcParams['figure.labelsize'] = 20 +# matplotlib.rcParams['axes.titlesize'] = 'large' +if os.environ.get("CI") == 'true': + matplotlib.use("agg") +from nmma.core.plotting_utils import fig_setup +nmma_colors = fig_setup() ############################################## ################# MAIN PLOTS ################# ############################################## def basic_em_analysis_plot( - transient, mags_to_plot, error_dict, chi2_dict, mismatches, - sub_model_plot_props, xlim, ylim, save_path, - ncol = 2): + transient, plot_filters, mags_to_plot, error_dict, chi2_dict, + mismatches, sub_model_plot_props, xlim, ylim, save_path, + ncol = 2, fig = None, shared_data = True, **kwargs): + + + if fig: + return add_em_analysis_plot(fig, transient, mags_to_plot, error_dict, mismatches, sub_model_plot_props, xlim, ylim, save_path, shared_data, **kwargs) time = mags_to_plot.pop("time") - filter_names = list(mags_to_plot.keys()) + filter_names = list(plot_filters.keys()) + n_filters = len(filter_names) + data_colors = plt.cm.plasma(np.linspace(0, 1, len(filter_names)))[::-1] + color = kwargs.pop('color', next(nmma_colors)) + fig, axes = analysis_plot_geometry(filter_names, ncol=ncol) + fig.supylabel("AB magnitude", rotation=90) + fig.supxlabel("Time [days]") if xlim is None: xlim = get_time_limits_from_obs_data(transient) else: xlim = check_limit(xlim) - + if ylim is None: ylim = get_mag_limits_from_obs_data(transient, filter_names) else: shared_ylim = check_limit(ylim) ylim = {filt: shared_ylim for filt in filter_names} - fig, axes = analysis_plot_geometry(filter_names, ncol=ncol) - colors = plt.cm.Spectral(np.linspace(0, 1, len(filter_names)))[::-1] + for cnt, filt in enumerate(filter_names): - # summary plot row, col = divmod(cnt, ncol) ax_sum = axes[row, col] + ax_sum.set_ylim(ylim[filt]) + ax_sum.set_xlim(xlim) + if xlim[0] > 0: + ax_sum.set_xscale('log') + ax_sum.set_ylabel(plot_filters[filt]) + # adding the ax for the Delta divider = make_axes_locatable(ax_sum) ax_delta = divider.append_axes('bottom', - size='30%', - sharex=ax_sum) + size='40%', + sharex=ax_sum, + pad=0.1) + ax_delta.set_ylabel(r"$\Delta$ mag") - # configuring ax_sum - ax_sum.set_ylabel("AB magnitude", rotation=90) - ax_delta.set_ylabel(r"$\Delta (\sigma)$") - if cnt == len(filter_names)-1: - ax_delta.set_xlabel("Time [days]") + # ax_delta.set_yscale('log') + # ax_delta.yaxis.set_major_locator(MaxNLocator(min_n_ticks=2)) + + if cnt not in [n_filters - i-1 for i in range(ncol)]: + ax_sum.set_xticklabels([]) else: - ax_delta.set_xticklabels([]) + plt.setp(ax_sum.get_xticklabels(), visible=False) + # only show x labels on the lowest delta plots + + # configuring ax_sum + + if cnt not in [n_filters - i-1 for i in range(ncol)]: # only show x labels on the last two plots + ax_sum.set_xticklabels([]) # plot the observations - ax_sum, det_times = plot_observations(ax_sum, transient, colors[cnt], filter=filt) + ax_sum, det_times = plot_observations(ax_sum, transient, data_colors[cnt], filter=filt) if det_times.size>0: # plot the mismatch between the model and the data - diff_per_data, sigma_per_data = mismatches[filt] + diff_per_data, _ = mismatches[filt] ax_delta.axhline(0, linestyle='--', color='k') - ax_delta.scatter(det_times, diff_per_data, color=colors[cnt]) - - ax_sum.set_title(f'{filt}: ' + fr'$\chi^2 / d.o.f. = {round(chi2_dict[filt], 2)}$') + ax_delta.scatter(det_times, diff_per_data, color=color) + ax_sum.plot([], [], label=fr'$\chi^2$ / d.o.f. = {round(chi2_dict[filt], 2)}', color=color) + ax_sum.legend(loc='best', frameon=False, handlelength=0, handletextpad=0) - else: - ax_sum.set_title(f'{filt}') # plot the best-fit lc with errors - mag_plot = mags_to_plot[filt] - error_budget = error_dict[filt] - label = 'combined' if sub_model_plot_props is not None else "" - - ax_sum.plot(time, mag_plot, - color='coral', linewidth=3, linestyle="--") - ax_sum.fill_between(time, - mag_plot + error_budget, - mag_plot - error_budget, - facecolor='coral', - alpha=0.2, - label=label, - ) + plot_bestfit_with_errors(ax_sum, time, mags_to_plot[filt], error_dict[filt], sub_model_plot_props, cnt, color) - if sub_model_plot_props is not None: - ## plot additional lcs for each sub_model - for model_name, prop_dict in sub_model_plot_props.items(): - mag_plot = prop_dict['plot_mags'][cnt] - mag_err = prop_dict['plot_errors'][cnt] - plot_times = prop_dict['plot_times'] - ax_sum.plot(plot_times, mag_plot, - color='coral', linewidth=3, linestyle="--") - ax_sum.fill_between(plot_times, - mag_plot + mag_err, - mag_plot - mag_err, - facecolor=prop_dict['color'], - alpha=0.2, - label=model_name, - ) + + fig.tight_layout() + if save_path: + fig.savefig(save_path, bbox_inches='tight') + return fig + +def add_em_analysis_plot(fig, + transient, mags_to_plot, error_dict, mismatches, + sub_model_plot_props, xlim, ylim, save_path, shared_data = True, **kwargs): + + time = mags_to_plot.pop("time") + filter_names = list(mags_to_plot.keys()) + n_axes = int(np.ceil(len(fig.axes)/2)) + + data_colors = plt.cm.plasma(np.linspace(0, 1, len(filter_names)))[::-1] + color = kwargs.pop('color', next(nmma_colors)) + + if not shared_data: + if xlim is None: + add_xlim = get_time_limits_from_obs_data(transient) + old_xlim = fig.axes[0].get_xlim() + xlim = (min(add_xlim[0], old_xlim[0]), max(add_xlim[1], old_xlim[1])) + if ylim is None: + add_ylim = get_mag_limits_from_obs_data(transient, filter_names) + ylim = {} + for i, filt in enumerate(filter_names): + old_ylim = fig.axes[i].get_ylim() + ylim[filt] = (max(add_ylim[filt][0], old_ylim[0]), min(add_ylim[filt][1], old_ylim[1])) + + for cnt, filt in enumerate(filter_names): - ax_sum.set_ylim(ylim[filt]) - ax_sum.set_xlim(xlim) - ax_delta.set_xlim(xlim) - if xlim[0] > 0: - ax_sum.set_xscale('log') + # summary plot + ax_sum = fig.axes[cnt] + ax_delta = fig.axes[cnt+n_axes] - plt.tight_layout() - plt.savefig(save_path, bbox_inches='tight') - plt.close() + # plot the observations + ax_sum, det_times = plot_observations(ax_sum, transient, data_colors[cnt], filter=filt) + if not shared_data: + ax_sum.set_xlim(xlim) + ax_sum.set_ylim(ylim[filt]) + ax_delta.set_xlim(xlim) + + if det_times.size>0: + # plot the mismatch between the model and the data + diff_per_data, _ = mismatches[filt] + delta_ylim = ax_delta.get_ylim() + ylim = (min(delta_ylim[0], 0.9*min(diff_per_data)), + max(delta_ylim[1], 1.1*max(diff_per_data))) + ax_delta.set_ylim(ylim) + ax_delta.axhline(0, linestyle='--', color='k') + ax_delta.scatter(det_times, diff_per_data, color=color) + + ax_sum.get_legend().remove() + # ax_sum.plot([], [], label=round(chi2_dict[filt], 2), color=color) + # ax_sum.legend(loc='best', frameon=False, handlelength=0, handletextpad=0, labelcolor='linecolor') + + plot_bestfit_with_errors(ax_sum, time, mags_to_plot[filt], error_dict[filt], sub_model_plot_props, cnt, color) + + + fig.tight_layout() + if save_path: + fig.savefig(save_path, bbox_inches='tight') + return fig + +def plot_bestfit_with_errors(ax_sum, time, mag_plot, error_budget, + sub_model_plot_props, cnt, color): + + label = 'combined' if sub_model_plot_props is not None else "" + ax_sum.plot(time, mag_plot, + color=color, linewidth=3, linestyle="--") + ax_sum.fill_between(time, + mag_plot + error_budget, + mag_plot - error_budget, + facecolor=color, + alpha=0.2, + label=label, + ) + + if sub_model_plot_props is not None: + ## plot additional lcs for each sub_model + for model_name, prop_dict in sub_model_plot_props.items(): + mag_plot = prop_dict['plot_mags'][cnt] + mag_err = prop_dict['plot_errors'][cnt] + plot_times = prop_dict['plot_times'] + ax_sum.plot(plot_times, mag_plot, + color=color, linewidth=3, linestyle="--") + ax_sum.fill_between(plot_times, + mag_plot + mag_err, + mag_plot - mag_err, + facecolor=prop_dict['color'], + alpha=0.2, + label=model_name, + ) + def bolometric_lc_plot(transient, time, lc, save_path, color = "coral"): matplotlib.rcParams.update( - {'font.size': 12, - # 'font.family': 'Times New Roman' + {'font.size': 12, 'font.family': 'serif' } ) fig, ax = plt.subplots(1, 1) @@ -181,7 +264,7 @@ def spec_plot_func(fig, ax, XX, YY, plot_data, label): def chi2_hists_from_dict(chi2_dict, outpath): matplotlib.rcParams.update( - {"font.size": 16, "font.family": "Times New Roman"} + {"font.size": 16, "font.family": "Serif"} ) for filt, chi2_array in chi2_dict.items(): plt.figure() @@ -292,7 +375,7 @@ def return_hist(x): hist, _ = np.histogram(x, bins=bins) return hist - hist = np.apply_along_axis(lambda x: return_hist(x), -1, plot_data.T) + hist = np.apply_along_axis(return_hist, -1, plot_data.T) bins = (bins[1:] + bins[:-1]) / 2.0 X, Y = np.meshgrid(sample_times, bins) @@ -353,27 +436,6 @@ def check_limit(lim): assert len(lim) == 2, f"{lim} is no valid plot-limit." return lim -def get_time_limits_from_obs_data(transient): - """ - A function that goes through the lc data and finds the time range that encompasses all data points. - """ - - xmin = np.min([t_arr.min() for t_arr in transient.light_curve_times.values()]) - xmax = np.max([t_arr.max() for t_arr in transient.light_curve_times.values()]) - - return (0.8*xmin, 1.2*xmax) - -def get_mag_limits_from_obs_data(transient, filter_names): - """ - A function that goes through the lc data and finds the magnitude range for each filter. - """ - - ylim = {} - for filt in filter_names: - ylim[filt] = (1.2*transient.light_curves[filt].max(), 0.8*transient.light_curves[filt].min()) - - return ylim - def plot_observations(ax, transient, color="k",**kwargs): obs_times, obs_lc, obs_unc = transient.light_curve_times, transient.light_curves, transient.light_curve_uncertainties if 'filter' in kwargs: @@ -401,7 +463,6 @@ def analysis_plot_geometry(filters_to_plot, ncol=2): wpanel = 3. nrow = int(np.ceil(len(filters_to_plot) / ncol)) - fig, axes = plt.subplots(nrow, ncol) figsize = (1.5 * (lspace + wpanel * ncol + wspace * (ncol - 1) + trspace), 1.5 * (bspace + hpanel * nrow + hspace * (nrow - 1) + trspace)) @@ -409,25 +470,27 @@ def analysis_plot_geometry(filters_to_plot, ncol=2): fig, axes = plt.subplots(nrow, ncol, figsize=figsize, squeeze=False) fig.subplots_adjust(left=lspace / figsize[0], bottom=bspace / figsize[1], - right=1. - trspace / figsize[0], - top=1. - trspace / figsize[1], + right = 1. - trspace / figsize[0], + top = 1. - trspace / figsize[1], wspace=wspace / wpanel, hspace=hspace / hpanel) - if len(filters_to_plot) % 2: - axes[-1, -1].axis('off') + if len(filters_to_plot) % ncol: + for i in range(len(filters_to_plot) % ncol, ncol): + axes[-1, i-ncol].axis('off') return fig, axes + def get_time_limits_from_obs_data(transient): """ A function that goes through the lc data and finds the time range that encompasses all data points. """ - - xmin = np.min([t_arr.min() for t_arr in transient.light_curve_times.values()]) - xmax = np.max([t_arr.max() for t_arr in transient.light_curve_times.values()]) - return (0.8*xmin, 1.2*xmax) + xmin = np.min([t_arr.min() for t_arr in transient.light_curve_times.values()]) + xmax = np.max([t_arr.max() for t_arr in transient.light_curve_times.values()]) + + return (0.9*xmin, 1.1*xmax) def get_mag_limits_from_obs_data(transient, filter_names): """ @@ -436,7 +499,8 @@ def get_mag_limits_from_obs_data(transient, filter_names): ylim = {} for filt in filter_names: - ylim[filt] = (1.2*transient.light_curves[filt].max(), 0.8*transient.light_curves[filt].min()) - + min_mag = transient.light_curves[filt].min() + max_mag = transient.light_curves[filt].max() + ylim[filt] = (min(1.05*max_mag, 1+ max_mag), max(min_mag-1, 0.95*min_mag)) return ylim diff --git a/nmma/eos/eos_likelihood.py b/nmma/eos/eos_likelihood.py index e1dc6455..1db3cab2 100644 --- a/nmma/eos/eos_likelihood.py +++ b/nmma/eos/eos_likelihood.py @@ -10,11 +10,13 @@ from scipy.special import logsumexp from scipy.stats import norm from scipy.ndimage import gaussian_filter +# from scipy.spatial import ConvexHull from matplotlib import pyplot as plt from bilby.core.prior import WeightedCategorical, PriorDict from .eos_processing import EoSConverter from ..core.base import NMMALikelihood -from ..core.utils import fading_cmap +from ..core.utils import nan_level +from ..core.plotting_utils import fading_cmap def setup_tabulated_eos_priors(args, priors, logger=None): if logger: @@ -64,10 +66,10 @@ def setup_submodel_conversion(self): def final_diagnostics(self, bestfit_params, args, result=None, fig = None): matplotlib.rcParams.update({'font.size': 16, 'font.family': 'serif'}) - # matplotlib.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif'] + bestfit_params =self.parameter_conversion(bestfit_params) - radii, masses, lambdas = self.sub_model.eos_converter.macro_parameters.values() + if fig is None: x_lim = (min(np.min(radii)-0.3, 9), max(np.max(radii)+0.3, 15)) y_lim = (masses[0], masses[-1]+0.1) @@ -82,6 +84,11 @@ def final_diagnostics(self, bestfit_params, args, result=None, fig = None): ax = constraint.plot(ax=ax, color=color) else: ax = fig.axes[0] + xlow, xhigh = ax.get_xlim() + ylow, yhigh = ax.get_ylim() + ax.set_xlim(min(np.min(radii)-0.3, xlow), max(np.max(radii)+0.3, xhigh)) + ax.set_ylim(min(masses[0], ylow), max(masses[-1]+0.1, yhigh)) + labels = [line.get_label() for line in fig.legends[0].legend_handles] for constraint in self.sub_model.constraints: if constraint.name in labels: @@ -89,7 +96,31 @@ def final_diagnostics(self, bestfit_params, args, result=None, fig = None): color = ax._get_lines.get_next_color() ax = constraint.plot(ax=ax, color=color) fig.legends.clear() ## remove old legend to avoid duplicates - ax.plot(radii, masses, label=f'{args.label} Best fit EoS', zorder=10) + + + line =ax.plot(radii, masses, label=f'{args.label}',linewidth=3, zorder=10) + + if result is not None: + cmap = fading_cmap(line[0].get_color()) + posterior = self.parameter_conversion(result.posterior) + # posterior['log_post'] = posterior.log_likelihood + posterior.log_prior + # post_weights = np.exp(posterior.log_post) + post_weights = None + eos_data = self.sub_model.eos_converter.macro_conversion(posterior) + max_masses = np.max([eos[1][-1] for eos in eos_data]) + mass_range = np.linspace(1.0, max_masses, 151) + show_radii = np.empty((len(mass_range), len(eos_data))) + + for i, eos in enumerate(eos_data): + show_radii[:,i] = np.interp(mass_range, eos[1], eos[0], right = np.nan) + + for level in [0.5, 0.9]: + bounds = np.array([nan_level(radii, level, post_weights) for radii in show_radii]) + ax.fill_betweenx(mass_range, bounds[:, 0], bounds[:, 1], color=cmap(1-0.5*level), zorder=1) + + if result.injection_parameters is not None: + inj_eos = self.sub_model.eos_converter.macro_conversion(result.injection_parameters) + ax.plot(inj_eos[0], inj_eos[1], label='Injection', color='black', linestyle='dashed', linewidth=3, zorder=10) fig.legend(ncols=2, loc='upper center', bbox_to_anchor=(0.5, 0.00), handlelength=2) @@ -207,7 +238,7 @@ def initialise_from_dict(self, constraint_dict): elif constraint_kind == 'mass_radius': for label, constraint in sub_constraints.items(): constraint_list.append(MassRadiusConstraint( - file_path=constraint.get('posterior', None), + file_path=constraint.get('posterior', constraint.get('file_path', None)), name=label, arxiv_ref=constraint.get('arxiv', None) )) @@ -333,16 +364,17 @@ def log_likelihood(self, parameters, local_parameters=None): tov_mass = local_parameters['masses'][-1] return self.lognorm_method(tov_mass, loc=self.mass, scale=self.error) - def plot(self, ax, resolution = 100, **kwargs): + def plot(self, ax, **kwargs): """Plot the mass constraint on the given figure.""" x_lim = ax.get_xlim() dummy_line = ax.plot([], [], label=self.name, linestyle=self.linestyle, linewidth=2.5, **kwargs) line = ax.hlines(self.mass, *x_lim, linestyle=self.linestyle, linewidth=2.5, zorder=3, **kwargs) - cmap = fading_cmap(line.get_color()) + cmap = fading_cmap(dummy_line[0].get_color()) levels = [0.95, 0.68] for i, level in enumerate(levels): - ax.hlines(self.mass + (i+1)*self.error, *x_lim, color=cmap(0.9*level), linewidth=1.5, zorder=1) - ax.hlines(self.mass - (i+1)*self.error, *x_lim, color=cmap(0.9*level), linewidth=1.5, zorder=1) + ax.fill_between(x_lim, self.mass - (i+1)*self.error, self.mass + (i+1)*self.error, color=cmap(0.8*level), zorder=2-i) + # ax.hlines(self.mass + (i+1)*self.error, *x_lim, color=cmap(0.9*level), linewidth=1.5, zorder=1) + # ax.hlines(self.mass - (i+1)*self.error, *x_lim, color=cmap(0.9*level), linewidth=1.5, zorder=1) return ax diff --git a/nmma/post_processing/plotting_routines.py b/nmma/post_processing/plotting_routines.py index a500f28b..f64a3758 100644 --- a/nmma/post_processing/plotting_routines.py +++ b/nmma/post_processing/plotting_routines.py @@ -3,58 +3,30 @@ import numpy as np import pandas as pd import matplotlib +from matplotlib.ticker import MaxNLocator import seaborn from matplotlib import gridspec, pyplot as plt from ast import literal_eval -import itertools from ..core.conversion import chirp_mass_and_eta_to_component_masses, tidal_deformabilities_and_mass_ratio_to_eff_tidal_deformabilities, label_mapping from ..core import utils, parsing from ..core import plotting_utils as corepu from .parser import corner_plot_parser -color_array = corepu.fig_setup() -nmma_colors = itertools.cycle(color_array) +nmma_colors = corepu.fig_setup() - -def plot_multi_corner(args, key_selection=None): - - plot_kwargs = literal_eval(args.kwargs) - quantiles = [0.16, 0.5, 0.84] - fig = None - labels = [lab for lab in args.label_name] if args.label_name is not None else [f for f in args.posterior_files] - for i, f in enumerate(args.posterior_files): - plot_keys, plot_labels = corepu.plotting_parameters_from_priors(args.prior, keys=key_selection).items() - if args.injection_json is not None: - truths = utils.read_injection_file(args.injection_json) - truths = truths.iloc[args.injection_num].to_dict() - truths = np.array([truths[k] for k in plot_keys]) - if args.verbose: - print("\nLoaded Injection:") - print(f"Truths from injection: {truths}") - elif args.bestfit_params is not None: - truths = utils.read_bestfit_from_json(args.bestfit_json, plot_keys, args.verbose) - else: - truths = None - - fig = setup_corner_plot(f, [], label =labels[i], truths = truths, fig=fig, - quantiles=quantiles, plot_keys=plot_keys, default_labels = plot_labels, **plot_kwargs) - - filename, ext = os.path.splitext(args.output) - if not ext: - filename = os.path.join(os.getcwd(), f"{filename}.png") - fig.savefig(filename, bbox_inches="tight", dpi=300) - print("\nSaved corner plot:", filename) - - -def setup_corner_plot(posterior_samples,limits = None, plot_keys = None, fig = None, - injection=None, post_dir = None, default_labels={}, **plot_kwargs): +def setup_plot_quantities(posterior_samples, limits, plot_keys, injection, post_dir = None, default_labels={}, **plot_kwargs): + matplotlib.rcParams.update({'font.size': 16, 'font.family': 'serif'}) #load samples posterior_samples = utils.get_posteriors(posterior_samples, post_dir) + best_fit = posterior_samples.iloc[posterior_samples['log_likelihood'].idxmax()] if plot_keys is None: plot_keys = posterior_samples.columns.tolist() # show all we can + for key in ['log_likelihood', 'log_prior']: + if key in plot_keys: + plot_keys.remove(key) # but not the likelihood itself if limits is None: - limits = [(np.inf, -np.inf) for key in plot_keys] # will adjust more permissively later + limits = [(np.inf, -np.inf) for _ in plot_keys] # will adjust more permissively later # find what to actually plot plot_samples, plot_labels, titles = [], [], [] for i, k in enumerate(plot_keys): @@ -78,15 +50,112 @@ def setup_corner_plot(posterior_samples,limits = None, plot_keys = None, fig = N plot_samples.append(np.linspace(100*cur_max, 1001*cur_max, posterior_samples.shape[0])) plot_labels.append('') titles.append('') + plot_samples = np.column_stack(plot_samples) - if injection is not None: truths = [injection.get(key, None) for key in plot_keys] else: truths = None - # limits = ((np.amin(posterior_samples[k]), np.amax(posterior_samples[k])) for k in plot_keys) + + + return plot_samples, plot_labels, titles, limits, truths, best_fit + +def plot_histograms_only(posterior_samples,limits = None, plot_keys = None, fig = None, + injection=None, post_dir = None, default_labels={}, best_fit = False, ncols=None, loc_labels='left',title_kwargs ={}, **plot_kwargs): + plot_samples, plot_labels, titles, limits, truths, best_fit = setup_plot_quantities( + posterior_samples, limits, plot_keys, injection, post_dir, default_labels, **plot_kwargs) + + if fig is None: + fig, axes = corepu.setup_multi_axes(len(plot_keys), ncols=ncols) + else: + axes = fig.get_axes() + label = plot_kwargs.pop('label', None) + color = plot_kwargs.pop('color', next(nmma_colors)) + for i, ax in enumerate(axes): + if i >= len(plot_keys): + ax.axis('off') # Hide any extra subplots + continue + ax.hist(plot_samples[:, i], bins=50, density=True, color=color, alpha=0.7,histtype = 'step', **plot_kwargs) + if loc_labels == 'left': + ax.yaxis.set_label_position("left") + ax.set_ylabel(plot_labels[i], fontsize=16) + elif loc_labels == 'top': + ax.xaxis.set_label_position("top") + ax.set_xlabel(plot_labels[i], fontsize=16) + # ax.xaxis.set_label_coords(0.5, -0.3) + ax = set_title(ax, titles[i], plot_labels[i], color, **title_kwargs) + + if truths is not None and truths[i] is not None: + ax.axvline(truths[i], color='tab:orange', linestyle='--') + + if isinstance(best_fit, (dict, pd.Series)) and plot_keys[i] in best_fit: + ax.axvline(best_fit[plot_keys[i]], color=color, linestyle='-.') + + ax.xaxis.set_major_locator(MaxNLocator(nbins=4, min_n_ticks=3, prune='both')) + [l.set_rotation(45) for l in ax.get_xticklabels()] + [l.set_rotation(45) for l in ax.get_xticklabels(minor=True)] + ax.autoscale(enable=True, axis='x', tight=True) + ax.set_yticks([]) + ax.set_yticklabels([]) + + + # allow joint legend + if label: + fig.legends.clear() + fig.axes[0].plot([], [], label = label, color= color, **plot_kwargs) + fig.legend(ncols=2, loc='upper center', bbox_to_anchor=(0.5, 0.), handlelength=2) + + return fig, limits + + + + + +def plot_multi_corner(args, key_selection=None): + + plot_kwargs = literal_eval(args.kwargs) + quantiles = [0.16, 0.5, 0.84] + fig = None + labels = [lab for lab in args.label_name] if args.label_name is not None else [f for f in args.posterior_files] + for i, f in enumerate(args.posterior_files): + plot_keys, plot_labels = corepu.plotting_parameters_from_priors(args.prior, keys=key_selection).items() + if args.injection_json is not None: + truths = utils.read_injection_file(args.injection_json) + truths = truths.iloc[args.injection_num].to_dict() + truths = np.array([truths[k] for k in plot_keys]) + if args.verbose: + print("\nLoaded Injection:") + print(f"Truths from injection: {truths}") + elif args.bestfit_params is not None: + truths = utils.read_bestfit_from_json(args.bestfit_json, plot_keys, args.verbose) + else: + truths = None + + fig = setup_corner_plot(f, [], label =labels[i], truths = truths, fig=fig, + quantiles=quantiles, plot_keys=plot_keys, default_labels = plot_labels, **plot_kwargs) + + filename, ext = os.path.splitext(args.output) + if not ext: + filename = os.path.join(os.getcwd(), f"{filename}.png") + fig.savefig(filename, bbox_inches="tight", dpi=300) + print("\nSaved corner plot:", filename) + + +def setup_corner_plot(posterior_samples,limits = None, plot_keys = None, fig = None, + injection=None, post_dir = None, default_labels={}, **plot_kwargs): + + plot_samples, plot_labels, titles, limits, truths, _ = setup_plot_quantities( + posterior_samples, limits, plot_keys, injection, post_dir, default_labels, **plot_kwargs) + color = plot_kwargs.pop('color', next(nmma_colors)) - fig = corner_plot(plot_samples, plot_labels, limits, fig=fig, truths= truths, color = color, titles=titles, show_titles = False, **plot_kwargs) + fig = corner_plot(plot_samples, plot_labels, limits, fig=fig, truths= truths, color = color, show_titles = False, **plot_kwargs) + + + # adjust titles + axes = fig.get_axes() + for i, title in enumerate(titles): + ax = axes[i*len(plot_labels) + i] + ax = set_title(ax, title, plot_labels[i], color) # allow joint legend if 'label' in plot_kwargs: @@ -98,11 +167,11 @@ def setup_corner_plot(posterior_samples,limits = None, plot_keys = None, fig = N def corner_plot(plot_samples, labels, limits, fig = None, save=False, **kwargs): - matplotlib.rcParams.update({'font.size': 16, 'font.family': 'Times New Roman'}) + matplotlib.rcParams.update({'font.size': 16, 'font.family': 'Serif'}) matplotlib.rcParams['text.usetex'] = (os.environ.get("CI") != 'true') default_kwargs = dict(bins=50, smooth=1.3, label_kwargs=dict(fontsize=16), show_titles=True, - title_kwargs=dict(fontsize=16), color = color_array[0], #color='#0072C1', - truth_color='tab:orange', quantiles=[0.05, 0.5, 0.95], + title_kwargs=dict(fontsize=16), color = next(nmma_colors), #color='#0072C1', + truth_color='tab:orange', quantiles=[0.16, 0.5, 0.84], levels=(0.10, 0.32, 0.68, 0.95), median_line=True, title_fmt=".2f", plot_density=False, plot_datapoints=False, fill_contours=True, max_n_ticks=4, hist_kwargs={'density': True}) @@ -114,6 +183,27 @@ def corner_plot(plot_samples, labels, limits, fig = None, save=False, **kwargs): return fig +def set_title(ax, title, label, color, **title_kwargs): + old_text = ax._left_title.get_text() + if old_text == '': # meaning no title set, so we can set our own + text = f'{label}={title}' + ax.set_title(text, loc = 'left', color=color, **title_kwargs) + + elif ax._right_title.get_text() == '': # meaning only left title set yet + old_text_parts = old_text.split('=') + ax._left_title.set_text(old_text_parts[-1]) # remove old title + ax.set_title(title, loc='right', color=color, **title_kwargs) + + elif ax.get_title() == '': # meaning only left and right title set yet + ax.set_title(title, loc='center', color=color, pad=10, **title_kwargs) + else: + # we already have a title in all three locations, so we rather remove all + print("Warning: More than three titles set for this axis, so all titles were removed to avoid confusion. ") + print(f"All intended title information was {label}: {old_text} (first, left), {ax._right_title.get_text()} (second, right), {ax.get_title()} (third, center), {title} (new).") + ax._left_title.set_text('') + ax._right_title.set_text('') + ax.set_title('', loc='center', **title_kwargs) + return ax def resampling_corner_plot(posterior_samples, solution, outdir, withNSBH): From 17ae5b68a6a53792d6f696b1f9317ae342c9d23b Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Mon, 4 May 2026 13:48:46 +0200 Subject: [PATCH 11/13] BUG: adjust proper handling of timeshift and trigger time in injection setups --- nmma/em/em_likelihood.py | 2 +- nmma/em/lightcurve_generation.py | 5 ++++- nmma/em/utils.py | 4 ++-- nmma/population/pop_likelihood.py | 35 ++++++++++++------------------- requirements.txt | 4 +--- 5 files changed, 21 insertions(+), 29 deletions(-) diff --git a/nmma/em/em_likelihood.py b/nmma/em/em_likelihood.py index a0e4c1f9..841cb891 100644 --- a/nmma/em/em_likelihood.py +++ b/nmma/em/em_likelihood.py @@ -21,7 +21,7 @@ def setup_em_kwargs(priors, data_dump, args, logger=None): light_curve_model = model.create_light_curve_model_from_args(lc_model, args, filters) trigger_time = read_trigger_time(None, args) light_curve_data = utils.setup_filtered_lc_data(light_curve_data, trigger_time) - light_curve_data = utils.check_model_time_consistency(light_curve_data, light_curve_model, priors) + light_curve_data = utils.check_model_time_consistency(light_curve_data, light_curve_model, priors, args.injection) sys_handler = systematics.FilterSystematicsHandler(filters, data_dump['systematics_dict'], error_budget=args.em_error_budget, light_curve_times=light_curve_data[0]) diff --git a/nmma/em/lightcurve_generation.py b/nmma/em/lightcurve_generation.py index bd6983e9..e83b8540 100644 --- a/nmma/em/lightcurve_generation.py +++ b/nmma/em/lightcurve_generation.py @@ -11,6 +11,7 @@ from scipy.interpolate import CubicSpline from . import utils +from ..core.utils import read_trigger_time try: import afterglowpy @@ -823,7 +824,9 @@ def create_light_curve_data( injection_parameters = light_curve_model.parameter_conversion(injection_parameters) filters = utils.set_filters(args) - trigger_time = injection_parameters.get("trigger_time", 0.) + trigger_time = read_trigger_time(injection_parameters, args) + if trigger_time is None: + trigger_time = 0. if rng is None: rng = np.random.default_rng(args.generation_seed) if getattr(args, 'absolute', False): diff --git a/nmma/em/utils.py b/nmma/em/utils.py index dfdaa0e9..e4fca364 100644 --- a/nmma/em/utils.py +++ b/nmma/em/utils.py @@ -103,6 +103,7 @@ def set_filters(args): em_detectors = args.em_detectors.split(",") else: em_detectors = getattr(args, "em_detectors", []).copy() + em_detectors = [det.strip().lower() for det in em_detectors] filters = [] if 'ztf' in em_detectors: em_detectors.remove('ztf') @@ -110,7 +111,7 @@ def set_filters(args): if 'lsst' in em_detectors: em_detectors.remove('lsst') filters.extend( ["lsstg", "lsstr", "lssti", "lsstz", "lssty"] ) - elif hasattr(args, "rubin_ToO_type"): + elif getattr(args, "rubin_ToO_type"): if args.rubin_ToO_type == 'platinum': filters.extend( ["ps1::g","ps1::r","ps1::i","ps1::z","ps1::y"] ) elif args.rubin_ToO_type == 'gold': @@ -129,7 +130,6 @@ def set_filters(args): if em_detectors: raise NotImplementedError(f"{em_detectors} not implemented yet.") ## to be extended - return filters diff --git a/nmma/population/pop_likelihood.py b/nmma/population/pop_likelihood.py index d94ed6a5..833584a4 100644 --- a/nmma/population/pop_likelihood.py +++ b/nmma/population/pop_likelihood.py @@ -1,22 +1,15 @@ import numpy as np from scipy.stats import uniform, truncnorm -# class NeutronStarPopulationLikelihood(NMMALikelihood): -# """ -# Base class for neutron star population likelihoods. -# This class is intended to be subclassed for specific population models. -# """ - -# def __init__(self, pop_model): -# super().__init__(pop_model) class NeutronStarPopulation: - #based on https://doi.org/10.3847/2041-8213/ac2f3e """ - Object to compute the likelihood of a binary to align with - a given population model from Landry & Read.""" - def __init__(self, model_name): - self.beta = 0.0 + Object to compute the likelihood of a binary to align with + a given population model from Landry & Read. + (https://doi.org/10.3847/2041-8213/ac2f3e) + """ + def __init__(self, model_name, beta=0.0): + self.beta = beta if model_name.lower() == 'flat': m_min, m_max = 1.1, 2.0 self.distribution = uniform(loc=m_min, scale=m_max) @@ -24,14 +17,12 @@ def __init__(self, model_name): m_min, m_max = 1.1, 2.1 loc = 1.5 scale = 1.0 - trunc_low, trunc_high = (m_min - loc) / scale, (m_max - loc) / scale - self.distribution = truncnorm(trunc_low, trunc_high, loc=loc, scale=scale) - - + trunc_low = (m_min - loc) / scale + trunc_high = (m_max - loc) / scale + self.distribution = truncnorm(trunc_low, trunc_high, + loc=loc, scale=scale) def log_likelihood(self, parameters): - # - return ( self.distribution.logpdf(parameters['mass_1_source']) - + self.distribution.logpdf(parameters['mass_2_source']) - + np.log(parameters['mass_ratio']**self.beta) - ) + return (self.distribution.logpdf(parameters['mass_1_source']) + + self.distribution.logpdf(parameters['mass_2_source']) + + np.log(parameters['mass_ratio']**self.beta)) diff --git a/requirements.txt b/requirements.txt index 8b34e285..1ec16d56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ future -bilby>=2.7.1 +bilby>=2.8 bilby_pipe>=1.7.0 schwimmbad colorcet @@ -18,6 +18,4 @@ toml ligo.skymap healpy scikit-learn -tensorflow>=2.19; platform_system == "Darwin" -tensorflow>=2.19; platform_system == "Windows" tensorflow-cpu>=2.19; platform_system == "Linux" From e45a0d424e24bb311925c7b0aeb11d897ca49bab Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Mon, 4 May 2026 14:13:51 +0200 Subject: [PATCH 12/13] bug: remove breakpoints --- nmma/core/mpi_setup.py | 1 - nmma/post_processing/maximum_mass_constraint.py | 9 +++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/nmma/core/mpi_setup.py b/nmma/core/mpi_setup.py index 6517af32..fa954472 100644 --- a/nmma/core/mpi_setup.py +++ b/nmma/core/mpi_setup.py @@ -142,7 +142,6 @@ def __init__( plot= False, meta_data = {}, ): - breakpoint() super().__init__(args, prior, likelihood, injection_parameters, plot, skip_import_verification = False) diff --git a/nmma/post_processing/maximum_mass_constraint.py b/nmma/post_processing/maximum_mass_constraint.py index 21c12b7f..8f2e12e4 100644 --- a/nmma/post_processing/maximum_mass_constraint.py +++ b/nmma/post_processing/maximum_mass_constraint.py @@ -136,12 +136,9 @@ def LogLikelihood(self, x): mdisk = 10**log10_mdisk mej_dyn = 10**log10_mej_dyn - try: - b1 = baryonic_mass(mass_1, EOS, self.eos_path_macro, self.eos_path_micro) - b2 = baryonic_mass(mass_2, EOS, self.eos_path_macro, self.eos_path_micro) - m_rem_b = b1 + b2 - mdisk - mej_dyn #calculate the baryonic remnant mass - except: - breakpoint() + b1 = baryonic_mass(mass_1, EOS, self.eos_path_macro, self.eos_path_micro) + b2 = baryonic_mass(mass_2, EOS, self.eos_path_macro, self.eos_path_micro) + m_rem_b = b1 + b2 - mdisk - mej_dyn #calculate the baryonic remnant mass if self.use_M_max: m_threshold = baryonic_Kepler_mass(mTOV, R_14, ratio_R, delta) #if the Kepler limit is the threshold, use the quasiuniversal relation From ff8bf7cc0e47443da31b263ad6aa73aa899a9023 Mon Sep 17 00:00:00 2001 From: Hen42rik Date: Mon, 4 May 2026 14:29:28 +0200 Subject: [PATCH 13/13] bugfix:enforced gw_input removed --- nmma/joint/generation.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/nmma/joint/generation.py b/nmma/joint/generation.py index 81768754..97f26cdb 100644 --- a/nmma/joint/generation.py +++ b/nmma/joint/generation.py @@ -183,7 +183,7 @@ def __init__(self, args, unknown_args, logger=None): self.plot_injection = args.plot_injection - self.sampler = "dynesty" + self.sampler = args.sampler self.sampling_seed = args.sampling_seed self.data_dump_file = f"{self.data_directory}/{self.label}_data_dump.pickle" @@ -272,11 +272,13 @@ def adjust_priors_and_data(self, args, logger): pass finally: # correct injection only once lambdas are properly set - # some routines return np.float32 which raises errors downstream in results - # processing, so we convert to float here. Should be handled more elegantly + # some routines return np.float32 which raises errors + # downstream in results processing, so we convert to float + # here. Should be handled more elegantly args.injection_dict = {k: float(v) for k, v in converted_injection.items()} - self.gw_inputs= NMMAGravitationalWaveInput(args, self.unknown_args) - data_dump |= dict(waveform_generator=self.gw_inputs.waveform_generator, + if 'gw' in messengers: + self.gw_inputs= NMMAGravitationalWaveInput(args, self.unknown_args) + data_dump |= dict(waveform_generator=self.gw_inputs.waveform_generator, ifo_list=self.gw_inputs.interferometers) # EM SETUP