diff --git a/nmma/core/base.py b/nmma/core/base.py index 942806dc..c151a11a 100644 --- a/nmma/core/base.py +++ b/nmma/core/base.py @@ -1,6 +1,5 @@ import inspect -import io -import contextlib +import os import h5py from ast import literal_eval import numpy as np @@ -15,6 +14,7 @@ from .utils import input_obj_to_str, read_bestfit_from_posterior from .constants import set_cosmology from .conversion import cosmology_to_distance +from .parsing import single_messenger_analysis_parsing, nmma_base_parsing def initialisation_args_from_signature_and_namespace(_callable, namespace, prefixes = []): prefixes.append('') @@ -98,12 +98,15 @@ def final_diagnostics(self, bestfit_params, args, result=None): The figure object containing the plot """ - pass + try: + return self.sub_model.final_diagnostics(bestfit_params, args, result) + except AttributeError: + pass def post_process_bestfit(self, args, result=None): bestfit_params = read_bestfit_from_posterior(args) bestfit_params = self.parameter_conversion(bestfit_params) - self.final_diagnostics(bestfit_params, args, result) + return self.final_diagnostics(bestfit_params, args, result) def check_parameter_equivalencies(self, parameter_names): """Check for equivalent parameters and terminate if found""" @@ -272,28 +275,29 @@ def check_priors_and_likelihood_for_nmma(priors, likelihood): constraints = {k: priors.pop(k) for k in priors.copy().keys() if isinstance(priors[k], Constraint)} likelihood.constraints.update(constraints) + test_draw = priors.sample(1) test_conversion = priors.conversion_function(test_draw) if len(set(test_conversion.keys()) ) != len(test_conversion.keys()): priors.conversion_function = priors.default_conversion_function likelihood.conv_functions.append(likelihood.priors.conversion_function) + # add final conversions likelihood.setup_parameter_conversion() return priors, likelihood def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): - priors, likelihood = check_priors_and_likelihood_for_nmma(priors, likelihood) - + if isinstance(args, dict): + def_args = nmma_base_parsing(single_messenger_analysis_parsing) + def_args.__dict__.update(args) + args = def_args # fetch the additional sampler kwargs - if isinstance(args.sampler_kwargs, str): - sampler_kwargs = literal_eval(args.sampler_kwargs) - else: - sampler_kwargs = args.sampler_kwargs + sampler_kwargs = getattr(args, "sampler_kwargs", {}) print("Running with the following additional sampler_kwargs:") print(sampler_kwargs) # check if it is running with reactive sampler - nlive = None if args.reactive_sampling else args.nlive + nlive = None if getattr(args, 'reactive_sampling', False) else args.nlive if nlive is None and args.sampler != "ultranest": raise ValueError("reactive sampling is only available for ultranest, " "please set nlive or use ultranest sampler") @@ -307,7 +311,7 @@ def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): sampler_kwargs["niter"] = 1 elif args.sampler == "dynesty": sampler_kwargs["maxiter"] = 1 - + result = run_sampler( likelihood, priors, @@ -327,6 +331,7 @@ def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): return try: + result.posterior = likelihood.posterior_conversion(result.posterior) result.save_to_file() result.save_posterior_samples() except FileMovedError: @@ -359,49 +364,50 @@ def bilby_sampling(likelihood, priors, args, injection_parameters=None, rank=0): if args.bestfit or args.plot: likelihood.post_process_bestfit(args, result) + return result def multi_analysis_loop(args, analysis_setup): - + USE_MPI = False # check if it is running under mpi try: from mpi4py import MPI rank = MPI.COMM_WORLD.Get_rank() + USE_MPI = True except ImportError: rank = 0 - if rank != 0: - # Create buffers - stdout_buffer = io.StringIO() - stderr_buffer = io.StringIO() - - # Redirect python output into buffers - redirect_out = contextlib.redirect_stdout(stdout_buffer) - redirect_err = contextlib.redirect_stderr(stderr_buffer) - else: - redirect_out = contextlib.nullcontext() - redirect_err = contextlib.nullcontext() + if rank != 0 and not getattr(args, 'verbose', False): + devnull = os.open(os.devnull, os.O_WRONLY) + os.dup2(devnull, 1) + os.dup2(devnull, 2) - with redirect_out, redirect_err: - if getattr(args, 'multi', None): - sub_runs = [] - if len(args.multi) == 1: - arg, vals = list(args.multi.items())[0] - for i, val in enumerate(vals): - run_args = deepcopy(args) - setattr(run_args, arg, val) - setattr(run_args, 'label', f"{args.label}_{i}") - sub_runs.append(run_args) - else: - for run_name, changes in args.multi.items(): - run_args = deepcopy(args) - setattr(run_args, 'label', f"{args.label}_{run_name}") - for key, value in changes.items(): - if key not in args: - raise KeyError(f"{key} not a known argument... please remove") - setattr(run_args, key, value) - sub_runs.append(run_args) + if getattr(args, 'multi', None): + sub_runs = [] + if len(args.multi) == 1: + arg, vals = list(args.multi.items())[0] + for i, val in enumerate(vals): + run_args = deepcopy(args) + setattr(run_args, arg, val) + setattr(run_args, 'label', f"{args.label}_{i}") + sub_runs.append(run_args) + else: + for run_name, changes in args.multi.items(): + run_args = deepcopy(args) + setattr(run_args, 'label', f"{args.label}_{run_name}") + for key, value in changes.items(): + if key not in args: + raise KeyError(f"{key} not a known argument... please remove") + setattr(run_args, key, value) + sub_runs.append(run_args) + else: + sub_runs = [args] + for run_args in sub_runs: + priors, likelihood, injection_parameters = analysis_setup(run_args) + priors, likelihood = check_priors_and_likelihood_for_nmma(priors, likelihood) + if USE_MPI and run_args.sampler =='dynesty': + from .mpi_setup import pbilby_sampling + run_function = pbilby_sampling else: - sub_runs = [args] - for run_args in sub_runs: - priors, likelihood, injection_parameters = analysis_setup(run_args) - bilby_sampling(likelihood, priors, run_args, injection_parameters, rank) \ No newline at end of file + run_function = bilby_sampling + out = run_function(likelihood, priors, run_args, injection_parameters, rank) + return out \ No newline at end of file diff --git a/nmma/core/constants.py b/nmma/core/constants.py index ae8682d3..26249f3b 100644 --- a/nmma/core/constants.py +++ b/nmma/core/constants.py @@ -8,6 +8,7 @@ # helpers mc2 = const.M_sun*const.c**2 G_per_c2 = const.G/const.c**2 +seconds_a_day = 24*3600 @@ -34,6 +35,11 @@ msun_to_ergs = mc2.cgs.value MeV_per_fm3_to_Msun_per_km3 = 1e54/((mc2).to(u.MeV).value) # 1 MeV/fm**3 is 8.9653E-7 Msun/km**3 +## pulsar timing constants +msun_s = (const.M_sun * const.G / const.c**3).value # geometrised Msun in seconds +msun_mus = msun_s * 1e6 # Msun in microseconds, used for pulsar timing +einstein_factor = (const.G*const.M_sun/const.c**3).value**(2/3) + default_cosmology = cosmology.Planck18 def set_cosmology(cosmology_input=None): """Set the cosmology for the NMMA package. diff --git a/nmma/core/conversion.py b/nmma/core/conversion.py index 6279a99a..b63271c2 100644 --- a/nmma/core/conversion.py +++ b/nmma/core/conversion.py @@ -4,7 +4,7 @@ from scipy.integrate import simpson from astropy import units from astropy import cosmology as cosmo -from .constants import geom_msun_km, msun_to_ergs, get_cosmology, set_cosmology +from .constants import geom_msun_km, msun_to_ergs, msun_s, get_cosmology, set_cosmology from bilby.gw.conversion import ( component_masses_to_chirp_mass, @@ -191,32 +191,50 @@ def convert_mtot_mni(params): params["mrp_c"] = (params["xmix"]*(params["mtot"]-params["mni"])-params["mrp"]) return params +############################## pulsar timing conversions #################################### +def binary_mass_function(m_obs, m_comp, sin_i): + return (m_comp * sin_i)**3 / (m_obs + m_comp)**2 + +def shapiro_delay(m_comp, sin_i): + "see https://arxiv.org/pdf/1007.0933.pdf" + shapiro_range = msun_s*1.e6 * m_comp # in microseconds + orthometric_ratio = sin_i/(1+np.sqrt(1-sin_i**2)) + return shapiro_range * orthometric_ratio**3 + +def einstein_delay_orbital_factor(orbital_period, eccentricity): + "see, e.g., 10.1007/978-3-662-62110-3_1, p.12 " + return msun_s**(2/3) * eccentricity * (orbital_period /2/np.pi)**(1/3) +def simplified_einstein_delay(m_psr, m_comp, einstein_factor): + "see, e.g., 10.1007/978-3-662-62110-3_1, p.12 " + return einstein_factor *m_comp * (m_psr + 2*m_comp) / (m_psr + m_comp)**(4/3) + +def einstein_delay(m_psr, m_comp, orbital_period, eccentricity): + "see, e.g., 10.1007/978-3-662-62110-3_1, p.12 " + einstein_delay_factor = einstein_delay_orbital_factor(orbital_period, eccentricity) + return simplified_einstein_delay(m_psr, m_comp, einstein_delay_factor) + +def mass_parameters_to_sini(total_mass, mass_function, m_comp): + "Invert the binary mass function to get sin(i) for a given total mass and mass function" + return np.cbrt(mass_function * total_mass**2)/m_comp + ############################## EOS-related conversions #################################### -def EOS2Parameters(radii, masses, lambdas, m1_source, m2_source): - ### FIXME: Under what circumstance would these not simply be mass_val[-1], radius_val[-1]? +def EOS_to_ns_parameters(radii, masses, lambdas): TOV_mass = masses.max(axis=-1) TOV_radius = radii[np.argmax(masses)] + R_14, R_16 = np.interp(x=[1.4, 1.6], xp=masses, fp=radii, left=0, right=0) + + return TOV_mass, TOV_radius, R_14, R_16 +def EOS_to_system_parameters(radii, masses, lambdas, m1_source, m2_source): (log_lambda_1, log_lambda_2) = np.interp(x=[m1_source, m2_source], xp= masses, fp=np.log(lambdas), left=-np.inf, right=-np.inf) lambda_1 = np.exp(log_lambda_1) lambda_2 = np.exp(log_lambda_2) - try: - (radius_1, radius_2, R_14, R_16) = np.interp( - x=[m1_source, m2_source, 1.4, 1.6], - xp=masses, fp= radii, left =0, right=0) + (radius_1, radius_2) = np.interp( x=[m1_source, m2_source], + xp=masses, fp= radii, left =0, right=0) - return TOV_mass, TOV_radius, lambda_1, lambda_2, radius_1, radius_2, R_14, R_16 - ## radius interpolation will raise an error if dealing with multiple sources at once - # In that case we return all values as corresponding arrays - except ValueError: - ref = np.ones_like(lambda_1) - (radius_1, radius_2, R_14, R_16) = np.interp( - x=[m1_source, m2_source, 1.4*ref, 1.6*ref], - xp=masses, fp= radii, left =0, right=0) - - return ref*TOV_mass, ref*TOV_radius, lambda_1, lambda_2, radius_1, radius_2, R_14, R_16 + return lambda_1, lambda_2, radius_1, radius_2 def radii_from_qur(parameters): mass_1_source = parameters["mass_1_source"] @@ -638,7 +656,7 @@ def chiBH_fitting( return chi_BH - def bns_parameter_conversion(self, converted_parameters): + def bns_ejecta_conversion(self, converted_parameters): # prevent the output message flooded by these warning messages old = np.seterr() @@ -674,26 +692,48 @@ def bns_parameter_conversion(self, converted_parameters): # total eject mass total_ejeta_mass = 10**log10_mej_dyn + 10**log10_mej_wind + np.seterr(**old) + return log10_mej_dyn, log10_mej_wind, np.log10(total_ejeta_mass), log10_mdisk_fit + + def grb_energy_conversion(self, converted_parameters, log10_mdisk_fit): + # GRB afterglow energy - log10_Ejet = converted_parameters.get("log10_E0", ( - np.log10(converted_parameters.get("ratio_epsilon", 2e-4)) - + np.log10(1.0 - converted_parameters["ratio_zeta"]) - + log10_mdisk_fit + np.log10(msun_to_ergs) ) - ) + log10_Ejet = np.log10(converted_parameters.get("ratio_epsilon", 2e-4)) + log10_Ejet += np.log10(1.0 - converted_parameters["ratio_zeta"]) + log10_Ejet += log10_mdisk_fit + np.log10(msun_to_ergs) - thetaCore = converted_parameters.get("thetaCore", 0.105) + thetaCore = converted_parameters.get("thetaCore", 0.105) ## default about 6 degree, see arxiv:2210.05695 + + if not any(key in converted_parameters for key in ["thetaWing", "alphaWing", "b"]): + return log10_Ejet - np.log10(np.sin(thetaCore/2)**2) + + if "alphaWing" in converted_parameters: + alphaWing = converted_parameters['alphaWing'] + else: + alphaWing = converted_parameters["thetaWing"] / converted_parameters["thetaCore"] + if "b" in converted_parameters: # power law jet - alphaWing = converted_parameters.get("alphaWing", converted_parameters["thetaWing"] / converted_parameters["thetaCore"]) - log10_E0 = np.log10(powerlaw_jet_energy_to_central_isotropic_energy_equivalent(10**log10_Ejet, thetaCore, alphaWing, converted_parameters["b"])) - elif "b" not in converted_parameters and ("thetaWing" in converted_parameters or "alphaWing" in converted_parameters): # gaussian jet - alphaWing = converted_parameters.get("alphaWing", converted_parameters["thetaWing"] / converted_parameters["thetaCore"]) - log10_E0 = np.log10(gaussian_jet_energy_to_central_isotropic_energy_equivalent(10**log10_Ejet, thetaCore, alphaWing)) + jet_func = powerlaw_jet_energy_to_central_isotropic_energy_equivalent + data = np.column_stack((10**log10_Ejet, thetaCore, alphaWing, converted_parameters["b"])) + + else: + jet_func = gaussian_jet_energy_to_central_isotropic_energy_equivalent + data = np.column_stack((10**log10_Ejet, thetaCore, alphaWing)) + + return np.log10([jet_func(*row) for row in data]) + + + + def bns_parameter_conversion(self, parameters): + log10_mej_dyn, log10_mej_wind, log10_mej_total, log10_mdisk_fit = self.bns_ejecta_conversion(parameters) + + if "log10_E0" in parameters: + log10_E0 = parameters["log10_E0"] else: - log10_E0 = log10_Ejet - np.log10(np.sin(thetaCore/2)**2) + log10_E0 = self.grb_energy_conversion(parameters, log10_mdisk_fit) - np.seterr(**old) - converted_ejecta = np.stack((log10_mej_dyn, log10_mej_wind, np.log10(total_ejeta_mass), log10_E0 )) + converted_ejecta = (log10_mej_dyn, log10_mej_wind, log10_mej_total, log10_E0) return np.where(np.isfinite(converted_ejecta), converted_ejecta, -np.inf) @@ -753,6 +793,9 @@ def from_dict(cls, instruction_dict): if 'em' in instruction_dict: conversions.append(instruction_dict['em']) + + if 'custom' in instruction_dict: + conversions.append(instruction_dict['custom']) return cls(*conversions) @@ -807,8 +850,14 @@ def identity_conversion(self, parameters): 'log10_E0' : r'$\log_{10}(E_0{\rm [erg]})$', 'ratio_zeta' : r'$\zeta$', 'alpha' : r'$\alpha$', - 'KNtheta' : r'$\theta_{KN}$', - 'KNphi' : r'$\phi_{KN}$', + 'KNtheta' : r'$\theta_{KN} [^\circ]$', + 'KNphi' : r'$\phi_{KN} [^\circ]$', + # Bu parameters ## + 'vej_dyn' : r'$v_{\rm{dyn}}{\rm [c]}$', + 'vej_wind' : r'$v_{\rm{wind}}{\rm [c]}$', + 'Ye_dyn' : r'$Y_{e,{\rm{dyn}}}$', + 'kappa_Ye' : r'$\kappa_{\rm{Y_e}}$', + 'kappa_v' : r'$\kappa_{v}$', ## GRB parameters ## 'log10_E0' : r'$\log_{10}(E_0{\rm [erg]})$', 'ratio_epsilon' : r'$\epsilon$', @@ -826,11 +875,11 @@ def identity_conversion(self, parameters): 'mni_c' : r'$M_{\rm{Ni}}/M_{\rm{tot}}$', 'mrp_c' : r'$M_{\rm{rp,c}}{\rm [M_{\odot}]}$', ### EOS parameters ### - 'L_sym' : r'$L_{\rm{sym}}{\rm [MeV]}$', - 'K_sym' : r'$K_{\rm{sym}}{\rm [MeV]}$', - 'K_sat' : r'$K_{\rm{sat}}{\rm [MeV]}$', - '3n_sat' : r'$c_{3n_{\rm{sat}}{\rm [c]}$', - '5n_sat' : r'$c_{5n_{\rm{sat}}{\rm [c]}$', + 'L_sym' : r'$L_{\rm sym}{\rm [MeV]}$', + 'K_sym' : r'$K_{\rm sym}{\rm [MeV]}$', + 'K_sat' : r'$K_{\rm sat}{\rm [MeV]}$', + '3n_sat' : r'$c^2_{3n_{\rm sat}}{\rm [c^2]}$', + '5n_sat' : r'$c^2_{5n_{\rm sat}}{\rm [c^2]}$', 'TOV_mass' : r'$M_{\rm{TOV}}{\rm [M_{\odot}]}$', 'R_14' : r'$R_{1.4}{\rm[km]}$', 'lambda_tilde' : r'$\tilde{\Lambda}$', diff --git a/nmma/joint/analysis_run.py b/nmma/core/mpi_setup.py similarity index 81% rename from nmma/joint/analysis_run.py rename to nmma/core/mpi_setup.py index 9f57d954..fa954472 100644 --- a/nmma/joint/analysis_run.py +++ b/nmma/core/mpi_setup.py @@ -2,9 +2,9 @@ import sys import traceback from io import BufferedWriter -from glob import glob from copy import deepcopy import pickle +import signal from functools import wraps from time import time from datetime import timedelta @@ -12,16 +12,16 @@ import numpy as np from pandas import DataFrame from numpy.random import Generator, PCG64, SeedSequence +from schwimmbad import MPIPool, MultiPool -from bilby.core.prior import PriorDict from bilby.core.sampler import base_sampler as bs, dynesty3_utils as dy_utils from bilby.core.sampler.dynesty import dynesty_stats_plot import dynesty from dynesty.plotting import traceplot, runplot -from ..core.conversion import label_mapping -from ..core.utils import rejection_sample, read_bestfit_from_posterior, logger -from .joint_likelihood import MultiMessengerLikelihood +from .conversion import label_mapping +from .utils import rejection_sample, read_bestfit_from_posterior, logger +from .parsing import process_sampler_kwargs def time_storage(func): @@ -42,46 +42,27 @@ class Worker(bs.NestedSampler): Parameters: data_dump: a pickle-file containing all relevant data to create priors and likelihoods. """ - def __init__(self, data_dump, - outdir, label, plot = False, - skip_import_verification = True, + def __init__( + self, args, prior, likelihood, + injection_parameters, plot = False, + skip_import_verification = True, ): - ## Load the data dump - if not data_dump.endswith("_dump.pickle"): - test_out = os.path.join(os.getcwd(), data_dump) - test_dump = glob(f"{test_out}/data/*_dump.pickle") - data_dump = test_dump[0] - with open(data_dump, "rb") as file: - data_dump = pickle.load(file) - - ## Set properties from the data dump - self.data_dump = data_dump - args = data_dump["args"] args.plot = plot self.args = args + self.outdir = args.outdir + self.label = args.label - # If the run dir has not been specified, get it from the args - if outdir is None: - outdir = self.args.outdir - os.makedirs(outdir, exist_ok=True) + os.makedirs(args.outdir, exist_ok=True) - # If the label has not been specified, get it from the args - if label is None: - label = self.args.label - priors = PriorDict.from_json(data_dump["prior_file"]) - - ## Set up the likelihood - likelihood = MultiMessengerLikelihood.setup_from_args( - data_dump, priors, self.args, logger) - super().__init__( - likelihood, priors, outdir, label, - injection_parameters= data_dump.get("injection_parameters", None), + likelihood, prior, self.outdir, self.label, + injection_parameters = injection_parameters, skip_import_verification = skip_import_verification, - plot=self.args.plot, + plot= plot, soft_init=True, + use_ratio = True, ) @@ -150,7 +131,8 @@ class Dynesty(Worker): def __init__( self, - data_dump, outdir, label, + args, prior, likelihood, + injection_parameters = None, maxmcmc=5000, naccept=60, nact=2, @@ -158,13 +140,11 @@ def __init__( sampler_kwargs={}, sampler_init_kwargs={}, plot= False, - ): - super().__init__(data_dump, outdir, label, plot, - skip_import_verification = False) - # for handler in logger.handlers: - # if isinstance(handler, logging.StreamHandler): - # handler.stream = sys.stdout - + meta_data = {}, + ): + super().__init__(args, prior, likelihood, injection_parameters, + plot, skip_import_verification = False) + self.resume_file = f"{self.outdir}/{self.label}_checkpoint_resume.pickle" self.samples_file= f'{self.outdir}/{self.label}_samples.parquet' @@ -177,9 +157,9 @@ def __init__( # dynesty3 sampler kwargs self.dlogz = sampler_kwargs['dlogz'] self.sampler_kwargs = sampler_kwargs - self._init_sampler_kwargs(sampler_init_kwargs, nact, naccept, maxmcmc) self.nlive = sampler_init_kwargs['nlive'] + self.meta_data = meta_data def _init_sampler_kwargs(self, kwargs, nact, naccept, maxmcmc): @@ -503,19 +483,21 @@ def plot_current_state(self): plt.close("all") def storable_metadata(self): - meta_data = self.data_dump.copy() - waveform_generator = meta_data.pop("waveform_generator", None) - if waveform_generator is not None: - meta_data["waveform_generator"] = waveform_generator.__repr__() - ifo_list = meta_data.pop("ifo_list", None) - if ifo_list is not None: - meta_data["ifo_list"] = [ifo.__repr__() for ifo in ifo_list] - - meta_data["args"] = vars(self.args) # convert Namespace to dict for storing + meta_data = self.meta_data + meta_data["args"] = vars(self.args).copy() # convert Namespace to dict for storing meta_data["likelihood"] = self.likelihood.meta_data meta_data["sampler_kwargs"] = self.init_sampler_kwargs meta_data["run_sampler_kwargs"] = self.sampler_kwargs + meta_data = self.floatify_dict(meta_data) return meta_data + + def floatify_dict(self, d): + for k, v in d.items(): + if isinstance(v, dict): + d[k] = self.floatify_dict(v) + elif isinstance(v, np.floating): + d[k] = float(v) + return d def format_result( self, @@ -604,3 +586,114 @@ def format_result( logger.warning(f"Failed to create diagnostic plots: {e} \n{traceback.format_exc()}") logger.info("Finished formatting result.") return result + + + +def pbilby_sampling( + likelihood, prior, args, + injection_parameters, rank, + pool_type = 'mpi', + meta_data = {}, + **kwargs +): + + default_kwargs = dict( + sampler_kwargs={}, + sampling_seed=42, + plot=True, + # + maxmcmc=5000, + naccept=60, + nact=2, + check_point_delta_t=1800, + n_check_point=2000, + max_its=1e10, + max_run_time=1e10, + checkpoint_plot=False, + # + rejection_sample_posterior=True, + result_format="hdf5", + ) + + # priority: kwargs > args > defaults + use_kwargs = {} + for key in default_kwargs.keys(): + if key in kwargs: + use_kwargs[key] = kwargs[key] + elif hasattr(args, key): + use_kwargs[key] = getattr(args, key) + else: + use_kwargs[key] = default_kwargs[key] + + use_kwargs |= kwargs # in case there are additional kwargs not in default_kwargs + + # Initialise a worker. this needs a global scope to allow + # persistence of states beyond the pool's scope. + # Otherwise emulators retrace on each evaluation. + global worker + if rank == 0: + sampler_init_kwargs, run_kwargs = process_sampler_kwargs( + kwargs.pop('sampler_kwargs', {}), use_kwargs) + + worker = Dynesty( + args, prior, likelihood, + injection_parameters, + maxmcmc=use_kwargs['maxmcmc'], + nact=use_kwargs['nact'], + naccept=use_kwargs['naccept'], + sampling_seed=use_kwargs['sampling_seed'], + sampler_kwargs = run_kwargs, + sampler_init_kwargs=sampler_init_kwargs, + plot=use_kwargs['plot'], + meta_data=meta_data, + ) + + else: + worker = Worker(args, prior, likelihood, + injection_parameters, plot = use_kwargs['plot']) + + ## graceful handling of preemptive shutdowns + def handle_sigterm(signum, frame): + try: + worker.checkpointing(False, + 'Received termination signal. Checkpointing and exiting gracefully.') + sys.exit() + except Exception: + pass + + signal.signal(signal.SIGTERM, handle_sigterm) + signal.signal(signal.SIGINT , handle_sigterm) + signal.signal(signal.SIGUSR1, handle_sigterm) + + POOL = MPIPool if pool_type == 'mpi' else MultiPool + with POOL() as pool: + result = None + if pool.is_master(): + worker.start_sampler( + pool, + pooled_log_likelihood, + pooled_prior_transform, + pooled_initial_point_from_prior) + + results = worker.run_sampler( + check_point_delta_t=use_kwargs['check_point_delta_t'], + n_check_point=use_kwargs['n_check_point'], + max_its=use_kwargs['max_its'], + max_run_time=use_kwargs['max_run_time'], + checkpoint_plot=use_kwargs['checkpoint_plot'] + ) + result = worker.format_result( + results, use_kwargs['result_format'], + use_kwargs['rejection_sample_posterior']) + return result + + +# Worker functions. These are read in the global scope by each worker +def pooled_initial_point_from_prior(args): + return worker.get_initial_point_from_prior(args) + +def pooled_log_likelihood(v_array): + return worker.log_likelihood(v_array) + +def pooled_prior_transform(u_array): + return worker.prior_transform(u_array) diff --git a/nmma/core/parsing.py b/nmma/core/parsing.py index d0002162..7e431576 100644 --- a/nmma/core/parsing.py +++ b/nmma/core/parsing.py @@ -108,33 +108,72 @@ def base_analysis_parsing(parser): parser.add_argument("--cosmology", help="Name of the cosmology to be used, see astropy.cosmology for available cosmologies (implicit default: Planck18)") parser.add_argument("--sampling-seed","--seed", type=int, default=42, help="Sampling seed (default: 42)") + parser.add_argument("--sampler-kwargs", default="{}", type = yaml_parse, + help="Additional keyword arguments to pass to the sampler as a dictionary" ) + parser.add_argument("--skip-sampling", action='store_true', + help="If analysis has already run, skip bilby sampling and compute results from checkpoint files. Combine with --plot to make plots from these files.") + parser.add_argument("--dlogz", default=0.1, type=float, + help="Stopping criteria: remaining evidence, (default=0.1)" ) + parser.add_argument("--soft-init", action='store_true', + help="To start the sampler softly (without any checking, default: False)") + parser.add_argument("--cpus", type=int, default=1, + help="Number of cores to be used, only needed for dynesty (default: 1)") + parser.add_argument("-n","--nlive", "--n-live", type=int, default=2048, help="Number of live points (default: 2048)") return parser +def dynesty_parsing(parser): + dynesty_group = parser.add_argument_group(title="Dynesty Settings") + dynesty_group.add_argument("--n-check-point", default=1000, type=int, + help="Steps to take before attempting checkpoint") + dynesty_group.add_argument("--max-its", default=10**10, type=int, + help="Maximum number of iterations to sample for (default=1.e10)") + dynesty_group.add_argument("--max-run-time", default=1.0e10, type=float, + help="Maximum time to run for (default=1.e10 s)") + dynesty_group.add_argument("--rejection-sample-posterior", action='store_false', help=( + "Whether to generate the posterior samples by rejection sampling the " + "nested samples or resampling with replacement" ) ) + dynesty_group.add_argument( "--walks", default=100, type=int, + help="Minimum number of walks, defaults to 100" ) + dynesty_group.add_argument( "--proposals", action="append", + help="The jump proposals to use, the options are 'diff' and 'volumetric'" ) + dynesty_group.add_argument("--maxmcmc", default=5000, type=int, + help="Maximum number of walks, defaults to 5000" ) + dynesty_group.add_argument( "--nact", default=2, type=int, + help="Number of autocorrelation times to take, defaults to 2") + dynesty_group.add_argument("--naccept", default=60, type=int, + help="The average number of accepted steps per MCMC chain, defaults to 60") + dynesty_group.add_argument("--min-eff", default=10, type=float, + help="The minimum efficiency at which to switch from uniform sampling.") + + + dynesty_group.add_argument("--facc", default=0.5, type=float, + help="See dynesty.NestedSampler") + dynesty_group.add_argument("--enlarge", default=1.5, type=float, + help="See dynesty.NestedSampler") + return parser + + def single_messenger_analysis_parsing(parser): parser = base_analysis_parsing(parser) + parser = dynesty_parsing(parser) parser.add_argument("--config", help="Name of the configuration file containing parameter values.") parser.add_argument("-o", "--outdir", default="outdir", help="Path to the output directory") parser.add_argument("--label", default ="nmma_transient", help="Label for the run") parser.add_argument("--plot", action='store_true', help="create characteristic plot") + parser.add_argument("--plot-kwargs", type=yaml_parse, default={}, help="Additional keyword arguments to pass to the plotting routine as a dictionary" ) parser.add_argument("--verbose", action='store_true', help="print out log likelihoods" ) parser.add_argument("--prior-file","--prior", help="Path to the prior file") parser.add_argument("--sampler", default="pymultinest", help="Sampler to be used (default: pymultinest)") - parser.add_argument("--sampler-kwargs", default="{}", - help="Additional kwargs (e.g. {'evidence_tolerance':0.5}) for bilby.run_sampler, put a double quotation marks around the dictionary") - parser.add_argument("--soft-init", action='store_true', - help="To start the sampler softly (without any checking, default: False)") - parser.add_argument("--cpus", type=int, default=1, - help="Number of cores to be used, only needed for dynesty (default: 1)") - parser.add_argument("-n","--nlive", "--n-live", type=int, default=2048, help="Number of live points (default: 2048)") parser.add_argument("--reactive-sampling", action='store_true', help="To use reactive sampling in ultranest (default: False)") - parser.add_argument("--skip-sampling", action='store_true', - help="If analysis has already run, skip bilby sampling and compute results from checkpoint files. Combine with --plot to make plots from these files.") - + parser.add_argument("--bestfit", "--best-fit", action='store_true', + help="Save the best fit parameters to JSON") + return parser + def base_injection_parsing(parser): parser.add_argument("-f", "--injection-file","--filename", default="injection", help="Path to the file with injection parameters, default: 'injection'.") @@ -229,6 +268,27 @@ def slurm_analysis_parser(parser): return parser +def process_sampler_kwargs(sampler_kwargs, kwargs): + # Set defaults here to avoid inconsistent values + default_kwargs = dict(dlogz=0.1, save_bounds=False, + min_eff=10, sample="acceptance-walk", nlive=1000, bound="live", + walks=100, facc=0.5, enlarge=1.5) + + run_sampler_kwargs = {key: kwargs.get(key, default_kwargs[key]) + for key in ['dlogz', 'save_bounds']} + + + def_init_kwargs = {key: kwargs.get(key, default_kwargs[key]) + for key in ['min_eff','sample', 'nlive', 'bound', 'walks', 'facc', 'enlarge']} + + sampler_init_kwargs = def_init_kwargs | sampler_kwargs + sampler_init_kwargs['first_update'] = dict(min_eff=sampler_init_kwargs.pop('min_eff'), + min_ncall= 2 * sampler_init_kwargs['nlive']) + + return sampler_init_kwargs, run_sampler_kwargs + + + ############# UTILS ############# def process_multi_condition_string(multi_condition_string): # Supported operators diff --git a/nmma/core/plotting_utils.py b/nmma/core/plotting_utils.py index 7d480b63..e6953760 100644 --- a/nmma/core/plotting_utils.py +++ b/nmma/core/plotting_utils.py @@ -2,8 +2,11 @@ from bilby.core.prior import PriorDict, DeltaFunction import numpy as np import matplotlib +from matplotlib.colors import LinearSegmentedColormap import os -matplotlib.use("Agg") +# matplotlib.use("Agg") +from matplotlib import pyplot as plt +import itertools def fig_setup(): fig_width_pt = 750.0 # Get this from LaTeX using \showthe\columnwidth @@ -18,7 +21,7 @@ def fig_setup(): "legend.fontsize": 18, "xtick.labelsize": 18, "ytick.labelsize": 18, - "font.family": "Times New Roman", + "font.family": "serif", "figure.figsize": fig_size, "mathtext.fontset": "stix", } @@ -45,7 +48,7 @@ def fig_setup(): matplotlib.rcParams.update(params) matplotlib.rcParams['text.usetex'] = (os.environ.get("CI") != 'true') - return [ + color_array = [ "#22ADFC", # blue "#F42969", # red "#F4A429", # orange @@ -58,7 +61,7 @@ def fig_setup(): "#8B4513", # brown "#FF6347", # tomato ] - + return itertools.cycle(color_array) def plotting_parameters_from_priors(priors, keys=None): """ @@ -83,4 +86,39 @@ def plotting_parameters_from_priors(priors, keys=None): if keys is None: keys = priors.keys() - return {k: v.latex_label for k, v in priors.items() if k in keys and not isinstance(v, DeltaFunction)} \ No newline at end of file + return {k: v.latex_label for k, v in priors.items() if k in keys and not isinstance(v, DeltaFunction)} + +def setup_multi_axes(num_axes, sharex=False, sharey=False, ncols=None): + "Set up a multi-panel figure with the specified number of axes, essentially stolen from corner.py" + + if ncols is None: + ncols = np.min([5, np.ceil(np.sqrt(num_axes)).astype(int)]) + nrows = np.ceil(num_axes / ncols).astype(int) + + + factor = 2.0 # size of one side of one panel + lbdim = 0.5 * factor # size of left/bottom margin + trdim = 0.2 * factor # size of top/right margin + whspace = 0.05 # w/hspace size + whspace = trdim + rowdim = lbdim + factor * ncols + factor * (ncols - 1.0) * whspace + trdim + coldim = lbdim + factor * nrows + factor * (nrows - 1.0) * whspace + trdim + + + + # Create a new figure if one wasn't provided. + fig, axes = plt.subplots(nrows, ncols, figsize=(rowdim, coldim), + sharex=sharex, sharey=sharey, constrained_layout=True) + + return fig, axes.flatten() + + +def fading_cmap(color): + cmap = LinearSegmentedColormap.from_list("custom_cmap", ["white",color], gamma = 2) + + cdict = cmap._segmentdata.copy() + vals = cdict['alpha'][:, 0] + alpha = np.linspace(0, 1, len(vals)) + cdict['alpha'] = np.column_stack([vals, alpha, alpha]) + + return LinearSegmentedColormap(cmap.name, cdict, cmap.N, cmap._gamma) \ No newline at end of file diff --git a/nmma/core/utils.py b/nmma/core/utils.py index cb866f38..88456ffb 100644 --- a/nmma/core/utils.py +++ b/nmma/core/utils.py @@ -8,7 +8,6 @@ from bilby.core.result import read_in_result from astropy import time -from matplotlib.colors import LinearSegmentedColormap from pathlib import Path import yaml @@ -69,7 +68,8 @@ def read_trigger_time(parameters=None, args=None, out_format = 'mjd'): format = 'gps' trigger_time= time.Time(args.trigger_time, format=format) trigger_time # this fails if not a valid time - + elif args is not None: + args.trigger_time = trigger_time.mjd if out_format == 'mjd' else trigger_time.gps if out_format == 'mjd': return trigger_time.mjd @@ -151,19 +151,20 @@ def set_filename(basename, args, identifier=''): return f"{base}{identifier}{ext}" -def read_bestfit_from_posterior(args, mode = 'max_likelihood'): +def read_bestfit_from_posterior(args, mode = 'max_likelihood', return_posterior=False): posterior_samples = get_posteriors(args) if mode == 'max_likelihood': bestfit = posterior_samples.loc[posterior_samples.log_likelihood.idxmax()] elif mode == 'max_posterior': - bestfit = posterior_samples.loc[(posterior_samples.log_likelihood*posterior_samples.log_prior).idxmax()] + bestfit = posterior_samples.loc[(posterior_samples.log_likelihood + posterior_samples.log_prior).idxmax()] else: raise ValueError(f"Mode {mode} not recognized. Use 'max_likelihood' or 'max_posterior'.") bestfit_params = bestfit.to_dict() bestfit_idx = bestfit.name print(f"Best fit parameters: {str(bestfit_params)}\nBest fit index: {bestfit_idx}") bestfit_params["best_fit_index"] = int(bestfit_idx) - return bestfit_params + + return (bestfit_params, posterior_samples) if return_posterior else bestfit_params def read_bestfit_from_json(bestfit_file_json, cols, verbose=False): df = pd.read_json(bestfit_file_json, typ="series") @@ -190,6 +191,7 @@ def sig_lims(values, quantiles=None, sig_unc=2): fmt = f".{ord_error}f" return f"${{{q_mean:{fmt}}}}_{{-{low_err:{fmt}}}}^{{+{high_err:{fmt}}}}$" else: + q_mean, low_err, high_err =np.around([q_mean, low_err, high_err], ord_error) return f"${{{int(q_mean)}}}_{{-{int(low_err)}}}^{{+{int(high_err)}}}$" @@ -222,12 +224,15 @@ def input_obj_to_str(input_obj, ref_name= None): else: raise TypeError("Input object could not be identified.") -def fading_cmap(color): - cmap = LinearSegmentedColormap.from_list("custom_cmap", ["white",color], gamma = 2) - - cdict = cmap._segmentdata.copy() - vals = cdict['alpha'][:, 0] - alpha = np.linspace(0, 1, len(vals)) - cdict['alpha'] = np.column_stack([vals, alpha, alpha]) - - return LinearSegmentedColormap(cmap.name, cdict, cmap.N, cmap._gamma) \ No newline at end of file +def nan_level(data, level, weights=None): + nans, clean_data = np.isnan(data), data[~np.isnan(data)] + if weights is not None: + weights = np.array(weights)[~nans] + weights = weights / np.sum(weights) + nan_share = np.sum((nans)) / len(data) + if nan_share > level: + return [np.nan, np.nan] + rest_level = level - nan_share + low = np.quantile(clean_data, (1-rest_level)/2, weights=weights, method='inverted_cdf') + up = np.quantile(clean_data, 1-(1-rest_level)/2, weights=weights, method='inverted_cdf') + return [low, up] diff --git a/nmma/em/analysis.py b/nmma/em/analysis.py index 8778e4d5..3a2862ee 100644 --- a/nmma/em/analysis.py +++ b/nmma/em/analysis.py @@ -52,7 +52,7 @@ def check_detections(data, remove_nondetections=False): def set_analysis_filters(filters, data): if filters is None: - filters = list(data.keys()) + return list(data.keys()) filters_to_analyze = [filt for filt in data.keys() if filt in filters] print(f"Running with filters {filters_to_analyze}") @@ -94,26 +94,25 @@ def bolometric_setup(args): return priors, likelihood, injection_parameters def analysis_setup(args): - + filters = utils.set_filters(args) - detection_limit = utils.create_detection_limit(args, filters) - try: # load observational data data = io.load_em_observations(args, format='observations') trigger_time = read_trigger_time(None,args) injection_parameters = None except ValueError: + detection_limit = utils.create_detection_limit(args, filters) # try to work with injection data instead data, injection_parameters = data_from_injection(args, filters, detection_limit) trigger_time = injection_parameters.get('trigger_time',0) except FileNotFoundError: # If the injection file is not found, raise an error raise FileNotFoundError("Injection file not found.") - data = utils.cut_data_to_time_range(data, args, trigger_time) data = check_detections(data, args.remove_nondetections) filters_to_analyze = set_analysis_filters(filters, data) + detection_limit = utils.create_detection_limit(args, filters_to_analyze) # initialize light curve model print("Creating light curve model for inference") @@ -252,7 +251,14 @@ def nnanalysis(args): print('saved posterior plot') def main(args=None): + if isinstance(args, dict): + non_default = args.copy() + args = [] + else: + non_default = {} args = parsing_and_logging(multi_wavelength_analysis_parser, args) + args.__dict__.update(non_default) + if args.sampler == 'neuralnet': nnanalysis(args) else: diff --git a/nmma/em/em_likelihood.py b/nmma/em/em_likelihood.py index e9a56cf1..841cb891 100644 --- a/nmma/em/em_likelihood.py +++ b/nmma/em/em_likelihood.py @@ -21,7 +21,7 @@ def setup_em_kwargs(priors, data_dump, args, logger=None): light_curve_model = model.create_light_curve_model_from_args(lc_model, args, filters) trigger_time = read_trigger_time(None, args) light_curve_data = utils.setup_filtered_lc_data(light_curve_data, trigger_time) - light_curve_data = utils.check_model_time_consistency(light_curve_data, light_curve_model, priors) + light_curve_data = utils.check_model_time_consistency(light_curve_data, light_curve_model, priors, args.injection) sys_handler = systematics.FilterSystematicsHandler(filters, data_dump['systematics_dict'], error_budget=args.em_error_budget, light_curve_times=light_curve_data[0]) @@ -119,7 +119,7 @@ def final_diagnostics(self, bestfit_params, args, result=None): The figure object containing the plot """ - self.sub_model.final_diagnostics(bestfit_params, args, result) + return self.sub_model.final_diagnostics(bestfit_params, args, result) def posterior_conversion(self, posterior_samples): if 'log10_mej_dyn' in posterior_samples and 'log10_mej_wind' in posterior_samples: @@ -259,7 +259,7 @@ def final_diagnostics(self, bestfit_params, args, result=None): if result is None: save_path = f'{args.outdir}/{args.label}_bol_lightcurve.png' save_path = f'{result.outdir}/{result.label}_bol_lightcurve.png' - bolometric_lc_plot(self, obs_times, obs_lc, save_path = save_path) + return bolometric_lc_plot(self, obs_times, obs_lc, save_path = save_path) class MultiFilterTransient(BasicEMTransient): @@ -351,4 +351,4 @@ def band_log_likelihood(self, expected_mags, obs_error): return minus_chisquare_total + gaussprob_total def final_diagnostics(self, bestfit_params, args, result=None): - lch_bestfit(self, bestfit_params, args, result) \ No newline at end of file + return lch_bestfit(self, bestfit_params, args, result) \ No newline at end of file diff --git a/nmma/em/em_parsing.py b/nmma/em/em_parsing.py index cef22c68..cb722e0b 100644 --- a/nmma/em/em_parsing.py +++ b/nmma/em/em_parsing.py @@ -38,7 +38,6 @@ def basic_em_only_parsing(parser): parser.add_argument("--verbose", action='store_true', help="print out log likelihoods" ) return parser - def basic_em_only_analysis_parsing(parser): parser = single_messenger_analysis_parsing(parser) @@ -47,8 +46,6 @@ def basic_em_only_analysis_parsing(parser): parser.add_argument("--light-curve-data", "--data", help="Path to data in [time filter magnitude error] format, time format will be inferred, but can be explicitly adjusted with --time-format. If not given, will try to generate data from the injection file.") parser.add_argument("--time-format", help="Time format of the light curve data, e.g. isot, mjd, see https://docs.astropy.org/en/stable/time/#time-format") - parser.add_argument("--bestfit", "--best-fit", action='store_true', - help="Save the best fit parameters and magnitudes to JSON") return parser @@ -73,6 +70,8 @@ def em_model_parsing(parser): help="Name of the model-type to be used, can be a comma-seperated list for joint lightcurve models" ) em_model_parser.add_argument("--em-model", "--kilonova-model","--model", type=yaml_parse, nargs="*", help="Name of the transient model to be used") + em_model_parser.add_argument("--em-model-kwargs", "--model-parameters", type=yaml_parse, + help="Additional keyword arguments for the transient model, given like a python-dict ") em_model_parser.add_argument("--interpolation-type", "--gptype", default="keras", help="Interpolation library to be used for EM "\ "transient model. Default: keras, further options: tensorflow, sklearn_gp, api_gp" ) @@ -233,7 +232,8 @@ def modified_em_prior_parsing(parser): mod_em_prior_parser.add_argument("--em-error-budget", "--kilonova-error", help="Additional statistical error (mag) to be introduced in each filter," \ " can be passed as list or dict. Will only be used if em_syserr is not given in prior") - mod_em_prior_parser.add_argument("--systematics-file", help="Path to systematics configuration file") + mod_em_prior_parser.add_argument("--systematics-file", type=yaml_parse, + help="Path to systematics configuration file") return parser def em_analysis_parsing(parser): diff --git a/nmma/em/io.py b/nmma/em/io.py index b1846b73..83bee8df 100644 --- a/nmma/em/io.py +++ b/nmma/em/io.py @@ -34,6 +34,8 @@ def load_em_observations(filename, args=None, format='observations'): Returns: - data (dict): Dictionary containing the lightcurve data from the file. The keys are generally 'time' and each of the filters in the file as well as their accompanying error values. """ + if isinstance(filename, dict): + return filename # assume it is already in the correct format if isinstance(filename, argparse.Namespace): args = filename filename = args.light_curve_data @@ -78,7 +80,9 @@ def read_lc_from_csv(filename, args, format): data = {} for line in lines: - lineSplit = line.split(" ") + if line.startswith("#") or line.startswith("time") or line.startswith("mjd"): + continue + lineSplit = line.split(None) lineSplit = list(filter(None, lineSplit)) try: mjd = Time(lineSplit[0]).mjd diff --git a/nmma/em/lightcurve_generation.py b/nmma/em/lightcurve_generation.py index 11e7cc2f..e83b8540 100644 --- a/nmma/em/lightcurve_generation.py +++ b/nmma/em/lightcurve_generation.py @@ -11,6 +11,7 @@ from scipy.interpolate import CubicSpline from . import utils +from ..core.utils import read_trigger_time try: import afterglowpy @@ -36,12 +37,15 @@ def inner(func): ################################################################# +def dummy_add(nu): + return 0.0 + def bb_flux_from_inv_temp(nu, inv_temp, R_photo, dist_squared = abs_mag_dist_factor): exponent = np.clip(h * nu * inv_temp / kb, None, 700) # to avoid overflow in exp bb_factor = 2.* h/ c_cgs**2 return bb_factor * nu**3 /np.expm1(exponent) * R_photo * R_photo / dist_squared -def mag_dict_for_blackbody(filters, inv_temp, R_photo, nu_host, add = lambda x: 0.): +def mag_dict_for_blackbody(filters, inv_temp, R_photo, nu_host, add = dummy_add): mag = {} # nu_host = nu_obs * (1 + redshift) for idx, filt in enumerate(filters): @@ -57,10 +61,11 @@ def mag_dict_for_blackbody(filters, inv_temp, R_photo, nu_host, add = lambda x: ################################################################# ######################### LC MODELS ############################# ################################################################# - ## Arnett model convenience functions def arnett_lc_get_int_A_non_vec(x, y): - r = quad(lambda z: 2 * z * np.exp(-2 * z * y + z**2), 0, x) + def arnett_func(z): + return 2 * z * np.exp(-2 * z * y + z**2) + r = quad(arnett_func, 0, x) return r[0] @@ -68,7 +73,9 @@ def arnett_lc_get_int_A_non_vec(x, y): def arnett_lc_get_int_B_non_vec(x, y, s): - r = quad(lambda z: 2 * z * np.exp(-2 * z * y + 2 * z * s + z**2), 0, x) + def arnett_func(z): + return 2 * z * np.exp(-2 * z * y + 2 * z * s + z**2) + r = quad(arnett_func, 0, x) return r[0] @@ -235,13 +242,14 @@ def flux_density_on_E0_array(default_time, obs_frequencies, param_dict): time_scale = np.log10(default_time / t_end) log10_E0[mask] = log10_Eend + energy_exponential * time_scale[mask] E0 = 10 ** log10_E0 - vec_func = np.vectorize( - lambda i: fluxDensity( + def helper(i): + return fluxDensity( default_time[i], obs_frequencies, E0=E0[i], **param_dict - ), + ) + vec_func = np.vectorize(helper, otypes=[np.ndarray] ) mJys = vec_func(np.arange(len(default_time))) @@ -288,20 +296,18 @@ def host_lc(sample_times, parameters, filters, host_mag): ## supernova model def sn_lc(sample_times_stretched, sn_model, filters, lambdas): mag = {} - - for filt, lambda_A in zip(filters, lambdas): - # convert back to AA - lambda_AA = 1e10 * lambda_A - if lambda_AA < sn_model.minwave() or lambda_AA > sn_model.maxwave(): - mag[filt] = np.full_like(sample_times_stretched,np.inf) - else: - try: - flux = sn_model.flux(sample_times_stretched, [lambda_AA])[:, 0] - # see https://en.wikipedia.org/wiki/AB_magnitude - flux_jy = 3.34e4 * np.power(lambda_AA, 2.0) * flux - mag[filt] = utils.flux_to_ABmag(flux_jy, unit='Jy') - except Exception: - return {} + for filt, lambda_ in zip(filters, lambdas): + try: + mag[filt] = sn_model.bandmag(filt, 'ab', sample_times_stretched) + except ValueError: + lambda_AA = 1e10 * lambda_ + if lambda_AA < sn_model.minwave() or lambda_AA > sn_model.maxwave(): + mag[filt] = np.full_like(sample_times_stretched, np.inf) + continue + #NOTE: workaround for potential bug in sncosmo: buffer error if lambdaa as float + flux_AA = sn_model.flux(sample_times_stretched, [lambda_AA]).flatten() + # see https://en.wikipedia.org/wiki/AB_magnitude + mag[filt] = utils.flux_to_ABmag(flux_AA*3.34e4 * lambda_AA**2, unit='Jy') return mag ## shock-cooling lightcurve @@ -818,7 +824,9 @@ def create_light_curve_data( injection_parameters = light_curve_model.parameter_conversion(injection_parameters) filters = utils.set_filters(args) - trigger_time = injection_parameters.get("trigger_time", 0.) + trigger_time = read_trigger_time(injection_parameters, args) + if trigger_time is None: + trigger_time = 0. if rng is None: rng = np.random.default_rng(args.generation_seed) if getattr(args, 'absolute', False): diff --git a/nmma/em/lightcurve_handling.py b/nmma/em/lightcurve_handling.py index ee077acd..40d92872 100644 --- a/nmma/em/lightcurve_handling.py +++ b/nmma/em/lightcurve_handling.py @@ -63,12 +63,12 @@ def post_process_bestfit(transient, bestfit_params, args, result=None): print(f"Saved bestfit parameters and magnitudes to {bestfit_file}") if args.plot: - filters_to_plot = [ - filt for filt in transient.observed_filters - if not np.isnan(transient.light_curves[filt]).all() - ] - plot_error = {filt: model_error[filt] for filt in filters_to_plot} - mags_to_plot = {filt: best_mags[filt] for filt in filters_to_plot} + plot_kwargs = getattr(args, "plot_kwargs", {}) + if (plot_filters := plot_kwargs.pop("filters", None)) is None: + plot_filters = {filt: filt for filt in transient.observed_filters if not np.isnan(transient.light_curves[filt]).all()} + + plot_error = {k: model_error[k] for k in plot_filters.keys()} + mags_to_plot = {k: best_mags[k] for k in plot_filters.keys()} mags_to_plot["time"] = observable_times if isinstance(lc_model, model.CombinedLightCurveModelContainer): @@ -85,7 +85,7 @@ def post_process_bestfit(transient, bestfit_params, args, result=None): } plot_errors = [] plot_mags = [] - for filt in filters_to_plot: + for filt in plot_filters.keys(): try: plot_mags.append(utils.get_filtered_mag(mag_all[i], filt)) except KeyError: @@ -98,11 +98,15 @@ def post_process_bestfit(transient, bestfit_params, args, result=None): sub_model_plot_props[sub_model.model]['plot_errors'] = plot_errors else: sub_model_plot_props = None - - basic_em_analysis_plot( - transient, mags_to_plot, plot_error, chi2_dict, mismatches, - sub_model_plot_props, xlim = args.xlim, ylim = args.ylim, - save_path = os.path.join(args.outdir, f"{args.label}_lightcurves.png") + save_path = plot_kwargs.pop("save_path", + os.path.join(args.outdir, + f"{args.label}_bestfit_lightcurves.png")) + return basic_em_analysis_plot( + transient, plot_filters, mags_to_plot, plot_error, + chi2_dict, mismatches, sub_model_plot_props, + xlim = args.xlim, ylim = args.ylim, + save_path = save_path, + fig = getattr(args, "fig", None), **plot_kwargs ) diff --git a/nmma/em/model.py b/nmma/em/model.py index d5242b48..c334656c 100644 --- a/nmma/em/model.py +++ b/nmma/em/model.py @@ -351,7 +351,7 @@ def combine_detector_data(self, model_lc, observable_times): # apparent_magnitude = utils.autocomplete_data( # observable_times, observable_times[use_mask], apparent_magnitude[use_mask]) else: #no meaningful inter-/extrapolation possible - apparent_magnitude = np.full_like(self.model_times, np.inf) + apparent_magnitude = np.full_like(observable_times, np.inf) lc_data[filt] = apparent_magnitude return (observable_times, lc_data) @@ -448,6 +448,8 @@ class SimpleBolometricLightCurveModel(LightCurveModelContainer): ---------- model: string, optional Name of the model. Can be either "Arnett" (default) or "Arnett_modified" + em_model_kwargs: optional + Additional keyword arguments, not used in Arnett SN model. Returns ------- @@ -455,7 +457,7 @@ class SimpleBolometricLightCurveModel(LightCurveModelContainer): A light curve model object to evaluate the light curve from a set of parameters """ - def __init__(self, model="Arnett", sample_times=None): + def __init__(self, model="Arnett", sample_times=None, **em_model_kwargs): super().__init__(model, sample_times=sample_times) if model == "Arnett": self.lc_func = lc_gen.arnett_lc @@ -504,7 +506,8 @@ class SVDLightCurveModel(LightCurveModelContainer): List of filters to create model for. Defaults to all available filters. local_only: bool, optional If True, only local models will be used. - + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying emulator. Returns ------- LightCurveModel: `nmma.em.model.SVDLightCurveModel` @@ -523,6 +526,7 @@ def __init__( filters=None, sample_times = None, local_only=False, + **em_model_kwargs ): # Some models have underscores. Keep those, but drop '_tf' if it exists model_name_components = model.split("_") @@ -546,7 +550,7 @@ def __init__( ##FIXME Does this make sense for api_gp, too? filters = self.get_model_data(core_model_name, filters) try: - svd_mag_model = joblib.load(modelfile) + svd_mag_model = joblib.load(modelfile, **em_model_kwargs) # temporary fix, moving towards permament setting of sncosmo filter names self.svd_mag_model = {k.replace("_", ":"): v for k, v in svd_mag_model.items()} self.svd_lbol_model= None # FIXME: this is not yet implemented @@ -667,11 +671,13 @@ class FiestaKilonovaModel(FiestaModel): Parameters ---------- model: str, optional - Name of the model. Default is "Bu2025_MLP". + Name of the model. Default is "Bu2026_MLP". filters: list of str, optional List of filters to create model for. Defaults to all available filters. surrogate_dir: str, optional path to the directory containing the surrogate models. + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying fiesta model. Returns ------- @@ -679,19 +685,28 @@ class FiestaKilonovaModel(FiestaModel): A light curve model object to evaluate the light curve from a set of parameters. """ - def __init__(self, model="Bu2025_MLP", filters=None, surrogate_dir=None, **kwargs): + def __init__(self, model="Bu2026_MLP", filters=None, surrogate_dir=None, **em_model_kwargs): if model.endswith("_lc"): from fiesta.inference.lightcurve_model import BullaLightcurveModel as BullaSurrogate else: from fiesta.inference.lightcurve_model import BullaFlux as BullaSurrogate - fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir,) - try: - fiesta_model = BullaSurrogate(**fiesta_kwargs) - except OSError: - fiesta_kwargs['directory'] = f'{surrogate_dir}/KN/{model}/model' - fiesta_model = BullaSurrogate(**fiesta_kwargs) + fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir) + fiesta_model = BullaSurrogate(**fiesta_kwargs) - super().__init__(fiesta_model, filters, sample_times=kwargs.get('sample_times', None)) + super().__init__(fiesta_model, filters, sample_times=em_model_kwargs.get('sample_times', None)) + + def parameter_conversion(self, parameters): + if "kappa_Ye" in parameters: + if "Ye_wind" in parameters: + parameters["Ye_dyn"] = parameters["kappa_Ye"] * parameters["Ye_wind"] + else: + parameters["Ye_wind"] = parameters["Ye_dyn"] / parameters["kappa_Ye"] + if "kappa_v" in parameters: + if "v_ej_wind" in parameters: + parameters["v_ej_dyn"] = parameters["kappa_v"] * parameters["v_ej_wind"] + else: + parameters["v_ej_wind"] = parameters["v_ej_dyn"] / parameters["kappa_v"] + return super().parameter_conversion(parameters) class GRBMixin: def __init__(self, *args, resolution=12, **kwargs): @@ -741,6 +756,8 @@ class FiestaGRBModel(GRBMixin,FiestaModel): List of filters to create model for. Defaults to all available filters. surrogate_dir: str, optional path to the directory containing the surrogate models. + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying Fiesta GRB model. Returns ------- @@ -748,16 +765,12 @@ class FiestaGRBModel(GRBMixin,FiestaModel): A light curve model object to evaluate the light curve from a set of parameters. """ - def __init__(self, model="afgpy_gaussian_CVAE", filters=None, surrogate_dir=None, **kwargs): + def __init__(self, model="afgpy_gaussian_CVAE", filters=None, surrogate_dir=None, **em_model_kwargs): from fiesta.inference.lightcurve_model import AfterglowFlux - fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir,) - try: - fiesta_model = AfterglowFlux(**fiesta_kwargs) - except OSError: - fiesta_kwargs['directory'] = f'{surrogate_dir}/GRB/{model}/model' - fiesta_model = AfterglowFlux(**fiesta_kwargs) - - super().__init__(fiesta_model, filters, sample_times=kwargs.get('sample_times', None)) + fiesta_kwargs= dict( name=model, filters=filters, directory=surrogate_dir) + fiesta_model = AfterglowFlux(**fiesta_kwargs) + + super().__init__(fiesta_model, filters, sample_times=em_model_kwargs.get('sample_times', None)) class GRBLightCurveModel(GRBMixin, LightCurveModelContainer): @@ -778,6 +791,8 @@ class GRBLightCurveModel(GRBMixin, LightCurveModelContainer): Type of jet for the GRB model. Default is 0. filters: list of str, optional List of filters to create model for. Defaults to all available filters. + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying afterglowpy model. Returns ------- @@ -793,10 +808,11 @@ def __init__( jet_type=0, filters=None, sample_times=None, + **em_model_kwargs ): super().__init__(model, filters, model_parameters, sample_times, resolution=resolution) self.jet_type = jet_type - self.default_parameters = {"xi_N": 1.0, "d_L": 3.086e19, "jetType": jet_type, "specType": 0} # d_L=10pc in cm + self.default_parameters = {"xi_N": 1.0, "d_L": 3.086e19, "jetType": jet_type, "specType": 0, **em_model_kwargs} # d_L=10pc in cm self.def_keys = self.default_parameters.keys() #keys we typically sample in log space, but need to convert to linear space self.log_sampling_keys = ["E0", "n0", "epsilon_e", "epsilon_B"] @@ -874,6 +890,8 @@ class HostGalaxyLightCurveModel(LightCurveModelContainer): Magnitude of the host galaxy. Default is 23.9. model_parameters: list, optional List of alternative model parameters. If not specified, default will be used. + em_model_kwargs: optional + Additional keyword arguments, not used, but provided for consistency. Returns ------- @@ -888,7 +906,8 @@ def __init__( sample_times=None, # host_mag is the magnitude of the host galaxy in the filters host_mag=23.9, # value for case of arxiv:2303.12849 - model_parameters=None + model_parameters=None, + **em_model_kwargs ): super().__init__(model, filters, model_parameters, sample_times=sample_times) if isinstance(host_mag, (float, int)): @@ -915,6 +934,8 @@ class SupernovaLightCurveModel(LightCurveModelContainer): List of filters to create model for. Defaults to all available filters. model_parameters: list, optional List of alternative model parameters. If not specified, default will be used. + em_model_kwargs: optional + Additional keyword arguments to be passed to the underlying sncosmo model. Returns ------- @@ -927,15 +948,25 @@ def __init__( model="nugent-hyper", filters=None, sample_times=None, - model_parameters=None + model_parameters=None, + **em_model_kwargs ): - self.sn_model = sncosmo.Model(source=model) + if isinstance(model, str): + self.sn_model = sncosmo.Model(source=model, **em_model_kwargs) + else: + self.sn_model = model if model_parameters is not None: print("Warning: model_parameters are ignored for SupernovaLightCurveModel, using sncosmo defaults.") model_parameters = self.sn_model.param_names if sample_times is None: sample_times = np.linspace(self.sn_model.mintime(), self.sn_model.maxtime(), 200) + + if (sample_times < 0).any(): + # NOTE: We assume this means the sncosmo model is relative to peak time. + sample_times += sample_times[0] + print(f"Warning: Some supernova models are relative to the peak, some relative to the explosion time, " + "but nmma always expects times relative to the explosion time. Adjust your t0 prior accordingly." ) super().__init__(model, filters, model_parameters, sample_times) def em_parameter_setup(self, parameters): @@ -943,22 +974,37 @@ def em_parameter_setup(self, parameters): self.sn_model.set(**lc_pars) def combine_lc_params(self, parameters): - # FIXME: This should probably be removed, use sncosmo parameters instead + self.stretch = parameters.get('supernova_mag_stretch', 1.) + parameters["t0"] = parameters.get("t0", 0.) parameters["z"] = self.redshift return {p: parameters.get(p, self.sn_model.get(p)) for p in self.model_parameters} - def generate_lightcurve(self, sample_times, parameters): - self.em_parameter_setup(parameters) - + def gen_detector_lc(self, parameters = None, sample_times=None): + """Generate a light curve for given parameters as observable in detector frame. + Parameters + ---------- + parameters: dict + Parameters of the Supernova model. + sample_times: times at which to explore the light curve. If None, uses the default times for the model.""" - mag = lc_gen.sn_lc(sample_times / self.stretch, self.sn_model, + if sample_times is None: + sample_times = self.model_times + + # convert the parameters to the fiesta model parameters + self.em_parameter_setup(parameters) + mag = lc_gen.sn_lc(sample_times / self.stretch/(1 + self.redshift), self.sn_model, self.default_filts, self.lambdas) - - return {filt: filt_mag for filt, filt_mag in mag.items()} - + + # apply the extinction correction + ext_mag = self.get_extinction_mags() + obs_mags = self.apply_extinction_correction(mag, ext_mag, self.default_filts) + + # we are in observer frame, but still need to add the timeshift + return (sample_times + self.timeshift, obs_mags) + class ShockCoolingLightCurveModel(LightCurveModelContainer): def __init__( @@ -974,6 +1020,10 @@ def __init__( Name of the model. Default is "Piro2021". filters: list of str, optional List of filters to create model for. Defaults to all available filters. + model_parameters: list, optional + List of alternative model parameters. If not specified, default will be used. + em_model_kwargs: optional + Additional keyword arguments, not used, but provided for consistency. Returns ------- @@ -1020,6 +1070,8 @@ class SimpleKilonovaLightCurveModel(LightCurveModelContainer): Name of the model. Default is "Me2017". filters: list of str, optional List of filters to create model for. Defaults to all available filters. + em_model_kwargs: optional + Additional keyword arguments, not used, but provided for consistency. Returns ------- @@ -1028,7 +1080,7 @@ class SimpleKilonovaLightCurveModel(LightCurveModelContainer): from a set of parameters. """ def __init__( - self, model="Me2017", filters=None, sample_times=None + self, model="Me2017", filters=None, sample_times=None, **em_model_kwargs ): super().__init__(model, filters, sample_times=sample_times) lc_dict={ @@ -1311,10 +1363,12 @@ def single_model_from_args(model_class, model_name, args, # update explicit args model_args = default_model_args | dict(filters=filters, - sample_times=utils.setup_sample_times(args), + sample_times=utils.setup_sample_times(args) ) - if model_name is not None: + if model_name: model_args["model"] = model_name.strip() + if args.em_model_kwargs: + model_args|= args.em_model_kwargs return model_class(**model_args) def create_light_curve_model_from_args( diff --git a/nmma/em/plotting_utils.py b/nmma/em/plotting_utils.py index 797e2708..e194f293 100644 --- a/nmma/em/plotting_utils.py +++ b/nmma/em/plotting_utils.py @@ -1,120 +1,203 @@ import matplotlib.pyplot as plt import matplotlib +from matplotlib.ticker import MaxNLocator + from mpl_toolkits.axes_grid1 import make_axes_locatable import numpy as np import os -matplotlib.use("agg") params = { "backend": "pdf", "figure.figsize": [18, 25], } matplotlib.rcParams.update(params) matplotlib.rcParams['text.usetex'] = (os.environ.get("CI") != 'true') +matplotlib.rcParams['figure.labelsize'] = 20 +# matplotlib.rcParams['axes.titlesize'] = 'large' +if os.environ.get("CI") == 'true': + matplotlib.use("agg") +from nmma.core.plotting_utils import fig_setup +nmma_colors = fig_setup() ############################################## ################# MAIN PLOTS ################# ############################################## def basic_em_analysis_plot( - transient, mags_to_plot, error_dict, chi2_dict, mismatches, - sub_model_plot_props, xlim, ylim, save_path, - ncol = 2): + transient, plot_filters, mags_to_plot, error_dict, chi2_dict, + mismatches, sub_model_plot_props, xlim, ylim, save_path, + ncol = 2, fig = None, shared_data = True, **kwargs): + + + if fig: + return add_em_analysis_plot(fig, transient, mags_to_plot, error_dict, mismatches, sub_model_plot_props, xlim, ylim, save_path, shared_data, **kwargs) time = mags_to_plot.pop("time") - filter_names = list(mags_to_plot.keys()) + filter_names = list(plot_filters.keys()) + n_filters = len(filter_names) + data_colors = plt.cm.plasma(np.linspace(0, 1, len(filter_names)))[::-1] + color = kwargs.pop('color', next(nmma_colors)) + fig, axes = analysis_plot_geometry(filter_names, ncol=ncol) + fig.supylabel("AB magnitude", rotation=90) + fig.supxlabel("Time [days]") if xlim is None: xlim = get_time_limits_from_obs_data(transient) else: xlim = check_limit(xlim) - + if ylim is None: ylim = get_mag_limits_from_obs_data(transient, filter_names) else: shared_ylim = check_limit(ylim) ylim = {filt: shared_ylim for filt in filter_names} - fig, axes = analysis_plot_geometry(filter_names, ncol=ncol) - colors = plt.cm.Spectral(np.linspace(0, 1, len(filter_names)))[::-1] + for cnt, filt in enumerate(filter_names): - # summary plot row, col = divmod(cnt, ncol) ax_sum = axes[row, col] + ax_sum.set_ylim(ylim[filt]) + ax_sum.set_xlim(xlim) + if xlim[0] > 0: + ax_sum.set_xscale('log') + ax_sum.set_ylabel(plot_filters[filt]) + # adding the ax for the Delta divider = make_axes_locatable(ax_sum) ax_delta = divider.append_axes('bottom', - size='30%', - sharex=ax_sum) + size='40%', + sharex=ax_sum, + pad=0.1) + ax_delta.set_ylabel(r"$\Delta$ mag") - # configuring ax_sum - ax_sum.set_ylabel("AB magnitude", rotation=90) - ax_delta.set_ylabel(r"$\Delta (\sigma)$") - if cnt == len(filter_names)-1: - ax_delta.set_xlabel("Time [days]") + # ax_delta.set_yscale('log') + # ax_delta.yaxis.set_major_locator(MaxNLocator(min_n_ticks=2)) + + if cnt not in [n_filters - i-1 for i in range(ncol)]: + ax_sum.set_xticklabels([]) else: - ax_delta.set_xticklabels([]) + plt.setp(ax_sum.get_xticklabels(), visible=False) + # only show x labels on the lowest delta plots + + # configuring ax_sum + + if cnt not in [n_filters - i-1 for i in range(ncol)]: # only show x labels on the last two plots + ax_sum.set_xticklabels([]) # plot the observations - ax_sum, det_times = plot_observations(ax_sum, transient, colors[cnt], filter=filt) + ax_sum, det_times = plot_observations(ax_sum, transient, data_colors[cnt], filter=filt) if det_times.size>0: # plot the mismatch between the model and the data - diff_per_data, sigma_per_data = mismatches[filt] + diff_per_data, _ = mismatches[filt] ax_delta.axhline(0, linestyle='--', color='k') - ax_delta.scatter(det_times, diff_per_data, color=colors[cnt]) - - ax_sum.set_title(f'{filt}: ' + fr'$\chi^2 / d.o.f. = {round(chi2_dict[filt], 2)}$') + ax_delta.scatter(det_times, diff_per_data, color=color) + ax_sum.plot([], [], label=fr'$\chi^2$ / d.o.f. = {round(chi2_dict[filt], 2)}', color=color) + ax_sum.legend(loc='best', frameon=False, handlelength=0, handletextpad=0) - else: - ax_sum.set_title(f'{filt}') # plot the best-fit lc with errors - mag_plot = mags_to_plot[filt] - error_budget = error_dict[filt] - label = 'combined' if sub_model_plot_props is not None else "" - - ax_sum.plot(time, mag_plot, - color='coral', linewidth=3, linestyle="--") - ax_sum.fill_between(time, - mag_plot + error_budget, - mag_plot - error_budget, - facecolor='coral', - alpha=0.2, - label=label, - ) + plot_bestfit_with_errors(ax_sum, time, mags_to_plot[filt], error_dict[filt], sub_model_plot_props, cnt, color) - if sub_model_plot_props is not None: - ## plot additional lcs for each sub_model - for model_name, prop_dict in sub_model_plot_props.items(): - mag_plot = prop_dict['plot_mags'][cnt] - mag_err = prop_dict['plot_errors'][cnt] - plot_times = prop_dict['plot_times'] - ax_sum.plot(plot_times, mag_plot, - color='coral', linewidth=3, linestyle="--") - ax_sum.fill_between(plot_times, - mag_plot + mag_err, - mag_plot - mag_err, - facecolor=prop_dict['color'], - alpha=0.2, - label=model_name, - ) + + fig.tight_layout() + if save_path: + fig.savefig(save_path, bbox_inches='tight') + return fig + +def add_em_analysis_plot(fig, + transient, mags_to_plot, error_dict, mismatches, + sub_model_plot_props, xlim, ylim, save_path, shared_data = True, **kwargs): + + time = mags_to_plot.pop("time") + filter_names = list(mags_to_plot.keys()) + n_axes = int(np.ceil(len(fig.axes)/2)) + + data_colors = plt.cm.plasma(np.linspace(0, 1, len(filter_names)))[::-1] + color = kwargs.pop('color', next(nmma_colors)) + + if not shared_data: + if xlim is None: + add_xlim = get_time_limits_from_obs_data(transient) + old_xlim = fig.axes[0].get_xlim() + xlim = (min(add_xlim[0], old_xlim[0]), max(add_xlim[1], old_xlim[1])) + if ylim is None: + add_ylim = get_mag_limits_from_obs_data(transient, filter_names) + ylim = {} + for i, filt in enumerate(filter_names): + old_ylim = fig.axes[i].get_ylim() + ylim[filt] = (max(add_ylim[filt][0], old_ylim[0]), min(add_ylim[filt][1], old_ylim[1])) + + for cnt, filt in enumerate(filter_names): - ax_sum.set_ylim(ylim[filt]) - ax_sum.set_xlim(xlim) - ax_delta.set_xlim(xlim) - if xlim[0] > 0: - ax_sum.set_xscale('log') + # summary plot + ax_sum = fig.axes[cnt] + ax_delta = fig.axes[cnt+n_axes] - plt.tight_layout() - plt.savefig(save_path, bbox_inches='tight') - plt.close() + # plot the observations + ax_sum, det_times = plot_observations(ax_sum, transient, data_colors[cnt], filter=filt) + if not shared_data: + ax_sum.set_xlim(xlim) + ax_sum.set_ylim(ylim[filt]) + ax_delta.set_xlim(xlim) + + if det_times.size>0: + # plot the mismatch between the model and the data + diff_per_data, _ = mismatches[filt] + delta_ylim = ax_delta.get_ylim() + ylim = (min(delta_ylim[0], 0.9*min(diff_per_data)), + max(delta_ylim[1], 1.1*max(diff_per_data))) + ax_delta.set_ylim(ylim) + ax_delta.axhline(0, linestyle='--', color='k') + ax_delta.scatter(det_times, diff_per_data, color=color) + + ax_sum.get_legend().remove() + # ax_sum.plot([], [], label=round(chi2_dict[filt], 2), color=color) + # ax_sum.legend(loc='best', frameon=False, handlelength=0, handletextpad=0, labelcolor='linecolor') + + plot_bestfit_with_errors(ax_sum, time, mags_to_plot[filt], error_dict[filt], sub_model_plot_props, cnt, color) + + + fig.tight_layout() + if save_path: + fig.savefig(save_path, bbox_inches='tight') + return fig + +def plot_bestfit_with_errors(ax_sum, time, mag_plot, error_budget, + sub_model_plot_props, cnt, color): + + label = 'combined' if sub_model_plot_props is not None else "" + ax_sum.plot(time, mag_plot, + color=color, linewidth=3, linestyle="--") + ax_sum.fill_between(time, + mag_plot + error_budget, + mag_plot - error_budget, + facecolor=color, + alpha=0.2, + label=label, + ) + + if sub_model_plot_props is not None: + ## plot additional lcs for each sub_model + for model_name, prop_dict in sub_model_plot_props.items(): + mag_plot = prop_dict['plot_mags'][cnt] + mag_err = prop_dict['plot_errors'][cnt] + plot_times = prop_dict['plot_times'] + ax_sum.plot(plot_times, mag_plot, + color=color, linewidth=3, linestyle="--") + ax_sum.fill_between(plot_times, + mag_plot + mag_err, + mag_plot - mag_err, + facecolor=prop_dict['color'], + alpha=0.2, + label=model_name, + ) + def bolometric_lc_plot(transient, time, lc, save_path, color = "coral"): matplotlib.rcParams.update( - {'font.size': 12, - # 'font.family': 'Times New Roman' + {'font.size': 12, 'font.family': 'serif' } ) fig, ax = plt.subplots(1, 1) @@ -181,7 +264,7 @@ def spec_plot_func(fig, ax, XX, YY, plot_data, label): def chi2_hists_from_dict(chi2_dict, outpath): matplotlib.rcParams.update( - {"font.size": 16, "font.family": "Times New Roman"} + {"font.size": 16, "font.family": "Serif"} ) for filt, chi2_array in chi2_dict.items(): plt.figure() @@ -292,7 +375,7 @@ def return_hist(x): hist, _ = np.histogram(x, bins=bins) return hist - hist = np.apply_along_axis(lambda x: return_hist(x), -1, plot_data.T) + hist = np.apply_along_axis(return_hist, -1, plot_data.T) bins = (bins[1:] + bins[:-1]) / 2.0 X, Y = np.meshgrid(sample_times, bins) @@ -353,27 +436,6 @@ def check_limit(lim): assert len(lim) == 2, f"{lim} is no valid plot-limit." return lim -def get_time_limits_from_obs_data(transient): - """ - A function that goes through the lc data and finds the time range that encompasses all data points. - """ - - xmin = np.min([t_arr.min() for t_arr in transient.light_curve_times.values()]) - xmax = np.max([t_arr.max() for t_arr in transient.light_curve_times.values()]) - - return (0.8*xmin, 1.2*xmax) - -def get_mag_limits_from_obs_data(transient, filter_names): - """ - A function that goes through the lc data and finds the magnitude range for each filter. - """ - - ylim = {} - for filt in filter_names: - ylim[filt] = (1.2*transient.light_curves[filt].max(), 0.8*transient.light_curves[filt].min()) - - return ylim - def plot_observations(ax, transient, color="k",**kwargs): obs_times, obs_lc, obs_unc = transient.light_curve_times, transient.light_curves, transient.light_curve_uncertainties if 'filter' in kwargs: @@ -401,7 +463,6 @@ def analysis_plot_geometry(filters_to_plot, ncol=2): wpanel = 3. nrow = int(np.ceil(len(filters_to_plot) / ncol)) - fig, axes = plt.subplots(nrow, ncol) figsize = (1.5 * (lspace + wpanel * ncol + wspace * (ncol - 1) + trspace), 1.5 * (bspace + hpanel * nrow + hspace * (nrow - 1) + trspace)) @@ -409,25 +470,27 @@ def analysis_plot_geometry(filters_to_plot, ncol=2): fig, axes = plt.subplots(nrow, ncol, figsize=figsize, squeeze=False) fig.subplots_adjust(left=lspace / figsize[0], bottom=bspace / figsize[1], - right=1. - trspace / figsize[0], - top=1. - trspace / figsize[1], + right = 1. - trspace / figsize[0], + top = 1. - trspace / figsize[1], wspace=wspace / wpanel, hspace=hspace / hpanel) - if len(filters_to_plot) % 2: - axes[-1, -1].axis('off') + if len(filters_to_plot) % ncol: + for i in range(len(filters_to_plot) % ncol, ncol): + axes[-1, i-ncol].axis('off') return fig, axes + def get_time_limits_from_obs_data(transient): """ A function that goes through the lc data and finds the time range that encompasses all data points. """ - - xmin = np.min([t_arr.min() for t_arr in transient.light_curve_times.values()]) - xmax = np.max([t_arr.max() for t_arr in transient.light_curve_times.values()]) - return (0.8*xmin, 1.2*xmax) + xmin = np.min([t_arr.min() for t_arr in transient.light_curve_times.values()]) + xmax = np.max([t_arr.max() for t_arr in transient.light_curve_times.values()]) + + return (0.9*xmin, 1.1*xmax) def get_mag_limits_from_obs_data(transient, filter_names): """ @@ -436,7 +499,8 @@ def get_mag_limits_from_obs_data(transient, filter_names): ylim = {} for filt in filter_names: - ylim[filt] = (1.2*transient.light_curves[filt].max(), 0.8*transient.light_curves[filt].min()) - + min_mag = transient.light_curves[filt].min() + max_mag = transient.light_curves[filt].max() + ylim[filt] = (min(1.05*max_mag, 1+ max_mag), max(min_mag-1, 0.95*min_mag)) return ylim diff --git a/nmma/em/prior.py b/nmma/em/prior.py index f054641a..7e88ab30 100644 --- a/nmma/em/prior.py +++ b/nmma/em/prior.py @@ -229,7 +229,7 @@ def create_prior_from_args(args, systematics_handler): lc_model : nmma.em.model.LightCurveModelContainer Light curve model object to compute light curves """ - priors = PriorDict(args.prior_file) + priors = PriorDict(args.prior_file) if getattr(args, 'prior_file', None) else PriorDict(args.prior) priors = adjust_hubble_prior(priors, args) priors = extinction_prior(priors, args) diff --git a/nmma/em/utils.py b/nmma/em/utils.py index 72cb62ee..e4fca364 100644 --- a/nmma/em/utils.py +++ b/nmma/em/utils.py @@ -77,7 +77,7 @@ def setup_sample_times(args): return np.arange(tmin, tmax + args.em_tstep, args.em_tstep) # otherwise, create the array based on selected scale - if 'lin' in args.em_timescale: + if 'lin' in args.em_timescale or tmin<=0.: return np.linspace(tmin, tmax, args.em_nsteps) elif any(scale in args.em_timescale for scale in ['log', 'geo']): return np.geomspace(tmin, tmax, args.em_nsteps) @@ -103,6 +103,7 @@ def set_filters(args): em_detectors = args.em_detectors.split(",") else: em_detectors = getattr(args, "em_detectors", []).copy() + em_detectors = [det.strip().lower() for det in em_detectors] filters = [] if 'ztf' in em_detectors: em_detectors.remove('ztf') @@ -110,7 +111,7 @@ def set_filters(args): if 'lsst' in em_detectors: em_detectors.remove('lsst') filters.extend( ["lsstg", "lsstr", "lssti", "lsstz", "lssty"] ) - elif hasattr(args, "rubin_ToO_type"): + elif getattr(args, "rubin_ToO_type"): if args.rubin_ToO_type == 'platinum': filters.extend( ["ps1::g","ps1::r","ps1::i","ps1::z","ps1::y"] ) elif args.rubin_ToO_type == 'gold': @@ -129,7 +130,6 @@ def set_filters(args): if em_detectors: raise NotImplementedError(f"{em_detectors} not implemented yet.") ## to be extended - return filters @@ -198,7 +198,6 @@ def set_filter_associated_dict(quantity, filters, default_limit = np.inf): def cut_data_to_time_range(data, args, trigger_time, tmin = 0, tmax = np.inf): - tmin = getattr(args, "data_tmin", tmin) tmax = getattr(args, "data_tmax", tmax) @@ -251,7 +250,6 @@ def setup_filtered_lc_data(light_curve_data, trigger_time): def check_model_time_consistency(light_curve_data, light_curve_model, priors, injection = None): (lc_times, lc_mags, lc_uncertainties, trigger_time) = light_curve_data - data_tmin = np.min([lc_times[key].min() for key in lc_times.keys()]) data_tmax = np.max([lc_times[key].max() for key in lc_times.keys()]) diff --git a/nmma/eos/eos_likelihood.py b/nmma/eos/eos_likelihood.py index 289d4745..1db3cab2 100644 --- a/nmma/eos/eos_likelihood.py +++ b/nmma/eos/eos_likelihood.py @@ -4,15 +4,19 @@ import shutil import json from ast import literal_eval +import matplotlib from tqdm.contrib.concurrent import process_map import numpy as np from scipy.special import logsumexp -from scipy.stats import norm, gaussian_kde +from scipy.stats import norm +from scipy.ndimage import gaussian_filter +# from scipy.spatial import ConvexHull from matplotlib import pyplot as plt from bilby.core.prior import WeightedCategorical, PriorDict from .eos_processing import EoSConverter from ..core.base import NMMALikelihood -from ..core.utils import fading_cmap +from ..core.utils import nan_level +from ..core.plotting_utils import fading_cmap def setup_tabulated_eos_priors(args, priors, logger=None): if logger: @@ -60,27 +64,69 @@ def setup_submodel_conversion(self): self.conv_functions.append(self.sub_model.parameter_conversion) - def final_diagnostics(self, bestfit_params, args, result=None): + def final_diagnostics(self, bestfit_params, args, result=None, fig = None): + matplotlib.rcParams.update({'font.size': 16, 'font.family': 'serif'}) + bestfit_params =self.parameter_conversion(bestfit_params) - - color_densities = [] radii, masses, lambdas = self.sub_model.eos_converter.macro_parameters.values() - x_lim = (np.min(radii)-0.3, np.max(radii)+0.3) - y_lim = (masses[0], masses[-1]+0.1) - fig, ax = plt.subplots(figsize=(10, 6)) - ax.set_xlim(x_lim) - ax.set_ylim(y_lim) - for constraint in self.sub_model.constraints: - color = ax._get_lines.get_next_color() - ax = constraint.plot(ax=ax, resolution=100, color=color) - ax.plot(radii, masses, label='Best fit EOS', zorder=10) - - ax.set_xlabel(r'Radius [km]') - ax.set_ylabel(r'Mass [M$_\odot$]') - ax.legend() - fig.savefig(os.path.join(args.outdir, f"{args.label}_mr_curve.png")) - plt.show() + if fig is None: + x_lim = (min(np.min(radii)-0.3, 9), max(np.max(radii)+0.3, 15)) + y_lim = (masses[0], masses[-1]+0.1) + fig, ax = plt.subplots(figsize=(10, 10)) + ax.set_xlim(x_lim) + ax.set_ylim(y_lim) + ax.set_xlabel(r'Radius [km]') + ax.set_ylabel(r'Mass [M$_\odot$]') + + for constraint in self.sub_model.constraints: + color = ax._get_lines.get_next_color() + ax = constraint.plot(ax=ax, color=color) + else: + ax = fig.axes[0] + xlow, xhigh = ax.get_xlim() + ylow, yhigh = ax.get_ylim() + ax.set_xlim(min(np.min(radii)-0.3, xlow), max(np.max(radii)+0.3, xhigh)) + ax.set_ylim(min(masses[0], ylow), max(masses[-1]+0.1, yhigh)) + + labels = [line.get_label() for line in fig.legends[0].legend_handles] + for constraint in self.sub_model.constraints: + if constraint.name in labels: + continue + color = ax._get_lines.get_next_color() + ax = constraint.plot(ax=ax, color=color) + fig.legends.clear() ## remove old legend to avoid duplicates + + + line =ax.plot(radii, masses, label=f'{args.label}',linewidth=3, zorder=10) + + if result is not None: + cmap = fading_cmap(line[0].get_color()) + posterior = self.parameter_conversion(result.posterior) + # posterior['log_post'] = posterior.log_likelihood + posterior.log_prior + # post_weights = np.exp(posterior.log_post) + post_weights = None + eos_data = self.sub_model.eos_converter.macro_conversion(posterior) + max_masses = np.max([eos[1][-1] for eos in eos_data]) + mass_range = np.linspace(1.0, max_masses, 151) + show_radii = np.empty((len(mass_range), len(eos_data))) + + for i, eos in enumerate(eos_data): + show_radii[:,i] = np.interp(mass_range, eos[1], eos[0], right = np.nan) + + for level in [0.5, 0.9]: + bounds = np.array([nan_level(radii, level, post_weights) for radii in show_radii]) + ax.fill_betweenx(mass_range, bounds[:, 0], bounds[:, 1], color=cmap(1-0.5*level), zorder=1) + + if result.injection_parameters is not None: + inj_eos = self.sub_model.eos_converter.macro_conversion(result.injection_parameters) + ax.plot(inj_eos[0], inj_eos[1], label='Injection', color='black', linestyle='dashed', linewidth=3, zorder=10) + + fig.legend(ncols=2, loc='upper center', + bbox_to_anchor=(0.5, 0.00), handlelength=2) + fig.tight_layout() + fig.savefig(os.path.join(args.outdir, f"{args.label}_mr_curve.png"),bbox_inches='tight') + return fig def compose_eos_constraints(args, constraint_kinds=['lower_mtov', 'upper_mtov', 'mass_radius']): @@ -192,7 +238,7 @@ def initialise_from_dict(self, constraint_dict): elif constraint_kind == 'mass_radius': for label, constraint in sub_constraints.items(): constraint_list.append(MassRadiusConstraint( - file_path=constraint.get('posterior', None), + file_path=constraint.get('posterior', constraint.get('file_path', None)), name=label, arxiv_ref=constraint.get('arxiv', None) )) @@ -200,7 +246,7 @@ def initialise_from_dict(self, constraint_dict): raise ValueError('Unknown type of EoS Constraint. Must be "lower_mtov", \ "upper_mtov", "mass-radius" or "micro\ ') - return constraint_list + return constraint_list def parameter_conversion(self, parameters): return self.eos_converter.parameter_conversion(parameters) @@ -300,6 +346,7 @@ def __init__(self, measured_mass, measure_error, name=None, arxiv_ref=None, logn self.error = measure_error self.repr_add = f'of {measured_mass}+-{measure_error} M_sun' self.lognorm_method = lognorm_method + self.linestyle = '--' def __repr__(self): @@ -317,19 +364,20 @@ def log_likelihood(self, parameters, local_parameters=None): tov_mass = local_parameters['masses'][-1] return self.lognorm_method(tov_mass, loc=self.mass, scale=self.error) - def plot(self, ax, resolution = 100, **kwargs): + def plot(self, ax, **kwargs): """Plot the mass constraint on the given figure.""" x_lim = ax.get_xlim() - y_lim = ax.get_ylim() - M_grid = np.linspace(*y_lim, num=resolution) - show_x, show_y = np.meshgrid(np.linspace(*x_lim, resolution), M_grid) - line = ax.hlines(self.mass, *x_lim, linestyle='--', linewidth=1.5, zorder=3, label=self.name, **kwargs) - cmap = fading_cmap(line.get_color()) - shade_profile = norm.pdf(M_grid, loc=self.mass, scale=self.error) - shading_matrix = np.repeat(shade_profile[:, np.newaxis], resolution, axis=1) - ax.contourf(show_x, show_y, shading_matrix, levels=50, cmap=cmap) + dummy_line = ax.plot([], [], label=self.name, linestyle=self.linestyle, linewidth=2.5, **kwargs) + line = ax.hlines(self.mass, *x_lim, linestyle=self.linestyle, linewidth=2.5, zorder=3, **kwargs) + cmap = fading_cmap(dummy_line[0].get_color()) + levels = [0.95, 0.68] + for i, level in enumerate(levels): + ax.fill_between(x_lim, self.mass - (i+1)*self.error, self.mass + (i+1)*self.error, color=cmap(0.8*level), zorder=2-i) + # ax.hlines(self.mass + (i+1)*self.error, *x_lim, color=cmap(0.9*level), linewidth=1.5, zorder=1) + # ax.hlines(self.mass - (i+1)*self.error, *x_lim, color=cmap(0.9*level), linewidth=1.5, zorder=1) return ax + class LowerMTOVConstraint(MassConstraint): '''Constraint that an EOS supports at least a certain TOV mass(within Gaussian uncertainty)''' def __init__(self, measured_mass, measure_error, name=None, arxiv_ref=None): @@ -364,6 +412,7 @@ def __init__(self, measured_mass, measure_error, name=None, arxiv_ref=None): Identifier of a relevant source """ super().__init__(measured_mass, measure_error, name, arxiv_ref, lognorm_method=norm.logsf) + self.linestyle = ':' class MassRadiusConstraint(EoSConstraint): @@ -389,22 +438,9 @@ def __init__(self, mass_array=None, radius_array=None, weights = None, file_path mass_array, radius_array, weights = self.read_data(file_path) elif mass_array is None or radius_array is None: raise ValueError('Must provide data for masses and radii as arrays or file from which to load') - - if len(radius_array) > 10000: - ratio = len(radius_array) // 10000 - else: - ratio = 1 - - radius = radius_array[::ratio] - mass = mass_array[::ratio] - if weights is not None: - weights = weights[::ratio] - - self.KDE = gaussian_kde((radius, mass), weights=weights) - self.test_masses= np.linspace(start=1., stop=2.5, num=150 ) # 1 to 2.5 Msun - self.rng = np.random.default_rng() - - + self.set_grid(mass_array, radius_array, weights) + self.test_masses = np.linspace(1.2, 2.5, 151) + def read_data(self, file_path): """Read mass-radius data from a file.""" data = np.loadtxt(file_path, unpack=True) @@ -443,12 +479,45 @@ def read_data(self, file_path): return masses, radius, weights - def log_likelihood(self, parameters, local_parameters): + def set_grid(self, masses, radii, weights, mass_step = 0.01, radius_step = 0.03): + """Set up a grid upon which to build a histogram of mass-radius data to approximate the pdf. + Note that when using multiple mass-radius measurements, all measurements should use the same stepsizes! + Parameters + ---------- + masses: np.array + Array with mass posterior of M-R measurement + radii: np.array + Array with radius posterior of M-R measurement, must be specified along an equal-length mass_array + weights: np.array, optional + Array with weights of the M-R samples, must be specified along an equal-length mass_array + mass_step: float + step size for mass grid in solar masses, default is 0.01 Msun + radius_step: float + step size for radius grid in km, default is 0.02 km (20 m) + """ + mass_bins = self.set_bins(masses, mass_step) + rad_bins = self.set_bins(radii, radius_step) + if 3*len(mass_bins)*len(rad_bins) > len(masses): + print("Warning: The histogram might be to sparsely populated to get meaningful results.") + + histogram, self.rad_edges, self.mass_edges = np.histogram2d(radii, masses, bins=[rad_bins, mass_bins], weights=weights, density=True) + drad = self.rad_edges[1] - self.rad_edges[0] + dmass = self.mass_edges[1] - self.mass_edges[0] + + self.histogram = gaussian_filter(histogram*dmass*drad, sigma=3) + + def set_bins(self, array, step_size, sensitivity=0.001): + low, high = np.quantile(array, [sensitivity, 1.- sensitivity]) + bins = np.arange(0.95*low, 1.05*high, step_size, dtype=np.float64) + return bins + + def log_likelihood(self, parameters, local_parameters): try: tov_mass = parameters.get('TOV_mass', local_parameters['masses'][-1]) return self.single_logl(tov_mass, local_parameters['masses'], local_parameters['radii']) except (ValueError, IndexError): + self.single_logl(tov_mass, local_parameters['masses'], local_parameters['radii']) return [ self.single_logl(masses[-1], masses, local_parameters['radii'][i]) for i, masses in enumerate(local_parameters['masses']) @@ -458,30 +527,31 @@ def single_logl(self, tov_mass, masses, radii): ## interpolate radii along equally spaced mass grid up to MTov test_mass_range=self.test_masses[self.test_masses&to - except ValueError: # filename - with open(to, 'wb') as to_file: - os.dup2(to_file.fileno(), stdout_fd) # $ exec > to - try: - yield stdout # allow code to be run with the redirected stdout - finally: - # restore stdout to its previous value - #NOTE: dup2 makes stdout_fd inheritable unconditionally - stdout.flush() - os.dup2(copied.fileno(), stdout_fd) # $ exec >&copied - - def baryonic_mass(gravitational_mass, EOS, eos_path_macro, eos_path_micro): @@ -76,8 +43,7 @@ def TOVeq(y, x): x = np.arange(dr, r+dr, dr) y0 = [p0, m0] - with stdout_redirected(): - p_solv, m_solv = scipy.integrate.odeint(TOVeq, y0=y0, t = x).T + p_solv, m_solv = scipy.integrate.odeint(TOVeq, y0=y0, t = x).T n_solv = np.interp(p_solv, P, N) @@ -145,7 +111,6 @@ def __init__(self, prior, posterior_samples, Neos, eos_path_macro, eos_path_micr log10_mdisk = self.posterior_samples.log10_mdisk.to_numpy() self.KDE = scipy.stats.gaussian_kde((chirp_mass, eta_star, EOS, log10_mdisk, log10_mej_dyn)) - super().__init__(**kwargs) def Prior(self, x): @@ -171,7 +136,9 @@ def LogLikelihood(self, x): mdisk = 10**log10_mdisk mej_dyn = 10**log10_mej_dyn - m_rem_b = baryonic_mass(mass_1, EOS, self.eos_path_macro, self.eos_path_micro) + baryonic_mass(mass_2, EOS, self.eos_path_macro, self.eos_path_micro) - mdisk - mej_dyn #calculate the baryonic remnant mass + b1 = baryonic_mass(mass_1, EOS, self.eos_path_macro, self.eos_path_micro) + b2 = baryonic_mass(mass_2, EOS, self.eos_path_macro, self.eos_path_micro) + m_rem_b = b1 + b2 - mdisk - mej_dyn #calculate the baryonic remnant mass if self.use_M_max: m_threshold = baryonic_Kepler_mass(mTOV, R_14, ratio_R, delta) #if the Kepler limit is the threshold, use the quasiuniversal relation @@ -223,7 +190,6 @@ def maximum_mass_resampling(args): class PostmergerInference(PostmergerInferenceMixIn, Solver): pass - solution = PostmergerInference(prior, posterior_samples, Neos, args.eos_path_macro, args.eos_path_micro, args.use_M_Kepler, **pymulti_kwargs) diff --git a/nmma/post_processing/plotting_routines.py b/nmma/post_processing/plotting_routines.py index a500f28b..f64a3758 100644 --- a/nmma/post_processing/plotting_routines.py +++ b/nmma/post_processing/plotting_routines.py @@ -3,58 +3,30 @@ import numpy as np import pandas as pd import matplotlib +from matplotlib.ticker import MaxNLocator import seaborn from matplotlib import gridspec, pyplot as plt from ast import literal_eval -import itertools from ..core.conversion import chirp_mass_and_eta_to_component_masses, tidal_deformabilities_and_mass_ratio_to_eff_tidal_deformabilities, label_mapping from ..core import utils, parsing from ..core import plotting_utils as corepu from .parser import corner_plot_parser -color_array = corepu.fig_setup() -nmma_colors = itertools.cycle(color_array) +nmma_colors = corepu.fig_setup() - -def plot_multi_corner(args, key_selection=None): - - plot_kwargs = literal_eval(args.kwargs) - quantiles = [0.16, 0.5, 0.84] - fig = None - labels = [lab for lab in args.label_name] if args.label_name is not None else [f for f in args.posterior_files] - for i, f in enumerate(args.posterior_files): - plot_keys, plot_labels = corepu.plotting_parameters_from_priors(args.prior, keys=key_selection).items() - if args.injection_json is not None: - truths = utils.read_injection_file(args.injection_json) - truths = truths.iloc[args.injection_num].to_dict() - truths = np.array([truths[k] for k in plot_keys]) - if args.verbose: - print("\nLoaded Injection:") - print(f"Truths from injection: {truths}") - elif args.bestfit_params is not None: - truths = utils.read_bestfit_from_json(args.bestfit_json, plot_keys, args.verbose) - else: - truths = None - - fig = setup_corner_plot(f, [], label =labels[i], truths = truths, fig=fig, - quantiles=quantiles, plot_keys=plot_keys, default_labels = plot_labels, **plot_kwargs) - - filename, ext = os.path.splitext(args.output) - if not ext: - filename = os.path.join(os.getcwd(), f"{filename}.png") - fig.savefig(filename, bbox_inches="tight", dpi=300) - print("\nSaved corner plot:", filename) - - -def setup_corner_plot(posterior_samples,limits = None, plot_keys = None, fig = None, - injection=None, post_dir = None, default_labels={}, **plot_kwargs): +def setup_plot_quantities(posterior_samples, limits, plot_keys, injection, post_dir = None, default_labels={}, **plot_kwargs): + matplotlib.rcParams.update({'font.size': 16, 'font.family': 'serif'}) #load samples posterior_samples = utils.get_posteriors(posterior_samples, post_dir) + best_fit = posterior_samples.iloc[posterior_samples['log_likelihood'].idxmax()] if plot_keys is None: plot_keys = posterior_samples.columns.tolist() # show all we can + for key in ['log_likelihood', 'log_prior']: + if key in plot_keys: + plot_keys.remove(key) # but not the likelihood itself if limits is None: - limits = [(np.inf, -np.inf) for key in plot_keys] # will adjust more permissively later + limits = [(np.inf, -np.inf) for _ in plot_keys] # will adjust more permissively later # find what to actually plot plot_samples, plot_labels, titles = [], [], [] for i, k in enumerate(plot_keys): @@ -78,15 +50,112 @@ def setup_corner_plot(posterior_samples,limits = None, plot_keys = None, fig = N plot_samples.append(np.linspace(100*cur_max, 1001*cur_max, posterior_samples.shape[0])) plot_labels.append('') titles.append('') + plot_samples = np.column_stack(plot_samples) - if injection is not None: truths = [injection.get(key, None) for key in plot_keys] else: truths = None - # limits = ((np.amin(posterior_samples[k]), np.amax(posterior_samples[k])) for k in plot_keys) + + + return plot_samples, plot_labels, titles, limits, truths, best_fit + +def plot_histograms_only(posterior_samples,limits = None, plot_keys = None, fig = None, + injection=None, post_dir = None, default_labels={}, best_fit = False, ncols=None, loc_labels='left',title_kwargs ={}, **plot_kwargs): + plot_samples, plot_labels, titles, limits, truths, best_fit = setup_plot_quantities( + posterior_samples, limits, plot_keys, injection, post_dir, default_labels, **plot_kwargs) + + if fig is None: + fig, axes = corepu.setup_multi_axes(len(plot_keys), ncols=ncols) + else: + axes = fig.get_axes() + label = plot_kwargs.pop('label', None) + color = plot_kwargs.pop('color', next(nmma_colors)) + for i, ax in enumerate(axes): + if i >= len(plot_keys): + ax.axis('off') # Hide any extra subplots + continue + ax.hist(plot_samples[:, i], bins=50, density=True, color=color, alpha=0.7,histtype = 'step', **plot_kwargs) + if loc_labels == 'left': + ax.yaxis.set_label_position("left") + ax.set_ylabel(plot_labels[i], fontsize=16) + elif loc_labels == 'top': + ax.xaxis.set_label_position("top") + ax.set_xlabel(plot_labels[i], fontsize=16) + # ax.xaxis.set_label_coords(0.5, -0.3) + ax = set_title(ax, titles[i], plot_labels[i], color, **title_kwargs) + + if truths is not None and truths[i] is not None: + ax.axvline(truths[i], color='tab:orange', linestyle='--') + + if isinstance(best_fit, (dict, pd.Series)) and plot_keys[i] in best_fit: + ax.axvline(best_fit[plot_keys[i]], color=color, linestyle='-.') + + ax.xaxis.set_major_locator(MaxNLocator(nbins=4, min_n_ticks=3, prune='both')) + [l.set_rotation(45) for l in ax.get_xticklabels()] + [l.set_rotation(45) for l in ax.get_xticklabels(minor=True)] + ax.autoscale(enable=True, axis='x', tight=True) + ax.set_yticks([]) + ax.set_yticklabels([]) + + + # allow joint legend + if label: + fig.legends.clear() + fig.axes[0].plot([], [], label = label, color= color, **plot_kwargs) + fig.legend(ncols=2, loc='upper center', bbox_to_anchor=(0.5, 0.), handlelength=2) + + return fig, limits + + + + + +def plot_multi_corner(args, key_selection=None): + + plot_kwargs = literal_eval(args.kwargs) + quantiles = [0.16, 0.5, 0.84] + fig = None + labels = [lab for lab in args.label_name] if args.label_name is not None else [f for f in args.posterior_files] + for i, f in enumerate(args.posterior_files): + plot_keys, plot_labels = corepu.plotting_parameters_from_priors(args.prior, keys=key_selection).items() + if args.injection_json is not None: + truths = utils.read_injection_file(args.injection_json) + truths = truths.iloc[args.injection_num].to_dict() + truths = np.array([truths[k] for k in plot_keys]) + if args.verbose: + print("\nLoaded Injection:") + print(f"Truths from injection: {truths}") + elif args.bestfit_params is not None: + truths = utils.read_bestfit_from_json(args.bestfit_json, plot_keys, args.verbose) + else: + truths = None + + fig = setup_corner_plot(f, [], label =labels[i], truths = truths, fig=fig, + quantiles=quantiles, plot_keys=plot_keys, default_labels = plot_labels, **plot_kwargs) + + filename, ext = os.path.splitext(args.output) + if not ext: + filename = os.path.join(os.getcwd(), f"{filename}.png") + fig.savefig(filename, bbox_inches="tight", dpi=300) + print("\nSaved corner plot:", filename) + + +def setup_corner_plot(posterior_samples,limits = None, plot_keys = None, fig = None, + injection=None, post_dir = None, default_labels={}, **plot_kwargs): + + plot_samples, plot_labels, titles, limits, truths, _ = setup_plot_quantities( + posterior_samples, limits, plot_keys, injection, post_dir, default_labels, **plot_kwargs) + color = plot_kwargs.pop('color', next(nmma_colors)) - fig = corner_plot(plot_samples, plot_labels, limits, fig=fig, truths= truths, color = color, titles=titles, show_titles = False, **plot_kwargs) + fig = corner_plot(plot_samples, plot_labels, limits, fig=fig, truths= truths, color = color, show_titles = False, **plot_kwargs) + + + # adjust titles + axes = fig.get_axes() + for i, title in enumerate(titles): + ax = axes[i*len(plot_labels) + i] + ax = set_title(ax, title, plot_labels[i], color) # allow joint legend if 'label' in plot_kwargs: @@ -98,11 +167,11 @@ def setup_corner_plot(posterior_samples,limits = None, plot_keys = None, fig = N def corner_plot(plot_samples, labels, limits, fig = None, save=False, **kwargs): - matplotlib.rcParams.update({'font.size': 16, 'font.family': 'Times New Roman'}) + matplotlib.rcParams.update({'font.size': 16, 'font.family': 'Serif'}) matplotlib.rcParams['text.usetex'] = (os.environ.get("CI") != 'true') default_kwargs = dict(bins=50, smooth=1.3, label_kwargs=dict(fontsize=16), show_titles=True, - title_kwargs=dict(fontsize=16), color = color_array[0], #color='#0072C1', - truth_color='tab:orange', quantiles=[0.05, 0.5, 0.95], + title_kwargs=dict(fontsize=16), color = next(nmma_colors), #color='#0072C1', + truth_color='tab:orange', quantiles=[0.16, 0.5, 0.84], levels=(0.10, 0.32, 0.68, 0.95), median_line=True, title_fmt=".2f", plot_density=False, plot_datapoints=False, fill_contours=True, max_n_ticks=4, hist_kwargs={'density': True}) @@ -114,6 +183,27 @@ def corner_plot(plot_samples, labels, limits, fig = None, save=False, **kwargs): return fig +def set_title(ax, title, label, color, **title_kwargs): + old_text = ax._left_title.get_text() + if old_text == '': # meaning no title set, so we can set our own + text = f'{label}={title}' + ax.set_title(text, loc = 'left', color=color, **title_kwargs) + + elif ax._right_title.get_text() == '': # meaning only left title set yet + old_text_parts = old_text.split('=') + ax._left_title.set_text(old_text_parts[-1]) # remove old title + ax.set_title(title, loc='right', color=color, **title_kwargs) + + elif ax.get_title() == '': # meaning only left and right title set yet + ax.set_title(title, loc='center', color=color, pad=10, **title_kwargs) + else: + # we already have a title in all three locations, so we rather remove all + print("Warning: More than three titles set for this axis, so all titles were removed to avoid confusion. ") + print(f"All intended title information was {label}: {old_text} (first, left), {ax._right_title.get_text()} (second, right), {ax.get_title()} (third, center), {title} (new).") + ax._left_title.set_text('') + ax._right_title.set_text('') + ax.set_title('', loc='center', **title_kwargs) + return ax def resampling_corner_plot(posterior_samples, solution, outdir, withNSBH): diff --git a/requirements.txt b/requirements.txt index 8b34e285..1ec16d56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ future -bilby>=2.7.1 +bilby>=2.8 bilby_pipe>=1.7.0 schwimmbad colorcet @@ -18,6 +18,4 @@ toml ligo.skymap healpy scikit-learn -tensorflow>=2.19; platform_system == "Darwin" -tensorflow>=2.19; platform_system == "Windows" tensorflow-cpu>=2.19; platform_system == "Linux"