diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..fa097d5d5 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# apply ruff formatting +3920071bac94e2c4b20bcf0ce9911a7b7656d0ac \ No newline at end of file diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index ae4e4ef6a..17f3b133c 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -18,11 +18,14 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v5 + - uses: actions/setup-python@v6 + with: + python-version: '3.13' + cache: 'pip' - name: Install jupyter run: | python -m pip install jupyterlab - uses: pre-commit/action@v3.0.1 - - uses: pre-commit-ci/lite-action@v1.0.2 + - uses: pre-commit-ci/lite-action@v1.1.0 if: always() diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b4ab87110..3e7338b8b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,18 +3,12 @@ repos: rev: v2.3.0 hooks: - id: check-merge-conflict # prevent committing files with merge conflicts -- repo: https://github.com/pycqa/flake8 - rev: 6.1.0 - hooks: - - id: flake8 - additional_dependencies: - - Flake8-pyproject -- repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - language_version: python3 - files: '^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)' +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.0 # Use the latest stable Ruff version + hooks: + - id: ruff-check # Runs the Ruff linter + args: [ --fix ] # Optionally enable automatic fixes + - id: ruff-format # Runs the Ruff formatter - repo: https://github.com/codespell-project/codespell rev: v2.1.0 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index 16c762f38..1e31a6a24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The original MRs are only visible on the [LIGO GitLab repository](https://git.li ## [Unreleased] +* MAINT: switch to ruff for automated formatting + ## [2.7.1] ### Fixes diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d85e2176c..72fb17c33 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,14 +24,12 @@ Code of Conduct](https://www.python.org/psf/codeofconduct/). Furthermore, member ## Code style -During a code review (when you want to contribute changes to the code base), -you may be asked to change your code to fit with the bilby style. This is based -on a few python conventions and is generally maintained to ensure the code base -remains consistent and readable to new users. Here we list some typical things -to keep in mind ensuring the code review is as smooth as possible - -1. We follow the [standard python PEP8](https://www.python.org/dev/peps/pep-0008/) conventions for style. While the testing of this is largely automated (the C.I. pipeline tests check using [flake8](http://flake8.pycqa.org/en/latest/)), some more subjective things might slip the net. -2. New classes/functions/methods should have a docstring and following the [numpy docstring guide](https://numpydoc.readthedocs.io/en/latest/format.html), for example +We apply [ruff](https://docs.astral.sh/ruff/) automated formatting to the +entire codebase. This ensures that we have an objectively enforced and +consistent style across the project. In addition to the automated test below +are a few useful guidelines when writing code for Bilby. + +1. New classes/functions/methods should have a docstring and following the [numpy docstring guide](https://numpydoc.readthedocs.io/en/latest/format.html), for example ```python def my_new_function(x, y, print=False): """ A function to calculate the sum of two numbers @@ -47,6 +45,7 @@ def my_new_function(x, y, print=False): print("Message!") return x + y ``` +2. Changes to existing functions and classes that change the functionality should generally be accompanied by a change in the docstring describing the change in behaviour and version the change was introduced. 3. Avoid inline comments unless necessary. Ideally, the code should make it obvious what is going on, if not the docstring, only in subtle cases use comments 4. Name variables sensibly. Avoid using single-letter variables, it is better to name something `power_spectral_density_array` than `psda`. 5. Don't repeat yourself. If code is repeated in multiple places, wrap it up into a function. This also helps with the writing of robust unit tests (see below). @@ -54,8 +53,9 @@ def my_new_function(x, y, print=False): ## Automated code checking -In order to automate checking of the code quality, we use -[pre-commit](https://pre-commit.com/). For more details, see the documentation, +We use [pre-commit](https://pre-commit.com/) to apply the automated code checking. +To maximize smoothness of the contributing process, we recommend that you install +the pre-commit tests on your development machine. For more details, see the documentation, here we will give a quick-start guide: 1. Install and configure: ```console @@ -78,6 +78,14 @@ $ pre-commit install If you experience any issues with pre-commit, please ask for support on the usual help channels. +You can additionally run the individual checks manually, e.g., + +```console +$ pip install ruff +$ cd bilby +$ ruff format +$ ruff check --fix +``` ## Unit Testing diff --git a/bilby/__init__.py b/bilby/__init__.py index fb0abff05..5a2a46c71 100644 --- a/bilby/__init__.py +++ b/bilby/__init__.py @@ -15,39 +15,29 @@ """ - -import sys - from . import core, gw, hyper - -from .core import utils, likelihood, prior, result, sampler -from .core.sampler import run_sampler +from .core import likelihood, prior, result, sampler, utils from .core.likelihood import Likelihood from .core.result import read_in_result, read_in_result_list +from .core.sampler import run_sampler try: from ._version import version as __version__ except ModuleNotFoundError: # development mode - __version__ = 'unknown' - - -if sys.version_info < (3,): - raise ImportError( -"""You are running bilby >= 0.6.4 on Python 2 - -Bilby 0.6.4 and above are no longer compatible with Python 2, and you still -ended up with this version installed. That's unfortunate; sorry about that. -It should not have happened. Make sure you have pip >= 9.0 to avoid this kind -of issue, as well as setuptools >= 24.2: - - $ pip install pip setuptools --upgrade - -Your choices: - -- Upgrade to Python 3. - -- Install an older version of bilby: - - $ pip install 'bilby<0.6.4' - -""") + __version__ = "unknown" + +__all__ = [ + core, + gw, + hyper, + likelihood, + prior, + result, + sampler, + utils, + Likelihood, + read_in_result, + read_in_result_list, + run_sampler, + __version__, +] diff --git a/bilby/bilby_mcmc/__init__.py b/bilby/bilby_mcmc/__init__.py index cbf49adb1..44dd18975 100644 --- a/bilby/bilby_mcmc/__init__.py +++ b/bilby/bilby_mcmc/__init__.py @@ -1 +1,3 @@ from .sampler import Bilby_MCMC + +__all__ = [Bilby_MCMC] diff --git a/bilby/bilby_mcmc/chain.py b/bilby/bilby_mcmc/chain.py index 4b78a3044..2a8294ae0 100644 --- a/bilby/bilby_mcmc/chain.py +++ b/bilby/bilby_mcmc/chain.py @@ -7,7 +7,7 @@ from .utils import LOGLKEY, LOGLLATEXKEY, LOGPKEY, LOGPLATEXKEY -class Chain(object): +class Chain: def __init__( self, initial_sample, @@ -84,9 +84,7 @@ def _get_zero_chain_array(self): return np.zeros((self.block_length, self.ndim + 2), dtype=np.float64) def _extend_chain_array(self): - self._chain_array = np.concatenate( - (self._chain_array, self._get_zero_chain_array()), axis=0 - ) + self._chain_array = np.concatenate((self._chain_array, self._get_zero_chain_array()), axis=0) self._chain_array_length = len(self._chain_array) @property @@ -259,11 +257,7 @@ def tau(self): if self.position in self.max_tau_dict: # If we have the ACT at the current position, return it return self.max_tau_dict[self.position] - elif ( - self.tau_last < np.inf - and self.cached_tau_count < 50 - and self.nsamples_last > 50 - ): + elif self.tau_last < np.inf and self.cached_tau_count < 50 and self.nsamples_last > 50: # If we have a recent ACT return it self.cached_tau_count += 1 return self.tau_last @@ -312,9 +306,7 @@ def _calculate_tau_dict(self, minimum_index): # Choose minimimum index for the ACT calculation last_tau = self.tau_last if self.tau_window is not None and last_tau < np.inf: - minimum_index_for_act = max( - minimum_index, int(self.position - self.tau_window * last_tau) - ) + minimum_index_for_act = max(minimum_index, int(self.position - self.tau_window * last_tau)) else: minimum_index_for_act = minimum_index @@ -364,9 +356,7 @@ def samples(self): def plot(self, outdir=".", label="label", priors=None, all_samples=None): import matplotlib.pyplot as plt - fig, axes = plt.subplots( - nrows=self.ndim + 3, ncols=2, figsize=(8, 9 + 3 * (self.ndim)) - ) + fig, axes = plt.subplots(nrows=self.ndim + 3, ncols=2, figsize=(8, 9 + 3 * (self.ndim))) scatter_kwargs = dict( lw=0, marker="o", @@ -386,7 +376,7 @@ def plot(self, outdir=".", label="label", priors=None, all_samples=None): position_indexes = np.arange(self.position + 1) # Plot the traceplots - for (start, stop, thin, color, alpha, ms) in plot_setups: + for start, stop, thin, color, alpha, ms in plot_setups: for ax, key in zip(axes[:, 0], self.keys): xx = position_indexes[start:stop:thin] / K yy = self.get_1d_array(key)[start:stop:thin] @@ -415,16 +405,12 @@ def plot(self, outdir=".", label="label", priors=None, all_samples=None): if all_samples is not None: yy_all = all_samples[key] if np.any(np.isinf(yy_all)): - logger.warning( - f"Could not plot histogram for parameter {key} due to infinite values" - ) + logger.warning(f"Could not plot histogram for parameter {key} due to infinite values") else: ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k") yy = self.get_1d_array(key)[nburn : self.position : self.thin] if np.any(np.isinf(yy)): - logger.warning( - f"Could not plot histogram for parameter {key} due to infinite values" - ) + logger.warning(f"Could not plot histogram for parameter {key} due to infinite values") else: ax.hist(yy, bins=50, alpha=0.8, density=True) ax.set_xlabel(self._get_plot_label_by_key(key, priors)) @@ -441,16 +427,13 @@ def plot(self, outdir=".", label="label", priors=None, all_samples=None): axes[-1, 1].set_axis_off() - filename = "{}/{}_checkpoint_trace.png".format(outdir, label) + filename = f"{outdir}/{label}_checkpoint_trace.png" msg = [ r"Maximum $\tau$" + f"={self.tau:0.1f} ", r"$n_{\rm samples}=$" + f"{self.nsamples} ", ] if self.thin_by_nact != 1: - msg += [ - r"$n_{\rm samples}^{\rm eff}=$" - + f"{int(self.nsamples * self.thin_by_nact)} " - ] + msg += [r"$n_{\rm samples}^{\rm eff}=$" + f"{int(self.nsamples * self.thin_by_nact)} "] fig.suptitle( "| ".join(msg), y=1, @@ -471,7 +454,7 @@ def _get_plot_label_by_key(key, priors=None): return key -class Sample(object): +class Sample: def __init__(self, sample_dict): """A single sample diff --git a/bilby/bilby_mcmc/flows.py b/bilby/bilby_mcmc/flows.py index b08ea3a93..e85f10933 100644 --- a/bilby/bilby_mcmc/flows.py +++ b/bilby/bilby_mcmc/flows.py @@ -44,7 +44,6 @@ def __init__( batch_norm_between_layers=False, random_permutation=True, ): - if use_volume_preserving: coupling_constructor = AdditiveCouplingTransform else: @@ -66,9 +65,7 @@ def create_resnet(in_features, out_features): layers = [] for _ in range(num_layers): - transform = coupling_constructor( - mask=mask, transform_net_create_fn=create_resnet - ) + transform = coupling_constructor(mask=mask, transform_net_create_fn=create_resnet) layers.append(transform) mask *= -1 if batch_norm_between_layers: @@ -87,9 +84,7 @@ class BasicFlow(Flow): def __init__(self, features): transform = CompositeTransform( [ - MaskedAffineAutoregressiveTransform( - features=features, hidden_features=2 * features - ), + MaskedAffineAutoregressiveTransform(features=features, hidden_features=2 * features), RandomPermutation(features=features), ] ) diff --git a/bilby/bilby_mcmc/proposals.py b/bilby/bilby_mcmc/proposals.py index 09215f17e..d8dc4d243 100644 --- a/bilby/bilby_mcmc/proposals.py +++ b/bilby/bilby_mcmc/proposals.py @@ -13,7 +13,7 @@ from ..gw.source import PARAMETER_SETS -class ProposalCycle(object): +class ProposalCycle: def __init__(self, proposal_list): self.proposal_list = proposal_list self.weights = [prop.weight for prop in self.proposal_list] @@ -45,7 +45,7 @@ def __str__(self): return string -class BaseProposal(object): +class BaseProposal: _accepted = 0 _rejected = 0 __metaclass__ = ABCMeta @@ -134,7 +134,6 @@ def apply_reflective_boundary(self, key, val): return minimum + width * val_normalised_reflected def __call__(self, chain, likelihood=None, priors=None): - if getattr(self, "needs_likelihood_and_priors", False): sample, log_factor = self.propose(chain, likelihood, priors) else: @@ -203,7 +202,7 @@ class FixedGaussianProposal(BaseProposal): """ def __init__(self, priors, weight=1, subset=None, sigma=0.01): - super(FixedGaussianProposal, self).__init__(priors, weight, subset) + super().__init__(priors, weight, subset) self.sigmas = {} for key in self.parameters: if np.isinf(self.prior_width_dict[key]): @@ -235,7 +234,7 @@ def __init__( stop=1e5, target_facc=0.234, ): - super(AdaptiveGaussianProposal, self).__init__(priors, weight, subset) + super().__init__(priors, weight, subset) self.sigmas = {} for key in self.parameters: if np.isinf(self.prior_width_dict[key]): @@ -299,7 +298,7 @@ class DifferentialEvolutionProposal(BaseProposal): """ def __init__(self, priors, weight=1, subset=None, mode_hopping_frac=0.5): - super(DifferentialEvolutionProposal, self).__init__(priors, weight, subset) + super().__init__(priors, weight, subset) self.mode_hopping_frac = mode_hopping_frac def propose(self, chain): @@ -340,16 +339,14 @@ class UniformProposal(BaseProposal): """ def __init__(self, priors, weight=1, subset=None): - super(UniformProposal, self).__init__(priors, weight, subset) + super().__init__(priors, weight, subset) def propose(self, chain): sample = chain.current_sample for key in self.parameters: width = self.prior_width_dict[key] if np.isinf(width) is False: - sample[key] = random.rng.uniform( - self.prior_minimum_dict[key], self.prior_maximum_dict[key] - ) + sample[key] = random.rng.uniform(self.prior_minimum_dict[key], self.prior_maximum_dict[key]) else: # Unable to generate a uniform sample on infinite support pass @@ -376,7 +373,7 @@ class PriorProposal(BaseProposal): """ def __init__(self, priors, weight=1, subset=None): - super(PriorProposal, self).__init__(priors, weight, subset) + super().__init__(priors, weight, subset) self.priors = PriorDict({key: priors[key] for key in self.parameters}) def propose(self, chain): @@ -426,7 +423,7 @@ def __init__( fallback=AdaptiveGaussianProposal, scale_fits=1, ): - super(DensityEstimateProposal, self).__init__(priors, weight, subset) + super().__init__(priors, weight, subset) self.nsamples_for_density = nsamples_for_density self.fallback = fallback(priors, weight, subset) self.fit_multiplier = fit_multiplier * scale_fits @@ -490,9 +487,7 @@ def refit(self, chain): fail_parameters.append(key) if len(fail_parameters) > 0: - logger.debug( - f"{self.density_name} construction failed verification and is discarded" - ) + logger.debug(f"{self.density_name} construction failed verification and is discarded") self.density = current_density else: self.trained = True @@ -565,9 +560,7 @@ def _sample(self, nsamples=None): def check_dependencies(warn=True): if importlib.util.find_spec("sklearn") is None: if warn: - logger.warning( - "Unable to utilise GMMProposal as sklearn is not installed" - ) + logger.warning("Unable to utilise GMMProposal as sklearn is not installed") return False else: return True @@ -598,7 +591,7 @@ def __init__( js_factor=10, fallback=AdaptiveGaussianProposal, ): - super(NormalizingFlowProposal, self).__init__( + super().__init__( priors=priors, weight=weight, subset=subset, @@ -690,19 +683,14 @@ def train(self, chain): # Draw from the current flow self.flow.eval() - training_samples_draw = ( - self.flow.sample(self.nsamples_for_density).detach().numpy() - ) + training_samples_draw = self.flow.sample(self.nsamples_for_density).detach().numpy() self.flow.train() if np.mod(epoch, 10) == 0: - max_js_bits = self._calculate_js( - validation_samples, training_samples_draw - ) + max_js_bits = self._calculate_js(validation_samples, training_samples_draw) if max_js_bits < max_js_threshold: logger.debug( - f"Training complete after {epoch} steps, " - f"max_js_bits={max_js_bits:0.5f}<{max_js_threshold}" + f"Training complete after {epoch} steps, max_js_bits={max_js_bits:0.5f}<{max_js_threshold}" ) break @@ -740,9 +728,7 @@ def propose(self, chain): theta_prime_T = self.flow.sample(1) logp_theta_prime = self.flow.log_prob(theta_prime_T).detach().numpy()[0] - theta_T = torch.tensor( - np.atleast_2d([theta[key] for key in self.parameters]), dtype=torch.float32 - ) + theta_T = torch.tensor(np.atleast_2d([theta[key] for key in self.parameters]), dtype=torch.float32) logp_theta = self.flow.log_prob(theta_T).detach().numpy()[0] log_factor = logp_theta - logp_theta_prime @@ -756,9 +742,7 @@ def propose(self, chain): def check_dependencies(warn=True): if importlib.util.find_spec("glasflow") is None: if warn: - logger.warning( - "Unable to utilise NormalizingFlowProposal as glasflow is not installed" - ) + logger.warning("Unable to utilise NormalizingFlowProposal as glasflow is not installed") return False else: return True @@ -766,7 +750,7 @@ def check_dependencies(warn=True): class FixedJumpProposal(BaseProposal): def __init__(self, priors, jumps=1, subset=None, weight=1, scale=1e-4): - super(FixedJumpProposal, self).__init__(priors, weight, subset) + super().__init__(priors, weight, subset) self.scale = scale if isinstance(jumps, (int, float)): self.jumps = {key: jumps for key in self.parameters} @@ -808,9 +792,7 @@ def __init__( fd_eps=1e-4, adapt=False, ): - super(FisherMatrixProposal, self).__init__( - priors, weight, subset, scale_init=scale_init - ) + super().__init__(priors, weight, subset, scale_init=scale_init) self.update_interval = update_interval self.steps_since_update = update_interval self.adapt = adapt @@ -822,9 +804,7 @@ def propose(self, chain, likelihood, priors): if self.adapt: self.update_scale(chain) if self.steps_since_update >= self.update_interval: - fmp = FisherMatrixPosteriorEstimator( - likelihood, priors, parameters=self.parameters, fd_eps=self.fd_eps - ) + fmp = FisherMatrixPosteriorEstimator(likelihood, priors, parameters=self.parameters, fd_eps=self.fd_eps) parameters = {key: priors[key].peak for key in priors.fixed_keys} parameters.update(sample.dict) try: @@ -836,9 +816,7 @@ def propose(self, chain, likelihood, priors): return sample, 0 self.steps_since_update = 0 - jump = self.scale * random.rng.multivariate_normal( - self.mean, self.iFIM, check_valid="ignore" - ) + jump = self.scale * random.rng.multivariate_normal(self.mean, self.iFIM, check_valid="ignore") for key, val in zip(self.parameters, jump): sample[key] += val @@ -850,9 +828,7 @@ def propose(self, chain, likelihood, priors): class BaseGravitationalWaveTransientProposal(BaseProposal): def __init__(self, priors, weight=1): - super(BaseGravitationalWaveTransientProposal, self).__init__( - priors, weight=weight - ) + super().__init__(priors, weight=weight) if "phase" in priors: self.phase_key = "phase" elif "delta_phase" in priors: @@ -890,7 +866,7 @@ def get_delta_phase(self, phase, sample): class CorrelatedPolarisationPhaseJump(BaseGravitationalWaveTransientProposal): def __init__(self, priors, weight=1): - super(CorrelatedPolarisationPhaseJump, self).__init__(priors, weight=weight) + super().__init__(priors, weight=weight) def propose(self, chain): sample = chain.current_sample @@ -920,13 +896,11 @@ def propose(self, chain): class PhaseReversalProposal(BaseGravitationalWaveTransientProposal): def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-1): - super(PhaseReversalProposal, self).__init__(priors, weight) + super().__init__(priors, weight) self.fuzz = fuzz self.fuzz_sigma = fuzz_sigma if self.phase_key is None: - raise SamplerError( - f"{type(self).__name__} initialised without a phase prior" - ) + raise SamplerError(f"{type(self).__name__} initialised without a phase prior") def propose(self, chain): sample = chain.current_sample @@ -945,9 +919,7 @@ def epsilon(self): class PolarisationReversalProposal(PhaseReversalProposal): def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-3): - super(PolarisationReversalProposal, self).__init__( - priors, weight, fuzz, fuzz_sigma - ) + super().__init__(priors, weight, fuzz, fuzz_sigma) self.fuzz = fuzz def propose(self, chain): @@ -960,16 +932,12 @@ def propose(self, chain): class PhasePolarisationReversalProposal(PhaseReversalProposal): def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-1): - super(PhasePolarisationReversalProposal, self).__init__( - priors, weight, fuzz, fuzz_sigma - ) + super().__init__(priors, weight, fuzz, fuzz_sigma) self.fuzz = fuzz def propose(self, chain): sample = chain.current_sample - sample[self.phase_key] = np.mod( - sample[self.phase_key] + np.pi + self.epsilon, 2 * np.pi - ) + sample[self.phase_key] = np.mod(sample[self.phase_key] + np.pi + self.epsilon, 2 * np.pi) sample["psi"] = np.mod(sample["psi"] + np.pi / 2 + self.epsilon, np.pi) log_factor = 0 return sample, log_factor @@ -989,7 +957,7 @@ class StretchProposal(BaseProposal): """ def __init__(self, priors, weight=1, subset=None, scale=2): - super(StretchProposal, self).__init__(priors, weight, subset) + super().__init__(priors, weight, subset) self.scale = scale def propose(self, chain): @@ -1018,7 +986,7 @@ class EnsembleProposal(BaseProposal): """Base EnsembleProposal class for ensemble-based swap proposals""" def __init__(self, priors, weight=1): - super(EnsembleProposal, self).__init__(priors, weight) + super().__init__(priors, weight) def __call__(self, chain, chain_complement): sample, log_factor = self.propose(chain, chain_complement) @@ -1041,17 +1009,13 @@ class EnsembleStretch(EnsembleProposal): """ def __init__(self, priors, weight=1, scale=2): - super(EnsembleStretch, self).__init__(priors, weight) + super().__init__(priors, weight) self.scale = scale def propose(self, chain, chain_complement): sample = chain.current_sample - completement = chain_complement[ - random.rng.integers(len(chain_complement)) - ].current_sample - return _stretch_move( - sample, completement, self.scale, self.ndim, self.parameters - ) + completement = chain_complement[random.rng.integers(len(chain_complement))].current_sample + return _stretch_move(sample, completement, self.scale, self.ndim, self.parameters) def get_default_ensemble_proposal_cycle(priors): @@ -1065,60 +1029,40 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True): if "gwA" in string: # Parameters for learning proposals - learning_kwargs = dict( - first_fit=1000, nsamples_for_density=10000, fit_multiplier=2 - ) + learning_kwargs = dict(first_fit=1000, nsamples_for_density=10000, fit_multiplier=2) all_but_cal = [key for key in priors if "recalib" not in key] plist = [ AdaptiveGaussianProposal(priors, weight=small_weight, subset=all_but_cal), - DifferentialEvolutionProposal( - priors, weight=small_weight, subset=all_but_cal - ), + DifferentialEvolutionProposal(priors, weight=small_weight, subset=all_but_cal), ] if GMMProposal.check_dependencies(warn=warn) is False: - raise SamplerError( - "the gwA proposal_cycle required the GMMProposal dependencies" - ) + raise SamplerError("the gwA proposal_cycle required the GMMProposal dependencies") if priors.intrinsic: intrinsic = PARAMETER_SETS["intrinsic"] plist += [ AdaptiveGaussianProposal(priors, weight=small_weight, subset=intrinsic), - DifferentialEvolutionProposal( - priors, weight=small_weight, subset=intrinsic - ), - KDEProposal( - priors, weight=small_weight, subset=intrinsic, **learning_kwargs - ), - GMMProposal( - priors, weight=small_weight, subset=intrinsic, **learning_kwargs - ), + DifferentialEvolutionProposal(priors, weight=small_weight, subset=intrinsic), + KDEProposal(priors, weight=small_weight, subset=intrinsic, **learning_kwargs), + GMMProposal(priors, weight=small_weight, subset=intrinsic, **learning_kwargs), ] if priors.extrinsic: extrinsic = PARAMETER_SETS["extrinsic"] plist += [ AdaptiveGaussianProposal(priors, weight=small_weight, subset=extrinsic), - DifferentialEvolutionProposal( - priors, weight=small_weight, subset=extrinsic - ), - KDEProposal( - priors, weight=small_weight, subset=extrinsic, **learning_kwargs - ), - GMMProposal( - priors, weight=small_weight, subset=extrinsic, **learning_kwargs - ), + DifferentialEvolutionProposal(priors, weight=small_weight, subset=extrinsic), + KDEProposal(priors, weight=small_weight, subset=extrinsic, **learning_kwargs), + GMMProposal(priors, weight=small_weight, subset=extrinsic, **learning_kwargs), ] if priors.mass: mass = PARAMETER_SETS["mass"] plist += [ DifferentialEvolutionProposal(priors, weight=small_weight, subset=mass), - GMMProposal( - priors, weight=small_weight, subset=mass, **learning_kwargs - ), + GMMProposal(priors, weight=small_weight, subset=mass, **learning_kwargs), FisherMatrixProposal( priors, weight=small_weight, @@ -1130,9 +1074,7 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True): spin = PARAMETER_SETS["spin"] plist += [ DifferentialEvolutionProposal(priors, weight=small_weight, subset=spin), - GMMProposal( - priors, weight=small_weight, subset=spin, **learning_kwargs - ), + GMMProposal(priors, weight=small_weight, subset=spin, **learning_kwargs), FisherMatrixProposal( priors, weight=big_weight, @@ -1142,9 +1084,7 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True): if priors.measured_spin: measured_spin = PARAMETER_SETS["measured_spin"] plist += [ - AdaptiveGaussianProposal( - priors, weight=small_weight, subset=measured_spin - ), + AdaptiveGaussianProposal(priors, weight=small_weight, subset=measured_spin), FisherMatrixProposal( priors, weight=small_weight, @@ -1155,17 +1095,13 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True): if priors.mass and priors.spin: primary_spin_and_q = PARAMETER_SETS["primary_spin_and_q"] plist += [ - DifferentialEvolutionProposal( - priors, weight=small_weight, subset=primary_spin_and_q - ), + DifferentialEvolutionProposal(priors, weight=small_weight, subset=primary_spin_and_q), ] if getattr(priors, "tidal", False): tidal = PARAMETER_SETS["tidal"] plist += [ - DifferentialEvolutionProposal( - priors, weight=small_weight, subset=tidal - ), + DifferentialEvolutionProposal(priors, weight=small_weight, subset=tidal), PriorProposal(priors, weight=small_weight, subset=tidal), ] if priors.phase: diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py index 699cea265..1c701eea2 100644 --- a/bilby/bilby_mcmc/sampler.py +++ b/bilby/bilby_mcmc/sampler.py @@ -184,8 +184,7 @@ def __init__( normalize_prior=True, **kwargs, ): - - super(Bilby_MCMC, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -202,12 +201,8 @@ def __init__( self.L1steps = self.kwargs["L1steps"] self.L2steps = self.kwargs["L2steps"] self.normalize_prior = normalize_prior - self.pt_inputs = ParallelTemperingInputs( - **{key: self.kwargs[key] for key in ParallelTemperingInputs._fields} - ) - self.convergence_inputs = ConvergenceInputs( - **{key: self.kwargs[key] for key in ConvergenceInputs._fields} - ) + self.pt_inputs = ParallelTemperingInputs(**{key: self.kwargs[key] for key in ParallelTemperingInputs._fields}) + self.convergence_inputs = ConvergenceInputs(**{key: self.kwargs[key] for key in ConvergenceInputs._fields}) self.proposal_cycle = self.kwargs["proposal_cycle"] self.pt_rejection_sample = self.kwargs["pt_rejection_sample"] self.evidence_method = self.kwargs["evidence_method"] @@ -218,7 +213,7 @@ def __init__( self.check_point_delta_t = self.kwargs["check_point_delta_t"] check_directory_exists_and_if_not_mkdir(self.outdir) self.resume = resume - self.resume_file = "{}/{}_resume.pickle".format(self.outdir, self.label) + self.resume_file = f"{self.outdir}/{self.label}_resume.pickle" self.verify_configuration() self.verbose = verbose @@ -305,7 +300,6 @@ def setup_chain_set(self): self.init_ptsampler() def init_ptsampler(self): - logger.info(f"Initializing BilbyPTMCMCSampler with:\n{self.get_setup_string()}") self.ptsampler = BilbyPTMCMCSampler( convergence_inputs=self.convergence_inputs, @@ -350,9 +344,7 @@ def draw(self): tp = datetime.datetime.now() ppt_frac = (tp - tp0).total_seconds() / self._time_since_last_print if ppt_frac > 0.01: - logger.warning( - f"Non-negligible print progress time (ppt_frac={ppt_frac:0.2f})" - ) + logger.warning(f"Non-negligible print progress time (ppt_frac={ppt_frac:0.2f})") self._steps_since_last_print = 0 self._time_since_last_print = 0 @@ -391,9 +383,7 @@ def read_current_state(self): If true, resume file was successfully loaded, otherwise false """ - if os.path.isfile(self.resume_file) is False or not os.path.getsize( - self.resume_file - ): + if os.path.isfile(self.resume_file) is False or not os.path.getsize(self.resume_file): return False import dill @@ -404,10 +394,7 @@ def read_current_state(self): return False self.ptsampler = ptsampler if self.ptsampler.pt_inputs != self.pt_inputs: - msg = ( - f"pt_inputs has changed: {self.ptsampler.pt_inputs} " - f"-> {self.pt_inputs}" - ) + msg = f"pt_inputs has changed: {self.ptsampler.pt_inputs} -> {self.pt_inputs}" raise ResumeError(msg) self.ptsampler.set_convergence_inputs(self.convergence_inputs) self.ptsampler.pt_rejection_sample = self.pt_rejection_sample @@ -432,11 +419,9 @@ def write_current_state(self): self.ptsampler.pool = None if dill.pickles(self.ptsampler): safe_file_dump(self.ptsampler, self.resume_file, dill) - logger.info("Written checkpoint file {}".format(self.resume_file)) + logger.info(f"Written checkpoint file {self.resume_file}") else: - logger.warning( - "Cannot write pickle resume file! Job may not resume if interrupted." - ) + logger.warning("Cannot write pickle resume file! Job may not resume if interrupted.") # Touch the file to postpone next check-point attempt Path(self.resume_file).touch(exist_ok=True) self.ptsampler.pool = _pool @@ -449,12 +434,8 @@ def print_long_progress(self): if self.ptsampler.nensemble > 1: self.print_ensemble_acceptance() if self.check_point_plot: - self.plot_progress( - self.ptsampler, self.label, self.outdir, self.priors, self.diagnostic - ) - self.ptsampler.compute_evidence( - outdir=self.outdir, label=self.label, make_plots=True - ) + self.plot_progress(self.ptsampler, self.label, self.outdir, self.priors, self.diagnostic) + self.ptsampler.compute_evidence(outdir=self.outdir, label=self.label, make_plots=True) def print_ensemble_acceptance(self): logger.info(f"Ensemble swaps = {self.ptsampler.swap_counter['ensemble']}") @@ -468,9 +449,7 @@ def print_progress(self): time = str(sampling_time).split(".")[0] # Time for last evaluation set - time_per_eval_ms = ( - 1000 * self._time_since_last_print / self._steps_since_last_print - ) + time_per_eval_ms = 1000 * self._time_since_last_print / self._steps_since_last_print # Pull out progress summary tau = self.ptsampler.tau @@ -487,12 +466,7 @@ def print_progress(self): # Estimated time til finish (ETF) if tau < np.inf: remaining_samples = self.target_nsamples - nsamples - remaining_evals = ( - remaining_samples - * self.convergence_inputs.thin_by_nact - * tau - * self.L1steps - ) + remaining_evals = remaining_samples * self.convergence_inputs.thin_by_nact * tau * self.L1steps remaining_time_s = time_per_eval_ms * 1e-3 * remaining_evals remaining_time_dt = datetime.timedelta(seconds=remaining_time_s) if remaining_samples > 0: @@ -582,7 +556,7 @@ def get_expected_outputs(cls, outdir=None, label=None): return filenames, [] -class BilbyPTMCMCSampler(object): +class BilbyPTMCMCSampler: def __init__( self, convergence_inputs, @@ -615,9 +589,7 @@ def __init__( self.swap_counter["L2-ensemble"] = int(self.L2steps / 2) + 1 self._nsamples_dict = {} - self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle( - _sampling_convenience_dump.priors - ) + self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle(_sampling_convenience_dump.priors) self.sampling_time = 0 self.ln_z_dict = dict() self.ln_z_err_dict = dict() @@ -644,7 +616,6 @@ def get_initial_betas(self): return betas def setup_sampler_dictionary(self, convergence_inputs, proposal_cycle): - betas = self.get_initial_betas() logger.info( f"Initializing BilbyPTMCMCSampler with:" @@ -772,9 +743,7 @@ def _calculate_nsamples(self): nsamples_list.append(sampler.nsamples) if self.pt_rejection_sample: for samp in self.sampler_list[1:]: - nsamples_list.append( - len(samp.rejection_sample_zero_temperature_samples()) - ) + nsamples_list.append(len(samp.rejection_sample_zero_temperature_samples())) return sum(nsamples_list) @property @@ -824,9 +793,7 @@ def step_all_chains(self): if self.adapt: self.adapt_temperatures() elif self.adapt: - logger.info( - f"Adaptation of temperature chains finished at step {self.position}" - ) + logger.info(f"Adaptation of temperature chains finished at step {self.position}") self.adapt = False self.swap_counter["L2-ensemble"] += 1 @@ -948,9 +915,7 @@ def compute_evidence(self, outdir, label, make_plots=True): ln_z, ln_z_err = self.compute_evidence_per_ensemble(method, kwargs) self.ln_z_dict[key] = ln_z self.ln_z_err_dict[key] = ln_z_err - logger.debug( - f"Log-evidence of {ln_z:0.2f}+/-{ln_z_err:0.2f} calculated using {key} method" - ) + logger.debug(f"Log-evidence of {ln_z:0.2f}+/-{ln_z_err:0.2f} calculated using {key} method") def compute_evidence_per_ensemble(self, method, kwargs): from scipy.special import logsumexp @@ -975,9 +940,7 @@ def compute_evidence_per_ensemble(self, method, kwargs): return lnZ, lnZerr - def thermodynamic_integration_evidence( - self, ptchain, outdir, label, make_plots=True - ): + def thermodynamic_integration_evidence(self, ptchain, outdir, label, make_plots=True): """Computes the evidence using thermodynamic integration We compute the evidence without the burnin samples, no thinning @@ -1043,9 +1006,7 @@ def stepping_stone_evidence(self, ptchain, outdir, label, make_plots=True): return np.nan, np.nan # Read in log likelihoods - ln_likes = np.array( - [samp.chain.get_1d_array(LOGLKEY)[min_index:max_index] for samp in ptchain] - )[:-1].T + ln_likes = np.array([samp.chain.get_1d_array(LOGLKEY)[min_index:max_index] for samp in ptchain])[:-1].T # Thin to only independent samples ln_likes = ln_likes[:: int(self.tau), :] @@ -1064,9 +1025,7 @@ def stepping_stone_evidence(self, ptchain, outdir, label, make_plots=True): try: for _ in range(repeats): idxs = [random.rng.integers(i, i + ll) for i in range(steps - ll)] - ln_z_realisations.append( - self._calculate_stepping_stone(betas, ln_likes[idxs, :])[0] - ) + ln_z_realisations.append(self._calculate_stepping_stone(betas, ln_likes[idxs, :])[0]) ln_z_err = np.std(ln_z_realisations) except ValueError: logger.info("Failed to estimate stepping stone uncertainty") @@ -1117,7 +1076,7 @@ def _create_lnZ_plots(self, betas, mean_lnlikes, outdir, label, sem_lnlikes=None ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$") plt.tight_layout() - fig.savefig("{}/{}_beta_lnl.png".format(outdir, label)) + fig.savefig(f"{outdir}/{label}_beta_lnl.png") plt.close() def _create_stepping_stone_plot(self, means, outdir, label): @@ -1140,7 +1099,7 @@ def _create_stepping_stone_plot(self, means, outdir, label): ax.set_ylabel("Cumulative $\\ln Z$") plt.tight_layout() - fig.savefig("{}/{}_stepping_stone.png".format(outdir, label)) + fig.savefig(f"{outdir}/{label}_stepping_stone.png") plt.close() @property @@ -1155,7 +1114,7 @@ def rejection_sampling_count(self): return None -class BilbyMCMCSampler(object): +class BilbyMCMCSampler: def __init__( self, convergence_inputs, @@ -1179,16 +1138,12 @@ def __init__( if initial_sample_method.lower() == "prior": full_sample_dict = _sampling_convenience_dump.priors.sample() initial_sample = { - k: v - for k, v in full_sample_dict.items() - if k in _sampling_convenience_dump.priors.non_fixed_keys + k: v for k, v in full_sample_dict.items() if k in _sampling_convenience_dump.priors.non_fixed_keys } elif initial_sample_method.lower() in ["maximize", "maximise", "maximum"]: initial_sample = get_initial_maximimum_posterior_sample(self.beta) else: - raise ValueError( - f"initial sample method {initial_sample_method} not understood" - ) + raise ValueError(f"initial sample method {initial_sample_method} not understood") if initial_sample_dict is not None: initial_sample.update(initial_sample_dict) @@ -1295,11 +1250,7 @@ def step(self): with np.errstate(over="ignore"): alpha = np.exp( - log_factor - + self.beta * prop[LOGLKEY] - + prop[LOGPKEY] - - self.beta * curr[LOGLKEY] - - curr[LOGPKEY] + log_factor + self.beta * prop[LOGLKEY] + prop[LOGPKEY] - self.beta * curr[LOGLKEY] - curr[LOGPKEY] ) if random.rng.uniform(0, 1) <= alpha: @@ -1338,14 +1289,9 @@ def samples(self): def rejection_sample_zero_temperature_samples(self, print_message=False): beta = self.beta chain = self.chain - hot_samples = pd.DataFrame( - chain._chain_array[chain.minimum_index : chain.position], columns=chain.keys - ) + hot_samples = pd.DataFrame(chain._chain_array[chain.minimum_index : chain.position], columns=chain.keys) if len(hot_samples) == 0: - logger.debug( - f"Rejection sampling for Temp {self.Tindex} failed: " - "no usable hot samples" - ) + logger.debug(f"Rejection sampling for Temp {self.Tindex} failed: no usable hot samples") return hot_samples # Pull out log likelihood @@ -1353,9 +1299,7 @@ def rejection_sample_zero_temperature_samples(self, print_message=False): # Revert to true likelihood if needed if _sampling_convenience_dump.use_ratio: - zerotemp_logl += ( - _sampling_convenience_dump.likelihood.noise_log_likelihood() - ) + zerotemp_logl += _sampling_convenience_dump.likelihood.noise_log_likelihood() # Calculate normalised weights log_weights = (1 - beta) * zerotemp_logl @@ -1370,10 +1314,7 @@ def rejection_sample_zero_temperature_samples(self, print_message=False): self.rejection_sampling_count = len(samples) if print_message: - logger.info( - f"Rejection sampling Temp {self.Tindex}, beta={beta:0.2f} " - f"yielded {len(samples)} samples" - ) + logger.info(f"Rejection sampling Temp {self.Tindex}, beta={beta:0.2f} yielded {len(samples)} samples") return samples diff --git a/bilby/core/__init__.py b/bilby/core/__init__.py index 968f961d0..9e2483593 100644 --- a/bilby/core/__init__.py +++ b/bilby/core/__init__.py @@ -1 +1,3 @@ -from . import grid, likelihood, prior, result, sampler, series, utils, fisher +from . import fisher, grid, likelihood, prior, result, sampler, series, utils + +__all__ = [fisher, grid, likelihood, prior, result, sampler, series, utils] diff --git a/bilby/core/fisher.py b/bilby/core/fisher.py index e8a869ec7..385271f8c 100644 --- a/bilby/core/fisher.py +++ b/bilby/core/fisher.py @@ -6,9 +6,9 @@ from .likelihood import _safe_likelihood_call -class FisherMatrixPosteriorEstimator(object): +class FisherMatrixPosteriorEstimator: def __init__(self, likelihood, priors, parameters=None, fd_eps=1e-6, n_prior_samples=100): - """ A class to estimate posteriors using the Fisher Matrix approach + """A class to estimate posteriors using the Fisher Matrix approach Parameters ---------- @@ -34,9 +34,7 @@ def __init__(self, likelihood, priors, parameters=None, fd_eps=1e-6, n_prior_sam self.n_prior_samples = n_prior_samples self.N = len(self.parameter_names) - self.prior_samples = [ - priors.sample_subset(self.parameter_names) for _ in range(n_prior_samples) - ] + self.prior_samples = [priors.sample_subset(self.parameter_names) for _ in range(n_prior_samples)] self.prior_bounds = [(priors[key].minimum, priors[key].maximum) for key in self.parameter_names] self.prior_width_dict = {} @@ -93,13 +91,13 @@ def get_finite_difference_xx(self, sample, ii): p = self.shift_sample_x(sample, ii, 1) m = self.shift_sample_x(sample, ii, -1) - dx = .5 * (p[ii] - m[ii]) + dx = 0.5 * (p[ii] - m[ii]) loglp = self.log_likelihood(p) logl = self.log_likelihood(sample) loglm = self.log_likelihood(m) - return (loglp - 2 * logl + loglm) / dx ** 2 + return (loglp - 2 * logl + loglm) / dx**2 def get_finite_difference_xy(self, sample, ii, jj): # Sample grid @@ -108,8 +106,8 @@ def get_finite_difference_xy(self, sample, ii, jj): mp = self.shift_sample_xy(sample, ii, -1, jj, 1) mm = self.shift_sample_xy(sample, ii, -1, jj, -1) - dx = .5 * (pp[ii] - mm[ii]) - dy = .5 * (pp[jj] - mm[jj]) + dx = 0.5 * (pp[ii] - mm[ii]) + dy = 0.5 * (pp[jj] - mm[jj]) loglpp = self.log_likelihood(pp) loglpm = self.log_likelihood(pm) @@ -119,7 +117,6 @@ def get_finite_difference_xy(self, sample, ii, jj): return (loglpp - loglpm - loglmp + loglmm) / (4 * dx * dy) def shift_sample_x(self, sample, x_key, x_coef): - vx = sample[x_key] dvx = self.fd_eps * self.prior_width_dict[x_key] @@ -129,7 +126,6 @@ def shift_sample_x(self, sample, x_key, x_coef): return shift_sample def shift_sample_xy(self, sample, x_key, x_coef, y_key, y_coef): - vx = sample[x_key] vy = sample[y_key] @@ -142,7 +138,7 @@ def shift_sample_xy(self, sample, x_key, x_coef, y_key, y_coef): return shift_sample def get_maximum_likelihood_sample(self, initial_sample=None): - """ A method to attempt optimization of the maximum likelihood + """A method to attempt optimization of the maximum likelihood This uses a simple scipy optimization approach, starting from a number of draws from the prior to avoid problems with local optimization. @@ -160,7 +156,7 @@ def get_maximum_likelihood_sample(self, initial_sample=None): def neg_log_like(x, self, T=1): sample = {key: val for key, val in zip(self.parameter_names, x)} - return - 1 / T * self.log_likelihood(sample) + return -1 / T * self.log_likelihood(sample) out = minimize( neg_log_like, diff --git a/bilby/core/grid.py b/bilby/core/grid.py index 0d103d4cc..122c25d8a 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -5,15 +5,19 @@ from .likelihood import _safe_likelihood_call from .prior import Prior, PriorDict +from .result import FileMovedError from .utils import ( - logtrapzexp, check_directory_exists_and_if_not_mkdir, logger, - BilbyJsonEncoder, load_json, move_old_file + BilbyJsonEncoder, + check_directory_exists_and_if_not_mkdir, + load_json, + logger, + logtrapzexp, + move_old_file, ) -from .result import FileMovedError def grid_file_name(outdir, label, gzip=False): - """ Returns the standard filename used for a grid file + """Returns the standard filename used for a grid file Parameters ========== @@ -29,15 +33,15 @@ def grid_file_name(outdir, label, gzip=False): str: File name of the output file """ if gzip: - return os.path.join(outdir, '{}_grid.json.gz'.format(label)) + return os.path.join(outdir, f"{label}_grid.json.gz") else: - return os.path.join(outdir, '{}_grid.json'.format(label)) - + return os.path.join(outdir, f"{label}_grid.json") -class Grid(object): - def __init__(self, likelihood=None, priors=None, grid_size=101, - save=False, label='no_label', outdir='.', gzip=False): +class Grid: + def __init__( + self, likelihood=None, priors=None, grid_size=101, save=False, label="no_label", outdir=".", gzip=False + ): """ Parameters @@ -72,9 +76,8 @@ def __init__(self, likelihood=None, priors=None, grid_size=101, # evaluate the prior on the grid points if self.n_dims > 0: self._ln_prior = self.priors.ln_prob( - {key: self.mesh_grid[i].flatten() for i, key in - enumerate(self.parameter_names)}, axis=0).reshape( - self.mesh_grid[0].shape) + {key: self.mesh_grid[i].flatten() for i, key in enumerate(self.parameter_names)}, axis=0 + ).reshape(self.mesh_grid[0].shape) self._ln_likelihood = None # evaluate the likelihood on the grid points @@ -179,8 +182,7 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): """ if name not in self.parameter_names: - raise ValueError("'{}' is not a recognised " - "parameter".format(name)) + raise ValueError(f"'{name}' is not a recognised parameter") if non_marg_names is None: non_marg_names = list(self.parameter_names) @@ -192,9 +194,7 @@ def _marginalize_single(self, log_array, name, non_marg_names=None): if len(places) > 1: dx = np.diff(places) - out = np.apply_along_axis( - logtrapzexp, axis, log_array, dx - ) + out = np.apply_along_axis(logtrapzexp, axis, log_array, dx) else: # no marginalisation required, just remove the singleton dimension z = log_array.shape @@ -233,8 +233,7 @@ def marginalize_ln_likelihood(self, parameters=None, not_parameters=None): array-like: The marginalized ln likelihood. """ - return self.marginalize(self.ln_likelihood, parameters=parameters, - not_parameters=not_parameters) + return self.marginalize(self.ln_likelihood, parameters=parameters, not_parameters=not_parameters) def marginalize_ln_posterior(self, parameters=None, not_parameters=None): """ @@ -254,8 +253,7 @@ def marginalize_ln_posterior(self, parameters=None, not_parameters=None): array-like: The marginalized ln posterior. """ - return self.marginalize(self.ln_posterior, parameters=parameters, - not_parameters=not_parameters) + return self.marginalize(self.ln_posterior, parameters=parameters, not_parameters=not_parameters) def marginalize_likelihood(self, parameters=None, not_parameters=None): """ @@ -275,8 +273,7 @@ def marginalize_likelihood(self, parameters=None, not_parameters=None): array-like: The marginalized likelihood. """ - ln_like = self.marginalize(self.ln_likelihood, parameters=parameters, - not_parameters=not_parameters) + ln_like = self.marginalize(self.ln_likelihood, parameters=parameters, not_parameters=not_parameters) # NOTE: the output will not be properly normalised return np.exp(ln_like - np.max(ln_like)) @@ -298,8 +295,7 @@ def marginalize_posterior(self, parameters=None, not_parameters=None): array-like: The marginalized posterior. """ - ln_post = self.marginalize(self.ln_posterior, parameters=parameters, - not_parameters=not_parameters) + ln_post = self.marginalize(self.ln_posterior, parameters=parameters, not_parameters=not_parameters) # NOTE: the output will not be properly normalised return np.exp(ln_post - np.max(ln_post)) @@ -310,12 +306,10 @@ def _evaluate(self): def _evaluate_recursion(self, dimension, parameters): if dimension == self.n_dims: - current_point = tuple([[int(np.where( - parameters[name] == - self.sample_points[name])[0])] for name in self.parameter_names]) - self._ln_likelihood[current_point] = _safe_likelihood_call( - self.likelihood, parameters + current_point = tuple( + [[int(np.where(parameters[name] == self.sample_points[name])[0])] for name in self.parameter_names] ) + self._ln_likelihood[current_point] = _safe_likelihood_call(self.likelihood, parameters) else: name = self.parameter_names[dimension] for ii in range(self._ln_likelihood.shape[dimension]): @@ -326,40 +320,42 @@ def _get_sample_points(self, grid_size): for ii, key in enumerate(self.parameter_names): if isinstance(self.priors[key], Prior): if isinstance(grid_size, int): - self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size)) + self.sample_points[key] = self.priors[key].rescale(np.linspace(0, 1, grid_size)) elif isinstance(grid_size, list): if isinstance(grid_size[ii], int): - self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size[ii])) + self.sample_points[key] = self.priors[key].rescale(np.linspace(0, 1, grid_size[ii])) else: self.sample_points[key] = grid_size[ii] elif isinstance(grid_size, dict): if isinstance(grid_size[key], int): - self.sample_points[key] = self.priors[key].rescale( - np.linspace(0, 1, grid_size[key])) + self.sample_points[key] = self.priors[key].rescale(np.linspace(0, 1, grid_size[key])) else: self.sample_points[key] = grid_size[key] else: raise TypeError("Unrecognized 'grid_size' type") # set the mesh of points - self.mesh_grid = np.meshgrid( - *(self.sample_points[key] for key in self.parameter_names), - indexing='ij') + self.mesh_grid = np.meshgrid(*(self.sample_points[key] for key in self.parameter_names), indexing="ij") def _get_save_data_dictionary(self): # This list defines all the parameters saved in the grid object save_attrs = [ - 'label', 'outdir', 'parameter_names', 'n_dims', 'priors', - 'sample_points', 'ln_likelihood', 'ln_evidence', - 'ln_noise_evidence'] + "label", + "outdir", + "parameter_names", + "n_dims", + "priors", + "sample_points", + "ln_likelihood", + "ln_evidence", + "ln_noise_evidence", + ] dictionary = dict() for attr in save_attrs: try: dictionary[attr] = getattr(self, attr) except ValueError as e: - logger.debug("Unable to save {}, message: {}".format(attr, e)) + logger.debug(f"Unable to save {attr}, message: {e}") pass return dictionary @@ -369,14 +365,15 @@ def _safe_outdir_creation(self, outdir=None, caller_func=None): try: check_directory_exists_and_if_not_mkdir(outdir) except PermissionError: - raise FileMovedError("Can not write in the out directory.\n" - "Did you move the here file from another system?\n" - "Try calling " + caller_func.__name__ + " with the 'outdir' " - "keyword argument, e.g. " + caller_func.__name__ + "(outdir='.')") + raise FileMovedError( + "Can not write in the out directory.\n" + "Did you move the here file from another system?\n" + "Try calling " + caller_func.__name__ + " with the 'outdir' " + "keyword argument, e.g. " + caller_func.__name__ + "(outdir='.')" + ) return outdir - def save_to_file(self, filename=None, overwrite=False, outdir=None, - gzip=False): + def save_to_file(self, filename=None, overwrite=False, outdir=None, gzip=False): """ Writes the Grid to a file. @@ -406,22 +403,22 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None, try: dictionary["priors"] = dictionary["priors"]._get_json_dict() - if gzip or (os.path.splitext(filename)[-1] == '.gz'): + if gzip or (os.path.splitext(filename)[-1] == ".gz"): import gzip + # encode to a string - json_str = json.dumps(dictionary, cls=BilbyJsonEncoder).encode('utf-8') - with gzip.GzipFile(filename, 'w') as file: + json_str = json.dumps(dictionary, cls=BilbyJsonEncoder).encode("utf-8") + with gzip.GzipFile(filename, "w") as file: file.write(json_str) else: - with open(filename, 'w') as file: + with open(filename, "w") as file: json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder) except Exception as e: - logger.error("\n\n Saving the data has failed with the " - "following message:\n {} \n\n".format(e)) + logger.error(f"\n\n Saving the data has failed with the following message:\n {e} \n\n") @classmethod def read(cls, filename=None, outdir=None, label=None, gzip=False): - """ Read in a saved .json grid file + """Read in a saved .json grid file Parameters ========== @@ -454,16 +451,20 @@ def read(cls, filename=None, outdir=None, label=None, gzip=False): if os.path.isfile(filename): dictionary = load_json(filename, gzip) try: - grid = cls(likelihood=None, priors=dictionary['priors'], - grid_size=dictionary['sample_points'], - label=dictionary['label'], outdir=dictionary['outdir']) + grid = cls( + likelihood=None, + priors=dictionary["priors"], + grid_size=dictionary["sample_points"], + label=dictionary["label"], + outdir=dictionary["outdir"], + ) # set the likelihood - grid._ln_likelihood = dictionary['ln_likelihood'] - grid.ln_noise_evidence = dictionary['ln_noise_evidence'] + grid._ln_likelihood = dictionary["ln_likelihood"] + grid.ln_noise_evidence = dictionary["ln_noise_evidence"] return grid except TypeError as e: - raise IOError("Unable to load dictionary, error={}".format(e)) + raise OSError(f"Unable to load dictionary, error={e}") else: - raise IOError("No result '{}' found".format(filename)) + raise OSError(f"No result '{filename}' found") diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py index 0ba344ec5..08f7e4145 100644 --- a/bilby/core/likelihood.py +++ b/bilby/core/likelihood.py @@ -7,7 +7,7 @@ from scipy.special import gammaln, xlogy from scipy.stats import multivariate_normal -from .utils import infer_parameters_from_function, infer_args_from_function_except_n_args, logger +from .utils import infer_args_from_function_except_n_args, infer_parameters_from_function, logger PARAMETERS_AS_STATE = os.environ.get("BILBY_ALLOW_PARAMETERS_AS_STATE", "TRUE") @@ -22,7 +22,6 @@ def set_parameters_as_state(level): def _fallback_to_parameters(obj, parameters): - if parameters is None: msg = "No parameters provided in likelihood call, falling back to values stored in {obj}" if PARAMETERS_AS_STATE == "FALSE": @@ -46,9 +45,7 @@ def _safe_likelihood_call(likelihood, parameters=None, use_ratio=False): logl = method(parameters=parameters) else: if PARAMETERS_AS_STATE == "FALSE": - raise LikelihoodParameterError( - f"Unable to call {likelihood} with {parameters} as an argument" - ) + raise LikelihoodParameterError(f"Unable to call {likelihood} with {parameters} as an argument") elif PARAMETERS_AS_STATE == "WARN": warn(f"Using parameters as state for {likelihood}", FutureWarning) likelihood.parameters.update(parameters) @@ -57,7 +54,6 @@ def _safe_likelihood_call(likelihood, parameters=None, use_ratio=False): class Likelihood: - def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if ( @@ -140,7 +136,7 @@ def log_likelihood_ratio(self, parameters=None): @property def meta_data(self): - return getattr(self, '_meta_data', None) + return getattr(self, "_meta_data", None) @meta_data.setter def meta_data(self, meta_data): @@ -155,7 +151,7 @@ def marginalized_parameters(self): class ZeroLikelihood(Likelihood): - """ A special test-only class which already returns zero likelihood + """A special test-only class which already returns zero likelihood Parameters ========== @@ -165,7 +161,7 @@ class ZeroLikelihood(Likelihood): """ def __init__(self, likelihood): - super(ZeroLikelihood, self).__init__() + super().__init__() self.parameters = likelihood.parameters self._parent = likelihood @@ -197,7 +193,7 @@ class Analytical1DLikelihood(Likelihood): def __init__(self, x, y, func, **kwargs): parameters = infer_parameters_from_function(func) - super(Analytical1DLikelihood, self).__init__() + super().__init__() self.x = x self.y = y self._func = func @@ -205,31 +201,31 @@ def __init__(self, x, y, func, **kwargs): self.kwargs = kwargs def __repr__(self): - return self.__class__.__name__ + '(x={}, y={}, func={})'.format(self.x, self.y, self.func.__name__) + return self.__class__.__name__ + f"(x={self.x}, y={self.y}, func={self.func.__name__})" @property def func(self): - """ Make func read-only """ + """Make func read-only""" return self._func def model_parameters(self, parameters=None): - """ This sets up the function only parameters (i.e. not sigma for the GaussianLikelihood) """ + """This sets up the function only parameters (i.e. not sigma for the GaussianLikelihood)""" parameters = _fallback_to_parameters(self, parameters) return {key: parameters[key] for key in self.function_keys} @property def function_keys(self): - """ Makes function_keys read_only """ + """Makes function_keys read_only""" return self._function_keys @property def n(self): - """ The number of data points """ + """The number of data points""" return len(self.x) @property def x(self): - """ The independent variable. Setter assures that single numbers will be converted to arrays internally """ + """The independent variable. Setter assures that single numbers will be converted to arrays internally""" return self._x @x.setter @@ -240,7 +236,7 @@ def x(self, x): @property def y(self): - """ The dependent variable. Setter assures that single numbers will be converted to arrays internally """ + """The dependent variable. Setter assures that single numbers will be converted to arrays internally""" return self._y @y.setter @@ -250,7 +246,7 @@ def y(self, y): self._y = y def residual(self, parameters=None): - """ Residual of the function against the data. """ + """Residual of the function against the data.""" return self.y - self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) @@ -277,19 +273,17 @@ def __init__(self, x, y, func, sigma=None, **kwargs): to that for `x` and `y`. """ - super(GaussianLikelihood, self).__init__(x=x, y=y, func=func, **kwargs) + super().__init__(x=x, y=y, func=func, **kwargs) self.sigma = sigma def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) sigma = parameters.get("sigma", self.sigma) - log_l = np.sum(- (self.residual(parameters) / sigma)**2 / 2 - - np.log(2 * np.pi * sigma**2) / 2) + log_l = np.sum(-((self.residual(parameters) / sigma) ** 2) / 2 - np.log(2 * np.pi * sigma**2) / 2) return log_l def __repr__(self): - return self.__class__.__name__ + '(x={}, y={}, func={}, sigma={})' \ - .format(self.x, self.y, self.func.__name__, self.sigma) + return self.__class__.__name__ + f"(x={self.x}, y={self.y}, func={self.func.__name__}, sigma={self.sigma})" @property def sigma(self): @@ -302,7 +296,7 @@ def sigma(self): if PARAMETERS_AS_STATE == "FALSE": return self._sigma else: - return self.parameters.get('sigma', self._sigma) + return self.parameters.get("sigma", self._sigma) @sigma.setter def sigma(self, sigma): @@ -313,7 +307,7 @@ def sigma(self, sigma): elif len(sigma) == self.n: self._sigma = sigma else: - raise ValueError('Sigma must be either float or array-like x.') + raise ValueError("Sigma must be either float or array-like x.") class PoissonLikelihood(Analytical1DLikelihood): @@ -339,18 +333,17 @@ def __init__(self, x, y, func, **kwargs): fixed value is given). """ - super(PoissonLikelihood, self).__init__(x=x, y=y, func=func, **kwargs) + super().__init__(x=x, y=y, func=func, **kwargs) def log_likelihood(self, parameters=None): rate = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) if not isinstance(rate, np.ndarray): raise ValueError( - "Poisson rate function returns wrong value type! " - "Is {} when it should be numpy.ndarray".format(type(rate))) - elif np.any(rate < 0.): - raise ValueError(("Poisson rate function returns a negative", - " value!")) - elif np.any(rate == 0.): + f"Poisson rate function returns wrong value type! Is {type(rate)} when it should be numpy.ndarray" + ) + elif np.any(rate < 0.0): + raise ValueError(("Poisson rate function returns a negative", " value!")) + elif np.any(rate == 0.0): return -np.inf else: return np.sum(-rate + self.y * np.log(rate) - gammaln(self.y + 1)) @@ -360,7 +353,7 @@ def __repr__(self): @property def y(self): - """ Property assures that y-value is a positive integer. """ + """Property assures that y-value is a positive integer.""" return self.__y @y.setter @@ -368,7 +361,7 @@ def y(self, y): if not isinstance(y, np.ndarray): y = np.array([y]) # check array is a non-negative integer array - if y.dtype.kind not in 'ui' or np.any(y < 0): + if y.dtype.kind not in "ui" or np.any(y < 0): raise ValueError("Data must be non-negative integers") self.__y = y @@ -390,11 +383,11 @@ def __init__(self, x, y, func, **kwargs): value is given). The model should return the expected mean of the exponential distribution for each data point. """ - super(ExponentialLikelihood, self).__init__(x=x, y=y, func=func, **kwargs) + super().__init__(x=x, y=y, func=func, **kwargs) def log_likelihood(self, parameters=None): mu = self.func(self.x, **self.model_parameters(parameters=parameters), **self.kwargs) - if np.any(mu < 0.): + if np.any(mu < 0.0): return -np.inf return -np.sum(np.log(mu) + (self.y / mu)) @@ -403,7 +396,7 @@ def __repr__(self): @property def y(self): - """ Property assures that y-value is positive. """ + """Property assures that y-value is positive.""" return self._y @y.setter @@ -445,7 +438,7 @@ def __init__(self, x, y, func, nu=None, sigma=1, **kwargs): Set the scale of the distribution. If not given then this defaults to 1, which specifies a standard (central) Student's t-distribution """ - super(StudentTLikelihood, self).__init__(x=x, y=y, func=func, **kwargs) + super().__init__(x=x, y=y, func=func, **kwargs) self.nu = nu self.sigma = sigma @@ -453,36 +446,36 @@ def __init__(self, x, y, func, nu=None, sigma=1, **kwargs): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) nu = parameters.get("nu", self.nu) - if nu <= 0.: - raise ValueError("Number of degrees of freedom for Student's " - "t-likelihood must be positive") - - log_l =\ - np.sum(- (nu + 1) * np.log1p(self.lam * self.residual(parameters=parameters)**2 / nu) / 2 + - np.log(self.lam / (nu * np.pi)) / 2 + - gammaln((nu + 1) / 2) - gammaln(nu / 2)) + if nu <= 0.0: + raise ValueError("Number of degrees of freedom for Student's t-likelihood must be positive") + + log_l = np.sum( + -(nu + 1) * np.log1p(self.lam * self.residual(parameters=parameters) ** 2 / nu) / 2 + + np.log(self.lam / (nu * np.pi)) / 2 + + gammaln((nu + 1) / 2) + - gammaln(nu / 2) + ) return log_l def __repr__(self): - base_string = '(x={}, y={}, func={}, nu={}, sigma={})' - return self.__class__.__name__ + base_string.format( - self.x, self.y, self.func.__name__, self.nu, self.sigma) + base_string = "(x={}, y={}, func={}, nu={}, sigma={})" + return self.__class__.__name__ + base_string.format(self.x, self.y, self.func.__name__, self.nu, self.sigma) @property def lam(self): - """ Converts 'scale' to 'precision' """ - return 1. / self.sigma ** 2 + """Converts 'scale' to 'precision'""" + return 1.0 / self.sigma**2 @property def nu(self): - """ This checks if nu or sigma have been set in parameters. If so, those + """This checks if nu or sigma have been set in parameters. If so, those values will be used. Otherwise, the attribute nu is used. The logic is that if nu is not in parameters the attribute is used which was given at init (i.e. the known nu as a float).""" if PARAMETERS_AS_STATE == "FALSE": return self._nu else: - return self.parameters.get('nu', self._nu) + return self.parameters.get("nu", self._nu) @nu.setter def nu(self, nu): @@ -508,7 +501,7 @@ def __init__(self, data, n_dimensions, base="parameter_"): """ self.data = np.array(data) self._total = np.sum(self.data) - super(Multinomial, self).__init__() + super().__init__() self.n = n_dimensions self.base = base self._nll = None @@ -518,8 +511,7 @@ def log_likelihood(self, parameters=None): Since n - 1 parameters are sampled, the last parameter is 1 - the rest """ parameters = _fallback_to_parameters(self, parameters) - probs = [parameters[self.base + str(ii)] - for ii in range(self.n - 1)] + probs = [parameters[self.base + str(ii)] for ii in range(self.n - 1)] probs.append(1 - sum(probs)) return self._multinomial_ln_pdf(probs=probs) @@ -534,30 +526,29 @@ def noise_log_likelihood(self): def _multinomial_ln_pdf(self, probs): """Lifted from scipy.stats.multinomial._logpdf""" - ln_prob = gammaln(self._total + 1) + np.sum( - xlogy(self.data, probs) - gammaln(self.data + 1), axis=-1) + ln_prob = gammaln(self._total + 1) + np.sum(xlogy(self.data, probs) - gammaln(self.data + 1), axis=-1) return ln_prob class AnalyticalMultidimensionalCovariantGaussian(Likelihood): """ - A multivariate Gaussian likelihood - with known analytic solution. + A multivariate Gaussian likelihood + with known analytic solution. - Parameters - ========== - mean: array_like - Array with the mean values of distribution - cov: array_like - The ndim*ndim covariance matrix - """ + Parameters + ========== + mean: array_like + Array with the mean values of distribution + cov: array_like + The ndim*ndim covariance matrix + """ def __init__(self, mean, cov): self.cov = np.atleast_2d(cov) self.mean = np.atleast_1d(mean) self.sigma = np.sqrt(np.diag(self.cov)) self.pdf = multivariate_normal(mean=self.mean, cov=self.cov) - super(AnalyticalMultidimensionalCovariantGaussian, self).__init__() + super().__init__() @property def dim(self): @@ -565,23 +556,23 @@ def dim(self): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) + x = np.array([parameters[f"x{i}"] for i in range(self.dim)]) return self.pdf.logpdf(x) class AnalyticalMultidimensionalBimodalCovariantGaussian(Likelihood): """ - A multivariate Gaussian likelihood - with known analytic solution. + A multivariate Gaussian likelihood + with known analytic solution. - Parameters - ========== - mean_1: array_like - Array with the mean value of the first mode - mean_2: array_like - Array with the mean value of the second mode - cov: array_like - """ + Parameters + ========== + mean_1: array_like + Array with the mean value of the first mode + mean_2: array_like + Array with the mean value of the second mode + cov: array_like + """ def __init__(self, mean_1, mean_2, cov): self.cov = np.atleast_2d(cov) @@ -590,7 +581,7 @@ def __init__(self, mean_1, mean_2, cov): self.mean_2 = np.atleast_1d(mean_2) self.pdf_1 = multivariate_normal(mean=self.mean_1, cov=self.cov) self.pdf_2 = multivariate_normal(mean=self.mean_2, cov=self.cov) - super(AnalyticalMultidimensionalBimodalCovariantGaussian, self).__init__() + super().__init__() @property def dim(self): @@ -598,7 +589,7 @@ def dim(self): def log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) - x = np.array([parameters["x{0}".format(i)] for i in range(self.dim)]) + x = np.array([parameters[f"x{i}"] for i in range(self.dim)]) return -np.log(2) + np.logaddexp(self.pdf_1.logpdf(x), self.pdf_2.logpdf(x)) @@ -619,11 +610,11 @@ def __init__(self, *likelihoods): likelihoods to be combined parsed as arguments """ self.likelihoods = likelihoods - super(JointLikelihood, self).__init__(parameters={}) + super().__init__(parameters={}) self.__sync_parameters() def __sync_parameters(self): - """ Synchronizes parameters between the likelihoods + """Synchronizes parameters between the likelihoods so that all likelihoods share a single parameter dict.""" if PARAMETERS_AS_STATE != "FALSE": for likelihood in self.likelihoods: @@ -633,7 +624,7 @@ def __sync_parameters(self): @property def likelihoods(self): - """ The list of likelihoods """ + """The list of likelihoods""" return self._likelihoods @likelihoods.setter @@ -643,29 +634,33 @@ def likelihoods(self, likelihoods): if all(isinstance(likelihood, Likelihood) for likelihood in likelihoods): self._likelihoods = list(likelihoods) else: - raise ValueError('Try setting the JointLikelihood like this\n' - 'JointLikelihood(first_likelihood, second_likelihood, ...)') + raise ValueError( + "Try setting the JointLikelihood like this\n" + "JointLikelihood(first_likelihood, second_likelihood, ...)" + ) elif isinstance(likelihoods, Likelihood): self._likelihoods = [likelihoods] else: - raise ValueError('Input likelihood is not a list of tuple. You need to set multiple likelihoods.') + raise ValueError("Input likelihood is not a list of tuple. You need to set multiple likelihoods.") def log_likelihood(self, parameters=None): - """ This is just the sum of the log likelihoods of all parts of the joint likelihood""" + """This is just the sum of the log likelihoods of all parts of the joint likelihood""" return sum([likelihood.log_likelihood(parameters=parameters) for likelihood in self.likelihoods]) def noise_log_likelihood(self): - """ This is just the sum of the noise likelihoods of all parts of the joint likelihood""" + """This is just the sum of the noise likelihoods of all parts of the joint likelihood""" return sum([likelihood.noise_log_likelihood() for likelihood in self.likelihoods]) def function_to_celerite_mean_model(func): from celerite.modeling import Model as CeleriteModel + return _function_to_gp_model(func, CeleriteModel) def function_to_george_mean_model(func): from celerite.modeling import Model as GeorgeModel + return _function_to_gp_model(func, GeorgeModel) @@ -684,29 +679,28 @@ def compute_gradient(self, *args, **kwargs): class _GPLikelihood(Likelihood): - def __init__(self, kernel, mean_model, t, y, yerr=1e-6, gp_class=None): """ - Basic Gaussian Process likelihood interface for `celerite` and `george`. - For `celerite` documentation see: https://celerite.readthedocs.io/en/stable/ - For `george` documentation see: https://george.readthedocs.io/en/latest/ - - Parameters - ========== - kernel: Union[celerite.term.Term, george.kernels.Kernel] - `celerite` or `george` kernel. See the respective package documentation about the usage. - mean_model: Union[celerite.modeling.Model, george.modeling.Model] - Mean model - t: array_like - The `times` or `x` values of the data set. - y: array_like - The `y` values of the data set. - yerr: float, int, array_like, optional - The error values on the y-values. If a single value is given, it is assumed that the value - applies for all y-values. Default is 1e-6, effectively assuming that no y-errors are present. - gp_class: type, None, optional - GPClass to use. This is determined by the child class used to instantiate the GP. Should usually - not be given by the user and is mostly used for testing + Basic Gaussian Process likelihood interface for `celerite` and `george`. + For `celerite` documentation see: https://celerite.readthedocs.io/en/stable/ + For `george` documentation see: https://george.readthedocs.io/en/latest/ + + Parameters + ========== + kernel: Union[celerite.term.Term, george.kernels.Kernel] + `celerite` or `george` kernel. See the respective package documentation about the usage. + mean_model: Union[celerite.modeling.Model, george.modeling.Model] + Mean model + t: array_like + The `times` or `x` values of the data set. + y: array_like + The `y` values of the data set. + yerr: float, int, array_like, optional + The error values on the y-values. If a single value is given, it is assumed that the value + applies for all y-values. Default is 1e-6, effectively assuming that no y-errors are present. + gp_class: type, None, optional + GPClass to use. This is determined by the child class used to instantiate the GP. Should usually + not be given by the user and is mostly used for testing """ self.kernel = kernel self.mean_model = mean_model @@ -742,28 +736,28 @@ def set_parameters(self, parameters): class CeleriteLikelihood(_GPLikelihood): - def __init__(self, kernel, mean_model, t, y, yerr=1e-6): """ - Basic Gaussian Process likelihood interface for `celerite` and `george`. - For `celerite` documentation see: https://celerite.readthedocs.io/en/stable/ - For `george` documentation see: https://george.readthedocs.io/en/latest/ - - Parameters - ========== - kernel: celerite.term.Term - `celerite` or `george` kernel. See the respective package documentation about the usage. - mean_model: celerite.modeling.Model - Mean model - t: array_like - The `times` or `x` values of the data set. - y: array_like - The `y` values of the data set. - yerr: float, int, array_like, optional - The error values on the y-values. If a single value is given, it is assumed that the value - applies for all y-values. Default is 1e-6, effectively assuming that no y-errors are present. + Basic Gaussian Process likelihood interface for `celerite` and `george`. + For `celerite` documentation see: https://celerite.readthedocs.io/en/stable/ + For `george` documentation see: https://george.readthedocs.io/en/latest/ + + Parameters + ========== + kernel: celerite.term.Term + `celerite` or `george` kernel. See the respective package documentation about the usage. + mean_model: celerite.modeling.Model + Mean model + t: array_like + The `times` or `x` values of the data set. + y: array_like + The `y` values of the data set. + yerr: float, int, array_like, optional + The error values on the y-values. If a single value is given, it is assumed that the value + applies for all y-values. Default is 1e-6, effectively assuming that no y-errors are present. """ import celerite + super().__init__(kernel=kernel, mean_model=mean_model, t=t, y=y, yerr=yerr, gp_class=celerite.GP) def log_likelihood(self, parameters=None): @@ -783,28 +777,28 @@ def log_likelihood(self, parameters=None): class GeorgeLikelihood(_GPLikelihood): - def __init__(self, kernel, mean_model, t, y, yerr=1e-6): """ - Basic Gaussian Process likelihood interface for `celerite` and `george`. - For `celerite` documentation see: https://celerite.readthedocs.io/en/stable/ - For `george` documentation see: https://george.readthedocs.io/en/latest/ - - Parameters - ========== - kernel: george.kernels.Kernel - `celerite` or `george` kernel. See the respective package documentation about the usage. - mean_model: george.modeling.Model - Mean model - t: array_like - The `times` or `x` values of the data set. - y: array_like - The `y` values of the data set. - yerr: float, int, array_like, optional - The error values on the y-values. If a single value is given, it is assumed that the value - applies for all y-values. Default is 1e-6, effectively assuming that no y-errors are present. + Basic Gaussian Process likelihood interface for `celerite` and `george`. + For `celerite` documentation see: https://celerite.readthedocs.io/en/stable/ + For `george` documentation see: https://george.readthedocs.io/en/latest/ + + Parameters + ========== + kernel: george.kernels.Kernel + `celerite` or `george` kernel. See the respective package documentation about the usage. + mean_model: george.modeling.Model + Mean model + t: array_like + The `times` or `x` values of the data set. + y: array_like + The `y` values of the data set. + yerr: float, int, array_like, optional + The error values on the y-values. If a single value is given, it is assumed that the value + applies for all y-values. Default is 1e-6, effectively assuming that no y-errors are present. """ import george + super().__init__(kernel=kernel, mean_model=mean_model, t=t, y=y, yerr=yerr, gp_class=george.GP) def log_likelihood(self, parameters=None): diff --git a/bilby/core/prior/__init__.py b/bilby/core/prior/__init__.py index fc795c3e1..f29fcd7f1 100644 --- a/bilby/core/prior/__init__.py +++ b/bilby/core/prior/__init__.py @@ -1,3 +1,5 @@ +# ruff: noqa: F403 + from .analytical import * from .base import * from .conditional import * diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index bc47cf680..a4aa18756 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1,25 +1,24 @@ import numpy as np from scipy.special import ( - xlogy, + betainc, + betaincinv, + betaln, erf, erfinv, - log1p, - stdtrit, + gammainc, + gammaincinv, gammaln, + log1p, stdtr, - betaln, - betainc, - betaincinv, - gammaincinv, - gammainc, + stdtrit, + xlogy, ) -from .base import Prior from ..utils import logger +from .base import Prior class DeltaFunction(Prior): - def __init__(self, peak, name=None, latex_label=None, unit=None): """Dirac delta function prior, this always returns peak. @@ -35,8 +34,9 @@ def __init__(self, peak, name=None, latex_label=None, unit=None): See superclass """ - super(DeltaFunction, self).__init__(name=name, latex_label=latex_label, unit=unit, - minimum=peak, maximum=peak, check_range_nonzero=False) + super().__init__( + name=name, latex_label=latex_label, unit=unit, minimum=peak, maximum=peak, check_range_nonzero=False + ) self.peak = peak self._is_fixed = True self.least_recently_sampled = peak @@ -52,7 +52,7 @@ def rescale(self, val): ======= float: Rescaled probability, equivalent to peak """ - return self.peak * val ** 0 + return self.peak * val**0 def prob(self, val): """Return the prior probability of val @@ -66,7 +66,7 @@ def prob(self, val): Union[float, array_like]: np.inf if val = peak, 0 otherwise """ - at_peak = (val == self.peak) + at_peak = val == self.peak return np.nan_to_num(np.multiply(at_peak, np.inf)) def cdf(self, val): @@ -74,9 +74,7 @@ def cdf(self, val): class PowerLaw(Prior): - - def __init__(self, alpha, minimum, maximum, name=None, latex_label=None, - unit=None, boundary=None): + def __init__(self, alpha, minimum, maximum, name=None, latex_label=None, unit=None, boundary=None): """Power law with bounds and alpha, spectral index Parameters @@ -96,9 +94,9 @@ def __init__(self, alpha, minimum, maximum, name=None, latex_label=None, boundary: str See superclass """ - super(PowerLaw, self).__init__(name=name, latex_label=latex_label, - minimum=minimum, maximum=maximum, unit=unit, - boundary=boundary) + super().__init__( + name=name, latex_label=latex_label, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary + ) self.alpha = alpha def rescale(self, val): @@ -119,8 +117,10 @@ def rescale(self, val): if self.alpha == -1: return self.minimum * np.exp(val * np.log(self.maximum / self.minimum)) else: - return (self.minimum ** (1 + self.alpha) + val * - (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha))) ** (1. / (1 + self.alpha)) + return ( + self.minimum ** (1 + self.alpha) + + val * (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha)) + ) ** (1.0 / (1 + self.alpha)) def prob(self, val): """Return the prior probability of val @@ -136,9 +136,11 @@ def prob(self, val): if self.alpha == -1: return np.nan_to_num(1 / val / np.log(self.maximum / self.minimum)) * self.is_in_prior_range(val) else: - return np.nan_to_num(val ** self.alpha * (1 + self.alpha) / - (self.maximum ** (1 + self.alpha) - - self.minimum ** (1 + self.alpha))) * self.is_in_prior_range(val) + return np.nan_to_num( + val**self.alpha + * (1 + self.alpha) + / (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha)) + ) * self.is_in_prior_range(val) def ln_prob(self, val): """Return the logarithmic prior probability of val @@ -153,25 +155,22 @@ def ln_prob(self, val): """ if self.alpha == -1: - normalising = 1. / np.log(self.maximum / self.minimum) + normalising = 1.0 / np.log(self.maximum / self.minimum) else: - normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - - self.minimum ** (1 + self.alpha)) + normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha)) - with np.errstate(divide='ignore', invalid='ignore'): - ln_in_range = np.log(1. * self.is_in_prior_range(val)) + with np.errstate(divide="ignore", invalid="ignore"): + ln_in_range = np.log(1.0 * self.is_in_prior_range(val)) ln_p = self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising) return ln_p + ln_in_range def cdf(self, val): if self.alpha == -1: - _cdf = (np.log(val / self.minimum) / - np.log(self.maximum / self.minimum)) + _cdf = np.log(val / self.minimum) / np.log(self.maximum / self.minimum) else: - _cdf = ( - (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) - / (self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) + _cdf = (val ** (self.alpha + 1) - self.minimum ** (self.alpha + 1)) / ( + self.maximum ** (self.alpha + 1) - self.minimum ** (self.alpha + 1) ) _cdf = np.minimum(_cdf, 1) _cdf = np.maximum(_cdf, 0) @@ -179,9 +178,7 @@ def cdf(self, val): class Uniform(Prior): - - def __init__(self, minimum, maximum, name=None, latex_label=None, - unit=None, boundary=None): + def __init__(self, minimum, maximum, name=None, latex_label=None, unit=None, boundary=None): """Uniform prior with bounds Parameters @@ -199,9 +196,9 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, boundary: str See superclass """ - super(Uniform, self).__init__(name=name, latex_label=latex_label, - minimum=minimum, maximum=maximum, unit=unit, - boundary=boundary) + super().__init__( + name=name, latex_label=latex_label, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary + ) def rescale(self, val): """ @@ -254,9 +251,7 @@ def cdf(self, val): class LogUniform(PowerLaw): - - def __init__(self, minimum, maximum, name=None, latex_label=None, - unit=None, boundary=None): + def __init__(self, minimum, maximum, name=None, latex_label=None, unit=None, boundary=None): """Log-Uniform prior with bounds Parameters @@ -274,16 +269,15 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, boundary: str See superclass """ - super(LogUniform, self).__init__(name=name, latex_label=latex_label, unit=unit, - minimum=minimum, maximum=maximum, alpha=-1, boundary=boundary) + super().__init__( + name=name, latex_label=latex_label, unit=unit, minimum=minimum, maximum=maximum, alpha=-1, boundary=boundary + ) if self.minimum <= 0: - logger.warning('You specified a uniform-in-log prior with minimum={}'.format(self.minimum)) + logger.warning(f"You specified a uniform-in-log prior with minimum={self.minimum}") class SymmetricLogUniform(Prior): - - def __init__(self, minimum, maximum, name=None, latex_label=None, - unit=None, boundary=None): + def __init__(self, minimum, maximum, name=None, latex_label=None, unit=None, boundary=None): """Symmetric Log-Uniform distributions with bounds This is identical to a Log-Uniform distribution, but mirrored about @@ -306,9 +300,9 @@ def __init__(self, minimum, maximum, name=None, latex_label=None, boundary: str See superclass """ - super(SymmetricLogUniform, self).__init__(name=name, latex_label=latex_label, - minimum=minimum, maximum=maximum, unit=unit, - boundary=boundary) + super().__init__( + name=name, latex_label=latex_label, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary + ) def rescale(self, val): """ @@ -333,10 +327,12 @@ def rescale(self, val): else: vals_less_than_5 = val < 0.5 rescaled = np.empty_like(val) - rescaled[vals_less_than_5] = -self.maximum * np.exp(-2 * val[vals_less_than_5] * - np.log(self.maximum / self.minimum)) - rescaled[~vals_less_than_5] = self.minimum * np.exp(np.log(self.maximum / self.minimum) * - (2 * val[~vals_less_than_5] - 1)) + rescaled[vals_less_than_5] = -self.maximum * np.exp( + -2 * val[vals_less_than_5] * np.log(self.maximum / self.minimum) + ) + rescaled[~vals_less_than_5] = self.minimum * np.exp( + np.log(self.maximum / self.minimum) * (2 * val[~vals_less_than_5] - 1) + ) return rescaled def prob(self, val): @@ -351,8 +347,7 @@ def prob(self, val): float: Prior probability of val """ val = np.abs(val) - return (np.nan_to_num(0.5 / val / np.log(self.maximum / self.minimum)) * - self.is_in_prior_range(val)) + return np.nan_to_num(0.5 / val / np.log(self.maximum / self.minimum)) * self.is_in_prior_range(val) def ln_prob(self, val): """Return the logarithmic prior probability of val @@ -366,15 +361,13 @@ def ln_prob(self, val): float: """ - return np.nan_to_num(- np.log(2 * np.abs(val)) - np.log(np.log(self.maximum / self.minimum))) + return np.nan_to_num(-np.log(2 * np.abs(val)) - np.log(np.log(self.maximum / self.minimum))) def cdf(self, val): norm = 0.5 / np.log(self.maximum / self.minimum) _cdf = ( - -norm * np.log(abs(val) / self.maximum) - * (val <= -self.minimum) * (val >= -self.maximum) - + (0.5 + norm * np.log(abs(val) / self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) + -norm * np.log(abs(val) / self.maximum) * (val <= -self.minimum) * (val >= -self.maximum) + + (0.5 + norm * np.log(abs(val) / self.minimum)) * (val >= self.minimum) * (val <= self.maximum) + 0.5 * (val > -self.minimum) * (val < self.minimum) + 1 * (val > self.maximum) ) @@ -382,9 +375,7 @@ def cdf(self, val): class Cosine(Prior): - - def __init__(self, minimum=-np.pi / 2, maximum=np.pi / 2, name=None, - latex_label=None, unit=None, boundary=None): + def __init__(self, minimum=-np.pi / 2, maximum=np.pi / 2, name=None, latex_label=None, unit=None, boundary=None): """Cosine prior with bounds Parameters @@ -402,8 +393,9 @@ def __init__(self, minimum=-np.pi / 2, maximum=np.pi / 2, name=None, boundary: str See superclass """ - super(Cosine, self).__init__(minimum=minimum, maximum=maximum, name=name, - latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__( + minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary + ) def rescale(self, val): """ @@ -428,19 +420,14 @@ def prob(self, val): return np.cos(val) / 2 * self.is_in_prior_range(val) def cdf(self, val): - _cdf = ( - (np.sin(val) - np.sin(self.minimum)) - / (np.sin(self.maximum) - np.sin(self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) - ) + _cdf = (np.sin(val) - np.sin(self.minimum)) / (np.sin(self.maximum) - np.sin(self.minimum)) * ( + val >= self.minimum + ) * (val <= self.maximum) + 1 * (val > self.maximum) return _cdf class Sine(Prior): - - def __init__(self, minimum=0, maximum=np.pi, name=None, - latex_label=None, unit=None, boundary=None): + def __init__(self, minimum=0, maximum=np.pi, name=None, latex_label=None, unit=None, boundary=None): """Sine prior with bounds Parameters @@ -458,8 +445,9 @@ def __init__(self, minimum=0, maximum=np.pi, name=None, boundary: str See superclass """ - super(Sine, self).__init__(minimum=minimum, maximum=maximum, name=name, - latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__( + minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary + ) def rescale(self, val): """ @@ -484,17 +472,13 @@ def prob(self, val): return np.sin(val) / 2 * self.is_in_prior_range(val) def cdf(self, val): - _cdf = ( - (np.cos(val) - np.cos(self.minimum)) - / (np.cos(self.maximum) - np.cos(self.minimum)) - * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) - ) + _cdf = (np.cos(val) - np.cos(self.minimum)) / (np.cos(self.maximum) - np.cos(self.minimum)) * ( + val >= self.minimum + ) * (val <= self.maximum) + 1 * (val > self.maximum) return _cdf class Gaussian(Prior): - def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=None): """Gaussian prior with mean mu and width sigma @@ -513,7 +497,7 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N boundary: str See superclass """ - super(Gaussian, self).__init__(name=name, latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__(name=name, latex_label=latex_label, unit=unit, boundary=boundary) self.mu = mu self.sigma = sigma @@ -527,7 +511,7 @@ def rescale(self, val): This maps to the inverse CDF. This has been analytically solved for this case. """ - return self.mu + erfinv(2 * val - 1) * 2 ** 0.5 * self.sigma + return self.mu + erfinv(2 * val - 1) * 2**0.5 * self.sigma def prob(self, val): """Return the prior probability of val. @@ -540,7 +524,7 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma + return np.exp(-((self.mu - val) ** 2) / (2 * self.sigma**2)) / (2 * np.pi) ** 0.5 / self.sigma def ln_prob(self, val): """Return the Log prior probability of val. @@ -554,20 +538,18 @@ def ln_prob(self, val): Union[float, array_like]: Prior probability of val """ - return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + np.log(2 * np.pi * self.sigma ** 2)) + return -0.5 * ((self.mu - val) ** 2 / self.sigma**2 + np.log(2 * np.pi * self.sigma**2)) def cdf(self, val): - return (1 - erf((self.mu - val) / 2 ** 0.5 / self.sigma)) / 2 + return (1 - erf((self.mu - val) / 2**0.5 / self.sigma)) / 2 class Normal(Gaussian): - """A synonym for the Gaussian distribution. """ + """A synonym for the Gaussian distribution.""" class TruncatedGaussian(Prior): - - def __init__(self, mu, sigma, minimum, maximum, name=None, - latex_label=None, unit=None, boundary=None): + def __init__(self, mu, sigma, minimum, maximum, name=None, latex_label=None, unit=None, boundary=None): """Truncated Gaussian prior with mean mu and width sigma https://en.wikipedia.org/wiki/Truncated_normal_distribution @@ -591,21 +573,23 @@ def __init__(self, mu, sigma, minimum, maximum, name=None, boundary: str See superclass """ - super(TruncatedGaussian, self).__init__(name=name, latex_label=latex_label, unit=unit, - minimum=minimum, maximum=maximum, boundary=boundary) + super().__init__( + name=name, latex_label=latex_label, unit=unit, minimum=minimum, maximum=maximum, boundary=boundary + ) self.mu = mu self.sigma = sigma @property def normalisation(self): - """ Calculates the proper normalisation of the truncated Gaussian + """Calculates the proper normalisation of the truncated Gaussian Returns ======= float: Proper normalisation of the truncated Gaussian """ - return (erf((self.maximum - self.mu) / 2 ** 0.5 / self.sigma) - erf( - (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 + return ( + erf((self.maximum - self.mu) / 2**0.5 / self.sigma) - erf((self.minimum - self.mu) / 2**0.5 / self.sigma) + ) / 2 def rescale(self, val): """ @@ -613,8 +597,12 @@ def rescale(self, val): This maps to the inverse CDF. This has been analytically solved for this case. """ - return erfinv(2 * val * self.normalisation + erf( - (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) * 2 ** 0.5 * self.sigma + self.mu + return ( + erfinv(2 * val * self.normalisation + erf((self.minimum - self.mu) / 2**0.5 / self.sigma)) + * 2**0.5 + * self.sigma + + self.mu + ) def prob(self, val): """Return the prior probability of val. @@ -627,17 +615,18 @@ def prob(self, val): ======= float: Prior probability of val """ - return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 \ - / self.sigma / self.normalisation * self.is_in_prior_range(val) + return ( + np.exp(-((self.mu - val) ** 2) / (2 * self.sigma**2)) + / (2 * np.pi) ** 0.5 + / self.sigma + / self.normalisation + * self.is_in_prior_range(val) + ) def cdf(self, val): _cdf = ( - ( - erf((val - self.mu) / 2 ** 0.5 / self.sigma) - - erf((self.minimum - self.mu) / 2 ** 0.5 / self.sigma) - ) / 2 / self.normalisation * (val >= self.minimum) * (val <= self.maximum) - + 1 * (val > self.maximum) - ) + erf((val - self.mu) / 2**0.5 / self.sigma) - erf((self.minimum - self.mu) / 2**0.5 / self.sigma) + ) / 2 / self.normalisation * (val >= self.minimum) * (val <= self.maximum) + 1 * (val > self.maximum) return _cdf @@ -662,9 +651,16 @@ def __init__(self, sigma, name=None, latex_label=None, unit=None, boundary=None) boundary: str See superclass """ - super(HalfGaussian, self).__init__(mu=0., sigma=sigma, minimum=0., maximum=np.inf, - name=name, latex_label=latex_label, - unit=unit, boundary=boundary) + super().__init__( + mu=0.0, + sigma=sigma, + minimum=0.0, + maximum=np.inf, + name=name, + latex_label=latex_label, + unit=unit, + boundary=boundary, + ) class HalfNormal(HalfGaussian): @@ -692,10 +688,9 @@ def __init__(self, mu, sigma, name=None, latex_label=None, unit=None, boundary=N boundary: str See superclass """ - super(LogNormal, self).__init__(name=name, minimum=0., latex_label=latex_label, - unit=unit, boundary=boundary) + super().__init__(name=name, minimum=0.0, latex_label=latex_label, unit=unit, boundary=boundary) - if sigma <= 0.: + if sigma <= 0.0: raise ValueError("For the LogGaussian prior the standard deviation must be positive") self.mu = mu @@ -707,7 +702,7 @@ def rescale(self, val): This maps to the inverse CDF. This has been analytically solved for this case. """ - return np.exp(self.mu + np.sqrt(2 * self.sigma ** 2) * erfinv(2 * val - 1)) + return np.exp(self.mu + np.sqrt(2 * self.sigma**2) * erfinv(2 * val - 1)) def prob(self, val): """Returns the prior probability of val. @@ -722,15 +717,20 @@ def prob(self, val): """ if isinstance(val, (float, int)): if val <= self.minimum: - _prob = 0. + _prob = 0.0 else: - _prob = np.exp(-(np.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / np.sqrt(2 * np.pi) / val / self.sigma + _prob = ( + np.exp(-((np.log(val) - self.mu) ** 2) / self.sigma**2 / 2) / np.sqrt(2 * np.pi) / val / self.sigma + ) else: _prob = np.zeros(val.size) - idx = (val > self.minimum) - _prob[idx] = np.exp(-(np.log(val[idx]) - self.mu) ** 2 / self.sigma ** 2 / 2)\ - / np.sqrt(2 * np.pi) / val[idx] / self.sigma + idx = val > self.minimum + _prob[idx] = ( + np.exp(-((np.log(val[idx]) - self.mu) ** 2) / self.sigma**2 / 2) + / np.sqrt(2 * np.pi) + / val[idx] + / self.sigma + ) return _prob def ln_prob(self, val): @@ -748,25 +748,28 @@ def ln_prob(self, val): if val <= self.minimum: _ln_prob = -np.inf else: - _ln_prob = -(np.log(val) - self.mu) ** 2 / self.sigma ** 2 / 2\ - - np.log(np.sqrt(2 * np.pi) * val * self.sigma) + _ln_prob = -((np.log(val) - self.mu) ** 2) / self.sigma**2 / 2 - np.log( + np.sqrt(2 * np.pi) * val * self.sigma + ) else: _ln_prob = -np.inf * np.ones(val.size) - idx = (val > self.minimum) - _ln_prob[idx] = -(np.log(val[idx]) - self.mu) ** 2\ - / self.sigma ** 2 / 2 - np.log(np.sqrt(2 * np.pi) * val[idx] * self.sigma) + idx = val > self.minimum + _ln_prob[idx] = -((np.log(val[idx]) - self.mu) ** 2) / self.sigma**2 / 2 - np.log( + np.sqrt(2 * np.pi) * val[idx] * self.sigma + ) return _ln_prob def cdf(self, val): if isinstance(val, (float, int)): if val <= self.minimum: - _cdf = 0. + _cdf = 0.0 else: _cdf = 0.5 + erf((np.log(val) - self.mu) / self.sigma / np.sqrt(2)) / 2 else: _cdf = np.zeros(val.size) - _cdf[val > self.minimum] = 0.5 + erf(( - np.log(val[val > self.minimum]) - self.mu) / self.sigma / np.sqrt(2)) / 2 + _cdf[val > self.minimum] = ( + 0.5 + erf((np.log(val[val > self.minimum]) - self.mu) / self.sigma / np.sqrt(2)) / 2 + ) return _cdf @@ -791,8 +794,7 @@ def __init__(self, mu, name=None, latex_label=None, unit=None, boundary=None): boundary: str See superclass """ - super(Exponential, self).__init__(name=name, minimum=0., latex_label=latex_label, - unit=unit, boundary=boundary) + super().__init__(name=name, minimum=0.0, latex_label=latex_label, unit=unit, boundary=boundary) self.mu = mu def rescale(self, val): @@ -816,7 +818,7 @@ def prob(self, val): """ if isinstance(val, (float, int)): if val < self.minimum: - _prob = 0. + _prob = 0.0 else: _prob = np.exp(-val / self.mu) / self.mu else: @@ -848,18 +850,17 @@ def ln_prob(self, val): def cdf(self, val): if isinstance(val, (float, int)): if val < self.minimum: - _cdf = 0. + _cdf = 0.0 else: - _cdf = 1. - np.exp(-val / self.mu) + _cdf = 1.0 - np.exp(-val / self.mu) else: _cdf = np.zeros(val.size) - _cdf[val >= self.minimum] = 1. - np.exp(-val[val >= self.minimum] / self.mu) + _cdf[val >= self.minimum] = 1.0 - np.exp(-val[val >= self.minimum] / self.mu) return _cdf class StudentT(Prior): - def __init__(self, df, mu=0., scale=1., name=None, latex_label=None, - unit=None, boundary=None): + def __init__(self, df, mu=0.0, scale=1.0, name=None, latex_label=None, unit=None, boundary=None): """Student's t-distribution prior with number of degrees of freedom df, mean mu and scale @@ -882,9 +883,9 @@ def __init__(self, df, mu=0., scale=1., name=None, latex_label=None, boundary: str See superclass """ - super(StudentT, self).__init__(name=name, latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__(name=name, latex_label=latex_label, unit=unit, boundary=boundary) - if df <= 0. or scale <= 0.: + if df <= 0.0 or scale <= 0.0: raise ValueError("For the StudentT prior the number of degrees of freedom and scale must be positive") self.df = df @@ -934,17 +935,19 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return gammaln(0.5 * (self.df + 1)) - gammaln(0.5 * self.df)\ - - np.log(np.sqrt(np.pi * self.df) * self.scale) - (self.df + 1) / 2 *\ - np.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) + return ( + gammaln(0.5 * (self.df + 1)) + - gammaln(0.5 * self.df) + - np.log(np.sqrt(np.pi * self.df) * self.scale) + - (self.df + 1) / 2 * np.log(1 + ((val - self.mu) / self.scale) ** 2 / self.df) + ) def cdf(self, val): return stdtr(self.df, (val - self.mu) / self.scale) class Beta(Prior): - def __init__(self, alpha, beta, minimum=0, maximum=1, name=None, - latex_label=None, unit=None, boundary=None): + def __init__(self, alpha, beta, minimum=0, maximum=1, name=None, latex_label=None, unit=None, boundary=None): """Beta distribution https://en.wikipedia.org/wiki/Beta_distribution @@ -971,10 +974,11 @@ def __init__(self, alpha, beta, minimum=0, maximum=1, name=None, boundary: str See superclass """ - super(Beta, self).__init__(minimum=minimum, maximum=maximum, name=name, - latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__( + minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary + ) - if alpha <= 0. or beta <= 0.: + if alpha <= 0.0 or beta <= 0.0: raise ValueError("alpha and beta must both be positive values") self.alpha = alpha @@ -1012,8 +1016,12 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - _ln_prob = xlogy(self.alpha - 1, val - self.minimum) + xlogy(self.beta - 1, self.maximum - val)\ - - betaln(self.alpha, self.beta) - xlogy(self.alpha + self.beta - 1, self.maximum - self.minimum) + _ln_prob = ( + xlogy(self.alpha - 1, val - self.minimum) + + xlogy(self.beta - 1, self.maximum - val) + - betaln(self.alpha, self.beta) + - xlogy(self.alpha + self.beta - 1, self.maximum - self.minimum) + ) # deal with the fact that if alpha or beta are < 1 you get infinities at 0 and 1 if isinstance(val, (float, int)): @@ -1029,19 +1037,15 @@ def ln_prob(self, val): def cdf(self, val): if isinstance(val, (float, int)): if val > self.maximum: - return 1. + return 1.0 elif val < self.minimum: - return 0. + return 0.0 else: - return betainc( - self.alpha, self.beta, - (val - self.minimum) / (self.maximum - self.minimum) - ) + return betainc(self.alpha, self.beta, (val - self.minimum) / (self.maximum - self.minimum)) else: - _cdf = np.nan_to_num(betainc(self.alpha, self.beta, - (val - self.minimum) / (self.maximum - self.minimum))) - _cdf[val < self.minimum] = 0. - _cdf[val > self.maximum] = 1. + _cdf = np.nan_to_num(betainc(self.alpha, self.beta, (val - self.minimum) / (self.maximum - self.minimum))) + _cdf[val < self.minimum] = 0.0 + _cdf[val > self.maximum] = 1.0 return _cdf @@ -1066,9 +1070,9 @@ def __init__(self, mu, scale, name=None, latex_label=None, unit=None, boundary=N boundary: str See superclass """ - super(Logistic, self).__init__(name=name, latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__(name=name, latex_label=latex_label, unit=unit, boundary=boundary) - if scale <= 0.: + if scale <= 0.0: raise ValueError("For the Logistic prior the scale must be positive") self.mu = mu @@ -1086,12 +1090,13 @@ def rescale(self, val): elif val == 1: rescaled = np.inf else: - rescaled = self.mu + self.scale * np.log(val / (1. - val)) + rescaled = self.mu + self.scale * np.log(val / (1.0 - val)) else: rescaled = np.inf * np.ones(val.size) rescaled[val == 0] = -np.inf - rescaled[(val > 0) & (val < 1)] = self.mu + self.scale\ - * np.log(val[(val > 0) & (val < 1)] / (1. - val[(val > 0) & (val < 1)])) + rescaled[(val > 0) & (val < 1)] = self.mu + self.scale * np.log( + val[(val > 0) & (val < 1)] / (1.0 - val[(val > 0) & (val < 1)]) + ) return rescaled def prob(self, val): @@ -1118,11 +1123,14 @@ def ln_prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return -(val - self.mu) / self.scale -\ - 2. * np.log(1. + np.exp(-(val - self.mu) / self.scale)) - np.log(self.scale) + return ( + -(val - self.mu) / self.scale + - 2.0 * np.log(1.0 + np.exp(-(val - self.mu) / self.scale)) + - np.log(self.scale) + ) def cdf(self, val): - return 1. / (1. + np.exp(-(val - self.mu) / self.scale)) + return 1.0 / (1.0 + np.exp(-(val - self.mu) / self.scale)) class Cauchy(Prior): @@ -1146,9 +1154,9 @@ def __init__(self, alpha, beta, name=None, latex_label=None, unit=None, boundary boundary: str See superclass """ - super(Cauchy, self).__init__(name=name, latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__(name=name, latex_label=latex_label, unit=unit, boundary=boundary) - if beta <= 0.: + if beta <= 0.0: raise ValueError("For the Cauchy prior the scale must be positive") self.alpha = alpha @@ -1182,7 +1190,7 @@ def prob(self, val): ======= Union[float, array_like]: Prior probability of val """ - return 1. / self.beta / np.pi / (1. + ((val - self.alpha) / self.beta) ** 2) + return 1.0 / self.beta / np.pi / (1.0 + ((val - self.alpha) / self.beta) ** 2) def ln_prob(self, val): """Return the log prior probability of val. @@ -1195,7 +1203,7 @@ def ln_prob(self, val): ======= Union[float, array_like]: Log prior probability of val """ - return - np.log(self.beta * np.pi) - np.log(1. + ((val - self.alpha) / self.beta) ** 2) + return -np.log(self.beta * np.pi) - np.log(1.0 + ((val - self.alpha) / self.beta) ** 2) def cdf(self, val): return 0.5 + np.arctan((val - self.alpha) / self.beta) / np.pi @@ -1206,7 +1214,7 @@ class Lorentzian(Cauchy): class Gamma(Prior): - def __init__(self, k, theta=1., name=None, latex_label=None, unit=None, boundary=None): + def __init__(self, k, theta=1.0, name=None, latex_label=None, unit=None, boundary=None): """Gamma distribution https://en.wikipedia.org/wiki/Gamma_distribution @@ -1226,8 +1234,7 @@ def __init__(self, k, theta=1., name=None, latex_label=None, unit=None, boundary boundary: str See superclass """ - super(Gamma, self).__init__(name=name, minimum=0., latex_label=latex_label, - unit=unit, boundary=boundary) + super().__init__(name=name, minimum=0.0, latex_label=latex_label, unit=unit, boundary=boundary) if k <= 0 or theta <= 0: raise ValueError("For the Gamma prior the shape and scale must be positive") @@ -1274,15 +1281,16 @@ def ln_prob(self, val): _ln_prob = xlogy(self.k - 1, val) - val / self.theta - xlogy(self.k, self.theta) - gammaln(self.k) else: _ln_prob = -np.inf * np.ones(val.size) - idx = (val >= self.minimum) - _ln_prob[idx] = xlogy(self.k - 1, val[idx]) - val[idx] / self.theta\ - - xlogy(self.k, self.theta) - gammaln(self.k) + idx = val >= self.minimum + _ln_prob[idx] = ( + xlogy(self.k - 1, val[idx]) - val[idx] / self.theta - xlogy(self.k, self.theta) - gammaln(self.k) + ) return _ln_prob def cdf(self, val): if isinstance(val, (float, int)): if val < self.minimum: - _cdf = 0. + _cdf = 0.0 else: _cdf = gammainc(self.k, val / self.theta) else: @@ -1314,8 +1322,7 @@ def __init__(self, nu, name=None, latex_label=None, unit=None, boundary=None): if nu <= 0 or not isinstance(nu, int): raise ValueError("For the ChiSquared prior the number of degrees of freedom must be a positive integer") - super(ChiSquared, self).__init__(name=name, k=nu / 2., theta=2., - latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__(name=name, k=nu / 2.0, theta=2.0, latex_label=latex_label, unit=unit, boundary=boundary) @property def nu(self): @@ -1323,12 +1330,11 @@ def nu(self): @nu.setter def nu(self, nu): - self.k = nu / 2. + self.k = nu / 2.0 class FermiDirac(Prior): - def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, - unit=None): + def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, unit=None): """A Fermi-Dirac type prior, with a fixed lower boundary at zero (see, e.g. Section 2.3.5 of [1]_). The probability distribution is defined by Equation 22 of [1]_. @@ -1356,13 +1362,12 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, .. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1 `_, 2017. """ - super(FermiDirac, self).__init__(name=name, latex_label=latex_label, unit=unit, minimum=0.) + super().__init__(name=name, latex_label=latex_label, unit=unit, minimum=0.0) self.sigma = sigma if mu is None and r is None: - raise ValueError("For the Fermi-Dirac prior either a 'mu' value or 'r' " - "value must be given.") + raise ValueError("For the Fermi-Dirac prior either a 'mu' value or 'r' value must be given.") if r is None and mu is not None: self.mu = mu @@ -1371,9 +1376,8 @@ def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None, self.r = r self.mu = self.sigma * self.r - if self.r <= 0. or self.sigma <= 0.: - raise ValueError("For the Fermi-Dirac prior the values of sigma and r " - "must be positive.") + if self.r <= 0.0 or self.sigma <= 0.0: + raise ValueError("For the Fermi-Dirac prior the values of sigma and r must be positive.") self.expr = np.exp(self.r) @@ -1394,7 +1398,7 @@ def rescale(self, val): .. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1 `_, 2017. """ - inv = -1 / self.expr + (1 + self.expr)**-val + (1 + self.expr)**-val / self.expr + inv = -1 / self.expr + (1 + self.expr) ** -val + (1 + self.expr) ** -val / self.expr return -self.sigma * np.log(np.maximum(inv, 0)) def prob(self, val): @@ -1409,7 +1413,7 @@ def prob(self, val): float: Prior probability of val """ return ( - (np.exp((val - self.mu) / self.sigma) + 1)**-1 + (np.exp((val - self.mu) / self.sigma) + 1) ** -1 / (self.sigma * np.log1p(self.expr)) * (val >= self.minimum) ) @@ -1448,10 +1452,7 @@ def cdf(self, val): .. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1 `_, 2017. """ - result = ( - (np.logaddexp(0, -self.r) - np.logaddexp(-val / self.sigma, -self.r)) - / np.logaddexp(0, self.r) - ) + result = (np.logaddexp(0, -self.r) - np.logaddexp(-val / self.sigma, -self.r)) / np.logaddexp(0, self.r) return np.clip(result, 0, 1) @@ -1485,15 +1486,13 @@ def __init__( nvalues = len(values) values = np.array(values) if values.shape != (nvalues,): - raise ValueError( - f"Shape of argument 'values' must be 1d array-like but has shape {values.shape}" - ) + raise ValueError(f"Shape of argument 'values' must be 1d array-like but has shape {values.shape}") minimum = np.min(values) # Small delta added to help with MCMC walking maximum = np.max(values) * (1 + 1e-15) - super(WeightedDiscreteValues, self).__init__( - name=name, latex_label=latex_label, minimum=minimum, - maximum=maximum, unit=unit, boundary=boundary) + super().__init__( + name=name, latex_label=latex_label, minimum=minimum, maximum=maximum, unit=unit, boundary=boundary + ) self.nvalues = nvalues sorter = np.argsort(values) self._values_array = values[sorter] @@ -1502,11 +1501,7 @@ def __init__( # python buildins self.values = self._values_array.tolist() - weights = ( - np.array(weights) / np.sum(weights) - if weights is not None - else np.ones(self.nvalues) / self.nvalues - ) + weights = np.array(weights) / np.sum(weights) if weights is not None else np.ones(self.nvalues) / self.nvalues # check for consistent shape of input if weights.shape != (self.nvalues,): raise ValueError( @@ -1586,16 +1581,13 @@ def ln_prob(self, val): """ index = np.searchsorted(self._values_array, val) index = np.clip(index, 0, self.nvalues - 1) - lnp = np.where( - self._values_array[index] == val, self._lnweights_array[index], -np.inf - ) + lnp = np.where(self._values_array[index] == val, self._lnweights_array[index], -np.inf) # turn 0d numpy array to scalar return lnp[()] class DiscreteValues(WeightedDiscreteValues): - def __init__(self, values, name=None, latex_label=None, - unit=None, boundary="periodic"): + def __init__(self, values, name=None, latex_label=None, unit=None, boundary="periodic"): """An equal-weighted discrete-valued prior Parameters @@ -1610,7 +1602,7 @@ def __init__(self, values, name=None, latex_label=None, See superclass """ weights = np.ones_like(values) - super(DiscreteValues, self).__init__( + super().__init__( values=values, weights=weights, name=name, @@ -1649,7 +1641,7 @@ def __init__( """ self.ncategories = ncategories values = np.arange(0, ncategories) - super(WeightedCategorical, self).__init__( + super().__init__( values=values, weights=weights, name=name, @@ -1660,9 +1652,7 @@ def __init__( class Categorical(DiscreteValues): - def __init__( - self, ncategories, name=None, latex_label=None, unit=None, boundary="periodic" - ): + def __init__(self, ncategories, name=None, latex_label=None, unit=None, boundary="periodic"): """An equal-weighted Categorical prior Parameters ========== @@ -1678,7 +1668,7 @@ def __init__( """ self.ncategories = ncategories values = np.arange(0, ncategories) - super(Categorical, self).__init__( + super().__init__( values=values, name=name, latex_label=latex_label, @@ -1697,8 +1687,9 @@ class Triangular(Prior): where the mode has the highest pdf value. """ + def __init__(self, mode, minimum, maximum, name=None, latex_label=None, unit=None): - super(Triangular, self).__init__( + super().__init__( name=name, latex_label=latex_label, unit=unit, @@ -1706,9 +1697,7 @@ def __init__(self, mode, minimum, maximum, name=None, latex_label=None, unit=Non maximum=maximum, ) self.mode = mode - self.fractional_mode = (self.mode - self.minimum) / ( - self.maximum - self.minimum - ) + self.fractional_mode = (self.mode - self.minimum) / (self.maximum - self.minimum) self.scale = self.maximum - self.minimum self.rescaled_minimum = self.minimum - (self.minimum == self.mode) * self.scale self.rescaled_maximum = self.maximum + (self.maximum == self.mode) * self.scale @@ -1731,9 +1720,9 @@ def rescale(self, val): """ below_mode = (val * self.scale * (self.mode - self.minimum)) ** 0.5 above_mode = ((1 - val) * self.scale * (self.maximum - self.mode)) ** 0.5 - return (self.minimum + below_mode) * (val < self.fractional_mode) + ( - self.maximum - above_mode - ) * (val >= self.fractional_mode) + return (self.minimum + below_mode) * (val < self.fractional_mode) + (self.maximum - above_mode) * ( + val >= self.fractional_mode + ) def prob(self, val): """ @@ -1775,17 +1764,7 @@ def cdf(self, val): float: prior cumulative probability at val """ - return ( - (val > self.mode) - + (val > self.minimum) - * (val <= self.maximum) - / (self.scale) - * ( - (val > self.mode) - * (self.rescaled_maximum - val) ** 2.0 - / (self.mode - self.rescaled_maximum) - + (val <= self.mode) - * (val - self.rescaled_minimum) ** 2.0 - / (self.mode - self.rescaled_minimum) - ) + return (val > self.mode) + (val > self.minimum) * (val <= self.maximum) / (self.scale) * ( + (val > self.mode) * (self.rescaled_maximum - val) ** 2.0 / (self.mode - self.rescaled_maximum) + + (val <= self.mode) * (val - self.rescaled_minimum) ** 2.0 / (self.mode - self.rescaled_minimum) ) diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index 5ca28de28..06b8bf2d0 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -1,27 +1,37 @@ -from importlib import import_module import json import os import re +from importlib import import_module import numpy as np import scipy.stats from ..utils import ( - infer_args_from_method, BilbyJsonEncoder, decode_bilby_json, - logger, get_dict_with_properties, + infer_args_from_method, + logger, +) +from ..utils import ( WrappedInterp1d as interp1d, ) -class Prior(object): +class Prior: _default_latex_labels = {} - def __init__(self, name=None, latex_label=None, unit=None, minimum=-np.inf, - maximum=np.inf, check_range_nonzero=True, boundary=None): - """ Implements a Prior object + def __init__( + self, + name=None, + latex_label=None, + unit=None, + minimum=-np.inf, + maximum=np.inf, + check_range_nonzero=True, + boundary=None, + ): + """Implements a Prior object Parameters ========== @@ -42,11 +52,7 @@ def __init__(self, name=None, latex_label=None, unit=None, minimum=-np.inf, Currently implemented in cpnest, dynesty and pymultinest. """ if check_range_nonzero and maximum <= minimum: - raise ValueError( - "maximum {} <= minimum {} for {} prior on {}".format( - maximum, minimum, type(self).__name__, name - ) - ) + raise ValueError(f"maximum {maximum} <= minimum {minimum} for {type(self).__name__} prior on {name}") self.name = name self.latex_label = latex_label self.unit = unit @@ -109,7 +115,7 @@ def __eq__(self, other): if isinstance(this_dict[key], np.ndarray): if not np.array_equal(this_dict[key], other_dict[key]): return False - elif isinstance(this_dict[key], type(scipy.stats.beta(1., 1.))): + elif isinstance(this_dict[key], type(scipy.stats.beta(1.0, 1.0))): continue else: if not this_dict[key] == other_dict[key]: @@ -167,17 +173,15 @@ def prob(self, val): return np.nan def cdf(self, val): - """ Generic method to calculate CDF, can be overwritten in subclass """ + """Generic method to calculate CDF, can be overwritten in subclass""" from scipy.integrate import cumulative_trapezoid + if np.any(np.isinf([self.minimum, self.maximum])): - raise ValueError( - "Unable to use the generic CDF calculation for priors with" - "infinite support") + raise ValueError("Unable to use the generic CDF calculation for priors withinfinite support") x = np.linspace(self.minimum, self.maximum, 1000) pdf = self.prob(x) cdf = cumulative_trapezoid(pdf, x, initial=0) - interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False, - fill_value=(0, 1)) + interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False, fill_value=(0, 1)) output = interp(val) if isinstance(val, (int, float)): output = float(output) @@ -195,7 +199,7 @@ def ln_prob(self, val): np.nan """ - with np.errstate(divide='ignore'): + with np.errstate(divide="ignore"): return np.log(self.prob(val)) def is_in_prior_range(self, val): @@ -226,7 +230,7 @@ def __repr__(self): prior_name = self.__class__.__name__ prior_module = self.__class__.__module__ instantiation_dict = self.get_instantiation_dict() - args = ', '.join([f'{key}={repr(instantiation_dict[key])}' for key in instantiation_dict]) + args = ", ".join([f"{key}={repr(instantiation_dict[key])}" for key in instantiation_dict]) if "bilby.core.prior" in prior_module: return f"{prior_name}({args})" else: @@ -276,9 +280,9 @@ def unit(self, unit): @property def latex_label_with_unit(self): - """ If a unit is specified, returns a string of the latex label and unit """ + """If a unit is specified, returns a string of the latex label and unit""" if self.unit is not None: - return "{} [{}]".format(self.latex_label, self.unit) + return f"{self.latex_label} [{self.unit}]" else: return self.latex_label @@ -313,8 +317,8 @@ def boundary(self): @boundary.setter def boundary(self, boundary): - if boundary not in ['periodic', 'reflective', None]: - raise ValueError('{} is not a valid setting for prior boundaries'.format(boundary)) + if boundary not in ["periodic", "reflective", None]: + raise ValueError(f"{boundary} is not a valid setting for prior boundaries") self._boundary = boundary @property @@ -341,19 +345,18 @@ def from_repr(cls, string): def _from_repr(cls, string): subclass_args = infer_args_from_method(cls.__init__) - string = string.replace(' ', '') + string = string.replace(" ", "") kwargs = cls._split_repr(string) for key in kwargs: val = kwargs[key] if key not in subclass_args and not hasattr(cls, "reference_params"): - raise AttributeError('Unknown argument {} for class {}'.format( - key, cls.__name__)) + raise AttributeError(f"Unknown argument {key} for class {cls.__name__}") else: kwargs[key] = cls._parse_argument_string(val) if key in ["condition_func", "conversion_function"] and isinstance(kwargs[key], str): if "." in kwargs[key]: - module = '.'.join(kwargs[key].split('.')[:-1]) - name = kwargs[key].split('.')[-1] + module = ".".join(kwargs[key].split(".")[:-1]) + name = kwargs[key].split(".")[-1] else: module = __name__ name = kwargs[key] @@ -363,30 +366,29 @@ def _from_repr(cls, string): @classmethod def _split_repr(cls, string): subclass_args = infer_args_from_method(cls.__init__) - args = string.split(',') + args = string.split(",") remove = list() for ii, key in enumerate(args): - for paren_pair in ['()', '{}', '[]']: + for paren_pair in ["()", "{}", "[]"]: if paren_pair[0] in key: jj = ii while paren_pair[1] not in args[jj]: jj += 1 - args[ii] = ','.join([args[ii], args[jj]]).strip() + args[ii] = ",".join([args[ii], args[jj]]).strip() remove.append(jj) remove.reverse() for ii in remove: del args[ii] kwargs = dict() for ii, arg in enumerate(args): - if '=' not in arg: - logger.debug( - 'Reading priors with non-keyword arguments is dangerous!') + if "=" not in arg: + logger.debug("Reading priors with non-keyword arguments is dangerous!") key = subclass_args[ii] val = arg else: - split_arg = arg.split('=') + split_arg = arg.split("=") key = split_arg[0] - val = '='.join(split_arg[1:]) + val = "=".join(split_arg[1:]) kwargs[key] = val return kwargs @@ -431,20 +433,20 @@ def _parse_argument_string(cls, val): TypeError: If val cannot be parsed as described above. """ - if val == 'None': + if val == "None": val = None - elif re.sub(r'\'.*\'', '', val) in ['r', 'u']: + elif re.sub(r"\'.*\'", "", val) in ["r", "u"]: val = val[2:-1] elif val.startswith("'") and val.endswith("'"): val = val.strip("'") - elif '(' in val and not val.startswith(("[", "{")): - other_cls = val.split('(')[0] - vals = '('.join(val.split('(')[1:])[:-1] + elif "(" in val and not val.startswith(("[", "{")): + other_cls = val.split("(")[0] + vals = "(".join(val.split("(")[1:])[:-1] if "." in other_cls: - module = '.'.join(other_cls.split('.')[:-1]) - other_cls = other_cls.split('.')[-1] + module = ".".join(other_cls.split(".")[:-1]) + other_cls = other_cls.split(".")[-1] else: - module = __name__.replace('.' + os.path.basename(__file__).replace('.py', ''), '') + module = __name__.replace("." + os.path.basename(__file__).replace(".py", ""), "") other_cls = getattr(import_module(module), other_cls) val = other_cls.from_repr(vals) else: @@ -452,25 +454,19 @@ def _parse_argument_string(cls, val): val = eval(val, dict(), dict(np=np, inf=np.inf, pi=np.pi)) except NameError: if "." in val: - module = '.'.join(val.split('.')[:-1]) - func = val.split('.')[-1] + module = ".".join(val.split(".")[:-1]) + func = val.split(".")[-1] new_val = getattr(import_module(module), func, val) if val == new_val: - raise TypeError( - "Cannot evaluate prior, " - f"failed to parse argument {val}" - ) + raise TypeError(f"Cannot evaluate prior, failed to parse argument {val}") else: val = new_val return val class Constraint(Prior): - - def __init__(self, minimum, maximum, name=None, latex_label=None, - unit=None): - super(Constraint, self).__init__(minimum=minimum, maximum=maximum, name=name, - latex_label=latex_label, unit=unit) + def __init__(self, minimum, maximum, name=None, latex_label=None, unit=None): + super().__init__(minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit) self._is_fixed = True def prob(self, val): @@ -478,4 +474,4 @@ def prob(self, val): class PriorException(Exception): - """ General base class for all prior exceptions """ + """General base class for all prior exceptions""" diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 7c2a739e2..0c251e56e 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -1,15 +1,32 @@ +from ..utils import infer_args_from_method, infer_parameters_from_function +from .analytical import ( + Beta, + Cauchy, + ChiSquared, + Cosine, + DeltaFunction, + Exponential, + FermiDirac, + Gamma, + Gaussian, + HalfGaussian, + Logistic, + LogNormal, + LogUniform, + PowerLaw, + Sine, + StudentT, + SymmetricLogUniform, + TruncatedGaussian, + Uniform, +) from .base import Prior, PriorException from .interpolated import Interped -from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ - SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ - LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac -from ..utils import infer_args_from_method, infer_parameters_from_function def conditional_prior_factory(prior_class): class ConditionalPrior(prior_class): - def __init__(self, condition_func, name=None, latex_label=None, unit=None, - boundary=None, **reference_params): + def __init__(self, condition_func, name=None, latex_label=None, unit=None, boundary=None, **reference_params): """ Parameters @@ -28,10 +45,7 @@ def __init__(self, condition_func, name=None, latex_label=None, unit=None, .. code-block:: python def condition_func(reference_params, y): - return dict( - minimum=reference_params['minimum'] + y, - maximum=reference_params['maximum'] + y - ) + return dict(minimum=reference_params["minimum"] + y, maximum=reference_params["maximum"] + y) name: str, optional See superclass @@ -46,18 +60,16 @@ def condition_func(reference_params, y): This differs on the `prior_class`, for example for the Gaussian prior this is `mu` and `sigma`. """ - if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__): - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, boundary=boundary, **reference_params) + if "boundary" in infer_args_from_method(super().__init__): + super().__init__(name=name, latex_label=latex_label, unit=unit, boundary=boundary, **reference_params) else: - super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, - unit=unit, **reference_params) + super().__init__(name=name, latex_label=latex_label, unit=unit, **reference_params) self._required_variables = None self.condition_func = condition_func self._reference_params = reference_params - self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__) - self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__) + self.__class__.__name__ = f"Conditional{prior_class.__name__}" + self.__class__.__qualname__ = f"Conditional{prior_class.__qualname__}" def sample(self, size=None, **required_variables): """Draw a sample from the prior @@ -93,7 +105,7 @@ def rescale(self, val, **required_variables): """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).rescale(val) + return super().rescale(val) def prob(self, val, **required_variables): """Return the prior probability of val. @@ -111,7 +123,7 @@ def prob(self, val, **required_variables): float: Prior probability of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).prob(val) + return super().prob(val) def ln_prob(self, val, **required_variables): """Return the natural log prior probability of val. @@ -129,7 +141,7 @@ def ln_prob(self, val, **required_variables): float: Natural log prior probability of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).ln_prob(val) + return super().ln_prob(val) def cdf(self, val, **required_variables): """Return the cdf of val. @@ -147,7 +159,7 @@ def cdf(self, val, **required_variables): float: CDF of val """ self.update_conditions(**required_variables) - return super(ConditionalPrior, self).cdf(val) + return super().cdf(val) def update_conditions(self, **required_variables): """ @@ -171,9 +183,10 @@ class depending on the required variables it depends on. elif len(required_variables) == 0: return else: - raise IllegalRequiredVariablesException("Expected kwargs for {}. Got kwargs for {} instead." - .format(self.required_variables, - list(required_variables.keys()))) + raise IllegalRequiredVariablesException( + f"Expected kwargs for {self.required_variables}. " + f"Got kwargs for {list(required_variables.keys())} instead." + ) @property def reference_params(self): @@ -198,11 +211,11 @@ def condition_func(self, condition_func): @property def required_variables(self): - """ The required variables to pass into the condition function. """ + """The required variables to pass into the condition function.""" return self._required_variables def get_instantiation_dict(self): - instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict() + instantiation_dict = super().get_instantiation_dict() for key, value in self.reference_params.items(): instantiation_dict[key] = value return instantiation_dict @@ -227,13 +240,11 @@ def __repr__(self): """ prior_name = self.__class__.__name__ instantiation_dict = self.get_instantiation_dict() - instantiation_dict["condition_func"] = ".".join([ - instantiation_dict["condition_func"].__module__, - instantiation_dict["condition_func"].__name__ - ]) - args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) - for key in instantiation_dict]) - return "{}({})".format(prior_name, args) + instantiation_dict["condition_func"] = ".".join( + [instantiation_dict["condition_func"].__module__, instantiation_dict["condition_func"].__name__] + ) + args = ", ".join([f"{key}={repr(instantiation_dict[key])}" for key in instantiation_dict]) + return f"{prior_name}({args})" return ConditionalPrior @@ -360,24 +371,23 @@ class DirichletElement(ConditionalBeta): def __init__(self, order, n_dimensions, label): """ """ - super(DirichletElement, self).__init__( - minimum=0, maximum=1, alpha=1, beta=n_dimensions - order - 1, + super().__init__( + minimum=0, + maximum=1, + alpha=1, + beta=n_dimensions - order - 1, name=label + str(order), - condition_func=self.dirichlet_condition + condition_func=self.dirichlet_condition, ) self.label = label self.n_dimensions = n_dimensions self.order = order - self._required_variables = [ - label + str(ii) for ii in range(order) - ] - self.__class__.__name__ = 'DirichletElement' - self.__class__.__qualname__ = 'DirichletElement' + self._required_variables = [label + str(ii) for ii in range(order)] + self.__class__.__name__ = "DirichletElement" + self.__class__.__qualname__ = "DirichletElement" def dirichlet_condition(self, reference_parms, **kwargs): - remaining = 1 - sum( - [kwargs[self.label + str(ii)] for ii in range(self.order)] - ) + remaining = 1 - sum([kwargs[self.label + str(ii)] for ii in range(self.order)]) return dict(minimum=reference_parms["minimum"], maximum=remaining) def __repr__(self): @@ -388,8 +398,8 @@ def get_instantiation_dict(self): class ConditionalPriorException(PriorException): - """ General base class for all conditional prior exceptions """ + """General base class for all conditional prior exceptions""" class IllegalRequiredVariablesException(ConditionalPriorException): - """ Exception class for exceptions relating to handling the required variables. """ + """Exception class for exceptions relating to handling the required variables.""" diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 6d244610b..6d37ba212 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -2,20 +2,19 @@ import os import re from importlib import import_module -from io import open as ioopen from warnings import warn import numpy as np -from .analytical import DeltaFunction -from .base import Prior, Constraint -from .joint import JointPrior, BaseJointPriorDist from ..utils import ( - logger, - check_directory_exists_and_if_not_mkdir, BilbyJsonEncoder, + check_directory_exists_and_if_not_mkdir, decode_bilby_json, + logger, ) +from .analytical import DeltaFunction +from .base import Constraint, Prior +from .joint import BaseJointPriorDist, JointPrior class PriorDict(dict): @@ -32,14 +31,11 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): Function to convert between sampled parameters and constraints. Default is no conversion. """ - super(PriorDict, self).__init__() + super().__init__() if isinstance(dictionary, dict): self.from_dictionary(dictionary) elif type(dictionary) is str: - logger.debug( - 'Argument "dictionary" is a string.' - + " Assuming it is intended as a file name." - ) + logger.debug('Argument "dictionary" is a string.' + " Assuming it is intended as a file name.") self.from_file(dictionary) elif type(filename) is str: self.from_file(filename) @@ -90,25 +86,21 @@ def to_file(self, outdir, label): """ check_directory_exists_and_if_not_mkdir(outdir) - prior_file = os.path.join(outdir, "{}.prior".format(label)) - logger.debug("Writing priors to {}".format(prior_file)) + prior_file = os.path.join(outdir, f"{label}.prior") + logger.debug(f"Writing priors to {prior_file}") joint_dists = [] with open(prior_file, "w") as outfile: for key in self.keys(): if JointPrior in self[key].__class__.__mro__: - distname = "_".join(self[key].dist.names) + "_{}".format( - self[key].dist.distname - ) + distname = "_".join(self[key].dist.names) + f"_{self[key].dist.distname}" if distname not in joint_dists: joint_dists.append(distname) - outfile.write("{} = {}\n".format(distname, self[key].dist)) + outfile.write(f"{distname} = {self[key].dist}\n") diststr = repr(self[key].dist) priorstr = repr(self[key]) - outfile.write( - "{} = {}\n".format(key, priorstr.replace(diststr, distname)) - ) + outfile.write(f"{key} = {priorstr.replace(diststr, distname)}\n") else: - outfile.write("{} = {}\n".format(key, self[key])) + outfile.write(f"{key} = {self[key]}\n") def _get_json_dict(self): self.convert_floats_to_delta_functions() @@ -120,8 +112,8 @@ def _get_json_dict(self): def to_json(self, outdir, label): check_directory_exists_and_if_not_mkdir(outdir) - prior_file = os.path.join(outdir, "{}_prior.json".format(label)) - logger.debug("Writing priors to {}".format(prior_file)) + prior_file = os.path.join(outdir, f"{label}_prior.json") + logger.debug(f"Writing priors to {prior_file}") with open(prior_file, "w") as outfile: json.dump(self._get_json_dict(), outfile, cls=BilbyJsonEncoder, indent=2) @@ -147,7 +139,7 @@ def from_file(self, filename): comments = ["#", "\n"] prior = dict() - with ioopen(filename, "r", encoding="unicode_escape") as f: + with open(filename, encoding="unicode_escape") as f: for line in f: if line[0] in comments: continue @@ -161,15 +153,9 @@ def from_file(self, filename): @classmethod def _get_from_json_dict(cls, prior_dict): try: - class_ = getattr( - import_module(prior_dict["__module__"]), prior_dict["__name__"] - ) + class_ = getattr(import_module(prior_dict["__module__"]), prior_dict["__name__"]) except ImportError: - logger.debug( - "Cannot import prior module {}.{}".format( - prior_dict["__module__"], prior_dict["__name__"] - ) - ) + logger.debug("Cannot import prior module {}.{}".format(prior_dict["__module__"], prior_dict["__name__"])) class_ = cls except KeyError: logger.debug("Cannot find module name to load") @@ -189,7 +175,7 @@ def from_json(cls, filename): filename: str Name of the file to be read in """ - with open(filename, "r") as ff: + with open(filename) as ff: obj = json.load(ff, object_hook=decode_bilby_json) # make sure priors containing JointDists are properly handled and point @@ -219,7 +205,7 @@ def from_dictionary(self, dictionary): args = "(".join(val.split("(")[1:])[:-1] try: dictionary[key] = DeltaFunction(peak=float(cls)) - logger.debug("{} converted to DeltaFunction prior".format(key)) + logger.debug(f"{key} converted to DeltaFunction prior") continue except ValueError: pass @@ -227,23 +213,17 @@ def from_dictionary(self, dictionary): module = ".".join(cls.split(".")[:-1]) cls = cls.split(".")[-1] else: - module = __name__.replace( - "." + os.path.basename(__file__).replace(".py", ""), "" - ) + module = __name__.replace("." + os.path.basename(__file__).replace(".py", ""), "") try: cls = getattr(import_module(module), cls, cls) except ModuleNotFoundError: - logger.error( - "Cannot import prior class {} for entry: {}={}".format( - cls, key, val - ) - ) + logger.error(f"Cannot import prior class {cls} for entry: {key}={val}") raise if key.lower() in ["conversion_function", "condition_func"]: setattr(self, key, cls) elif isinstance(cls, str): if "(" in val: - raise TypeError("Unable to parse prior class {}".format(cls)) + raise TypeError(f"Unable to parse prior class {cls}") else: continue elif issubclass(cls, BaseJointPriorDist): @@ -254,16 +234,12 @@ def from_dictionary(self, dictionary): jpkwargs = { item[0].strip(): cls._parse_argument_string(item[1]) for item in cls._split_repr( - ", ".join( - [arg for arg in args.split(",") if "dist=" not in arg] - ) + ", ".join([arg for arg in args.split(",") if "dist=" not in arg]) ).items() } keymatch = re.match(r"dist=(?P\S+),", args) if keymatch is None: - raise ValueError( - "'dist' argument for JointPrior is not specified" - ) + raise ValueError("'dist' argument for JointPrior is not specified") if keymatch["distkey"] not in jpdkwargs: raise ValueError( @@ -276,10 +252,7 @@ def from_dictionary(self, dictionary): try: dictionary[key] = cls.from_repr(args) except TypeError as e: - raise TypeError( - "Unable to parse prior, bad entry: {} " - "= {}. Error message {}".format(key, val, e) - ) + raise TypeError(f"Unable to parse prior, bad entry: {key} = {val}. Error message {e}") elif isinstance(val, dict): try: _class = getattr( @@ -293,16 +266,10 @@ def from_dictionary(self, dictionary): val.get("__module__", "none"), val.get("__name__", "none") ) ) - logger.warning( - "Cannot convert {} into a prior object. " - "Leaving as dictionary.".format(key) - ) + logger.warning(f"Cannot convert {key} into a prior object. Leaving as dictionary.") continue else: - raise TypeError( - "Unable to parse prior, bad entry: {} " - "= {} of type {}".format(key, val, type(val)) - ) + raise TypeError(f"Unable to parse prior, bad entry: {key} = {val} of type {type(val)}") self.update(dictionary) def convert_floats_to_delta_functions(self): @@ -312,11 +279,9 @@ def convert_floats_to_delta_functions(self): continue elif isinstance(self[key], float) or isinstance(self[key], int): self[key] = DeltaFunction(self[key]) - logger.debug("{} converted to delta function prior.".format(key)) + logger.debug(f"{key} converted to delta function prior.") else: - logger.debug( - "{} cannot be converted to delta function prior.".format(key) - ) + logger.debug(f"{key} cannot be converted to delta function prior.") def fill_priors(self, likelihood=None, default_priors_file=None): """ @@ -402,7 +367,7 @@ def sample_subset(self, keys=iter([]), size=None): elif isinstance(self[key], Prior): samples[key] = self[key].sample(size=size) else: - logger.debug("{} not a known prior.".format(key)) + logger.debug(f"{key} not a known prior.") return samples @property @@ -415,9 +380,7 @@ def non_fixed_keys(self): @property def fixed_keys(self): - return [ - k for k, p in self.items() if (p.is_fixed and k not in self.constraint_keys) - ] + return [k for k, p in self.items() if (p.is_fixed and k not in self.constraint_keys)] @property def constraint_keys(self): @@ -473,29 +436,20 @@ def check_efficiency(n_tested, n_valid): samples = self.sample_subset(keys=keys, size=needed) keep = np.array(self.evaluate_constraints(samples), dtype=bool) for key in keys: - all_samples[key] = np.hstack( - [all_samples[key], samples[key][keep].flatten()] - ) + all_samples[key] = np.hstack([all_samples[key], samples[key][keep].flatten()]) n_tested_samples += needed n_valid_samples += np.sum(keep) check_efficiency(n_tested_samples, n_valid_samples) - all_samples = { - key: np.reshape(all_samples[key][:needed], size) for key in keys - } + all_samples = {key: np.reshape(all_samples[key][:needed], size) for key in keys} return all_samples - def normalize_constraint_factor( - self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10 - ): + def normalize_constraint_factor(self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10): if len(self.constraint_keys) == 0: return 1 elif keys in self._cached_normalizations.keys(): return self._cached_normalizations[keys] else: - factor_estimates = [ - self._estimate_normalization(keys, min_accept, sampling_chunk) - for _ in range(nrepeats) - ] + factor_estimates = [self._estimate_normalization(keys, min_accept, sampling_chunk) for _ in range(nrepeats)] factor = np.mean(factor_estimates) if np.std(factor_estimates) > 0: decimals = int(-np.floor(np.log10(3 * np.std(factor_estimates)))) @@ -577,8 +531,7 @@ def ln_prob(self, sample, axis=None, normalized=True): """ ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) - return self.check_ln_prob(sample, ln_prob, - normalized=normalized) + return self.check_ln_prob(sample, ln_prob, normalized=normalized) def check_ln_prob(self, sample, ln_prob, normalized=True): if normalized: @@ -614,9 +567,7 @@ def cdf(self, sample): dict, pandas.DataFrame: Dictionary containing the CDF values """ - return sample.__class__( - {key: self[key].cdf(sample) for key, sample in sample.items()} - ) + return sample.__class__({key: self[key].cdf(sample) for key, sample in sample.items()}) def rescale(self, keys, theta): """Rescale samples from unit cube to prior @@ -632,9 +583,7 @@ def rescale(self, keys, theta): ======= list: List of floats containing the rescaled sample """ - return list( - [self[key].rescale(sample) for key, sample in zip(keys, theta)] - ) + return list([self[key].rescale(sample) for key, sample in zip(keys, theta)]) def test_redundancy(self, key, disable_logging=False): """Empty redundancy test, should be overwritten in subclasses""" @@ -655,9 +604,7 @@ def test_has_redundant_keys(self): temp = self.copy() del temp[key] if temp.test_redundancy(key, disable_logging=True): - logger.warning( - f"{key} is a redundant key in this {self.__class__.__name__}." - ) + logger.warning(f"{key} is a redundant key in this {self.__class__.__name__}.") redundant = True return redundant @@ -689,7 +636,7 @@ def __init__(self, dictionary=None, filename=None, conversion_function=None): self._rescale_keys = [] self._rescale_indexes = [] self._least_recently_rescaled_keys = [] - super(ConditionalPriorDict, self).__init__( + super().__init__( dictionary=dictionary, filename=filename, conversion_function=conversion_function, @@ -709,12 +656,8 @@ def _resolve_conditions(self): 4. We set the `self._resolved` flag to True if all conditional priors were added in the right order """ - self._unconditional_keys = [ - key for key in self.keys() if not hasattr(self[key], "condition_func") - ] - conditional_keys_unsorted = [ - key for key in self.keys() if hasattr(self[key], "condition_func") - ] + self._unconditional_keys = [key for key in self.keys() if not hasattr(self[key], "condition_func")] + conditional_keys_unsorted = [key for key in self.keys() if hasattr(self[key], "condition_func")] self._conditional_keys = [] for _ in range(len(self)): for key in conditional_keys_unsorted[:]: @@ -736,38 +679,28 @@ def _check_conditions_resolved(self, key, sampled_keys): def sample_subset(self, keys=iter([]), size=None): self.convert_floats_to_delta_functions() - add_delta_keys = [ - key - for key in self.keys() - if key not in keys and isinstance(self[key], DeltaFunction) - ] + add_delta_keys = [key for key in self.keys() if key not in keys and isinstance(self[key], DeltaFunction)] use_keys = add_delta_keys + list(keys) subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) if not subset_dict._resolved: - raise IllegalConditionsException( - "The current set of priors contains unresolvable conditions." - ) + raise IllegalConditionsException("The current set of priors contains unresolvable conditions.") samples = dict() for key in subset_dict.sorted_keys: if key not in keys or isinstance(self[key], Constraint): continue if isinstance(self[key], Prior): try: - samples[key] = subset_dict[key].sample( - size=size, **subset_dict.get_required_variables(key) - ) + samples[key] = subset_dict[key].sample(size=size, **subset_dict.get_required_variables(key)) except ValueError: # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) # If that is the case, we sample each sample individually. required_variables = subset_dict.get_required_variables(key) samples[key] = np.zeros(size) for i in range(size): - rvars = { - key: value[i] for key, value in required_variables.items() - } + rvars = {key: value[i] for key, value in required_variables.items()} samples[key][i] = subset_dict[key].sample(**rvars) else: - logger.debug("{} not a known prior.".format(key)) + logger.debug(f"{key} not a known prior.") return samples def get_required_variables(self, key): @@ -782,10 +715,7 @@ def get_required_variables(self, key): ======= dict: key/value pairs of the required variables """ - return { - k: self[k].least_recently_sampled - for k in getattr(self[key], "required_variables", []) - } + return {k: self[k].least_recently_sampled for k in getattr(self[key], "required_variables", [])} def prob(self, sample, **kwargs): """ @@ -803,10 +733,7 @@ def prob(self, sample, **kwargs): """ self._prepare_evaluation(*zip(*sample.items())) - res = [ - self[key].prob(sample[key], **self.get_required_variables(key)) - for key in sample - ] + res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample] prob = np.prod(res, **kwargs) return self.check_prob(sample, prob) @@ -829,20 +756,13 @@ def ln_prob(self, sample, axis=None, normalized=True): """ self._prepare_evaluation(*zip(*sample.items())) - res = [ - self[key].ln_prob(sample[key], **self.get_required_variables(key)) - for key in sample - ] + res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample] ln_prob = np.sum(res, axis=axis) - return self.check_ln_prob(sample, ln_prob, - normalized=normalized) + return self.check_ln_prob(sample, ln_prob, normalized=normalized) def cdf(self, sample): self._prepare_evaluation(*zip(*sample.items())) - res = { - key: self[key].cdf(sample[key], **self.get_required_variables(key)) - for key in sample - } + res = {key: self[key].cdf(sample[key], **self.get_required_variables(key)) for key in sample} return sample.__class__(res) def rescale(self, keys, theta): @@ -865,12 +785,8 @@ def rescale(self, keys, theta): self._update_rescale_keys(keys) result = dict() joint = dict() - for key, index in zip( - self.sorted_keys_without_fixed_parameters, self._rescale_indexes - ): - result[key] = self[key].rescale( - theta[index], **self.get_required_variables(key) - ) + for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): + result[key] = self[key].rescale(theta[index], **self.get_required_variables(key)) self[key].least_recently_sampled = result[key] if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint: joint[self[key].dist.distname] = [key] @@ -904,10 +820,7 @@ def safe_flatten(value): def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: - self._rescale_indexes = [ - keys.index(element) - for element in self.sorted_keys_without_fixed_parameters - ] + self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters] self._least_recently_rescaled_keys = keys def _prepare_evaluation(self, keys, theta): @@ -917,9 +830,7 @@ def _prepare_evaluation(self, keys, theta): def _check_resolved(self): if not self._resolved: - raise IllegalConditionsException( - "The current set of priors contains unresolveable conditions." - ) + raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") @property def conditional_keys(self): @@ -935,18 +846,14 @@ def sorted_keys(self): @property def sorted_keys_without_fixed_parameters(self): - return [ - key - for key in self.sorted_keys - if not isinstance(self[key], (DeltaFunction, Constraint)) - ] + return [key for key in self.sorted_keys if not isinstance(self[key], (DeltaFunction, Constraint))] def __setitem__(self, key, value): - super(ConditionalPriorDict, self).__setitem__(key, value) + super().__setitem__(key, value) self._resolve_conditions() def __delitem__(self, key): - super(ConditionalPriorDict, self).__delitem__(key) + super().__delitem__(key) self._resolve_conditions() @@ -956,11 +863,9 @@ def __init__(self, n_dim=None, label="dirichlet_"): self.n_dim = n_dim self.label = label - super(DirichletPriorDict, self).__init__(dictionary=dict()) + super().__init__(dictionary=dict()) for ii in range(n_dim - 1): - self[label + "{}".format(ii)] = DirichletElement( - order=ii, n_dimensions=n_dim, label=label - ) + self[label + f"{ii}"] = DirichletElement(order=ii, n_dimensions=n_dim, label=label) def copy(self, **kwargs): return self.__class__(n_dim=self.n_dim, label=self.label) @@ -977,15 +882,9 @@ def _get_json_dict(self): @classmethod def _get_from_json_dict(cls, prior_dict): try: - cls == getattr( - import_module(prior_dict["__module__"]), prior_dict["__name__"] - ) + cls == getattr(import_module(prior_dict["__module__"]), prior_dict["__name__"]) except ImportError: - logger.debug( - "Cannot import prior module {}.{}".format( - prior_dict["__module__"], prior_dict["__name__"] - ) - ) + logger.debug("Cannot import prior module {}.{}".format(prior_dict["__module__"], prior_dict["__name__"])) except KeyError: logger.debug("Cannot find module name to load") for key in ["__module__", "__name__", "__prior_dict__"]: @@ -1024,7 +923,7 @@ def create_default_prior(name, default_priors_file=None): if name in default_priors.keys(): prior = default_priors[name] else: - logger.debug("No default prior found for variable {}.".format(name)) + logger.debug(f"No default prior found for variable {name}.") prior = None return prior diff --git a/bilby/core/prior/interpolated.py b/bilby/core/prior/interpolated.py index 5fbf8f9c1..d6b3f2640 100644 --- a/bilby/core/prior/interpolated.py +++ b/bilby/core/prior/interpolated.py @@ -1,14 +1,13 @@ import numpy as np from scipy.integrate import trapezoid +from ..utils import WrappedInterp1d as interp1d +from ..utils import logger from .base import Prior -from ..utils import logger, WrappedInterp1d as interp1d class Interped(Prior): - - def __init__(self, xx, yy, minimum=np.nan, maximum=np.nan, name=None, - latex_label=None, unit=None, boundary=None): + def __init__(self, xx, yy, minimum=np.nan, maximum=np.nan, name=None, latex_label=None, unit=None, boundary=None): """Creates an interpolated prior function from arrays of xx and yy=p(xx) Parameters @@ -53,8 +52,9 @@ def __init__(self, xx, yy, minimum=np.nan, maximum=np.nan, name=None, self.__all_interpolated = interp1d(x=xx, y=yy, bounds_error=False, fill_value=0) minimum = float(np.nanmax(np.array((min(xx), minimum)))) maximum = float(np.nanmin(np.array((max(xx), maximum)))) - super(Interped, self).__init__(name=name, latex_label=latex_label, unit=unit, - minimum=minimum, maximum=maximum, boundary=boundary) + super().__init__( + name=name, latex_label=latex_label, unit=unit, minimum=minimum, maximum=maximum, boundary=boundary + ) self._update_instance() def __eq__(self, other): @@ -106,9 +106,9 @@ def minimum(self): @minimum.setter def minimum(self, minimum): if minimum < self.min_limit: - raise ValueError('Minimum cannot be set below {}.'.format(round(self.min_limit, 2))) + raise ValueError(f"Minimum cannot be set below {round(self.min_limit, 2)}.") self._minimum = minimum - if '_maximum' in self.__dict__ and self._maximum < np.inf: + if "_maximum" in self.__dict__ and self._maximum < np.inf: self._update_instance() @property @@ -129,9 +129,9 @@ def maximum(self): @maximum.setter def maximum(self, maximum): if maximum > self.max_limit: - raise ValueError('Maximum cannot be set above {}.'.format(round(self.max_limit, 2))) + raise ValueError(f"Maximum cannot be set above {round(self.max_limit, 2)}.") self._maximum = maximum - if '_minimum' in self.__dict__ and self._minimum < np.inf: + if "_minimum" in self.__dict__ and self._minimum < np.inf: self._update_instance() @property @@ -160,8 +160,9 @@ def _update_instance(self): def _initialize_attributes(self): from scipy.integrate import cumulative_trapezoid + if trapezoid(self._yy, self.xx) != 1: - logger.debug('Supplied PDF for {} is not normalised, normalising.'.format(self.name)) + logger.debug(f"Supplied PDF for {self.name} is not normalised, normalising.") self._yy /= trapezoid(self._yy, self.xx) self.YY = cumulative_trapezoid(self._yy, self.xx, initial=0) # Need last element of cumulative distribution to be exactly one. @@ -172,9 +173,7 @@ def _initialize_attributes(self): class FromFile(Interped): - - def __init__(self, file_name, minimum=None, maximum=None, name=None, - latex_label=None, unit=None, boundary=None): + def __init__(self, file_name, minimum=None, maximum=None, name=None, latex_label=None, unit=None, boundary=None): """Creates an interpolated prior function from arrays of xx and yy=p(xx) extracted from a file Parameters @@ -198,10 +197,17 @@ def __init__(self, file_name, minimum=None, maximum=None, name=None, try: self.file_name = file_name xx, yy = np.genfromtxt(self.file_name).T - super(FromFile, self).__init__(xx=xx, yy=yy, minimum=minimum, - maximum=maximum, name=name, latex_label=latex_label, - unit=unit, boundary=boundary) - except IOError: - logger.warning("Can't load {}.".format(self.file_name)) + super().__init__( + xx=xx, + yy=yy, + minimum=minimum, + maximum=maximum, + name=name, + latex_label=latex_label, + unit=unit, + boundary=boundary, + ) + except OSError: + logger.warning(f"Can't load {self.file_name}.") logger.warning("Format should be:") logger.warning(r"x\tp(x)") diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 43c8913e3..1ff8e1ad9 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -4,12 +4,11 @@ import scipy.stats from scipy.special import erfinv +from ..utils import get_dict_with_properties, infer_args_from_method, logger, random from .base import Prior, PriorException -from ..utils import logger, infer_args_from_method, get_dict_with_properties -from ..utils import random -class BaseJointPriorDist(object): +class BaseJointPriorDist: def __init__(self, names, bounds=None): """ A class defining JointPriorDist that will be overwritten with child @@ -43,9 +42,7 @@ def __init__(self, names, bounds=None): for bound in bounds: if isinstance(bounds, (list, tuple, np.ndarray)): if len(bound) != 2: - raise ValueError( - "Bounds must contain an upper and lower value." - ) + raise ValueError("Bounds must contain an upper and lower value.") else: if bound[1] <= bound[0]: raise ValueError("Bounds are not properly set") @@ -132,13 +129,8 @@ def __repr__(self): """ dist_name = self.__class__.__name__ instantiation_dict = self.get_instantiation_dict() - args = ", ".join( - [ - "{}={}".format(key, repr(instantiation_dict[key])) - for key in instantiation_dict - ] - ) - return "{}({})".format(dist_name, args) + args = ", ".join([f"{key}={repr(instantiation_dict[key])}" for key in instantiation_dict]) + return f"{dist_name}({args})" @classmethod def from_repr(cls, string): @@ -154,9 +146,7 @@ def _from_repr(cls, string): for key in kwargs: val = kwargs[key] if key not in subclass_args: - raise AttributeError( - "Unknown argument {} for class {}".format(key, cls.__name__) - ) + raise AttributeError(f"Unknown argument {key} for class {cls.__name__}") else: kwargs[key.strip()] = Prior._parse_argument_string(val) @@ -405,7 +395,7 @@ def __init__( A list of bounds on each parameter. The defaults are for bounds at +/- infinity. """ - super(MultivariateGaussianDist, self).__init__(names=names, bounds=bounds) + super().__init__(names=names, bounds=bounds) for name in self.names: bound = self.bounds[name] if bound[0] != -np.inf or bound[1] != np.inf: @@ -420,7 +410,7 @@ def __init__( self.covs = [] self.corrcoefs = [] self.sigmas = [] - self.logprodsigmas = [] # log of product of sigmas, needed for "standard" multivariate normal + self.logprodsigmas = [] # log of product of sigmas, needed for "standard" multivariate normal self.weights = [] self.eigvalues = [] self.eigvectors = [] @@ -456,9 +446,7 @@ def __init__( if len(np.shape(corrcoefs)) == 2: corrcoefs = [np.array(corrcoefs)] elif len(np.shape(corrcoefs)) != 3: - raise TypeError( - "List of correlation coefficients the wrong shape" - ) + raise TypeError("List of correlation coefficients the wrong shape") elif not isinstance(corrcoefs, list): raise TypeError("Must pass a list of correlation coefficients") if weights is not None: @@ -507,10 +495,7 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): if len(self.covs[-1].shape) != 2: raise ValueError("Covariance matrix must be a 2d array") - if ( - self.covs[-1].shape[0] != self.covs[-1].shape[1] - or self.covs[-1].shape[0] != self.num_vars - ): + if self.covs[-1].shape[0] != self.covs[-1].shape[1] or self.covs[-1].shape[0] != self.num_vars: raise ValueError("Covariance shape is inconsistent") # check matrix is symmetric @@ -527,17 +512,13 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): self.corrcoefs.append(np.asarray(corrcoef)) if len(self.corrcoefs[-1].shape) != 2: - raise ValueError( - "Correlation coefficient matrix must be a 2d array." - ) + raise ValueError("Correlation coefficient matrix must be a 2d array.") if ( self.corrcoefs[-1].shape[0] != self.corrcoefs[-1].shape[1] or self.corrcoefs[-1].shape[0] != self.num_vars ): - raise ValueError( - "Correlation coefficient matrix shape is inconsistent" - ) + raise ValueError("Correlation coefficient matrix shape is inconsistent") # check matrix is symmetric if not np.allclose(self.corrcoefs[-1], self.corrcoefs[-1].T): @@ -553,10 +534,7 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): raise TypeError("'sigmas' must be a list") if len(self.sigmas[-1]) != self.num_vars: - raise ValueError( - "Number of standard deviations must be the " - "same as the number of parameters." - ) + raise ValueError("Number of standard deviations must be the same as the number of parameters.") # convert correlation coefficients to covariance matrix D = self.sigmas[-1] * np.identity(self.corrcoefs[-1].shape[0]) @@ -576,15 +554,11 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): self.eigvalues.append(evals) self.eigvectors.append(evecs) except Exception as e: - raise RuntimeError( - "Problem getting eigenvalues and vectors: {}".format(e) - ) + raise RuntimeError(f"Problem getting eigenvalues and vectors: {e}") # check eigenvalues are positive if np.any(self.eigvalues[-1] <= 0.0): - raise ValueError( - "Correlation coefficient matrix is not positive definite" - ) + raise ValueError("Correlation coefficient matrix is not positive definite") self.sqeigvalues.append(np.sqrt(self.eigvalues[-1])) # set the weights @@ -607,9 +581,7 @@ def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0): # - this modifies the multivariate normal PDF as follows: # multivariate_normal(mean=mus, cov=cov).logpdf(x) # = multivariate_normal(mean=0, cov=corrcoefs).logpdf((x - mus)/sigmas) - logprodsigmas - self.mvn.append( - scipy.stats.multivariate_normal(mean=np.zeros(self.num_vars), cov=self.corrcoefs[-1]) - ) + self.mvn.append(scipy.stats.multivariate_normal(mean=np.zeros(self.num_vars), cov=self.corrcoefs[-1])) def _rescale(self, samp, **kwargs): try: @@ -623,7 +595,7 @@ def _rescale(self, samp, **kwargs): else: mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] - samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5 + samp = erfinv(2.0 * samp - 1) * 2.0**0.5 # rotate and scale to the multivariate normal shape samp = self.mus[mode] + self.sigmas[mode] * np.einsum( @@ -645,10 +617,7 @@ def _sample(self, size, **kwargs): mode = np.argwhere(self.cumweights - random.rng.uniform(0, 1) > 0)[0][0] else: # pick modes - mode = [ - np.argwhere(self.cumweights - r > 0)[0][0] - for r in random.rng.uniform(0, 1, size) - ] + mode = [np.argwhere(self.cumweights - r > 0)[0][0] for r in random.rng.uniform(0, 1, size)] samps = np.zeros((size, len(self))) for i in range(size): @@ -697,9 +666,7 @@ def __eq__(self, other): if len(self.__dict__[key]) != len(other.__dict__[key]): return False for thismvn, othermvn in zip(self.__dict__[key], other.__dict__[key]): - if not isinstance( - thismvn, scipy.stats._multivariate.multivariate_normal_frozen - ) or not isinstance( + if not isinstance(thismvn, scipy.stats._multivariate.multivariate_normal_frozen) or not isinstance( othermvn, scipy.stats._multivariate.multivariate_normal_frozen ): return False @@ -742,17 +709,13 @@ def __init__(self, dist, name=None, latex_label=None, unit=None): See superclass """ if not isinstance(dist, BaseJointPriorDist): - raise TypeError( - "Must supply a JointPriorDist object instance to be shared by all joint params" - ) + raise TypeError("Must supply a JointPriorDist object instance to be shared by all joint params") if name not in dist.names: - raise ValueError( - "'{}' is not a parameter in the JointPriorDist".format(name) - ) + raise ValueError(f"'{name}' is not a parameter in the JointPriorDist") self.dist = dist - super(JointPrior, self).__init__( + super().__init__( name=name, latex_label=latex_label, unit=unit, @@ -822,9 +785,7 @@ def sample(self, size=1, **kwargs): if self.name in self.dist.sampled_parameters: logger.warning( - "You have already drawn a sample from parameter " - "'{}'. The same sample will be " - "returned".format(self.name) + f"You have already drawn a sample from parameter '{self.name}'. The same sample will be returned" ) if len(self.dist.current_sample) == 0: @@ -865,22 +826,12 @@ def ln_prob(self, val): # check for the same number of values for each parameter for i in range(len(self.dist) - 1): - if isinstance(values[i], (list, np.ndarray)) or isinstance( - values[i + 1], (list, np.ndarray) - ): - if isinstance(values[i], (list, np.ndarray)) and isinstance( - values[i + 1], (list, np.ndarray) - ): + if isinstance(values[i], (list, np.ndarray)) or isinstance(values[i + 1], (list, np.ndarray)): + if isinstance(values[i], (list, np.ndarray)) and isinstance(values[i + 1], (list, np.ndarray)): if len(values[i]) != len(values[i + 1]): - raise ValueError( - "Each parameter must have the same " - "number of requested values." - ) + raise ValueError("Each parameter must have the same number of requested values.") else: - raise ValueError( - "Each parameter must have the same " - "number of requested values." - ) + raise ValueError("Each parameter must have the same number of requested values.") lnp = self.dist.ln_prob(np.asarray(values).T) @@ -896,7 +847,7 @@ def ln_prob(self, val): # check value has a length len(val) except Exception as e: - raise TypeError("Invalid type for ln_prob: {}".format(e)) + raise TypeError(f"Invalid type for ln_prob: {e}") if len(val) == 1: return 0.0 @@ -923,12 +874,8 @@ def prob(self, val): class MultivariateGaussian(JointPrior): def __init__(self, dist, name=None, latex_label=None, unit=None): if not isinstance(dist, MultivariateGaussianDist): - raise JointPriorDistError( - "dist object must be instance of MultivariateGaussianDist" - ) - super(MultivariateGaussian, self).__init__( - dist=dist, name=name, latex_label=latex_label, unit=unit - ) + raise JointPriorDistError("dist object must be instance of MultivariateGaussianDist") + super().__init__(dist=dist, name=name, latex_label=latex_label, unit=unit) class MultivariateNormal(MultivariateGaussian): diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py index 6910be608..3dbb910e0 100644 --- a/bilby/core/prior/slabspike.py +++ b/bilby/core/prior/slabspike.py @@ -1,12 +1,12 @@ from numbers import Number + import numpy as np -from .base import Prior from ..utils import logger +from .base import Prior class SlabSpikePrior(Prior): - def __init__(self, slab, spike_location=None, spike_height=0): """'Slab-and-spike' prior, see e.g. https://arxiv.org/abs/1812.07259 This prior is composed of a `slab`, i.e. any common prior distribution, @@ -31,15 +31,21 @@ def __init__(self, slab, spike_location=None, spike_height=0): """ self.slab = slab - super().__init__(name=self.slab.name, latex_label=self.slab.latex_label, unit=self.slab.unit, - minimum=self.slab.minimum, maximum=self.slab.maximum, - check_range_nonzero=self.slab.check_range_nonzero, boundary=self.slab.boundary) + super().__init__( + name=self.slab.name, + latex_label=self.slab.latex_label, + unit=self.slab.unit, + minimum=self.slab.minimum, + maximum=self.slab.maximum, + check_range_nonzero=self.slab.check_range_nonzero, + boundary=self.slab.boundary, + ) self.spike_location = spike_location self.spike_height = spike_height try: self.inverse_cdf_below_spike = self._find_inverse_cdf_fraction_before_spike() except Exception as e: - logger.warning("Disregard the following warning when running tests:\n {}".format(e)) + logger.warning(f"Disregard the following warning when running tests:\n {e}") @property def spike_location(self): @@ -50,7 +56,7 @@ def spike_location(self, spike_loc): if spike_loc is None: spike_loc = self.minimum if not self.minimum <= spike_loc <= self.maximum: - raise ValueError("Spike location {} not within prior domain ".format(spike_loc)) + raise ValueError(f"Spike location {spike_loc} not within prior domain ") self._spike_loc = spike_loc @property @@ -62,11 +68,11 @@ def spike_height(self, spike_height): if 0 <= spike_height <= 1: self._spike_height = spike_height else: - raise ValueError("Spike height must be between 0 and 1, but is {}".format(spike_height)) + raise ValueError(f"Spike height must be between 0 and 1, but is {spike_height}") @property def slab_fraction(self): - """ Relative prior weight of the slab. """ + """Relative prior weight of the slab.""" return 1 - self.spike_height def _find_inverse_cdf_fraction_before_spike(self): @@ -90,8 +96,8 @@ def rescale(self, val): lower_indices = val < self.inverse_cdf_below_spike intermediate_indices = np.logical_and( - self.inverse_cdf_below_spike <= val, - val <= (self.inverse_cdf_below_spike + self.spike_height)) + self.inverse_cdf_below_spike <= val, val <= (self.inverse_cdf_below_spike + self.spike_height) + ) higher_indices = val > (self.inverse_cdf_below_spike + self.spike_height) res = np.zeros(len(val)) @@ -102,8 +108,10 @@ def rescale(self, val): try: res = res[0] except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + logger.warning( + "Based on inputs, a number should be output\ + but this could not be accessed from what was computed" + ) return res def _contracted_rescale(self, val): @@ -142,8 +150,10 @@ def prob(self, val): try: res = res[0] except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + logger.warning( + "Based on inputs, a number should be output\ + but this could not be accessed from what was computed" + ) return res def ln_prob(self, val): @@ -166,12 +176,14 @@ def ln_prob(self, val): try: res = res[0] except (KeyError, TypeError): - logger.warning("Based on inputs, a number should be output\ - but this could not be accessed from what was computed") + logger.warning( + "Based on inputs, a number should be output\ + but this could not be accessed from what was computed" + ) return res def cdf(self, val): - """ Return the CDF of the prior. + """Return the CDF of the prior. This calls to the slab CDF and adds a discrete step at the spike location. diff --git a/bilby/core/result.py b/bilby/core/result.py index c4d12ff2a..f1f61ba4c 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -1,41 +1,46 @@ import datetime import inspect import json +import multiprocessing import os from collections import namedtuple from copy import copy +from functools import partial from importlib import import_module from itertools import product -import multiprocessing -from functools import partial + import numpy as np import pandas as pd import scipy.stats from . import utils from .likelihood import _safe_likelihood_call +from .prior import ConditionalDeltaFunction, DeltaFunction, Prior, PriorDict from .utils import ( - logger, infer_parameters_from_function, + BilbyJsonEncoder, check_directory_exists_and_if_not_mkdir, - latex_plot_format, safe_save_figure, - BilbyJsonEncoder, load_json, - move_old_file, get_version_information, - decode_bilby_json, docstring, - recursively_save_dict_contents_to_group, - recursively_load_dict_contents_from_group, + decode_bilby_json, + docstring, + get_version_information, + infer_parameters_from_function, + latex_plot_format, + load_json, + logger, + move_old_file, + random, recursively_decode_bilby_json, + recursively_load_dict_contents_from_group, + recursively_save_dict_contents_to_group, safe_file_dump, - random, + safe_save_figure, string_to_boolean, ) -from .prior import Prior, PriorDict, DeltaFunction, ConditionalDeltaFunction - EXTENSIONS = ["json", "hdf5", "h5", "pickle", "pkl"] -def result_file_name(outdir, label, extension='json', gzip=False): - """ Returns the standard filename used for a result file +def result_file_name(outdir, label, extension="json", gzip=False): + """Returns the standard filename used for a result file Parameters ========== @@ -52,19 +57,19 @@ def result_file_name(outdir, label, extension='json', gzip=False): ======= str: File name of the output file """ - if extension == 'pickle': - extension = 'pkl' - if extension in ['json', 'hdf5', 'pkl']: - if extension == 'json' and gzip: - return os.path.join(outdir, '{}_result.{}.gz'.format(label, extension)) + if extension == "pickle": + extension = "pkl" + if extension in ["json", "hdf5", "pkl"]: + if extension == "json" and gzip: + return os.path.join(outdir, f"{label}_result.{extension}.gz") else: - return os.path.join(outdir, '{}_result.{}'.format(label, extension)) + return os.path.join(outdir, f"{label}_result.{extension}") else: - raise ValueError("Extension type {} not understood".format(extension)) + raise ValueError(f"Extension type {extension} not understood") def _determine_file_name(filename, outdir, label, extension, gzip): - """ Helper method to determine the filename """ + """Helper method to determine the filename""" if filename is not None: if isinstance(filename, os.PathLike): # convert PathLike object to string @@ -79,7 +84,7 @@ def _determine_file_name(filename, outdir, label, extension, gzip): def read_in_result(filename=None, outdir=None, label=None, extension=None, gzip=False, result_class=None): - """ Reads in a stored bilby result object + """Reads in a stored bilby result object Parameters ========== @@ -97,9 +102,7 @@ def read_in_result(filename=None, outdir=None, label=None, extension=None, gzip= but objects which inherit from this class can be given providing additional methods. """ - filename = _determine_file_name( - filename, outdir, label, extension or "json", gzip - ) + filename = _determine_file_name(filename, outdir, label, extension or "json", gzip) if result_class is None: result_class = Result @@ -110,24 +113,22 @@ def read_in_result(filename=None, outdir=None, label=None, extension=None, gzip= if extension is None: ext = os.path.splitext(filename)[1][1:] extension = ext if ext else None - if extension == 'gz': # gzipped file - extension = os.path.splitext(os.path.splitext(filename)[0])[1].lstrip('.') + if extension == "gz": # gzipped file + extension = os.path.splitext(os.path.splitext(filename)[0])[1].lstrip(".") if extension is None: raise ValueError("No filetype extension provided and could not be inferred from filename") extension = extension.lower() read_functions = { - 'json': result_class.from_json, - 'hdf5': result_class.from_hdf5, - 'h5': result_class.from_hdf5, - 'pkl': result_class.from_pickle, - 'pickle': result_class.from_pickle, + "json": result_class.from_json, + "hdf5": result_class.from_hdf5, + "h5": result_class.from_hdf5, + "pkl": result_class.from_pickle, + "pickle": result_class.from_pickle, } if extension not in read_functions: - raise ValueError( - f"Filetype {extension} not understood, known types are {list(read_functions.keys())}" - ) + raise ValueError(f"Filetype {extension} not understood, known types are {list(read_functions.keys())}") func = read_functions[extension] @@ -135,8 +136,8 @@ def read_in_result(filename=None, outdir=None, label=None, extension=None, gzip= # Catch all other exceptions and raise a FileLoadError try: result = func(filename=filename) - except IOError as e: - raise IOError( + except OSError as e: + raise OSError( f"Failed to read in file {filename} using " f"`{result_class.__name__}.{func.__name__}` " f"(extension={extension}). " @@ -153,7 +154,7 @@ def read_in_result(filename=None, outdir=None, label=None, extension=None, gzip= def read_in_result_list(filename_list, invalid="warning"): - """ Read in a set of results + """Read in a set of results Parameters ========== @@ -169,15 +170,9 @@ def read_in_result_list(filename_list, invalid="warning"): """ results_list = [] for filename in filename_list: - if ( - not os.path.exists(filename) - and os.path.exists(f"{os.path.splitext(filename)[0]}.pkl") - ): + if not os.path.exists(filename) and os.path.exists(f"{os.path.splitext(filename)[0]}.pkl"): pickle_path = f"{os.path.splitext(filename)[0]}.pkl" - logger.warning( - f"Result file {filename} doesn't exist but {pickle_path} does. " - f"Using {pickle_path}." - ) + logger.warning(f"Result file {filename} doesn't exist but {pickle_path} does. Using {pickle_path}.") filename = pickle_path try: results_list.append(read_in_result(filename=filename)) @@ -191,9 +186,16 @@ def read_in_result_list(filename_list, invalid="warning"): def get_weights_for_reweighting( - result, new_likelihood=None, new_prior=None, old_likelihood=None, - old_prior=None, resume_file=None, n_checkpoint=5000, npool=1): - """ Calculate the weights for reweight() + result, + new_likelihood=None, + new_prior=None, + old_likelihood=None, + old_prior=None, + resume_file=None, + n_checkpoint=5000, + npool=1, +): + """Calculate the weights for reweight() See bilby.core.result.reweight() for help with the inputs @@ -226,33 +228,35 @@ def get_weights_for_reweighting( starting_index = 0 if (resume_file is not None) and os.path.exists(resume_file): - old_log_likelihood_array, old_log_prior_array, new_log_likelihood_array, new_log_prior_array = \ - np.genfromtxt(resume_file) + old_log_likelihood_array, old_log_prior_array, new_log_likelihood_array, new_log_prior_array = np.genfromtxt( + resume_file + ) starting_index = np.argmin(np.abs(old_log_likelihood_array)) - logger.info(f'Checkpoint resuming from {starting_index}.') + logger.info(f"Checkpoint resuming from {starting_index}.") elif resume_file is not None: basedir = os.path.split(resume_file)[0] check_directory_exists_and_if_not_mkdir(basedir) - dict_samples = [{key: sample[key] for key in result.posterior} - for _, sample in result.posterior.iterrows()] + dict_samples = [{key: sample[key] for key in result.posterior} for _, sample in result.posterior.iterrows()] n = len(dict_samples) - starting_index # Helper function to compute likelihoods in parallel def eval_pool(this_logl): with multiprocessing.Pool(processes=npool) as pool: chunksize = max(100, n // (2 * npool)) - return list(tqdm( - pool.imap(partial(_safe_likelihood_call, this_logl), - dict_samples[starting_index:], chunksize=chunksize), - desc='Computing likelihoods', - total=n) + return list( + tqdm( + pool.imap( + partial(_safe_likelihood_call, this_logl), dict_samples[starting_index:], chunksize=chunksize + ), + desc="Computing likelihoods", + total=n, + ) ) if old_likelihood is None: - old_log_likelihood_array[starting_index:] = \ - result.posterior["log_likelihood"][starting_index:].to_numpy() + old_log_likelihood_array[starting_index:] = result.posterior["log_likelihood"][starting_index:].to_numpy() else: old_log_likelihood_array[starting_index:] = eval_pool(old_likelihood) @@ -263,10 +267,9 @@ def eval_pool(this_logl): new_log_likelihood_array[starting_index:] = eval_pool(new_likelihood) # Compute priors - for ii, sample in enumerate(tqdm(dict_samples[starting_index:], - desc='Computing priors', - total=n), - start=starting_index): + for ii, sample in enumerate( + tqdm(dict_samples[starting_index:], desc="Computing priors", total=n), start=starting_index + ): # prior calculation needs to not have prior or likelihood keys ln_prior = sample.pop("log_prior", np.nan) if "log_likelihood" in sample: @@ -285,19 +288,19 @@ def eval_pool(this_logl): if (ii % (n_checkpoint) == 0) and (resume_file is not None): checkpointed_index = np.argmin(np.abs(old_log_likelihood_array)) - logger.info(f'Checkpointing with {checkpointed_index} samples') + logger.info(f"Checkpointing with {checkpointed_index} samples") np.savetxt( resume_file, - [old_log_likelihood_array, old_log_prior_array, new_log_likelihood_array, new_log_prior_array]) + [old_log_likelihood_array, old_log_prior_array, new_log_likelihood_array, new_log_prior_array], + ) - ln_weights = ( - new_log_likelihood_array + new_log_prior_array - old_log_likelihood_array - old_log_prior_array) + ln_weights = new_log_likelihood_array + new_log_prior_array - old_log_likelihood_array - old_log_prior_array return ln_weights, new_log_likelihood_array, new_log_prior_array, old_log_likelihood_array, old_log_prior_array def rejection_sample(posterior, weights): - """ Perform rejection sampling on a posterior using weights + """Perform rejection sampling on a posterior using weights Parameters ========== @@ -316,11 +319,21 @@ def rejection_sample(posterior, weights): return posterior[keep] -def reweight(result, label=None, new_likelihood=None, new_prior=None, - old_likelihood=None, old_prior=None, conversion_function=None, npool=1, - verbose_output=False, resume_file=None, n_checkpoint=5000, - use_nested_samples=False): - """ Reweight a result to a new likelihood/prior using rejection sampling +def reweight( + result, + label=None, + new_likelihood=None, + new_prior=None, + old_likelihood=None, + old_prior=None, + conversion_function=None, + npool=1, + verbose_output=False, + resume_file=None, + n_checkpoint=5000, + use_nested_samples=False, +): + """Reweight a result to a new likelihood/prior using rejection sampling Parameters ========== @@ -376,13 +389,20 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None, result.posterior = result.nested_samples nposterior = len(result.posterior) - logger.info("Reweighting posterior with {} samples".format(nposterior)) + logger.info(f"Reweighting posterior with {nposterior} samples") - ln_weights, new_log_likelihood_array, new_log_prior_array, old_log_likelihood_array, old_log_prior_array =\ + ln_weights, new_log_likelihood_array, new_log_prior_array, old_log_likelihood_array, old_log_prior_array = ( get_weights_for_reweighting( - result, new_likelihood=new_likelihood, new_prior=new_prior, - old_likelihood=old_likelihood, old_prior=old_prior, - resume_file=resume_file, n_checkpoint=n_checkpoint, npool=npool) + result, + new_likelihood=new_likelihood, + new_prior=new_prior, + old_likelihood=old_likelihood, + old_prior=old_prior, + resume_file=resume_file, + n_checkpoint=n_checkpoint, + npool=npool, + ) + ) if use_nested_samples: ln_weights += np.log(result.posterior["weights"]) @@ -395,7 +415,7 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None, result.posterior = rejection_sample(result.posterior, weights=weights) result.posterior = result.posterior.reset_index(drop=True) - logger.info("Rejection sampling resulted in {} samples".format(len(result.posterior))) + logger.info(f"Rejection sampling resulted in {len(result.posterior)} samples") result.meta_data["reweighted_using_rejection_sampling"] = True if use_nested_samples: @@ -421,28 +441,52 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None, result.label += "_reweighted" if verbose_output: - return result, weights, new_log_likelihood_array, \ - new_log_prior_array, old_log_likelihood_array, old_log_prior_array + return ( + result, + weights, + new_log_likelihood_array, + new_log_prior_array, + old_log_likelihood_array, + old_log_prior_array, + ) else: return result -class Result(object): - def __init__(self, label='no_label', outdir='.', sampler=None, - search_parameter_keys=None, fixed_parameter_keys=None, - constraint_parameter_keys=None, priors=None, - sampler_kwargs=None, injection_parameters=None, - meta_data=None, posterior=None, samples=None, - nested_samples=None, log_evidence=np.nan, - log_evidence_err=np.nan, information_gain=np.nan, - log_noise_evidence=np.nan, log_bayes_factor=np.nan, - log_likelihood_evaluations=None, - log_prior_evaluations=None, sampling_time=None, nburn=None, - num_likelihood_evaluations=None, walkers=None, - max_autocorrelation_time=None, use_ratio=None, - parameter_labels=None, parameter_labels_with_unit=None, - version=None): - """ A class to store the results of the sampling run +class Result: + def __init__( + self, + label="no_label", + outdir=".", + sampler=None, + search_parameter_keys=None, + fixed_parameter_keys=None, + constraint_parameter_keys=None, + priors=None, + sampler_kwargs=None, + injection_parameters=None, + meta_data=None, + posterior=None, + samples=None, + nested_samples=None, + log_evidence=np.nan, + log_evidence_err=np.nan, + information_gain=np.nan, + log_noise_evidence=np.nan, + log_bayes_factor=np.nan, + log_likelihood_evaluations=None, + log_prior_evaluations=None, + sampling_time=None, + nburn=None, + num_likelihood_evaluations=None, + walkers=None, + max_autocorrelation_time=None, + use_ratio=None, + parameter_labels=None, + parameter_labels_with_unit=None, + version=None, + ): + """A class to store the results of the sampling run Parameters ========== @@ -570,8 +614,9 @@ def __init__(self, label='no_label', outdir='.', sampler=None, @staticmethod @docstring(_load_doctstring.format(format="pickle")) def from_pickle(filename=None, outdir=None, label=None): - filename = _determine_file_name(filename, outdir, label, 'hdf5', False) + filename = _determine_file_name(filename, outdir, label, "hdf5", False) import dill + with open(filename, "rb") as ff: return dill.load(ff) @@ -579,19 +624,16 @@ def from_pickle(filename=None, outdir=None, label=None): @docstring(_load_doctstring.format(format="hdf5")) def from_hdf5(cls, filename=None, outdir=None, label=None): import h5py - filename = _determine_file_name(filename, outdir, label, 'hdf5', False) + + filename = _determine_file_name(filename, outdir, label, "hdf5", False) with h5py.File(filename, "r") as ff: - data = recursively_load_dict_contents_from_group(ff, '/') + data = recursively_load_dict_contents_from_group(ff, "/") data["posterior"] = pd.DataFrame(data["posterior"]) - data["priors"] = PriorDict._get_from_json_dict( - json.loads(data["priors"], object_hook=decode_bilby_json) - ) + data["priors"] = PriorDict._get_from_json_dict(json.loads(data["priors"], object_hook=decode_bilby_json)) try: - cls = getattr(import_module(data['__module__']), data['__name__']) + cls = getattr(import_module(data["__module__"]), data["__name__"]) except ImportError: - logger.debug( - "Module {}.{} not found".format(data["__module__"], data["__name__"]) - ) + logger.debug("Module {}.{} not found".format(data["__module__"], data["__name__"])) except KeyError: logger.debug("No class specified, using base Result.") for key in ["__module__", "__name__"]: @@ -604,39 +646,37 @@ def from_hdf5(cls, filename=None, outdir=None, label=None): def from_json(cls, filename=None, outdir=None, label=None, gzip=False): from json.decoder import JSONDecodeError - filename = _determine_file_name(filename, outdir, label, 'json', gzip) + filename = _determine_file_name(filename, outdir, label, "json", gzip) if os.path.isfile(filename): try: dictionary = load_json(filename, gzip) except JSONDecodeError as e: - raise IOError( - "JSON failed to decode {} with message {}".format(filename, e) - ) + raise OSError(f"JSON failed to decode {filename} with message {e}") try: return cls(**dictionary) except TypeError as e: - raise IOError("Unable to load dictionary, error={}".format(e)) + raise OSError(f"Unable to load dictionary, error={e}") else: - raise IOError("No result '{}' found".format(filename)) + raise OSError(f"No result '{filename}' found") def __str__(self): - """Print a summary """ - if getattr(self, 'posterior', None) is not None: - if getattr(self, 'log_noise_evidence', None) is not None: - return ("nsamples: {:d}\n" - "ln_noise_evidence: {:6.3f}\n" - "ln_evidence: {:6.3f} +/- {:6.3f}\n" - "ln_bayes_factor: {:6.3f} +/- {:6.3f}\n" - .format(len(self.posterior), self.log_noise_evidence, self.log_evidence, - self.log_evidence_err, self.log_bayes_factor, - self.log_evidence_err)) + """Print a summary""" + if getattr(self, "posterior", None) is not None: + if getattr(self, "log_noise_evidence", None) is not None: + return ( + f"nsamples: {len(self.posterior):d}\n" + f"ln_noise_evidence: {self.log_noise_evidence:6.3f}\n" + f"ln_evidence: {self.log_evidence:6.3f} +/- {self.log_evidence_err:6.3f}\n" + f"ln_bayes_factor: {self.log_bayes_factor:6.3f} +/- {self.log_evidence_err:6.3f}\n" + ) else: - return ("nsamples: {:d}\n" - "ln_evidence: {:6.3f} +/- {:6.3f}\n" - .format(len(self.posterior), self.log_evidence, self.log_evidence_err)) + return ( + f"nsamples: {len(self.posterior):d}\n" + f"ln_evidence: {self.log_evidence:6.3f} +/- {self.log_evidence_err:6.3f}\n" + ) else: - return '' + return "" @property def meta_data(self): @@ -654,7 +694,7 @@ def priors(self): if self._priors is not None: return self._priors else: - raise ValueError('Result object has no priors') + raise ValueError("Result object has no priors") @priors.setter def priors(self, priors): @@ -664,12 +704,11 @@ def priors(self, priors): else: self._priors = PriorDict(priors) if self.parameter_labels is None: - self.parameter_labels = [self.priors[k].latex_label for k in - self.search_parameter_keys] + self.parameter_labels = [self.priors[k].latex_label for k in self.search_parameter_keys] if self.parameter_labels_with_unit is None: self.parameter_labels_with_unit = [ - self.priors[k].latex_label_with_unit for k in - self.search_parameter_keys] + self.priors[k].latex_label_with_unit for k in self.search_parameter_keys + ] elif priors is None: self._priors = priors self.parameter_labels = self.search_parameter_keys @@ -679,7 +718,7 @@ def priors(self, priors): @property def samples(self): - """ An array of samples """ + """An array of samples""" if self._samples is not None: return self._samples else: @@ -691,7 +730,7 @@ def samples(self, samples): @property def num_likelihood_evaluations(self): - """ number of likelihood evaluations """ + """number of likelihood evaluations""" if self._num_likelihood_evaluations is not None: return self._num_likelihood_evaluations else: @@ -703,7 +742,7 @@ def num_likelihood_evaluations(self, num_likelihood_evaluations): @property def nested_samples(self): - """" An array of unweighted samples """ + """ " An array of unweighted samples""" if self._nested_samples is not None: return self._nested_samples else: @@ -715,7 +754,7 @@ def nested_samples(self, nested_samples): @property def walkers(self): - """" An array of the ensemble walkers """ + """ " An array of the ensemble walkers""" if self._walkers is not None: return self._walkers else: @@ -727,7 +766,7 @@ def walkers(self, walkers): @property def nburn(self): - """" An array of the ensemble walkers """ + """ " An array of the ensemble walkers""" if self._nburn is not None: return self._nburn else: @@ -739,7 +778,7 @@ def nburn(self, nburn): @property def posterior(self): - """ A pandas data frame of the posterior """ + """A pandas data frame of the posterior""" if self._posterior is not None: return self._posterior else: @@ -772,33 +811,52 @@ def version(self): @version.setter def version(self, version): if version is None: - self._version = 'bilby={}'.format(utils.get_version_information()) + self._version = f"bilby={utils.get_version_information()}" else: self._version = version def _get_save_data_dictionary(self): # This list defines all the parameters saved in the result object save_attrs = [ - 'label', 'outdir', 'sampler', 'log_evidence', 'log_evidence_err', - 'log_noise_evidence', 'log_bayes_factor', 'priors', 'posterior', - 'injection_parameters', 'meta_data', 'search_parameter_keys', - 'fixed_parameter_keys', 'constraint_parameter_keys', - 'sampling_time', 'sampler_kwargs', 'use_ratio', 'information_gain', - 'log_likelihood_evaluations', 'log_prior_evaluations', - 'num_likelihood_evaluations', 'samples', 'nested_samples', - 'walkers', 'nburn', 'parameter_labels', 'parameter_labels_with_unit', - 'version'] + "label", + "outdir", + "sampler", + "log_evidence", + "log_evidence_err", + "log_noise_evidence", + "log_bayes_factor", + "priors", + "posterior", + "injection_parameters", + "meta_data", + "search_parameter_keys", + "fixed_parameter_keys", + "constraint_parameter_keys", + "sampling_time", + "sampler_kwargs", + "use_ratio", + "information_gain", + "log_likelihood_evaluations", + "log_prior_evaluations", + "num_likelihood_evaluations", + "samples", + "nested_samples", + "walkers", + "nburn", + "parameter_labels", + "parameter_labels_with_unit", + "version", + ] dictionary = dict() for attr in save_attrs: try: dictionary[attr] = getattr(self, attr) except ValueError as e: - logger.debug("Unable to save {}, message: {}".format(attr, e)) + logger.debug(f"Unable to save {attr}, message: {e}") pass return dictionary - def save_to_file(self, filename=None, overwrite=False, outdir=None, - extension=None, gzip=False): + def save_to_file(self, filename=None, overwrite=False, outdir=None, extension=None, gzip=False): """ Writes the Result to a file. @@ -836,20 +894,16 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None, if ext in EXTENSIONS: if extension is None: logger.debug( - f"Inferred extension '{ext}' from filename '{filename}'. " - "Using this extension for saving." + f"Inferred extension '{ext}' from filename '{filename}'. Using this extension for saving." ) extension = ext elif ext != extension: - message = ( - f"The specified extension '{ext}' " - f"does not match the provided extension '{extension}'. " - ) + message = f"The specified extension '{ext}' does not match the provided extension '{extension}'. " logger.warning(message) if extension is None: logger.info("No extension given, defaulting to JSON.") - extension = 'json' + extension = "json" if extension is True: message = ( @@ -858,7 +912,7 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None, "This behaviour is deprecated and will be removed. " ) logger.warning(message) - extension = 'json' + extension = "json" outdir = _outdir if outdir is None else outdir outdir = self._safe_outdir_creation(outdir, self.save_to_file) @@ -873,31 +927,33 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None, dictionary = self._get_save_data_dictionary() # Convert callable sampler_kwargs to strings - if dictionary.get('sampler_kwargs', None) is not None: - for key in dictionary['sampler_kwargs']: - if hasattr(dictionary['sampler_kwargs'][key], '__call__'): - dictionary['sampler_kwargs'][key] = str(dictionary['sampler_kwargs']) + if dictionary.get("sampler_kwargs", None) is not None: + for key in dictionary["sampler_kwargs"]: + if hasattr(dictionary["sampler_kwargs"][key], "__call__"): + dictionary["sampler_kwargs"][key] = str(dictionary["sampler_kwargs"]) try: # convert priors to JSON dictionary for both JSON and hdf5 files - if extension == 'json': + if extension == "json": dictionary["priors"] = dictionary["priors"]._get_json_dict() if gzip: import gzip + # encode to a string - json_str = json.dumps(dictionary, cls=BilbyJsonEncoder).encode('utf-8') - with gzip.GzipFile(output_path, 'w') as file: + json_str = json.dumps(dictionary, cls=BilbyJsonEncoder).encode("utf-8") + with gzip.GzipFile(output_path, "w") as file: file.write(json_str) else: - with open(output_path, 'w') as file: + with open(output_path, "w") as file: json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder) - elif extension in ['hdf5', 'h5']: + elif extension in ["hdf5", "h5"]: import h5py + dictionary["__module__"] = self.__module__ dictionary["__name__"] = self.__class__.__name__ - with h5py.File(output_path, 'w') as h5file: - recursively_save_dict_contents_to_group(h5file, '/', dictionary) - elif extension in ['pkl', 'pickle']: + with h5py.File(output_path, "w") as h5file: + recursively_save_dict_contents_to_group(h5file, "/", dictionary) + elif extension in ["pkl", "pickle"]: safe_file_dump(self, output_path, "dill") else: raise ValueError(f"Extension type {extension} not understood") @@ -910,7 +966,7 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None, ) def save_posterior_samples(self, filename=None, outdir=None, label=None): - """ Saves posterior samples to a file + """Saves posterior samples to a file Generates a .dat file containing the posterior samples and auxiliary data saved in the posterior. Note, strings in the posterior are @@ -930,7 +986,7 @@ def save_posterior_samples(self, filename=None, outdir=None, label=None): if label is None: label = self.label outdir = self._safe_outdir_creation(outdir, self.save_posterior_samples) - filename = '{}/{}_posterior_samples.dat'.format(outdir, label) + filename = f"{outdir}/{label}_posterior_samples.dat" else: outdir = os.path.dirname(filename) self._safe_outdir_creation(outdir, self.save_posterior_samples) @@ -945,11 +1001,11 @@ def save_posterior_samples(self, filename=None, outdir=None, label=None): df.loc[:, key + "_abs"] = np.abs(complex_term) df.loc[:, key + "_angle"] = np.angle(complex_term) - logger.info("Writing samples file to {}".format(filename)) - df.to_csv(filename, index=False, header=True, sep=' ') + logger.info(f"Writing samples file to {filename}") + df.to_csv(filename, index=False, header=True, sep=" ") def get_latex_labels_from_parameter_keys(self, keys): - """ Returns a list of latex_labels corresponding to the given keys + """Returns a list of latex_labels corresponding to the given keys Parameters ========== @@ -970,9 +1026,7 @@ def get_latex_labels_from_parameter_keys(self, keys): label = key else: label = None - logger.debug( - 'key {} not a parameter label or latex label'.format(key) - ) + logger.debug(f"key {key} not a parameter label or latex label") if label is None: label = key.replace("_", " ") latex_labels.append(label) @@ -980,26 +1034,25 @@ def get_latex_labels_from_parameter_keys(self, keys): @property def covariance_matrix(self): - """ The covariance matrix of the samples the posterior """ + """The covariance matrix of the samples the posterior""" samples = self.posterior[self.search_parameter_keys].values return np.cov(samples.T) @property def posterior_volume(self): - """ The posterior volume """ + """The posterior volume""" if self.covariance_matrix.ndim == 0: return np.sqrt(self.covariance_matrix) else: - return 1 / np.sqrt(np.abs(np.linalg.det( - 1 / self.covariance_matrix))) + return 1 / np.sqrt(np.abs(np.linalg.det(1 / self.covariance_matrix))) @staticmethod def prior_volume(priors): - """ The prior volume, given a set of priors """ + """The prior volume, given a set of priors""" return np.prod([priors[k].maximum - priors[k].minimum for k in priors]) def occam_factor(self, priors): - """ The Occam factor, + """The Occam factor, See Chapter 28, `Mackay "Information Theory, Inference, and Learning Algorithms" `_ Cambridge @@ -1010,7 +1063,7 @@ def occam_factor(self, priors): @property def bayesian_model_dimensionality(self): - """ Characterises how many parameters are effectively constraint by the data + """Characterises how many parameters are effectively constraint by the data See @@ -1018,12 +1071,10 @@ def bayesian_model_dimensionality(self): ======= float: The model dimensionality """ - return 2 * (np.mean(self.posterior['log_likelihood']**2) - - np.mean(self.posterior['log_likelihood'])**2) + return 2 * (np.mean(self.posterior["log_likelihood"] ** 2) - np.mean(self.posterior["log_likelihood"]) ** 2) - def get_one_dimensional_median_and_error_bar(self, key, fmt='.2f', - quantiles=(0.16, 0.84)): - """ Calculate the median and error bar for a given key + def get_one_dimensional_median_and_error_bar(self, key, fmt=".2f", quantiles=(0.16, 0.84)): + """Calculate the median and error bar for a given key Parameters ========== @@ -1041,7 +1092,7 @@ def get_one_dimensional_median_and_error_bar(self, key, fmt='.2f', An object with attributes, median, lower, upper and string """ - summary = namedtuple('summary', ['median', 'lower', 'upper', 'string']) + summary = namedtuple("summary", ["median", "lower", "upper", "string"]) if len(quantiles) != 2: raise ValueError("quantiles must be of length 2") @@ -1052,18 +1103,28 @@ def get_one_dimensional_median_and_error_bar(self, key, fmt='.2f', summary.plus = quants[2] - summary.median summary.minus = summary.median - quants[0] - fmt = "{{0:{0}}}".format(fmt).format + fmt = f"{{0:{fmt}}}".format string_template = r"${{{0}}}_{{-{1}}}^{{+{2}}}$" - summary.string = string_template.format( - fmt(summary.median), fmt(summary.minus), fmt(summary.plus)) + summary.string = string_template.format(fmt(summary.median), fmt(summary.minus), fmt(summary.plus)) return summary @latex_plot_format - def plot_single_density(self, key, prior=None, cumulative=False, - title=None, truth=None, save=True, - file_base_name=None, bins=50, label_fontsize=16, - title_fontsize=16, quantiles=(0.16, 0.84), dpi=300): - """ Plot a 1D marginal density, either probability or cumulative. + def plot_single_density( + self, + key, + prior=None, + cumulative=False, + title=None, + truth=None, + save=True, + file_base_name=None, + bins=50, + label_fontsize=16, + title_fontsize=16, + quantiles=(0.16, 0.84), + dpi=300, + ): + """Plot a 1D marginal density, either probability or cumulative. Parameters ========== @@ -1104,52 +1165,58 @@ def plot_single_density(self, key, prior=None, cumulative=False, A matplotlib figure object """ import matplotlib.pyplot as plt - logger.info('Plotting {} marginal distribution'.format(key)) + + logger.info(f"Plotting {key} marginal distribution") label = self.get_latex_labels_from_parameter_keys([key])[0] label = sanity_check_labels([label])[0] fig, ax = plt.subplots() try: - ax.hist(self.posterior[key].values, bins=bins, density=True, - histtype='step', cumulative=cumulative) + ax.hist(self.posterior[key].values, bins=bins, density=True, histtype="step", cumulative=cumulative) except ValueError as e: - logger.info( - 'Failed to generate 1d plot for {}, error message: {}' - .format(key, e)) + logger.info(f"Failed to generate 1d plot for {key}, error message: {e}") return ax.set_xlabel(label, fontsize=label_fontsize) if truth is not None: - ax.axvline(truth, ls='-', color='orange') + ax.axvline(truth, ls="-", color="orange") - summary = self.get_one_dimensional_median_and_error_bar( - key, quantiles=quantiles) - ax.axvline(summary.median - summary.minus, ls='--', color='C0') - ax.axvline(summary.median + summary.plus, ls='--', color='C0') + summary = self.get_one_dimensional_median_and_error_bar(key, quantiles=quantiles) + ax.axvline(summary.median - summary.minus, ls="--", color="C0") + ax.axvline(summary.median + summary.plus, ls="--", color="C0") if title: ax.set_title(summary.string, fontsize=title_fontsize) if isinstance(prior, Prior): theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300) if cumulative is False: - ax.plot(theta, prior.prob(theta), color='C2') + ax.plot(theta, prior.prob(theta), color="C2") else: - ax.plot(theta, prior.cdf(theta), color='C2') + ax.plot(theta, prior.cdf(theta), color="C2") if save: fig.tight_layout() if cumulative: - file_name = file_base_name + key + '_cdf' + file_name = file_base_name + key + "_cdf" else: - file_name = file_base_name + key + '_pdf' + file_name = file_base_name + key + "_pdf" safe_save_figure(fig=fig, filename=file_name, dpi=dpi) plt.close(fig) else: return fig - def plot_marginals(self, parameters=None, priors=None, titles=True, - file_base_name=None, bins=50, label_fontsize=16, - title_fontsize=16, quantiles=(0.16, 0.84), dpi=300, - outdir=None): - """ Plot 1D marginal distributions + def plot_marginals( + self, + parameters=None, + priors=None, + titles=True, + file_base_name=None, + bins=50, + label_fontsize=16, + title_fontsize=16, + quantiles=(0.16, 0.84), + dpi=300, + outdir=None, + ): + """Plot 1D marginal distributions Parameters ========== @@ -1201,17 +1268,17 @@ def plot_marginals(self, parameters=None, priors=None, titles=True, if file_base_name is None: outdir = self._safe_outdir_creation(outdir, self.plot_marginals) - file_base_name = '{}/{}_1d/'.format(outdir, self.label) + file_base_name = f"{outdir}/{self.label}_1d/" check_directory_exists_and_if_not_mkdir(file_base_name) if priors is True: - priors = getattr(self, 'priors', dict()) + priors = getattr(self, "priors", dict()) elif isinstance(priors, dict): pass elif priors in [False, None]: priors = dict() else: - raise ValueError('Input priors={} not understood'.format(priors)) + raise ValueError(f"Input priors={priors} not understood") for i, key in enumerate(plot_parameter_keys): if not isinstance(self.posterior[key].values[0], float): @@ -1220,15 +1287,23 @@ def plot_marginals(self, parameters=None, priors=None, titles=True, truth = truths.get(key, None) for cumulative in [False, True]: self.plot_single_density( - key, prior=prior, cumulative=cumulative, title=titles, - truth=truth, save=True, file_base_name=file_base_name, - bins=bins, label_fontsize=label_fontsize, dpi=dpi, - title_fontsize=title_fontsize, quantiles=quantiles) + key, + prior=prior, + cumulative=cumulative, + title=titles, + truth=truth, + save=True, + file_base_name=file_base_name, + bins=bins, + label_fontsize=label_fontsize, + dpi=dpi, + title_fontsize=title_fontsize, + quantiles=quantiles, + ) @latex_plot_format - def plot_corner(self, parameters=None, priors=None, titles=True, save=True, - filename=None, dpi=300, **kwargs): - """ Plot a corner-plot + def plot_corner(self, parameters=None, priors=None, titles=True, save=True, filename=None, dpi=300, **kwargs): + """Plot a corner-plot Parameters ========== @@ -1282,16 +1357,22 @@ def plot_corner(self, parameters=None, priors=None, titles=True, save=True, return defaults_kwargs = dict( - bins=50, smooth=0.9, - title_kwargs=dict(fontsize=16), color='#0072C1', - truth_color='tab:orange', quantiles=[0.16, 0.84], - levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)), - plot_density=False, plot_datapoints=True, fill_contours=True, - max_n_ticks=3) + bins=50, + smooth=0.9, + title_kwargs=dict(fontsize=16), + color="#0072C1", + truth_color="tab:orange", + quantiles=[0.16, 0.84], + levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.0)), + plot_density=False, + plot_datapoints=True, + fill_contours=True, + max_n_ticks=3, + ) - if 'lionize' in kwargs and kwargs['lionize'] is True: - defaults_kwargs['truth_color'] = 'tab:blue' - defaults_kwargs['color'] = '#FF8C00' + if "lionize" in kwargs and kwargs["lionize"] is True: + defaults_kwargs["truth_color"] = "tab:blue" + defaults_kwargs["color"] = "#FF8C00" label_kwargs_defaults = dict(fontsize=16) hist_kwargs_defaults = dict(density=True) @@ -1309,59 +1390,52 @@ def plot_corner(self, parameters=None, priors=None, titles=True, save=True, kwargs["hist_kwargs"] = hist_kwargs_defaults # Handle if truths was passed in - if 'truth' in kwargs: - kwargs['truths'] = kwargs.pop('truth') + if "truth" in kwargs: + kwargs["truths"] = kwargs.pop("truth") if "truths" in kwargs: - truths = kwargs.get('truths') + truths = kwargs.get("truths") if isinstance(parameters, list) and isinstance(truths, list): if len(parameters) != len(truths): - raise ValueError( - "Length of parameters and truths don't match") + raise ValueError("Length of parameters and truths don't match") elif isinstance(truths, dict) and parameters is None: - parameters = kwargs.pop('truths') + parameters = kwargs.pop("truths") elif isinstance(truths, bool): pass elif truths is None: kwargs["truths"] = False else: - raise ValueError( - "Combination of parameters and truths not understood") + raise ValueError("Combination of parameters and truths not understood") # If injection parameters where stored, use these as parameter values # but do not overwrite input parameters (or truths) - cond1 = getattr(self, 'injection_parameters', None) is not None + cond1 = getattr(self, "injection_parameters", None) is not None cond2 = parameters is None cond3 = bool(kwargs.get("truths", True)) if cond1 and cond2 and cond3: - parameters = { - key: self.injection_parameters.get(key, np.nan) - for key in self.search_parameter_keys - } + parameters = {key: self.injection_parameters.get(key, np.nan) for key in self.search_parameter_keys} # If parameters is a dictionary, use the keys to determine which # parameters to plot and the values as truths. if isinstance(parameters, dict): plot_parameter_keys = list(parameters.keys()) - kwargs['truths'] = list(parameters.values()) + kwargs["truths"] = list(parameters.values()) elif parameters is None: plot_parameter_keys = self.search_parameter_keys else: plot_parameter_keys = list(parameters) # Get latex formatted strings for the plot labels - kwargs['labels'] = kwargs.get( - 'labels', self.get_latex_labels_from_parameter_keys( - plot_parameter_keys)) + kwargs["labels"] = kwargs.get("labels", self.get_latex_labels_from_parameter_keys(plot_parameter_keys)) kwargs["labels"] = sanity_check_labels(kwargs["labels"]) # Unless already set, set the range to include all samples # This prevents ValueErrors being raised for parameters with no range - kwargs['range'] = kwargs.get('range', [1] * len(plot_parameter_keys)) + kwargs["range"] = kwargs.get("range", [1] * len(plot_parameter_keys)) # Remove truths if it is a bool - if isinstance(kwargs.get('truths'), bool): - kwargs.pop('truths') + if isinstance(kwargs.get("truths"), bool): + kwargs.pop("truths") # Create the data array to plot and pass everything to corner xs = self.posterior[plot_parameter_keys].values @@ -1369,40 +1443,40 @@ def plot_corner(self, parameters=None, priors=None, titles=True, save=True, fig = corner.corner(xs, **kwargs) else: ax = kwargs.get("ax", plt.subplot()) - ax.hist(xs, bins=kwargs["bins"], color=kwargs["color"], - histtype="step", **kwargs["hist_kwargs"]) + ax.hist(xs, bins=kwargs["bins"], color=kwargs["color"], histtype="step", **kwargs["hist_kwargs"]) ax.set_xlabel(kwargs["labels"][0]) fig = plt.gcf() axes = fig.get_axes() # Add the titles - if titles and kwargs.get('quantiles', None) is not None: + if titles and kwargs.get("quantiles", None) is not None: for i, par in enumerate(plot_parameter_keys): ax = axes[i + i * len(plot_parameter_keys)] - if ax.title.get_text() == '': - ax.set_title(self.get_one_dimensional_median_and_error_bar( - par, quantiles=kwargs['quantiles']).string, - **kwargs['title_kwargs']) + if ax.title.get_text() == "": + ax.set_title( + self.get_one_dimensional_median_and_error_bar(par, quantiles=kwargs["quantiles"]).string, + **kwargs["title_kwargs"], + ) # Add priors to the 1D plots if priors is True: - priors = getattr(self, 'priors', False) + priors = getattr(self, "priors", False) if isinstance(priors, dict): for i, par in enumerate(plot_parameter_keys): ax = axes[i + i * len(plot_parameter_keys)] theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300) - ax.plot(theta, priors[par].prob(theta), color='C2') + ax.plot(theta, priors[par].prob(theta), color="C2") elif priors in [False, None]: pass else: - raise ValueError('Input priors={} not understood'.format(priors)) + raise ValueError(f"Input priors={priors} not understood") if save: if filename is None: - outdir = self._safe_outdir_creation(kwargs.get('outdir'), self.plot_corner) - filename = '{}/{}_corner.png'.format(outdir, self.label) - logger.debug('Saving corner plot to {}'.format(filename)) + outdir = self._safe_outdir_creation(kwargs.get("outdir"), self.plot_corner) + filename = f"{outdir}/{self.label}_corner.png" + logger.debug(f"Saving corner plot to {filename}") safe_save_figure(fig=fig, filename=filename, dpi=dpi) plt.close(fig) @@ -1410,9 +1484,10 @@ def plot_corner(self, parameters=None, priors=None, titles=True, save=True, @latex_plot_format def plot_walkers(self, **kwargs): - """ Method to plot the trace of the walkers in an ensemble MCMC plot """ + """Method to plot the trace of the walkers in an ensemble MCMC plot""" import matplotlib.pyplot as plt - if hasattr(self, 'walkers') is False: + + if hasattr(self, "walkers") is False: logger.warning("Cannot plot_walkers as no walkers are saved") return @@ -1425,28 +1500,39 @@ def plot_walkers(self, **kwargs): walkers = self.walkers[:, :, :] parameter_labels = sanity_check_labels(self.parameter_labels) for i, ax in enumerate(axes): - ax.plot(idxs[:self.nburn + 1], walkers[:, :self.nburn + 1, i].T, - lw=0.1, color='r') + ax.plot(idxs[: self.nburn + 1], walkers[:, : self.nburn + 1, i].T, lw=0.1, color="r") ax.set_ylabel(parameter_labels[i]) for i, ax in enumerate(axes): - ax.plot(idxs[self.nburn:], walkers[:, self.nburn:, i].T, lw=0.1, - color='k') + ax.plot(idxs[self.nburn :], walkers[:, self.nburn :, i].T, lw=0.1, color="k") ax.set_ylabel(parameter_labels[i]) fig.tight_layout() - outdir = self._safe_outdir_creation(kwargs.get('outdir'), self.plot_walkers) - filename = '{}/{}_walkers.png'.format(outdir, self.label) - logger.debug('Saving walkers plot to {}'.format('filename')) + outdir = self._safe_outdir_creation(kwargs.get("outdir"), self.plot_walkers) + filename = f"{outdir}/{self.label}_walkers.png" + logger.debug("Saving walkers plot to {}".format("filename")) safe_save_figure(fig=fig, filename=filename) plt.close(fig) @latex_plot_format - def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000, - xlabel=None, ylabel=None, data_label='data', - data_fmt='o', draws_label=None, filename=None, - maxl_label='max likelihood', dpi=300, outdir=None): - """ Generate a figure showing the data and fits to the data + def plot_with_data( + self, + model, + x, + y, + ndraws=1000, + npoints=1000, + xlabel=None, + ylabel=None, + data_label="data", + data_fmt="o", + draws_label=None, + filename=None, + maxl_label="max likelihood", + dpi=300, + outdir=None, + ): + """Generate a figure showing the data and fits to the data Parameters ========== @@ -1484,20 +1570,17 @@ def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000, xsmooth = np.linspace(np.min(x), np.max(x), npoints) fig, ax = plt.subplots() - logger.info('Plotting {} draws'.format(ndraws)) + logger.info(f"Plotting {ndraws} draws") for _ in range(ndraws): - s = model_posterior.sample().to_dict('records')[0] - ax.plot(xsmooth, model(xsmooth, **s), alpha=0.25, lw=0.1, color='r', - label=draws_label) + s = model_posterior.sample().to_dict("records")[0] + ax.plot(xsmooth, model(xsmooth, **s), alpha=0.25, lw=0.1, color="r", label=draws_label) try: if all(~np.isnan(self.posterior.log_likelihood)): - logger.info('Plotting maximum likelihood') + logger.info("Plotting maximum likelihood") s = model_posterior.iloc[self.posterior.log_likelihood.idxmax()] - ax.plot(xsmooth, model(xsmooth, **s), lw=1, color='k', - label=maxl_label) + ax.plot(xsmooth, model(xsmooth, **s), lw=1, color="k", label=maxl_label) except (AttributeError, TypeError): - logger.debug( - "No log likelihood values stored, unable to plot max") + logger.debug("No log likelihood values stored, unable to plot max") ax.plot(x, y, data_fmt, markersize=2, label=data_label) @@ -1513,7 +1596,7 @@ def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000, fig.tight_layout() if filename is None: outdir = self._safe_outdir_creation(outdir, self.plot_with_data) - filename = '{}/{}_plot_with_data'.format(outdir, self.label) + filename = f"{outdir}/{self.label}_plot_with_data" safe_save_figure(fig=fig, filename=filename, dpi=dpi) plt.close(fig) @@ -1522,15 +1605,13 @@ def _add_prior_fixed_values_to_posterior(posterior, priors): if priors is None: return posterior for key in priors: - if isinstance(priors[key], DeltaFunction) and \ - not isinstance(priors[key], ConditionalDeltaFunction): + if isinstance(priors[key], DeltaFunction) and not isinstance(priors[key], ConditionalDeltaFunction): posterior[key] = priors[key].peak elif isinstance(priors[key], float): posterior[key] = priors[key] return posterior - def samples_to_posterior(self, likelihood=None, priors=None, - conversion_function=None, npool=1): + def samples_to_posterior(self, likelihood=None, priors=None, conversion_function=None, npool=1): """ Convert array of samples to posterior (a Pandas data frame) @@ -1547,17 +1628,13 @@ def samples_to_posterior(self, likelihood=None, priors=None, should take the data_frame, likelihood and prior as arguments. """ - data_frame = pd.DataFrame( - self.samples, columns=self.search_parameter_keys) - data_frame = self._add_prior_fixed_values_to_posterior( - data_frame, priors) - data_frame['log_likelihood'] = getattr( - self, 'log_likelihood_evaluations', np.nan) + data_frame = pd.DataFrame(self.samples, columns=self.search_parameter_keys) + data_frame = self._add_prior_fixed_values_to_posterior(data_frame, priors) + data_frame["log_likelihood"] = getattr(self, "log_likelihood_evaluations", np.nan) if self.log_prior_evaluations is None and priors is not None: - data_frame['log_prior'] = priors.ln_prob( - dict(data_frame[self.search_parameter_keys]), axis=0) + data_frame["log_prior"] = priors.ln_prob(dict(data_frame[self.search_parameter_keys]), axis=0) else: - data_frame['log_prior'] = self.log_prior_evaluations + data_frame["log_prior"] = self.log_prior_evaluations if conversion_function is not None: if "npool" in inspect.signature(conversion_function).parameters: @@ -1581,8 +1658,7 @@ def calculate_prior_values(self, priors): if isinstance(priors[key], DeltaFunction): continue else: - self.prior_values[key]\ - = priors[key].prob(self.posterior[key].values) + self.prior_values[key] = priors[key].prob(self.posterior[key].values) def get_all_injection_credible_levels(self, keys=None, weights=None): """ @@ -1606,13 +1682,12 @@ def get_all_injection_credible_levels(self, keys=None, weights=None): if keys is None: keys = self.search_parameter_keys if self.injection_parameters is None: - raise TypeError( - "Result object has no 'injection_parameters'. " - "Cannot compute credible levels." - ) - credible_levels = {key: self.get_injection_credible_level(key, weights=weights) - for key in keys - if isinstance(self.injection_parameters.get(key, None), float)} + raise TypeError("Result object has no 'injection_parameters'. Cannot compute credible levels.") + credible_levels = { + key: self.get_injection_credible_level(key, weights=weights) + for key in keys + if isinstance(self.injection_parameters.get(key, None), float) + } return credible_levels def get_injection_credible_level(self, parameter, weights=None): @@ -1635,26 +1710,21 @@ def get_injection_credible_level(self, parameter, weights=None): float: credible level """ if self.injection_parameters is None: - raise ( - TypeError, - "Result object has no 'injection_parameters'. " - "Cannot copmute credible levels." - ) + raise (TypeError, "Result object has no 'injection_parameters'. Cannot copmute credible levels.") if weights is None: weights = np.ones(len(self.posterior)) - if parameter in self.posterior and\ - parameter in self.injection_parameters: - credible_level =\ - sum(np.array(self.posterior[parameter].values < - self.injection_parameters[parameter]) * weights) / (sum(weights)) + if parameter in self.posterior and parameter in self.injection_parameters: + credible_level = sum( + np.array(self.posterior[parameter].values < self.injection_parameters[parameter]) * weights + ) / (sum(weights)) return credible_level else: return np.nan def _check_attribute_match_to_other_object(self, name, other_object): - """ Check attribute name exists in other_object and is the same + """Check attribute name exists in other_object and is the same Parameters ========== @@ -1670,7 +1740,7 @@ def _check_attribute_match_to_other_object(self, name, other_object): """ a = getattr(self, name, False) b = getattr(other_object, name, False) - logger.debug('Checking {} value: {}=={}'.format(name, a, b)) + logger.debug(f"Checking {name} value: {a}=={b}") if (a is not False) and (b is not False): type_a = type(a) type_b = type(b) @@ -1686,19 +1756,18 @@ def _check_attribute_match_to_other_object(self, name, other_object): @property def kde(self): - """ Kernel density estimate built from the stored posterior + """Kernel density estimate built from the stored posterior Uses `scipy.stats.gaussian_kde` to generate the kernel density """ if self._kde: return self._kde else: - self._kde = scipy.stats.gaussian_kde( - self.posterior[self.search_parameter_keys].values.T) + self._kde = scipy.stats.gaussian_kde(self.posterior[self.search_parameter_keys].values.T) return self._kde def posterior_probability(self, sample): - """ Calculate the posterior probability for a new sample + """Calculate the posterior probability for a new sample This queries a Kernel Density Estimate of the posterior to calculate the posterior probability density for the new sample. @@ -1718,8 +1787,7 @@ def posterior_probability(self, sample): """ if isinstance(sample, dict): sample = [sample] - ordered_sample = [[s[key] for key in self.search_parameter_keys] - for s in sample] + ordered_sample = [[s[key] for key in self.search_parameter_keys] for s in sample] return self.kde(ordered_sample) def _safe_outdir_creation(self, outdir=None, caller_func=None): @@ -1728,42 +1796,46 @@ def _safe_outdir_creation(self, outdir=None, caller_func=None): try: utils.check_directory_exists_and_if_not_mkdir(outdir) except PermissionError: - raise FileMovedError("Can not write in the out directory.\n" - "Did you move the here file from another system?\n" - "Try calling " + caller_func.__name__ + " with the 'outdir' " - "keyword argument, e.g. " + caller_func.__name__ + "(outdir='.')") + raise FileMovedError( + "Can not write in the out directory.\n" + "Did you move the here file from another system?\n" + "Try calling " + caller_func.__name__ + " with the 'outdir' " + "keyword argument, e.g. " + caller_func.__name__ + "(outdir='.')" + ) return outdir def get_weights_by_new_prior(self, old_prior, new_prior, prior_names=None): - """ Calculate a list of sample weights based on the ratio of new to old priors + """Calculate a list of sample weights based on the ratio of new to old priors - Parameters - ========== - old_prior: PriorDict, - The prior used in the generation of the original samples. + Parameters + ========== + old_prior: PriorDict, + The prior used in the generation of the original samples. - new_prior: PriorDict, - The prior to use to reweight the samples. + new_prior: PriorDict, + The prior to use to reweight the samples. - prior_names: list - A list of the priors to include in the ratio during reweighting. + prior_names: list + A list of the priors to include in the ratio during reweighting. - Returns - ======= - weights: array-like, - A list of sample weights. + Returns + ======= + weights: array-like, + A list of sample weights. - """ + """ weights = [] # Shared priors - these will form a ratio if prior_names is not None: - shared_parameters = {key: self.posterior[key] for key in new_prior if - key in old_prior and key in prior_names} + shared_parameters = { + key: self.posterior[key] for key in new_prior if key in old_prior and key in prior_names + } else: shared_parameters = {key: self.posterior[key] for key in new_prior if key in old_prior} - parameters = [{key: self.posterior[key][i] for key in shared_parameters.keys()} - for i in range(len(self.posterior))] + parameters = [ + {key: self.posterior[key][i] for key in shared_parameters.keys()} for i in range(len(self.posterior)) + ] for i in range(len(self.posterior)): weight = 1 @@ -1777,45 +1849,39 @@ def get_weights_by_new_prior(self, old_prior, new_prior, prior_names=None): return weights def to_arviz(self, prior=None): - """ Convert the Result object to an ArviZ InferenceData object. - - Parameters - ========== - prior: int - If a positive integer is given then that number of prior - samples will be drawn and stored in the ArviZ InferenceData - object. - - Returns - ======= - azdata: InferenceData - The ArviZ InferenceData object. - - Raises - ====== - RuntimeError: If ArviZ is not installed. + """Convert the Result object to an ArviZ InferenceData object. + + Parameters + ========== + prior: int + If a positive integer is given then that number of prior + samples will be drawn and stored in the ArviZ InferenceData + object. + + Returns + ======= + azdata: InferenceData + The ArviZ InferenceData object. + + Raises + ====== + RuntimeError: If ArviZ is not installed. """ try: import arviz as az except ImportError: - raise ResultError( - "ArviZ is not installed, so cannot convert to InferenceData." - ) + raise ResultError("ArviZ is not installed, so cannot convert to InferenceData.") posdict = {} for key in self.posterior: posdict[key] = self.posterior[key].values if "log_likelihood" in posdict: - loglikedict = { - "log_likelihood": posdict.pop("log_likelihood") - } + loglikedict = {"log_likelihood": posdict.pop("log_likelihood")} else: if self.log_likelihood_evaluations is not None: - loglikedict = { - "log_likelihood": self.log_likelihood_evaluations - } + loglikedict = {"log_likelihood": self.log_likelihood_evaluations} else: loglikedict = None @@ -1823,8 +1889,7 @@ def to_arviz(self, prior=None): if prior is not None: if self.priors is None: logger.warning( - "No priors are in the Result object, so prior samples " - "will not be included in the output." + "No priors are in the Result object, so prior samples will not be included in the output." ) else: priorsamples = self.priors.sample(size=prior) @@ -1837,8 +1902,8 @@ def to_arviz(self, prior=None): # add attributes version = { - "inference_library": "bilby: {}".format(self.sampler), - "inference_library_version": get_version_information() + "inference_library": f"bilby: {self.sampler}", + "inference_library_version": get_version_information(), } azdata.posterior.attrs.update(version) @@ -1851,9 +1916,8 @@ def to_arviz(self, prior=None): class ResultList(list): - def __init__(self, results=None, consistency_level="warning"): - """ A class to store a list of :class:`bilby.core.result.Result` objects + """A class to store a list of :class:`bilby.core.result.Result` objects from equivalent runs on the same data. This provides methods for outputting combined results. @@ -1868,7 +1932,7 @@ def __init__(self, results=None, consistency_level="warning"): nothing. """ - super(ResultList, self).__init__() + super().__init__() self.consistency_level = consistency_level for result in results: self.append(result) @@ -1885,9 +1949,9 @@ def append(self, result): """ if isinstance(result, Result): - super(ResultList, self).append(result) + super().append(result) elif isinstance(result, (str, os.PathLike)): - super(ResultList, self).append(read_in_result(result)) + super().append(read_in_result(result)) else: raise TypeError("Could not append a non-Result type") @@ -1923,7 +1987,7 @@ def combine(self, shuffle=False, consistency_level="error"): result = copy(self[0]) if result.label is not None: - result.label += '_combined' + result.label += "_combined" self.check_consistent_sampler() self.check_consistent_data() @@ -1970,17 +2034,18 @@ def _combine_nested_sampled_runs(self, result): The result object with the combined evidences. """ from scipy.special import logsumexp + self.check_nested_samples() # Combine evidences log_evidences = np.array([res.log_evidence for res in self]) - result.log_evidence = logsumexp(log_evidences, b=1. / len(self)) + result.log_evidence = logsumexp(log_evidences, b=1.0 / len(self)) result.log_bayes_factor = result.log_evidence - result.log_noise_evidence # Propagate uncertainty in combined evidence log_errs = [res.log_evidence_err for res in self if np.isfinite(res.log_evidence_err)] if len(log_errs) > 0: - result.log_evidence_err = 0.5 * logsumexp(2 * np.array(log_errs), b=1. / len(self)) + result.log_evidence_err = 0.5 * logsumexp(2 * np.array(log_errs), b=1.0 / len(self)) else: result.log_evidence_err = np.nan @@ -1988,7 +2053,7 @@ def _combine_nested_sampled_runs(self, result): result_weights = np.exp(log_evidences - np.max(log_evidences)) posteriors = list() for res, frac in zip(self, result_weights): - selected_samples = (random.rng.uniform(size=len(res.posterior)) < frac) + selected_samples = random.rng.uniform(size=len(res.posterior)) < frac posteriors.append(res.posterior[selected_samples]) # remove original nested_samples @@ -2020,13 +2085,13 @@ def _combine_mcmc_sampled_runs(self, result): # Combine evidences log_evidences = np.array([res.log_evidence for res in self]) - result.log_evidence = logsumexp(log_evidences, b=1. / len(self)) + result.log_evidence = logsumexp(log_evidences, b=1.0 / len(self)) result.log_bayes_factor = result.log_evidence - result.log_noise_evidence # Propagate uncertainty in combined evidence log_errs = [res.log_evidence_err for res in self if np.isfinite(res.log_evidence_err)] if len(log_errs) > 0: - result.log_evidence_err = 0.5 * logsumexp(2 * np.array(log_errs), b=1. / len(self)) + result.log_evidence_err = 0.5 * logsumexp(2 * np.array(log_errs), b=1.0 / len(self)) else: result.log_evidence_err = np.nan @@ -2084,10 +2149,19 @@ def check_consistent_sampler(self): @latex_plot_format -def plot_multiple(results, filename=None, labels=None, colours=None, - save=True, evidences=False, corner_labels=None, linestyles=None, - fig=None, **kwargs): - """ Generate a corner plot overlaying two sets of results +def plot_multiple( + results, + filename=None, + labels=None, + colours=None, + save=True, + evidences=False, + corner_labels=None, + linestyles=None, + fig=None, + **kwargs, +): + """Generate a corner plot overlaying two sets of results Parameters ========== @@ -2127,33 +2201,33 @@ def plot_multiple(results, filename=None, labels=None, colours=None, A matplotlib figure instance """ - import matplotlib.pyplot as plt import matplotlib.lines as mpllines + import matplotlib.pyplot as plt - kwargs['show_titles'] = False - kwargs['truths'] = None + kwargs["show_titles"] = False + kwargs["truths"] = None if corner_labels is not None: - kwargs['labels'] = corner_labels + kwargs["labels"] = corner_labels fig = results[0].plot_corner(fig=fig, save=False, **kwargs) - default_filename = '{}/{}'.format(results[0].outdir, 'combined') + default_filename = "{}/{}".format(results[0].outdir, "combined") lines = [] default_labels = [] for i, result in enumerate(results): if colours: c = colours[i] else: - c = 'C{}'.format(i) + c = f"C{i}" if linestyles is not None: linestyle = linestyles[i] else: - linestyle = 'solid' - hist_kwargs = kwargs.get('hist_kwargs', dict()) - hist_kwargs['color'] = c + linestyle = "solid" + hist_kwargs = kwargs.get("hist_kwargs", dict()) + hist_kwargs["color"] = c hist_kwargs["linestyle"] = linestyle kwargs["hist_kwargs"] = hist_kwargs fig = result.plot_corner(fig=fig, save=False, color=c, contour_kwargs={"linestyles": linestyle}, **kwargs) - default_filename += '_{}'.format(result.label) + default_filename += f"_{result.label}" lines.append(mpllines.Line2D([0], [0], color=c, linestyle=linestyle)) default_labels.append(result.label) @@ -2169,9 +2243,9 @@ def plot_multiple(results, filename=None, labels=None, colours=None, if evidences: if np.isnan(results[0].log_bayes_factor): - template = r'{label} $\mathrm{{ln}}(Z)={lnz:1.3g}$' + template = r"{label} $\mathrm{{ln}}(Z)={lnz:1.3g}$" else: - template = r'{label} $\mathrm{{ln}}(B)={lnbf:1.3g}$' + template = r"{label} $\mathrm{{ln}}(B)={lnbf:1.3g}$" labels = [ template.format( label=label, @@ -2194,10 +2268,19 @@ def plot_multiple(results, filename=None, labels=None, colours=None, @latex_plot_format -def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0.95, 0.997], - lines=None, legend_fontsize='x-small', keys=None, title=True, - confidence_interval_alpha=0.1, weight_list=None, - **kwargs): +def make_pp_plot( + results, + filename=None, + save=True, + confidence_interval=[0.68, 0.95, 0.997], + lines=None, + legend_fontsize="x-small", + keys=None, + title=True, + confidence_interval_alpha=0.1, + weight_list=None, + **kwargs, +): """ Make a P-P plot for a set of runs with injected signals. @@ -2243,15 +2326,13 @@ def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0 credible_levels = list() for i, result in enumerate(results): - credible_levels.append( - result.get_all_injection_credible_levels(keys, weights=weight_list[i]) - ) + credible_levels.append(result.get_all_injection_credible_levels(keys, weights=weight_list[i])) credible_levels = pd.DataFrame(credible_levels) if lines is None: - colors = ["C{}".format(i) for i in range(8)] + colors = [f"C{i}" for i in range(8)] linestyles = ["-", "--", ":"] - lines = ["{}{}".format(a, b) for a, b in product(linestyles, colors)] + lines = [f"{a}{b}" for a, b in product(linestyles, colors)] if len(lines) < len(credible_levels.keys()): raise ValueError("Larger number of parameters than unique linestyles") @@ -2265,45 +2346,41 @@ def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0 if isinstance(confidence_interval_alpha, float): confidence_interval_alpha = [confidence_interval_alpha] * len(confidence_interval) elif len(confidence_interval_alpha) != len(confidence_interval): - raise ValueError( - "confidence_interval_alpha must have the same length as confidence_interval") + raise ValueError("confidence_interval_alpha must have the same length as confidence_interval") for ci, alpha in zip(confidence_interval, confidence_interval_alpha): - edge_of_bound = (1. - ci) / 2. + edge_of_bound = (1.0 - ci) / 2.0 lower = scipy.stats.binom.ppf(1 - edge_of_bound, N, x_values) / N upper = scipy.stats.binom.ppf(edge_of_bound, N, x_values) / N # The binomial point percent function doesn't always return 0 @ 0, # so set those bounds explicitly to be sure lower[0] = 0 upper[0] = 0 - ax.fill_between(x_values, lower, upper, alpha=alpha, color='k') + ax.fill_between(x_values, lower, upper, alpha=alpha, color="k") pvalues = [] logger.info("Key: KS-test p-value") for ii, key in enumerate(credible_levels): - pp = np.array([sum(credible_levels[key].values < xx) / - len(credible_levels) for xx in x_values]) - pvalue = scipy.stats.kstest(credible_levels[key], 'uniform').pvalue + pp = np.array([sum(credible_levels[key].values < xx) / len(credible_levels) for xx in x_values]) + pvalue = scipy.stats.kstest(credible_levels[key], "uniform").pvalue pvalues.append(pvalue) - logger.info("{}: {}".format(key, pvalue)) + logger.info(f"{key}: {pvalue}") try: name = results[0].priors[key].latex_label except (AttributeError, KeyError): name = key - label = "{} ({:2.3f})".format(name, pvalue) + label = f"{name} ({pvalue:2.3f})" plt.plot(x_values, pp, lines[ii], label=label, **kwargs) - Pvals = namedtuple('pvals', ['combined_pvalue', 'pvalues', 'names']) - pvals = Pvals(combined_pvalue=scipy.stats.combine_pvalues(pvalues)[1], - pvalues=pvalues, - names=list(credible_levels.keys())) - logger.info( - "Combined p-value: {}".format(pvals.combined_pvalue)) + Pvals = namedtuple("pvals", ["combined_pvalue", "pvalues", "names"]) + pvals = Pvals( + combined_pvalue=scipy.stats.combine_pvalues(pvalues)[1], pvalues=pvalues, names=list(credible_levels.keys()) + ) + logger.info(f"Combined p-value: {pvals.combined_pvalue}") if title: - ax.set_title("N={}, p-value={:2.4f}".format( - len(results), pvals.combined_pvalue)) + ax.set_title(f"N={len(results)}, p-value={pvals.combined_pvalue:2.4f}") ax.set_xlabel("C.I.") ax.set_ylabel("Fraction of events in C.I.") ax.legend(handlelength=2, labelspacing=0.25, fontsize=legend_fontsize) @@ -2312,14 +2389,14 @@ def make_pp_plot(results, filename=None, save=True, confidence_interval=[0.68, 0 fig.tight_layout() if save: if filename is None: - filename = 'outdir/pp.png' + filename = "outdir/pp.png" safe_save_figure(fig=fig, filename=filename, dpi=500) return fig, pvals def sanity_check_labels(labels): - """ Check labels for plotting to remove matplotlib errors """ + """Check labels for plotting to remove matplotlib errors""" for ii, lab in enumerate(labels): if "_" in lab and "$" not in lab: lab = lab.replace("_", "-") @@ -2328,16 +2405,16 @@ def sanity_check_labels(labels): class ResultError(Exception): - """ Base exception for all Result related errors """ + """Base exception for all Result related errors""" class ResultListError(ResultError): - """ For Errors occurring during combining results. """ + """For Errors occurring during combining results.""" class FileMovedError(ResultError): - """ Exceptions that occur when files have been moved """ + """Exceptions that occur when files have been moved""" class FileLoadError(ResultError): - """ Exceptions that occur when files cannot be loaded """ + """Exceptions that occur when files cannot be loaded""" diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 086e69388..6328d85a6 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -14,6 +14,16 @@ from . import proposal from .base_sampler import Sampler, SamplingMarginalisedParameterError +__all__ = [ + proposal, + Sampler, + "ImplementedSamplers", + "IMPLEMENTED_SAMPLERS", + "get_implemented_samplers", + "get_sampler_class", + "run_sampler", +] + class ImplementedSamplers: """Dictionary-like object that contains implemented samplers. @@ -67,10 +77,7 @@ def __getitem__(self, key): elif f"bilby.{key}" in self._samplers: return self._samplers[f"bilby.{key}"] else: - raise ValueError( - f"Sampler {key} is not implemented! " - f"Available samplers are: {list(self.keys())}" - ) + raise ValueError(f"Sampler {key} is not implemented! Available samplers are: {list(self.keys())}") def __contains__(self, value): return value in self.valid_keys() @@ -130,10 +137,7 @@ def get_sampler_class(sampler): print(sampler_class.__doc__) else: if sampler == "None": - print( - "For help with a specific sampler, call sampler-help with " - "the name of the sampler" - ) + print("For help with a specific sampler, call sampler-help with the name of the sampler") else: print(f"Requested sampler {sampler} not implemented") print(f"Available samplers = {get_implemented_samplers()}") @@ -235,14 +239,12 @@ def run_sampler( if command_line_args.clean: kwargs["resume"] = False - from . import IMPLEMENTED_SAMPLERS - if priors is None: priors = dict() _check_marginalized_parameters_not_sampled(likelihood, priors) - if type(priors) == dict: + if type(priors) is dict: priors = PriorDict(priors) elif isinstance(priors, PriorDict): pass @@ -298,8 +300,7 @@ def run_sampler( ) else: raise ValueError( - "Provided sampler should be a Sampler object or name of a known " - f"sampler: {get_implemented_samplers()}." + f"Provided sampler should be a Sampler object or name of a known sampler: {get_implemented_samplers()}." ) if sampler.cached_result: @@ -333,9 +334,7 @@ def run_sampler( result.log_bayes_factor = result.log_evidence - result.log_noise_evidence if None not in [result.injection_parameters, conversion_function]: - result.injection_parameters = conversion_function( - result.injection_parameters - ) + result.injection_parameters = conversion_function(result.injection_parameters) # Initial save of the sampler in case of failure in samples_to_posterior if save: diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 012517c83..b90de4aa3 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -108,7 +108,7 @@ def wrapped(self, *args, **kwargs): return wrapped -class Sampler(object): +class Sampler: """A sampler object to aid in setting up an inference run Parameters @@ -333,9 +333,7 @@ def _verify_external_sampler(self): try: __import__(external_sampler_name) except (ImportError, SystemExit): - raise SamplerNotInstalledError( - f"Sampler {external_sampler_name} is not installed on this system" - ) + raise SamplerNotInstalledError(f"Sampler {external_sampler_name} is not installed on this system") def _verify_kwargs_against_default_kwargs(self): """ @@ -360,10 +358,7 @@ def _initialise_parameters(self): the respective parameter is fixed. """ for key in self.priors: - if ( - isinstance(self.priors[key], Prior) - and self.priors[key].is_fixed is False - ): + if isinstance(self.priors[key], Prior) and self.priors[key].is_fixed is False: self._search_parameter_keys.append(key) elif isinstance(self.priors[key], Constraint): self._constraint_parameter_keys.append(key) @@ -378,9 +373,7 @@ def _log_information_about_priors_and_likelihood(self): for key in self._fixed_parameter_keys: logger.info(f"{key}={self.priors[key].peak}") logger.info(f"Analysis likelihood class: {self.likelihood.__class__}") - logger.info( - f"Analysis likelihood noise evidence: {self.likelihood.noise_log_likelihood()}" - ) + logger.info(f"Analysis likelihood noise evidence: {self.likelihood.noise_log_likelihood()}") def _initialise_result(self, result_class): """ @@ -425,23 +418,16 @@ def _verify_parameters(self): """ if self.priors.test_has_redundant_keys(): - raise IllegalSamplingSetError( - "Your sampling set contains redundant parameters." - ) + raise IllegalSamplingSetError("Your sampling set contains redundant parameters.") - theta = self.priors.sample_subset_constrained_as_array( - self.search_parameter_keys, size=1 - )[:, 0] + theta = self.priors.sample_subset_constrained_as_array(self.search_parameter_keys, size=1)[:, 0] try: self.log_likelihood(theta) except TypeError as e: params = deepcopy(self.parameters) - params.update( - {key: val for key, val in zip(self.search_parameter_keys, theta)} - ) + params.update({key: val for key, val in zip(self.search_parameter_keys, theta)}) raise TypeError( - f"Likelihood evaluation failed with message: \n'{e}'\n" - f"Have you specified all the parameters:\n{params}" + f"Likelihood evaluation failed with message: \n'{e}'\nHave you specified all the parameters:\n{params}" ) def _time_likelihood(self, n_evaluations=100): @@ -460,9 +446,7 @@ def _time_likelihood(self, n_evaluations=100): t1 = datetime.datetime.now() for _ in range(n_evaluations): - theta = self.priors.sample_subset_constrained_as_array( - self._search_parameter_keys, size=1 - )[:, 0] + theta = self.priors.sample_subset_constrained_as_array(self._search_parameter_keys, size=1)[:, 0] self.log_likelihood(theta) total_time = (datetime.datetime.now() - t1).total_seconds() log_likelihood_eval_time = total_time / n_evaluations @@ -471,9 +455,7 @@ def _time_likelihood(self, n_evaluations=100): log_likelihood_eval_time = np.nan logger.info("Unable to measure single likelihood time") else: - logger.info( - f"Single likelihood evaluation took {log_likelihood_eval_time:.3e} s" - ) + logger.info(f"Single likelihood evaluation took {log_likelihood_eval_time:.3e} s") return log_likelihood_eval_time def _verify_use_ratio(self): @@ -484,9 +466,7 @@ def _verify_use_ratio(self): try: self.priors.sample_subset(self.search_parameter_keys) except (KeyError, AttributeError): - logger.error( - f"Cannot sample from priors with keys: {self.search_parameter_keys}." - ) + logger.error(f"Cannot sample from priors with keys: {self.search_parameter_keys}.") raise if self.use_ratio is False: logger.debug("use_ratio set to False") @@ -495,15 +475,10 @@ def _verify_use_ratio(self): parameters = deepcopy(self.parameters) parameters.update(self.priors.sample()) - ratio_is_nan = np.isnan( - _safe_likelihood_call(self.likelihood, parameters, use_ratio=True) - ) + ratio_is_nan = np.isnan(_safe_likelihood_call(self.likelihood, parameters, use_ratio=True)) if self.use_ratio is True and ratio_is_nan: - logger.warning( - "You have requested to use the loglikelihood_ratio, but it " - " returns a NaN" - ) + logger.warning("You have requested to use the loglikelihood_ratio, but it returns a NaN") elif self.use_ratio is None and not ratio_is_nan: logger.debug("use_ratio not spec. but gives valid answer, setting True") self.use_ratio = True @@ -560,9 +535,7 @@ def log_likelihood(self, theta): params = deepcopy(self.parameters) params.update({key: t for key, t in zip(self._search_parameter_keys, theta)}) - return _safe_likelihood_call( - self.likelihood, parameters=params, use_ratio=self.use_ratio - ) + return _safe_likelihood_call(self.likelihood, parameters=params, use_ratio=self.use_ratio) def get_random_draw_from_prior(self): """Get a random draw from the prior distribution @@ -636,9 +609,7 @@ def check_draw(self, theta, warning=True): """ log_p = self.log_prior(theta) log_l = self.log_likelihood(theta) - return self._check_bad_value( - val=log_p, warning=warning, theta=theta, label="prior" - ) and self._check_bad_value( + return self._check_bad_value(val=log_p, warning=warning, theta=theta, label="prior") and self._check_bad_value( val=log_l, warning=warning, theta=theta, label="likelihood" ) @@ -674,10 +645,8 @@ def _check_cached_result(self, result_class=None): return try: - self.cached_result = read_in_result( - outdir=self.outdir, label=self.label, result_class=result_class - ) - except IOError: + self.cached_result = read_in_result(outdir=self.outdir, label=self.label, result_class=result_class) + except OSError: self.cached_result = None if command_line_args.use_cached: @@ -689,10 +658,7 @@ def _check_cached_result(self, result_class=None): check_keys = ["search_parameter_keys", "fixed_parameter_keys"] use_cache = True for key in check_keys: - if ( - self.cached_result._check_attribute_match_to_other_object(key, self) - is False - ): + if self.cached_result._check_attribute_match_to_other_object(key, self) is False: logger.debug(f"Cached value {key} is unmatched") use_cache = False try: @@ -717,9 +683,7 @@ def _log_summary_for_sampler(self): kwargs_print[k] = f"array_like, shape={array_repr.shape}" elif isinstance(kwargs_print[k], DataFrame): kwargs_print[k] = f"DataFrame, shape={kwargs_print[k].shape}" - logger.info( - f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}" - ) + logger.info(f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}") def calc_likelihood_count(self): if self.likelihood_benchmark: @@ -736,13 +700,9 @@ def npool(self): def _log_interruption(self, signum=None): if signum == 14: - logger.info( - f"Run interrupted by alarm signal {signum}: checkpoint and exit on {self.exit_code}" - ) + logger.info(f"Run interrupted by alarm signal {signum}: checkpoint and exit on {self.exit_code}") else: - logger.info( - f"Run interrupted by signal {signum}: checkpoint and exit on {self.exit_code}" - ) + logger.info(f"Run interrupted by signal {signum}: checkpoint and exit on {self.exit_code}") def write_current_state_and_exit(self, signum=None, frame=None): """ @@ -848,9 +808,7 @@ class NestedSampler(Sampler): walks_equiv_kwargs = ["walks", "steps", "nmcmc"] @staticmethod - def reorder_loglikelihoods( - unsorted_loglikelihoods, unsorted_samples, sorted_samples - ): + def reorder_loglikelihoods(unsorted_loglikelihoods, unsorted_samples, sorted_samples): """Reorders the stored log-likelihood after they have been reweighted This creates a sorting index by matching the reweights `result.samples` @@ -879,8 +837,7 @@ def reorder_loglikelihoods( idx = np.where(np.all(sorted_samples[ii] == unsorted_samples, axis=1))[0] if len(idx) > 1: logger.warning( - "Multiple likelihood matches found between sorted and " - "unsorted samples. Taking the first match." + "Multiple likelihood matches found between sorted and unsorted samples. Taking the first match." ) idxs.append(idx[0]) return unsorted_loglikelihoods[idxs] @@ -899,9 +856,7 @@ def log_likelihood(self, theta): ======= float: log_likelihood """ - if self.priors.evaluate_constraints( - {key: theta[ii] for ii, key in enumerate(self.search_parameter_keys)} - ): + if self.priors.evaluate_constraints({key: theta[ii] for ii, key in enumerate(self.search_parameter_keys)}): return Sampler.log_likelihood(self, theta) else: return np.nan_to_num(-np.inf) @@ -917,14 +872,9 @@ def print_nburn_logging_info(self): if type(self.nburn) in [float, int]: logger.info(f"Discarding {self.nburn} steps for burn-in") elif self.result.max_autocorrelation_time is None: - logger.info( - f"Autocorrelation time not calculated, discarding " - f"{self.nburn} steps for burn-in" - ) + logger.info(f"Autocorrelation time not calculated, discarding {self.nburn} steps for burn-in") else: - logger.info( - f"Discarding {self.nburn} steps for burn-in, estimated from autocorr" - ) + logger.info(f"Discarding {self.nburn} steps for burn-in, estimated from autocorr") def calculate_autocorrelation(self, samples, c=3): """Uses the `emcee.autocorr` module to estimate the autocorrelation @@ -940,9 +890,7 @@ def calculate_autocorrelation(self, samples, c=3): import emcee try: - self.result.max_autocorrelation_time = int( - np.max(emcee.autocorr.integrated_time(samples, c=c)) - ) + self.result.max_autocorrelation_time = int(np.max(emcee.autocorr.integrated_time(samples, c=c))) logger.info(f"Max autocorr time = {self.result.max_autocorrelation_time}") except emcee.autocorr.AutocorrError as e: self.result.max_autocorrelation_time = None @@ -962,7 +910,7 @@ class _TemporaryFileSamplerMixin: short_name = "" def __init__(self, temporary_directory, **kwargs): - super(_TemporaryFileSamplerMixin, self).__init__(**kwargs) + super().__init__(**kwargs) try: from mpi4py import MPI @@ -971,17 +919,14 @@ def __init__(self, temporary_directory, **kwargs): using_mpi = False if using_mpi and temporary_directory: - logger.info( - "Temporary directory incompatible with MPI, " - "will run in original directory" - ) + logger.info("Temporary directory incompatible with MPI, will run in original directory") self.use_temporary_directory = temporary_directory and not using_mpi self._outputfiles_basename = None self._temporary_outputfiles_basename = None def _check_and_load_sampling_time_file(self): if os.path.exists(self.time_file_path): - with open(self.time_file_path, "r") as time_file: + with open(self.time_file_path) as time_file: self.total_sampling_time = float(time_file.readline()) else: self.total_sampling_time = 0 @@ -1024,9 +969,7 @@ def temporary_outputfiles_basename(self, temporary_outputfiles_basename): temporary_outputfiles_basename += "/" self._temporary_outputfiles_basename = temporary_outputfiles_basename if os.path.exists(self.outputfiles_basename): - shutil.copytree( - self.outputfiles_basename, self.temporary_outputfiles_basename - ) + shutil.copytree(self.outputfiles_basename, self.temporary_outputfiles_basename) def write_current_state(self): self._calculate_and_save_sampling_time() @@ -1047,9 +990,7 @@ def _copy_temporary_directory_contents_to_proper_path(self): Copy the temporary back to the proper path. Do not delete the temporary directory. """ - logger.info( - f"Overwriting {self.outputfiles_basename} with {self.temporary_outputfiles_basename}" - ) + logger.info(f"Overwriting {self.outputfiles_basename} with {self.temporary_outputfiles_basename}") outputfiles_basename_stripped = self.outputfiles_basename.rstrip("/") shutil.copytree( self.temporary_outputfiles_basename, diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index e777ebc67..86fc3e041 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -97,10 +97,7 @@ def log_prior(x, **kwargs): return self.log_prior(theta) def _update_bounds(self): - self.bounds = [ - [self.priors[key].minimum, self.priors[key].maximum] - for key in self.names - ] + self.bounds = [[self.priors[key].minimum, self.priors[key].maximum] for key in self.names] def new_point(self): """Draw a point from the prior""" @@ -137,14 +134,10 @@ def new_point(self): out.plot() self.calc_likelihood_count() - self.result.samples = structured_to_unstructured( - out.posterior_samples[self.search_parameter_keys] - ) + self.result.samples = structured_to_unstructured(out.posterior_samples[self.search_parameter_keys]) self.result.log_likelihood_evaluations = out.posterior_samples["logL"] self.result.nested_samples = DataFrame(out.get_nested_samples(filename="")) - self.result.nested_samples.rename( - columns=dict(logL="log_likelihood"), inplace=True - ) + self.result.nested_samples.rename(columns=dict(logL="log_likelihood"), inplace=True) _, log_weights = compute_weights( np.array(self.result.nested_samples.log_likelihood), np.array(out.NS.state.nlive), @@ -184,14 +177,10 @@ def _resolve_proposal_functions(self): if self.kwargs["proposals"] is None: return if isinstance(self.kwargs["proposals"], JumpProposalCycle): - self.kwargs["proposals"] = dict( - mhs=self.kwargs["proposals"], hmc=self.kwargs["proposals"] - ) + self.kwargs["proposals"] = dict(mhs=self.kwargs["proposals"], hmc=self.kwargs["proposals"]) for key, proposal in self.kwargs["proposals"].items(): if isinstance(proposal, JumpProposalCycle): - self.kwargs["proposals"][key] = cpnest_proposal_cycle_factory( - proposal - ) + self.kwargs["proposals"][key] = cpnest_proposal_cycle_factory(proposal) elif isinstance(proposal, ProposalCycle): pass else: @@ -233,11 +222,9 @@ class CPNestProposalCycle(cpnest.proposal.ProposalCycle): def __init__(self): self.jump_proposals = copy.deepcopy(jump_proposals) for i, prop in enumerate(self.jump_proposals.proposal_functions): - self.jump_proposals.proposal_functions[i] = cpnest_proposal_factory( - prop - ) + self.jump_proposals.proposal_functions[i] = cpnest_proposal_factory(prop) self.jump_proposals.update_cycle() - super(CPNestProposalCycle, self).__init__( + super().__init__( proposals=self.jump_proposals.proposal_functions, weights=self.jump_proposals.weights, cyclelength=self.jump_proposals.cycle_length, diff --git a/bilby/core/sampler/dnest4.py b/bilby/core/sampler/dnest4.py index 0ec3a97dc..7c97229f3 100644 --- a/bilby/core/sampler/dnest4.py +++ b/bilby/core/sampler/dnest4.py @@ -8,7 +8,7 @@ from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper -class _DNest4Model(object): +class _DNest4Model: msg = ( "The DNest4 sampler interface in bilby is deprecated and will" " be removed in future release. Please use the `dnest4-bilby`" @@ -16,9 +16,7 @@ class _DNest4Model(object): ) warnings.warn(msg, FutureWarning) - def __init__( - self, log_likelihood_func, from_prior_func, widths, centers, highs, lows - ): + def __init__(self, log_likelihood_func, from_prior_func, widths, centers, highs, lows): """Initialize the DNest4 model. Args: log_likelihood_func: function @@ -64,14 +62,11 @@ def perturb(self, coords): @staticmethod def wrap(x, minimum, maximum): if maximum <= minimum: - raise ValueError( - f"maximum {maximum} <= minimum {minimum}, when trying to wrap coordinates" - ) + raise ValueError(f"maximum {maximum} <= minimum {minimum}, when trying to wrap coordinates") return (x - minimum) % (maximum - minimum) + minimum class DNest4(_TemporaryFileSamplerMixin, NestedSampler): - """ Bilby wrapper of DNest4 @@ -139,7 +134,7 @@ def __init__( temporary_directory=True, **kwargs, ): - super(DNest4, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -199,9 +194,7 @@ def _set_backend(self): import dnest4 if self._backend == "csv": - return dnest4.backends.CSVBackend( - f"{self.outdir}/dnest4{self.label}/", sep=" " - ) + return dnest4.backends.CSVBackend(f"{self.outdir}/dnest4{self.label}/", sep=" ") else: return dnest4.backends.MemoryBackend() @@ -222,9 +215,7 @@ def run_sampler(self): self.start_time = time.time() self.sampler = dnest4.DNest4Sampler(self._dnest4_model, backend=backend) - out = self.sampler.sample( - self.max_num_levels, num_particles=self.num_particles, **self.dnest4_kwargs - ) + out = self.sampler.sample(self.max_num_levels, num_particles=self.num_particles, **self.dnest4_kwargs) for i, sample in enumerate(out): if self._verbose and ((i + 1) % 100 == 0): @@ -257,4 +248,4 @@ def _translate_kwargs(self, kwargs): def _verify_kwargs_against_default_kwargs(self): self.outputfiles_basename = self.kwargs.pop("outputfiles_basename", None) - super(DNest4, self)._verify_kwargs_against_default_kwargs() + super()._verify_kwargs_against_default_kwargs() diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py index 752930e29..e7d43bbc3 100644 --- a/bilby/core/sampler/dynamic_dynesty.py +++ b/bilby/core/sampler/dynamic_dynesty.py @@ -51,7 +51,7 @@ def _remove_live(self): pass def read_saved_state(self, continuing=False): - resume = super(DynamicDynesty, self).read_saved_state(continuing=continuing) + resume = super().read_saved_state(continuing=continuing) if not resume: return resume else: diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 2771df874..c0fcaaea4 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -32,9 +32,7 @@ def _prior_transform_wrapper(theta): """Wrapper to the prior transformation. Needed for multiprocessing.""" from .base_sampler import _sampling_convenience_dump - return _sampling_convenience_dump.priors.rescale( - _sampling_convenience_dump.search_parameter_keys, theta - ) + return _sampling_convenience_dump.priors.rescale(_sampling_convenience_dump.search_parameter_keys, theta) def _log_likelihood_wrapper(theta): @@ -42,20 +40,10 @@ def _log_likelihood_wrapper(theta): from .base_sampler import _sampling_convenience_dump if _sampling_convenience_dump.priors.evaluate_constraints( - { - key: theta[ii] - for ii, key in enumerate(_sampling_convenience_dump.search_parameter_keys) - } + {key: theta[ii] for ii, key in enumerate(_sampling_convenience_dump.search_parameter_keys)} ): params = deepcopy(_sampling_convenience_dump.parameters) - params.update( - { - key: t - for key, t in zip( - _sampling_convenience_dump.search_parameter_keys, theta - ) - } - ) + params.update({key: t for key, t in zip(_sampling_convenience_dump.search_parameter_keys, theta)}) return _safe_likelihood_call( _sampling_convenience_dump.likelihood, params, @@ -165,11 +153,7 @@ class Dynesty(NestedSampler): @property def _dynesty_init_kwargs(self): params = inspect.signature(self.sampler_init).parameters - kwargs = { - key: param.default - for key, param in params.items() - if param.default != param.empty - } + kwargs = {key: param.default for key, param in params.items() if param.default != param.empty} kwargs["sample"] = "act-walk" kwargs["bound"] = "live" kwargs["update_interval"] = 600 @@ -179,11 +163,7 @@ def _dynesty_init_kwargs(self): @property def _dynesty_sampler_kwargs(self): params = inspect.signature(self.sampler_class.run_nested).parameters - kwargs = { - key: param.default - for key, param in params.items() - if param.default != param.empty - } + kwargs = {key: param.default for key, param in params.items() if param.default != param.empty} kwargs["save_bounds"] = False if "dlogz" in kwargs: kwargs["dlogz"] = 0.1 @@ -235,7 +215,7 @@ def __init__( self.proposals = proposals self.print_method = print_method self._translate_kwargs(kwargs) - super(Dynesty, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -260,9 +240,7 @@ def __init__( self.n_check_point = ( 10 if np.isnan(self._log_likelihood_eval_time) - else max( - int(check_point_delta_t / self._log_likelihood_eval_time / 10), 10 - ) + else max(int(check_point_delta_t / self._log_likelihood_eval_time / 10), 10) ) self.check_point_delta_t = check_point_delta_t logger.info(f"Checkpoint every check_point_delta_t = {check_point_delta_t}s") @@ -294,9 +272,7 @@ def sampler_init_kwargs(self): if kwargs["sample"] == "act-walk": internal_kwargs["nact"] = self.nact - internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk( - **internal_kwargs - ) + internal_sampler = dynesty_utils.ACTTrackingEnsembleWalk(**internal_kwargs) bound = "none" logger.info( f"Using the bilby-implemented ensemble rwalk sampling tracking the " @@ -315,9 +291,7 @@ def sampler_init_kwargs(self): ) elif kwargs["sample"] == "rwalk": internal_kwargs["nact"] = self.nact - internal_sampler = dynesty_utils.AcceptanceTrackingRWalk( - **internal_kwargs - ) + internal_sampler = dynesty_utils.AcceptanceTrackingRWalk(**internal_kwargs) bound = "none" logger.info( f"Using the bilby-implemented ensemble rwalk sampling method with ACT " @@ -360,9 +334,7 @@ def _translate_kwargs(self, kwargs): if "rstate" not in kwargs: kwargs["rstate"] = np.random.default_rng(seed) else: - logger.warning( - "Kwargs contain both 'rstate' and 'seed', ignoring 'seed'." - ) + logger.warning("Kwargs contain both 'rstate' and 'seed', ignoring 'seed'.") def _verify_kwargs_against_default_kwargs(self): if not self.kwargs["walks"]: @@ -371,9 +343,7 @@ def _verify_kwargs_against_default_kwargs(self): self.kwargs["print_func"] = self._print_func if "interval" in self.print_method: self._last_print_time = datetime.datetime.now() - self._print_interval = datetime.timedelta( - seconds=float(self.print_method.split("-")[1]) - ) + self._print_interval = datetime.timedelta(seconds=float(self.print_method.split("-")[1])) Sampler._verify_kwargs_against_default_kwargs(self) @classmethod @@ -540,8 +510,7 @@ def _set_sampling_method(self): "live-multi", ]: logger.info( - "Live-point based bound method requested with dynesty sample " - f"'{sample}', overwriting to 'multi'" + f"Live-point based bound method requested with dynesty sample '{sample}', overwriting to 'multi'" ) self.kwargs["bound"] = "multi" elif bound == "live": @@ -551,25 +520,15 @@ def _set_sampling_method(self): elif bound == "live-multi": from .dynesty_utils import MultiEllipsoidLivePointSampler - dynesty.dynamicsampler._SAMPLERS[ - "live-multi" - ] = MultiEllipsoidLivePointSampler + dynesty.dynamicsampler._SAMPLERS["live-multi"] = MultiEllipsoidLivePointSampler elif sample == "acceptance-walk": - raise DynestySetupError( - "bound must be set to live or live-multi for sample=acceptance-walk" - ) + raise DynestySetupError("bound must be set to live or live-multi for sample=acceptance-walk") elif self.proposals is None: - logger.warning( - "No proposals specified using dynesty sampling, defaulting " - "to 'volumetric'." - ) + logger.warning("No proposals specified using dynesty sampling, defaulting to 'volumetric'.") self.proposals = ["volumetric"] _SamplingContainer.proposals = self.proposals elif "diff" in self.proposals: - raise DynestySetupError( - "bound must be set to live or live-multi to use differential " - "evolution proposals" - ) + raise DynestySetupError("bound must be set to live or live-multi to use differential evolution proposals") if sample == "rwalk": logger.info( @@ -624,9 +583,7 @@ def run_sampler(self): logger.info("Resume file successfully loaded.") else: if self.kwargs["live_points"] is None: - self.kwargs["live_points"] = self.get_initial_points_from_prior( - self.nlive - ) + self.kwargs["live_points"] = self.get_initial_points_from_prior(self.nlive) self.kwargs["live_points"] = (*self.kwargs["live_points"], None) self.sampler = self.sampler_init( loglikelihood=_log_likelihood_wrapper, @@ -672,17 +629,13 @@ def _setup_pool(self): every process. To make sure we get every process, run the kwarg setting more times than we have processes. """ - super(Dynesty, self)._setup_pool() + super()._setup_pool() if self.new_dynesty_api: return if self.pool is not None: - args = ( - [(self.nact, self.maxmcmc, self.proposals, self.naccept)] - * self.npool - * 10 - ) + args = [(self.nact, self.maxmcmc, self.proposals, self.naccept)] * self.npool * 10 self.pool.map(_set_sampling_kwargs, args) def _generate_result(self, out): @@ -711,9 +664,7 @@ def _generate_result(self, out): keep = weights > random.rng.uniform(0, max(weights), len(weights)) self.result.samples = out.samples[keep] self.result.log_likelihood_evaluations = out.logl[keep] - logger.info( - f"Rejection sampling nested samples to obtain {sum(keep)} posterior samples" - ) + logger.info(f"Rejection sampling nested samples to obtain {sum(keep)} posterior samples") else: self.result.samples = dynesty.utils.resample_equal(out.samples, weights) self.result.log_likelihood_evaluations = self.reorder_loglikelihoods( @@ -785,9 +736,7 @@ def _run_external_sampler_with_checkpointing(self): if os.path.isfile(self.resume_file): last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file) else: - last_checkpoint_s = ( - datetime.datetime.now() - self.start_time - ).total_seconds() + last_checkpoint_s = (datetime.datetime.now() - self.start_time).total_seconds() if last_checkpoint_s > self.check_point_delta_t: self.write_current_state() self._add_live() @@ -907,7 +856,7 @@ def read_saved_state(self, continuing=False): def write_current_state_and_exit(self, signum=None, frame=None): if self.pbar is not None: self.pbar = self.pbar.close() - super(Dynesty, self).write_current_state_and_exit(signum=signum, frame=frame) + super().write_current_state_and_exit(signum=signum, frame=frame) def write_current_state(self): """ @@ -949,10 +898,7 @@ def write_current_state(self): safe_file_dump((self.sampler, versions, metadata), self.resume_file, dill) logger.info(f"Written checkpoint file {self.resume_file}") else: - logger.warning( - "Cannot write pickle resume file! " - "Job will not resume if interrupted." - ) + logger.warning("Cannot write pickle resume file! Job will not resume if interrupted.") self.sampler.pool = self.pool if self.sampler.pool is not None: if self.new_dynesty_api: @@ -1015,8 +961,7 @@ def plot_current_state(self): logger.warning("Failed to create dynesty state plot at checkpoint") except Exception as e: logger.warning( - f"Unexpected error {e} in dynesty plotting. " - "Please report at github.com/bilby-dev/bilby/issues" + f"Unexpected error {e} in dynesty plotting. Please report at github.com/bilby-dev/bilby/issues" ) finally: plt.close("all") @@ -1041,16 +986,13 @@ def plot_current_state(self): logger.warning("Failed to create dynesty unit state plot at checkpoint") except Exception as e: logger.warning( - f"Unexpected error {e} in dynesty plotting. " - "Please report at github.com/bilby-dev/bilby/issues" + f"Unexpected error {e} in dynesty plotting. Please report at github.com/bilby-dev/bilby/issues" ) finally: plt.close("all") try: filename = f"{self.outdir}/{self.label}_checkpoint_run.png" - fig, _ = dyplot.runplot( - self.sampler.results, logplot=False, use_math_text=False - ) + fig, _ = dyplot.runplot(self.sampler.results, logplot=False, use_math_text=False) fig.tight_layout() plt.savefig(filename) except ( @@ -1063,8 +1005,7 @@ def plot_current_state(self): logger.warning("Failed to create dynesty run plot at checkpoint") except Exception as e: logger.warning( - f"Unexpected error {e} in dynesty plotting. " - "Please report at github.com/bilby-dev/bilby/issues" + f"Unexpected error {e} in dynesty plotting. Please report at github.com/bilby-dev/bilby/issues" ) finally: plt.close("all") @@ -1080,8 +1021,7 @@ def plot_current_state(self): logger.debug("Cannot create Dynesty stats plot with dynamic sampler.") except Exception as e: logger.warning( - f"Unexpected error {e} in dynesty plotting. " - "Please report at github.com/bilby-dev/bilby/issues" + f"Unexpected error {e} in dynesty plotting. Please report at github.com/bilby-dev/bilby/issues" ) finally: plt.close("all") @@ -1112,9 +1052,7 @@ def _run_test(self): self.pbar = self.pbar.close() print("") N = 100 - self.result.samples = pd.DataFrame(self.priors.sample(N))[ - self.search_parameter_keys - ].values + self.result.samples = pd.DataFrame(self.priors.sample(N))[self.search_parameter_keys].values self.result.nested_samples = self.result.samples self.result.log_likelihood_evaluations = np.ones(N) self.result.log_evidence = 1 @@ -1213,9 +1151,7 @@ def dynesty_stats_plot(sampler): axs[-1].legend() axs[-1].set_yscale("log") else: - axs[-2].plot( - np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey" - ) + axs[-2].plot(np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey") axs[-2].plot( np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], diff --git a/bilby/core/sampler/dynesty3_utils.py b/bilby/core/sampler/dynesty3_utils.py index 26e668f58..863d91245 100644 --- a/bilby/core/sampler/dynesty3_utils.py +++ b/bilby/core/sampler/dynesty3_utils.py @@ -190,18 +190,14 @@ def sample(args): evaluation_history = list() for prop in proposals: - u_prop = proposal_funcs[prop]( - u=current_u, **common_kwargs, **proposal_kwargs[prop] - ) + u_prop = proposal_funcs[prop](u=current_u, **common_kwargs, **proposal_kwargs[prop]) u_prop = apply_boundaries_(u_prop=u_prop, **boundary_kwargs) if u_prop is None: continue v_prop = args.prior_transform(u_prop) logl_prop = args.loglikelihood(v_prop) - evaluation_history.append( - SamplerHistoryItem(u=v_prop, v=u_prop, logl=logl_prop) - ) + evaluation_history.append(SamplerHistoryItem(u=v_prop, v=u_prop, logl=logl_prop)) ncall += 1 if logl_prop > args.loglstar: @@ -381,9 +377,7 @@ def build_cache(args): # Initialize internal variables current_v = args.prior_transform(np.array(current_u)) logl = args.loglikelihood(np.array(current_v)) - evaluation_history.append( - SamplerHistoryItem(u=current_u, v=current_v, logl=logl) - ) + evaluation_history.append(SamplerHistoryItem(u=current_u, v=current_v, logl=logl)) accept = 0 reject = 0 nfail = 0 @@ -400,17 +394,13 @@ def build_cache(args): iteration += 1 prop = proposals[iteration % len(proposals)] - u_prop = proposal_funcs[prop]( - u=current_u, **common_kwargs, **proposal_kwargs[prop] - ) + u_prop = proposal_funcs[prop](u=current_u, **common_kwargs, **proposal_kwargs[prop]) u_prop = apply_boundaries_(u_prop=u_prop, **boundary_kwargs) success = False if u_prop is not None: v_prop = args.prior_transform(np.array(u_prop)) logl_prop = args.loglikelihood(np.array(v_prop)) - evaluation_history.append( - SamplerHistoryItem(u=v_prop, v=u_prop, logl=logl_prop) - ) + evaluation_history.append(SamplerHistoryItem(u=v_prop, v=u_prop, logl=logl_prop)) ncall += 1 if logl_prop > args.loglstar: success = True @@ -465,20 +455,14 @@ def build_cache(args): cache = ACTTrackingEnsembleWalk._cache if accept == 0: - logger.warning( - "Unable to find a new point using walk: returning a random point" - ) + logger.warning("Unable to find a new point using walk: returning a random point") u = common_kwargs["rstate"].uniform(size=len(current_u)) v = args.prior_transform(u) logl = args.loglikelihood(v) - evaluation_history = [ - SamplerHistoryItem(u=current_u, v=current_v, logl=logl) - ] + evaluation_history = [SamplerHistoryItem(u=current_u, v=current_v, logl=logl)] cache.append((u, v, logl, ncall, blob, evaluation_history)) elif not np.isfinite(act): - logger.warning( - "Unable to find a new point using walk: try increasing maxmcmc" - ) + logger.warning("Unable to find a new point using walk: try increasing maxmcmc") cache.append((current_u, current_v, logl, ncall, blob, evaluation_history)) elif (thin == -1) or (len(u_list) <= thin): cache.append((current_u, current_v, logl, ncall, blob, evaluation_history)) @@ -486,18 +470,13 @@ def build_cache(args): u_list = u_list[thin::thin] v_list = v_list[thin::thin] logl_list = logl_list[thin::thin] - evaluation_history_list = ( - evaluation_history[thin * ii : thin * (ii + 1)] - for ii in range(len(u_list)) - ) + evaluation_history_list = (evaluation_history[thin * ii : thin * (ii + 1)] for ii in range(len(u_list))) n_found = len(u_list) accept = max(accept // n_found, 1) reject //= n_found nfail //= n_found ncall_list = [ncall // n_found] * n_found - blob_list = [ - dict(accept=accept, reject=reject, fail=nfail, act=act) - ] * n_found + blob_list = [dict(accept=accept, reject=reject, fail=nfail, act=act)] * n_found cache.extend( zip( u_list, @@ -597,9 +576,7 @@ def sample(args): iteration += 1 prop = proposals[iteration % len(proposals)] - u_prop = proposal_funcs[prop]( - current_u, **common_kwargs, **proposal_kwargs[prop] - ) + u_prop = proposal_funcs[prop](current_u, **common_kwargs, **proposal_kwargs[prop]) u_prop = apply_boundaries_(u_prop, **boundary_kwargs) if u_prop is None: @@ -609,9 +586,7 @@ def sample(args): # Check proposed point. v_prop = args.prior_transform(np.array(u_prop)) logl_prop = args.loglikelihood(np.array(v_prop)) - evaluation_history.append( - SamplerHistoryItem(u=v_prop, v=u_prop, logl=logl_prop) - ) + evaluation_history.append(SamplerHistoryItem(u=v_prop, v=u_prop, logl=logl_prop)) if logl_prop > args.loglstar: current_u = u_prop current_v = v_prop @@ -639,9 +614,7 @@ def sample(args): break if not (np.isfinite(act) and accept > 0): - logger.debug( - "Unable to find a new point using walk: returning a random point" - ) + logger.debug("Unable to find a new point using walk: returning a random point") current_u = rstate.uniform(size=len(current_u)) current_v = args.prior_transform(current_u) logl = args.loglikelihood(current_v) @@ -935,6 +908,4 @@ def apply_boundaries_(u_prop, periodic, reflective): return u_prop -proposal_funcs = dict( - diff=propose_differential_evolution, volumetric=propose_volumetric -) +proposal_funcs = dict(diff=propose_differential_evolution, volumetric=propose_volumetric) diff --git a/bilby/core/sampler/dynesty_utils.py b/bilby/core/sampler/dynesty_utils.py index 22e725aaf..b3e3f7be1 100644 --- a/bilby/core/sampler/dynesty_utils.py +++ b/bilby/core/sampler/dynesty_utils.py @@ -89,9 +89,7 @@ class MultiEllipsoidLivePointSampler(MultiEllipsoidSampler): def update_user(self, blob, update=True): LivePointSampler.update_user(self, blob=blob, update=update) - super(MultiEllipsoidLivePointSampler, self).update_rwalk( - blob=blob, update=update - ) + super().update_rwalk(blob=blob, update=update) update_rwalk = update_user @@ -102,7 +100,7 @@ def propose_live(self, *args): """ self.kwargs["nlive"] = self.nlive self.kwargs["live"] = self.live_u - return super(MultiEllipsoidLivePointSampler, self).propose_live(*args) + return super().propose_live(*args) class FixedRWalk: @@ -127,9 +125,7 @@ def __call__(self, args): accepted = list() for prop in proposals: - u_prop = proposal_funcs[prop]( - u=current_u, **common_kwargs, **proposal_kwargs[prop] - ) + u_prop = proposal_funcs[prop](u=current_u, **common_kwargs, **proposal_kwargs[prop]) u_prop = apply_boundaries_(u_prop=u_prop, **boundary_kwargs) if u_prop is None: accepted.append(0) @@ -244,9 +240,7 @@ def build_cache(self): iteration += 1 prop = proposals[iteration % len(proposals)] - u_prop = proposal_funcs[prop]( - u=current_u, **common_kwargs, **proposal_kwargs[prop] - ) + u_prop = proposal_funcs[prop](u=current_u, **common_kwargs, **proposal_kwargs[prop]) u_prop = apply_boundaries_(u_prop=u_prop, **boundary_kwargs) success = False if u_prop is not None: @@ -304,17 +298,13 @@ def build_cache(self): thin = self.thin * iact if accept == 0: - logger.warning( - "Unable to find a new point using walk: returning a random point" - ) + logger.warning("Unable to find a new point using walk: returning a random point") u = common_kwargs["rstate"].uniform(size=len(current_u)) v = args.prior_transform(u) logl = args.loglikelihood(v) self._cache.append((u, v, logl, ncall, blob)) elif not np.isfinite(act): - logger.warning( - "Unable to find a new point using walk: try increasing maxmcmc" - ) + logger.warning("Unable to find a new point using walk: try increasing maxmcmc") self._cache.append((current_u, current_v, logl, ncall, blob)) elif (self.thin == -1) or (len(u_list) <= thin): self._cache.append((current_u, current_v, logl, ncall, blob)) @@ -327,9 +317,7 @@ def build_cache(self): reject //= n_found nfail //= n_found ncall_list = [ncall // n_found] * n_found - blob_list = [ - dict(accept=accept, reject=reject, fail=nfail, scale=args.scale) - ] * n_found + blob_list = [dict(accept=accept, reject=reject, fail=nfail, scale=args.scale)] * n_found self._cache.extend(zip(u_list, v_list, logl_list, ncall_list, blob_list)) logger.debug( f"act: {self.act:.2f}, max failures: {most_failures}, thin: {thin}, " @@ -442,9 +430,7 @@ def __call__(self, args): break if not (np.isfinite(act) and accept > 0): - logger.debug( - "Unable to find a new point using walk: returning a random point" - ) + logger.debug("Unable to find a new point using walk: returning a random point") u = rstate.uniform(size=len(u)) v = args.prior_transform(u) logl = args.loglikelihood(v) diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 2c12ee354..d09a6090e 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -80,7 +80,7 @@ def __init__( **kwargs, ): self._check_version() - super(Emcee, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -139,10 +139,7 @@ def sampler_function_kwargs(self): if self.prerelease: if function_kwargs["mh_proposal"] is not None: - logger.warning( - "The 'mh_proposal' option is no longer used " - "in emcee > 2.2.1, and will be ignored." - ) + logger.warning("The 'mh_proposal' option is no longer used in emcee > 2.2.1, and will be ignored.") del function_kwargs["mh_proposal"] for key in updatekeys: @@ -155,11 +152,7 @@ def sampler_function_kwargs(self): @property def sampler_init_kwargs(self): - init_kwargs = { - key: value - for key, value in self.kwargs.items() - if key not in self.sampler_function_kwargs - } + init_kwargs = {key: value for key, value in self.kwargs.items() if key not in self.sampler_function_kwargs} init_kwargs["lnpostfn"] = _evaluator.call_emcee init_kwargs["dim"] = self.ndim @@ -192,10 +185,7 @@ def nburn(self): def nburn(self, nburn): if isinstance(nburn, (float, int)): if nburn > self.kwargs["iterations"] - 1: - raise ValueError( - "Number of burn-in samples must be smaller " - "than the total number of iterations" - ) + raise ValueError("Number of burn-in samples must be smaller than the total number of iterations") self.__nburn = nburn @@ -249,20 +239,14 @@ def checkpoint_info(self): to write the chain data to disk """ - out_dir = os.path.join( - self.outdir, f"{self.__class__.__name__.lower()}_{self.label}" - ) + out_dir = os.path.join(self.outdir, f"{self.__class__.__name__.lower()}_{self.label}") check_directory_exists_and_if_not_mkdir(out_dir) chain_file = os.path.join(out_dir, "chain.dat") sampler_file = os.path.join(out_dir, "sampler.pickle") - chain_template = ( - "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n" - ) + chain_template = "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n" - CheckpointInfo = namedtuple( - "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"] - ) + CheckpointInfo = namedtuple("CheckpointInfo", ["sampler_file", "chain_file", "chain_template"]) checkpoint_info = CheckpointInfo( sampler_file=sampler_file, @@ -284,9 +268,7 @@ def write_current_state(self): Overwrites the stored sampler chain with one that is truncated to only the completed steps """ - logger.info( - f"Checkpointing sampler to file {self.checkpoint_info.sampler_file}" - ) + logger.info(f"Checkpointing sampler to file {self.checkpoint_info.sampler_file}") self.sampler._chain = self.sampler_chain _pool = self.sampler.pool self.sampler.pool = None @@ -317,9 +299,7 @@ def sampler(self): ): import dill - logger.info( - f"Resuming run from checkpoint file {self.checkpoint_info.sampler_file}" - ) + logger.info(f"Resuming run from checkpoint file {self.checkpoint_info.sampler_file}") with open(self.checkpoint_info.sampler_file, "rb") as f: self._sampler = dill.load(f) self._sampler.pool = self.pool @@ -337,8 +317,7 @@ def write_chains_to_file(self, sample): else: points = np.hstack([sample[0], np.array(sample[3])]) data_to_write = "\n".join( - self.checkpoint_info.chain_template.format(ii, *point) - for ii, point in enumerate(points) + self.checkpoint_info.chain_template.format(ii, *point) for ii, point in enumerate(points) ) with open(temp_chain_file, "w") as ff: ff.write(data_to_write) @@ -359,9 +338,7 @@ def _previous_iterations(self): return 0 def _draw_pos0_from_prior(self): - return np.array( - [self.get_random_draw_from_prior() for _ in range(self.nwalkers)] - ) + return np.array([self.get_random_draw_from_prior() for _ in range(self.nwalkers)]) @property def _pos0_shape(self): @@ -418,9 +395,7 @@ def run_sampler(self): self._generate_result() - self.result.samples = self.sampler.chain[:, self.nburn :, :].reshape( - (-1, self.ndim) - ) + self.result.samples = self.sampler.chain[:, self.nburn :, :].reshape((-1, self.ndim)) self.result.walkers = self.sampler.chain return self.result diff --git a/bilby/core/sampler/fake_sampler.py b/bilby/core/sampler/fake_sampler.py index c1d285605..c394cad40 100644 --- a/bilby/core/sampler/fake_sampler.py +++ b/bilby/core/sampler/fake_sampler.py @@ -20,9 +20,7 @@ class FakeSampler(Sampler): sampler_name = "fake_sampler" - default_kwargs = dict( - verbose=True, logl_args=None, logl_kwargs=None, print_progress=True - ) + default_kwargs = dict(verbose=True, logl_args=None, logl_kwargs=None, print_progress=True) def __init__( self, @@ -36,9 +34,9 @@ def __init__( injection_parameters=None, meta_data=None, result_class=None, - **kwargs + **kwargs, ): - super(FakeSampler, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -49,7 +47,7 @@ def __init__( injection_parameters=None, meta_data=None, result_class=None, - **kwargs + **kwargs, ) self._read_parameter_list_from_file(sample_file) self.result.outdir = outdir diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py index 467751959..87d270f2d 100644 --- a/bilby/core/sampler/kombine.py +++ b/bilby/core/sampler/kombine.py @@ -78,7 +78,7 @@ def __init__( autoburnin=False, **kwargs, ): - super(Kombine, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -140,8 +140,7 @@ def sampler_init_kwargs(self): init_kwargs = { key: value for key, value in self.kwargs.items() - if key not in self.sampler_function_kwargs - and key not in self.sampler_burnin_kwargs + if key not in self.sampler_function_kwargs and key not in self.sampler_burnin_kwargs } init_kwargs.pop("burnin_verbose") init_kwargs["lnpostfn"] = _evaluator.call_emcee @@ -183,9 +182,7 @@ def run_sampler(self): self.sampler.burnin(**self.sampler_burnin_kwargs) self.kwargs["iterations"] += self._previous_iterations self.nburn = self._previous_iterations - logger.info( - f"Kombine auto-burnin complete. Removing {self.nburn} samples from chains" - ) + logger.info(f"Kombine auto-burnin complete. Removing {self.nburn} samples from chains") self._set_pos0_for_resume() from tqdm.auto import tqdm @@ -212,15 +209,13 @@ def run_sampler(self): tmp_chain = self.sampler.chain[self.nburn :, :, :].copy() self.result.samples = tmp_chain.reshape((-1, self.ndim)) - self.result.walkers = self.sampler.chain.reshape( - (self.nwalkers, self.nsteps, self.ndim) - ) + self.result.walkers = self.sampler.chain.reshape((self.nwalkers, self.nsteps, self.ndim)) return self.result def _setup_pool(self): from kombine import SerialPool - super(Kombine, self)._setup_pool() + super()._setup_pool() if self.pool is None: self.pool = SerialPool() @@ -229,4 +224,4 @@ def _close_pool(self): if isinstance(self.pool, SerialPool): self.pool = None - super(Kombine, self)._close_pool() + super()._close_pool() diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py index 440e4a02f..a8067524f 100644 --- a/bilby/core/sampler/nessai.py +++ b/bilby/core/sampler/nessai.py @@ -143,10 +143,7 @@ def log_prior(x, **kwargs): return self.log_prior(theta) def _update_bounds(self): - self.bounds = { - key: [self.priors[key].minimum, self.priors[key].maximum] - for key in self.names - } + self.bounds = {key: [self.priors[key].minimum, self.priors[key].maximum] for key in self.names} def new_point(self, N=1): """Draw a point from the prior""" @@ -212,9 +209,7 @@ def update_result(self): self.result.num_likelihood_evaluations = self.fs.ns.total_likelihood_evaluations self.result.sampling_time = self.fs.ns.sampling_time - self.result.samples = live_points_to_array( - self.fs.posterior_samples, self.search_parameter_keys - ) + self.result.samples = live_points_to_array(self.fs.posterior_samples, self.search_parameter_keys) self.result.log_likelihood_evaluations = self.fs.posterior_samples["logL"] self.result.nested_samples = self.get_nested_samples() self.result.nested_samples["weights"] = self.get_posterior_weights() @@ -291,9 +286,7 @@ def _verify_kwargs_against_default_kwargs(self): self.kwargs["plot"] = self.plot if not self.kwargs["output"]: - self.kwargs["output"] = os.path.join( - self.outdir, f"{self.label}_nessai", "" - ) + self.kwargs["output"] = os.path.join(self.outdir, f"{self.label}_nessai", "") check_directory_exists_and_if_not_mkdir(self.kwargs["output"]) NestedSampler._verify_kwargs_against_default_kwargs(self) diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py index 75d93bf69..3a5b2f3c4 100644 --- a/bilby/core/sampler/nestle.py +++ b/bilby/core/sampler/nestle.py @@ -77,18 +77,13 @@ def run_sampler(self): nestle.np.int = int out = nestle.sample( - loglikelihood=self.log_likelihood, - prior_transform=self.prior_transform, - ndim=self.ndim, - **self.kwargs + loglikelihood=self.log_likelihood, prior_transform=self.prior_transform, ndim=self.ndim, **self.kwargs ) print("") self.result.sampler_output = out self.result.samples = nestle.resample_equal(out.samples, out.weights) - self.result.nested_samples = DataFrame( - out.samples, columns=self.search_parameter_keys - ) + self.result.nested_samples = DataFrame(out.samples, columns=self.search_parameter_keys) self.result.nested_samples["weights"] = out.weights self.result.nested_samples["log_likelihood"] = out.logl self.result.log_likelihood_evaluations = self.reorder_loglikelihoods( diff --git a/bilby/core/sampler/polychord.py b/bilby/core/sampler/polychord.py index 7c0886720..ea77a8cb7 100644 --- a/bilby/core/sampler/polychord.py +++ b/bilby/core/sampler/polychord.py @@ -7,7 +7,6 @@ class PyPolyChord(NestedSampler): - """ Bilby wrapper of PyPolyChord https://arxiv.org/abs/1506.00171 @@ -86,9 +85,7 @@ def run_sampler(self): pc_kwargs["base_dir"] = self._sample_file_directory pc_kwargs["file_root"] = self.label pc_kwargs.pop("use_polychord_defaults") - settings = PolyChordSettings( - nDims=self.ndim, nDerived=self.ndim, **pc_kwargs - ) + settings = PolyChordSettings(nDims=self.ndim, nDerived=self.ndim, **pc_kwargs) self._verify_kwargs_against_default_kwargs() out = pypolychord.run_polychord( loglikelihood=self.log_likelihood, @@ -125,7 +122,7 @@ def _translate_kwargs(self, kwargs): def log_likelihood(self, theta): """Overrides the log_likelihood so that PolyChord understands it""" - return super(PyPolyChord, self).log_likelihood(theta), theta + return super().log_likelihood(theta), theta def _read_sample_file(self): """ @@ -138,9 +135,7 @@ def _read_sample_file(self): array_like, array_like: The log_likelihoods and the associated parameters """ - sample_file = ( - self._sample_file_directory + "/" + self.label + "_equal_weights.txt" - ) + sample_file = self._sample_file_directory + "/" + self.label + "_equal_weights.txt" samples = np.loadtxt(sample_file) log_likelihoods = -0.5 * samples[:, 1] physical_parameters = samples[:, -self.ndim :] diff --git a/bilby/core/sampler/proposal.py b/bilby/core/sampler/proposal.py index d23d19b4c..4e615354f 100644 --- a/bilby/core/sampler/proposal.py +++ b/bilby/core/sampler/proposal.py @@ -10,7 +10,7 @@ class Sample(dict): def __init__(self, dictionary=None): if dictionary is None: dictionary = dict() - super(Sample, self).__init__(dictionary) + super().__init__(dictionary) def __add__(self, other): return Sample({key: self[key] + other[key] for key in self.keys()}) @@ -35,7 +35,7 @@ def from_external_type(cls, external_sample, sampler_name): return external_sample -class JumpProposal(object): +class JumpProposal: def __init__(self, priors=None): """A generic class for jump proposals @@ -69,20 +69,13 @@ def __call__(self, sample, **kwargs): return self._apply_boundaries(sample) def _move_reflecting_keys(self, sample): - keys = [ - key for key in sample.keys() if self.priors[key].boundary == "reflective" - ] + keys = [key for key in sample.keys() if self.priors[key].boundary == "reflective"] for key in keys: - if ( - sample[key] > self.priors[key].maximum - or sample[key] < self.priors[key].minimum - ): + if sample[key] > self.priors[key].maximum or sample[key] < self.priors[key].minimum: r = self.priors[key].maximum - self.priors[key].minimum delta = (sample[key] - self.priors[key].minimum) % (2 * r) if delta > r: - sample[key] = ( - 2 * self.priors[key].maximum - self.priors[key].minimum - delta - ) + sample[key] = 2 * self.priors[key].maximum - self.priors[key].minimum - delta elif delta < r: sample[key] = self.priors[key].minimum + delta return sample @@ -90,13 +83,9 @@ def _move_reflecting_keys(self, sample): def _move_periodic_keys(self, sample): keys = [key for key in sample.keys() if self.priors[key].boundary == "periodic"] for key in keys: - if ( - sample[key] > self.priors[key].maximum - or sample[key] < self.priors[key].minimum - ): + if sample[key] > self.priors[key].maximum or sample[key] < self.priors[key].minimum: sample[key] = self.priors[key].minimum + ( - (sample[key] - self.priors[key].minimum) - % (self.priors[key].maximum - self.priors[key].minimum) + (sample[key] - self.priors[key].minimum) % (self.priors[key].maximum - self.priors[key].minimum) ) return sample @@ -106,7 +95,7 @@ def _apply_boundaries(self, sample): return sample -class JumpProposalCycle(object): +class JumpProposalCycle: def __init__(self, proposal_functions, weights, cycle_length=100): """A generic wrapper class for proposal cycles @@ -190,22 +179,18 @@ def __init__(self, step_size, priors=None): priors: See superclass """ - super(NormJump, self).__init__(priors) + super().__init__(priors) self.step_size = step_size def __call__(self, sample, **kwargs): for key in sample.keys(): sample[key] = random.rng.normal(sample[key], self.step_size) - return super(NormJump, self).__call__(sample) + return super().__call__(sample) class EnsembleWalk(JumpProposal): def __init__( - self, - random_number_generator=random.rng.uniform, - n_points=3, - priors=None, - **random_number_generator_args + self, random_number_generator=random.rng.uniform, n_points=3, priors=None, **random_number_generator_args ): """ An ensemble walk @@ -221,7 +206,7 @@ def __init__( random_number_generator_args: Additional keyword arguments for the random number generator """ - super(EnsembleWalk, self).__init__(priors) + super().__init__(priors) self.random_number_generator = random_number_generator self.n_points = n_points self.random_number_generator_args = random_number_generator_args @@ -229,15 +214,11 @@ def __init__( def __call__(self, sample, **kwargs): subset = random.rng.choice(kwargs["coordinates"], self.n_points, replace=False) for i in range(len(subset)): - subset[i] = Sample.from_external_type( - subset[i], kwargs.get("sampler_name", None) - ) + subset[i] = Sample.from_external_type(subset[i], kwargs.get("sampler_name", None)) center_of_mass = self.get_center_of_mass(subset) for x in subset: - sample += (x - center_of_mass) * self.random_number_generator( - **self.random_number_generator_args - ) - return super(EnsembleWalk, self).__call__(sample) + sample += (x - center_of_mass) * self.random_number_generator(**self.random_number_generator_args) + return super().__call__(sample) @staticmethod def get_center_of_mass(subset): @@ -254,18 +235,16 @@ def __init__(self, scale=2.0, priors=None): scale: float, optional Stretching scale. Default is 2.0. """ - super(EnsembleStretch, self).__init__(priors) + super().__init__(priors) self.scale = scale def __call__(self, sample, **kwargs): second_sample = random.rng.choice(kwargs["coordinates"]) - second_sample = Sample.from_external_type( - second_sample, kwargs.get("sampler_name", None) - ) + second_sample = Sample.from_external_type(second_sample, kwargs.get("sampler_name", None)) step = random.rng.uniform(-1, 1) * np.log(self.scale) sample = second_sample + (sample - second_sample) * np.exp(step) self.log_j = len(sample) * step - return super(EnsembleStretch, self).__call__(sample) + return super().__call__(sample) class DifferentialEvolution(JumpProposal): @@ -281,14 +260,14 @@ def __init__(self, sigma=1e-4, mu=1.0, priors=None): mu: float, optional Scale of the randomization. Default is 1.0 """ - super(DifferentialEvolution, self).__init__(priors) + super().__init__(priors) self.sigma = sigma self.mu = mu def __call__(self, sample, **kwargs): a, b = random.rng.choice(kwargs["coordinates"], 2, replace=False) sample = sample + (b - a) * random.rng.normal(self.mu, self.sigma) - return super(DifferentialEvolution, self).__call__(sample) + return super().__call__(sample) class EnsembleEigenVector(JumpProposal): @@ -301,7 +280,7 @@ def __init__(self, priors=None): priors: See superclass """ - super(EnsembleEigenVector, self).__init__(priors) + super().__init__(priors) self.eigen_values = None self.eigen_vectors = None self.covariance = None @@ -338,7 +317,7 @@ def __call__(self, sample, **kwargs): jump_size = np.sqrt(np.fabs(self.eigen_values[i])) * random.rng.normal(0, 1) for j, key in enumerate(sample.keys()): sample[key] += jump_size * self.eigen_vectors[j, i] - return super(EnsembleEigenVector, self).__call__(sample) + return super().__call__(sample) class DrawFlatPrior(JumpProposal): @@ -348,7 +327,7 @@ class DrawFlatPrior(JumpProposal): def __call__(self, sample, **kwargs): sample = _draw_from_flat_priors(sample, self.priors) - return super(DrawFlatPrior, self).__call__(sample) + return super().__call__(sample) def _draw_from_flat_priors(sample, priors): diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index d5f8d1009..7af0b7418 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -180,7 +180,7 @@ def __init__( verbose=True, **kwargs, ): - super(Ptemcee, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -236,18 +236,10 @@ def __init__( self.store_walkers = store_walkers self.pos0 = pos0 - self._periodic = [ - self.priors[key].boundary == "periodic" - for key in self.search_parameter_keys - ] + self._periodic = [self.priors[key].boundary == "periodic" for key in self.search_parameter_keys] self.priors.sample() - self._minima = np.array( - [self.priors[key].minimum for key in self.search_parameter_keys] - ) - self._range = ( - np.array([self.priors[key].maximum for key in self.search_parameter_keys]) - - self._minima - ) + self._minima = np.array([self.priors[key].minimum for key in self.search_parameter_keys]) + self._range = np.array([self.priors[key].maximum for key in self.search_parameter_keys]) - self._minima self.log10beta_min = log10beta_min if self.log10beta_min is not None: @@ -281,11 +273,7 @@ def sampler_function_kwargs(self): @property def sampler_init_kwargs(self): """Kwargs passed to initialize ptemcee.Sampler()""" - return { - key: value - for key, value in self.kwargs.items() - if key not in self.sampler_function_kwargs - } + return {key: value for key, value in self.kwargs.items() if key not in self.sampler_function_kwargs} def _translate_kwargs(self, kwargs): """Translate kwargs""" @@ -306,10 +294,7 @@ def get_pos0_from_prior(self): """ logger.info("Generating pos0 samples") return np.array( - [ - [self.get_random_draw_from_prior() for _ in range(self.nwalkers)] - for _ in range(self.kwargs["ntemps"]) - ] + [[self.get_random_draw_from_prior() for _ in range(self.nwalkers)] for _ in range(self.kwargs["ntemps"])] ) def get_pos0_from_minimize(self, minimize_list=None): @@ -349,10 +334,7 @@ def neg_log_like(params): return +np.inf # Bounds used in the minimization - bounds = [ - (self.priors[key].minimum, self.priors[key].maximum) - for key in minimize_list - ] + bounds = [(self.priors[key].minimum, self.priors[key].maximum) for key in minimize_list] # Run the minimization step several times to get a range of values trials = 0 @@ -361,9 +343,7 @@ def neg_log_like(params): draw = self.priors.sample() likelihood_copy.parameters.update(draw) x0 = [draw[key] for key in minimize_list] - res = minimize( - neg_log_like, x0, bounds=bounds, method="L-BFGS-B", tol=1e-15 - ) + res = minimize(neg_log_like, x0, bounds=bounds, method="L-BFGS-B", tol=1e-15) if res.success: success.append(res.x) if trials > 100: @@ -415,11 +395,7 @@ def setup_sampler(self): # This is a very ugly hack to support numpy>=1.24 ptemcee.sampler.np.float = float - if ( - os.path.isfile(self.resume_file) - and os.path.getsize(self.resume_file) - and self.resume is True - ): + if os.path.isfile(self.resume_file) and os.path.getsize(self.resume_file) and self.resume is True: import dill logger.info(f"Resume data {self.resume_file} found") @@ -506,7 +482,7 @@ def _close_pool(self): self.sampler.pool = None if "pool" in self.result.sampler_kwargs: del self.result.sampler_kwargs["pool"] - super(Ptemcee, self)._close_pool() + super()._close_pool() @signal_wrapper def run_sampler(self): @@ -532,24 +508,16 @@ def run_sampler(self): ) if self.iteration == self.chain_array.shape[1]: - self.chain_array = np.concatenate( - (self.chain_array, self.get_zero_chain_array()), axis=1 - ) - self.log_likelihood_array = np.concatenate( - (self.log_likelihood_array, self.get_zero_array()), axis=2 - ) - self.log_posterior_array = np.concatenate( - (self.log_posterior_array, self.get_zero_array()), axis=2 - ) + self.chain_array = np.concatenate((self.chain_array, self.get_zero_chain_array()), axis=1) + self.log_likelihood_array = np.concatenate((self.log_likelihood_array, self.get_zero_array()), axis=2) + self.log_posterior_array = np.concatenate((self.log_posterior_array, self.get_zero_array()), axis=2) self.pos0 = pos0 self.chain_array[:, self.iteration, :] = pos0[0, :, :] self.log_likelihood_array[:, :, self.iteration] = log_likelihood self.log_posterior_array[:, :, self.iteration] = log_posterior - self.mean_log_posterior = np.mean( - self.log_posterior_array[:, :, : self.iteration], axis=1 - ) + self.mean_log_posterior = np.mean(self.log_posterior_array[:, :, : self.iteration], axis=1) # (nwalkers, ntemps, iterations) # so mean_log_posterior is shaped (nwalkers, iterations) @@ -567,9 +535,7 @@ def run_sampler(self): logger.debug(f"Minimum iteration = {minimum_iteration}") # Calculate the maximum discard number - discard_max = np.max( - [self.convergence_inputs.burn_in_fixed_discard, minimum_iteration] - ) + discard_max = np.max([self.convergence_inputs.burn_in_fixed_discard, minimum_iteration]) if self.iteration > discard_max + self.nwalkers: # If we have taken more than nwalkers steps after the discard @@ -618,13 +584,13 @@ def run_sampler(self): self.write_current_state(plot=self.check_point_plot) # Get 0-likelihood samples and store in the result - self.result.samples = self.chain_array[ - :, self.discard + self.nburn : self.iteration : self.thin, : - ].reshape((-1, self.ndim)) + self.result.samples = self.chain_array[:, self.discard + self.nburn : self.iteration : self.thin, :].reshape( + (-1, self.ndim) + ) loglikelihood = self.log_likelihood_array[ 0, :, self.discard + self.nburn : self.iteration : self.thin ] # nwalkers, nsteps - self.result.log_likelihood_evaluations = loglikelihood.reshape((-1)) + self.result.log_likelihood_evaluations = loglikelihood.reshape(-1) if self.store_walkers: self.result.walkers = self.sampler.chain @@ -644,9 +610,7 @@ def run_sampler(self): self.result.log_evidence = log_evidence self.result.log_evidence_err = log_evidence_err - self.result.sampling_time = datetime.timedelta( - seconds=np.sum(self.time_per_check) - ) + self.result.sampling_time = datetime.timedelta(seconds=np.sum(self.time_per_check)) self._close_pool() @@ -903,9 +867,8 @@ def check_iteration( # Calculate convergence boolean converged = gelman_rubin_statistic < ci.Q_tol and ci.nsamples < nsamples_effective logger.debug( - "Convergence: Q GRAD_WINDOW_LENGTH: - gradient_tau = get_max_gradient( - check_taus, axis=0, window_length=GRAD_WINDOW_LENGTH - ) + gradient_tau = get_max_gradient(check_taus, axis=0, window_length=GRAD_WINDOW_LENGTH) if gradient_tau < ci.gradient_tau: - logger.debug( - f"tau usable as {gradient_tau} < gradient_tau={ci.gradient_tau}" - ) + logger.debug(f"tau usable as {gradient_tau} < gradient_tau={ci.gradient_tau}") tau_usable = True else: - logger.debug( - f"tau not usable as {gradient_tau} > gradient_tau={ci.gradient_tau}" - ) + logger.debug(f"tau not usable as {gradient_tau} > gradient_tau={ci.gradient_tau}") tau_usable = False check_mean_log_posterior = mean_log_posterior[:, -nsteps_to_check:] @@ -1027,11 +984,7 @@ def get_max_gradient(x, axis=0, window_length=11, polyorder=2, smooth=False): if smooth: x = savgol_filter(x, axis=axis, window_length=window_length, polyorder=3) - return np.max( - savgol_filter( - x, axis=axis, window_length=window_length, polyorder=polyorder, deriv=1 - ) - ) + return np.max(savgol_filter(x, axis=axis, window_length=window_length, polyorder=polyorder, deriv=1)) def get_Q_convergence(samples): @@ -1112,11 +1065,7 @@ def print_progress( tswap_acceptance_str = f"{np.min(tswap_acceptance_fraction):1.2f}-{np.max(tswap_acceptance_fraction):1.2f}" ave_time_per_check = np.mean(time_per_check[-3:]) - time_left = ( - (convergence_inputs.nsamples - nsamples_effective) - * ave_time_per_check - / samples_per_check - ) + time_left = (convergence_inputs.nsamples - nsamples_effective) * ave_time_per_check / samples_per_check if time_left > 0: time_left = str(datetime.timedelta(seconds=int(time_left))) else: @@ -1133,16 +1082,9 @@ def print_progress( Q_str = f"{Q:0.2f}" - evals_per_check = ( - sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check - ) + evals_per_check = sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check - approximate_ncalls = ( - convergence_inputs.niterations_per_check - * iteration - * sampler.nwalkers - * sampler.ntemps - ) + approximate_ncalls = convergence_inputs.niterations_per_check * iteration * sampler.nwalkers * sampler.ntemps ncalls = f"{approximate_ncalls:1.1e}" eval_timing = f"{1000.0 * ave_time_per_check / evals_per_check:1.2f}ms/ev" @@ -1194,9 +1136,7 @@ def calculate_tau_array(samples, search_parameter_keys, ci): if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key: continue try: - tau_array[ii, jj] = emcee.autocorr.integrated_time( - samples[ii, :, jj], c=ci.autocorr_c, tol=0 - )[0] + tau_array[ii, jj] = emcee.autocorr.integrated_time(samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0] except emcee.autocorr.AutocorrError: tau_array[ii, jj] = np.inf return tau_array @@ -1229,9 +1169,7 @@ def checkpoint( # Store the samples if possible if nsamples_effective > 0: filename = f"{outdir}/{label}_samples.txt" - samples = np.array(chain_array)[ - :, discard + nburn : iteration : thin, : - ].reshape((-1, ndim)) + samples = np.array(chain_array)[:, discard + nburn : iteration : thin, :].reshape((-1, ndim)) df = pd.DataFrame(samples, columns=search_parameter_keys) df.to_csv(filename, index=False, header=True, sep=" ") @@ -1305,9 +1243,7 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label, discard= color="C0", **scatter_kwargs, ) - axh.hist( - walkers[:, discard + nburn :: thin, i].reshape((-1)), bins=50, alpha=0.8 - ) + axh.hist(walkers[:, discard + nburn :: thin, i].reshape(-1), bins=50, alpha=0.8) for i, (ax, axh) in enumerate(axes): axh.set_xlabel(parameter_labels[i]) @@ -1389,8 +1325,7 @@ def compute_evidence( if any(np.isinf(mean_lnlikes)): logger.warning( - "mean_lnlikes contains inf: recalculating without" - f" the {len(betas[np.isinf(mean_lnlikes)])} infs" + f"mean_lnlikes contains inf: recalculating without the {len(betas[np.isinf(mean_lnlikes)])} infs" ) idxs = np.isinf(mean_lnlikes) mean_lnlikes = mean_lnlikes[~idxs] @@ -1414,8 +1349,7 @@ def compute_evidence( ax2.semilogx(min_betas, evidence, "-o") ax2.set_ylabel( - r"$\int_{\beta_{min}}^{\beta=1}" - + r"\langle \log(\mathcal{L})\rangle d\beta$", + r"$\int_{\beta_{min}}^{\beta=1}" + r"\langle \log(\mathcal{L})\rangle d\beta$", size=16, ) ax2.set_xlabel(r"$\beta_{min}$") @@ -1431,7 +1365,7 @@ def do_nothing_function(): pass -class LikePriorEvaluator(object): +class LikePriorEvaluator: """ This class is copied and modified from ptemcee.LikePriorEvaluator, see https://github.com/willvousden/ptemcee for the original version @@ -1448,15 +1382,10 @@ def __init__(self): def _setup_periodic(self): priors = _sampling_convenience_dump.priors search_parameter_keys = _sampling_convenience_dump.search_parameter_keys - self._periodic = [ - priors[key].boundary == "periodic" for key in search_parameter_keys - ] + self._periodic = [priors[key].boundary == "periodic" for key in search_parameter_keys] priors.sample() self._minima = np.array([priors[key].minimum for key in search_parameter_keys]) - self._range = ( - np.array([priors[key].maximum for key in search_parameter_keys]) - - self._minima - ) + self._range = np.array([priors[key].maximum for key in search_parameter_keys]) - self._minima self.periodic_set = True def _wrap_periodic(self, array): @@ -1478,9 +1407,7 @@ def logl(self, v_array): parameters = _sampling_convenience_dump.parameters.copy() parameters.update({key: v for key, v in zip(search_parameter_keys, v_array)}) if priors.evaluate_constraints(parameters) > 0: - return _safe_likelihood_call( - likelihood, parameters, _sampling_convenience_dump.use_ratio - ) + return _safe_likelihood_call(likelihood, parameters, _sampling_convenience_dump.use_ratio) else: return np.nan_to_num(-np.inf) diff --git a/bilby/core/sampler/ptmcmc.py b/bilby/core/sampler/ptmcmc.py index f2a771cb0..1c5f4a7ea 100644 --- a/bilby/core/sampler/ptmcmc.py +++ b/bilby/core/sampler/ptmcmc.py @@ -87,8 +87,7 @@ def __init__( skip_import_verification=False, **kwargs, ): - - super(PTMCMCSampler, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -113,9 +112,7 @@ def _verify_external_sampler(self): try: __import__(external_sampler_name) except (ImportError, SystemExit): - raise SamplerNotInstalledError( - f"Sampler {external_sampler_name} is not installed on this system" - ) + raise SamplerNotInstalledError(f"Sampler {external_sampler_name} is not installed on this system") def _translate_kwargs(self, kwargs): kwargs = super()._translate_kwargs(kwargs) @@ -193,9 +190,7 @@ def run_sampler(self): ) if self.custom_proposals is not None: for proposal in self.custom_proposals: - logger.info( - f"Adding {proposal} to proposals with weight {self.custom_proposals[proposal][1]}" - ) + logger.info(f"Adding {proposal} to proposals with weight {self.custom_proposals[proposal][1]}") sampler.addProposalToCycle( self.custom_proposals[proposal][0], self.custom_proposals[proposal][1], diff --git a/bilby/core/sampler/pymc.py b/bilby/core/sampler/pymc.py index 5d263a54d..4b175dbe3 100644 --- a/bilby/core/sampler/pymc.py +++ b/bilby/core/sampler/pymc.py @@ -108,7 +108,7 @@ def __init__( self.default_step_kwargs = {m.__name__.lower(): None for m in STEP_METHODS} self.default_kwargs.update(self.default_step_kwargs) - super(Pymc, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -234,9 +234,7 @@ def setup_prior_mapping(self): prior_map["Cosine"] = {"internal": self._cosine_prior} prior_map["PowerLaw"] = {"internal": self._powerlaw_prior} prior_map["LogUniform"] = {"internal": self._powerlaw_prior} - prior_map["MultivariateGaussian"] = { - "internal": self._multivariate_normal_prior - } + prior_map["MultivariateGaussian"] = {"internal": self._multivariate_normal_prior} prior_map["MultivariateNormal"] = {"internal": self._multivariate_normal_prior} def _deltafunction_prior(self, key, **kwargs): @@ -270,15 +268,12 @@ def __init__(self, lower=0.0, upper=np.pi): self.upper = upper = tt.as_tensor_variable(floatX(upper)) self.norm = tt.cos(lower) - tt.cos(upper) self.mean = ( - tt.sin(upper) - + lower * tt.cos(lower) - - tt.sin(lower) - - upper * tt.cos(upper) + tt.sin(upper) + lower * tt.cos(lower) - tt.sin(lower) - upper * tt.cos(upper) ) / self.norm transform = pymc.distributions.transforms.interval(lower, upper) - super(PymcSine, self).__init__(transform=transform) + super().__init__(transform=transform) def logp(self, value): upper = self.upper @@ -289,9 +284,7 @@ def logp(self, value): value <= upper, ) - return PymcSine( - key, lower=self.priors[key].minimum, upper=self.priors[key].maximum - ) + return PymcSine(key, lower=self.priors[key].minimum, upper=self.priors[key].maximum) else: raise ValueError(f"Prior for '{key}' is not a Sine") @@ -314,15 +307,12 @@ def __init__(self, lower=-np.pi / 2.0, upper=np.pi / 2.0): self.upper = upper = tt.as_tensor_variable(floatX(upper)) self.norm = tt.sin(upper) - tt.sin(lower) self.mean = ( - upper * tt.sin(upper) - + tt.cos(upper) - - lower * tt.sin(lower) - - tt.cos(lower) + upper * tt.sin(upper) + tt.cos(upper) - lower * tt.sin(lower) - tt.cos(lower) ) / self.norm transform = pymc.distributions.transforms.interval(lower, upper) - super(PymcCosine, self).__init__(transform=transform) + super().__init__(transform=transform) def logp(self, value): upper = self.upper @@ -333,9 +323,7 @@ def logp(self, value): value <= upper, ) - return PymcCosine( - key, lower=self.priors[key].minimum, upper=self.priors[key].maximum - ) + return PymcCosine(key, lower=self.priors[key].minimum, upper=self.priors[key].maximum) else: raise ValueError(f"Prior for '{key}' is not a Cosine") @@ -348,7 +336,6 @@ def _powerlaw_prior(self, key): pymc, _, floatX = self._import_external_sampler() _, tt, _ = self._import_tensor() if isinstance(self.priors[key], PowerLaw): - # check power law is set if not hasattr(self.priors[key], "alpha"): raise AttributeError("No 'alpha' attribute set for PowerLaw prior") @@ -373,16 +360,11 @@ def __init__(self, lower, upper, alpha, testval=1): self.norm = 1.0 / (tt.log(self.upper / self.lower)) else: beta = 1.0 + self.alpha - self.norm = 1.0 / ( - beta - * (tt.pow(self.upper, beta) - tt.pow(self.lower, beta)) - ) + self.norm = 1.0 / (beta * (tt.pow(self.upper, beta) - tt.pow(self.lower, beta))) transform = pymc.distributions.transforms.interval(lower, upper) - super(PymcPowerLaw, self).__init__( - transform=transform, testval=testval - ) + super().__init__(transform=transform, testval=testval) def logp(self, value): upper = self.upper @@ -499,9 +481,7 @@ def run_sampler(self): elif isinstance(self.step_method, dict): for key in self.step_method: if key not in self._search_parameter_keys: - raise ValueError( - f"Setting a step method for an unknown parameter '{key}'" - ) + raise ValueError(f"Setting a step method for an unknown parameter '{key}'") else: # check if using a compound step (a list of step # methods for a particular parameter) @@ -511,9 +491,7 @@ def run_sampler(self): sms = [self.step_method[key]] for sm in sms: if sm.lower() not in step_methods: - raise ValueError( - f"Using invalid step method '{self.step_method[key]}'" - ) + raise ValueError(f"Using invalid step method '{self.step_method[key]}'") else: # check if using a compound step (a list of step # methods for a particular parameter) @@ -615,9 +593,7 @@ def run_sampler(self): for sms in self.step_method: curmethod = sms.lower() methodslist.append(curmethod) - args, nuts_kwargs = self._create_args_and_nuts_kwargs( - curmethod, nuts_kwargs, step_kwargs - ) + args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs) compound.append(pymc.__dict__[step_methods[curmethod]](**args)) self.kwargs["step"] = compound else: @@ -625,12 +601,8 @@ def run_sampler(self): if self.step_method is not None: curmethod = self.step_method.lower() methodslist.append(curmethod) - args, nuts_kwargs = self._create_args_and_nuts_kwargs( - curmethod, nuts_kwargs, step_kwargs - ) - self.kwargs["step"] = pymc.__dict__[step_methods[curmethod]]( - **args - ) + args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs) + self.kwargs["step"] = pymc.__dict__[step_methods[curmethod]](**args) # check whether only NUTS step method has been assigned if np.all([sm.lower() == "nuts" for sm in methodslist]): @@ -653,9 +625,7 @@ def run_sampler(self): posterior = trace.posterior.to_dataframe().reset_index() self.result.samples = posterior[self.search_parameter_keys] - self.result.log_likelihood_evaluations = np.sum( - trace.log_likelihood.likelihood.values, axis=-1 - ).flatten() + self.result.log_likelihood_evaluations = np.sum(trace.log_likelihood.likelihood.values, axis=-1).flatten() self.result.sampler_output = np.nan self.calculate_autocorrelation(self.result.samples) self.result.log_evidence = np.nan @@ -670,9 +640,7 @@ def _create_args_and_nuts_kwargs(self, curmethod, nuts_kwargs, step_kwargs): args = step_kwargs.get(curmethod, {}) return args, nuts_kwargs - def _create_nuts_kwargs( - self, curmethod, key, nuts_kwargs, pymc, step_kwargs, step_methods - ): + def _create_nuts_kwargs(self, curmethod, key, nuts_kwargs, pymc, step_kwargs, step_methods): if curmethod == "nuts": args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs) else: @@ -680,9 +648,7 @@ def _create_nuts_kwargs( args = step_kwargs.get(curmethod, {}) else: args = {} - self.kwargs["step"].append( - pymc.__dict__[step_methods[curmethod]](vars=[self.pymc_priors[key]], **args) - ) + self.kwargs["step"].append(pymc.__dict__[step_methods[curmethod]](vars=[self.pymc_priors[key]], **args)) return nuts_kwargs @staticmethod @@ -721,46 +687,30 @@ def set_prior(self): try: self.pymc_priors[key] = self.priors[key].ln_prob(sampler=self) except RuntimeError: - raise RuntimeError((f"Problem setting PyMC prior for '{key}'")) + raise RuntimeError(f"Problem setting PyMC prior for '{key}'") else: # use Prior distribution name distname = self.priors[key].__class__.__name__ if distname in self.prior_map: # check if we have a predefined PyMC distribution - if ( - "pymc" in self.prior_map[distname] - and "argmap" in self.prior_map[distname] - ): + if "pymc" in self.prior_map[distname] and "argmap" in self.prior_map[distname]: # check the required arguments for the PyMC distribution pymcdistname = self.prior_map[distname]["pymc"] if pymcdistname not in pymc.__dict__: - raise ValueError( - f"Prior '{pymcdistname}' is not a known PyMC distribution." - ) + raise ValueError(f"Prior '{pymcdistname}' is not a known PyMC distribution.") - reqargs = infer_args_from_method( - pymc.__dict__[pymcdistname].dist - ) + reqargs = infer_args_from_method(pymc.__dict__[pymcdistname].dist) # set keyword arguments priorkwargs = {} - for (targ, parg) in self.prior_map[distname][ - "argmap" - ].items(): + for targ, parg in self.prior_map[distname]["argmap"].items(): if hasattr(self.priors[key], targ): if parg in reqargs: if "argtransform" in self.prior_map[distname]: - if ( - targ - in self.prior_map[distname][ - "argtransform" - ] - ): - tfunc = self.prior_map[distname][ - "argtransform" - ][targ] + if targ in self.prior_map[distname]["argtransform"]: + tfunc = self.prior_map[distname]["argtransform"][targ] else: def tfunc(x): @@ -771,29 +721,19 @@ def tfunc(x): def tfunc(x): return x - priorkwargs[parg] = tfunc( - getattr(self.priors[key], targ) - ) + priorkwargs[parg] = tfunc(getattr(self.priors[key], targ)) else: raise ValueError(f"Unknown argument {parg}") else: if parg in reqargs: priorkwargs[parg] = None - self.pymc_priors[key] = pymc.__dict__[pymcdistname]( - key, **priorkwargs - ) + self.pymc_priors[key] = pymc.__dict__[pymcdistname](key, **priorkwargs) elif "internal" in self.prior_map[distname]: - self.pymc_priors[key] = self.prior_map[distname][ - "internal" - ](key) + self.pymc_priors[key] = self.prior_map[distname]["internal"](key) else: - raise ValueError( - f"Prior '{distname}' is not a known distribution." - ) + raise ValueError(f"Prior '{distname}' is not a known distribution.") else: - raise ValueError( - f"Prior '{distname}' is not a known distribution." - ) + raise ValueError(f"Prior '{distname}' is not a known distribution.") def set_likelihood(self): """ @@ -805,7 +745,6 @@ def set_likelihood(self): _, tt, _ = self._import_tensor() class LogLike(tt.Op): - itypes = [tt.dvector] otypes = [tt.dscalar] @@ -819,18 +758,14 @@ def __init__(self, parameters, loglike, priors): if isinstance(self.priors[key], float): self.parameters[key] = self.priors[key] - self.logpgrad = LogLikeGrad( - self.parameters, self.likelihood, self.priors - ) + self.logpgrad = LogLikeGrad(self.parameters, self.likelihood, self.priors) def perform(self, node, inputs, outputs): (theta,) = inputs for i, key in enumerate(self.parameters): self.parameters[key] = theta[i] - logl = _safe_likelihood_call( - self.likelihood.log_likelihood, self.parameters - ) + logl = _safe_likelihood_call(self.likelihood.log_likelihood, self.parameters) outputs[0][0] = np.array(logl) def grad(self, inputs, g): @@ -839,7 +774,6 @@ def grad(self, inputs, g): # create Op for calculating the gradient of the log likelihood class LogLikeGrad(tt.Op): - itypes = [tt.dvector] otypes = [tt.dvector] @@ -861,14 +795,10 @@ def perform(self, node, inputs, outputs): def lnlike(values): for i, key in enumerate(self.parameters): self.parameters[key] = values[i] - return _safe_likelihood_call( - self.likelihood.log_likelihood, self.parameters - ) + return _safe_likelihood_call(self.likelihood.log_likelihood, self.parameters) # calculate gradients - grads = derivatives( - theta, lnlike, abseps=1e-5, mineps=1e-12, reltol=1e-2 - ) + grads = derivatives(theta, lnlike, abseps=1e-5, mineps=1e-12, reltol=1e-2) outputs[0][0] = grads @@ -881,9 +811,7 @@ def lnlike(values): or not hasattr(self.likelihood, "x") or not hasattr(self.likelihood, "y") ): - raise ValueError( - "Gaussian Likelihood does not have all the correct attributes!" - ) + raise ValueError("Gaussian Likelihood does not have all the correct attributes!") if "sigma" in self.pymc_priors: # if sigma is suppled use that value @@ -907,12 +835,8 @@ def lnlike(values): ) elif isinstance(self.likelihood, PoissonLikelihood): # check required attributes exist - if not hasattr(self.likelihood, "x") or not hasattr( - self.likelihood, "y" - ): - raise ValueError( - "Poisson Likelihood does not have all the correct attributes!" - ) + if not hasattr(self.likelihood, "x") or not hasattr(self.likelihood, "y"): + raise ValueError("Poisson Likelihood does not have all the correct attributes!") for key in self.pymc_priors: if key not in self.likelihood.function_keys: @@ -925,12 +849,8 @@ def lnlike(values): pymc.Poisson("likelihood", mu=model, observed=self.likelihood.y) elif isinstance(self.likelihood, ExponentialLikelihood): # check required attributes exist - if not hasattr(self.likelihood, "x") or not hasattr( - self.likelihood, "y" - ): - raise ValueError( - "Exponential Likelihood does not have all the correct attributes!" - ) + if not hasattr(self.likelihood, "x") or not hasattr(self.likelihood, "y"): + raise ValueError("Exponential Likelihood does not have all the correct attributes!") for key in self.pymc_priors: if key not in self.likelihood.function_keys: @@ -940,9 +860,7 @@ def lnlike(values): model = self.likelihood.func(self.likelihood.x, **self.pymc_priors) # set the distribution - pymc.Exponential( - "likelihood", lam=1.0 / model, observed=self.likelihood.y - ) + pymc.Exponential("likelihood", lam=1.0 / model, observed=self.likelihood.y) elif isinstance(self.likelihood, StudentTLikelihood): # check required attributes exist if ( @@ -951,9 +869,7 @@ def lnlike(values): or not hasattr(self.likelihood, "nu") or not hasattr(self.likelihood, "sigma") ): - raise ValueError( - "StudentT Likelihood does not have all the correct attributes!" - ) + raise ValueError("StudentT Likelihood does not have all the correct attributes!") if "nu" in self.pymc_priors: # if nu is suppled use that value @@ -981,24 +897,18 @@ def lnlike(values): (GravitationalWaveTransient, BasicGravitationalWaveTransient), ): # set theano Op - pass _search_parameter_keys, which only contains non-fixed variables - logl = LogLike( - self._search_parameter_keys, self.likelihood, self.pymc_priors - ) + logl = LogLike(self._search_parameter_keys, self.likelihood, self.pymc_priors) parameters = dict() for key in self._search_parameter_keys: try: parameters[key] = self.pymc_priors[key] except KeyError: - raise KeyError( - f"Unknown key '{key}' when setting GravitationalWaveTransient likelihood" - ) + raise KeyError(f"Unknown key '{key}' when setting GravitationalWaveTransient likelihood") # convert to tensor variable values = tt.as_tensor_variable(list(parameters.values())) - pymc.DensityDist( - "likelihood", lambda v: logl(v), observed={"v": values} - ) + pymc.DensityDist("likelihood", lambda v: logl(v), observed={"v": values}) else: raise ValueError("Unknown likelihood has been provided") diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py index 0a9bb0aaf..f2a022d07 100644 --- a/bilby/core/sampler/pymultinest.py +++ b/bilby/core/sampler/pymultinest.py @@ -76,9 +76,9 @@ def __init__( exit_code=77, skip_import_verification=False, temporary_directory=True, - **kwargs + **kwargs, ): - super(Pymultinest, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -88,7 +88,7 @@ def __init__( skip_import_verification=skip_import_verification, exit_code=exit_code, temporary_directory=temporary_directory, - **kwargs + **kwargs, ) self._apply_multinest_boundaries() @@ -145,15 +145,13 @@ def run_sampler(self): LogLikelihood=self.log_likelihood, Prior=self.prior_transform, n_dims=self.ndim, - **self.kwargs + **self.kwargs, ) self._calculate_and_save_sampling_time() self._clean_up_run_directory() - post_equal_weights = os.path.join( - self.outputfiles_basename, "post_equal_weights.dat" - ) + post_equal_weights = os.path.join(self.outputfiles_basename, "post_equal_weights.dat") post_equal_weights_data = np.loadtxt(post_equal_weights) self.result.log_likelihood_evaluations = post_equal_weights_data[:, -1] self.result.sampler_output = out @@ -187,7 +185,6 @@ def _nested_samples(self): nested_samples = pd.DataFrame( np.vstack([dead_points, live_points]).copy(), - columns=self.search_parameter_keys - + ["log_likelihood", "log_prior_volume", "mode"], + columns=self.search_parameter_keys + ["log_likelihood", "log_prior_volume", "mode"], ) return nested_samples diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py index 756ca2c96..f2295b820 100644 --- a/bilby/core/sampler/ultranest.py +++ b/bilby/core/sampler/ultranest.py @@ -86,7 +86,7 @@ def __init__( callback_interval=10, **kwargs, ): - super(Ultranest, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -136,10 +136,7 @@ def _viz_callback(self, *args, **kwargs): self._viz_callback_counter += 1 def _apply_ultranest_boundaries(self): - if ( - self.kwargs["wrapped_params"] is None - or len(self.kwargs.get("wrapped_params", [])) == 0 - ): + if self.kwargs["wrapped_params"] is None or len(self.kwargs.get("wrapped_params", [])) == 0: self.kwargs["wrapped_params"] = [] for param, value in self.priors.items(): if param in self.search_parameter_keys: @@ -154,7 +151,7 @@ def _copy_temporary_directory_contents_to_proper_path(self): Do not delete the temporary directory. """ if inspect.stack()[1].function != "_viz_callback": - super(Ultranest, self)._copy_temporary_directory_contents_to_proper_path() + super()._copy_temporary_directory_contents_to_proper_path() @property def sampler_function_kwargs(self): @@ -244,8 +241,7 @@ def run_sampler(self): sampler.stepsampler = stepsampler else: logger.warning( - "The supplied step sampler is not the correct type. " - "The default step sampling will be used instead." + "The supplied step sampler is not the correct type. The default step sampling will be used instead." ) if self.use_temporary_directory: @@ -275,22 +271,18 @@ def _generate_result(self, out): nested_samples = DataFrame(data, columns=self.search_parameter_keys) nested_samples["weights"] = weights nested_samples["log_likelihood"] = out["weighted_samples"]["logl"] - self.result.log_likelihood_evaluations = np.array( - out["weighted_samples"]["logl"] - )[mask] + self.result.log_likelihood_evaluations = np.array(out["weighted_samples"]["logl"])[mask] self.result.sampler_output = out self.result.samples = data[mask, :] self.result.nested_samples = nested_samples self.result.log_evidence = out["logz"] self.result.log_evidence_err = out["logzerr"] if self.kwargs["num_live_points"] is not None: - self.result.information_gain = ( - np.power(out["logzerr"], 2) * self.kwargs["num_live_points"] - ) + self.result.information_gain = np.power(out["logzerr"], 2) * self.kwargs["num_live_points"] self.result.outputfiles_basename = self.outputfiles_basename self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time) def log_likelihood(self, theta): - log_l = super(Ultranest, self).log_likelihood(theta=theta) + log_l = super().log_likelihood(theta=theta) return np.nan_to_num(log_l) diff --git a/bilby/core/sampler/zeus.py b/bilby/core/sampler/zeus.py index ad6e7edb8..0925e1bf2 100644 --- a/bilby/core/sampler/zeus.py +++ b/bilby/core/sampler/zeus.py @@ -67,7 +67,7 @@ def __init__( burn_in_act=3, **kwargs, ): - super(Zeus, self).__init__( + super().__init__( likelihood=likelihood, priors=priors, outdir=outdir, @@ -84,7 +84,7 @@ def __init__( ) def _translate_kwargs(self, kwargs): - super(Zeus, self)._translate_kwargs(kwargs=kwargs) + super()._translate_kwargs(kwargs=kwargs) # check if using emcee-style arguments if "start" not in kwargs: @@ -104,11 +104,7 @@ def sampler_function_kwargs(self): @property def sampler_init_kwargs(self): - init_kwargs = { - key: value - for key, value in self.kwargs.items() - if key not in self.sampler_function_kwargs - } + init_kwargs = {key: value for key, value in self.kwargs.items() if key not in self.sampler_function_kwargs} init_kwargs["logprob_fn"] = _evaluator.call_emcee init_kwargs["ndim"] = self.ndim @@ -117,7 +113,7 @@ def sampler_init_kwargs(self): def write_current_state(self): self._sampler.distribute = map - super(Zeus, self).write_current_state() + super().write_current_state() self._sampler.distribute = getattr(self._sampler.pool, "map", map) def _initialise_sampler(self): @@ -152,9 +148,7 @@ def run_sampler(self): sampler_function_kwargs["start"] = self.pos0 # main iteration loop - for sample in self.sampler.sample( - iterations=iterations, **sampler_function_kwargs - ): + for sample in self.sampler.sample(iterations=iterations, **sampler_function_kwargs): self.write_chains_to_file(sample) self._close_pool() self.write_current_state() @@ -178,9 +172,7 @@ def _generate_result(self): f"`nburn < nsteps` ({self.result.nburn} < {self.nsteps})." " Try increasing the number of steps." ) - blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape( - (-1, 2) - ) + blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape((-1, 2)) log_likelihoods, log_priors = blobs.T self.result.log_likelihood_evaluations = log_likelihoods self.result.log_prior_evaluations = log_priors diff --git a/bilby/core/series.py b/bilby/core/series.py index ba1d0ffcb..0da57297d 100644 --- a/bilby/core/series.py +++ b/bilby/core/series.py @@ -1,19 +1,18 @@ from . import utils -class CoupledTimeAndFrequencySeries(object): - +class CoupledTimeAndFrequencySeries: def __init__(self, duration=None, sampling_frequency=None, start_time=0): - """ A waveform generator - - Parameters - ========== - sampling_frequency: float, optional - The sampling frequency - duration: float, optional - Time duration of data - start_time: float, optional - Starting time of the time array + """A waveform generator + + Parameters + ========== + sampling_frequency: float, optional + The sampling frequency + duration: float, optional + Time duration of data + start_time: float, optional + Starting time of the time array """ self._duration = duration self._sampling_frequency = sampling_frequency @@ -24,12 +23,14 @@ def __init__(self, duration=None, sampling_frequency=None, start_time=0): self._time_array = None def __repr__(self): - return self.__class__.__name__ + '(duration={}, sampling_frequency={}, start_time={})'\ - .format(self.duration, self.sampling_frequency, self.start_time) + return ( + self.__class__.__name__ + + f"(duration={self.duration}, sampling_frequency={self.sampling_frequency}, start_time={self.start_time})" + ) @property def frequency_array(self): - """ Frequency array for the waveforms. Automatically updates if sampling_frequency or duration are updated. + """Frequency array for the waveforms. Automatically updates if sampling_frequency or duration are updated. Returns ======= @@ -38,12 +39,13 @@ def frequency_array(self): if not self._frequency_array_updated: if self.sampling_frequency and self.duration: self._frequency_array = utils.create_frequency_series( - sampling_frequency=self.sampling_frequency, - duration=self.duration) + sampling_frequency=self.sampling_frequency, duration=self.duration + ) else: - raise ValueError('Can not calculate a frequency series without a ' - 'legitimate sampling_frequency ({}) or duration ({})' - .format(self.sampling_frequency, self.duration)) + raise ValueError( + "Can not calculate a frequency series without a " + f"legitimate sampling_frequency ({self.sampling_frequency}) or duration ({self.duration})" + ) self._frequency_array_updated = True return self._frequency_array @@ -51,13 +53,14 @@ def frequency_array(self): @frequency_array.setter def frequency_array(self, frequency_array): self._frequency_array = frequency_array - self._sampling_frequency, self._duration = \ - utils.get_sampling_frequency_and_duration_from_frequency_array(frequency_array) + self._sampling_frequency, self._duration = utils.get_sampling_frequency_and_duration_from_frequency_array( + frequency_array + ) self._frequency_array_updated = True @property def time_array(self): - """ Time array for the waveforms. Automatically updates if sampling_frequency or duration are updated. + """Time array for the waveforms. Automatically updates if sampling_frequency or duration are updated. Returns ======= @@ -67,13 +70,13 @@ def time_array(self): if not self._time_array_updated: if self.sampling_frequency and self.duration: self._time_array = utils.create_time_series( - sampling_frequency=self.sampling_frequency, - duration=self.duration, - starting_time=self.start_time) + sampling_frequency=self.sampling_frequency, duration=self.duration, starting_time=self.start_time + ) else: - raise ValueError('Can not calculate a time series without a ' - 'legitimate sampling_frequency ({}) or duration ({})' - .format(self.sampling_frequency, self.duration)) + raise ValueError( + "Can not calculate a time series without a " + f"legitimate sampling_frequency ({self.sampling_frequency}) or duration ({self.duration})" + ) self._time_array_updated = True return self._time_array @@ -81,14 +84,13 @@ def time_array(self): @time_array.setter def time_array(self, time_array): self._time_array = time_array - self._sampling_frequency, self._duration = \ - utils.get_sampling_frequency_and_duration_from_time_array(time_array) + self._sampling_frequency, self._duration = utils.get_sampling_frequency_and_duration_from_time_array(time_array) self._start_time = time_array[0] self._time_array_updated = True @property def duration(self): - """ Allows one to set the time duration and automatically updates the frequency and time array. + """Allows one to set the time duration and automatically updates the frequency and time array. Returns ======= @@ -105,7 +107,7 @@ def duration(self, duration): @property def sampling_frequency(self): - """ Allows one to set the sampling frequency and automatically updates the frequency and time array. + """Allows one to set the sampling frequency and automatically updates the frequency and time array. Returns ======= diff --git a/bilby/core/utils/__init__.py b/bilby/core/utils/__init__.py index 4b17c2b6c..03ca26521 100644 --- a/bilby/core/utils/__init__.py +++ b/bilby/core/utils/__init__.py @@ -1,6 +1,9 @@ +# ruff: noqa F403 + from . import random from .calculus import * from .cmd import * +from .cmd import set_up_command_line_arguments from .colors import * from .constants import * from .conversion import * @@ -11,6 +14,7 @@ from .introspection import * from .io import * from .log import * +from .log import setup_logger from .meta_data import * from .plotting import * from .samples import * diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index e10ce6111..eeac93f63 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -1,7 +1,7 @@ import math import numpy as np -from scipy.interpolate import interp1d, RectBivariateSpline +from scipy.interpolate import RectBivariateSpline, interp1d from scipy.special import logsumexp from .log import logger @@ -116,9 +116,7 @@ def derivatives( cureps *= epsscale if cureps < mineps or flipflop > flipflopmax: # if no convergence set flat derivative (TODO: check if there is a better thing to do instead) - logger.warning( - "Derivative calculation did not converge: setting flat derivative." - ) + logger.warning("Derivative calculation did not converge: setting flat derivative.") grads[count] = 0.0 break leps *= epsscale @@ -175,9 +173,7 @@ def logtrapzexp(lnf, dx): C = np.log(dx / 2.0) elif isinstance(dx, (list, np.ndarray)): if len(dx) != len(lnf) - 1: - raise ValueError( - "Step size array must have length one less than the function length" - ) + raise ValueError("Step size array must have length one less than the function length") lndx = np.log(dx) lnfdx1 = lnfdx1.copy() + lndx @@ -190,7 +186,6 @@ def logtrapzexp(lnf, dx): class BoundedRectBivariateSpline(RectBivariateSpline): - def __init__(self, x, y, z, bbox=[None] * 4, kx=3, ky=3, s=0, fill_value=None): self.x_min, self.x_max, self.y_min, self.y_max = bbox if self.x_min is None: @@ -224,6 +219,7 @@ class WrappedInterp1d(interp1d): A wrapper around scipy interp1d which sets equality-by-instantiation and makes sure that the output is a float if the input is a float or int. """ + def __call__(self, x): output = super().__call__(x) if isinstance(x, (float, int)): @@ -254,4 +250,4 @@ def round_up_to_power_of_two(x): float: next power of two """ - return 2**math.ceil(np.log2(x)) + return 2 ** math.ceil(np.log2(x)) diff --git a/bilby/core/utils/cmd.py b/bilby/core/utils/cmd.py index eba784bb8..364f4713f 100644 --- a/bilby/core/utils/cmd.py +++ b/bilby/core/utils/cmd.py @@ -7,7 +7,7 @@ def set_up_command_line_arguments(): - """ Sets up command line arguments that can be used to modify how scripts are run. + """Sets up command line arguments that can be used to modify how scripts are run. Returns ======= @@ -50,30 +50,37 @@ def set_up_command_line_arguments(): """ try: parser = argparse.ArgumentParser( - description="Command line interface for bilby scripts", - add_help=False, allow_abbrev=False) + description="Command line interface for bilby scripts", add_help=False, allow_abbrev=False + ) except TypeError: - parser = argparse.ArgumentParser( - description="Command line interface for bilby scripts", - add_help=False) - parser.add_argument("-v", "--verbose", action="store_true", - help=("Increase output verbosity [logging.DEBUG]." + - " Overridden by script level settings")) - parser.add_argument("-q", "--quiet", action="store_true", - help=("Decrease output verbosity [logging.WARNING]." + - " Overridden by script level settings")) - parser.add_argument("-c", "--clean", action="store_true", - help="Force clean data, never use cached data") - parser.add_argument("-u", "--use-cached", action="store_true", - help="Force cached data and do not check its validity") - parser.add_argument("--sampler-help", nargs='?', default=False, - const='None', help="Print help for given sampler") - parser.add_argument("--bilby-test-mode", action="store_true", - help=("Used for testing only: don't run full PE, but" - " just check nothing breaks")) - parser.add_argument("--bilby-zero-likelihood-mode", action="store_true", - help=("Used for testing only: don't run full PE, but" - " just check nothing breaks")) + parser = argparse.ArgumentParser(description="Command line interface for bilby scripts", add_help=False) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help=("Increase output verbosity [logging.DEBUG]." + " Overridden by script level settings"), + ) + parser.add_argument( + "-q", + "--quiet", + action="store_true", + help=("Decrease output verbosity [logging.WARNING]." + " Overridden by script level settings"), + ) + parser.add_argument("-c", "--clean", action="store_true", help="Force clean data, never use cached data") + parser.add_argument( + "-u", "--use-cached", action="store_true", help="Force cached data and do not check its validity" + ) + parser.add_argument("--sampler-help", nargs="?", default=False, const="None", help="Print help for given sampler") + parser.add_argument( + "--bilby-test-mode", + action="store_true", + help=("Used for testing only: don't run full PE, but just check nothing breaks"), + ) + parser.add_argument( + "--bilby-zero-likelihood-mode", + action="store_true", + help=("Used for testing only: don't run full PE, but just check nothing breaks"), + ) args, unknown_args = parser.parse_known_args() if args.quiet: args.log_level = logging.WARNING @@ -98,19 +105,17 @@ def run_commandline(cl, log_level=20, raise_error=True, return_output=True): """ - logger.log(log_level, 'Now executing: ' + cl) + logger.log(log_level, "Now executing: " + cl) if return_output: try: - out = subprocess.check_output( - cl, stderr=subprocess.STDOUT, shell=True, - universal_newlines=True) + out = subprocess.check_output(cl, stderr=subprocess.STDOUT, shell=True, universal_newlines=True) except subprocess.CalledProcessError as e: - logger.log(log_level, 'Execution failed: {}'.format(e.output)) + logger.log(log_level, f"Execution failed: {e.output}") if raise_error: raise else: out = 0 - os.system('\n') + os.system("\n") return out else: process = subprocess.Popen(cl, shell=True) diff --git a/bilby/core/utils/colors.py b/bilby/core/utils/colors.py index 7ce34e66c..53c604ee6 100644 --- a/bilby/core/utils/colors.py +++ b/bilby/core/utils/colors.py @@ -1,5 +1,5 @@ class tcolors: - KEY = '\033[93m' - VALUE = '\033[91m' - HIGHLIGHT = '\033[95m' - END = '\033[0m' + KEY = "\033[93m" + VALUE = "\033[91m" + HIGHLIGHT = "\033[95m" + END = "\033[0m" diff --git a/bilby/core/utils/constants.py b/bilby/core/utils/constants.py index 8dbec27da..45f3cd03d 100644 --- a/bilby/core/utils/constants.py +++ b/bilby/core/utils/constants.py @@ -1,7 +1,7 @@ # Constants: values taken from LAL 505df9dd2e69b4812f1e8eee3a6d468ba7f80674 speed_of_light = 299792458.0 # m/s -parsec = 3.085677581491367e+16 # m +parsec = 3.085677581491367e16 # m solar_mass = 1.988409870698050731911960804878414216e30 # Kg radius_of_earth = 6378136.6 # m gravitational_constant = 6.6743e-11 # m^3 kg^-1 s^-2 diff --git a/bilby/core/utils/conversion.py b/bilby/core/utils/conversion.py index c89e2a9ec..5f020dc84 100644 --- a/bilby/core/utils/conversion.py +++ b/bilby/core/utils/conversion.py @@ -2,7 +2,7 @@ def ra_dec_to_theta_phi(ra, dec, gmst): - """ Convert from RA and DEC to polar coordinates on celestial sphere + """Convert from RA and DEC to polar coordinates on celestial sphere Parameters ========== diff --git a/bilby/core/utils/counter.py b/bilby/core/utils/counter.py index 29428057e..0ed8b6800 100644 --- a/bilby/core/utils/counter.py +++ b/bilby/core/utils/counter.py @@ -1,7 +1,7 @@ import multiprocessing -class Counter(object): +class Counter: """ General class to count number of times a function is Called, returns total number of function calls @@ -11,8 +11,9 @@ class Counter(object): initalval : int, 0 number to start counting from """ + def __init__(self, initval=0): - self.val = multiprocessing.RawValue('i', initval) + self.val = multiprocessing.RawValue("i", initval) self.lock = multiprocessing.Lock() def increment(self): diff --git a/bilby/core/utils/docs.py b/bilby/core/utils/docs.py index a7e326e8b..37c6e983b 100644 --- a/bilby/core/utils/docs.py +++ b/bilby/core/utils/docs.py @@ -12,10 +12,12 @@ def docstring(docstr, sep="\n"): sep: str Separation character for appending the existing docstring. """ + def _decorator(func): if func.__doc__ is None: func.__doc__ = docstr else: func.__doc__ = sep.join([func.__doc__, docstr]) return func + return _decorator diff --git a/bilby/core/utils/entry_points.py b/bilby/core/utils/entry_points.py index 305fc5704..5aef275f6 100644 --- a/bilby/core/utils/entry_points.py +++ b/bilby/core/utils/entry_points.py @@ -1,8 +1,4 @@ -import sys -if sys.version_info < (3, 10): - from importlib_metadata import entry_points -else: - from importlib.metadata import entry_points +from importlib.metadata import entry_points def get_entry_points(group): @@ -13,6 +9,4 @@ def get_entry_points(group): group: str Entry points you wish to query """ - return { - custom.name: custom for custom in entry_points(group=group) - } + return {custom.name: custom for custom in entry_points(group=group)} diff --git a/bilby/core/utils/env.py b/bilby/core/utils/env.py index fc3ba4e1d..7df35bc43 100644 --- a/bilby/core/utils/env.py +++ b/bilby/core/utils/env.py @@ -1,16 +1,12 @@ - - def string_to_boolean(value): """Convert a string to a boolean. Supports True/False (case-insensitive), and 1/0. """ value = value.strip().lower() - if value in ['true', '1']: + if value in ["true", "1"]: return True - elif value in ['false', '0']: + elif value in ["false", "0"]: return False else: - raise ValueError( - f"Invalid value for boolean: {value}" - ) + raise ValueError(f"Invalid value for boolean: {value}") diff --git a/bilby/core/utils/introspection.py b/bilby/core/utils/introspection.py index 70073f995..2393dc375 100644 --- a/bilby/core/utils/introspection.py +++ b/bilby/core/utils/introspection.py @@ -3,7 +3,7 @@ def infer_parameters_from_function(func): - """ Infers the arguments of a function + """Infers the arguments of a function (except the first arg which is assumed to be the dep. variable). Throws out `*args` and `**kwargs` type arguments @@ -40,7 +40,7 @@ def infer_parameters_from_function(func): def infer_args_from_method(method): - """ Infers all arguments of a method except for `self` + """Infers all arguments of a method except for `self` Throws out `*args` and `**kwargs` type arguments. @@ -54,7 +54,7 @@ def infer_args_from_method(method): def infer_args_from_function_except_n_args(func, n=1): - """ Inspects a function to find its arguments, and ignoring the + """Inspects a function to find its arguments, and ignoring the first n of these, returns a list of arguments from the function's signature. @@ -104,8 +104,7 @@ def _infer_args_from_function_except_for_first_arg(func): def get_dict_with_properties(obj): - property_names = [p for p in dir(obj.__class__) - if isinstance(getattr(obj.__class__, p), property)] + property_names = [p for p in dir(obj.__class__) if isinstance(getattr(obj.__class__, p), property)] dict_with_properties = obj.__dict__.copy() for key in property_names: dict_with_properties[key] = getattr(obj, key) @@ -114,12 +113,12 @@ def get_dict_with_properties(obj): def get_function_path(func): if hasattr(func, "__module__") and hasattr(func, "__name__"): - return "{}.{}".format(func.__module__, func.__name__) + return f"{func.__module__}.{func.__name__}" else: return func -class PropertyAccessor(object): +class PropertyAccessor: """ Generic descriptor class that allows handy access of properties without long boilerplate code. The properties of Interferometer are defined as instances diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index 8299d6816..5c4c27f04 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -1,12 +1,12 @@ -from collections import UserDict, UserList import datetime import inspect import json import os import shutil +from collections import UserDict, UserList +from datetime import timedelta from importlib import import_module from pathlib import Path -from datetime import timedelta import numpy as np import pandas as pd @@ -15,7 +15,7 @@ def check_directory_exists_and_if_not_mkdir(directory): - """ Checks if the given directory exists and creates it if it does not exist + """Checks if the given directory exists and creates it if it does not exist Parameters ========== @@ -28,8 +28,8 @@ def check_directory_exists_and_if_not_mkdir(directory): class BilbyJsonEncoder(json.JSONEncoder): def default(self, obj): - from ..prior import BaseJointPriorDist, Prior, PriorDict from ...bilby_mcmc.proposals import ProposalCycle + from ..prior import BaseJointPriorDist, Prior, PriorDict if isinstance(obj, np.integer): return int(obj) @@ -49,7 +49,8 @@ def default(self, obj): if isinstance(obj, ProposalCycle): return str(obj) try: - from astropy import cosmology as cosmo, units + from astropy import cosmology as cosmo + from astropy import units if isinstance(obj, cosmo.FLRW): return encode_astropy_cosmology(obj) @@ -82,10 +83,7 @@ def default(self, obj): "__name__": obj.__name__, } if isinstance(obj, (timedelta)): - return { - "__timedelta__": True, - "__total_seconds__": obj.total_seconds() - } + return {"__timedelta__": True, "__total_seconds__": obj.total_seconds()} return obj.isoformat() return json.JSONEncoder.default(self, obj) @@ -132,9 +130,7 @@ def encode_astropy_unit(obj): return dict(__astropy_unit__=True, unit=obj.to_string()) except ImportError: - logger.debug( - "Cannot import astropy, cosmological priors may not be properly loaded." - ) + logger.debug("Cannot import astropy, cosmological priors may not be properly loaded.") return obj @@ -179,9 +175,7 @@ def decode_astropy_cosmology(dct): del dct["__cosmology__"] return cosmo.Cosmology.from_format(dct, format="mapping") except ImportError: - logger.debug( - "Cannot import astropy, cosmological priors may not be properly loaded." - ) + logger.debug("Cannot import astropy, cosmological priors may not be properly loaded.") return dct except KeyError: # Support decoding result files that used the previous encoding @@ -213,9 +207,7 @@ def decode_astropy_quantity(dct): del dct["__astropy_quantity__"] return units.Quantity(**dct) except ImportError: - logger.debug( - "Cannot import astropy, cosmological priors may not be properly loaded." - ) + logger.debug("Cannot import astropy, cosmological priors may not be properly loaded.") return dct @@ -240,9 +232,7 @@ def decode_astropy_unit(dct): del dct["__astropy_unit__"] return units.Unit(dct["unit"]) except ImportError: - logger.debug( - "Cannot import astropy, cosmological priors may not be properly loaded." - ) + logger.debug("Cannot import astropy, cosmological priors may not be properly loaded.") return dct @@ -260,10 +250,8 @@ def decode_numpy_random_generator(dct): f"Original error: {e}" ) from e # Convert the state and inc integers back to integers - dct["bit_generator_state"]["state"]["state"] = \ - int(dct["bit_generator_state"]["state"]["state"]) - dct["bit_generator_state"]["state"]["inc"] = \ - int(dct["bit_generator_state"]["state"]["inc"]) + dct["bit_generator_state"]["state"]["state"] = int(dct["bit_generator_state"]["state"]["state"]) + dct["bit_generator_state"]["state"]["inc"] = int(dct["bit_generator_state"]["state"]["inc"]) generator = np.random.Generator(bit_generator()) generator.bit_generator.state = dct["bit_generator_state"] @@ -278,7 +266,7 @@ def load_json(filename, gzip): json_str = file.read().decode("utf-8") dictionary = json.loads(json_str, object_hook=decode_bilby_json) else: - with open(filename, "r") as file: + with open(filename) as file: dictionary = json.load(file, object_hook=decode_bilby_json) return dictionary @@ -293,14 +281,13 @@ def decode_bilby_json(dct): cls = getattr(import_module(dct["__module__"]), dct["__name__"]) obj = cls(**dct["kwargs"]) except (AttributeError, ValueError) as e: - if type(e).__name__ == 'AttributeError': + if type(e).__name__ == "AttributeError": warning_message = "Unknown prior class for parameter {}, defaulting to base Prior object".format( dct["kwargs"]["name"] ) - elif type(e).__name__ == 'ValueError': + elif type(e).__name__ == "ValueError": warning_message = ( - f"Unable to load prior {cls} with arguments {dct['kwargs']}, " - "defaulting to base Prior object" + f"Unable to load prior {cls} with arguments {dct['kwargs']}, defaulting to base Prior object" ) logger.warning(warning_message) from ..prior import Prior @@ -332,9 +319,7 @@ def decode_bilby_json(dct): try: cls = getattr(import_module(dct["__module__"]), dct["__name__"], default) except ModuleNotFoundError: - logger.warning( - f"Cannot load module {dct['__module__']}, returning function name as string" - ) + logger.warning(f"Cannot load module {dct['__module__']}, returning function name as string") cls = default return cls if dct.get("__timedelta__", False): @@ -421,7 +406,8 @@ def encode_for_hdf5(key, item): from ..prior.dict import PriorDict try: - from astropy import cosmology as cosmo, units + from astropy import cosmology as cosmo + from astropy import units except ImportError: logger.debug("Cannot import astropy, cannot write cosmological priors") cosmo = None @@ -435,9 +421,9 @@ def encode_for_hdf5(key, item): item = complex(item) if isinstance(item, np.ndarray): # Numpy's wide unicode strings are not supported by hdf5 - if item.dtype.kind == 'U': - logger.debug(f'converting dtype {item.dtype} for hdf5') - item = np.array(item, dtype='S') + if item.dtype.kind == "U": + logger.debug(f"converting dtype {item.dtype} for hdf5") + item = np.array(item, dtype="S") if isinstance(item, (np.ndarray, int, float, complex, str, bytes)): output = item elif isinstance(item, np.random.Generator): @@ -462,7 +448,7 @@ def encode_for_hdf5(key, item): else: output.append(str(value).encode("utf-8")) else: - raise ValueError(f'Cannot save {key}: {type(item)} type') + raise ValueError(f"Cannot save {key}: {type(item)} type") elif isinstance(item, PriorDict): output = json.dumps(item._get_json_dict()) elif isinstance(item, pd.DataFrame): @@ -474,9 +460,7 @@ def encode_for_hdf5(key, item): elif units is not None and isinstance(item, (units.PrefixUnit, units.UnitBase, units.FunctionUnitBase)): output = encode_astropy_unit(item) elif inspect.isfunction(item) or inspect.isclass(item): - output = dict( - __module__=item.__module__, __name__=item.__name__, __class__=True - ) + output = dict(__module__=item.__module__, __name__=item.__name__, __class__=True) elif isinstance(item, dict): output = item.copy() elif isinstance(item, (UserDict, UserList)): @@ -486,7 +470,7 @@ def encode_for_hdf5(key, item): elif isinstance(item, datetime.timedelta): output = item.total_seconds() else: - raise ValueError(f'Cannot save {key}: {type(item)} type') + raise ValueError(f"Cannot save {key}: {type(item)} type") return output @@ -539,9 +523,7 @@ def recursively_load_dict_contents_from_group(h5file, path): if isinstance(item, h5py.Dataset): output[key] = decode_from_hdf5(item[()]) elif isinstance(item, h5py.Group): - output[key] = recursively_load_dict_contents_from_group( - h5file, path + key + "/" - ) + output[key] = recursively_load_dict_contents_from_group(h5file, path + key + "/") # Some items may be encoded as dictionaries, so we need to decode them # after the dictionary has been constructed. # This includes decoding astropy and bilby types @@ -573,7 +555,7 @@ def recursively_save_dict_contents_to_group(h5file, path, dic): def safe_file_dump(data, filename, module): - """ Safely dump data to a .pickle file + """Safely dump data to a .pickle file Parameters ========== @@ -593,7 +575,7 @@ def safe_file_dump(data, filename, module): def move_old_file(filename, overwrite=False): - """ Moves or removes an old file. + """Moves or removes an old file. Parameters ========== @@ -605,14 +587,12 @@ def move_old_file(filename, overwrite=False): """ if os.path.isfile(filename): if overwrite: - logger.debug("Removing existing file {}".format(filename)) + logger.debug(f"Removing existing file {filename}") os.remove(filename) else: - logger.debug( - "Renaming existing file {} to {}.old".format(filename, filename) - ) + logger.debug(f"Renaming existing file {filename} to {filename}.old") shutil.move(filename, filename + ".old") - logger.debug("Saving result to {}".format(filename)) + logger.debug(f"Saving result to {filename}") def safe_save_figure(fig, filename, **kwargs): diff --git a/bilby/core/utils/log.py b/bilby/core/utils/log.py index b6c45152b..37f52fe7d 100644 --- a/bilby/core/utils/log.py +++ b/bilby/core/utils/log.py @@ -1,15 +1,15 @@ import json import logging -from pathlib import Path import subprocess import sys from importlib import metadata +from pathlib import Path -logger = logging.getLogger('bilby') +logger = logging.getLogger("bilby") -def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False): - """ Setup logging output: call at the start of the script to use +def setup_logger(outdir=".", label=None, log_level="INFO", print_version=False): + """Setup logging output: call at the start of the script to use Parameters ========== @@ -27,28 +27,28 @@ def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False): try: level = getattr(logging, log_level.upper()) except AttributeError: - raise ValueError('log_level {} not understood'.format(log_level)) + raise ValueError(f"log_level {log_level} not understood") else: level = int(log_level) - logger = logging.getLogger('bilby') + logger = logging.getLogger("bilby") logger.propagate = False logger.setLevel(level) if not any([isinstance(h, logging.StreamHandler) for h in logger.handlers]): stream_handler = logging.StreamHandler() - stream_handler.setFormatter(logging.Formatter( - '%(asctime)s %(name)s %(levelname)-8s: %(message)s', datefmt='%H:%M')) + stream_handler.setFormatter( + logging.Formatter("%(asctime)s %(name)s %(levelname)-8s: %(message)s", datefmt="%H:%M") + ) stream_handler.setLevel(level) logger.addHandler(stream_handler) if not any([isinstance(h, logging.FileHandler) for h in logger.handlers]): if label: Path(outdir).mkdir(parents=True, exist_ok=True) - log_file = '{}/{}.log'.format(outdir, label) + log_file = f"{outdir}/{label}.log" file_handler = logging.FileHandler(log_file) - file_handler.setFormatter(logging.Formatter( - '%(asctime)s %(levelname)-8s: %(message)s', datefmt='%H:%M')) + file_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)-8s: %(message)s", datefmt="%H:%M")) file_handler.setLevel(level) logger.addHandler(file_handler) @@ -58,11 +58,12 @@ def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False): if print_version: version = get_version_information() - logger.info('Running bilby version: {}'.format(version)) + logger.info(f"Running bilby version: {version}") def get_version_information(): from bilby import __version__ + return __version__ @@ -106,12 +107,7 @@ def env_package_list(as_dataframe=False): conda_detected = (Path(prefix) / "conda-meta").is_dir() if conda_detected: try: - pkgs = json.loads(subprocess.check_output([ - "conda", - "list", - "--prefix", prefix, - "--json" - ])) + pkgs = json.loads(subprocess.check_output(["conda", "list", "--prefix", prefix, "--json"])) except (FileNotFoundError, subprocess.CalledProcessError): # When a conda env is in use but conda is unavailable conda_detected = False @@ -126,15 +122,23 @@ def env_package_list(as_dataframe=False): modules = loaded_modules_dict() pkgs = [{"name": x, "version": y} for x, y in modules.items()] else: - pkgs = json.loads(subprocess.check_output([ - sys.executable, - "-m", "pip", - "list", "installed", - "--format", "json", - ])) + pkgs = json.loads( + subprocess.check_output( + [ + sys.executable, + "-m", + "pip", + "list", + "installed", + "--format", + "json", + ] + ) + ) # convert to recarray for storage if as_dataframe: from pandas import DataFrame + return DataFrame(pkgs) return pkgs diff --git a/bilby/core/utils/meta_data.py b/bilby/core/utils/meta_data.py index 858ecf7a3..e5c0ccae6 100644 --- a/bilby/core/utils/meta_data.py +++ b/bilby/core/utils/meta_data.py @@ -21,10 +21,7 @@ def __init__(self, mapping=None, /, **kwargs): def __setitem__(self, key, item): if key in self: - logger.debug( - f"Overwriting meta data key {key} with value {item}. " - f"Old value was {self[key]}" - ) + logger.debug(f"Overwriting meta data key {key} with value {item}. Old value was {self[key]}") else: logger.debug(f"Setting meta data key {key} with value {item}") return super().__setitem__(key, item) @@ -37,13 +34,8 @@ def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super().__new__(cls) else: - logger.warning( - "GlobalMetaData has already been instantiated. " - "Returning the existing instance." - ) + logger.warning("GlobalMetaData has already been instantiated. Returning the existing instance.") return cls._instance -global_meta_data = GlobalMetaData({ - "rng": random.rng -}) +global_meta_data = GlobalMetaData({"rng": random.rng}) diff --git a/bilby/core/utils/plotting.py b/bilby/core/utils/plotting.py index 02d149447..41ad17ff7 100644 --- a/bilby/core/utils/plotting.py +++ b/bilby/core/utils/plotting.py @@ -21,6 +21,7 @@ def latex_plot_format(func): latex_plot_format-wrapped plotting function and will be set directly. """ + @functools.wraps(func) def wrapper_decorator(*args, **kwargs): import matplotlib.pyplot as plt @@ -38,9 +39,9 @@ def wrapper_decorator(*args, **kwargs): if bilby_mathdefault == 1: logger.debug("Setting mathdefault in the rcParams") - rcParams['text.latex.preamble'] = r'\providecommand{\mathdefault}[1][]{}' + rcParams["text.latex.preamble"] = r"\providecommand{\mathdefault}[1][]{}" - logger.debug("Using BILBY_STYLE={}".format(bilby_style)) + logger.debug(f"Using BILBY_STYLE={bilby_style}") if bilby_style.lower() == "none": return func(*args, **kwargs) elif os.path.isfile(bilby_style): @@ -64,9 +65,7 @@ def wrapper_decorator(*args, **kwargs): rcParams["font.family"] = _old_family return func(*args, **kwargs) else: - logger.debug( - "Environment variable BILBY_STYLE={} not used" - .format(bilby_style) - ) + logger.debug(f"Environment variable BILBY_STYLE={bilby_style} not used") return func(*args, **kwargs) + return wrapper_decorator diff --git a/bilby/core/utils/random.py b/bilby/core/utils/random.py index ccb7654c6..a8b1bb176 100644 --- a/bilby/core/utils/random.py +++ b/bilby/core/utils/random.py @@ -24,10 +24,11 @@ Do not import :code:`rng` directly from :code:`bilby.core.utils.random` since it will not be seeded correctly. """ + import sys import warnings -from numpy.random import default_rng, SeedSequence +from numpy.random import SeedSequence, default_rng def __getattr__(name): @@ -45,6 +46,7 @@ class Generator: It should not be used directly, instead use :code:`random.rng` to generate random numbers. See the documentation for more details. """ + rng = default_rng() """Random number generator. @@ -77,7 +79,7 @@ def seed(seed): # Warn if the original rng object (i.e., pre-seed) still exists elsewhere for module in sys.modules.values(): - if not module or not hasattr(module, '__dict__'): + if not module or not hasattr(module, "__dict__"): continue rng_obj = module.__dict__.get("rng") if rng_obj is _original_rng: @@ -85,7 +87,7 @@ def seed(seed): "Detected that `rng` was likely imported directly before calling `seed()`. " "This means the imported reference will not reflect the newly seeded generator. " "Use `from bilby.core.utils import random` and access `random.rng` instead.", - RuntimeWarning + RuntimeWarning, ) break diff --git a/bilby/core/utils/samples.py b/bilby/core/utils/samples.py index a075d6dcd..c3dd3850a 100644 --- a/bilby/core/utils/samples.py +++ b/bilby/core/utils/samples.py @@ -2,8 +2,8 @@ from scipy.special import logsumexp -class SamplesSummary(object): - """ Object to store a set of samples and calculate summary statistics +class SamplesSummary: + """Object to store a set of samples and calculate summary statistics Parameters ========== @@ -16,7 +16,8 @@ class SamplesSummary(object): The default confidence interval level, defaults t0 0.9 """ - def __init__(self, samples, average='median', confidence_level=.9): + + def __init__(self, samples, average="median", confidence_level=0.9): self.samples = samples self.average = average self.confidence_level = confidence_level @@ -42,18 +43,18 @@ def confidence_level(self, confidence_level): @property def average(self): - if self._average == 'mean': + if self._average == "mean": return self.mean - elif self._average == 'median': + elif self._average == "median": return self.median @average.setter def average(self, average): - allowed_averages = ['mean', 'median'] + allowed_averages = ["mean", "median"] if average in allowed_averages: self._average = average else: - raise ValueError("Average {} not in allowed averages".format(average)) + raise ValueError(f"Average {average} not in allowed averages") @property def median(self): @@ -65,37 +66,37 @@ def mean(self): @property def _lower_level(self): - """ The credible interval lower quantile value """ - return (1 - self.confidence_level) / 2. + """The credible interval lower quantile value""" + return (1 - self.confidence_level) / 2.0 @property def _upper_level(self): - """ The credible interval upper quantile value """ - return (1 + self.confidence_level) / 2. + """The credible interval upper quantile value""" + return (1 + self.confidence_level) / 2.0 @property def lower_absolute_credible_interval(self): - """ Absolute lower value of the credible interval """ + """Absolute lower value of the credible interval""" return np.quantile(self.samples, self._lower_level, axis=0) @property def upper_absolute_credible_interval(self): - """ Absolute upper value of the credible interval """ + """Absolute upper value of the credible interval""" return np.quantile(self.samples, self._upper_level, axis=0) @property def lower_relative_credible_interval(self): - """ Relative (to average) lower value of the credible interval """ + """Relative (to average) lower value of the credible interval""" return self.lower_absolute_credible_interval - self.average @property def upper_relative_credible_interval(self): - """ Relative (to average) upper value of the credible interval """ + """Relative (to average) upper value of the credible interval""" return self.upper_absolute_credible_interval - self.average def kish_log_effective_sample_size(ln_weights): - """ Calculate the Kish effective sample size from the natural-log weights + """Calculate the Kish effective sample size from the natural-log weights See https://en.wikipedia.org/wiki/Effective_sample_size for details diff --git a/bilby/core/utils/series.py b/bilby/core/utils/series.py index 63daebd6e..168acae18 100644 --- a/bilby/core/utils/series.py +++ b/bilby/core/utils/series.py @@ -25,7 +25,7 @@ def get_sampling_frequency(time_array): if np.ptp(np.diff(time_array)) > tol: raise ValueError("Your time series was not evenly sampled") else: - return np.round(1. / (time_array[1] - time_array[0]), decimals=_TOL) + return np.round(1.0 / (time_array[1] - time_array[0]), decimals=_TOL) def get_sampling_frequency_and_duration_from_time_array(time_array): @@ -83,7 +83,7 @@ def get_sampling_frequency_and_duration_from_frequency_array(frequency_array): return sampling_frequency, duration -def create_time_series(sampling_frequency, duration, starting_time=0.): +def create_time_series(sampling_frequency, duration, starting_time=0.0): """ Parameters @@ -99,13 +99,13 @@ def create_time_series(sampling_frequency, duration, starting_time=0.): """ _check_legal_sampling_frequency_and_duration(sampling_frequency, duration) number_of_samples = int(duration * sampling_frequency) - return np.linspace(start=starting_time, - stop=duration + starting_time - 1 / sampling_frequency, - num=number_of_samples) + return np.linspace( + start=starting_time, stop=duration + starting_time - 1 / sampling_frequency, num=number_of_samples + ) def create_frequency_series(sampling_frequency, duration): - """ Create a frequency series with the correct length and spacing. + """Create a frequency series with the correct length and spacing. Parameters ========== @@ -121,13 +121,11 @@ def create_frequency_series(sampling_frequency, duration): number_of_samples = int(np.round(duration * sampling_frequency)) number_of_frequencies = int(np.round(number_of_samples / 2) + 1) - return np.linspace(start=0, - stop=sampling_frequency / 2, - num=number_of_frequencies) + return np.linspace(start=0, stop=sampling_frequency / 2, num=number_of_frequencies) def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): - """ By convention, sampling_frequency and duration have to multiply to an integer + """By convention, sampling_frequency and duration have to multiply to an integer This will check if the product of both parameters multiplies reasonably close to an integer. @@ -139,19 +137,17 @@ def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): """ num = sampling_frequency * duration - if np.abs(num - np.round(num)) > 10**(-_TOL): + if np.abs(num - np.round(num)) > 10 ** (-_TOL): raise IllegalDurationAndSamplingFrequencyException( - '\nYour sampling frequency and duration must multiply to a number' - 'up to (tol = {}) decimals close to an integer number. ' - '\nBut sampling_frequency={} and duration={} multiply to {}'.format( - _TOL, sampling_frequency, duration, - sampling_frequency * duration - ) + "\nYour sampling frequency and duration must multiply to a number" + f"up to (tol = {_TOL}) decimals close to an integer number. " + f"\nBut sampling_frequency={sampling_frequency} and duration={duration} " + f"multiply to {sampling_frequency * duration}" ) def create_white_noise(sampling_frequency, duration): - """ Create white_noise which is then coloured by a given PSD + """Create white_noise which is then coloured by a given PSD Parameters ========== @@ -189,7 +185,7 @@ def create_white_noise(sampling_frequency, duration): def nfft(time_domain_strain, sampling_frequency): - """ Perform an FFT while keeping track of the frequency bins. Assumes input + """Perform an FFT while keeping track of the frequency bins. Assumes input time series is real (positive frequencies only). Parameters @@ -209,14 +205,13 @@ def nfft(time_domain_strain, sampling_frequency): frequency_domain_strain = np.fft.rfft(time_domain_strain) frequency_domain_strain /= sampling_frequency - frequency_array = np.linspace( - 0, sampling_frequency / 2, len(frequency_domain_strain)) + frequency_array = np.linspace(0, sampling_frequency / 2, len(frequency_domain_strain)) return frequency_domain_strain, frequency_array def infft(frequency_domain_strain, sampling_frequency): - """ Inverse FFT for use in conjunction with nfft. + """Inverse FFT for use in conjunction with nfft. Parameters ========== diff --git a/bilby/gw/__init__.py b/bilby/gw/__init__.py index b5115766b..3dc5e384d 100644 --- a/bilby/gw/__init__.py +++ b/bilby/gw/__init__.py @@ -1,6 +1,21 @@ -from . import (conversion, cosmology, detector, eos, likelihood, prior, - result, source, utils, waveform_generator) -from .waveform_generator import WaveformGenerator, LALCBCWaveformGenerator -from .likelihood import GravitationalWaveTransient +from . import conversion, cosmology, detector, eos, likelihood, prior, result, source, utils, waveform_generator from .detector import calibration +from .likelihood import GravitationalWaveTransient +from .waveform_generator import LALCBCWaveformGenerator, WaveformGenerator +__all__ = [ + conversion, + cosmology, + detector, + eos, + likelihood, + prior, + result, + source, + utils, + waveform_generator, + calibration, + GravitationalWaveTransient, + LALCBCWaveformGenerator, + WaveformGenerator, +] diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 87fcfe78e..7ba9e7889 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -3,35 +3,36 @@ gravitational-wave sources. """ -import os -import sys import multiprocessing +import os import pickle +import sys import numpy as np from pandas import DataFrame, Series from scipy.stats import norm -from .utils import (lalsim_SimNeutronStarEOS4ParamSDGammaCheck, - lalsim_SimNeutronStarEOS4ParameterSpectralDecomposition, - lalsim_SimNeutronStarEOS4ParamSDViableFamilyCheck, - lalsim_SimNeutronStarEOS3PieceDynamicPolytrope, - lalsim_SimNeutronStarEOS3PieceCausalAnalytic, - lalsim_SimNeutronStarEOS3PDViableFamilyCheck, - lalsim_CreateSimNeutronStarFamily, - lalsim_SimNeutronStarEOSMaxPseudoEnthalpy, - lalsim_SimNeutronStarEOSSpeedOfSoundGeometerized, - lalsim_SimNeutronStarFamMinimumMass, - lalsim_SimNeutronStarMaximumMass, - lalsim_SimNeutronStarRadius, - lalsim_SimNeutronStarLoveNumberK2) - from ..core.likelihood import MarginalizedLikelihoodReconstructionError -from ..core.utils import logger, solar_mass, gravitational_constant, speed_of_light, command_line_args, safe_file_dump from ..core.prior import DeltaFunction -from .utils import lalsim_SimInspiralTransformPrecessingNewInitialConditions -from .eos.eos import IntegrateTOV +from ..core.utils import command_line_args, gravitational_constant, logger, safe_file_dump, solar_mass, speed_of_light from .cosmology import get_cosmology, z_at_value +from .eos.eos import IntegrateTOV +from .utils import ( + lalsim_CreateSimNeutronStarFamily, + lalsim_SimInspiralTransformPrecessingNewInitialConditions, + lalsim_SimNeutronStarEOS3PDViableFamilyCheck, + lalsim_SimNeutronStarEOS3PieceCausalAnalytic, + lalsim_SimNeutronStarEOS3PieceDynamicPolytrope, + lalsim_SimNeutronStarEOS4ParameterSpectralDecomposition, + lalsim_SimNeutronStarEOS4ParamSDGammaCheck, + lalsim_SimNeutronStarEOS4ParamSDViableFamilyCheck, + lalsim_SimNeutronStarEOSMaxPseudoEnthalpy, + lalsim_SimNeutronStarEOSSpeedOfSoundGeometerized, + lalsim_SimNeutronStarFamMinimumMass, + lalsim_SimNeutronStarLoveNumberK2, + lalsim_SimNeutronStarMaximumMass, + lalsim_SimNeutronStarRadius, +) def redshift_to_luminosity_distance(redshift, cosmology=None): @@ -46,6 +47,7 @@ def redshift_to_comoving_distance(redshift, cosmology=None): def luminosity_distance_to_redshift(distance, cosmology=None): from astropy import units + cosmology = get_cosmology(cosmology) if isinstance(distance, Series): distance = distance.values @@ -54,6 +56,7 @@ def luminosity_distance_to_redshift(distance, cosmology=None): def comoving_distance_to_redshift(distance, cosmology=None): from astropy import units + cosmology = get_cosmology(cosmology) if isinstance(distance, Series): distance = distance.values @@ -103,8 +106,7 @@ def luminosity_distance_to_comoving_distance(distance, cosmology=None): def bilby_to_lalsimulation_spins( - theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, - reference_frequency, phase + theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, reference_frequency, phase ): """ Convert from Bilby spin parameters to lalsimulation ones. @@ -152,10 +154,8 @@ def bilby_to_lalsimulation_spins( iota = theta_jn else: from numbers import Number - args = ( - theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, - mass_2, reference_frequency, phase - ) + + args = (theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, reference_frequency, phase) float_inputs = all([isinstance(arg, Number) for arg in args]) if float_inputs: func = lalsim_SimInspiralTransformPrecessingNewInitialConditions @@ -207,44 +207,36 @@ def convert_to_lal_binary_black_hole_parameters(parameters): converted_parameters = parameters.copy() original_keys = list(converted_parameters.keys()) - if 'luminosity_distance' not in original_keys: - if 'redshift' in converted_parameters.keys(): - converted_parameters['luminosity_distance'] = \ - redshift_to_luminosity_distance(parameters['redshift']) - elif 'comoving_distance' in converted_parameters.keys(): - converted_parameters['luminosity_distance'] = \ - comoving_distance_to_luminosity_distance( - parameters['comoving_distance']) + if "luminosity_distance" not in original_keys: + if "redshift" in converted_parameters.keys(): + converted_parameters["luminosity_distance"] = redshift_to_luminosity_distance(parameters["redshift"]) + elif "comoving_distance" in converted_parameters.keys(): + converted_parameters["luminosity_distance"] = comoving_distance_to_luminosity_distance( + parameters["comoving_distance"] + ) for key in original_keys: - if key[-7:] == '_source': - if 'redshift' not in converted_parameters.keys(): - converted_parameters['redshift'] =\ - luminosity_distance_to_redshift( - parameters['luminosity_distance']) - converted_parameters[key[:-7]] = converted_parameters[key] * ( - 1 + converted_parameters['redshift']) + if key[-7:] == "_source": + if "redshift" not in converted_parameters.keys(): + converted_parameters["redshift"] = luminosity_distance_to_redshift(parameters["luminosity_distance"]) + converted_parameters[key[:-7]] = converted_parameters[key] * (1 + converted_parameters["redshift"]) # we do not require the component masses be added if no mass parameters are present converted_parameters = generate_component_masses(converted_parameters, require_add=False) - for idx in ['1', '2']: - key = 'chi_{}'.format(idx) + for idx in ["1", "2"]: + key = f"chi_{idx}" if key in original_keys: - if "chi_{}_in_plane".format(idx) in original_keys: - converted_parameters["a_{}".format(idx)] = ( - converted_parameters[f"chi_{idx}"] ** 2 - + converted_parameters[f"chi_{idx}_in_plane"] ** 2 + if f"chi_{idx}_in_plane" in original_keys: + converted_parameters[f"a_{idx}"] = ( + converted_parameters[f"chi_{idx}"] ** 2 + converted_parameters[f"chi_{idx}_in_plane"] ** 2 ) ** 0.5 converted_parameters[f"cos_tilt_{idx}"] = ( - converted_parameters[f"chi_{idx}"] - / converted_parameters[f"a_{idx}"] + converted_parameters[f"chi_{idx}"] / converted_parameters[f"a_{idx}"] ) - elif "a_{}".format(idx) not in original_keys: - converted_parameters['a_{}'.format(idx)] = abs( - converted_parameters[key]) - converted_parameters['cos_tilt_{}'.format(idx)] = \ - np.sign(converted_parameters[key]) + elif f"a_{idx}" not in original_keys: + converted_parameters[f"a_{idx}"] = abs(converted_parameters[key]) + converted_parameters[f"cos_tilt_{idx}"] = np.sign(converted_parameters[key]) else: with np.errstate(invalid="raise"): try: @@ -263,8 +255,8 @@ def convert_to_lal_binary_black_hole_parameters(parameters): if key not in converted_parameters: converted_parameters[key] = 0.0 - for angle in ['tilt_1', 'tilt_2', 'theta_jn']: - cos_angle = str('cos_' + angle) + for angle in ["tilt_1", "tilt_2", "theta_jn"]: + cos_angle = str("cos_" + angle) if cos_angle in converted_parameters.keys(): with np.errstate(invalid="ignore"): converted_parameters[angle] = np.arccos(converted_parameters[cos_angle]) @@ -273,11 +265,10 @@ def convert_to_lal_binary_black_hole_parameters(parameters): with np.errstate(invalid="ignore"): converted_parameters["phase"] = np.mod( converted_parameters["delta_phase"] - - np.sign(np.cos(converted_parameters["theta_jn"])) - * converted_parameters["psi"], - 2 * np.pi) - added_keys = [key for key in converted_parameters.keys() - if key not in original_keys] + - np.sign(np.cos(converted_parameters["theta_jn"])) * converted_parameters["psi"], + 2 * np.pi, + ) + added_keys = [key for key in converted_parameters.keys() if key not in original_keys] return converted_parameters, added_keys @@ -312,275 +303,310 @@ def convert_to_lal_binary_neutron_star_parameters(parameters): """ converted_parameters = parameters.copy() original_keys = list(converted_parameters.keys()) - converted_parameters, added_keys =\ - convert_to_lal_binary_black_hole_parameters(converted_parameters) - - if not any([key in converted_parameters for key in - ['lambda_1', 'lambda_2', - 'lambda_tilde', 'delta_lambda_tilde', 'lambda_symmetric', - 'eos_polytrope_gamma_0', 'eos_spectral_pca_gamma_0', 'eos_v1']]): - converted_parameters['lambda_1'] = 0 - converted_parameters['lambda_2'] = 0 - added_keys = added_keys + ['lambda_1', 'lambda_2'] + converted_parameters, added_keys = convert_to_lal_binary_black_hole_parameters(converted_parameters) + + if not any( + [ + key in converted_parameters + for key in [ + "lambda_1", + "lambda_2", + "lambda_tilde", + "delta_lambda_tilde", + "lambda_symmetric", + "eos_polytrope_gamma_0", + "eos_spectral_pca_gamma_0", + "eos_v1", + ] + ] + ): + converted_parameters["lambda_1"] = 0 + converted_parameters["lambda_2"] = 0 + added_keys = added_keys + ["lambda_1", "lambda_2"] return converted_parameters, added_keys - if 'delta_lambda_tilde' in converted_parameters.keys(): - converted_parameters['lambda_1'], converted_parameters['lambda_2'] =\ + if "delta_lambda_tilde" in converted_parameters.keys(): + converted_parameters["lambda_1"], converted_parameters["lambda_2"] = ( lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2( - converted_parameters['lambda_tilde'], - parameters['delta_lambda_tilde'], converted_parameters['mass_1'], - converted_parameters['mass_2']) - elif 'lambda_tilde' in converted_parameters.keys(): - converted_parameters['lambda_1'], converted_parameters['lambda_2'] =\ - lambda_tilde_to_lambda_1_lambda_2( - converted_parameters['lambda_tilde'], - converted_parameters['mass_1'], converted_parameters['mass_2']) - if 'lambda_2' not in converted_parameters.keys() and 'lambda_1' in converted_parameters.keys(): - converted_parameters['lambda_2'] =\ - converted_parameters['lambda_1']\ - * converted_parameters['mass_1']**5\ - / converted_parameters['mass_2']**5 - elif 'lambda_2' in converted_parameters.keys() and converted_parameters['lambda_2'] is None: - converted_parameters['lambda_2'] =\ - converted_parameters['lambda_1']\ - * converted_parameters['mass_1']**5\ - / converted_parameters['mass_2']**5 - elif 'eos_spectral_pca_gamma_0' in converted_parameters.keys(): # FIXME: This is a clunky way to do this + converted_parameters["lambda_tilde"], + parameters["delta_lambda_tilde"], + converted_parameters["mass_1"], + converted_parameters["mass_2"], + ) + ) + elif "lambda_tilde" in converted_parameters.keys(): + converted_parameters["lambda_1"], converted_parameters["lambda_2"] = lambda_tilde_to_lambda_1_lambda_2( + converted_parameters["lambda_tilde"], converted_parameters["mass_1"], converted_parameters["mass_2"] + ) + if "lambda_2" not in converted_parameters.keys() and "lambda_1" in converted_parameters.keys(): + converted_parameters["lambda_2"] = ( + converted_parameters["lambda_1"] * converted_parameters["mass_1"] ** 5 / converted_parameters["mass_2"] ** 5 + ) + elif "lambda_2" in converted_parameters.keys() and converted_parameters["lambda_2"] is None: + converted_parameters["lambda_2"] = ( + converted_parameters["lambda_1"] * converted_parameters["mass_1"] ** 5 / converted_parameters["mass_2"] ** 5 + ) + elif "eos_spectral_pca_gamma_0" in converted_parameters.keys(): # FIXME: This is a clunky way to do this converted_parameters = generate_source_frame_parameters(converted_parameters) float_eos_params = {} max_len = 1 - eos_keys = ['eos_spectral_pca_gamma_0', - 'eos_spectral_pca_gamma_1', - 'eos_spectral_pca_gamma_2', - 'eos_spectral_pca_gamma_3', - 'mass_1_source', 'mass_2_source'] + eos_keys = [ + "eos_spectral_pca_gamma_0", + "eos_spectral_pca_gamma_1", + "eos_spectral_pca_gamma_2", + "eos_spectral_pca_gamma_3", + "mass_1_source", + "mass_2_source", + ] for key in eos_keys: try: - if (len(converted_parameters[key]) > max_len): + if len(converted_parameters[key]) > max_len: max_len = len(converted_parameters[key]) except TypeError: float_eos_params[key] = converted_parameters[key] if len(float_eos_params) == len(eos_keys): # case where all eos params are floats (pinned) g0, g1, g2, g3 = spectral_pca_to_spectral( - converted_parameters['eos_spectral_pca_gamma_0'], - converted_parameters['eos_spectral_pca_gamma_1'], - converted_parameters['eos_spectral_pca_gamma_2'], - converted_parameters['eos_spectral_pca_gamma_3']) - converted_parameters['lambda_1'], converted_parameters['lambda_2'], converted_parameters['eos_check'] = \ + converted_parameters["eos_spectral_pca_gamma_0"], + converted_parameters["eos_spectral_pca_gamma_1"], + converted_parameters["eos_spectral_pca_gamma_2"], + converted_parameters["eos_spectral_pca_gamma_3"], + ) + converted_parameters["lambda_1"], converted_parameters["lambda_2"], converted_parameters["eos_check"] = ( spectral_params_to_lambda_1_lambda_2( - g0, g1, g2, g3, converted_parameters['mass_1_source'], converted_parameters['mass_2_source']) + g0, g1, g2, g3, converted_parameters["mass_1_source"], converted_parameters["mass_2_source"] + ) + ) elif len(float_eos_params) < len(eos_keys): # case where some or none of the eos params are floats (pinned) for key in float_eos_params.keys(): converted_parameters[key] = np.ones(max_len) * converted_parameters[key] - g0pca = converted_parameters['eos_spectral_pca_gamma_0'] - g1pca = converted_parameters['eos_spectral_pca_gamma_1'] - g2pca = converted_parameters['eos_spectral_pca_gamma_2'] - g3pca = converted_parameters['eos_spectral_pca_gamma_3'] - m1s = converted_parameters['mass_1_source'] - m2s = converted_parameters['mass_2_source'] + g0pca = converted_parameters["eos_spectral_pca_gamma_0"] + g1pca = converted_parameters["eos_spectral_pca_gamma_1"] + g2pca = converted_parameters["eos_spectral_pca_gamma_2"] + g3pca = converted_parameters["eos_spectral_pca_gamma_3"] + m1s = converted_parameters["mass_1_source"] + m2s = converted_parameters["mass_2_source"] all_lambda_1 = np.empty(0) all_lambda_2 = np.empty(0) all_eos_check = np.empty(0, dtype=bool) - for (g_0pca, g_1pca, g_2pca, g_3pca, m1_s, m2_s) in zip(g0pca, g1pca, g2pca, g3pca, m1s, m2s): + for g_0pca, g_1pca, g_2pca, g_3pca, m1_s, m2_s in zip(g0pca, g1pca, g2pca, g3pca, m1s, m2s): g_0, g_1, g_2, g_3 = spectral_pca_to_spectral(g_0pca, g_1pca, g_2pca, g_3pca) - lambda_1, lambda_2, eos_check = \ - spectral_params_to_lambda_1_lambda_2(g_0, g_1, g_2, g_3, m1_s, m2_s) + lambda_1, lambda_2, eos_check = spectral_params_to_lambda_1_lambda_2(g_0, g_1, g_2, g_3, m1_s, m2_s) all_lambda_1 = np.append(all_lambda_1, lambda_1) all_lambda_2 = np.append(all_lambda_2, lambda_2) all_eos_check = np.append(all_eos_check, eos_check) - converted_parameters['lambda_1'] = all_lambda_1 - converted_parameters['lambda_2'] = all_lambda_2 - converted_parameters['eos_check'] = all_eos_check + converted_parameters["lambda_1"] = all_lambda_1 + converted_parameters["lambda_2"] = all_lambda_2 + converted_parameters["eos_check"] = all_eos_check for key in float_eos_params.keys(): converted_parameters[key] = float_eos_params[key] - elif 'eos_polytrope_gamma_0' and 'eos_polytrope_log10_pressure_1' in converted_parameters.keys(): + elif "eos_polytrope_gamma_0" and "eos_polytrope_log10_pressure_1" in converted_parameters.keys(): converted_parameters = generate_source_frame_parameters(converted_parameters) float_eos_params = {} max_len = 1 - eos_keys = ['eos_polytrope_gamma_0', - 'eos_polytrope_gamma_1', - 'eos_polytrope_gamma_2', - 'eos_polytrope_log10_pressure_1', - 'eos_polytrope_log10_pressure_2', - 'mass_1_source', 'mass_2_source'] + eos_keys = [ + "eos_polytrope_gamma_0", + "eos_polytrope_gamma_1", + "eos_polytrope_gamma_2", + "eos_polytrope_log10_pressure_1", + "eos_polytrope_log10_pressure_2", + "mass_1_source", + "mass_2_source", + ] for key in eos_keys: try: - if (len(converted_parameters[key]) > max_len): + if len(converted_parameters[key]) > max_len: max_len = len(converted_parameters[key]) except TypeError: float_eos_params[key] = converted_parameters[key] if len(float_eos_params) == len(eos_keys): # case where all eos params are floats (pinned) - converted_parameters['lambda_1'], converted_parameters['lambda_2'], converted_parameters['eos_check'] = \ + converted_parameters["lambda_1"], converted_parameters["lambda_2"], converted_parameters["eos_check"] = ( polytrope_or_causal_params_to_lambda_1_lambda_2( - converted_parameters['eos_polytrope_gamma_0'], - converted_parameters['eos_polytrope_log10_pressure_1'], - converted_parameters['eos_polytrope_gamma_1'], - converted_parameters['eos_polytrope_log10_pressure_2'], - converted_parameters['eos_polytrope_gamma_2'], - converted_parameters['mass_1_source'], - converted_parameters['mass_2_source'], - causal=0) + converted_parameters["eos_polytrope_gamma_0"], + converted_parameters["eos_polytrope_log10_pressure_1"], + converted_parameters["eos_polytrope_gamma_1"], + converted_parameters["eos_polytrope_log10_pressure_2"], + converted_parameters["eos_polytrope_gamma_2"], + converted_parameters["mass_1_source"], + converted_parameters["mass_2_source"], + causal=0, + ) + ) elif len(float_eos_params) < len(eos_keys): # case where some or none are floats (pinned) for key in float_eos_params.keys(): converted_parameters[key] = np.ones(max_len) * converted_parameters[key] - pg0 = converted_parameters['eos_polytrope_gamma_0'] - pg1 = converted_parameters['eos_polytrope_gamma_1'] - pg2 = converted_parameters['eos_polytrope_gamma_2'] - logp1 = converted_parameters['eos_polytrope_log10_pressure_1'] - logp2 = converted_parameters['eos_polytrope_log10_pressure_2'] - m1s = converted_parameters['mass_1_source'] - m2s = converted_parameters['mass_2_source'] + pg0 = converted_parameters["eos_polytrope_gamma_0"] + pg1 = converted_parameters["eos_polytrope_gamma_1"] + pg2 = converted_parameters["eos_polytrope_gamma_2"] + logp1 = converted_parameters["eos_polytrope_log10_pressure_1"] + logp2 = converted_parameters["eos_polytrope_log10_pressure_2"] + m1s = converted_parameters["mass_1_source"] + m2s = converted_parameters["mass_2_source"] all_lambda_1 = np.empty(0) all_lambda_2 = np.empty(0) all_eos_check = np.empty(0, dtype=bool) - for (pg_0, pg_1, pg_2, logp_1, logp_2, m1_s, m2_s) in zip(pg0, pg1, pg2, logp1, logp2, m1s, m2s): - lambda_1, lambda_2, eos_check = \ - polytrope_or_causal_params_to_lambda_1_lambda_2( - pg_0, logp_1, pg_1, logp_2, pg_2, m1_s, m2_s, causal=0) + for pg_0, pg_1, pg_2, logp_1, logp_2, m1_s, m2_s in zip(pg0, pg1, pg2, logp1, logp2, m1s, m2s): + lambda_1, lambda_2, eos_check = polytrope_or_causal_params_to_lambda_1_lambda_2( + pg_0, logp_1, pg_1, logp_2, pg_2, m1_s, m2_s, causal=0 + ) all_lambda_1 = np.append(all_lambda_1, lambda_1) all_lambda_2 = np.append(all_lambda_2, lambda_2) all_eos_check = np.append(all_eos_check, eos_check) - converted_parameters['lambda_1'] = all_lambda_1 - converted_parameters['lambda_2'] = all_lambda_2 - converted_parameters['eos_check'] = all_eos_check + converted_parameters["lambda_1"] = all_lambda_1 + converted_parameters["lambda_2"] = all_lambda_2 + converted_parameters["eos_check"] = all_eos_check for key in float_eos_params.keys(): converted_parameters[key] = float_eos_params[key] - elif 'eos_polytrope_gamma_0' and 'eos_polytrope_scaled_pressure_ratio' in converted_parameters.keys(): + elif "eos_polytrope_gamma_0" and "eos_polytrope_scaled_pressure_ratio" in converted_parameters.keys(): converted_parameters = generate_source_frame_parameters(converted_parameters) float_eos_params = {} max_len = 1 - eos_keys = ['eos_polytrope_gamma_0', - 'eos_polytrope_gamma_1', - 'eos_polytrope_gamma_2', - 'eos_polytrope_scaled_pressure_ratio', - 'eos_polytrope_scaled_pressure_2', - 'mass_1_source', 'mass_2_source'] + eos_keys = [ + "eos_polytrope_gamma_0", + "eos_polytrope_gamma_1", + "eos_polytrope_gamma_2", + "eos_polytrope_scaled_pressure_ratio", + "eos_polytrope_scaled_pressure_2", + "mass_1_source", + "mass_2_source", + ] for key in eos_keys: try: - if (len(converted_parameters[key]) > max_len): + if len(converted_parameters[key]) > max_len: max_len = len(converted_parameters[key]) except TypeError: float_eos_params[key] = converted_parameters[key] if len(float_eos_params) == len(eos_keys): # case where all eos params are floats (pinned) logp1, logp2 = log_pressure_reparameterization_conversion( - converted_parameters['eos_polytrope_scaled_pressure_ratio'], - converted_parameters['eos_polytrope_scaled_pressure_2']) - converted_parameters['lambda_1'], converted_parameters['lambda_2'], converted_parameters['eos_check'] = \ + converted_parameters["eos_polytrope_scaled_pressure_ratio"], + converted_parameters["eos_polytrope_scaled_pressure_2"], + ) + converted_parameters["lambda_1"], converted_parameters["lambda_2"], converted_parameters["eos_check"] = ( polytrope_or_causal_params_to_lambda_1_lambda_2( - converted_parameters['eos_polytrope_gamma_0'], + converted_parameters["eos_polytrope_gamma_0"], logp1, - converted_parameters['eos_polytrope_gamma_1'], + converted_parameters["eos_polytrope_gamma_1"], logp2, - converted_parameters['eos_polytrope_gamma_2'], - converted_parameters['mass_1_source'], - converted_parameters['mass_2_source'], - causal=0) + converted_parameters["eos_polytrope_gamma_2"], + converted_parameters["mass_1_source"], + converted_parameters["mass_2_source"], + causal=0, + ) + ) elif len(float_eos_params) < len(eos_keys): # case where some or none are floats (pinned) for key in float_eos_params.keys(): converted_parameters[key] = np.ones(max_len) * converted_parameters[key] - pg0 = converted_parameters['eos_polytrope_gamma_0'] - pg1 = converted_parameters['eos_polytrope_gamma_1'] - pg2 = converted_parameters['eos_polytrope_gamma_2'] - scaledratio = converted_parameters['eos_polytrope_scaled_pressure_ratio'] - scaled_p2 = converted_parameters['eos_polytrope_scaled_pressure_2'] - m1s = converted_parameters['mass_1_source'] - m2s = converted_parameters['mass_2_source'] + pg0 = converted_parameters["eos_polytrope_gamma_0"] + pg1 = converted_parameters["eos_polytrope_gamma_1"] + pg2 = converted_parameters["eos_polytrope_gamma_2"] + scaledratio = converted_parameters["eos_polytrope_scaled_pressure_ratio"] + scaled_p2 = converted_parameters["eos_polytrope_scaled_pressure_2"] + m1s = converted_parameters["mass_1_source"] + m2s = converted_parameters["mass_2_source"] all_lambda_1 = np.empty(0) all_lambda_2 = np.empty(0) all_eos_check = np.empty(0, dtype=bool) - for (pg_0, pg_1, pg_2, scaled_ratio, scaled_p_2, m1_s, - m2_s) in zip(pg0, pg1, pg2, scaledratio, scaled_p2, m1s, m2s): + for pg_0, pg_1, pg_2, scaled_ratio, scaled_p_2, m1_s, m2_s in zip( + pg0, pg1, pg2, scaledratio, scaled_p2, m1s, m2s + ): logp_1, logp_2 = log_pressure_reparameterization_conversion(scaled_ratio, scaled_p_2) - lambda_1, lambda_2, eos_check = \ - polytrope_or_causal_params_to_lambda_1_lambda_2( - pg_0, logp_1, pg_1, logp_2, pg_2, m1_s, m2_s, causal=0) + lambda_1, lambda_2, eos_check = polytrope_or_causal_params_to_lambda_1_lambda_2( + pg_0, logp_1, pg_1, logp_2, pg_2, m1_s, m2_s, causal=0 + ) all_lambda_1 = np.append(all_lambda_1, lambda_1) all_lambda_2 = np.append(all_lambda_2, lambda_2) all_eos_check = np.append(all_eos_check, eos_check) - converted_parameters['lambda_1'] = all_lambda_1 - converted_parameters['lambda_2'] = all_lambda_2 - converted_parameters['eos_check'] = all_eos_check + converted_parameters["lambda_1"] = all_lambda_1 + converted_parameters["lambda_2"] = all_lambda_2 + converted_parameters["eos_check"] = all_eos_check for key in float_eos_params.keys(): converted_parameters[key] = float_eos_params[key] - elif 'eos_v1' in converted_parameters.keys(): + elif "eos_v1" in converted_parameters.keys(): converted_parameters = generate_source_frame_parameters(converted_parameters) float_eos_params = {} max_len = 1 - eos_keys = ['eos_v1', - 'eos_v2', - 'eos_v3', - 'eos_log10_pressure1_cgs', - 'eos_log10_pressure2_cgs', - 'mass_1_source', 'mass_2_source'] + eos_keys = [ + "eos_v1", + "eos_v2", + "eos_v3", + "eos_log10_pressure1_cgs", + "eos_log10_pressure2_cgs", + "mass_1_source", + "mass_2_source", + ] for key in eos_keys: try: - if (len(converted_parameters[key]) > max_len): + if len(converted_parameters[key]) > max_len: max_len = len(converted_parameters[key]) except TypeError: float_eos_params[key] = converted_parameters[key] if len(float_eos_params) == len(eos_keys): # case where all eos params are floats (pinned) - converted_parameters['lambda_1'], converted_parameters['lambda_2'], converted_parameters['eos_check'] = \ + converted_parameters["lambda_1"], converted_parameters["lambda_2"], converted_parameters["eos_check"] = ( polytrope_or_causal_params_to_lambda_1_lambda_2( - converted_parameters['eos_v1'], - converted_parameters['eos_log10_pressure1_cgs'], - converted_parameters['eos_v2'], - converted_parameters['eos_log10_pressure2_cgs'], - converted_parameters['eos_v3'], - converted_parameters['mass_1_source'], - converted_parameters['mass_2_source'], - causal=1) + converted_parameters["eos_v1"], + converted_parameters["eos_log10_pressure1_cgs"], + converted_parameters["eos_v2"], + converted_parameters["eos_log10_pressure2_cgs"], + converted_parameters["eos_v3"], + converted_parameters["mass_1_source"], + converted_parameters["mass_2_source"], + causal=1, + ) + ) elif len(float_eos_params) < len(eos_keys): # case where some or none are floats (pinned) for key in float_eos_params.keys(): converted_parameters[key] = np.ones(max_len) * converted_parameters[key] - v1 = converted_parameters['eos_v1'] - v2 = converted_parameters['eos_v2'] - v3 = converted_parameters['eos_v3'] - logp1 = converted_parameters['eos_log10_pressure1_cgs'] - logp2 = converted_parameters['eos_log10_pressure2_cgs'] - m1s = converted_parameters['mass_1_source'] - m2s = converted_parameters['mass_2_source'] + v1 = converted_parameters["eos_v1"] + v2 = converted_parameters["eos_v2"] + v3 = converted_parameters["eos_v3"] + logp1 = converted_parameters["eos_log10_pressure1_cgs"] + logp2 = converted_parameters["eos_log10_pressure2_cgs"] + m1s = converted_parameters["mass_1_source"] + m2s = converted_parameters["mass_2_source"] all_lambda_1 = np.empty(0) all_lambda_2 = np.empty(0) all_eos_check = np.empty(0, dtype=bool) - for (v_1, v_2, v_3, logp_1, logp_2, m1_s, m2_s) in zip(v1, v2, v3, logp1, logp2, m1s, m2s): - lambda_1, lambda_2, eos_check = \ - polytrope_or_causal_params_to_lambda_1_lambda_2( - v_1, logp_1, v_2, logp_2, v_3, m1_s, m2_s, causal=1) + for v_1, v_2, v_3, logp_1, logp_2, m1_s, m2_s in zip(v1, v2, v3, logp1, logp2, m1s, m2s): + lambda_1, lambda_2, eos_check = polytrope_or_causal_params_to_lambda_1_lambda_2( + v_1, logp_1, v_2, logp_2, v_3, m1_s, m2_s, causal=1 + ) all_lambda_1 = np.append(all_lambda_1, lambda_1) all_lambda_2 = np.append(all_lambda_2, lambda_2) all_eos_check = np.append(all_eos_check, eos_check) - converted_parameters['lambda_1'] = all_lambda_1 - converted_parameters['lambda_2'] = all_lambda_2 - converted_parameters['eos_check'] = all_eos_check + converted_parameters["lambda_1"] = all_lambda_1 + converted_parameters["lambda_2"] = all_lambda_2 + converted_parameters["eos_check"] = all_eos_check for key in float_eos_params.keys(): converted_parameters[key] = float_eos_params[key] - elif 'lambda_symmetric' in converted_parameters.keys(): - if 'lambda_antisymmetric' in converted_parameters.keys(): - converted_parameters['lambda_1'], converted_parameters['lambda_2'] =\ + elif "lambda_symmetric" in converted_parameters.keys(): + if "lambda_antisymmetric" in converted_parameters.keys(): + converted_parameters["lambda_1"], converted_parameters["lambda_2"] = ( lambda_symmetric_lambda_antisymmetric_to_lambda_1_lambda_2( - converted_parameters['lambda_symmetric'], - converted_parameters['lambda_antisymmetric']) - elif 'mass_ratio' in converted_parameters.keys(): - if 'binary_love_uniform' in converted_parameters.keys(): - converted_parameters['lambda_1'], converted_parameters['lambda_2'] =\ + converted_parameters["lambda_symmetric"], converted_parameters["lambda_antisymmetric"] + ) + ) + elif "mass_ratio" in converted_parameters.keys(): + if "binary_love_uniform" in converted_parameters.keys(): + converted_parameters["lambda_1"], converted_parameters["lambda_2"] = ( binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation( - converted_parameters['binary_love_uniform'], - converted_parameters['lambda_symmetric'], - converted_parameters['mass_ratio']) + converted_parameters["binary_love_uniform"], + converted_parameters["lambda_symmetric"], + converted_parameters["mass_ratio"], + ) + ) else: - converted_parameters['lambda_1'], converted_parameters['lambda_2'] =\ + converted_parameters["lambda_1"], converted_parameters["lambda_2"] = ( binary_love_lambda_symmetric_to_lambda_1_lambda_2_automatic_marginalisation( - converted_parameters['lambda_symmetric'], - converted_parameters['mass_ratio']) + converted_parameters["lambda_symmetric"], converted_parameters["mass_ratio"] + ) + ) - added_keys = [key for key in converted_parameters.keys() - if key not in original_keys] + added_keys = [key for key in converted_parameters.keys() if key not in original_keys] return converted_parameters, added_keys def log_pressure_reparameterization_conversion(scaled_pressure_ratio, scaled_pressure_2): - ''' + """ Converts the reparameterization joining pressures from (scaled_pressure_ratio,scaled_pressure_2) to (log10_pressure_1,log10_pressure_2). This reparameterization with a triangular prior (with mode = max) on scaled_pressure_2 @@ -601,8 +627,8 @@ def log_pressure_reparameterization_conversion(scaled_pressure_ratio, scaled_pre log10_pressure_1, log10_pressure_2: float joining pressures in the original parameterization - ''' - minimum_pressure = 33. + """ + minimum_pressure = 33.0 log10_pressure_1 = (scaled_pressure_ratio * scaled_pressure_2) + minimum_pressure log10_pressure_2 = minimum_pressure + scaled_pressure_2 @@ -610,7 +636,7 @@ def log_pressure_reparameterization_conversion(scaled_pressure_ratio, scaled_pre def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3): - ''' + """ Change of basis on parameter space from an efficient space to sample in (sample space) to the space used in spectral eos model (model space). @@ -629,27 +655,28 @@ def spectral_pca_to_spectral(gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3) converted_gamma_parameters: np.array() array of gamma_0, gamma_1, gamma_2, gamma_3 in model space - ''' + """ sampled_pca_gammas = np.array([gamma_pca_0, gamma_pca_1, gamma_pca_2, gamma_pca_3]) transformation_matrix = np.array( [ [0.43801, -0.76705, 0.45143, 0.12646], [-0.53573, 0.17169, 0.67968, 0.47070], [0.52660, 0.31255, -0.19454, 0.76626], - [-0.49379, -0.53336, -0.54444, 0.41868] + [-0.49379, -0.53336, -0.54444, 0.41868], ] ) model_space_mean = np.array([0.89421, 0.33878, -0.07894, 0.00393]) model_space_standard_deviation = np.array([0.35700, 0.25769, 0.05452, 0.00312]) - converted_gamma_parameters = \ - model_space_mean + model_space_standard_deviation * np.dot(transformation_matrix, sampled_pca_gammas) + converted_gamma_parameters = model_space_mean + model_space_standard_deviation * np.dot( + transformation_matrix, sampled_pca_gammas + ) return converted_gamma_parameters def spectral_params_to_lambda_1_lambda_2(gamma_0, gamma_1, gamma_2, gamma_3, mass_1_source, mass_2_source): - ''' + """ Converts from the 4 spectral decomposition parameters and the source masses to the tidal deformability parameters. @@ -668,7 +695,7 @@ def spectral_params_to_lambda_1_lambda_2(gamma_0, gamma_1, gamma_2, gamma_3, mas whether or not the equation of state is viable / if eos_check = False, lambdas are 0 and the sample is rejected. - ''' + """ eos_check = True if lalsim_SimNeutronStarEOS4ParamSDGammaCheck(gamma_0, gamma_1, gamma_2, gamma_3) != 0: lambda_1 = 0.0 @@ -687,7 +714,8 @@ def spectral_params_to_lambda_1_lambda_2(gamma_0, gamma_1, gamma_2, gamma_3, mas def polytrope_or_causal_params_to_lambda_1_lambda_2( - param1, log10_pressure1_cgs, param2, log10_pressure2_cgs, param3, mass_1_source, mass_2_source, causal): + param1, log10_pressure1_cgs, param2, log10_pressure2_cgs, param3, mass_1_source, mass_2_source, causal +): """ Converts parameters from sampled dynamic piecewise polytrope parameters to component tidal deformablity parameters. @@ -727,12 +755,18 @@ def polytrope_or_causal_params_to_lambda_1_lambda_2( else: if causal == 0: eos = lalsim_SimNeutronStarEOS3PieceDynamicPolytrope( - param1, log10_pressure1_cgs - 1., param2, log10_pressure2_cgs - 1., param3) + param1, log10_pressure1_cgs - 1.0, param2, log10_pressure2_cgs - 1.0, param3 + ) else: eos = lalsim_SimNeutronStarEOS3PieceCausalAnalytic( - param1, log10_pressure1_cgs - 1., param2, log10_pressure2_cgs - 1., param3) - if lalsim_SimNeutronStarEOS3PDViableFamilyCheck( - param1, log10_pressure1_cgs - 1., param2, log10_pressure2_cgs - 1., param3, causal) != 0: + param1, log10_pressure1_cgs - 1.0, param2, log10_pressure2_cgs - 1.0, param3 + ) + if ( + lalsim_SimNeutronStarEOS3PDViableFamilyCheck( + param1, log10_pressure1_cgs - 1.0, param2, log10_pressure2_cgs - 1.0, param3, causal + ) + != 0 + ): lambda_1 = 0.0 lambda_2 = 0.0 eos_check = False @@ -799,9 +833,9 @@ def lambda_from_mass_and_family(mass_i, family): """ radius = lalsim_SimNeutronStarRadius(mass_i * solar_mass, family) love_number_k2 = lalsim_SimNeutronStarLoveNumberK2(mass_i * solar_mass, family) - mass_geometrized = mass_i * solar_mass * gravitational_constant / speed_of_light ** 2. + mass_geometrized = mass_i * solar_mass * gravitational_constant / speed_of_light**2.0 compactness = mass_geometrized / radius - lambda_i = (2. / 3.) * love_number_k2 / compactness ** 5. + lambda_i = (2.0 / 3.0) * love_number_k2 / compactness**5.0 return lambda_i @@ -886,12 +920,8 @@ def chirp_mass_and_mass_ratio_to_component_masses(chirp_mass, mass_ratio): mass_2: float Mass of the lighter object """ - total_mass = chirp_mass_and_mass_ratio_to_total_mass(chirp_mass=chirp_mass, - mass_ratio=mass_ratio) - mass_1, mass_2 = ( - total_mass_and_mass_ratio_to_component_masses( - total_mass=total_mass, mass_ratio=mass_ratio) - ) + total_mass = chirp_mass_and_mass_ratio_to_total_mass(chirp_mass=chirp_mass, mass_ratio=mass_ratio) + mass_1, mass_2 = total_mass_and_mass_ratio_to_component_masses(total_mass=total_mass, mass_ratio=mass_ratio) return mass_1, mass_2 @@ -910,8 +940,8 @@ def symmetric_mass_ratio_to_mass_ratio(symmetric_mass_ratio): Mass ratio of the binary """ - temp = (1 / symmetric_mass_ratio / 2 - 1) - return temp - (temp ** 2 - 1) ** 0.5 + temp = 1 / symmetric_mass_ratio / 2 - 1 + return temp - (temp**2 - 1) ** 0.5 def chirp_mass_and_total_mass_to_symmetric_mass_ratio(chirp_mass, total_mass): @@ -958,7 +988,7 @@ def chirp_mass_and_primary_mass_to_mass_ratio(chirp_mass, mass_1): Mass ratio (mass_2/mass_1) of the binary """ a = (chirp_mass / mass_1) ** 5 - t0 = np.cbrt(9 * a + np.sqrt(3) * np.sqrt(27 * a ** 2 - 4 * a ** 3)) + t0 = np.cbrt(9 * a + np.sqrt(3) * np.sqrt(27 * a**2 - 4 * a**3)) t1 = np.cbrt(2) * 3 ** (2 / 3) t2 = np.cbrt(2 / 3) * a return t2 / t0 + t0 / t1 @@ -984,7 +1014,7 @@ def chirp_mass_and_mass_ratio_to_total_mass(chirp_mass, mass_ratio): """ with np.errstate(invalid="ignore"): - return chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio ** 0.6 + return chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio**0.6 def component_masses_to_chirp_mass(mass_1, mass_2): @@ -1086,10 +1116,9 @@ def mass_1_and_chirp_mass_to_mass_ratio(mass_1, chirp_mass): Mass ratio of the binary """ temp = (chirp_mass / mass_1) ** 5 - mass_ratio = (2 / 3 / (3 ** 0.5 * (27 * temp ** 2 - 4 * temp ** 3) ** 0.5 + - 9 * temp)) ** (1 / 3) * temp + \ - ((3 ** 0.5 * (27 * temp ** 2 - 4 * temp ** 3) ** 0.5 + - 9 * temp) / (2 * 3 ** 2)) ** (1 / 3) + mass_ratio = (2 / 3 / (3**0.5 * (27 * temp**2 - 4 * temp**3) ** 0.5 + 9 * temp)) ** (1 / 3) * temp + ( + (3**0.5 * (27 * temp**2 - 4 * temp**3) ** 0.5 + 9 * temp) / (2 * 3**2) + ) ** (1 / 3) return mass_ratio @@ -1143,9 +1172,14 @@ def lambda_1_lambda_2_to_lambda_tilde(lambda_1, lambda_2, mass_1, mass_2): eta = component_masses_to_symmetric_mass_ratio(mass_1, mass_2) lambda_plus = lambda_1 + lambda_2 lambda_minus = lambda_1 - lambda_2 - lambda_tilde = 8 / 13 * ( - (1 + 7 * eta - 31 * eta**2) * lambda_plus + - (1 - 4 * eta)**0.5 * (1 + 9 * eta - 11 * eta**2) * lambda_minus) + lambda_tilde = ( + 8 + / 13 + * ( + (1 + 7 * eta - 31 * eta**2) * lambda_plus + + (1 - 4 * eta) ** 0.5 * (1 + 9 * eta - 11 * eta**2) * lambda_minus + ) + ) return lambda_tilde @@ -1175,15 +1209,18 @@ def lambda_1_lambda_2_to_delta_lambda_tilde(lambda_1, lambda_2, mass_1, mass_2): eta = component_masses_to_symmetric_mass_ratio(mass_1, mass_2) lambda_plus = lambda_1 + lambda_2 lambda_minus = lambda_1 - lambda_2 - delta_lambda_tilde = 1 / 2 * ( - (1 - 4 * eta) ** 0.5 * (1 - 13272 / 1319 * eta + 8944 / 1319 * eta**2) * - lambda_plus + (1 - 15910 / 1319 * eta + 32850 / 1319 * eta ** 2 + - 3380 / 1319 * eta ** 3) * lambda_minus) + delta_lambda_tilde = ( + 1 + / 2 + * ( + (1 - 4 * eta) ** 0.5 * (1 - 13272 / 1319 * eta + 8944 / 1319 * eta**2) * lambda_plus + + (1 - 15910 / 1319 * eta + 32850 / 1319 * eta**2 + 3380 / 1319 * eta**3) * lambda_minus + ) + ) return delta_lambda_tilde -def lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2( - lambda_tilde, delta_lambda_tilde, mass_1, mass_2): +def lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2(lambda_tilde, delta_lambda_tilde, mass_1, mass_2): """ Convert from dominant tidal terms to individual tidal parameters. @@ -1209,28 +1246,29 @@ def lambda_tilde_delta_lambda_tilde_to_lambda_1_lambda_2( """ eta = component_masses_to_symmetric_mass_ratio(mass_1, mass_2) - coefficient_1 = (1 + 7 * eta - 31 * eta**2) - coefficient_2 = (1 - 4 * eta)**0.5 * (1 + 9 * eta - 11 * eta**2) - coefficient_3 = (1 - 4 * eta)**0.5 *\ - (1 - 13272 / 1319 * eta + 8944 / 1319 * eta**2) - coefficient_4 = (1 - 15910 / 1319 * eta + 32850 / 1319 * eta**2 + - 3380 / 1319 * eta**3) - lambda_1 =\ - (13 * lambda_tilde / 8 * (coefficient_3 - coefficient_4) - - 2 * delta_lambda_tilde * (coefficient_1 - coefficient_2))\ - / ((coefficient_1 + coefficient_2) * (coefficient_3 - coefficient_4) - - (coefficient_1 - coefficient_2) * (coefficient_3 + coefficient_4)) - lambda_2 =\ - (13 * lambda_tilde / 8 * (coefficient_3 + coefficient_4) - - 2 * delta_lambda_tilde * (coefficient_1 + coefficient_2)) \ - / ((coefficient_1 - coefficient_2) * (coefficient_3 + coefficient_4) - - (coefficient_1 + coefficient_2) * (coefficient_3 - coefficient_4)) + coefficient_1 = 1 + 7 * eta - 31 * eta**2 + coefficient_2 = (1 - 4 * eta) ** 0.5 * (1 + 9 * eta - 11 * eta**2) + coefficient_3 = (1 - 4 * eta) ** 0.5 * (1 - 13272 / 1319 * eta + 8944 / 1319 * eta**2) + coefficient_4 = 1 - 15910 / 1319 * eta + 32850 / 1319 * eta**2 + 3380 / 1319 * eta**3 + lambda_1 = ( + 13 * lambda_tilde / 8 * (coefficient_3 - coefficient_4) + - 2 * delta_lambda_tilde * (coefficient_1 - coefficient_2) + ) / ( + (coefficient_1 + coefficient_2) * (coefficient_3 - coefficient_4) + - (coefficient_1 - coefficient_2) * (coefficient_3 + coefficient_4) + ) + lambda_2 = ( + 13 * lambda_tilde / 8 * (coefficient_3 + coefficient_4) + - 2 * delta_lambda_tilde * (coefficient_1 + coefficient_2) + ) / ( + (coefficient_1 - coefficient_2) * (coefficient_3 + coefficient_4) + - (coefficient_1 + coefficient_2) * (coefficient_3 - coefficient_4) + ) return lambda_1, lambda_2 -def lambda_tilde_to_lambda_1_lambda_2( - lambda_tilde, mass_1, mass_2): +def lambda_tilde_to_lambda_1_lambda_2(lambda_tilde, mass_1, mass_2): """ Convert from dominant tidal term to individual tidal parameters assuming lambda_1 * mass_1**5 = lambda_2 * mass_2**5. @@ -1255,9 +1293,12 @@ def lambda_tilde_to_lambda_1_lambda_2( """ eta = component_masses_to_symmetric_mass_ratio(mass_1, mass_2) q = mass_2 / mass_1 - lambda_1 = 13 / 8 * lambda_tilde / ( - (1 + 7 * eta - 31 * eta**2) * (1 + q**-5) + - (1 - 4 * eta)**0.5 * (1 + 9 * eta - 11 * eta**2) * (1 - q**-5)) + lambda_1 = ( + 13 + / 8 + * lambda_tilde + / ((1 + 7 * eta - 31 * eta**2) * (1 + q**-5) + (1 - 4 * eta) ** 0.5 * (1 + 9 * eta - 11 * eta**2) * (1 - q**-5)) + ) lambda_2 = lambda_1 / q**5 return lambda_1, lambda_2 @@ -1280,7 +1321,7 @@ def lambda_1_lambda_2_to_lambda_symmetric(lambda_1, lambda_2): lambda_symmetric: float Symmetric tidal parameter. """ - lambda_symmetric = (lambda_2 + lambda_1) / 2. + lambda_symmetric = (lambda_2 + lambda_1) / 2.0 return lambda_symmetric @@ -1302,7 +1343,7 @@ def lambda_1_lambda_2_to_lambda_antisymmetric(lambda_1, lambda_2): lambda_antisymmetric: float Antisymmetric tidal parameter. """ - lambda_antisymmetric = (lambda_2 - lambda_1) / 2. + lambda_antisymmetric = (lambda_2 - lambda_1) / 2.0 return lambda_antisymmetric @@ -1376,7 +1417,6 @@ def lambda_symmetric_lambda_antisymmetric_to_lambda_1_lambda_2(lambda_symmetric, def binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_symmetric, mass_ratio): - """ Convert from symmetric tidal terms and mass ratio to antisymmetric tidal terms using BinaryLove relations. @@ -1403,7 +1443,7 @@ def binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_s lambda_antisymmetric: float Antisymmetric tidal parameter. """ - lambda_symmetric_m1o5 = np.power(lambda_symmetric, -1. / 5.) + lambda_symmetric_m1o5 = np.power(lambda_symmetric, -1.0 / 5.0) lambda_symmetric_m2o5 = lambda_symmetric_m1o5 * lambda_symmetric_m1o5 lambda_symmetric_m3o5 = lambda_symmetric_m2o5 * lambda_symmetric_m1o5 @@ -1413,8 +1453,8 @@ def binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_s # Eqn.2 from CHZ, incorporating the dependence on mass ratio n_polytropic = 0.743 # average polytropic index for the EoSs included in the fit - q_for_Fnofq = np.power(q, 10. / (3. - n_polytropic)) - Fnofq = (1. - q_for_Fnofq) / (1. + q_for_Fnofq) + q_for_Fnofq = np.power(q, 10.0 / (3.0 - n_polytropic)) + Fnofq = (1.0 - q_for_Fnofq) / (1.0 + q_for_Fnofq) # b_ij and c_ij coefficients are given in Table I of CHZ @@ -1435,23 +1475,34 @@ def binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_s # Eqn 1 from CHZ, giving the lambda_antisymmetric_fitOnly # not yet accounting for the uncertainty in the fit - numerator = 1.0 + \ - (b11 * q * lambda_symmetric_m1o5) + (b12 * q2 * lambda_symmetric_m1o5) + \ - (b21 * q * lambda_symmetric_m2o5) + (b22 * q2 * lambda_symmetric_m2o5) + \ - (b31 * q * lambda_symmetric_m3o5) + (b32 * q2 * lambda_symmetric_m3o5) + numerator = ( + 1.0 + + (b11 * q * lambda_symmetric_m1o5) + + (b12 * q2 * lambda_symmetric_m1o5) + + (b21 * q * lambda_symmetric_m2o5) + + (b22 * q2 * lambda_symmetric_m2o5) + + (b31 * q * lambda_symmetric_m3o5) + + (b32 * q2 * lambda_symmetric_m3o5) + ) - denominator = 1.0 + \ - (c11 * q * lambda_symmetric_m1o5) + (c12 * q2 * lambda_symmetric_m1o5) + \ - (c21 * q * lambda_symmetric_m2o5) + (c22 * q2 * lambda_symmetric_m2o5) + \ - (c31 * q * lambda_symmetric_m3o5) + (c32 * q2 * lambda_symmetric_m3o5) + denominator = ( + 1.0 + + (c11 * q * lambda_symmetric_m1o5) + + (c12 * q2 * lambda_symmetric_m1o5) + + (c21 * q * lambda_symmetric_m2o5) + + (c22 * q2 * lambda_symmetric_m2o5) + + (c31 * q * lambda_symmetric_m3o5) + + (c32 * q2 * lambda_symmetric_m3o5) + ) lambda_antisymmetric_fitOnly = Fnofq * lambda_symmetric * numerator / denominator return lambda_antisymmetric_fitOnly -def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(binary_love_uniform, - lambda_symmetric, mass_ratio): +def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation( + binary_love_uniform, lambda_symmetric, mass_ratio +): """ Convert from symmetric tidal terms to lambda_1 and lambda_2 using BinaryLove relations @@ -1480,8 +1531,9 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(bin lambda_2: float Tidal parameter of less massive neutron star. """ - lambda_antisymmetric_fitOnly = binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric(lambda_symmetric, - mass_ratio) + lambda_antisymmetric_fitOnly = binary_love_fit_lambda_symmetric_mass_ratio_to_lambda_antisymmetric( + lambda_symmetric, mass_ratio + ) lambda_symmetric_sqrt = np.sqrt(lambda_symmetric) @@ -1495,7 +1547,7 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(bin mu_3 = 0.5168637 mu_4 = -11.2765281 mu_5 = 14.9499544 - mu_6 = - 4.6638851 + mu_6 = -4.6638851 sigma_1 = -0.0000739 sigma_2 = 0.0103778 @@ -1509,60 +1561,58 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation(bin # uncertainty in the mean of the lambdaS residual fit, # using coefficients mu_1, mu_2 and mu_3 from Table II of CHZ - lambda_antisymmetric_lambda_symmetric_meanCorr = \ - (mu_1 / (lambda_symmetric * lambda_symmetric)) + \ - (mu_2 / lambda_symmetric) + mu_3 + lambda_antisymmetric_lambda_symmetric_meanCorr = ( + (mu_1 / (lambda_symmetric * lambda_symmetric)) + (mu_2 / lambda_symmetric) + mu_3 + ) # Eqn 8 from CHZ, correction on fit for lambdaA caused by # uncertainty in the standard deviation of lambdaS residual fit, # using coefficients sigma_1, sigma_2, sigma_3 and sigma_4 from Table II - lambda_antisymmetric_lambda_symmetric_stdCorr = \ - (sigma_1 * lambda_symmetric * lambda_symmetric_sqrt) + \ - (sigma_2 * lambda_symmetric) + \ - (sigma_3 * lambda_symmetric_sqrt) + sigma_4 + lambda_antisymmetric_lambda_symmetric_stdCorr = ( + (sigma_1 * lambda_symmetric * lambda_symmetric_sqrt) + + (sigma_2 * lambda_symmetric) + + (sigma_3 * lambda_symmetric_sqrt) + + sigma_4 + ) # Eqn 7, correction on fit for lambdaA caused by # uncertainty in the mean of the q residual fit, # using coefficients mu_4, mu_5 and mu_6 from Table II - lambda_antisymmetric_mass_ratio_meanCorr = \ - (mu_4 * q2) + (mu_5 * q) + mu_6 + lambda_antisymmetric_mass_ratio_meanCorr = (mu_4 * q2) + (mu_5 * q) + mu_6 # Eqn 9 from CHZ, correction on fit for lambdaA caused by # uncertainty in the standard deviation of the q residual fit, # using coefficients sigma_5, sigma_6 and sigma_7 from Table II - lambda_antisymmetric_mass_ratio_stdCorr = \ - (sigma_5 * q2) + (sigma_6 * q) + sigma_7 + lambda_antisymmetric_mass_ratio_stdCorr = (sigma_5 * q2) + (sigma_6 * q) + sigma_7 # Eqn 4 from CHZ, averaging the corrections from the # mean of the residual fits - lambda_antisymmetric_meanCorr = \ - (lambda_antisymmetric_lambda_symmetric_meanCorr + - lambda_antisymmetric_mass_ratio_meanCorr) / 2. + lambda_antisymmetric_meanCorr = ( + lambda_antisymmetric_lambda_symmetric_meanCorr + lambda_antisymmetric_mass_ratio_meanCorr + ) / 2.0 # Eqn 5 from CHZ, averaging the corrections from the # standard deviations of the residual fits - lambda_antisymmetric_stdCorr = \ - np.sqrt(np.square(lambda_antisymmetric_lambda_symmetric_stdCorr) + - np.square(lambda_antisymmetric_mass_ratio_stdCorr)) + lambda_antisymmetric_stdCorr = np.sqrt( + np.square(lambda_antisymmetric_lambda_symmetric_stdCorr) + np.square(lambda_antisymmetric_mass_ratio_stdCorr) + ) # Draw a correction on the fit from a # Gaussian distribution with width lambda_antisymmetric_stdCorr # this is done by sampling a percent point function (inverse cdf) # through a U{0,1} variable called binary_love_uniform - lambda_antisymmetric_scatter = norm.ppf(binary_love_uniform, loc=0., - scale=lambda_antisymmetric_stdCorr) + lambda_antisymmetric_scatter = norm.ppf(binary_love_uniform, loc=0.0, scale=lambda_antisymmetric_stdCorr) # Add the correction of the residual mean # and the Gaussian scatter to the lambda_antisymmetric_fitOnly value - lambda_antisymmetric = lambda_antisymmetric_fitOnly + \ - (lambda_antisymmetric_meanCorr + lambda_antisymmetric_scatter) + lambda_antisymmetric = lambda_antisymmetric_fitOnly + (lambda_antisymmetric_meanCorr + lambda_antisymmetric_scatter) lambda_1 = lambda_symmetric_lambda_antisymmetric_to_lambda_1(lambda_symmetric, lambda_antisymmetric) lambda_2 = lambda_symmetric_lambda_antisymmetric_to_lambda_2(lambda_symmetric, lambda_antisymmetric) @@ -1624,42 +1674,41 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_automatic_marginalisation( binary_love_uniform = random.rng.uniform(0, 1, len(lambda_symmetric)) lambda_1, lambda_2 = binary_love_lambda_symmetric_to_lambda_1_lambda_2_manual_marginalisation( - binary_love_uniform, lambda_symmetric, mass_ratio) + binary_love_uniform, lambda_symmetric, mass_ratio + ) return lambda_1, lambda_2 -def _generate_all_cbc_parameters(sample, defaults, base_conversion, - likelihood=None, priors=None, npool=1): +def _generate_all_cbc_parameters(sample, defaults, base_conversion, likelihood=None, priors=None, npool=1): """Generate all cbc parameters, helper function for BBH/BNS""" output_sample = sample.copy() waveform_defaults = defaults for key in waveform_defaults: try: - output_sample[key] = \ - likelihood.waveform_generator.waveform_arguments[key] + output_sample[key] = likelihood.waveform_generator.waveform_arguments[key] except (KeyError, AttributeError): default = waveform_defaults[key] output_sample[key] = default - logger.debug('Assuming {} = {}'.format(key, default)) + logger.debug(f"Assuming {key} = {default}") output_sample = fill_from_fixed_priors(output_sample, priors) output_sample, _ = base_conversion(output_sample) if likelihood is not None: - compute_per_detector_log_likelihoods( - samples=output_sample, likelihood=likelihood, npool=npool) + compute_per_detector_log_likelihoods(samples=output_sample, likelihood=likelihood, npool=npool) marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) if len(marginalized_parameters) > 0: try: generate_posterior_samples_from_marginalized_likelihood( - samples=output_sample, likelihood=likelihood, npool=npool) + samples=output_sample, likelihood=likelihood, npool=npool + ) except MarginalizedLikelihoodReconstructionError as e: logger.warning( "Marginalised parameter reconstruction failed with message " - "{}. Some parameters may not have the intended " - "interpretation.".format(e) + f"{e}. Some parameters may not have the intended " + "interpretation." ) if priors is not None: misnamed_marginalizations = dict( @@ -1669,35 +1718,25 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion, ) for par in marginalized_parameters: name = misnamed_marginalizations.get(par, par) - if ( - getattr(likelihood, f'{name}_marginalization', False) - and par in likelihood.priors - ): + if getattr(likelihood, f"{name}_marginalization", False) and par in likelihood.priors: priors[par] = likelihood.priors[par] - if ( - not getattr(likelihood, "reference_frame", "sky") == "sky" - or "geocent" not in getattr(likelihood, "time_reference", "geocent") + if not getattr(likelihood, "reference_frame", "sky") == "sky" or "geocent" not in getattr( + likelihood, "time_reference", "geocent" ): try: - generate_sky_frame_parameters( - samples=output_sample, likelihood=likelihood - ) + generate_sky_frame_parameters(samples=output_sample, likelihood=likelihood) except TypeError: - logger.info( - "Failed to generate sky frame parameters for type {}" - .format(type(output_sample)) - ) + logger.info(f"Failed to generate sky frame parameters for type {type(output_sample)}") compute_snrs(output_sample, likelihood, npool=npool) - for key, func in zip(["mass", "spin", "source frame"], [ - generate_mass_parameters, generate_spin_parameters, - generate_source_frame_parameters]): + for key, func in zip( + ["mass", "spin", "source frame"], + [generate_mass_parameters, generate_spin_parameters, generate_source_frame_parameters], + ): try: output_sample = func(output_sample) except KeyError as e: - logger.info( - "Generation of {} parameters failed with message {}".format( - key, e)) + logger.info(f"Generation of {key} parameters failed with message {e}") return output_sample @@ -1723,13 +1762,15 @@ def generate_all_bbh_parameters(sample, likelihood=None, priors=None, npool=1): this function, the initial value of :code:`likelihood.parameters` are saved and reset at the end of the function. """ - waveform_defaults = { - 'reference_frequency': 50.0, 'waveform_approximant': 'IMRPhenomPv2', - 'minimum_frequency': 20.0} + waveform_defaults = {"reference_frequency": 50.0, "waveform_approximant": "IMRPhenomPv2", "minimum_frequency": 20.0} output_sample = _generate_all_cbc_parameters( - sample, defaults=waveform_defaults, + sample, + defaults=waveform_defaults, base_conversion=convert_to_lal_binary_black_hole_parameters, - likelihood=likelihood, priors=priors, npool=npool) + likelihood=likelihood, + priors=priors, + npool=npool, + ) return output_sample @@ -1754,18 +1795,19 @@ def generate_all_bns_parameters(sample, likelihood=None, priors=None, npool=1): npool: int, (default=1) If given, perform generation (where possible) using a multiprocessing pool """ - waveform_defaults = { - 'reference_frequency': 50.0, 'waveform_approximant': 'TaylorF2', - 'minimum_frequency': 20.0} + waveform_defaults = {"reference_frequency": 50.0, "waveform_approximant": "TaylorF2", "minimum_frequency": 20.0} output_sample = _generate_all_cbc_parameters( - sample, defaults=waveform_defaults, + sample, + defaults=waveform_defaults, base_conversion=convert_to_lal_binary_neutron_star_parameters, - likelihood=likelihood, priors=priors, npool=npool) + likelihood=likelihood, + priors=priors, + npool=npool, + ) try: output_sample = generate_tidal_parameters(output_sample) except KeyError as e: - logger.debug( - "Generation of tidal parameters failed with message {}".format(e)) + logger.debug(f"Generation of tidal parameters failed with message {e}") return output_sample @@ -1799,7 +1841,7 @@ def generate_specific_parameters(sample, parameters): if key in updated_sample: output_sample[key] = updated_sample[key] else: - raise KeyError("{} not in converted sample.".format(key)) + raise KeyError(f"{key} not in converted sample.") return output_sample @@ -1826,7 +1868,7 @@ def fill_from_fixed_priors(sample, priors): def generate_component_masses(sample, require_add=False, source=False): - """" + """ " Add the component masses to the dataframe/dictionary We add: mass_1, mass_2 @@ -1855,11 +1897,13 @@ def generate_component_masses(sample, require_add=False, source=False): Returns dict : the updated dictionary """ + def check_and_return_quietly(require_add, sample): if require_add: raise KeyError("Insufficient mass parameters in input dictionary") else: return sample + output_sample = sample.copy() if source: @@ -1877,59 +1921,41 @@ def check_and_return_quietly(require_add, sample): if mass_2_key in sample.keys(): return output_sample if total_mass_key in sample.keys(): - output_sample[mass_2_key] = output_sample[total_mass_key] - ( - output_sample[mass_1_key] - ) + output_sample[mass_2_key] = output_sample[total_mass_key] - (output_sample[mass_1_key]) return output_sample elif "mass_ratio" in sample.keys(): pass elif "symmetric_mass_ratio" in sample.keys(): - output_sample["mass_ratio"] = ( - symmetric_mass_ratio_to_mass_ratio( - output_sample["symmetric_mass_ratio"]) - ) + output_sample["mass_ratio"] = symmetric_mass_ratio_to_mass_ratio(output_sample["symmetric_mass_ratio"]) elif chirp_mass_key in sample.keys(): - output_sample["mass_ratio"] = ( - mass_1_and_chirp_mass_to_mass_ratio( - mass_1=output_sample[mass_1_key], - chirp_mass=output_sample[chirp_mass_key]) + output_sample["mass_ratio"] = mass_1_and_chirp_mass_to_mass_ratio( + mass_1=output_sample[mass_1_key], chirp_mass=output_sample[chirp_mass_key] ) else: return check_and_return_quietly(require_add, sample) - output_sample[mass_2_key] = ( - output_sample["mass_ratio"] * output_sample[mass_1_key] - ) + output_sample[mass_2_key] = output_sample["mass_ratio"] * output_sample[mass_1_key] return output_sample elif mass_2_key in sample.keys(): # mass_1 is not in the dict if total_mass_key in sample.keys(): - output_sample[mass_1_key] = ( - output_sample[total_mass_key] - output_sample[mass_2_key] - ) + output_sample[mass_1_key] = output_sample[total_mass_key] - output_sample[mass_2_key] return output_sample elif "mass_ratio" in sample.keys(): pass elif "symmetric_mass_ratio" in sample.keys(): - output_sample["mass_ratio"] = ( - symmetric_mass_ratio_to_mass_ratio( - output_sample["symmetric_mass_ratio"]) - ) + output_sample["mass_ratio"] = symmetric_mass_ratio_to_mass_ratio(output_sample["symmetric_mass_ratio"]) elif chirp_mass_key in sample.keys(): - output_sample["mass_ratio"] = ( - mass_2_and_chirp_mass_to_mass_ratio( - mass_2=output_sample[mass_2_key], - chirp_mass=output_sample[chirp_mass_key]) + output_sample["mass_ratio"] = mass_2_and_chirp_mass_to_mass_ratio( + mass_2=output_sample[mass_2_key], chirp_mass=output_sample[chirp_mass_key] ) else: check_and_return_quietly(require_add, sample) - output_sample[mass_1_key] = 1 / output_sample["mass_ratio"] * ( - output_sample[mass_2_key] - ) + output_sample[mass_1_key] = 1 / output_sample["mass_ratio"] * (output_sample[mass_2_key]) return output_sample @@ -1938,20 +1964,12 @@ def check_and_return_quietly(require_add, sample): if "mass_ratio" in sample.keys(): pass # We have everything we need already elif "symmetric_mass_ratio" in sample.keys(): - output_sample["mass_ratio"] = ( - symmetric_mass_ratio_to_mass_ratio( - output_sample["symmetric_mass_ratio"]) - ) + output_sample["mass_ratio"] = symmetric_mass_ratio_to_mass_ratio(output_sample["symmetric_mass_ratio"]) elif chirp_mass_key in sample.keys(): - output_sample["symmetric_mass_ratio"] = ( - chirp_mass_and_total_mass_to_symmetric_mass_ratio( - chirp_mass=output_sample[chirp_mass_key], - total_mass=output_sample[total_mass_key]) - ) - output_sample["mass_ratio"] = ( - symmetric_mass_ratio_to_mass_ratio( - output_sample["symmetric_mass_ratio"]) + output_sample["symmetric_mass_ratio"] = chirp_mass_and_total_mass_to_symmetric_mass_ratio( + chirp_mass=output_sample[chirp_mass_key], total_mass=output_sample[total_mass_key] ) + output_sample["mass_ratio"] = symmetric_mass_ratio_to_mass_ratio(output_sample["symmetric_mass_ratio"]) else: return check_and_return_quietly(require_add, sample) @@ -1959,27 +1977,19 @@ def check_and_return_quietly(require_add, sample): if "mass_ratio" in sample.keys(): pass elif "symmetric_mass_ratio" in sample.keys(): - output_sample["mass_ratio"] = ( - symmetric_mass_ratio_to_mass_ratio( - sample["symmetric_mass_ratio"]) - ) + output_sample["mass_ratio"] = symmetric_mass_ratio_to_mass_ratio(sample["symmetric_mass_ratio"]) else: return check_and_return_quietly(require_add, sample) - output_sample[total_mass_key] = ( - chirp_mass_and_mass_ratio_to_total_mass( - chirp_mass=output_sample[chirp_mass_key], - mass_ratio=output_sample["mass_ratio"]) + output_sample[total_mass_key] = chirp_mass_and_mass_ratio_to_total_mass( + chirp_mass=output_sample[chirp_mass_key], mass_ratio=output_sample["mass_ratio"] ) # We haven't matched any of the criteria - if total_mass_key not in output_sample.keys() or ( - "mass_ratio" not in output_sample.keys()): + if total_mass_key not in output_sample.keys() or ("mass_ratio" not in output_sample.keys()): return check_and_return_quietly(require_add, sample) - mass_1, mass_2 = ( - total_mass_and_mass_ratio_to_component_masses( - total_mass=output_sample[total_mass_key], - mass_ratio=output_sample["mass_ratio"]) + mass_1, mass_2 = total_mass_and_mass_ratio_to_component_masses( + total_mass=output_sample[total_mass_key], mass_ratio=output_sample["mass_ratio"] ) output_sample[mass_1_key] = mass_1 output_sample[mass_2_key] = mass_2 @@ -2017,35 +2027,31 @@ def generate_mass_parameters(sample, source=False): output_sample = intermediate_sample.copy() if source: - mass_1_key = 'mass_1_source' - mass_2_key = 'mass_2_source' - total_mass_key = 'total_mass_source' - chirp_mass_key = 'chirp_mass_source' + mass_1_key = "mass_1_source" + mass_2_key = "mass_2_source" + total_mass_key = "total_mass_source" + chirp_mass_key = "chirp_mass_source" else: - mass_1_key = 'mass_1' - mass_2_key = 'mass_2' - total_mass_key = 'total_mass' - chirp_mass_key = 'chirp_mass' + mass_1_key = "mass_1" + mass_2_key = "mass_2" + total_mass_key = "total_mass" + chirp_mass_key = "chirp_mass" if chirp_mass_key not in output_sample.keys(): - output_sample[chirp_mass_key] = ( - component_masses_to_chirp_mass(output_sample[mass_1_key], - output_sample[mass_2_key]) + output_sample[chirp_mass_key] = component_masses_to_chirp_mass( + output_sample[mass_1_key], output_sample[mass_2_key] ) if total_mass_key not in output_sample.keys(): - output_sample[total_mass_key] = ( - component_masses_to_total_mass(output_sample[mass_1_key], - output_sample[mass_2_key]) + output_sample[total_mass_key] = component_masses_to_total_mass( + output_sample[mass_1_key], output_sample[mass_2_key] ) - if 'symmetric_mass_ratio' not in output_sample.keys(): - output_sample['symmetric_mass_ratio'] = ( - component_masses_to_symmetric_mass_ratio(output_sample[mass_1_key], - output_sample[mass_2_key]) + if "symmetric_mass_ratio" not in output_sample.keys(): + output_sample["symmetric_mass_ratio"] = component_masses_to_symmetric_mass_ratio( + output_sample[mass_1_key], output_sample[mass_2_key] ) - if 'mass_ratio' not in output_sample.keys(): - output_sample['mass_ratio'] = ( - component_masses_to_mass_ratio(output_sample[mass_1_key], - output_sample[mass_2_key]) + if "mass_ratio" not in output_sample.keys(): + output_sample["mass_ratio"] = component_masses_to_mass_ratio( + output_sample[mass_1_key], output_sample[mass_2_key] ) return output_sample @@ -2072,27 +2078,24 @@ def generate_spin_parameters(sample): output_sample = generate_component_spins(output_sample) - output_sample['chi_eff'] = (output_sample['spin_1z'] + - output_sample['spin_2z'] * - output_sample['mass_ratio']) /\ - (1 + output_sample['mass_ratio']) - - output_sample['chi_1_in_plane'] = np.sqrt( - output_sample['spin_1x'] ** 2 + output_sample['spin_1y'] ** 2 - ) - output_sample['chi_2_in_plane'] = np.sqrt( - output_sample['spin_2x'] ** 2 + output_sample['spin_2y'] ** 2 + output_sample["chi_eff"] = (output_sample["spin_1z"] + output_sample["spin_2z"] * output_sample["mass_ratio"]) / ( + 1 + output_sample["mass_ratio"] ) - output_sample['chi_p'] = np.maximum( - output_sample['chi_1_in_plane'], - (4 * output_sample['mass_ratio'] + 3) / - (3 * output_sample['mass_ratio'] + 4) * output_sample['mass_ratio'] * - output_sample['chi_2_in_plane']) + output_sample["chi_1_in_plane"] = np.sqrt(output_sample["spin_1x"] ** 2 + output_sample["spin_1y"] ** 2) + output_sample["chi_2_in_plane"] = np.sqrt(output_sample["spin_2x"] ** 2 + output_sample["spin_2y"] ** 2) + + output_sample["chi_p"] = np.maximum( + output_sample["chi_1_in_plane"], + (4 * output_sample["mass_ratio"] + 3) + / (3 * output_sample["mass_ratio"] + 4) + * output_sample["mass_ratio"] + * output_sample["chi_2_in_plane"], + ) try: - output_sample['cos_tilt_1'] = np.cos(output_sample['tilt_1']) - output_sample['cos_tilt_2'] = np.cos(output_sample['tilt_2']) + output_sample["cos_tilt_1"] = np.cos(output_sample["tilt_1"]) + output_sample["cos_tilt_2"] = np.cos(output_sample["tilt_2"]) except KeyError: pass @@ -2117,38 +2120,56 @@ def generate_component_spins(sample): """ output_sample = sample.copy() - spin_conversion_parameters =\ - ['theta_jn', 'phi_jl', 'tilt_1', 'tilt_2', 'phi_12', 'a_1', 'a_2', - 'mass_1', 'mass_2', 'reference_frequency', 'phase'] + spin_conversion_parameters = [ + "theta_jn", + "phi_jl", + "tilt_1", + "tilt_2", + "phi_12", + "a_1", + "a_2", + "mass_1", + "mass_2", + "reference_frequency", + "phase", + ] if all(key in output_sample.keys() for key in spin_conversion_parameters): ( - output_sample['iota'], output_sample['spin_1x'], - output_sample['spin_1y'], output_sample['spin_1z'], - output_sample['spin_2x'], output_sample['spin_2y'], - output_sample['spin_2z'] + output_sample["iota"], + output_sample["spin_1x"], + output_sample["spin_1y"], + output_sample["spin_1z"], + output_sample["spin_2x"], + output_sample["spin_2y"], + output_sample["spin_2z"], ) = np.vectorize(bilby_to_lalsimulation_spins)( - output_sample['theta_jn'], output_sample['phi_jl'], - output_sample['tilt_1'], output_sample['tilt_2'], - output_sample['phi_12'], output_sample['a_1'], output_sample['a_2'], - output_sample['mass_1'] * solar_mass, - output_sample['mass_2'] * solar_mass, - output_sample['reference_frequency'], output_sample['phase'] + output_sample["theta_jn"], + output_sample["phi_jl"], + output_sample["tilt_1"], + output_sample["tilt_2"], + output_sample["phi_12"], + output_sample["a_1"], + output_sample["a_2"], + output_sample["mass_1"] * solar_mass, + output_sample["mass_2"] * solar_mass, + output_sample["reference_frequency"], + output_sample["phase"], ) - output_sample['phi_1'] =\ - np.fmod(2 * np.pi + np.arctan2( - output_sample['spin_1y'], output_sample['spin_1x']), 2 * np.pi) - output_sample['phi_2'] =\ - np.fmod(2 * np.pi + np.arctan2( - output_sample['spin_2y'], output_sample['spin_2x']), 2 * np.pi) - - elif 'chi_1' in output_sample and 'chi_2' in output_sample: - output_sample['spin_1x'] = 0 - output_sample['spin_1y'] = 0 - output_sample['spin_1z'] = output_sample['chi_1'] - output_sample['spin_2x'] = 0 - output_sample['spin_2y'] = 0 - output_sample['spin_2z'] = output_sample['chi_2'] + output_sample["phi_1"] = np.fmod( + 2 * np.pi + np.arctan2(output_sample["spin_1y"], output_sample["spin_1x"]), 2 * np.pi + ) + output_sample["phi_2"] = np.fmod( + 2 * np.pi + np.arctan2(output_sample["spin_2y"], output_sample["spin_2x"]), 2 * np.pi + ) + + elif "chi_1" in output_sample and "chi_2" in output_sample: + output_sample["spin_1x"] = 0 + output_sample["spin_1y"] = 0 + output_sample["spin_1z"] = output_sample["chi_1"] + output_sample["spin_2x"] = 0 + output_sample["spin_2y"] = 0 + output_sample["spin_2z"] = output_sample["chi_2"] else: logger.debug("Component spin extraction failed.") @@ -2173,14 +2194,12 @@ def generate_tidal_parameters(sample): """ output_sample = sample.copy() - output_sample['lambda_tilde'] =\ - lambda_1_lambda_2_to_lambda_tilde( - output_sample['lambda_1'], output_sample['lambda_2'], - output_sample['mass_1'], output_sample['mass_2']) - output_sample['delta_lambda_tilde'] = \ - lambda_1_lambda_2_to_delta_lambda_tilde( - output_sample['lambda_1'], output_sample['lambda_2'], - output_sample['mass_1'], output_sample['mass_2']) + output_sample["lambda_tilde"] = lambda_1_lambda_2_to_lambda_tilde( + output_sample["lambda_1"], output_sample["lambda_2"], output_sample["mass_1"], output_sample["mass_2"] + ) + output_sample["delta_lambda_tilde"] = lambda_1_lambda_2_to_delta_lambda_tilde( + output_sample["lambda_1"], output_sample["lambda_2"], output_sample["mass_1"], output_sample["mass_2"] + ) return output_sample @@ -2199,15 +2218,12 @@ def generate_source_frame_parameters(sample): """ output_sample = sample.copy() - output_sample['redshift'] =\ - luminosity_distance_to_redshift(output_sample['luminosity_distance']) - output_sample['comoving_distance'] =\ - redshift_to_comoving_distance(output_sample['redshift']) + output_sample["redshift"] = luminosity_distance_to_redshift(output_sample["luminosity_distance"]) + output_sample["comoving_distance"] = redshift_to_comoving_distance(output_sample["redshift"]) - for key in ['mass_1', 'mass_2', 'chirp_mass', 'total_mass']: + for key in ["mass_1", "mass_2", "chirp_mass", "total_mass"]: if key in output_sample: - output_sample['{}_source'.format(key)] =\ - output_sample[key] / (1 + output_sample['redshift']) + output_sample[f"{key}_source"] = output_sample[key] / (1 + output_sample["redshift"]) return output_sample @@ -2230,30 +2246,29 @@ def compute_snrs(sample, likelihood, npool=1): signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(sample.copy()) for ifo in likelihood.interferometers: per_detector_snr = likelihood.calculate_snrs(signal_polarizations, ifo, parameters=sample) - sample['{}_matched_filter_snr'.format(ifo.name)] =\ - per_detector_snr.complex_matched_filter_snr - sample['{}_optimal_snr'.format(ifo.name)] = \ - per_detector_snr.optimal_snr_squared.real ** 0.5 + sample[f"{ifo.name}_matched_filter_snr"] = per_detector_snr.complex_matched_filter_snr + sample[f"{ifo.name}_optimal_snr"] = per_detector_snr.optimal_snr_squared.real**0.5 else: from tqdm.auto import tqdm - logger.info('Computing SNRs for every sample.') + + logger.info("Computing SNRs for every sample.") fill_args = [(ii, row) for ii, row in sample.iterrows()] if npool > 1: from ..core.sampler.base_sampler import _initialize_global_variables + pool = multiprocessing.Pool( processes=npool, initializer=_initialize_global_variables, initargs=(likelihood, None, None, False, dict()), ) - logger.info( - "Using a pool with size {} for nsamples={}".format(npool, len(sample)) - ) + logger.info(f"Using a pool with size {npool} for nsamples={len(sample)}") new_samples = pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout)) pool.close() pool.join() else: from ..core.sampler.base_sampler import _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)] @@ -2268,23 +2283,20 @@ def compute_snrs(sample, likelihood, npool=1): for k, v in snr_updates.items(): sample[k] = v else: - logger.debug('Not computing SNRs.') + logger.debug("Not computing SNRs.") def _compute_snrs(args): """A wrapper of computing the SNRs to enable multiprocessing""" from ..core.sampler.base_sampler import _sampling_convenience_dump + likelihood = _sampling_convenience_dump.likelihood ii, sample = args sample = dict(sample).copy() - signal_polarizations = likelihood.waveform_generator.frequency_domain_strain( - sample.copy() - ) + signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(sample.copy()) snrs = list() for ifo in likelihood.interferometers: - snrs.append(likelihood.calculate_snrs( - signal_polarizations, ifo, return_array=False, parameters=sample - )) + snrs.append(likelihood.calculate_snrs(signal_polarizations, ifo, return_array=False, parameters=sample)) return snrs @@ -2310,7 +2322,7 @@ def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10) """ if likelihood is not None: if not callable(likelihood.compute_per_detector_log_likelihood): - logger.debug('Not computing per-detector log likelihoods.') + logger.debug("Not computing per-detector log likelihoods.") return samples if isinstance(samples, dict): @@ -2318,10 +2330,10 @@ def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10) return samples elif not isinstance(samples, DataFrame): - raise ValueError("Unable to handle input samples of type {}".format(type(samples))) + raise ValueError(f"Unable to handle input samples of type {type(samples)}") from tqdm.auto import tqdm - logger.info('Computing per-detector log likelihoods.') + logger.info("Computing per-detector log likelihoods.") # Initialize cache dict cached_samples_dict = dict() @@ -2332,17 +2344,16 @@ def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10) # Set up the multiprocessing if npool > 1: from ..core.sampler.base_sampler import _initialize_global_variables + pool = multiprocessing.Pool( processes=npool, initializer=_initialize_global_variables, initargs=(likelihood, None, None, False, dict()), ) - logger.info( - "Using a pool with size {} for nsamples={}" - .format(npool, len(samples)) - ) + logger.info(f"Using a pool with size {npool} for nsamples={len(samples)}") else: from ..core.sampler.base_sampler import _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood pool = None @@ -2356,11 +2367,9 @@ def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10) continue if pool is not None: - subset_samples = pool.map(_compute_per_detector_log_likelihoods, - fill_args[ii: ii + block]) + subset_samples = pool.map(_compute_per_detector_log_likelihoods, fill_args[ii : ii + block]) else: - subset_samples = [list(_compute_per_detector_log_likelihoods(xx)) - for xx in fill_args[ii: ii + block]] + subset_samples = [list(_compute_per_detector_log_likelihoods(xx)) for xx in fill_args[ii : ii + block]] cached_samples_dict[ii] = subset_samples @@ -2372,33 +2381,29 @@ def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10) pool.close() pool.join() - new_samples = np.concatenate( - [np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"] - ) + new_samples = np.concatenate([np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"]) - for ii, key in \ - enumerate([f'{ifo.name}_log_likelihood' for ifo in likelihood.interferometers]): + for ii, key in enumerate([f"{ifo.name}_log_likelihood" for ifo in likelihood.interferometers]): samples[key] = new_samples[:, ii] return samples else: - logger.debug('Not computing per-detector log likelihoods.') + logger.debug("Not computing per-detector log likelihoods.") def _compute_per_detector_log_likelihoods(args): """A wrapper of computing the per-detector log likelihoods to enable multiprocessing""" from ..core.sampler.base_sampler import _sampling_convenience_dump + likelihood = _sampling_convenience_dump.likelihood _, sample = args sample = dict(sample).copy() new_sample = likelihood.compute_per_detector_log_likelihood(sample) - return tuple((new_sample[key] for key in - [f'{ifo.name}_log_likelihood' for ifo in likelihood.interferometers])) + return tuple(new_sample[key] for key in [f"{ifo.name}_log_likelihood" for ifo in likelihood.interferometers]) -def generate_posterior_samples_from_marginalized_likelihood( - samples, likelihood, npool=1, block=10, use_cache=True): +def generate_posterior_samples_from_marginalized_likelihood(samples, likelihood, npool=1, block=10, use_cache=True): """ Reconstruct the distance posterior from a run which used a likelihood which explicitly marginalised over time/distance/phase. @@ -2434,10 +2439,10 @@ def generate_posterior_samples_from_marginalized_likelihood( if isinstance(samples, dict): return samples elif not isinstance(samples, DataFrame): - raise ValueError("Unable to handle input samples of type {}".format(type(samples))) + raise ValueError(f"Unable to handle input samples of type {type(samples)}") from tqdm.auto import tqdm - logger.info('Reconstructing marginalised parameters.') + logger.info("Reconstructing marginalised parameters.") try: cache_filename = f"{likelihood.outdir}/.{likelihood.label}_generate_posterior_cache.pickle" @@ -2456,11 +2461,9 @@ def generate_posterior_samples_from_marginalized_likelihood( # Check the samples are identical between the cache and current if (cached_samples_dict is not None) and (cached_samples_dict["_samples"].equals(samples)): # Calculate reconstruction percentage and print a log message - nsamples_converted = np.sum( - [len(val) for key, val in cached_samples_dict.items() if key != "_samples"] - ) + nsamples_converted = np.sum([len(val) for key, val in cached_samples_dict.items() if key != "_samples"]) perc = 100 * nsamples_converted / len(cached_samples_dict["_samples"]) - logger.info(f'Using cached reconstruction with {perc:0.1f}% converted.') + logger.info(f"Using cached reconstruction with {perc:0.1f}% converted.") else: logger.info("Cached samples dict out of date, ignoring") cached_samples_dict = dict(_samples=samples) @@ -2475,17 +2478,16 @@ def generate_posterior_samples_from_marginalized_likelihood( # Set up the multiprocessing if npool > 1: from ..core.sampler.base_sampler import _initialize_global_variables + pool = multiprocessing.Pool( processes=npool, initializer=_initialize_global_variables, initargs=(likelihood, None, None, False, dict()), ) - logger.info( - "Using a pool with size {} for nsamples={}" - .format(npool, len(samples)) - ) + logger.info(f"Using a pool with size {npool} for nsamples={len(samples)}") else: from ..core.sampler.base_sampler import _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood pool = None @@ -2500,9 +2502,9 @@ def generate_posterior_samples_from_marginalized_likelihood( continue if pool is not None: - subset_samples = pool.map(fill_sample, fill_args[ii: ii + block]) + subset_samples = pool.map(fill_sample, fill_args[ii : ii + block]) else: - subset_samples = [list(fill_sample(xx)) for xx in fill_args[ii: ii + block]] + subset_samples = [list(fill_sample(xx)) for xx in fill_args[ii : ii + block]] cached_samples_dict[ii] = subset_samples @@ -2517,9 +2519,7 @@ def generate_posterior_samples_from_marginalized_likelihood( pool.close() pool.join() - new_samples = np.concatenate( - [np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"] - ) + new_samples = np.concatenate([np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"]) for ii, key in enumerate(marginalized_parameters): samples[key] = new_samples[:, ii] @@ -2535,7 +2535,7 @@ def generate_sky_frame_parameters(samples, likelihood): raise ValueError from tqdm.auto import tqdm - logger.info('Generating sky frame parameters.') + logger.info("Generating sky frame parameters.") new_samples = list() for ii in tqdm(range(len(samples)), file=sys.stdout): sample = dict(samples.iloc[ii]).copy() @@ -2555,7 +2555,7 @@ def fill_sample(args): marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) sample = dict(sample).copy() new_sample = likelihood.generate_posterior_sample_from_marginalized_likelihood(sample) - return tuple((new_sample[key] for key in marginalized_parameters)) + return tuple(new_sample[key] for key in marginalized_parameters) def identity_map_conversion(parameters): @@ -2589,26 +2589,24 @@ def identity_map_generation(sample, likelihood=None, priors=None, npool=1): output_sample = fill_from_fixed_priors(output_sample, priors) if likelihood is not None: - compute_per_detector_log_likelihoods( - samples=output_sample, likelihood=likelihood, npool=npool) + compute_per_detector_log_likelihoods(samples=output_sample, likelihood=likelihood, npool=npool) marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) if len(marginalized_parameters) > 0: try: generate_posterior_samples_from_marginalized_likelihood( - samples=output_sample, likelihood=likelihood, npool=npool) + samples=output_sample, likelihood=likelihood, npool=npool + ) except MarginalizedLikelihoodReconstructionError as e: logger.warning( "Marginalised parameter reconstruction failed with message " - "{}. Some parameters may not have the intended " - "interpretation.".format(e) + f"{e}. Some parameters may not have the intended " + "interpretation." ) - if ("ra" in output_sample.keys() and "dec" in output_sample.keys() and "psi" in output_sample.keys()): + if "ra" in output_sample.keys() and "dec" in output_sample.keys() and "psi" in output_sample.keys(): compute_snrs(output_sample, likelihood, npool=npool) else: - logger.info( - "Skipping SNR computation since samples have insufficient sky location information" - ) + logger.info("Skipping SNR computation since samples have insufficient sky location information") return output_sample diff --git a/bilby/gw/cosmology.py b/bilby/gw/cosmology.py index 6606fdd85..42b8eaca4 100644 --- a/bilby/gw/cosmology.py +++ b/bilby/gw/cosmology.py @@ -8,7 +8,9 @@ def _set_default_cosmology(): from astropy import cosmology as cosmo + from ..core.utils.meta_data import global_meta_data + global DEFAULT_COSMOLOGY, COSMOLOGY if DEFAULT_COSMOLOGY is None: DEFAULT_COSMOLOGY = cosmo.Planck15 @@ -27,6 +29,7 @@ def get_available_cosmologies(): A tuple of strings with the names of the available cosmologies. """ from astropy.cosmology.realizations import available + return (*available, "Planck15_LAL") @@ -51,6 +54,7 @@ def get_cosmology(cosmology=None): Cosmology instance """ from astropy import cosmology as cosmo + _set_default_cosmology() if cosmology is None: cosmology = DEFAULT_COSMOLOGY @@ -65,7 +69,8 @@ def get_cosmology(cosmology=None): # Older version of LAL do not expose H0 and Omega_M try: - from lal import H0_SI as LAL_H0_SI, OMEGA_M as LAL_OMEGA_M + from lal import H0_SI as LAL_H0_SI + from lal import OMEGA_M as LAL_OMEGA_M except ImportError: LAL_H0_SI, LAL_OMEGA_M = 2.200489137532724e-18, 0.3065 @@ -73,14 +78,12 @@ def get_cosmology(cosmology=None): # consistency LAL_H0 = LAL_H0_SI * 1e3 * LAL_PC_SI * units.km / (units.Mpc * units.s) - cosmology = cosmo.FlatLambdaCDM( - H0=LAL_H0, Om0=LAL_OMEGA_M, name="Planck15_LAL" - ) + cosmology = cosmo.FlatLambdaCDM(H0=LAL_H0, Om0=LAL_OMEGA_M, name="Planck15_LAL") else: cosmology = getattr(cosmo, cosmology) elif isinstance(cosmology, dict): - if 'Ode0' in cosmology.keys(): - if 'w0' in cosmology.keys(): + if "Ode0" in cosmology.keys(): + if "w0" in cosmology.keys(): cosmology = cosmo.wCDM(**cosmology) else: cosmology = cosmo.LambdaCDM(**cosmology) @@ -105,6 +108,7 @@ def set_cosmology(cosmology=None): class. """ from ..core.utils.meta_data import global_meta_data + cosmology = get_cosmology(cosmology) global COSMOLOGY, DEFAULT_COSMOLOGY DEFAULT_COSMOLOGY = cosmology @@ -125,4 +129,5 @@ def z_at_value(func, fval, **kwargs): for detailed documentation. """ from astropy.cosmology import z_at_value + return z_at_value(func=func, fval=fval, **kwargs).value diff --git a/bilby/gw/detector/__init__.py b/bilby/gw/detector/__init__.py index 51ff98279..63fee8500 100644 --- a/bilby/gw/detector/__init__.py +++ b/bilby/gw/detector/__init__.py @@ -1,12 +1,20 @@ +# ruff: noqa: F403 + +import numpy as np + +from ...core.utils import logger +from .. import utils from ..conversion import convert_to_lal_binary_black_hole_parameters from .calibration import * from .interferometer import * from .networks import * -from .psd import * +from .networks import get_empty_interferometer +from .psd import PowerSpectralDensity from .strain_data import * + def get_safe_signal_duration(mass_1, mass_2, a_1, a_2, tilt_1, tilt_2, flow=10): - """ Calculate the safe signal duration, given the parameters + """Calculate the safe signal duration, given the parameters Parameters ========== @@ -24,15 +32,17 @@ def get_safe_signal_duration(mass_1, mass_2, a_1, a_2, tilt_1, tilt_2, flow=10): """ from lal import MSUN_SI from lalsimulation import SimInspiralChirpTimeBound + chirp_time = SimInspiralChirpTimeBound( - flow, mass_1 * MSUN_SI, mass_2 * MSUN_SI, - a_1 * np.cos(tilt_1), a_2 * np.cos(tilt_2)) - return max(2**(int(np.log2(chirp_time)) + 1), 4) + flow, mass_1 * MSUN_SI, mass_2 * MSUN_SI, a_1 * np.cos(tilt_1), a_2 * np.cos(tilt_2) + ) + return max(2 ** (int(np.log2(chirp_time)) + 1), 4) def inject_signal_into_gwpy_timeseries( - data, waveform_generator, parameters, det, power_spectral_density=None, outdir=None, label=None): - """ Inject a signal into a gwpy timeseries + data, waveform_generator, parameters, det, power_spectral_density=None, outdir=None, label=None +): + """Inject a signal into a gwpy timeseries Parameters ========== @@ -57,8 +67,8 @@ def inject_signal_into_gwpy_timeseries( A dictionary of meta data about the injection """ - from gwpy.timeseries import TimeSeries from gwpy.plot import Plot + from gwpy.timeseries import TimeSeries ifo = get_empty_interferometer(det) @@ -67,19 +77,16 @@ def inject_signal_into_gwpy_timeseries( elif power_spectral_density is not None: raise TypeError( "Input power_spectral_density should be bilby.gw.detector.psd.PowerSpectralDensity " - "object or None, received {}.".format(type(power_spectral_density)) + f"object or None, received {type(power_spectral_density)}." ) ifo.strain_data.set_from_gwpy_timeseries(data) parameters_check, _ = convert_to_lal_binary_black_hole_parameters(parameters) - parameters_check = {key: parameters_check[key] for key in - ['mass_1', 'mass_2', 'a_1', 'a_2', 'tilt_1', 'tilt_2']} + parameters_check = {key: parameters_check[key] for key in ["mass_1", "mass_2", "a_1", "a_2", "tilt_1", "tilt_2"]} safe_time = get_safe_signal_duration(**parameters_check) if data.duration.value < safe_time: - ValueError( - "Injecting a signal with safe-duration {} longer than the data {}" - .format(safe_time, data.duration.value)) + ValueError(f"Injecting a signal with safe-duration {safe_time} longer than the data {data.duration.value}") waveform_polarizations = waveform_generator.time_domain_strain(parameters) @@ -87,36 +94,29 @@ def inject_signal_into_gwpy_timeseries( for mode in waveform_polarizations.keys(): det_response = ifo.antenna_response( - parameters['ra'], parameters['dec'], parameters['geocent_time'], - parameters['psi'], mode) + parameters["ra"], parameters["dec"], parameters["geocent_time"], parameters["psi"], mode + ) signal += waveform_polarizations[mode] * det_response - time_shift = ifo.time_delay_from_geocenter( - parameters['ra'], parameters['dec'], parameters['geocent_time']) + time_shift = ifo.time_delay_from_geocenter(parameters["ra"], parameters["dec"], parameters["geocent_time"]) - dt = parameters['geocent_time'] + time_shift - data.times[0].value + dt = parameters["geocent_time"] + time_shift - data.times[0].value n_roll = dt * data.sample_rate.value n_roll = int(np.round(n_roll)) - signal_shifted = TimeSeries( - data=np.roll(signal, n_roll), times=data.times, unit=data.unit) + signal_shifted = TimeSeries(data=np.roll(signal, n_roll), times=data.times, unit=data.unit) signal_and_data = data.inject(signal_shifted) if outdir is not None and label is not None: fig = Plot(signal_shifted) - fig.savefig('{}/{}_{}_time_domain_injected_signal'.format( - outdir, ifo.name, label)) + fig.savefig(f"{outdir}/{ifo.name}_{label}_time_domain_injected_signal") meta_data = dict(name=det) - frequency_domain_signal, _ = utils.nfft( - signal_shifted.value, waveform_generator.sampling_frequency - ) - frequency_domain_data, _ = utils.nfft( - signal_and_data.value, waveform_generator.sampling_frequency - ) - meta_data['optimal_SNR'] = ( - np.sqrt(ifo.optimal_snr_squared(signal=frequency_domain_signal)).real) + frequency_domain_signal, _ = utils.nfft(signal_shifted.value, waveform_generator.sampling_frequency) + frequency_domain_data, _ = utils.nfft(signal_and_data.value, waveform_generator.sampling_frequency) + meta_data["optimal_SNR"] = np.sqrt(ifo.optimal_snr_squared(signal=frequency_domain_signal)).real from ..utils import matched_filter_snr - meta_data['matched_filter_SNR'] = matched_filter_snr( + + meta_data["matched_filter_SNR"] = matched_filter_snr( frequency_domain_signal, frequency_domain_data, ifo.power_spectral_density_array, @@ -124,20 +124,30 @@ def inject_signal_into_gwpy_timeseries( ) meta_data["parameters"] = parameters - logger.info("Injected signal in {}:".format(ifo.name)) - logger.info(" optimal SNR = {:.2f}".format(meta_data['optimal_SNR'])) - logger.info(" matched filter SNR = {:.2f}".format(meta_data['matched_filter_SNR'])) + logger.info(f"Injected signal in {ifo.name}:") + logger.info(" optimal SNR = {:.2f}".format(meta_data["optimal_SNR"])) + logger.info(" matched filter SNR = {:.2f}".format(meta_data["matched_filter_SNR"])) for key in parameters: - logger.info(' {} = {}'.format(key, parameters[key])) + logger.info(f" {key} = {parameters[key]}") return signal_and_data, meta_data def get_interferometer_with_fake_noise_and_injection( - name, injection_parameters, injection_polarizations=None, - waveform_generator=None, sampling_frequency=4096, duration=4, - start_time=None, outdir='outdir', label=None, plot=True, save=True, - zero_noise=False, raise_error=True): + name, + injection_parameters, + injection_polarizations=None, + waveform_generator=None, + sampling_frequency=4096, + duration=4, + start_time=None, + outdir="outdir", + label=None, + plot=True, + save=True, + zero_noise=False, + raise_error=True, +): """ Helper function to obtain an Interferometer instance with appropriate power spectral density and data, given an center_time. @@ -186,27 +196,27 @@ def get_interferometer_with_fake_noise_and_injection( utils.check_directory_exists_and_if_not_mkdir(outdir) if start_time is None: - start_time = injection_parameters['geocent_time'] + 2 - duration + start_time = injection_parameters["geocent_time"] + 2 - duration interferometer = get_empty_interferometer(name) interferometer.power_spectral_density = PowerSpectralDensity.from_aligo() if zero_noise: interferometer.set_strain_data_from_zero_noise( - sampling_frequency=sampling_frequency, duration=duration, - start_time=start_time) + sampling_frequency=sampling_frequency, duration=duration, start_time=start_time + ) else: interferometer.set_strain_data_from_power_spectral_density( - sampling_frequency=sampling_frequency, duration=duration, - start_time=start_time) + sampling_frequency=sampling_frequency, duration=duration, start_time=start_time + ) injection_polarizations = interferometer.inject_signal( parameters=injection_parameters, injection_polarizations=injection_polarizations, waveform_generator=waveform_generator, - raise_error=raise_error) + raise_error=raise_error, + ) - signal = interferometer.get_detector_response( - injection_polarizations, injection_parameters) + signal = interferometer.get_detector_response(injection_polarizations, injection_parameters) if plot: interferometer.plot_data(signal=signal, outdir=outdir, label=label) @@ -218,10 +228,18 @@ def get_interferometer_with_fake_noise_and_injection( def load_data_from_cache_file( - cache_file, start_time, segment_duration, psd_duration, psd_start_time, - channel_name=None, sampling_frequency=4096, roll_off=0.2, - overlap=0, outdir=None): - """ Helper routine to generate an interferometer from a cache file + cache_file, + start_time, + segment_duration, + psd_duration, + psd_start_time, + channel_name=None, + sampling_frequency=4096, + roll_off=0.2, + overlap=0, + outdir=None, +): + """Helper routine to generate an interferometer from a cache file Parameters ========== @@ -250,71 +268,78 @@ def load_data_from_cache_file( appropriate data in the cache file and a PSD. """ import lal + data_set = False psd_set = False - with open(cache_file, 'r') as ff: + with open(cache_file) as ff: lines = ff.readlines() - if len(lines)>1: - raise ValueError('This method cannot handle cache files with' - ' multiple frames. Use `load_data_by_channel_name' - ' instead.') + if len(lines) > 1: + raise ValueError( + "This method cannot handle cache files with multiple frames. Use `load_data_by_channel_name instead." + ) else: line = lines[0] cache = lal.utils.cache.CacheEntry(line) - data_in_cache = ( - (cache.segment[0].gpsSeconds < start_time) & - (cache.segment[1].gpsSeconds > start_time + segment_duration)) - psd_in_cache = ( - (cache.segment[0].gpsSeconds < psd_start_time) & - (cache.segment[1].gpsSeconds > psd_start_time + psd_duration)) - ifo = get_empty_interferometer( - "{}1".format(cache.observatory)) + data_in_cache = (cache.segment[0].gpsSeconds < start_time) & ( + cache.segment[1].gpsSeconds > start_time + segment_duration + ) + psd_in_cache = (cache.segment[0].gpsSeconds < psd_start_time) & ( + cache.segment[1].gpsSeconds > psd_start_time + psd_duration + ) + ifo = get_empty_interferometer(f"{cache.observatory}1") if not data_in_cache: - raise ValueError('The specified data segment does not exist in' - ' this frame.') + raise ValueError("The specified data segment does not exist in this frame.") if not psd_in_cache: - raise ValueError('The specified PSD data segment does not exist' - ' in this frame.') + raise ValueError("The specified PSD data segment does not exist in this frame.") if (not data_set) & data_in_cache: ifo.set_strain_data_from_frame_file( frame_file=cache.path, sampling_frequency=sampling_frequency, duration=segment_duration, start_time=start_time, - channel=channel_name, buffer_time=0) + channel=channel_name, + buffer_time=0, + ) data_set = True if (not psd_set) & psd_in_cache: - ifo.power_spectral_density = \ - PowerSpectralDensity.from_frame_file( - cache.path, - psd_start_time=psd_start_time, - psd_duration=psd_duration, - fft_length=segment_duration, - sampling_frequency=sampling_frequency, - roll_off=roll_off, - overlap=overlap, - channel=channel_name, - name=cache.observatory, - outdir=outdir, - analysis_segment_start_time=start_time) + ifo.power_spectral_density = PowerSpectralDensity.from_frame_file( + cache.path, + psd_start_time=psd_start_time, + psd_duration=psd_duration, + fft_length=segment_duration, + sampling_frequency=sampling_frequency, + roll_off=roll_off, + overlap=overlap, + channel=channel_name, + name=cache.observatory, + outdir=outdir, + analysis_segment_start_time=start_time, + ) psd_set = True if data_set and psd_set: return ifo elif not data_set: - raise ValueError('Data not loaded for {}'.format(ifo.name)) + raise ValueError(f"Data not loaded for {ifo.name}") elif not psd_set: - raise ValueError('PSD not created for {}'.format(ifo.name)) + raise ValueError(f"PSD not created for {ifo.name}") def load_data_by_channel_name( - channel_name, start_time, segment_duration, psd_duration, psd_start_time, - sampling_frequency=4096, roll_off=0.2, - overlap=0, outdir=None): - """ Helper routine to generate an interferometer from a channel name - This function creates an empty interferometer specified in the name - of the channel. It calls `ifo.set_strain_data_from_channel_name` to - set the data and PSD in the interferometer using data retrieved from + channel_name, + start_time, + segment_duration, + psd_duration, + psd_start_time, + sampling_frequency=4096, + roll_off=0.2, + overlap=0, + outdir=None, +): + """Helper routine to generate an interferometer from a channel name + This function creates an empty interferometer specified in the name + of the channel. It calls `ifo.set_strain_data_from_channel_name` to + set the data and PSD in the interferometer using data retrieved from the specified channel using gwpy.TimeSeries.get() Parameters @@ -342,27 +367,25 @@ def load_data_by_channel_name( appropriate data fetched from the specified channel and a PSD. """ try: - det = channel_name.split(':')[-2] + det = channel_name.split(":")[-2] except IndexError: raise IndexError("Channel name must be of the format `IFO:Channel`") ifo = get_empty_interferometer(det) ifo.set_strain_data_from_channel_name( + channel=channel_name, sampling_frequency=sampling_frequency, duration=segment_duration, start_time=start_time + ) + + ifo.power_spectral_density = PowerSpectralDensity.from_channel_name( channel=channel_name, + psd_start_time=psd_start_time, + psd_duration=psd_duration, + fft_length=segment_duration, sampling_frequency=sampling_frequency, - duration=segment_duration, - start_time=start_time) - - ifo.power_spectral_density = \ - PowerSpectralDensity.from_channel_name( - channel=channel_name, - psd_start_time=psd_start_time, - psd_duration=psd_duration, - fft_length=segment_duration, - sampling_frequency=sampling_frequency, - roll_off=roll_off, - overlap=overlap, - name=det, - outdir=outdir, - analysis_segment_start_time=start_time) + roll_off=roll_off, + overlap=overlap, + name=det, + outdir=outdir, + analysis_segment_start_time=start_time, + ) return ifo diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py index 729b9e332..0fccde95c 100644 --- a/bilby/gw/detector/calibration.py +++ b/bilby/gw/detector/calibration.py @@ -39,6 +39,7 @@ :code:`"data"` convention. """ + import copy import os @@ -46,8 +47,8 @@ import pandas as pd from scipy.interpolate import interp1d -from ...core.utils.log import logger from ...core.prior.dict import PriorDict +from ...core.utils.log import logger from ..prior import CalibrationPriorDict @@ -64,18 +65,13 @@ def _check_calibration_correction_type(correction_type): correction_type = "data" if correction_type.lower() not in ["data", "template"]: raise ValueError( - "Calibration envelope correction should be one of 'data' or " - f"'template', found {correction_type}." + f"Calibration envelope correction should be one of 'data' or 'template', found {correction_type}." ) - logger.debug( - f"Supplied calibration correction will be applied to the {correction_type}" - ) + logger.debug(f"Supplied calibration correction will be applied to the {correction_type}") return correction_type -def read_calibration_file( - filename, frequency_array, number_of_response_curves, starting_index=0, correction_type=None -): +def read_calibration_file(filename, frequency_array, number_of_response_curves, starting_index=0, correction_type=None): r""" Function to read the hdf5 files from the calibration group containing the physical calibration response curves. @@ -112,21 +108,21 @@ def read_calibration_file( correction_type = _check_calibration_correction_type(correction_type=correction_type) logger.info(f"Reading calibration draws from {filename}") - with h5py.File(filename, 'r') as calibration_file: + with h5py.File(filename, "r") as calibration_file: try: - dr = calibration_file['deltaR'] + dr = calibration_file["deltaR"] except KeyError: raise KeyError(f"File {filename} does not contain 'deltaR' group.") # slice draws according to starting_index and number_of_response_curves - calibration_amplitude = dr['draws_amp_rel'][starting_index: starting_index + number_of_response_curves] - calibration_phase = dr['draws_phase'][starting_index: starting_index + number_of_response_curves] - calibration_frequencies = dr['freq'][:] + calibration_amplitude = dr["draws_amp_rel"][starting_index : starting_index + number_of_response_curves] + calibration_phase = dr["draws_phase"][starting_index : starting_index + number_of_response_curves] + calibration_frequencies = dr["freq"][:] # read parameter draws if present; stored under CalParams/table as a structured dataset parameter_draws = None - if 'CalParams' in calibration_file and 'table' in calibration_file['CalParams']: - table_ds = calibration_file['CalParams']['table'] + if "CalParams" in calibration_file and "table" in calibration_file["CalParams"]: + table_ds = calibration_file["CalParams"]["table"] try: rec = np.array(table_ds) # convert structured array to DataFrame @@ -138,8 +134,8 @@ def read_calibration_file( # combine into complex responses and interpolate to requested frequency array calibration_draws = calibration_amplitude * np.exp(1j * calibration_phase) calibration_draws = interp1d( - calibration_frequencies, calibration_draws, kind='cubic', - bounds_error=False, fill_value=1)(frequency_array) + calibration_frequencies, calibration_draws, kind="cubic", bounds_error=False, fill_value=1 + )(frequency_array) if correction_type == "data": calibration_draws = 1 / calibration_draws @@ -183,31 +179,30 @@ def write_calibration_file( logger.info(f"Writing calibration draws to {filename}") # overwrite/create the file - with h5py.File(filename, 'w') as calibration_file: - deltaR_group = calibration_file.create_group('deltaR') + with h5py.File(filename, "w") as calibration_file: + deltaR_group = calibration_file.create_group("deltaR") # Save output: amplitude and phase arrays amp = np.abs(calibration_draws) phase = np.angle(calibration_draws) - deltaR_group.create_dataset('draws_amp_rel', data=amp, dtype=np.float64, compression='gzip') - deltaR_group.create_dataset('draws_phase', data=phase, dtype=np.float64, compression='gzip') - deltaR_group.create_dataset('freq', data=frequency_array, dtype=np.float64) + deltaR_group.create_dataset("draws_amp_rel", data=amp, dtype=np.float64, compression="gzip") + deltaR_group.create_dataset("draws_phase", data=phase, dtype=np.float64, compression="gzip") + deltaR_group.create_dataset("freq", data=frequency_array, dtype=np.float64) # Save calibration parameter draws (DataFrame) if provided. Store as a structured # dataset under CalParams/table so it can be read back into a DataFrame. if calibration_parameter_draws is not None: - cp_group = calibration_file.create_group('CalParams') + cp_group = calibration_file.create_group("CalParams") # convert DataFrame to a numpy recarray / structured array rec = calibration_parameter_draws.to_records(index=False) # create dataset - cp_group.create_dataset('table', data=rec, compression='gzip') + cp_group.create_dataset("table", data=rec, compression="gzip") -class Recalibrate(object): +class Recalibrate: + name = "none" - name = 'none' - - def __init__(self, prefix='recalib_'): + def __init__(self, prefix="recalib_"): """ Base calibration object. This applies no transformation @@ -220,7 +215,7 @@ def __init__(self, prefix='recalib_'): self.prefix = prefix def __repr__(self): - return self.__class__.__name__ + '(prefix=\'{}\')'.format(self.prefix) + return self.__class__.__name__ + f"(prefix='{self.prefix}')" def get_calibration_factor(self, frequency_array, **params): """Apply calibration model @@ -243,16 +238,14 @@ def get_calibration_factor(self, frequency_array, **params): return np.ones_like(frequency_array) def set_calibration_parameters(self, **params): - self.params.update({key[len(self.prefix):]: params[key] for key in params - if self.prefix in key}) + self.params.update({key[len(self.prefix) :]: params[key] for key in params if self.prefix in key}) def __eq__(self, other): return self.__dict__ == other.__dict__ class CubicSpline(Recalibrate): - - name = 'cubic_spline' + name = "cubic_spline" def __init__(self, prefix, minimum_frequency, maximum_frequency, n_points): """ @@ -274,14 +267,13 @@ def __init__(self, prefix, minimum_frequency, maximum_frequency, n_points): n_points: int number of spline points """ - super(CubicSpline, self).__init__(prefix=prefix) + super().__init__(prefix=prefix) if n_points < 4: - raise ValueError('Cubic spline calibration requires at least 4 spline nodes.') + raise ValueError("Cubic spline calibration requires at least 4 spline nodes.") self.n_points = n_points self.minimum_frequency = minimum_frequency self.maximum_frequency = maximum_frequency - self._log_spline_points = np.linspace( - np.log10(minimum_frequency), np.log10(maximum_frequency), n_points) + self._log_spline_points = np.linspace(np.log10(minimum_frequency), np.log10(maximum_frequency), n_points) @property def delta_log_spline_points(self): @@ -325,8 +317,11 @@ def log_spline_points(self): return self._log_spline_points def __repr__(self): - return self.__class__.__name__ + '(prefix=\'{}\', minimum_frequency={}, maximum_frequency={}, n_points={})'\ - .format(self.prefix, self.minimum_frequency, self.maximum_frequency, self.n_points) + return ( + self.__class__.__name__ + + f"(prefix='{self.prefix}', minimum_frequency={self.minimum_frequency}, " + + f"maximum_frequency={self.maximum_frequency}, n_points={self.n_points})" + ) def _evaluate_spline(self, kind, a, b, c, d, previous_nodes): """Evaluate Eq. (1) in https://dcc.ligo.org/LIGO-T2300140""" @@ -358,9 +353,7 @@ def get_calibration_factor(self, frequency_array, **params): calibration_factor : array-like The factor to multiply the strain by. """ - log10f_per_deltalog10f = ( - np.log10(frequency_array) - self.log_spline_points[0] - ) / self.delta_log_spline_points + log10f_per_deltalog10f = (np.log10(frequency_array) - self.log_spline_points[0]) / self.delta_log_spline_points previous_nodes = np.clip(np.floor(log10f_per_deltalog10f).astype(int), a_min=0, a_max=self.n_points - 2) b = log10f_per_deltalog10f - previous_nodes a = 1 - b @@ -377,7 +370,6 @@ def get_calibration_factor(self, frequency_array, **params): class Precomputed(Recalibrate): - name = "precomputed" def __init__(self, label, curves, frequency_array, parameters=None): @@ -399,7 +391,7 @@ def __init__(self, label, curves, frequency_array, parameters=None): self.curves = curves self.frequency_array = frequency_array self.parameters = parameters - super(Precomputed, self).__init__(prefix=f"recalib_index_{self.label}") + super().__init__(prefix=f"recalib_index_{self.label}") def get_calibration_factor(self, frequency_array, **params): idx = int(params.get(self.prefix, None)) @@ -410,9 +402,7 @@ def get_calibration_factor(self, frequency_array, **params): return self.curves[idx] @classmethod - def constant_uncertainty_spline( - cls, amplitude_sigma, phase_sigma, frequency_array, n_nodes, label, n_curves - ): + def constant_uncertainty_spline(cls, amplitude_sigma, phase_sigma, frequency_array, n_nodes, label, n_curves): priors = CalibrationPriorDict.constant_uncertainty_spline( amplitude_sigma=amplitude_sigma, phase_sigma=phase_sigma, @@ -423,11 +413,7 @@ def constant_uncertainty_spline( ) parameters = pd.DataFrame(priors.sample(n_curves)) curves = curves_from_spline_and_prior( - label=label, - frequency_array=frequency_array, - n_points=n_nodes, - parameters=parameters, - n_curves=n_curves + label=label, frequency_array=frequency_array, n_points=n_nodes, parameters=parameters, n_curves=n_curves ) return cls( label=label, @@ -437,9 +423,7 @@ def constant_uncertainty_spline( ) @classmethod - def from_envelope_file( - cls, envelope, frequency_array, n_nodes, label, n_curves, correction_type - ): + def from_envelope_file(cls, envelope, frequency_array, n_nodes, label, n_curves, correction_type): priors = CalibrationPriorDict.from_envelope_file( envelope_file=envelope, minimum_frequency=frequency_array[0], @@ -517,10 +501,7 @@ def build_calibration_lookup( parameters[name][model.prefix] = idxs else: if priors is None: - raise ValueError( - "Priors must be passed to generate calibration response curves " - "for cubic spline." - ) + raise ValueError("Priors must be passed to generate calibration response curves for cubic spline.") draws[name], parameters[name] = _generate_calibration_draws( interferometer=interferometer, priors=priors, @@ -544,13 +525,15 @@ def _generate_calibration_draws(interferometer, priors, n_curves): parameters = pd.DataFrame(calibration_priors.sample(n_curves)) - draws = np.array(curves_from_spline_and_prior( - parameters=parameters, - label=name, - n_points=interferometer.calibration_model.n_points, - frequency_array=frequencies, - n_curves=n_curves, - )) + draws = np.array( + curves_from_spline_and_prior( + parameters=parameters, + label=name, + n_points=interferometer.calibration_model.n_points, + frequency_array=frequencies, + n_curves=n_curves, + ) + ) return draws, parameters @@ -563,8 +546,5 @@ def curves_from_spline_and_prior(parameters, label, n_points, frequency_array, n ) curves = list() for ii in range(n_curves): - curves.append(spline.get_calibration_factor( - frequency_array=frequency_array, - **parameters.iloc[ii] - )) + curves.append(spline.get_calibration_factor(frequency_array=frequency_array, **parameters.iloc[ii])) return curves diff --git a/bilby/gw/detector/geometry.py b/bilby/gw/detector/geometry.py index d7e1433de..d57f3d5ab 100644 --- a/bilby/gw/detector/geometry.py +++ b/bilby/gw/detector/geometry.py @@ -4,9 +4,10 @@ from .. import utils as gwutils -class InterferometerGeometry(object): - def __init__(self, length, latitude, longitude, elevation, xarm_azimuth, yarm_azimuth, - xarm_tilt=0., yarm_tilt=0.): +class InterferometerGeometry: + def __init__( + self, length, latitude, longitude, elevation, xarm_azimuth, yarm_azimuth, xarm_tilt=0.0, yarm_tilt=0.0 + ): """ Instantiate an Interferometer object. @@ -51,22 +52,31 @@ def __init__(self, length, latitude, longitude, elevation, xarm_azimuth, yarm_az self._detector_tensor = None def __eq__(self, other): - for attribute in ['length', 'latitude', 'longitude', 'elevation', - 'xarm_azimuth', 'yarm_azimuth', 'xarm_tilt', 'yarm_tilt']: + for attribute in [ + "length", + "latitude", + "longitude", + "elevation", + "xarm_azimuth", + "yarm_azimuth", + "xarm_tilt", + "yarm_tilt", + ]: if not getattr(self, attribute) == getattr(other, attribute): return False return True def __repr__(self): - return self.__class__.__name__ + '(length={}, latitude={}, longitude={}, elevation={}, ' \ - 'xarm_azimuth={}, yarm_azimuth={}, xarm_tilt={}, yarm_tilt={})' \ - .format(float(self.length), float(self.latitude), float(self.longitude), - float(self.elevation), float(self.xarm_azimuth), float(self.yarm_azimuth), float(self.xarm_tilt), - float(self.yarm_tilt)) + return ( + self.__class__.__name__ + f"(length={float(self.length)}, latitude={float(self.latitude)}, " + f"longitude={float(self.longitude)}, elevation={float(self.elevation)}, " + + f"xarm_azimuth={float(self.xarm_azimuth)}, yarm_azimuth={float(self.yarm_azimuth)}, " + + f"xarm_tilt={float(self.xarm_tilt)}, yarm_tilt={float(self.yarm_tilt)})" + ) @property def latitude(self): - """ Saves latitude in rad internally. Updates related quantities if set to a different value. + """Saves latitude in rad internally. Updates related quantities if set to a different value. Returns ======= @@ -93,7 +103,7 @@ def latitude_radians(self): @property def longitude(self): - """ Saves longitude in rad internally. Updates related quantities if set to a different value. + """Saves longitude in rad internally. Updates related quantities if set to a different value. Returns ======= @@ -120,7 +130,7 @@ def longitude_radians(self): @property def elevation(self): - """ Updates related quantities if set to a different values. + """Updates related quantities if set to a different values. Returns ======= @@ -135,7 +145,7 @@ def elevation(self, elevation): @property def xarm_azimuth(self): - """ Saves the x-arm azimuth in rad internally. Updates related quantities if set to a different values. + """Saves the x-arm azimuth in rad internally. Updates related quantities if set to a different values. Returns ======= @@ -151,7 +161,7 @@ def xarm_azimuth(self, xarm_azimuth): @property def yarm_azimuth(self): - """ Saves the y-arm azimuth in rad internally. Updates related quantities if set to a different values. + """Saves the y-arm azimuth in rad internally. Updates related quantities if set to a different values. Returns ======= @@ -167,7 +177,7 @@ def yarm_azimuth(self, yarm_azimuth): @property def xarm_tilt(self): - """ Updates related quantities if set to a different values. + """Updates related quantities if set to a different values. Returns ======= @@ -183,7 +193,7 @@ def xarm_tilt(self, xarm_tilt): @property def yarm_tilt(self): - """ Updates related quantities if set to a different values. + """Updates related quantities if set to a different values. Returns ======= @@ -199,7 +209,7 @@ def yarm_tilt(self, yarm_tilt): @property def vertex(self): - """ Position of the IFO vertex in geocentric coordinates in meters. + """Position of the IFO vertex in geocentric coordinates in meters. Is automatically updated if related quantities are modified. @@ -208,14 +218,13 @@ def vertex(self): array_like: A 3D array representation of the vertex """ if not self._vertex_updated: - self._vertex = gwutils.get_vertex_position_geocentric(self._latitude, self._longitude, - self.elevation) + self._vertex = gwutils.get_vertex_position_geocentric(self._latitude, self._longitude, self.elevation) self._vertex_updated = True return self._vertex @property def x(self): - """ A unit vector along the x-arm + """A unit vector along the x-arm Is automatically updated if related quantities are modified. @@ -225,14 +234,14 @@ def x(self): """ if not self._x_updated: - self._x = self.unit_vector_along_arm('x') + self._x = self.unit_vector_along_arm("x") self._x_updated = True self._detector_tensor_updated = False return self._x @property def y(self): - """ A unit vector along the y-arm + """A unit vector along the y-arm Is automatically updated if related quantities are modified. @@ -242,7 +251,7 @@ def y(self): """ if not self._y_updated: - self._y = self.unit_vector_along_arm('y') + self._y = self.unit_vector_along_arm("y") self._y_updated = True self._detector_tensor_updated = False return self._y @@ -288,19 +297,19 @@ def unit_vector_along_arm(self, arm): ValueError: If arm is neither 'x' nor 'y' """ - if arm == 'x': + if arm == "x": return calculate_arm( arm_tilt=self._xarm_tilt, arm_azimuth=self._xarm_azimuth, longitude=self._longitude, - latitude=self._latitude + latitude=self._latitude, ) - elif arm == 'y': + elif arm == "y": return calculate_arm( arm_tilt=self._yarm_tilt, arm_azimuth=self._yarm_azimuth, longitude=self._longitude, - latitude=self._latitude + latitude=self._latitude, ) else: raise ValueError("Arm must either be 'x' or 'y'.") diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 2fca163e0..44a42dc98 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -8,46 +8,60 @@ ) from ...core import utils -from ...core.utils import docstring, logger, PropertyAccessor, safe_file_dump +from ...core.utils import PropertyAccessor, docstring, logger, safe_file_dump from ...core.utils.env import string_to_boolean from .. import utils as gwutils +from ..conversion import generate_all_bbh_parameters from .calibration import Recalibrate from .geometry import InterferometerGeometry from .strain_data import InterferometerStrainData -from ..conversion import generate_all_bbh_parameters -class Interferometer(object): - """Class for the Interferometer """ - - length = PropertyAccessor('geometry', 'length') - latitude = PropertyAccessor('geometry', 'latitude') - latitude_radians = PropertyAccessor('geometry', 'latitude_radians') - longitude = PropertyAccessor('geometry', 'longitude') - longitude_radians = PropertyAccessor('geometry', 'longitude_radians') - elevation = PropertyAccessor('geometry', 'elevation') - x = PropertyAccessor('geometry', 'x') - y = PropertyAccessor('geometry', 'y') - xarm_azimuth = PropertyAccessor('geometry', 'xarm_azimuth') - yarm_azimuth = PropertyAccessor('geometry', 'yarm_azimuth') - xarm_tilt = PropertyAccessor('geometry', 'xarm_tilt') - yarm_tilt = PropertyAccessor('geometry', 'yarm_tilt') - vertex = PropertyAccessor('geometry', 'vertex') - detector_tensor = PropertyAccessor('geometry', 'detector_tensor') - - duration = PropertyAccessor('strain_data', 'duration') - sampling_frequency = PropertyAccessor('strain_data', 'sampling_frequency') - start_time = PropertyAccessor('strain_data', 'start_time') - frequency_array = PropertyAccessor('strain_data', 'frequency_array') - time_array = PropertyAccessor('strain_data', 'time_array') - minimum_frequency = PropertyAccessor('strain_data', 'minimum_frequency') - maximum_frequency = PropertyAccessor('strain_data', 'maximum_frequency') - frequency_mask = PropertyAccessor('strain_data', 'frequency_mask') - frequency_domain_strain = PropertyAccessor('strain_data', 'frequency_domain_strain') - time_domain_strain = PropertyAccessor('strain_data', 'time_domain_strain') - - def __init__(self, name, power_spectral_density, minimum_frequency, maximum_frequency, length, latitude, longitude, - elevation, xarm_azimuth, yarm_azimuth, xarm_tilt=0., yarm_tilt=0., calibration_model=Recalibrate()): +class Interferometer: + """Class for the Interferometer""" + + length = PropertyAccessor("geometry", "length") + latitude = PropertyAccessor("geometry", "latitude") + latitude_radians = PropertyAccessor("geometry", "latitude_radians") + longitude = PropertyAccessor("geometry", "longitude") + longitude_radians = PropertyAccessor("geometry", "longitude_radians") + elevation = PropertyAccessor("geometry", "elevation") + x = PropertyAccessor("geometry", "x") + y = PropertyAccessor("geometry", "y") + xarm_azimuth = PropertyAccessor("geometry", "xarm_azimuth") + yarm_azimuth = PropertyAccessor("geometry", "yarm_azimuth") + xarm_tilt = PropertyAccessor("geometry", "xarm_tilt") + yarm_tilt = PropertyAccessor("geometry", "yarm_tilt") + vertex = PropertyAccessor("geometry", "vertex") + detector_tensor = PropertyAccessor("geometry", "detector_tensor") + + duration = PropertyAccessor("strain_data", "duration") + sampling_frequency = PropertyAccessor("strain_data", "sampling_frequency") + start_time = PropertyAccessor("strain_data", "start_time") + frequency_array = PropertyAccessor("strain_data", "frequency_array") + time_array = PropertyAccessor("strain_data", "time_array") + minimum_frequency = PropertyAccessor("strain_data", "minimum_frequency") + maximum_frequency = PropertyAccessor("strain_data", "maximum_frequency") + frequency_mask = PropertyAccessor("strain_data", "frequency_mask") + frequency_domain_strain = PropertyAccessor("strain_data", "frequency_domain_strain") + time_domain_strain = PropertyAccessor("strain_data", "time_domain_strain") + + def __init__( + self, + name, + power_spectral_density, + minimum_frequency, + maximum_frequency, + length, + latitude, + longitude, + elevation, + xarm_azimuth, + yarm_azimuth, + xarm_tilt=0.0, + yarm_tilt=0.0, + calibration_model=Recalibrate(), + ): """ Instantiate an Interferometer object. @@ -82,39 +96,43 @@ def __init__(self, name, power_spectral_density, minimum_frequency, maximum_freq Calibration model, this applies the calibration correction to the template, the default model applies no correction. """ - self.geometry = InterferometerGeometry(length, latitude, longitude, elevation, - xarm_azimuth, yarm_azimuth, xarm_tilt, yarm_tilt) + self.geometry = InterferometerGeometry( + length, latitude, longitude, elevation, xarm_azimuth, yarm_azimuth, xarm_tilt, yarm_tilt + ) self.name = name self.power_spectral_density = power_spectral_density self.calibration_model = calibration_model self.strain_data = InterferometerStrainData( - minimum_frequency=minimum_frequency, - maximum_frequency=maximum_frequency) + minimum_frequency=minimum_frequency, maximum_frequency=maximum_frequency + ) self.meta_data = dict(name=name) def __eq__(self, other): - if self.name == other.name and \ - self.geometry == other.geometry and \ - self.power_spectral_density.__eq__(other.power_spectral_density) and \ - self.calibration_model == other.calibration_model and \ - self.strain_data == other.strain_data: + if ( + self.name == other.name + and self.geometry == other.geometry + and self.power_spectral_density.__eq__(other.power_spectral_density) + and self.calibration_model == other.calibration_model + and self.strain_data == other.strain_data + ): return True return False def __repr__(self): - return self.__class__.__name__ + '(name=\'{}\', power_spectral_density={}, minimum_frequency={}, ' \ - 'maximum_frequency={}, length={}, latitude={}, longitude={}, elevation={}, ' \ - 'xarm_azimuth={}, yarm_azimuth={}, xarm_tilt={}, yarm_tilt={})' \ - .format(self.name, self.power_spectral_density, float(self.strain_data.minimum_frequency), - float(self.strain_data.maximum_frequency), float(self.geometry.length), - float(self.geometry.latitude), float(self.geometry.longitude), - float(self.geometry.elevation), float(self.geometry.xarm_azimuth), - float(self.geometry.yarm_azimuth), float(self.geometry.xarm_tilt), - float(self.geometry.yarm_tilt)) + return ( + self.__class__.__name__ + + f"(name='{self.name}', power_spectral_density={self.power_spectral_density}, " + + f"minimum_frequency={float(self.strain_data.minimum_frequency)}, " + + f"maximum_frequency={float(self.strain_data.maximum_frequency)}, length={float(self.geometry.length)}, " + + f"latitude={float(self.geometry.latitude)}, longitude={float(self.geometry.longitude)}, " + + f"elevation={float(self.geometry.elevation)}, " + + f"xarm_azimuth={float(self.geometry.xarm_azimuth)}, yarm_azimuth={float(self.geometry.yarm_azimuth)}, " + + f"xarm_tilt={float(self.geometry.xarm_tilt)}, yarm_tilt={float(self.geometry.yarm_tilt)})" + ) def set_strain_data_from_gwpy_timeseries(self, time_series): - """ Set the `Interferometer.strain_data` from a gwpy TimeSeries + """Set the `Interferometer.strain_data` from a gwpy TimeSeries Parameters ========== @@ -125,9 +143,9 @@ def set_strain_data_from_gwpy_timeseries(self, time_series): self.strain_data.set_from_gwpy_timeseries(time_series=time_series) def set_strain_data_from_frequency_domain_strain( - self, frequency_domain_strain, sampling_frequency=None, - duration=None, start_time=0, frequency_array=None): - """ Set the `Interferometer.strain_data` from a numpy array + self, frequency_domain_strain, sampling_frequency=None, duration=None, start_time=0, frequency_array=None + ): + """Set the `Interferometer.strain_data` from a numpy array Parameters ========== @@ -146,12 +164,14 @@ def set_strain_data_from_frequency_domain_strain( """ self.strain_data.set_from_frequency_domain_strain( frequency_domain_strain=frequency_domain_strain, - sampling_frequency=sampling_frequency, duration=duration, - start_time=start_time, frequency_array=frequency_array) + sampling_frequency=sampling_frequency, + duration=duration, + start_time=start_time, + frequency_array=frequency_array, + ) - def set_strain_data_from_power_spectral_density( - self, sampling_frequency, duration, start_time=0): - """ Set the `Interferometer.strain_data` from a power spectal density + def set_strain_data_from_power_spectral_density(self, sampling_frequency, duration, start_time=0): + """Set the `Interferometer.strain_data` from a power spectal density This uses the `interferometer.power_spectral_density` object to set the `strain_data` to a noise realization. See @@ -168,13 +188,13 @@ def set_strain_data_from_power_spectral_density( """ self.strain_data.set_from_power_spectral_density( - self.power_spectral_density, sampling_frequency=sampling_frequency, - duration=duration, start_time=start_time) + self.power_spectral_density, sampling_frequency=sampling_frequency, duration=duration, start_time=start_time + ) def set_strain_data_from_frame_file( - self, frame_file, sampling_frequency, duration, start_time=0, - channel=None, buffer_time=1): - """ Set the `Interferometer.strain_data` from a frame file + self, frame_file, sampling_frequency, duration, start_time=0, channel=None, buffer_time=1 + ): + """Set the `Interferometer.strain_data` from a frame file Parameters ========== @@ -194,12 +214,15 @@ def set_strain_data_from_frame_file( """ self.strain_data.set_from_frame_file( - frame_file=frame_file, sampling_frequency=sampling_frequency, - duration=duration, start_time=start_time, - channel=channel, buffer_time=buffer_time) + frame_file=frame_file, + sampling_frequency=sampling_frequency, + duration=duration, + start_time=start_time, + channel=channel, + buffer_time=buffer_time, + ) - def set_strain_data_from_channel_name( - self, channel, sampling_frequency, duration, start_time=0): + def set_strain_data_from_channel_name(self, channel, sampling_frequency, duration, start_time=0): """ Set the `Interferometer.strain_data` by fetching from given channel using strain_data.set_from_channel_name() @@ -217,11 +240,11 @@ def set_strain_data_from_channel_name( """ self.strain_data.set_from_channel_name( - channel=channel, sampling_frequency=sampling_frequency, - duration=duration, start_time=start_time) + channel=channel, sampling_frequency=sampling_frequency, duration=duration, start_time=start_time + ) def set_strain_data_from_csv(self, filename): - """ Set the `Interferometer.strain_data` from a csv file + """Set the `Interferometer.strain_data` from a csv file Parameters ========== @@ -231,9 +254,8 @@ def set_strain_data_from_csv(self, filename): """ self.strain_data.set_from_csv(filename) - def set_strain_data_from_zero_noise( - self, sampling_frequency, duration, start_time=0): - """ Set the `Interferometer.strain_data` to zero noise + def set_strain_data_from_zero_noise(self, sampling_frequency, duration, start_time=0): + """Set the `Interferometer.strain_data` to zero noise Parameters ========== @@ -247,8 +269,8 @@ def set_strain_data_from_zero_noise( """ self.strain_data.set_from_zero_noise( - sampling_frequency=sampling_frequency, duration=duration, - start_time=start_time) + sampling_frequency=sampling_frequency, duration=duration, start_time=start_time + ) def antenna_response(self, ra, dec, time, psi, mode): """ @@ -287,7 +309,7 @@ def antenna_response(self, ra, dec, time, psi, mode): return 0 def get_detector_response(self, waveform_polarizations, parameters, frequencies=None): - """ Get the detector response for a particular waveform + """Get the detector response for a particular waveform Parameters ========== @@ -313,32 +335,29 @@ def get_detector_response(self, waveform_polarizations, parameters, frequencies= signal = {} for mode in waveform_polarizations.keys(): det_response = self.antenna_response( - parameters['ra'], - parameters['dec'], - parameters['geocent_time'], - parameters['psi'], mode) + parameters["ra"], parameters["dec"], parameters["geocent_time"], parameters["psi"], mode + ) signal[mode] = waveform_polarizations[mode] * det_response signal_ifo = sum(signal.values()) * mask - time_shift = self.time_delay_from_geocenter( - parameters['ra'], parameters['dec'], parameters['geocent_time']) + time_shift = self.time_delay_from_geocenter(parameters["ra"], parameters["dec"], parameters["geocent_time"]) # Be careful to first subtract the two GPS times which are ~1e9 sec. # And then add the time_shift which varies at ~1e-5 sec - dt_geocent = parameters['geocent_time'] - self.strain_data.start_time + dt_geocent = parameters["geocent_time"] - self.strain_data.start_time dt = dt_geocent + time_shift signal_ifo[mask] = signal_ifo[mask] * np.exp(-1j * 2 * np.pi * dt * frequencies) signal_ifo[mask] *= self.calibration_model.get_calibration_factor( - frequencies, prefix='recalib_{}_'.format(self.name), **parameters + frequencies, prefix=f"recalib_{self.name}_", **parameters ) return signal_ifo def check_signal_duration(self, parameters, raise_error=True): - """ Check that the signal with the given parameters fits in the data + """Check that the signal with the given parameters fits in the data Parameters ========== @@ -351,9 +370,7 @@ def check_signal_duration(self, parameters, raise_error=True): try: parameters = generate_all_bbh_parameters(parameters) except AttributeError: - logger.debug( - "generate_all_bbh_parameters parameters failed during check_signal_duration" - ) + logger.debug("generate_all_bbh_parameters parameters failed during check_signal_duration") return if ("mass_1" not in parameters) and ("mass_2" not in parameters): @@ -379,9 +396,8 @@ def check_signal_duration(self, parameters, raise_error=True): else: logger.warning(msg) - def inject_signal(self, parameters, injection_polarizations=None, - waveform_generator=None, raise_error=True): - """ General signal injection method. + def inject_signal(self, parameters, injection_polarizations=None, waveform_generator=None, raise_error=True): + """General signal injection method. Provide the injection parameters and either the injection polarizations or the waveform generator to inject a signal into the detector. Defaults to the injection polarizations is both are given. @@ -419,19 +435,19 @@ def inject_signal(self, parameters, injection_polarizations=None, self.check_signal_duration(parameters, raise_error) if injection_polarizations is None and waveform_generator is None: - raise ValueError( - "inject_signal needs one of waveform_generator or " - "injection_polarizations.") + raise ValueError("inject_signal needs one of waveform_generator or injection_polarizations.") elif injection_polarizations is not None: - self.inject_signal_from_waveform_polarizations(parameters=parameters, - injection_polarizations=injection_polarizations) + self.inject_signal_from_waveform_polarizations( + parameters=parameters, injection_polarizations=injection_polarizations + ) elif waveform_generator is not None: - injection_polarizations = self.inject_signal_from_waveform_generator(parameters=parameters, - waveform_generator=waveform_generator) + injection_polarizations = self.inject_signal_from_waveform_generator( + parameters=parameters, waveform_generator=waveform_generator + ) return injection_polarizations def inject_signal_from_waveform_generator(self, parameters, waveform_generator): - """ Inject a signal using a waveform generator and a set of parameters. + """Inject a signal using a waveform generator and a set of parameters. Alternative to `inject_signal` and `inject_signal_from_waveform_polarizations` Parameters @@ -453,14 +469,14 @@ def inject_signal_from_waveform_generator(self, parameters, waveform_generator): The internally generated injection parameters """ - injection_polarizations = \ - waveform_generator.frequency_domain_strain(parameters) - self.inject_signal_from_waveform_polarizations(parameters=parameters, - injection_polarizations=injection_polarizations) + injection_polarizations = waveform_generator.frequency_domain_strain(parameters) + self.inject_signal_from_waveform_polarizations( + parameters=parameters, injection_polarizations=injection_polarizations + ) return injection_polarizations def inject_signal_from_waveform_polarizations(self, parameters, injection_polarizations): - """ Inject a signal into the detector from a dict of waveform polarizations. + """Inject a signal into the detector from a dict of waveform polarizations. Alternative to `inject_signal` and `inject_signal_from_waveform_generator`. Parameters @@ -472,25 +488,25 @@ def inject_signal_from_waveform_polarizations(self, parameters, injection_polari `waveform_generator.frequency_domain_strain()`. """ - if not self.strain_data.time_within_data(parameters['geocent_time']): + if not self.strain_data.time_within_data(parameters["geocent_time"]): logger.warning( - 'Injecting signal outside segment, start_time={}, merger time={}.' - .format(self.strain_data.start_time, parameters['geocent_time'])) + "Injecting signal outside segment, start_time={}, merger time={}.".format( + self.strain_data.start_time, parameters["geocent_time"] + ) + ) signal_ifo = self.get_detector_response(injection_polarizations, parameters) self.strain_data.frequency_domain_strain += signal_ifo - self.meta_data['optimal_SNR'] = ( - np.sqrt(self.optimal_snr_squared(signal=signal_ifo)).real) - self.meta_data['matched_filter_SNR'] = ( - self.matched_filter_snr(signal=signal_ifo)) - self.meta_data['parameters'] = parameters + self.meta_data["optimal_SNR"] = np.sqrt(self.optimal_snr_squared(signal=signal_ifo)).real + self.meta_data["matched_filter_SNR"] = self.matched_filter_snr(signal=signal_ifo) + self.meta_data["parameters"] = parameters - logger.info("Injected signal in {}:".format(self.name)) - logger.info(" optimal SNR = {:.2f}".format(self.meta_data['optimal_SNR'])) - logger.info(" matched filter SNR = {:.2f}".format(self.meta_data['matched_filter_SNR'])) + logger.info(f"Injected signal in {self.name}:") + logger.info(" optimal SNR = {:.2f}".format(self.meta_data["optimal_SNR"])) + logger.info(" matched filter SNR = {:.2f}".format(self.meta_data["matched_filter_SNR"])) for key in parameters: - logger.info(' {} = {}'.format(key, parameters[key])) + logger.info(f" {key} = {parameters[key]}") @property def _window_power_correction(self): @@ -498,29 +514,30 @@ def _window_power_correction(self): This property enables the old (incorrect) PSD correction to be applied using the :code:`BILBY_INCORRECT_PSD_NORMALIZATION` environment variable. """ - if string_to_boolean( - os.environ.get("BILBY_INCORRECT_PSD_NORMALIZATION", "FALSE").upper() - ): + if string_to_boolean(os.environ.get("BILBY_INCORRECT_PSD_NORMALIZATION", "FALSE").upper()): return self.strain_data.window_factor else: return 1 @property def amplitude_spectral_density_array(self): - """ Returns the amplitude spectral density (ASD) given we know a power spectral density (PSD) + """Returns the amplitude spectral density (ASD) given we know a power spectral density (PSD) Returns ======= array_like: An array representation of the ASD """ - return self.power_spectral_density.get_amplitude_spectral_density_array( - frequency_array=self.strain_data.frequency_array - ) * self._window_power_correction**0.5 + return ( + self.power_spectral_density.get_amplitude_spectral_density_array( + frequency_array=self.strain_data.frequency_array + ) + * self._window_power_correction**0.5 + ) @property def power_spectral_density_array(self): - """ Returns the power spectral density (PSD) + """Returns the power spectral density (PSD) This accounts for whether the data in the interferometer has been windowed. @@ -529,13 +546,18 @@ def power_spectral_density_array(self): array_like: An array representation of the PSD """ - return self.power_spectral_density.get_power_spectral_density_array( - frequency_array=self.strain_data.frequency_array - ) * self._window_power_correction + return ( + self.power_spectral_density.get_power_spectral_density_array( + frequency_array=self.strain_data.frequency_array + ) + * self._window_power_correction + ) def unit_vector_along_arm(self, arm): - logger.warning("This method has been moved and will be removed in the future." - "Use Interferometer.geometry.unit_vector_along_arm instead.") + logger.warning( + "This method has been moved and will be removed in the future." + "Use Interferometer.geometry.unit_vector_along_arm instead." + ) return self.geometry.unit_vector_along_arm(arm) def time_delay_from_geocenter(self, ra, dec, time): @@ -570,9 +592,9 @@ def vertex_position_geocentric(self): ======= array_like: A 3D array representation of the vertex """ - return gwutils.get_vertex_position_geocentric(self.geometry.latitude_radians, - self.geometry.longitude_radians, - self.geometry.elevation) + return gwutils.get_vertex_position_geocentric( + self.geometry.latitude_radians, self.geometry.longitude_radians, self.geometry.elevation + ) def optimal_snr_squared(self, signal): """ @@ -589,7 +611,8 @@ def optimal_snr_squared(self, signal): return gwutils.optimal_snr_squared( signal=signal[self.strain_data.frequency_mask], power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask], - duration=self.strain_data.duration) + duration=self.strain_data.duration, + ) def inner_product(self, signal): """ @@ -607,7 +630,8 @@ def inner_product(self, signal): aa=signal[self.strain_data.frequency_mask], bb=self.strain_data.frequency_domain_strain[self.strain_data.frequency_mask], power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask], - duration=self.strain_data.duration) + duration=self.strain_data.duration, + ) def template_template_inner_product(self, signal_1, signal_2): """A noise weighted inner product between two templates, using this ifo's PSD. @@ -627,7 +651,8 @@ def template_template_inner_product(self, signal_1, signal_2): aa=signal_1[self.strain_data.frequency_mask], bb=signal_2[self.strain_data.frequency_mask], power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask], - duration=self.strain_data.duration) + duration=self.strain_data.duration, + ) def matched_filter_snr(self, signal): """ @@ -646,9 +671,10 @@ def matched_filter_snr(self, signal): signal=signal[self.strain_data.frequency_mask], frequency_domain_strain=self.strain_data.frequency_domain_strain[self.strain_data.frequency_mask], power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask], - duration=self.strain_data.duration) + duration=self.strain_data.duration, + ) - def whiten_frequency_series(self, frequency_series : np.array) -> np.array: + def whiten_frequency_series(self, frequency_series: np.array) -> np.array: """Whitens a frequency series with the noise properties of the detector .. math:: @@ -669,10 +695,7 @@ def whiten_frequency_series(self, frequency_series : np.array) -> np.array: """ return frequency_series / (self.amplitude_spectral_density_array * np.sqrt(self.duration / 4)) - def get_whitened_time_series_from_whitened_frequency_series( - self, - whitened_frequency_series : np.array - ) -> np.array: + def get_whitened_time_series_from_whitened_frequency_series(self, whitened_frequency_series: np.array) -> np.array: """Gets the whitened time series from a whitened frequency series. This ifft's and also applies a windowing factor, @@ -698,14 +721,10 @@ def get_whitened_time_series_from_whitened_frequency_series( w = \\sqrt{N W} = \\sqrt{\\sum_{k=0}^N \\Theta(f_{max} - f_k)\\Theta(f_k - f_{min})} """ - frequency_window_factor = ( - np.sum(self.frequency_mask) - / len(self.frequency_mask) - ) + frequency_window_factor = np.sum(self.frequency_mask) / len(self.frequency_mask) whitened_time_series = ( - np.fft.irfft(whitened_frequency_series) - * np.sqrt(np.sum(self.frequency_mask)) / frequency_window_factor + np.fft.irfft(whitened_frequency_series) * np.sqrt(np.sum(self.frequency_mask)) / frequency_window_factor ) return whitened_time_series @@ -739,7 +758,7 @@ def whitened_time_domain_strain(self) -> np.array: return self.get_whitened_time_series_from_whitened_frequency_series(self.whitened_frequency_domain_strain) def save_data(self, outdir, label=None): - """ Creates save files for interferometer data in plain text format. + """Creates save files for interferometer data in plain text format. Saves two files: the frequency domain strain data with three columns [f, real part of h(f), imaginary part of h(f)], and the amplitude spectral density with two columns [f, ASD(f)]. @@ -755,65 +774,75 @@ def save_data(self, outdir, label=None): """ if label is None: - filename_asd = '{}/{}_asd.dat'.format(outdir, self.name) - filename_data = '{}/{}_frequency_domain_data.dat'.format(outdir, self.name) + filename_asd = f"{outdir}/{self.name}_asd.dat" + filename_data = f"{outdir}/{self.name}_frequency_domain_data.dat" else: - filename_asd = '{}/{}_{}_asd.dat'.format(outdir, self.name, label) - filename_data = '{}/{}_{}_frequency_domain_data.dat'.format(outdir, self.name, label) - np.savetxt(filename_data, - np.array( - [self.strain_data.frequency_array, - self.strain_data.frequency_domain_strain.real, - self.strain_data.frequency_domain_strain.imag]).T, - header='f real_h(f) imag_h(f)') - np.savetxt(filename_asd, - np.array( - [self.strain_data.frequency_array, - self.amplitude_spectral_density_array]).T, - header='f h(f)') - - def plot_data(self, signal=None, outdir='.', label=None): + filename_asd = f"{outdir}/{self.name}_{label}_asd.dat" + filename_data = f"{outdir}/{self.name}_{label}_frequency_domain_data.dat" + np.savetxt( + filename_data, + np.array( + [ + self.strain_data.frequency_array, + self.strain_data.frequency_domain_strain.real, + self.strain_data.frequency_domain_strain.imag, + ] + ).T, + header="f real_h(f) imag_h(f)", + ) + np.savetxt( + filename_asd, + np.array([self.strain_data.frequency_array, self.amplitude_spectral_density_array]).T, + header="f h(f)", + ) + + def plot_data(self, signal=None, outdir=".", label=None): import matplotlib.pyplot as plt + if utils.command_line_args.bilby_test_mode: return fig, ax = plt.subplots() df = self.strain_data.frequency_array[1] - self.strain_data.frequency_array[0] - asd = gwutils.asd_from_freq_series( - freq_data=self.strain_data.frequency_domain_strain, df=df) - - ax.loglog(self.strain_data.frequency_array[self.strain_data.frequency_mask], - asd[self.strain_data.frequency_mask], - color='C0', label=self.name) - ax.loglog(self.strain_data.frequency_array[self.strain_data.frequency_mask], - self.amplitude_spectral_density_array[self.strain_data.frequency_mask], - color='C1', lw=1.0, label=self.name + ' ASD') + asd = gwutils.asd_from_freq_series(freq_data=self.strain_data.frequency_domain_strain, df=df) + + ax.loglog( + self.strain_data.frequency_array[self.strain_data.frequency_mask], + asd[self.strain_data.frequency_mask], + color="C0", + label=self.name, + ) + ax.loglog( + self.strain_data.frequency_array[self.strain_data.frequency_mask], + self.amplitude_spectral_density_array[self.strain_data.frequency_mask], + color="C1", + lw=1.0, + label=self.name + " ASD", + ) if signal is not None: - signal_asd = gwutils.asd_from_freq_series( - freq_data=signal, df=df) + signal_asd = gwutils.asd_from_freq_series(freq_data=signal, df=df) - ax.loglog(self.strain_data.frequency_array[self.strain_data.frequency_mask], - signal_asd[self.strain_data.frequency_mask], - color='C2', - label='Signal') + ax.loglog( + self.strain_data.frequency_array[self.strain_data.frequency_mask], + signal_asd[self.strain_data.frequency_mask], + color="C2", + label="Signal", + ) ax.grid(True) - ax.set_ylabel(r'Strain [strain/$\sqrt{\rm Hz}$]') - ax.set_xlabel(r'Frequency [Hz]') - ax.legend(loc='best') + ax.set_ylabel(r"Strain [strain/$\sqrt{\rm Hz}$]") + ax.set_xlabel(r"Frequency [Hz]") + ax.legend(loc="best") fig.tight_layout() if label is None: - fig.savefig( - '{}/{}_frequency_domain_data.png'.format(outdir, self.name)) + fig.savefig(f"{outdir}/{self.name}_frequency_domain_data.png") else: - fig.savefig( - '{}/{}_{}_frequency_domain_data.png'.format( - outdir, self.name, label)) + fig.savefig(f"{outdir}/{self.name}_{label}_frequency_domain_data.png") plt.close(fig) def plot_time_domain_data( - self, outdir='.', label=None, bandpass_frequencies=(50, 250), - notches=None, start_end=None, t0=None): - """ Plots the strain data in the time domain + self, outdir=".", label=None, bandpass_frequencies=(50, 250), notches=None, start_end=None, t0=None + ): + """Plots the strain data in the time domain Parameters ========== @@ -832,23 +861,19 @@ def plot_time_domain_data( """ import matplotlib.pyplot as plt - from gwpy.timeseries import TimeSeries from gwpy.signal.filter_design import bandpass, concatenate_zpks, notch + from gwpy.timeseries import TimeSeries # We use the gwpy timeseries to perform bandpass and notching if notches is None: notches = list() - timeseries = TimeSeries( - data=self.strain_data.time_domain_strain, times=self.strain_data.time_array) + timeseries = TimeSeries(data=self.strain_data.time_domain_strain, times=self.strain_data.time_array) zpks = [] if bandpass_frequencies is not None: - zpks.append(bandpass( - bandpass_frequencies[0], bandpass_frequencies[1], - self.strain_data.sampling_frequency)) + zpks.append(bandpass(bandpass_frequencies[0], bandpass_frequencies[1], self.strain_data.sampling_frequency)) if notches is not None: for line in notches: - zpks.append(notch( - line, self.strain_data.sampling_frequency)) + zpks.append(notch(line, self.strain_data.sampling_frequency)) if len(zpks) > 0: zpk = concatenate_zpks(*zpks) strain = timeseries.filter(zpk, filtfilt=False) @@ -859,14 +884,14 @@ def plot_time_domain_data( if t0: x = self.strain_data.time_array - t0 - xlabel = 'GPS time [s] - {}'.format(t0) + xlabel = f"GPS time [s] - {t0}" else: x = self.strain_data.time_array - xlabel = 'GPS time [s]' + xlabel = "GPS time [s]" ax.plot(x, strain) ax.set_xlabel(xlabel) - ax.set_ylabel('Strain') + ax.set_ylabel("Strain") if start_end is not None: ax.set_xlim(*start_end) @@ -874,16 +899,14 @@ def plot_time_domain_data( fig.tight_layout() if label is None: - fig.savefig( - '{}/{}_time_domain_data.png'.format(outdir, self.name)) + fig.savefig(f"{outdir}/{self.name}_time_domain_data.png") else: - fig.savefig( - '{}/{}_{}_time_domain_data.png'.format(outdir, self.name, label)) + fig.savefig(f"{outdir}/{self.name}_{label}_time_domain_data.png") plt.close(fig) @staticmethod def _filename_from_outdir_label_extension(outdir, label, extension="h5"): - return os.path.join(outdir, label + f'.{extension}') + return os.path.join(outdir, label + f".{extension}") _save_ifo_docstring = """ Save the object to a {format} file @@ -906,11 +929,9 @@ def _filename_from_outdir_label_extension(outdir, label, extension="h5"): """ - @docstring(_save_ifo_docstring.format( - format="pickle", extra=".. versionadded:: 1.1.0" - )) + @docstring(_save_ifo_docstring.format(format="pickle", extra=".. versionadded:: 1.1.0")) def to_pickle(self, outdir="outdir", label=None): - utils.check_directory_exists_and_if_not_mkdir('outdir') + utils.check_directory_exists_and_if_not_mkdir("outdir") filename = self._filename_from_outdir_label_extension(outdir, label, extension="pkl") safe_file_dump(self, filename, "dill") @@ -918,8 +939,9 @@ def to_pickle(self, outdir="outdir", label=None): @docstring(_load_docstring.format(format="pickle")) def from_pickle(cls, filename=None): import dill + with open(filename, "rb") as ff: res = dill.load(ff) if res.__class__ != cls: - raise TypeError('The loaded object is not an Interferometer') + raise TypeError("The loaded object is not an Interferometer") return res diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py index 4efd3d8db..06cc986cf 100644 --- a/bilby/gw/detector/networks.py +++ b/bilby/gw/detector/networks.py @@ -1,7 +1,7 @@ +import math import os import numpy as np -import math from ...core import utils from ...core.utils import logger, safe_file_dump @@ -24,16 +24,14 @@ def __init__(self, interferometers): The list of interferometers """ - super(InterferometerList, self).__init__() + super().__init__() if isinstance(interferometers, str): raise TypeError("Input must not be a string") for ifo in interferometers: if isinstance(ifo, str): ifo = get_empty_interferometer(ifo) if not isinstance(ifo, (Interferometer, TriangularInterferometer)): - raise TypeError( - "Input list of interferometers are not all Interferometer objects" - ) + raise TypeError("Input list of interferometers are not all Interferometer objects") else: self.append(ifo) self._check_interferometers() @@ -48,24 +46,12 @@ def _check_interferometers(self): """ consistent_attributes = ["duration", "start_time", "sampling_frequency"] for attribute in consistent_attributes: - x = [ - getattr(interferometer.strain_data, attribute) - for interferometer in self - ] + x = [getattr(interferometer.strain_data, attribute) for interferometer in self] try: if not all(y == x[0] for y in x): - ifo_strs = [ - "{ifo}[{attribute}]={value}".format( - ifo=ifo.name, - attribute=attribute, - value=getattr(ifo.strain_data, attribute), - ) - for ifo in self - ] + ifo_strs = [f"{ifo.name}[{attribute}]={getattr(ifo.strain_data, attribute)}" for ifo in self] raise ValueError( - "The {} of all interferometers are not the same: {}".format( - attribute, ", ".join(ifo_strs) - ) + "The {} of all interferometers are not the same: {}".format(attribute, ", ".join(ifo_strs)) ) except ValueError as e: if not all(math.isclose(y, x[0], abs_tol=1e-5) for y in x): @@ -73,9 +59,7 @@ def _check_interferometers(self): else: logger.warning(e) - def set_strain_data_from_power_spectral_densities( - self, sampling_frequency, duration, start_time=0 - ): + def set_strain_data_from_power_spectral_densities(self, sampling_frequency, duration, start_time=0): """Set the `Interferometer.strain_data` from the power spectral densities of the detectors This uses the `interferometer.power_spectral_density` object to set @@ -99,9 +83,7 @@ def set_strain_data_from_power_spectral_densities( start_time=start_time, ) - def set_strain_data_from_zero_noise( - self, sampling_frequency, duration, start_time=0 - ): + def set_strain_data_from_zero_noise(self, sampling_frequency, duration, start_time=0): """Set the `Interferometer.strain_data` to zero in each detector See :py:meth:`bilby.gw.detector.InterferometerStrainData.set_from_zero_noise` @@ -131,7 +113,7 @@ def inject_signal( waveform_generator=None, raise_error=True, ): - """ Inject a signal into noise in each of the three detectors. + """Inject a signal into noise in each of the three detectors. Parameters ========== @@ -162,14 +144,9 @@ def inject_signal( """ if injection_polarizations is None: if waveform_generator is not None: - injection_polarizations = waveform_generator.frequency_domain_strain( - parameters - ) + injection_polarizations = waveform_generator.frequency_domain_strain(parameters) else: - raise ValueError( - "inject_signal needs one of waveform_generator or " - "injection_polarizations." - ) + raise ValueError("inject_signal needs one of waveform_generator or injection_polarizations.") all_injection_polarizations = list() for interferometer in self: @@ -204,8 +181,7 @@ def plot_data(self, signal=None, outdir=".", label=None): interferometer.plot_data(signal=signal, outdir=outdir, label=label) def plot_time_domain_data( - self, outdir=".", label=None, bandpass_frequencies=(50, 250), - notches=None, start_end=None, t0=None + self, outdir=".", label=None, bandpass_frequencies=(50, 250), notches=None, start_end=None, t0=None ): """Plots the strain data in the time domain for each of the interfeormeters @@ -237,7 +213,7 @@ def plot_time_domain_data( bandpass_frequencies=bandpass_frequencies, notches=notches, start_end=start_end, - t0=t0 + t0=t0, ) @property @@ -262,25 +238,23 @@ def frequency_array(self): def append(self, interferometer): if isinstance(interferometer, InterferometerList): - super(InterferometerList, self).extend(interferometer) + super().extend(interferometer) else: - super(InterferometerList, self).append(interferometer) + super().append(interferometer) self._check_interferometers() def extend(self, interferometers): - super(InterferometerList, self).extend(interferometers) + super().extend(interferometers) self._check_interferometers() def insert(self, index, interferometer): - super(InterferometerList, self).insert(index, interferometer) + super().insert(index, interferometer) self._check_interferometers() @property def meta_data(self): """Dictionary of the per-interferometer meta_data""" - return { - interferometer.name: interferometer.meta_data for interferometer in self - } + return {interferometer.name: interferometer.meta_data for interferometer in self} @staticmethod def _filename_from_outdir_label_extension(outdir, label, extension="h5"): @@ -311,9 +285,7 @@ def _filename_from_outdir_label_extension(outdir, label, extension="h5"): def to_pickle(self, outdir="outdir", label="ifo_list"): utils.check_directory_exists_and_if_not_mkdir(outdir) label = label + "_" + "".join(ifo.name for ifo in self) - filename = self._filename_from_outdir_label_extension( - outdir, label, extension="pkl" - ) + filename = self._filename_from_outdir_label_extension(outdir, label, extension="pkl") safe_file_dump(self, filename, "dill") @classmethod @@ -326,9 +298,7 @@ def from_pickle(cls, filename=None): raise TypeError("The loaded object is not an InterferometerList") return res - to_pickle.__doc__ = _save_docstring.format( - format="pickle", extra=".. versionadded:: 1.1.0" - ) + to_pickle.__doc__ = _save_docstring.format(format="pickle", extra=".. versionadded:: 1.1.0") from_pickle.__doc__ = _load_docstring.format(format="pickle") @@ -348,7 +318,7 @@ def __init__( xarm_tilt=0.0, yarm_tilt=0.0, ): - super(TriangularInterferometer, self).__init__([]) + super().__init__([]) self.name = name # for attr in ['power_spectral_density', 'minimum_frequency', 'maximum_frequency']: if isinstance(power_spectral_density, PowerSpectralDensity): @@ -361,7 +331,7 @@ def __init__( for ii in range(3): self.append( Interferometer( - "{}{}".format(name, ii + 1), + f"{name}{ii + 1}", power_spectral_density[ii], minimum_frequency[ii], maximum_frequency[ii], @@ -380,24 +350,10 @@ def __init__( yarm_azimuth += 240 latitude += ( - np.arctan( - length - * np.sin(xarm_azimuth * np.pi / 180) - * 1e3 - / utils.radius_of_earth - ) - * 180 - / np.pi + np.arctan(length * np.sin(xarm_azimuth * np.pi / 180) * 1e3 / utils.radius_of_earth) * 180 / np.pi ) longitude += ( - np.arctan( - length - * np.cos(xarm_azimuth * np.pi / 180) - * 1e3 - / utils.radius_of_earth - ) - * 180 - / np.pi + np.arctan(length * np.cos(xarm_azimuth * np.pi / 180) * 1e3 / utils.radius_of_earth) * 180 / np.pi ) @@ -431,19 +387,17 @@ def get_empty_interferometer(name): interferometer: Interferometer Interferometer instance """ - filename = os.path.join( - os.path.dirname(__file__), "detectors", "{}.interferometer".format(name) - ) + filename = os.path.join(os.path.dirname(__file__), "detectors", f"{name}.interferometer") try: return load_interferometer(filename) except OSError: - raise ValueError("Interferometer {} not implemented".format(name)) + raise ValueError(f"Interferometer {name} not implemented") def load_interferometer(filename): """Load an interferometer from a file.""" parameters = dict() - with open(filename, "r") as parameter_file: + with open(filename) as parameter_file: lines = parameter_file.readlines() for line in lines: if line[0] == "#" or line[0] == "\n": @@ -462,7 +416,5 @@ def load_interferometer(filename): parameters.pop("shape") ifo = TriangularInterferometer(**parameters) else: - raise IOError( - "{} could not be loaded. Invalid parameter 'shape'.".format(filename) - ) + raise OSError(f"{filename} could not be loaded. Invalid parameter 'shape'.") return ifo diff --git a/bilby/gw/detector/psd.py b/bilby/gw/detector/psd.py index a3948f966..9e56c3304 100644 --- a/bilby/gw/detector/psd.py +++ b/bilby/gw/detector/psd.py @@ -8,10 +8,8 @@ from .strain_data import InterferometerStrainData -class PowerSpectralDensity(object): - - def __init__(self, frequency_array=None, psd_array=None, asd_array=None, - psd_file=None, asd_file=None): +class PowerSpectralDensity: + def __init__(self, frequency_array=None, psd_array=None, asd_array=None, psd_file=None, asd_file=None): """ Instantiate a new PowerSpectralDensity object. @@ -46,8 +44,7 @@ def __init__(self, frequency_array=None, psd_array=None, asd_array=None, Interpolated function of the PSD """ - self._cache = dict( - frequency_array=np.array([]), psd_array=None, asd_array=None) + self._cache = dict(frequency_array=np.array([]), psd_array=None, asd_array=None) self.frequency_array = np.array(frequency_array) if psd_array is not None: self.psd_array = psd_array @@ -58,30 +55,33 @@ def __init__(self, frequency_array=None, psd_array=None, asd_array=None, def _update_cache(self, frequency_array): psd_array = self.power_spectral_density_interpolated(frequency_array) - self._cache['psd_array'] = psd_array - self._cache['asd_array'] = psd_array**0.5 - self._cache['frequency_array'] = frequency_array + self._cache["psd_array"] = psd_array + self._cache["asd_array"] = psd_array**0.5 + self._cache["frequency_array"] = frequency_array def __eq__(self, other): - if self.psd_file == other.psd_file \ - and self.asd_file == other.asd_file \ - and np.array_equal(self.frequency_array, other.frequency_array) \ - and np.array_equal(self.psd_array, other.psd_array) \ - and np.array_equal(self.asd_array, other.asd_array): + if ( + self.psd_file == other.psd_file + and self.asd_file == other.asd_file + and np.array_equal(self.frequency_array, other.frequency_array) + and np.array_equal(self.psd_array, other.psd_array) + and np.array_equal(self.asd_array, other.asd_array) + ): return True return False def __repr__(self): if self.asd_file is not None or self.psd_file is not None: - return self.__class__.__name__ + '(psd_file=\'{}\', asd_file=\'{}\')' \ - .format(self.psd_file, self.asd_file) + return self.__class__.__name__ + f"(psd_file='{self.psd_file}', asd_file='{self.asd_file}')" else: - return self.__class__.__name__ + '(frequency_array={}, psd_array={}, asd_array={})' \ - .format(self.frequency_array, self.psd_array, self.asd_array) + return ( + self.__class__.__name__ + + f"(frequency_array={self.frequency_array}, psd_array={self.psd_array}, asd_array={self.asd_array})" + ) @staticmethod def from_amplitude_spectral_density_file(asd_file): - """ Set the amplitude spectral density from a given file + """Set the amplitude spectral density from a given file Parameters ========== @@ -93,7 +93,7 @@ def from_amplitude_spectral_density_file(asd_file): @staticmethod def from_power_spectral_density_file(psd_file): - """ Set the power spectral density from a given file + """Set the power spectral density from a given file Parameters ========== @@ -104,11 +104,20 @@ def from_power_spectral_density_file(psd_file): return PowerSpectralDensity(psd_file=psd_file) @staticmethod - def from_frame_file(frame_file, psd_start_time, psd_duration, - fft_length=4, sampling_frequency=4096, roll_off=0.2, - overlap=0, channel=None, name=None, outdir=None, - analysis_segment_start_time=None): - """ Generate power spectral density from a frame file + def from_frame_file( + frame_file, + psd_start_time, + psd_duration, + fft_length=4, + sampling_frequency=4096, + roll_off=0.2, + overlap=0, + channel=None, + name=None, + outdir=None, + analysis_segment_start_time=None, + ): + """Generate power spectral density from a frame file Parameters ========== @@ -139,19 +148,35 @@ def from_frame_file(frame_file, psd_start_time, psd_duration, """ strain = InterferometerStrainData(roll_off=roll_off) strain.set_from_frame_file( - frame_file, start_time=psd_start_time, duration=psd_duration, - channel=channel, sampling_frequency=sampling_frequency) + frame_file, + start_time=psd_start_time, + duration=psd_duration, + channel=channel, + sampling_frequency=sampling_frequency, + ) frequency_array, psd_array = strain.create_power_spectral_density( - fft_length=fft_length, name=name, outdir=outdir, overlap=overlap, - analysis_segment_start_time=analysis_segment_start_time) + fft_length=fft_length, + name=name, + outdir=outdir, + overlap=overlap, + analysis_segment_start_time=analysis_segment_start_time, + ) return PowerSpectralDensity(frequency_array=frequency_array, psd_array=psd_array) @staticmethod - def from_channel_name(channel, psd_start_time, psd_duration, - fft_length=4, sampling_frequency=4096, roll_off=0.2, - overlap=0, name=None, outdir=None, - analysis_segment_start_time=None): - """ Generate power spectral density from a given channel name + def from_channel_name( + channel, + psd_start_time, + psd_duration, + fft_length=4, + sampling_frequency=4096, + roll_off=0.2, + overlap=0, + name=None, + outdir=None, + analysis_segment_start_time=None, + ): + """Generate power spectral density from a given channel name by loading data using `strain_data.set_from_channel_name` Parameters @@ -182,11 +207,15 @@ def from_channel_name(channel, psd_start_time, psd_duration, """ strain = InterferometerStrainData(roll_off=roll_off) strain.set_from_channel_name( - channel, duration=psd_duration, start_time=psd_start_time, - sampling_frequency=sampling_frequency) + channel, duration=psd_duration, start_time=psd_start_time, sampling_frequency=sampling_frequency + ) frequency_array, psd_array = strain.create_power_spectral_density( - fft_length=fft_length, name=name, outdir=outdir, overlap=overlap, - analysis_segment_start_time=analysis_segment_start_time) + fft_length=fft_length, + name=name, + outdir=outdir, + overlap=overlap, + analysis_segment_start_time=analysis_segment_start_time, + ) return PowerSpectralDensity(frequency_array=frequency_array, psd_array=psd_array) @staticmethod @@ -199,9 +228,8 @@ def from_power_spectral_density_array(frequency_array, psd_array): @staticmethod def from_aligo(): - logger.info("No power spectral density provided, using aLIGO," - "zero detuning, high power.") - return PowerSpectralDensity.from_power_spectral_density_file(psd_file='aLIGO_ZERO_DET_high_P_psd.txt') + logger.info("No power spectral density provided, using aLIGO,zero detuning, high power.") + return PowerSpectralDensity.from_power_spectral_density_file(psd_file="aLIGO_ZERO_DET_high_P_psd.txt") @property def psd_array(self): @@ -211,7 +239,7 @@ def psd_array(self): def psd_array(self, psd_array): self.__check_frequency_array_matches_density_array(psd_array) self.__psd_array = np.array(psd_array) - self.__asd_array = psd_array ** 0.5 + self.__asd_array = psd_array**0.5 self.__interpolate_power_spectral_density() @property @@ -222,34 +250,34 @@ def asd_array(self): def asd_array(self, asd_array): self.__check_frequency_array_matches_density_array(asd_array) self.__asd_array = np.array(asd_array) - self.__psd_array = asd_array ** 2 + self.__psd_array = asd_array**2 self.__interpolate_power_spectral_density() def __check_frequency_array_matches_density_array(self, density_array): if len(self.frequency_array) != len(density_array): - raise ValueError('Provided spectral density does not match frequency array. Not updating.\n' - 'Length spectral density {}\n Length frequency array {}\n' - .format(density_array, self.frequency_array)) + raise ValueError( + "Provided spectral density does not match frequency array. Not updating.\n" + f"Length spectral density {density_array}\n Length frequency array {self.frequency_array}\n" + ) def __interpolate_power_spectral_density(self): """Interpolate the loaded power spectral density so it can be resampled - for arbitrary frequency arrays. + for arbitrary frequency arrays. """ - self.__power_spectral_density_interpolated = interp1d(self.frequency_array, - self.psd_array, - bounds_error=False, - fill_value=np.inf) + self.__power_spectral_density_interpolated = interp1d( + self.frequency_array, self.psd_array, bounds_error=False, fill_value=np.inf + ) self._update_cache(self.frequency_array) def get_power_spectral_density_array(self, frequency_array): - if not np.array_equal(frequency_array, self._cache['frequency_array']): + if not np.array_equal(frequency_array, self._cache["frequency_array"]): self._update_cache(frequency_array=frequency_array) - return self._cache['psd_array'] + return self._cache["psd_array"] def get_amplitude_spectral_density_array(self, frequency_array): - if not np.array_equal(frequency_array, self._cache['frequency_array']): + if not np.array_equal(frequency_array, self._cache["frequency_array"]): self._update_cache(frequency_array=frequency_array) - return self._cache['asd_array'] + return self._cache["asd_array"] @property def power_spectral_density_interpolated(self): @@ -271,7 +299,7 @@ def __check_file_was_asd_file(self): if min(self.asd_array) < 1e-30: logger.warning("You specified an amplitude spectral density file.") logger.warning("{} WARNING {}".format("*" * 30, "*" * 30)) - logger.warning("The minimum of the provided curve is {:.2e}.".format(min(self.asd_array))) + logger.warning(f"The minimum of the provided curve is {min(self.asd_array):.2e}.") logger.warning("You may have intended to provide this as a power spectral density.") @property @@ -290,7 +318,7 @@ def __check_file_was_psd_file(self): if min(self.psd_array) > 1e-30: logger.warning("You specified a power spectral density file.") logger.warning("{} WARNING {}".format("*" * 30, "*" * 30)) - logger.warning("The minimum of the provided curve is {:.2e}.".format(min(self.psd_array))) + logger.warning(f"The minimum of the provided curve is {min(self.psd_array):.2e}.") logger.warning("You may have intended to provide this as an amplitude spectral density.") @staticmethod @@ -319,26 +347,23 @@ def __validate_file_name(file): logger.debug("PSD file set to None") return None elif os.path.isfile(file): - logger.debug("PSD file {} exists".format(file)) + logger.debug(f"PSD file {file} exists") return file else: - file_in_default_directory = ( - os.path.join(os.path.dirname(__file__), 'noise_curves', file)) + file_in_default_directory = os.path.join(os.path.dirname(__file__), "noise_curves", file) if os.path.isfile(file_in_default_directory): - logger.debug("PSD file {} exists in default dir.".format(file)) + logger.debug(f"PSD file {file} exists in default dir.") return file_in_default_directory else: - raise ValueError( - "Unable to locate PSD file {} locally or in the default dir" - .format(file)) + raise ValueError(f"Unable to locate PSD file {file} locally or in the default dir") return file def __import_amplitude_spectral_density(self): - """ Automagically load an amplitude spectral density curve """ + """Automagically load an amplitude spectral density curve""" self.frequency_array, self.asd_array = np.genfromtxt(self.asd_file).T def __import_power_spectral_density(self): - """ Automagically load a power spectral density curve """ + """Automagically load a power spectral density curve""" self.frequency_array, self.psd_array = np.genfromtxt(self.psd_file).T def get_noise_realisation(self, sampling_frequency, duration): diff --git a/bilby/gw/detector/strain_data.py b/bilby/gw/detector/strain_data.py index bca7acced..0390e500f 100644 --- a/bilby/gw/detector/strain_data.py +++ b/bilby/gw/detector/strain_data.py @@ -2,22 +2,21 @@ from ...core import utils from ...core.series import CoupledTimeAndFrequencySeries -from ...core.utils import logger, PropertyAccessor +from ...core.utils import PropertyAccessor, logger from .. import utils as gwutils -class InterferometerStrainData(object): - """ Strain data for an interferometer """ +class InterferometerStrainData: + """Strain data for an interferometer""" - duration = PropertyAccessor('_times_and_frequencies', 'duration') - sampling_frequency = PropertyAccessor('_times_and_frequencies', 'sampling_frequency') - start_time = PropertyAccessor('_times_and_frequencies', 'start_time') - frequency_array = PropertyAccessor('_times_and_frequencies', 'frequency_array') - time_array = PropertyAccessor('_times_and_frequencies', 'time_array') + duration = PropertyAccessor("_times_and_frequencies", "duration") + sampling_frequency = PropertyAccessor("_times_and_frequencies", "sampling_frequency") + start_time = PropertyAccessor("_times_and_frequencies", "start_time") + frequency_array = PropertyAccessor("_times_and_frequencies", "frequency_array") + time_array = PropertyAccessor("_times_and_frequencies", "time_array") - def __init__(self, minimum_frequency=0, maximum_frequency=np.inf, - roll_off=0.2, notch_list=None): - """ Initiate an InterferometerStrainData object + def __init__(self, minimum_frequency=0, maximum_frequency=np.inf, roll_off=0.2, notch_list=None): + """Initiate an InterferometerStrainData object The initialised object contains no data, this should be added using one of the `set_from..` methods. @@ -51,22 +50,24 @@ def __init__(self, minimum_frequency=0, maximum_frequency=np.inf, self._channel = None def __eq__(self, other): - if self.minimum_frequency == other.minimum_frequency \ - and self.maximum_frequency == other.maximum_frequency \ - and self.roll_off == other.roll_off \ - and self.window_factor == other.window_factor \ - and self.sampling_frequency == other.sampling_frequency \ - and self.duration == other.duration \ - and self.start_time == other.start_time \ - and np.array_equal(self.time_array, other.time_array) \ - and np.array_equal(self.frequency_array, other.frequency_array) \ - and np.array_equal(self.frequency_domain_strain, other.frequency_domain_strain) \ - and np.array_equal(self.time_domain_strain, other.time_domain_strain): + if ( + self.minimum_frequency == other.minimum_frequency + and self.maximum_frequency == other.maximum_frequency + and self.roll_off == other.roll_off + and self.window_factor == other.window_factor + and self.sampling_frequency == other.sampling_frequency + and self.duration == other.duration + and self.start_time == other.start_time + and np.array_equal(self.time_array, other.time_array) + and np.array_equal(self.frequency_array, other.frequency_array) + and np.array_equal(self.frequency_domain_strain, other.frequency_domain_strain) + and np.array_equal(self.time_domain_strain, other.time_domain_strain) + ): return True return False def time_within_data(self, time): - """ Check if time is within the data span + """Check if time is within the data span Parameters ========== @@ -99,10 +100,10 @@ def minimum_frequency(self, minimum_frequency): @property def maximum_frequency(self): - """ Force the maximum frequency be less than the Nyquist frequency """ + """Force the maximum frequency be less than the Nyquist frequency""" if self.sampling_frequency is not None: if 2 * self._maximum_frequency > self.sampling_frequency: - self._maximum_frequency = self.sampling_frequency / 2. + self._maximum_frequency = self.sampling_frequency / 2.0 return self._maximum_frequency @maximum_frequency.setter @@ -116,7 +117,7 @@ def notch_list(self): @notch_list.setter def notch_list(self, notch_list): - """ Set the notch_list + """Set the notch_list Parameters ========== @@ -132,12 +133,12 @@ def notch_list(self, notch_list): elif isinstance(notch_list, NotchList): self._notch_list = notch_list else: - raise ValueError("notch_list {} not understood".format(notch_list)) + raise ValueError(f"notch_list {notch_list} not understood") self._frequency_mask_updated = False @property def frequency_mask(self): - """ Masking array for limiting the frequency band. + """Masking array for limiting the frequency band. Returns ======= @@ -146,8 +147,7 @@ def frequency_mask(self): """ if not self._frequency_mask_updated: frequency_array = self._times_and_frequencies.frequency_array - mask = ((frequency_array >= self.minimum_frequency) & - (frequency_array <= self.maximum_frequency)) + mask = (frequency_array >= self.minimum_frequency) & (frequency_array <= self.maximum_frequency) for notch in self.notch_list: mask[notch.get_idxs(frequency_array)] = False self._frequency_mask = mask @@ -184,22 +184,22 @@ def time_domain_window(self, roll_off=None, alpha=None): Window function over time array """ from scipy.signal.windows import tukey + if roll_off is not None: self.roll_off = roll_off elif alpha is not None: self.roll_off = alpha * self.duration / 2 window = tukey(len(self._time_domain_strain), alpha=self.alpha) - self.window_factor = np.mean(window ** 2) + self.window_factor = np.mean(window**2) return window @property def time_domain_strain(self): - """ The time domain strain, in units of strain """ + """The time domain strain, in units of strain""" if self._time_domain_strain is not None: return self._time_domain_strain elif self._frequency_domain_strain is not None: - self._time_domain_strain = utils.infft( - self.frequency_domain_strain, self.sampling_frequency) + self._time_domain_strain = utils.infft(self.frequency_domain_strain, self.sampling_frequency) return self._time_domain_strain else: @@ -207,7 +207,7 @@ def time_domain_strain(self): @property def frequency_domain_strain(self): - """ Returns the frequency domain strain + """Returns the frequency domain strain This is the frequency domain strain normalised to units of strain / Hz, obtained by a one-sided Fourier transform of the @@ -216,14 +216,13 @@ def frequency_domain_strain(self): if self._frequency_domain_strain is not None: return self._frequency_domain_strain * self.frequency_mask elif self._time_domain_strain is not None: - logger.debug("Generating frequency domain strain from given time " - "domain strain.") - logger.debug("Applying a tukey window with alpha={}, roll off={}".format( - self.alpha, self.roll_off)) + logger.debug("Generating frequency domain strain from given time domain strain.") + logger.debug(f"Applying a tukey window with alpha={self.alpha}, roll off={self.roll_off}") # self.low_pass_filter() window = self.time_domain_window() self._frequency_domain_strain, self.frequency_array = utils.nfft( - self._time_domain_strain * window, self.sampling_frequency) + self._time_domain_strain * window, self.sampling_frequency + ) return self._frequency_domain_strain * self.frequency_mask else: raise ValueError("frequency domain strain data not yet set") @@ -245,8 +244,7 @@ def to_gwpy_timeseries(self): raise ModuleNotFoundError("Cannot output strain data as gwpy TimeSeries") return TimeSeries( - self.time_domain_strain, sample_rate=self.sampling_frequency, - t0=self.start_time, channel=self.channel + self.time_domain_strain, sample_rate=self.sampling_frequency, t0=self.start_time, channel=self.channel ) def to_pycbc_timeseries(self): @@ -255,14 +253,13 @@ def to_pycbc_timeseries(self): """ try: - from pycbc.types.timeseries import TimeSeries from lal import LIGOTimeGPS + from pycbc.types.timeseries import TimeSeries except ModuleNotFoundError: raise ModuleNotFoundError("Cannot output strain data as PyCBC TimeSeries") return TimeSeries( - self.time_domain_strain, delta_t=(1. / self.sampling_frequency), - epoch=LIGOTimeGPS(self.start_time) + self.time_domain_strain, delta_t=(1.0 / self.sampling_frequency), epoch=LIGOTimeGPS(self.start_time) ) def to_lal_timeseries(self): @@ -275,8 +272,7 @@ def to_lal_timeseries(self): raise ModuleNotFoundError("Cannot output strain data as PyCBC TimeSeries") lal_data = CreateREAL8TimeSeries( - "", LIGOTimeGPS(self.start_time), 0, 1 / self.sampling_frequency, - SecondUnit, len(self.time_domain_strain) + "", LIGOTimeGPS(self.start_time), 0, 1 / self.sampling_frequency, SecondUnit, len(self.time_domain_strain) ) lal_data.data.data[:] = self.time_domain_strain @@ -292,10 +288,7 @@ def to_gwpy_frequencyseries(self): raise ModuleNotFoundError("Cannot output strain data as gwpy FrequencySeries") return FrequencySeries( - self.frequency_domain_strain, - frequencies=self.frequency_array, - epoch=self.start_time, - channel=self.channel + self.frequency_domain_strain, frequencies=self.frequency_array, epoch=self.start_time, channel=self.channel ) def to_pycbc_frequencyseries(self): @@ -304,15 +297,13 @@ def to_pycbc_frequencyseries(self): """ try: - from pycbc.types.frequencyseries import FrequencySeries from lal import LIGOTimeGPS + from pycbc.types.frequencyseries import FrequencySeries except ImportError: raise ImportError("Cannot output strain data as PyCBC FrequencySeries") return FrequencySeries( - self.frequency_domain_strain, - delta_f=1 / self.duration, - epoch=LIGOTimeGPS(self.start_time) + self.frequency_domain_strain, delta_f=1 / self.duration, epoch=LIGOTimeGPS(self.start_time) ) def to_lal_frequencyseries(self): @@ -330,39 +321,38 @@ def to_lal_frequencyseries(self): self.frequency_array[0], 1 / self.duration, SecondUnit, - len(self.frequency_domain_strain) + len(self.frequency_domain_strain), ) lal_data.data.data[:] = self.frequency_domain_strain return lal_data def low_pass_filter(self, filter_freq=None): - """ Low pass filter the data """ + """Low pass filter the data""" from gwpy.signal.filter_design import lowpass from gwpy.timeseries import TimeSeries if filter_freq is None: - logger.debug( - "Setting low pass filter_freq using given maximum frequency") + logger.debug("Setting low pass filter_freq using given maximum frequency") filter_freq = self.maximum_frequency if 2 * filter_freq >= self.sampling_frequency: logger.info( - "Low pass filter frequency of {}Hz requested, this is equal" + f"Low pass filter frequency of {filter_freq}Hz requested, this is equal" " or greater than the Nyquist frequency so no filter applied" - .format(filter_freq)) + ) return - logger.debug("Applying low pass filter with filter frequency {}".format(filter_freq)) + logger.debug(f"Applying low pass filter with filter frequency {filter_freq}") bp = lowpass(filter_freq, self.sampling_frequency) strain = TimeSeries(self.time_domain_strain, sample_rate=self.sampling_frequency) strain = strain.filter(bp, filtfilt=True) self._time_domain_strain = strain.value def create_power_spectral_density( - self, fft_length, overlap=0, name='unknown', outdir=None, - analysis_segment_start_time=None): - """ Use the time domain strain to generate a power spectral density + self, fft_length, overlap=0, name="unknown", outdir=None, analysis_segment_start_time=None + ): + """Use the time domain strain to generate a power spectral density This create a Tukey-windowed power spectral density and writes it to a PSD file. @@ -395,77 +385,75 @@ def create_power_spectral_density( if analysis_segment_start_time is not None: analysis_segment_end_time = analysis_segment_start_time + fft_length - inside = (analysis_segment_start_time > self.time_array[0] + - analysis_segment_end_time < self.time_array[-1]) + inside = analysis_segment_start_time > self.time_array[0] + analysis_segment_end_time < self.time_array[-1] if inside: logger.info("Removing analysis segment data from the PSD data") - idxs = ( - (self.time_array < analysis_segment_start_time) + - (self.time_array > analysis_segment_end_time)) + idxs = (self.time_array < analysis_segment_start_time) + (self.time_array > analysis_segment_end_time) data = data[idxs] # WARNING this line can cause issues if the data is non-contiguous strain = TimeSeries(data=data, sample_rate=self.sampling_frequency) psd_alpha = 2 * self.roll_off / fft_length - logger.info( - "Tukey window PSD data with alpha={}, roll off={}".format( - psd_alpha, self.roll_off)) - psd = strain.psd( - fftlength=fft_length, overlap=overlap, window=('tukey', psd_alpha)) + logger.info(f"Tukey window PSD data with alpha={psd_alpha}, roll off={self.roll_off}") + psd = strain.psd(fftlength=fft_length, overlap=overlap, window=("tukey", psd_alpha)) if outdir: - psd_file = '{}/{}_PSD_{}_{}.txt'.format(outdir, name, self.start_time, self.duration) - with open('{}'.format(psd_file), 'w+') as opened_file: + psd_file = f"{outdir}/{name}_PSD_{self.start_time}_{self.duration}.txt" + with open(f"{psd_file}", "w+") as opened_file: for f, p in zip(psd.frequencies.value, psd.value): - opened_file.write('{} {}\n'.format(f, p)) + opened_file.write(f"{f} {p}\n") return psd.frequencies.value, psd.value - def _infer_time_domain_dependence( - self, start_time, sampling_frequency, duration, time_array): - """ Helper function to figure out if the time_array, or - sampling_frequency and duration where given + def _infer_time_domain_dependence(self, start_time, sampling_frequency, duration, time_array): + """Helper function to figure out if the time_array, or + sampling_frequency and duration where given """ - self._infer_dependence(domain='time', array=time_array, duration=duration, - sampling_frequency=sampling_frequency, start_time=start_time) + self._infer_dependence( + domain="time", + array=time_array, + duration=duration, + sampling_frequency=sampling_frequency, + start_time=start_time, + ) - def _infer_frequency_domain_dependence( - self, start_time, sampling_frequency, duration, frequency_array): - """ Helper function to figure out if the frequency_array, or - sampling_frequency and duration where given + def _infer_frequency_domain_dependence(self, start_time, sampling_frequency, duration, frequency_array): + """Helper function to figure out if the frequency_array, or + sampling_frequency and duration where given """ - self._infer_dependence(domain='frequency', array=frequency_array, - duration=duration, sampling_frequency=sampling_frequency, start_time=start_time) + self._infer_dependence( + domain="frequency", + array=frequency_array, + duration=duration, + sampling_frequency=sampling_frequency, + start_time=start_time, + ) def _infer_dependence(self, domain, array, duration, sampling_frequency, start_time): if (sampling_frequency is not None) and (duration is not None): if array is not None: - raise ValueError( - "You have given the sampling_frequency, duration, and " - "an array") + raise ValueError("You have given the sampling_frequency, duration, and an array") pass elif array is not None: - if domain == 'time': + if domain == "time": self.time_array = array - elif domain == 'frequency': + elif domain == "frequency": self.frequency_array = array self.start_time = start_time return elif sampling_frequency is None or duration is None: - raise ValueError( - "You must provide both sampling_frequency and duration") + raise ValueError("You must provide both sampling_frequency and duration") else: - raise ValueError( - "Insufficient information given to set arrays") - self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration, - sampling_frequency=sampling_frequency, - start_time=start_time) + raise ValueError("Insufficient information given to set arrays") + self._times_and_frequencies = CoupledTimeAndFrequencySeries( + duration=duration, sampling_frequency=sampling_frequency, start_time=start_time + ) def set_from_time_domain_strain( - self, time_domain_strain, sampling_frequency=None, duration=None, - start_time=0, time_array=None): - """ Set the strain data from a time domain strain array + self, time_domain_strain, sampling_frequency=None, duration=None, start_time=0, time_array=None + ): + """Set the strain data from a time domain strain array This sets the time_domain_strain attribute, the frequency_domain_strain is automatically calculated after a low-pass filter and Tukey window @@ -486,12 +474,11 @@ def set_from_time_domain_strain( given. """ - self._infer_time_domain_dependence(start_time=start_time, - sampling_frequency=sampling_frequency, - duration=duration, - time_array=time_array) + self._infer_time_domain_dependence( + start_time=start_time, sampling_frequency=sampling_frequency, duration=duration, time_array=time_array + ) - logger.debug('Setting data using provided time_domain_strain') + logger.debug("Setting data using provided time_domain_strain") if np.shape(time_domain_strain) == np.shape(self.time_array): self._time_domain_strain = time_domain_strain self._frequency_domain_strain = None @@ -499,7 +486,7 @@ def set_from_time_domain_strain( raise ValueError("Data times do not match time array") def set_from_gwpy_timeseries(self, time_series): - """ Set the strain data from a gwpy TimeSeries + """Set the strain data from a gwpy TimeSeries This sets the time_domain_strain attribute, the frequency_domain_strain is automatically calculated after a low-pass filter and Tukey window @@ -512,13 +499,15 @@ def set_from_gwpy_timeseries(self, time_series): """ from gwpy.timeseries import TimeSeries - logger.debug('Setting data using provided gwpy TimeSeries object') + + logger.debug("Setting data using provided gwpy TimeSeries object") if not isinstance(time_series, TimeSeries): raise ValueError("Input time_series is not a gwpy TimeSeries") - self._times_and_frequencies = \ - CoupledTimeAndFrequencySeries(duration=time_series.duration.value, - sampling_frequency=time_series.sample_rate.value, - start_time=time_series.epoch.value) + self._times_and_frequencies = CoupledTimeAndFrequencySeries( + duration=time_series.duration.value, + sampling_frequency=time_series.sample_rate.value, + start_time=time_series.epoch.value, + ) self._time_domain_strain = time_series.value self._frequency_domain_strain = None self._channel = time_series.channel @@ -527,10 +516,8 @@ def set_from_gwpy_timeseries(self, time_series): def channel(self): return self._channel - def set_from_open_data( - self, name, start_time, duration=4, outdir='outdir', cache=True, - **kwargs): - """ Set the strain data from open LOSC data + def set_from_open_data(self, name, start_time, duration=4, outdir="outdir", cache=True, **kwargs): + """Set the strain data from open LOSC data This sets the time_domain_strain attribute, the frequency_domain_strain is automatically calculated after a low-pass filter and Tukey window @@ -555,13 +542,13 @@ def set_from_open_data( """ timeseries = gwutils.get_open_strain_data( - name, start_time, start_time + duration, outdir=outdir, cache=cache, - **kwargs) + name, start_time, start_time + duration, outdir=outdir, cache=cache, **kwargs + ) self.set_from_gwpy_timeseries(timeseries) def set_from_csv(self, filename): - """ Set the strain data from a csv file + """Set the strain data from a csv file Parameters ========== @@ -570,13 +557,14 @@ def set_from_csv(self, filename): """ from gwpy.timeseries import TimeSeries - timeseries = TimeSeries.read(filename, format='csv') + + timeseries = TimeSeries.read(filename, format="csv") self.set_from_gwpy_timeseries(timeseries) def set_from_frequency_domain_strain( - self, frequency_domain_strain, sampling_frequency=None, - duration=None, start_time=0, frequency_array=None): - """ Set the `frequency_domain_strain` from a numpy array + self, frequency_domain_strain, sampling_frequency=None, duration=None, start_time=0, frequency_array=None + ): + """Set the `frequency_domain_strain` from a numpy array Parameters ========== @@ -594,22 +582,22 @@ def set_from_frequency_domain_strain( """ - self._infer_frequency_domain_dependence(start_time=start_time, - sampling_frequency=sampling_frequency, - duration=duration, - frequency_array=frequency_array) + self._infer_frequency_domain_dependence( + start_time=start_time, + sampling_frequency=sampling_frequency, + duration=duration, + frequency_array=frequency_array, + ) - logger.debug('Setting data using provided frequency_domain_strain') + logger.debug("Setting data using provided frequency_domain_strain") if np.shape(frequency_domain_strain) == np.shape(self.frequency_array): self._frequency_domain_strain = frequency_domain_strain self.window_factor = 1 else: raise ValueError("Data frequencies do not match frequency_array") - def set_from_power_spectral_density( - self, power_spectral_density, sampling_frequency, duration, - start_time=0): - """ Set the `frequency_domain_strain` by generating a noise realisation + def set_from_power_spectral_density(self, power_spectral_density, sampling_frequency, duration, start_time=0): + """Set the `frequency_domain_strain` by generating a noise realisation Parameters ========== @@ -624,15 +612,13 @@ def set_from_power_spectral_density( """ - self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration, - sampling_frequency=sampling_frequency, - start_time=start_time) - logger.debug( - 'Setting data using noise realization from provided' - 'power_spectal_density') - frequency_domain_strain, frequency_array = \ - power_spectral_density.get_noise_realisation( - self.sampling_frequency, self.duration) + self._times_and_frequencies = CoupledTimeAndFrequencySeries( + duration=duration, sampling_frequency=sampling_frequency, start_time=start_time + ) + logger.debug("Setting data using noise realization from providedpower_spectal_density") + frequency_domain_strain, frequency_array = power_spectral_density.get_noise_realisation( + self.sampling_frequency, self.duration + ) if np.array_equal(frequency_array, self.frequency_array): self._frequency_domain_strain = frequency_domain_strain @@ -640,7 +626,7 @@ def set_from_power_spectral_density( raise ValueError("Data frequencies do not match frequency_array") def set_from_zero_noise(self, sampling_frequency, duration, start_time=0): - """ Set the `frequency_domain_strain` to zero noise + """Set the `frequency_domain_strain` to zero noise This sets the `strain_data` to an array of (complex) zeros, while also making sure the frequency and time arrays match @@ -657,17 +643,14 @@ def set_from_zero_noise(self, sampling_frequency, duration, start_time=0): """ - self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration, - sampling_frequency=sampling_frequency, - start_time=start_time) - logger.debug('Setting zero noise data') - self._frequency_domain_strain = np.zeros_like(self.frequency_array, - dtype=complex) + self._times_and_frequencies = CoupledTimeAndFrequencySeries( + duration=duration, sampling_frequency=sampling_frequency, start_time=start_time + ) + logger.debug("Setting zero noise data") + self._frequency_domain_strain = np.zeros_like(self.frequency_array, dtype=complex) - def set_from_frame_file( - self, frame_file, sampling_frequency, duration, start_time=0, - channel=None, buffer_time=1): - """ Set the `frequency_domain_strain` from a frame fiile + def set_from_frame_file(self, frame_file, sampling_frequency, duration, start_time=0, channel=None, buffer_time=1): + """Set the `frequency_domain_strain` from a frame fiile Parameters ========== @@ -688,19 +671,23 @@ def set_from_frame_file( """ self._times_and_frequencies = CoupledTimeAndFrequencySeries( - duration=duration, sampling_frequency=sampling_frequency, - start_time=start_time) + duration=duration, sampling_frequency=sampling_frequency, start_time=start_time + ) - logger.info('Reading data from frame file {}'.format(frame_file)) + logger.info(f"Reading data from frame file {frame_file}") strain = gwutils.read_frame_file( - frame_file, start_time=start_time, end_time=start_time + duration, - buffer_time=buffer_time, channel=channel, - resample=sampling_frequency) + frame_file, + start_time=start_time, + end_time=start_time + duration, + buffer_time=buffer_time, + channel=channel, + resample=sampling_frequency, + ) self.set_from_gwpy_timeseries(strain) def set_from_channel_name(self, channel, duration, start_time, sampling_frequency): - """ Set the `frequency_domain_strain` by fetching from given channel + """Set the `frequency_domain_strain` by fetching from given channel using gwpy.TimesSeries.get(), which dynamically accesses either frames on disk, or a remote NDS2 server to find and return data. This function also verifies that the specified channel is given in the correct format. @@ -718,24 +705,25 @@ def set_from_channel_name(self, channel, duration, start_time, sampling_frequenc """ from gwpy.timeseries import TimeSeries - channel_comp = channel.split(':') + + channel_comp = channel.split(":") if len(channel_comp) != 2: - raise IndexError('Channel name must have format `IFO:Channel`') + raise IndexError("Channel name must have format `IFO:Channel`") self._times_and_frequencies = CoupledTimeAndFrequencySeries( - duration=duration, sampling_frequency=sampling_frequency, - start_time=start_time) + duration=duration, sampling_frequency=sampling_frequency, start_time=start_time + ) - logger.info('Fetching data using channel {}'.format(channel)) + logger.info(f"Fetching data using channel {channel}") strain = TimeSeries.get(channel, start_time, start_time + duration) strain = strain.resample(sampling_frequency) self.set_from_gwpy_timeseries(strain) -class Notch(object): +class Notch: def __init__(self, minimum_frequency, maximum_frequency): - """ A notch object storing the maximum and minimum frequency of the notch + """A notch object storing the maximum and minimum frequency of the notch Parameters ========== @@ -748,12 +736,14 @@ def __init__(self, minimum_frequency, maximum_frequency): self.minimum_frequency = minimum_frequency self.maximum_frequency = maximum_frequency else: - msg = ("Your notch minimum_frequency {} and maximum_frequency {} are invalid" - .format(minimum_frequency, maximum_frequency)) + msg = ( + f"Your notch minimum_frequency {minimum_frequency} and " + f"maximum_frequency {maximum_frequency} are invalid" + ) raise ValueError(msg) def get_idxs(self, frequency_array): - """ Get a boolean mask for the frequencies in frequency_array in the notch + """Get a boolean mask for the frequencies in frequency_array in the notch Parameters ========== @@ -766,12 +756,12 @@ def get_idxs(self, frequency_array): An array of booleans which are True for frequencies in the notch """ - lower = (frequency_array > self.minimum_frequency) - upper = (frequency_array < self.maximum_frequency) + lower = frequency_array > self.minimum_frequency + upper = frequency_array < self.maximum_frequency return lower & upper def check_frequency(self, freq): - """ Check if freq is inside the notch + """Check if freq is inside the notch Parameters ========== @@ -792,7 +782,7 @@ def check_frequency(self, freq): class NotchList(list): def __init__(self, notch_list): - """ A list of notches + """A list of notches Parameters ========== @@ -811,11 +801,11 @@ def __init__(self, notch_list): if isinstance(notch, tuple) and len(notch) == 2: self.append(Notch(*notch)) else: - msg = "notch_list {} is malformed".format(notch_list) + msg = f"notch_list {notch_list} is malformed" raise ValueError(msg) def check_frequency(self, freq): - """ Check if freq is inside the notch list + """Check if freq is inside the notch list Parameters ========== diff --git a/bilby/gw/eos/__init__.py b/bilby/gw/eos/__init__.py index 36eba397f..aff86a2ac 100644 --- a/bilby/gw/eos/__init__.py +++ b/bilby/gw/eos/__init__.py @@ -1,3 +1,4 @@ +from .eos import EOSFamily, SpectralDecompositionEOS, TabularEOS from .tov_solver import IntegrateTOV -from .eos import (SpectralDecompositionEOS, - EOSFamily, TabularEOS) + +__all__ = [EOSFamily, SpectralDecompositionEOS, TabularEOS, IntegrateTOV] diff --git a/bilby/gw/eos/eos.py b/bilby/gw/eos/eos.py index 2e962aa0c..eda310431 100644 --- a/bilby/gw/eos/eos.py +++ b/bilby/gw/eos/eos.py @@ -1,36 +1,38 @@ import os + import numpy as np -from scipy.interpolate import interp1d, CubicSpline +from scipy.interpolate import CubicSpline, interp1d -from .tov_solver import IntegrateTOV from ...core import utils +from .tov_solver import IntegrateTOV C_SI = utils.speed_of_light # m/s -C_CGS = C_SI * 100. +C_CGS = C_SI * 100.0 G_SI = utils.gravitational_constant # m^3 kg^-1 s^-2 MSUN_SI = utils.solar_mass # Kg # Stores conversions from geometerized to cgs or si unit systems -conversion_dict = {'pressure': {'cgs': C_SI ** 4. / G_SI * 10., 'si': C_SI ** 4. / G_SI, 'geom': 1.}, - 'energy_density': {'cgs': C_SI ** 4. / G_SI * 10., 'si': C_SI ** 4. / G_SI, 'geom': 1.}, - 'density': {'cgs': C_SI ** 2. / G_SI / 1000., 'si': C_SI ** 2. / G_SI, 'geom': 1.}, - 'pseudo_enthalpy': {'dimensionless': 1.}, - 'mass': {'g': C_SI ** 2. / G_SI * 1000, 'kg': C_SI ** 2. / G_SI, 'geom': 1., - 'm_sol': C_SI ** 2. / G_SI / MSUN_SI}, - 'radius': {'cm': 100., 'm': 1., 'km': .001}, - 'tidal_deformability': {'geom': 1.}} +conversion_dict = { + "pressure": {"cgs": C_SI**4.0 / G_SI * 10.0, "si": C_SI**4.0 / G_SI, "geom": 1.0}, + "energy_density": {"cgs": C_SI**4.0 / G_SI * 10.0, "si": C_SI**4.0 / G_SI, "geom": 1.0}, + "density": {"cgs": C_SI**2.0 / G_SI / 1000.0, "si": C_SI**2.0 / G_SI, "geom": 1.0}, + "pseudo_enthalpy": {"dimensionless": 1.0}, + "mass": {"g": C_SI**2.0 / G_SI * 1000, "kg": C_SI**2.0 / G_SI, "geom": 1.0, "m_sol": C_SI**2.0 / G_SI / MSUN_SI}, + "radius": {"cm": 100.0, "m": 1.0, "km": 0.001}, + "tidal_deformability": {"geom": 1.0}, +} # construct dictionary of pre-shipped EOS pressure denstity table -path_to_eos_tables = os.path.join(os.path.dirname(__file__), 'eos_tables') +path_to_eos_tables = os.path.join(os.path.dirname(__file__), "eos_tables") list_of_eos_tables = os.listdir(path_to_eos_tables) -valid_eos_files = [i for i in list_of_eos_tables if 'LAL' in i] +valid_eos_files = [i for i in list_of_eos_tables if "LAL" in i] valid_eos_file_paths = [os.path.join(path_to_eos_tables, filename) for filename in valid_eos_files] -valid_eos_names = [i.split('_', maxsplit=1)[-1].strip('.dat') for i in valid_eos_files] +valid_eos_names = [i.split("_", maxsplit=1)[-1].strip(".dat") for i in valid_eos_files] valid_eos_dict = dict(zip(valid_eos_names, valid_eos_file_paths)) -class TabularEOS(object): +class TabularEOS: """ Given a valid eos input format, such as 2-D array, an ascii file, or a string, parse, and interpolate @@ -68,8 +70,9 @@ def __init__(self, eos, sampling_flag=False, warning_flag=False): elif isinstance(eos, np.ndarray): table = eos else: - raise ValueError("eos provided is invalid type please supply a str name, str path to ASCII file, " - "or a numpy array") + raise ValueError( + "eos provided is invalid type please supply a str name, str path to ASCII file, or a numpy array" + ) table = self.__remove_leading_zero(table) @@ -85,18 +88,22 @@ def __init__(self, eos, sampling_flag=False, warning_flag=False): integrand = self.pressure / (self.energy_density + self.pressure) self.pseudo_enthalpy = cumulative_trapezoid(integrand, np.log(self.pressure), initial=0) + integrand[0] - self.interp_energy_density_from_pressure = CubicSpline(np.log10(self.pressure), - np.log10(self.energy_density), - ) + self.interp_energy_density_from_pressure = CubicSpline( + np.log10(self.pressure), + np.log10(self.energy_density), + ) - self.interp_energy_density_from_pseudo_enthalpy = CubicSpline(np.log10(self.pseudo_enthalpy), - np.log10(self.energy_density)) + self.interp_energy_density_from_pseudo_enthalpy = CubicSpline( + np.log10(self.pseudo_enthalpy), np.log10(self.energy_density) + ) - self.interp_pressure_from_pseudo_enthalpy = CubicSpline(np.log10(self.pseudo_enthalpy), - np.log10(self.pressure)) + self.interp_pressure_from_pseudo_enthalpy = CubicSpline( + np.log10(self.pseudo_enthalpy), np.log10(self.pressure) + ) - self.interp_pseudo_enthalpy_from_energy_density = CubicSpline(np.log10(self.energy_density), - np.log10(self.pseudo_enthalpy)) + self.interp_pseudo_enthalpy_from_energy_density = CubicSpline( + np.log10(self.energy_density), np.log10(self.pseudo_enthalpy) + ) self.__construct_all_tables() @@ -110,13 +117,13 @@ def __remove_leading_zero(self, table): loglog interpolation breaks if the first entries are 0s """ - if table[0, 0] == 0. or table[0, 1] == 0.: + if table[0, 0] == 0.0 or table[0, 1] == 0.0: return table[1:, :] else: return table - def energy_from_pressure(self, pressure, interp_type='CubicSpline'): + def energy_from_pressure(self, pressure, interp_type="CubicSpline"): """ Find value of energy_from_pressure as in lalsimulation, return e = K * p**(3./5.) below min pressure @@ -137,26 +144,28 @@ def energy_from_pressure(self, pressure, interp_type='CubicSpline'): indices_greater_than_min = np.nonzero(pressure >= self.minimum_pressure) # We do this special for less than min pressure - energy_returned[indices_less_than_min] = 10 ** (np.log10(self.energy_density[0]) + - (3. / 5.) * (np.log10(pressure[indices_less_than_min]) - - np.log10(self.pressure[0]))) - - if interp_type == 'CubicSpline': - energy_returned[indices_greater_than_min] = ( - 10. ** self.interp_energy_density_from_pressure(np.log10(pressure[indices_greater_than_min]))) - elif interp_type == 'linear': - energy_returned[indices_greater_than_min] = np.interp(pressure[indices_greater_than_min], - self.pressure, - self.energy_density) + energy_returned[indices_less_than_min] = 10 ** ( + np.log10(self.energy_density[0]) + + (3.0 / 5.0) * (np.log10(pressure[indices_less_than_min]) - np.log10(self.pressure[0])) + ) + + if interp_type == "CubicSpline": + energy_returned[indices_greater_than_min] = 10.0 ** self.interp_energy_density_from_pressure( + np.log10(pressure[indices_greater_than_min]) + ) + elif interp_type == "linear": + energy_returned[indices_greater_than_min] = np.interp( + pressure[indices_greater_than_min], self.pressure, self.energy_density + ) else: - raise ValueError('Interpolation scheme must be linear or CubicSpline') + raise ValueError("Interpolation scheme must be linear or CubicSpline") if energy_returned.size == 1: return energy_returned[0] else: return energy_returned - def pressure_from_pseudo_enthalpy(self, pseudo_enthalpy, interp_type='CubicSpline'): + def pressure_from_pseudo_enthalpy(self, pseudo_enthalpy, interp_type="CubicSpline"): """ Find p(h) as in lalsimulation, return p = K * h**(5./2.) below min enthalpy @@ -172,26 +181,28 @@ def pressure_from_pseudo_enthalpy(self, pseudo_enthalpy, interp_type='CubicSplin indices_less_than_min = np.nonzero(pseudo_enthalpy < self.minimum_pseudo_enthalpy) indices_greater_than_min = np.nonzero(pseudo_enthalpy >= self.minimum_pseudo_enthalpy) - pressure_returned[indices_less_than_min] = 10. ** (np.log10(self.pressure[0]) + - 2.5 * (np.log10(pseudo_enthalpy[indices_less_than_min]) - - np.log10(self.pseudo_enthalpy[0]))) - - if interp_type == 'CubicSpline': - pressure_returned[indices_greater_than_min] = ( - 10. ** self.interp_pressure_from_pseudo_enthalpy(np.log10(pseudo_enthalpy[indices_greater_than_min]))) - elif interp_type == 'linear': - pressure_returned[indices_greater_than_min] = np.interp(pseudo_enthalpy[indices_greater_than_min], - self.pseudo_enthalpy, - self.pressure) + pressure_returned[indices_less_than_min] = 10.0 ** ( + np.log10(self.pressure[0]) + + 2.5 * (np.log10(pseudo_enthalpy[indices_less_than_min]) - np.log10(self.pseudo_enthalpy[0])) + ) + + if interp_type == "CubicSpline": + pressure_returned[indices_greater_than_min] = 10.0 ** self.interp_pressure_from_pseudo_enthalpy( + np.log10(pseudo_enthalpy[indices_greater_than_min]) + ) + elif interp_type == "linear": + pressure_returned[indices_greater_than_min] = np.interp( + pseudo_enthalpy[indices_greater_than_min], self.pseudo_enthalpy, self.pressure + ) else: - raise ValueError('Interpolation scheme must be linear or CubicSpline') + raise ValueError("Interpolation scheme must be linear or CubicSpline") if pressure_returned.size == 1: return pressure_returned[0] else: return pressure_returned - def energy_density_from_pseudo_enthalpy(self, pseudo_enthalpy, interp_type='CubicSpline'): + def energy_density_from_pseudo_enthalpy(self, pseudo_enthalpy, interp_type="CubicSpline"): """ Find energy_density_from_pseudo_enthalpy(pseudo_enthalpy) as in lalsimulation, return e = K * h**(3./2.) below min enthalpy @@ -207,25 +218,26 @@ def energy_density_from_pseudo_enthalpy(self, pseudo_enthalpy, interp_type='Cubi indices_less_than_min = np.nonzero(pseudo_enthalpy < self.minimum_pseudo_enthalpy) indices_greater_than_min = np.nonzero(pseudo_enthalpy >= self.minimum_pseudo_enthalpy) - energy_returned[indices_less_than_min] = 10 ** (np.log10(self.energy_density[0]) + - 1.5 * (np.log10(pseudo_enthalpy[indices_less_than_min]) - - np.log10(self.pseudo_enthalpy[0]))) - if interp_type == 'CubicSpline': + energy_returned[indices_less_than_min] = 10 ** ( + np.log10(self.energy_density[0]) + + 1.5 * (np.log10(pseudo_enthalpy[indices_less_than_min]) - np.log10(self.pseudo_enthalpy[0])) + ) + if interp_type == "CubicSpline": x = np.log10(pseudo_enthalpy[indices_greater_than_min]) energy_returned[indices_greater_than_min] = 10 ** self.interp_energy_density_from_pseudo_enthalpy(x) - elif interp_type == 'linear': - energy_returned[indices_greater_than_min] = np.interp(pseudo_enthalpy[indices_greater_than_min], - self.pseudo_enthalpy, - self.energy_density) + elif interp_type == "linear": + energy_returned[indices_greater_than_min] = np.interp( + pseudo_enthalpy[indices_greater_than_min], self.pseudo_enthalpy, self.energy_density + ) else: - raise ValueError('Interpolation scheme must be linear or CubicSpline') + raise ValueError("Interpolation scheme must be linear or CubicSpline") if energy_returned.size == 1: return energy_returned[0] else: return energy_returned - def pseudo_enthalpy_from_energy_density(self, energy_density, interp_type='CubicSpline'): + def pseudo_enthalpy_from_energy_density(self, energy_density, interp_type="CubicSpline"): """ Find h(epsilon) as in lalsimulation, return h = K * e**(2./3.) below min enthalpy @@ -241,26 +253,29 @@ def pseudo_enthalpy_from_energy_density(self, energy_density, interp_type='Cubic indices_less_than_min = np.nonzero(energy_density < self.minimum_energy_density) indices_greater_than_min = np.nonzero(energy_density >= self.minimum_energy_density) - pseudo_enthalpy_returned[indices_less_than_min] = 10 ** (np.log10(self.pseudo_enthalpy[0]) + (2. / 3.) * - (np.log10(energy_density[indices_less_than_min]) - - np.log10(self.energy_density[0]))) + pseudo_enthalpy_returned[indices_less_than_min] = 10 ** ( + np.log10(self.pseudo_enthalpy[0]) + + (2.0 / 3.0) * (np.log10(energy_density[indices_less_than_min]) - np.log10(self.energy_density[0])) + ) - if interp_type == 'CubicSpline': + if interp_type == "CubicSpline": x = np.log10(energy_density[indices_greater_than_min]) - pseudo_enthalpy_returned[indices_greater_than_min] = 10**self.interp_pseudo_enthalpy_from_energy_density(x) - elif interp_type == 'linear': - pseudo_enthalpy_returned[indices_greater_than_min] = np.interp(energy_density[indices_greater_than_min], - self.energy_density, - self.pseudo_enthalpy) + pseudo_enthalpy_returned[indices_greater_than_min] = 10 ** self.interp_pseudo_enthalpy_from_energy_density( + x + ) + elif interp_type == "linear": + pseudo_enthalpy_returned[indices_greater_than_min] = np.interp( + energy_density[indices_greater_than_min], self.energy_density, self.pseudo_enthalpy + ) else: - raise ValueError('Interpolation scheme must be linear or CubicSpline') + raise ValueError("Interpolation scheme must be linear or CubicSpline") if pseudo_enthalpy_returned.size == 1: return pseudo_enthalpy_returned[0] else: return pseudo_enthalpy_returned - def dedh(self, pseudo_enthalpy, rel_dh=1e-5, interp_type='CubicSpline'): + def dedh(self, pseudo_enthalpy, rel_dh=1e-5, interp_type="CubicSpline"): """ Value of [depsilon/dh](p) @@ -279,9 +294,9 @@ def dedh(self, pseudo_enthalpy, rel_dh=1e-5, interp_type='CubicSpline'): eps_upper = self.energy_density_from_pseudo_enthalpy(pseudo_enthalpy + dh, interp_type=interp_type) eps_lower = self.energy_density_from_pseudo_enthalpy(pseudo_enthalpy - dh, interp_type=interp_type) - return (eps_upper - eps_lower) / (2. * dh) + return (eps_upper - eps_lower) / (2.0 * dh) - def dedp(self, pressure, rel_dp=1e-5, interp_type='CubicSpline'): + def dedp(self, pressure, rel_dp=1e-5, interp_type="CubicSpline"): """ Find value of [depsilon/dp](p) @@ -300,9 +315,9 @@ def dedp(self, pressure, rel_dp=1e-5, interp_type='CubicSpline'): eps_upper = self.energy_from_pressure(pressure + dp, interp_type=interp_type) eps_lower = self.energy_from_pressure(pressure - dp, interp_type=interp_type) - return (eps_upper - eps_lower) / (2. * dp) + return (eps_upper - eps_lower) / (2.0 * dp) - def velocity_from_pseudo_enthalpy(self, pseudo_enthalpy, interp_type='CubicSpline'): + def velocity_from_pseudo_enthalpy(self, pseudo_enthalpy, interp_type="CubicSpline"): """ Returns the speed of sound in geometerized units in the neutron star at the specified pressure. @@ -399,53 +414,62 @@ def plot(self, rep, xlim=None, ylim=None, units=None): import matplotlib.pyplot as plt # Set data based on specified representation - varnames = rep.split('-') + varnames = rep.split("-") - assert varnames[0] != varnames[ - 1], 'Cannot plot the same variable against itself. Please choose another representation' + assert varnames[0] != varnames[1], ( + "Cannot plot the same variable against itself. Please choose another representation" + ) # Correspondence of rep parameter, data, and latex symbol # rep_dict = {'energy_density': [self.epsilon, r'$\epsilon$'], # 'pressure': [self.p, r'$p$'], 'pseudo_enthalpy': [pseudo_enthalpy, r'$h$']} # FIXME: The second element in these arrays should be tex labels, but tex's not working rn - rep_dict = {'energy_density': [self.energy_density, 'energy_density'], - 'pressure': [self.pressure, 'pressure'], - 'pseudo_enthalpy': [self.pseudo_enthalpy, 'pseudo_enthalpy']} + rep_dict = { + "energy_density": [self.energy_density, "energy_density"], + "pressure": [self.pressure, "pressure"], + "pseudo_enthalpy": [self.pseudo_enthalpy, "pseudo_enthalpy"], + } xname = varnames[1] yname = varnames[0] # Set units - eos_default_units = {'pressure': 'cgs', 'energy_density': 'cgs', 'density': 'cgs', - 'pseudo_enthalpy': 'dimensionless'} + eos_default_units = { + "pressure": "cgs", + "energy_density": "cgs", + "density": "cgs", + "pseudo_enthalpy": "dimensionless", + } if units is None: - units = [eos_default_units[yname], eos_default_units[xname]] # Default unit system is cgs + units = [eos_default_units[yname], eos_default_units[xname]] # Default unit system is cgs elif isinstance(units, str): - units = [units, units] # If only one unit system given, use for both + units = [units, units] # If only one unit system given, use for both xunits = units[1] yunits = units[0] # Ensure valid units if xunits not in list(conversion_dict[xname].keys()) or yunits not in list(conversion_dict[yname].keys()): - s = ''' + s = """ Invalid unit system. Valid variable-unit pairs are: p: {p_units} e: {e_units} rho: {rho_units} h: {h_units}. - '''.format(p_units=list(conversion_dict['pressure'].keys()), - e_units=list(conversion_dict['energy_density'].keys()), - rho_units=list(conversion_dict['density'].keys()), - h_units=list(conversion_dict['pseudo_enthalpy'].keys())) + """.format( + p_units=list(conversion_dict["pressure"].keys()), + e_units=list(conversion_dict["energy_density"].keys()), + rho_units=list(conversion_dict["density"].keys()), + h_units=list(conversion_dict["pseudo_enthalpy"].keys()), + ) raise ValueError(s) xdat = rep_dict[xname][0] * conversion_dict[xname][xunits] ydat = rep_dict[yname][0] * conversion_dict[yname][yunits] - xlabel = rep_dict[varnames[1]][1].replace('_', ' ') - ylabel = rep_dict[varnames[0]][1].replace('_', ' ') + '(' + xlabel + ')' + xlabel = rep_dict[varnames[1]][1].replace("_", " ") + ylabel = rep_dict[varnames[0]][1].replace("_", " ") + "(" + xlabel + ")" # Determine plot ranges. Currently shows 10% wider than actual data range. if xlim is None: @@ -468,7 +492,7 @@ def plot(self, rep, xlim=None, ylim=None, units=None): def spectral_adiabatic_index(gammas, x): arg = 0 for i in range(len(gammas)): - arg += gammas[i] * x ** i + arg += gammas[i] * x**i return np.exp(arg) @@ -529,29 +553,28 @@ def __determine_xmax(self, a_max=6.0): highest_order_gamma = np.abs(self.gammas[-1])[0] expansion_order = float(len(self.gammas) - 1) - xmax = (np.log(a_max) / highest_order_gamma) ** (1. / expansion_order) + xmax = (np.log(a_max) / highest_order_gamma) ** (1.0 / expansion_order) return xmax def __mu_integrand(self, x): - - return 1. / spectral_adiabatic_index(self.gammas, x) + return 1.0 / spectral_adiabatic_index(self.gammas, x) def mu(self, x): from scipy.integrate import quad + return np.exp(-quad(self.__mu_integrand, 0, x)[0]) def __eps_integrand(self, x): - return np.exp(x) * self.mu(x) / spectral_adiabatic_index(self.gammas, x) def energy_density(self, x, eps0): from scipy.integrate import quad + quad_result, quad_err = quad(self.__eps_integrand, 0, x) - eps_of_x = (eps0 * C_CGS ** 2.) / self.mu(x) + self.p0 / self.mu(x) * quad_result + eps_of_x = (eps0 * C_CGS**2.0) / self.mu(x) + self.p0 / self.mu(x) * quad_result return eps_of_x def __construct_a_of_x_table(self): - xdat = np.linspace(0, self.xmax, num=self.npts) # Generate adiabatic index points until a point is out of prior range @@ -568,7 +591,7 @@ def __construct_a_of_x_table(self): xmax_new = xdat[i - 1] # If EOS is too short, set prior to 0, else resample the function and set new xmax - if xmax_new < 4. or i == 0: + if xmax_new < 4.0 or i == 0: self.warning_flag = True else: xdat = np.linspace(0, xmax_new, num=self.npts) @@ -580,7 +603,6 @@ def __construct_a_of_x_table(self): self.adat = adat def __construct_e_of_p_table(self): - """ Creates p, epsilon table for a given set of spectral parameters """ @@ -598,11 +620,11 @@ def __construct_e_of_p_table(self): # convert eos to geometrized units in *m^-2* # IMPORTANT - eos_vals = eos_vals * 0.1 * G_SI / C_SI ** 4 + eos_vals = eos_vals * 0.1 * G_SI / C_SI**4 # doing as those before me have done and using SLY4 as low density region # SLY4 in geometrized units - low_density_path = os.path.join(os.path.dirname(__file__), 'eos_tables', 'LALSimNeutronStarEOS_SLY4.dat') + low_density_path = os.path.join(os.path.dirname(__file__), "eos_tables", "LALSimNeutronStarEOS_SLY4.dat") low_density = np.loadtxt(low_density_path) cutoff = eos_vals[0, :] @@ -620,7 +642,7 @@ def __construct_e_of_p_table(self): return eos_vals -class EOSFamily(object): +class EOSFamily: """ Create a EOS family and get mass-radius information @@ -636,8 +658,10 @@ class EOSFamily(object): The mass-radius and mass-k2 data should be populated here via the TOV solver upon object construction. """ + def __init__(self, eos, npts=500): from scipy.optimize import minimize_scalar + self.eos = eos # FIXME: starting_energy_density is set somewhat arbitrarily @@ -645,9 +669,7 @@ def __init__(self, eos, npts=500): ending_energy_density = max(self.eos.energy_density) log_starting_energy_density = np.log(starting_energy_density) log_ending_energy_density = np.log(ending_energy_density) - log_energy_density_grid = np.linspace(log_starting_energy_density, - log_ending_energy_density, - num=npts) + log_energy_density_grid = np.linspace(log_starting_energy_density, log_ending_energy_density, num=npts) energy_density_grid = np.exp(log_energy_density_grid) # Generate m, r, and k2 lists @@ -674,7 +696,7 @@ def __init__(self, eos, npts=500): x = [energy_density_grid[i - 2], energy_density_grid[i - 1], energy_density_grid[i]] y = [mass[i - 2], mass[i - 1], mass[i]] - f = interp1d(x, y, kind='quadratic', bounds_error=False, fill_value='extrapolate') + f = interp1d(x, y, kind="quadratic", bounds_error=False, fill_value="extrapolate") res = minimize_scalar(lambda x: -f(x)) @@ -692,8 +714,7 @@ def __init__(self, eos, npts=500): # with these quantities, then convert to SI. # Calculating dimensionless lambda values from k2, radii, and mass - tidal_deformability = [2. / 3. * k2 * r ** 5. / m ** 5. for k2, r, m in - zip(k2love_number, radius, mass)] + tidal_deformability = [2.0 / 3.0 * k2 * r**5.0 / m**5.0 for k2, r, m in zip(k2love_number, radius, mass)] # As a last resort, if highest mass is still smaller than second # to last point, remove the last point from each array @@ -707,16 +728,16 @@ def __init__(self, eos, npts=500): self.radius = np.array(radius) self.k2love_number = np.array(k2love_number) self.tidal_deformability = np.array(tidal_deformability) - self.maximum_mass = mass[-1] * conversion_dict['mass']['m_sol'] + self.maximum_mass = mass[-1] * conversion_dict["mass"]["m_sol"] def radius_from_mass(self, m): """ :param m: mass of neutron star in solar masses :return: radius of neutron star in meters """ - f = CubicSpline(self.mass, self.radius, bc_type='natural', extrapolate=True) + f = CubicSpline(self.mass, self.radius, bc_type="natural", extrapolate=True) - mass_converted_to_geom = m * MSUN_SI * G_SI / C_SI ** 2. + mass_converted_to_geom = m * MSUN_SI * G_SI / C_SI**2.0 return f(mass_converted_to_geom) def k2_from_mass(self, m): @@ -724,9 +745,9 @@ def k2_from_mass(self, m): :param m: mass of neutron star in solar masses. :return: dimensionless second tidal love number. """ - f = CubicSpline(self.mass, self.k2love_number, bc_type='natural', extrapolate=True) + f = CubicSpline(self.mass, self.k2love_number, bc_type="natural", extrapolate=True) - m_geom = m * MSUN_SI * G_SI / C_SI ** 2. + m_geom = m * MSUN_SI * G_SI / C_SI**2.0 return f(m_geom) def lambda_from_mass(self, m): @@ -742,9 +763,9 @@ def lambda_from_mass(self, m): r = self.radius_from_mass(m) k = self.k2_from_mass(m) - m_geom = m * MSUN_SI * G_SI / C_SI ** 2. + m_geom = m * MSUN_SI * G_SI / C_SI**2.0 c = m_geom / r - lmbda = (2. / 3.) * k / c ** 5. + lmbda = (2.0 / 3.0) * k / c**5.0 return lmbda @@ -784,20 +805,25 @@ def plot(self, rep, xlim=None, ylim=None, units=None): import matplotlib.pyplot as plt # Set data based on specified representation - varnames = rep.split('-') + varnames = rep.split("-") - assert varnames[0] != varnames[ - 1], 'Cannot plot the same variable against itself. Please choose another representation' + assert varnames[0] != varnames[1], ( + "Cannot plot the same variable against itself. Please choose another representation" + ) # Correspondence of rep parameter, data, and latex symbol - rep_dict = {'mass': [self.mass, r'$M$'], 'radius': [self.radius, r'$R$'], 'k2': [self.k2love_number, r'$k_2$'], - 'tidal_deformability': [self.tidal_deformability, r'$l$']} + rep_dict = { + "mass": [self.mass, r"$M$"], + "radius": [self.radius, r"$R$"], + "k2": [self.k2love_number, r"$k_2$"], + "tidal_deformability": [self.tidal_deformability, r"$l$"], + } xname = varnames[1] yname = varnames[0] # Set units - fam_default_units = {'mass': 'm_sol', 'radius': 'km', 'tidal_deformability': 'geom'} + fam_default_units = {"mass": "m_sol", "radius": "km", "tidal_deformability": "geom"} if units is None: units = [fam_default_units[yname], fam_default_units[xname]] # Default unit system is cgs elif isinstance(units, str): @@ -808,21 +834,23 @@ def plot(self, rep, xlim=None, ylim=None, units=None): # Ensure valid units if xunits not in list(conversion_dict[xname].keys()) or yunits not in list(conversion_dict[yname].keys()): - s = ''' + s = """ Invalid unit system. Valid variable-unit pairs are: m: {m_units} r: {r_units} l: {l_units}. - '''.format(m_units=list(conversion_dict['mass'].keys()), - r_units=list(conversion_dict['radius'].keys()), - l_units=list(conversion_dict['tidal_deformability'].keys())) + """.format( + m_units=list(conversion_dict["mass"].keys()), + r_units=list(conversion_dict["radius"].keys()), + l_units=list(conversion_dict["tidal_deformability"].keys()), + ) raise ValueError(s) xdat = rep_dict[varnames[1]][0] * conversion_dict[xname][xunits] ydat = rep_dict[varnames[0]][0] * conversion_dict[yname][yunits] - xlabel = rep_dict[varnames[1]][1].replace('_', ' ') - ylabel = rep_dict[varnames[0]][1].replace('_', ' ') + '(' + xlabel + ')' + xlabel = rep_dict[varnames[1]][1].replace("_", " ") + ylabel = rep_dict[varnames[0]][1].replace("_", " ") + "(" + xlabel + ")" # Determine plot ranges. Currently shows 10% wider than actual data range. if xlim is None: diff --git a/bilby/gw/eos/tov_solver.py b/bilby/gw/eos/tov_solver.py index 0135086ef..e7eae960a 100644 --- a/bilby/gw/eos/tov_solver.py +++ b/bilby/gw/eos/tov_solver.py @@ -4,8 +4,7 @@ class IntegrateTOV: - """Class that given an initial pressure a mass radius value and a k2-love number - """ + """Class that given an initial pressure a mass radius value and a k2-love number""" def __init__(self, eos, eps_0): self.eos = eos @@ -18,8 +17,8 @@ def __init__(self, eos, eps_0): mass0, radius0 = self.__mass_radius_cent(pseudo_enthalpy0, self.pseudo_enthalpy) # k2 integration starting vals - H0 = radius0 ** 2 - B0 = 2. * radius0 + H0 = radius0**2 + B0 = 2.0 * radius0 self.y = np.array([mass0, radius0, H0, B0]) @@ -34,12 +33,20 @@ def __mass_radius_cent(self, pseudo_enthalpy0, pseudo_enthalpy): p_c = self.eos.pressure_from_pseudo_enthalpy(pseudo_enthalpy0) depsdh_c = self.eos.dedh(pseudo_enthalpy0) - radius = ((3. * (pseudo_enthalpy0 - pseudo_enthalpy)) / (2. * np.pi * (eps_c + 3. * p_c))) ** 0.5 \ - * (1. - 0.25 * (eps_c - 3. * p_c - (3. / 5.) * depsdh_c) * - ((pseudo_enthalpy0 - pseudo_enthalpy) / (eps_c + 3. * p_c))) - - mass = (4. * np.pi) / 3. * eps_c * radius ** 3 * (1. - (3. / 5.) * - depsdh_c * (pseudo_enthalpy0 - pseudo_enthalpy) / eps_c) + radius = ((3.0 * (pseudo_enthalpy0 - pseudo_enthalpy)) / (2.0 * np.pi * (eps_c + 3.0 * p_c))) ** 0.5 * ( + 1.0 + - 0.25 + * (eps_c - 3.0 * p_c - (3.0 / 5.0) * depsdh_c) + * ((pseudo_enthalpy0 - pseudo_enthalpy) / (eps_c + 3.0 * p_c)) + ) + + mass = ( + (4.0 * np.pi) + / 3.0 + * eps_c + * radius**3 + * (1.0 - (3.0 / 5.0) * depsdh_c * (pseudo_enthalpy0 - pseudo_enthalpy) / eps_c) + ) return mass, radius @@ -59,25 +66,24 @@ def __tov_eqns(self, h, y): r = y[1] H = y[2] B = y[3] - eps = self.eos.energy_density_from_pseudo_enthalpy(h, interp_type='CubicSpline') - p = self.eos.pressure_from_pseudo_enthalpy(h, interp_type='CubicSpline') - depsdp = self.eos.dedp(p, interp_type='CubicSpline') + eps = self.eos.energy_density_from_pseudo_enthalpy(h, interp_type="CubicSpline") + p = self.eos.pressure_from_pseudo_enthalpy(h, interp_type="CubicSpline") + depsdp = self.eos.dedp(p, interp_type="CubicSpline") - dmdh = (- (4. * np.pi * eps * r ** 3 * (r - 2. * m)) / - (m + 4. * np.pi * r ** 3 * p)) + dmdh = -(4.0 * np.pi * eps * r**3 * (r - 2.0 * m)) / (m + 4.0 * np.pi * r**3 * p) - drdh = -(r * (r - 2. * m)) / (m + 4. * np.pi * r ** 3 * p) + drdh = -(r * (r - 2.0 * m)) / (m + 4.0 * np.pi * r**3 * p) dHdh = B * drdh # taken from Damour & Nagar - e_lam = (1. - 2. * m / r) ** (-1) + e_lam = (1.0 - 2.0 * m / r) ** (-1) - C1 = 2. / r + e_lam * (2. * m / r ** 2. + 4. * np.pi * r * (p - eps)) - C0 = (e_lam * (- 6. / r ** 2. + 4. * np.pi * (eps + p) * - depsdp + 4. * np.pi * (5. * eps + 9. * p)) - - (2. * (m + 4. * np.pi * r ** 3. * p) / - (r ** 2. - 2. * m * r)) ** 2.) + C1 = 2.0 / r + e_lam * (2.0 * m / r**2.0 + 4.0 * np.pi * r * (p - eps)) + C0 = ( + e_lam * (-6.0 / r**2.0 + 4.0 * np.pi * (eps + p) * depsdp + 4.0 * np.pi * (5.0 * eps + 9.0 * p)) + - (2.0 * (m + 4.0 * np.pi * r**3.0 * p) / (r**2.0 - 2.0 * m * r)) ** 2.0 + ) dBdh = -(C1 * B + C0 * H) * drdh @@ -96,13 +102,15 @@ def __calc_k2(self, R, Beta, H, C): y = (R * Beta) / H - num = ((8. / 5.) * (1. - 2. * C) ** 2 * - C ** 5 * (2. * C * (y - 1.) - y + 2.)) - denom = (2. * C * (4. * (y + 1.) * C ** 4 + (6. * y - 4.) * C ** 3 + - (26. - 22. * y) * C ** 2 + 3. * (5. * y - 8.) * - C - 3. * y + 6.) - 3. * (1. - 2 * C) ** 2 * - (2. * C * (y - 1.) - y + 2.) * - np.log(1. / (1. - 2. * C))) + num = (8.0 / 5.0) * (1.0 - 2.0 * C) ** 2 * C**5 * (2.0 * C * (y - 1.0) - y + 2.0) + denom = 2.0 * C * ( + 4.0 * (y + 1.0) * C**4 + + (6.0 * y - 4.0) * C**3 + + (26.0 - 22.0 * y) * C**2 + + 3.0 * (5.0 * y - 8.0) * C + - 3.0 * y + + 6.0 + ) - 3.0 * (1.0 - 2 * C) ** 2 * (2.0 * C * (y - 1.0) - y + 2.0) * np.log(1.0 / (1.0 - 2.0 * C)) return num / denom @@ -116,8 +124,7 @@ def integrate_TOV(self): rel_err = 1e-4 abs_err = 0.0 - result = solve_ivp(self.__tov_eqns, (self.pseudo_enthalpy, 1e-16), self.y, rtol=rel_err, - atol=abs_err) + result = solve_ivp(self.__tov_eqns, (self.pseudo_enthalpy, 1e-16), self.y, rtol=rel_err, atol=abs_err) m_fin = result.y[0, -1] r_fin = result.y[1, -1] H_fin = result.y[2, -1] diff --git a/bilby/gw/likelihood/__init__.py b/bilby/gw/likelihood/__init__.py index d752f77a9..7e2f0e78e 100644 --- a/bilby/gw/likelihood/__init__.py +++ b/bilby/gw/likelihood/__init__.py @@ -1,15 +1,24 @@ +from ..source import lal_binary_black_hole +from ..waveform_generator import WaveformGenerator from .base import GravitationalWaveTransient from .basic import BasicGravitationalWaveTransient -from .roq import BilbyROQParamsRangeError, ROQGravitationalWaveTransient from .multiband import MBGravitationalWaveTransient from .relative import RelativeBinningGravitationalWaveTransient +from .roq import BilbyROQParamsRangeError, ROQGravitationalWaveTransient -from ..source import lal_binary_black_hole -from ..waveform_generator import WaveformGenerator +__all__ = [ + BasicGravitationalWaveTransient, + GravitationalWaveTransient, + MBGravitationalWaveTransient, + RelativeBinningGravitationalWaveTransient, + BilbyROQParamsRangeError, + ROQGravitationalWaveTransient, + "get_binary_black_hole_likelihood", +] def get_binary_black_hole_likelihood(interferometers): - """ A wrapper to quickly set up a likelihood for BBH parameter estimation + """A wrapper to quickly set up a likelihood for BBH parameter estimation Parameters ========== @@ -27,8 +36,6 @@ def get_binary_black_hole_likelihood(interferometers): duration=interferometers.duration, sampling_frequency=interferometers.sampling_frequency, frequency_domain_source_model=lal_binary_black_hole, - waveform_arguments={'waveform_approximant': 'IMRPhenomPv2', - 'reference_frequency': 50}) + waveform_arguments={"waveform_approximant": "IMRPhenomPv2", "reference_frequency": 50}, + ) return GravitationalWaveTransient(interferometers, waveform_generator) - - diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index c66e05a34..5a3700302 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -1,21 +1,20 @@ - -import os import copy +import os import attr import numpy as np from scipy.special import logsumexp from ...core.likelihood import Likelihood, _fallback_to_parameters -from ...core.utils import logger, BoundedRectBivariateSpline, create_time_series -from ...core.prior import Interped, Prior, Uniform, DeltaFunction -from ..detector import InterferometerList, get_empty_interferometer, calibration +from ...core.prior import DeltaFunction, Interped, Prior, Uniform +from ...core.utils import BoundedRectBivariateSpline, create_time_series, logger +from ..detector import InterferometerList, calibration, get_empty_interferometer from ..prior import BBHPriorDict, Cosmological -from ..utils import noise_weighted_inner_product, zenith_azimuth_to_ra_dec, ln_i0 +from ..utils import ln_i0, noise_weighted_inner_product, zenith_azimuth_to_ra_dec class GravitationalWaveTransient(Likelihood): - """ A gravitational-wave transient likelihood object + """A gravitational-wave transient likelihood object This is the usual likelihood object to use for transient gravitational wave parameter estimation. It computes the log-likelihood in the frequency @@ -138,20 +137,29 @@ def snrs_as_sample(self) -> dict: The dictionary of SNRs labelled accordingly """ return { - "matched_filter_snr" : self.complex_matched_filter_snr, - "optimal_snr" : self.optimal_snr_squared.real ** 0.5 + "matched_filter_snr": self.complex_matched_filter_snr, + "optimal_snr": self.optimal_snr_squared.real**0.5, } def __init__( - self, interferometers, waveform_generator, time_marginalization=False, - distance_marginalization=False, phase_marginalization=False, calibration_marginalization=False, priors=None, - distance_marginalization_lookup_table=None, calibration_lookup_table=None, - number_of_response_curves=1000, starting_index=0, jitter_time=True, reference_frame="sky", - time_reference="geocenter" + self, + interferometers, + waveform_generator, + time_marginalization=False, + distance_marginalization=False, + phase_marginalization=False, + calibration_marginalization=False, + priors=None, + distance_marginalization_lookup_table=None, + calibration_lookup_table=None, + number_of_response_curves=1000, + starting_index=0, + jitter_time=True, + reference_frame="sky", + time_reference="geocenter", ): - self.waveform_generator = waveform_generator - super(GravitationalWaveTransient, self).__init__() + super().__init__() self.interferometers = InterferometerList(interferometers) self.time_marginalization = time_marginalization self.distance_marginalization = distance_marginalization @@ -174,80 +182,79 @@ def __init__( self.reference_ifo = None if self.time_marginalization: - self._check_marginalized_prior_is_set(key='geocent_time') + self._check_marginalized_prior_is_set(key="geocent_time") self._setup_time_marginalization() - priors['geocent_time'] = float(self.interferometers.start_time) + priors["geocent_time"] = float(self.interferometers.start_time) if self.jitter_time: - priors['time_jitter'] = Uniform( - minimum=- self._delta_tc / 2, + priors["time_jitter"] = Uniform( + minimum=-self._delta_tc / 2, maximum=self._delta_tc / 2, - boundary='periodic', + boundary="periodic", name="time_jitter", - latex_label="$t_j$" + latex_label="$t_j$", ) - self._marginalized_parameters.append('geocent_time') + self._marginalized_parameters.append("geocent_time") elif self.jitter_time: - logger.debug( - "Time jittering requested with non-time-marginalised " - "likelihood, ignoring.") + logger.debug("Time jittering requested with non-time-marginalised likelihood, ignoring.") self.jitter_time = False if self.phase_marginalization: - self._check_marginalized_prior_is_set(key='phase') - priors['phase'] = float(0) - self._marginalized_parameters.append('phase') + self._check_marginalized_prior_is_set(key="phase") + priors["phase"] = float(0) + self._marginalized_parameters.append("phase") if self.distance_marginalization: self._lookup_table_filename = None - self._check_marginalized_prior_is_set(key='luminosity_distance') + self._check_marginalized_prior_is_set(key="luminosity_distance") self._distance_array = np.linspace( - self.priors['luminosity_distance'].minimum, - self.priors['luminosity_distance'].maximum, int(1e4)) + self.priors["luminosity_distance"].minimum, self.priors["luminosity_distance"].maximum, int(1e4) + ) self.distance_prior_array = np.array( - [self.priors['luminosity_distance'].prob(distance) - for distance in self._distance_array]) - self._ref_dist = self.priors['luminosity_distance'].rescale(0.5) - self._setup_distance_marginalization( - distance_marginalization_lookup_table) - for key in ['redshift', 'comoving_distance']: + [self.priors["luminosity_distance"].prob(distance) for distance in self._distance_array] + ) + self._ref_dist = self.priors["luminosity_distance"].rescale(0.5) + self._setup_distance_marginalization(distance_marginalization_lookup_table) + for key in ["redshift", "comoving_distance"]: if key in priors: del priors[key] - priors['luminosity_distance'] = float(self._ref_dist) - self._marginalized_parameters.append('luminosity_distance') + priors["luminosity_distance"] = float(self._ref_dist) + self._marginalized_parameters.append("luminosity_distance") if self.calibration_marginalization: self.number_of_response_curves = number_of_response_curves self.starting_index = starting_index self._setup_calibration_marginalization(calibration_lookup_table, priors) - self._marginalized_parameters.append('recalib_index') + self._marginalized_parameters.append("recalib_index") def __repr__(self): - return self.__class__.__name__ + '(interferometers={},\n\twaveform_generator={},\n\ttime_marginalization={}, ' \ - 'distance_marginalization={}, phase_marginalization={}, ' \ - 'calibration_marginalization={}, priors={})' \ - .format(self.interferometers, self.waveform_generator, self.time_marginalization, - self.distance_marginalization, self.phase_marginalization, self.calibration_marginalization, - self.priors) + return ( + self.__class__.__name__ + f"(interferometers={self.interferometers}," + f"\n\twaveform_generator={self.waveform_generator}," + f"\n\ttime_marginalization={self.time_marginalization}, " + f"distance_marginalization={self.distance_marginalization}, " + f"phase_marginalization={self.phase_marginalization}, " + f"calibration_marginalization={self.calibration_marginalization}, " + f"priors={self.priors})" + ) def _check_set_duration_and_sampling_frequency_of_waveform_generator(self): - """ Check the waveform_generator has the same duration and + """Check the waveform_generator has the same duration and sampling_frequency as the interferometers. If they are unset, then set them, if they differ, raise an error """ - attributes = ['duration', 'sampling_frequency', 'start_time'] + attributes = ["duration", "sampling_frequency", "start_time"] for attribute in attributes: wfg_attr = getattr(self.waveform_generator, attribute) ifo_attr = getattr(self.interferometers, attribute) if wfg_attr is None: - logger.debug( - "The waveform_generator {} is None. Setting from the " - "provided interferometers.".format(attribute)) + logger.debug(f"The waveform_generator {attribute} is None. Setting from the provided interferometers.") elif wfg_attr != ifo_attr: logger.debug( - "The waveform_generator {} is not equal to that of the " + f"The waveform_generator {attribute} is not equal to that of the " "provided interferometers. Overwriting the " - "waveform_generator.".format(attribute)) + "waveform_generator." + ) setattr(self.waveform_generator, attribute, ifo_attr) def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True, parameters=None): @@ -280,8 +287,8 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr ) _mask = interferometer.frequency_mask - if 'recalib_index' in parameters: - signal[_mask] *= self.calibration_draws[interferometer.name][int(parameters['recalib_index'])] + if "recalib_index" in parameters: + signal[_mask] *= self.calibration_draws[interferometer.name][int(parameters["recalib_index"])] d_inner_h = interferometer.inner_product(signal=signal) optimal_snr_squared = interferometer.optimal_snr_squared(signal=signal) @@ -296,23 +303,22 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr d_inner_h_array = None optimal_snr_squared_array = None elif self.time_marginalization and self.calibration_marginalization: - d_inner_h_integrand = np.tile( - interferometer.frequency_domain_strain.conjugate() * signal / - interferometer.power_spectral_density_array, (self.number_of_response_curves, 1)).T + interferometer.frequency_domain_strain.conjugate() + * signal + / interferometer.power_spectral_density_array, + (self.number_of_response_curves, 1), + ).T d_inner_h_integrand[_mask] *= self.calibration_draws[interferometer.name].T - d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft( - d_inner_h_integrand[0:-1], axis=0 - ).T + d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft(d_inner_h_integrand[0:-1], axis=0).T optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array + normalization * np.abs(signal) ** 2 / interferometer.power_spectral_density_array ) optimal_snr_squared_array = np.dot( - optimal_snr_squared_integrand[_mask], - self.calibration_abs_draws[interferometer.name].T + optimal_snr_squared_integrand[_mask], self.calibration_abs_draws[interferometer.name].T ) elif self.time_marginalization and not self.calibration_marginalization: @@ -322,20 +328,20 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr / interferometer.power_spectral_density_array[0:-1] ) - elif self.calibration_marginalization and ('recalib_index' not in parameters): + elif self.calibration_marginalization and ("recalib_index" not in parameters): d_inner_h_integrand = ( - normalization * - interferometer.frequency_domain_strain.conjugate() * signal + normalization + * interferometer.frequency_domain_strain.conjugate() + * signal / interferometer.power_spectral_density_array ) d_inner_h_array = np.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T) optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array + normalization * np.abs(signal) ** 2 / interferometer.power_spectral_density_array ) optimal_snr_squared_array = np.dot( - optimal_snr_squared_integrand[_mask], - self.calibration_abs_draws[interferometer.name].T + optimal_snr_squared_integrand[_mask], self.calibration_abs_draws[interferometer.name].T ) return self._CalculatedSNRs( @@ -348,27 +354,23 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr def _check_marginalized_prior_is_set(self, key): if key in self.priors and self.priors[key].is_fixed: - raise ValueError( - "Cannot use marginalized likelihood for {}: prior is fixed".format(key) - ) - if key not in self.priors or not isinstance( - self.priors[key], Prior): - logger.warning( - 'Prior not provided for {}, using the BBH default.'.format(key)) - if key == 'geocent_time': + raise ValueError(f"Cannot use marginalized likelihood for {key}: prior is fixed") + if key not in self.priors or not isinstance(self.priors[key], Prior): + logger.warning(f"Prior not provided for {key}, using the BBH default.") + if key == "geocent_time": self.priors[key] = Uniform( - self.interferometers.start_time, - self.interferometers.start_time + self.interferometers.duration) - elif key == 'luminosity_distance': - for key in ['redshift', 'comoving_distance']: + self.interferometers.start_time, self.interferometers.start_time + self.interferometers.duration + ) + elif key == "luminosity_distance": + for key in ["redshift", "comoving_distance"]: if key in self.priors: if not isinstance(self.priors[key], Cosmological): raise TypeError( - "To marginalize over {}, the prior must be specified as a " - "subclass of bilby.gw.prior.Cosmological.".format(key) + f"To marginalize over {key}, the prior must be specified as a " + "subclass of bilby.gw.prior.Cosmological." ) - self.priors['luminosity_distance'] = self.priors[key].get_corresponding_prior( - 'luminosity_distance' + self.priors["luminosity_distance"] = self.priors[key].get_corresponding_prior( + "luminosity_distance" ) del self.priors[key] else: @@ -382,8 +384,7 @@ def priors(self): def priors(self, priors): if priors is not None: self._prior = priors.copy() - elif any([self.time_marginalization, self.phase_marginalization, - self.distance_marginalization]): + elif any([self.time_marginalization, self.phase_marginalization, self.distance_marginalization]): raise ValueError("You can't use a marginalized likelihood without specifying a priors") else: self._prior = None @@ -392,11 +393,15 @@ def _calculate_noise_log_likelihood(self): log_l = 0 for interferometer in self.interferometers: mask = interferometer.frequency_mask - log_l -= noise_weighted_inner_product( - interferometer.frequency_domain_strain[mask], - interferometer.frequency_domain_strain[mask], - interferometer.power_spectral_density_array[mask], - self.waveform_generator.duration) / 2 + log_l -= ( + noise_weighted_inner_product( + interferometer.frequency_domain_strain[mask], + interferometer.frequency_domain_strain[mask], + interferometer.power_spectral_density_array[mask], + self.waveform_generator.duration, + ) + / 2 + ) return float(np.real(log_l)) def noise_log_likelihood(self): @@ -410,13 +415,12 @@ def log_likelihood_ratio(self, parameters=None): parameters = copy.deepcopy(parameters) else: parameters = _fallback_to_parameters(self, parameters) - waveform_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) + waveform_polarizations = self.waveform_generator.frequency_domain_strain(parameters) if waveform_polarizations is None: return np.nan_to_num(-np.inf) if self.time_marginalization and self.jitter_time: - parameters['geocent_time'] += parameters['time_jitter'] + parameters["geocent_time"] += parameters["time_jitter"] parameters.update(self.get_sky_frame_parameters(parameters)) @@ -434,7 +438,7 @@ def log_likelihood_ratio(self, parameters=None): log_l = self.compute_log_likelihood_from_snrs(total_snrs, parameters=parameters) if self.time_marginalization and self.jitter_time: - parameters['geocent_time'] -= parameters['time_jitter'] + parameters["geocent_time"] -= parameters["time_jitter"] return float(log_l.real) @@ -443,8 +447,8 @@ def compute_log_likelihood_from_snrs(self, total_snrs, parameters=None): if self.calibration_marginalization: log_l = self.calibration_marginalized_likelihood( - d_inner_h_calibration_array=total_snrs.d_inner_h_array, - h_inner_h=total_snrs.optimal_snr_squared_array) + d_inner_h_calibration_array=total_snrs.d_inner_h_array, h_inner_h=total_snrs.optimal_snr_squared_array + ) elif self.time_marginalization: log_l = self.time_marginalized_likelihood( @@ -455,13 +459,15 @@ def compute_log_likelihood_from_snrs(self, total_snrs, parameters=None): elif self.distance_marginalization: log_l = self.distance_marginalized_likelihood( - d_inner_h=total_snrs.d_inner_h, h_inner_h=total_snrs.optimal_snr_squared, + d_inner_h=total_snrs.d_inner_h, + h_inner_h=total_snrs.optimal_snr_squared, parameters=parameters, ) elif self.phase_marginalization: log_l = self.phase_marginalized_likelihood( - d_inner_h=total_snrs.d_inner_h, h_inner_h=total_snrs.optimal_snr_squared) + d_inner_h=total_snrs.d_inner_h, h_inner_h=total_snrs.optimal_snr_squared + ) else: log_l = np.real(total_snrs.d_inner_h) - total_snrs.optimal_snr_squared / 2 @@ -470,11 +476,10 @@ def compute_log_likelihood_from_snrs(self, total_snrs, parameters=None): def compute_per_detector_log_likelihood(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) - waveform_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) + waveform_polarizations = self.waveform_generator.frequency_domain_strain(parameters) if self.time_marginalization and self.jitter_time: - parameters['geocent_time'] += parameters['time_jitter'] + parameters["geocent_time"] += parameters["time_jitter"] parameters.update(self.get_sky_frame_parameters(parameters)) @@ -485,11 +490,12 @@ def compute_per_detector_log_likelihood(self, parameters=None): parameters=parameters, ) - parameters['{}_log_likelihood'.format(interferometer.name)] = \ - self.compute_log_likelihood_from_snrs(per_detector_snr, parameters=parameters) + parameters[f"{interferometer.name}_log_likelihood"] = self.compute_log_likelihood_from_snrs( + per_detector_snr, parameters=parameters + ) if self.time_marginalization and self.jitter_time: - parameters['geocent_time'] -= parameters['time_jitter'] + parameters["geocent_time"] -= parameters["time_jitter"] return parameters.copy() @@ -512,32 +518,33 @@ def generate_posterior_sample_from_marginalized_likelihood(self, parameters=None """ parameters = _fallback_to_parameters(self, parameters) if len(self._marginalized_parameters) > 0: - signal_polarizations = copy.deepcopy( - self.waveform_generator.frequency_domain_strain( - parameters)) + signal_polarizations = copy.deepcopy(self.waveform_generator.frequency_domain_strain(parameters)) else: return parameters if self.calibration_marginalization: new_calibration = self.generate_calibration_sample_from_marginalized_likelihood( - signal_polarizations=signal_polarizations, parameters=parameters) - parameters['recalib_index'] = new_calibration + signal_polarizations=signal_polarizations, parameters=parameters + ) + parameters["recalib_index"] = new_calibration if self.time_marginalization: new_time = self.generate_time_sample_from_marginalized_likelihood( - signal_polarizations=signal_polarizations, parameters=parameters) - parameters['geocent_time'] = new_time + signal_polarizations=signal_polarizations, parameters=parameters + ) + parameters["geocent_time"] = new_time if self.distance_marginalization: new_distance = self.generate_distance_sample_from_marginalized_likelihood( - signal_polarizations=signal_polarizations, parameters=parameters) - parameters['luminosity_distance'] = new_distance + signal_polarizations=signal_polarizations, parameters=parameters + ) + parameters["luminosity_distance"] = new_distance if self.phase_marginalization: new_phase = self.generate_phase_sample_from_marginalized_likelihood( - signal_polarizations=signal_polarizations, parameters=parameters) - parameters['phase'] = new_phase + signal_polarizations=signal_polarizations, parameters=parameters + ) + parameters["phase"] = new_phase return parameters.copy() - def generate_calibration_sample_from_marginalized_likelihood( - self, signal_polarizations=None, parameters=None): + def generate_calibration_sample_from_marginalized_likelihood(self, signal_polarizations=None, parameters=None): """ Generate a single sample from the posterior distribution for the set of calibration response curves when explicitly marginalizing over the calibration uncertainty. @@ -555,12 +562,11 @@ def generate_calibration_sample_from_marginalized_likelihood( from ...core.utils import random parameters = _fallback_to_parameters(self, parameters) - if 'recalib_index' in parameters: - parameters.pop('recalib_index') + if "recalib_index" in parameters: + parameters.pop("recalib_index") parameters.update(self.get_sky_frame_parameters(parameters)) if signal_polarizations is None: - signal_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) + signal_polarizations = self.waveform_generator.frequency_domain_strain(parameters) log_like = self.get_calibration_log_likelihoods( signal_polarizations=signal_polarizations, parameters=parameters @@ -573,8 +579,7 @@ def generate_calibration_sample_from_marginalized_likelihood( return new_calibration - def generate_time_sample_from_marginalized_likelihood( - self, signal_polarizations=None, parameters=None): + def generate_time_sample_from_marginalized_likelihood(self, signal_polarizations=None, parameters=None): """ Generate a single sample from the posterior distribution for coalescence time when using a likelihood which explicitly marginalises over time. @@ -596,15 +601,15 @@ def generate_time_sample_from_marginalized_likelihood( parameters = _fallback_to_parameters(self, parameters) parameters.update(self.get_sky_frame_parameters(parameters)) if self.jitter_time: - parameters['geocent_time'] += parameters['time_jitter'] + parameters["geocent_time"] += parameters["time_jitter"] if signal_polarizations is None: - signal_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) + signal_polarizations = self.waveform_generator.frequency_domain_strain(parameters) times = create_time_series( sampling_frequency=16384, - starting_time=parameters['geocent_time'] - self.waveform_generator.start_time, - duration=self.waveform_generator.duration) + starting_time=parameters["geocent_time"] - self.waveform_generator.start_time, + duration=self.waveform_generator.duration, + ) times = times % self.waveform_generator.duration times += self.waveform_generator.start_time @@ -633,17 +638,16 @@ def generate_time_sample_from_marginalized_likelihood( h_inner_h += ifo.optimal_snr_squared(signal=signal).real if self.distance_marginalization: - time_log_like = self.distance_marginalized_likelihood( - d_inner_h, h_inner_h, parameters=parameters) + time_log_like = self.distance_marginalized_likelihood(d_inner_h, h_inner_h, parameters=parameters) elif self.phase_marginalization: time_log_like = ln_i0(abs(d_inner_h)) - h_inner_h.real / 2 else: - time_log_like = (d_inner_h.real - h_inner_h.real / 2) + time_log_like = d_inner_h.real - h_inner_h.real / 2 - time_prior_array = self.priors['geocent_time'].prob(times) + time_prior_array = self.priors["geocent_time"].prob(times) time_post = np.exp(time_log_like - max(time_log_like)) * time_prior_array - keep = (time_post > max(time_post) / 1000) + keep = time_post > max(time_post) / 1000 if sum(keep) < 3: keep[1:-1] = keep[1:-1] | keep[2:] | keep[:-2] time_post = time_post[keep] @@ -652,8 +656,7 @@ def generate_time_sample_from_marginalized_likelihood( new_time = Interped(times, time_post).sample() return new_time - def generate_distance_sample_from_marginalized_likelihood( - self, signal_polarizations=None, parameters=None): + def generate_distance_sample_from_marginalized_likelihood(self, signal_polarizations=None, parameters=None): """ Generate a single sample from the posterior distribution for luminosity distance when using a likelihood which explicitly marginalises over @@ -676,31 +679,22 @@ def generate_distance_sample_from_marginalized_likelihood( parameters = _fallback_to_parameters(self, parameters) parameters.update(self.get_sky_frame_parameters(parameters)) if signal_polarizations is None: - signal_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) + signal_polarizations = self.waveform_generator.frequency_domain_strain(parameters) - d_inner_h, h_inner_h = self._calculate_inner_products( - signal_polarizations, parameters=parameters - ) + d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations, parameters=parameters) - d_inner_h_dist = ( - d_inner_h * parameters['luminosity_distance'] / self._distance_array - ) + d_inner_h_dist = d_inner_h * parameters["luminosity_distance"] / self._distance_array - h_inner_h_dist = ( - h_inner_h * parameters['luminosity_distance']**2 / self._distance_array**2 - ) + h_inner_h_dist = h_inner_h * parameters["luminosity_distance"] ** 2 / self._distance_array**2 if self.phase_marginalization: distance_log_like = ln_i0(abs(d_inner_h_dist)) - h_inner_h_dist.real / 2 else: - distance_log_like = (d_inner_h_dist.real - h_inner_h_dist.real / 2) + distance_log_like = d_inner_h_dist.real - h_inner_h_dist.real / 2 - distance_post = (np.exp(distance_log_like - max(distance_log_like)) * - self.distance_prior_array) + distance_post = np.exp(distance_log_like - max(distance_log_like)) * self.distance_prior_array - new_distance = Interped( - self._distance_array, distance_post).sample() + new_distance = Interped(self._distance_array, distance_post).sample() self._rescale_signal(signal_polarizations, new_distance) return new_distance @@ -709,8 +703,7 @@ def _calculate_inner_products(self, signal_polarizations, parameters): d_inner_h = 0 h_inner_h = 0 for interferometer in self.interferometers: - per_detector_snr = self.calculate_snrs( - signal_polarizations, interferometer, parameters=parameters) + per_detector_snr = self.calculate_snrs(signal_polarizations, interferometer, parameters=parameters) d_inner_h += per_detector_snr.d_inner_h h_inner_h += per_detector_snr.optimal_snr_squared @@ -734,8 +727,7 @@ def _compute_full_waveform(self, signal_polarizations, interferometer, parameter parameters = _fallback_to_parameters(self, parameters) return interferometer.get_detector_response(signal_polarizations, parameters) - def generate_phase_sample_from_marginalized_likelihood( - self, signal_polarizations=None, parameters=None): + def generate_phase_sample_from_marginalized_likelihood(self, signal_polarizations=None, parameters=None): r""" Generate a single sample from the posterior distribution for phase when using a likelihood which explicitly marginalises over phase. @@ -759,11 +751,8 @@ def generate_phase_sample_from_marginalized_likelihood( parameters = _fallback_to_parameters(self, parameters) parameters.update(self.get_sky_frame_parameters(parameters)) if signal_polarizations is None: - signal_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) - d_inner_h, h_inner_h = self._calculate_inner_products( - signal_polarizations, parameters=parameters - ) + signal_polarizations = self.waveform_generator.frequency_domain_strain(parameters) + d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations, parameters=parameters) phases = np.linspace(0, 2 * np.pi, 101) phasor = np.exp(-2j * phases) @@ -774,15 +763,13 @@ def generate_phase_sample_from_marginalized_likelihood( def distance_marginalized_likelihood(self, d_inner_h, h_inner_h, parameters=None): parameters = _fallback_to_parameters(self, parameters) - d_inner_h_ref, h_inner_h_ref = self._setup_rho( - d_inner_h, h_inner_h, parameters=parameters) + d_inner_h_ref, h_inner_h_ref = self._setup_rho(d_inner_h, h_inner_h, parameters=parameters) if self.phase_marginalization: d_inner_h_ref = np.abs(d_inner_h_ref) else: d_inner_h_ref = np.real(d_inner_h_ref) - return self._interp_dist_margd_loglikelihood( - d_inner_h_ref, h_inner_h_ref, grid=False) + return self._interp_dist_margd_loglikelihood(d_inner_h_ref, h_inner_h_ref, grid=False) def phase_marginalized_likelihood(self, d_inner_h, h_inner_h): d_inner_h = ln_i0(abs(d_inner_h)) @@ -796,12 +783,12 @@ def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h, parameters parameters = _fallback_to_parameters(self, parameters) times = self._times if self.jitter_time: - times = self._times + parameters['time_jitter'] + times = self._times + parameters["time_jitter"] - _time_prior = self.priors['geocent_time'] + _time_prior = self.priors["geocent_time"] time_mask = (times >= _time_prior.minimum) & (times <= _time_prior.maximum) times = times[time_mask] - time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc + time_prior_array = self.priors["geocent_time"].prob(times) * self._delta_tc if self.calibration_marginalization: d_inner_h_tc_array = d_inner_h_tc_array[:, time_mask] else: @@ -809,11 +796,10 @@ def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h, parameters if self.distance_marginalization: log_l_tc_array = self.distance_marginalized_likelihood( - d_inner_h=d_inner_h_tc_array, h_inner_h=h_inner_h, parameters=parameters) + d_inner_h=d_inner_h_tc_array, h_inner_h=h_inner_h, parameters=parameters + ) elif self.phase_marginalization: - log_l_tc_array = self.phase_marginalized_likelihood( - d_inner_h=d_inner_h_tc_array, - h_inner_h=h_inner_h) + log_l_tc_array = self.phase_marginalized_likelihood(d_inner_h=d_inner_h_tc_array, h_inner_h=h_inner_h) elif self.calibration_marginalization: log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h[:, np.newaxis] / 2 else: @@ -824,8 +810,7 @@ def get_calibration_log_likelihoods(self, signal_polarizations=None, parameters= parameters = _fallback_to_parameters(self, parameters) parameters.update(self.get_sky_frame_parameters(parameters)) if signal_polarizations is None: - signal_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) + signal_polarizations = self.waveform_generator.frequency_domain_strain(parameters) total_snrs = self._CalculatedSNRs() @@ -851,11 +836,10 @@ def get_calibration_log_likelihoods(self, signal_polarizations=None, parameters= ) elif self.phase_marginalization: log_l_cal_array = self.phase_marginalized_likelihood( - d_inner_h=total_snrs.d_inner_h_array, - h_inner_h=total_snrs.optimal_snr_squared_array) + d_inner_h=total_snrs.d_inner_h_array, h_inner_h=total_snrs.optimal_snr_squared_array + ) else: - log_l_cal_array = \ - np.real(total_snrs.d_inner_h_array - total_snrs.optimal_snr_squared_array / 2) + log_l_cal_array = np.real(total_snrs.d_inner_h_array - total_snrs.optimal_snr_squared_array / 2) return log_l_cal_array @@ -869,11 +853,12 @@ def calibration_marginalized_likelihood(self, d_inner_h_calibration_array, h_inn ) elif self.distance_marginalization: log_l_cal_array = self.distance_marginalized_likelihood( - d_inner_h=d_inner_h_calibration_array, h_inner_h=h_inner_h, parameters=parameters) + d_inner_h=d_inner_h_calibration_array, h_inner_h=h_inner_h, parameters=parameters + ) elif self.phase_marginalization: log_l_cal_array = self.phase_marginalized_likelihood( - d_inner_h=d_inner_h_calibration_array, - h_inner_h=h_inner_h) + d_inner_h=d_inner_h_calibration_array, h_inner_h=h_inner_h + ) else: log_l_cal_array = np.real(d_inner_h_calibration_array - h_inner_h / 2) @@ -881,11 +866,10 @@ def calibration_marginalized_likelihood(self, d_inner_h_calibration_array, h_inn def _setup_rho(self, d_inner_h, optimal_snr_squared, parameters=None): parameters = _fallback_to_parameters(self, parameters) - optimal_snr_squared_ref = (optimal_snr_squared.real * - parameters['luminosity_distance'] ** 2 / - self._ref_dist ** 2.) - d_inner_h_ref = (d_inner_h * parameters['luminosity_distance'] / - self._ref_dist) + optimal_snr_squared_ref = ( + optimal_snr_squared.real * parameters["luminosity_distance"] ** 2 / self._ref_dist**2.0 + ) + d_inner_h_ref = d_inner_h * parameters["luminosity_distance"] / self._ref_dist return d_inner_h_ref, optimal_snr_squared_ref def log_likelihood(self, parameters=None): @@ -897,55 +881,53 @@ def _delta_distance(self): @property def _dist_multiplier(self): - ''' Maximum value of ref_dist/dist_array ''' + """Maximum value of ref_dist/dist_array""" return self._ref_dist / self._distance_array[0] @property def _optimal_snr_squared_ref_array(self): - """ Optimal filter snr at fiducial distance of ref_dist Mpc """ + """Optimal filter snr at fiducial distance of ref_dist Mpc""" return np.logspace(-5, 10, self._dist_margd_loglikelihood_array.shape[0]) @property def _d_inner_h_ref_array(self): - """ Matched filter snr at fiducial distance of ref_dist Mpc """ + """Matched filter snr at fiducial distance of ref_dist Mpc""" if self.phase_marginalization: return np.logspace(-5, 10, self._dist_margd_loglikelihood_array.shape[1]) else: n_negative = self._dist_margd_loglikelihood_array.shape[1] // 2 n_positive = self._dist_margd_loglikelihood_array.shape[1] - n_negative - return np.hstack(( - -np.logspace(3, -3, n_negative), np.logspace(-3, 10, n_positive) - )) + return np.hstack((-np.logspace(3, -3, n_negative), np.logspace(-3, 10, n_positive))) def _setup_distance_marginalization(self, lookup_table=None): if isinstance(lookup_table, str) or lookup_table is None: self.cached_lookup_table_filename = lookup_table - lookup_table = self.load_lookup_table( - self.cached_lookup_table_filename) + lookup_table = self.load_lookup_table(self.cached_lookup_table_filename) if isinstance(lookup_table, dict): if self._test_cached_lookup_table(lookup_table): - self._dist_margd_loglikelihood_array = lookup_table[ - 'lookup_table'] + self._dist_margd_loglikelihood_array = lookup_table["lookup_table"] else: self._create_lookup_table() else: self._create_lookup_table() self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline( - self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array, - self._dist_margd_loglikelihood_array.T, fill_value=-np.inf) + self._d_inner_h_ref_array, + self._optimal_snr_squared_ref_array, + self._dist_margd_loglikelihood_array.T, + fill_value=-np.inf, + ) @property def cached_lookup_table_filename(self): if self._lookup_table_filename is None: - self._lookup_table_filename = ( - '.distance_marginalization_lookup.npz') + self._lookup_table_filename = ".distance_marginalization_lookup.npz" return self._lookup_table_filename @cached_lookup_table_filename.setter def cached_lookup_table_filename(self, filename): if isinstance(filename, str): - if filename[-4:] != '.npz': - filename += '.npz' + if filename[-4:] != ".npz": + filename += ".npz" self._lookup_table_filename = filename def load_lookup_table(self, filename): @@ -958,63 +940,59 @@ def load_lookup_table(self, filename): return None match, failure = self._test_cached_lookup_table(loaded_file) if match: - logger.info('Loaded distance marginalisation lookup table from ' - '{}.'.format(filename)) + logger.info(f"Loaded distance marginalisation lookup table from {filename}.") return loaded_file else: - logger.info('Loaded distance marginalisation lookup table does ' - 'not match for {}.'.format(failure)) + logger.info(f"Loaded distance marginalisation lookup table does not match for {failure}.") elif isinstance(filename, str): - logger.info('Distance marginalisation file {} does not ' - 'exist'.format(filename)) + logger.info(f"Distance marginalisation file {filename} does not exist") return None def cache_lookup_table(self): - np.savez(self.cached_lookup_table_filename, - distance_array=self._distance_array, - prior_array=self.distance_prior_array, - lookup_table=self._dist_margd_loglikelihood_array, - reference_distance=self._ref_dist, - phase_marginalization=self.phase_marginalization) + np.savez( + self.cached_lookup_table_filename, + distance_array=self._distance_array, + prior_array=self.distance_prior_array, + lookup_table=self._dist_margd_loglikelihood_array, + reference_distance=self._ref_dist, + phase_marginalization=self.phase_marginalization, + ) def _test_cached_lookup_table(self, loaded_file): pairs = dict( distance_array=self._distance_array, prior_array=self.distance_prior_array, reference_distance=self._ref_dist, - phase_marginalization=self.phase_marginalization) + phase_marginalization=self.phase_marginalization, + ) for key in pairs: if key not in loaded_file: return False, key - elif not np.allclose(np.atleast_1d(loaded_file[key]), - np.atleast_1d(pairs[key]), - rtol=1e-15): + elif not np.allclose(np.atleast_1d(loaded_file[key]), np.atleast_1d(pairs[key]), rtol=1e-15): return False, key return True, None def _create_lookup_table(self): - """ Make the lookup table """ + """Make the lookup table""" from tqdm.auto import tqdm - logger.info('Building lookup table for distance marginalisation.') + + logger.info("Building lookup table for distance marginalisation.") self._dist_margd_loglikelihood_array = np.zeros((400, 800)) scaling = self._ref_dist / self._distance_array d_inner_h_array_full = np.outer(self._d_inner_h_ref_array, scaling) - h_inner_h_array_full = np.outer(self._optimal_snr_squared_ref_array, scaling ** 2) + h_inner_h_array_full = np.outer(self._optimal_snr_squared_ref_array, scaling**2) if self.phase_marginalization: d_inner_h_array_full = ln_i0(abs(d_inner_h_array_full)) prior_term = self.distance_prior_array * self._delta_distance for ii, optimal_snr_squared_array in tqdm( - enumerate(h_inner_h_array_full), total=len(self._optimal_snr_squared_ref_array) + enumerate(h_inner_h_array_full), total=len(self._optimal_snr_squared_ref_array) ): for jj, d_inner_h_array in enumerate(d_inner_h_array_full): self._dist_margd_loglikelihood_array[ii][jj] = logsumexp( - d_inner_h_array - optimal_snr_squared_array / 2, - b=prior_term + d_inner_h_array - optimal_snr_squared_array / 2, b=prior_term ) - log_norm = logsumexp( - 0 / self._distance_array, b=self.distance_prior_array * self._delta_distance - ) + log_norm = logsumexp(0 / self._distance_array, b=self.distance_prior_array * self._delta_distance) self._dist_margd_loglikelihood_array -= log_norm self.cache_lookup_table() @@ -1027,13 +1005,15 @@ def _setup_phase_marginalization(self, min_bound=-5, max_bound=10): def _setup_time_marginalization(self): self._delta_tc = 2 / self.waveform_generator.sampling_frequency - self._times = \ - self.interferometers.start_time + np.linspace( - 0, self.interferometers.duration, - int(self.interferometers.duration / 2 * - self.waveform_generator.sampling_frequency + 1))[1:] - self.time_prior_array = \ - self.priors['geocent_time'].prob(self._times) * self._delta_tc + self._times = ( + self.interferometers.start_time + + np.linspace( + 0, + self.interferometers.duration, + int(self.interferometers.duration / 2 * self.waveform_generator.sampling_frequency + 1), + )[1:] + ) + self.time_prior_array = self.priors["geocent_time"].prob(self._times) * self._delta_tc def _setup_calibration_marginalization(self, calibration_lookup_table, priors=None): self.calibration_draws, self.calibration_parameter_draws = calibration.build_calibration_lookup( @@ -1049,7 +1029,7 @@ def _setup_calibration_marginalization(self, calibration_lookup_table, priors=No priors[key] = DeltaFunction(0.0) self.calibration_abs_draws = dict() for name in self.calibration_draws: - self.calibration_abs_draws[name] = np.abs(self.calibration_draws[name])**2 + self.calibration_abs_draws[name] = np.abs(self.calibration_draws[name]) ** 2 @property def interferometers(self): @@ -1085,7 +1065,7 @@ def reference_frame(self, frame): elif isinstance(frame, str): self._reference_frame = InterferometerList([frame[:2], frame[2:4]]) else: - raise ValueError("Unable to parse reference frame {}".format(frame)) + raise ValueError(f"Unable to parse reference frame {frame}") def get_sky_frame_parameters(self, parameters=None): """ @@ -1105,34 +1085,26 @@ def get_sky_frame_parameters(self, parameters=None): dict: dictionary containing ra, dec, and geocent_time """ parameters = _fallback_to_parameters(self, parameters) - time = parameters.get(f'{self.time_reference}_time', None) + time = parameters.get(f"{self.time_reference}_time", None) if time is None and "geocent_time" in parameters: - logger.warning( - f"Cannot find {self.time_reference}_time in parameters. " - "Falling back to geocent time" - ) + logger.warning(f"Cannot find {self.time_reference}_time in parameters. Falling back to geocent time") if not self.reference_frame == "sky": try: ra, dec = zenith_azimuth_to_ra_dec( - parameters['zenith'], parameters['azimuth'], - time, self.reference_frame) + parameters["zenith"], parameters["azimuth"], time, self.reference_frame + ) except KeyError: if "ra" in parameters and "dec" in parameters: ra = parameters["ra"] dec = parameters["dec"] - logger.warning( - "Cannot convert from zenith/azimuth to ra/dec falling " - "back to provided ra/dec" - ) + logger.warning("Cannot convert from zenith/azimuth to ra/dec falling back to provided ra/dec") else: raise else: ra = parameters["ra"] dec = parameters["dec"] if "geocent" not in self.time_reference and f"{self.time_reference}_time" in parameters: - geocent_time = time - self.reference_ifo.time_delay_from_geocenter( - ra=ra, dec=dec, time=time - ) + geocent_time = time - self.reference_ifo.time_delay_from_geocenter(ra=ra, dec=dec, time=time) else: geocent_time = parameters["geocent_time"] return dict(ra=ra, dec=dec, geocent_time=geocent_time) @@ -1140,24 +1112,26 @@ def get_sky_frame_parameters(self, parameters=None): @property def lal_version(self): try: - from lal import git_version, __version__ + from lal import __version__, git_version + lal_version = str(__version__) - logger.info("Using lal version {}".format(lal_version)) + logger.info(f"Using lal version {lal_version}") lal_git_version = str(git_version.verbose_msg).replace("\n", ";") - logger.info("Using lal git version {}".format(lal_git_version)) - return "lal_version={}, lal_git_version={}".format(lal_version, lal_git_version) + logger.info(f"Using lal git version {lal_git_version}") + return f"lal_version={lal_version}, lal_git_version={lal_git_version}" except (ImportError, AttributeError): return "N/A" @property def lalsimulation_version(self): try: - from lalsimulation import git_version, __version__ + from lalsimulation import __version__, git_version + lalsim_version = str(__version__) - logger.info("Using lalsimulation version {}".format(lalsim_version)) + logger.info(f"Using lalsimulation version {lalsim_version}") lalsim_git_version = str(git_version.verbose_msg).replace("\n", ";") - logger.info("Using lalsimulation git version {}".format(lalsim_git_version)) - return "lalsimulation_version={}, lalsimulation_git_version={}".format(lalsim_version, lalsim_git_version) + logger.info(f"Using lalsimulation git version {lalsim_git_version}") + return f"lalsimulation_version={lalsim_version}, lalsimulation_git_version={lalsim_git_version}" except (ImportError, AttributeError): return "N/A" @@ -1180,4 +1154,5 @@ def meta_data(self): time_reference=self.time_reference, reference_frame=self._reference_frame_str, lal_version=self.lal_version, - lalsimulation_version=self.lalsimulation_version) + lalsimulation_version=self.lalsimulation_version, + ) diff --git a/bilby/gw/likelihood/basic.py b/bilby/gw/likelihood/basic.py index b2f04eb69..120fbbe0c 100644 --- a/bilby/gw/likelihood/basic.py +++ b/bilby/gw/likelihood/basic.py @@ -4,7 +4,6 @@ class BasicGravitationalWaveTransient(Likelihood): - def __init__(self, interferometers, waveform_generator): """ @@ -25,16 +24,18 @@ def __init__(self, interferometers, waveform_generator): given some set of parameters """ - super(BasicGravitationalWaveTransient, self).__init__(dict()) + super().__init__(dict()) self.interferometers = interferometers self.waveform_generator = waveform_generator def __repr__(self): - return self.__class__.__name__ + '(interferometers={},\n\twaveform_generator={})' \ - .format(self.interferometers, self.waveform_generator) + return ( + self.__class__.__name__ + + f"(interferometers={self.interferometers},\n\twaveform_generator={self.waveform_generator})" + ) def noise_log_likelihood(self): - """ Calculates the real part of noise log-likelihood + """Calculates the real part of noise log-likelihood Returns ======= @@ -43,13 +44,15 @@ def noise_log_likelihood(self): """ log_l = 0 for interferometer in self.interferometers: - log_l -= 2. / self.waveform_generator.duration * np.sum( - abs(interferometer.frequency_domain_strain) ** 2 / - interferometer.power_spectral_density_array) + log_l -= ( + 2.0 + / self.waveform_generator.duration + * np.sum(abs(interferometer.frequency_domain_strain) ** 2 / interferometer.power_spectral_density_array) + ) return log_l.real def log_likelihood(self, parameters=None): - """ Calculates the real part of log-likelihood value + """Calculates the real part of log-likelihood value Returns ======= @@ -58,17 +61,14 @@ def log_likelihood(self, parameters=None): """ parameters = _fallback_to_parameters(self, parameters) log_l = 0 - waveform_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) + waveform_polarizations = self.waveform_generator.frequency_domain_strain(parameters) if waveform_polarizations is None: return np.nan_to_num(-np.inf) for interferometer in self.interferometers: - log_l += self.log_likelihood_interferometer( - waveform_polarizations, interferometer) + log_l += self.log_likelihood_interferometer(waveform_polarizations, interferometer) return log_l.real - def log_likelihood_interferometer(self, waveform_polarizations, - interferometer, parameters=None): + def log_likelihood_interferometer(self, waveform_polarizations, interferometer, parameters=None): """ Parameters @@ -84,11 +84,14 @@ def log_likelihood_interferometer(self, waveform_polarizations, """ parameters = _fallback_to_parameters(self, parameters) - signal_ifo = interferometer.get_detector_response( - waveform_polarizations, parameters) - - log_l = - 2. / self.waveform_generator.duration * np.vdot( - interferometer.frequency_domain_strain - signal_ifo, - (interferometer.frequency_domain_strain - signal_ifo) / - interferometer.power_spectral_density_array) + signal_ifo = interferometer.get_detector_response(waveform_polarizations, parameters) + + log_l = ( + -2.0 + / self.waveform_generator.duration + * np.vdot( + interferometer.frequency_domain_strain - signal_ifo, + (interferometer.frequency_domain_strain - signal_ifo) / interferometer.power_spectral_density_array, + ) + ) return log_l.real diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index cc7b8e386..499dd1e20 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -1,19 +1,22 @@ - import math import numbers import numpy as np -from .base import GravitationalWaveTransient +from ...core.likelihood import _fallback_to_parameters from ...core.utils import ( - logger, speed_of_light, solar_mass, radius_of_earth, - gravitational_constant, round_up_to_power_of_two, + gravitational_constant, + logger, + radius_of_earth, recursively_load_dict_contents_from_group, - recursively_save_dict_contents_to_group + recursively_save_dict_contents_to_group, + round_up_to_power_of_two, + solar_mass, + speed_of_light, ) -from ...core.likelihood import _fallback_to_parameters from ..prior import CBCPriorDict from ..utils import ln_i0 +from .base import GravitationalWaveTransient class MBGravitationalWaveTransient(GravitationalWaveTransient): @@ -90,20 +93,40 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): A likelihood object, able to compute the likelihood of the data given some model parameters """ + def __init__( - self, interferometers, waveform_generator, reference_chirp_mass=None, highest_mode=2, - linear_interpolation=True, accuracy_factor=5, time_offset=None, delta_f_end=None, - maximum_banding_frequency=None, minimum_banding_duration=0., weights=None, - distance_marginalization=False, phase_marginalization=False, priors=None, - time_marginalization=False, jitter_time=True, distance_marginalization_lookup_table=None, - reference_frame="sky", time_reference="geocenter" + self, + interferometers, + waveform_generator, + reference_chirp_mass=None, + highest_mode=2, + linear_interpolation=True, + accuracy_factor=5, + time_offset=None, + delta_f_end=None, + maximum_banding_frequency=None, + minimum_banding_duration=0.0, + weights=None, + distance_marginalization=False, + phase_marginalization=False, + priors=None, + time_marginalization=False, + jitter_time=True, + distance_marginalization_lookup_table=None, + reference_frame="sky", + time_reference="geocenter", ): - super(MBGravitationalWaveTransient, self).__init__( - interferometers=interferometers, waveform_generator=waveform_generator, priors=priors, - distance_marginalization=distance_marginalization, phase_marginalization=phase_marginalization, + super().__init__( + interferometers=interferometers, + waveform_generator=waveform_generator, + priors=priors, + distance_marginalization=distance_marginalization, + phase_marginalization=phase_marginalization, time_marginalization=time_marginalization, distance_marginalization_lookup_table=distance_marginalization_lookup_table, - jitter_time=jitter_time, reference_frame=reference_frame, time_reference=time_reference + jitter_time=jitter_time, + reference_frame=reference_frame, + time_reference=time_reference, ) if weights is None: self.reference_chirp_mass = reference_chirp_mass @@ -118,9 +141,10 @@ def __init__( else: if isinstance(weights, str): import h5py + logger.info(f"Loading multiband weights from {weights}.") - with h5py.File(weights, 'r') as f: - weights = recursively_load_dict_contents_from_group(f, '/') + with h5py.File(weights, "r") as f: + weights = recursively_load_dict_contents_from_group(f, "/") self.setup_multibanding_from_weights(weights) if self.time_marginalization: self._setup_time_marginalization_multiband() @@ -131,7 +155,7 @@ def reference_chirp_mass(self): @property def reference_chirp_mass_in_second(self): - return gravitational_constant * self._reference_chirp_mass * solar_mass / speed_of_light**3. + return gravitational_constant * self._reference_chirp_mass * solar_mass / speed_of_light**3.0 @reference_chirp_mass.setter def reference_chirp_mass(self, reference_chirp_mass): @@ -215,13 +239,16 @@ def time_offset(self, time_offset): raise TypeError("time_offset must be a number") elif self.priors is not None and time_parameter in self.priors: self._time_offset = ( - self.interferometers.start_time + self.interferometers.duration - - self.priors[time_parameter].minimum + safety + self.interferometers.start_time + + self.interferometers.duration + - self.priors[time_parameter].minimum + + safety ) else: self._time_offset = 2.12 - logger.warning("time offset can not be inferred. Use the standard time offset of {} seconds.".format( - self._time_offset)) + logger.warning( + f"time offset can not be inferred. Use the standard time offset of {self._time_offset} seconds." + ) @property def delta_f_end(self): @@ -249,13 +276,14 @@ def delta_f_end(self, delta_f_end): raise TypeError("delta_f_end must be a number") elif self.priors is not None and time_parameter in self.priors: self._delta_f_end = 100 / ( - self.interferometers.start_time + self.interferometers.duration - - self.priors[time_parameter].maximum - safety + self.interferometers.start_time + + self.interferometers.duration + - self.priors[time_parameter].maximum + - safety ) else: - self._delta_f_end = 53. - logger.warning("delta_f_end can not be inferred. Use the standard delta_f_end of {} Hz.".format( - self._delta_f_end)) + self._delta_f_end = 53.0 + logger.warning(f"delta_f_end can not be inferred. Use the standard delta_f_end of {self._delta_f_end} Hz.") @property def maximum_banding_frequency(self): @@ -271,16 +299,14 @@ def maximum_banding_frequency(self, maximum_banding_frequency): time-to-merger \tau(f). The user-specified frequency is used if it is lower than that frequency. """ fmax_tmp = ( - (15 / 968)**(3 / 5) * (self.highest_mode / (2 * np.pi))**(8 / 5) - / self.reference_chirp_mass_in_second + (15 / 968) ** (3 / 5) * (self.highest_mode / (2 * np.pi)) ** (8 / 5) / self.reference_chirp_mass_in_second ) if maximum_banding_frequency is not None: if isinstance(maximum_banding_frequency, numbers.Number): if maximum_banding_frequency < fmax_tmp: fmax_tmp = maximum_banding_frequency else: - logger.warning("The input maximum_banding_frequency is too large." - "It is set to be {} Hz.".format(fmax_tmp)) + logger.warning(f"The input maximum_banding_frequency is too large.It is set to be {fmax_tmp} Hz.") else: raise TypeError("maximum_banding_frequency must be a number") self._maximum_banding_frequency = fmax_tmp @@ -335,7 +361,9 @@ def _tau(self, f): """ f_22 = 2 * f / self.highest_mode return ( - 5 / 256 * self.reference_chirp_mass_in_second + 5 + / 256 + * self.reference_chirp_mass_in_second * (np.pi * self.reference_chirp_mass_in_second * f_22) ** (-8 / 3) ) @@ -355,8 +383,11 @@ def _dtaudf(self, f): """ f_22 = 2 * f / self.highest_mode return ( - -5 / 96 * self.reference_chirp_mass_in_second - * (np.pi * self.reference_chirp_mass_in_second * f_22) ** (-8. / 3.) / f + -5 + / 96 + * self.reference_chirp_mass_in_second + * (np.pi * self.reference_chirp_mass_in_second * f_22) ** (-8.0 / 3.0) + / f ) def _find_starting_frequency(self, duration, fnow): @@ -379,25 +410,24 @@ def _find_starting_frequency(self, duration, fnow): exist. """ + def _is_above_fnext(f): """This function returns True if f > fnext""" - cond1 = ( - duration - self.time_offset - self._tau(f) - - self.accuracy_factor * np.sqrt(-self._dtaudf(f)) - ) > 0 - cond2 = f - 1. / np.sqrt(-self._dtaudf(f)) - fnow > 0 + cond1 = (duration - self.time_offset - self._tau(f) - self.accuracy_factor * np.sqrt(-self._dtaudf(f))) > 0 + cond2 = f - 1.0 / np.sqrt(-self._dtaudf(f)) - fnow > 0 return cond1 and cond2 + # Bisection search for fnext fmin, fmax = fnow, self.maximum_banding_frequency if not _is_above_fnext(fmax): return None, None while fmax - fmin > 1e-2 / duration: - f = (fmin + fmax) / 2. + f = (fmin + fmax) / 2.0 if _is_above_fnext(f): fmax = f else: fmin = f - return f, 1. / np.sqrt(-self._dtaudf(f)) + return f, 1.0 / np.sqrt(-self._dtaudf(f)) def _setup_frequency_bands(self): r"""Set up frequency bands. The durations of bands geometrically decrease T, T/2. T/4, ..., where T is the @@ -409,7 +439,7 @@ def _setup_frequency_bands(self): """ self.durations = np.array([self.interferometers.duration]) - self.fb_dfb = [[self.minimum_frequency, 0.]] + self.fb_dfb = [[self.minimum_frequency, 0.0]] dnext = self.interferometers.duration / 2 while dnext > max(self.time_offset, self.minimum_banding_duration): fnow, _ = self.fb_dfb[-1] @@ -422,8 +452,11 @@ def _setup_frequency_bands(self): break self.fb_dfb.append([self.maximum_frequency + self.delta_f_end, self.delta_f_end]) self.fb_dfb = np.array(self.fb_dfb) - logger.info("The total frequency range is divided into {} bands with frequency intervals of {}.".format( - self.number_of_bands, ", ".join(["1/{} Hz".format(d) for d in self.durations]))) + logger.info( + "The total frequency range is divided into {} bands with frequency intervals of {}.".format( + self.number_of_bands, ", ".join([f"1/{d} Hz" for d in self.durations]) + ) + ) def _setup_integers(self): """Set up integers needed for likelihood evaluations. This sets the following instance variables. @@ -440,7 +473,7 @@ def _setup_integers(self): dnow = self.durations[b] fnow, dfnow = self.fb_dfb[b] fnext, _ = self.fb_dfb[b + 1] - Nb = max(round_up_to_power_of_two(2. * (fnext * self.interferometers.duration + 1.)), 2**b) + Nb = max(round_up_to_power_of_two(2.0 * (fnext * self.interferometers.duration + 1.0)), 2**b) self.Nbs = np.append(self.Nbs, Nb) self.Mbs = np.append(self.Mbs, Nb // 2**b) self.Ks_Ke.append([math.ceil((fnow - dfnow) * dnow), math.floor(fnext * dnow)]) @@ -469,13 +502,16 @@ def _setup_waveform_frequency_points(self): start_idx = end_idx + 1 self.start_end_idxs = np.array(self.start_end_idxs) unique_frequencies, idxs = np.unique(self.banded_frequency_points, return_inverse=True) - self.waveform_generator.waveform_arguments['frequencies'] = unique_frequencies + self.waveform_generator.waveform_arguments["frequencies"] = unique_frequencies self.unique_to_original_frequencies = idxs - logger.info("The number of frequency points where waveforms are evaluated is {}.".format( - len(unique_frequencies))) - logger.info("The speed-up gain of multi-banding is {}.".format( - (self.maximum_frequency - self.minimum_frequency) * self.interferometers.duration / - len(unique_frequencies))) + logger.info(f"The number of frequency points where waveforms are evaluated is {len(unique_frequencies)}.") + logger.info( + "The speed-up gain of multi-banding is {}.".format( + (self.maximum_frequency - self.minimum_frequency) + * self.interferometers.duration + / len(unique_frequencies) + ) + ) def _get_window_sequence(self, delta_f, start_idx, length, b): """Compute window function on frequencies with a fixed frequency interval @@ -500,29 +536,21 @@ def _get_window_sequence(self, delta_f, start_idx, length, b): fnext, dfnext = self.fb_dfb[b + 1] window_sequence = np.zeros(length) - increase_start = np.clip( - math.floor((fnow - dfnow) / delta_f) - start_idx + 1, 0, length - ) + increase_start = np.clip(math.floor((fnow - dfnow) / delta_f) - start_idx + 1, 0, length) unity_start = np.clip(math.ceil(fnow / delta_f) - start_idx, 0, length) - decrease_start = np.clip( - math.floor((fnext - dfnext) / delta_f) - start_idx + 1, 0, length - ) + decrease_start = np.clip(math.floor((fnext - dfnext) / delta_f) - start_idx + 1, 0, length) decrease_stop = np.clip(math.ceil(fnext / delta_f) - start_idx, 0, length) - window_sequence[unity_start:decrease_start] = 1. + window_sequence[unity_start:decrease_start] = 1.0 # this if statement avoids overflow caused by vanishing dfnow if increase_start < unity_start: frequencies = (np.arange(increase_start, unity_start) + start_idx) * delta_f - window_sequence[increase_start:unity_start] = ( - 1. + np.cos(np.pi * (frequencies - fnow) / dfnow) - ) / 2. + window_sequence[increase_start:unity_start] = (1.0 + np.cos(np.pi * (frequencies - fnow) / dfnow)) / 2.0 if decrease_start < decrease_stop: frequencies = (np.arange(decrease_start, decrease_stop) + start_idx) * delta_f - window_sequence[decrease_start:decrease_stop] = ( - 1. - np.cos(np.pi * (frequencies - fnext) / dfnext) - ) / 2. + window_sequence[decrease_start:decrease_stop] = (1.0 - np.cos(np.pi * (frequencies - fnext) / dfnext)) / 2.0 return window_sequence @@ -531,20 +559,22 @@ def _setup_linear_coefficients(self): self.linear_coeffs = dict((ifo.name, np.array([])) for ifo in self.interferometers) N = self.Nbs[-1] for ifo in self.interferometers: - logger.info("Pre-computing linear coefficients for {}".format(ifo.name)) + logger.info(f"Pre-computing linear coefficients for {ifo.name}") fddata = np.zeros(N // 2 + 1, dtype=complex) - fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask[:len(fddata)]] += \ + fddata[: len(ifo.frequency_domain_strain)][ifo.frequency_mask[: len(fddata)]] += ( ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] + ) for b in range(self.number_of_bands): Ks, Ke = self.Ks_Ke[b] - windows = self._get_window_sequence(1. / self.durations[b], Ks, Ke - Ks + 1, b) - fddata_in_ith_band = np.copy(fddata[:int(self.Nbs[b] / 2 + 1)]) - fddata_in_ith_band[-1] = 0. # zeroing data at the Nyquist frequency - tddata = np.fft.irfft(fddata_in_ith_band)[-self.Mbs[b]:] + windows = self._get_window_sequence(1.0 / self.durations[b], Ks, Ke - Ks + 1, b) + fddata_in_ith_band = np.copy(fddata[: int(self.Nbs[b] / 2 + 1)]) + fddata_in_ith_band[-1] = 0.0 # zeroing data at the Nyquist frequency + tddata = np.fft.irfft(fddata_in_ith_band)[-self.Mbs[b] :] Ks, Ke = self.Ks_Ke[b] - fddata_in_ith_band = np.fft.rfft(tddata)[Ks:Ke + 1] + fddata_in_ith_band = np.fft.rfft(tddata)[Ks : Ke + 1] self.linear_coeffs[ifo.name] = np.append( - self.linear_coeffs[ifo.name], (4. / self.durations[b]) * windows * np.conj(fddata_in_ith_band)) + self.linear_coeffs[ifo.name], (4.0 / self.durations[b]) * windows * np.conj(fddata_in_ith_band) + ) def _setup_quadratic_coefficients_linear_interp(self): """Set up coefficients by which the squares of waveforms are multiplied to compute (h, h) for the @@ -556,7 +586,7 @@ def _setup_quadratic_coefficients_linear_interp(self): for b in range(self.number_of_bands): logger.info(f"Pre-computing quadratic coefficients for the {b}-th band") _start, _end = self.start_end_idxs[b] - banded_frequencies = self.banded_frequency_points[_start:_end + 1] + banded_frequencies = self.banded_frequency_points[_start : _end + 1] prefactor = 4 * self.durations[b] / original_duration # precompute window values @@ -567,44 +597,40 @@ def _setup_quadratic_coefficients_linear_interp(self): 1 / original_duration, start_idx_in_band, math.floor(_fnext * original_duration) - start_idx_in_band + 1, - b + b, ) for ifo in self.interferometers: end_idx_in_band = min( - start_idx_in_band + len(window_sequence) - 1, - len(ifo.power_spectral_density_array) - 1 + start_idx_in_band + len(window_sequence) - 1, len(ifo.power_spectral_density_array) - 1 ) - _frequency_mask = ifo.frequency_mask[start_idx_in_band:end_idx_in_band + 1] + _frequency_mask = ifo.frequency_mask[start_idx_in_band : end_idx_in_band + 1] window_over_psd = np.zeros(end_idx_in_band + 1 - start_idx_in_band) - window_over_psd[_frequency_mask] = \ - 1. / ifo.power_spectral_density_array[start_idx_in_band:end_idx_in_band + 1][_frequency_mask] - window_over_psd *= window_sequence[:len(window_over_psd)] + window_over_psd[_frequency_mask] = ( + 1.0 / ifo.power_spectral_density_array[start_idx_in_band : end_idx_in_band + 1][_frequency_mask] + ) + window_over_psd *= window_sequence[: len(window_over_psd)] coeffs = np.zeros(len(banded_frequencies)) for k in range(len(coeffs) - 1): if k == 0: start_idx_in_sum = start_idx_in_band else: - start_idx_in_sum = max( - start_idx_in_band, - math.ceil(original_duration * banded_frequencies[k]) - ) + start_idx_in_sum = max(start_idx_in_band, math.ceil(original_duration * banded_frequencies[k])) if k == len(coeffs) - 2: end_idx_in_sum = end_idx_in_band else: end_idx_in_sum = min( - end_idx_in_band, - math.ceil(original_duration * banded_frequencies[k + 1]) - 1 + end_idx_in_band, math.ceil(original_duration * banded_frequencies[k + 1]) - 1 ) frequencies_in_sum = np.arange(start_idx_in_sum, end_idx_in_sum + 1) / original_duration coeffs[k] += prefactor * np.sum( - (banded_frequencies[k + 1] - frequencies_in_sum) * - window_over_psd[start_idx_in_sum - start_idx_in_band:end_idx_in_sum - start_idx_in_band + 1] + (banded_frequencies[k + 1] - frequencies_in_sum) + * window_over_psd[start_idx_in_sum - start_idx_in_band : end_idx_in_sum - start_idx_in_band + 1] ) coeffs[k + 1] += prefactor * np.sum( - (frequencies_in_sum - banded_frequencies[k]) * - window_over_psd[start_idx_in_sum - start_idx_in_band:end_idx_in_sum - start_idx_in_band + 1] + (frequencies_in_sum - banded_frequencies[k]) + * window_over_psd[start_idx_in_sum - start_idx_in_band : end_idx_in_sum - start_idx_in_band + 1] ) self.quadratic_coeffs[ifo.name] = np.append(self.quadratic_coeffs[ifo.name], coeffs) @@ -620,15 +646,15 @@ def _setup_quadratic_coefficients_ifft_fft(self): self.hbcs = dict((ifo.name, []) for ifo in self.interferometers) self.wths = dict((ifo.name, []) for ifo in self.interferometers) for ifo in self.interferometers: - logger.info("Pre-computing quadratic coefficients for {}".format(ifo.name)) + logger.info(f"Pre-computing quadratic coefficients for {ifo.name}") full_inv_psds = np.zeros(N // 2 + 1) - full_inv_psds[:len(ifo.power_spectral_density_array)][ifo.frequency_mask[:len(full_inv_psds)]] = ( + full_inv_psds[: len(ifo.power_spectral_density_array)][ifo.frequency_mask[: len(full_inv_psds)]] = ( 1 / ifo.power_spectral_density_array[ifo.frequency_mask] ) for b in range(self.number_of_bands): - Imb = np.fft.irfft(full_inv_psds[:self.Nbs[b] // 2 + 1]) + Imb = np.fft.irfft(full_inv_psds[: self.Nbs[b] // 2 + 1]) half_length = Nhatbs[b] // 2 - Imbc = np.append(Imb[:half_length + 1], Imb[-(Nhatbs[b] - half_length - 1):]) + Imbc = np.append(Imb[: half_length + 1], Imb[-(Nhatbs[b] - half_length - 1) :]) self.Ibcs[ifo.name].append(np.fft.rfft(Imbc)) # Allocate arrays for IFFT-FFT operations self.hbcs[ifo.name].append(np.zeros(Nhatbs[b])) @@ -638,7 +664,7 @@ def _setup_quadratic_coefficients_ifft_fft(self): self.square_root_windows = np.array([]) for b in range(self.number_of_bands): Ks, Ke = self.Ks_Ke[b] - ws = self._get_window_sequence(1. / self.durations[b], Ks, Ke - Ks + 1, b) + ws = self._get_window_sequence(1.0 / self.durations[b], Ks, Ke - Ks + 1, b) self.windows = np.append(self.windows, ws) self.square_root_windows = np.append(self.square_root_windows, np.sqrt(ws)) @@ -646,16 +672,26 @@ def _setup_quadratic_coefficients_ifft_fft(self): def weights(self): _weights = {} for key in [ - "reference_chirp_mass", "highest_mode", "linear_interpolation", - "accuracy_factor", "time_offset", "delta_f_end", - "maximum_banding_frequency", "minimum_banding_duration", - "durations", "fb_dfb", "Nbs", "Mbs", "Ks_Ke", - "banded_frequency_points", "start_end_idxs", - "unique_to_original_frequencies", "linear_coeffs" + "reference_chirp_mass", + "highest_mode", + "linear_interpolation", + "accuracy_factor", + "time_offset", + "delta_f_end", + "maximum_banding_frequency", + "minimum_banding_duration", + "durations", + "fb_dfb", + "Nbs", + "Mbs", + "Ks_Ke", + "banded_frequency_points", + "start_end_idxs", + "unique_to_original_frequencies", + "linear_coeffs", ]: _weights[key] = getattr(self, key) - _weights["waveform_frequencies"] = \ - self.waveform_generator.waveform_arguments['frequencies'] + _weights["waveform_frequencies"] = self.waveform_generator.waveform_arguments["frequencies"] if self.linear_interpolation: _weights["quadratic_coeffs"] = self.quadratic_coeffs else: @@ -678,11 +714,12 @@ def save_weights(self, filename): """ import h5py + if not filename.endswith(".hdf5"): filename += ".hdf5" logger.info(f"Saving multiband weights to {filename}") - with h5py.File(filename, 'w') as f: - recursively_save_dict_contents_to_group(f, '/', self.weights) + with h5py.File(filename, "w") as f: + recursively_save_dict_contents_to_group(f, "/", self.weights) def setup_multibanding_from_weights(self, weights): """ @@ -705,7 +742,7 @@ def setup_multibanding_from_weights(self, weights): to_set[ifo_name] = [data[str(b)] for b in range(len(data.keys()))] setattr(self, key, to_set) elif key == "waveform_frequencies": - self.waveform_generator.waveform_arguments['frequencies'] = weights["waveform_frequencies"] + self.waveform_generator.waveform_arguments["frequencies"] = weights["waveform_frequencies"] else: setattr(self, key, value) @@ -713,16 +750,14 @@ def _setup_time_marginalization_multiband(self): """This overwrites attributes set by _setup_time_marginalization of the base likelihood class""" N = self.Nbs[-1] // 2 self._delta_tc = self.durations[0] / N - self._times = \ - self.interferometers.start_time + np.arange(N) * self._delta_tc - self.time_prior_array = \ - self.priors['geocent_time'].prob(self._times) * self._delta_tc + self._times = self.interferometers.start_time + np.arange(N) * self._delta_tc + self.time_prior_array = self.priors["geocent_time"].prob(self._times) * self._delta_tc # allocate array which is FFTed at each likelihood evaluation self._full_d_h = np.zeros(N, dtype=complex) # idxs to convert full frequency points to banded frequency points, used for filling _full_d_h. self._full_to_multiband = [int(f * self.durations[0]) for f in self.banded_frequency_points] self._beam_pattern_reference_time = ( - self.priors['geocent_time'].minimum + self.priors['geocent_time'].maximum + self.priors["geocent_time"].minimum + self.priors["geocent_time"].maximum ) / 2 def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True, parameters=None): @@ -750,51 +785,52 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr if self.time_marginalization: time_ref = self._beam_pattern_reference_time else: - time_ref = parameters['geocent_time'] + time_ref = parameters["geocent_time"] strain = np.zeros(len(self.banded_frequency_points), dtype=complex) for mode in waveform_polarizations: response = interferometer.antenna_response( - parameters['ra'], parameters['dec'], - time_ref, parameters['psi'], mode + parameters["ra"], parameters["dec"], time_ref, parameters["psi"], mode ) strain += waveform_polarizations[mode][self.unique_to_original_frequencies] * response - dt = interferometer.time_delay_from_geocenter( - parameters['ra'], parameters['dec'], time_ref) - dt_geocent = parameters['geocent_time'] - interferometer.strain_data.start_time + dt = interferometer.time_delay_from_geocenter(parameters["ra"], parameters["dec"], time_ref) + dt_geocent = parameters["geocent_time"] - interferometer.strain_data.start_time ifo_time = dt_geocent + dt - strain *= np.exp(-1j * 2. * np.pi * self.banded_frequency_points * ifo_time) + strain *= np.exp(-1j * 2.0 * np.pi * self.banded_frequency_points * ifo_time) strain *= interferometer.calibration_model.get_calibration_factor( - self.banded_frequency_points, prefix='recalib_{}_'.format(interferometer.name), **parameters) + self.banded_frequency_points, prefix=f"recalib_{interferometer.name}_", **parameters + ) d_inner_h = np.conj(np.dot(strain, self.linear_coeffs[interferometer.name])) if self.linear_interpolation: optimal_snr_squared = np.vdot( - np.real(strain * np.conjugate(strain)), - self.quadratic_coeffs[interferometer.name] + np.real(strain * np.conjugate(strain)), self.quadratic_coeffs[interferometer.name] ) else: - optimal_snr_squared = 0. + optimal_snr_squared = 0.0 for b in range(self.number_of_bands): Ks, Ke = self.Ks_Ke[b] start_idx, end_idx = self.start_end_idxs[b] Mb = self.Mbs[b] if b == 0: - optimal_snr_squared += (4. / self.interferometers.duration) * np.vdot( - np.real(strain[start_idx:end_idx + 1] * np.conjugate(strain[start_idx:end_idx + 1])), - interferometer.frequency_mask[Ks:Ke + 1] * self.windows[start_idx:end_idx + 1] - / interferometer.power_spectral_density_array[Ks:Ke + 1]) + optimal_snr_squared += (4.0 / self.interferometers.duration) * np.vdot( + np.real(strain[start_idx : end_idx + 1] * np.conjugate(strain[start_idx : end_idx + 1])), + interferometer.frequency_mask[Ks : Ke + 1] + * self.windows[start_idx : end_idx + 1] + / interferometer.power_spectral_density_array[Ks : Ke + 1], + ) else: - self.wths[interferometer.name][b][Ks:Ke + 1] = ( - self.square_root_windows[start_idx:end_idx + 1] * strain[start_idx:end_idx + 1] + self.wths[interferometer.name][b][Ks : Ke + 1] = ( + self.square_root_windows[start_idx : end_idx + 1] * strain[start_idx : end_idx + 1] ) self.hbcs[interferometer.name][b][-Mb:] = np.fft.irfft(self.wths[interferometer.name][b]) thbc = np.fft.rfft(self.hbcs[interferometer.name][b]) - optimal_snr_squared += (4. / self.Tbhats[b]) * np.vdot( - np.real(thbc * np.conjugate(thbc)), self.Ibcs[interferometer.name][b]) + optimal_snr_squared += (4.0 / self.Tbhats[b]) * np.vdot( + np.real(thbc * np.conjugate(thbc)), self.Ibcs[interferometer.name][b] + ) complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) @@ -802,8 +838,9 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr self._full_d_h[self._full_to_multiband] *= 0 for b in range(self.number_of_bands): start_idx, end_idx = self.start_end_idxs[b] - self._full_d_h[self._full_to_multiband[start_idx:end_idx + 1]] += \ - strain[start_idx:end_idx + 1] * self.linear_coeffs[interferometer.name][start_idx:end_idx + 1] + self._full_d_h[self._full_to_multiband[start_idx : end_idx + 1]] += ( + strain[start_idx : end_idx + 1] * self.linear_coeffs[interferometer.name][start_idx : end_idx + 1] + ) d_inner_h_array = np.fft.fft(self._full_d_h) else: d_inner_h_array = None @@ -823,31 +860,26 @@ def generate_time_sample_from_marginalized_likelihood(self, signal_polarizations parameters = _fallback_to_parameters(self, parameters) parameters.update(self.get_sky_frame_parameters(parameters=parameters)) if signal_polarizations is None: - signal_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) + signal_polarizations = self.waveform_generator.frequency_domain_strain(parameters) snrs = self._CalculatedSNRs() for interferometer in self.interferometers: - snrs += self.calculate_snrs( - waveform_polarizations=signal_polarizations, - interferometer=interferometer - ) + snrs += self.calculate_snrs(waveform_polarizations=signal_polarizations, interferometer=interferometer) d_inner_h = snrs.d_inner_h_array h_inner_h = snrs.optimal_snr_squared if self.distance_marginalization: - time_log_like = self.distance_marginalized_likelihood( - d_inner_h, h_inner_h) + time_log_like = self.distance_marginalized_likelihood(d_inner_h, h_inner_h) elif self.phase_marginalization: time_log_like = ln_i0(abs(d_inner_h)) - h_inner_h.real / 2 else: - time_log_like = (d_inner_h.real - h_inner_h.real / 2) + time_log_like = d_inner_h.real - h_inner_h.real / 2 times = self._times if self.jitter_time: times = times + parameters["time_jitter"] - time_prior_array = self.priors['geocent_time'].prob(times) + time_prior_array = self.priors["geocent_time"].prob(times) time_post = np.exp(time_log_like - max(time_log_like)) * time_prior_array time_post /= np.sum(time_post) return np.random.choice(times, p=time_post) diff --git a/bilby/gw/likelihood/relative.py b/bilby/gw/likelihood/relative.py index f4c72e8ef..853d46ba0 100644 --- a/bilby/gw/likelihood/relative.py +++ b/bilby/gw/likelihood/relative.py @@ -3,12 +3,12 @@ import numpy as np from scipy.optimize import differential_evolution -from .base import GravitationalWaveTransient -from ...core.utils import logger -from ...core.prior.base import Constraint -from ...core.prior import DeltaFunction from ...core.likelihood import _fallback_to_parameters +from ...core.prior import DeltaFunction +from ...core.prior.base import Constraint +from ...core.utils import logger from ..utils import noise_weighted_inner_product +from .base import GravitationalWaveTransient class RelativeBinningGravitationalWaveTransient(GravitationalWaveTransient): @@ -96,24 +96,26 @@ class RelativeBinningGravitationalWaveTransient(GravitationalWaveTransient): The relative binning likelihood does not currently support calibration marginalization. """ - def __init__(self, interferometers, - waveform_generator, - fiducial_parameters=None, - parameter_bounds=None, - maximization_kwargs=None, - update_fiducial_parameters=False, - distance_marginalization=False, - time_marginalization=False, - phase_marginalization=False, - priors=None, - distance_marginalization_lookup_table=None, - jitter_time=True, - reference_frame="sky", - time_reference="geocenter", - chi=1, - epsilon=0.5): - - super(RelativeBinningGravitationalWaveTransient, self).__init__( + def __init__( + self, + interferometers, + waveform_generator, + fiducial_parameters=None, + parameter_bounds=None, + maximization_kwargs=None, + update_fiducial_parameters=False, + distance_marginalization=False, + time_marginalization=False, + phase_marginalization=False, + priors=None, + distance_marginalization_lookup_table=None, + jitter_time=True, + reference_frame="sky", + time_reference="geocenter", + chi=1, + epsilon=0.5, + ): + super().__init__( interferometers=interferometers, waveform_generator=waveform_generator, distance_marginalization=distance_marginalization, @@ -123,7 +125,8 @@ def __init__(self, interferometers, distance_marginalization_lookup_table=distance_marginalization_lookup_table, jitter_time=jitter_time, reference_frame=reference_frame, - time_reference=time_reference) + time_reference=time_reference, + ) if fiducial_parameters is None: logger.info("Drawing fiducial parameters from prior.") @@ -154,8 +157,9 @@ def __init__(self, interferometers, if update_fiducial_parameters: # write a check to make sure prior is not None logger.info("Using scipy optimization to find maximum likelihood parameters.") - self.parameters_to_be_updated = [key for key in priors if not isinstance( - priors[key], (DeltaFunction, Constraint, float, int))] + self.parameters_to_be_updated = [ + key for key in priors if not isinstance(priors[key], (DeltaFunction, Constraint, float, int)) + ] logger.info(f"Parameters over which likelihood is maximized: {self.parameters_to_be_updated}") if parameter_bounds is None: logger.info("No parameter bounds were given. Using priors instead.") @@ -163,12 +167,16 @@ def __init__(self, interferometers, else: self.parameter_bounds = self.get_parameter_list_from_dictionary(parameter_bounds) self.fiducial_parameters = self.find_maximum_likelihood_parameters( - self.parameter_bounds, maximization_kwargs=maximization_kwargs) + self.parameter_bounds, maximization_kwargs=maximization_kwargs + ) logger.info(f"Fiducial likelihood: {self.log_likelihood_ratio(self.fiducial_parameters):.2f}") def __repr__(self): - return self.__class__.__name__ + '(interferometers={},\n\twaveform_generator={},\n\fiducial_parameters={},' \ - .format(self.interferometers, self.waveform_generator, self.fiducial_parameters) + return ( + self.__class__.__name__ + + f"(interferometers={self.interferometers},\n\twaveform_generator={self.waveform_generator}," + + f"\n\fiducial_parameters={self.fiducial_parameters}," + ) def setup_bins(self): """ @@ -188,18 +196,19 @@ def setup_bins(self): minimum_frequency = min(minimum_frequency, interferometer.minimum_frequency) maximum_frequency = min(maximum_frequency, self.maximum_frequency) frequency_array_useful = frequency_array[ - (frequency_array >= minimum_frequency) - & (frequency_array <= maximum_frequency) + (frequency_array >= minimum_frequency) & (frequency_array <= maximum_frequency) ] - d_alpha = self.chi * 2 * np.pi / np.abs( - (minimum_frequency ** gamma) * np.heaviside(-gamma, 1) - - (maximum_frequency ** gamma) * np.heaviside(gamma, 1) - ) - d_phi = np.sum( - np.sign(gamma) * d_alpha * frequency_array_useful ** gamma, - axis=0 + d_alpha = ( + self.chi + * 2 + * np.pi + / np.abs( + (minimum_frequency**gamma) * np.heaviside(-gamma, 1) + - (maximum_frequency**gamma) * np.heaviside(gamma, 1) + ) ) + d_phi = np.sum(np.sign(gamma) * d_alpha * frequency_array_useful**gamma, axis=0) d_phi_from_start = d_phi - d_phi[0] number_of_bins = int(d_phi_from_start[-1] // self.epsilon) bin_inds = list() @@ -220,25 +229,19 @@ def setup_bins(self): self.bin_sizes[-1] += 1 self.bin_freqs = np.array(bin_freqs) self.number_of_bins = len(self.bin_inds) - 1 - logger.debug( - f"Set up {self.number_of_bins} bins " - f"between {minimum_frequency} Hz and {maximum_frequency} Hz" - ) + logger.debug(f"Set up {self.number_of_bins} bins between {minimum_frequency} Hz and {maximum_frequency} Hz") self.waveform_generator.waveform_arguments["frequency_bin_edges"] = self.bin_freqs self.bin_widths = self.bin_freqs[1:] - self.bin_freqs[:-1] self.bin_centers = (self.bin_freqs[1:] + self.bin_freqs[:-1]) / 2 for interferometer in self.interferometers: name = interferometer.name - self.per_detector_fiducial_waveform_points[name] = ( - self.per_detector_fiducial_waveforms[name][self.bin_inds] - ) + self.per_detector_fiducial_waveform_points[name] = self.per_detector_fiducial_waveforms[name][self.bin_inds] def set_fiducial_waveforms(self, parameters): parameters = parameters.copy() self._set_fiducial() parameters.update(self.get_sky_frame_parameters(parameters=parameters)) - self.fiducial_polarizations = self.waveform_generator.frequency_domain_strain( - parameters) + self.fiducial_polarizations = self.waveform_generator.frequency_domain_strain(parameters) self._unset_fiducial() maximum_nonzero_index = np.where(self.fiducial_polarizations["plus"] != 0j)[0][-1] @@ -256,8 +259,7 @@ def set_fiducial_waveforms(self, parameters): wf[interferometer.frequency_array > self.maximum_frequency] = 0 self.per_detector_fiducial_waveforms[interferometer.name] = wf - def find_maximum_likelihood_parameters(self, parameter_bounds, - iterations=5, maximization_kwargs=None): + def find_maximum_likelihood_parameters(self, parameter_bounds, iterations=5, maximization_kwargs=None): if maximization_kwargs is None: maximization_kwargs = dict() parameters = deepcopy(self.fiducial_parameters) @@ -274,7 +276,7 @@ def find_maximum_likelihood_parameters(self, parameter_bounds, x0=updated_parameters_list, **maximization_kwargs, ) - updated_parameters_list = output['x'] + updated_parameters_list = output["x"] updated_parameters = deepcopy(self.fiducial_parameters) updated_parameters.update(self.get_parameter_dictionary_from_list(updated_parameters_list)) self.set_fiducial_waveforms(updated_parameters) @@ -375,9 +377,7 @@ def compute_waveform_ratio_per_interferometer(self, waveform_polarizations, inte def _compute_full_waveform(self, signal_polarizations, interferometer, parameters=None): fiducial_waveform = self.per_detector_fiducial_waveforms[interferometer.name] r0, r1 = self.compute_waveform_ratio_per_interferometer( - waveform_polarizations=signal_polarizations, - interferometer=interferometer, - parameters=parameters + waveform_polarizations=signal_polarizations, interferometer=interferometer, parameters=parameters ) idxs = slice(self.bin_inds[0], self.bin_inds[-1] + 1) @@ -400,7 +400,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr d_inner_h = np.sum(a0 * np.conjugate(r0) + a1 * np.conjugate(r1)) h_inner_h = np.sum(b0 * np.abs(r0) ** 2 + 2 * b1 * np.real(r0 * np.conjugate(r1))) optimal_snr_squared = h_inner_h - complex_matched_filter_snr = d_inner_h / (optimal_snr_squared ** 0.5) + complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) if return_array and self.time_marginalization: full_waveform = self._compute_full_waveform( @@ -408,10 +408,15 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr interferometer=interferometer, parameters=parameters, ) - d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft( - full_waveform[0:-1] - * interferometer.frequency_domain_strain.conjugate()[0:-1] - / interferometer.power_spectral_density_array[0:-1]) + d_inner_h_array = ( + 4 + / self.waveform_generator.duration + * np.fft.fft( + full_waveform[0:-1] + * interferometer.frequency_domain_strain.conjugate()[0:-1] + / interferometer.power_spectral_density_array[0:-1] + ) + ) else: d_inner_h_array = None @@ -420,7 +425,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr d_inner_h=d_inner_h, optimal_snr_squared=optimal_snr_squared.real, complex_matched_filter_snr=complex_matched_filter_snr, - d_inner_h_array=d_inner_h_array + d_inner_h_array=d_inner_h_array, ) def _set_fiducial(self): diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 0f5a4c003..6d1a4f4c0 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -1,13 +1,10 @@ - import numpy as np -from .base import GravitationalWaveTransient -from ...core.utils import ( - logger, create_frequency_series, speed_of_light, radius_of_earth -) from ...core.likelihood import _fallback_to_parameters +from ...core.utils import create_frequency_series, logger, radius_of_earth, speed_of_light from ..prior import CBCPriorDict from ..utils import ln_i0 +from .base import GravitationalWaveTransient class ROQGravitationalWaveTransient(GravitationalWaveTransient): @@ -83,28 +80,40 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): - e.g., "H1": sample in the time of arrival at H1 """ - def __init__( - self, interferometers, waveform_generator, priors, - weights=None, linear_matrix=None, quadratic_matrix=None, - roq_params=None, roq_params_check=True, roq_scale_factor=1, - distance_marginalization=False, phase_marginalization=False, - time_marginalization=False, jitter_time=True, delta_tc=None, - distance_marginalization_lookup_table=None, - reference_frame="sky", time_reference="geocenter", - parameter_conversion=None + def __init__( + self, + interferometers, + waveform_generator, + priors, + weights=None, + linear_matrix=None, + quadratic_matrix=None, + roq_params=None, + roq_params_check=True, + roq_scale_factor=1, + distance_marginalization=False, + phase_marginalization=False, + time_marginalization=False, + jitter_time=True, + delta_tc=None, + distance_marginalization_lookup_table=None, + reference_frame="sky", + time_reference="geocenter", + parameter_conversion=None, ): self._delta_tc = delta_tc - super(ROQGravitationalWaveTransient, self).__init__( + super().__init__( interferometers=interferometers, - waveform_generator=waveform_generator, priors=priors, + waveform_generator=waveform_generator, + priors=priors, distance_marginalization=distance_marginalization, phase_marginalization=phase_marginalization, time_marginalization=time_marginalization, distance_marginalization_lookup_table=distance_marginalization_lookup_table, jitter_time=jitter_time, reference_frame=reference_frame, - time_reference=time_reference + time_reference=time_reference, ) self.roq_params_check = roq_params_check @@ -121,37 +130,43 @@ def __init__( elif isinstance(weights, str): self.weights = self.load_weights(weights) else: - is_hdf5_linear = isinstance(linear_matrix, str) and linear_matrix.endswith('.hdf5') - linear_matrix = self._parse_basis(linear_matrix, 'linear') - is_hdf5_quadratic = isinstance(quadratic_matrix, str) and quadratic_matrix.endswith('.hdf5') - quadratic_matrix = self._parse_basis(quadratic_matrix, 'quadratic') + is_hdf5_linear = isinstance(linear_matrix, str) and linear_matrix.endswith(".hdf5") + linear_matrix = self._parse_basis(linear_matrix, "linear") + is_hdf5_quadratic = isinstance(quadratic_matrix, str) and quadratic_matrix.endswith(".hdf5") + quadratic_matrix = self._parse_basis(quadratic_matrix, "quadratic") # retrieve roq params from a basis file if it is .hdf5 if self.roq_params is None: if is_hdf5_linear: self.roq_params = np.array( - [(linear_matrix['minimum_frequency_hz'][()], - linear_matrix['maximum_frequency_hz'][()], - linear_matrix['duration_s'][()])], - dtype=[('flow', float), ('fhigh', float), ('seglen', float)] + [ + ( + linear_matrix["minimum_frequency_hz"][()], + linear_matrix["maximum_frequency_hz"][()], + linear_matrix["duration_s"][()], + ) + ], + dtype=[("flow", float), ("fhigh", float), ("seglen", float)], ) if is_hdf5_quadratic: if self.roq_params is None: self.roq_params = np.array( - [(quadratic_matrix['minimum_frequency_hz'][()], - quadratic_matrix['maximum_frequency_hz'][()], - quadratic_matrix['duration_s'][()])], - dtype=[('flow', float), ('fhigh', float), ('seglen', float)] + [ + ( + quadratic_matrix["minimum_frequency_hz"][()], + quadratic_matrix["maximum_frequency_hz"][()], + quadratic_matrix["duration_s"][()], + ) + ], + dtype=[("flow", float), ("fhigh", float), ("seglen", float)], ) else: - self.roq_params['flow'] = max( - self.roq_params['flow'], quadratic_matrix['minimum_frequency_hz'][()] - ) - self.roq_params['fhigh'] = min( - self.roq_params['fhigh'], quadratic_matrix['maximum_frequency_hz'][()] + self.roq_params["flow"] = max( + self.roq_params["flow"], quadratic_matrix["minimum_frequency_hz"][()] ) - self.roq_params['seglen'] = min( - self.roq_params['seglen'], quadratic_matrix['duration_s'][()] + self.roq_params["fhigh"] = min( + self.roq_params["fhigh"], quadratic_matrix["maximum_frequency_hz"][()] ) + self.roq_params["seglen"] = min(self.roq_params["seglen"], quadratic_matrix["duration_s"][()]) if self.roq_params is not None: for ifo in self.interferometers: self.perform_roq_params_check(ifo) @@ -163,13 +178,13 @@ def __init__( if is_hdf5_quadratic: quadratic_matrix.close() - self.number_of_bases_linear = len(self.weights[f'{self.interferometers[0].name}_linear']) - self.number_of_bases_quadratic = len(self.weights[f'{self.interferometers[0].name}_quadratic']) + self.number_of_bases_linear = len(self.weights[f"{self.interferometers[0].name}_linear"]) + self.number_of_bases_quadratic = len(self.weights[f"{self.interferometers[0].name}_quadratic"]) self._cache = dict(parameters=None, basis_number_linear=None, basis_number_quadratic=None) self.parameter_conversion = parameter_conversion - for basis_type in ['linear', 'quadratic']: - number_of_bases = getattr(self, f'number_of_bases_{basis_type}') + for basis_type in ["linear", "quadratic"]: + number_of_bases = getattr(self, f"number_of_bases_{basis_type}") if number_of_bases > 1: self._verify_numbers_of_prior_ranges_and_frequency_nodes(basis_type) else: @@ -179,11 +194,10 @@ def __init__( self._set_unique_frequency_nodes_and_inverse() # need to fill waveform_arguments here if single basis is used, as they will never be updated. if self.number_of_bases_linear == 1 and self.number_of_bases_quadratic == 1: - frequency_nodes, linear_indices, quadratic_indices = \ - self._unique_frequency_nodes_and_inverse[0][0] - self._waveform_generator.waveform_arguments['frequency_nodes'] = frequency_nodes - self._waveform_generator.waveform_arguments['linear_indices'] = linear_indices - self._waveform_generator.waveform_arguments['quadratic_indices'] = quadratic_indices + frequency_nodes, linear_indices, quadratic_indices = self._unique_frequency_nodes_and_inverse[0][0] + self._waveform_generator.waveform_arguments["frequency_nodes"] = frequency_nodes + self._waveform_generator.waveform_arguments["linear_indices"] = linear_indices + self._waveform_generator.waveform_arguments["quadratic_indices"] = quadratic_indices def _verify_numbers_of_prior_ranges_and_frequency_nodes(self, basis_type): """ @@ -195,29 +209,28 @@ def _verify_numbers_of_prior_ranges_and_frequency_nodes(self, basis_type): basis_type: str """ - number_of_bases = getattr(self, f'number_of_bases_{basis_type}') - key = f'prior_range_{basis_type}' + number_of_bases = getattr(self, f"number_of_bases_{basis_type}") + key = f"prior_range_{basis_type}" try: prior_ranges = self.weights[key] except KeyError: - raise AttributeError( - f'For the use of multiple {basis_type} ROQ bases, weights should contain "{key}".') + raise AttributeError(f'For the use of multiple {basis_type} ROQ bases, weights should contain "{key}".') else: for param_name in prior_ranges: if len(prior_ranges[param_name]) != number_of_bases: raise ValueError( - f'The number of prior ranges for "{param_name}" does not ' - f'match the number of {basis_type} bases') - key = f'frequency_nodes_{basis_type}' + f'The number of prior ranges for "{param_name}" does not match the number of {basis_type} bases' + ) + key = f"frequency_nodes_{basis_type}" try: frequency_nodes = self.weights[key] except KeyError: - raise AttributeError( - f'For the use of multiple {basis_type} ROQ bases, weights should contain "{key}".') + raise AttributeError(f'For the use of multiple {basis_type} ROQ bases, weights should contain "{key}".') else: if len(frequency_nodes) != number_of_bases: raise ValueError( - f'The number of arrays of frequency nodes does not match the number of {basis_type} bases') + f"The number of arrays of frequency nodes does not match the number of {basis_type} bases" + ) def _verify_prior_ranges(self, basis_type): """Check if the union of prior ranges is within the ROQ basis bounds. @@ -227,7 +240,7 @@ def _verify_prior_ranges(self, basis_type): basis_type: str """ - key = f'prior_range_{basis_type}' + key = f"prior_range_{basis_type}" if key not in self.weights: return prior_ranges = self.weights[key] @@ -236,16 +249,14 @@ def _verify_prior_ranges(self, basis_type): basis_minimum = np.min(prior_ranges_of_this_param[:, 0]) if prior_minimum < basis_minimum: raise BilbyROQParamsRangeError( - f"Prior minimum of {param_name} {prior_minimum} less " - f"than ROQ basis bound {basis_minimum}" + f"Prior minimum of {param_name} {prior_minimum} less than ROQ basis bound {basis_minimum}" ) prior_maximum = self.priors[param_name].maximum basis_maximum = np.max(prior_ranges_of_this_param[:, 1]) if prior_maximum > basis_maximum: raise BilbyROQParamsRangeError( - f"Prior maximum of {param_name} {prior_maximum} greater " - f"than ROQ basis bound {basis_maximum}" + f"Prior maximum of {param_name} {prior_maximum} greater than ROQ basis bound {basis_maximum}" ) def _check_frequency_nodes_exist_for_single_basis(self, basis_type): @@ -259,9 +270,9 @@ def _check_frequency_nodes_exist_for_single_basis(self, basis_type): basis_type: str """ - key = f'frequency_nodes_{basis_type}' + key = f"frequency_nodes_{basis_type}" if not (key in self.weights or key in self._waveform_generator.waveform_arguments): - raise AttributeError(f'{key} should be contained in weights or waveform arguments.') + raise AttributeError(f"{key} should be contained in weights or waveform arguments.") elif key not in self._waveform_generator.waveform_arguments: self._waveform_generator.waveform_arguments[key] = self.weights[key][0] elif key not in self.weights: @@ -274,33 +285,29 @@ def _set_unique_frequency_nodes_and_inverse(self): self._unique_frequency_nodes_and_inverse = [] for idx_linear in range(self.number_of_bases_linear): tmp = [] - frequency_nodes_linear = self.weights['frequency_nodes_linear'][idx_linear] + frequency_nodes_linear = self.weights["frequency_nodes_linear"][idx_linear] size_linear = len(frequency_nodes_linear) for idx_quadratic in range(self.number_of_bases_quadratic): - frequency_nodes_quadratic = self.weights['frequency_nodes_quadratic'][idx_quadratic] + frequency_nodes_quadratic = self.weights["frequency_nodes_quadratic"][idx_quadratic] frequency_nodes_unique, original_indices = np.unique( - np.hstack((frequency_nodes_linear, frequency_nodes_quadratic)), - return_inverse=True + np.hstack((frequency_nodes_linear, frequency_nodes_quadratic)), return_inverse=True ) linear_indices = original_indices[:size_linear] quadratic_indices = original_indices[size_linear:] - tmp.append( - (frequency_nodes_unique, linear_indices, quadratic_indices) - ) + tmp.append((frequency_nodes_unique, linear_indices, quadratic_indices)) self._unique_frequency_nodes_and_inverse.append(tmp) def _setup_time_marginalization(self): if self._delta_tc is None: self._delta_tc = self._get_time_resolution() - tcmin = self.priors['geocent_time'].minimum - tcmax = self.priors['geocent_time'].maximum + tcmin = self.priors["geocent_time"].minimum + tcmax = self.priors["geocent_time"].maximum number_of_time_samples = int(np.ceil((tcmax - tcmin) / self._delta_tc)) # adjust delta tc so that the last time sample has an equal weight self._delta_tc = (tcmax - tcmin) / number_of_time_samples - logger.info( - "delta tc for time marginalization = {} seconds.".format(self._delta_tc)) - self._times = tcmin + self._delta_tc / 2. + np.arange(number_of_time_samples) * self._delta_tc - self._beam_pattern_reference_time = (tcmin + tcmax) / 2. + logger.info(f"delta tc for time marginalization = {self._delta_tc} seconds.") + self._times = tcmin + self._delta_tc / 2.0 + np.arange(number_of_time_samples) * self._delta_tc + self._beam_pattern_reference_time = (tcmin + tcmax) / 2.0 @staticmethod def _parse_basis(basis, basis_type): @@ -319,22 +326,23 @@ def _parse_basis(basis, basis_type): basis : hdf5-like object """ - if basis_type not in ['linear', 'quadratic']: - raise ValueError(f'basis_type {basis_type} not recognized') + if basis_type not in ["linear", "quadratic"]: + raise ValueError(f"basis_type {basis_type} not recognized") if isinstance(basis, str): - logger.info(f'Loading {basis_type}_matrix from {basis}') - format = basis.split('.')[-1] - if format == 'npy': - basis = {f'basis_{basis_type}': {'0': {'basis': np.load(basis)}}} - elif format == 'hdf5': + logger.info(f"Loading {basis_type}_matrix from {basis}") + format = basis.split(".")[-1] + if format == "npy": + basis = {f"basis_{basis_type}": {"0": {"basis": np.load(basis)}}} + elif format == "hdf5": import h5py - basis = h5py.File(basis, 'r') + + basis = h5py.File(basis, "r") else: - raise IOError(f'Format {format} not recognized.') + raise OSError(f"Format {format} not recognized.") elif isinstance(basis, np.ndarray): - basis = {f'basis_{basis_type}': {'0': {'basis': basis.T}}} + basis = {f"basis_{basis_type}": {"0": {"basis": basis.T}}} else: - raise TypeError('basis needs to be str or np.ndarray') + raise TypeError("basis needs to be str or np.ndarray") return basis def _select_prior_ranges(self, prior_ranges): @@ -363,13 +371,13 @@ def _select_prior_ranges(self, prior_ranges): except KeyError: continue prior_ranges_of_this_param = prior_ranges[param_name] - in_prior_range *= \ - (prior_ranges_of_this_param[:, 1] >= prior.minimum) * \ - (prior_ranges_of_this_param[:, 0] <= prior.maximum) + in_prior_range *= (prior_ranges_of_this_param[:, 1] >= prior.minimum) * ( + prior_ranges_of_this_param[:, 0] <= prior.maximum + ) idxs_in_prior_range = np.arange(number_of_prior_ranges)[in_prior_range] - return idxs_in_prior_range, \ - dict((param_name, prior_ranges[param_name][idxs_in_prior_range]) - for param_name in param_names) + return idxs_in_prior_range, dict( + (param_name, prior_ranges[param_name][idxs_in_prior_range]) for param_name in param_names + ) def _update_basis(self, parameters=None): """ @@ -386,41 +394,42 @@ def _update_basis(self, parameters=None): if self._cache["parameters"] == parameters: return for basis_type, number_of_bases in zip( - ['linear', 'quadratic'], [self.number_of_bases_linear, self.number_of_bases_quadratic] + ["linear", "quadratic"], [self.number_of_bases_linear, self.number_of_bases_quadratic] ): - basis_number_key = f'basis_number_{basis_type}' + basis_number_key = f"basis_number_{basis_type}" if number_of_bases == 1: self._cache[basis_number_key] = 0 continue in_prior_range = np.ones(number_of_bases, dtype=bool) - prior_range_key = f'prior_range_{basis_type}' + prior_range_key = f"prior_range_{basis_type}" for param_name in self.weights[prior_range_key]: if param_name not in parameters: continue - in_prior_range *= \ - (self.weights[prior_range_key][param_name][:, 0] <= parameters[param_name]) * \ - (self.weights[prior_range_key][param_name][:, 1] >= parameters[param_name]) + in_prior_range *= (self.weights[prior_range_key][param_name][:, 0] <= parameters[param_name]) * ( + self.weights[prior_range_key][param_name][:, 1] >= parameters[param_name] + ) self._cache[basis_number_key] = np.arange(number_of_bases)[in_prior_range][0] - basis_number_linear = self._cache['basis_number_linear'] - basis_number_quadratic = self._cache['basis_number_quadratic'] - frequency_nodes, linear_indices, quadratic_indices = \ - self._unique_frequency_nodes_and_inverse[basis_number_linear][basis_number_quadratic] - self._waveform_generator.waveform_arguments['frequency_nodes'] = frequency_nodes - self._waveform_generator.waveform_arguments['linear_indices'] = linear_indices - self._waveform_generator.waveform_arguments['quadratic_indices'] = quadratic_indices - self._cache['parameters'] = parameters.copy() + basis_number_linear = self._cache["basis_number_linear"] + basis_number_quadratic = self._cache["basis_number_quadratic"] + frequency_nodes, linear_indices, quadratic_indices = self._unique_frequency_nodes_and_inverse[ + basis_number_linear + ][basis_number_quadratic] + self._waveform_generator.waveform_arguments["frequency_nodes"] = frequency_nodes + self._waveform_generator.waveform_arguments["linear_indices"] = linear_indices + self._waveform_generator.waveform_arguments["quadratic_indices"] = quadratic_indices + self._cache["parameters"] = parameters.copy() @property def basis_number_linear(self): if self.number_of_bases_linear > 1: - return self._cache['basis_number_linear'] + return self._cache["basis_number_linear"] else: return 0 @property def basis_number_quadratic(self): if self.number_of_bases_quadratic > 1: - return self._cache['basis_number_quadratic'] + return self._cache["basis_number_quadratic"] else: return 0 @@ -451,53 +460,49 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr if self.time_marginalization: time_ref = self._beam_pattern_reference_time else: - time_ref = parameters['geocent_time'] + time_ref = parameters["geocent_time"] - frequency_nodes = self.waveform_generator.waveform_arguments['frequency_nodes'] - linear_indices = self.waveform_generator.waveform_arguments['linear_indices'] - quadratic_indices = self.waveform_generator.waveform_arguments['quadratic_indices'] + frequency_nodes = self.waveform_generator.waveform_arguments["frequency_nodes"] + linear_indices = self.waveform_generator.waveform_arguments["linear_indices"] + quadratic_indices = self.waveform_generator.waveform_arguments["quadratic_indices"] size_linear = len(linear_indices) size_quadratic = len(quadratic_indices) h_linear = np.zeros(size_linear, dtype=complex) h_quadratic = np.zeros(size_quadratic, dtype=complex) - for mode in waveform_polarizations['linear']: + for mode in waveform_polarizations["linear"]: response = interferometer.antenna_response( - parameters['ra'], parameters['dec'], - time_ref, - parameters['psi'], - mode + parameters["ra"], parameters["dec"], time_ref, parameters["psi"], mode ) - h_linear += waveform_polarizations['linear'][mode] * response - h_quadratic += waveform_polarizations['quadratic'][mode] * response + h_linear += waveform_polarizations["linear"][mode] * response + h_quadratic += waveform_polarizations["quadratic"][mode] * response calib_factor = interferometer.calibration_model.get_calibration_factor( - frequency_nodes, prefix='recalib_{}_'.format(interferometer.name), **parameters) + frequency_nodes, prefix=f"recalib_{interferometer.name}_", **parameters + ) h_linear *= calib_factor[linear_indices] h_quadratic *= calib_factor[quadratic_indices] optimal_snr_squared = np.vdot( - np.abs(h_quadratic)**2, - self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic] + np.abs(h_quadratic) ** 2, self.weights[interferometer.name + "_quadratic"][self.basis_number_quadratic] ) - dt = interferometer.time_delay_from_geocenter( - parameters['ra'], parameters['dec'], time_ref) - dt_geocent = parameters['geocent_time'] - interferometer.strain_data.start_time + dt = interferometer.time_delay_from_geocenter(parameters["ra"], parameters["dec"], time_ref) + dt_geocent = parameters["geocent_time"] - interferometer.strain_data.start_time ifo_time = dt_geocent + dt - indices, in_bounds = self._closest_time_indices( - ifo_time, self.weights['time_samples']) + indices, in_bounds = self._closest_time_indices(ifo_time, self.weights["time_samples"]) if not in_bounds: logger.debug("SNR calculation error: requested time at edge of ROQ time samples") d_inner_h = -np.inf complex_matched_filter_snr = -np.inf else: d_inner_h_tc_array = np.einsum( - 'i,ji->j', np.conjugate(h_linear), - self.weights[interferometer.name + '_linear'][self.basis_number_linear][indices]) + "i,ji->j", + np.conjugate(h_linear), + self.weights[interferometer.name + "_linear"][self.basis_number_linear][indices], + ) - d_inner_h = self._interp_five_samples( - self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time) + d_inner_h = self._interp_five_samples(self.weights["time_samples"][indices], d_inner_h_tc_array, ifo_time) with np.errstate(invalid="ignore"): complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) @@ -506,7 +511,7 @@ def calculate_snrs(self, waveform_polarizations, interferometer, return_array=Tr ifo_times = self._times - interferometer.strain_data.start_time ifo_times += dt if self.jitter_time: - ifo_times += parameters['time_jitter'] + ifo_times += parameters["time_jitter"] d_inner_h_array = self._calculate_d_inner_h_array(ifo_times, h_linear, interferometer.name) else: d_inner_h_array = None @@ -562,12 +567,12 @@ def _interp_five_samples(time_samples, values, time): value: float The value of the function at the input time """ - r1 = (-values[0] + 8. * values[1] - 14. * values[2] + 8. * values[3] - values[4]) / 4. - r2 = values[2] - 2. * values[3] + values[4] + r1 = (-values[0] + 8.0 * values[1] - 14.0 * values[2] + 8.0 * values[3] - values[4]) / 4.0 + r2 = values[2] - 2.0 * values[3] + values[4] a = (time_samples[3] - time) / (time_samples[1] - time_samples[0]) - b = 1. - a - c = (a**3. - a) / 6. - d = (b**3. - b) / 6. + b = 1.0 - a + c = (a**3.0 - a) / 6.0 + d = (b**3.0 - b) / 6.0 return a * values[2] + b * values[3] + c * r1 + d * r2 def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): @@ -588,12 +593,12 @@ def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): ======= d_inner_h_array: array-like """ - roq_time_space = self.weights['time_samples'][1] - self.weights['time_samples'][0] - times_per_roq_time_space = (times - self.weights['time_samples'][0]) / roq_time_space + roq_time_space = self.weights["time_samples"][1] - self.weights["time_samples"][0] + times_per_roq_time_space = (times - self.weights["time_samples"][0]) / roq_time_space closest_idxs = np.floor(times_per_roq_time_space).astype(int) # Get the nearest 5 samples of d_inner_h. Calculate only the required d_inner_h values if the time # spacing is larger than 5 times the ROQ time spacing. - weights_linear = self.weights[ifo_name + '_linear'][self.basis_number_linear] + weights_linear = self.weights[ifo_name + "_linear"][self.basis_number_linear] h_linear_conj = np.conjugate(h_linear) if (times[1] - times[0]) / roq_time_space > 5: d_inner_h_m2 = np.dot(weights_linear[closest_idxs - 2], h_linear_conj) @@ -610,15 +615,15 @@ def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): d_inner_h_p2 = d_inner_h_at_roq_time_samples[closest_idxs + 2] # quantities required for spline interpolation b = times_per_roq_time_space - closest_idxs - a = 1. - b - c = (a**3. - a) / 6. - d = (b**3. - b) / 6. - r1 = (-d_inner_h_m2 + 8. * d_inner_h_m1 - 14. * d_inner_h_0 + 8. * d_inner_h_p1 - d_inner_h_p2) / 4. - r2 = d_inner_h_0 - 2. * d_inner_h_p1 + d_inner_h_p2 + a = 1.0 - b + c = (a**3.0 - a) / 6.0 + d = (b**3.0 - b) / 6.0 + r1 = (-d_inner_h_m2 + 8.0 * d_inner_h_m1 - 14.0 * d_inner_h_0 + 8.0 * d_inner_h_p1 - d_inner_h_p2) / 4.0 + r2 = d_inner_h_0 - 2.0 * d_inner_h_p1 + d_inner_h_p2 return a * d_inner_h_0 + b * d_inner_h_p1 + c * r1 + d * r2 def perform_roq_params_check(self, ifo=None): - """ Perform checking that the prior and data are valid for the ROQ + """Perform checking that the prior and data are valid for the ROQ Parameters ========== @@ -630,43 +635,39 @@ def perform_roq_params_check(self, ifo=None): return else: if getattr(self, "roq_params_file", None) is not None: - msg = ("Check ROQ params {} with roq_scale_factor={}" - .format(self.roq_params_file, self.roq_scale_factor)) + msg = f"Check ROQ params {self.roq_params_file} with roq_scale_factor={self.roq_scale_factor}" else: - msg = ("Check ROQ params with roq_scale_factor={}" - .format(self.roq_scale_factor)) + msg = f"Check ROQ params with roq_scale_factor={self.roq_scale_factor}" logger.info(msg) roq_params = self.roq_params - roq_minimum_frequency = roq_params['flow'] * self.roq_scale_factor - roq_maximum_frequency = roq_params['fhigh'] * self.roq_scale_factor - roq_segment_length = roq_params['seglen'] / self.roq_scale_factor + roq_minimum_frequency = roq_params["flow"] * self.roq_scale_factor + roq_maximum_frequency = roq_params["fhigh"] * self.roq_scale_factor + roq_segment_length = roq_params["seglen"] / self.roq_scale_factor try: - roq_minimum_chirp_mass = roq_params['chirpmassmin'] / self.roq_scale_factor + roq_minimum_chirp_mass = roq_params["chirpmassmin"] / self.roq_scale_factor except ValueError: roq_minimum_chirp_mass = None try: - roq_maximum_chirp_mass = roq_params['chirpmassmax'] / self.roq_scale_factor + roq_maximum_chirp_mass = roq_params["chirpmassmax"] / self.roq_scale_factor except ValueError: roq_maximum_chirp_mass = None try: - roq_minimum_component_mass = roq_params['compmin'] / self.roq_scale_factor + roq_minimum_component_mass = roq_params["compmin"] / self.roq_scale_factor except ValueError: roq_minimum_component_mass = None if ifo.maximum_frequency > roq_maximum_frequency: raise BilbyROQParamsRangeError( - "Requested maximum frequency {} larger than ROQ basis fhigh {}" - .format(ifo.maximum_frequency, roq_maximum_frequency) + f"Requested maximum frequency {ifo.maximum_frequency} larger than ROQ basis " + f"fhigh {roq_maximum_frequency}" ) if ifo.minimum_frequency < roq_minimum_frequency: raise BilbyROQParamsRangeError( - "Requested minimum frequency {} lower than ROQ basis flow {}" - .format(ifo.minimum_frequency, roq_minimum_frequency) + f"Requested minimum frequency {ifo.minimum_frequency} lower than ROQ basis flow {roq_minimum_frequency}" ) if ifo.strain_data.duration != roq_segment_length: - raise BilbyROQParamsRangeError( - "Requested duration differs from ROQ basis seglen") + raise BilbyROQParamsRangeError("Requested duration differs from ROQ basis seglen") priors = self.priors if isinstance(priors, CBCPriorDict) is False: @@ -678,8 +679,8 @@ def perform_roq_params_check(self, ifo=None): logger.warning("Unable to check minimum chirp mass ROQ bounds") elif priors.minimum_chirp_mass < roq_minimum_chirp_mass: raise BilbyROQParamsRangeError( - "Prior minimum chirp mass {} less than ROQ basis bound {}" - .format(priors.minimum_chirp_mass, roq_minimum_chirp_mass) + f"Prior minimum chirp mass {priors.minimum_chirp_mass} less than " + f"ROQ basis bound {roq_minimum_chirp_mass}" ) if roq_maximum_chirp_mass is not None: @@ -687,8 +688,8 @@ def perform_roq_params_check(self, ifo=None): logger.warning("Unable to check maximum_chirp mass ROQ bounds") elif priors.maximum_chirp_mass > roq_maximum_chirp_mass: raise BilbyROQParamsRangeError( - "Prior maximum chirp mass {} greater than ROQ basis bound {}" - .format(priors.maximum_chirp_mass, roq_maximum_chirp_mass) + f"Prior maximum chirp mass {priors.maximum_chirp_mass} greater than " + f"ROQ basis bound {roq_maximum_chirp_mass}" ) if roq_minimum_component_mass is not None: @@ -696,8 +697,8 @@ def perform_roq_params_check(self, ifo=None): logger.warning("Unable to check minimum component mass ROQ bounds") elif priors.minimum_component_mass < roq_minimum_component_mass: raise BilbyROQParamsRangeError( - "Prior minimum component mass {} less than ROQ basis bound {}" - .format(priors.minimum_component_mass, roq_minimum_component_mass) + f"Prior minimum component mass {priors.minimum_component_mass} less " + f"than ROQ basis bound {roq_minimum_component_mass}" ) def _set_weights(self, linear_matrix, quadratic_matrix): @@ -715,34 +716,44 @@ def _set_weights(self, linear_matrix, quadratic_matrix): earth_light_crossing_time = 2 * radius_of_earth / speed_of_light + 5 * time_space start_idx = max( 0, - int(np.floor(( - self.priors['{}_time'.format(self.time_reference)].minimum - - earth_light_crossing_time - - self.interferometers.start_time - ) / time_space)) + int( + np.floor( + ( + self.priors[f"{self.time_reference}_time"].minimum + - earth_light_crossing_time + - self.interferometers.start_time + ) + / time_space + ) + ), ) end_idx = min( number_of_time_samples - 1, - int(np.ceil(( - self.priors['{}_time'.format(self.time_reference)].maximum - + earth_light_crossing_time - - self.interferometers.start_time - ) / time_space)) + int( + np.ceil( + ( + self.priors[f"{self.time_reference}_time"].maximum + + earth_light_crossing_time + - self.interferometers.start_time + ) + / time_space + ) + ), ) - self.weights['time_samples'] = np.arange(start_idx, end_idx + 1) * time_space - logger.info("Using {} ROQ time samples".format(len(self.weights['time_samples']))) + self.weights["time_samples"] = np.arange(start_idx, end_idx + 1) * time_space + logger.info("Using {} ROQ time samples".format(len(self.weights["time_samples"]))) # select bases to be used, set prior ranges and frequency nodes if exist idxs_in_prior_range = dict() - for basis_type, matrix in zip(['linear', 'quadratic'], [linear_matrix, quadratic_matrix]): - key = f'prior_range_{basis_type}' + for basis_type, matrix in zip(["linear", "quadratic"], [linear_matrix, quadratic_matrix]): + key = f"prior_range_{basis_type}" if key in matrix: prior_ranges = {} for param_name in matrix[key]: - if 'roq_scale_power' in matrix[key][param_name].attrs: - roq_scale_factor = self.roq_scale_factor**matrix[key][param_name].attrs['roq_scale_power'] + if "roq_scale_power" in matrix[key][param_name].attrs: + roq_scale_factor = self.roq_scale_factor ** matrix[key][param_name].attrs["roq_scale_power"] else: - roq_scale_factor = 1. + roq_scale_factor = 1.0 prior_ranges[param_name] = matrix[key][param_name][()] * roq_scale_factor selected_idxs, selected_prior_ranges = self._select_prior_ranges(prior_ranges) if len(selected_idxs) == 0: @@ -751,17 +762,18 @@ def _set_weights(self, linear_matrix, quadratic_matrix): idxs_in_prior_range[basis_type] = selected_idxs else: idxs_in_prior_range[basis_type] = [0] - if 'frequency_nodes' in matrix[f'basis_{basis_type}'][str(idxs_in_prior_range[basis_type][0])]: - self.weights[f'frequency_nodes_{basis_type}'] = [ - matrix[f'basis_{basis_type}'][str(i)]['frequency_nodes'][()] * self.roq_scale_factor - for i in idxs_in_prior_range[basis_type]] - - if 'multiband_linear' in linear_matrix: - multiband_linear = linear_matrix['multiband_linear'][()] + if "frequency_nodes" in matrix[f"basis_{basis_type}"][str(idxs_in_prior_range[basis_type][0])]: + self.weights[f"frequency_nodes_{basis_type}"] = [ + matrix[f"basis_{basis_type}"][str(i)]["frequency_nodes"][()] * self.roq_scale_factor + for i in idxs_in_prior_range[basis_type] + ] + + if "multiband_linear" in linear_matrix: + multiband_linear = linear_matrix["multiband_linear"][()] else: multiband_linear = False - if 'multiband_quadratic' in quadratic_matrix: - multiband_quadratic = quadratic_matrix['multiband_quadratic'][()] + if "multiband_quadratic" in quadratic_matrix: + multiband_quadratic = quadratic_matrix["multiband_quadratic"][()] else: multiband_quadratic = False @@ -772,45 +784,43 @@ def _set_weights(self, linear_matrix, quadratic_matrix): for ifo in self.interferometers: if self.roq_params is not None: # Get scaled ROQ quantities - roq_scaled_minimum_frequency = self.roq_params['flow'] * self.roq_scale_factor - roq_scaled_maximum_frequency = self.roq_params['fhigh'] * self.roq_scale_factor - roq_scaled_segment_length = self.roq_params['seglen'] / self.roq_scale_factor + roq_scaled_minimum_frequency = self.roq_params["flow"] * self.roq_scale_factor + roq_scaled_maximum_frequency = self.roq_params["fhigh"] * self.roq_scale_factor + roq_scaled_segment_length = self.roq_params["seglen"] / self.roq_scale_factor # Generate frequencies for the ROQ roq_frequencies = create_frequency_series( - sampling_frequency=roq_scaled_maximum_frequency * 2, - duration=roq_scaled_segment_length) + sampling_frequency=roq_scaled_maximum_frequency * 2, duration=roq_scaled_segment_length + ) roq_mask = roq_frequencies >= roq_scaled_minimum_frequency roq_frequencies = roq_frequencies[roq_mask] overlap_frequencies, ifo_idxs_this_ifo, roq_idxs_this_ifo = np.intersect1d( - ifo.frequency_array[ifo.frequency_mask], roq_frequencies, - return_indices=True) + ifo.frequency_array[ifo.frequency_mask], roq_frequencies, return_indices=True + ) else: overlap_frequencies = ifo.frequency_array[ifo.frequency_mask] roq_idxs_this_ifo = np.arange( - linear_matrix['basis_linear'][str(idxs_in_prior_range['linear'][0])]['basis'].shape[1], - dtype=int) + linear_matrix["basis_linear"][str(idxs_in_prior_range["linear"][0])]["basis"].shape[1], + dtype=int, + ) ifo_idxs_this_ifo = np.arange(sum(ifo.frequency_mask)) if len(ifo_idxs_this_ifo) != len(roq_idxs_this_ifo): - raise ValueError( - "Mismatch between ROQ basis and frequency array for " - "{}".format(ifo.name)) + raise ValueError(f"Mismatch between ROQ basis and frequency array for {ifo.name}") logger.info( - "Building ROQ weights for {} with {} frequencies between {} " - "and {}.".format( - ifo.name, len(overlap_frequencies), - min(overlap_frequencies), max(overlap_frequencies))) + f"Building ROQ weights for {ifo.name} with {len(overlap_frequencies)} " + f"frequencies between {min(overlap_frequencies)} and {max(overlap_frequencies)}." + ) roq_idxs[ifo.name] = roq_idxs_this_ifo ifo_idxs[ifo.name] = ifo_idxs_this_ifo if multiband_linear: - self._set_weights_linear_multiband(linear_matrix, idxs_in_prior_range['linear']) + self._set_weights_linear_multiband(linear_matrix, idxs_in_prior_range["linear"]) else: - self._set_weights_linear(linear_matrix, idxs_in_prior_range['linear'], roq_idxs, ifo_idxs) + self._set_weights_linear(linear_matrix, idxs_in_prior_range["linear"], roq_idxs, ifo_idxs) if multiband_quadratic: - self._set_weights_quadratic_multiband(quadratic_matrix, idxs_in_prior_range['quadratic']) + self._set_weights_quadratic_multiband(quadratic_matrix, idxs_in_prior_range["quadratic"]) else: - self._set_weights_quadratic(quadratic_matrix, idxs_in_prior_range['quadratic'], roq_idxs, ifo_idxs) + self._set_weights_quadratic(quadratic_matrix, idxs_in_prior_range["quadratic"], roq_idxs, ifo_idxs) def _set_weights_linear(self, linear_matrix, basis_idxs, roq_idxs, ifo_idxs): """ @@ -832,23 +842,27 @@ def _set_weights_linear(self, linear_matrix, basis_idxs, roq_idxs, ifo_idxs): """ for ifo in self.interferometers: - self.weights[ifo.name + '_linear'] = [] - time_space = self.weights['time_samples'][1] - self.weights['time_samples'][0] + self.weights[ifo.name + "_linear"] = [] + time_space = self.weights["time_samples"][1] - self.weights["time_samples"][0] number_of_time_samples = int(self.interferometers.duration / time_space) - start_idx = int(self.weights['time_samples'][0] / time_space) - end_idx = int(self.weights['time_samples'][-1] / time_space) + start_idx = int(self.weights["time_samples"][0] / time_space) + end_idx = int(self.weights["time_samples"][-1] / time_space) nonzero_idxs = {} data_over_psd = {} for ifo in self.interferometers: nonzero_idxs[ifo.name] = ifo_idxs[ifo.name] + int( - ifo.frequency_array[ifo.frequency_mask][0] * self.interferometers.duration) - data_over_psd[ifo.name] = ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs[ifo.name]] / \ - ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]] + ifo.frequency_array[ifo.frequency_mask][0] * self.interferometers.duration + ) + data_over_psd[ifo.name] = ( + ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs[ifo.name]] + / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]] + ) try: import pyfftw + ifft_input = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex) - ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD') + ifft = pyfftw.FFTW(ifft_input, ifft_output, direction="FFTW_BACKWARD") except ImportError: pyfftw = None logger.warning("You do not have pyfftw installed, falling back to numpy.fft.") @@ -856,18 +870,17 @@ def _set_weights_linear(self, linear_matrix, basis_idxs, roq_idxs, ifo_idxs): ifft = np.fft.ifft for basis_idx in basis_idxs: logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.") - linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'] + linear_matrix_single = linear_matrix["basis_linear"][str(basis_idx)]["basis"] basis_size = linear_matrix_single.shape[0] for ifo in self.interferometers: - ifft_input[:] *= 0. - linear_weights = \ - np.zeros((len(self.weights['time_samples']), basis_size), dtype=complex) + ifft_input[:] *= 0.0 + linear_weights = np.zeros((len(self.weights["time_samples"]), basis_size), dtype=complex) for i in range(basis_size): basis_element = linear_matrix_single[i][roq_idxs[ifo.name]] ifft_input[nonzero_idxs[ifo.name]] = data_over_psd[ifo.name] * np.conj(basis_element) - linear_weights[:, i] = ifft(ifft_input)[start_idx:end_idx + 1] - linear_weights *= 4. * number_of_time_samples / self.interferometers.duration - self.weights[ifo.name + '_linear'].append(linear_weights) + linear_weights[:, i] = ifft(ifft_input)[start_idx : end_idx + 1] + linear_weights *= 4.0 * number_of_time_samples / self.interferometers.duration + self.weights[ifo.name + "_linear"].append(linear_weights) if pyfftw is not None: pyfftw.forget_wisdom() @@ -884,37 +897,45 @@ def _set_weights_linear_multiband(self, linear_matrix, basis_idxs): """ for ifo in self.interferometers: - self.weights[ifo.name + '_linear'] = [] - Tbs = linear_matrix['durations_s_linear'][()] / self.roq_scale_factor - start_end_frequency_bins = linear_matrix['start_end_frequency_bins_linear'][()] + self.weights[ifo.name + "_linear"] = [] + Tbs = linear_matrix["durations_s_linear"][()] / self.roq_scale_factor + start_end_frequency_bins = linear_matrix["start_end_frequency_bins_linear"][()] basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1) fhigh_basis = np.max(start_end_frequency_bins[:, 1] / Tbs) # prepare time-shifted data, which is multiplied by basis tc_shifted_data = dict() for ifo in self.interferometers: over_whitened_frequency_data = np.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex) - over_whitened_frequency_data[np.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask]] = \ + over_whitened_frequency_data[np.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask]] = ( ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] + ) over_whitened_time_data = np.fft.irfft(over_whitened_frequency_data) - tc_shifted_data[ifo.name] = np.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex) + tc_shifted_data[ifo.name] = np.zeros((basis_dimension, len(self.weights["time_samples"])), dtype=complex) start_idx_of_band = 0 for b, Tb in enumerate(Tbs): start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b] fs = np.arange(start_frequency_bin, end_frequency_bin + 1) / Tb - Db = np.fft.rfft( - over_whitened_time_data[-int(2. * fhigh_basis * Tb):] - )[start_frequency_bin:end_frequency_bin + 1] + Db = np.fft.rfft(over_whitened_time_data[-int(2.0 * fhigh_basis * Tb) :])[ + start_frequency_bin : end_frequency_bin + 1 + ] start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1 - tc_shifted_data[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * Db[:, None] * np.exp( - 2. * np.pi * 1j * fs[:, None] * (self.weights['time_samples'][None, :] - ifo.duration + Tb)) + tc_shifted_data[ifo.name][start_idx_of_band:start_idx_of_next_band] = ( + 4.0 + / Tb + * Db[:, None] + * np.exp( + 2.0 * np.pi * 1j * fs[:, None] * (self.weights["time_samples"][None, :] - ifo.duration + Tb) + ) + ) start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.") - linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'][()] + linear_matrix_single = linear_matrix["basis_linear"][str(basis_idx)]["basis"][()] for ifo in self.interferometers: - self.weights[ifo.name + '_linear'].append( - np.dot(np.conj(linear_matrix_single), tc_shifted_data[ifo.name]).T) + self.weights[ifo.name + "_linear"].append( + np.dot(np.conj(linear_matrix_single), tc_shifted_data[ifo.name]).T + ) def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idxs): """ @@ -935,15 +956,19 @@ def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idx """ for ifo in self.interferometers: - self.weights[ifo.name + '_quadratic'] = [] + self.weights[ifo.name + "_quadratic"] = [] for basis_idx in basis_idxs: logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.") - quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real + quadratic_matrix_single = quadratic_matrix["basis_quadratic"][str(basis_idx)]["basis"][()].real for ifo in self.interferometers: - self.weights[ifo.name + '_quadratic'].append( - 4. / ifo.strain_data.duration * np.dot( + self.weights[ifo.name + "_quadratic"].append( + 4.0 + / ifo.strain_data.duration + * np.dot( quadratic_matrix_single[:, roq_idxs[ifo.name]], - 1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]])) + 1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]], + ) + ) del quadratic_matrix_single def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): @@ -959,17 +984,18 @@ def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): """ for ifo in self.interferometers: - self.weights[ifo.name + '_quadratic'] = [] - Tbs = quadratic_matrix['durations_s_quadratic'][()] / self.roq_scale_factor - start_end_frequency_bins = quadratic_matrix['start_end_frequency_bins_quadratic'][()] + self.weights[ifo.name + "_quadratic"] = [] + Tbs = quadratic_matrix["durations_s_quadratic"][()] / self.roq_scale_factor + start_end_frequency_bins = quadratic_matrix["start_end_frequency_bins_quadratic"][()] basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1) fhigh_basis = np.max(start_end_frequency_bins[:, 1] / Tbs) # prepare coefficients multiplied by basis multibanded_inverse_psd = dict() for ifo in self.interferometers: inverse_psd_frequency = np.zeros(int(fhigh_basis * ifo.duration) + 1) - inverse_psd_frequency[np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask]] = \ - 1. / ifo.power_spectral_density_array[ifo.frequency_mask] + inverse_psd_frequency[np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask]] = ( + 1.0 / ifo.power_spectral_density_array[ifo.frequency_mask] + ) inverse_psd_time = np.fft.irfft(inverse_psd_frequency) multibanded_inverse_psd[ifo.name] = np.zeros(basis_dimension) start_idx_of_band = 0 @@ -977,19 +1003,24 @@ def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs): start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b] number_of_samples_half = int(fhigh_basis * Tb) start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1 - multibanded_inverse_psd[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * np.fft.rfft( - np.append(inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:]) - )[start_frequency_bin:end_frequency_bin + 1].real + multibanded_inverse_psd[ifo.name][start_idx_of_band:start_idx_of_next_band] = ( + 4.0 + / Tb + * np.fft.rfft( + np.append(inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:]) + )[start_frequency_bin : end_frequency_bin + 1].real + ) start_idx_of_band = start_idx_of_next_band # compute inner products for basis_idx in basis_idxs: logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.") - quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real + quadratic_matrix_single = quadratic_matrix["basis_quadratic"][str(basis_idx)]["basis"][()].real for ifo in self.interferometers: - self.weights[ifo.name + '_quadratic'].append( - np.dot(quadratic_matrix_single, multibanded_inverse_psd[ifo.name])) + self.weights[ifo.name + "_quadratic"].append( + np.dot(quadratic_matrix_single, multibanded_inverse_psd[ifo.name]) + ) - def save_weights(self, filename, format='hdf5'): + def save_weights(self, filename, format="hdf5"): """ Save ROQ weights into a single file. Support for json format was removed in :code:`v2.7`, only hdf5 and npz are supported. @@ -1001,45 +1032,42 @@ def save_weights(self, filename, format='hdf5'): format : str The format to save the weight in, should be in :code:`hdf5, npz`. """ - if format not in ['hdf5', 'npz']: - raise IOError(f"Format {format} not recognized.") + if format not in ["hdf5", "npz"]: + raise OSError(f"Format {format} not recognized.") if format not in filename: filename += "." + format logger.info(f"Saving ROQ weights to {filename}") - if format == 'npz': + if format == "npz": if self.number_of_bases_linear > 1 or self.number_of_bases_quadratic > 1: - raise ValueError(f'Format {format} not compatible with multiple bases') + raise ValueError(f"Format {format} not compatible with multiple bases") weights = dict() - weights['time_samples'] = self.weights['time_samples'] - for basis_type in ['linear', 'quadratic']: + weights["time_samples"] = self.weights["time_samples"] + for basis_type in ["linear", "quadratic"]: for ifo in self.interferometers: - key = f'{ifo.name}_{basis_type}' + key = f"{ifo.name}_{basis_type}" weights[key] = self.weights[key][0] np.savez(filename, **weights) else: import h5py - with h5py.File(filename, 'w') as f: - f.create_dataset('time_samples', - data=self.weights['time_samples']) - for basis_type in ['linear', 'quadratic']: - key = f'prior_range_{basis_type}' + + with h5py.File(filename, "w") as f: + f.create_dataset("time_samples", data=self.weights["time_samples"]) + for basis_type in ["linear", "quadratic"]: + key = f"prior_range_{basis_type}" if key in self.weights: grp = f.create_group(key) for param_name in self.weights[key]: - grp.create_dataset( - param_name, data=self.weights[key][param_name]) - key = f'frequency_nodes_{basis_type}' + grp.create_dataset(param_name, data=self.weights[key][param_name]) + key = f"frequency_nodes_{basis_type}" if key in self.weights: grp = f.create_group(key) for i in range(len(self.weights[key])): - grp.create_dataset( - str(i), data=self.weights[key][i]) + grp.create_dataset(str(i), data=self.weights[key][i]) for ifo in self.interferometers: key = f"{ifo.name}_{basis_type}" grp = f.create_group(key) for i in range(len(self.weights[key])): - grp.create_dataset( - str(i), data=self.weights[key][i]) + grp.create_dataset(str(i), data=self.weights[key][i]) def load_weights(self, filename, format=None): """ @@ -1062,41 +1090,36 @@ def load_weights(self, filename, format=None): if format == "json": import warnings - warnings.warn( - "json format for ROQ weights is deprecated, use hdf5 instead.", - DeprecationWarning - ) + warnings.warn("json format for ROQ weights is deprecated, use hdf5 instead.", DeprecationWarning) elif format not in ["npz", "hdf5"]: - raise IOError(f"Format {format} not recognized.") + raise OSError(f"Format {format} not recognized.") logger.info(f"Loading ROQ weights from {filename}") if format == "npz": weights = dict(np.load(filename)) - for basis_type in ['linear', 'quadratic']: + for basis_type in ["linear", "quadratic"]: for ifo in self.interferometers: - key = f'{ifo.name}_{basis_type}' + key = f"{ifo.name}_{basis_type}" weights[key] = [weights[key]] else: weights = dict() import h5py - with h5py.File(filename, 'r') as f: - weights['time_samples'] = f['time_samples'][()] - for basis_type in ['linear', 'quadratic']: - key = f'prior_range_{basis_type}' + + with h5py.File(filename, "r") as f: + weights["time_samples"] = f["time_samples"][()] + for basis_type in ["linear", "quadratic"]: + key = f"prior_range_{basis_type}" if key in f: - idxs_in_prior_range, selected_prior_ranges = \ - self._select_prior_ranges(f[key]) + idxs_in_prior_range, selected_prior_ranges = self._select_prior_ranges(f[key]) weights[key] = selected_prior_ranges else: idxs_in_prior_range = [0] - key = f'frequency_nodes_{basis_type}' + key = f"frequency_nodes_{basis_type}" if key in f: - weights[key] = [f[key][str(i)][()] - for i in idxs_in_prior_range] + weights[key] = [f[key][str(i)][()] for i in idxs_in_prior_range] for ifo in self.interferometers: key = f"{ifo.name}_{basis_type}" - weights[key] = [f[key][str(i)][()] - for i in idxs_in_prior_range] + weights[key] = [f[key][str(i)][()] for i in idxs_in_prior_range] return weights def _get_time_resolution(self): @@ -1115,7 +1138,7 @@ def _get_time_resolution(self): Time resolution """ - def calc_fhigh(freq, psd, scaling=20.): + def calc_fhigh(freq, psd, scaling=20.0): """ Parameters @@ -1133,21 +1156,22 @@ def calc_fhigh(freq, psd, scaling=20.): The maximum frequency which must be considered """ from scipy.integrate import simpson - integrand1 = np.power(freq, -7. / 3) / psd + + integrand1 = np.power(freq, -7.0 / 3) / psd integral1 = simpson(y=integrand1, x=freq) - integrand3 = np.power(freq, 2. / 3.) / (psd * integral1) + integrand3 = np.power(freq, 2.0 / 3.0) / (psd * integral1) f_3_bar = simpson(y=integrand3, x=freq) - f_high = scaling * f_3_bar**(1 / 3) + f_high = scaling * f_3_bar ** (1 / 3) return f_high def c_f_scaling(snr): - return (np.pi**2 * snr**2 / 6)**(1 / 3) + return (np.pi**2 * snr**2 / 6) ** (1 / 3) inj_snr_sq = 0 for ifo in self.interferometers: - inj_snr_sq += max(10, ifo.meta_data.get('optimal_SNR', 30))**2 + inj_snr_sq += max(10, ifo.meta_data.get("optimal_SNR", 30)) ** 2 psd = ifo.power_spectral_density_array[ifo.frequency_mask] freq = ifo.frequency_array[ifo.frequency_mask] @@ -1161,14 +1185,15 @@ def c_f_scaling(snr): # duration / delta_t needs to be a power of 2 for IFFT number_of_time_samples = max( self.interferometers.duration / delta_t, - self.interferometers.frequency_array[-1] * self.interferometers.duration + 1) - number_of_time_samples = int(2**np.ceil(np.log2(number_of_time_samples))) + self.interferometers.frequency_array[-1] * self.interferometers.duration + 1, + ) + number_of_time_samples = int(2 ** np.ceil(np.log2(number_of_time_samples))) delta_t = self.interferometers.duration / number_of_time_samples - logger.info("ROQ time-step = {}".format(delta_t)) + logger.info(f"ROQ time-step = {delta_t}") return delta_t def _rescale_signal(self, signal, new_distance): - for kind in ['linear', 'quadratic']: + for kind in ["linear", "quadratic"]: for mode in signal[kind]: signal[kind][mode] *= self._ref_dist / new_distance @@ -1178,31 +1203,26 @@ def generate_time_sample_from_marginalized_likelihood(self, signal_polarizations parameters = _fallback_to_parameters(self, parameters) parameters.update(self.get_sky_frame_parameters(parameters=parameters)) if signal_polarizations is None: - signal_polarizations = \ - self.waveform_generator.frequency_domain_strain(parameters) + signal_polarizations = self.waveform_generator.frequency_domain_strain(parameters) snrs = self._CalculatedSNRs() for interferometer in self.interferometers: - snrs += self.calculate_snrs( - waveform_polarizations=signal_polarizations, - interferometer=interferometer - ) + snrs += self.calculate_snrs(waveform_polarizations=signal_polarizations, interferometer=interferometer) d_inner_h = snrs.d_inner_h_array h_inner_h = snrs.optimal_snr_squared if self.distance_marginalization: - time_log_like = self.distance_marginalized_likelihood( - d_inner_h, h_inner_h) + time_log_like = self.distance_marginalized_likelihood(d_inner_h, h_inner_h) elif self.phase_marginalization: time_log_like = ln_i0(abs(d_inner_h)) - h_inner_h.real / 2 else: - time_log_like = (d_inner_h.real - h_inner_h.real / 2) + time_log_like = d_inner_h.real - h_inner_h.real / 2 times = self._times if self.jitter_time: times = times + parameters["time_jitter"] - time_prior_array = self.priors['geocent_time'].prob(times) + time_prior_array = self.priors["geocent_time"].prob(times) time_post = np.exp(time_log_like - max(time_log_like)) * time_prior_array time_post /= np.sum(time_post) return random.rng.choice(times, p=time_post) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index e262eaaf3..c3d5d81a7 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -1,32 +1,46 @@ -import os import copy +import os import numpy as np -from scipy.integrate import cumulative_trapezoid, trapezoid, quad +from scipy.integrate import cumulative_trapezoid, quad, trapezoid from scipy.interpolate import InterpolatedUnivariateSpline from scipy.special import hyp2f1 from scipy.stats import norm from ..core.prior import ( - PriorDict, Uniform, Prior, DeltaFunction, Gaussian, Interped, Constraint, - conditional_prior_factory, PowerLaw, ConditionalLogUniform, - ConditionalPriorDict, ConditionalBasePrior, BaseJointPriorDist, JointPrior, + BaseJointPriorDist, + ConditionalBasePrior, + ConditionalLogUniform, + ConditionalPriorDict, + Constraint, + DeltaFunction, + Gaussian, + Interped, + JointPrior, JointPriorDistError, + PowerLaw, + Prior, + PriorDict, + Uniform, + conditional_prior_factory, ) -from ..core.utils import infer_args_from_method, logger, random, WrappedInterp1d as interp1d +from ..core.utils import WrappedInterp1d as interp1d +from ..core.utils import infer_args_from_method, logger, random from .conversion import ( + chirp_mass_and_mass_ratio_to_total_mass, convert_to_lal_binary_black_hole_parameters, - convert_to_lal_binary_neutron_star_parameters, generate_mass_parameters, - generate_tidal_parameters, fill_from_fixed_priors, + convert_to_lal_binary_neutron_star_parameters, + fill_from_fixed_priors, generate_all_bbh_parameters, - chirp_mass_and_mass_ratio_to_total_mass, - total_mass_and_mass_ratio_to_component_masses) + generate_mass_parameters, + generate_tidal_parameters, + total_mass_and_mass_ratio_to_component_masses, +) from .cosmology import get_cosmology, z_at_value from .source import PARAMETER_SETS from .utils import calculate_time_to_merger - -DEFAULT_PRIOR_DIR = os.path.join(os.path.dirname(__file__), 'prior_files') +DEFAULT_PRIOR_DIR = os.path.join(os.path.dirname(__file__), "prior_files") class BilbyPriorConversionError(Exception): @@ -34,7 +48,7 @@ class BilbyPriorConversionError(Exception): def convert_to_flat_in_component_mass_prior(result, fraction=0.25): - """ Converts samples with a defined prior in chirp-mass and mass-ratio to flat in component mass by resampling with + """Converts samples with a defined prior in chirp-mass and mass-ratio to flat in component mass by resampling with the posterior with weights defined as ratio in new:old prior values times the jacobian which for F(mc, q) -> G(m1, m2) is defined as J := m1^2 / mc @@ -49,14 +63,14 @@ def convert_to_flat_in_component_mass_prior(result, fraction=0.25): """ if getattr(result, "priors") is not None: - for key in ['chirp_mass', 'mass_ratio']: + for key in ["chirp_mass", "mass_ratio"]: if key not in result.priors.keys(): - BilbyPriorConversionError("{} Prior not found in result object".format(key)) + BilbyPriorConversionError(f"{key} Prior not found in result object") if isinstance(result.priors[key], Constraint): - BilbyPriorConversionError("{} Prior should not be a Constraint".format(key)) - for key in ['mass_1', 'mass_2']: + BilbyPriorConversionError(f"{key} Prior should not be a Constraint") + for key in ["mass_1", "mass_2"]: if not isinstance(result.priors[key], Constraint): - BilbyPriorConversionError("{} Prior should be a Constraint Prior".format(key)) + BilbyPriorConversionError(f"{key} Prior should be a Constraint Prior") else: BilbyPriorConversionError("No prior in the result: unable to convert") @@ -65,27 +79,30 @@ def convert_to_flat_in_component_mass_prior(result, fraction=0.25): old_priors = copy.copy(result.priors) posterior = result.posterior - for key in ['chirp_mass', 'mass_ratio']: + for key in ["chirp_mass", "mass_ratio"]: priors[key] = Constraint(priors[key].minimum, priors[key].maximum, key, latex_label=priors[key].latex_label) - for key in ['mass_1', 'mass_2']: - priors[key] = Uniform(priors[key].minimum, priors[key].maximum, key, latex_label=priors[key].latex_label, - unit=r"$M_{\odot}$") + for key in ["mass_1", "mass_2"]: + priors[key] = Uniform( + priors[key].minimum, priors[key].maximum, key, latex_label=priors[key].latex_label, unit=r"$M_{\odot}$" + ) - weights = np.array(result.get_weights_by_new_prior(old_priors, priors, - prior_names=['chirp_mass', 'mass_ratio', 'mass_1', 'mass_2'])) + weights = np.array( + result.get_weights_by_new_prior( + old_priors, priors, prior_names=["chirp_mass", "mass_ratio", "mass_1", "mass_2"] + ) + ) jacobian = posterior["mass_1"] ** 2 / posterior["chirp_mass"] weights = jacobian * weights result.posterior = posterior.sample(frac=fraction, weights=weights) logger.info("Resampling posterior to flat-in-component mass") - effective_sample_size = sum(weights)**2 / sum(weights**2) + effective_sample_size = sum(weights) ** 2 / sum(weights**2) n_posterior = len(posterior) if fraction > effective_sample_size / n_posterior: logger.warning( - "Sampling posterior of length {} with fraction {}, but " - "effective_sample_size / len(posterior) = {}. This may produce " + f"Sampling posterior of length {n_posterior} with fraction {fraction}, but " + f"effective_sample_size / len(posterior) = {effective_sample_size / n_posterior}. This may produce " "biased results" - .format(n_posterior, fraction, effective_sample_size / n_posterior) ) result.posterior = posterior.sample(frac=fraction, weights=weights, replace=True) result.meta_data["reweighted_to_flat_in_component_mass"] = True @@ -100,15 +117,14 @@ class Cosmological(Interped): @property def _default_args_dict(self): from astropy import units + return dict( - redshift=dict(name='redshift', latex_label='$z$', unit=None), - luminosity_distance=dict( - name='luminosity_distance', latex_label='$d_L$', unit=units.Mpc), - comoving_distance=dict( - name='comoving_distance', latex_label='$d_C$', unit=units.Mpc)) - - def __init__(self, minimum, maximum, cosmology=None, name=None, - latex_label=None, unit=None, boundary=None): + redshift=dict(name="redshift", latex_label="$z$", unit=None), + luminosity_distance=dict(name="luminosity_distance", latex_label="$d_L$", unit=units.Mpc), + comoving_distance=dict(name="comoving_distance", latex_label="$d_C$", unit=units.Mpc), + ) + + def __init__(self, minimum, maximum, cosmology=None, name=None, latex_label=None, unit=None, boundary=None): """ Parameters @@ -133,34 +149,34 @@ def __init__(self, minimum, maximum, cosmology=None, name=None, The boundary condition to apply to the prior when sampling. """ from astropy import units + self.cosmology = get_cosmology(cosmology) if name not in self._default_args_dict: raise ValueError( - "Name {} not recognised. Must be one of luminosity_distance, " - "comoving_distance, redshift".format(name)) + f"Name {name} not recognised. Must be one of luminosity_distance, comoving_distance, redshift" + ) self.name = name label_args = self._default_args_dict[self.name] if latex_label is not None: - label_args['latex_label'] = latex_label + label_args["latex_label"] = latex_label if unit is not None: if not isinstance(unit, units.Unit): unit = units.Unit(unit) - label_args['unit'] = unit - self.unit = label_args['unit'] + label_args["unit"] = unit + self.unit = label_args["unit"] self._minimum = dict() self._maximum = dict() self.minimum = minimum self.maximum = maximum - if name == 'redshift': + if name == "redshift": xx, yy = self._get_redshift_arrays() - elif name == 'comoving_distance': + elif name == "comoving_distance": xx, yy = self._get_comoving_distance_arrays() - elif name == 'luminosity_distance': + elif name == "luminosity_distance": xx, yy = self._get_luminosity_distance_arrays() else: - raise ValueError('Name {} not recognized.'.format(name)) - super(Cosmological, self).__init__(xx=xx, yy=yy, minimum=minimum, maximum=maximum, - boundary=boundary, **label_args) + raise ValueError(f"Name {name} not recognized.") + super().__init__(xx=xx, yy=yy, minimum=minimum, maximum=maximum, boundary=boundary, **label_args) @property def minimum(self): @@ -199,37 +215,27 @@ def _set_limit(self, value, limit_dict, recalculate_array=False): """ cosmology = get_cosmology(self.cosmology) limit_dict[self.name] = value - if self.name == 'redshift': - limit_dict['luminosity_distance'] = \ - cosmology.luminosity_distance(value).value - limit_dict['comoving_distance'] = \ - cosmology.comoving_distance(value).value - elif self.name == 'luminosity_distance': + if self.name == "redshift": + limit_dict["luminosity_distance"] = cosmology.luminosity_distance(value).value + limit_dict["comoving_distance"] = cosmology.comoving_distance(value).value + elif self.name == "luminosity_distance": if value == 0: - limit_dict['redshift'] = 0 + limit_dict["redshift"] = 0 else: - limit_dict['redshift'] = z_at_value( - cosmology.luminosity_distance, value * self.unit - ) - limit_dict['comoving_distance'] = ( - cosmology.comoving_distance(limit_dict['redshift']).value - ) - elif self.name == 'comoving_distance': + limit_dict["redshift"] = z_at_value(cosmology.luminosity_distance, value * self.unit) + limit_dict["comoving_distance"] = cosmology.comoving_distance(limit_dict["redshift"]).value + elif self.name == "comoving_distance": if value == 0: - limit_dict['redshift'] = 0 + limit_dict["redshift"] = 0 else: - limit_dict['redshift'] = z_at_value( - cosmology.comoving_distance, value * self.unit - ) - limit_dict['luminosity_distance'] = ( - cosmology.luminosity_distance(limit_dict['redshift']).value - ) + limit_dict["redshift"] = z_at_value(cosmology.comoving_distance, value * self.unit) + limit_dict["luminosity_distance"] = cosmology.luminosity_distance(limit_dict["redshift"]).value if recalculate_array: - if self.name == 'redshift': + if self.name == "redshift": self.xx, self.yy = self._get_redshift_arrays() - elif self.name == 'comoving_distance': + elif self.name == "comoving_distance": self.xx, self.yy = self._get_comoving_distance_arrays() - elif self.name == 'luminosity_distance': + elif self.name == "luminosity_distance": self.xx, self.yy = self._get_luminosity_distance_arrays() try: self._update_instance() @@ -257,13 +263,13 @@ def get_corresponding_prior(self, name=None, unit=None): args_dict = {key: getattr(self, key) for key in subclass_args} self._convert_to(new=name, args_dict=args_dict) if unit is not None: - args_dict['unit'] = unit + args_dict["unit"] = unit return self.__class__(**args_dict) def _convert_to(self, new, args_dict): args_dict.update(self._default_args_dict[new]) - args_dict['minimum'] = self._minimum[args_dict['name']] - args_dict['maximum'] = self._maximum[args_dict['name']] + args_dict["minimum"] = self._minimum[args_dict["name"]] + args_dict["maximum"] = self._maximum[args_dict["name"]] def _get_comoving_distance_arrays(self): zs, p_dz = self._get_redshift_arrays() @@ -286,8 +292,7 @@ def _get_redshift_arrays(self): def from_repr(cls, string): if "FlatLambdaCDM" in string: logger.warning( - "Cosmological priors cannot be loaded from a string. " - "If the prior has a name, use that instead." + "Cosmological priors cannot be loaded from a string. If the prior has a name, use that instead." ) return string else: @@ -295,13 +300,15 @@ def from_repr(cls, string): def get_instantiation_dict(self): from astropy import units + from .cosmology import get_available_cosmologies + available = get_available_cosmologies() instantiation_dict = super().get_instantiation_dict() if self.cosmology.name in available: - instantiation_dict['cosmology'] = self.cosmology.name + instantiation_dict["cosmology"] = self.cosmology.name if isinstance(self.unit, units.Unit): - instantiation_dict['unit'] = self.unit.to_string() + instantiation_dict["unit"] = self.unit.to_string() return instantiation_dict @@ -318,8 +325,7 @@ class UniformComovingVolume(Cosmological): """ def _get_redshift_arrays(self): - zs = np.linspace(self._minimum['redshift'] * 0.99, - self._maximum['redshift'] * 1.01, 1000) + zs = np.linspace(self._minimum["redshift"] * 0.99, self._maximum["redshift"] * 1.01, 1000) p_dz = self.cosmology.differential_comoving_volume(zs).value return zs, p_dz @@ -337,8 +343,7 @@ class UniformSourceFrame(Cosmological): """ def _get_redshift_arrays(self): - zs = np.linspace(self._minimum['redshift'] * 0.99, - self._maximum['redshift'] * 1.01, 1000) + zs = np.linspace(self._minimum["redshift"] * 0.99, self._maximum["redshift"] * 1.01, 1000) p_dz = self.cosmology.differential_comoving_volume(zs).value / (1 + zs) return zs, p_dz @@ -360,8 +365,7 @@ class UniformInComponentsChirpMass(PowerLaw): :code:`bilby.gw.prior.UniformInComponentsMassRatio`. """ - def __init__(self, minimum, maximum, name='chirp_mass', - latex_label=r'$\mathcal{M}$', unit=None, boundary=None): + def __init__(self, minimum, maximum, name="chirp_mass", latex_label=r"$\mathcal{M}$", unit=None, boundary=None): """ Parameters ========== @@ -374,9 +378,15 @@ def __init__(self, minimum, maximum, name='chirp_mass', unit: see superclass boundary: see superclass """ - super(UniformInComponentsChirpMass, self).__init__( - alpha=1., minimum=minimum, maximum=maximum, - name=name, latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__( + alpha=1.0, + minimum=minimum, + maximum=maximum, + name=name, + latex_label=latex_label, + unit=unit, + boundary=boundary, + ) class UniformInComponentsMassRatio(Prior): @@ -396,8 +406,9 @@ class UniformInComponentsMassRatio(Prior): :code:`bilby.gw.prior.UniformInComponentsChirpMass`. """ - def __init__(self, minimum, maximum, name='mass_ratio', latex_label='$q$', - unit=None, boundary=None, equal_mass=False): + def __init__( + self, minimum, maximum, name="mass_ratio", latex_label="$q$", unit=None, boundary=None, equal_mass=False + ): """ Parameters ========== @@ -417,18 +428,16 @@ def __init__(self, minimum, maximum, name='mass_ratio', latex_label='$q$', """ self.equal_mass = equal_mass - super(UniformInComponentsMassRatio, self).__init__( - minimum=minimum, maximum=maximum, name=name, - latex_label=latex_label, unit=unit, boundary=boundary) + super().__init__( + minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary + ) self.norm = self._integral(maximum) - self._integral(minimum) qs = np.linspace(minimum, maximum, 1000) - self.icdf = interp1d( - self.cdf(qs), qs, kind='cubic', - bounds_error=False, fill_value=(minimum, maximum)) + self.icdf = interp1d(self.cdf(qs), qs, kind="cubic", bounds_error=False, fill_value=(minimum, maximum)) @staticmethod def _integral(q): - return -5. * q**(-1. / 5.) * hyp2f1(-2. / 5., -1. / 5., 4. / 5., -q) + return -5.0 * q ** (-1.0 / 5.0) * hyp2f1(-2.0 / 5.0, -1.0 / 5.0, 4.0 / 5.0, -q) def cdf(self, val): return (self._integral(val) - self._integral(self.minimum)) / self.norm @@ -441,7 +450,7 @@ def rescale(self, val): def prob(self, val): in_prior = (val >= self.minimum) & (val <= self.maximum) with np.errstate(invalid="ignore"): - prob = (1. + val)**(2. / 5.) / (val**(6. / 5.)) / self.norm * in_prior + prob = (1.0 + val) ** (2.0 / 5.0) / (val ** (6.0 / 5.0)) / self.norm * in_prior return prob def ln_prob(self, val): @@ -496,13 +505,12 @@ def __init__( """ self.a_prior = a_prior self.z_prior = z_prior - chi_min = min(a_prior.maximum * z_prior.minimum, - a_prior.minimum * z_prior.maximum) + chi_min = min(a_prior.maximum * z_prior.minimum, a_prior.minimum * z_prior.maximum) chi_max = a_prior.maximum * z_prior.maximum if self._is_simple_aligned_prior: self.num_interp = 100_000 if num_interp is None else num_interp xx = np.linspace(chi_min, chi_max, self.num_interp) - yy = - np.log(np.abs(xx) / a_prior.maximum) / (2 * a_prior.maximum) + yy = -np.log(np.abs(xx) / a_prior.maximum) / (2 * a_prior.maximum) else: def integrand(aa, chi): @@ -515,10 +523,7 @@ def integrand(aa, chi): self.num_interp = 10_000 if num_interp is None else num_interp xx = np.linspace(chi_min, chi_max, self.num_interp) - yy = [ - quad(integrand, a_prior.minimum, a_prior.maximum, chi)[0] - for chi in xx - ] + yy = [quad(integrand, a_prior.minimum, a_prior.maximum, chi)[0] for chi in xx] super().__init__( xx=xx, yy=yy, @@ -554,9 +559,15 @@ class ConditionalChiUniformSpinMagnitude(ConditionalLogUniform): """ def __init__(self, minimum, maximum, name, latex_label=None, unit=None, boundary=None): - super(ConditionalChiUniformSpinMagnitude, self).__init__( - minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary, - condition_func=self._condition_function) + super().__init__( + minimum=minimum, + maximum=maximum, + name=name, + latex_label=latex_label, + unit=unit, + boundary=boundary, + condition_func=self._condition_function, + ) self._required_variables = [name.replace("a", "chi")] self.__class__.__name__ = "ConditionalChiUniformSpinMagnitude" self.__class__.__qualname__ = "ConditionalChiUniformSpinMagnitude" @@ -589,11 +600,14 @@ class ConditionalChiInPlane(ConditionalBasePrior): """ def __init__(self, minimum, maximum, name, latex_label=None, unit=None, boundary=None): - super(ConditionalChiInPlane, self).__init__( - minimum=minimum, maximum=maximum, - name=name, latex_label=latex_label, - unit=unit, boundary=boundary, - condition_func=self._condition_function + super().__init__( + minimum=minimum, + maximum=maximum, + name=name, + latex_label=latex_label, + unit=unit, + boundary=boundary, + condition_func=self._condition_function, ) self._required_variables = [name[:5]] self._reference_maximum = maximum @@ -604,9 +618,10 @@ def prob(self, val, **required_variables): self.update_conditions(**required_variables) chi_aligned = abs(required_variables[self._required_variables[0]]) return ( - (val >= self.minimum) * (val <= self.maximum) + (val >= self.minimum) + * (val <= self.maximum) * val - / (chi_aligned ** 2 + val ** 2) + / (chi_aligned**2 + val**2) / np.log(self._reference_maximum / chi_aligned) ) @@ -634,12 +649,17 @@ def cdf(self, val, **required_variables): """ self.update_conditions(**required_variables) chi_aligned = abs(required_variables[self._required_variables[0]]) - return np.maximum(np.minimum( - (val >= self.minimum) * (val <= self.maximum) - * np.log(1 + (val / chi_aligned) ** 2) - / 2 / np.log(self._reference_maximum / chi_aligned) - , 1 - ), 0) + return np.maximum( + np.minimum( + (val >= self.minimum) + * (val <= self.maximum) + * np.log(1 + (val / chi_aligned) ** 2) + / 2 + / np.log(self._reference_maximum / chi_aligned), + 1, + ), + 0, + ) def rescale(self, val, **required_variables): r""" @@ -664,9 +684,7 @@ def rescale(self, val, **required_variables): def _condition_function(self, reference_params, **kwargs): with np.errstate(invalid="ignore"): - maximum = np.sqrt( - self._reference_maximum ** 2 - kwargs[self._required_variables[0]] ** 2 - ) + maximum = np.sqrt(self._reference_maximum**2 - kwargs[self._required_variables[0]] ** 2) return dict(minimum=0, maximum=maximum) def __repr__(self): @@ -697,7 +715,6 @@ def prob(self, val): return val def ln_prob(self, val): - if val: result = 0.0 elif not val: @@ -714,9 +731,9 @@ def minimum_chirp_mass(self): if "chirp_mass" in self: return self["chirp_mass"].minimum elif "mass_1" in self: - mass_1 = self['mass_1'].minimum + mass_1 = self["mass_1"].minimum if "mass_2" in self: - mass_2 = self['mass_2'].minimum + mass_2 = self["mass_2"].minimum elif "mass_ratio" in self: mass_2 = mass_1 * self["mass_ratio"].minimum if mass_1 is not None and mass_2 is not None: @@ -733,9 +750,9 @@ def maximum_chirp_mass(self): if "chirp_mass" in self: return self["chirp_mass"].maximum elif "mass_1" in self: - mass_1 = self['mass_1'].maximum + mass_1 = self["mass_1"].maximum if "mass_2" in self: - mass_2 = self['mass_2'].maximum + mass_2 = self["mass_2"].maximum elif "mass_ratio" in self: mass_2 = mass_1 * self["mass_ratio"].maximum if mass_1 is not None and mass_2 is not None: @@ -762,17 +779,15 @@ def minimum_component_mass(self): if "mass_2" in self: return self["mass_2"].minimum if "chirp_mass" in self and "mass_ratio" in self: - total_mass = chirp_mass_and_mass_ratio_to_total_mass( - self["chirp_mass"].minimum, self["mass_ratio"].minimum) - _, mass_2 = total_mass_and_mass_ratio_to_component_masses( - self["mass_ratio"].minimum, total_mass) + total_mass = chirp_mass_and_mass_ratio_to_total_mass(self["chirp_mass"].minimum, self["mass_ratio"].minimum) + _, mass_2 = total_mass_and_mass_ratio_to_component_masses(self["mass_ratio"].minimum, total_mass) return mass_2 else: logger.warning("Unable to determine minimum component mass") return None def is_nonempty_intersection(self, pset): - """ Check if keys in self exist in the parameter set + """Check if keys in self exist in the parameter set Parameters ---------- @@ -793,54 +808,52 @@ def is_nonempty_intersection(self, pset): @property def spin(self): - """ Return true if priors include any spin parameters """ + """Return true if priors include any spin parameters""" return self.is_nonempty_intersection("spin") @property def precession(self): - """ Return true if priors include any precession parameters """ + """Return true if priors include any precession parameters""" return self.is_nonempty_intersection("precession_only") @property def measured_spin(self): - """ Return true if priors include any measured_spin parameters """ + """Return true if priors include any measured_spin parameters""" return self.is_nonempty_intersection("measured_spin") @property def intrinsic(self): - """ Return true if priors include any intrinsic parameters """ + """Return true if priors include any intrinsic parameters""" return self.is_nonempty_intersection("intrinsic") @property def extrinsic(self): - """ Return true if priors include any extrinsic parameters """ + """Return true if priors include any extrinsic parameters""" return self.is_nonempty_intersection("extrinsic") @property def sky(self): - """ Return true if priors include any extrinsic parameters """ + """Return true if priors include any extrinsic parameters""" return self.is_nonempty_intersection("sky") @property def distance_inclination(self): - """ Return true if priors include any extrinsic parameters """ + """Return true if priors include any extrinsic parameters""" return self.is_nonempty_intersection("distance_inclination") @property def mass(self): - """ Return true if priors include any mass parameters """ + """Return true if priors include any mass parameters""" return self.is_nonempty_intersection("mass") @property def phase(self): - """ Return true if priors include phase parameters """ + """Return true if priors include phase parameters""" return self.is_nonempty_intersection("phase") @property def _cosmological_priors(self): - return [ - key for key, prior in self.items() if isinstance(prior, Cosmological) - ] + return [key for key, prior in self.items() if isinstance(prior, Cosmological)] @property def is_cosmological(self): @@ -888,10 +901,7 @@ def check_valid_cosmology(self, error=True, warning=False): if all(cosmology_equal(cosmologies[0], c, allow_equivalent=True) for c in cosmologies[1:]): return True - message = ( - "All cosmological priors must use the same cosmology. " - f"Found: {cosmologies}" - ) + message = f"All cosmological priors must use the same cosmology. Found: {cosmologies}" if warning: logger.warning(message) return False @@ -899,7 +909,7 @@ def check_valid_cosmology(self, error=True, warning=False): raise ValueError(message) def validate_prior(self, duration, minimum_frequency, N=1000, error=True, warning=False): - """ Validate the prior is suitable for use + """Validate the prior is suitable for use Parameters ========== @@ -920,14 +930,16 @@ def validate_prior(self, duration, minimum_frequency, N=1000, error=True, warnin """ samples = self.sample(N) samples = generate_all_bbh_parameters(samples) - durations = np.array([ - calculate_time_to_merger( - frequency=minimum_frequency, - mass_1=mass_1, - mass_2=mass_2, - ) - for (mass_1, mass_2) in zip(samples["mass_1"], samples["mass_2"]) - ]) + durations = np.array( + [ + calculate_time_to_merger( + frequency=minimum_frequency, + mass_1=mass_1, + mass_2=mass_2, + ) + for (mass_1, mass_2) in zip(samples["mass_1"], samples["mass_2"]) + ] + ) longest_duration = max(durations) if longest_duration < duration: return True @@ -944,9 +956,8 @@ def validate_prior(self, duration, minimum_frequency, N=1000, error=True, warnin class BBHPriorDict(CBCPriorDict): - def __init__(self, dictionary=None, filename=None, aligned_spin=False, - conversion_function=None): - """ Initialises a Prior set for Binary Black holes + def __init__(self, dictionary=None, filename=None, aligned_spin=False, conversion_function=None): + """Initialises a Prior set for Binary Black holes Parameters ========== @@ -961,17 +972,16 @@ def __init__(self, dictionary=None, filename=None, aligned_spin=False, """ if dictionary is None and filename is None: if aligned_spin: - fname = 'aligned_spins_bbh.prior' - logger.info('Using aligned spin prior') + fname = "aligned_spins_bbh.prior" + logger.info("Using aligned spin prior") else: - fname = 'precessing_spins_bbh.prior' + fname = "precessing_spins_bbh.prior" filename = os.path.join(DEFAULT_PRIOR_DIR, fname) - logger.info('No prior given, using default BBH priors in {}.'.format(filename)) + logger.info(f"No prior given, using default BBH priors in {filename}.") elif filename is not None: if not os.path.isfile(filename): filename = os.path.join(DEFAULT_PRIOR_DIR, filename) - super(BBHPriorDict, self).__init__(dictionary=dictionary, filename=filename, - conversion_function=conversion_function) + super().__init__(dictionary=dictionary, filename=filename, conversion_function=conversion_function) def default_conversion_function(self, sample): """ @@ -1019,41 +1029,43 @@ def test_redundancy(self, key, disable_logging=False): Whether the key is redundant or not """ if key in self: - logger.debug('{} already in prior'.format(key)) + logger.debug(f"{key} already in prior") return True - sampling_parameters = {key for key in self if not isinstance( - self[key], (DeltaFunction, Constraint))} - - mass_parameters = {'mass_1', 'mass_2', 'chirp_mass', 'total_mass', 'mass_ratio', 'symmetric_mass_ratio'} - spin_tilt_1_parameters = {'tilt_1', 'cos_tilt_1'} - spin_tilt_2_parameters = {'tilt_2', 'cos_tilt_2'} - spin_azimuth_parameters = {'phi_1', 'phi_2', 'phi_12', 'phi_jl'} - inclination_parameters = {'theta_jn', 'cos_theta_jn'} - distance_parameters = {'luminosity_distance', 'comoving_distance', 'redshift'} - - for independent_parameters, parameter_set in \ - zip([2, 2, 1, 1, 1, 1], - [mass_parameters, spin_azimuth_parameters, - spin_tilt_1_parameters, spin_tilt_2_parameters, - inclination_parameters, distance_parameters]): + sampling_parameters = {key for key in self if not isinstance(self[key], (DeltaFunction, Constraint))} + + mass_parameters = {"mass_1", "mass_2", "chirp_mass", "total_mass", "mass_ratio", "symmetric_mass_ratio"} + spin_tilt_1_parameters = {"tilt_1", "cos_tilt_1"} + spin_tilt_2_parameters = {"tilt_2", "cos_tilt_2"} + spin_azimuth_parameters = {"phi_1", "phi_2", "phi_12", "phi_jl"} + inclination_parameters = {"theta_jn", "cos_theta_jn"} + distance_parameters = {"luminosity_distance", "comoving_distance", "redshift"} + + for independent_parameters, parameter_set in zip( + [2, 2, 1, 1, 1, 1], + [ + mass_parameters, + spin_azimuth_parameters, + spin_tilt_1_parameters, + spin_tilt_2_parameters, + inclination_parameters, + distance_parameters, + ], + ): if key in parameter_set: - if len(parameter_set.intersection( - sampling_parameters)) >= independent_parameters: + if len(parameter_set.intersection(sampling_parameters)) >= independent_parameters: logger.disabled = disable_logging - logger.warning('{} already in prior. ' - 'This may lead to unexpected behaviour.' - .format(parameter_set.intersection(self))) + logger.warning( + f"{parameter_set.intersection(self)} already in prior. This may lead to unexpected behaviour." + ) logger.disabled = False return True return False class BNSPriorDict(CBCPriorDict): - - def __init__(self, dictionary=None, filename=None, aligned_spin=True, - conversion_function=None): - """ Initialises a Prior set for Binary Neutron Stars + def __init__(self, dictionary=None, filename=None, aligned_spin=True, conversion_function=None): + """Initialises a Prior set for Binary Neutron Stars Parameters ========== @@ -1067,17 +1079,16 @@ def __init__(self, dictionary=None, filename=None, aligned_spin=True, BNSPriorDict.default_conversion_function """ if aligned_spin: - default_file = 'aligned_spins_bns_tides_on.prior' + default_file = "aligned_spins_bns_tides_on.prior" else: - default_file = 'precessing_spins_bns_tides_on.prior' + default_file = "precessing_spins_bns_tides_on.prior" if dictionary is None and filename is None: filename = os.path.join(DEFAULT_PRIOR_DIR, default_file) - logger.info('No prior given, using default BNS priors in {}.'.format(filename)) + logger.info(f"No prior given, using default BNS priors in {filename}.") elif filename is not None: if not os.path.isfile(filename): filename = os.path.join(DEFAULT_PRIOR_DIR, filename) - super(BNSPriorDict, self).__init__(dictionary=dictionary, filename=filename, - conversion_function=conversion_function) + super().__init__(dictionary=dictionary, filename=filename, conversion_function=conversion_function) def default_conversion_function(self, sample): """ @@ -1120,19 +1131,17 @@ def test_redundancy(self, key, disable_logging=False): return True redundant = False - sampling_parameters = {key for key in self if not isinstance( - self[key], (DeltaFunction, Constraint))} + sampling_parameters = {key for key in self if not isinstance(self[key], (DeltaFunction, Constraint))} - tidal_parameters = \ - {'lambda_1', 'lambda_2', 'lambda_tilde', 'delta_lambda_tilde'} + tidal_parameters = {"lambda_1", "lambda_2", "lambda_tilde", "delta_lambda_tilde"} if key in tidal_parameters: if len(tidal_parameters.intersection(sampling_parameters)) > 2: redundant = True logger.disabled = disable_logging - logger.warning('{} already in prior. ' - 'This may lead to unexpected behaviour.' - .format(tidal_parameters.intersection(self))) + logger.warning( + f"{tidal_parameters.intersection(self)} already in prior. This may lead to unexpected behaviour." + ) logger.disabled = False elif len(tidal_parameters.intersection(sampling_parameters)) == 2: redundant = True @@ -1140,44 +1149,44 @@ def test_redundancy(self, key, disable_logging=False): @property def tidal(self): - """ Return true if priors include phase parameters """ + """Return true if priors include phase parameters""" return self.is_nonempty_intersection("tidal") Prior._default_latex_labels = { - 'mass_1': '$m_1$', - 'mass_2': '$m_2$', - 'total_mass': '$M$', - 'chirp_mass': r'$\mathcal{M}$', - 'mass_ratio': '$q$', - 'symmetric_mass_ratio': r'$\eta$', - 'a_1': '$a_1$', - 'a_2': '$a_2$', - 'tilt_1': r'$\theta_1$', - 'tilt_2': r'$\theta_2$', - 'cos_tilt_1': r'$\cos\theta_1$', - 'cos_tilt_2': r'$\cos\theta_2$', - 'phi_12': r'$\Delta\phi$', - 'phi_jl': r'$\phi_{JL}$', - 'luminosity_distance': '$d_L$', - 'dec': r'$\mathrm{DEC}$', - 'ra': r'$\mathrm{RA}$', - 'iota': r'$\iota$', - 'cos_iota': r'$\cos\iota$', - 'theta_jn': r'$\theta_{JN}$', - 'cos_theta_jn': r'$\cos\theta_{JN}$', - 'psi': r'$\psi$', - 'phase': r'$\phi$', - 'geocent_time': '$t_c$', - 'time_jitter': '$t_j$', - 'lambda_1': r'$\Lambda_1$', - 'lambda_2': r'$\Lambda_2$', - 'lambda_tilde': r'$\tilde{\Lambda}$', - 'delta_lambda_tilde': r'$\delta\tilde{\Lambda}$', - 'chi_1': r'$\chi_1$', - 'chi_2': r'$\chi_2$', - 'chi_1_in_plane': r'$\chi_{1, \perp}$', - 'chi_2_in_plane': r'$\chi_{2, \perp}$', + "mass_1": "$m_1$", + "mass_2": "$m_2$", + "total_mass": "$M$", + "chirp_mass": r"$\mathcal{M}$", + "mass_ratio": "$q$", + "symmetric_mass_ratio": r"$\eta$", + "a_1": "$a_1$", + "a_2": "$a_2$", + "tilt_1": r"$\theta_1$", + "tilt_2": r"$\theta_2$", + "cos_tilt_1": r"$\cos\theta_1$", + "cos_tilt_2": r"$\cos\theta_2$", + "phi_12": r"$\Delta\phi$", + "phi_jl": r"$\phi_{JL}$", + "luminosity_distance": "$d_L$", + "dec": r"$\mathrm{DEC}$", + "ra": r"$\mathrm{RA}$", + "iota": r"$\iota$", + "cos_iota": r"$\cos\iota$", + "theta_jn": r"$\theta_{JN}$", + "cos_theta_jn": r"$\cos\theta_{JN}$", + "psi": r"$\psi$", + "phase": r"$\phi$", + "geocent_time": "$t_c$", + "time_jitter": "$t_j$", + "lambda_1": r"$\Lambda_1$", + "lambda_2": r"$\Lambda_2$", + "lambda_tilde": r"$\tilde{\Lambda}$", + "delta_lambda_tilde": r"$\delta\tilde{\Lambda}$", + "chi_1": r"$\chi_1$", + "chi_2": r"$\chi_2$", + "chi_1_in_plane": r"$\chi_{1, \perp}$", + "chi_2_in_plane": r"$\chi_{2, \perp}$", } @@ -1201,7 +1210,7 @@ def __init__(self, dictionary=None, filename=None): """ if dictionary is None and filename is not None: filename = os.path.join(DEFAULT_PRIOR_DIR, filename) - super(CalibrationPriorDict, self).__init__(dictionary=dictionary, filename=filename) + super().__init__(dictionary=dictionary, filename=filename) self.source = None def to_file(self, outdir, label): @@ -1218,14 +1227,14 @@ def to_file(self, outdir, label): """ PriorDict.to_file(self, outdir=outdir, label=label) if self.source is not None: - prior_file = os.path.join(outdir, "{}.prior".format(label)) + prior_file = os.path.join(outdir, f"{label}.prior") with open(prior_file, "a") as outfile: - outfile.write("# prior source file is {}".format(self.source)) + outfile.write(f"# prior source file is {self.source}") @staticmethod - def from_envelope_file(envelope_file, minimum_frequency, - maximum_frequency, n_nodes, label, - boundary="reflective", correction_type=None): + def from_envelope_file( + envelope_file, minimum_frequency, maximum_frequency, n_nodes, label, boundary="reflective", correction_type=None + ): """ Load in the calibration envelope. @@ -1274,6 +1283,7 @@ def from_envelope_file(envelope_file, minimum_frequency, This includes the frequencies of the nodes which are _not_ sampled. """ from .detector.calibration import _check_calibration_correction_type + correction_type = _check_calibration_correction_type(correction_type=correction_type) calibration_data = np.genfromtxt(envelope_file).T @@ -1290,45 +1300,45 @@ def from_envelope_file(envelope_file, minimum_frequency, amplitude_sigma = abs(calibration_data[5] - calibration_data[3]) / 2 phase_sigma = abs(calibration_data[6] - calibration_data[4]) / 2 - log_nodes = np.linspace(np.log(minimum_frequency), - np.log(maximum_frequency), n_nodes) + log_nodes = np.linspace(np.log(minimum_frequency), np.log(maximum_frequency), n_nodes) - amplitude_mean_nodes = \ - InterpolatedUnivariateSpline(log_frequency_array, amplitude_median)(log_nodes) - amplitude_sigma_nodes = \ - InterpolatedUnivariateSpline(log_frequency_array, amplitude_sigma)(log_nodes) - phase_mean_nodes = \ - InterpolatedUnivariateSpline(log_frequency_array, phase_median)(log_nodes) - phase_sigma_nodes = \ - InterpolatedUnivariateSpline(log_frequency_array, phase_sigma)(log_nodes) + amplitude_mean_nodes = InterpolatedUnivariateSpline(log_frequency_array, amplitude_median)(log_nodes) + amplitude_sigma_nodes = InterpolatedUnivariateSpline(log_frequency_array, amplitude_sigma)(log_nodes) + phase_mean_nodes = InterpolatedUnivariateSpline(log_frequency_array, phase_median)(log_nodes) + phase_sigma_nodes = InterpolatedUnivariateSpline(log_frequency_array, phase_sigma)(log_nodes) prior = CalibrationPriorDict() for ii in range(n_nodes): - name = "recalib_{}_amplitude_{}".format(label, ii) - latex_label = "$A^{}_{}$".format(label, ii) - prior[name] = Gaussian(mu=amplitude_mean_nodes[ii], - sigma=amplitude_sigma_nodes[ii], - name=name, latex_label=latex_label, - boundary=boundary) + name = f"recalib_{label}_amplitude_{ii}" + latex_label = f"$A^{label}_{ii}$" + prior[name] = Gaussian( + mu=amplitude_mean_nodes[ii], + sigma=amplitude_sigma_nodes[ii], + name=name, + latex_label=latex_label, + boundary=boundary, + ) for ii in range(n_nodes): - name = "recalib_{}_phase_{}".format(label, ii) - latex_label = r"$\phi^{}_{}$".format(label, ii) - prior[name] = Gaussian(mu=phase_mean_nodes[ii], - sigma=phase_sigma_nodes[ii], - name=name, latex_label=latex_label, - boundary=boundary) + name = f"recalib_{label}_phase_{ii}" + latex_label = rf"$\phi^{label}_{ii}$" + prior[name] = Gaussian( + mu=phase_mean_nodes[ii], + sigma=phase_sigma_nodes[ii], + name=name, + latex_label=latex_label, + boundary=boundary, + ) for ii in range(n_nodes): - name = "recalib_{}_frequency_{}".format(label, ii) - latex_label = "$f^{}_{}$".format(label, ii) - prior[name] = DeltaFunction(peak=np.exp(log_nodes[ii]), name=name, - latex_label=latex_label) + name = f"recalib_{label}_frequency_{ii}" + latex_label = f"$f^{label}_{ii}$" + prior[name] = DeltaFunction(peak=np.exp(log_nodes[ii]), name=name, latex_label=latex_label) prior.source = os.path.abspath(envelope_file) return prior @staticmethod def constant_uncertainty_spline( - amplitude_sigma, phase_sigma, minimum_frequency, maximum_frequency, - n_nodes, label, boundary="reflective"): + amplitude_sigma, phase_sigma, minimum_frequency, maximum_frequency, n_nodes, label, boundary="reflective" + ): """ Make prior assuming constant in frequency calibration uncertainty. @@ -1357,8 +1367,7 @@ def constant_uncertainty_spline( Priors for the relevant parameters. This includes the frequencies of the nodes which are _not_ sampled. """ - nodes = np.logspace(np.log10(minimum_frequency), - np.log10(maximum_frequency), n_nodes) + nodes = np.logspace(np.log10(minimum_frequency), np.log10(maximum_frequency), n_nodes) amplitude_mean_nodes = [0] * n_nodes amplitude_sigma_nodes = [amplitude_sigma] * n_nodes @@ -1367,24 +1376,29 @@ def constant_uncertainty_spline( prior = CalibrationPriorDict() for ii in range(n_nodes): - name = "recalib_{}_amplitude_{}".format(label, ii) - latex_label = "$A^{}_{}$".format(label, ii) - prior[name] = Gaussian(mu=amplitude_mean_nodes[ii], - sigma=amplitude_sigma_nodes[ii], - name=name, latex_label=latex_label, - boundary=boundary) + name = f"recalib_{label}_amplitude_{ii}" + latex_label = f"$A^{label}_{ii}$" + prior[name] = Gaussian( + mu=amplitude_mean_nodes[ii], + sigma=amplitude_sigma_nodes[ii], + name=name, + latex_label=latex_label, + boundary=boundary, + ) for ii in range(n_nodes): - name = "recalib_{}_phase_{}".format(label, ii) - latex_label = r"$\phi^{}_{}$".format(label, ii) - prior[name] = Gaussian(mu=phase_mean_nodes[ii], - sigma=phase_sigma_nodes[ii], - name=name, latex_label=latex_label, - boundary=boundary) + name = f"recalib_{label}_phase_{ii}" + latex_label = rf"$\phi^{label}_{ii}$" + prior[name] = Gaussian( + mu=phase_mean_nodes[ii], + sigma=phase_sigma_nodes[ii], + name=name, + latex_label=latex_label, + boundary=boundary, + ) for ii in range(n_nodes): - name = "recalib_{}_frequency_{}".format(label, ii) - latex_label = "$f^{}_{}$".format(label, ii) - prior[name] = DeltaFunction(peak=nodes[ii], name=name, - latex_label=latex_label) + name = f"recalib_{label}_frequency_{ii}" + latex_label = f"$f^{label}_{ii}$" + prior[name] = DeltaFunction(peak=nodes[ii], name=name, latex_label=latex_label) return prior @@ -1398,6 +1412,7 @@ def secondary_mass_condition_function(reference_params, mass_1): .. code-block:: python import bilby + priors = bilby.gw.prior.CBCPriorDict() priors["mass_1"] = bilby.core.prior.Uniform(5, 50) priors["mass_2"] = bilby.core.prior.ConditionalUniform( @@ -1419,7 +1434,7 @@ def secondary_mass_condition_function(reference_params, mass_1): dict: Updated prior limits given the provided primary mass. """ - return dict(minimum=reference_params['minimum'], maximum=mass_1) + return dict(minimum=reference_params["minimum"], maximum=mass_1) ConditionalCosmological = conditional_prior_factory(Cosmological) @@ -1451,6 +1466,7 @@ class HealPixMapPriorDist(BaseJointPriorDist): PriorDist : `bilby.gw.prior.HealPixMapPriorDist` A JointPriorDist object to store the joint prior distribution according to passed healpix map """ + def __init__(self, hp_file, names=None, bounds=None, distance=False): self.hp = self._check_imports() self.hp_file = hp_file @@ -1469,15 +1485,13 @@ def __init__(self, hp_file, names=None, bounds=None, distance=False): if len(bounds) == 2: bounds.append([0, np.inf]) self.distance = True - self.prob, self.distmu, self.distsigma, self.distnorm = self.hp.read_map( - hp_file, field=range(4) - ) + self.prob, self.distmu, self.distsigma, self.distnorm = self.hp.read_map(hp_file, field=range(4)) else: self.distance = False self.prob = self.hp.read_map(hp_file) self.prob = self._check_norm(self.prob) - super(HealPixMapPriorDist, self).__init__(names=names, bounds=bounds) + super().__init__(names=names, bounds=bounds) self.distname = "hpmap" self.npix = len(self.prob) self.nside = self.hp.npix2nside(self.npix) @@ -1540,7 +1554,7 @@ def _rescale(self, samp, **kwargs): samp = samp[:, 0] pix_rescale = self.inverse_cdf(samp) sample = np.empty((len(pix_rescale), 2)) - dist_samples = np.empty((len(pix_rescale))) + dist_samples = np.empty(len(pix_rescale)) for i, val in enumerate(pix_rescale): theta, ra = self.hp.pix2ang(self.nside, int(round(val))) dec = 0.5 * np.pi - theta @@ -1571,7 +1585,7 @@ def update_distance(self, pix_idx): self.distance_pdf = lambda r: self.distnorm[pix_idx] * norm( loc=self.distmu[pix_idx], scale=self.distsigma[pix_idx] ).pdf(r) - pdfs = self.rs ** 2 * self.distance_pdf(self.rs) + pdfs = self.rs**2 * self.distance_pdf(self.rs) cdfs = np.cumsum(pdfs) / np.sum(pdfs) self.distance_icdf = interp1d(cdfs, self.rs) @@ -1734,7 +1748,7 @@ def _ln_prob(self, samp, lnprob, outbounds): lnprob[i] = np.log(self.prob[pixel] / self.pixel_area) if self.distance: self.update_distance(pixel) - lnprob[i] += np.log(self.distance_pdf(dist) * dist ** 2) + lnprob[i] += np.log(self.distance_pdf(dist) * dist**2) lnprob[outbounds] = -np.inf return lnprob @@ -1777,6 +1791,7 @@ class HealPixPrior(JointPrior): See :code:`bilby.gw.prior.HealPixMapPriorDist` for more details of how to instantiate the prior. """ + def __init__(self, dist, name=None, latex_label=None, unit=None): """ @@ -1795,4 +1810,4 @@ def __init__(self, dist, name=None, latex_label=None, unit=None): """ if not isinstance(dist, HealPixMapPriorDist): raise JointPriorDistError("dist object must be instance of HealPixMapPriorDist") - super(HealPixPrior, self).__init__(dist=dist, name=name, latex_label=latex_label, unit=unit) + super().__init__(dist=dist, name=name, latex_label=latex_label, unit=unit) diff --git a/bilby/gw/result.py b/bilby/gw/result.py index 01ebb4971..51751f379 100644 --- a/bilby/gw/result.py +++ b/bilby/gw/result.py @@ -6,11 +6,15 @@ from ..core.result import Result as CoreResult from ..core.utils import ( - infft, logger, check_directory_exists_and_if_not_mkdir, - latex_plot_format, safe_file_dump, safe_save_figure, + check_directory_exists_and_if_not_mkdir, + infft, + latex_plot_format, + logger, + safe_file_dump, + safe_save_figure, ) -from .utils import plot_spline_pos, spline_angle_xform, asd_from_freq_series -from .detector import get_empty_interferometer, Interferometer +from .detector import Interferometer, get_empty_interferometer +from .utils import asd_from_freq_series, plot_spline_pos, spline_angle_xform class CompactBinaryCoalescenceResult(CoreResult): @@ -18,8 +22,8 @@ class CompactBinaryCoalescenceResult(CoreResult): Result class with additional methods and attributes specific to analyses of compact binaries. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs): if "meta_data" not in kwargs: kwargs["meta_data"] = dict() if "global_meta_data" not in kwargs: @@ -29,9 +33,10 @@ def __init__(self, **kwargs): # Ensure cosmology is always stored in the meta_data if "cosmology" not in kwargs["meta_data"]["global_meta_data"]: from .cosmology import get_cosmology + kwargs["meta_data"]["global_meta_data"]["cosmology"] = get_cosmology() - super(CompactBinaryCoalescenceResult, self).__init__(**kwargs) + super().__init__(**kwargs) def __get_from_nested_meta_data(self, *keys): dictionary = self.meta_data @@ -42,92 +47,77 @@ def __get_from_nested_meta_data(self, *keys): dictionary = item return item except KeyError: - raise AttributeError( - "No information stored for {}".format('/'.join(keys))) + raise AttributeError("No information stored for {}".format("/".join(keys))) @property def sampling_frequency(self): - """ Sampling frequency in Hertz""" - return self.__get_from_nested_meta_data( - 'likelihood', 'sampling_frequency') + """Sampling frequency in Hertz""" + return self.__get_from_nested_meta_data("likelihood", "sampling_frequency") @property def duration(self): - """ Duration in seconds """ - return self.__get_from_nested_meta_data( - 'likelihood', 'duration') + """Duration in seconds""" + return self.__get_from_nested_meta_data("likelihood", "duration") @property def start_time(self): - """ Start time in seconds """ - return self.__get_from_nested_meta_data( - 'likelihood', 'start_time') + """Start time in seconds""" + return self.__get_from_nested_meta_data("likelihood", "start_time") @property def time_marginalization(self): - """ Boolean for if the likelihood used time marginalization """ - return self.__get_from_nested_meta_data( - 'likelihood', 'time_marginalization') + """Boolean for if the likelihood used time marginalization""" + return self.__get_from_nested_meta_data("likelihood", "time_marginalization") @property def phase_marginalization(self): - """ Boolean for if the likelihood used phase marginalization """ - return self.__get_from_nested_meta_data( - 'likelihood', 'phase_marginalization') + """Boolean for if the likelihood used phase marginalization""" + return self.__get_from_nested_meta_data("likelihood", "phase_marginalization") @property def distance_marginalization(self): - """ Boolean for if the likelihood used distance marginalization """ - return self.__get_from_nested_meta_data( - 'likelihood', 'distance_marginalization') + """Boolean for if the likelihood used distance marginalization""" + return self.__get_from_nested_meta_data("likelihood", "distance_marginalization") @property def interferometers(self): - """ List of interferometer names """ - return [name for name in self.__get_from_nested_meta_data( - 'likelihood', 'interferometers')] + """List of interferometer names""" + return [name for name in self.__get_from_nested_meta_data("likelihood", "interferometers")] @property def waveform_approximant(self): - """ String of the waveform approximant """ - return self.__get_from_nested_meta_data( - 'likelihood', 'waveform_arguments', 'waveform_approximant') + """String of the waveform approximant""" + return self.__get_from_nested_meta_data("likelihood", "waveform_arguments", "waveform_approximant") @property def waveform_generator_class(self): - """ Dict of waveform arguments """ - return self.__get_from_nested_meta_data( - 'likelihood', 'waveform_generator_class') + """Dict of waveform arguments""" + return self.__get_from_nested_meta_data("likelihood", "waveform_generator_class") @property def waveform_arguments(self): - """ Dict of waveform arguments """ - return self.__get_from_nested_meta_data( - 'likelihood', 'waveform_arguments') + """Dict of waveform arguments""" + return self.__get_from_nested_meta_data("likelihood", "waveform_arguments") @property def reference_frequency(self): - """ Float of the reference frequency """ - return self.__get_from_nested_meta_data( - 'likelihood', 'waveform_arguments', 'reference_frequency') + """Float of the reference frequency""" + return self.__get_from_nested_meta_data("likelihood", "waveform_arguments", "reference_frequency") @property def frequency_domain_source_model(self): - """ The frequency domain source model (function)""" - return self.__get_from_nested_meta_data( - 'likelihood', 'frequency_domain_source_model') + """The frequency domain source model (function)""" + return self.__get_from_nested_meta_data("likelihood", "frequency_domain_source_model") @property def time_domain_source_model(self): - """ The time domain source model (function)""" - return self.__get_from_nested_meta_data( - 'likelihood', 'time_domain_source_model') + """The time domain source model (function)""" + return self.__get_from_nested_meta_data("likelihood", "time_domain_source_model") @property def parameter_conversion(self): - """ The frequency domain source model (function)""" - return self.__get_from_nested_meta_data( - 'likelihood', 'parameter_conversion') + """The frequency domain source model (function)""" + return self.__get_from_nested_meta_data("likelihood", "parameter_conversion") @property def cosmology(self): @@ -140,9 +130,7 @@ def cosmology(self): .. versionadded:: 2.5.0 """ try: - return self.__get_from_nested_meta_data( - 'global_meta_data', 'cosmology' - ) + return self.__get_from_nested_meta_data("global_meta_data", "cosmology") except AttributeError as e: logger.warning( "No cosmology found in result. " @@ -152,7 +140,7 @@ def cosmology(self): return None def detector_injection_properties(self, detector): - """ Returns a dictionary of the injection properties for each detector + """Returns a dictionary of the injection properties for each detector The injection properties include the parameters injected, and information about the signal to noise ratio (SNR) given the noise @@ -170,15 +158,14 @@ def detector_injection_properties(self, detector): """ try: - return self.__get_from_nested_meta_data( - 'likelihood', 'interferometers', detector) + return self.__get_from_nested_meta_data("likelihood", "interferometers", detector) except AttributeError: - logger.info("No injection for detector {}".format(detector)) + logger.info(f"No injection for detector {detector}") return None @latex_plot_format - def plot_calibration_posterior(self, level=.9, format="png"): - """ Plots the calibration amplitude and phase uncertainty. + def plot_calibration_posterior(self, level=0.9, format="png"): + """Plots the calibration amplitude and phase uncertainty. Adapted from the LALInference version in bayespputils Plot is saved to {self.outdir}/{self.label}_calibration.{format} @@ -191,6 +178,7 @@ def plot_calibration_posterior(self, level=.9, format="png"): Format to save the plot, default=png, options are png/pdf """ import matplotlib.pyplot as plt + if format not in ["png", "pdf"]: raise ValueError("Format should be one of png or pdf") @@ -201,68 +189,72 @@ def plot_calibration_posterior(self, level=.9, format="png"): outdir = self.outdir parameters = posterior.keys() - ifos = np.unique([param.split('_')[1] for param in parameters if 'recalib_' in param]) + ifos = np.unique([param.split("_")[1] for param in parameters if "recalib_" in param]) if ifos.size == 0: logger.info("No calibration parameters found. Aborting calibration plot.") return for ifo in ifos: - if ifo == 'H1': - color = 'r' - elif ifo == 'L1': - color = 'g' - elif ifo == 'V1': - color = 'm' + if ifo == "H1": + color = "r" + elif ifo == "L1": + color = "g" + elif ifo == "V1": + color = "m" else: - color = 'c' + color = "c" # Assume spline control frequencies are constant - freq_params = np.sort([param for param in parameters if - 'recalib_{0}_frequency_'.format(ifo) in param]) + freq_params = np.sort([param for param in parameters if f"recalib_{ifo}_frequency_" in param]) logfreqs = np.log([posterior[param].iloc[0] for param in freq_params]) # Amplitude calibration model plt.sca(ax1) - amp_params = np.sort([param for param in parameters if - 'recalib_{0}_amplitude_'.format(ifo) in param]) + amp_params = np.sort([param for param in parameters if f"recalib_{ifo}_amplitude_" in param]) if len(amp_params) > 0: amplitude = 100 * np.column_stack([posterior[param] for param in amp_params]) - plot_spline_pos(logfreqs, amplitude, color=color, level=level, - label=r"{0} (mean, {1}$\%$)".format(ifo.upper(), int(level * 100))) + plot_spline_pos( + logfreqs, + amplitude, + color=color, + level=level, + label=rf"{ifo.upper()} (mean, {int(level * 100)}$\%$)", + ) # Phase calibration model plt.sca(ax2) - phase_params = np.sort([param for param in parameters if - 'recalib_{0}_phase_'.format(ifo) in param]) + phase_params = np.sort([param for param in parameters if f"recalib_{ifo}_phase_" in param]) if len(phase_params) > 0: phase = np.column_stack([posterior[param] for param in phase_params]) - plot_spline_pos(logfreqs, phase, color=color, level=level, - label=r"{0} (mean, {1}$\%$)".format(ifo.upper(), int(level * 100)), - xform=spline_angle_xform) + plot_spline_pos( + logfreqs, + phase, + color=color, + level=level, + label=rf"{ifo.upper()} (mean, {int(level * 100)}$\%$)", + xform=spline_angle_xform, + ) - ax1.tick_params(labelsize=.75 * font_size) - ax2.tick_params(labelsize=.75 * font_size) - plt.legend(loc='upper right', prop={'size': .75 * font_size}, framealpha=0.1) - ax1.set_xscale('log') - ax2.set_xscale('log') + ax1.tick_params(labelsize=0.75 * font_size) + ax2.tick_params(labelsize=0.75 * font_size) + plt.legend(loc="upper right", prop={"size": 0.75 * font_size}, framealpha=0.1) + ax1.set_xscale("log") + ax2.set_xscale("log") - ax2.set_xlabel('Frequency [Hz]', fontsize=font_size) - ax1.set_ylabel(r'Amplitude [$\%$]', fontsize=font_size) - ax2.set_ylabel('Phase [deg]', fontsize=font_size) + ax2.set_xlabel("Frequency [Hz]", fontsize=font_size) + ax1.set_ylabel(r"Amplitude [$\%$]", fontsize=font_size) + ax2.set_ylabel("Phase [deg]", fontsize=font_size) - filename = os.path.join(outdir, self.label + '_calibration.' + format) + filename = os.path.join(outdir, self.label + "_calibration." + format) fig.tight_layout() - safe_save_figure( - fig=fig, filename=filename, - format=format, dpi=600, bbox_inches='tight' - ) - logger.debug("Calibration figure saved to {}".format(filename)) + safe_save_figure(fig=fig, filename=filename, format=format, dpi=600, bbox_inches="tight") + logger.debug(f"Calibration figure saved to {filename}") plt.close() def plot_waveform_posterior( - self, interferometers=None, level=0.9, n_samples=None, - format='png', start_time=None, end_time=None): + self, interferometers=None, level=0.9, n_samples=None, format="png", start_time=None, end_time=None + ): """ Plot the posterior for the waveform in the frequency domain and whitened time domain for all detectors. @@ -293,18 +285,22 @@ def plot_waveform_posterior( if interferometers is None: interferometers = self.interferometers elif not isinstance(interferometers, list): - raise TypeError( - 'interferometers must be a list or InterferometerList') + raise TypeError("interferometers must be a list or InterferometerList") for ifo in interferometers: self.plot_interferometer_waveform_posterior( - interferometer=ifo, level=level, n_samples=n_samples, - save=True, format=format, start_time=start_time, - end_time=end_time) + interferometer=ifo, + level=level, + n_samples=n_samples, + save=True, + format=format, + start_time=start_time, + end_time=end_time, + ) @latex_plot_format def plot_interferometer_waveform_posterior( - self, interferometer, level=0.9, n_samples=None, save=True, - format='png', start_time=None, end_time=None): + self, interferometer, level=0.9, n_samples=None, save=True, format="png", start_time=None, end_time=None + ): """ Plot the posterior for the waveform in the frequency domain and whitened time domain. @@ -360,57 +356,46 @@ def plot_interferometer_waveform_posterior( except ImportError: logger.warning( "HTML plotting requested, but plotly cannot be imported, " - "falling back to png format for waveform plot.") + "falling back to png format for waveform plot." + ) format = "png" if isinstance(interferometer, str): interferometer = get_empty_interferometer(interferometer) interferometer.set_strain_data_from_zero_noise( - sampling_frequency=self.sampling_frequency, - duration=self.duration, start_time=self.start_time) + sampling_frequency=self.sampling_frequency, duration=self.duration, start_time=self.start_time + ) PLOT_DATA = False elif not isinstance(interferometer, Interferometer): - raise TypeError( - 'interferometer must be either str or Interferometer') + raise TypeError("interferometer must be either str or Interferometer") else: PLOT_DATA = True - logger.info("Generating waveform figure for {}".format( - interferometer.name)) + logger.info(f"Generating waveform figure for {interferometer.name}") if n_samples is None: samples = self.posterior elif n_samples > len(self.posterior): logger.debug( - "Requested more waveform samples ({}) than we have " - "posterior samples ({})!".format( - n_samples, len(self.posterior) - ) + f"Requested more waveform samples ({n_samples}) than we have posterior samples ({len(self.posterior)})!" ) samples = self.posterior else: samples = self.posterior.sample(n_samples, replace=False) if start_time is None: - start_time = - 0.4 + start_time = -0.4 start_time = np.mean(samples.geocent_time) + start_time if end_time is None: end_time = 0.2 end_time = np.mean(samples.geocent_time) + end_time if format == "html": - start_time = - np.inf + start_time = -np.inf end_time = np.inf - time_idxs = ( - (interferometer.time_array >= start_time) & - (interferometer.time_array <= end_time) - ) + time_idxs = (interferometer.time_array >= start_time) & (interferometer.time_array <= end_time) frequency_idxs = np.where(interferometer.frequency_mask)[0] - logger.debug("Frequency mask contains {} values".format( - len(frequency_idxs)) - ) - frequency_idxs = frequency_idxs[::max(1, len(frequency_idxs) // 4000)] - logger.debug("Downsampling frequency mask to {} values".format( - len(frequency_idxs)) - ) + logger.debug(f"Frequency mask contains {len(frequency_idxs)} values") + frequency_idxs = frequency_idxs[:: max(1, len(frequency_idxs) // 4000)] + logger.debug(f"Downsampling frequency mask to {len(frequency_idxs)} values") plot_times = interferometer.time_array[time_idxs] plot_times -= interferometer.strain_data.start_time start_time -= interferometer.strain_data.start_time @@ -418,34 +403,34 @@ def plot_interferometer_waveform_posterior( plot_frequencies = interferometer.frequency_array[frequency_idxs] waveform_generator = self.waveform_generator_class( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, start_time=self.start_time, frequency_domain_source_model=self.frequency_domain_source_model, time_domain_source_model=self.time_domain_source_model, parameter_conversion=self.parameter_conversion, - waveform_arguments=self.waveform_arguments) + waveform_arguments=self.waveform_arguments, + ) if format == "html": fig = make_subplots( - rows=2, cols=1, + rows=2, + cols=1, row_heights=[0.5, 0.5], ) fig.update_layout( - template='plotly_white', + template="plotly_white", font=dict( family="Computer Modern", - ) + ), ) else: import matplotlib.pyplot as plt from matplotlib import rcParams + old_font_size = rcParams["font.size"] rcParams["font.size"] = 20 - fig, axs = plt.subplots( - 2, 1, - gridspec_kw=dict(height_ratios=[1.5, 1]), - figsize=(16, 12.5) - ) + fig, axs = plt.subplots(2, 1, gridspec_kw=dict(height_ratios=[1.5, 1]), figsize=(16, 12.5)) if PLOT_DATA: if format == "html": @@ -454,13 +439,14 @@ def plot_interferometer_waveform_posterior( x=plot_frequencies, y=asd_from_freq_series( interferometer.frequency_domain_strain[frequency_idxs], - 1 / interferometer.strain_data.duration + 1 / interferometer.strain_data.duration, ), fill=None, - mode='lines', line_color=DATA_COLOR, + mode="lines", + line_color=DATA_COLOR, opacity=0.5, name="Data", - legendgroup='data', + legendgroup="data", ), row=1, col=1, @@ -470,10 +456,11 @@ def plot_interferometer_waveform_posterior( x=plot_frequencies, y=interferometer.amplitude_spectral_density_array[frequency_idxs], fill=None, - mode='lines', line_color=DATA_COLOR, + mode="lines", + line_color=DATA_COLOR, opacity=0.8, name="ASD", - legendgroup='asd', + legendgroup="asd", ), row=1, col=1, @@ -483,10 +470,11 @@ def plot_interferometer_waveform_posterior( x=plot_times, y=interferometer.whitened_time_domain_strain[time_idxs], fill=None, - mode='lines', line_color=DATA_COLOR, + mode="lines", + line_color=DATA_COLOR, opacity=0.5, name="Data", - legendgroup='data', + legendgroup="data", showlegend=False, ), row=2, @@ -496,17 +484,22 @@ def plot_interferometer_waveform_posterior( axs[0].loglog( plot_frequencies, asd_from_freq_series( - interferometer.frequency_domain_strain[frequency_idxs], - 1 / interferometer.strain_data.duration), - color=DATA_COLOR, label='Data', alpha=0.3) + interferometer.frequency_domain_strain[frequency_idxs], 1 / interferometer.strain_data.duration + ), + color=DATA_COLOR, + label="Data", + alpha=0.3, + ) axs[0].loglog( plot_frequencies, interferometer.amplitude_spectral_density_array[frequency_idxs], - color=DATA_COLOR, label='ASD') + color=DATA_COLOR, + label="ASD", + ) axs[1].plot( - plot_times, interferometer.whitened_time_domain_strain[time_idxs], - color=DATA_COLOR, alpha=0.3) - logger.debug('Plotted interferometer data.') + plot_times, interferometer.whitened_time_domain_strain[time_idxs], color=DATA_COLOR, alpha=0.3 + ) + logger.debug("Plotted interferometer data.") fd_waveforms = list() td_waveforms = list() @@ -515,59 +508,57 @@ def plot_interferometer_waveform_posterior( fd_waveform = interferometer.get_detector_response(wf_pols, params) fd_waveforms.append(fd_waveform[frequency_idxs]) whitened_fd_waveform = interferometer.whiten_frequency_series(fd_waveform) - td_waveform = interferometer.get_whitened_time_series_from_whitened_frequency_series( - whitened_fd_waveform - )[time_idxs] + td_waveform = interferometer.get_whitened_time_series_from_whitened_frequency_series(whitened_fd_waveform)[ + time_idxs + ] td_waveforms.append(td_waveform) - fd_waveforms = asd_from_freq_series( - fd_waveforms, - 1 / interferometer.strain_data.duration) + fd_waveforms = asd_from_freq_series(fd_waveforms, 1 / interferometer.strain_data.duration) td_waveforms = np.array(td_waveforms) delta = (1 + level) / 2 upper_percentile = delta * 100 lower_percentile = (1 - delta) * 100 - logger.debug( - 'Plotting posterior between the {} and {} percentiles'.format( - lower_percentile, upper_percentile - ) - ) + logger.debug(f"Plotting posterior between the {lower_percentile} and {upper_percentile} percentiles") if format == "html": fig.add_trace( go.Scatter( - x=plot_frequencies, y=np.median(fd_waveforms, axis=0), + x=plot_frequencies, + y=np.median(fd_waveforms, axis=0), fill=None, - mode='lines', line_color=WAVEFORM_COLOR, + mode="lines", + line_color=WAVEFORM_COLOR, opacity=1, name="Median reconstructed", - legendgroup='median', + legendgroup="median", ), row=1, col=1, ) fig.add_trace( go.Scatter( - x=plot_frequencies, y=np.percentile(fd_waveforms, lower_percentile, axis=0), + x=plot_frequencies, + y=np.percentile(fd_waveforms, lower_percentile, axis=0), fill=None, - mode='lines', + mode="lines", line_color=WAVEFORM_COLOR, opacity=0.1, - name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), - legendgroup='uncertainty', + name=f"{upper_percentile - lower_percentile:.2f}% credible interval", + legendgroup="uncertainty", ), row=1, col=1, ) fig.add_trace( go.Scatter( - x=plot_frequencies, y=np.percentile(fd_waveforms, upper_percentile, axis=0), - fill='tonexty', - mode='lines', + x=plot_frequencies, + y=np.percentile(fd_waveforms, upper_percentile, axis=0), + fill="tonexty", + mode="lines", line_color=WAVEFORM_COLOR, opacity=0.1, - name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), - legendgroup='uncertainty', + name=f"{upper_percentile - lower_percentile:.2f}% credible interval", + legendgroup="uncertainty", showlegend=False, ), row=1, @@ -575,12 +566,14 @@ def plot_interferometer_waveform_posterior( ) fig.add_trace( go.Scatter( - x=plot_times, y=np.median(td_waveforms, axis=0), + x=plot_times, + y=np.median(td_waveforms, axis=0), fill=None, - mode='lines', line_color=WAVEFORM_COLOR, + mode="lines", + line_color=WAVEFORM_COLOR, opacity=1, name="Median reconstructed", - legendgroup='median', + legendgroup="median", showlegend=False, ), row=2, @@ -588,13 +581,14 @@ def plot_interferometer_waveform_posterior( ) fig.add_trace( go.Scatter( - x=plot_times, y=np.percentile(td_waveforms, lower_percentile, axis=0), + x=plot_times, + y=np.percentile(td_waveforms, lower_percentile, axis=0), fill=None, - mode='lines', + mode="lines", line_color=WAVEFORM_COLOR, opacity=0.1, - name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), - legendgroup='uncertainty', + name=f"{upper_percentile - lower_percentile:.2f}% credible interval", + legendgroup="uncertainty", showlegend=False, ), row=2, @@ -602,13 +596,14 @@ def plot_interferometer_waveform_posterior( ) fig.add_trace( go.Scatter( - x=plot_times, y=np.percentile(td_waveforms, upper_percentile, axis=0), - fill='tonexty', - mode='lines', + x=plot_times, + y=np.percentile(td_waveforms, upper_percentile, axis=0), + fill="tonexty", + mode="lines", line_color=WAVEFORM_COLOR, opacity=0.1, - name="{:.2f}% credible interval".format(upper_percentile - lower_percentile), - legendgroup='uncertainty', + name=f"{upper_percentile - lower_percentile:.2f}% credible interval", + legendgroup="uncertainty", showlegend=False, ), row=2, @@ -617,59 +612,58 @@ def plot_interferometer_waveform_posterior( else: lower_limit = np.mean(fd_waveforms, axis=0)[0] / 1e3 axs[0].loglog( - plot_frequencies, - np.mean(fd_waveforms, axis=0), color=WAVEFORM_COLOR, label='Mean reconstructed') + plot_frequencies, np.mean(fd_waveforms, axis=0), color=WAVEFORM_COLOR, label="Mean reconstructed" + ) axs[0].fill_between( plot_frequencies, np.percentile(fd_waveforms, lower_percentile, axis=0), np.percentile(fd_waveforms, upper_percentile, axis=0), color=WAVEFORM_COLOR, - label=r'{}% credible interval'.format(int(upper_percentile - lower_percentile)), - alpha=0.3) - axs[1].plot( - plot_times, np.mean(td_waveforms, axis=0), - color=WAVEFORM_COLOR) + label=rf"{int(upper_percentile - lower_percentile)}% credible interval", + alpha=0.3, + ) + axs[1].plot(plot_times, np.mean(td_waveforms, axis=0), color=WAVEFORM_COLOR) axs[1].fill_between( - plot_times, np.percentile( - td_waveforms, lower_percentile, axis=0), + plot_times, + np.percentile(td_waveforms, lower_percentile, axis=0), np.percentile(td_waveforms, upper_percentile, axis=0), color=WAVEFORM_COLOR, - alpha=0.3) + alpha=0.3, + ) if self.injection_parameters is not None: try: - hf_inj = waveform_generator.frequency_domain_strain( - self.injection_parameters) - hf_inj_det = interferometer.get_detector_response( - hf_inj, self.injection_parameters) + hf_inj = waveform_generator.frequency_domain_strain(self.injection_parameters) + hf_inj_det = interferometer.get_detector_response(hf_inj, self.injection_parameters) ht_inj_det = infft( - hf_inj_det * np.sqrt(2. / interferometer.sampling_frequency) / - interferometer.amplitude_spectral_density_array, - self.sampling_frequency)[time_idxs] + hf_inj_det + * np.sqrt(2.0 / interferometer.sampling_frequency) + / interferometer.amplitude_spectral_density_array, + self.sampling_frequency, + )[time_idxs] if format == "html": fig.add_trace( go.Scatter( x=plot_frequencies, - y=asd_from_freq_series( - hf_inj_det[frequency_idxs], - 1 / interferometer.strain_data.duration), + y=asd_from_freq_series(hf_inj_det[frequency_idxs], 1 / interferometer.strain_data.duration), fill=None, - mode='lines', - line=dict(color=INJECTION_COLOR, dash='dot'), + mode="lines", + line=dict(color=INJECTION_COLOR, dash="dot"), name="Injection", - legendgroup='injection', + legendgroup="injection", ), row=1, col=1, ) fig.add_trace( go.Scatter( - x=plot_times, y=ht_inj_det, + x=plot_times, + y=ht_inj_det, fill=None, - mode='lines', - line=dict(color=INJECTION_COLOR, dash='dot'), + mode="lines", + line=dict(color=INJECTION_COLOR, dash="dot"), name="Injection", - legendgroup='injection', + legendgroup="injection", showlegend=False, ), row=2, @@ -678,20 +672,19 @@ def plot_interferometer_waveform_posterior( else: axs[0].loglog( plot_frequencies, - asd_from_freq_series( - hf_inj_det[frequency_idxs], - 1 / interferometer.strain_data.duration), - color=INJECTION_COLOR, label='Injection', linestyle=':') - axs[1].plot( - plot_times, ht_inj_det, - color=INJECTION_COLOR, linestyle=':') - logger.debug('Plotted injection.') + asd_from_freq_series(hf_inj_det[frequency_idxs], 1 / interferometer.strain_data.duration), + color=INJECTION_COLOR, + label="Injection", + linestyle=":", + ) + axs[1].plot(plot_times, ht_inj_det, color=INJECTION_COLOR, linestyle=":") + logger.debug("Plotted injection.") except IndexError as e: - logger.info('Failed to plot injection with message {}.'.format(e)) + logger.info(f"Failed to plot injection with message {e}.") f_domain_x_label = "$f [\\mathrm{Hz}]$" f_domain_y_label = "$\\mathrm{ASD} \\left[\\mathrm{Hz}^{-1/2}\\right]$" - t_domain_x_label = "$t - {} [s]$".format(interferometer.strain_data.start_time) + t_domain_x_label = f"$t - {interferometer.strain_data.start_time} [s]$" t_domain_y_label = "Whitened Strain" if format == "html": fig.update_xaxes(title_text=f_domain_x_label, type="log", row=1) @@ -699,42 +692,47 @@ def plot_interferometer_waveform_posterior( fig.update_xaxes(title_text=t_domain_x_label, type="linear", row=2) fig.update_yaxes(title_text=t_domain_y_label, type="linear", row=2) else: - axs[0].set_xlim(interferometer.minimum_frequency, - interferometer.maximum_frequency) + axs[0].set_xlim(interferometer.minimum_frequency, interferometer.maximum_frequency) axs[1].set_xlim(start_time, end_time) axs[0].set_ylim(lower_limit) axs[0].set_xlabel(f_domain_x_label) axs[0].set_ylabel(f_domain_y_label) axs[1].set_xlabel(t_domain_x_label) axs[1].set_ylabel(t_domain_y_label) - axs[0].legend(loc='lower left', ncol=2) + axs[0].legend(loc="lower left", ncol=2) if save: - filename = os.path.join( - self.outdir, - self.label + '_{}_waveform.{}'.format( - interferometer.name, format)) - if format == 'html': - plot(fig, filename=filename, include_mathjax='cdn', auto_open=False) + filename = os.path.join(self.outdir, self.label + f"_{interferometer.name}_waveform.{format}") + if format == "html": + plot(fig, filename=filename, include_mathjax="cdn", auto_open=False) else: plt.tight_layout() - safe_save_figure( - fig=fig, filename=filename, - format=format, dpi=600 - ) + safe_save_figure(fig=fig, filename=filename, format=format, dpi=600) plt.close() - logger.debug("Waveform figure saved to {}".format(filename)) + logger.debug(f"Waveform figure saved to {filename}") rcParams["font.size"] = old_font_size else: rcParams["font.size"] = old_font_size return fig def plot_skymap( - self, maxpts=None, trials=5, jobs=1, enable_multiresolution=True, - objid=None, instruments=None, geo=False, dpi=600, - transparent=False, colorbar=False, contour=[50, 90], - annotate=True, cmap='cylon', load_pickle=False): - """ Generate a fits file and sky map from a result + self, + maxpts=None, + trials=5, + jobs=1, + enable_multiresolution=True, + objid=None, + instruments=None, + geo=False, + dpi=600, + transparent=False, + colorbar=False, + contour=[50, 90], + annotate=True, + cmap="cylon", + load_pickle=False, + ): + """Generate a fits file and sky map from a result Code adapted from ligo.skymap.tool.ligo_skymap_from_samples and ligo.skymap.tool.plot_skymap. Note, the use of this additionally @@ -776,38 +774,38 @@ def plot_skymap( from matplotlib import rcParams try: - from astropy.time import Time - from ligo.skymap import io, version, plot, postprocess, bayestar, kde import healpy as hp + from astropy.time import Time + from ligo.skymap import bayestar, io, kde, plot, postprocess, version except ImportError as e: - logger.info("Unable to generate skymap: error {}".format(e)) + logger.info(f"Unable to generate skymap: error {e}") return check_directory_exists_and_if_not_mkdir(self.outdir) - logger.info('Reading samples for skymap') + logger.info("Reading samples for skymap") data = self.posterior if maxpts is not None and maxpts < len(data): - logger.info('Taking random subsample of chain') + logger.info("Taking random subsample of chain") data = data.sample(maxpts) - default_obj_filename = os.path.join(self.outdir, '{}_skypost.obj'.format(self.label)) + default_obj_filename = os.path.join(self.outdir, f"{self.label}_skypost.obj") if load_pickle is False: try: - pts = data[['ra', 'dec', 'luminosity_distance']].values + pts = data[["ra", "dec", "luminosity_distance"]].values confidence_levels = kde.Clustered2Plus1DSkyKDE distance = True except KeyError: logger.warning("The results file does not contain luminosity_distance") - pts = data[['ra', 'dec']].values + pts = data[["ra", "dec"]].values confidence_levels = kde.Clustered2DSkyKDE distance = False - logger.info('Initialising skymap class') + logger.info("Initialising skymap class") skypost = confidence_levels(pts, trials=trials, jobs=jobs) - logger.info('Pickling skymap to {}'.format(default_obj_filename)) + logger.info(f"Pickling skymap to {default_obj_filename}") safe_file_dump(skypost, default_obj_filename, "pickle") else: @@ -815,38 +813,38 @@ def plot_skymap( obj_filename = load_pickle else: obj_filename = default_obj_filename - logger.info('Reading from pickle {}'.format(obj_filename)) - with open(obj_filename, 'rb') as file: + logger.info(f"Reading from pickle {obj_filename}") + with open(obj_filename, "rb") as file: skypost = pickle.load(file) skypost.jobs = jobs distance = isinstance(skypost, kde.Clustered2Plus1DSkyKDE) - logger.info('Making skymap') + logger.info("Making skymap") hpmap = skypost.as_healpix() if not enable_multiresolution: hpmap = bayestar.rasterize(hpmap) hpmap.meta.update(io.fits.metadata_for_version_module(version)) - hpmap.meta['creator'] = "bilby" - hpmap.meta['origin'] = 'LIGO/Virgo' - hpmap.meta['gps_creation_time'] = Time.now().gps - hpmap.meta['history'] = "" + hpmap.meta["creator"] = "bilby" + hpmap.meta["origin"] = "LIGO/Virgo" + hpmap.meta["gps_creation_time"] = Time.now().gps + hpmap.meta["history"] = "" if objid is not None: - hpmap.meta['objid'] = objid + hpmap.meta["objid"] = objid if instruments: - hpmap.meta['instruments'] = instruments + hpmap.meta["instruments"] = instruments if distance: - hpmap.meta['distmean'] = np.mean(data['luminosity_distance']) - hpmap.meta['diststd'] = np.std(data['luminosity_distance']) + hpmap.meta["distmean"] = np.mean(data["luminosity_distance"]) + hpmap.meta["diststd"] = np.std(data["luminosity_distance"]) try: - time = data['geocent_time'] - hpmap.meta['gps_time'] = time.mean() + time = data["geocent_time"] + hpmap.meta["gps_time"] = time.mean() except KeyError: - logger.warning('Cannot determine the event time from geocent_time') + logger.warning("Cannot determine the event time from geocent_time") - fits_filename = os.path.join(self.outdir, "{}_skymap.fits".format(self.label)) - logger.info('Saving skymap fits-file to {}'.format(fits_filename)) + fits_filename = os.path.join(self.outdir, f"{self.label}_skymap.fits") + logger.info(f"Saving skymap fits-file to {fits_filename}") io.write_sky_map(fits_filename, hpmap, nest=True) skymap, metadata = io.fits.read_sky_map(fits_filename, nest=None) @@ -857,41 +855,36 @@ def plot_skymap( probperdeg2 = skymap / deg2perpix if geo: - obstime = Time(metadata['gps_time'], format='gps').utc.isot - ax = plt.axes(projection='geo degrees mollweide', obstime=obstime) + obstime = Time(metadata["gps_time"], format="gps").utc.isot + ax = plt.axes(projection="geo degrees mollweide", obstime=obstime) else: - ax = plt.axes(projection='astro hours mollweide') + ax = plt.axes(projection="astro hours mollweide") ax.grid() # Plot sky map. vmax = probperdeg2.max() - img = ax.imshow_hpx( - (probperdeg2, 'ICRS'), nested=metadata['nest'], vmin=0., vmax=vmax, - cmap=cmap) + img = ax.imshow_hpx((probperdeg2, "ICRS"), nested=metadata["nest"], vmin=0.0, vmax=vmax, cmap=cmap) # Add colorbar. if colorbar: cb = plot.colorbar(img) - cb.set_label(r'prob. per deg$^2$') + cb.set_label(r"prob. per deg$^2$") if contour is not None: confidence_levels = 100 * postprocess.find_greedy_credible_levels(skymap) contours = ax.contour_hpx( - (confidence_levels, 'ICRS'), nested=metadata['nest'], - colors='k', linewidths=0.5, levels=contour) - fmt = r'%g\%%' if rcParams['text.usetex'] else '%g%%' + (confidence_levels, "ICRS"), nested=metadata["nest"], colors="k", linewidths=0.5, levels=contour + ) + fmt = r"%g\%%" if rcParams["text.usetex"] else "%g%%" plt.clabel(contours, fmt=fmt, fontsize=6, inline=True) # Add continents. if geo: - geojson_filename = os.path.join( - os.path.dirname(plot.__file__), 'ne_simplified_coastline.json') - with open(geojson_filename, 'r') as geojson_file: - geoms = json.load(geojson_file)['geometries'] - verts = [coord for geom in geoms - for coord in zip(*geom['coordinates'])] - plt.plot(*verts, color='0.5', linewidth=0.5, - transform=ax.get_transform('world')) + geojson_filename = os.path.join(os.path.dirname(plot.__file__), "ne_simplified_coastline.json") + with open(geojson_filename) as geojson_file: + geoms = json.load(geojson_file)["geometries"] + verts = [coord for geom in geoms for coord in zip(*geom["coordinates"])] + plt.plot(*verts, color="0.5", linewidth=0.5, transform=ax.get_transform("world")) # Add a white outline to all text to make it stand out from the background. plot.outline_text(ax) @@ -899,22 +892,20 @@ def plot_skymap( if annotate: text = [] try: - objid = metadata['objid'] + objid = metadata["objid"] except KeyError: pass else: - text.append('event ID: {}'.format(objid)) + text.append(f"event ID: {objid}") if contour: pp = np.round(contour).astype(int) - ii = np.round(np.searchsorted(np.sort(confidence_levels), contour) * - deg2perpix).astype(int) + ii = np.round(np.searchsorted(np.sort(confidence_levels), contour) * deg2perpix).astype(int) for i, p in zip(ii, pp): - text.append( - u'{:d}% area: {:d} deg$^2$'.format(p, i)) - ax.text(1, 1, '\n'.join(text), transform=ax.transAxes, ha='right') + text.append(f"{p:d}% area: {i:d} deg$^2$") + ax.text(1, 1, "\n".join(text), transform=ax.transAxes, ha="right") - filename = os.path.join(self.outdir, "{}_skymap.png".format(self.label)) - logger.info("Generating 2D projected skymap to {}".format(filename)) + filename = os.path.join(self.outdir, f"{self.label}_skymap.png") + logger.info(f"Generating 2D projected skymap to {filename}") safe_save_figure(fig=plt.gcf(), filename=filename, dpi=dpi) diff --git a/bilby/gw/sampler/__init__.py b/bilby/gw/sampler/__init__.py index 70a655892..be2406630 100644 --- a/bilby/gw/sampler/__init__.py +++ b/bilby/gw/sampler/__init__.py @@ -1 +1 @@ -from . import proposal +from . import proposal # noqa: F401 diff --git a/bilby/gw/sampler/proposal.py b/bilby/gw/sampler/proposal.py index 79e1ec92c..9d2484194 100644 --- a/bilby/gw/sampler/proposal.py +++ b/bilby/gw/sampler/proposal.py @@ -12,11 +12,11 @@ class SkyLocationWanderJump(JumpProposal): """ def __call__(self, sample, **kwargs): - temperature = 1 / kwargs.get('inverse_temperature', 1.0) + temperature = 1 / kwargs.get("inverse_temperature", 1.0) sigma = np.sqrt(temperature) / 2 / np.pi - sample['ra'] += random.gauss(0, sigma) - sample['dec'] += random.gauss(0, sigma) - return super(SkyLocationWanderJump, self).__call__(sample) + sample["ra"] += random.gauss(0, sigma) + sample["dec"] += random.gauss(0, sigma) + return super().__call__(sample) class CorrelatedPolarisationPhaseJump(JumpProposal): @@ -25,17 +25,17 @@ class CorrelatedPolarisationPhaseJump(JumpProposal): """ def __call__(self, sample, **kwargs): - alpha = sample['psi'] + sample['phase'] - beta = sample['psi'] - sample['phase'] + alpha = sample["psi"] + sample["phase"] + beta = sample["psi"] - sample["phase"] draw = random.random() if draw < 0.5: alpha = 3.0 * np.pi * random.random() else: beta = 3.0 * np.pi * random.random() - 2 * np.pi - sample['psi'] = (alpha + beta) * 0.5 - sample['phase'] = (alpha - beta) * 0.5 - return super(CorrelatedPolarisationPhaseJump, self).__call__(sample) + sample["psi"] = (alpha + beta) * 0.5 + sample["phase"] = (alpha - beta) * 0.5 + return super().__call__(sample) class PolarisationPhaseJump(JumpProposal): @@ -44,6 +44,6 @@ class PolarisationPhaseJump(JumpProposal): """ def __call__(self, sample, **kwargs): - sample['phase'] += np.pi - sample['psi'] += np.pi / 2 - return super(PolarisationPhaseJump, self).__call__(sample) + sample["phase"] += np.pi + sample["psi"] += np.pi / 2 + return super().__call__(sample) diff --git a/bilby/gw/source.py b/bilby/gw/source.py index 951346760..02a94ab77 100644 --- a/bilby/gw/source.py +++ b/bilby/gw/source.py @@ -3,11 +3,13 @@ from ..core import utils from ..core.utils import logger from .conversion import bilby_to_lalsimulation_spins -from .utils import (lalsim_GetApproximantFromString, - lalsim_SimInspiralFD, - lalsim_SimInspiralChooseFDWaveform, - lalsim_SimInspiralChooseFDWaveformSequence, - safe_cast_mode_to_int) +from .utils import ( + lalsim_GetApproximantFromString, + lalsim_SimInspiralChooseFDWaveform, + lalsim_SimInspiralChooseFDWaveformSequence, + lalsim_SimInspiralFD, + safe_cast_mode_to_int, +) UNUSED_KWARGS_MESSAGE = """There are unused waveform kwargs. This is deprecated behavior and will result in an error in future releases. Make sure all of the waveform kwargs are correctly @@ -17,8 +19,21 @@ """ -def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, phi_jl, theta_jn, phase, **kwargs): +def gwsignal_binary_black_hole( + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + phi_jl, + theta_jn, + phase, + **kwargs, +): """ A binary black hole waveform model using GWsignal @@ -93,9 +108,9 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista " instead." ) + import astropy.units as u from lalsimulation.gwsignal import GenerateFDWaveform from lalsimulation.gwsignal.models import gwsignal_get_waveform_generator - import astropy.units as u waveform_kwargs = dict( waveform_approximant="SEOBNRv5PHM", @@ -108,19 +123,20 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista ) waveform_kwargs.update(kwargs) - waveform_approximant = waveform_kwargs['waveform_approximant'] + waveform_approximant = waveform_kwargs["waveform_approximant"] if waveform_approximant not in ["SEOBNRv5HM", "SEOBNRv5PHM"]: if waveform_approximant == "IMRPhenomXPHM": - logger.warning("The new waveform interface is unreviewed for this model" + - "and it is only intended for testing.") + logger.warning( + "The new waveform interface is unreviewed for this model" + "and it is only intended for testing." + ) else: raise ValueError("The new waveform interface is unreviewed for this model.") - reference_frequency = waveform_kwargs['reference_frequency'] - minimum_frequency = waveform_kwargs['minimum_frequency'] - maximum_frequency = waveform_kwargs['maximum_frequency'] - catch_waveform_errors = waveform_kwargs['catch_waveform_errors'] - mode_array = waveform_kwargs['mode_array'] - pn_amplitude_order = waveform_kwargs['pn_amplitude_order'] + reference_frequency = waveform_kwargs["reference_frequency"] + minimum_frequency = waveform_kwargs["minimum_frequency"] + maximum_frequency = waveform_kwargs["maximum_frequency"] + catch_waveform_errors = waveform_kwargs["catch_waveform_errors"] + mode_array = waveform_kwargs["mode_array"] + pn_amplitude_order = waveform_kwargs["pn_amplitude_order"] if pn_amplitude_order != 0: # This is to mimic the behaviour in @@ -130,7 +146,7 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista pn_amplitude_order = 3 # Equivalent to MAX_PRECESSING_AMP_PN_ORDER in LALSimulation else: pn_amplitude_order = 6 # Equivalent to MAX_NONPRECESSING_AMP_PN_ORDER in LALSimulation - start_frequency = minimum_frequency * 2. / (pn_amplitude_order + 2) + start_frequency = minimum_frequency * 2.0 / (pn_amplitude_order + 2) else: start_frequency = minimum_frequency @@ -139,13 +155,21 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista delta_frequency = frequency_array[1] - frequency_array[0] - frequency_bounds = ((frequency_array >= minimum_frequency) * - (frequency_array <= maximum_frequency)) + frequency_bounds = (frequency_array >= minimum_frequency) * (frequency_array <= maximum_frequency) iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_lalsimulation_spins( - theta_jn=theta_jn, phi_jl=phi_jl, tilt_1=tilt_1, tilt_2=tilt_2, - phi_12=phi_12, a_1=a_1, a_2=a_2, mass_1=mass_1 * utils.solar_mass, mass_2=mass_2 * utils.solar_mass, - reference_frequency=reference_frequency, phase=phase) + theta_jn=theta_jn, + phi_jl=phi_jl, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_12=phi_12, + a_1=a_1, + a_2=a_2, + mass_1=mass_1 * utils.solar_mass, + mass_2=mass_2 * utils.solar_mass, + reference_frequency=reference_frequency, + phase=phase, + ) eccentricity = 0.0 longitude_ascending_nodes = 0.0 @@ -153,39 +177,39 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista # Check if conditioning is needed condition = 0 - if wf_gen.metadata["implemented_domain"] == 'time': + if wf_gen.metadata["implemented_domain"] == "time": condition = 1 # Create dict for gwsignal generator - gwsignal_dict = {'mass1' : mass_1 * u.solMass, - 'mass2' : mass_2 * u.solMass, - 'spin1x' : spin_1x * u.dimensionless_unscaled, - 'spin1y' : spin_1y * u.dimensionless_unscaled, - 'spin1z' : spin_1z * u.dimensionless_unscaled, - 'spin2x' : spin_2x * u.dimensionless_unscaled, - 'spin2y' : spin_2y * u.dimensionless_unscaled, - 'spin2z' : spin_2z * u.dimensionless_unscaled, - 'deltaF' : delta_frequency * u.Hz, - 'f22_start' : start_frequency * u.Hz, - 'f_max': maximum_frequency * u.Hz, - 'f22_ref': reference_frequency * u.Hz, - 'phi_ref' : phase * u.rad, - 'distance' : luminosity_distance * u.Mpc, - 'inclination' : iota * u.rad, - 'eccentricity' : eccentricity * u.dimensionless_unscaled, - 'longAscNodes' : longitude_ascending_nodes * u.rad, - 'meanPerAno' : mean_per_ano * u.rad, - # 'ModeArray': mode_array, - 'condition': condition - } + gwsignal_dict = { + "mass1": mass_1 * u.solMass, + "mass2": mass_2 * u.solMass, + "spin1x": spin_1x * u.dimensionless_unscaled, + "spin1y": spin_1y * u.dimensionless_unscaled, + "spin1z": spin_1z * u.dimensionless_unscaled, + "spin2x": spin_2x * u.dimensionless_unscaled, + "spin2y": spin_2y * u.dimensionless_unscaled, + "spin2z": spin_2z * u.dimensionless_unscaled, + "deltaF": delta_frequency * u.Hz, + "f22_start": start_frequency * u.Hz, + "f_max": maximum_frequency * u.Hz, + "f22_ref": reference_frequency * u.Hz, + "phi_ref": phase * u.rad, + "distance": luminosity_distance * u.Mpc, + "inclination": iota * u.rad, + "eccentricity": eccentricity * u.dimensionless_unscaled, + "longAscNodes": longitude_ascending_nodes * u.rad, + "meanPerAno": mean_per_ano * u.rad, + # 'ModeArray': mode_array, + "condition": condition, + } if mode_array is not None: try: mode_array = [tuple(map(safe_cast_mode_to_int, mode)) for mode in mode_array] except (ValueError, TypeError) as e: raise ValueError( - f"Unable to convert mode_array elements to tuples of ints. " - f"mode_array: {mode_array}, Error: {e}" + f"Unable to convert mode_array elements to tuples of ints. mode_array: {mode_array}, Error: {e}" ) from e gwsignal_dict.update(ModeArray=mode_array) @@ -193,17 +217,17 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista extra_args = waveform_kwargs.copy() for key in [ - "waveform_approximant", - "reference_frequency", - "minimum_frequency", - "maximum_frequency", - "catch_waveform_errors", - "mode_array", - "pn_spin_order", - "pn_amplitude_order", - "pn_tidal_order", - "pn_phase_order", - "numerical_relativity_file", + "waveform_approximant", + "reference_frequency", + "minimum_frequency", + "maximum_frequency", + "catch_waveform_errors", + "mode_array", + "pn_spin_order", + "pn_amplitude_order", + "pn_tidal_order", + "pn_phase_order", + "numerical_relativity_file", ]: if key in extra_args.keys(): del extra_args[key] @@ -216,22 +240,26 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista if not catch_waveform_errors: raise else: - EDOM = ( - "Internal function call failed: Input domain error" in e.args[0] - ) or "Input domain error" in e.args[ + EDOM = ("Internal function call failed: Input domain error" in e.args[0]) or "Input domain error" in e.args[ 0 ] if EDOM: - failed_parameters = dict(mass_1=mass_1, mass_2=mass_2, - spin_1=(spin_1x, spin_1y, spin_1z), - spin_2=(spin_2x, spin_2y, spin_2z), - luminosity_distance=luminosity_distance, - iota=iota, phase=phase, - eccentricity=eccentricity, - start_frequency=minimum_frequency) - logger.warning("Evaluating the waveform failed with error: {}\n".format(e) + - "The parameters were {}\n".format(failed_parameters) + - "Likelihood will be set to -inf.") + failed_parameters = dict( + mass_1=mass_1, + mass_2=mass_2, + spin_1=(spin_1x, spin_1y, spin_1z), + spin_2=(spin_2x, spin_2y, spin_2z), + luminosity_distance=luminosity_distance, + iota=iota, + phase=phase, + eccentricity=eccentricity, + start_frequency=minimum_frequency, + ) + logger.warning( + f"Evaluating the waveform failed with error: {e}\n" + + f"The parameters were {failed_parameters}\n" + + "Likelihood will be set to -inf." + ) return None else: raise @@ -243,15 +271,17 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista h_cross = np.zeros_like(frequency_array, dtype=complex) if len(hplus) > len(frequency_array): - logger.debug("GWsignal waveform longer than bilby's `frequency_array`" + - "({} vs {}), ".format(len(hplus), len(frequency_array)) + - "probably because padded with zeros up to the next power of two length." + - " Truncating GWsignal array.") - h_plus = hplus[:len(h_plus)] - h_cross = hcross[:len(h_cross)] + logger.debug( + "GWsignal waveform longer than bilby's `frequency_array`" + + f"({len(hplus)} vs {len(frequency_array)}), " + + "probably because padded with zeros up to the next power of two length." + + " Truncating GWsignal array." + ) + h_plus = hplus[: len(h_plus)] + h_cross = hcross[: len(h_cross)] else: - h_plus[:len(hplus)] = hplus - h_cross[:len(hcross)] = hcross + h_plus[: len(hplus)] = hplus + h_cross[: len(hcross)] = hcross h_plus *= frequency_bounds h_cross *= frequency_bounds @@ -266,9 +296,21 @@ def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_dista def lal_binary_black_hole( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, phi_jl, theta_jn, phase, **kwargs): - """ A Binary Black Hole waveform model using lalsimulation + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + phi_jl, + theta_jn, + phase, + **kwargs, +): + """A Binary Black Hole waveform model using lalsimulation Parameters ========== @@ -335,23 +377,52 @@ def lal_binary_black_hole( dict: A dictionary with the plus and cross polarisation strain modes """ waveform_kwargs = dict( - waveform_approximant='IMRPhenomPv2', reference_frequency=50.0, - minimum_frequency=20.0, maximum_frequency=frequency_array[-1], - catch_waveform_errors=False, pn_spin_order=-1, pn_tidal_order=-1, - pn_phase_order=-1, pn_amplitude_order=0) + waveform_approximant="IMRPhenomPv2", + reference_frequency=50.0, + minimum_frequency=20.0, + maximum_frequency=frequency_array[-1], + catch_waveform_errors=False, + pn_spin_order=-1, + pn_tidal_order=-1, + pn_phase_order=-1, + pn_amplitude_order=0, + ) waveform_kwargs.update(kwargs) return _base_lal_cbc_fd_waveform( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_12=phi_12, - phi_jl=phi_jl, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_12=phi_12, + phi_jl=phi_jl, + **waveform_kwargs, + ) def lal_binary_neutron_star( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, phi_jl, theta_jn, phase, lambda_1, lambda_2, - **kwargs): - """ A Binary Neutron Star waveform model using lalsimulation + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + phi_jl, + theta_jn, + phase, + lambda_1, + lambda_2, + **kwargs, +): + """A Binary Neutron Star waveform model using lalsimulation Parameters ========== @@ -419,22 +490,40 @@ def lal_binary_neutron_star( dict: A dictionary with the plus and cross polarisation strain modes """ waveform_kwargs = dict( - waveform_approximant='IMRPhenomPv2_NRTidal', reference_frequency=50.0, - minimum_frequency=20.0, maximum_frequency=frequency_array[-1], - catch_waveform_errors=False, pn_spin_order=-1, pn_tidal_order=-1, - pn_phase_order=-1, pn_amplitude_order=0) + waveform_approximant="IMRPhenomPv2_NRTidal", + reference_frequency=50.0, + minimum_frequency=20.0, + maximum_frequency=frequency_array[-1], + catch_waveform_errors=False, + pn_spin_order=-1, + pn_tidal_order=-1, + pn_phase_order=-1, + pn_amplitude_order=0, + ) waveform_kwargs.update(kwargs) return _base_lal_cbc_fd_waveform( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_12=phi_12, - phi_jl=phi_jl, lambda_1=lambda_1, lambda_2=lambda_2, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_12=phi_12, + phi_jl=phi_jl, + lambda_1=lambda_1, + lambda_2=lambda_2, + **waveform_kwargs, + ) def lal_eccentric_binary_black_hole_no_spins( - frequency_array, mass_1, mass_2, eccentricity, luminosity_distance, - theta_jn, phase, **kwargs): - """ Eccentric binary black hole waveform model using lalsimulation (EccentricFD) + frequency_array, mass_1, mass_2, eccentricity, luminosity_distance, theta_jn, phase, **kwargs +): + """Eccentric binary black hole waveform model using lalsimulation (EccentricFD) Parameters ========== @@ -487,15 +576,27 @@ def lal_eccentric_binary_black_hole_no_spins( dict: A dictionary with the plus and cross polarisation strain modes """ waveform_kwargs = dict( - waveform_approximant='EccentricFD', reference_frequency=10.0, - minimum_frequency=10.0, maximum_frequency=frequency_array[-1], - catch_waveform_errors=False, pn_spin_order=-1, pn_tidal_order=-1, - pn_phase_order=-1, pn_amplitude_order=0) + waveform_approximant="EccentricFD", + reference_frequency=10.0, + minimum_frequency=10.0, + maximum_frequency=frequency_array[-1], + catch_waveform_errors=False, + pn_spin_order=-1, + pn_tidal_order=-1, + pn_phase_order=-1, + pn_amplitude_order=0, + ) waveform_kwargs.update(kwargs) return _base_lal_cbc_fd_waveform( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - eccentricity=eccentricity, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + eccentricity=eccentricity, + **waveform_kwargs, + ) def set_waveform_dictionary(waveform_kwargs, lambda_1=0, lambda_2=0): @@ -519,15 +620,14 @@ def set_waveform_dictionary(waveform_kwargs, lambda_1=0, lambda_2=0): """ import lalsimulation as lalsim from lal import CreateDict - waveform_dictionary = waveform_kwargs.pop('lal_waveform_dictionary', CreateDict()) + + waveform_dictionary = waveform_kwargs.pop("lal_waveform_dictionary", CreateDict()) waveform_kwargs["TidalLambda1"] = lambda_1 waveform_kwargs["TidalLambda2"] = lambda_2 waveform_kwargs["NumRelData"] = waveform_kwargs.pop("numerical_relativity_file", None) - for key in [ - "pn_spin_order", "pn_tidal_order", "pn_phase_order", "pn_amplitude_order" - ]: - waveform_kwargs[key[:2].upper() + key[3:].title().replace('_', '')] = waveform_kwargs.pop(key) + for key in ["pn_spin_order", "pn_tidal_order", "pn_phase_order", "pn_amplitude_order"]: + waveform_kwargs[key[:2].upper() + key[3:].title().replace("_", "")] = waveform_kwargs.pop(key) for key in list(waveform_kwargs.keys()).copy(): func = getattr(lalsim, f"SimInspiralWaveformParamsInsert{key}", None) @@ -548,10 +648,24 @@ def set_waveform_dictionary(waveform_kwargs, lambda_1=0, lambda_2=0): def _base_lal_cbc_fd_waveform( - frequency_array, mass_1, mass_2, luminosity_distance, theta_jn, phase, - a_1=0.0, a_2=0.0, tilt_1=0.0, tilt_2=0.0, phi_12=0.0, phi_jl=0.0, - lambda_1=0.0, lambda_2=0.0, eccentricity=0.0, **waveform_kwargs): - """ Generate a cbc waveform model using lalsimulation + frequency_array, + mass_1, + mass_2, + luminosity_distance, + theta_jn, + phase, + a_1=0.0, + a_2=0.0, + tilt_1=0.0, + tilt_2=0.0, + phi_12=0.0, + phi_jl=0.0, + lambda_1=0.0, + lambda_2=0.0, + eccentricity=0.0, + **waveform_kwargs, +): + """Generate a cbc waveform model using lalsimulation Parameters ========== @@ -594,36 +708,42 @@ def _base_lal_cbc_fd_waveform( """ import lalsimulation as lalsim - waveform_approximant = waveform_kwargs.pop('waveform_approximant') - reference_frequency = waveform_kwargs.pop('reference_frequency') - minimum_frequency = waveform_kwargs.pop('minimum_frequency') - maximum_frequency = waveform_kwargs.pop('maximum_frequency') - catch_waveform_errors = waveform_kwargs.pop('catch_waveform_errors') - pn_amplitude_order = waveform_kwargs['pn_amplitude_order'] + waveform_approximant = waveform_kwargs.pop("waveform_approximant") + reference_frequency = waveform_kwargs.pop("reference_frequency") + minimum_frequency = waveform_kwargs.pop("minimum_frequency") + maximum_frequency = waveform_kwargs.pop("maximum_frequency") + catch_waveform_errors = waveform_kwargs.pop("catch_waveform_errors") + pn_amplitude_order = waveform_kwargs["pn_amplitude_order"] waveform_dictionary = set_waveform_dictionary(waveform_kwargs, lambda_1, lambda_2) approximant = lalsim_GetApproximantFromString(waveform_approximant) if pn_amplitude_order != 0: - start_frequency = lalsim.SimInspiralfLow2fStart( - float(minimum_frequency), int(pn_amplitude_order), approximant - ) + start_frequency = lalsim.SimInspiralfLow2fStart(float(minimum_frequency), int(pn_amplitude_order), approximant) else: start_frequency = minimum_frequency delta_frequency = frequency_array[1] - frequency_array[0] - frequency_bounds = ((frequency_array >= minimum_frequency) * - (frequency_array <= maximum_frequency)) + frequency_bounds = (frequency_array >= minimum_frequency) * (frequency_array <= maximum_frequency) luminosity_distance = luminosity_distance * 1e6 * utils.parsec mass_1 = mass_1 * utils.solar_mass mass_2 = mass_2 * utils.solar_mass iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_lalsimulation_spins( - theta_jn=theta_jn, phi_jl=phi_jl, tilt_1=tilt_1, tilt_2=tilt_2, - phi_12=phi_12, a_1=a_1, a_2=a_2, mass_1=mass_1, mass_2=mass_2, - reference_frequency=reference_frequency, phase=phase) + theta_jn=theta_jn, + phi_jl=phi_jl, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_12=phi_12, + a_1=a_1, + a_2=a_2, + mass_1=mass_1, + mass_2=mass_2, + reference_frequency=reference_frequency, + phase=phase, + ) longitude_ascending_nodes = 0.0 mean_per_ano = 0.0 @@ -634,27 +754,49 @@ def _base_lal_cbc_fd_waveform( wf_func = lalsim_SimInspiralFD try: hplus, hcross = wf_func( - mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, - spin_2z, luminosity_distance, iota, phase, - longitude_ascending_nodes, eccentricity, mean_per_ano, delta_frequency, - start_frequency, maximum_frequency, reference_frequency, - waveform_dictionary, approximant) + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + luminosity_distance, + iota, + phase, + longitude_ascending_nodes, + eccentricity, + mean_per_ano, + delta_frequency, + start_frequency, + maximum_frequency, + reference_frequency, + waveform_dictionary, + approximant, + ) except Exception as e: if not catch_waveform_errors: raise else: - EDOM = (e.args[0] == 'Internal function call failed: Input domain error') + EDOM = e.args[0] == "Internal function call failed: Input domain error" if EDOM: - failed_parameters = dict(mass_1=mass_1, mass_2=mass_2, - spin_1=(spin_1x, spin_1y, spin_1z), - spin_2=(spin_2x, spin_2y, spin_2z), - luminosity_distance=luminosity_distance, - iota=iota, phase=phase, - eccentricity=eccentricity, - start_frequency=start_frequency) - logger.warning("Evaluating the waveform failed with error: {}\n".format(e) + - "The parameters were {}\n".format(failed_parameters) + - "Likelihood will be set to -inf.") + failed_parameters = dict( + mass_1=mass_1, + mass_2=mass_2, + spin_1=(spin_1x, spin_1y, spin_1z), + spin_2=(spin_2x, spin_2y, spin_2z), + luminosity_distance=luminosity_distance, + iota=iota, + phase=phase, + eccentricity=eccentricity, + start_frequency=start_frequency, + ) + logger.warning( + f"Evaluating the waveform failed with error: {e}\n" + + f"The parameters were {failed_parameters}\n" + + "Likelihood will be set to -inf." + ) return None else: raise @@ -663,15 +805,17 @@ def _base_lal_cbc_fd_waveform( h_cross = np.zeros_like(frequency_array, dtype=complex) if len(hplus.data.data) > len(frequency_array): - logger.debug("LALsim waveform longer than bilby's `frequency_array`" + - "({} vs {}), ".format(len(hplus.data.data), len(frequency_array)) + - "probably because padded with zeros up to the next power of two length." + - " Truncating lalsim array.") - h_plus = hplus.data.data[:len(h_plus)] - h_cross = hcross.data.data[:len(h_cross)] + logger.debug( + "LALsim waveform longer than bilby's `frequency_array`" + + f"({len(hplus.data.data)} vs {len(frequency_array)}), " + + "probably because padded with zeros up to the next power of two length." + + " Truncating lalsim array." + ) + h_plus = hplus.data.data[: len(h_plus)] + h_cross = hcross.data.data[: len(h_cross)] else: - h_plus[:len(hplus.data.data)] = hplus.data.data - h_cross[:len(hcross.data.data)] = hcross.data.data + h_plus[: len(hplus.data.data)] = hplus.data.data + h_cross[: len(hcross.data.data)] = hcross.data.data h_plus *= frequency_bounds h_cross *= frequency_bounds @@ -689,40 +833,111 @@ def _base_lal_cbc_fd_waveform( def binary_black_hole_roq( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, phi_jl, theta_jn, phase, **waveform_arguments): + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + phi_jl, + theta_jn, + phase, + **waveform_arguments, +): waveform_kwargs = dict( - waveform_approximant='IMRPhenomPv2', reference_frequency=20.0, - catch_waveform_errors=False, pn_spin_order=-1, pn_tidal_order=-1, - pn_phase_order=-1, pn_amplitude_order=0) + waveform_approximant="IMRPhenomPv2", + reference_frequency=20.0, + catch_waveform_errors=False, + pn_spin_order=-1, + pn_tidal_order=-1, + pn_phase_order=-1, + pn_amplitude_order=0, + ) waveform_kwargs.update(waveform_arguments) return _base_roq_waveform( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_jl=phi_jl, - phi_12=phi_12, lambda_1=0.0, lambda_2=0.0, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_jl=phi_jl, + phi_12=phi_12, + lambda_1=0.0, + lambda_2=0.0, + **waveform_kwargs, + ) def binary_neutron_star_roq( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, phi_jl, lambda_1, lambda_2, theta_jn, phase, - **waveform_arguments): + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + phi_jl, + lambda_1, + lambda_2, + theta_jn, + phase, + **waveform_arguments, +): waveform_kwargs = dict( - waveform_approximant='IMRPhenomD_NRTidal', reference_frequency=20.0, - catch_waveform_errors=False, pn_spin_order=-1, pn_tidal_order=-1, - pn_phase_order=-1, pn_amplitude_order=0) + waveform_approximant="IMRPhenomD_NRTidal", + reference_frequency=20.0, + catch_waveform_errors=False, + pn_spin_order=-1, + pn_tidal_order=-1, + pn_phase_order=-1, + pn_amplitude_order=0, + ) waveform_kwargs.update(waveform_arguments) return _base_roq_waveform( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_jl=phi_jl, - phi_12=phi_12, lambda_1=lambda_1, lambda_2=lambda_2, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_jl=phi_jl, + phi_12=phi_12, + lambda_1=lambda_1, + lambda_2=lambda_2, + **waveform_kwargs, + ) def lal_binary_black_hole_relative_binning( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, phi_jl, theta_jn, phase, **kwargs): - """ Source model to go with RelativeBinningGravitationalWaveTransient likelihood. + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + phi_jl, + theta_jn, + phase, + **kwargs, +): + """Source model to go with RelativeBinningGravitationalWaveTransient likelihood. All parameters are the same as in the usual source models, except `fiducial` @@ -734,35 +949,79 @@ def lal_binary_black_hole_relative_binning( fiducial = kwargs.pop("fiducial", 0) waveform_kwargs = dict( - waveform_approximant='IMRPhenomPv2', reference_frequency=50.0, - minimum_frequency=20.0, maximum_frequency=frequency_array[-1], - catch_waveform_errors=False, pn_spin_order=-1, pn_tidal_order=-1, - pn_phase_order=-1, pn_amplitude_order=0) + waveform_approximant="IMRPhenomPv2", + reference_frequency=50.0, + minimum_frequency=20.0, + maximum_frequency=frequency_array[-1], + catch_waveform_errors=False, + pn_spin_order=-1, + pn_tidal_order=-1, + pn_phase_order=-1, + pn_amplitude_order=0, + ) waveform_kwargs.update(kwargs) if fiducial == 1: _ = waveform_kwargs.pop("frequency_bin_edges", None) return _base_lal_cbc_fd_waveform( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_jl=phi_jl, - phi_12=phi_12, lambda_1=0.0, lambda_2=0.0, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_jl=phi_jl, + phi_12=phi_12, + lambda_1=0.0, + lambda_2=0.0, + **waveform_kwargs, + ) else: _ = waveform_kwargs.pop("minimum_frequency", None) _ = waveform_kwargs.pop("maximum_frequency", None) waveform_kwargs["frequencies"] = waveform_kwargs.pop("frequency_bin_edges") return _base_waveform_frequency_sequence( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_jl=phi_jl, - phi_12=phi_12, lambda_1=0.0, lambda_2=0.0, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_jl=phi_jl, + phi_12=phi_12, + lambda_1=0.0, + lambda_2=0.0, + **waveform_kwargs, + ) def lal_binary_neutron_star_relative_binning( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, phi_jl, lambda_1, lambda_2, theta_jn, phase, **kwargs): - """ Source model to go with RelativeBinningGravitationalWaveTransient likelihood. + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + phi_jl, + lambda_1, + lambda_2, + theta_jn, + phase, + **kwargs, +): + """Source model to go with RelativeBinningGravitationalWaveTransient likelihood. All parameters are the same as in the usual source models, except `fiducial` @@ -774,34 +1033,77 @@ def lal_binary_neutron_star_relative_binning( fiducial = kwargs.pop("fiducial", 0) waveform_kwargs = dict( - waveform_approximant='IMRPhenomPv2_NRTidal', reference_frequency=50.0, - minimum_frequency=20.0, maximum_frequency=frequency_array[-1], - catch_waveform_errors=False, pn_spin_order=-1, pn_tidal_order=-1, - pn_phase_order=-1, pn_amplitude_order=0) + waveform_approximant="IMRPhenomPv2_NRTidal", + reference_frequency=50.0, + minimum_frequency=20.0, + maximum_frequency=frequency_array[-1], + catch_waveform_errors=False, + pn_spin_order=-1, + pn_tidal_order=-1, + pn_phase_order=-1, + pn_amplitude_order=0, + ) waveform_kwargs.update(kwargs) if fiducial == 1: return _base_lal_cbc_fd_waveform( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_12=phi_12, - phi_jl=phi_jl, lambda_1=lambda_1, lambda_2=lambda_2, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_12=phi_12, + phi_jl=phi_jl, + lambda_1=lambda_1, + lambda_2=lambda_2, + **waveform_kwargs, + ) else: _ = waveform_kwargs.pop("minimum_frequency", None) _ = waveform_kwargs.pop("maximum_frequency", None) waveform_kwargs["frequencies"] = waveform_kwargs.pop("frequency_bin_edges") return _base_waveform_frequency_sequence( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_jl=phi_jl, - phi_12=phi_12, lambda_1=lambda_1, lambda_2=lambda_2, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_jl=phi_jl, + phi_12=phi_12, + lambda_1=lambda_1, + lambda_2=lambda_2, + **waveform_kwargs, + ) def _base_roq_waveform( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, lambda_1, lambda_2, phi_jl, theta_jn, phase, - **waveform_arguments): - """ Base source model for ROQGravitationalWaveTransient, which evaluates + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + lambda_1, + lambda_2, + phi_jl, + theta_jn, + phase, + **waveform_arguments, +): + """Base source model for ROQGravitationalWaveTransient, which evaluates waveform values at frequency nodes contained in waveform_arguments. This requires that waveform_arguments contain all of 'frequency_nodes', 'linear_indices', and 'quadratic_indices', or both 'frequency_nodes_linear' and @@ -860,46 +1162,67 @@ def _base_roq_waveform( Dict containing plus and cross modes evaluated at the linear and quadratic frequency nodes. """ - if 'frequency_nodes' not in waveform_arguments: - size_linear = len(waveform_arguments['frequency_nodes_linear']) + if "frequency_nodes" not in waveform_arguments: + size_linear = len(waveform_arguments["frequency_nodes_linear"]) frequency_nodes_combined = np.hstack( - (waveform_arguments.pop('frequency_nodes_linear'), - waveform_arguments.pop('frequency_nodes_quadratic')) - ) - frequency_nodes_unique, original_indices = np.unique( - frequency_nodes_combined, return_inverse=True + (waveform_arguments.pop("frequency_nodes_linear"), waveform_arguments.pop("frequency_nodes_quadratic")) ) + frequency_nodes_unique, original_indices = np.unique(frequency_nodes_combined, return_inverse=True) linear_indices = original_indices[:size_linear] quadratic_indices = original_indices[size_linear:] - waveform_arguments['frequencies'] = frequency_nodes_unique + waveform_arguments["frequencies"] = frequency_nodes_unique else: linear_indices = waveform_arguments.pop("linear_indices") quadratic_indices = waveform_arguments.pop("quadratic_indices") for key in ["frequency_nodes_linear", "frequency_nodes_quadratic"]: _ = waveform_arguments.pop(key, None) - waveform_arguments['frequencies'] = waveform_arguments.pop('frequency_nodes') + waveform_arguments["frequencies"] = waveform_arguments.pop("frequency_nodes") waveform_polarizations = _base_waveform_frequency_sequence( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_jl=phi_jl, - phi_12=phi_12, lambda_1=lambda_1, lambda_2=lambda_2, **waveform_arguments) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_jl=phi_jl, + phi_12=phi_12, + lambda_1=lambda_1, + lambda_2=lambda_2, + **waveform_arguments, + ) return { - 'linear': { - 'plus': waveform_polarizations['plus'][linear_indices], - 'cross': waveform_polarizations['cross'][linear_indices] + "linear": { + "plus": waveform_polarizations["plus"][linear_indices], + "cross": waveform_polarizations["cross"][linear_indices], + }, + "quadratic": { + "plus": waveform_polarizations["plus"][quadratic_indices], + "cross": waveform_polarizations["cross"][quadratic_indices], }, - 'quadratic': { - 'plus': waveform_polarizations['plus'][quadratic_indices], - 'cross': waveform_polarizations['cross'][quadratic_indices] - } } def binary_black_hole_frequency_sequence( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, phi_jl, theta_jn, phase, **kwargs): - """ A Binary Black Hole waveform model using lalsimulation. This generates + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + phi_jl, + theta_jn, + phase, + **kwargs, +): + """A Binary Black Hole waveform model using lalsimulation. This generates a waveform only on specified frequency points. This is useful for likelihood requiring waveform values at a subset of all the frequency samples. For example, this is used for MBGravitationalWaveTransient. @@ -966,22 +1289,52 @@ def binary_black_hole_frequency_sequence( dict: A dictionary with the plus and cross polarisation strain modes """ waveform_kwargs = dict( - waveform_approximant='IMRPhenomPv2', reference_frequency=50.0, - catch_waveform_errors=False, pn_spin_order=-1, pn_tidal_order=-1, - pn_phase_order=-1, pn_amplitude_order=0) + waveform_approximant="IMRPhenomPv2", + reference_frequency=50.0, + catch_waveform_errors=False, + pn_spin_order=-1, + pn_tidal_order=-1, + pn_phase_order=-1, + pn_amplitude_order=0, + ) waveform_kwargs.update(kwargs) return _base_waveform_frequency_sequence( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_jl=phi_jl, - phi_12=phi_12, lambda_1=0.0, lambda_2=0.0, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_jl=phi_jl, + phi_12=phi_12, + lambda_1=0.0, + lambda_2=0.0, + **waveform_kwargs, + ) def binary_neutron_star_frequency_sequence( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, phi_jl, lambda_1, lambda_2, theta_jn, phase, - **kwargs): - """ A Binary Neutron Star waveform model using lalsimulation. This generates + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + phi_jl, + lambda_1, + lambda_2, + theta_jn, + phase, + **kwargs, +): + """A Binary Neutron Star waveform model using lalsimulation. This generates a waveform only on specified frequency points. This is useful for likelihood requiring waveform values at a subset of all the frequency samples. For example, this is used for MBGravitationalWaveTransient. @@ -1052,22 +1405,52 @@ def binary_neutron_star_frequency_sequence( dict: A dictionary with the plus and cross polarisation strain modes """ waveform_kwargs = dict( - waveform_approximant='IMRPhenomPv2_NRTidal', reference_frequency=50.0, - catch_waveform_errors=False, pn_spin_order=-1, pn_tidal_order=-1, - pn_phase_order=-1, pn_amplitude_order=0) + waveform_approximant="IMRPhenomPv2_NRTidal", + reference_frequency=50.0, + catch_waveform_errors=False, + pn_spin_order=-1, + pn_tidal_order=-1, + pn_phase_order=-1, + pn_amplitude_order=0, + ) waveform_kwargs.update(kwargs) return _base_waveform_frequency_sequence( - frequency_array=frequency_array, mass_1=mass_1, mass_2=mass_2, - luminosity_distance=luminosity_distance, theta_jn=theta_jn, phase=phase, - a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_jl=phi_jl, - phi_12=phi_12, lambda_1=lambda_1, lambda_2=lambda_2, **waveform_kwargs) + frequency_array=frequency_array, + mass_1=mass_1, + mass_2=mass_2, + luminosity_distance=luminosity_distance, + theta_jn=theta_jn, + phase=phase, + a_1=a_1, + a_2=a_2, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_jl=phi_jl, + phi_12=phi_12, + lambda_1=lambda_1, + lambda_2=lambda_2, + **waveform_kwargs, + ) def _base_waveform_frequency_sequence( - frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1, - phi_12, a_2, tilt_2, lambda_1, lambda_2, phi_jl, theta_jn, phase, - **waveform_kwargs): - """ Generate a cbc waveform model on specified frequency samples + frequency_array, + mass_1, + mass_2, + luminosity_distance, + a_1, + tilt_1, + phi_12, + a_2, + tilt_2, + lambda_1, + lambda_2, + phi_jl, + theta_jn, + phase, + **waveform_kwargs, +): + """Generate a cbc waveform model on specified frequency samples Parameters ---------- @@ -1102,10 +1485,10 @@ def _base_waveform_frequency_sequence( Dict containing plus and cross modes evaluated at the linear and quadratic frequency nodes. """ - frequencies = waveform_kwargs.pop('frequencies') - reference_frequency = waveform_kwargs.pop('reference_frequency') - approximant = waveform_kwargs.pop('waveform_approximant') - catch_waveform_errors = waveform_kwargs.pop('catch_waveform_errors') + frequencies = waveform_kwargs.pop("frequencies") + reference_frequency = waveform_kwargs.pop("reference_frequency") + approximant = waveform_kwargs.pop("waveform_approximant") + catch_waveform_errors = waveform_kwargs.pop("catch_waveform_errors") waveform_dictionary = set_waveform_dictionary(waveform_kwargs, lambda_1, lambda_2) approximant = lalsim_GetApproximantFromString(approximant) @@ -1115,29 +1498,57 @@ def _base_waveform_frequency_sequence( mass_2 = mass_2 * utils.solar_mass iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_lalsimulation_spins( - theta_jn=theta_jn, phi_jl=phi_jl, tilt_1=tilt_1, tilt_2=tilt_2, - phi_12=phi_12, a_1=a_1, a_2=a_2, mass_1=mass_1, mass_2=mass_2, - reference_frequency=reference_frequency, phase=phase) + theta_jn=theta_jn, + phi_jl=phi_jl, + tilt_1=tilt_1, + tilt_2=tilt_2, + phi_12=phi_12, + a_1=a_1, + a_2=a_2, + mass_1=mass_1, + mass_2=mass_2, + reference_frequency=reference_frequency, + phase=phase, + ) try: h_plus, h_cross = lalsim_SimInspiralChooseFDWaveformSequence( - phase, mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, - spin_2z, reference_frequency, luminosity_distance, iota, - waveform_dictionary, approximant, frequencies) + phase, + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + reference_frequency, + luminosity_distance, + iota, + waveform_dictionary, + approximant, + frequencies, + ) except Exception as e: if not catch_waveform_errors: raise else: - EDOM = (e.args[0] == 'Internal function call failed: Input domain error') + EDOM = e.args[0] == "Internal function call failed: Input domain error" if EDOM: - failed_parameters = dict(mass_1=mass_1, mass_2=mass_2, - spin_1=(spin_1x, spin_1y, spin_1z), - spin_2=(spin_2x, spin_2y, spin_2z), - luminosity_distance=luminosity_distance, - iota=iota, phase=phase) - logger.warning("Evaluating the waveform failed with error: {}\n".format(e) + - "The parameters were {}\n".format(failed_parameters) + - "Likelihood will be set to -inf.") + failed_parameters = dict( + mass_1=mass_1, + mass_2=mass_2, + spin_1=(spin_1x, spin_1y, spin_1z), + spin_2=(spin_2x, spin_2y, spin_2z), + luminosity_distance=luminosity_distance, + iota=iota, + phase=phase, + ) + logger.warning( + f"Evaluating the waveform failed with error: {e}\n" + + f"The parameters were {failed_parameters}\n" + + "Likelihood will be set to -inf." + ) return None else: raise @@ -1192,17 +1603,20 @@ def sinegaussian(frequency_array, hrss, Q, frequency, **kwargs): fm = frequency_array - frequency fp = frequency_array + frequency - h_plus = ((hrss / np.sqrt(temp * (1 + np.exp(-Q**2)))) * - ((np.sqrt(np.pi) * tau) / 2.0) * - (np.exp(-fm**2 * np.pi**2 * tau**2) + - np.exp(-fp**2 * np.pi**2 * tau**2))) + h_plus = ( + (hrss / np.sqrt(temp * (1 + np.exp(-(Q**2))))) + * ((np.sqrt(np.pi) * tau) / 2.0) + * (np.exp(-(fm**2) * np.pi**2 * tau**2) + np.exp(-(fp**2) * np.pi**2 * tau**2)) + ) - h_cross = (-1j * (hrss / np.sqrt(temp * (1 - np.exp(-Q**2)))) * - ((np.sqrt(np.pi) * tau) / 2.0) * - (np.exp(-fm**2 * np.pi**2 * tau**2) - - np.exp(-fp**2 * np.pi**2 * tau**2))) + h_cross = ( + -1j + * (hrss / np.sqrt(temp * (1 - np.exp(-(Q**2))))) + * ((np.sqrt(np.pi) * tau) / 2.0) + * (np.exp(-(fm**2) * np.pi**2 * tau**2) - np.exp(-(fp**2) * np.pi**2 * tau**2)) + ) - return {'plus': h_plus, 'cross': h_cross} + return {"plus": h_plus, "cross": h_cross} def supernova(frequency_array, luminosity_distance, **kwargs): @@ -1241,11 +1655,11 @@ def supernova(frequency_array, luminosity_distance, **kwargs): h_plus = scaling * (data[:, 0] + 1j * data[:, 1]) h_cross = scaling * (data[:, 2] + 1j * data[:, 3]) - return {'plus': h_plus, 'cross': h_cross} + return {"plus": h_plus, "cross": h_cross} def supernova_pca_model( - frequency_array, pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5, luminosity_distance, **kwargs + frequency_array, pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5, luminosity_distance, **kwargs ): r""" Signal model based on a five-component principal component decomposition @@ -1287,10 +1701,7 @@ def supernova_pca_model( principal_components = kwargs["realPCs"] + 1j * kwargs["imagPCs"] coefficients = [pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5] - strain = np.sum( - [coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)], - axis=0 - ) + strain = np.sum([coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)], axis=0) # file at 10kpc scaling = 1e-23 * (10.0 / luminosity_distance) @@ -1298,50 +1709,83 @@ def supernova_pca_model( h_plus = scaling * strain h_cross = scaling * strain - return {'plus': h_plus, 'cross': h_cross} + return {"plus": h_plus, "cross": h_cross} precession_only = { - "tilt_1", "tilt_2", "phi_12", "phi_jl", "chi_1_in_plane", "chi_2_in_plane", + "tilt_1", + "tilt_2", + "phi_12", + "phi_jl", + "chi_1_in_plane", + "chi_2_in_plane", } spin = { - "a_1", "a_2", "tilt_1", "tilt_2", "phi_12", "phi_jl", "chi_1", "chi_2", - "chi_1_in_plane", "chi_2_in_plane", + "a_1", + "a_2", + "tilt_1", + "tilt_2", + "phi_12", + "phi_jl", + "chi_1", + "chi_2", + "chi_1_in_plane", + "chi_2_in_plane", } mass = { - "chirp_mass", "mass_ratio", "total_mass", "mass_1", "mass_2", + "chirp_mass", + "mass_ratio", + "total_mass", + "mass_1", + "mass_2", "symmetric_mass_ratio", } -primary_spin_and_q = { - "a_1", "chi_1", "mass_ratio" -} -tidal = { - "lambda_1", "lambda_2", "lambda_tilde", "delta_lambda_tilde" -} +primary_spin_and_q = {"a_1", "chi_1", "mass_ratio"} +tidal = {"lambda_1", "lambda_2", "lambda_tilde", "delta_lambda_tilde"} phase = { - "phase", "delta_phase", + "phase", + "delta_phase", } extrinsic = { - "azimuth", "zenith", "luminosity_distance", "psi", "theta_jn", - "cos_theta_jn", "geocent_time", "time_jitter", "ra", "dec", - "H1_time", "L1_time", "V1_time", + "azimuth", + "zenith", + "luminosity_distance", + "psi", + "theta_jn", + "cos_theta_jn", + "geocent_time", + "time_jitter", + "ra", + "dec", + "H1_time", + "L1_time", + "V1_time", } sky = { - "azimuth", "zenith", "ra", "dec", + "azimuth", + "zenith", + "ra", + "dec", } distance_inclination = { - "luminosity_distance", "redshift", "theta_jn", "cos_theta_jn", -} -measured_spin = { - "chi_1", "chi_2", "a_1", "a_2", "chi_1_in_plane" + "luminosity_distance", + "redshift", + "theta_jn", + "cos_theta_jn", } +measured_spin = {"chi_1", "chi_2", "a_1", "a_2", "chi_1_in_plane"} PARAMETER_SETS = dict( - spin=spin, mass=mass, phase=phase, extrinsic=extrinsic, - tidal=tidal, primary_spin_and_q=primary_spin_and_q, + spin=spin, + mass=mass, + phase=phase, + extrinsic=extrinsic, + tidal=tidal, + primary_spin_and_q=primary_spin_and_q, intrinsic=spin.union(mass).union(phase).union(tidal), precession_only=precession_only, - sky=sky, distance_inclination=distance_inclination, + sky=sky, + distance_inclination=distance_inclination, measured_spin=measured_spin, ) diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index f1f4c0291..cc00bd95e 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -3,16 +3,20 @@ from functools import lru_cache import numpy as np -from scipy.interpolate import interp1d -from scipy.special import i0e from bilby_cython.geometry import ( zenith_azimuth_to_theta_phi as _zenith_azimuth_to_theta_phi, ) from bilby_cython.time import greenwich_mean_sidereal_time +from scipy.interpolate import interp1d +from scipy.special import i0e -from ..core.utils import (logger, run_commandline, - check_directory_exists_and_if_not_mkdir, - SamplesSummary, theta_phi_to_ra_dec) +from ..core.utils import ( + SamplesSummary, + check_directory_exists_and_if_not_mkdir, + logger, + run_commandline, + theta_phi_to_ra_dec, +) from ..core.utils.constants import solar_mass @@ -78,11 +82,12 @@ def get_vertex_position_geocentric(latitude, longitude, elevation): """ semi_major_axis = 6378137 # for ellipsoid model of Earth, in m semi_minor_axis = 6356752.314 # in m - radius = semi_major_axis**2 * (semi_major_axis**2 * np.cos(latitude)**2 + - semi_minor_axis**2 * np.sin(latitude)**2)**(-0.5) + radius = semi_major_axis**2 * ( + semi_major_axis**2 * np.cos(latitude) ** 2 + semi_minor_axis**2 * np.sin(latitude) ** 2 + ) ** (-0.5) x_comp = (radius + elevation) * np.cos(latitude) * np.cos(longitude) y_comp = (radius + elevation) * np.cos(latitude) * np.sin(longitude) - z_comp = ((semi_minor_axis / semi_major_axis)**2 * radius + elevation) * np.sin(latitude) + z_comp = ((semi_minor_axis / semi_major_axis) ** 2 * radius + elevation) * np.sin(latitude) return np.array([x_comp, y_comp, z_comp]) @@ -110,7 +115,7 @@ def inner_product(aa, bb, frequency, PSD): df = frequency[1] - frequency[0] integral = np.sum(integrand) * df - return 4. * np.real(integral) + return 4.0 * np.real(integral) def noise_weighted_inner_product(aa, bb, power_spectral_density, duration): @@ -159,11 +164,11 @@ def matched_filter_snr(signal, frequency_domain_strain, power_spectral_density, """ rho_mf = noise_weighted_inner_product( - aa=signal, bb=frequency_domain_strain, - power_spectral_density=power_spectral_density, duration=duration) - rho_mf /= optimal_snr_squared( - signal=signal, power_spectral_density=power_spectral_density, - duration=duration)**0.5 + aa=signal, bb=frequency_domain_strain, power_spectral_density=power_spectral_density, duration=duration + ) + rho_mf /= ( + optimal_snr_squared(signal=signal, power_spectral_density=power_spectral_density, duration=duration) ** 0.5 + ) return rho_mf @@ -190,8 +195,16 @@ def optimal_snr_squared(signal, power_spectral_density, duration): return noise_weighted_inner_product(signal, signal, power_spectral_density, duration) -def overlap(signal_a, signal_b, power_spectral_density=None, delta_frequency=None, - lower_cut_off=None, upper_cut_off=None, norm_a=None, norm_b=None): +def overlap( + signal_a, + signal_b, + power_spectral_density=None, + delta_frequency=None, + lower_cut_off=None, + upper_cut_off=None, + norm_a=None, + norm_b=None, +): r""" Compute the overlap between two signals @@ -306,9 +319,8 @@ def get_event_time(event): return datasets.event_gps(event) -def get_open_strain_data( - name, start_time, end_time, outdir, cache=False, buffer_time=0, **kwargs): - """ A function which accesses the open strain data +def get_open_strain_data(name, start_time, end_time, outdir, cache=False, buffer_time=0, **kwargs): + """A function which accesses the open strain data This uses `gwpy` to download the open data and then saves a cached copy for later use @@ -336,7 +348,8 @@ def get_open_strain_data( """ from gwpy.timeseries import TimeSeries - filename = '{}/{}_{}_{}.txt'.format(outdir, name, start_time, end_time) + + filename = f"{outdir}/{name}_{start_time}_{end_time}.txt" if buffer_time < 0: raise ValueError("buffer_time < 0") @@ -344,26 +357,24 @@ def get_open_strain_data( end_time = end_time + buffer_time if os.path.isfile(filename) and cache: - logger.info('Using cached data from {}'.format(filename)) + logger.info(f"Using cached data from {filename}") strain = TimeSeries.read(filename) else: - logger.info('Fetching open data from {} to {} with buffer time {}' - .format(start_time, end_time, buffer_time)) + logger.info(f"Fetching open data from {start_time} to {end_time} with buffer time {buffer_time}") try: strain = TimeSeries.fetch_open_data(name, start_time, end_time, **kwargs) - logger.info('Saving cache of data to {}'.format(filename)) + logger.info(f"Saving cache of data to {filename}") strain.write(filename) except Exception as e: logger.info("Unable to fetch open data, see debug for detailed info") - logger.info("Call to gwpy.timeseries.TimeSeries.fetch_open_data returned {}" - .format(e)) + logger.info(f"Call to gwpy.timeseries.TimeSeries.fetch_open_data returned {e}") strain = None return strain def read_frame_file(file_name, start_time, end_time, resample=None, channel=None, buffer_time=0, **kwargs): - """ A function which accesses the open strain data + """A function which accesses the open strain data This uses `gwpy` to download the open data and then saves a cached copy for later use @@ -390,6 +401,7 @@ def read_frame_file(file_name, start_time, end_time, resample=None, channel=None """ from gwpy.timeseries import TimeSeries + loaded = False strain = None @@ -397,27 +409,37 @@ def read_frame_file(file_name, start_time, end_time, resample=None, channel=None try: strain = TimeSeries.read(source=file_name, channel=channel, start=start_time, end=end_time, **kwargs) loaded = True - logger.info('Successfully loaded {}.'.format(channel)) + logger.info(f"Successfully loaded {channel}.") except (RuntimeError, ValueError): - logger.warning('Channel {} not found. Trying preset channel names'.format(channel)) + logger.warning(f"Channel {channel} not found. Trying preset channel names") if loaded is False: - ligo_channel_types = ['GDS-CALIB_STRAIN', 'DCS-CALIB_STRAIN_C01', 'DCS-CALIB_STRAIN_C02', - 'DCH-CLEAN_STRAIN_C02', 'GWOSC-16KHZ_R1_STRAIN', - 'GWOSC-4KHZ_R1_STRAIN'] - virgo_channel_types = ['Hrec_hoft_V1O2Repro2A_16384Hz', 'FAKE_h_16384Hz_4R', - 'GWOSC-16KHZ_R1_STRAIN', 'GWOSC-4KHZ_R1_STRAIN'] + ligo_channel_types = [ + "GDS-CALIB_STRAIN", + "DCS-CALIB_STRAIN_C01", + "DCS-CALIB_STRAIN_C02", + "DCH-CLEAN_STRAIN_C02", + "GWOSC-16KHZ_R1_STRAIN", + "GWOSC-4KHZ_R1_STRAIN", + ] + virgo_channel_types = [ + "Hrec_hoft_V1O2Repro2A_16384Hz", + "FAKE_h_16384Hz_4R", + "GWOSC-16KHZ_R1_STRAIN", + "GWOSC-4KHZ_R1_STRAIN", + ] channel_types = dict(H1=ligo_channel_types, L1=ligo_channel_types, V1=virgo_channel_types) for detector in channel_types.keys(): for channel_type in channel_types[detector]: if loaded: break - channel = '{}:{}'.format(detector, channel_type) + channel = f"{detector}:{channel_type}" try: - strain = TimeSeries.read(source=file_name, channel=channel, start=start_time, end=end_time, - **kwargs) + strain = TimeSeries.read( + source=file_name, channel=channel, start=start_time, end=end_time, **kwargs + ) loaded = True - logger.info('Successfully read strain data for channel {}.'.format(channel)) + logger.info(f"Successfully read strain data for channel {channel}.") except (RuntimeError, ValueError): pass @@ -426,7 +448,7 @@ def read_frame_file(file_name, start_time, end_time, resample=None, channel=None strain = strain.resample(resample) return strain else: - logger.warning('No data loaded.') + logger.warning("No data loaded.") return None @@ -460,21 +482,21 @@ def get_gracedb(gracedb, outdir, duration, calibration, detectors, query_types=N List of cache filenames, one per interferometer. """ candidate = gracedb_to_json(gracedb, outdir=outdir) - trigger_time = candidate['gpstime'] + trigger_time = candidate["gpstime"] gps_start_time = trigger_time - duration cache_files = [] if query_types is None: query_types = [None] * len(detectors) for i, det in enumerate(detectors): output_cache_file = gw_data_find( - det, gps_start_time, duration, calibration, - outdir=outdir, query_type=query_types[i], server=server) + det, gps_start_time, duration, calibration, outdir=outdir, query_type=query_types[i], server=server + ) cache_files.append(output_cache_file) return candidate, cache_files -def gracedb_to_json(gracedb, cred=None, service_url='https://gracedb.ligo.org/api/', outdir=None): - """ Script to download a GraceDB candidate +def gracedb_to_json(gracedb, cred=None, service_url="https://gracedb.ligo.org/api/", outdir=None): + """Script to download a GraceDB candidate Parameters ========== @@ -489,39 +511,35 @@ def gracedb_to_json(gracedb, cred=None, service_url='https://gracedb.ligo.org/ap outdir: str, optional If given, a string identfying the location in which to store the json """ - logger.info( - 'Starting routine to download GraceDb candidate {}'.format(gracedb)) + logger.info(f"Starting routine to download GraceDb candidate {gracedb}") from ligo.gracedb.rest import GraceDb - logger.info('Initialise client and attempt to download') - logger.info('Fetching from {}'.format(service_url)) + logger.info("Initialise client and attempt to download") + logger.info(f"Fetching from {service_url}") try: client = GraceDb(cred=cred, service_url=service_url) - except IOError: - raise ValueError( - 'Failed to authenticate with gracedb: check your X509 ' - 'certificate is accessible and valid') + except OSError: + raise ValueError("Failed to authenticate with gracedb: check your X509 certificate is accessible and valid") try: candidate = client.event(gracedb) - logger.info('Successfully downloaded candidate') + logger.info("Successfully downloaded candidate") except Exception as e: - raise ValueError("Unable to obtain GraceDB candidate, exception: {}".format(e)) + raise ValueError(f"Unable to obtain GraceDB candidate, exception: {e}") json_output = candidate.json() if outdir is not None: check_directory_exists_and_if_not_mkdir(outdir) - outfilepath = os.path.join(outdir, '{}.json'.format(gracedb)) - logger.info('Writing candidate to {}'.format(outfilepath)) - with open(outfilepath, 'w') as outfile: + outfilepath = os.path.join(outdir, f"{gracedb}.json") + logger.info(f"Writing candidate to {outfilepath}") + with open(outfilepath, "w") as outfile: json.dump(json_output, outfile, indent=2) return json_output -def gw_data_find(observatory, gps_start_time, duration, calibration, - outdir='.', query_type=None, server=None): - """ Builds a gw_data_find call and process output +def gw_data_find(observatory, gps_start_time, duration, calibration, outdir=".", query_type=None, server=None): + """Builds a gw_data_find call and process output Parameters ========== @@ -544,49 +562,50 @@ def gw_data_find(observatory, gps_start_time, duration, calibration, Path to the output cache file """ - logger.info('Building gw_data_find command line') + logger.info("Building gw_data_find command line") - observatory_lookup = dict(H1='H', L1='L', V1='V') + observatory_lookup = dict(H1="H", L1="L", V1="V") observatory_code = observatory_lookup[observatory] if query_type is None: - logger.warning('No query type provided. This may prevent data from being read.') - if observatory_code == 'V': - query_type = 'V1Online' + logger.warning("No query type provided. This may prevent data from being read.") + if observatory_code == "V": + query_type = "V1Online" else: - query_type = '{}_HOFT_C0{}'.format(observatory, calibration) + query_type = f"{observatory}_HOFT_C0{calibration}" - logger.info('Using LDRDataFind query type {}'.format(query_type)) + logger.info(f"Using LDRDataFind query type {query_type}") - cache_file = '{}-{}_CACHE-{}-{}.lcf'.format( - observatory, query_type, gps_start_time, duration) + cache_file = f"{observatory}-{query_type}_CACHE-{gps_start_time}-{duration}.lcf" output_cache_file = os.path.join(outdir, cache_file) gps_end_time = gps_start_time + duration if server is None: server = os.environ.get("LIGO_DATAFIND_SERVER") if server is None: - logger.warning('No data_find server found, defaulting to CIT server. ' - 'To run on other clusters, pass the output of ' - '`echo $LIGO_DATAFIND_SERVER`') - server = 'ldr.ldas.cit:80' - - cl_list = ['gw_data_find'] - cl_list.append('--observatory {}'.format(observatory_code)) - cl_list.append('--gps-start-time {}'.format(int(np.floor(gps_start_time)))) - cl_list.append('--gps-end-time {}'.format(int(np.ceil(gps_end_time)))) - cl_list.append('--type {}'.format(query_type)) - cl_list.append('--output {}'.format(output_cache_file)) - cl_list.append('--server {}'.format(server)) - cl_list.append('--url-type file') - cl_list.append('--lal-cache') - cl = ' '.join(cl_list) + logger.warning( + "No data_find server found, defaulting to CIT server. " + "To run on other clusters, pass the output of " + "`echo $LIGO_DATAFIND_SERVER`" + ) + server = "ldr.ldas.cit:80" + + cl_list = ["gw_data_find"] + cl_list.append(f"--observatory {observatory_code}") + cl_list.append(f"--gps-start-time {int(np.floor(gps_start_time))}") + cl_list.append(f"--gps-end-time {int(np.ceil(gps_end_time))}") + cl_list.append(f"--type {query_type}") + cl_list.append(f"--output {output_cache_file}") + cl_list.append(f"--server {server}") + cl_list.append("--url-type file") + cl_list.append("--lal-cache") + cl = " ".join(cl_list) run_commandline(cl) return output_cache_file def convert_args_list_to_float(*args_list): - """ Converts inputs to floats, returns a list in the same order as the input""" + """Converts inputs to floats, returns a list in the same order as the input""" try: args_list = [float(arg) for arg in args_list] except ValueError: @@ -595,13 +614,13 @@ def convert_args_list_to_float(*args_list): def lalsim_SimInspiralTransformPrecessingNewInitialConditions( - theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, - reference_frequency, phase): + theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, reference_frequency, phase +): from lalsimulation import SimInspiralTransformPrecessingNewInitialConditions args_list = convert_args_list_to_float( - theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, - reference_frequency, phase) + theta_jn, phi_jl, tilt_1, tilt_2, phi_12, a_1, a_2, mass_1, mass_2, reference_frequency, phase + ) return SimInspiralTransformPrecessingNewInitialConditions(*args_list) @@ -609,6 +628,7 @@ def lalsim_SimInspiralTransformPrecessingNewInitialConditions( @lru_cache(maxsize=10) def lalsim_GetApproximantFromString(waveform_approximant): from lalsimulation import GetApproximantFromString + if isinstance(waveform_approximant, str): return GetApproximantFromString(waveform_approximant) else: @@ -616,11 +636,27 @@ def lalsim_GetApproximantFromString(waveform_approximant): def lalsim_SimInspiralFD( - mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, - spin_2z, luminosity_distance, iota, phase, - longitude_ascending_nodes, eccentricity, mean_per_ano, delta_frequency, - minimum_frequency, maximum_frequency, reference_frequency, - waveform_dictionary, approximant): + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + luminosity_distance, + iota, + phase, + longitude_ascending_nodes, + eccentricity, + mean_per_ano, + delta_frequency, + minimum_frequency, + maximum_frequency, + reference_frequency, + waveform_dictionary, + approximant, +): """ Safely call lalsimulation.SimInspiralFD @@ -650,10 +686,25 @@ def lalsim_SimInspiralFD( from lalsimulation import SimInspiralFD args = convert_args_list_to_float( - mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, - luminosity_distance, iota, phase, longitude_ascending_nodes, - eccentricity, mean_per_ano, delta_frequency, minimum_frequency, - maximum_frequency, reference_frequency) + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + luminosity_distance, + iota, + phase, + longitude_ascending_nodes, + eccentricity, + mean_per_ano, + delta_frequency, + minimum_frequency, + maximum_frequency, + reference_frequency, + ) approximant = _get_lalsim_approximant(approximant) @@ -661,11 +712,27 @@ def lalsim_SimInspiralFD( def lalsim_SimInspiralChooseFDWaveform( - mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, - spin_2z, luminosity_distance, iota, phase, - longitude_ascending_nodes, eccentricity, mean_per_ano, delta_frequency, - minimum_frequency, maximum_frequency, reference_frequency, - waveform_dictionary, approximant): + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + luminosity_distance, + iota, + phase, + longitude_ascending_nodes, + eccentricity, + mean_per_ano, + delta_frequency, + minimum_frequency, + maximum_frequency, + reference_frequency, + waveform_dictionary, + approximant, +): """ Safely call lalsimulation.SimInspiralChooseFDWaveform @@ -695,10 +762,25 @@ def lalsim_SimInspiralChooseFDWaveform( from lalsimulation import SimInspiralChooseFDWaveform args = convert_args_list_to_float( - mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, - luminosity_distance, iota, phase, longitude_ascending_nodes, - eccentricity, mean_per_ano, delta_frequency, minimum_frequency, - maximum_frequency, reference_frequency) + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + luminosity_distance, + iota, + phase, + longitude_ascending_nodes, + eccentricity, + mean_per_ano, + delta_frequency, + minimum_frequency, + maximum_frequency, + reference_frequency, + ) approximant = _get_lalsim_approximant(approximant) @@ -717,9 +799,22 @@ def _get_lalsim_approximant(approximant): def lalsim_SimInspiralChooseFDWaveformSequence( - phase, mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, - spin_2z, reference_frequency, luminosity_distance, iota, - waveform_dictionary, approximant, frequency_array): + phase, + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + reference_frequency, + luminosity_distance, + iota, + waveform_dictionary, + approximant, + frequency_array, +): """ Safely call lalsimulation.SimInspiralChooseFDWaveformSequence @@ -741,13 +836,36 @@ def lalsim_SimInspiralChooseFDWaveformSequence( approximant: int, str frequency_array: np.ndarray, lal.REAL8Vector """ - from lal import REAL8Vector, CreateREAL8Vector + from lal import CreateREAL8Vector, REAL8Vector from lalsimulation import SimInspiralChooseFDWaveformSequence - [mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, - luminosity_distance, iota, phase, reference_frequency] = convert_args_list_to_float( - mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z, - luminosity_distance, iota, phase, reference_frequency) + [ + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + luminosity_distance, + iota, + phase, + reference_frequency, + ] = convert_args_list_to_float( + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + luminosity_distance, + iota, + phase, + reference_frequency, + ) if isinstance(approximant, int): pass @@ -762,37 +880,49 @@ def lalsim_SimInspiralChooseFDWaveformSequence( frequency_array.data = old_frequency_array return SimInspiralChooseFDWaveformSequence( - phase, mass_1, mass_2, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, - spin_2z, reference_frequency, luminosity_distance, iota, - waveform_dictionary, approximant, frequency_array) + phase, + mass_1, + mass_2, + spin_1x, + spin_1y, + spin_1z, + spin_2x, + spin_2y, + spin_2z, + reference_frequency, + luminosity_distance, + iota, + waveform_dictionary, + approximant, + frequency_array, + ) -def lalsim_SimInspiralWaveformParamsInsertTidalLambda1( - waveform_dictionary, lambda_1): +def lalsim_SimInspiralWaveformParamsInsertTidalLambda1(waveform_dictionary, lambda_1): from lalsimulation import SimInspiralWaveformParamsInsertTidalLambda1 + try: lambda_1 = float(lambda_1) except ValueError: raise ValueError("Unable to convert lambda_1 to float") - return SimInspiralWaveformParamsInsertTidalLambda1( - waveform_dictionary, lambda_1) + return SimInspiralWaveformParamsInsertTidalLambda1(waveform_dictionary, lambda_1) -def lalsim_SimInspiralWaveformParamsInsertTidalLambda2( - waveform_dictionary, lambda_2): +def lalsim_SimInspiralWaveformParamsInsertTidalLambda2(waveform_dictionary, lambda_2): from lalsimulation import SimInspiralWaveformParamsInsertTidalLambda2 + try: lambda_2 = float(lambda_2) except ValueError: raise ValueError("Unable to convert lambda_2 to float") - return SimInspiralWaveformParamsInsertTidalLambda2( - waveform_dictionary, lambda_2) + return SimInspiralWaveformParamsInsertTidalLambda2(waveform_dictionary, lambda_2) def lalsim_SimNeutronStarEOS4ParamSDGammaCheck(g0, g1, g2, g3): from lalsimulation import SimNeutronStarEOS4ParamSDGammaCheck + try: g0 = float(g0) g1 = float(g1) @@ -808,6 +938,7 @@ def lalsim_SimNeutronStarEOS4ParamSDGammaCheck(g0, g1, g2, g3): def lalsim_SimNeutronStarEOS4ParameterSpectralDecomposition(g0, g1, g2, g3): from lalsimulation import SimNeutronStarEOS4ParameterSpectralDecomposition + try: g0 = float(g0) g1 = float(g1) @@ -823,6 +954,7 @@ def lalsim_SimNeutronStarEOS4ParameterSpectralDecomposition(g0, g1, g2, g3): def lalsim_SimNeutronStarEOS4ParamSDViableFamilyCheck(g0, g1, g2, g3): from lalsimulation import SimNeutronStarEOS4ParamSDViableFamilyCheck + try: g0 = float(g0) g1 = float(g1) @@ -838,6 +970,7 @@ def lalsim_SimNeutronStarEOS4ParamSDViableFamilyCheck(g0, g1, g2, g3): def lalsim_SimNeutronStarEOS3PieceDynamicPolytrope(g0, log10p1_si, g1, log10p2_si, g2): from lalsimulation import SimNeutronStarEOS3PieceDynamicPolytrope + try: g0 = float(g0) g1 = float(g1) @@ -854,6 +987,7 @@ def lalsim_SimNeutronStarEOS3PieceDynamicPolytrope(g0, log10p1_si, g1, log10p2_s def lalsim_SimNeutronStarEOS3PieceCausalAnalytic(v1, log10p1_si, v2, log10p2_si, v3): from lalsimulation import SimNeutronStarEOS3PieceCausalAnalytic + try: v1 = float(v1) v2 = float(v2) @@ -870,6 +1004,7 @@ def lalsim_SimNeutronStarEOS3PieceCausalAnalytic(v1, log10p1_si, v2, log10p2_si, def lalsim_SimNeutronStarEOS3PDViableFamilyCheck(p0, log10p1_si, p1, log10p2_si, p2, causal): from lalsimulation import SimNeutronStarEOS3PDViableFamilyCheck + try: p0 = float(p0) p1 = float(p1) @@ -899,6 +1034,7 @@ def lalsim_SimNeutronStarEOSMaxPseudoEnthalpy(eos): def lalsim_SimNeutronStarEOSSpeedOfSoundGeometerized(max_pseudo_enthalpy, eos): from lalsimulation import SimNeutronStarEOSSpeedOfSoundGeometerized + try: max_pseudo_enthalpy = float(max_pseudo_enthalpy) except ValueError: @@ -923,6 +1059,7 @@ def lalsim_SimNeutronStarMaximumMass(fam): def lalsim_SimNeutronStarRadius(mass_in_SI, fam): from lalsimulation import SimNeutronStarRadius + try: mass_in_SI = float(mass_in_SI) except ValueError: @@ -935,6 +1072,7 @@ def lalsim_SimNeutronStarRadius(mass_in_SI, fam): def lalsim_SimNeutronStarLoveNumberK2(mass_in_SI, fam): from lalsimulation import SimNeutronStarLoveNumberK2 + try: mass_in_SI = float(mass_in_SI) except ValueError: @@ -965,7 +1103,7 @@ def spline_angle_xform(delta_psi): return 180.0 / np.pi * np.arctan2(np.imag(rotation), np.real(rotation)) -def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label=None, xform=None): +def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color="k", label=None, xform=None): """ Plot calibration posterior estimates for a spline model in log space. Adapted from the same function in lalinference.bayespputils @@ -989,6 +1127,7 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label= """ import matplotlib.pyplot as plt + freq_points = np.exp(log_freqs) freqs = np.logspace(min(log_freqs), max(log_freqs), nfreqs, base=np.exp(1)) @@ -999,28 +1138,40 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label= else: scaled_samples = xform(samples) - scaled_samples_summary = SamplesSummary(scaled_samples, average='mean') - data_summary = SamplesSummary(data, average='mean') - - plt.errorbar(freq_points, scaled_samples_summary.average, - yerr=[-scaled_samples_summary.lower_relative_credible_interval, - scaled_samples_summary.upper_relative_credible_interval], - fmt='.', color=color, lw=4, alpha=0.5, capsize=0) + scaled_samples_summary = SamplesSummary(scaled_samples, average="mean") + data_summary = SamplesSummary(data, average="mean") + + plt.errorbar( + freq_points, + scaled_samples_summary.average, + yerr=[ + -scaled_samples_summary.lower_relative_credible_interval, + scaled_samples_summary.upper_relative_credible_interval, + ], + fmt=".", + color=color, + lw=4, + alpha=0.5, + capsize=0, + ) for i, sample in enumerate(samples): - temp = interp1d( - log_freqs, sample, kind="cubic", fill_value=0, - bounds_error=False)(np.log(freqs)) + temp = interp1d(log_freqs, sample, kind="cubic", fill_value=0, bounds_error=False)(np.log(freqs)) if xform is None: data[i] = temp else: data[i] = xform(temp) plt.plot(freqs, np.mean(data, axis=0), color=color, label=label) - plt.fill_between(freqs, data_summary.lower_absolute_credible_interval, - data_summary.upper_absolute_credible_interval, - color=color, alpha=.1, linewidth=0.1) - plt.xlim(freq_points.min() - .5, freq_points.max() + 50) + plt.fill_between( + freqs, + data_summary.lower_absolute_credible_interval, + data_summary.upper_absolute_credible_interval, + color=color, + alpha=0.1, + linewidth=0.1, + ) + plt.xlim(freq_points.min() - 0.5, freq_points.max() + 50) def ln_i0(value): @@ -1042,7 +1193,7 @@ def ln_i0(value): def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): - """ Leading-order calculation of the time to merger from frequency + """Leading-order calculation of the time to merger from frequency This uses the XLALSimInspiralTaylorF2ReducedSpinChirpTime routine to estimate the time to merger. Based on the implementation in @@ -1066,12 +1217,9 @@ def calculate_time_to_merger(frequency, mass_1, mass_2, chi=0, safety=1.1): """ import lalsimulation + return safety * lalsimulation.SimInspiralTaylorF2ReducedSpinChirpTime( - frequency, - mass_1 * solar_mass, - mass_2 * solar_mass, - chi, - -1 + frequency, mass_1 * solar_mass, mass_2 * solar_mass, chi, -1 ) diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py index 1012fc931..d0ec15a95 100644 --- a/bilby/gw/waveform_generator.py +++ b/bilby/gw/waveform_generator.py @@ -2,29 +2,35 @@ from ..core import utils from ..core.series import CoupledTimeAndFrequencySeries -from ..core.utils import PropertyAccessor -from ..core.utils import logger +from ..core.utils import PropertyAccessor, logger from .conversion import convert_to_lal_binary_black_hole_parameters from .utils import lalsim_GetApproximantFromString -class WaveformGenerator(object): +class WaveformGenerator: """ The base waveform generator class. Waveform generators provide a unified method to call disparate source models. """ - duration = PropertyAccessor('_times_and_frequencies', 'duration') - sampling_frequency = PropertyAccessor('_times_and_frequencies', 'sampling_frequency') - start_time = PropertyAccessor('_times_and_frequencies', 'start_time') - frequency_array = PropertyAccessor('_times_and_frequencies', 'frequency_array') - time_array = PropertyAccessor('_times_and_frequencies', 'time_array') - - def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None, - time_domain_source_model=None, parameters=None, - parameter_conversion=None, - waveform_arguments=None): + duration = PropertyAccessor("_times_and_frequencies", "duration") + sampling_frequency = PropertyAccessor("_times_and_frequencies", "sampling_frequency") + start_time = PropertyAccessor("_times_and_frequencies", "start_time") + frequency_array = PropertyAccessor("_times_and_frequencies", "frequency_array") + time_array = PropertyAccessor("_times_and_frequencies", "time_array") + + def __init__( + self, + duration=None, + sampling_frequency=None, + start_time=0, + frequency_domain_source_model=None, + time_domain_source_model=None, + parameters=None, + parameter_conversion=None, + waveform_arguments=None, + ): """ The base waveform generator class. @@ -59,9 +65,9 @@ def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequen the WaveformGenerator object and initialised to `None`. """ - self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration, - sampling_frequency=sampling_frequency, - start_time=start_time) + self._times_and_frequencies = CoupledTimeAndFrequencySeries( + duration=duration, sampling_frequency=sampling_frequency, start_time=start_time + ) self.frequency_domain_source_model = frequency_domain_source_model self.time_domain_source_model = time_domain_source_model self.source_parameter_keys = self._parameters_from_source_model() @@ -92,15 +98,16 @@ def __repr__(self): else: param_conv_name = utils.get_function_path(self.parameter_conversion) - return self.__class__.__name__ + '(duration={}, sampling_frequency={}, start_time={}, ' \ - 'frequency_domain_source_model={}, time_domain_source_model={}, ' \ - 'parameter_conversion={}, ' \ - 'waveform_arguments={})'\ - .format(self.duration, self.sampling_frequency, self.start_time, fdsm_name, tdsm_name, - param_conv_name, self.waveform_arguments) + return ( + self.__class__.__name__ + + f"(duration={self.duration}, sampling_frequency={self.sampling_frequency}, start_time={self.start_time}, " + f"frequency_domain_source_model={fdsm_name}, time_domain_source_model={tdsm_name}, " + f"parameter_conversion={param_conv_name}, " + f"waveform_arguments={self.waveform_arguments})" + ) def frequency_domain_strain(self, parameters=None): - """ Wrapper to source_model. + """Wrapper to source_model. Converts self.parameters with self.parameter_conversion before handing it off to the source model. Automatically refers to the time_domain_source model via NFFT if no frequency_domain_source_model is given. @@ -121,15 +128,17 @@ def frequency_domain_strain(self, parameters=None): RuntimeError: If no source model is given """ - return self._calculate_strain(model=self.frequency_domain_source_model, - model_data_points=self.frequency_array, - parameters=parameters, - transformation_function=utils.nfft, - transformed_model=self.time_domain_source_model, - transformed_model_data_points=self.time_array) + return self._calculate_strain( + model=self.frequency_domain_source_model, + model_data_points=self.frequency_array, + parameters=parameters, + transformation_function=utils.nfft, + transformed_model=self.time_domain_source_model, + transformed_model_data_points=self.time_array, + ) def time_domain_strain(self, parameters=None): - """ Wrapper to source_model. + """Wrapper to source_model. Converts self.parameters with self.parameter_conversion before handing it off to the source model. Automatically refers to the frequency_domain_source model via INFFT if no frequency_domain_source_model is @@ -151,33 +160,46 @@ def time_domain_strain(self, parameters=None): RuntimeError: If no source model is given """ - return self._calculate_strain(model=self.time_domain_source_model, - model_data_points=self.time_array, - parameters=parameters, - transformation_function=utils.infft, - transformed_model=self.frequency_domain_source_model, - transformed_model_data_points=self.frequency_array) - - def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model, - transformed_model_data_points, parameters): + return self._calculate_strain( + model=self.time_domain_source_model, + model_data_points=self.time_array, + parameters=parameters, + transformation_function=utils.infft, + transformed_model=self.frequency_domain_source_model, + transformed_model_data_points=self.frequency_array, + ) + + def _calculate_strain( + self, + model, + model_data_points, + transformation_function, + transformed_model, + transformed_model_data_points, + parameters, + ): if parameters is None: parameters = self.parameters - if parameters == self._cache['parameters'] and self._cache['model'] == model and \ - self._cache['transformed_model'] == transformed_model: - return self._cache['waveform'] + if ( + parameters == self._cache["parameters"] + and self._cache["model"] == model + and self._cache["transformed_model"] == transformed_model + ): + return self._cache["waveform"] else: - self._cache['parameters'] = parameters.copy() - self._cache['model'] = model - self._cache['transformed_model'] = transformed_model + self._cache["parameters"] = parameters.copy() + self._cache["model"] = model + self._cache["transformed_model"] = transformed_model parameters = self._format_parameters(parameters) if model is not None: model_strain = self._strain_from_model(model_data_points, model, parameters) elif transformed_model is not None: - model_strain = self._strain_from_transformed_model(transformed_model_data_points, transformed_model, - transformation_function, parameters) + model_strain = self._strain_from_transformed_model( + transformed_model_data_points, transformed_model, transformation_function, parameters + ) else: raise RuntimeError("No source model given") - self._cache['waveform'] = model_strain + self._cache["waveform"] = model_strain return model_strain def _strain_from_model(self, model_data_points, model, parameters): @@ -186,9 +208,7 @@ def _strain_from_model(self, model_data_points, model, parameters): def _strain_from_transformed_model( self, transformed_model_data_points, transformed_model, transformation_function, parameters ): - transformed_model_strain = self._strain_from_model( - transformed_model_data_points, transformed_model, parameters - ) + transformed_model_strain = self._strain_from_model(transformed_model_data_points, transformed_model, parameters) if isinstance(transformed_model_strain, np.ndarray): return transformation_function(transformed_model_strain, self.sampling_frequency) @@ -196,15 +216,14 @@ def _strain_from_transformed_model( model_strain = dict() for key in transformed_model_strain: if transformation_function == utils.nfft: - model_strain[key], _ = \ - transformation_function(transformed_model_strain[key], self.sampling_frequency) + model_strain[key], _ = transformation_function(transformed_model_strain[key], self.sampling_frequency) else: model_strain[key] = transformation_function(transformed_model_strain[key], self.sampling_frequency) return model_strain @property def parameters(self): - """ The dictionary of parameters for source model. + """The dictionary of parameters for source model. Returns ======= @@ -238,8 +257,7 @@ def _format_parameters(self, parameters): raise TypeError('"parameters" must be a dictionary.') new_parameters = parameters.copy() new_parameters, _ = self.parameter_conversion(new_parameters) - for key in self.source_parameter_keys.symmetric_difference( - new_parameters): + for key in self.source_parameter_keys.symmetric_difference(new_parameters): new_parameters.pop(key) new_parameters.update(self.waveform_arguments) return new_parameters @@ -257,13 +275,13 @@ def _parameters_from_source_model(self): elif self.time_domain_source_model is not None: model = self.time_domain_source_model else: - raise AttributeError('Either time or frequency domain source ' - 'model must be provided.') + raise AttributeError("Either time or frequency domain source model must be provided.") return set(utils.infer_parameters_from_function(model)) class LALCBCWaveformGenerator(WaveformGenerator): - """ A waveform generator with specific checks for LAL CBC waveforms """ + """A waveform generator with specific checks for LAL CBC waveforms""" + LAL_SIM_INSPIRAL_SPINS_FLOW = 1 def __init__(self, **kwargs): @@ -272,6 +290,7 @@ def __init__(self, **kwargs): def validate_reference_frequency(self): from lalsimulation import SimInspiralGetSpinFreqFromApproximant + waveform_approximant = self.waveform_arguments["waveform_approximant"] waveform_approximant_number = lalsim_GetApproximantFromString(waveform_approximant) if SimInspiralGetSpinFreqFromApproximant(waveform_approximant_number) == self.LAL_SIM_INSPIRAL_SPINS_FLOW: @@ -311,7 +330,8 @@ class GWSignalWaveformGenerator(WaveformGenerator): A dictionary of fixed keyword arguments to pass to the waveform generator. There is one required waveform argument :code:`waveform_approximant`. - .. gwsignal waveform generator: https://docs.ligo.org/lscsoft/lalsuite/lalsimulation/classlalsimulation_1_1gwsignal_1_1core_1_1waveform_1_1_gravitational_wave_generator.html # noqa + # noqa: E501 + .. gwsignal waveform generator: https://docs.ligo.org/lscsoft/lalsuite/lalsimulation/classlalsimulation_1_1gwsignal_1_1core_1_1waveform_1_1_gravitational_wave_generator.html """ generator_pickles = False @@ -389,11 +409,11 @@ def _from_bilby_parameters(self, **parameters): pn_amplitude_order=0, ) waveform_kwargs.update(self.waveform_arguments) - reference_frequency = waveform_kwargs['reference_frequency'] - minimum_frequency = waveform_kwargs['minimum_frequency'] - maximum_frequency = waveform_kwargs['maximum_frequency'] - mode_array = waveform_kwargs['mode_array'] - pn_amplitude_order = waveform_kwargs['pn_amplitude_order'] + reference_frequency = waveform_kwargs["reference_frequency"] + minimum_frequency = waveform_kwargs["minimum_frequency"] + maximum_frequency = waveform_kwargs["maximum_frequency"] + mode_array = waveform_kwargs["mode_array"] + pn_amplitude_order = waveform_kwargs["pn_amplitude_order"] if pn_amplitude_order != 0: # This is to mimic the behaviour in @@ -403,7 +423,7 @@ def _from_bilby_parameters(self, **parameters): pn_amplitude_order = 3 # Equivalent to MAX_PRECESSING_AMP_PN_ORDER in LALSimulation else: pn_amplitude_order = 6 # Equivalent to MAX_NONPRECESSING_AMP_PN_ORDER in LALSimulation - start_frequency = minimum_frequency * 2. / (pn_amplitude_order + 2) + start_frequency = minimum_frequency * 2.0 / (pn_amplitude_order + 2) else: start_frequency = minimum_frequency @@ -426,36 +446,33 @@ def _from_bilby_parameters(self, **parameters): ) gwsignal_dict = { - 'mass1': parameters["mass_1"], - 'mass2': parameters["mass_2"], - 'spin1x': spin_1x, - 'spin1y': spin_1y, - 'spin1z': spin_1z, - 'spin2x': spin_2x, - 'spin2y': spin_2y, - 'spin2z': spin_2z, - 'lambda1': parameters["lambda_1"], - 'lambda2': parameters["lambda_2"], - 'deltaF': 1 / self.duration, - 'deltaT': 1 / self.sampling_frequency, - 'f22_start': start_frequency, - 'f_max': maximum_frequency, - 'f22_ref': reference_frequency, - 'phi_ref': parameters["phase"], - 'distance': parameters["luminosity_distance"] * 1e6, - 'inclination': iota, - 'eccentricity': parameters["eccentricity"], - 'meanPerAno': parameters["mean_per_ano"], - 'condition': int(self.generator.metadata["implemented_domain"] == 'time'), + "mass1": parameters["mass_1"], + "mass2": parameters["mass_2"], + "spin1x": spin_1x, + "spin1y": spin_1y, + "spin1z": spin_1z, + "spin2x": spin_2x, + "spin2y": spin_2y, + "spin2z": spin_2z, + "lambda1": parameters["lambda_1"], + "lambda2": parameters["lambda_2"], + "deltaF": 1 / self.duration, + "deltaT": 1 / self.sampling_frequency, + "f22_start": start_frequency, + "f_max": maximum_frequency, + "f22_ref": reference_frequency, + "phi_ref": parameters["phase"], + "distance": parameters["luminosity_distance"] * 1e6, + "inclination": iota, + "eccentricity": parameters["eccentricity"], + "meanPerAno": parameters["mean_per_ano"], + "condition": int(self.generator.metadata["implemented_domain"] == "time"), } # add astropy units to the parameters using the defaults from gwsignal from lalsimulation.gwsignal.core.parameter_conventions import Cosmo_units_dictionary - gwsignal_dict = { - key: val << Cosmo_units_dictionary.get(key, 0) - for key, val in gwsignal_dict.items() - } + gwsignal_dict = {key: val << Cosmo_units_dictionary.get(key, 0) for key, val in gwsignal_dict.items()} if mode_array is not None: gwsignal_dict.update(ModeArray=mode_array) @@ -463,17 +480,17 @@ def _from_bilby_parameters(self, **parameters): extra_args = waveform_kwargs.copy() for key in [ - "waveform_approximant", - "reference_frequency", - "minimum_frequency", - "maximum_frequency", - "catch_waveform_errors", - "mode_array", - "pn_spin_order", - "pn_amplitude_order", - "pn_tidal_order", - "pn_phase_order", - "numerical_relativity_file", + "waveform_approximant", + "reference_frequency", + "minimum_frequency", + "maximum_frequency", + "catch_waveform_errors", + "mode_array", + "pn_spin_order", + "pn_amplitude_order", + "pn_tidal_order", + "pn_phase_order", + "numerical_relativity_file", ]: if key in extra_args.keys(): del extra_args[key] @@ -491,7 +508,7 @@ def frequency_domain_strain(self, parameters): GenerateFDWaveform, self._from_bilby_parameters(**parameters), self.generator, - self.waveform_arguments.get("catch_waveform_errors", False) + self.waveform_arguments.get("catch_waveform_errors", False), ) wf = self._extract_waveform(hpc, "frequency") @@ -500,14 +517,11 @@ def frequency_domain_strain(self, parameters): minimum_frequency = self.waveform_arguments.get("minimum_frequency", 20.0) maximum_frequency = self.waveform_arguments.get("maximum_frequency", self.frequency_array[-1]) - frequency_bounds = ( - (self.frequency_array >= minimum_frequency) - * (self.frequency_array <= maximum_frequency) - ) + frequency_bounds = (self.frequency_array >= minimum_frequency) * (self.frequency_array <= maximum_frequency) for key in wf: wf[key] *= frequency_bounds - if self.generator.metadata["implemented_domain"] == 'time': + if self.generator.metadata["implemented_domain"] == "time": dt = 1 / hpc.hp.df.value + hpc.hp.epoch.value time_shift = np.exp(-1j * 2 * np.pi * dt * self.frequency_array[frequency_bounds]) for key in wf: @@ -525,7 +539,7 @@ def time_domain_strain(self, parameters): GenerateTDWaveform, self._from_bilby_parameters(**parameters), self.generator, - self.waveform_arguments.get("catch_waveform_errors", False) + self.waveform_arguments.get("catch_waveform_errors", False), ) return self._extract_waveform(hpc, "time") @@ -547,13 +561,14 @@ def _extract_waveform(self, hpc, kind): if len(hpc.hp) > len(array): logger.debug( f"GWsignal waveform longer than bilby's `{kind}_array`({len(hpc.hp)} " - f"vs {len(array)}). Truncating GWsignal array.") + f"vs {len(array)}). Truncating GWsignal array." + ) # set slice to force the output into a numpy array - h_plus[:] = hpc.hp[:len(h_plus)] - h_cross[:] = hpc.hc[:len(h_cross)] + h_plus[:] = hpc.hp[: len(h_plus)] + h_cross[:] = hpc.hc[: len(h_cross)] else: - h_plus[:len(hpc.hp)] = hpc.hp - h_cross[:len(hpc.hc)] = hpc.hc + h_plus[: len(hpc.hp)] = hpc.hp + h_cross[: len(hpc.hc)] = hpc.hc return dict(plus=h_plus, cross=h_cross) diff --git a/bilby/hyper/__init__.py b/bilby/hyper/__init__.py index b25e08c4f..afcfe3ccc 100644 --- a/bilby/hyper/__init__.py +++ b/bilby/hyper/__init__.py @@ -1 +1,3 @@ from . import likelihood, model + +__all__ = [likelihood, model] diff --git a/bilby/hyper/likelihood.py b/bilby/hyper/likelihood.py index 4b691845e..7a9060017 100644 --- a/bilby/hyper/likelihood.py +++ b/bilby/hyper/likelihood.py @@ -1,15 +1,14 @@ - import logging import numpy as np from ..core.likelihood import Likelihood, _fallback_to_parameters -from .model import Model from ..core.prior import PriorDict +from .model import Model class HyperparameterLikelihood(Likelihood): - """ A likelihood for inferring hyperparameter posterior distributions + """A likelihood for inferring hyperparameter posterior distributions See Eq. (34) of https://arxiv.org/abs/1809.02293 for a definition. @@ -32,14 +31,15 @@ class HyperparameterLikelihood(Likelihood): """ - def __init__(self, posteriors, hyper_prior, sampling_prior=None, - log_evidences=None, max_samples=1e100): + def __init__(self, posteriors, hyper_prior, sampling_prior=None, log_evidences=None, max_samples=1e100): if not isinstance(hyper_prior, Model): hyper_prior = Model([hyper_prior]) if sampling_prior is None: - if ('log_prior' not in posteriors[0].keys()) and ('prior' not in posteriors[0].keys()): - raise ValueError('Missing both sampling prior function and prior or log_prior ' - 'column in posterior dictionary. Must pass one or the other.') + if ("log_prior" not in posteriors[0].keys()) and ("prior" not in posteriors[0].keys()): + raise ValueError( + "Missing both sampling prior function and prior or log_prior " + "column in posterior dictionary. Must pass one or the other." + ) else: if not (isinstance(sampling_prior, Model) or isinstance(sampling_prior, PriorDict)): sampling_prior = Model([sampling_prior]) @@ -51,18 +51,16 @@ def __init__(self, posteriors, hyper_prior, sampling_prior=None, self.hyper_prior = hyper_prior self.sampling_prior = sampling_prior self.max_samples = max_samples - super(HyperparameterLikelihood, self).__init__() + super().__init__() self.data = self.resample_posteriors() self.n_posteriors = len(self.posteriors) self.samples_per_posterior = self.max_samples - self.samples_factor =\ - - self.n_posteriors * np.log(self.samples_per_posterior) + self.samples_factor = -self.n_posteriors * np.log(self.samples_per_posterior) def log_likelihood_ratio(self, parameters=None): parameters = _fallback_to_parameters(self, parameters) - log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data, **parameters) / - self.data['prior'], axis=-1))) + log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data, **parameters) / self.data["prior"], axis=-1))) log_l += self.samples_factor return np.nan_to_num(log_l) @@ -92,18 +90,17 @@ def resample_posteriors(self, max_samples=None): for posterior in self.posteriors: self.max_samples = min(len(posterior), self.max_samples) data = {key: [] for key in self.posteriors[0]} - if 'log_prior' in data.keys(): - data.pop('log_prior') - if 'prior' not in data.keys(): - data['prior'] = [] - logging.debug('Downsampling to {} samples per posterior.'.format( - self.max_samples)) + if "log_prior" in data.keys(): + data.pop("log_prior") + if "prior" not in data.keys(): + data["prior"] = [] + logging.debug(f"Downsampling to {self.max_samples} samples per posterior.") for posterior in self.posteriors: temp = posterior.sample(self.max_samples) if self.sampling_prior is not None: - temp['prior'] = self.sampling_prior.prob(temp, axis=0) - elif 'log_prior' in temp.keys(): - temp['prior'] = np.exp(temp['log_prior']) + temp["prior"] = self.sampling_prior.prob(temp, axis=0) + elif "log_prior" in temp.keys(): + temp["prior"] = np.exp(temp["log_prior"]) for key in data: data[key].append(temp[key]) for key in data: diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py index eaa2da98c..23f40dece 100644 --- a/bilby/hyper/model.py +++ b/bilby/hyper/model.py @@ -54,10 +54,7 @@ def prob(self, data, **kwargs): probability = 1.0 for ii, function in enumerate(self.models): function_parameters = self._get_function_parameters(function, **kwargs) - if ( - self.cache - and self._cached_parameters[function] == function_parameters - ): + if self.cache and self._cached_parameters[function] == function_parameters: new_probability = self._cached_probability[function] else: new_probability = function(data, **function_parameters) diff --git a/cli_bilby/bilby_result.py b/cli_bilby/bilby_result.py index 48d546747..03e40a6f4 100644 --- a/cli_bilby/bilby_result.py +++ b/cli_bilby/bilby_result.py @@ -1,4 +1,4 @@ -""" A command line interface to ease the process of batch jobs on result files +"""A command line interface to ease the process of batch jobs on result files Examples -------- @@ -21,6 +21,7 @@ individually. Note that passing extra commands in is not yet implemented. """ + import argparse import bilby @@ -54,9 +55,7 @@ def setup_command_line_args(): action="store_true", help="Gzip the merged output results file if using JSON format.", ) - parser.add_argument( - "-o", "--outdir", type=str, default=None, help="Output directory." - ) + parser.add_argument("-o", "--outdir", type=str, default=None, help="Output directory.") parser.add_argument( "-l", "--label", @@ -95,9 +94,7 @@ def setup_command_line_args(): help="Merge the set of runs, output saved using the outdir and label", ) - action_parser.add_argument( - "-b", "--bayes", action="store_true", help="Print all Bayes factors." - ) + action_parser.add_argument("-b", "--bayes", action="store_true", help="Print all Bayes factors.") action_parser.add_argument( "-p", "--print", @@ -115,10 +112,7 @@ def setup_command_line_args(): action_parser.add_argument( "--ipython", action="store_true", - help=( - "For each result given, drops the user into an " - "IPython shell with the result loaded in" - ), + help=("For each result given, drops the user into an IPython shell with the result loaded in"), ) args = parser.parse_args() @@ -143,10 +137,8 @@ def print_bayes_factors(results_list): def drop_to_ipython(results_list): for result in results_list: - message = "Opened IPython terminal for result {}".format(result.label) - message += "\nRunning with bilby={},\nResult generated with bilby={}".format( - bilby.__version__, result.version - ) + message = f"Opened IPython terminal for result {result.label}" + message += f"\nRunning with bilby={bilby.__version__},\nResult generated with bilby={result.version}" message += "\nBilby result loaded as `result`" import IPython @@ -155,7 +147,7 @@ def drop_to_ipython(results_list): def print_matches(results_list, args): for r in results_list: - print("\nResult file: {}/{}".format(r.outdir, r.label)) + print(f"\nResult file: {r.outdir}/{r.label}") for key in args.keys: for attr in r.__dict__: if key in attr: @@ -178,11 +170,7 @@ def apply_max_samples(result, args): def apply_lightweight(result, args): - strip_keys = [ - "_nested_samples", - "log_likelihood_evaluations", - "log_prior_evaluations" - ] + strip_keys = ["_nested_samples", "log_likelihood_evaluations", "log_prior_evaluations"] for key in strip_keys: setattr(result, key, None) return result diff --git a/cli_bilby/plot_multiple_posteriors.py b/cli_bilby/plot_multiple_posteriors.py index 51636e48b..54b23e1a0 100644 --- a/cli_bilby/plot_multiple_posteriors.py +++ b/cli_bilby/plot_multiple_posteriors.py @@ -2,18 +2,14 @@ def setup_command_line_args(): - parser = argparse.ArgumentParser( - description="Plot corner plots from results files") - parser.add_argument("-r", "--results", nargs='+', - help="List of results files to use.") - parser.add_argument("-f", "--filename", default=None, - help="Output file name.") - parser.add_argument("-l", "--labels", nargs='+', default=None, - help="List of labels to use for each result.") - parser.add_argument("-p", "--parameters", nargs='+', default=None, - help="List of parameters.") - parser.add_argument("-e", "--evidences", action='store_true', default=False, - help="Add the evidences to the legend.") + parser = argparse.ArgumentParser(description="Plot corner plots from results files") + parser.add_argument("-r", "--results", nargs="+", help="List of results files to use.") + parser.add_argument("-f", "--filename", default=None, help="Output file name.") + parser.add_argument("-l", "--labels", nargs="+", default=None, help="List of labels to use for each result.") + parser.add_argument("-p", "--parameters", nargs="+", default=None, help="List of parameters.") + parser.add_argument( + "-e", "--evidences", action="store_true", default=False, help="Add the evidences to the legend." + ) args, _ = parser.parse_known_args() return args @@ -22,9 +18,8 @@ def setup_command_line_args(): def main(): args = setup_command_line_args() import bilby - results = [bilby.core.result.read_in_result(filename=r) - for r in args.results] - bilby.core.result.plot_multiple(results, filename=args.filename, - labels=args.labels, - parameters=args.parameters, - evidences=args.evidences) + + results = [bilby.core.result.read_in_result(filename=r) for r in args.results] + bilby.core.result.plot_multiple( + results, filename=args.filename, labels=args.labels, parameters=args.parameters, evidences=args.evidences + ) diff --git a/containers/write_dockerfiles.py b/containers/write_dockerfiles.py index 064c5d0f1..a201306f8 100644 --- a/containers/write_dockerfiles.py +++ b/containers/write_dockerfiles.py @@ -1,6 +1,6 @@ from datetime import date -with open("dockerfile-template", "r") as ff: +with open("dockerfile-template") as ff: template = ff.read() python_versions = [(3, 10), (3, 11), (3, 12)] @@ -8,16 +8,12 @@ for python_major_version, python_minor_version in python_versions: key = f"python{python_major_version}{python_minor_version}" - with open( - f"v3-dockerfile-test-suite-{key}", - "w" - ) as ff: + with open(f"v3-dockerfile-test-suite-{key}", "w") as ff: + ff.write("# This dockerfile is written automatically and should not be modified by hand.\n\n") ff.write( - "# This dockerfile is written automatically and should not be " - "modified by hand.\n\n" + template.format( + date=today, + python_major_version=python_major_version, + python_minor_version=python_minor_version, + ) ) - ff.write(template.format( - date=today, - python_major_version=python_major_version, - python_minor_version=python_minor_version, - )) diff --git a/docs/conf.py b/docs/conf.py index d65666dbe..475f9ee72 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # bilby documentation build configuration file, created by # sphinx-quickstart on Fri May 25 12:08:01 2018. @@ -21,13 +20,15 @@ import os import subprocess import sys + import bilby -sys.path.insert(0, os.path.abspath('../')) + +sys.path.insert(0, os.path.abspath("../")) def git_revision_hash(): try: - return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip() + return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() except subprocess.CalledProcessError: return "master" @@ -58,43 +59,43 @@ def git_upstream_url(): # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.mathjax', - 'numpydoc', - 'nbsphinx', - 'sphinx.ext.autosummary', - 'sphinx.ext.autosectionlabel', - 'sphinx_tabs.tabs', + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "numpydoc", + "nbsphinx", + "sphinx.ext.autosummary", + "sphinx.ext.autosectionlabel", + "sphinx_tabs.tabs", "sphinx.ext.linkcode", - 'myst_parser', + "myst_parser", "sphinx_sitemap", ] autosummary_generate = True # Add any paths that contain templates here, relative to this directory. -templates_path = ['templates'] +templates_path = ["templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # -source_suffix = ['.rst', '.md', '.txt'] +source_suffix = [".rst", ".md", ".txt"] # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'bilby' -copyright = u'2019, Greg Ashton' -author = u'Greg Ashton' +project = "bilby" +copyright = "2019, Greg Ashton" +author = "Greg Ashton" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # -fullversion = bilby.__version__.split(':')[0] +fullversion = bilby.__version__.split(":")[0] # The short X.Y version. -version = '.'.join(fullversion.split('.')[:2]) +version = ".".join(fullversion.split(".")[:2]) # The full version, including alpha/beta/rc tags. release = version @@ -109,16 +110,16 @@ def git_upstream_url(): # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'requirements.txt'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "requirements.txt"] # The html path allows for google search console verification through # Aditya Vijaykumar's gmail ID -html_extra_path = ['robots.txt', 'google063678b5c432c237.html'] +html_extra_path = ["robots.txt", "google063678b5c432c237.html"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -129,7 +130,7 @@ def git_upstream_url(): # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -148,21 +149,21 @@ def git_upstream_url(): # This is required for the alabaster theme # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars html_sidebars = { - '**': [ - 'about.html', - 'navigation.html', - 'relations.html', # needs 'show_related': True theme option to display - 'searchbox.html', - 'donate.html', + "**": [ + "about.html", + "navigation.html", + "relations.html", # needs 'show_related': True theme option to display + "searchbox.html", + "donate.html", ] } -html_baseurl = 'https://bilby-dev.github.io/bilby/' +html_baseurl = "https://bilby-dev.github.io/bilby/" # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'bilbydoc' +htmlhelp_basename = "bilbydoc" # -- Options for LaTeX output --------------------------------------------- @@ -171,15 +172,12 @@ def git_upstream_url(): # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -189,8 +187,7 @@ def git_upstream_url(): # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'bilby.tex', u'bilby Documentation', - u'Paul Lasky', 'manual'), + (master_doc, "bilby.tex", "bilby Documentation", "Paul Lasky", "manual"), ] @@ -198,10 +195,7 @@ def git_upstream_url(): # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'bilby', u'bilby Documentation', - [author], 1) -] +man_pages = [(master_doc, "bilby", "bilby Documentation", [author], 1)] # -- Options for Texinfo output ------------------------------------------- @@ -210,9 +204,7 @@ def git_upstream_url(): # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'bilby', u'bilby Documentation', - author, 'bilby', 'One line description of project.', - 'Miscellaneous'), + (master_doc, "bilby", "bilby Documentation", author, "bilby", "One line description of project.", "Miscellaneous"), ] numpydoc_show_class_members = False @@ -225,9 +217,9 @@ def linkcode_resolve(domain, info): """ Adapted from https://github.com/aaugustin/websockets/blob/8e1628a14e0dd2ca98871c7500484b5d42d16b67/docs/conf.py """ - if domain != 'py': + if domain != "py": return None - if not info['module']: + if not info["module"]: return None try: diff --git a/examples/core_examples/15d_gaussian.py b/examples/core_examples/15d_gaussian.py index e52bc9e6b..b9ac1c7bb 100644 --- a/examples/core_examples/15d_gaussian.py +++ b/examples/core_examples/15d_gaussian.py @@ -273,12 +273,7 @@ likelihood = AnalyticalMultidimensionalCovariantGaussian(mean, cov) priors = bilby.core.prior.PriorDict() -priors.update( - { - "x{0}".format(i): bilby.core.prior.Uniform(-5, 5, "x{0}".format(i)) - for i in range(dim) - } -) +priors.update({f"x{i}": bilby.core.prior.Uniform(-5, 5, f"x{i}") for i in range(dim)}) result = bilby.run_sampler( likelihood=likelihood, @@ -290,27 +285,18 @@ resume=True, ) -result.plot_corner(parameters={"x{0}".format(i): mean[i] for i in range(dim)}) +result.plot_corner(parameters={f"x{i}": mean[i] for i in range(dim)}) # The prior is constant and flat, and the likelihood is normalised such that the area under it is one. # The analytical evidence is then given as 1/(prior volume) -log_prior_vol = np.sum( - np.log([prior.maximum - prior.minimum for key, prior in priors.items()]) -) +log_prior_vol = np.sum(np.log([prior.maximum - prior.minimum for key, prior in priors.items()])) log_evidence = -log_prior_vol -sampled_std = [ - np.std(result.posterior[param]) for param in result.search_parameter_keys -] +sampled_std = [np.std(result.posterior[param]) for param in result.search_parameter_keys] logger.info("Analytic log evidence: " + str(log_evidence)) -logger.info( - "Sampled log evidence: " - + str(result.log_evidence) - + " +/- " - + str(result.log_evidence_err) -) +logger.info("Sampled log evidence: " + str(result.log_evidence) + " +/- " + str(result.log_evidence_err)) for i, search_parameter_key in enumerate(result.search_parameter_keys): logger.info(search_parameter_key) @@ -327,12 +313,7 @@ likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian(mean_1, mean_2, cov) priors = bilby.core.prior.PriorDict() -priors.update( - { - "x{0}".format(i): bilby.core.prior.Uniform(-5, 5, "x{0}".format(i)) - for i in range(dim) - } -) +priors.update({f"x{i}": bilby.core.prior.Uniform(-5, 5, f"x{i}") for i in range(dim)}) result = bilby.run_sampler( likelihood=likelihood, @@ -344,17 +325,15 @@ resume=True, ) result.plot_corner( - parameters={"x{0}".format(i): mean_1[i] for i in range(dim)}, + parameters={f"x{i}": mean_1[i] for i in range(dim)}, filename=outdir + "/multidim_gaussian_bimodal_mode_1", ) result.plot_corner( - parameters={"x{0}".format(i): mean_2[i] for i in range(dim)}, + parameters={f"x{i}": mean_2[i] for i in range(dim)}, filename=outdir + "/multidim_gaussian_bimodal_mode_2", ) -log_prior_vol = np.sum( - np.log([prior.maximum - prior.minimum for key, prior in priors.items()]) -) +log_prior_vol = np.sum(np.log([prior.maximum - prior.minimum for key, prior in priors.items()])) log_evidence = -log_prior_vol sampled_std_1 = [] sampled_std_2 = [] @@ -366,21 +345,10 @@ sampled_std_2.append(np.std(samples_2)) logger.info("Analytic log evidence: " + str(log_evidence)) -logger.info( - "Sampled log evidence: " - + str(result.log_evidence) - + " +/- " - + str(result.log_evidence_err) -) +logger.info("Sampled log evidence: " + str(result.log_evidence) + " +/- " + str(result.log_evidence_err)) for i, search_parameter_key in enumerate(result.search_parameter_keys): logger.info(search_parameter_key) - logger.info( - "Expected posterior standard deviation both modes: " + str(likelihood.sigma[i]) - ) - logger.info( - "Sampled posterior standard deviation first mode: " + str(sampled_std_1[i]) - ) - logger.info( - "Sampled posterior standard deviation second mode: " + str(sampled_std_2[i]) - ) + logger.info("Expected posterior standard deviation both modes: " + str(likelihood.sigma[i])) + logger.info("Sampled posterior standard deviation first mode: " + str(sampled_std_1[i])) + logger.info("Sampled posterior standard deviation second mode: " + str(sampled_std_2[i])) diff --git a/examples/core_examples/alternative_samplers/linear_regression_bilby_mcmc.py b/examples/core_examples/alternative_samplers/linear_regression_bilby_mcmc.py index 50fb5d769..338f89877 100644 --- a/examples/core_examples/alternative_samplers/linear_regression_bilby_mcmc.py +++ b/examples/core_examples/alternative_samplers/linear_regression_bilby_mcmc.py @@ -5,6 +5,7 @@ data with background Gaussian noise """ + import bilby import matplotlib.pyplot as plt import numpy as np @@ -45,7 +46,7 @@ def model(time, m, c): ax.set_xlabel("time") ax.set_ylabel("y") ax.legend() -fig.savefig("{}/{}_data.png".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_data.png") # Now lets instantiate a version of our GaussianLikelihood, giving it # the time, data and signal model diff --git a/examples/core_examples/alternative_samplers/linear_regression_pymc.py b/examples/core_examples/alternative_samplers/linear_regression_pymc.py index cc28e1c56..64792f732 100644 --- a/examples/core_examples/alternative_samplers/linear_regression_pymc.py +++ b/examples/core_examples/alternative_samplers/linear_regression_pymc.py @@ -5,6 +5,7 @@ data with background Gaussian noise """ + import bilby import matplotlib.pyplot as plt import numpy as np @@ -46,7 +47,7 @@ def model(time, m, c): ax.set_xlabel("time") ax.set_ylabel("y") ax.legend() -fig.savefig("{}/{}_data.png".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_data.png") # Now lets instantiate a version of our GaussianLikelihood, giving it # the time, data and signal model diff --git a/examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py b/examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py index cb381aec7..4bf018c0e 100644 --- a/examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py +++ b/examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py @@ -50,7 +50,7 @@ def model(time, m, c): ax.set_xlabel("time") ax.set_ylabel("y") ax.legend() -fig.savefig("{}/{}_data.png".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_data.png") # Parameter estimation: we now define a Gaussian Likelihood class relevant for @@ -73,7 +73,7 @@ def __init__(self, x, y, sigma, func): will require a prior and will be sampled over (unless a fixed value is given). """ - super(GaussianLikelihoodPyMC, self).__init__(x=x, y=y, func=func, sigma=sigma) + super().__init__(x=x, y=y, func=func, sigma=sigma) def log_likelihood(self, sampler=None, parameters=None): """ @@ -91,9 +91,7 @@ def log_likelihood(self, sampler=None, parameters=None): """ if not isinstance(sampler, Pymc): - return super(GaussianLikelihoodPyMC, self).log_likelihood( - parameters=parameters - ) + return super().log_likelihood(parameters=parameters) if not hasattr(sampler, "pymc_model"): raise AttributeError("Sampler has not PyMC model attribute") @@ -119,9 +117,7 @@ def __init__(self, minimum, maximum, name=None, latex_label=None): """ Uniform prior with bounds (should be equivalent to bilby.prior.Uniform) """ - bilby.core.prior.Prior.__init__( - self, name, latex_label, minimum=minimum, maximum=maximum - ) + bilby.core.prior.Prior.__init__(self, name, latex_label, minimum=minimum, maximum=maximum) def ln_prob(self, sampler=None): """ @@ -132,7 +128,7 @@ def ln_prob(self, sampler=None): float or array to be passed to the superclass. """ if not isinstance(sampler, Pymc): - return super(PyMCUniformPrior, self).ln_prob(sampler) + return super().ln_prob(sampler) return pm.Uniform(self.name, lower=self.minimum, upper=self.maximum) diff --git a/examples/core_examples/dirichlet.py b/examples/core_examples/dirichlet.py index 6cd8ceff4..b5443d4b6 100644 --- a/examples/core_examples/dirichlet.py +++ b/examples/core_examples/dirichlet.py @@ -25,14 +25,10 @@ injection_parameters=injection_parameters, ) -result.posterior[label + str(n_dim - 1)] = 1 - np.sum( - [result.posterior[key] for key in priors], axis=0 -) +result.posterior[label + str(n_dim - 1)] = 1 - np.sum([result.posterior[key] for key in priors], axis=0) result.plot_corner(parameters=injection_parameters) samples = priors.sample(10000) samples[label + str(n_dim - 1)] = 1 - np.sum([samples[key] for key in samples], axis=0) result.posterior = pd.DataFrame(samples) -result.plot_corner( - parameters=[key for key in samples], filename="outdir/dirichlet_prior_corner.png" -) +result.plot_corner(parameters=[key for key in samples], filename="outdir/dirichlet_prior_corner.png") diff --git a/examples/core_examples/gaussian_example.py b/examples/core_examples/gaussian_example.py index f8b80ec37..52b732447 100644 --- a/examples/core_examples/gaussian_example.py +++ b/examples/core_examples/gaussian_example.py @@ -3,6 +3,7 @@ An example of how to use bilby to perform parameter estimation for non-gravitational wave data consisting of a Gaussian with a mean and variance """ + import bilby import numpy as np from bilby.core.utils import random @@ -40,9 +41,7 @@ def __init__(self, data): def log_likelihood(self, parameters): sigma = parameters["sigma"] res = self.data - parameters["mu"] - return -0.5 * ( - np.sum((res / sigma) ** 2) + self.N * np.log(2 * np.pi * sigma**2) - ) + return -0.5 * (np.sum((res / sigma) ** 2) + self.N * np.log(2 * np.pi * sigma**2)) likelihood = SimpleGaussianLikelihood(data) diff --git a/examples/core_examples/gaussian_process_celerite_example.py b/examples/core_examples/gaussian_process_celerite_example.py index a06dd22ed..8fd08fd6d 100644 --- a/examples/core_examples/gaussian_process_celerite_example.py +++ b/examples/core_examples/gaussian_process_celerite_example.py @@ -56,9 +56,7 @@ def linear_function(x, a, b): duration = times[-1] - times[0] ys = ( - amplitude - * np.sin(2 * np.pi * times / period) - * np.exp(-((times - 50) ** 2) / 2 / width**2) + amplitude * np.sin(2 * np.pi * times / period) * np.exp(-((times - 50) ** 2) / 2 / width**2) + random.rng.normal(scale=jitter, size=len(times)) + linear_function(x=times, a=slope, b=offset) ) @@ -94,9 +92,7 @@ def linear_function(x, a, b): # of white noise during the inference process. Smaller values of `yerr` # cause the program to break. If you know the `yerr` in your problem, # you can pass them in as an array. -likelihood = bilby.core.likelihood.CeleriteLikelihood( - kernel=kernel, mean_model=mean_model, t=times, y=ys, yerr=1e-6 -) +likelihood = bilby.core.likelihood.CeleriteLikelihood(kernel=kernel, mean_model=mean_model, t=times, y=ys, yerr=1e-6) # Print the parameter names. This is useful if we have trouble figuring out # how `celerite` applies its naming scheme. @@ -105,18 +101,10 @@ def linear_function(x, a, b): # Set up the priors. We know the name of the parameters from the print # statement in the line before. priors = bilby.core.prior.PriorDict() -priors["kernel:terms[0]:log_S0"] = Uniform( - minimum=-10, maximum=30, name="log_S0", latex_label=r"$\ln S_0$" -) -priors["kernel:terms[0]:log_Q"] = Uniform( - minimum=-10, maximum=30, name="log_Q", latex_label=r"$\ln Q$" -) -priors["kernel:terms[0]:log_omega0"] = Uniform( - minimum=-5, maximum=5, name="log_omega0", latex_label=r"$\ln \omega_0$" -) -priors["kernel:terms[1]:log_sigma"] = Uniform( - minimum=-5, maximum=5, name="log sigma", latex_label=r"$\ln \sigma$" -) +priors["kernel:terms[0]:log_S0"] = Uniform(minimum=-10, maximum=30, name="log_S0", latex_label=r"$\ln S_0$") +priors["kernel:terms[0]:log_Q"] = Uniform(minimum=-10, maximum=30, name="log_Q", latex_label=r"$\ln Q$") +priors["kernel:terms[0]:log_omega0"] = Uniform(minimum=-5, maximum=5, name="log_omega0", latex_label=r"$\ln \omega_0$") +priors["kernel:terms[1]:log_sigma"] = Uniform(minimum=-5, maximum=5, name="log sigma", latex_label=r"$\ln \sigma$") priors["mean:a"] = Uniform(minimum=-100, maximum=100, name="a", latex_label=r"$a$") priors["mean:b"] = Uniform(minimum=-100, maximum=100, name="b", latex_label=r"$b$") @@ -186,9 +174,7 @@ def linear_function(x, a, b): plt.plot(x, trend, color="green", label="Mean") # Plot the mean model for ten other posterior samples. -samples = [ - result.posterior.iloc[random.rng.integers(len(result.posterior))] for _ in range(10) -] +samples = [result.posterior.iloc[random.rng.integers(len(result.posterior))] for _ in range(10)] for sample in samples: likelihood.set_parameters(sample) if not isinstance(likelihood.mean_model, (float, int)): diff --git a/examples/core_examples/gaussian_process_george_example.py b/examples/core_examples/gaussian_process_george_example.py index 409201997..bc9bd761b 100644 --- a/examples/core_examples/gaussian_process_george_example.py +++ b/examples/core_examples/gaussian_process_george_example.py @@ -54,9 +54,7 @@ def linear_function(x, a, b): duration = times[-1] - times[0] ys = ( - amplitude - * np.sin(2 * np.pi * times / period) - * np.exp(-((times - 50) ** 2) / 2 / width**2) + amplitude * np.sin(2 * np.pi * times / period) * np.exp(-((times - 50) ** 2) / 2 / width**2) + random.rng.normal(scale=jitter, size=len(times)) + linear_function(x=times, a=slope, b=offset) ) @@ -87,9 +85,7 @@ def linear_function(x, a, b): # of white noise during the inference process. # Smaller values of `yerr` # cause the program to break. If you know the `yerr` in your problem, # you can pass them in as # an array. -likelihood = bilby.core.likelihood.GeorgeLikelihood( - kernel=kernel, mean_model=mean_model, t=times, y=ys, yerr=1e-6 -) +likelihood = bilby.core.likelihood.GeorgeLikelihood(kernel=kernel, mean_model=mean_model, t=times, y=ys, yerr=1e-6) # Print the parameter names. This is useful if we have trouble figuring out # how `celerite` applies its naming scheme. @@ -99,15 +95,9 @@ def linear_function(x, a, b): # Set up the priors. We know the name of the parameters from the print # statement in the line before. priors = bilby.core.prior.PriorDict() -priors["kernel:k1:log_constant"] = Uniform( - minimum=-10, maximum=30, name="log_A", latex_label=r"$\ln A$" -) -priors["kernel:k2:metric:log_M_0_0"] = Uniform( - minimum=-10, maximum=30, name="log_M_0_0", latex_label=r"$\ln M_{00}$" -) -priors["white_noise:value"] = Uniform( - minimum=0, maximum=10, name="white noise", latex_label=r"$\sigma$" -) +priors["kernel:k1:log_constant"] = Uniform(minimum=-10, maximum=30, name="log_A", latex_label=r"$\ln A$") +priors["kernel:k2:metric:log_M_0_0"] = Uniform(minimum=-10, maximum=30, name="log_M_0_0", latex_label=r"$\ln M_{00}$") +priors["white_noise:value"] = Uniform(minimum=0, maximum=10, name="white noise", latex_label=r"$\sigma$") priors["mean:a"] = Uniform(minimum=-100, maximum=100, name="a", latex_label=r"$a$") priors["mean:b"] = Uniform(minimum=-100, maximum=100, name="b", latex_label=r"$b$") @@ -165,9 +155,7 @@ def linear_function(x, a, b): plt.plot(x, trend, color="green", label="Mean") # Plot the mean model for ten other posterior samples. -samples = [ - result.posterior.iloc[random.rng.integer(len(result.posterior))] for _ in range(10) -] +samples = [result.posterior.iloc[random.rng.integer(len(result.posterior))] for _ in range(10)] for sample in samples: likelihood.set_parameters(sample) if not isinstance(likelihood.mean_model, (float, int)): diff --git a/examples/core_examples/grid_example.py b/examples/core_examples/grid_example.py index f2afa36c8..9aee5e89d 100644 --- a/examples/core_examples/grid_example.py +++ b/examples/core_examples/grid_example.py @@ -5,6 +5,7 @@ data with background Gaussian noise """ + import bilby import matplotlib.pyplot as plt import numpy as np @@ -45,7 +46,7 @@ def model(time, m, c): ax.set_xlabel("time") ax.set_ylabel("y") ax.legend() -fig.savefig("{}/{}_data.png".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_data.png") plt.close() # Now lets instantiate a version of our GaussianLikelihood, giving it diff --git a/examples/core_examples/hyper_parameter_example.py b/examples/core_examples/hyper_parameter_example.py index f78430565..1bea10085 100644 --- a/examples/core_examples/hyper_parameter_example.py +++ b/examples/core_examples/hyper_parameter_example.py @@ -2,6 +2,7 @@ """ An example of how to use bilby to perform parameter estimation for hyper params """ + import matplotlib.pyplot as plt import numpy as np from bilby.core.likelihood import GaussianLikelihood @@ -55,7 +56,7 @@ def model(x, c0, c1): nlive=1000, outdir=outdir, verbose=False, - label="individual_{}".format(i), + label=f"individual_{i}", save=False, injection_parameters=injection_parameters, ) @@ -79,10 +80,7 @@ def model(x, c0, c1): def hyper_prior(dataset, mu, sigma): - return ( - np.exp(-((dataset["c0"] - mu) ** 2) / (2 * sigma**2)) - / (2 * np.pi * sigma**2) ** 0.5 - ) + return np.exp(-((dataset["c0"] - mu) ** 2) / (2 * sigma**2)) / (2 * np.pi * sigma**2) ** 0.5 samples = [result.posterior for result in results] diff --git a/examples/core_examples/linear_regression.py b/examples/core_examples/linear_regression.py index 744c3017c..630d27bd9 100644 --- a/examples/core_examples/linear_regression.py +++ b/examples/core_examples/linear_regression.py @@ -5,6 +5,7 @@ data with background Gaussian noise """ + import bilby import matplotlib.pyplot as plt import numpy as np @@ -45,7 +46,7 @@ def model(time, m, c): ax.set_xlabel("time") ax.set_ylabel("y") ax.legend() -fig.savefig("{}/{}_data.png".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_data.png") # Now lets instantiate a version of our GaussianLikelihood, giving it # the time, data and signal model diff --git a/examples/core_examples/linear_regression_grid.py b/examples/core_examples/linear_regression_grid.py index 24030f4c3..c19942a1e 100644 --- a/examples/core_examples/linear_regression_grid.py +++ b/examples/core_examples/linear_regression_grid.py @@ -5,6 +5,7 @@ This will compare the output of using a stochastic sampling method to evaluating the posterior on a grid. """ + import bilby import matplotlib.pyplot as plt import numpy as np @@ -47,7 +48,7 @@ def model(time, m, c): ax.set_xlabel("time") ax.set_ylabel("y") ax.legend() -fig.savefig("{}/{}_data.png".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_data.png") # Now lets instantiate a version of our GaussianLikelihood, giving it # the time, data and signal model @@ -93,8 +94,8 @@ def model(time, m, c): np.exp(grid.ln_posterior - np.max(grid.ln_posterior)), ) -fig.savefig("{}/{}_corner.png".format(outdir, label), dpi=300) +fig.savefig(f"{outdir}/{label}_corner.png", dpi=300) # compare evidences -print("Dynesty log(evidence): {}".format(result.log_evidence)) -print("Grid log(evidence): {}".format(grid_evidence)) +print(f"Dynesty log(evidence): {result.log_evidence}") +print(f"Grid log(evidence): {grid_evidence}") diff --git a/examples/core_examples/linear_regression_unknown_noise.py b/examples/core_examples/linear_regression_unknown_noise.py index fe869e8cb..e67d6c938 100644 --- a/examples/core_examples/linear_regression_unknown_noise.py +++ b/examples/core_examples/linear_regression_unknown_noise.py @@ -5,6 +5,7 @@ data with background Gaussian noise with unknown variance. """ + import bilby import matplotlib.pyplot as plt import numpy as np @@ -45,7 +46,7 @@ def model(time, m, c): ax.set_xlabel("time") ax.set_ylabel("y") ax.legend() -fig.savefig("{}/{}_data.png".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_data.png") injection_parameters.update(dict(sigma=1)) diff --git a/examples/core_examples/linear_regression_with_Fisher.py b/examples/core_examples/linear_regression_with_Fisher.py index ebf4188af..23b9ec4ab 100644 --- a/examples/core_examples/linear_regression_with_Fisher.py +++ b/examples/core_examples/linear_regression_with_Fisher.py @@ -6,6 +6,8 @@ estimated using the Fisher Information Matrix approximation. """ +# ruff: noqa: E402 + import copy import bilby @@ -69,6 +71,4 @@ def model(time, m, c): result_fim.posterior = fim.sample_dataframe("maxL", 10000) result_fim.label = "Fisher" -bilby.core.result.plot_multiple( - [result, result_fim], parameters=injection_parameters, truth_color="k" -) +bilby.core.result.plot_multiple([result, result_fim], parameters=injection_parameters, truth_color="k") diff --git a/examples/core_examples/logo/sample_logo.py b/examples/core_examples/logo/sample_logo.py index 61ef772df..32d869985 100644 --- a/examples/core_examples/logo/sample_logo.py +++ b/examples/core_examples/logo/sample_logo.py @@ -1,4 +1,5 @@ -""" Script used to generate the samples for the bilby logo """ +"""Script used to generate the samples for the bilby logo""" + import bilby import numpy as np import scipy.interpolate as si @@ -15,7 +16,7 @@ def log_likelihood(self, parameters): for letter in ["B", "I", "L", "Y"]: - img = 1 - io.imread("{}.png".format(letter), as_gray=True)[::-1, :] + img = 1 - io.imread(f"{letter}.png", as_gray=True)[::-1, :] x = np.arange(img.shape[0]) y = np.arange(img.shape[1]) interp = si.RectBivariateSpline(x, y, img, kx=1, ky=1) diff --git a/examples/core_examples/multivariate_gaussian_prior.py b/examples/core_examples/multivariate_gaussian_prior.py index d0419d386..b6de752cf 100644 --- a/examples/core_examples/multivariate_gaussian_prior.py +++ b/examples/core_examples/multivariate_gaussian_prior.py @@ -99,4 +99,4 @@ def model(time, m, c): ) axs[2].add_artist(ell) -fig.savefig("{}/{}_corner.png".format(outdir, label), dpi=300) +fig.savefig(f"{outdir}/{label}_corner.png", dpi=300) diff --git a/examples/core_examples/occam_factor_example.py b/examples/core_examples/occam_factor_example.py index 3997d132f..ca6c9ae7e 100644 --- a/examples/core_examples/occam_factor_example.py +++ b/examples/core_examples/occam_factor_example.py @@ -30,6 +30,7 @@ improved by increasing this to say 500 or 1000. """ + import bilby import matplotlib.pyplot as plt import numpy as np @@ -56,7 +57,7 @@ ax.set_xlabel("time") ax.set_ylabel("y") ax.legend() -fig.savefig("{}/{}_data.png".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_data.png") class Polynomial(bilby.Likelihood): @@ -73,7 +74,7 @@ def __init__(self, x, y, sigma, n): n: int The degree of the polynomial to fit. """ - self.keys = ["c{}".format(k) for k in range(n)] + self.keys = [f"c{k}" for k in range(n)] super().__init__() self.x = x self.y = y @@ -87,17 +88,14 @@ def polynomial(self, x, parameters): def log_likelihood(self, parameters): res = self.y - self.polynomial(self.x, parameters) - return -0.5 * ( - np.sum((res / self.sigma) ** 2) - + self.N * np.log(2 * np.pi * self.sigma**2) - ) + return -0.5 * (np.sum((res / self.sigma) ** 2) + self.N * np.log(2 * np.pi * self.sigma**2)) def fit(n): likelihood = Polynomial(time, data, sigma, n) priors = {} for i in range(n): - k = "c{}".format(i) + k = f"c{i}" priors[k] = bilby.core.prior.Uniform(0, 10, k) result = bilby.run_sampler( @@ -137,4 +135,4 @@ def fit(n): ax2.set_ylabel("Occam factor", color="C1") ax1.set_xlabel("Degree of polynomial") -fig.savefig("{}/{}_test".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_test") diff --git a/examples/core_examples/radioactive_decay.py b/examples/core_examples/radioactive_decay.py index 0e669979b..01131ccba 100644 --- a/examples/core_examples/radioactive_decay.py +++ b/examples/core_examples/radioactive_decay.py @@ -4,6 +4,7 @@ non-gravitational wave data. In this case, fitting the half-life and initial radionuclide number for Polonium 214. """ + import bilby import matplotlib.pyplot as plt import numpy as np @@ -75,7 +76,7 @@ def decay_rate(delta_t, halflife, n_init): ax.set_xlabel("time") ax.set_ylabel("counts") ax.legend() -fig.savefig("{}/{}_data.png".format(outdir, label)) +fig.savefig(f"{outdir}/{label}_data.png") # Now lets instantiate a version of the Poisson Likelihood, giving it # the time intervals, counts and rate model @@ -84,9 +85,7 @@ def decay_rate(delta_t, halflife, n_init): # Make the prior priors = dict() priors["halflife"] = LogUniform(1e-5, 1e5, latex_label="$t_{1/2}$", unit="min") -priors["n_init"] = LogUniform( - 1e-25 / atto, 1e-10 / atto, latex_label="$N_0$", unit="attomole" -) +priors["n_init"] = LogUniform(1e-25 / atto, 1e-10 / atto, latex_label="$N_0$", unit="attomole") # And run sampler result = bilby.run_sampler( diff --git a/examples/core_examples/slabspike_example.py b/examples/core_examples/slabspike_example.py index 2c42c174b..52a7ae6ef 100644 --- a/examples/core_examples/slabspike_example.py +++ b/examples/core_examples/slabspike_example.py @@ -24,11 +24,7 @@ # Here we define our model. We want to inject two Gaussians and recover with up to three. def gaussian(xs, amplitude, mu, sigma): - return ( - amplitude - / np.sqrt(2 * np.pi * sigma**2) - * np.exp(-0.5 * (xs - mu) ** 2 / sigma**2) - ) + return amplitude / np.sqrt(2 * np.pi * sigma**2) * np.exp(-0.5 * (xs - mu) ** 2 / sigma**2) def triple_gaussian( @@ -84,52 +80,26 @@ def triple_gaussian( # Now we want to set up our priors. priors = bilby.core.prior.PriorDict() # For the slab-and-spike prior, we first need to define the 'slab' part, which is just a regular bilby prior. -amplitude_slab_0 = bilby.core.prior.Uniform( - minimum=-10, maximum=10, name="amplitude_0", latex_label="$A_0$" -) -amplitude_slab_1 = bilby.core.prior.Uniform( - minimum=-10, maximum=10, name="amplitude_1", latex_label="$A_1$" -) -amplitude_slab_2 = bilby.core.prior.Uniform( - minimum=-10, maximum=10, name="amplitude_2", latex_label="$A_2$" -) +amplitude_slab_0 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name="amplitude_0", latex_label="$A_0$") +amplitude_slab_1 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name="amplitude_1", latex_label="$A_1$") +amplitude_slab_2 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name="amplitude_2", latex_label="$A_2$") # We do the following to create the slab-and-spike prior. The spike height is somewhat arbitrary and can # be corrected in post-processing. -priors["amplitude_0"] = bilby.core.prior.SlabSpikePrior( - slab=amplitude_slab_0, spike_location=0, spike_height=0.1 -) -priors["amplitude_1"] = bilby.core.prior.SlabSpikePrior( - slab=amplitude_slab_1, spike_location=0, spike_height=0.1 -) -priors["amplitude_2"] = bilby.core.prior.SlabSpikePrior( - slab=amplitude_slab_2, spike_location=0, spike_height=0.1 -) +priors["amplitude_0"] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_0, spike_location=0, spike_height=0.1) +priors["amplitude_1"] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_1, spike_location=0, spike_height=0.1) +priors["amplitude_2"] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_2, spike_location=0, spike_height=0.1) # Our problem has a degeneracy in the ordering. In general, this problem is somewhat difficult to resolve properly. # See e.g. https://github.com/GregoryAshton/kookaburra/blob/master/src/priors.py#L72 for an implementation. # We resolve this by not letting the priors overlap in this case. -priors["mu_0"] = bilby.core.prior.Uniform( - minimum=-5, maximum=-2, name="mu_0", latex_label=r"$\mu_0$" -) -priors["mu_1"] = bilby.core.prior.Uniform( - minimum=-2, maximum=2, name="mu_1", latex_label=r"$\mu_1$" -) -priors["mu_2"] = bilby.core.prior.Uniform( - minimum=2, maximum=5, name="mu_2", latex_label=r"$\mu_2$" -) -priors["sigma_0"] = bilby.core.prior.LogUniform( - minimum=0.01, maximum=10, name="sigma_0", latex_label=r"$\sigma_0$" -) -priors["sigma_1"] = bilby.core.prior.LogUniform( - minimum=0.01, maximum=10, name="sigma_1", latex_label=r"$\sigma_1$" -) -priors["sigma_2"] = bilby.core.prior.LogUniform( - minimum=0.01, maximum=10, name="sigma_2", latex_label=r"$\sigma_2$" -) +priors["mu_0"] = bilby.core.prior.Uniform(minimum=-5, maximum=-2, name="mu_0", latex_label=r"$\mu_0$") +priors["mu_1"] = bilby.core.prior.Uniform(minimum=-2, maximum=2, name="mu_1", latex_label=r"$\mu_1$") +priors["mu_2"] = bilby.core.prior.Uniform(minimum=2, maximum=5, name="mu_2", latex_label=r"$\mu_2$") +priors["sigma_0"] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name="sigma_0", latex_label=r"$\sigma_0$") +priors["sigma_1"] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name="sigma_1", latex_label=r"$\sigma_1$") +priors["sigma_2"] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name="sigma_2", latex_label=r"$\sigma_2$") # Setting up the likelihood and running the samplers works the same as elsewhere. -likelihood = bilby.core.likelihood.GaussianLikelihood( - x=xs, y=ys, func=triple_gaussian, sigma=sigma -) +likelihood = bilby.core.likelihood.GaussianLikelihood(x=xs, y=ys, func=triple_gaussian, sigma=sigma) result = bilby.run_sampler( likelihood=likelihood, priors=priors, @@ -152,15 +122,9 @@ def triple_gaussian( plt.clf() # Finally, we can check what fraction of amplitude samples are exactly on the spike. -spike_samples_0 = len(np.where(result.posterior["amplitude_0"] == 0.0)[0]) / len( - result.posterior -) -spike_samples_1 = len(np.where(result.posterior["amplitude_1"] == 0.0)[0]) / len( - result.posterior -) -spike_samples_2 = len(np.where(result.posterior["amplitude_2"] == 0.0)[0]) / len( - result.posterior -) +spike_samples_0 = len(np.where(result.posterior["amplitude_0"] == 0.0)[0]) / len(result.posterior) +spike_samples_1 = len(np.where(result.posterior["amplitude_1"] == 0.0)[0]) / len(result.posterior) +spike_samples_2 = len(np.where(result.posterior["amplitude_2"] == 0.0)[0]) / len(result.posterior) print(f"{spike_samples_0 * 100:.2f}% of amplitude_0 samples are exactly 0.0") print(f"{spike_samples_1 * 100:.2f}% of amplitude_1 samples are exactly 0.0") print(f"{spike_samples_2 * 100:.2f}% of amplitude_2 samples are exactly 0.0") diff --git a/examples/gw_examples/data_examples/GW150914.py b/examples/gw_examples/data_examples/GW150914.py index f1aff756e..dfb486314 100755 --- a/examples/gw_examples/data_examples/GW150914.py +++ b/examples/gw_examples/data_examples/GW150914.py @@ -9,6 +9,7 @@ [1] https://gwpy.github.io/docs/stable/timeseries/remote-access.html """ + import bilby from gwpy.timeseries import TimeSeries @@ -36,17 +37,15 @@ # We now use gwpy to obtain analysis and psd data and create the ifo_list ifo_list = bilby.gw.detector.InterferometerList([]) for det in detectors: - logger.info("Downloading analysis data for ifo {}".format(det)) + logger.info(f"Downloading analysis data for ifo {det}") ifo = bilby.gw.detector.get_empty_interferometer(det) data = TimeSeries.fetch_open_data(det, start_time, end_time) ifo.strain_data.set_from_gwpy_timeseries(data) - logger.info("Downloading psd data for ifo {}".format(det)) + logger.info(f"Downloading psd data for ifo {det}") psd_data = TimeSeries.fetch_open_data(det, psd_start_time, psd_end_time) psd_alpha = 2 * roll_off / duration - psd = psd_data.psd( - fftlength=duration, overlap=0, window=("tukey", psd_alpha), method="median" - ) + psd = psd_data.psd(fftlength=duration, overlap=0, window=("tukey", psd_alpha), method="median") ifo.power_spectral_density = bilby.gw.detector.PowerSpectralDensity( frequency_array=psd.frequencies.value, psd_array=psd.value ) @@ -54,7 +53,7 @@ ifo.minimum_frequency = minimum_frequency ifo_list.append(ifo) -logger.info("Saving data plots to {}".format(outdir)) +logger.info(f"Saving data plots to {outdir}") bilby.core.utils.check_directory_exists_and_if_not_mkdir(outdir) ifo_list.plot_data(outdir=outdir, label=label) @@ -66,9 +65,7 @@ priors = bilby.gw.prior.BBHPriorDict(filename="GW150914.prior") # Add the geocent time prior -priors["geocent_time"] = bilby.core.prior.Uniform( - trigger_time - 0.1, trigger_time + 0.1, name="geocent_time" -) +priors["geocent_time"] = bilby.core.prior.Uniform(trigger_time - 0.1, trigger_time + 0.1, name="geocent_time") # In this step we define a `waveform_generator`. This is the object which # creates the frequency-domain strain. In this instance, we are using the diff --git a/examples/gw_examples/data_examples/GW190425.py b/examples/gw_examples/data_examples/GW190425.py index efe9da71a..28d4f5318 100644 --- a/examples/gw_examples/data_examples/GW190425.py +++ b/examples/gw_examples/data_examples/GW190425.py @@ -9,6 +9,7 @@ [1] https://gwpy.github.io/docs/stable/timeseries/remote-access.html """ + import bilby from gwpy.timeseries import TimeSeries @@ -63,17 +64,15 @@ # We now use gwpy to obtain analysis and psd data and create the ifo_list ifo_list = bilby.gw.detector.InterferometerList([]) for det in detectors: - logger.info("Downloading analysis data for ifo {}".format(det)) + logger.info(f"Downloading analysis data for ifo {det}") ifo = bilby.gw.detector.get_empty_interferometer(det) data = TimeSeries.fetch_open_data(det, start_time, end_time) ifo.strain_data.set_from_gwpy_timeseries(data) - logger.info("Downloading psd data for ifo {}".format(det)) + logger.info(f"Downloading psd data for ifo {det}") psd_data = TimeSeries.fetch_open_data(det, psd_start_time, psd_end_time) psd_alpha = 2 * roll_off / duration - psd = psd_data.psd( - fftlength=duration, overlap=0, window=("tukey", psd_alpha), method="median" - ) + psd = psd_data.psd(fftlength=duration, overlap=0, window=("tukey", psd_alpha), method="median") ifo.power_spectral_density = bilby.gw.detector.PowerSpectralDensity( frequency_array=psd.frequencies.value, psd_array=psd.value ) @@ -81,7 +80,7 @@ ifo.minimum_frequency = minimum_frequency ifo_list.append(ifo) -logger.info("Saving data plots to {}".format(outdir)) +logger.info(f"Saving data plots to {outdir}") bilby.core.utils.check_directory_exists_and_if_not_mkdir(outdir) ifo_list.plot_data(outdir=outdir, label=label) @@ -94,9 +93,7 @@ priors["fiducial"] = 0 # Add the geocent time prior -priors["geocent_time"] = bilby.core.prior.Uniform( - trigger_time - 0.1, trigger_time + 0.1, name="geocent_time" -) +priors["geocent_time"] = bilby.core.prior.Uniform(trigger_time - 0.1, trigger_time + 0.1, name="geocent_time") # In this step we define a `waveform_generator`. This is the object which # creates the frequency-domain strain. In this instance, we are using the diff --git a/examples/gw_examples/data_examples/read_gracedb_data.py b/examples/gw_examples/data_examples/read_gracedb_data.py index e92ef8cd8..500f63881 100644 --- a/examples/gw_examples/data_examples/read_gracedb_data.py +++ b/examples/gw_examples/data_examples/read_gracedb_data.py @@ -32,9 +32,7 @@ minimum_frequency = 10 # Hz # Get frame caches -candidate, frame_caches = bilby.gw.utils.get_gracedb( - gracedb, outdir, duration, calibration, detectors, query_types -) +candidate, frame_caches = bilby.gw.utils.get_gracedb(gracedb, outdir, duration, calibration, detectors, query_types) # Set up interferometer objects from the cache files interferometers = bilby.gw.detector.InterferometerList([]) diff --git a/examples/gw_examples/injection_examples/australian_detector.py b/examples/gw_examples/injection_examples/australian_detector.py index d2a390cf0..a62bdd796 100644 --- a/examples/gw_examples/injection_examples/australian_detector.py +++ b/examples/gw_examples/injection_examples/australian_detector.py @@ -31,9 +31,7 @@ # Set up the detector as a four-kilometer detector in Gingin # The location of this detector is not defined in Bilby, so we need to add it AusIFO = bilby.gw.detector.Interferometer( - power_spectral_density=bilby.gw.detector.PowerSpectralDensity( - frequency_array=curve.freq, psd_array=curve.psd - ), + power_spectral_density=bilby.gw.detector.PowerSpectralDensity(frequency_array=curve.freq, psd_array=curve.psd), name="AusIFO", length=4, minimum_frequency=20, @@ -48,9 +46,7 @@ # Set up two other detectors at Hanford and Livingston interferometers = bilby.gw.detector.InterferometerList(["H1", "L1"]) for ifo in interferometers: - ifo.power_spectral_density = bilby.gw.detector.PowerSpectralDensity( - frequency_array=curve.freq, psd_array=curve.psd - ) + ifo.power_spectral_density = bilby.gw.detector.PowerSpectralDensity(frequency_array=curve.freq, psd_array=curve.psd) # append the Australian detector to the list of other detectors interferometers.append(AusIFO) @@ -94,17 +90,11 @@ # inject the signal into the interferometers for ifo in interferometers: - ifo.set_strain_data_from_power_spectral_density( - sampling_frequency=sampling_frequency, duration=duration - ) - ifo.inject_signal( - parameters=injection_parameters, waveform_generator=waveform_generator - ) + ifo.set_strain_data_from_power_spectral_density(sampling_frequency=sampling_frequency, duration=duration) + ifo.inject_signal(parameters=injection_parameters, waveform_generator=waveform_generator) # plot the data for sanity - signal = ifo.get_detector_response( - waveform_generator.frequency_domain_strain(), injection_parameters - ) + signal = ifo.get_detector_response(waveform_generator.frequency_domain_strain(), injection_parameters) ifo.plot_data(signal=signal, outdir=outdir, label=label) # set up priors diff --git a/examples/gw_examples/injection_examples/binary_neutron_star_example.py b/examples/gw_examples/injection_examples/binary_neutron_star_example.py index bc5a8ec40..04eb72ad9 100644 --- a/examples/gw_examples/injection_examples/binary_neutron_star_example.py +++ b/examples/gw_examples/injection_examples/binary_neutron_star_example.py @@ -8,7 +8,6 @@ tidal deformabilities """ - import bilby from bilby.core.utils.random import seed @@ -73,9 +72,7 @@ interferometers.set_strain_data_from_power_spectral_densities( sampling_frequency=sampling_frequency, duration=duration, start_time=start_time ) -interferometers.inject_signal( - parameters=injection_parameters, waveform_generator=waveform_generator -) +interferometers.inject_signal(parameters=injection_parameters, waveform_generator=waveform_generator) # Load the default prior for binary neutron stars. # We're going to sample in chirp_mass, symmetric_mass_ratio, lambda_tilde, and @@ -96,22 +93,12 @@ ]: priors[key] = injection_parameters[key] del priors["mass_ratio"], priors["lambda_1"], priors["lambda_2"] -priors["chirp_mass"] = bilby.core.prior.Gaussian( - 1.215, 0.1, name="chirp_mass", unit="$M_{\\odot}$" -) -priors["symmetric_mass_ratio"] = bilby.core.prior.Uniform( - 0.1, 0.25, name="symmetric_mass_ratio" -) +priors["chirp_mass"] = bilby.core.prior.Gaussian(1.215, 0.1, name="chirp_mass", unit="$M_{\\odot}$") +priors["symmetric_mass_ratio"] = bilby.core.prior.Uniform(0.1, 0.25, name="symmetric_mass_ratio") priors["lambda_tilde"] = bilby.core.prior.Uniform(0, 5000, name="lambda_tilde") -priors["delta_lambda_tilde"] = bilby.core.prior.Uniform( - -500, 1000, name="delta_lambda_tilde" -) -priors["lambda_1"] = bilby.core.prior.Constraint( - name="lambda_1", minimum=0, maximum=10000 -) -priors["lambda_2"] = bilby.core.prior.Constraint( - name="lambda_2", minimum=0, maximum=10000 -) +priors["delta_lambda_tilde"] = bilby.core.prior.Uniform(-500, 1000, name="delta_lambda_tilde") +priors["lambda_1"] = bilby.core.prior.Constraint(name="lambda_1", minimum=0, maximum=10000) +priors["lambda_2"] = bilby.core.prior.Constraint(name="lambda_2", minimum=0, maximum=10000) # Initialise the likelihood by passing in the interferometer data (IFOs) diff --git a/examples/gw_examples/injection_examples/bns_eos_example.py b/examples/gw_examples/injection_examples/bns_eos_example.py index fd24289db..d39b71eae 100644 --- a/examples/gw_examples/injection_examples/bns_eos_example.py +++ b/examples/gw_examples/injection_examples/bns_eos_example.py @@ -7,7 +7,6 @@ WARNING: The code is extremely slow. """ - import bilby from bilby.core.utils.random import seed from bilby.gw.eos import EOSFamily, TabularEOS @@ -90,9 +89,7 @@ interferometers.set_strain_data_from_power_spectral_densities( sampling_frequency=sampling_frequency, duration=duration, start_time=start_time ) -interferometers.inject_signal( - parameters=injection_parameters, waveform_generator=waveform_generator -) +interferometers.inject_signal(parameters=injection_parameters, waveform_generator=waveform_generator) # We're going to sample in chirp_mass, symmetric_mass_ratio, and # specific EoS model parameters. We're using a 4-parameter @@ -114,24 +111,12 @@ priors[key] = injection_parameters[key] for key in ["mass_1", "mass_2", "lambda_1", "lambda_2", "mass_ratio"]: del priors[key] -priors["chirp_mass"] = bilby.core.prior.Gaussian( - 1.215, 0.1, name="chirp_mass", unit="$M_{\\odot}$" -) -priors["symmetric_mass_ratio"] = bilby.core.prior.Uniform( - 0.1, 0.25, name="symmetric_mass_ratio" -) -priors["eos_spectral_gamma_0"] = bilby.core.prior.Uniform( - 0.2, 2.0, name="gamma0", latex_label="$\\gamma_0" -) -priors["eos_spectral_gamma_1"] = bilby.core.prior.Uniform( - -1.6, 1.7, name="gamma1", latex_label="$\\gamma_1" -) -priors["eos_spectral_gamma_2"] = bilby.core.prior.Uniform( - -0.6, 0.6, name="gamma2", latex_label="$\\gamma_2" -) -priors["eos_spectral_gamma_3"] = bilby.core.prior.Uniform( - -0.02, 0.02, name="gamma3", latex_label="$\\gamma_3" -) +priors["chirp_mass"] = bilby.core.prior.Gaussian(1.215, 0.1, name="chirp_mass", unit="$M_{\\odot}$") +priors["symmetric_mass_ratio"] = bilby.core.prior.Uniform(0.1, 0.25, name="symmetric_mass_ratio") +priors["eos_spectral_gamma_0"] = bilby.core.prior.Uniform(0.2, 2.0, name="gamma0", latex_label="$\\gamma_0") +priors["eos_spectral_gamma_1"] = bilby.core.prior.Uniform(-1.6, 1.7, name="gamma1", latex_label="$\\gamma_1") +priors["eos_spectral_gamma_2"] = bilby.core.prior.Uniform(-0.6, 0.6, name="gamma2", latex_label="$\\gamma_2") +priors["eos_spectral_gamma_3"] = bilby.core.prior.Uniform(-0.02, 0.02, name="gamma3", latex_label="$\\gamma_3") # The eos_check prior imposes several hard physical constraints on samples like # enforcing causality and monotinicity of the EoSs. In almost ever conceivable diff --git a/examples/gw_examples/injection_examples/bns_polytrope_eos_example.py b/examples/gw_examples/injection_examples/bns_polytrope_eos_example.py index 845f196ce..5285fdbf3 100644 --- a/examples/gw_examples/injection_examples/bns_polytrope_eos_example.py +++ b/examples/gw_examples/injection_examples/bns_polytrope_eos_example.py @@ -14,7 +14,6 @@ LALSimNeutronStarEOSDynamicPolytrope.c. """ - import bilby from bilby.core.utils.random import seed @@ -80,9 +79,7 @@ interferometers.set_strain_data_from_zero_noise( sampling_frequency=sampling_frequency, duration=duration, start_time=start_time ) -interferometers.inject_signal( - parameters=injection_parameters, waveform_generator=waveform_generator -) +interferometers.inject_signal(parameters=injection_parameters, waveform_generator=waveform_generator) # Load the default prior for binary neutron stars. # We're going to sample in chirp_mass, mass_ratio, and model parameters @@ -96,15 +93,9 @@ # The following are dynamic polytrope model priors # They are required for EOS inference -priors["eos_polytrope_gamma_0"] = bilby.core.prior.Uniform( - 1.0, 5.0, name="Gamma0", latex_label="$\\Gamma_0$" -) -priors["eos_polytrope_gamma_1"] = bilby.core.prior.Uniform( - 1.0, 5.0, name="Gamma1", latex_label="$\\Gamma_1$" -) -priors["eos_polytrope_gamma_2"] = bilby.core.prior.Uniform( - 1.0, 5.0, name="Gamma2", latex_label="$\\Gamma_2$" -) +priors["eos_polytrope_gamma_0"] = bilby.core.prior.Uniform(1.0, 5.0, name="Gamma0", latex_label="$\\Gamma_0$") +priors["eos_polytrope_gamma_1"] = bilby.core.prior.Uniform(1.0, 5.0, name="Gamma1", latex_label="$\\Gamma_1$") +priors["eos_polytrope_gamma_2"] = bilby.core.prior.Uniform(1.0, 5.0, name="Gamma2", latex_label="$\\Gamma_2$") """ One can run this model without the reparameterization using the following priors in place of the scaled pressure priors. The reparameterization approximates diff --git a/examples/gw_examples/injection_examples/bns_spectral_pca_eos_example.py b/examples/gw_examples/injection_examples/bns_spectral_pca_eos_example.py index 9c9e95464..6ca269d2f 100644 --- a/examples/gw_examples/injection_examples/bns_spectral_pca_eos_example.py +++ b/examples/gw_examples/injection_examples/bns_spectral_pca_eos_example.py @@ -10,7 +10,6 @@ in the appendix of https://arxiv.org/pdf/2001.01747.pdf. """ - import bilby from bilby.core.utils.random import seed @@ -78,9 +77,7 @@ interferometers.set_strain_data_from_zero_noise( sampling_frequency=sampling_frequency, duration=duration, start_time=start_time ) -interferometers.inject_signal( - parameters=injection_parameters, waveform_generator=waveform_generator -) +interferometers.inject_signal(parameters=injection_parameters, waveform_generator=waveform_generator) # Load the default prior for binary neutron stars. # We're going to sample in chirp_mass, mass_ratio, and model parameters diff --git a/examples/gw_examples/injection_examples/calibration_example.py b/examples/gw_examples/injection_examples/calibration_example.py index 688b114fa..1a62cd882 100644 --- a/examples/gw_examples/injection_examples/calibration_example.py +++ b/examples/gw_examples/injection_examples/calibration_example.py @@ -68,14 +68,10 @@ # (LIGO-Hanford (H1), LIGO-Livingston (L1), and Virgo (V1)). # These default to their design sensitivity ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"]) -ifos.set_strain_data_from_power_spectral_densities( - sampling_frequency=sampling_frequency, duration=duration -) +ifos.set_strain_data_from_power_spectral_densities(sampling_frequency=sampling_frequency, duration=duration) ifo = ifos[0] -injection_parameters.update( - {f"recalib_{ifo.name}_amplitude_{ii}": 0.1 for ii in range(5)} -) +injection_parameters.update({f"recalib_{ifo.name}_amplitude_{ii}": 0.1 for ii in range(5)}) injection_parameters.update({f"recalib_{ifo.name}_phase_{ii}": 0.01 for ii in range(5)}) ifo.calibration_model = bilby.gw.calibration.CubicSpline( prefix=f"recalib_{ifo.name}_", @@ -94,9 +90,7 @@ n_curves=100, ) -ifos.inject_signal( - parameters=injection_parameters, waveform_generator=waveform_generator -) +ifos.inject_signal(parameters=injection_parameters, waveform_generator=waveform_generator) # Set up prior, which is a dictionary # Here we fix the injected cbc parameters and most of the calibration parameters @@ -107,18 +101,12 @@ if "recalib" in key: priors[key] = injection_parameters[key] for name in ["recalib_H1_amplitude_0", "recalib_H1_amplitude_1"]: - priors[name] = bilby.core.prior.Gaussian( - mu=0, sigma=0.2, name=name, latex_label=f"H1 $A_{name[-1]}$" - ) -priors["recalib_index_L1"] = bilby.core.prior.Categorical( - ncategories=100, latex_label="recalib index L1" -) + priors[name] = bilby.core.prior.Gaussian(mu=0, sigma=0.2, name=name, latex_label=f"H1 $A_{name[-1]}$") +priors["recalib_index_L1"] = bilby.core.prior.Categorical(ncategories=100, latex_label="recalib index L1") # Initialise the likelihood by passing in the interferometer data (IFOs) and # the waveform generator -likelihood = bilby.gw.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=waveform_generator -) +likelihood = bilby.gw.GravitationalWaveTransient(interferometers=ifos, waveform_generator=waveform_generator) # Run sampler. In this case we're going to use the `dynesty` sampler result = bilby.run_sampler( diff --git a/examples/gw_examples/injection_examples/calibration_marginalization_example.py b/examples/gw_examples/injection_examples/calibration_marginalization_example.py index e19ccee7a..cbec51743 100644 --- a/examples/gw_examples/injection_examples/calibration_marginalization_example.py +++ b/examples/gw_examples/injection_examples/calibration_marginalization_example.py @@ -66,12 +66,8 @@ # These default to their design sensitivity ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"]) for ifo in ifos: - injection_parameters.update( - {f"recalib_{ifo.name}_amplitude_{ii}": 0.0 for ii in range(10)} - ) - injection_parameters.update( - {f"recalib_{ifo.name}_phase_{ii}": 0.0 for ii in range(10)} - ) + injection_parameters.update({f"recalib_{ifo.name}_amplitude_{ii}": 0.0 for ii in range(10)}) + injection_parameters.update({f"recalib_{ifo.name}_phase_{ii}": 0.0 for ii in range(10)}) ifo.calibration_model = bilby.gw.calibration.CubicSpline( prefix=f"recalib_{ifo.name}_", minimum_frequency=ifo.minimum_frequency, @@ -81,9 +77,7 @@ ifos.set_strain_data_from_power_spectral_densities( sampling_frequency=sampling_frequency, duration=duration, start_time=start_time ) -ifos.inject_signal( - parameters=injection_parameters, waveform_generator=waveform_generator -) +ifos.inject_signal(parameters=injection_parameters, waveform_generator=waveform_generator) ifos_rew = deepcopy(ifos) # Set up prior, which is a dictionary @@ -126,15 +120,11 @@ # Setting the log likelihood to actually be the log likelihood and not the log likelihood ratio... # This is used the for reweighting -result.posterior["log_likelihood"] = ( - result.posterior["log_likelihood"] + result.log_noise_evidence -) +result.posterior["log_likelihood"] = result.posterior["log_likelihood"] + result.log_noise_evidence # Setting the priors we want on the calibration response curve parameters - as an example. for name in ["recalib_H1_amplitude_1", "recalib_H1_amplitude_4"]: - priors_rew[name] = bilby.prior.Gaussian( - mu=0, sigma=0.03, name=name, latex_label=f"H1 $A_{name[-1]}$" - ) + priors_rew[name] = bilby.prior.Gaussian(mu=0, sigma=0.03, name=name, latex_label=f"H1 $A_{name[-1]}$") # Setting up the calibration marginalized likelihood. # We save the calibration response curve files into the output directory under {ifo.name}_calibration_file.h5 @@ -144,10 +134,7 @@ calibration_marginalization=True, priors=priors_rew, number_of_response_curves=100, - calibration_lookup_table={ - ifos[i].name: f"{outdir}/{ifos[i].name}_calibration_file.h5" - for i in range(len(ifos)) - }, + calibration_lookup_table={ifos[i].name: f"{outdir}/{ifos[i].name}_calibration_file.h5" for i in range(len(ifos))}, ) # Plot the magnitude of the curves to be used in the marginalization @@ -170,9 +157,7 @@ ) # Plot distance posterior with and without the calibration -for res, label in zip( - [result, result_rew], ["No calibration uncertainty", "Calibration uncertainty"] -): +for res, label in zip([result, result_rew], ["No calibration uncertainty", "Calibration uncertainty"]): plt.hist( res.posterior["luminosity_distance"], label=label, diff --git a/examples/gw_examples/injection_examples/change_sampled_parameters.py b/examples/gw_examples/injection_examples/change_sampled_parameters.py index c5ad7f618..a9d230707 100644 --- a/examples/gw_examples/injection_examples/change_sampled_parameters.py +++ b/examples/gw_examples/injection_examples/change_sampled_parameters.py @@ -7,6 +7,7 @@ mass ratio and redshift. The cosmology is according to the Planck 2015 data release. """ + import bilby import numpy as np from bilby.core.utils.random import seed @@ -60,9 +61,7 @@ duration=duration, start_time=injection_parameters["geocent_time"] - 2, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) # Set up prior # Note it is possible to sample in different parameters to those that were @@ -75,9 +74,7 @@ ) del priors["luminosity_distance"] -priors["redshift"] = bilby.prior.Uniform( - name="redshift", latex_label="$z$", minimum=0, maximum=0.5 -) +priors["redshift"] = bilby.prior.Uniform(name="redshift", latex_label="$z$", minimum=0, maximum=0.5) # These parameters will not be sampled for key in [ "a_1", @@ -98,9 +95,7 @@ print(priors) # Initialise GravitationalWaveTransient -likelihood = bilby.gw.likelihood.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=waveform_generator -) +likelihood = bilby.gw.likelihood.GravitationalWaveTransient(interferometers=ifos, waveform_generator=waveform_generator) # Run sampler # Note we've added a post-processing conversion function, this will generate diff --git a/examples/gw_examples/injection_examples/create_your_own_source_model.py b/examples/gw_examples/injection_examples/create_your_own_source_model.py index fab1a650c..97a06f1da 100644 --- a/examples/gw_examples/injection_examples/create_your_own_source_model.py +++ b/examples/gw_examples/injection_examples/create_your_own_source_model.py @@ -2,6 +2,7 @@ """ A script to demonstrate how to use your own source model """ + import bilby import numpy as np from bilby.core.utils.random import seed @@ -57,9 +58,7 @@ def gaussian(frequency_array, amplitude, f0, tau, phi0): # We now define some parameters that we will inject -injection_parameters = dict( - amplitude=1e-23, f0=100, tau=1, phi0=0, geocent_time=0, ra=0, dec=0, psi=0 -) +injection_parameters = dict(amplitude=1e-23, f0=100, tau=1, phi0=0, geocent_time=0, ra=0, dec=0, psi=0) # Now we pass our source function to the WaveformGenerator waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( @@ -84,14 +83,10 @@ def gaussian(frequency_array, amplitude, f0, tau, phi0): # Here we define the priors for the search. We use the injection parameters # except for the amplitude, f0, and geocent_time prior = injection_parameters.copy() -prior["amplitude"] = bilby.core.prior.LogUniform( - minimum=1e-25, maximum=1e-21, latex_label="$\\mathcal{A}$" -) +prior["amplitude"] = bilby.core.prior.LogUniform(minimum=1e-25, maximum=1e-21, latex_label="$\\mathcal{A}$") prior["f0"] = bilby.core.prior.Uniform(90, 110, latex_label="$f_{0}$") -likelihood = bilby.gw.likelihood.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=waveform_generator -) +likelihood = bilby.gw.likelihood.GravitationalWaveTransient(interferometers=ifos, waveform_generator=waveform_generator) result = bilby.core.sampler.run_sampler( likelihood, diff --git a/examples/gw_examples/injection_examples/create_your_own_time_domain_source_model.py b/examples/gw_examples/injection_examples/create_your_own_time_domain_source_model.py index 8e04c154e..c25d6523e 100644 --- a/examples/gw_examples/injection_examples/create_your_own_time_domain_source_model.py +++ b/examples/gw_examples/injection_examples/create_your_own_time_domain_source_model.py @@ -93,9 +93,7 @@ def time_domain_damped_sinusoid(time, amplitude, damping_time, frequency, phase, duration=duration, start_time=injection_parameters["geocent_time"] - 0.5, ) -ifos.inject_signal( - waveform_generator=waveform, parameters=injection_parameters, raise_error=False -) +ifos.inject_signal(waveform_generator=waveform, parameters=injection_parameters, raise_error=False) # create the priors prior = injection_parameters.copy() diff --git a/examples/gw_examples/injection_examples/custom_proposal_example.py b/examples/gw_examples/injection_examples/custom_proposal_example.py index 07fdaf527..d969523e6 100644 --- a/examples/gw_examples/injection_examples/custom_proposal_example.py +++ b/examples/gw_examples/injection_examples/custom_proposal_example.py @@ -7,6 +7,8 @@ Due to how cpnest creates parallel processes, the multiprocessing start method needs to be set on some operating systems. """ +# ruff: noqa: E402 + import multiprocessing multiprocessing.set_start_method("fork") # noqa @@ -62,15 +64,11 @@ duration=duration, start_time=injection_parameters["geocent_time"] - 2, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) priors = bilby.gw.prior.BBHPriorDict() for key in ["a_1", "a_2", "tilt_1", "tilt_2", "phi_12", "phi_jl", "geocent_time"]: priors[key] = injection_parameters[key] -likelihood = bilby.gw.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=waveform_generator -) +likelihood = bilby.gw.GravitationalWaveTransient(interferometers=ifos, waveform_generator=waveform_generator) # Definition of the custom jump proposals. Define a JumpProposalCycle. The first argument is a list # of all allowed jump proposals. The second argument is a list of weights for the respective jump diff --git a/examples/gw_examples/injection_examples/eccentric_inspiral.py b/examples/gw_examples/injection_examples/eccentric_inspiral.py index fdb9dd043..98f624fab 100644 --- a/examples/gw_examples/injection_examples/eccentric_inspiral.py +++ b/examples/gw_examples/injection_examples/eccentric_inspiral.py @@ -36,9 +36,7 @@ dec=5.73, ) -waveform_arguments = dict( - waveform_approximant="EccentricFD", reference_frequency=10.0, minimum_frequency=10.0 -) +waveform_arguments = dict(waveform_approximant="EccentricFD", reference_frequency=10.0, minimum_frequency=10.0) # Create the waveform_generator using the GWSignal interface, this allows us # to specify what physics is included in the model @@ -67,9 +65,7 @@ duration=duration, start_time=injection_parameters["geocent_time"] + 2 - duration, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) # Now we set up the priors on each of the binary parameters. priors = bilby.core.prior.PriorDict() @@ -79,23 +75,13 @@ priors["mass_2"] = bilby.core.prior.Uniform( name="mass_2", minimum=5, maximum=60, unit="$M_{\\odot}$", latex_label="$m_2$" ) -priors["eccentricity"] = bilby.core.prior.LogUniform( - name="eccentricity", latex_label="$e$", minimum=1e-4, maximum=0.4 -) -priors["luminosity_distance"] = bilby.gw.prior.UniformSourceFrame( - name="luminosity_distance", minimum=1e2, maximum=2e3 -) +priors["eccentricity"] = bilby.core.prior.LogUniform(name="eccentricity", latex_label="$e$", minimum=1e-4, maximum=0.4) +priors["luminosity_distance"] = bilby.gw.prior.UniformSourceFrame(name="luminosity_distance", minimum=1e2, maximum=2e3) priors["dec"] = bilby.core.prior.Cosine(name="dec") -priors["ra"] = bilby.core.prior.Uniform( - name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic" -) +priors["ra"] = bilby.core.prior.Uniform(name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic") priors["theta_jn"] = bilby.core.prior.Sine(name="theta_jn") -priors["psi"] = bilby.core.prior.Uniform( - name="psi", minimum=0, maximum=np.pi, boundary="periodic" -) -priors["phase"] = bilby.core.prior.Uniform( - name="phase", minimum=0, maximum=2 * np.pi, boundary="periodic" -) +priors["psi"] = bilby.core.prior.Uniform(name="psi", minimum=0, maximum=np.pi, boundary="periodic") +priors["phase"] = bilby.core.prior.Uniform(name="phase", minimum=0, maximum=2 * np.pi, boundary="periodic") priors["geocent_time"] = bilby.core.prior.Uniform( injection_parameters["geocent_time"] - 0.1, injection_parameters["geocent_time"] + 0.1, diff --git a/examples/gw_examples/injection_examples/fake_sampler_example.py b/examples/gw_examples/injection_examples/fake_sampler_example.py index 6563a657d..9cd4d7527 100755 --- a/examples/gw_examples/injection_examples/fake_sampler_example.py +++ b/examples/gw_examples/injection_examples/fake_sampler_example.py @@ -119,9 +119,7 @@ def main(): ) # update the waveform generator to use our higher-order mode waveform - likelihood.waveform_generator.waveform_arguments[ - "waveform_approximant" - ] = "IMRPhenomXHM" + likelihood.waveform_generator.waveform_arguments["waveform_approximant"] = "IMRPhenomXHM" # call the FakeSampler to compute the new likelihoods new_result = bilby.run_sampler( diff --git a/examples/gw_examples/injection_examples/fast_tutorial.py b/examples/gw_examples/injection_examples/fast_tutorial.py index 2fce438c7..45cdfee37 100644 --- a/examples/gw_examples/injection_examples/fast_tutorial.py +++ b/examples/gw_examples/injection_examples/fast_tutorial.py @@ -71,9 +71,7 @@ duration=duration, start_time=injection_parameters["geocent_time"] - 2, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) # Set up a PriorDict, which inherits from dict. # By default we will sample all terms in the signal models. However, this will @@ -106,9 +104,7 @@ # Initialise the likelihood by passing in the interferometer data (ifos) and # the waveform generator -likelihood = bilby.gw.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=waveform_generator -) +likelihood = bilby.gw.GravitationalWaveTransient(interferometers=ifos, waveform_generator=waveform_generator) # Run sampler. In this case we're going to use the `dynesty` sampler result = bilby.run_sampler( diff --git a/examples/gw_examples/injection_examples/how_to_specify_the_prior.py b/examples/gw_examples/injection_examples/how_to_specify_the_prior.py index 6015aaf00..054f99e57 100644 --- a/examples/gw_examples/injection_examples/how_to_specify_the_prior.py +++ b/examples/gw_examples/injection_examples/how_to_specify_the_prior.py @@ -54,9 +54,7 @@ duration=duration, start_time=injection_parameters["geocent_time"] - 2, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) # Set up prior # This loads in a predefined set of priors for BBHs. @@ -78,12 +76,8 @@ # We can make uniform distributions. del priors["chirp_mass"], priors["mass_ratio"] # We can make uniform distributions. -priors["mass_1"] = bilby.core.prior.Uniform( - name="mass_1", minimum=20, maximum=40, unit="$M_{\\odot}$" -) -priors["mass_2"] = bilby.core.prior.Uniform( - name="mass_2", minimum=20, maximum=40, unit="$M_{\\odot}$" -) +priors["mass_1"] = bilby.core.prior.Uniform(name="mass_1", minimum=20, maximum=40, unit="$M_{\\odot}$") +priors["mass_2"] = bilby.core.prior.Uniform(name="mass_2", minimum=20, maximum=40, unit="$M_{\\odot}$") # We can make a power-law distribution, p(x) ~ x^{alpha} # Note: alpha=0 is a uniform distribution, alpha=-1 is uniform-in-log priors["a_1"] = bilby.core.prior.PowerLaw(name="a_1", alpha=-1, minimum=1e-2, maximum=1) @@ -91,9 +85,7 @@ # Note: this doesn't have to be properly normalised. a_2 = np.linspace(0, 1, 1001) p_a_2 = a_2**4 -priors["a_2"] = bilby.core.prior.Interped( - name="a_2", xx=a_2, yy=p_a_2, minimum=0, maximum=0.5 -) +priors["a_2"] = bilby.core.prior.Interped(name="a_2", xx=a_2, yy=p_a_2, minimum=0, maximum=0.5) # Additionally, we have Gaussian, TruncatedGaussian, Sine and Cosine. # It's also possible to load an interpolate a prior from a file. # Finally, if you don't specify any necessary parameters it will be filled in @@ -101,9 +93,7 @@ # Enjoy. # Initialise GravitationalWaveTransient -likelihood = bilby.gw.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=waveform_generator -) +likelihood = bilby.gw.GravitationalWaveTransient(interferometers=ifos, waveform_generator=waveform_generator) # Run sampler result = bilby.run_sampler( diff --git a/examples/gw_examples/injection_examples/marginalized_likelihood.py b/examples/gw_examples/injection_examples/marginalized_likelihood.py index 21494c97a..8b1d3bdf0 100644 --- a/examples/gw_examples/injection_examples/marginalized_likelihood.py +++ b/examples/gw_examples/injection_examples/marginalized_likelihood.py @@ -6,6 +6,7 @@ We also demonstrate how the posterior distribution for the marginalised parameter can be recovered in post-processing. """ + import bilby from bilby.core.utils.random import seed @@ -52,9 +53,7 @@ duration=duration, start_time=injection_parameters["geocent_time"] - 2, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) # Set up prior priors = bilby.gw.prior.BBHPriorDict() diff --git a/examples/gw_examples/injection_examples/multiband_example.py b/examples/gw_examples/injection_examples/multiband_example.py index c3efdb1fd..bbc6c2e6c 100644 --- a/examples/gw_examples/injection_examples/multiband_example.py +++ b/examples/gw_examples/injection_examples/multiband_example.py @@ -38,9 +38,7 @@ duration=duration, sampling_frequency=sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=dict( - waveform_approximant=approximant, reference_frequency=reference_frequency - ), + waveform_arguments=dict(waveform_approximant=approximant, reference_frequency=reference_frequency), parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, ) ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"]) @@ -49,9 +47,7 @@ duration=duration, start_time=injection_parameters["geocent_time"] - duration + 2, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) for ifo in ifos: ifo.minimum_frequency = minimum_frequency @@ -60,9 +56,7 @@ duration=duration, sampling_frequency=sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - waveform_approximant=approximant, reference_frequency=reference_frequency - ), + waveform_arguments=dict(waveform_approximant=approximant, reference_frequency=reference_frequency), parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, ) @@ -71,9 +65,7 @@ priors["chi_1"] = 0 priors["chi_2"] = 0 del priors["lambda_1"], priors["lambda_2"] -priors["chirp_mass"] = bilby.core.prior.Uniform( - name="chirp_mass", minimum=1.15, maximum=1.25 -) +priors["chirp_mass"] = bilby.core.prior.Uniform(name="chirp_mass", minimum=1.15, maximum=1.25) priors["geocent_time"] = bilby.core.prior.Uniform( injection_parameters["geocent_time"] - 0.1, injection_parameters["geocent_time"] + 0.1, diff --git a/examples/gw_examples/injection_examples/non_tensor.py b/examples/gw_examples/injection_examples/non_tensor.py index 9721919f3..09ffc2ae0 100644 --- a/examples/gw_examples/injection_examples/non_tensor.py +++ b/examples/gw_examples/injection_examples/non_tensor.py @@ -6,6 +6,7 @@ We adapt the sine-Gaussian burst model to include vector polarizations with an unknown contribution from the vector modes. """ + import bilby import numpy as np from bilby.core.utils.random import seed @@ -31,9 +32,7 @@ def vector_tensor_sine_gaussian(frequency_array, hrss, Q, frequency, epsilon): epsilon: float Relative size of the vector modes compared to the tensor modes. """ - waveform_polarizations = bilby.gw.source.sinegaussian( - frequency_array, hrss, Q, frequency - ) + waveform_polarizations = bilby.gw.source.sinegaussian(frequency_array, hrss, Q, frequency) waveform_polarizations["x"] = epsilon * waveform_polarizations["plus"] waveform_polarizations["y"] = epsilon * waveform_polarizations["cross"] diff --git a/examples/gw_examples/injection_examples/plot_skymap.py b/examples/gw_examples/injection_examples/plot_skymap.py index d98d0dc1e..784050569 100644 --- a/examples/gw_examples/injection_examples/plot_skymap.py +++ b/examples/gw_examples/injection_examples/plot_skymap.py @@ -3,6 +3,7 @@ Example script which produces posterior samples of ra and dec and generates a skymap """ + import bilby from bilby.core.utils.random import seed @@ -47,9 +48,7 @@ duration=duration, start_time=injection_parameters["geocent_time"] - 2, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) priors = bilby.gw.prior.BBHPriorDict() for key in [ @@ -69,9 +68,7 @@ priors[key] = injection_parameters[key] del priors["chirp_mass"], priors["mass_ratio"] -likelihood = bilby.gw.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=waveform_generator -) +likelihood = bilby.gw.GravitationalWaveTransient(interferometers=ifos, waveform_generator=waveform_generator) result = bilby.run_sampler( likelihood=likelihood, diff --git a/examples/gw_examples/injection_examples/plot_time_domain_data.py b/examples/gw_examples/injection_examples/plot_time_domain_data.py index d1c29247a..75b542eae 100644 --- a/examples/gw_examples/injection_examples/plot_time_domain_data.py +++ b/examples/gw_examples/injection_examples/plot_time_domain_data.py @@ -3,6 +3,7 @@ This example demonstrates how to simulate some data, add an injected signal and plot the data. """ + from bilby.core.utils.random import seed from bilby.gw.detector import get_empty_interferometer from bilby.gw.source import lal_binary_black_hole @@ -34,9 +35,7 @@ dec=-1.2108, ) -waveform_arguments = dict( - waveform_approximant="IMRPhenomTPHM", reference_frequency=50.0 -) +waveform_arguments = dict(waveform_approximant="IMRPhenomTPHM", reference_frequency=50.0) waveform_generator = WaveformGenerator( duration=duration, @@ -48,9 +47,7 @@ hf_signal = waveform_generator.frequency_domain_strain(injection_parameters) ifo = get_empty_interferometer("H1") -ifo.set_strain_data_from_power_spectral_density( - duration=duration, sampling_frequency=sampling_frequency -) +ifo.set_strain_data_from_power_spectral_density(duration=duration, sampling_frequency=sampling_frequency) ifo.inject_signal(injection_polarizations=hf_signal, parameters=injection_parameters) t0 = injection_parameters["geocent_time"] diff --git a/examples/gw_examples/injection_examples/relative_binning.py b/examples/gw_examples/injection_examples/relative_binning.py index c64c62051..6215bd313 100644 --- a/examples/gw_examples/injection_examples/relative_binning.py +++ b/examples/gw_examples/injection_examples/relative_binning.py @@ -7,6 +7,7 @@ and distance using a uniform in comoving volume prior on luminosity distance between luminosity distances of 100Mpc and 5Gpc, the cosmology is Planck15. """ + from copy import deepcopy import bilby @@ -78,9 +79,7 @@ duration=duration, start_time=injection_parameters["geocent_time"] - 2, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) # Set up a PriorDict, which inherits from dict. # By default we will sample all terms in the signal models. However, this will @@ -117,9 +116,7 @@ fiducial_parameters = injection_parameters.copy() m1 = fiducial_parameters.pop("mass_1") m2 = fiducial_parameters.pop("mass_2") -fiducial_parameters["chirp_mass"] = bilby.gw.conversion.component_masses_to_chirp_mass( - m1, m2 -) +fiducial_parameters["chirp_mass"] = bilby.gw.conversion.component_masses_to_chirp_mass(m1, m2) fiducial_parameters["mass_ratio"] = m2 / m1 # Initialise the likelihood by passing in the interferometer data (ifos) and @@ -161,14 +158,9 @@ likelihood.distance_marginalization = False weights = list() for parameters in tqdm(result.posterior.to_dict(orient="records")): - weights.append( - alt_likelihood.log_likelihood_ratio(parameters) - - likelihood.log_likelihood_ratio(parameters) - ) + weights.append(alt_likelihood.log_likelihood_ratio(parameters) - likelihood.log_likelihood_ratio(parameters)) weights = np.exp(weights) -print( - f"Reweighting efficiency is {np.mean(weights)**2 / np.mean(weights**2) * 100:.2f}%" -) +print(f"Reweighting efficiency is {np.mean(weights) ** 2 / np.mean(weights**2) * 100:.2f}%") print(f"Binned vs unbinned log Bayes factor {np.log(np.mean(weights)):.2f}") # Generate result object with the posterior for the regular likelihood using diff --git a/examples/gw_examples/injection_examples/reproduce_mpa1_eos.py b/examples/gw_examples/injection_examples/reproduce_mpa1_eos.py index 6d2ee4fc2..166cc7b48 100644 --- a/examples/gw_examples/injection_examples/reproduce_mpa1_eos.py +++ b/examples/gw_examples/injection_examples/reproduce_mpa1_eos.py @@ -2,6 +2,7 @@ """ A script to demonstrate how to create and plot EoS's with Bilby """ + from bilby.gw import eos # In this script we're going to use Bilby to plot the MPA1 EoS from tabulated data @@ -14,20 +15,14 @@ MPA1_xmax = 6.63 # Dimensionless ending pressure # Create the spectral decomposition EoS class -MPA1_spectral = eos.SpectralDecompositionEOS( - MPA1_gammas, p0=MPA1_p0, e0=MPA1_e0_c2, xmax=MPA1_xmax, npts=100 -) +MPA1_spectral = eos.SpectralDecompositionEOS(MPA1_gammas, p0=MPA1_p0, e0=MPA1_e0_c2, xmax=MPA1_xmax, npts=100) # And create another from tabulated data MPA1_tabulated = eos.TabularEOS("MPA1") # Now let's plot them # To do so, we specify a representation and plot ranges. -MPA1_spectral_plot = MPA1_spectral.plot( - "pressure-energy_density", xlim=[1e22, 1e36], ylim=[1e9, 1e36] -) -MPA1_tabular_plot = MPA1_tabulated.plot( - "pressure-energy_density", xlim=[1e22, 1e36], ylim=[1e9, 1e36] -) +MPA1_spectral_plot = MPA1_spectral.plot("pressure-energy_density", xlim=[1e22, 1e36], ylim=[1e9, 1e36]) +MPA1_tabular_plot = MPA1_tabulated.plot("pressure-energy_density", xlim=[1e22, 1e36], ylim=[1e9, 1e36]) MPA1_spectral_plot.savefig("spectral_mpa1.pdf") MPA1_tabular_plot.savefig("tabular_mpa1.pdf") diff --git a/examples/gw_examples/injection_examples/roq_example.py b/examples/gw_examples/injection_examples/roq_example.py index 78a04be0c..72e310206 100644 --- a/examples/gw_examples/injection_examples/roq_example.py +++ b/examples/gw_examples/injection_examples/roq_example.py @@ -65,9 +65,7 @@ dec=-1.2108, ) -waveform_arguments = dict( - waveform_approximant="IMRPhenomPv2", reference_frequency=20.0 * scale_factor -) +waveform_arguments = dict(waveform_approximant="IMRPhenomPv2", reference_frequency=20.0 * scale_factor) waveform_generator = bilby.gw.WaveformGenerator( duration=duration, @@ -83,9 +81,7 @@ duration=duration, start_time=injection_parameters["geocent_time"] - 2 / scale_factor, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) for ifo in ifos: ifo.minimum_frequency = 20 * scale_factor @@ -130,9 +126,7 @@ ) # The roq parameters typically store the mass ratio bounds as m1/m2 not m2/m1 as in the # Bilby convention. -priors["mass_ratio"] = bilby.core.prior.Uniform( - 1 / params["qmax"], 1, name="mass_ratio" -) +priors["mass_ratio"] = bilby.core.prior.Uniform(1 / params["qmax"], 1, name="mass_ratio") priors["geocent_time"] = bilby.core.prior.Uniform( injection_parameters["geocent_time"] - 0.1, injection_parameters["geocent_time"] + 0.1, diff --git a/examples/gw_examples/injection_examples/sine_gaussian_example.py b/examples/gw_examples/injection_examples/sine_gaussian_example.py index a6aa425ce..545d0863d 100644 --- a/examples/gw_examples/injection_examples/sine_gaussian_example.py +++ b/examples/gw_examples/injection_examples/sine_gaussian_example.py @@ -3,6 +3,7 @@ Tutorial to demonstrate running parameter estimation on a sine gaussian injected signal. """ + import bilby from bilby.core.utils.random import seed @@ -66,9 +67,7 @@ # Initialise the likelihood by passing in the interferometer data (IFOs) and # the waveoform generator -likelihood = bilby.gw.likelihood.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=waveform_generator -) +likelihood = bilby.gw.likelihood.GravitationalWaveTransient(interferometers=ifos, waveform_generator=waveform_generator) # Run sampler. In this case we're going to use the `dynesty` sampler result = bilby.core.sampler.run_sampler( diff --git a/examples/gw_examples/injection_examples/standard_15d_cbc_tutorial.py b/examples/gw_examples/injection_examples/standard_15d_cbc_tutorial.py index e34a06f99..b2e8fdf4b 100644 --- a/examples/gw_examples/injection_examples/standard_15d_cbc_tutorial.py +++ b/examples/gw_examples/injection_examples/standard_15d_cbc_tutorial.py @@ -6,6 +6,7 @@ This will take many hours to run. """ + import bilby import numpy as np from bilby.core.utils.random import seed @@ -73,9 +74,7 @@ start_time=injection_parameters["geocent_time"] - 2, ) -ifos.inject_signal( - waveform_generator=waveform_generator, parameters=injection_parameters -) +ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters) # For this analysis, we implement the standard BBH priors defined, except for # the definition of the time prior, which is defined as uniform about the diff --git a/examples/gw_examples/supernova_example/supernova_example.py b/examples/gw_examples/supernova_example/supernova_example.py index 910979ce6..817d4352d 100644 --- a/examples/gw_examples/supernova_example/supernova_example.py +++ b/examples/gw_examples/supernova_example/supernova_example.py @@ -10,6 +10,7 @@ conda install -c conda-forge pymultinest """ + import bilby import numpy as np from bilby.core.utils.random import seed @@ -85,17 +86,13 @@ priors = bilby.core.prior.PriorDict() for key in ["psi", "geocent_time"]: priors[key] = injection_parameters[key] -priors["luminosity_distance"] = bilby.core.prior.Uniform( - 2, 20, "luminosity_distance", unit="$kpc$" -) +priors["luminosity_distance"] = bilby.core.prior.Uniform(2, 20, "luminosity_distance", unit="$kpc$") priors["pc_coeff1"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff1") priors["pc_coeff2"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff2") priors["pc_coeff3"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff3") priors["pc_coeff4"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff4") priors["pc_coeff5"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff5") -priors["ra"] = bilby.core.prior.Uniform( - minimum=0, maximum=2 * np.pi, name="ra", boundary="periodic" -) +priors["ra"] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi, name="ra", boundary="periodic") priors["dec"] = bilby.core.prior.Sine(name="dec") priors["geocent_time"] = bilby.core.prior.Uniform( injection_parameters["geocent_time"] - 1, diff --git a/examples/tutorials/compare_samplers.ipynb b/examples/tutorials/compare_samplers.ipynb index 18e9946d4..e50275f88 100644 --- a/examples/tutorials/compare_samplers.ipynb +++ b/examples/tutorials/compare_samplers.ipynb @@ -19,9 +19,10 @@ "metadata": {}, "outputs": [], "source": [ - "import bilby\n", - "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import bilby\n", "from bilby.core.utils import random\n", "\n", "# Sets seed of bilby's generator \"rng\" to \"123\" to ensure reproducibility\n", @@ -178,7 +179,7 @@ " resume=False,\n", " clean=True,\n", " verbose=False,\n", - " **samplers[sampler]\n", + " **samplers[sampler],\n", " )\n", " results[sampler] = result" ] @@ -217,9 +218,7 @@ "metadata": {}, "outputs": [], "source": [ - "_ = bilby.core.result.plot_multiple(\n", - " list(results.values()), labels=list(results.keys()), save=False\n", - ")\n", + "_ = bilby.core.result.plot_multiple(list(results.values()), labels=list(results.keys()), save=False)\n", "plt.show()\n", "plt.close()" ] @@ -232,9 +231,7 @@ "source": [ "fig, ax = plt.subplots(figsize=(12, 8))\n", "ax.plot(time, data, \"x\", label=\"Data\", color=\"r\")\n", - "ax.plot(\n", - " time, model(time, **injection_parameters), linestyle=\"--\", color=\"k\", label=\"Truth\"\n", - ")\n", + "ax.plot(time, model(time, **injection_parameters), linestyle=\"--\", color=\"k\", label=\"Truth\")\n", "\n", "for jj, sampler in enumerate(samplers):\n", " result = results[sampler]\n", diff --git a/examples/tutorials/conditional_priors.ipynb b/examples/tutorials/conditional_priors.ipynb index 0838e5da9..630a2232b 100644 --- a/examples/tutorials/conditional_priors.ipynb +++ b/examples/tutorials/conditional_priors.ipynb @@ -21,13 +21,17 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", - "from bilby.core.prior import (\n", - " Prior, PriorDict, ConditionalPriorDict,\n", - " Uniform, ConditionalUniform, Constraint, \n", - ")\n", "from corner import corner\n", "from scipy.stats import semicircular\n", "\n", + "from bilby.core.prior import (\n", + " ConditionalPriorDict,\n", + " ConditionalUniform,\n", + " Constraint,\n", + " Prior,\n", + " PriorDict,\n", + " Uniform,\n", + ")\n", "\n", "%matplotlib inline" ] @@ -63,9 +67,8 @@ "outputs": [], "source": [ "class SemiCircular(Prior):\n", - "\n", " def __init__(self, radius=1, center=0, name=None, latex_label=None, unit=None, boundary=None):\n", - " super(SemiCircular, self).__init__(\n", + " super().__init__(\n", " minimum=center - radius,\n", " maximum=center + radius,\n", " name=name,\n", @@ -91,7 +94,7 @@ "\n", "\n", "def conditional_func_y(reference_parameters, x):\n", - " condition = np.sqrt(reference_parameters[\"maximum\"]-x**2)\n", + " condition = np.sqrt(reference_parameters[\"maximum\"] - x**2)\n", " return dict(minimum=-condition, maximum=condition)" ] }, @@ -124,31 +127,28 @@ "\n", "def convert_to_radial(parameters):\n", " p = parameters.copy()\n", - " p['r'] = p['x']**2 + p['y']**2\n", + " p[\"r\"] = p[\"x\"] ** 2 + p[\"y\"] ** 2\n", " return p\n", "\n", + "\n", "def sample_circle_with_constraint():\n", " d = PriorDict(\n", - " dictionary=dict(\n", - " x=Uniform(-1, 1),\n", - " y=Uniform(-1, 1),\n", - " r=Constraint(0, 1),\n", - " ),\n", - " conversion_function=convert_to_radial\n", - " )\n", + " dictionary=dict(\n", + " x=Uniform(-1, 1),\n", + " y=Uniform(-1, 1),\n", + " r=Constraint(0, 1),\n", + " ),\n", + " conversion_function=convert_to_radial,\n", + " )\n", " return pd.DataFrame(d.sample(N))\n", "\n", "\n", "def sample_circle_with_conditional():\n", " d = ConditionalPriorDict(\n", - " dictionary=dict(\n", - " x=SemiCircular(),\n", - " y=ConditionalUniform(\n", - " condition_func=conditional_func_y, \n", - " minimum=-1, maximum=1\n", - " )\n", - " )\n", + " dictionary=dict(\n", + " x=SemiCircular(), y=ConditionalUniform(condition_func=conditional_func_y, minimum=-1, maximum=1)\n", " )\n", + " )\n", " return pd.DataFrame(d.sample(N))\n", "\n", "\n", @@ -179,22 +179,25 @@ "source": [ "class BoundedUniform(ConditionalUniform):\n", " \"\"\"Conditional Uniform prior where prior sample < previous prior sample\n", - " \n", + "\n", " This is ensured by fixing the maximum bound to be the previous prior sample value.\n", " \"\"\"\n", - " def __init__(self, idx: int, minimum, maximum, name=None, latex_label=None,\n", - " unit=None, boundary=None):\n", - " super(BoundedUniform, self).__init__(\n", - " minimum=minimum, maximum=maximum, name=name, \n", - " latex_label=latex_label, unit=unit,\n", - " boundary=boundary, condition_func=self.bounds_condition\n", + "\n", + " def __init__(self, idx: int, minimum, maximum, name=None, latex_label=None, unit=None, boundary=None):\n", + " super().__init__(\n", + " minimum=minimum,\n", + " maximum=maximum,\n", + " name=name,\n", + " latex_label=latex_label,\n", + " unit=unit,\n", + " boundary=boundary,\n", + " condition_func=self.bounds_condition,\n", " )\n", " self.idx = idx\n", " self.previous_name = f\"{name[:-1]}{self.idx - 1}\"\n", - " self._required_variables = [self.previous_name] \n", + " self._required_variables = [self.previous_name]\n", " # this is used in prior.sample(... **required_variables)\n", "\n", - "\n", " def bounds_condition(self, reference_params, **required_variables):\n", " previous_sample = required_variables[self.previous_name]\n", " return dict(maximum=previous_sample)\n", @@ -203,11 +206,11 @@ "def make_uniform_conditonal_priordict(n_priors=3):\n", " priors = ConditionalPriorDict()\n", " for i in range(n_priors):\n", - " if i==0:\n", + " if i == 0:\n", " priors[f\"uni{i}\"] = Uniform(minimum=0, maximum=1, name=f\"uni{i}\")\n", " else:\n", " priors[f\"uni{i}\"] = BoundedUniform(idx=i, minimum=0, maximum=1, name=f\"uni{i}\")\n", - " return priors\n" + " return priors" ] }, { diff --git a/examples/tutorials/fitting_with_x_and_y_errors.ipynb b/examples/tutorials/fitting_with_x_and_y_errors.ipynb index 38818b666..1c044b86a 100644 --- a/examples/tutorials/fitting_with_x_and_y_errors.ipynb +++ b/examples/tutorials/fitting_with_x_and_y_errors.ipynb @@ -15,12 +15,13 @@ "metadata": {}, "outputs": [], "source": [ - "import bilby\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", + "\n", + "import bilby\n", "from bilby.core.utils import random\n", "\n", - "#sets seed of bilby's generator \"rng\" to \"123\"\n", + "# sets seed of bilby's generator \"rng\" to \"123\"\n", "random.seed(123)\n", "\n", "%matplotlib inline" @@ -94,11 +95,11 @@ "outputs": [], "source": [ "# setting up bilby priors\n", - "priors = dict(\n", - " m=bilby.core.prior.Uniform(0, 30, \"m\"), c=bilby.core.prior.Uniform(0, 30, \"c\")\n", - ")\n", + "priors = dict(m=bilby.core.prior.Uniform(0, 30, \"m\"), c=bilby.core.prior.Uniform(0, 30, \"c\"))\n", "\n", - "sampler_kwargs = dict(priors=priors, sampler=\"bilby_mcmc\", nsamples=1000, printdt=5, outdir=\"outdir\", verbose=False, clean=True)" + "sampler_kwargs = dict(\n", + " priors=priors, sampler=\"bilby_mcmc\", nsamples=1000, printdt=5, outdir=\"outdir\", verbose=False, clean=True\n", + ")" ] }, { @@ -116,9 +117,7 @@ "metadata": {}, "outputs": [], "source": [ - "known_x = bilby.core.likelihood.GaussianLikelihood(\n", - " x=data[\"xtrue\"], y=data[\"yobs\"], func=model, sigma=data[\"yerr\"]\n", - ")\n", + "known_x = bilby.core.likelihood.GaussianLikelihood(x=data[\"xtrue\"], y=data[\"yobs\"], func=model, sigma=data[\"yerr\"])\n", "result_known_x = bilby.run_sampler(\n", " likelihood=known_x,\n", " label=\"known_x\",\n", @@ -152,9 +151,7 @@ "metadata": {}, "outputs": [], "source": [ - "incorrect_x = bilby.core.likelihood.GaussianLikelihood(\n", - " x=data[\"xobs\"], y=data[\"yobs\"], func=model, sigma=data[\"yerr\"]\n", - ")\n", + "incorrect_x = bilby.core.likelihood.GaussianLikelihood(x=data[\"xobs\"], y=data[\"yobs\"], func=model, sigma=data[\"yerr\"])\n", "result_incorrect_x = bilby.run_sampler(\n", " likelihood=incorrect_x,\n", " label=\"incorrect_x\",\n", @@ -206,7 +203,7 @@ " function:\n", " The python function to fit to the data\n", " \"\"\"\n", - " super(GaussianLikelihoodUncertainX, self).__init__()\n", + " super().__init__()\n", " self.xobs = xobs\n", " self.yobs = yobs\n", " self.yerr = yerr\n", diff --git a/examples/tutorials/making_priors.ipynb b/examples/tutorials/making_priors.ipynb index a16f941c9..14ab892d4 100644 --- a/examples/tutorials/making_priors.ipynb +++ b/examples/tutorials/making_priors.ipynb @@ -29,10 +29,11 @@ "metadata": {}, "outputs": [], "source": [ - "import bilby\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", + "import bilby\n", + "\n", "%matplotlib inline" ] }, @@ -59,14 +60,10 @@ " bilby.core.prior.Uniform(minimum=5, maximum=50),\n", " bilby.core.prior.LogUniform(minimum=5, maximum=50),\n", " bilby.core.prior.PowerLaw(name=\"name\", alpha=2, minimum=100, maximum=1000),\n", - " bilby.gw.prior.UniformComovingVolume(\n", - " name=\"luminosity_distance\", minimum=100, maximum=1000, latex_label=\"label\"\n", - " ),\n", + " bilby.gw.prior.UniformComovingVolume(name=\"luminosity_distance\", minimum=100, maximum=1000, latex_label=\"label\"),\n", " bilby.gw.prior.AlignedSpin(),\n", " bilby.core.prior.Gaussian(name=\"name\", mu=0, sigma=1, latex_label=\"label\"),\n", - " bilby.core.prior.TruncatedGaussian(\n", - " name=\"name\", mu=1, sigma=0.4, minimum=-1, maximum=1, latex_label=\"label\"\n", - " ),\n", + " bilby.core.prior.TruncatedGaussian(name=\"name\", mu=1, sigma=0.4, minimum=-1, maximum=1, latex_label=\"label\"),\n", " bilby.core.prior.Cosine(name=\"name\", latex_label=\"label\"),\n", " bilby.core.prior.Sine(name=\"name\", latex_label=\"label\"),\n", " bilby.core.prior.Interped(\n", @@ -89,7 +86,7 @@ " )\n", " else:\n", " plt.plot(np.linspace(-5, 5, 1000), prior.prob(np.linspace(-5, 5, 1000)))\n", - " plt.xlabel(\"{}\".format(prior.latex_label))\n", + " plt.xlabel(f\"{prior.latex_label}\")\n", "\n", "plt.tight_layout()\n", "plt.show()\n", @@ -115,16 +112,13 @@ " \"\"\"Define a new prior class where p(x) ~ exp(alpha * x)\"\"\"\n", "\n", " def __init__(self, alpha, minimum, maximum, name=None, latex_label=None):\n", - " super(Exponential, self).__init__(\n", - " name=name, latex_label=latex_label, minimum=minimum, maximum=maximum\n", - " )\n", + " super().__init__(name=name, latex_label=latex_label, minimum=minimum, maximum=maximum)\n", " self.alpha = alpha\n", "\n", " def rescale(self, val):\n", " return (\n", " np.log(\n", - " (np.exp(self.alpha * self.maximum) - np.exp(self.alpha * self.minimum))\n", - " * val\n", + " (np.exp(self.alpha * self.maximum) - np.exp(self.alpha * self.minimum)) * val\n", " + np.exp(self.alpha * self.minimum)\n", " )\n", " / self.alpha\n", @@ -154,7 +148,7 @@ " np.linspace(prior.minimum, prior.maximum, 1000),\n", " prior.prob(np.linspace(prior.minimum, prior.maximum, 1000)),\n", ")\n", - "plt.xlabel(\"{}\".format(prior.latex_label))\n", + "plt.xlabel(f\"{prior.latex_label}\")\n", "plt.show()\n", "plt.close()" ] diff --git a/examples/tutorials/visualising_the_results.ipynb b/examples/tutorials/visualising_the_results.ipynb index 362432563..cf6747858 100644 --- a/examples/tutorials/visualising_the_results.ipynb +++ b/examples/tutorials/visualising_the_results.ipynb @@ -24,9 +24,10 @@ "metadata": {}, "outputs": [], "source": [ - "import bilby\n", "import matplotlib.pyplot as plt\n", "\n", + "import bilby\n", + "\n", "%matplotlib inline" ] }, @@ -94,9 +95,7 @@ " sampling_frequency=sampling_frequency,\n", " start_time=injection_parameters[\"geocent_time\"] - 2,\n", ")\n", - "_ = ifos.inject_signal(\n", - " waveform_generator=waveform_generator, parameters=injection_parameters\n", - ")" + "_ = ifos.inject_signal(waveform_generator=waveform_generator, parameters=injection_parameters)" ] }, { @@ -110,9 +109,7 @@ "# then, reset the priors on the masses and luminosity distance to conduct a search over these parameters\n", "priors[\"mass_1\"] = bilby.core.prior.Uniform(25, 40, \"mass_1\")\n", "priors[\"mass_2\"] = bilby.core.prior.Uniform(25, 40, \"mass_2\")\n", - "priors[\"luminosity_distance\"] = bilby.core.prior.Uniform(\n", - " 400, 2000, \"luminosity_distance\"\n", - ")" + "priors[\"luminosity_distance\"] = bilby.core.prior.Uniform(400, 2000, \"luminosity_distance\")" ] }, { @@ -122,9 +119,7 @@ "outputs": [], "source": [ "# compute the likelihoods\n", - "likelihood = bilby.gw.likelihood.GravitationalWaveTransient(\n", - " interferometers=ifos, waveform_generator=waveform_generator\n", - ")" + "likelihood = bilby.gw.likelihood.GravitationalWaveTransient(interferometers=ifos, waveform_generator=waveform_generator)" ] }, { @@ -201,9 +196,7 @@ "\n", "cbc_result = CBCResult.from_json(\"visualising_the_results/example_result.json\")\n", "for ifo in ifos:\n", - " cbc_result.plot_interferometer_waveform_posterior(\n", - " interferometer=ifo, n_samples=500, save=False\n", - " )\n", + " cbc_result.plot_interferometer_waveform_posterior(interferometer=ifo, n_samples=500, save=False)\n", " plt.show()\n", " plt.close()" ] @@ -265,9 +258,7 @@ "metadata": {}, "outputs": [], "source": [ - "result.plot_corner(\n", - " parameters=[\"mass_1\", \"mass_2\"], filename=\"{}/subset.png\".format(outdir)\n", - ")\n", + "result.plot_corner(parameters=[\"mass_1\", \"mass_2\"], filename=f\"{outdir}/subset.png\")\n", "plt.show()\n", "plt.close()" ] diff --git a/pyproject.toml b/pyproject.toml index 145d905d1..052ec8bd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,18 +60,6 @@ Homepage = "https://github.com/bilby-dev/bilby" Source = "https://github.com/bilby-dev/bilby" Tracker = "https://github.com/bilby-dev/bilby/issues" -[tool.flake8] -exclude = [ - ".git", - "docs", - "build", - "dist", - "test", - "*__init__.py" -] -ignore = ["E129", "W503", "W504", "E203", "E402"] -max-line-length = 120 - [tool.isort] known_third_party = [ "astropy", @@ -110,6 +98,12 @@ addopts = [ "--ignore=test/integration/sampler_run_test.py" ] +[tool.ruff] +line-length = 120 + +[tool.ruff.format] +docstring-code-format = true + [tool.setuptools] packages = [ "bilby", diff --git a/test/bilby_mcmc/test_chain.py b/test/bilby_mcmc/test_chain.py index befc5b1d8..a9bbd9f76 100644 --- a/test/bilby_mcmc/test_chain.py +++ b/test/bilby_mcmc/test_chain.py @@ -2,12 +2,13 @@ import shutil import unittest +import numpy as np +import pandas as pd + import bilby from bilby.bilby_mcmc.chain import Chain, Sample, calculate_tau from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY from bilby.core.sampler.base_sampler import SamplerError -import numpy as np -import pandas as pd class TestChain(unittest.TestCase): @@ -22,12 +23,9 @@ def tearDown(self): shutil.rmtree(self.outdir) def create_random_sample(self): - return Sample({ - "a": np.random.normal(0, 1), - "b": np.random.normal(0, 1), - LOGLKEY: np.random.normal(0, 1), - LOGPKEY: -1 - }) + return Sample( + {"a": np.random.normal(0, 1), "b": np.random.normal(0, 1), LOGLKEY: np.random.normal(0, 1), LOGPKEY: -1} + ) def create_chain(self, n=1000): chain = Chain(initial_sample=self.initial_sample) @@ -43,7 +41,7 @@ def test_append(self): chain = Chain(initial_sample=self.initial_sample) chain.append(self.create_random_sample()) self.assertEqual(chain.position, 1) - self.assertEqual(len(chain.get_1d_array('a')), 2) + self.assertEqual(len(chain.get_1d_array("a")), 2) def test_append_within_init_space(self): chain = Chain(initial_sample=self.initial_sample) @@ -54,7 +52,7 @@ def test_append_within_init_space(self): self.assertEqual(chain.position, N) # N samples + 1 initial position - self.assertEqual(len(chain.get_1d_array('a')), N + 1) + self.assertEqual(len(chain.get_1d_array("a")), N + 1) def test_append_with_extending(self): block_length = 100 @@ -153,7 +151,7 @@ def test_plot(self): chain.plot(outdir=self.outdir, label="test") self.assertTrue(os.path.exists(f"{self.outdir}/test_checkpoint_trace.png")) priors = dict( - a=bilby.core.prior.Uniform(-10, 10, latex_label='a'), + a=bilby.core.prior.Uniform(-10, 10, latex_label="a"), b=bilby.core.prior.Uniform(-10, 10), ) chain.thin_by_nact = 0.5 @@ -180,18 +178,18 @@ def test_dict_access(self): def test_list_access(self): s = Sample(self.sample_dict) slist = s.list - self.assertEqual(slist, [self.sample_dict['a'], self.sample_dict['b']]) + self.assertEqual(slist, [self.sample_dict["a"], self.sample_dict["b"]]) def test_setitem(self): s = Sample(self.sample_dict) # Set existing parameter - s['a'] = 100 - self.assertEqual(s['a'], 100) + s["a"] = 100 + self.assertEqual(s["a"], 100) # Add parameter - s['c'] = 100 - self.assertEqual(s['c'], 100) + s["c"] = 100 + self.assertEqual(s["c"], 100) def test_parameter_only_dict(self): s = Sample(self.sample_dict) @@ -201,9 +199,9 @@ def test_update(self): sample_dict = dict(a=1, b=2) curr = Sample(sample_dict) prop = curr.copy() - prop['a'] = 200 - self.assertEqual(prop['a'], 200) - self.assertEqual(curr['a'], 1) + prop["a"] = 200 + self.assertEqual(prop["a"], 200) + self.assertEqual(curr["a"], 1) class TestACT(unittest.TestCase): diff --git a/test/bilby_mcmc/test_proposals.py b/test/bilby_mcmc/test_proposals.py index cd36f94fa..f4fe92e60 100644 --- a/test/bilby_mcmc/test_proposals.py +++ b/test/bilby_mcmc/test_proposals.py @@ -1,23 +1,26 @@ -import os import copy -import shutil -import unittest -import inspect import importlib +import inspect +import os +import shutil import sys import time +import unittest + +import numpy as np +import pytest + import bilby -from bilby.bilby_mcmc.chain import Chain, Sample from bilby.bilby_mcmc import proposals +from bilby.bilby_mcmc.chain import Chain, Sample from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY -import numpy as np -import pytest class GivenProposal(proposals.BaseProposal): - """ A simple proposal class used for testing """ + """A simple proposal class used for testing""" + def __init__(self, priors, weight=1, subset=None, sigma=0.01): - super(GivenProposal, self).__init__(priors, weight, subset) + super().__init__(priors, weight, subset) def propose(self, chain): log_factor = 0 @@ -26,10 +29,9 @@ def propose(self, chain): class TestBaseProposals(unittest.TestCase): def create_priors(self, ndim=2, boundary=None): - priors = bilby.core.prior.PriorDict({ - f'x{i}': bilby.core.prior.Uniform(-10, 10, name=f'x{i}', boundary=boundary) - for i in range(ndim) - }) + priors = bilby.core.prior.PriorDict( + {f"x{i}": bilby.core.prior.Uniform(-10, 10, name=f"x{i}", boundary=boundary) for i in range(ndim)} + ) priors["fixedA"] = bilby.core.prior.DeltaFunction(1) priors["infinite_support"] = bilby.core.prior.Normal(0, 1) priors["half_infinite_support"] = bilby.core.prior.HalfNormal(1) @@ -110,9 +112,7 @@ def tearDown(self): shutil.rmtree(self.outdir) def get_simple_proposals(self): - clsmembers = inspect.getmembers( - sys.modules[proposals.__name__], inspect.isclass - ) + clsmembers = inspect.getmembers(sys.modules[proposals.__name__], inspect.isclass) clsmembers_clean = [] for name, cls in clsmembers: a = "Proposal" in name @@ -130,7 +130,7 @@ def get_simple_proposals(self): def proposal_check(self, prop, ndim=2, N=100): chain = self.create_chain(ndim=ndim) - if getattr(prop, 'needs_likelihood_and_priors', False): + if getattr(prop, "needs_likelihood_and_priors", False): return print(f"Testing {prop.__class__.__name__}") diff --git a/test/bilby_mcmc/test_sampler.py b/test/bilby_mcmc/test_sampler.py index d72526683..041b95658 100644 --- a/test/bilby_mcmc/test_sampler.py +++ b/test/bilby_mcmc/test_sampler.py @@ -2,12 +2,13 @@ import shutil import unittest +import numpy as np +import pandas as pd + import bilby from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler from bilby.bilby_mcmc.utils import ConvergenceInputs from bilby.core.sampler.base_sampler import SamplerError -import numpy as np -import pandas as pd class TestBilbyMCMCSampler(unittest.TestCase): @@ -15,9 +16,7 @@ def setUp(self): default_kwargs = Bilby_MCMC.default_kwargs default_kwargs["target_nsamples"] = 100 default_kwargs["L1steps"] = 1 - self.convergence_inputs = ConvergenceInputs( - **{key: default_kwargs[key] for key in ConvergenceInputs._fields} - ) + self.convergence_inputs = ConvergenceInputs(**{key: default_kwargs[key] for key in ConvergenceInputs._fields}) self.outdir = "bilby_mcmc_sampler_test" if os.path.isdir(self.outdir) is False: @@ -25,6 +24,7 @@ def setUp(self): def model(time, m, c): return time * m + c + injection_parameters = dict(m=0.5, c=0.2) sampling_frequency = 10 time_duration = 10 @@ -37,11 +37,11 @@ def model(time, m, c): # From hereon, the syntax is exactly equivalent to other bilby examples # We make a prior priors = dict() - priors['m'] = bilby.core.prior.Uniform(0, 5, 'm') - priors['c'] = bilby.core.prior.Uniform(-2, 2, 'c') + priors["m"] = bilby.core.prior.Uniform(0, 5, "m") + priors["c"] = bilby.core.prior.Uniform(-2, 2, "c") priors = bilby.core.prior.PriorDict(priors) - search_parameter_keys = ['m', 'c'] + search_parameter_keys = ["m", "c"] use_ratio = False bilby.core.sampler.base_sampler._initialize_global_variables( @@ -64,7 +64,7 @@ def test_None_proposal_cycle(self): beta=1, Tindex=0, Eindex=0, - use_ratio=False + use_ratio=False, ) def test_default_proposal_cycle(self): @@ -74,7 +74,7 @@ def test_default_proposal_cycle(self): beta=1, Tindex=0, Eindex=0, - use_ratio=False + use_ratio=False, ) nsteps = 0 @@ -89,9 +89,7 @@ def test_default_proposal_cycle(self): def test_get_expected_outputs(): label = "par0" outdir = os.path.join("some", "bilby_pipe", "dir") - filenames, directories = Bilby_MCMC.get_expected_outputs( - outdir=outdir, label=label - ) + filenames, directories = Bilby_MCMC.get_expected_outputs(outdir=outdir, label=label) assert len(filenames) == 1 assert len(directories) == 0 assert os.path.join(outdir, f"{label}_resume.pickle") in filenames diff --git a/test/check_author_list.py b/test/check_author_list.py index 95d5be1ca..2bc8e63f0 100644 --- a/test/check_author_list.py +++ b/test/check_author_list.py @@ -1,11 +1,11 @@ -""" A script to verify that the .AUTHOR.md file is up to date """ +"""A script to verify that the .AUTHOR.md file is up to date""" import re import subprocess special_cases = ["plasky", "thomas", "mj-will", "richard", "douglas", "nixnyxnyx"] AUTHORS_list = [] -with open("AUTHORS.md", "r") as f: +with open("AUTHORS.md") as f: AUTHORS_list = " ".join([line for line in f]).lower() @@ -16,15 +16,14 @@ def remove_accents(raw_text): - - raw_text = re.sub(u"[àáâãäå]", 'a', raw_text) - raw_text = re.sub(u"[èéêë]", 'e', raw_text) - raw_text = re.sub(u"[ìíîï]", 'i', raw_text) - raw_text = re.sub(u"[òóôõö]", 'o', raw_text) - raw_text = re.sub(u"[ùúûü]", 'u', raw_text) - raw_text = re.sub(u"[ýÿ]", 'y', raw_text) - raw_text = re.sub(u"[ß]", 'ss', raw_text) - raw_text = re.sub(u"[ñ]", 'n', raw_text) + raw_text = re.sub("[àáâãäå]", "a", raw_text) + raw_text = re.sub("[èéêë]", "e", raw_text) + raw_text = re.sub("[ìíîï]", "i", raw_text) + raw_text = re.sub("[òóôõö]", "o", raw_text) + raw_text = re.sub("[ùúûü]", "u", raw_text) + raw_text = re.sub("[ýÿ]", "y", raw_text) + raw_text = re.sub("[ß]", "ss", raw_text) + raw_text = re.sub("[ñ]", "n", raw_text) return raw_text @@ -32,7 +31,7 @@ def remove_accents(raw_text): fail_test = False for line in lines: line = line.replace(".", " ") - line = re.sub('([A-Z][a-z]+)', r' \1', re.sub('([A-Z]+)', r' \1', line)) + line = re.sub("([A-Z][a-z]+)", r" \1", re.sub("([A-Z]+)", r" \1", line)) line = remove_accents(line) for element in line.split()[1:]: element = element.lower() diff --git a/test/conftest.py b/test/conftest.py index d08c38604..627379c83 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,9 +2,7 @@ def pytest_addoption(parser): - parser.addoption( - "--skip-roqs", action="store_true", default=False, help="Skip all tests that require ROQs" - ) + parser.addoption("--skip-roqs", action="store_true", default=False, help="Skip all tests that require ROQs") def pytest_configure(config): diff --git a/test/core/grid_test.py b/test/core/grid_test.py index 82e44c5cc..07a0c761f 100644 --- a/test/core/grid_test.py +++ b/test/core/grid_test.py @@ -1,7 +1,8 @@ +import os +import shutil import unittest + import numpy as np -import shutil -import os from scipy.stats import multivariate_normal import bilby @@ -10,7 +11,7 @@ # set 2D multivariate Gaussian likelihood class MultiGaussian(bilby.Likelihood): def __init__(self, mean, cov): - super(MultiGaussian, self).__init__(parameters=dict()) + super().__init__(parameters=dict()) self.cov = np.array(cov) self.mean = np.array(mean) self.sigma = np.sqrt(np.diag(self.cov)) @@ -21,7 +22,7 @@ def dim(self): return len(self.cov[0]) def log_likelihood(self): - x = np.array([self.parameters["x{0}".format(i)] for i in range(self.dim)]) + x = np.array([self.parameters[f"x{i}"] for i in range(self.dim)]) return self.pdf.logpdf(x) @@ -37,19 +38,10 @@ def setUp(self): # set priors out to +/- 5 sigma self.priors = bilby.core.prior.PriorDict() - self.priors.update( - { - "x{0}".format(i): bilby.core.prior.Uniform(-5, 5, "x{0}".format(i)) - for i in range(dim) - } - ) + self.priors.update({f"x{i}": bilby.core.prior.Uniform(-5, 5, f"x{i}") for i in range(dim)}) # expected evidence integral should be (1/V) where V is the prior volume - log_prior_vol = np.sum( - np.log( - [prior.maximum - prior.minimum for key, prior in self.priors.items()] - ) - ) + log_prior_vol = np.sum(np.log([prior.maximum - prior.minimum for key, prior in self.priors.items()])) self.expected_ln_evidence = -log_prior_vol self.grid_size = 100 @@ -80,11 +72,11 @@ def test_grid_file_name_default(self): label = "label" self.assertEqual( bilby.core.grid.grid_file_name(outdir, label), - "{}/{}_grid.json".format(outdir, label), + f"{outdir}/{label}_grid.json", ) self.assertEqual( bilby.core.grid.grid_file_name(outdir, label, True), - "{}/{}_grid.json.gz".format(outdir, label), + f"{outdir}/{label}_grid.json.gz", ) def test_fail_save_and_load(self): @@ -111,9 +103,11 @@ def test_parameter_names(self): def test_no_marginalization(self): # test arrays are the same if no parameters are given to marginalize # over - self.assertTrue(np.array_equal( - self.grid.ln_likelihood, - self.grid.marginalize_ln_likelihood(not_parameters=self.grid.parameter_names))) + self.assertTrue( + np.array_equal( + self.grid.ln_likelihood, self.grid.marginalize_ln_likelihood(not_parameters=self.grid.parameter_names) + ) + ) def test_marginalization_shapes(self): self.assertEqual(0, len(self.grid.marginalize_ln_likelihood().shape)) @@ -125,13 +119,19 @@ def test_marginalization_shapes(self): self.assertTupleEqual((self.grid_size, self.grid_size), self.grid.ln_posterior.shape) def test_marginalization_opposite(self): - self.assertTrue(np.array_equal( - self.grid.marginalize_ln_likelihood(parameters=self.grid.parameter_names[0]), - self.grid.marginalize_ln_likelihood(not_parameters=self.grid.parameter_names[1]))) + self.assertTrue( + np.array_equal( + self.grid.marginalize_ln_likelihood(parameters=self.grid.parameter_names[0]), + self.grid.marginalize_ln_likelihood(not_parameters=self.grid.parameter_names[1]), + ) + ) - self.assertTrue(np.array_equal( - self.grid.marginalize_ln_likelihood(parameters=self.grid.parameter_names[1]), - self.grid.marginalize_ln_likelihood(not_parameters=self.grid.parameter_names[0]))) + self.assertTrue( + np.array_equal( + self.grid.marginalize_ln_likelihood(parameters=self.grid.parameter_names[1]), + self.grid.marginalize_ln_likelihood(not_parameters=self.grid.parameter_names[0]), + ) + ) def test_max_marginalized_likelihood(self): # marginalised likelihoods should have max values of 1 (as they are not @@ -161,11 +161,7 @@ def test_mesh_grid(self): def test_grid_integer_points(self): n_points = [10, 20] grid = bilby.core.grid.Grid( - label="label", - outdir="outdir", - priors=self.priors, - grid_size=n_points, - likelihood=self.likelihood + label="label", outdir="outdir", priors=self.priors, grid_size=n_points, likelihood=self.likelihood ) self.assertTupleEqual(tuple(n_points), grid.mesh_grid[0].shape) @@ -175,11 +171,7 @@ def test_grid_integer_points(self): def test_grid_dict_points(self): n_points = {"x0": 15, "x1": 18} grid = bilby.core.grid.Grid( - label="label", - outdir="outdir", - priors=self.priors, - grid_size=n_points, - likelihood=self.likelihood + label="label", outdir="outdir", priors=self.priors, grid_size=n_points, likelihood=self.likelihood ) self.assertTupleEqual((n_points["x0"], n_points["x1"]), grid.mesh_grid[0].shape) self.assertEqual(grid.mesh_grid[0][0, 0], self.priors[self.grid.parameter_names[0]].minimum) @@ -227,9 +219,7 @@ def test_save_and_load_from_outdir_label(self): self.assertEqual(self.grid.n_dims, new_grid.n_dims) self.assertTrue(np.array_equal(new_grid.mesh_grid[0], self.grid.mesh_grid[0])) for par in new_grid.parameter_names: - self.assertTrue(np.array_equal( - new_grid.sample_points[par], self.grid.sample_points[par]) - ) + self.assertTrue(np.array_equal(new_grid.sample_points[par], self.grid.sample_points[par])) self.assertEqual(self.grid.ln_evidence, new_grid.ln_evidence) self.assertTrue(np.array_equal(self.grid.ln_likelihood, new_grid.ln_likelihood)) self.assertTrue(np.array_equal(self.grid.ln_posterior, new_grid.ln_posterior)) diff --git a/test/core/likelihood_test.py b/test/core/likelihood_test.py index 8061f9e55..8f46e7fd2 100644 --- a/test/core/likelihood_test.py +++ b/test/core/likelihood_test.py @@ -5,15 +5,15 @@ import bilby.core.likelihood from bilby.core.likelihood import ( - Likelihood, - GaussianLikelihood, - PoissonLikelihood, - StudentTLikelihood, Analytical1DLikelihood, - ExponentialLikelihood, - AnalyticalMultidimensionalCovariantGaussian, AnalyticalMultidimensionalBimodalCovariantGaussian, + AnalyticalMultidimensionalCovariantGaussian, + ExponentialLikelihood, + GaussianLikelihood, JointLikelihood, + Likelihood, + PoissonLikelihood, + StudentTLikelihood, ) @@ -62,9 +62,7 @@ def test_func(x, parameter1, parameter2): self.func = test_func self.parameter1_value = 4 self.parameter2_value = 7 - self.analytical_1d_likelihood = Analytical1DLikelihood( - x=self.x, y=self.y, func=self.func - ) + self.analytical_1d_likelihood = Analytical1DLikelihood(x=self.x, y=self.y, func=self.func) self.analytical_1d_likelihood.parameters["parameter1"] = self.parameter1_value self.analytical_1d_likelihood.parameters["parameter2"] = self.parameter2_value @@ -128,12 +126,8 @@ def new_func(x): self.analytical_1d_likelihood.func = new_func def test_parameters(self): - expected_parameters = dict( - parameter1=self.parameter1_value, parameter2=self.parameter2_value - ) - self.assertDictEqual( - expected_parameters, self.analytical_1d_likelihood.parameters - ) + expected_parameters = dict(parameter1=self.parameter1_value, parameter2=self.parameter2_value) + self.assertDictEqual(expected_parameters, self.analytical_1d_likelihood.parameters) def test_n(self): self.assertEqual(len(self.x), self.analytical_1d_likelihood.n) @@ -147,17 +141,11 @@ def test_model_parameters(self): sigma = 5 self.analytical_1d_likelihood.sigma = sigma self.analytical_1d_likelihood.parameters["sigma"] = sigma - expected_model_parameters = dict( - parameter1=self.parameter1_value, parameter2=self.parameter2_value - ) - self.assertDictEqual( - expected_model_parameters, self.analytical_1d_likelihood.model_parameters() - ) + expected_model_parameters = dict(parameter1=self.parameter1_value, parameter2=self.parameter2_value) + self.assertDictEqual(expected_model_parameters, self.analytical_1d_likelihood.model_parameters()) def test_repr(self): - expected = "Analytical1DLikelihood(x={}, y={}, func={})".format( - self.x, self.y, self.func.__name__ - ) + expected = f"Analytical1DLikelihood(x={self.x}, y={self.y}, func={self.func.__name__})" self.assertEqual(expected, repr(self.analytical_1d_likelihood)) @@ -219,9 +207,7 @@ def test_sigma_other(self): def test_repr(self): likelihood = GaussianLikelihood(self.x, self.y, self.function, sigma=self.sigma) - expected = "GaussianLikelihood(x={}, y={}, func={}, sigma={})".format( - self.x, self.y, self.function.__name__, self.sigma - ) + expected = f"GaussianLikelihood(x={self.x}, y={self.y}, func={self.function.__name__}, sigma={self.sigma})" self.assertEqual(expected, repr(likelihood)) @@ -246,9 +232,7 @@ def tearDown(self): del self.function def test_known_sigma(self): - likelihood = StudentTLikelihood( - self.x, self.y, self.function, self.nu, self.sigma - ) + likelihood = StudentTLikelihood(self.x, self.y, self.function, self.nu, self.sigma) likelihood.parameters["m"] = 2 likelihood.parameters["c"] = 0 likelihood.log_likelihood() @@ -297,12 +281,8 @@ def test_lam(self): def test_repr(self): nu = 0 sigma = 0.5 - likelihood = StudentTLikelihood( - self.x, self.y, self.function, nu=nu, sigma=sigma - ) - expected = "StudentTLikelihood(x={}, y={}, func={}, nu={}, sigma={})".format( - self.x, self.y, self.function.__name__, nu, sigma - ) + likelihood = StudentTLikelihood(self.x, self.y, self.function, nu=nu, sigma=sigma) + expected = f"StudentTLikelihood(x={self.x}, y={self.y}, func={self.function.__name__}, nu={nu}, sigma={sigma})" self.assertEqual(expected, repr(likelihood)) @@ -379,39 +359,29 @@ def test_set_y_to_float(self): self.poisson_likelihood.y = 5.3 def test_log_likelihood_wrong_func_return_type(self): - poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: "test" - ) + poisson_likelihood = PoissonLikelihood(x=self.x, y=self.y, func=lambda x: "test") with self.assertRaises(ValueError): poisson_likelihood.log_likelihood() def test_log_likelihood_negative_func_return_element(self): - poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([3, 6, -2]) - ) + poisson_likelihood = PoissonLikelihood(x=self.x, y=self.y, func=lambda x: np.array([3, 6, -2])) with self.assertRaises(ValueError): poisson_likelihood.log_likelihood() def test_log_likelihood_zero_func_return_element(self): - poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([3, 6, 0]) - ) + poisson_likelihood = PoissonLikelihood(x=self.x, y=self.y, func=lambda x: np.array([3, 6, 0])) self.assertEqual(-np.inf, poisson_likelihood.log_likelihood()) def test_log_likelihood_dummy(self): - """ Merely tests if it goes into the right if else bracket """ - poisson_likelihood = PoissonLikelihood( - x=self.x, y=self.y, func=lambda x: np.linspace(1, 100, self.N) - ) + """Merely tests if it goes into the right if else bracket""" + poisson_likelihood = PoissonLikelihood(x=self.x, y=self.y, func=lambda x: np.linspace(1, 100, self.N)) with mock.patch("numpy.sum") as m: m.return_value = 1 self.assertEqual(1, poisson_likelihood.log_likelihood()) def test_repr(self): likelihood = PoissonLikelihood(self.x, self.y, self.function) - expected = "PoissonLikelihood(x={}, y={}, func={})".format( - self.x, self.y, self.function.__name__ - ) + expected = f"PoissonLikelihood(x={self.x}, y={self.y}, func={self.function.__name__})" self.assertEqual(expected, repr(likelihood)) @@ -432,9 +402,7 @@ def test_function_array(x, c): self.function = test_function self.function_array = test_function_array - self.exponential_likelihood = ExponentialLikelihood( - x=self.x, y=self.y, func=self.function - ) + self.exponential_likelihood = ExponentialLikelihood(x=self.x, y=self.y, func=self.function) def tearDown(self): del self.N @@ -491,18 +459,14 @@ def test_set_y_to_nd_array_with_negative_element(self): self.exponential_likelihood.y = np.array([4.3, -1.2, 4]) def test_log_likelihood_default(self): - """ Merely tests that it ends up at the right place in the code """ - exponential_likelihood = ExponentialLikelihood( - x=self.x, y=self.y, func=lambda x: np.array([4.2]) - ) + """Merely tests that it ends up at the right place in the code""" + exponential_likelihood = ExponentialLikelihood(x=self.x, y=self.y, func=lambda x: np.array([4.2])) with mock.patch("numpy.sum") as m: m.return_value = 3 self.assertEqual(-3, exponential_likelihood.log_likelihood()) def test_repr(self): - expected = "ExponentialLikelihood(x={}, y={}, func={})".format( - self.x, self.y, self.function.__name__ - ) + expected = f"ExponentialLikelihood(x={self.x}, y={self.y}, func={self.function.__name__})" self.assertEqual(expected, repr(self.exponential_likelihood)) @@ -511,9 +475,7 @@ def setUp(self): self.cov = [[1, 0, 0], [0, 4, 0], [0, 0, 9]] self.sigma = [1, 2, 3] self.mean = [10, 11, 12] - self.likelihood = AnalyticalMultidimensionalCovariantGaussian( - mean=self.mean, cov=self.cov - ) + self.likelihood = AnalyticalMultidimensionalCovariantGaussian(mean=self.mean, cov=self.cov) self.likelihood.parameters.update({f"x{ii}": 0 for ii in range(len(self.sigma))}) def tearDown(self): @@ -579,9 +541,7 @@ def test_dim(self): self.assertEqual(3, self.likelihood.dim) def test_log_likelihood(self): - likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian( - mean_1=[0], mean_2=[0], cov=[1] - ) + likelihood = AnalyticalMultidimensionalBimodalCovariantGaussian(mean_1=[0], mean_2=[0], cov=[1]) self.assertEqual(-np.log(2 * np.pi) / 2, likelihood.log_likelihood(dict(x0=0))) @@ -601,9 +561,7 @@ def setUp(self): self.third_likelihood = ExponentialLikelihood( x=self.x, y=self.y, func=lambda x, param4, param5: (param4 + param5) * x ) - self.joint_likelihood = JointLikelihood( - self.first_likelihood, self.second_likelihood, self.third_likelihood - ) + self.joint_likelihood = JointLikelihood(self.first_likelihood, self.second_likelihood, self.third_likelihood) # self.first_parameters = dict(param1=1, param2=2) # self.second_parmaeters = dict(param2=2, param3=3) @@ -630,7 +588,13 @@ def tearDown(self): del self.joint_likelihood def test_parameters_consistent_from_init(self): - expected = dict(param1=1, param2=2, param3=3, param4=4, param5=5,) + expected = dict( + param1=1, + param2=2, + param3=3, + param4=4, + param5=5, + ) self.assertDictEqual(expected, self.joint_likelihood.parameters) def test_log_likelihood_correctly_sums(self): @@ -667,9 +631,7 @@ def test_log_noise_likelihood(self): self.first_likelihood.noise_log_likelihood = mock.MagicMock(return_value=1) self.second_likelihood.noise_log_likelihood = mock.MagicMock(return_value=2) self.third_likelihood.noise_log_likelihood = mock.MagicMock(return_value=3) - self.joint_likelihood = JointLikelihood( - self.first_likelihood, self.second_likelihood, self.third_likelihood - ) + self.joint_likelihood = JointLikelihood(self.first_likelihood, self.second_likelihood, self.third_likelihood) expected = ( self.first_likelihood.noise_log_likelihood() + self.second_likelihood.noise_log_likelihood() @@ -679,9 +641,7 @@ def test_log_noise_likelihood(self): def test_init_with_list_of_likelihoods(self): with self.assertRaises(ValueError): - JointLikelihood( - [self.first_likelihood, self.second_likelihood, self.third_likelihood] - ) + JointLikelihood([self.first_likelihood, self.second_likelihood, self.third_likelihood]) def test_setting_single_likelihood(self): self.joint_likelihood.likelihoods = self.first_likelihood @@ -702,7 +662,6 @@ def test_setting_likelihood_other(self): class TestGPLikelihood(unittest.TestCase): - def setUp(self) -> None: self.t = [1, 2, 3] self.y = [4, 5, 6] @@ -715,7 +674,8 @@ def setUp(self) -> None: self.gp_mock.get_parameter_dict = mock.MagicMock(return_value=dict(self.parameter_dict)) self.gp_class = mock.MagicMock(return_value=self.gp_mock) self.celerite_likelihood = bilby.core.likelihood._GPLikelihood( - kernel=self.kernel, mean_model=self.mean_model, t=self.t, y=self.y, yerr=self.yerr, gp_class=self.gp_class) + kernel=self.kernel, mean_model=self.mean_model, t=self.t, y=self.y, yerr=self.yerr, gp_class=self.gp_class + ) def tearDown(self) -> None: del self.t @@ -745,11 +705,13 @@ def test_gp_class(self): def test_gp_instantiation(self): self.celerite_likelihood.GPClass.assert_called_once_with( - kernel=self.kernel, mean=self.mean_model, fit_mean=True, fit_white_noise=True) + kernel=self.kernel, mean=self.mean_model, fit_mean=True, fit_white_noise=True + ) def test_gp_mock(self): self.celerite_likelihood.gp.compute.assert_called_once_with( - self.celerite_likelihood.t, yerr=self.celerite_likelihood.yerr) + self.celerite_likelihood.t, yerr=self.celerite_likelihood.yerr + ) def test_parameters(self): self.assertDictEqual(self.parameter_dict, self.celerite_likelihood.parameters) @@ -764,24 +726,22 @@ def test_set_parameters_no_exceptions(self): class TestFunctionMeanModel(unittest.TestCase): - def test_function_to_celerite_mean_model(self): def func(x, a, b, c): - return a * x ** 2 + b * x + c + return a * x**2 + b * x + c mean_model = bilby.core.likelihood.function_to_celerite_mean_model(func=func) self.assertListEqual(["a", "b", "c"], list(mean_model.parameter_names)) def test_function_to_george_mean_model(self): def func(x, a, b, c): - return a * x ** 2 + b * x + c + return a * x**2 + b * x + c mean_model = bilby.core.likelihood.function_to_celerite_mean_model(func=func) self.assertListEqual(["a", "b", "c"], list(mean_model.parameter_names)) class TestCeleriteLikelihoodEvaluation(unittest.TestCase): - def setUp(self) -> None: import celerite @@ -797,7 +757,8 @@ def func(x, a): self.MeanModel = bilby.likelihood.function_to_celerite_mean_model(func=func) self.mean_model = self.MeanModel(a=1) self.celerite_likelihood = bilby.core.likelihood.CeleriteLikelihood( - kernel=self.kernel, mean_model=self.mean_model, t=self.t, y=self.y, yerr=self.yerr) + kernel=self.kernel, mean_model=self.mean_model, t=self.t, y=self.y, yerr=self.yerr + ) self.celerite_likelihood.parameters = self.parameters def tearDown(self) -> None: @@ -822,7 +783,6 @@ def test_set_parameters(self): class TestGeorgeLikelihoodEvaluation(unittest.TestCase): - def setUp(self) -> None: import george @@ -838,7 +798,8 @@ def func(x, a): self.MeanModel = bilby.likelihood.function_to_celerite_mean_model(func=func) self.mean_model = self.MeanModel(a=1) self.george_likelihood = bilby.core.likelihood.GeorgeLikelihood( - kernel=self.kernel, mean_model=self.mean_model, t=self.t, y=self.y, yerr=self.yerr) + kernel=self.kernel, mean_model=self.mean_model, t=self.t, y=self.y, yerr=self.yerr + ) def tearDown(self) -> None: pass diff --git a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py index 12892aca1..67927e3a0 100644 --- a/test/core/prior/analytical_test.py +++ b/test/core/prior/analytical_test.py @@ -1,6 +1,7 @@ import unittest import numpy as np + import bilby @@ -57,8 +58,7 @@ def test_array_probability(self): discrete_value_prior = bilby.core.prior.DiscreteValues(values) self.assertTrue( np.all( - discrete_value_prior.prob([1.1, 2.2, 2.2, 300.0, 200.0]) - == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) + discrete_value_prior.prob([1.1, 2.2, 2.2, 300.0, 200.0]) == np.array([1 / N, 1 / N, 1 / N, 1 / N, 0]) ) ) @@ -132,8 +132,12 @@ def test_single_lnprobability(self): def test_array_lnprobability(self): N = 3 categorical_prior = bilby.core.prior.Categorical(N) - self.assertTrue(np.all(categorical_prior.ln_prob([0, 1, 1, 2, 3]) == np.array( - [-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]))) + self.assertTrue( + np.all( + categorical_prior.ln_prob([0, 1, 1, 2, 3]) + == np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf]) + ) + ) class TestWeightedCategoricalPrior(unittest.TestCase): @@ -214,9 +218,7 @@ def test_cdf(self): categorical_prior = bilby.core.prior.WeightedCategorical(N, weights=weights) sample = categorical_prior.sample(size=10) original = np.asarray(sample) - new = np.array(categorical_prior.rescale( - categorical_prior.cdf(sample) - )) + new = np.array(categorical_prior.rescale(categorical_prior.cdf(sample))) np.testing.assert_array_equal(original, new) diff --git a/test/core/prior/base_test.py b/test/core/prior/base_test.py index c9b788732..e8aaeaac4 100644 --- a/test/core/prior/base_test.py +++ b/test/core/prior/base_test.py @@ -149,6 +149,7 @@ def setUp(self): def conversion_func(parameters): return dict(a=parameters["a"], b=parameters["b"], c=parameters["a"] / parameters["b"]) + self.priors = bilby.core.prior.PriorDict(self.priors, conversion_function=conversion_func) def test_prob_integrate_to_one(self): @@ -158,7 +159,7 @@ def test_prob_integrate_to_one(self): prob = self.priors.prob(samples, axis=0) dm1 = self.priors["a"].maximum - self.priors["a"].minimum dm2 = self.priors["b"].maximum - self.priors["b"].minimum - prior_volume = (dm1 * dm2) + prior_volume = dm1 * dm2 n_accepted = np.sum(prob) integral = prior_volume * n_accepted / n_samples # binomial random distribution diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 20c0cda93..c27284f61 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -1,11 +1,11 @@ import os +import pickle import shutil import unittest from unittest import mock import numpy as np import pandas as pd -import pickle import bilby @@ -71,17 +71,13 @@ def test_get_instantiation_dict(self): self.assertEqual(value, actual[key]) def test_update_conditions_correct_variables(self): - self.prior.update_conditions( - test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2 - ) + self.prior.update_conditions(test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2) self.assertEqual(1, self.condition_func_call_counter) self.assertEqual(self.minimum + 1, self.prior.minimum) self.assertEqual(self.maximum + 1, self.prior.maximum) def test_update_conditions_no_variables(self): - self.prior.update_conditions( - test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2 - ) + self.prior.update_conditions(test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2) self.prior.update_conditions() self.assertEqual(1, self.condition_func_call_counter) self.assertEqual(self.minimum + 1, self.prior.minimum) @@ -166,9 +162,7 @@ def test_reset_to_reference_parameters(self): self.assertEqual(self.prior.reference_params["maximum"], self.prior.maximum) def test_cond_prior_instantiation_no_boundary_prior(self): - prior = bilby.core.prior.ConditionalFermiDirac( - condition_func=None, sigma=1, mu=1 - ) + prior = bilby.core.prior.ConditionalFermiDirac(condition_func=None, sigma=1, mu=1) self.assertIsNone(prior.boundary) @@ -185,9 +179,7 @@ def condition_func_3(reference_parameters, var_1, var_2): self.minimum = 0 self.maximum = 1 - self.prior_0 = bilby.core.prior.Uniform( - minimum=self.minimum, maximum=self.maximum - ) + self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=self.maximum) self.prior_1 = bilby.core.prior.ConditionalUniform( condition_func=condition_func_1, minimum=self.minimum, maximum=self.maximum ) @@ -205,9 +197,7 @@ def condition_func_3(reference_parameters, var_1, var_2): var_1=self.prior_1, ) ) - self.conditional_priors_manually_set_items = ( - bilby.core.prior.ConditionalPriorDict() - ) + self.conditional_priors_manually_set_items = bilby.core.prior.ConditionalPriorDict() self.test_sample = dict(var_0=0.7, var_1=0.6, var_2=0.5, var_3=0.4) self.test_value = 1 / np.prod([self.test_sample[f"var_{ii}"] for ii in range(3)]) for key, value in dict( @@ -230,9 +220,7 @@ def tearDown(self): del self.test_sample def test_conditions_resolved_upon_instantiation(self): - self.assertListEqual( - ["var_0", "var_1", "var_2", "var_3"], self.conditional_priors.sorted_keys - ) + self.assertListEqual(["var_0", "var_1", "var_2", "var_3"], self.conditional_priors.sorted_keys) def test_conditions_resolved_setting_items(self): self.assertListEqual( @@ -244,14 +232,10 @@ def test_unconditional_keys_upon_instantiation(self): self.assertListEqual(["var_0"], self.conditional_priors.unconditional_keys) def test_unconditional_keys_setting_items(self): - self.assertListEqual( - ["var_0"], self.conditional_priors_manually_set_items.unconditional_keys - ) + self.assertListEqual(["var_0"], self.conditional_priors_manually_set_items.unconditional_keys) def test_conditional_keys_upon_instantiation(self): - self.assertListEqual( - ["var_1", "var_2", "var_3"], self.conditional_priors.conditional_keys - ) + self.assertListEqual(["var_1", "var_2", "var_3"], self.conditional_priors.conditional_keys) def test_conditional_keys_setting_items(self): self.assertListEqual( @@ -284,9 +268,7 @@ def test_sample_subset_all_keys(self): var_2=0.33516501262044845, var_3=0.09579062316418356, ), - self.conditional_priors.sample_subset( - keys=["var_0", "var_1", "var_2", "var_3"] - ), + self.conditional_priors.sample_subset(keys=["var_0", "var_1", "var_2", "var_3"]), ) def test_sample_illegal_subset(self): @@ -318,9 +300,7 @@ def test_rescale(self): ) ) ref_variables = self.test_sample.values() - res = self.conditional_priors.rescale( - keys=self.test_sample.keys(), theta=ref_variables - ) + res = self.conditional_priors.rescale(keys=self.test_sample.keys(), theta=ref_variables) expected = [self.test_sample["var_0"]] for ii in range(1, 4): expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) @@ -335,7 +315,7 @@ def test_rescale_with_joint_prior(self): # set multivariate Gaussian distribution names = ["mvgvar_0", "mvgvar_1"] mu = [[0.79, -0.83]] - cov = [[[0.03, 0.], [0., 0.04]]] + cov = [[[0.03, 0.0], [0.0, 0.04]]] mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov) priordict = bilby.core.prior.ConditionalPriorDict( @@ -371,10 +351,8 @@ def test_cdf(self): """ sample = self.conditional_priors.sample() self.assertEqual( - self.conditional_priors.rescale( - sample.keys(), - self.conditional_priors.cdf(sample=sample).values() - ), list(sample.values()) + self.conditional_priors.rescale(sample.keys(), self.conditional_priors.cdf(sample=sample).values()), + list(sample.values()), ) def test_rescale_illegal_conditions(self): @@ -387,26 +365,18 @@ def test_rescale_illegal_conditions(self): def test_combined_conditions(self): def d_condition_func(reference_params, a, b, c): - return dict( - minimum=reference_params["minimum"], maximum=reference_params["maximum"] - ) + return dict(minimum=reference_params["minimum"], maximum=reference_params["maximum"]) def a_condition_func(reference_params, b, c): - return dict( - minimum=reference_params["minimum"], maximum=reference_params["maximum"] - ) + return dict(minimum=reference_params["minimum"], maximum=reference_params["maximum"]) priors = bilby.core.prior.ConditionalPriorDict() - priors["a"] = bilby.core.prior.ConditionalUniform( - condition_func=a_condition_func, minimum=0, maximum=1 - ) + priors["a"] = bilby.core.prior.ConditionalUniform(condition_func=a_condition_func, minimum=0, maximum=1) priors["b"] = bilby.core.prior.LogUniform(minimum=1, maximum=10) - priors["d"] = bilby.core.prior.ConditionalUniform( - condition_func=d_condition_func, minimum=0.0, maximum=1.0 - ) + priors["d"] = bilby.core.prior.ConditionalUniform(condition_func=d_condition_func, minimum=0.0, maximum=1.0) priors["c"] = bilby.core.prior.LogUniform(minimum=1, maximum=10) priors.sample() @@ -447,7 +417,6 @@ def _tp_conditional_uniform(ref_params, period): class TestDirichletPrior(unittest.TestCase): - def setUp(self): self.priors = bilby.core.prior.DirichletPriorDict(5) diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 3970277e0..85a0597ad 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -14,7 +14,6 @@ def __init__(self, names, mus, covs, weights): class FakeJointPriorDist(bilby.core.prior.BaseJointPriorDist): - def __init__(self, names, bounds=None): super().__init__(names=names, bounds=bounds) @@ -23,19 +22,13 @@ def __init__(self, names, bounds=None): class TestPriorDict(unittest.TestCase): - def setUp(self): - - self.first_prior = bilby.core.prior.Uniform( - name="a", minimum=0, maximum=1, unit="kg", boundary=None - ) + self.first_prior = bilby.core.prior.Uniform(name="a", minimum=0, maximum=1, unit="kg", boundary=None) self.second_prior = bilby.core.prior.PowerLaw( name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None ) self.third_prior = bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m") - self.priors = dict( - mass=self.first_prior, speed=self.second_prior, length=self.third_prior - ) + self.priors = dict(mass=self.first_prior, speed=self.second_prior, length=self.third_prior) self.prior_set_from_dict = bilby.core.prior.PriorDict(dictionary=self.priors) self.default_prior_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), @@ -45,13 +38,9 @@ def setUp(self): os.path.dirname(os.path.realpath(__file__)), "prior_files/joint_prior.prior", ) - self.prior_set_from_file = bilby.core.prior.PriorDict( - filename=self.default_prior_file - ) + self.prior_set_from_file = bilby.core.prior.PriorDict(filename=self.default_prior_file) - self.joint_prior_from_file = bilby.core.prior.PriorDict( - filename=self.joint_prior_file - ) + self.joint_prior_from_file = bilby.core.prior.PriorDict(filename=self.joint_prior_file) def tearDown(self): del self.first_prior @@ -100,20 +89,12 @@ def test_read_from_file(self): latex_label="$q$", unit=None, ), - a_1=bilby.core.prior.Uniform( - name="a_1", minimum=0, maximum=0.99 - ), - a_2=bilby.core.prior.Uniform( - name="a_2", minimum=0, maximum=0.99 - ), + a_1=bilby.core.prior.Uniform(name="a_1", minimum=0, maximum=0.99), + a_2=bilby.core.prior.Uniform(name="a_2", minimum=0, maximum=0.99), tilt_1=bilby.core.prior.Sine(name="tilt_1"), tilt_2=bilby.core.prior.Sine(name="tilt_2"), - phi_12=bilby.core.prior.Uniform( - name="phi_12", minimum=0, maximum=2 * np.pi, boundary="periodic" - ), - phi_jl=bilby.core.prior.Uniform( - name="phi_jl", minimum=0, maximum=2 * np.pi, boundary="periodic" - ), + phi_12=bilby.core.prior.Uniform(name="phi_12", minimum=0, maximum=2 * np.pi, boundary="periodic"), + phi_jl=bilby.core.prior.Uniform(name="phi_jl", minimum=0, maximum=2 * np.pi, boundary="periodic"), luminosity_distance=bilby.gw.prior.UniformSourceFrame( name="luminosity_distance", minimum=1e2, @@ -122,16 +103,10 @@ def test_read_from_file(self): boundary=None, ), dec=bilby.core.prior.Cosine(name="dec"), - ra=bilby.core.prior.Uniform( - name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic" - ), + ra=bilby.core.prior.Uniform(name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic"), theta_jn=bilby.core.prior.Sine(name="theta_jn"), - psi=bilby.core.prior.Uniform( - name="psi", minimum=0, maximum=np.pi, boundary="periodic" - ), - phase=bilby.core.prior.Uniform( - name="phase", minimum=0, maximum=2 * np.pi, boundary="periodic" - ), + psi=bilby.core.prior.Uniform(name="psi", minimum=0, maximum=np.pi, boundary="periodic"), + phase=bilby.core.prior.Uniform(name="phase", minimum=0, maximum=2 * np.pi, boundary="periodic"), ) self.assertDictEqual(expected, self.prior_set_from_file) @@ -143,10 +118,12 @@ def test_read_from_file(self): testBbase = bilby.core.prior.JointPrior(dist=base_dist, name="testBbase", unit="unit") expected_joint = dict(testAfake=testAfake, testBfake=testBfake, testAbase=testAbase, testBbase=testBbase) self.assertDictEqual(expected_joint, self.joint_prior_from_file) - self.assertTrue(id(self.joint_prior_from_file["testAfake"].dist) - == id(self.joint_prior_from_file["testBfake"].dist)) - self.assertTrue(id(self.joint_prior_from_file["testAbase"].dist) - == id(self.joint_prior_from_file["testBbase"].dist)) + self.assertTrue( + id(self.joint_prior_from_file["testAfake"].dist) == id(self.joint_prior_from_file["testBfake"].dist) + ) + self.assertTrue( + id(self.joint_prior_from_file["testAbase"].dist) == id(self.joint_prior_from_file["testBbase"].dist) + ) def test_to_file(self): """ @@ -156,23 +133,16 @@ def test_to_file(self): """ expected = [ "length = DeltaFunction(peak=42, name='c', latex_label='c', unit='m')\n", - "speed = PowerLaw(alpha=3, minimum=1, maximum=2, name='b', latex_label='b', " - "unit='m/s', boundary=None)\n", - "mass = Uniform(minimum=0, maximum=1, name='a', latex_label='a', " - "unit='kg', boundary=None)\n", + "speed = PowerLaw(alpha=3, minimum=1, maximum=2, name='b', latex_label='b', unit='m/s', boundary=None)\n", + "mass = Uniform(minimum=0, maximum=1, name='a', latex_label='a', unit='kg', boundary=None)\n", ] self.prior_set_from_dict.to_file(outdir="prior_files", label="to_file_test") with open("prior_files/to_file_test.prior") as f: for i, line in enumerate(f.readlines()): - self.assertTrue( - any([sorted(line) == sorted(expect) for expect in expected]) - ) + self.assertTrue(any([sorted(line) == sorted(expect) for expect in expected])) def test_from_dict_with_string(self): - string_prior = ( - "PowerLaw(name='b', alpha=3, minimum=1, maximum=2, unit='m/s', " - "boundary=None)" - ) + string_prior = "PowerLaw(name='b', alpha=3, minimum=1, maximum=2, unit='m/s', boundary=None)" self.priors["speed"] = string_prior from_dict = bilby.core.prior.PriorDict(dictionary=self.priors) self.assertDictEqual(self.prior_set_from_dict, from_dict) @@ -183,12 +153,8 @@ def test_convert_floats_to_delta_functions(self): self.prior_set_from_dict["f"] = "unconvertable" self.prior_set_from_dict.convert_floats_to_delta_functions() expected = dict( - mass=bilby.core.prior.Uniform( - name="a", minimum=0, maximum=1, unit="kg", boundary=None - ), - speed=bilby.core.prior.PowerLaw( - name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None - ), + mass=bilby.core.prior.Uniform(name="a", minimum=0, maximum=1, unit="kg", boundary=None), + speed=bilby.core.prior.PowerLaw(name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None), length=bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m"), d=bilby.core.prior.DeltaFunction(peak=5), e=bilby.core.prior.DeltaFunction(peak=7.3), @@ -224,19 +190,19 @@ def test_prior_set_from_dict_but_using_a_string(self): unit=None, ), a_1=bilby.core.prior.Uniform( - name="a_1", minimum=0, maximum=0.99, + name="a_1", + minimum=0, + maximum=0.99, ), a_2=bilby.core.prior.Uniform( - name="a_2", minimum=0, maximum=0.99, + name="a_2", + minimum=0, + maximum=0.99, ), tilt_1=bilby.core.prior.Sine(name="tilt_1"), tilt_2=bilby.core.prior.Sine(name="tilt_2"), - phi_12=bilby.core.prior.Uniform( - name="phi_12", minimum=0, maximum=2 * np.pi, boundary="periodic" - ), - phi_jl=bilby.core.prior.Uniform( - name="phi_jl", minimum=0, maximum=2 * np.pi, boundary="periodic" - ), + phi_12=bilby.core.prior.Uniform(name="phi_12", minimum=0, maximum=2 * np.pi, boundary="periodic"), + phi_jl=bilby.core.prior.Uniform(name="phi_jl", minimum=0, maximum=2 * np.pi, boundary="periodic"), luminosity_distance=bilby.gw.prior.UniformSourceFrame( name="luminosity_distance", minimum=1e2, @@ -245,16 +211,10 @@ def test_prior_set_from_dict_but_using_a_string(self): boundary=None, ), dec=bilby.core.prior.Cosine(name="dec"), - ra=bilby.core.prior.Uniform( - name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic" - ), + ra=bilby.core.prior.Uniform(name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic"), theta_jn=bilby.core.prior.Sine(name="theta_jn"), - psi=bilby.core.prior.Uniform( - name="psi", minimum=0, maximum=np.pi, boundary="periodic" - ), - phase=bilby.core.prior.Uniform( - name="phase", minimum=0, maximum=2 * np.pi, boundary="periodic" - ), + psi=bilby.core.prior.Uniform(name="psi", minimum=0, maximum=np.pi, boundary="periodic"), + phase=bilby.core.prior.Uniform(name="phase", minimum=0, maximum=2 * np.pi, boundary="periodic"), ) ) all_keys = set(prior_set.keys()).union(set(expected.keys())) @@ -267,18 +227,14 @@ def test_dict_argument_is_not_string_or_dict(self): def test_sample_subset_correct_size(self): size = 7 - samples = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys(), size=size - ) + samples = self.prior_set_from_dict.sample_subset(keys=self.prior_set_from_dict.keys(), size=size) self.assertEqual(len(self.prior_set_from_dict), len(samples)) for key in samples: self.assertEqual(size, len(samples[key])) def test_sample_subset_correct_size_when_non_priors_in_dict(self): self.prior_set_from_dict["asdf"] = "not_a_prior" - samples = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys() - ) + samples = self.prior_set_from_dict.sample_subset(keys=self.prior_set_from_dict.keys()) self.assertEqual(len(self.prior_set_from_dict) - 1, len(samples)) def test_sample_subset_with_actual_subset(self): @@ -295,12 +251,9 @@ def test_sample_subset_constrained_as_array(self): self.assertTrue(out.shape == (len(keys), size)) def test_sample_subset_constrained(self): - def conversion_function(parameters): converted_parameters = parameters.copy() - converted_parameters["delta_mass"] = ( - parameters["mass_1"] - parameters["mass_2"] - ) + converted_parameters["delta_mass"] = parameters["mass_1"] - parameters["mass_2"] return converted_parameters N = 1_000 @@ -311,9 +264,7 @@ def conversion_function(parameters): priors1["delta_mass"] = bilby.core.prior.Constraint(minimum=-2, maximum=0) with patch("bilby.core.prior.logger.warning") as mock_warning: - samples1 = priors1.sample_subset_constrained( - keys=list(priors1.keys()), size=N - ) + samples1 = priors1.sample_subset_constrained(keys=list(priors1.keys()), size=N) self.assertEqual(len(priors1) - 1, len(samples1)) for key in samples1: self.assertEqual(N, len(samples1[key])) @@ -324,9 +275,7 @@ def conversion_function(parameters): priors2["mass_2"] = bilby.core.prior.Uniform(minimum=1, maximum=1.4) with patch("bilby.core.prior.logger.warning") as mock_warning: - samples2 = priors2.sample_subset_constrained( - keys=list(priors2.keys()), size=N - ) + samples2 = priors2.sample_subset_constrained(keys=list(priors2.keys()), size=N) self.assertEqual(len(priors2), len(samples2)) for key in samples2: self.assertEqual(N, len(samples2[key])) @@ -335,9 +284,7 @@ def conversion_function(parameters): def test_sample(self): size = 7 bilby.core.utils.random.seed(42) - samples1 = self.prior_set_from_dict.sample_subset( - keys=self.prior_set_from_dict.keys(), size=size - ) + samples1 = self.prior_set_from_dict.sample_subset(keys=self.prior_set_from_dict.keys(), size=size) bilby.core.utils.random.seed(42) samples2 = self.prior_set_from_dict.sample(size=size) self.assertEqual(set(samples1.keys()), set(samples2.keys())) @@ -346,16 +293,12 @@ def test_sample(self): def test_prob(self): samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"]) - expected = self.first_prior.prob(samples["mass"]) * self.second_prior.prob( - samples["speed"] - ) + expected = self.first_prior.prob(samples["mass"]) * self.second_prior.prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.prob(samples)) def test_ln_prob(self): samples = self.prior_set_from_dict.sample_subset(keys=["mass", "speed"]) - expected = self.first_prior.ln_prob( - samples["mass"] - ) + self.second_prior.ln_prob(samples["speed"]) + expected = self.first_prior.ln_prob(samples["mass"]) + self.second_prior.ln_prob(samples["speed"]) self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples)) def test_rescale(self): @@ -367,11 +310,7 @@ def test_rescale(self): ] self.assertListEqual( sorted(expected), - sorted( - self.prior_set_from_dict.rescale( - keys=self.prior_set_from_dict.keys(), theta=theta - ) - ), + sorted(self.prior_set_from_dict.rescale(keys=self.prior_set_from_dict.keys(), theta=theta)), ) def test_cdf(self): @@ -382,10 +321,9 @@ def test_cdf(self): """ sample = self.prior_set_from_dict.sample() original = np.array(list(sample.values())) - new = np.array(self.prior_set_from_dict.rescale( - sample.keys(), - self.prior_set_from_dict.cdf(sample=sample).values() - )) + new = np.array( + self.prior_set_from_dict.rescale(sample.keys(), self.prior_set_from_dict.cdf(sample=sample).values()) + ) self.assertLess(max(abs(original - new)), 1e-10) def test_redundancy(self): @@ -421,9 +359,7 @@ def setUp(self): os.path.dirname(os.path.realpath(__file__)), "prior_files/GW150914_testing_skymap.fits", ) - hp_dist = bilby.gw.prior.HealPixMapPriorDist( - hp_map_file, names=["testra", "testdec"] - ) + hp_dist = bilby.gw.prior.HealPixMapPriorDist(hp_map_file, names=["testra", "testdec"]) hp_3d_dist = bilby.gw.prior.HealPixMapPriorDist( hp_map_file, names=["testRA", "testDEC", "testdistance"], distance=True ) @@ -433,27 +369,13 @@ def setUp(self): aa=bilby.core.prior.DeltaFunction(name="test", unit="unit", peak=1), bb=bilby.core.prior.Gaussian(name="test", unit="unit", mu=0, sigma=1), cc=bilby.core.prior.Normal(name="test", unit="unit", mu=0, sigma=1), - dd=bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=0, minimum=0, maximum=1 - ), - ee=bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=-1, minimum=0.5, maximum=1 - ), - ff=bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=2, minimum=1, maximum=1e2 - ), - gg=bilby.core.prior.Uniform( - name="test", unit="unit", minimum=0, maximum=1 - ), - hh=bilby.core.prior.LogUniform( - name="test", unit="unit", minimum=5e0, maximum=1e2 - ), - ii=bilby.gw.prior.UniformComovingVolume( - name="redshift", minimum=0.1, maximum=1.0 - ), - jj=bilby.gw.prior.UniformSourceFrame( - name="luminosity_distance", minimum=1.0, maximum=1000.0 - ), + dd=bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=0, minimum=0, maximum=1), + ee=bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=-1, minimum=0.5, maximum=1), + ff=bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=2, minimum=1, maximum=1e2), + gg=bilby.core.prior.Uniform(name="test", unit="unit", minimum=0, maximum=1), + hh=bilby.core.prior.LogUniform(name="test", unit="unit", minimum=5e0, maximum=1e2), + ii=bilby.gw.prior.UniformComovingVolume(name="redshift", minimum=0.1, maximum=1.0), + jj=bilby.gw.prior.UniformSourceFrame(name="luminosity_distance", minimum=1.0, maximum=1000.0), kk=bilby.core.prior.Sine(name="test", unit="unit"), ll=bilby.core.prior.Cosine(name="test", unit="unit"), m=bilby.core.prior.Interped( @@ -464,77 +386,41 @@ def setUp(self): minimum=3, maximum=5, ), - nn=bilby.core.prior.TruncatedGaussian( - name="test", unit="unit", mu=1, sigma=0.4, minimum=-1, maximum=1 - ), - oo=bilby.core.prior.TruncatedNormal( - name="test", unit="unit", mu=1, sigma=0.4, minimum=-1, maximum=1 - ), + nn=bilby.core.prior.TruncatedGaussian(name="test", unit="unit", mu=1, sigma=0.4, minimum=-1, maximum=1), + oo=bilby.core.prior.TruncatedNormal(name="test", unit="unit", mu=1, sigma=0.4, minimum=-1, maximum=1), pp=bilby.core.prior.HalfGaussian(name="test", unit="unit", sigma=1), qq=bilby.core.prior.HalfNormal(name="test", unit="unit", sigma=1), rr=bilby.core.prior.LogGaussian(name="test", unit="unit", mu=0, sigma=1), ss=bilby.core.prior.LogNormal(name="test", unit="unit", mu=0, sigma=1), tt=bilby.core.prior.Exponential(name="test", unit="unit", mu=1), - uu=bilby.core.prior.StudentT( - name="test", unit="unit", df=3, mu=0, scale=1 - ), + uu=bilby.core.prior.StudentT(name="test", unit="unit", df=3, mu=0, scale=1), vv=bilby.core.prior.Beta(name="test", unit="unit", alpha=2.0, beta=2.0), xx=bilby.core.prior.Logistic(name="test", unit="unit", mu=0, scale=1), yy=bilby.core.prior.Cauchy(name="test", unit="unit", alpha=0, beta=1), - zz=bilby.core.prior.Lorentzian( - name="test", unit="unit", alpha=0, beta=1 - ), + zz=bilby.core.prior.Lorentzian(name="test", unit="unit", alpha=0, beta=1), a_=bilby.core.prior.Gamma(name="test", unit="unit", k=1, theta=1), ab=bilby.core.prior.ChiSquared(name="test", unit="unit", nu=2), ac=bilby.gw.prior.AlignedSpin(name="test", unit="unit"), - testa=bilby.core.prior.MultivariateGaussian( - dist=mvg, name="testa", unit="unit" - ), - testb=bilby.core.prior.MultivariateGaussian( - dist=mvg, name="testb", unit="unit" - ), - testA=bilby.core.prior.MultivariateNormal( - dist=mvn, name="testA", unit="unit" - ), - testB=bilby.core.prior.MultivariateNormal( - dist=mvn, name="testB", unit="unit" - ), - testAsubclass=bilby.core.prior.JointPrior( - dist=mvn_subclass, name="testAsubclass", unit="unit" - ), - testBsubclass=bilby.core.prior.JointPrior( - dist=mvn_subclass, name="testBsubclass", unit="unit" - ), - testAfake=bilby.core.prior.JointPrior( - dist=fake_joint_prior, name="testAfake", unit="unit" - ), - testBfake=bilby.core.prior.JointPrior( - dist=fake_joint_prior, name="testBfake", unit="unit" - ), - testra=bilby.gw.prior.HealPixPrior( - dist=hp_dist, name="testra", unit="unit" - ), - testdec=bilby.gw.prior.HealPixPrior( - dist=hp_dist, name="testdec", unit="unit" - ), - testRA=bilby.gw.prior.HealPixPrior( - dist=hp_3d_dist, name="testRA", unit="unit" - ), - testDEC=bilby.gw.prior.HealPixPrior( - dist=hp_3d_dist, name="testDEC", unit="unit" - ), - testdistance=bilby.gw.prior.HealPixPrior( - dist=hp_3d_dist, name="testdistance", unit="unit" - ), + testa=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit"), + testb=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"), + testA=bilby.core.prior.MultivariateNormal(dist=mvn, name="testA", unit="unit"), + testB=bilby.core.prior.MultivariateNormal(dist=mvn, name="testB", unit="unit"), + testAsubclass=bilby.core.prior.JointPrior(dist=mvn_subclass, name="testAsubclass", unit="unit"), + testBsubclass=bilby.core.prior.JointPrior(dist=mvn_subclass, name="testBsubclass", unit="unit"), + testAfake=bilby.core.prior.JointPrior(dist=fake_joint_prior, name="testAfake", unit="unit"), + testBfake=bilby.core.prior.JointPrior(dist=fake_joint_prior, name="testBfake", unit="unit"), + testra=bilby.gw.prior.HealPixPrior(dist=hp_dist, name="testra", unit="unit"), + testdec=bilby.gw.prior.HealPixPrior(dist=hp_dist, name="testdec", unit="unit"), + testRA=bilby.gw.prior.HealPixPrior(dist=hp_3d_dist, name="testRA", unit="unit"), + testDEC=bilby.gw.prior.HealPixPrior(dist=hp_3d_dist, name="testDEC", unit="unit"), + testdistance=bilby.gw.prior.HealPixPrior(dist=hp_3d_dist, name="testdistance", unit="unit"), ) ) def test_read_write_to_json(self): - """ Interped prior is removed as there is numerical error in the recovered prior.""" + """Interped prior is removed as there is numerical error in the recovered prior.""" self.priors.to_json(outdir="prior_files", label="json_test") - new_priors = bilby.core.prior.PriorDict.from_json( - filename="prior_files/json_test_prior.json" - ) + new_priors = bilby.core.prior.PriorDict.from_json(filename="prior_files/json_test_prior.json") old_interped = self.priors.pop("m") new_interped = new_priors.pop("m") self.assertDictEqual(self.priors, new_priors) @@ -577,9 +463,7 @@ def test_load_prior_with_function(self): class TestCreateDefaultPrior(unittest.TestCase): def test_none_behaviour(self): - self.assertIsNone( - bilby.core.prior.create_default_prior(name="name", default_priors_file=None) - ) + self.assertIsNone(bilby.core.prior.create_default_prior(name="name", default_priors_file=None)) def test_bbh_params(self): prior_file = os.path.join( @@ -590,9 +474,7 @@ def test_bbh_params(self): for prior in prior_set: self.assertEqual( prior_set[prior], - bilby.core.prior.create_default_prior( - name=prior, default_priors_file=prior_file - ), + bilby.core.prior.create_default_prior(name=prior, default_priors_file=prior_file), ) def test_unknown_prior(self): @@ -600,11 +482,7 @@ def test_unknown_prior(self): os.path.dirname(os.path.realpath(__file__)), "prior_files/precessing_spins_bbh.prior", ) - self.assertIsNone( - bilby.core.prior.create_default_prior( - name="name", default_priors_file=prior_file - ) - ) + self.assertIsNone(bilby.core.prior.create_default_prior(name="name", default_priors_file=prior_file)) class TestFillPrior(unittest.TestCase): @@ -645,11 +523,9 @@ def test_without_available_default_priors_no_prior_is_set(self): class TestLoadPriorWithCosmologicalParameters(unittest.TestCase): - def test_load(self): prior_file = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "prior_files/prior_with_cosmo_params.prior" + os.path.dirname(os.path.realpath(__file__)), "prior_files/prior_with_cosmo_params.prior" ) prior_dict = bilby.gw.prior.BBHPriorDict(filename=prior_file) cosmology = prior_dict["luminosity_distance"].cosmology diff --git a/test/core/prior/interpolated_test.py b/test/core/prior/interpolated_test.py index 9b4a926f8..fba0404ae 100644 --- a/test/core/prior/interpolated_test.py +++ b/test/core/prior/interpolated_test.py @@ -1,5 +1,4 @@ import unittest - if __name__ == "__main__": unittest.main() diff --git a/test/core/prior/joint_test.py b/test/core/prior/joint_test.py index c99373b00..fc0e4da5b 100644 --- a/test/core/prior/joint_test.py +++ b/test/core/prior/joint_test.py @@ -1,8 +1,9 @@ import unittest -import bilby import numpy as np +import bilby + class TestMultivariateGaussianDistFromRepr(unittest.TestCase): def test_mvg_from_repr(self): @@ -34,9 +35,7 @@ def test_mvg_from_repr(self): for d1, d2 in zip(fromstr.__getattribute__(key), item): self.assertTrue(type(d1) == type(d2)) # noqa: E721 elif isinstance(item, (list, tuple, np.ndarray)): - self.assertTrue( - np.all(np.array(item) == np.array(fromstr.__getattribute__(key))) - ) + self.assertTrue(np.all(np.array(item) == np.array(fromstr.__getattribute__(key)))) class TestMultivariateGaussianDistParameterScales(unittest.TestCase): @@ -49,10 +48,8 @@ def _test_mvg_ln_prob_diff_expected(self, mvg, mus, sigmas, corrcoefs): for i in np.ndindex(4, 4, 4): ln_prob_mus_sigmas_d_i = mvg.ln_prob(mus + sigmas * (d @ i)) diff_ln_prob = ln_prob_mus - ln_prob_mus_sigmas_d_i - diff_ln_prob_expected = 0.5 * np.sum(np.array(i)**2) - self.assertTrue( - np.allclose(diff_ln_prob, diff_ln_prob_expected) - ) + diff_ln_prob_expected = 0.5 * np.sum(np.array(i) ** 2) + self.assertTrue(np.allclose(diff_ln_prob, diff_ln_prob_expected)) def test_mvg_unit_scales(self): # test using order-unity standard deviations and correlations @@ -60,7 +57,7 @@ def test_mvg_unit_scales(self): corrcoefs = np.identity(3) mus = np.array([3, 1, 2]) mvg = bilby.core.prior.MultivariateGaussianDist( - names=['a', 'b', 'c'], + names=["a", "b", "c"], mus=mus, sigmas=sigmas, corrcoefs=corrcoefs, @@ -74,11 +71,13 @@ def test_mvg_cw_scales(self): # parameters of a continuous wave signal T = 365.25 * 86400 rho = 10 - sigmas = np.array([ - 5 * np.sqrt(3) / (2 * np.pi * T * rho), - 6 * np.sqrt(5) / (np.pi * T**2 * rho), - 60 * np.sqrt(7) / (np.pi * T**3 * rho) - ]) + sigmas = np.array( + [ + 5 * np.sqrt(3) / (2 * np.pi * T * rho), + 6 * np.sqrt(5) / (np.pi * T**2 * rho), + 60 * np.sqrt(7) / (np.pi * T**3 * rho), + ] + ) corrcoefs = np.identity(3) corrcoefs[0, 2] = corrcoefs[2, 0] = -np.sqrt(21) / 5 diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 17d360d0c..93d4f27e1 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -1,14 +1,15 @@ -import bilby +import os import unittest + import numpy as np -import os import scipy.stats as ss from scipy.integrate import trapezoid +import bilby + class TestPriorClasses(unittest.TestCase): def setUp(self): - # set multivariate Gaussian mvg = bilby.core.prior.MultivariateGaussianDist( names=["testa", "testb"], @@ -26,9 +27,7 @@ def setUp(self): os.path.dirname(os.path.realpath(__file__)), "prior_files/GW150914_testing_skymap.fits", ) - hp_dist = bilby.gw.prior.HealPixMapPriorDist( - hp_map_file, names=["testra", "testdec"] - ) + hp_dist = bilby.gw.prior.HealPixMapPriorDist(hp_map_file, names=["testra", "testdec"]) hp_3d_dist = bilby.gw.prior.HealPixMapPriorDist( hp_map_file, names=["testra", "testdec", "testdistance"], distance=True ) @@ -40,25 +39,13 @@ def condition_func(reference_params, test_param): bilby.core.prior.DeltaFunction(name="test", unit="unit", peak=1), bilby.core.prior.Gaussian(name="test", unit="unit", mu=0, sigma=1), bilby.core.prior.Normal(name="test", unit="unit", mu=0, sigma=1), - bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=0, minimum=0, maximum=1 - ), - bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=-1, minimum=0.5, maximum=1 - ), - bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=2, minimum=1, maximum=1e2 - ), + bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=0, minimum=0, maximum=1), + bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=-1, minimum=0.5, maximum=1), + bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=2, minimum=1, maximum=1e2), bilby.core.prior.Uniform(name="test", unit="unit", minimum=0, maximum=1), - bilby.core.prior.LogUniform( - name="test", unit="unit", minimum=5e0, maximum=1e2 - ), - bilby.gw.prior.UniformComovingVolume( - name="redshift", minimum=0.1, maximum=1.0 - ), - bilby.gw.prior.UniformSourceFrame( - name="redshift", minimum=0.1, maximum=1.0 - ), + bilby.core.prior.LogUniform(name="test", unit="unit", minimum=5e0, maximum=1e2), + bilby.gw.prior.UniformComovingVolume(name="redshift", minimum=0.1, maximum=1.0), + bilby.gw.prior.UniformSourceFrame(name="redshift", minimum=0.1, maximum=1.0), bilby.core.prior.Sine(name="test", unit="unit"), bilby.core.prior.Cosine(name="test", unit="unit"), bilby.core.prior.Interped( @@ -69,12 +56,8 @@ def condition_func(reference_params, test_param): minimum=3, maximum=5, ), - bilby.core.prior.TruncatedGaussian( - name="test", unit="unit", mu=1, sigma=0.4, minimum=-1, maximum=1 - ), - bilby.core.prior.TruncatedNormal( - name="test", unit="unit", mu=1, sigma=0.4, minimum=-1, maximum=1 - ), + bilby.core.prior.TruncatedGaussian(name="test", unit="unit", mu=1, sigma=0.4, minimum=-1, maximum=1), + bilby.core.prior.TruncatedNormal(name="test", unit="unit", mu=1, sigma=0.4, minimum=-1, maximum=1), bilby.core.prior.HalfGaussian(name="test", unit="unit", sigma=1), bilby.core.prior.HalfNormal(name="test", unit="unit", sigma=1), bilby.core.prior.LogGaussian(name="test", unit="unit", mu=0, sigma=1), @@ -91,16 +74,10 @@ def condition_func(reference_params, test_param): bilby.core.prior.WeightedDiscreteValues( name="test", unit="unit", values=[1, 2, 3, 4], weights=[1, 2, 3, 4] ), - bilby.core.prior.DiscreteValues( - name="test", unit="unit", values=[1, 2, 3, 4] - ), - bilby.core.prior.WeightedCategorical( - name="test", unit="unit", ncategories=4, weights=[1, 2, 3, 4] - ), + bilby.core.prior.DiscreteValues(name="test", unit="unit", values=[1, 2, 3, 4]), + bilby.core.prior.WeightedCategorical(name="test", unit="unit", ncategories=4, weights=[1, 2, 3, 4]), bilby.core.prior.Categorical(name="test", unit="unit", ncategories=5), - bilby.core.prior.SymmetricLogUniform( - name="test", unit="unit", minimum=1e-2, maximum=1e2 - ), + bilby.core.prior.SymmetricLogUniform(name="test", unit="unit", minimum=1e-2, maximum=1e2), bilby.gw.prior.AlignedSpin(name="test", unit="unit"), bilby.gw.prior.AlignedSpin( a_prior=bilby.core.prior.Beta(alpha=2.0, beta=2.0), @@ -113,9 +90,7 @@ def condition_func(reference_params, test_param): bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"), bilby.core.prior.MultivariateNormal(dist=mvn, name="testa", unit="unit"), bilby.core.prior.MultivariateNormal(dist=mvn, name="testb", unit="unit"), - bilby.core.prior.ConditionalDeltaFunction( - condition_func=condition_func, name="test", unit="unit", peak=1 - ), + bilby.core.prior.ConditionalDeltaFunction(condition_func=condition_func, name="test", unit="unit", peak=1), bilby.core.prior.ConditionalGaussian( condition_func=condition_func, name="test", unit="unit", mu=0, sigma=1 ), @@ -184,12 +159,8 @@ def condition_func(reference_params, test_param): bilby.gw.prior.ConditionalUniformSourceFrame( condition_func=condition_func, name="redshift", minimum=0.1, maximum=1.0 ), - bilby.core.prior.ConditionalSine( - condition_func=condition_func, name="test", unit="unit" - ), - bilby.core.prior.ConditionalCosine( - condition_func=condition_func, name="test", unit="unit" - ), + bilby.core.prior.ConditionalSine(condition_func=condition_func, name="test", unit="unit"), + bilby.core.prior.ConditionalCosine(condition_func=condition_func, name="test", unit="unit"), bilby.core.prior.ConditionalTruncatedGaussian( condition_func=condition_func, name="test", @@ -199,15 +170,11 @@ def condition_func(reference_params, test_param): minimum=-1, maximum=1, ), - bilby.core.prior.ConditionalHalfGaussian( - condition_func=condition_func, name="test", unit="unit", sigma=1 - ), + bilby.core.prior.ConditionalHalfGaussian(condition_func=condition_func, name="test", unit="unit", sigma=1), bilby.core.prior.ConditionalLogNormal( condition_func=condition_func, name="test", unit="unit", mu=0, sigma=1 ), - bilby.core.prior.ConditionalExponential( - condition_func=condition_func, name="test", unit="unit", mu=1 - ), + bilby.core.prior.ConditionalExponential(condition_func=condition_func, name="test", unit="unit", mu=1), bilby.core.prior.ConditionalStudentT( condition_func=condition_func, name="test", @@ -229,19 +196,13 @@ def condition_func(reference_params, test_param): bilby.core.prior.ConditionalCauchy( condition_func=condition_func, name="test", unit="unit", alpha=0, beta=1 ), - bilby.core.prior.ConditionalGamma( - condition_func=condition_func, name="test", unit="unit", k=1, theta=1 - ), - bilby.core.prior.ConditionalChiSquared( - condition_func=condition_func, name="test", unit="unit", nu=2 - ), + bilby.core.prior.ConditionalGamma(condition_func=condition_func, name="test", unit="unit", k=1, theta=1), + bilby.core.prior.ConditionalChiSquared(condition_func=condition_func, name="test", unit="unit", nu=2), bilby.gw.prior.HealPixPrior(dist=hp_dist, name="testra", unit="unit"), bilby.gw.prior.HealPixPrior(dist=hp_dist, name="testdec", unit="unit"), bilby.gw.prior.HealPixPrior(dist=hp_3d_dist, name="testra", unit="unit"), bilby.gw.prior.HealPixPrior(dist=hp_3d_dist, name="testdec", unit="unit"), - bilby.gw.prior.HealPixPrior( - dist=hp_3d_dist, name="testdistance", unit="unit" - ), + bilby.gw.prior.HealPixPrior(dist=hp_3d_dist, name="testdistance", unit="unit"), ] def tearDown(self): @@ -288,16 +249,12 @@ def test_many_sample_rescaling(self): if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_rescale(): continue - self.assertTrue( - all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)) - ) + self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum))) def test_least_recently_sampled(self): for prior in self.priors: least_recently_sampled_expected = prior.sample() - self.assertEqual( - least_recently_sampled_expected, prior.least_recently_sampled - ) + self.assertEqual(least_recently_sampled_expected, prior.least_recently_sampled) def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" @@ -306,9 +263,7 @@ def test_sampling_single(self): # SymmetricLogUniform has support down to -maximum continue single_sample = prior.sample() - self.assertTrue( - (single_sample >= prior.minimum) & (single_sample <= prior.maximum) - ) + self.assertTrue((single_sample >= prior.minimum) & (single_sample <= prior.maximum)) def test_sampling_many(self): """Test that sampling from the prior always returns values within its domain.""" @@ -317,18 +272,13 @@ def test_sampling_many(self): # SymmetricLogUniform has support down to -maximum continue many_samples = prior.sample(5000) - self.assertTrue( - (all(many_samples >= prior.minimum)) - & (all(many_samples <= prior.maximum)) - ) + self.assertTrue((all(many_samples >= prior.minimum)) & (all(many_samples <= prior.maximum))) def test_probability_above_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: if prior.maximum != np.inf: - outside_domain = np.linspace( - prior.maximum + 1, prior.maximum + 1e4, 1000 - ) + outside_domain = np.linspace(prior.maximum + 1, prior.maximum + 1e4, 1000) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_request(): prior.dist.requested_parameters[prior.name] = outside_domain @@ -342,9 +292,7 @@ def test_probability_below_domain(self): # SymmetricLogUniform has support down to -maximum continue if prior.minimum != -np.inf: - outside_domain = np.linspace( - prior.minimum - 1e4, prior.minimum - 1, 1000 - ) + outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_request(): prior.dist.requested_parameters[prior.name] = outside_domain @@ -362,9 +310,7 @@ def test_prob_and_ln_prob(self): if not bilby.core.prior.JointPrior in prior.__class__.__mro__: # noqa # due to the way that the Multivariate Gaussian prior must sequentially call # the prob and ln_prob functions, it must be ignored in this test. - self.assertAlmostEqual( - np.log(prior.prob(sample)), prior.ln_prob(sample), 12 - ) + self.assertAlmostEqual(np.log(prior.prob(sample)), prior.ln_prob(sample), 12) def test_many_prob_and_many_ln_prob(self): for prior in self.priors: @@ -400,9 +346,7 @@ def test_cdf_is_inverse_of_rescaling(self): def test_cdf_one_above_domain(self): for prior in self.priors: if prior.maximum != np.inf: - outside_domain = np.linspace( - prior.maximum + 1, prior.maximum + 1e4, 1000 - ) + outside_domain = np.linspace(prior.maximum + 1, prior.maximum + 1e4, 1000) self.assertTrue(all(prior.cdf(outside_domain) == 1)) def test_cdf_zero_below_domain(self): @@ -410,15 +354,10 @@ def test_cdf_zero_below_domain(self): if isinstance(prior, bilby.core.prior.analytical.SymmetricLogUniform): # SymmetricLogUniform has support down to -maximum continue - if ( - bilby.core.prior.JointPrior in prior.__class__.__mro__ - and prior.maximum == np.inf - ): + if bilby.core.prior.JointPrior in prior.__class__.__mro__ and prior.maximum == np.inf: continue if prior.minimum != -np.inf: - outside_domain = np.linspace( - prior.minimum - 1e4, prior.minimum - 1, 1000 - ) + outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000) self.assertTrue(all(np.nan_to_num(prior.cdf(outside_domain)) == 0)) def test_cdf_float_with_float_input(self): @@ -439,10 +378,10 @@ def test_studentt_fail(self): def test_beta_fail(self): with self.assertRaises(ValueError): - bilby.core.prior.Beta(name="test", unit="unit", alpha=-2.0, beta=2.0), + (bilby.core.prior.Beta(name="test", unit="unit", alpha=-2.0, beta=2.0),) with self.assertRaises(ValueError): - bilby.core.prior.Beta(name="test", unit="unit", alpha=2.0, beta=-2.0), + (bilby.core.prior.Beta(name="test", unit="unit", alpha=2.0, beta=-2.0),) def test_multivariate_gaussian_fail(self): with self.assertRaises(ValueError): @@ -450,19 +389,13 @@ def test_multivariate_gaussian_fail(self): bilby.core.prior.MultivariateGaussianDist(["a", "b"], bounds=[(-1.0, 1.0)]) with self.assertRaises(ValueError): # bounds has lower value greater than upper - bilby.core.prior.MultivariateGaussianDist( - ["a", "b"], bounds=[(-1.0, 1.0), (1.0, -1)] - ) + bilby.core.prior.MultivariateGaussianDist(["a", "b"], bounds=[(-1.0, 1.0), (1.0, -1)]) with self.assertRaises(TypeError): # bound is not a list/tuple - bilby.core.prior.MultivariateGaussianDist( - ["a", "b"], bounds=[(-1.0, 1.0), 2] - ) + bilby.core.prior.MultivariateGaussianDist(["a", "b"], bounds=[(-1.0, 1.0), 2]) with self.assertRaises(ValueError): # bound contains too many values - bilby.core.prior.MultivariateGaussianDist( - ["a", "b"], bounds=[(-1.0, 1.0, 4), 2] - ) + bilby.core.prior.MultivariateGaussianDist(["a", "b"], bounds=[(-1.0, 1.0, 4), 2]) with self.assertRaises(ValueError): # means is not a list bilby.core.prior.MultivariateGaussianDist(["a", "b"], mus=1.0) @@ -480,19 +413,13 @@ def test_multivariate_gaussian_fail(self): bilby.core.prior.MultivariateGaussianDist(["a", "b"], weights=[0.5, 0.5]) with self.assertRaises(ValueError): # not enough modes set - bilby.core.prior.MultivariateGaussianDist( - ["a", "b"], mus=[[1.0, 2.0]], nmodes=2 - ) + bilby.core.prior.MultivariateGaussianDist(["a", "b"], mus=[[1.0, 2.0]], nmodes=2) with self.assertRaises(ValueError): # covariance is the wrong shape - bilby.core.prior.MultivariateGaussianDist( - ["a", "b"], covs=np.array([[[1.0, 1.0], [1.0, 1.0]]]) - ) + bilby.core.prior.MultivariateGaussianDist(["a", "b"], covs=np.array([[[1.0, 1.0], [1.0, 1.0]]])) with self.assertRaises(ValueError): # covariance is the wrong shape - bilby.core.prior.MultivariateGaussianDist( - ["a", "b"], covs=np.array([[[1.0, 1.0]]]) - ) + bilby.core.prior.MultivariateGaussianDist(["a", "b"], covs=np.array([[[1.0, 1.0]]])) with self.assertRaises(ValueError): # correlation coefficient matrix is the wrong shape bilby.core.prior.MultivariateGaussianDist( @@ -502,9 +429,7 @@ def test_multivariate_gaussian_fail(self): ) with self.assertRaises(ValueError): # correlation coefficient matrix is the wrong shape - bilby.core.prior.MultivariateGaussianDist( - ["a", "b"], sigmas=[1.0, 1.0], corrcoefs=np.array([[[1.0, 1.0]]]) - ) + bilby.core.prior.MultivariateGaussianDist(["a", "b"], sigmas=[1.0, 1.0], corrcoefs=np.array([[[1.0, 1.0]]])) with self.assertRaises(ValueError): # correlation coefficient has non-unity diagonal value bilby.core.prior.MultivariateGaussianDist( @@ -545,9 +470,7 @@ def test_multivariate_gaussian_covariance(self): corrcoef = np.array([[1.0, 0.5], [0.5, 1.0]]) sigma = [2.0, 2.0] - mvg = bilby.core.prior.MultivariateGaussianDist( - ["a", "b"], corrcoefs=corrcoef, sigmas=sigma - ) + mvg = bilby.core.prior.MultivariateGaussianDist(["a", "b"], corrcoefs=corrcoef, sigmas=sigma) self.assertTrue(np.allclose(mvg.corrcoefs[0], corrcoef)) self.assertTrue(np.allclose(mvg.sigmas[0], sigma)) self.assertTrue(np.allclose(np.diag(mvg.covs[0]), np.square(sigma))) @@ -580,12 +503,8 @@ def test_probability_surrounding_domain(self): # SymmetricLogUniform has support down to -maximum continue surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) - indomain = (surround_domain >= prior.minimum) | ( - surround_domain <= prior.maximum - ) - outdomain = (surround_domain < prior.minimum) | ( - surround_domain > prior.maximum - ) + indomain = (surround_domain >= prior.minimum) | (surround_domain <= prior.maximum) + outdomain = (surround_domain < prior.minimum) | (surround_domain > prior.maximum) if bilby.core.prior.JointPrior in prior.__class__.__mro__: if not prior.dist.filled_request(): continue @@ -667,9 +586,7 @@ def test_accuracy(self): scipy_lnprob = ss.t.logpdf(domain, 3, loc=0, scale=1) scipy_cdf = ss.t.cdf(domain, 3, loc=0, scale=1) scipy_rescale = ss.t.ppf(rescale_domain, 3, loc=0, scale=1) - elif isinstance(prior, bilby.core.prior.Gamma) and not isinstance( - prior, bilby.core.prior.ChiSquared - ): + elif isinstance(prior, bilby.core.prior.Gamma) and not isinstance(prior, bilby.core.prior.ChiSquared): domain = np.linspace(0.0, 1e2, 5000) scipy_prob = ss.gamma.pdf(domain, 1, loc=0, scale=1) scipy_lnprob = ss.gamma.logpdf(domain, 1, loc=0, scale=1) @@ -735,9 +652,7 @@ def test_accuracy(self): np.testing.assert_almost_equal(prior.prob(domain), scipy_prob) np.testing.assert_almost_equal(prior.ln_prob(domain), scipy_lnprob) np.testing.assert_almost_equal(prior.cdf(domain), scipy_cdf) - np.testing.assert_almost_equal( - prior.rescale(rescale_domain), scipy_rescale - ) + np.testing.assert_almost_equal(prior.rescale(rescale_domain), scipy_rescale) def test_unit_setting(self): for prior in self.priors: @@ -755,21 +670,13 @@ def test_eq_different_classes(self): self.assertNotEqual(self.priors[i], self.priors[j]) def test_eq_other_condition(self): - prior_1 = bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=0, minimum=0, maximum=1 - ) - prior_2 = bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=0, minimum=0, maximum=1.5 - ) + prior_1 = bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=0, minimum=0, maximum=1) + prior_2 = bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=0, minimum=0, maximum=1.5) self.assertNotEqual(prior_1, prior_2) def test_eq_different_keys(self): - prior_1 = bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=0, minimum=0, maximum=1 - ) - prior_2 = bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=0, minimum=0, maximum=1 - ) + prior_1 = bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=0, minimum=0, maximum=1) + prior_2 = bilby.core.prior.PowerLaw(name="test", unit="unit", alpha=0, minimum=0, maximum=1) prior_2.other_key = 5 self.assertNotEqual(prior_1, prior_2) diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py index d2cdcc55a..71ec3fb71 100644 --- a/test/core/prior/slabspike_test.py +++ b/test/core/prior/slabspike_test.py @@ -1,23 +1,37 @@ -import numpy as np import unittest +import numpy as np + import bilby +from bilby.core.prior.analytical import ( + Beta, + Cauchy, + ChiSquared, + Cosine, + Exponential, + Gamma, + Gaussian, + HalfGaussian, + Logistic, + LogNormal, + LogUniform, + PowerLaw, + Sine, + StudentT, + TruncatedGaussian, + Uniform, +) from bilby.core.prior.slabspike import SlabSpikePrior -from bilby.core.prior.analytical import Uniform, PowerLaw, LogUniform, TruncatedGaussian, \ - Beta, Gaussian, Cosine, Sine, HalfGaussian, LogNormal, Exponential, StudentT, Logistic, \ - Cauchy, Gamma, ChiSquared class TestSlabSpikePrior(unittest.TestCase): - def setUp(self): self.minimum = 0 self.maximum = 1 self.spike_loc = 0.5 self.spike_height = 0.3 self.slab = bilby.core.prior.Prior(minimum=self.minimum, maximum=self.maximum) - self.prior = SlabSpikePrior( - slab=self.slab, spike_location=self.spike_loc, spike_height=self.spike_height) + self.prior = SlabSpikePrior(slab=self.slab, spike_location=self.spike_loc, spike_height=self.spike_height) def tearDown(self): del self.minimum @@ -61,7 +75,6 @@ def test_set_spike_height_domain_edge(self): class TestSlabSpikeClasses(unittest.TestCase): - def setUp(self): self.minimum = 0.4 self.maximum = 2.4 @@ -83,15 +96,20 @@ def setUp(self): StudentT(df=2), Logistic(mu=2, scale=1), Cauchy(alpha=1, beta=2), - Gamma(k=1, theta=1.), - ChiSquared(nu=2)] - self.slab_spikes = [SlabSpikePrior(slab, spike_height=self.spike_height, spike_location=self.spike_loc) - for slab in self.slabs] + Gamma(k=1, theta=1.0), + ChiSquared(nu=2), + ] + self.slab_spikes = [ + SlabSpikePrior(slab, spike_height=self.spike_height, spike_location=self.spike_loc) for slab in self.slabs + ] self.test_nodes_finite_support = np.linspace(self.minimum, self.maximum, 1000) self.test_nodes_infinite_support = np.linspace(-10, 10, 1000) - self.test_nodes = [self.test_nodes_finite_support - if np.isinf(slab.minimum) or np.isinf(slab.maximum) - else self.test_nodes_finite_support for slab in self.slabs] + self.test_nodes = [ + self.test_nodes_finite_support + if np.isinf(slab.minimum) or np.isinf(slab.maximum) + else self.test_nodes_finite_support + for slab in self.slabs + ] def tearDown(self): del self.minimum @@ -187,8 +205,9 @@ def test_rescale_below_spike(self): def test_rescale_at_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): - vals = np.linspace(slab_spike.inverse_cdf_below_spike, - slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000) + vals = np.linspace( + slab_spike.inverse_cdf_below_spike, slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000 + ) expected = np.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) @@ -196,7 +215,6 @@ def test_rescale_at_spike(self): def test_rescale_above_spike(self): for slab, slab_spike in zip(self.slabs, self.slab_spikes): vals = np.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000) - expected = np.ones(len(vals)) * slab.rescale( - (vals - self.spike_height) / slab_spike.slab_fraction) + expected = np.ones(len(vals)) * slab.rescale((vals - self.spike_height) / slab_spike.slab_fraction) actual = slab_spike.rescale(vals) self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) diff --git a/test/core/result_test.py b/test/core/result_test.py index 23ba8e6b5..1d2e50a93 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -1,20 +1,20 @@ +import json +import os +import shutil import unittest +from unittest.mock import patch + import numpy as np import pandas as pd -import shutil -import os -import json import parameterized import pytest -from unittest.mock import patch import bilby -from bilby.core.result import ResultError, FileLoadError +from bilby.core.result import FileLoadError, ResultError from bilby.core.utils import logger class TestJson(unittest.TestCase): - def setUp(self): self.encoder = bilby.core.utils.BilbyJsonEncoder self.decoder = bilby.core.utils.decode_bilby_json @@ -54,7 +54,6 @@ def test_dataframe_encoding(self): class TestResult(unittest.TestCase): - @pytest.fixture(autouse=True) def init_outdir(self, tmp_path): # Use pytest's tmp_path fixture to create a temporary directory @@ -88,9 +87,7 @@ def setUp(self): ) n = 100 - posterior = pd.DataFrame( - dict(x=np.random.normal(0, 1, n), y=np.random.normal(0, 1, n)) - ) + posterior = pd.DataFrame(dict(x=np.random.normal(0, 1, n), y=np.random.normal(0, 1, n))) result.posterior = posterior result.log_evidence = 10 result.log_evidence_err = 11 @@ -113,7 +110,7 @@ def test_result_file_name_default(self): label = "label" self.assertEqual( bilby.core.result.result_file_name(outdir, label), - "{}/{}_result.json".format(outdir, label), + f"{outdir}/{label}_result.json", ) def test_result_file_name_hdf5(self): @@ -121,7 +118,7 @@ def test_result_file_name_hdf5(self): label = "label" self.assertEqual( bilby.core.result.result_file_name(outdir, label, extension="hdf5"), - "{}/{}_result.hdf5".format(outdir, label), + f"{outdir}/{label}_result.hdf5", ) def test_result_file_name_pkl(self): @@ -129,7 +126,7 @@ def test_result_file_name_pkl(self): label = "label" self.assertEqual( bilby.core.result.result_file_name(outdir, label, extension="pkl"), - "{}/{}_result.pkl".format(outdir, label), + f"{outdir}/{label}_result.pkl", ) def test_result_file_name_pickle(self): @@ -137,13 +134,11 @@ def test_result_file_name_pickle(self): label = "label" self.assertEqual( bilby.core.result.result_file_name(outdir, label, extension="pickle"), - "{}/{}_result.pkl".format(outdir, label), + f"{outdir}/{label}_result.pkl", ) def test_fail_save_and_load_missing_inputs(self): - with self.assertRaises( - ValueError, msg="No information given to load file" - ): + with self.assertRaises(ValueError, msg="No information given to load file"): bilby.core.result.read_in_result() def test_fail_save_and_load_no_extension(self): @@ -172,11 +167,9 @@ def test_fail_save_and_load_incomplete_json(self): "priors": { "chirp_mass": { """ - with open("{}/incomplete.json".format(self.result.outdir), "wb") as ff: + with open(f"{self.result.outdir}/incomplete.json", "wb") as ff: ff.write(incomplete_json) - bilby.core.result.read_in_result( - filename="{}/incomplete.json".format(self.result.outdir) - ) + bilby.core.result.read_in_result(filename=f"{self.result.outdir}/incomplete.json") def test_unset_priors(self): result = bilby.core.result.Result( @@ -193,9 +186,7 @@ def test_unset_priors(self): with self.assertRaises(ValueError): _ = result.priors self.assertEqual(result.parameter_labels, result.search_parameter_keys) - self.assertEqual( - result.parameter_labels_with_unit, result.search_parameter_keys - ) + self.assertEqual(result.parameter_labels_with_unit, result.search_parameter_keys) def test_unknown_priors_fail(self): with self.assertRaises(ValueError): @@ -237,16 +228,16 @@ def test_unset_posterior(self): _ = self.result.posterior def test_save_and_load_json(self): - self._save_and_load_test(extension='json') + self._save_and_load_test(extension="json") def test_save_and_load_json_gzip(self): - self._save_and_load_test(extension='json', gzip=True) + self._save_and_load_test(extension="json", gzip=True) def test_save_and_load_pkl(self): - self._save_and_load_test(extension='pkl') + self._save_and_load_test(extension="pkl") def test_save_and_load_hdf5(self): - self._save_and_load_test(extension='hdf5') + self._save_and_load_test(extension="hdf5") def _save_and_load_test(self, extension, gzip=False): self.result.save_to_file(extension=extension, gzip=gzip) @@ -259,20 +250,12 @@ def _save_and_load_test(self, extension, gzip=False): loaded_result.posterior.sort_values(by=["x"]), ) ) - self.assertTrue( - self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys - ) - self.assertTrue( - self.result.search_parameter_keys == loaded_result.search_parameter_keys - ) + self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys) + self.assertTrue(self.result.search_parameter_keys == loaded_result.search_parameter_keys) self.assertEqual(self.result.meta_data, loaded_result.meta_data) - self.assertEqual( - self.result.injection_parameters, loaded_result.injection_parameters - ) + self.assertEqual(self.result.injection_parameters, loaded_result.injection_parameters) self.assertEqual(self.result.log_evidence, loaded_result.log_evidence) - self.assertEqual( - self.result.log_noise_evidence, loaded_result.log_noise_evidence - ) + self.assertEqual(self.result.log_noise_evidence, loaded_result.log_noise_evidence) self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err) self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor) self.assertEqual(self.result.priors["x"], loaded_result.priors["x"]) @@ -282,13 +265,13 @@ def _save_and_load_test(self, extension, gzip=False): self.assertEqual(self.result.sampling_time, loaded_result.sampling_time) def test_save_and_dont_overwrite_json(self): - self._save_and_dont_overwrite_test(extension='json') + self._save_and_dont_overwrite_test(extension="json") def test_save_and_dont_overwrite_pkl(self): - self._save_and_dont_overwrite_test(extension='pkl') + self._save_and_dont_overwrite_test(extension="pkl") def test_save_and_dont_overwrite_hdf5(self): - self._save_and_dont_overwrite_test(extension='hdf5') + self._save_and_dont_overwrite_test(extension="hdf5") def _save_and_dont_overwrite_test(self, extension): self.result.save_to_file(overwrite=False, extension=extension) @@ -305,19 +288,13 @@ def test_save_with_outdir_and_filename_different_outdir(self): ) def test_save_with_outdir_and_filename_same_outdir(self): - self._save_with_outdir_and_filename( - f"{self.other_outdir}/result", None, f"{self.other_outdir}/result" - ) + self._save_with_outdir_and_filename(f"{self.other_outdir}/result", None, f"{self.other_outdir}/result") def test_save_with_outdir_and_filename_no_outdir_in_filename(self): - self._save_with_outdir_and_filename( - "result", self.other_outdir, f"{self.other_outdir}/result" - ) + self._save_with_outdir_and_filename("result", self.other_outdir, f"{self.other_outdir}/result") def test_save_with_filename_only(self): - self._save_with_outdir_and_filename( - "result", None, os.path.join(self.result.outdir, "result") - ) + self._save_with_outdir_and_filename("result", None, os.path.join(self.result.outdir, "result")) def test_save_with_outdir_no_filename(self): self._save_with_outdir_and_filename( @@ -330,13 +307,13 @@ def test_save_no_filename_or_outdir(self): ) def test_save_and_overwrite_json(self): - self._save_and_overwrite_test(extension='json') + self._save_and_overwrite_test(extension="json") def test_save_and_overwrite_pkl(self): - self._save_and_overwrite_test(extension='pkl') + self._save_and_overwrite_test(extension="pkl") def test_save_and_overwrite_hdf5(self): - self._save_and_overwrite_test(extension='hdf5') + self._save_and_overwrite_test(extension="hdf5") def _save_and_overwrite_test(self, extension): self.result.save_to_file(overwrite=True, extension=extension) @@ -345,17 +322,13 @@ def _save_and_overwrite_test(self, extension): def test_save_samples(self): self.result.save_posterior_samples() - filename = "{}/{}_posterior_samples.dat".format( - self.result.outdir, self.result.label - ) + filename = f"{self.result.outdir}/{self.result.label}_posterior_samples.dat" self.assertTrue(os.path.isfile(filename)) df = pd.read_csv(filename, sep=" ") self.assertTrue(np.allclose(self.result.posterior.values, df.values)) def test_save_samples_from_filename(self): - filename = "{}/{}_posterior_samples_OTHER.dat".format( - self.result.outdir, self.result.label - ) + filename = f"{self.result.outdir}/{self.result.label}_posterior_samples_OTHER.dat" self.result.save_posterior_samples(filename=filename) self.assertTrue(os.path.isfile(filename)) df = pd.read_csv(filename, sep=" ") @@ -363,9 +336,7 @@ def test_save_samples_from_filename(self): def test_save_samples_numpy_load(self): self.result.save_posterior_samples() - filename = "{}/{}_posterior_samples.dat".format( - self.result.outdir, self.result.label - ) + filename = f"{self.result.outdir}/{self.result.label}_posterior_samples.dat" self.assertTrue(os.path.isfile(filename)) data = np.genfromtxt(filename, names=True) df = pd.read_csv(filename, sep=" ") @@ -391,22 +362,16 @@ def test_samples_to_posterior(self): self.result.samples_to_posterior(priors=self.result.priors) self.assertTrue(all(self.result.posterior["x"] == x)) self.assertTrue(all(self.result.posterior["y"] == y)) - self.assertTrue( - np.array_equal(self.result.posterior.log_likelihood.values, log_likelihood) - ) - self.assertTrue( - all(self.result.posterior.c.values == self.result.priors["c"].peak) - ) - self.assertTrue( - all(self.result.posterior.d.values == self.result.priors["d"].peak) - ) + self.assertTrue(np.array_equal(self.result.posterior.log_likelihood.values, log_likelihood)) + self.assertTrue(all(self.result.posterior.c.values == self.result.priors["c"].peak)) + self.assertTrue(all(self.result.posterior.d.values == self.result.priors["d"].peak)) def test_calculate_prior_values(self): self.result.calculate_prior_values(priors=self.result.priors) self.assertEqual(len(self.result.posterior), len(self.result.prior_values)) def test_plot_multiple(self): - filename = "{}/multiple.png".format(self.result.outdir) + filename = f"{self.result.outdir}/multiple.png" bilby.core.result.plot_multiple([self.result, self.result], filename=filename) self.assertTrue(os.path.isfile(filename)) os.remove(filename) @@ -415,11 +380,7 @@ def test_plot_walkers(self): self.result.walkers = np.random.uniform(0, 1, (10, 11, 2)) self.result.nburn = 5 self.result.plot_walkers() - self.assertTrue( - os.path.isfile( - "{}/{}_walkers.png".format(self.result.outdir, self.result.label) - ) - ) + self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_walkers.png")) def test_plot_with_data(self): x = np.linspace(0, 1, 10) @@ -430,14 +391,8 @@ def model(xx, theta): self.result.posterior = pd.DataFrame(dict(theta=[1, 2, 3])) self.result.plot_with_data(model, x, y, ndraws=10) - self.assertTrue( - os.path.isfile( - "{}/{}_plot_with_data.png".format(self.result.outdir, self.result.label) - ) - ) - self.result.posterior["log_likelihood"] = np.random.uniform( - 0, 1, len(self.result.posterior) - ) + self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_plot_with_data.png")) + self.result.posterior["log_likelihood"] = np.random.uniform(0, 1, len(self.result.posterior)) self.result.plot_with_data(model, x, y, ndraws=10, xlabel="a", ylabel="y") def test_plot_corner(self): @@ -482,9 +437,7 @@ def test_get_credible_levels_raises_error_if_no_injection_parameters(self): self.result.injection_parameters = None with self.assertRaises(TypeError) as error_context: self.result.get_all_injection_credible_levels() - self.assertTrue( - "Result object has no 'injection_parameters" in str(error_context.exception) - ) + self.assertTrue("Result object has no 'injection_parameters" in str(error_context.exception)) def test_kde(self): kde = self.result.kde @@ -495,19 +448,13 @@ def test_kde(self): def test_posterior_probability(self): sample = dict(x=0, y=0.1) - self.assertTrue( - isinstance(self.result.posterior_probability(sample), np.ndarray) - ) + self.assertTrue(isinstance(self.result.posterior_probability(sample), np.ndarray)) self.assertTrue(len(self.result.posterior_probability(sample)), 1) - self.assertEqual( - self.result.posterior_probability(sample)[0], self.result.kde([0, 0.1]) - ) + self.assertEqual(self.result.posterior_probability(sample)[0], self.result.kde([0, 0.1])) def test_multiple_posterior_probability(self): sample = [dict(x=0, y=0.1), dict(x=0.8, y=0)] - self.assertTrue( - isinstance(self.result.posterior_probability(sample), np.ndarray) - ) + self.assertTrue(isinstance(self.result.posterior_probability(sample), np.ndarray)) self.assertTrue( np.array_equal( self.result.posterior_probability(sample), @@ -528,33 +475,22 @@ def test_to_arviz(self): self.assertTrue("x" in az.posterior and "y" in az.posterior) for var in ["x", "y"]: - self.assertTrue(np.array_equal(az.posterior[var].values.squeeze(), - self.result.posterior[var].values)) + self.assertTrue(np.array_equal(az.posterior[var].values.squeeze(), self.result.posterior[var].values)) self.assertTrue(len(az.prior[var][0]) == Nprior) - self.assertTrue(np.array_equal(az.log_likelihood["log_likelihood"].values.squeeze(), - log_likelihood)) + self.assertTrue(np.array_equal(az.log_likelihood["log_likelihood"].values.squeeze(), log_likelihood)) - self.assertTrue( - az.posterior.attrs["inference_library"] == "bilby: {}".format( - self.result.sampler - ) - ) - self.assertTrue( - az.posterior.attrs["inference_library_version"] - == bilby.utils.get_version_information() - ) + self.assertTrue(az.posterior.attrs["inference_library"] == f"bilby: {self.result.sampler}") + self.assertTrue(az.posterior.attrs["inference_library_version"] == bilby.utils.get_version_information()) # add log likelihood to samples and extract from there del az self.result.posterior["log_likelihood"] = log_likelihood az = self.result.to_arviz() - self.assertTrue(np.array_equal(az.log_likelihood["log_likelihood"].values.squeeze(), - log_likelihood)) + self.assertTrue(np.array_equal(az.log_likelihood["log_likelihood"].values.squeeze(), log_likelihood)) @patch("builtins.__import__") def test_to_arviz_not_installed(self, mock_import): - def import_side_effect(name, *args): if name == "arviz": raise ImportError @@ -565,19 +501,15 @@ def import_side_effect(name, *args): with self.assertRaises(ResultError) as excinfo: self.result.to_arviz() - self.assertEqual( - str(excinfo.exception), - "ArviZ is not installed, so cannot convert to InferenceData." - ) + self.assertEqual(str(excinfo.exception), "ArviZ is not installed, so cannot convert to InferenceData.") def test_result_caching(self): - class SimpleLikelihood(bilby.Likelihood): def __init__(self): super().__init__(parameters={"x": None}) def log_likelihood(self): - return -self.parameters["x"]**2 + return -(self.parameters["x"] ** 2) likelihood = SimpleLikelihood() priors = dict(x=bilby.core.prior.Uniform(-5, 5, "x")) @@ -590,14 +522,14 @@ class NotAResult(bilby.core.result.Result): result = bilby.run_sampler( likelihood, priors, - sampler='bilby_mcmc', + sampler="bilby_mcmc", outdir=self.outdir, nsamples=10, L1steps=1, proposal_cycle="default_noGMnoKD", printdt=1, check_point_plot=False, - result_class=NotAResult + result_class=NotAResult, ) assert isinstance(result, NotAResult) @@ -607,7 +539,7 @@ class NotAResult(bilby.core.result.Result): cached_result = bilby.run_sampler( likelihood, priors, - sampler='bilby_mcmc', + sampler="bilby_mcmc", outdir=self.outdir, nsamples=10, L1steps=1, @@ -815,17 +747,18 @@ def test_sanity_check_labels(self): class TestPPPlots(unittest.TestCase): - @pytest.fixture(autouse=True) def init_outdir(self, tmp_path): # Use pytest's tmp_path fixture to create a temporary directory self.outdir = str(tmp_path / "test_pp_plots") def setUp(self): - priors = bilby.core.prior.PriorDict(dict( - a=bilby.core.prior.Uniform(0, 1, latex_label="$a$"), - b=bilby.core.prior.Uniform(0, 1, latex_label="$b$"), - )) + priors = bilby.core.prior.PriorDict( + dict( + a=bilby.core.prior.Uniform(0, 1, latex_label="$a$"), + b=bilby.core.prior.Uniform(0, 1, latex_label="$b$"), + ) + ) self.results = [ bilby.core.result.Result( label=str(ii), @@ -847,9 +780,7 @@ def test_pp_plot_raises_error_with_wrong_number_of_lines(self): def test_pp_plot_raises_error_with_wrong_number_of_confidence_intervals(self): with self.assertRaises(ValueError): - _ = bilby.core.result.make_pp_plot( - self.results, save=False, confidence_interval_alpha=[0.1] - ) + _ = bilby.core.result.make_pp_plot(self.results, save=False, confidence_interval_alpha=[0.1]) class SimpleGaussianLikelihood(bilby.core.likelihood.Likelihood): @@ -858,6 +789,7 @@ def __init__(self, mean=0, sigma=1): A very simple Gaussian likelihood for testing """ from scipy.stats import norm + super().__init__(parameters=dict()) self.mean = mean self.sigma = sigma @@ -868,11 +800,12 @@ def log_likelihood(self): class TestReweight(unittest.TestCase): - def setUp(self): - self.priors = bilby.core.prior.PriorDict(dict( - mu=bilby.core.prior.TruncatedNormal(0, 1, minimum=-5, maximum=5), - )) + self.priors = bilby.core.prior.PriorDict( + dict( + mu=bilby.core.prior.TruncatedNormal(0, 1, minimum=-5, maximum=5), + ) + ) self.result = bilby.core.result.Result( search_parameter_keys=list(self.priors.keys()), priors=self.priors, @@ -890,9 +823,7 @@ def _run_reweighting(self, sigma): self.result.posterior["log_prior"] = self.priors.ln_prob(self.result.posterior) self.result.posterior["log_likelihood"] = original_ln_likelihoods self.original_ln_likelihoods = original_ln_likelihoods - return bilby.core.result.reweight( - self.result, likelihood_1, likelihood_2, verbose_output=True - ) + return bilby.core.result.reweight(self.result, likelihood_1, likelihood_2, verbose_output=True) def test_reweight_same_likelihood_weights_1(self): """ @@ -909,18 +840,15 @@ def test_reweight_different_likelihood_weights_correct(self): should be close to the original evidence within statistical error. """ from scipy.stats import norm + new, weights, _, _, _, _ = self._run_reweighting(sigma=0.5) - expected_weights = ( - norm(0, 0.5).pdf(self.result.posterior["mu"]) - / norm(0, 1).pdf(self.result.posterior["mu"]) - ) + expected_weights = norm(0, 0.5).pdf(self.result.posterior["mu"]) / norm(0, 1).pdf(self.result.posterior["mu"]) self.assertLess(min(abs(weights - expected_weights)), 1e-10) self.assertLess(abs(new.log_evidence - self.result.log_evidence), 0.05) self.assertNotEqual(new.log_evidence, self.result.log_evidence) class TestResultSaveAndRead(unittest.TestCase): - @pytest.fixture(autouse=True) def init_outdir(self, tmp_path): # Use pytest's tmp_path fixture to create a temporary directory @@ -951,9 +879,7 @@ def setUp(self): ) n = 100 - posterior = pd.DataFrame( - dict(x=np.random.normal(0, 1, n), y=np.random.normal(0, 1, n)) - ) + posterior = pd.DataFrame(dict(x=np.random.normal(0, 1, n), y=np.random.normal(0, 1, n))) result.posterior = posterior result.log_evidence = 10 result.log_evidence_err = 11 @@ -961,9 +887,15 @@ def setUp(self): result.log_noise_evidence = 13 self.result = result - @parameterized.parameterized.expand([ - ".h5", ".hdf5", ".json", ".pkl", ".pickle", - ]) + @parameterized.parameterized.expand( + [ + ".h5", + ".hdf5", + ".json", + ".pkl", + ".pickle", + ] + ) def test_save_and_read_filename_with_extension_and_extension_none(self, ext): # Should use the extension from filename filename = os.path.join(self.result.outdir, f"custom_name.{ext}") @@ -972,17 +904,19 @@ def test_save_and_read_filename_with_extension_and_extension_none(self, ext): bilby.core.result.read_in_result(filename=filename) os.remove(filename) - @parameterized.parameterized.expand([ - ("json",), - ("pkl",), - ("pickle",), - (True,), - ]) + @parameterized.parameterized.expand( + [ + ("json",), + ("pkl",), + ("pickle",), + (True,), + ] + ) def test_save_and_read_filename_with_extension_and_extension(self, extension): """Test all the extensions that are support when the filename is provided""" filename = os.path.join(self.result.outdir, "custom_name.hdf5") expected = filename - with self.assertLogs(logger, level='WARNING') as cm: + with self.assertLogs(logger, level="WARNING") as cm: self.result.save_to_file(filename=filename, extension=extension) self.assertIn("does not match the provided extension", cm.output[0]) self.assertTrue(os.path.isfile(expected)) @@ -1010,21 +944,21 @@ def test_save_to_file_defaults_to_pickle_with_incorrect_extension(self): bilby.core.result.read_in_result(filename=expected) os.remove(expected) - @parameterized.parameterized.expand([ - ("json", "hdf5"), - ("json", "pkl"), - ("hdf5", "json"), - ("pkl", "json"), - ("json", "pkl"), - ("hdf5", "pkl"), - ]) + @parameterized.parameterized.expand( + [ + ("json", "hdf5"), + ("json", "pkl"), + ("hdf5", "json"), + ("pkl", "json"), + ("json", "pkl"), + ("hdf5", "pkl"), + ] + ) def test_save_and_read_incorrect_extension(self, save_extension, read_extension): """Test that an incorrect extension raises a somewhat helpful error""" filename = os.path.join(self.result.outdir, "my_result") self.result.save_to_file(filename=filename, extension=save_extension) - with self.assertRaises( - (FileLoadError, IOError), msg=f"Failed to read in file {filename}" - ): + with self.assertRaises((FileLoadError, IOError), msg=f"Failed to read in file {filename}"): bilby.core.result.read_in_result(filename=filename, extension=read_extension) os.remove(filename) diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py index d20ee978a..401205c75 100644 --- a/test/core/sampler/base_sampler_test.py +++ b/test/core/sampler/base_sampler_test.py @@ -3,9 +3,9 @@ import shutil import unittest from unittest.mock import MagicMock -from parameterized import parameterized import numpy as np +from parameterized import parameterized import bilby from bilby.core import prior @@ -38,7 +38,7 @@ def setUp(self, soft_init=False): outdir=test_directory, use_ratio=False, skip_import_verification=True, - soft_init=soft_init + soft_init=soft_init, ) def tearDown(self): @@ -50,15 +50,11 @@ def test_softinit(self): def test_search_parameter_keys(self): expected_search_parameter_keys = ["c"] - self.assertListEqual( - self.sampler.search_parameter_keys, expected_search_parameter_keys - ) + self.assertListEqual(self.sampler.search_parameter_keys, expected_search_parameter_keys) def test_fixed_parameter_keys(self): expected_fixed_parameter_keys = ["a"] - self.assertListEqual( - self.sampler.fixed_parameter_keys, expected_fixed_parameter_keys - ) + self.assertListEqual(self.sampler.fixed_parameter_keys, expected_fixed_parameter_keys) def test_ndim(self): self.assertEqual(self.sampler.ndim, 1) @@ -88,9 +84,7 @@ def test_prior_transform_transforms_search_parameter_keys(self): def test_prior_transform_does_not_transform_fixed_parameter_keys(self): self.sampler.prior_transform([0]) - self.assertEqual( - self.sampler.priors["a"].peak, prior.DeltaFunction(peak=0).peak - ) + self.assertEqual(self.sampler.priors["a"].peak, prior.DeltaFunction(peak=0).peak) def test_log_prior(self): self.assertEqual(self.sampler.log_prior({1}), 0.0) @@ -120,40 +114,28 @@ def test_bad_value_nan(self): self.sampler._check_bad_value(val=np.nan, warning=False, theta=None, label=None) def test_bad_value_np_abs_nan(self): - self.sampler._check_bad_value( - val=np.abs(np.nan), warning=False, theta=None, label=None - ) + self.sampler._check_bad_value(val=np.abs(np.nan), warning=False, theta=None, label=None) def test_bad_value_abs_nan(self): - self.sampler._check_bad_value( - val=abs(np.nan), warning=False, theta=None, label=None - ) + self.sampler._check_bad_value(val=abs(np.nan), warning=False, theta=None, label=None) def test_bad_value_pos_inf(self): self.sampler._check_bad_value(val=np.inf, warning=False, theta=None, label=None) def test_bad_value_neg_inf(self): - self.sampler._check_bad_value( - val=-np.inf, warning=False, theta=None, label=None - ) + self.sampler._check_bad_value(val=-np.inf, warning=False, theta=None, label=None) def test_bad_value_pos_inf_nan_to_num(self): - self.sampler._check_bad_value( - val=np.nan_to_num(np.inf), warning=False, theta=None, label=None - ) + self.sampler._check_bad_value(val=np.nan_to_num(np.inf), warning=False, theta=None, label=None) def test_bad_value_neg_inf_nan_to_num(self): - self.sampler._check_bad_value( - val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None - ) + self.sampler._check_bad_value(val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None) def test_get_expected_outputs(): outdir = os.path.join("some", "bilby_pipe", "dir") label = "par0" - filenames, directories = bilby.core.sampler.Sampler.get_expected_outputs( - outdir=outdir, label=label - ) + filenames, directories = bilby.core.sampler.Sampler.get_expected_outputs(outdir=outdir, label=label) assert len(filenames) == 0 assert len(directories) == 1 assert directories[0] == os.path.join(outdir, f"sampler_{label}", "") @@ -163,9 +145,7 @@ def test_get_expected_outputs_abbreviation(): outdir = os.path.join("some", "bilby_pipe", "dir") label = "par0" bilby.core.sampler.Sampler.abbreviation = "abbr" - filenames, directories = bilby.core.sampler.Sampler.get_expected_outputs( - outdir=outdir, label=label - ) + filenames, directories = bilby.core.sampler.Sampler.get_expected_outputs(outdir=outdir, label=label) assert len(filenames) == 0 assert len(directories) == 1 assert directories[0] == os.path.join(outdir, f"abbr_{label}", "") @@ -208,9 +188,7 @@ def test_pool_creates_properly_no_pool(self, sampler_name): @parameterized.expand(samplers) def test_pool_creates_properly_pool(self, sampler): - sampler = loaded_samplers[sampler]( - self.likelihood, self.priors, npool=2 - ) + sampler = loaded_samplers[sampler](self.likelihood, self.priors, npool=2) sampler._setup_pool() if hasattr(sampler, "setup_sampler"): sampler.setup_sampler() @@ -230,12 +208,8 @@ def tearDown(self): def test_ordering(self): func = bilby.core.sampler.base_sampler.NestedSampler.reorder_loglikelihoods - sorted_ln_likelihoods = func( - self.unsorted_ln_likelihoods, self.unsorted_samples, self.sorted_samples - ) - self.assertTrue( - np.array_equal(sorted_ln_likelihoods, self.sorted_ln_likelihoods) - ) + sorted_ln_likelihoods = func(self.unsorted_ln_likelihoods, self.unsorted_samples, self.sorted_samples) + self.assertTrue(np.array_equal(sorted_ln_likelihoods, self.sorted_ln_likelihoods)) if __name__ == "__main__": diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py index 4177c4fea..842728ea5 100644 --- a/test/core/sampler/dynesty_test.py +++ b/test/core/sampler/dynesty_test.py @@ -3,14 +3,14 @@ import unittest from copy import deepcopy -import bilby -import bilby.core.sampler.dynesty import numpy as np import parameterized import pytest from attr import define -from scipy.stats import gamma, ks_1samp, uniform, powerlaw +from scipy.stats import gamma, ks_1samp, powerlaw, uniform +import bilby +import bilby.core.sampler.dynesty try: import dynesty.internal_samplers # noqa @@ -29,9 +29,7 @@ class Dummy: axes: np.ndarray scale: float = 1 rseed: float = 1234 - kwargs: dict = dict( - walks=500, live=np.zeros((2, 4)), periodic=None, reflective=None - ) + kwargs: dict = dict(walks=500, live=np.zeros((2, 4)), periodic=None, reflective=None) prior_transform: callable = lambda x: x loglikelihood: callable = lambda x: 0 loglstar: float = -1 @@ -122,9 +120,7 @@ def test_prior_boundary(self): self.priors["e"] = bilby.core.prior.Prior(boundary="periodic") self.init_sampler() if NEW_DYNESTY_API: - self.assertEqual( - [0, 4], self.dysampler.internal_sampler_next.sampler_kwargs["periodic"] - ) + self.assertEqual([0, 4], self.dysampler.internal_sampler_next.sampler_kwargs["periodic"]) self.assertEqual( [1, 3], self.dysampler.internal_sampler_next.sampler_kwargs["reflective"], @@ -164,9 +160,7 @@ def test_sampler_kwargs_rwalk(self): def test_sampler_kwargs_acceptance_walk(self): self.init_sampler(sample="acceptance-walk", naccept=5, maxmcmc=200) if NEW_DYNESTY_API: - self.assertIsInstance( - self.dysampler.internal_sampler_next, dynesty_utils.EnsembleWalkSampler - ) + self.assertIsInstance(self.dysampler.internal_sampler_next, dynesty_utils.EnsembleWalkSampler) self.assertEqual(self.dysampler.internal_sampler_next.naccept, 5) self.assertEqual(self.dysampler.internal_sampler_next.maxmcmc, 200) else: @@ -177,12 +171,14 @@ def test_sampler_kwargs_acceptance_walk(self): def test_run_test_runs(self): self.sampler._run_test() - @parameterized.parameterized.expand(( - ("unif", "single"), - ("unif", "multi"), - ("rslice", "single"), - ("rslice", "multi"), - )) + @parameterized.parameterized.expand( + ( + ("unif", "single"), + ("unif", "multi"), + ("rslice", "single"), + ("rslice", "multi"), + ) + ) def test_dynesty_native_methods_initialize(self, sample, bound): """ Make sure that we can initialize the native dynesty samplers. @@ -194,9 +190,7 @@ def test_dynesty_native_methods_initialize(self, sample, bound): def test_get_expected_outputs(): label = "par0" outdir = os.path.join("some", "bilby_pipe", "dir") - filenames, directories = bilby.core.sampler.dynesty.Dynesty.get_expected_outputs( - outdir=outdir, label=label - ) + filenames, directories = bilby.core.sampler.dynesty.Dynesty.get_expected_outputs(outdir=outdir, label=label) assert len(filenames) == 2 assert len(directories) == 0 assert os.path.join(outdir, f"{label}_resume.pickle") in filenames @@ -204,19 +198,13 @@ def test_get_expected_outputs(): class ProposalsTest(unittest.TestCase): - def test_boundaries(self): inputs = np.array([0.1, 1.1, -1.3]) expected = np.array([0.1, 0.1, 0.7]) periodic = [1] reflective = [2] self.assertLess( - max( - abs( - dynesty_utils.apply_boundaries_(inputs, periodic, reflective) - - expected - ) - ), + max(abs(dynesty_utils.apply_boundaries_(inputs, periodic, reflective) - expected)), 1e-10, ) @@ -233,13 +221,9 @@ def test_propose_volumetric(self): for _ in range(1000): new_samples.append(proposal_func(start, axes, 1, 4, 2, rng)) new_samples = np.array(new_samples) + self.assertGreater(ks_1samp(new_samples[:, 2:].flatten(), uniform(0, 1).cdf).pvalue, 0.01) self.assertGreater( - ks_1samp(new_samples[:, 2:].flatten(), uniform(0, 1).cdf).pvalue, 0.01 - ) - self.assertGreater( - ks_1samp( - np.linalg.norm(new_samples[:, :2], axis=-1), powerlaw(2, scale=2).cdf - ).pvalue, + ks_1samp(np.linalg.norm(new_samples[:, :2], axis=-1), powerlaw(2, scale=2).cdf).pvalue, 0.01, ) @@ -252,9 +236,7 @@ def test_propose_differential_evolution_mode_hopping(self): for _ in range(1000): new_samples.append(proposal_func(start, live, 4, 2, rng, mix=0)) new_samples = np.array(new_samples) - self.assertGreater( - ks_1samp(new_samples[:, 2:].flatten(), uniform(0, 1).cdf).pvalue, 0.01 - ) + self.assertGreater(ks_1samp(new_samples[:, 2:].flatten(), uniform(0, 1).cdf).pvalue, 0.01) self.assertLess(np.max(abs(new_samples[:, :2]) - np.array([1, 1])), 1e-10) @parameterized.parameterized.expand(((1,), (None,), (5,))) @@ -265,19 +247,13 @@ def test_propose_differential_evolution(self, scale): start = np.zeros(4) new_samples = list() for _ in range(1000): - new_samples.append( - proposal_func(start, live, 4, 2, rng, mix=1, scale=scale) - ) + new_samples.append(proposal_func(start, live, 4, 2, rng, mix=1, scale=scale)) new_samples = np.array(new_samples) if scale is None: scale = 1.17 + self.assertGreater(ks_1samp(new_samples[:, 2:].flatten(), uniform(0, 1).cdf).pvalue, 0.01) self.assertGreater( - ks_1samp(new_samples[:, 2:].flatten(), uniform(0, 1).cdf).pvalue, 0.01 - ) - self.assertGreater( - ks_1samp( - np.abs(new_samples[:, :2].flatten()), gamma(4, scale=scale / 4).cdf - ).pvalue, + ks_1samp(np.abs(new_samples[:, :2].flatten()), gamma(4, scale=scale / 4).cdf).pvalue, 0.01, ) @@ -290,9 +266,7 @@ def test_get_proposal_kwargs_diff(self): dynesty_utils._SamplingContainer.proposals = ["diff"] proposals, common, specific = dynesty_utils._get_proposal_kwargs(args) del common["rstate"] - self.assertTrue( - np.array_equal(proposals, np.array(["diff"] * args.kwargs["walks"])) - ) + self.assertTrue(np.array_equal(proposals, np.array(["diff"] * args.kwargs["walks"]))) self.assertDictEqual(common, dict(n=len(args.u), n_cluster=len(args.axes))) assert np.array_equal(args.kwargs["live"][:1, :2], specific["diff"]["live"]) del specific["diff"]["live"] @@ -307,13 +281,9 @@ def test_get_proposal_kwargs_volumetric(self): dynesty_utils._SamplingContainer.proposals = ["volumetric"] proposals, common, specific = dynesty_utils._get_proposal_kwargs(args) del common["rstate"] - self.assertTrue( - np.array_equal(proposals, np.array(["volumetric"] * args.kwargs["walks"])) - ) + self.assertTrue(np.array_equal(proposals, np.array(["volumetric"] * args.kwargs["walks"]))) self.assertDictEqual(common, dict(n=len(args.u), n_cluster=len(args.axes))) - self.assertDictEqual( - specific, dict(volumetric=dict(axes=args.axes, scale=args.scale)) - ) + self.assertDictEqual(specific, dict(volumetric=dict(axes=args.axes, scale=args.scale))) @pytest.mark.skipif(NEW_DYNESTY_API, reason="Invalid for new dynesty API") def test_proposal_functions_run_old(self): @@ -433,7 +403,6 @@ def test_converges_to_correct_value(self): class TestReproducibility(unittest.TestCase): - @staticmethod def model(x, m, c): return m * x + c @@ -445,12 +414,8 @@ def setUp(self): self.x = np.linspace(0, 1, 11) self.injection_parameters = dict(m=0.5, c=0.2) self.sigma = 0.1 - self.y = self.model(self.x, **self.injection_parameters) + rng.normal( - 0, self.sigma, len(self.x) - ) - self.likelihood = bilby.likelihood.GaussianLikelihood( - self.x, self.y, self.model, self.sigma - ) + self.y = self.model(self.x, **self.injection_parameters) + rng.normal(0, self.sigma, len(self.x)) + self.likelihood = bilby.likelihood.GaussianLikelihood(self.x, self.y, self.model, self.sigma) self.priors = bilby.core.prior.PriorDict() self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic") diff --git a/test/core/sampler/fake_sampler_test.py b/test/core/sampler/fake_sampler_test.py index 9b4a926f8..fba0404ae 100644 --- a/test/core/sampler/fake_sampler_test.py +++ b/test/core/sampler/fake_sampler_test.py @@ -1,5 +1,4 @@ import unittest - if __name__ == "__main__": unittest.main() diff --git a/test/core/sampler/general_sampler_tests.py b/test/core/sampler/general_sampler_tests.py index 38700c632..1b5bf1a78 100644 --- a/test/core/sampler/general_sampler_tests.py +++ b/test/core/sampler/general_sampler_tests.py @@ -1,8 +1,9 @@ +import pytest + from bilby.core.sampler import ( get_implemented_samplers, get_sampler_class, ) -import pytest def test_get_implemented_samplers(): @@ -25,8 +26,5 @@ def test_get_sampler_class(): def test_get_sampler_class_not_implemented(): """Assert an error is raised if the sampler is not recognized""" - with pytest.raises( - ValueError, - match=r"Sampler not_a_valid_sampler not yet implemented" - ): + with pytest.raises(ValueError, match=r"Sampler not_a_valid_sampler not yet implemented"): get_sampler_class("not_a_valid_sampler") diff --git a/test/core/sampler/implemented_samplers_test.py b/test/core/sampler/implemented_samplers_test.py index c34398185..20374f1c2 100644 --- a/test/core/sampler/implemented_samplers_test.py +++ b/test/core/sampler/implemented_samplers_test.py @@ -1,6 +1,7 @@ -from bilby.core.sampler import IMPLEMENTED_SAMPLERS, ImplementedSamplers import pytest +from bilby.core.sampler import IMPLEMENTED_SAMPLERS, ImplementedSamplers + def test_singleton(): assert ImplementedSamplers() is IMPLEMENTED_SAMPLERS @@ -20,10 +21,8 @@ def test_allowed_keys(): def test_values(): # Values and keys should have the same lengths - assert len(list(IMPLEMENTED_SAMPLERS.values())) \ - == len(list(IMPLEMENTED_SAMPLERS.keys())) - assert len(list(IMPLEMENTED_SAMPLERS.values())) \ - == len(list(IMPLEMENTED_SAMPLERS._samplers.values())) + assert len(list(IMPLEMENTED_SAMPLERS.values())) == len(list(IMPLEMENTED_SAMPLERS.keys())) + assert len(list(IMPLEMENTED_SAMPLERS.values())) == len(list(IMPLEMENTED_SAMPLERS._samplers.values())) def test_items(): diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py index 3246c74e7..370d1a52d 100644 --- a/test/core/sampler/nessai_test.py +++ b/test/core/sampler/nessai_test.py @@ -1,9 +1,9 @@ +import os import unittest -from unittest.mock import MagicMock, patch, mock_open +from unittest.mock import MagicMock, mock_open, patch import bilby import bilby.core.sampler.nessai -import os class TestNessai(unittest.TestCase): @@ -26,8 +26,8 @@ def setUp(self): ) self.expected = self.sampler.default_kwargs self.expected["n_pool"] = 1 # Because npool=1 by default - self.expected['output'] = 'outdir/label_nessai/' - self.expected['seed'] = 150914 + self.expected["output"] = "outdir/label_nessai/" + self.expected["seed"] = 150914 def tearDown(self): del self.likelihood @@ -88,9 +88,7 @@ def test_update_from_config_file(self): def test_get_expected_outputs(): label = "par0" outdir = os.path.join("some", "bilby_pipe", "dir") - filenames, directories = bilby.core.sampler.nessai.Nessai.get_expected_outputs( - outdir=outdir, label=label - ) + filenames, directories = bilby.core.sampler.nessai.Nessai.get_expected_outputs(outdir=outdir, label=label) assert len(filenames) == 0 assert len(directories) == 3 base_dir = os.path.join(outdir, f"{label}_nessai", "") diff --git a/test/core/sampler/proposal_test.py b/test/core/sampler/proposal_test.py index 1b2331af0..a2663db79 100644 --- a/test/core/sampler/proposal_test.py +++ b/test/core/sampler/proposal_test.py @@ -35,9 +35,7 @@ class TestJumpProposal(unittest.TestCase): def setUp(self): self.priors = prior.PriorDict( dict( - reflective=prior.Uniform( - minimum=-0.5, maximum=1, boundary="reflective" - ), + reflective=prior.Uniform(minimum=-0.5, maximum=1, boundary="reflective"), periodic=prior.Uniform(minimum=-0.5, maximum=1, boundary="periodic"), default=prior.Uniform(minimum=-0.5, maximum=1), ) @@ -45,15 +43,9 @@ def setUp(self): self.sample_above = dict(reflective=1.1, periodic=1.1, default=1.1) self.sample_below = dict(reflective=-0.6, periodic=-0.6, default=-0.6) self.sample_way_above_case1 = dict(reflective=272, periodic=272, default=272) - self.sample_way_above_case2 = dict( - reflective=270.1, periodic=270.1, default=270.1 - ) - self.sample_way_below_case1 = dict( - reflective=-274, periodic=-274.1, default=-274 - ) - self.sample_way_below_case2 = dict( - reflective=-273.1, periodic=-273.1, default=-273.1 - ) + self.sample_way_above_case2 = dict(reflective=270.1, periodic=270.1, default=270.1) + self.sample_way_below_case1 = dict(reflective=-274, periodic=-274.1, default=-274) + self.sample_way_below_case2 = dict(reflective=-273.1, periodic=-273.1, default=-273.1) self.jump_proposal = proposal.JumpProposal(priors=self.priors) def tearDown(self): @@ -148,11 +140,13 @@ def test_set_step_size(self): def test_jump_proposal_call(self): sample = proposal.Sample(dict(reflective=0.0, periodic=0.0, default=0.0)) new_sample = self.jump_proposal(sample) - expected = proposal.Sample(dict( - reflective=0.5942057242396577, - periodic=-0.02692301311556511, - default=-0.7450848662857457, - )) + expected = proposal.Sample( + dict( + reflective=0.5942057242396577, + periodic=-0.02692301311556511, + default=-0.7450848662857457, + ) + ) self.assertDictEqual(expected, new_sample) @@ -160,9 +154,7 @@ class TestEnsembleWalk(unittest.TestCase): def setUp(self): self.priors = prior.PriorDict( dict( - reflective=prior.Uniform( - minimum=-0.5, maximum=1, boundary="reflective" - ), + reflective=prior.Uniform(minimum=-0.5, maximum=1, boundary="reflective"), periodic=prior.Uniform(minimum=-0.5, maximum=1, boundary="periodic"), default=prior.Uniform(minimum=-0.5, maximum=1), ) @@ -188,10 +180,7 @@ def test_random_number_generator_init(self): self.assertEqual(bilby.core.utils.random.rng.uniform, self.jump_proposal.random_number_generator) def test_get_center_of_mass(self): - samples = [ - proposal.Sample(dict(reflective=0.1 * i, periodic=0.1 * i, default=0.1 * i)) - for i in range(3) - ] + samples = [proposal.Sample(dict(reflective=0.1 * i, periodic=0.1 * i, default=0.1 * i)) for i in range(3)] expected = proposal.Sample(dict(reflective=0.1, periodic=0.1, default=0.1)) actual = self.jump_proposal.get_center_of_mass(samples) for key in samples[0].keys(): @@ -200,11 +189,13 @@ def test_get_center_of_mass(self): def test_jump_proposal_call(self): sample = proposal.Sample(dict(periodic=0.1, reflective=0.1, default=0.1)) new_sample = self.jump_proposal(sample, coordinates=self.coordinates) - expected = proposal.Sample(dict( - periodic=0.437075089594473, - reflective=-0.18027731528487945, - default=-0.17570046901727415, - )) + expected = proposal.Sample( + dict( + periodic=0.437075089594473, + reflective=-0.18027731528487945, + default=-0.17570046901727415, + ) + ) for key, value in new_sample.items(): self.assertAlmostEqual(expected[key], value) @@ -213,9 +204,7 @@ class TestEnsembleEnsembleStretch(unittest.TestCase): def setUp(self): self.priors = prior.PriorDict( dict( - reflective=prior.Uniform( - minimum=-0.5, maximum=1, boundary="reflective" - ), + reflective=prior.Uniform(minimum=-0.5, maximum=1, boundary="reflective"), periodic=prior.Uniform(minimum=-0.5, maximum=1, boundary="periodic"), default=prior.Uniform(minimum=-0.5, maximum=1), ) @@ -236,22 +225,20 @@ def test_set_get_scale(self): self.assertEqual(5.0, self.jump_proposal.scale) def test_jump_proposal_call(self): - sample = proposal.Sample( - dict(periodic=0.1, reflective=0.1, default=0.1) - ) + sample = proposal.Sample(dict(periodic=0.1, reflective=0.1, default=0.1)) new_sample = self.jump_proposal(sample, coordinates=self.coordinates) - expected = proposal.Sample(dict( - periodic=0.5790181653312239, - reflective=-0.028378746842481914, - default=-0.23534241783479043, - )) + expected = proposal.Sample( + dict( + periodic=0.5790181653312239, + reflective=-0.028378746842481914, + default=-0.23534241783479043, + ) + ) for key, value in new_sample.items(): self.assertAlmostEqual(expected[key], value) def test_log_j_after_call(self): - sample = proposal.Sample( - dict(periodic=0.2, reflective=0.2, default=0.2) - ) + sample = proposal.Sample(dict(periodic=0.2, reflective=0.2, default=0.2)) self.jump_proposal(sample=sample, coordinates=self.coordinates) self.assertAlmostEqual(-3.2879289432183088, self.jump_proposal.log_j, 10) @@ -260,17 +247,13 @@ class TestDifferentialEvolution(unittest.TestCase): def setUp(self): self.priors = prior.PriorDict( dict( - reflective=prior.Uniform( - minimum=-0.5, maximum=1, boundary="reflective" - ), + reflective=prior.Uniform(minimum=-0.5, maximum=1, boundary="reflective"), periodic=prior.Uniform(minimum=-0.5, maximum=1, boundary="periodic"), default=prior.Uniform(minimum=-0.5, maximum=1), ) ) bilby.core.utils.random.seed(5) - self.jump_proposal = proposal.DifferentialEvolution( - sigma=1e-3, mu=0.5, priors=self.priors - ) + self.jump_proposal = proposal.DifferentialEvolution(sigma=1e-3, mu=0.5, priors=self.priors) self.coordinates = [proposal.Sample(self.priors.sample()) for _ in range(10)] def tearDown(self): @@ -289,14 +272,14 @@ def test_set_get_sigma(self): self.assertEqual(2, self.jump_proposal.sigma) def test_jump_proposal_call(self): - sample = proposal.Sample( - dict(periodic=0.1, reflective=0.1, default=0.1) + sample = proposal.Sample(dict(periodic=0.1, reflective=0.1, default=0.1)) + expected = proposal.Sample( + dict( + periodic=0.09440864471444077, + reflective=0.567962015300636, + default=0.0657296821780595, + ) ) - expected = proposal.Sample(dict( - periodic=0.09440864471444077, - reflective=0.567962015300636, - default=0.0657296821780595, - )) new_sample = self.jump_proposal(sample, coordinates=self.coordinates) for key, value in new_sample.items(): self.assertAlmostEqual(expected[key], value) @@ -306,9 +289,7 @@ class TestEnsembleEigenVector(unittest.TestCase): def setUp(self): self.priors = prior.PriorDict( dict( - reflective=prior.Uniform( - minimum=-0.5, maximum=1, boundary="reflective" - ), + reflective=prior.Uniform(minimum=-0.5, maximum=1, boundary="reflective"), periodic=prior.Uniform(minimum=-0.5, maximum=1, boundary="periodic"), default=prior.Uniform(minimum=-0.5, maximum=1), ) @@ -343,9 +324,7 @@ def test_jump_proposal_update_eigenvectors_1_d(self): self.jump_proposal.update_eigenvectors(coordinates) self.assertTrue(np.equal(np.array([1]), self.jump_proposal.eigen_values)) self.assertTrue(np.equal(np.array([1]), self.jump_proposal.covariance)) - self.assertTrue( - np.equal(np.array([[1.0]]), self.jump_proposal.eigen_vectors) - ) + self.assertTrue(np.equal(np.array([[1.0]]), self.jump_proposal.eigen_vectors)) def test_jump_proposal_update_eigenvectors_n_d(self): coordinates = [ diff --git a/test/core/sampler/ptemcee_test.py b/test/core/sampler/ptemcee_test.py index 4708a12b0..f99a6f0e2 100644 --- a/test/core/sampler/ptemcee_test.py +++ b/test/core/sampler/ptemcee_test.py @@ -1,11 +1,12 @@ +import os import unittest +import numpy as np + from bilby.core.likelihood import GaussianLikelihood -from bilby.core.prior import Uniform, PriorDict -from bilby.core.sampler.ptemcee import Ptemcee +from bilby.core.prior import PriorDict, Uniform from bilby.core.sampler.base_sampler import MCMCSampler -import numpy as np -import os +from bilby.core.sampler.ptemcee import Ptemcee class TestPTEmcee(unittest.TestCase): @@ -71,10 +72,7 @@ def test_set_pos0_using_dict(self): """ old = np.array(self.sampler.get_pos0()) pos0 = np.moveaxis(old, -1, 0) - pos0 = { - key: points for key, points in - zip(self.sampler.search_parameter_keys, pos0) - } + pos0 = {key: points for key, points in zip(self.sampler.search_parameter_keys, pos0)} new_sampler = Ptemcee(*self._args, **self._kwargs, pos0=pos0) new = new_sampler.get_pos0() self.assertTrue(np.array_equal(new, old)) @@ -93,9 +91,7 @@ def test_set_pos0_from_minimize(self): def test_get_expected_outputs(): label = "par0" outdir = os.path.join("some", "bilby_pipe", "dir") - filenames, directories = Ptemcee.get_expected_outputs( - outdir=outdir, label=label - ) + filenames, directories = Ptemcee.get_expected_outputs(outdir=outdir, label=label) assert len(filenames) == 1 assert len(directories) == 0 assert os.path.join(outdir, f"{label}_checkpoint_resume.pickle") in filenames diff --git a/test/core/sampler/ptmcmc_test.py b/test/core/sampler/ptmcmc_test.py index 9b4a926f8..fba0404ae 100644 --- a/test/core/sampler/ptmcmc_test.py +++ b/test/core/sampler/ptmcmc_test.py @@ -1,5 +1,4 @@ import unittest - if __name__ == "__main__": unittest.main() diff --git a/test/core/sampler/pymultinest_test.py b/test/core/sampler/pymultinest_test.py index 7ec64b486..678428261 100644 --- a/test/core/sampler/pymultinest_test.py +++ b/test/core/sampler/pymultinest_test.py @@ -29,40 +29,66 @@ def tearDown(self): del self.sampler def test_default_kwargs(self): - expected = dict(importance_nested_sampling=False, resume=True, - verbose=True, sampling_efficiency='parameter', - n_live_points=500, n_params=2, - n_clustering_params=None, wrapped_params=None, - multimodal=True, const_efficiency_mode=False, - evidence_tolerance=0.5, - n_iter_before_update=100, null_log_evidence=-1e90, - max_modes=100, mode_tolerance=-1e90, seed=-1, - context=0, write_output=True, log_zero=-1e100, - max_iter=0, init_MPI=False, dump_callback='dumper') - self.sampler.kwargs['dump_callback'] = 'dumper' # Check like the dynesty print_func - self.assertListEqual([1, 0], self.sampler.kwargs['wrapped_params']) # Check this separately - self.sampler.kwargs['wrapped_params'] = None # The dict comparison can't handle lists + expected = dict( + importance_nested_sampling=False, + resume=True, + verbose=True, + sampling_efficiency="parameter", + n_live_points=500, + n_params=2, + n_clustering_params=None, + wrapped_params=None, + multimodal=True, + const_efficiency_mode=False, + evidence_tolerance=0.5, + n_iter_before_update=100, + null_log_evidence=-1e90, + max_modes=100, + mode_tolerance=-1e90, + seed=-1, + context=0, + write_output=True, + log_zero=-1e100, + max_iter=0, + init_MPI=False, + dump_callback="dumper", + ) + self.sampler.kwargs["dump_callback"] = "dumper" # Check like the dynesty print_func + self.assertListEqual([1, 0], self.sampler.kwargs["wrapped_params"]) # Check this separately + self.sampler.kwargs["wrapped_params"] = None # The dict comparison can't handle lists self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs(self): - expected = dict(importance_nested_sampling=False, resume=True, - verbose=True, sampling_efficiency='parameter', - n_live_points=123, n_params=2, - n_clustering_params=None, wrapped_params=None, - multimodal=True, const_efficiency_mode=False, - evidence_tolerance=0.5, - n_iter_before_update=100, null_log_evidence=-1e90, - max_modes=100, mode_tolerance=-1e90, seed=-1, - context=0, write_output=True, log_zero=-1e100, - max_iter=0, init_MPI=False, dump_callback='dumper') + expected = dict( + importance_nested_sampling=False, + resume=True, + verbose=True, + sampling_efficiency="parameter", + n_live_points=123, + n_params=2, + n_clustering_params=None, + wrapped_params=None, + multimodal=True, + const_efficiency_mode=False, + evidence_tolerance=0.5, + n_iter_before_update=100, + null_log_evidence=-1e90, + max_modes=100, + mode_tolerance=-1e90, + seed=-1, + context=0, + write_output=True, + log_zero=-1e100, + max_iter=0, + init_MPI=False, + dump_callback="dumper", + ) for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy() del new_kwargs["n_live_points"] - new_kwargs[ - "wrapped_params" - ] = None # The dict comparison can't handle lists - new_kwargs['dump_callback'] = 'dumper' # Check this like Dynesty print_func + new_kwargs["wrapped_params"] = None # The dict comparison can't handle lists + new_kwargs["dump_callback"] = "dumper" # Check this like Dynesty print_func new_kwargs[equiv] = 123 self.sampler.kwargs = new_kwargs self.assertDictEqual(expected, self.sampler.kwargs) diff --git a/test/core/sampler/ultranest_test.py b/test/core/sampler/ultranest_test.py index c0219295b..b4e909084 100644 --- a/test/core/sampler/ultranest_test.py +++ b/test/core/sampler/ultranest_test.py @@ -7,13 +7,12 @@ class TestUltranest(unittest.TestCase): - def setUp(self): self.maxDiff = None self.likelihood = MagicMock() self.priors = bilby.core.prior.PriorDict( - dict(a=bilby.core.prior.Uniform(0, 1), - b=bilby.core.prior.Uniform(0, 1))) + dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) + ) self.priors["a"] = bilby.core.prior.Prior(boundary="periodic") self.priors["b"] = bilby.core.prior.Prior(boundary="reflective") self.sampler = bilby.core.sampler.ultranest.Ultranest( @@ -98,7 +97,7 @@ def test_translate_kwargs(self): ) for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy() - del new_kwargs['num_live_points'] + del new_kwargs["num_live_points"] new_kwargs[equiv] = 123 self.sampler.kwargs = new_kwargs self.sampler.kwargs["wrapped_params"] = None diff --git a/test/core/series_test.py b/test/core/series_test.py index bf1b19c43..2f2f0bfbe 100644 --- a/test/core/series_test.py +++ b/test/core/series_test.py @@ -1,8 +1,9 @@ import unittest + import numpy as np -from bilby.core.utils import create_frequency_series, create_time_series from bilby.core.series import CoupledTimeAndFrequencySeries +from bilby.core.utils import create_frequency_series, create_time_series class TestCoupledTimeAndFrequencySeries(unittest.TestCase): @@ -24,12 +25,9 @@ def tearDown(self): def test_repr(self): expected = ( - "CoupledTimeAndFrequencySeries(duration={}, sampling_frequency={}," - " start_time={})".format( - self.series.duration, - self.series.sampling_frequency, - self.series.start_time, - ) + f"CoupledTimeAndFrequencySeries(duration={self.series.duration}, " + f"sampling_frequency={self.series.sampling_frequency}," + f" start_time={self.series.start_time})" ) self.assertEqual(expected, repr(self.series)) @@ -49,9 +47,7 @@ def test_time_array_type(self): self.assertIsInstance(self.series.time_array, np.ndarray) def test_frequency_array_from_init(self): - expected = create_frequency_series( - sampling_frequency=self.sampling_frequency, duration=self.duration - ) + expected = create_frequency_series(sampling_frequency=self.sampling_frequency, duration=self.duration) self.assertTrue(np.array_equal(expected, self.series.frequency_array)) def test_time_array_from_init(self): @@ -65,16 +61,10 @@ def test_time_array_from_init(self): def test_frequency_array_setter(self): new_sampling_frequency = 100 new_duration = 3 - new_frequency_array = create_frequency_series( - sampling_frequency=new_sampling_frequency, duration=new_duration - ) + new_frequency_array = create_frequency_series(sampling_frequency=new_sampling_frequency, duration=new_duration) self.series.frequency_array = new_frequency_array - self.assertTrue( - np.array_equal(new_frequency_array, self.series.frequency_array) - ) - self.assertLessEqual( - np.abs(new_sampling_frequency - self.series.sampling_frequency), 1 - ) + self.assertTrue(np.array_equal(new_frequency_array, self.series.frequency_array)) + self.assertLessEqual(np.abs(new_sampling_frequency - self.series.sampling_frequency), 1) self.assertAlmostEqual(new_duration, self.series.duration) self.assertAlmostEqual(self.start_time, self.series.start_time) @@ -89,9 +79,7 @@ def test_time_array_setter(self): ) self.series.time_array = new_time_array self.assertTrue(np.array_equal(new_time_array, self.series.time_array)) - self.assertAlmostEqual( - new_sampling_frequency, self.series.sampling_frequency, places=1 - ) + self.assertAlmostEqual(new_sampling_frequency, self.series.sampling_frequency, places=1) self.assertAlmostEqual(new_duration, self.series.duration, places=1) self.assertAlmostEqual(new_start_time, self.series.start_time, places=1) diff --git a/test/core/utils_test.py b/test/core/utils_test.py index df46d6bb3..ca827e238 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -1,19 +1,19 @@ -import unittest +import importlib +import json +import logging import os +import sys +import types +import unittest +import warnings import dill -import numpy as np -from astropy import constants -import importlib +import h5py import lal -import logging import matplotlib.pyplot as plt -import h5py -import json +import numpy as np import pytest -import sys -import types -import warnings +from astropy import constants import bilby from bilby.core import utils @@ -23,9 +23,7 @@ class TestConstants(unittest.TestCase): def test_speed_of_light(self): self.assertEqual(utils.speed_of_light, lal.C_SI) - self.assertLess( - abs(utils.speed_of_light - constants.c.value) / utils.speed_of_light, 1e-16 - ) + self.assertLess(abs(utils.speed_of_light - constants.c.value) / utils.speed_of_light, 1e-16) def test_parsec(self): self.assertEqual(utils.parsec, lal.PC_SI) @@ -33,15 +31,12 @@ def test_parsec(self): def test_solar_mass(self): self.assertEqual(utils.solar_mass, lal.MSUN_SI) - self.assertLess( - abs(utils.solar_mass - constants.M_sun.value) / utils.solar_mass, 1e-4 - ) + self.assertLess(abs(utils.solar_mass - constants.M_sun.value) / utils.solar_mass, 1e-4) def test_radius_of_earth(self): self.assertEqual(bilby.core.utils.radius_of_earth, lal.REARTH_SI) self.assertLess( - abs(utils.radius_of_earth - constants.R_earth.value) - / utils.radius_of_earth, + abs(utils.radius_of_earth - constants.R_earth.value) / utils.radius_of_earth, 1e-5, ) @@ -63,20 +58,14 @@ def test_nfft_sine_function(self): time_domain_strain = np.sin(2 * np.pi * times * injected_frequency + 0.4) - frequency_domain_strain, frequencies = bilby.core.utils.nfft( - time_domain_strain, self.sampling_frequency - ) + frequency_domain_strain, frequencies = bilby.core.utils.nfft(time_domain_strain, self.sampling_frequency) frequency_at_peak = frequencies[np.argmax(np.abs(frequency_domain_strain))] self.assertAlmostEqual(injected_frequency, frequency_at_peak, places=1) def test_nfft_infft(self): time_domain_strain = np.random.normal(0, 1, 10) - frequency_domain_strain, _ = bilby.core.utils.nfft( - time_domain_strain, self.sampling_frequency - ) - new_time_domain_strain = bilby.core.utils.infft( - frequency_domain_strain, self.sampling_frequency - ) + frequency_domain_strain, _ = bilby.core.utils.nfft(time_domain_strain, self.sampling_frequency) + new_time_domain_strain = bilby.core.utils.infft(frequency_domain_strain, self.sampling_frequency) self.assertTrue(np.allclose(time_domain_strain, new_time_domain_strain)) @@ -166,14 +155,10 @@ def test_get_sampling_frequency_from_time_array(self): def test_get_sampling_frequency_from_time_array_unequally_sampled(self): self.time_array[-1] += 0.0001 with self.assertRaises(ValueError): - _, _ = utils.get_sampling_frequency_and_duration_from_time_array( - self.time_array - ) + _, _ = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array) def test_get_duration_from_time_array(self): - _, new_duration = utils.get_sampling_frequency_and_duration_from_time_array( - self.time_array - ) + _, new_duration = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array) self.assertEqual(self.duration, new_duration) def test_get_start_time_from_time_array(self): @@ -184,25 +169,19 @@ def test_get_sampling_frequency_from_frequency_array(self): ( new_sampling_freq, _, - ) = utils.get_sampling_frequency_and_duration_from_frequency_array( - self.frequency_array - ) + ) = utils.get_sampling_frequency_and_duration_from_frequency_array(self.frequency_array) self.assertEqual(self.sampling_frequency, new_sampling_freq) def test_get_sampling_frequency_from_frequency_array_unequally_sampled(self): self.frequency_array[-1] += 0.0001 with self.assertRaises(ValueError): - _, _ = utils.get_sampling_frequency_and_duration_from_frequency_array( - self.frequency_array - ) + _, _ = utils.get_sampling_frequency_and_duration_from_frequency_array(self.frequency_array) def test_get_duration_from_frequency_array(self): ( _, new_duration, - ) = utils.get_sampling_frequency_and_duration_from_frequency_array( - self.frequency_array - ) + ) = utils.get_sampling_frequency_and_duration_from_frequency_array(self.frequency_array) self.assertEqual(self.duration, new_duration) def test_consistency_time_array_to_time_array(self): @@ -222,9 +201,7 @@ def test_consistency_frequency_array_to_frequency_array(self): ( new_sampling_frequency, new_duration, - ) = utils.get_sampling_frequency_and_duration_from_frequency_array( - self.frequency_array - ) + ) = utils.get_sampling_frequency_and_duration_from_frequency_array(self.frequency_array) new_frequency_array = utils.create_frequency_series( sampling_frequency=new_sampling_frequency, duration=new_duration ) @@ -232,9 +209,7 @@ def test_consistency_frequency_array_to_frequency_array(self): def test_illegal_sampling_frequency_and_duration(self): with self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException): - _ = utils.create_time_series( - sampling_frequency=7.7, duration=1.3, starting_time=0 - ) + _ = utils.create_time_series(sampling_frequency=7.7, duration=1.3, starting_time=0) class TestReflect(unittest.TestCase): @@ -345,9 +320,7 @@ def test_returns_none_for_floats_outside_range(self): def test_returns_float_for_float_and_array(self): self.assertIsInstance(self.interpolant(0.5, np.random.random(10)), np.ndarray) self.assertIsInstance(self.interpolant(np.random.random(10), 0.5), np.ndarray) - self.assertIsInstance( - self.interpolant(np.random.random(10), np.random.random(10)), np.ndarray - ) + self.assertIsInstance(self.interpolant(np.random.random(10), np.random.random(10)), np.ndarray) def test_raises_for_mismatched_arrays(self): with self.assertRaises(ValueError): @@ -369,7 +342,7 @@ def setUp(self): self.lnfunc1 = np.log(self.x) self.func1int = (self.x[-1] ** 2 - self.x[0] ** 2) / 2 with np.errstate(divide="ignore"): - self.lnfunc2 = np.log(self.x ** 2) + self.lnfunc2 = np.log(self.x**2) self.func2int = (self.x[-1] ** 3 - self.x[0] ** 3) / 3 self.irregularx = np.array( @@ -391,7 +364,7 @@ def setUp(self): ) with np.errstate(divide="ignore"): self.lnfunc1irregular = np.log(self.irregularx) - self.lnfunc2irregular = np.log(self.irregularx ** 2) + self.lnfunc2irregular = np.log(self.irregularx**2) self.irregulardxs = np.diff(self.irregularx) def test_incorrect_step_type(self): @@ -423,7 +396,6 @@ def test_integral_func2_irregular_steps(self): class TestSavingNumpyRandomGenerator(unittest.TestCase): - @pytest.fixture(autouse=True) def init_outdir(self, tmp_path): # Use pytest's tmp_path fixture to create a temporary directory @@ -439,9 +411,7 @@ def setUp(self): def test_hdf5(self): with h5py.File(self.outdir / "test.h5", "w") as f: - bilby.core.utils.recursively_save_dict_contents_to_group( - f, "/", self.data - ) + bilby.core.utils.recursively_save_dict_contents_to_group(f, "/", self.data) a = self.data["rng"].random() with h5py.File(self.outdir / "test.h5", "r") as f: @@ -451,30 +421,29 @@ def test_hdf5(self): self.assertEqual(a, b) def test_json(self): - with open(self.outdir / "test.json", 'w') as file: + with open(self.outdir / "test.json", "w") as file: json.dump(self.data, file, indent=2, cls=bilby.core.utils.BilbyJsonEncoder) a = self.data["rng"].random() - with open(self.outdir / "test.json", 'r') as file: + with open(self.outdir / "test.json") as file: data = json.load(file, object_hook=bilby.core.utils.decode_bilby_json) b = data["rng"].random() self.assertEqual(a, b) def test_pickle(self): - with open(self.outdir / "test.pkl", 'wb') as file: + with open(self.outdir / "test.pkl", "wb") as file: dill.dump(self.data, file) a = self.data["rng"].random() - with open(self.outdir / "test.pkl", 'rb') as file: + with open(self.outdir / "test.pkl", "rb") as file: data = dill.load(file) b = data["rng"].random() self.assertEqual(a, b) class TestGlobalMetaData(unittest.TestCase): - @pytest.fixture(autouse=True) def set_caplog(self, caplog): self._caplog = caplog @@ -519,7 +488,6 @@ def test_init(self): class TestRandomUtils(unittest.TestCase): - def setUp(self): # Ensure a clean import of the random module if "bilby.core.utils.random" in sys.modules: @@ -541,8 +509,8 @@ def test_no_warning_when_accessed_via_module(self): self.assertNotIn("Detected that `rng` was likely imported directly", str(warning.message)) def test_warning_when_imported_directly(self): - from bilby.core.utils.random import rng from bilby.core.utils import random + from bilby.core.utils.random import rng # Simulate direct import of rng in a fake module fake_module = types.ModuleType("fake_module") diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py index af0c81f2f..7b9779a03 100644 --- a/test/gw/conversion_test.py +++ b/test/gw/conversion_test.py @@ -14,8 +14,8 @@ def setUp(self): self.mass_2 = 1.3 self.mass_ratio = 13 / 14 self.total_mass = 2.7 - self.chirp_mass = (1.4 * 1.3) ** 0.6 / 2.7 ** 0.2 - self.symmetric_mass_ratio = (1.4 * 1.3) / 2.7 ** 2 + self.chirp_mass = (1.4 * 1.3) ** 0.6 / 2.7**0.2 + self.symmetric_mass_ratio = (1.4 * 1.3) / 2.7**2 self.cos_angle = -1 self.angle = np.pi self.lambda_1 = 300 @@ -24,18 +24,10 @@ def setUp(self): 8 / 13 * ( - ( - 1 - + 7 * self.symmetric_mass_ratio - - 31 * self.symmetric_mass_ratio ** 2 - ) + (1 + 7 * self.symmetric_mass_ratio - 31 * self.symmetric_mass_ratio**2) * (self.lambda_1 + self.lambda_2) + (1 - 4 * self.symmetric_mass_ratio) ** 0.5 - * ( - 1 - + 9 * self.symmetric_mass_ratio - - 11 * self.symmetric_mass_ratio ** 2 - ) + * (1 + 9 * self.symmetric_mass_ratio - 11 * self.symmetric_mass_ratio**2) * (self.lambda_1 - self.lambda_2) ) ) @@ -44,17 +36,13 @@ def setUp(self): / 2 * ( (1 - 4 * self.symmetric_mass_ratio) ** 0.5 - * ( - 1 - - 13272 / 1319 * self.symmetric_mass_ratio - + 8944 / 1319 * self.symmetric_mass_ratio ** 2 - ) + * (1 - 13272 / 1319 * self.symmetric_mass_ratio + 8944 / 1319 * self.symmetric_mass_ratio**2) * (self.lambda_1 + self.lambda_2) + ( 1 - 15910 / 1319 * self.symmetric_mass_ratio - + 32850 / 1319 * self.symmetric_mass_ratio ** 2 - + 3380 / 1319 * self.symmetric_mass_ratio ** 3 + + 32850 / 1319 * self.symmetric_mass_ratio**2 + + 3380 / 1319 * self.symmetric_mass_ratio**3 ) * (self.lambda_1 - self.lambda_2) ) @@ -69,23 +57,15 @@ def tearDown(self): del self.symmetric_mass_ratio def test_total_mass_and_mass_ratio_to_component_masses(self): - mass_1, mass_2 = conversion.total_mass_and_mass_ratio_to_component_masses( - self.mass_ratio, self.total_mass - ) - self.assertTrue( - all([abs(mass_1 - self.mass_1) < 1e-5, abs(mass_2 - self.mass_2) < 1e-5]) - ) + mass_1, mass_2 = conversion.total_mass_and_mass_ratio_to_component_masses(self.mass_ratio, self.total_mass) + self.assertTrue(all([abs(mass_1 - self.mass_1) < 1e-5, abs(mass_2 - self.mass_2) < 1e-5])) def test_chirp_mass_and_primary_mass_to_mass_ratio(self): - mass_ratio = conversion.chirp_mass_and_primary_mass_to_mass_ratio( - self.chirp_mass, self.mass_1 - ) + mass_ratio = conversion.chirp_mass_and_primary_mass_to_mass_ratio(self.chirp_mass, self.mass_1) self.assertAlmostEqual(self.mass_ratio, mass_ratio) def test_symmetric_mass_ratio_to_mass_ratio(self): - mass_ratio = conversion.symmetric_mass_ratio_to_mass_ratio( - self.symmetric_mass_ratio - ) + mass_ratio = conversion.symmetric_mass_ratio_to_mass_ratio(self.symmetric_mass_ratio) self.assertAlmostEqual(self.mass_ratio, mass_ratio) def test_chirp_mass_and_total_mass_to_symmetric_mass_ratio(self): @@ -95,15 +75,11 @@ def test_chirp_mass_and_total_mass_to_symmetric_mass_ratio(self): self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) def test_chirp_mass_and_mass_ratio_to_total_mass(self): - total_mass = conversion.chirp_mass_and_mass_ratio_to_total_mass( - self.chirp_mass, self.mass_ratio - ) + total_mass = conversion.chirp_mass_and_mass_ratio_to_total_mass(self.chirp_mass, self.mass_ratio) self.assertAlmostEqual(self.total_mass, total_mass) def test_chirp_mass_and_mass_ratio_to_component_masses(self): - mass_1, mass_2 = \ - conversion.chirp_mass_and_mass_ratio_to_component_masses( - self.chirp_mass, self.mass_ratio) + mass_1, mass_2 = conversion.chirp_mass_and_mass_ratio_to_component_masses(self.chirp_mass, self.mass_ratio) self.assertAlmostEqual(self.mass_1, mass_1) self.assertAlmostEqual(self.mass_2, mass_2) @@ -116,9 +92,7 @@ def test_component_masses_to_total_mass(self): self.assertAlmostEqual(self.total_mass, total_mass) def test_component_masses_to_symmetric_mass_ratio(self): - symmetric_mass_ratio = conversion.component_masses_to_symmetric_mass_ratio( - self.mass_1, self.mass_2 - ) + symmetric_mass_ratio = conversion.component_masses_to_symmetric_mass_ratio(self.mass_1, self.mass_2) self.assertAlmostEqual(self.symmetric_mass_ratio, symmetric_mass_ratio) def test_component_masses_to_mass_ratio(self): @@ -126,15 +100,11 @@ def test_component_masses_to_mass_ratio(self): self.assertAlmostEqual(self.mass_ratio, mass_ratio) def test_mass_1_and_chirp_mass_to_mass_ratio(self): - mass_ratio = conversion.mass_1_and_chirp_mass_to_mass_ratio( - self.mass_1, self.chirp_mass - ) + mass_ratio = conversion.mass_1_and_chirp_mass_to_mass_ratio(self.mass_1, self.chirp_mass) self.assertAlmostEqual(self.mass_ratio, mass_ratio) def test_lambda_tilde_to_lambda_1_lambda_2(self): - lambda_1, lambda_2 = conversion.lambda_tilde_to_lambda_1_lambda_2( - self.lambda_tilde, self.mass_1, self.mass_2 - ) + lambda_1, lambda_2 = conversion.lambda_tilde_to_lambda_1_lambda_2(self.lambda_tilde, self.mass_1, self.mass_2) self.assertTrue( all( [ @@ -185,7 +155,7 @@ def test_identity_conversion(self): lambda_1=self.lambda_1, lambda_2=self.lambda_2, lambda_tilde=self.lambda_tilde, - delta_lambda_tilde=self.delta_lambda_tilde + delta_lambda_tilde=self.delta_lambda_tilde, ) identity_samples, blank_list = conversion.identity_map_conversion(original_samples) assert blank_list == [] @@ -199,30 +169,18 @@ def setUp(self): self.parameters = dict() self.component_mass_pars = dict(mass_1=1.4, mass_2=1.4) self.mass_parameters = self.component_mass_pars.copy() - self.mass_parameters["mass_ratio"] = conversion.component_masses_to_mass_ratio( - **self.component_mass_pars - ) - self.mass_parameters[ - "symmetric_mass_ratio" - ] = conversion.component_masses_to_symmetric_mass_ratio( - **self.component_mass_pars - ) - self.mass_parameters["chirp_mass"] = conversion.component_masses_to_chirp_mass( - **self.component_mass_pars - ) - self.mass_parameters["total_mass"] = conversion.component_masses_to_total_mass( + self.mass_parameters["mass_ratio"] = conversion.component_masses_to_mass_ratio(**self.component_mass_pars) + self.mass_parameters["symmetric_mass_ratio"] = conversion.component_masses_to_symmetric_mass_ratio( **self.component_mass_pars ) + self.mass_parameters["chirp_mass"] = conversion.component_masses_to_chirp_mass(**self.component_mass_pars) + self.mass_parameters["total_mass"] = conversion.component_masses_to_total_mass(**self.component_mass_pars) self.component_tidal_parameters = dict(lambda_1=300, lambda_2=300) self.all_component_pars = self.component_tidal_parameters.copy() self.all_component_pars.update(self.component_mass_pars) self.tidal_parameters = self.component_tidal_parameters.copy() - self.tidal_parameters[ - "lambda_tilde" - ] = conversion.lambda_1_lambda_2_to_lambda_tilde(**self.all_component_pars) - self.tidal_parameters[ - "delta_lambda_tilde" - ] = conversion.lambda_1_lambda_2_to_delta_lambda_tilde( + self.tidal_parameters["lambda_tilde"] = conversion.lambda_1_lambda_2_to_lambda_tilde(**self.all_component_pars) + self.tidal_parameters["delta_lambda_tilde"] = conversion.lambda_1_lambda_2_to_delta_lambda_tilde( **self.all_component_pars ) @@ -250,9 +208,7 @@ def test_redshift_to_luminosity_distance(self): def test_comoving_to_luminosity_distance(self): self.parameters["comoving_distance"] = 1 - dl = conversion.comoving_distance_to_luminosity_distance( - self.parameters["comoving_distance"] - ) + dl = conversion.comoving_distance_to_luminosity_distance(self.parameters["comoving_distance"]) self.bbh_convert() self.assertEqual(self.parameters["luminosity_distance"], dl) @@ -268,12 +224,7 @@ def _conversion_to_component_mass(self, keys): self.parameters[key] = self.mass_parameters[key] self.bbh_convert() self.assertAlmostEqual( - max( - [ - abs(self.parameters[key] - self.component_mass_pars[key]) - for key in ["mass_1", "mass_2"] - ] - ), + max([abs(self.parameters[key] - self.component_mass_pars[key]) for key in ["mass_1", "mass_2"]]), 0, ) @@ -318,10 +269,7 @@ def test_bbh_aligned_spin_to_spherical(self): phi_12 = 0.0 self.bbh_convert() self.assertDictEqual( - { - key: self.parameters[key] - for key in ["a_1", "tilt_1", "phi_12", "phi_jl"] - }, + {key: self.parameters[key] for key in ["a_1", "tilt_1", "phi_12", "phi_jl"]}, dict(a_1=a_1, tilt_1=tilt_1, phi_jl=phi_jl, phi_12=phi_12), ) @@ -345,10 +293,7 @@ def test_bbh_zero_aligned_spin_to_spherical_with_magnitude(self): phi_12 = 0 self.bbh_convert() self.assertDictEqual( - { - key: self.parameters[key] - for key in ["a_1", "a_2", "tilt_1", "tilt_2", "phi_12", "phi_jl"] - }, + {key: self.parameters[key] for key in ["a_1", "a_2", "tilt_1", "tilt_2", "phi_12", "phi_jl"]}, dict(a_1=a_1, a_2=a_2, tilt_1=tilt_1, tilt_2=tilt_2, phi_jl=phi_jl, phi_12=phi_12), ) @@ -367,10 +312,7 @@ def test_bbh_zero_aligned_spin_to_spherical_without_magnitude(self): phi_12 = 0 self.bbh_convert() self.assertDictEqual( - { - key: self.parameters[key] - for key in ["a_1", "tilt_1", "phi_12", "phi_jl"] - }, + {key: self.parameters[key] for key in ["a_1", "tilt_1", "phi_12", "phi_jl"]}, dict(a_1=a_1, tilt_1=tilt_1, phi_jl=phi_jl, phi_12=phi_12), ) @@ -471,9 +413,7 @@ def setUp(self): "lambda_tilde", "delta_lambda_tilde", ] - self.data_frame = pd.DataFrame({ - key: [value] * 100 for key, value in self.parameters.items() - }) + self.data_frame = pd.DataFrame({key: [value] * 100 for key, value in self.parameters.items()}) def test_generate_all_bbh_parameters(self): self._generate( @@ -536,10 +476,11 @@ def test_generate_bbh_parameters_with_likelihood(self): self.assertNotEqual(converted["mass_1"].values[0], likelihood.parameters["mass_1"]) def test_identity_generation_no_likelihood(self): - test_fixed_prior = bilby.core.prior.PriorDict({ - "test_param_a": bilby.core.prior.DeltaFunction(0, name="test_param_a"), - "test_param_b": bilby.core.prior.DeltaFunction(1, name="test_param_b") - } + test_fixed_prior = bilby.core.prior.PriorDict( + { + "test_param_a": bilby.core.prior.DeltaFunction(0, name="test_param_a"), + "test_param_b": bilby.core.prior.DeltaFunction(1, name="test_param_b"), + } ) output_sample = conversion.identity_map_generation(self.parameters, priors=test_fixed_prior) assert output_sample.pop("test_param_a") == 0 @@ -585,9 +526,7 @@ def setUp(self): self.distances = np.linspace(1, 1000, 100) def test_luminosity_redshift_with_cosmology(self): - z = conversion.luminosity_distance_to_redshift( - self.distances, cosmology="WMAP9" - ) + z = conversion.luminosity_distance_to_redshift(self.distances, cosmology="WMAP9") dl = conversion.redshift_to_luminosity_distance(z, cosmology="WMAP9") self.assertAlmostEqual(max(abs(dl - self.distances)), 0, 4) @@ -597,32 +536,30 @@ def test_comoving_redshift_with_cosmology(self): self.assertAlmostEqual(max(abs(dc - self.distances)), 0, 4) def test_comoving_luminosity_with_cosmology(self): - dc = conversion.comoving_distance_to_luminosity_distance( - self.distances, cosmology="WMAP9" - ) + dc = conversion.comoving_distance_to_luminosity_distance(self.distances, cosmology="WMAP9") dl = conversion.luminosity_distance_to_comoving_distance(dc, cosmology="WMAP9") self.assertAlmostEqual(max(abs(dl - self.distances)), 0, 4) class TestGenerateMassParameters(unittest.TestCase): def setUp(self): - self.expected_values = {'mass_1': 2.0, - 'mass_2': 1.0, - 'chirp_mass': 1.2167286837864113, - 'total_mass': 3.0, - 'mass_1_source': 4.0, - 'mass_2_source': 2.0, - 'chirp_mass_source': 2.433457367572823, - 'total_mass_source': 6, - 'symmetric_mass_ratio': 0.2222222222222222, - 'mass_ratio': 0.5} + self.expected_values = { + "mass_1": 2.0, + "mass_2": 1.0, + "chirp_mass": 1.2167286837864113, + "total_mass": 3.0, + "mass_1_source": 4.0, + "mass_2_source": 2.0, + "chirp_mass_source": 2.433457367572823, + "total_mass_source": 6, + "symmetric_mass_ratio": 0.2222222222222222, + "mass_ratio": 0.5, + } def helper_generation_from_keys(self, keys, expected_values, source=False): # Explicitly test the helper generate_component_masses - local_test_vars = \ - {key: expected_values[key] for key in keys} - local_test_vars_with_component_masses = \ - conversion.generate_component_masses(local_test_vars, source=source) + local_test_vars = {key: expected_values[key] for key in keys} + local_test_vars_with_component_masses = conversion.generate_component_masses(local_test_vars, source=source) if source: self.assertTrue("mass_1_source" in local_test_vars_with_component_masses.keys()) self.assertTrue("mass_2_source" in local_test_vars_with_component_masses.keys()) @@ -630,182 +567,170 @@ def helper_generation_from_keys(self, keys, expected_values, source=False): self.assertTrue("mass_1" in local_test_vars_with_component_masses.keys()) self.assertTrue("mass_2" in local_test_vars_with_component_masses.keys()) for key in local_test_vars_with_component_masses.keys(): - self.assertAlmostEqual( - local_test_vars_with_component_masses[key], - self.expected_values[key]) + self.assertAlmostEqual(local_test_vars_with_component_masses[key], self.expected_values[key]) # Test the function more generally - local_all_mass_parameters = \ - conversion.generate_mass_parameters(local_test_vars, source=source) + local_all_mass_parameters = conversion.generate_mass_parameters(local_test_vars, source=source) if source: self.assertEqual( set(local_all_mass_parameters.keys()), - set(["mass_1_source", - "mass_2_source", - "chirp_mass_source", - "total_mass_source", - "symmetric_mass_ratio", - "mass_ratio", - ] - ) + set( + [ + "mass_1_source", + "mass_2_source", + "chirp_mass_source", + "total_mass_source", + "symmetric_mass_ratio", + "mass_ratio", + ] + ), ) else: self.assertEqual( set(local_all_mass_parameters.keys()), - set(["mass_1", - "mass_2", - "chirp_mass", - "total_mass", - "symmetric_mass_ratio", - "mass_ratio", - ] - ) + set( + [ + "mass_1", + "mass_2", + "chirp_mass", + "total_mass", + "symmetric_mass_ratio", + "mass_ratio", + ] + ), ) for key in local_all_mass_parameters.keys(): self.assertAlmostEqual(expected_values[key], local_all_mass_parameters[key]) def test_from_mass_1_and_mass_2(self): - self.helper_generation_from_keys(["mass_1", "mass_2"], - self.expected_values) + self.helper_generation_from_keys(["mass_1", "mass_2"], self.expected_values) def test_from_mass_1_and_mass_ratio(self): - self.helper_generation_from_keys(["mass_1", "mass_ratio"], - self.expected_values) + self.helper_generation_from_keys(["mass_1", "mass_ratio"], self.expected_values) def test_from_mass_2_and_mass_ratio(self): - self.helper_generation_from_keys(["mass_2", "mass_ratio"], - self.expected_values) + self.helper_generation_from_keys(["mass_2", "mass_ratio"], self.expected_values) def test_from_mass_1_and_total_mass(self): - self.helper_generation_from_keys(["mass_2", "total_mass"], - self.expected_values) + self.helper_generation_from_keys(["mass_2", "total_mass"], self.expected_values) def test_from_chirp_mass_and_mass_ratio(self): - self.helper_generation_from_keys(["chirp_mass", "mass_ratio"], - self.expected_values) + self.helper_generation_from_keys(["chirp_mass", "mass_ratio"], self.expected_values) def test_from_chirp_mass_and_symmetric_mass_ratio(self): - self.helper_generation_from_keys(["chirp_mass", "symmetric_mass_ratio"], - self.expected_values) + self.helper_generation_from_keys(["chirp_mass", "symmetric_mass_ratio"], self.expected_values) def test_from_chirp_mass_and_symmetric_mass_1(self): - self.helper_generation_from_keys(["chirp_mass", "mass_1"], - self.expected_values) + self.helper_generation_from_keys(["chirp_mass", "mass_1"], self.expected_values) def test_from_chirp_mass_and_symmetric_mass_2(self): - self.helper_generation_from_keys(["chirp_mass", "mass_2"], - self.expected_values) + self.helper_generation_from_keys(["chirp_mass", "mass_2"], self.expected_values) def test_from_mass_1_source_and_mass_2_source(self): - self.helper_generation_from_keys(["mass_1_source", "mass_2_source"], - self.expected_values, source=True) + self.helper_generation_from_keys(["mass_1_source", "mass_2_source"], self.expected_values, source=True) def test_from_mass_1_source_and_mass_ratio(self): - self.helper_generation_from_keys(["mass_1_source", "mass_ratio"], - self.expected_values, source=True) + self.helper_generation_from_keys(["mass_1_source", "mass_ratio"], self.expected_values, source=True) def test_from_mass_2_source_and_mass_ratio(self): - self.helper_generation_from_keys(["mass_2_source", "mass_ratio"], - self.expected_values, source=True) + self.helper_generation_from_keys(["mass_2_source", "mass_ratio"], self.expected_values, source=True) def test_from_mass_1_source_and_total_mass(self): - self.helper_generation_from_keys(["mass_2_source", "total_mass_source"], - self.expected_values, source=True) + self.helper_generation_from_keys(["mass_2_source", "total_mass_source"], self.expected_values, source=True) def test_from_chirp_mass_source_and_mass_ratio(self): - self.helper_generation_from_keys(["chirp_mass_source", "mass_ratio"], - self.expected_values, source=True) + self.helper_generation_from_keys(["chirp_mass_source", "mass_ratio"], self.expected_values, source=True) def test_from_chirp_mass_source_and_symmetric_mass_ratio(self): - self.helper_generation_from_keys(["chirp_mass_source", "symmetric_mass_ratio"], - self.expected_values, source=True) + self.helper_generation_from_keys( + ["chirp_mass_source", "symmetric_mass_ratio"], self.expected_values, source=True + ) def test_from_chirp_mass_source_and_symmetric_mass_1(self): - self.helper_generation_from_keys(["chirp_mass_source", "mass_1_source"], - self.expected_values, source=True) + self.helper_generation_from_keys(["chirp_mass_source", "mass_1_source"], self.expected_values, source=True) def test_from_chirp_mass_source_and_symmetric_mass_2(self): - self.helper_generation_from_keys(["chirp_mass_source", "mass_2_source"], - self.expected_values, source=True) + self.helper_generation_from_keys(["chirp_mass_source", "mass_2_source"], self.expected_values, source=True) class TestEquationOfStateConversions(unittest.TestCase): - ''' + """ Class to test equation of state conversions. The test points were generated from a simulation independent of bilby using the original lalsimulation calls. Specific cases tested are described within each function. - ''' + """ + def setUp(self): self.mass_1_source_spectral = [ 4.922542724434885, 4.350626907771598, 4.206155335439082, 1.7822696459661311, - 1.3091740103047926 + 1.3091740103047926, ] self.mass_2_source_spectral = [ 3.459974694590303, 1.2276461777181447, 3.7287707089639976, 0.3724016563531846, - 1.055042934805801 + 1.055042934805801, ] self.spectral_pca_gamma_0 = [ 0.7074873121348357, 0.05855931126849878, 0.7795329261793462, 1.467907561566463, - 2.9066488405635624 + 2.9066488405635624, ] self.spectral_pca_gamma_1 = [ -0.29807111670823816, 2.027708558522935, -1.4415775226512115, -0.7104870098896858, - -0.4913817181089619 + -0.4913817181089619, ] self.spectral_pca_gamma_2 = [ 0.25625095371021156, -0.19574096643220049, -0.2710238103460012, 0.22815820981582358, - -0.1543413205016374 + -0.1543413205016374, ] self.spectral_pca_gamma_3 = [ -0.04030365100175101, 0.05698030777919032, -0.045595911403040264, -0.023480394227900117, - -0.07114492992285618 + -0.07114492992285618, ] self.spectral_gamma_0 = [ 1.1259406796075457, 0.3191335618787259, 1.3651245109783452, 1.3540140238735314, - 1.4551949842961993 + 1.4551949842961993, ] self.spectral_gamma_1 = [ 0.26791504475282835, 0.3930374252139248, 0.11438399886108475, 0.14181113477953, - -0.11989033256620368 + -0.11989033256620368, ] self.spectral_gamma_2 = [ -0.06810849354463173, -0.038250139296677754, -0.0801540229444505, -0.05230330841791625, - -0.005197303281460286 + -0.005197303281460286, ] self.spectral_gamma_3 = [ 0.002848121360389597, 0.000872447754855139, 0.005528747386660879, 0.0024325946344566484, - 0.00043890906202786106 + 0.00043890906202786106, ] self.mass_1_source_polytrope = [ 2.2466565877822573, @@ -813,7 +738,7 @@ def setUp(self): 4.123897187899834, 2.014160764697004, 1.414796714032148, - 2.0919349759766614 + 2.0919349759766614, ] self.mass_2_source_polytrope = [ 0.36696047254774256, @@ -821,7 +746,7 @@ def setUp(self): 1.650477659961306, 1.310399737462001, 0.5470843356210495, - 1.2311162283818198 + 1.2311162283818198, ] self.polytrope_log10_pressure_1 = [ 34.05849276958394, @@ -829,7 +754,7 @@ def setUp(self): 33.07579629429792, 33.93412833210738, 34.24096323517809, - 35.293288373856534 + 35.293288373856534, ] self.polytrope_log10_pressure_2 = [ 33.82891829901602, @@ -837,7 +762,7 @@ def setUp(self): 34.940095188881976, 34.72710820593933, 35.42780071717415, - 35.648689969687915 + 35.648689969687915, ] self.polytrope_gamma_0 = [ 2.359580734009537, @@ -845,7 +770,7 @@ def setUp(self): 4.784129809424835, 1.4900432021657437, 1.0037220431922798, - 4.183994058757201 + 4.183994058757201, ] self.polytrope_gamma_1 = [ 1.9497583698697314, @@ -853,7 +778,7 @@ def setUp(self): 2.8228335336587826, 4.032519623275465, 1.10894361284508, - 3.168076721819637 + 3.168076721819637, ] self.polytrope_gamma_2 = [ 4.6001755196585385, @@ -861,61 +786,62 @@ def setUp(self): 4.429607300132092, 1.8176338276795763, 2.9938859949129797, - 1.300271383168368 + 1.300271383168368, ] - self.lambda_1_spectral = [0., 0., 0., 0., 1275.7253186286332] - self.lambda_2_spectral = [0., 0., 0., 0., 4504.897675043909] - self.lambda_1_polytrope = [0., 0., 0., 0., 0., 234.66424898184766] - self.lambda_2_polytrope = [0., 0., 0., 0., 0., 3710.931378294547] + self.lambda_1_spectral = [0.0, 0.0, 0.0, 0.0, 1275.7253186286332] + self.lambda_2_spectral = [0.0, 0.0, 0.0, 0.0, 4504.897675043909] + self.lambda_1_polytrope = [0.0, 0.0, 0.0, 0.0, 0.0, 234.66424898184766] + self.lambda_2_polytrope = [0.0, 0.0, 0.0, 0.0, 0.0, 3710.931378294547] self.eos_check_spectral = [0, 0, 0, 0, 1] self.eos_check_polytrope = [0, 0, 0, 0, 0, 1] def test_spectral_pca_to_spectral(self): for i in range(len(self.mass_1_source_spectral)): - spectral_gamma_0, spectral_gamma_1, spectral_gamma_2, spectral_gamma_3 = \ + spectral_gamma_0, spectral_gamma_1, spectral_gamma_2, spectral_gamma_3 = ( conversion.spectral_pca_to_spectral( self.spectral_pca_gamma_0[i], self.spectral_pca_gamma_1[i], self.spectral_pca_gamma_2[i], - self.spectral_pca_gamma_3[i] + self.spectral_pca_gamma_3[i], ) + ) self.assertAlmostEqual(spectral_gamma_0, self.spectral_gamma_0[i], places=5) self.assertAlmostEqual(spectral_gamma_1, self.spectral_gamma_1[i], places=5) self.assertAlmostEqual(spectral_gamma_2, self.spectral_gamma_2[i], places=5) self.assertAlmostEqual(spectral_gamma_3, self.spectral_gamma_3[i], places=5) def test_spectral_params_to_lambda_1_lambda_2(self): - ''' + """ The points cover 5 test cases: - Fail SimNeutronStarEOS4ParamSDGammaCheck() - Fail max_speed_of_sound_ <=1.1 - Fail mass_1_source <= max_mass - Fail mass_2_source >= min_mass - Passes all and produces accurate lambda_1, lambda_2, eos_check values - ''' + """ for i in range(len(self.mass_1_source_spectral)): - spectral_gamma_0, spectral_gamma_1, spectral_gamma_2, spectral_gamma_3 = \ + spectral_gamma_0, spectral_gamma_1, spectral_gamma_2, spectral_gamma_3 = ( conversion.spectral_pca_to_spectral( self.spectral_pca_gamma_0[i], self.spectral_pca_gamma_1[i], self.spectral_pca_gamma_2[i], - self.spectral_pca_gamma_3[i] - ) - lambda_1, lambda_2, eos_check = \ - conversion.spectral_params_to_lambda_1_lambda_2( - spectral_gamma_0, - spectral_gamma_1, - spectral_gamma_2, - spectral_gamma_3, - self.mass_1_source_spectral[i], - self.mass_2_source_spectral[i] + self.spectral_pca_gamma_3[i], ) + ) + lambda_1, lambda_2, eos_check = conversion.spectral_params_to_lambda_1_lambda_2( + spectral_gamma_0, + spectral_gamma_1, + spectral_gamma_2, + spectral_gamma_3, + self.mass_1_source_spectral[i], + self.mass_2_source_spectral[i], + ) self.assertAlmostEqual(self.lambda_1_spectral[i], lambda_1, places=0) self.assertAlmostEqual(self.lambda_2_spectral[i], lambda_2, places=0) self.assertAlmostEqual(self.eos_check_spectral[i], eos_check) def test_polytrope_or_causal_params_to_lambda_1_lambda_2_causal(self): - ''' + """ The points cover 6 test cases: - Fail log10_pressure1 >= log10_pressure2 - Fail SimNeutronStarEOS3PDViableFamilyCheck() @@ -923,19 +849,18 @@ def test_polytrope_or_causal_params_to_lambda_1_lambda_2_causal(self): - Fail mass_1_source <= max_mass - Fail mass_2_source >= min_mass - Passes all and produces accurate lambda_1, lambda_2, eos_check values - ''' + """ for i in range(len(self.mass_1_source_polytrope)): - lambda_1, lambda_2, eos_check = \ - conversion.polytrope_or_causal_params_to_lambda_1_lambda_2( - self.polytrope_gamma_0[i], - self.polytrope_log10_pressure_1[i], - self.polytrope_gamma_1[i], - self.polytrope_log10_pressure_2[i], - self.polytrope_gamma_2[i], - self.mass_1_source_polytrope[i], - self.mass_2_source_polytrope[i], - 0 - ) + lambda_1, lambda_2, eos_check = conversion.polytrope_or_causal_params_to_lambda_1_lambda_2( + self.polytrope_gamma_0[i], + self.polytrope_log10_pressure_1[i], + self.polytrope_gamma_1[i], + self.polytrope_log10_pressure_2[i], + self.polytrope_gamma_2[i], + self.mass_1_source_polytrope[i], + self.mass_2_source_polytrope[i], + 0, + ) self.assertAlmostEqual(self.lambda_1_polytrope[i], lambda_1, places=2) self.assertAlmostEqual(self.lambda_2_polytrope[i], lambda_2, places=1) self.assertAlmostEqual(self.eos_check_polytrope[i], eos_check) diff --git a/test/gw/cosmology_test.py b/test/gw/cosmology_test.py index dacbff49c..3c6f5ac30 100644 --- a/test/gw/cosmology_test.py +++ b/test/gw/cosmology_test.py @@ -1,9 +1,10 @@ import unittest -from astropy.cosmology import WMAP9 -from bilby.gw import cosmology import lal import pytest +from astropy.cosmology import WMAP9 + +from bilby.gw import cosmology @pytest.fixture(autouse=True) @@ -77,7 +78,6 @@ def test_getting_cosmology_non_standard_default(self): class TestPlanck15LALCosmology(unittest.TestCase): - def setUp(self): pass @@ -87,14 +87,7 @@ def test_redshift_to_luminosity_distance(self): dist_bilby = cosmo.luminosity_distance(z) # Change to use CreateDefaultCosmologicalParameters once it is # available in a release - omega = lal.CreateCosmologicalParameters( - h=0.679, - om=0.3065, - ol=1 - 0.3065, - w0=-1.0, - w1=0.0, - w2=0.0 - ) + omega = lal.CreateCosmologicalParameters(h=0.679, om=0.3065, ol=1 - 0.3065, w0=-1.0, w1=0.0, w2=0.0) dist_lal = lal.LuminosityDistance(omega=omega, z=z) # Results are the same to within 9 decimal places self.assertAlmostEqual(dist_bilby.value, dist_lal) diff --git a/test/gw/detector/calibration_test.py b/test/gw/detector/calibration_test.py index 3453913c1..98f90aac3 100644 --- a/test/gw/detector/calibration_test.py +++ b/test/gw/detector/calibration_test.py @@ -1,8 +1,9 @@ import os import unittest -from parameterized import parameterized import numpy as np +from parameterized import parameterized + from bilby.core.prior import PriorDict from bilby.gw import calibration, detector, prior @@ -37,11 +38,7 @@ def setUp(self): maximum_frequency=self.maximum_frequency, n_points=self.n_points, ) - self.parameters = { - "recalib_{}_{}".format(param, ii): 0.0 - for ii in range(5) - for param in ["amplitude", "phase"] - } + self.parameters = {f"recalib_{param}_{ii}": 0.0 for ii in range(5) for param in ["amplitude", "phase"]} def tearDown(self): del self.prefix @@ -53,14 +50,13 @@ def tearDown(self): def test_calibration_factor(self): frequency_array = np.linspace(20, 1024, 1000) - cal_factor = self.model.get_calibration_factor( - frequency_array, **self.parameters - ) + cal_factor = self.model.get_calibration_factor(frequency_array, **self.parameters) assert np.all(cal_factor.real == np.ones_like(frequency_array)) def test_repr(self): - expected = "CubicSpline(prefix='{}', minimum_frequency={}, maximum_frequency={}, n_points={})".format( - self.prefix, self.minimum_frequency, self.maximum_frequency, self.n_points + expected = ( + f"CubicSpline(prefix='{self.prefix}', minimum_frequency={self.minimum_frequency}, " + f"maximum_frequency={self.maximum_frequency}, n_points={self.n_points})" ) actual = repr(self.model) self.assertEqual(expected, actual) @@ -81,9 +77,7 @@ def setUp(self): self.filename = "calibration_draws.h5" self.number_of_draws = 100 self.ifo = detector.get_empty_interferometer("H1") - self.ifo.set_strain_data_from_power_spectral_density( - duration=1, sampling_frequency=512 - ) + self.ifo.set_strain_data_from_power_spectral_density(duration=1, sampling_frequency=512) self.ifo.calibration_model = calibration.CubicSpline( prefix="recalib_H1_", n_points=10, @@ -104,17 +98,13 @@ def tearDown(self): os.remove(self.filename) def test_generate_draws(self): - draws, parameters = calibration._generate_calibration_draws( - self.ifo, self.priors, self.number_of_draws - ) + draws, parameters = calibration._generate_calibration_draws(self.ifo, self.priors, self.number_of_draws) self.assertEqual(draws.shape, (self.number_of_draws, sum(self.ifo.frequency_mask))) self.assertListEqual(list(self.priors.keys()), list(parameters.keys())) @parameterized.expand([("template",), ("data",), (None,)]) def test_read_write_matches(self, correction_type): - draws, parameters = calibration._generate_calibration_draws( - self.ifo, self.priors, self.number_of_draws - ) + draws, parameters = calibration._generate_calibration_draws(self.ifo, self.priors, self.number_of_draws) frequencies = self.ifo.frequency_array[self.ifo.frequency_mask] calibration.write_calibration_file( filename=self.filename, @@ -135,21 +125,21 @@ def test_read_write_matches(self, correction_type): def test_build_calibration_lookup(self): ifos = detector.InterferometerList(["H1", "L1", "H1"]) - ifos.set_strain_data_from_power_spectral_densities( - duration=4, sampling_frequency=1024 - ) + ifos.set_strain_data_from_power_spectral_densities(duration=4, sampling_frequency=1024) priors = PriorDict() for ifo in ifos: ifo.minimum_frequency = 20 ifo.maximum_frequency = 1024 - priors.update(prior.CalibrationPriorDict.constant_uncertainty_spline( - amplitude_sigma=0.1, - phase_sigma=0.1, - minimum_frequency=ifo.minimum_frequency, - maximum_frequency=ifo.maximum_frequency, - n_nodes=10, - label=ifo.name, - )) + priors.update( + prior.CalibrationPriorDict.constant_uncertainty_spline( + amplitude_sigma=0.1, + phase_sigma=0.1, + minimum_frequency=ifo.minimum_frequency, + maximum_frequency=ifo.maximum_frequency, + n_nodes=10, + label=ifo.name, + ) + ) ifo.calibration_model = calibration.CubicSpline( prefix=f"recalib_{ifo.name}_", n_points=10, diff --git a/test/gw/detector/geometry_test.py b/test/gw/detector/geometry_test.py index 358825b23..ea8468c26 100644 --- a/test/gw/detector/geometry_test.py +++ b/test/gw/detector/geometry_test.py @@ -191,17 +191,10 @@ def test_unit_vector_along_arm_y(self): def test_repr(self): expected = ( - "InterferometerGeometry(length={}, latitude={}, longitude={}, elevation={}, xarm_azimuth={}, " - "yarm_azimuth={}, xarm_tilt={}, yarm_tilt={})".format( - float(self.length), - float(self.latitude), - float(self.longitude), - float(self.elevation), - float(self.xarm_azimuth), - float(self.yarm_azimuth), - float(self.xarm_tilt), - float(self.yarm_tilt), - ) + f"InterferometerGeometry(length={float(self.length)}, latitude={float(self.latitude)}, " + f"longitude={float(self.longitude)}, elevation={float(self.elevation)}, " + f"xarm_azimuth={float(self.xarm_azimuth)}, yarm_azimuth={float(self.yarm_azimuth)}, " + f"xarm_tilt={float(self.xarm_tilt)}, yarm_tilt={float(self.yarm_tilt)})" ) self.assertEqual(expected, repr(self.geometry)) diff --git a/test/gw/detector/injection_test.py b/test/gw/detector/injection_test.py index 95d6a3da0..05c4d1ee9 100644 --- a/test/gw/detector/injection_test.py +++ b/test/gw/detector/injection_test.py @@ -1,7 +1,8 @@ -import bilby import pytest from gwpy.frequencyseries import FrequencySeries +import bilby + @pytest.mark.flaky(reruns=3) def test_injection_into_timeseries_matches_ifo_injections(): @@ -17,9 +18,7 @@ def test_injection_into_timeseries_matches_ifo_injections(): duration = 8 sampling_frequency = 16384 ifo = bilby.gw.detector.get_empty_interferometer("H1") - ifo.set_strain_data_from_zero_noise( - sampling_frequency=sampling_frequency, duration=duration - ) + ifo.set_strain_data_from_zero_noise(sampling_frequency=sampling_frequency, duration=duration) wfg = bilby.gw.waveform_generator.WaveformGenerator( frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, duration=duration, @@ -51,8 +50,6 @@ def test_injection_into_timeseries_matches_ifo_injections(): whitened_data_2 = ifo.whitened_time_domain_strain mismatch = 1 - ( - sum(whitened_data_1 * whitened_data_2) - / sum(whitened_data_1**2)**0.5 - / sum(whitened_data_2**2)**0.5 + sum(whitened_data_1 * whitened_data_2) / sum(whitened_data_1**2) ** 0.5 / sum(whitened_data_2**2) ** 0.5 ) assert mismatch < 3e-3 diff --git a/test/gw/detector/interferometer_test.py b/test/gw/detector/interferometer_test.py index cb1666320..54abde245 100644 --- a/test/gw/detector/interferometer_test.py +++ b/test/gw/detector/interferometer_test.py @@ -1,13 +1,12 @@ import os import unittest +from shutil import rmtree from unittest import mock import lal import lalsimulation -import pytest -from shutil import rmtree - import numpy as np +import pytest import bilby @@ -15,9 +14,7 @@ class TestInterferometer(unittest.TestCase): def setUp(self): self.name = "name" - self.power_spectral_density = ( - bilby.gw.detector.PowerSpectralDensity.from_aligo() - ) + self.power_spectral_density = bilby.gw.detector.PowerSpectralDensity.from_aligo() self.minimum_frequency = 10 self.maximum_frequency = 20 self.length = 30 @@ -54,16 +51,9 @@ def setUp(self): self.injection_polarizations["cross"] = np.random.random(4097) self.waveform_generator = mock.MagicMock() - self.wg_polarizations = dict( - plus=np.random.random(4097), cross=np.random.random(4097) - ) - self.waveform_generator.frequency_domain_strain = ( - lambda _: self.wg_polarizations - ) - self.parameters = dict( - ra=0.0, dec=0.0, geocent_time=0.0, psi=0.0, - mass_1=100, mass_2=100 - ) + self.wg_polarizations = dict(plus=np.random.random(4097), cross=np.random.random(4097)) + self.waveform_generator.frequency_domain_strain = lambda _: self.wg_polarizations + self.parameters = dict(ra=0.0, dec=0.0, geocent_time=0.0, psi=0.0, mass_1=100, mass_2=100) bilby.core.utils.check_directory_exists_and_if_not_mkdir(self.outdir) @@ -111,9 +101,7 @@ def test_get_detector_response_default_behaviour(self): waveform_polarizations=dict(plus=plus), parameters=dict(ra=0, dec=0, geocent_time=0, psi=0), ) - self.assertTrue( - np.array_equal(response, plus * self.ifo.frequency_mask * np.exp(-0j)) - ) + self.assertTrue(np.array_equal(response, plus * self.ifo.frequency_mask * np.exp(-0j))) def test_get_detector_response_with_dt(self): self.ifo.antenna_response = mock.MagicMock(return_value=1) @@ -126,11 +114,7 @@ def test_get_detector_response_with_dt(self): waveform_polarizations=dict(plus=plus), parameters=dict(ra=0, dec=0, geocent_time=0, psi=0), ) - expected_response = ( - plus - * self.ifo.frequency_mask - * np.exp(-1j * 2 * np.pi * self.ifo.frequency_array) - ) + expected_response = plus * self.ifo.frequency_mask * np.exp(-1j * 2 * np.pi * self.ifo.frequency_array) self.assertTrue(np.allclose(abs(expected_response), abs(response))) def test_get_detector_response_multiple_modes(self): @@ -145,11 +129,7 @@ def test_get_detector_response_multiple_modes(self): waveform_polarizations=dict(plus=plus, cross=cross), parameters=dict(ra=0, dec=0, geocent_time=0, psi=0), ) - self.assertTrue( - np.array_equal( - response, (plus + cross) * self.ifo.frequency_mask * np.exp(-0j) - ) - ) + self.assertTrue(np.array_equal(response, (plus + cross) * self.ifo.frequency_mask * np.exp(-0j))) def test_inject_signal_from_waveform_polarizations_correct_injection(self): original_strain = self.ifo.strain_data.frequency_domain_strain @@ -158,14 +138,8 @@ def test_inject_signal_from_waveform_polarizations_correct_injection(self): parameters=self.parameters, injection_polarizations=self.injection_polarizations, ) - expected = ( - self.injection_polarizations["plus"] - + self.injection_polarizations["cross"] - + original_strain - ) - self.assertTrue( - np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain) - ) + expected = self.injection_polarizations["plus"] + self.injection_polarizations["cross"] + original_strain + self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain)) def test_inject_signal_from_waveform_polarizations_update_time_domain_strain(self): original_td_strain = self.ifo.strain_data.time_domain_strain @@ -174,9 +148,7 @@ def test_inject_signal_from_waveform_polarizations_update_time_domain_strain(sel parameters=self.parameters, injection_polarizations=self.injection_polarizations, ) - self.assertFalse( - np.array_equal(original_td_strain, self.ifo.strain_data.time_domain_strain) - ) + self.assertFalse(np.array_equal(original_td_strain, self.ifo.strain_data.time_domain_strain)) def test_inject_signal_from_waveform_polarizations_meta_data(self): self.ifo.get_detector_response = lambda x, params: x["plus"] + x["cross"] @@ -184,9 +156,7 @@ def test_inject_signal_from_waveform_polarizations_meta_data(self): parameters=self.parameters, injection_polarizations=self.injection_polarizations, ) - signal_ifo_expected = ( - self.injection_polarizations["plus"] + self.injection_polarizations["cross"] - ) + signal_ifo_expected = self.injection_polarizations["plus"] + self.injection_polarizations["cross"] self.assertAlmostEqual( self.ifo.optimal_snr_squared(signal=signal_ifo_expected).real, self.ifo.meta_data["optimal_SNR"] ** 2, @@ -224,28 +194,14 @@ def test_inject_signal_from_waveform_generator_correct_return_value(self): returned_polarizations = self.ifo.inject_signal_from_waveform_generator( parameters=self.parameters, waveform_generator=self.waveform_generator ) - self.assertTrue( - np.array_equal( - self.wg_polarizations["plus"], returned_polarizations["plus"] - ) - ) - self.assertTrue( - np.array_equal( - self.wg_polarizations["cross"], returned_polarizations["cross"] - ) - ) + self.assertTrue(np.array_equal(self.wg_polarizations["plus"], returned_polarizations["plus"])) + self.assertTrue(np.array_equal(self.wg_polarizations["cross"], returned_polarizations["cross"])) - @mock.patch.object( - bilby.gw.detector.Interferometer, "inject_signal_from_waveform_generator" - ) + @mock.patch.object(bilby.gw.detector.Interferometer, "inject_signal_from_waveform_generator") def test_inject_signal_with_waveform_generator_correct_call(self, m): self.ifo.get_detector_response = lambda x, params: x["plus"] + x["cross"] - _ = self.ifo.inject_signal( - parameters=self.parameters, waveform_generator=self.waveform_generator - ) - m.assert_called_with( - parameters=self.parameters, waveform_generator=self.waveform_generator - ) + _ = self.ifo.inject_signal(parameters=self.parameters, waveform_generator=self.waveform_generator) + m.assert_called_with(parameters=self.parameters, waveform_generator=self.waveform_generator) def test_inject_signal_from_waveform_generator_correct_injection(self): original_strain = self.ifo.strain_data.frequency_domain_strain @@ -253,14 +209,8 @@ def test_inject_signal_from_waveform_generator_correct_injection(self): injection_polarizations = self.ifo.inject_signal_from_waveform_generator( parameters=self.parameters, waveform_generator=self.waveform_generator ) - expected = ( - injection_polarizations["plus"] - + injection_polarizations["cross"] - + original_strain - ) - self.assertTrue( - np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain) - ) + expected = injection_polarizations["plus"] + injection_polarizations["cross"] + original_strain + self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain)) def test_inject_signal_with_injection_polarizations(self): original_strain = self.ifo.strain_data.frequency_domain_strain @@ -269,18 +219,10 @@ def test_inject_signal_with_injection_polarizations(self): parameters=self.parameters, injection_polarizations=self.injection_polarizations, ) - expected = ( - self.injection_polarizations["plus"] - + self.injection_polarizations["cross"] - + original_strain - ) - self.assertTrue( - np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain) - ) + expected = self.injection_polarizations["plus"] + self.injection_polarizations["cross"] + original_strain + self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain)) - @mock.patch.object( - bilby.gw.detector.Interferometer, "inject_signal_from_waveform_polarizations" - ) + @mock.patch.object(bilby.gw.detector.Interferometer, "inject_signal_from_waveform_polarizations") def test_inject_signal_with_injection_polarizations_and_waveform_generator(self, m): self.ifo.get_detector_response = lambda x, params: x["plus"] + x["cross"] _ = self.ifo.inject_signal( @@ -328,8 +270,7 @@ def test_template_template_inner_product(self): signal_2 = np.ones_like(self.ifo.power_spectral_density_array) * 2 signal_1_optimal = self.ifo.optimal_snr_squared(signal=signal_1) signal_1_optimal_by_template_template = self.ifo.template_template_inner_product( - signal_1=signal_1, - signal_2=signal_1 + signal_1=signal_1, signal_2=signal_1 ) self.assertTrue(np.array_equal(signal_1_optimal, signal_1_optimal_by_template_template)) signal_1_signal_2_inner_product = self.ifo.template_template_inner_product(signal_1=signal_1, signal_2=signal_2) @@ -337,22 +278,12 @@ def test_template_template_inner_product(self): def test_repr(self): expected = ( - "Interferometer(name='{}', power_spectral_density={}, minimum_frequency={}, " - "maximum_frequency={}, length={}, latitude={}, longitude={}, elevation={}, xarm_azimuth={}, " - "yarm_azimuth={}, xarm_tilt={}, yarm_tilt={})".format( - self.name, - self.power_spectral_density, - float(self.minimum_frequency), - float(self.maximum_frequency), - float(self.length), - float(self.latitude), - float(self.longitude), - float(self.elevation), - float(self.xarm_azimuth), - float(self.yarm_azimuth), - float(self.xarm_tilt), - float(self.yarm_tilt), - ) + f"Interferometer(name='{self.name}', power_spectral_density={self.power_spectral_density}, " + f"minimum_frequency={float(self.minimum_frequency)}, " + f"maximum_frequency={float(self.maximum_frequency)}, length={float(self.length)}, " + f"latitude={float(self.latitude)}, longitude={float(self.longitude)}, elevation={float(self.elevation)}, " + f"xarm_azimuth={float(self.xarm_azimuth)}, yarm_azimuth={float(self.yarm_azimuth)}, " + f"xarm_tilt={float(self.xarm_tilt)}, yarm_tilt={float(self.yarm_tilt)})" ) self.assertEqual(expected, repr(self.ifo)) @@ -364,11 +295,10 @@ def test_to_and_from_pkl_loading(self): def test_to_and_from_pkl_wrong_class(self): import dill + with open("./outdir/psd.pkl", "wb") as ff: dill.dump(self.ifo.power_spectral_density, ff) - filename = self.ifo._filename_from_outdir_label_extension( - outdir="outdir", label="psd", extension="pkl" - ) + filename = self.ifo._filename_from_outdir_label_extension(outdir="outdir", label="psd", extension="pkl") with self.assertRaises(TypeError): bilby.gw.detector.Interferometer.from_pickle(filename) @@ -378,10 +308,8 @@ def test_psd_not_impacted_by_window_factor(monkeypatch): ifo.set_strain_data_from_zero_noise(duration=4, sampling_frequency=256) monkeypatch.setattr(ifo.strain_data, "window_factor", np.nan) np.testing.assert_array_equal( - ifo.power_spectral_density.get_power_spectral_density_array( - frequency_array=ifo.strain_data.frequency_array - ), - ifo.power_spectral_density_array + ifo.power_spectral_density.get_power_spectral_density_array(frequency_array=ifo.strain_data.frequency_array), + ifo.power_spectral_density_array, ) @@ -399,12 +327,8 @@ def test_psd_impacted_by_window_factor_with_environment_variable(monkeypatch): class TestInterferometerEquals(unittest.TestCase): def setUp(self): self.name = "name" - self.power_spectral_density_1 = ( - bilby.gw.detector.PowerSpectralDensity.from_aligo() - ) - self.power_spectral_density_2 = ( - bilby.gw.detector.PowerSpectralDensity.from_aligo() - ) + self.power_spectral_density_1 = bilby.gw.detector.PowerSpectralDensity.from_aligo() + self.power_spectral_density_2 = bilby.gw.detector.PowerSpectralDensity.from_aligo() self.minimum_frequency = 10 self.maximum_frequency = 20 self.length = 30 @@ -538,9 +462,9 @@ def test_eq_false_different_ifo_strain_data(self): class TestInterferometerAntennaPatternAgainstLAL(unittest.TestCase): def setUp(self): self.name = "name" - self.ifo_names = ['H1', 'L1', 'V1', 'K1', 'GEO600', 'ET'] - self.lal_prefixes = {'H1': 'H1', 'L1': 'L1', 'V1': 'V1', 'K1': 'K1', 'GEO600': 'G1', 'ET': 'E1'} - self.polarizations = ['plus', 'cross', 'breathing', 'longitudinal', 'x', 'y'] + self.ifo_names = ["H1", "L1", "V1", "K1", "GEO600", "ET"] + self.lal_prefixes = {"H1": "H1", "L1": "L1", "V1": "V1", "K1": "K1", "GEO600": "G1", "ET": "E1"} + self.polarizations = ["plus", "cross", "breathing", "longitudinal", "x", "y"] self.ifos = bilby.gw.detector.InterferometerList(self.ifo_names) self.gpstime = 1305303144 self.trial = 100 @@ -563,8 +487,8 @@ def test_antenna_pattern_vs_lal(self): response = lalsimulation.DetectorPrefixToLALDetector(self.lal_prefixes[ifo_name]).response ifo = self.ifos[n] for i in range(self.trial): - ra = 2. * np.pi * np.random.uniform() - dec = np.pi * np.random.uniform() - np.pi / 2. + ra = 2.0 * np.pi * np.random.uniform() + dec = np.pi * np.random.uniform() - np.pi / 2.0 psi = np.pi * np.random.uniform() f_lal[i] = lal.ComputeDetAMResponseExtraModes(response, ra, dec, psi, gmst) for m, pol in enumerate(self.polarizations): @@ -572,7 +496,7 @@ def test_antenna_pattern_vs_lal(self): std = np.std(f_bilby - f_lal, axis=0) for m, pol in enumerate(self.polarizations): - with self.subTest(':'.join((ifo_name, pol))): + with self.subTest(":".join((ifo_name, pol))): self.assertAlmostEqual(std[m], 0.0, places=7) def test_time_delay_vs_lal(self): @@ -583,12 +507,11 @@ def test_time_delay_vs_lal(self): det = lal.cached_detector_by_prefix[self.lal_prefixes[ifo_name]] for i in range(self.trial): gpstime = np.random.uniform(1205303144, 1405303144) - ra = 2. * np.pi * np.random.uniform() - dec = np.pi * np.random.uniform() - np.pi / 2. - delays[i] = ( - lal.TimeDelayFromEarthCenter(det.location, ra, dec, gpstime) - - ifo.time_delay_from_geocenter(ra, dec, gpstime) - ) + ra = 2.0 * np.pi * np.random.uniform() + dec = np.pi * np.random.uniform() - np.pi / 2.0 + delays[i] = lal.TimeDelayFromEarthCenter( + det.location, ra, dec, gpstime + ) - ifo.time_delay_from_geocenter(ra, dec, gpstime) std = max(abs(delays)) with self.subTest(ifo_name): @@ -600,33 +523,33 @@ class TestInterferometerWhitenedStrain(unittest.TestCase): def setUp(self): self.duration = 64 self.sampling_frequency = 4096 - self.ifo = bilby.gw.detector.get_empty_interferometer('H1') + self.ifo = bilby.gw.detector.get_empty_interferometer("H1") self.ifo.set_strain_data_from_power_spectral_density( - sampling_frequency=self.sampling_frequency, duration=self.duration) + sampling_frequency=self.sampling_frequency, duration=self.duration + ) self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments={ - "waveform_approximant": "IMRPhenomXP" - }) + waveform_arguments={"waveform_approximant": "IMRPhenomXP"}, + ) self.parameters = { - 'mass_1': 10, - 'mass_2': 10, - 'a_1': 0, - 'a_2': 0, - 'tilt_1': 0, - 'tilt_2': 0, - 'phi_12': 0, - 'phi_jl': 0, - 'theta_jn': 0, - 'luminosity_distance': 40, - 'phase': 0, - 'ra': 0, - 'dec': 0, - 'geocent_time': 62, - 'psi': 0 + "mass_1": 10, + "mass_2": 10, + "a_1": 0, + "a_2": 0, + "tilt_1": 0, + "tilt_2": 0, + "phi_12": 0, + "phi_jl": 0, + "theta_jn": 0, + "luminosity_distance": 40, + "phase": 0, + "ra": 0, + "dec": 0, + "geocent_time": 62, + "psi": 0, } def tearDown(self): @@ -661,8 +584,7 @@ def test_frequency_domain_noise_and_signal_whitening(self): # Make the template separately waveform_polarizations = self.waveform_generator.frequency_domain_strain(parameters=self.parameters) signal_ifo = self.ifo.get_detector_response( - waveform_polarizations=waveform_polarizations, - parameters=self.parameters + waveform_polarizations=waveform_polarizations, parameters=self.parameters ) # Whiten the template whitened_signal_ifo = self.ifo.whiten_frequency_series(signal_ifo) @@ -676,8 +598,7 @@ def test_time_domain_noise_and_signal_whitening(self): # Make the template separately waveform_polarizations = self.waveform_generator.frequency_domain_strain(parameters=self.parameters) signal_ifo = self.ifo.get_detector_response( - waveform_polarizations=waveform_polarizations, - parameters=self.parameters + waveform_polarizations=waveform_polarizations, parameters=self.parameters ) # Whiten the template in FD whitened_signal_ifo_fd = self.ifo.whiten_frequency_series(signal_ifo) diff --git a/test/gw/detector/networks_test.py b/test/gw/detector/networks_test.py index 6f7b211d3..446cf6cb7 100644 --- a/test/gw/detector/networks_test.py +++ b/test/gw/detector/networks_test.py @@ -1,7 +1,7 @@ import unittest -from unittest import mock -from shutil import rmtree from itertools import combinations +from shutil import rmtree +from unittest import mock import numpy as np @@ -13,12 +13,8 @@ def setUp(self): self.frequency_arrays = np.linspace(0, 4096, 4097) self.name1 = "name1" self.name2 = "name2" - self.power_spectral_density1 = ( - bilby.gw.detector.PowerSpectralDensity.from_aligo() - ) - self.power_spectral_density2 = ( - bilby.gw.detector.PowerSpectralDensity.from_aligo() - ) + self.power_spectral_density1 = bilby.gw.detector.PowerSpectralDensity.from_aligo() + self.power_spectral_density2 = bilby.gw.detector.PowerSpectralDensity.from_aligo() self.minimum_frequency1 = 10 self.minimum_frequency2 = 10 self.maximum_frequency1 = 20 @@ -192,25 +188,17 @@ def test_check_interferometers_relative_tolerance(self, mock_warning): self.assertTrue(mock_warning.called) warning_log_str = mock_warning.call_args.args[0].args[0] self.assertIsInstance(warning_log_str, str) - self.assertTrue( - "The start_time of all interferometers are not the same:" in warning_log_str - ) + self.assertTrue("The start_time of all interferometers are not the same:" in warning_log_str) - @mock.patch.object( - bilby.gw.detector.Interferometer, "set_strain_data_from_power_spectral_density" - ) + @mock.patch.object(bilby.gw.detector.Interferometer, "set_strain_data_from_power_spectral_density") def test_set_strain_data_from_power_spectral_density(self, m): - self.ifo_list.set_strain_data_from_power_spectral_densities( - sampling_frequency=123, duration=6.2, start_time=3 - ) + self.ifo_list.set_strain_data_from_power_spectral_densities(sampling_frequency=123, duration=6.2, start_time=3) m.assert_called_with(sampling_frequency=123, duration=6.2, start_time=3) self.assertEqual(len(self.ifo_list), m.call_count) def test_inject_signal_pol_and_wg_none(self): with self.assertRaises(ValueError): - self.ifo_list.inject_signal( - injection_polarizations=None, waveform_generator=None - ) + self.ifo_list.inject_signal(injection_polarizations=None, waveform_generator=None) def test_meta_data(self): ifos_list = [self.ifo1, self.ifo2] @@ -219,37 +207,27 @@ def test_meta_data(self): meta_data = {ifo.name: ifo.meta_data for ifo in ifos_list} self.assertEqual(ifos.meta_data, meta_data) - @mock.patch.object( - bilby.gw.waveform_generator.WaveformGenerator, "frequency_domain_strain" - ) + @mock.patch.object(bilby.gw.waveform_generator.WaveformGenerator, "frequency_domain_strain") def test_inject_signal_pol_none_calls_frequency_domain_strain(self, m): waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( frequency_domain_source_model=lambda x, y, z: x ) self.ifo1.inject_signal = mock.MagicMock(return_value=None) self.ifo2.inject_signal = mock.MagicMock(return_value=None) - self.ifo_list.inject_signal( - parameters=None, waveform_generator=waveform_generator - ) + self.ifo_list.inject_signal(parameters=None, waveform_generator=waveform_generator) self.assertTrue(m.called) @mock.patch.object(bilby.gw.detector.Interferometer, "inject_signal") def test_inject_signal_with_inj_pol(self, m): - self.ifo_list.inject_signal( - injection_polarizations=dict(plus=1), raise_error=False - ) - m.assert_called_with( - parameters=None, injection_polarizations=dict(plus=1), raise_error=False - ) + self.ifo_list.inject_signal(injection_polarizations=dict(plus=1), raise_error=False) + m.assert_called_with(parameters=None, injection_polarizations=dict(plus=1), raise_error=False) self.assertEqual(len(self.ifo_list), m.call_count) @mock.patch.object(bilby.gw.detector.Interferometer, "inject_signal") def test_inject_signal_returns_expected_polarisations(self, m): m.return_value = dict(plus=1, cross=2) injection_polarizations = dict(plus=1, cross=2) - ifos_pol = self.ifo_list.inject_signal( - injection_polarizations=injection_polarizations - ) + ifos_pol = self.ifo_list.inject_signal(injection_polarizations=injection_polarizations) self.assertDictEqual( self.ifo1.inject_signal(injection_polarizations=injection_polarizations), ifos_pol[0], @@ -273,28 +251,16 @@ def test_duration(self): self.assertEqual(self.ifo2.strain_data.duration, self.ifo_list.duration) def test_sampling_frequency(self): - self.assertEqual( - self.ifo1.strain_data.sampling_frequency, self.ifo_list.sampling_frequency - ) - self.assertEqual( - self.ifo2.strain_data.sampling_frequency, self.ifo_list.sampling_frequency - ) + self.assertEqual(self.ifo1.strain_data.sampling_frequency, self.ifo_list.sampling_frequency) + self.assertEqual(self.ifo2.strain_data.sampling_frequency, self.ifo_list.sampling_frequency) def test_start_time(self): self.assertEqual(self.ifo1.strain_data.start_time, self.ifo_list.start_time) self.assertEqual(self.ifo2.strain_data.start_time, self.ifo_list.start_time) def test_frequency_array(self): - self.assertTrue( - np.array_equal( - self.ifo1.strain_data.frequency_array, self.ifo_list.frequency_array - ) - ) - self.assertTrue( - np.array_equal( - self.ifo2.strain_data.frequency_array, self.ifo_list.frequency_array - ) - ) + self.assertTrue(np.array_equal(self.ifo1.strain_data.frequency_array, self.ifo_list.frequency_array)) + self.assertTrue(np.array_equal(self.ifo2.strain_data.frequency_array, self.ifo_list.frequency_array)) def test_append_with_ifo(self): self.ifo_list.append(self.ifo2) @@ -304,16 +270,12 @@ def test_append_with_ifo(self): def test_append_with_ifo_list(self): self.ifo_list.append(self.ifo_list) names = [ifo.name for ifo in self.ifo_list] - self.assertListEqual( - [self.ifo1.name, self.ifo2.name, self.ifo1.name, self.ifo2.name], names - ) + self.assertListEqual([self.ifo1.name, self.ifo2.name, self.ifo1.name, self.ifo2.name], names) def test_extend(self): self.ifo_list.extend(self.ifo_list) names = [ifo.name for ifo in self.ifo_list] - self.assertListEqual( - [self.ifo1.name, self.ifo2.name, self.ifo1.name, self.ifo2.name], names - ) + self.assertListEqual([self.ifo1.name, self.ifo2.name, self.ifo1.name, self.ifo2.name], names) def test_insert(self): new_ifo = self.ifo1 @@ -333,9 +295,7 @@ def test_to_and_from_pkl_wrong_class(self): with open("./outdir/psd.pkl", "wb") as ff: dill.dump(self.ifo_list[0].power_spectral_density, ff) - filename = self.ifo_list._filename_from_outdir_label_extension( - outdir="outdir", label="psd", extension="pkl" - ) + filename = self.ifo_list._filename_from_outdir_label_extension(outdir="outdir", label="psd", extension="pkl") with self.assertRaises(TypeError): bilby.gw.detector.InterferometerList.from_pickle(filename) @@ -376,10 +336,7 @@ def test_individual_positions(self): """ def a(delta_lat, delta_long, lat_1, lat_2): - return ( - np.sin(delta_lat / 2) ** 2 - + np.cos(lat_1) * np.cos(lat_2) * np.sin(delta_long / 2) ** 2 - ) + return np.sin(delta_lat / 2) ** 2 + np.cos(lat_1) * np.cos(lat_2) * np.sin(delta_long / 2) ** 2 def c(a): return 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a)) diff --git a/test/gw/detector/psd_test.py b/test/gw/detector/psd_test.py index 1f718c6ae..0fed8216c 100644 --- a/test/gw/detector/psd_test.py +++ b/test/gw/detector/psd_test.py @@ -19,34 +19,26 @@ def tearDown(self): del self.asd_array def test_init_with_asd_array(self): - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array, asd_array=self.asd_array - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array, asd_array=self.asd_array) self.assertTrue(np.array_equal(self.frequency_array, psd.frequency_array)) self.assertTrue(np.array_equal(self.asd_array, psd.asd_array)) self.assertTrue(np.array_equal(self.psd_array, psd.psd_array)) def test_init_with_psd_array(self): - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array, psd_array=self.psd_array - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array, psd_array=self.psd_array) self.assertTrue(np.array_equal(self.frequency_array, psd.frequency_array)) self.assertTrue(np.array_equal(self.asd_array, psd.asd_array)) self.assertTrue(np.array_equal(self.psd_array, psd.psd_array)) def test_setting_asd_array_after_init(self): - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array) psd.asd_array = self.asd_array self.assertTrue(np.array_equal(self.frequency_array, psd.frequency_array)) self.assertTrue(np.array_equal(self.asd_array, psd.asd_array)) self.assertTrue(np.array_equal(self.psd_array, psd.psd_array)) def test_setting_psd_array_after_init(self): - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array) psd.psd_array = self.psd_array self.assertTrue(np.array_equal(self.frequency_array, psd.frequency_array)) self.assertTrue(np.array_equal(self.asd_array, psd.asd_array)) @@ -54,16 +46,12 @@ def test_setting_psd_array_after_init(self): def test_power_spectral_density_interpolated_from_asd_array(self): expected = np.array([25.0]) - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array, asd_array=self.asd_array - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array, asd_array=self.asd_array) self.assertEqual(expected, psd.power_spectral_density_interpolated(2)) def test_power_spectral_density_interpolated_from_psd_array(self): expected = np.array([25.0]) - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array, psd_array=self.psd_array - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array, psd_array=self.psd_array) self.assertEqual(expected, psd.power_spectral_density_interpolated(2)) def test_from_amplitude_spectral_density_array(self): @@ -81,11 +69,10 @@ def test_from_power_spectral_density_array(self): self.assertTrue(np.array_equal(self.asd_array, actual.asd_array)) def test_repr(self): - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array, psd_array=self.psd_array - ) - expected = "PowerSpectralDensity(frequency_array={}, psd_array={}, asd_array={})".format( - self.frequency_array, self.psd_array, self.asd_array + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array, psd_array=self.psd_array) + expected = ( + f"PowerSpectralDensity(frequency_array={self.frequency_array}, " + f"psd_array={self.psd_array}, asd_array={self.asd_array})" ) self.assertEqual(expected, repr(psd)) @@ -94,12 +81,8 @@ class TestPowerSpectralDensityWithFiles(unittest.TestCase): def setUp(self): self.dir = os.path.join(os.path.dirname(__file__), "noise_curves") os.mkdir(self.dir) - self.asd_file = os.path.join( - os.path.dirname(__file__), "noise_curves", "asd_test_file.txt" - ) - self.psd_file = os.path.join( - os.path.dirname(__file__), "noise_curves", "psd_test_file.txt" - ) + self.asd_file = os.path.join(os.path.dirname(__file__), "noise_curves", "asd_test_file.txt") + self.psd_file = os.path.join(os.path.dirname(__file__), "noise_curves", "psd_test_file.txt") with open(self.asd_file, "w") as f: f.write("1.\t1.0e-21\n2.\t2.0e-21\n3.\t3.0e-21") with open(self.psd_file, "w") as f: @@ -119,34 +102,26 @@ def tearDown(self): del self.psd_file def test_init_with_psd_file(self): - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array, psd_file=self.psd_file - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array, psd_file=self.psd_file) self.assertEqual(self.psd_file, psd.psd_file) self.assertTrue(np.array_equal(self.psd_array, psd.psd_array)) self.assertTrue(np.allclose(self.asd_array, psd.asd_array, atol=1e-30)) def test_init_with_asd_file(self): - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array, asd_file=self.asd_file - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array, asd_file=self.asd_file) self.assertEqual(self.asd_file, psd.asd_file) self.assertTrue(np.allclose(self.psd_array, psd.psd_array, atol=1e-60)) self.assertTrue(np.array_equal(self.asd_array, psd.asd_array)) def test_setting_psd_array_after_init(self): - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array) psd.psd_file = self.psd_file self.assertEqual(self.psd_file, psd.psd_file) self.assertTrue(np.array_equal(self.psd_array, psd.psd_array)) self.assertTrue(np.allclose(self.asd_array, psd.asd_array, atol=1e-30)) def test_init_with_asd_array_after_init(self): - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array) psd.asd_file = self.asd_file self.assertEqual(self.asd_file, psd.asd_file) self.assertTrue(np.allclose(self.psd_array, psd.psd_array, atol=1e-60)) @@ -154,34 +129,22 @@ def test_init_with_asd_array_after_init(self): def test_power_spectral_density_interpolated_from_asd_file(self): expected = np.array([4.0e-42]) - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array, asd_file=self.asd_file - ) - self.assertTrue( - np.allclose( - expected, psd.power_spectral_density_interpolated(2), atol=1e-60 - ) - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array, asd_file=self.asd_file) + self.assertTrue(np.allclose(expected, psd.power_spectral_density_interpolated(2), atol=1e-60)) def test_power_spectral_density_interpolated_from_psd_file(self): expected = np.array([4.0e-42]) - psd = bilby.gw.detector.PowerSpectralDensity( - frequency_array=self.frequency_array, psd_file=self.psd_file - ) + psd = bilby.gw.detector.PowerSpectralDensity(frequency_array=self.frequency_array, psd_file=self.psd_file) self.assertAlmostEqual(expected, psd.power_spectral_density_interpolated(2)) def test_from_amplitude_spectral_density_file(self): - psd = bilby.gw.detector.PowerSpectralDensity.from_amplitude_spectral_density_file( - asd_file=self.asd_file - ) + psd = bilby.gw.detector.PowerSpectralDensity.from_amplitude_spectral_density_file(asd_file=self.asd_file) self.assertEqual(self.asd_file, psd.asd_file) self.assertTrue(np.allclose(self.psd_array, psd.psd_array, atol=1e-60)) self.assertTrue(np.array_equal(self.asd_array, psd.asd_array)) def test_from_power_spectral_density_file(self): - psd = bilby.gw.detector.PowerSpectralDensity.from_power_spectral_density_file( - psd_file=self.psd_file - ) + psd = bilby.gw.detector.PowerSpectralDensity.from_power_spectral_density_file(psd_file=self.psd_file) self.assertEqual(self.psd_file, psd.psd_file) self.assertTrue(np.array_equal(self.psd_array, psd.psd_array)) self.assertTrue(np.allclose(self.asd_array, psd.asd_array, atol=1e-30)) @@ -218,26 +181,18 @@ def test_check_file_not_called_asd_file_set_to_asd_file(self, mock_warning): def test_from_frame_file(self): expected_frequency_array = np.array([1.0, 2.0, 3.0]) expected_psd_array = np.array([16.0, 25.0, 36.0]) - with mock.patch( - "bilby.gw.detector.InterferometerStrainData.set_from_frame_file" - ) as _: - with mock.patch( - "bilby.gw.detector.InterferometerStrainData.create_power_spectral_density" - ) as n: + with mock.patch("bilby.gw.detector.InterferometerStrainData.set_from_frame_file") as _: + with mock.patch("bilby.gw.detector.InterferometerStrainData.create_power_spectral_density") as n: n.return_value = expected_frequency_array, expected_psd_array psd = bilby.gw.detector.PowerSpectralDensity.from_frame_file( frame_file=self.asd_file, psd_start_time=0, psd_duration=4 ) - self.assertTrue( - np.array_equal(expected_frequency_array, psd.frequency_array) - ) + self.assertTrue(np.array_equal(expected_frequency_array, psd.frequency_array)) self.assertTrue(np.array_equal(expected_psd_array, psd.psd_array)) def test_repr(self): psd = bilby.gw.detector.PowerSpectralDensity(psd_file=self.psd_file) - expected = "PowerSpectralDensity(psd_file='{}', asd_file='{}')".format( - self.psd_file, None - ) + expected = f"PowerSpectralDensity(psd_file='{self.psd_file}', asd_file='{None}')" self.assertEqual(expected, repr(psd)) diff --git a/test/gw/detector/strain_data_test.py b/test/gw/detector/strain_data_test.py index 0f82a40a2..6a8b96080 100644 --- a/test/gw/detector/strain_data_test.py +++ b/test/gw/detector/strain_data_test.py @@ -30,16 +30,12 @@ def test_frequency_mask(self): frequency_domain_strain=np.array([0, 1, 2]), frequency_array=np.array([5, 15, 25]), ) - self.assertTrue( - np.array_equal(self.ifosd.frequency_mask, [False, True, False]) - ) + self.assertTrue(np.array_equal(self.ifosd.frequency_mask, [False, True, False])) def test_frequency_mask_2(self): - strain_data = bilby.gw.detector.InterferometerStrainData( - minimum_frequency=20, maximum_frequency=512) + strain_data = bilby.gw.detector.InterferometerStrainData(minimum_frequency=20, maximum_frequency=512) strain_data.set_from_time_domain_strain( - time_domain_strain=np.random.normal(0, 1, 4096), - time_array=np.arange(0, 4, 4 / 4096) + time_domain_strain=np.random.normal(0, 1, 4096), time_array=np.arange(0, 4, 4 / 4096) ) # Test from init @@ -56,10 +52,10 @@ def test_frequency_mask_2(self): def test_notches_frequency_mask(self): strain_data = bilby.gw.detector.InterferometerStrainData( - minimum_frequency=20, maximum_frequency=512, notch_list=[(100, 101)]) + minimum_frequency=20, maximum_frequency=512, notch_list=[(100, 101)] + ) strain_data.set_from_time_domain_strain( - time_domain_strain=np.random.normal(0, 1, 4096), - time_array=np.arange(0, 4, 4 / 4096) + time_domain_strain=np.random.normal(0, 1, 4096), time_array=np.arange(0, 4, 4 / 4096) ) # Test from init @@ -81,9 +77,7 @@ def test_set_data_fails(self): with mock.patch("bilby.core.utils.create_frequency_series") as m: m.return_value = [1, 2, 3] with self.assertRaises(ValueError): - self.ifosd.set_from_frequency_domain_strain( - frequency_domain_strain=np.array([0, 1, 2]) - ) + self.ifosd.set_from_frequency_domain_strain(frequency_domain_strain=np.array([0, 1, 2])) def test_set_data_fails_too_much(self): with mock.patch("bilby.core.utils.create_frequency_series") as m: @@ -124,12 +118,8 @@ def test_start_time_set(self): def test_time_array_frequency_array_consistency(self): duration = 1 sampling_frequency = 10 - time_array = bilby.core.utils.create_time_series( - sampling_frequency=sampling_frequency, duration=duration - ) - time_domain_strain = np.random.normal( - 0, duration - 1 / sampling_frequency, len(time_array) - ) + time_array = bilby.core.utils.create_time_series(sampling_frequency=sampling_frequency, duration=duration) + time_domain_strain = np.random.normal(0, duration - 1 / sampling_frequency, len(time_array)) self.ifosd.roll_off = 0 self.ifosd.set_from_time_domain_strain( time_domain_strain=time_domain_strain, @@ -137,15 +127,10 @@ def test_time_array_frequency_array_consistency(self): sampling_frequency=sampling_frequency, ) - frequency_domain_strain, freqs = bilby.core.utils.nfft( - time_domain_strain, sampling_frequency - ) + frequency_domain_strain, freqs = bilby.core.utils.nfft(time_domain_strain, sampling_frequency) self.assertTrue( - np.all( - self.ifosd.frequency_domain_strain - == frequency_domain_strain * self.ifosd.frequency_mask - ) + np.all(self.ifosd.frequency_domain_strain == frequency_domain_strain * self.ifosd.frequency_mask) ) def test_time_within_data_before(self): @@ -169,11 +154,9 @@ def test_time_domain_window_no_roll_off_no_alpha(self): self.ifosd._time_domain_strain = np.array([3]) self.ifosd.duration = 5 self.ifosd.roll_off = 2 - expected_window = scipy.signal.windows.tukey( - len(self.ifosd._time_domain_strain), alpha=self.ifosd.alpha - ) + expected_window = scipy.signal.windows.tukey(len(self.ifosd._time_domain_strain), alpha=self.ifosd.alpha) self.assertEqual(expected_window, self.ifosd.time_domain_window()) - self.assertEqual(np.mean(expected_window ** 2), self.ifosd.window_factor) + self.assertEqual(np.mean(expected_window**2), self.ifosd.window_factor) def test_time_domain_window_sets_roll_off_directly(self): self.ifosd._time_domain_strain = np.array([3]) @@ -217,9 +200,7 @@ def test_frequency_domain_strain_when_set(self): self.ifosd.duration = 4 expected_strain = self.ifosd.frequency_array * self.ifosd.frequency_mask self.ifosd._frequency_domain_strain = expected_strain - self.assertTrue( - np.array_equal(expected_strain, self.ifosd.frequency_domain_strain) - ) + self.assertTrue(np.array_equal(expected_strain, self.ifosd.frequency_domain_strain)) @mock.patch("bilby.core.utils.nfft") def test_frequency_domain_strain_from_frequency_domain_strain(self, m): @@ -394,7 +375,6 @@ def test_idxs(self): class TestNotchList(unittest.TestCase): - def test_init_single(self): notch_list_of_tuples = [(32, 34)] notch_list = bilby.gw.detector.strain_data.NotchList(notch_list_of_tuples) diff --git a/test/gw/eos/eos_test.py b/test/gw/eos/eos_test.py index 8f23b6a95..713ec53fc 100644 --- a/test/gw/eos/eos_test.py +++ b/test/gw/eos/eos_test.py @@ -1,18 +1,18 @@ import unittest -import numpy + import lalsimulation as lalsim -from bilby.gw.eos import SpectralDecompositionEOS, EOSFamily, TabularEOS -from bilby.core import utils +import numpy +from bilby.core import utils +from bilby.gw.eos import EOSFamily, SpectralDecompositionEOS, TabularEOS KNOWN_TOV_RESULT = 1 KNOWN_EOS_PRIOR_RESULTS = 1 ENERGY_FROM_PRESSURE_RESULTS = 0.04047517810698063 -EOS_FROM_TABLE = EOSFamily(TabularEOS('AP4', True)) -EOS_FROM_SPRECTRAL_DECOMPOSITION = EOSFamily(SpectralDecompositionEOS(gammas=[0.8651, 0.1548, -0.0151, -0.0002], - p0=1.5e33, - e0=2.02e14, - xmax=7.04)) +EOS_FROM_TABLE = EOSFamily(TabularEOS("AP4", True)) +EOS_FROM_SPRECTRAL_DECOMPOSITION = EOSFamily( + SpectralDecompositionEOS(gammas=[0.8651, 0.1548, -0.0151, -0.0002], p0=1.5e33, e0=2.02e14, xmax=7.04) +) MASSES_TO_TEST = numpy.linspace(0.9, 2.0, 12) PRESSURE = 1e-11 ENTHALPY = 1.2 @@ -20,7 +20,7 @@ # Setup for MPA1 Comparison PRESSURE_0 = 5.3716e32 -ENERGY_DENSITY_0 = 1.1555e35 / (utils.speed_of_light * 100) ** 2. +ENERGY_DENSITY_0 = 1.1555e35 / (utils.speed_of_light * 100) ** 2.0 XMAX = 12.3081 GAMMAS = [1.0215, 0.1653, -0.0235, -0.0004] @@ -28,39 +28,44 @@ BILBY_MPA1 = SpectralDecompositionEOS(GAMMAS, PRESSURE_0, ENERGY_DENSITY_0, XMAX) MPA1_PRESSURES = BILBY_MPA1.e_pdat.T[0] -LALSIM_MPA1_ENERGY_DENSITY = [lalsim.SimNeutronStarEOSEnergyDensityOfPressureGeometerized(PRESSURE, LALSIM_MPA1) - for PRESSURE in MPA1_PRESSURES[:-1]] +LALSIM_MPA1_ENERGY_DENSITY = [ + lalsim.SimNeutronStarEOSEnergyDensityOfPressureGeometerized(PRESSURE, LALSIM_MPA1) + for PRESSURE in MPA1_PRESSURES[:-1] +] BILBY_MPA1_ENERGY_DENSITY = BILBY_MPA1.e_pdat.T[1][:-1] class TestEOSFamily(unittest.TestCase): def test_spectral_decomposition_energy_from_pressure(self): - self.assertAlmostEqual(EOS_FROM_TABLE.eos.energy_from_pressure(PRESSURE), - 3.2736497985232014e-10) + self.assertAlmostEqual(EOS_FROM_TABLE.eos.energy_from_pressure(PRESSURE), 3.2736497985232014e-10) - self.assertAlmostEqual(EOS_FROM_SPRECTRAL_DECOMPOSITION.eos.energy_from_pressure(PRESSURE), - 3.270622527256167e-10) + self.assertAlmostEqual( + EOS_FROM_SPRECTRAL_DECOMPOSITION.eos.energy_from_pressure(PRESSURE), 3.270622527256167e-10 + ) def test_spectral_decomposition_pressure_from_pseudo_enthalpy(self): - self.assertAlmostEqual(EOS_FROM_TABLE.eos.pressure_from_pseudo_enthalpy(ENTHALPY), - 2.7338376042831513e-09) + self.assertAlmostEqual(EOS_FROM_TABLE.eos.pressure_from_pseudo_enthalpy(ENTHALPY), 2.7338376042831513e-09) - self.assertAlmostEqual(EOS_FROM_SPRECTRAL_DECOMPOSITION.eos.pressure_from_pseudo_enthalpy(ENTHALPY), - 2.754018499535077e-09) + self.assertAlmostEqual( + EOS_FROM_SPRECTRAL_DECOMPOSITION.eos.pressure_from_pseudo_enthalpy(ENTHALPY), 2.754018499535077e-09 + ) def test_spectral_decomposition_energy_density_from_pseudo_enthalpy(self): - self.assertAlmostEqual(EOS_FROM_TABLE.eos.energy_density_from_pseudo_enthalpy(ENTHALPY), - 2.9486942467903607e-09) + self.assertAlmostEqual(EOS_FROM_TABLE.eos.energy_density_from_pseudo_enthalpy(ENTHALPY), 2.9486942467903607e-09) - self.assertAlmostEqual(EOS_FROM_SPRECTRAL_DECOMPOSITION.eos.energy_density_from_pseudo_enthalpy(ENTHALPY), - 3.0402598495601078e-09) + self.assertAlmostEqual( + EOS_FROM_SPRECTRAL_DECOMPOSITION.eos.energy_density_from_pseudo_enthalpy(ENTHALPY), 3.0402598495601078e-09 + ) def test_spectral_decomposition_pseudo_enthalpy_from_energy_density(self): - self.assertAlmostEqual(EOS_FROM_TABLE.eos.pseudo_enthalpy_from_energy_density(ENERGY_DENSITY), - 0.024415755781136812) + self.assertAlmostEqual( + EOS_FROM_TABLE.eos.pseudo_enthalpy_from_energy_density(ENERGY_DENSITY), 0.024415755781136812 + ) print(EOS_FROM_SPRECTRAL_DECOMPOSITION.eos.pseudo_enthalpy_from_energy_density(ENERGY_DENSITY)) - self.assertAlmostEqual(EOS_FROM_SPRECTRAL_DECOMPOSITION.eos.pseudo_enthalpy_from_energy_density(ENERGY_DENSITY), - 0.02420629785967365) + self.assertAlmostEqual( + EOS_FROM_SPRECTRAL_DECOMPOSITION.eos.pseudo_enthalpy_from_energy_density(ENERGY_DENSITY), + 0.02420629785967365, + ) class TestBilbyLALSimComparison(unittest.TestCase): diff --git a/test/gw/likelihood/marginalization_test.py b/test/gw/likelihood/marginalization_test.py index 3538a4c58..8ec32a89c 100644 --- a/test/gw/likelihood/marginalization_test.py +++ b/test/gw/likelihood/marginalization_test.py @@ -1,17 +1,18 @@ import itertools import os -import pytest import unittest from copy import deepcopy from itertools import product -from parameterized import parameterized import numpy as np -import bilby -from bilby.gw.detector import calibration +import pytest +from parameterized import parameterized from scipy.integrate import trapezoid from scipy.special import logsumexp +import bilby +from bilby.gw.detector import calibration + class TestMarginalizedLikelihood(unittest.TestCase): def setUp(self): @@ -166,6 +167,7 @@ class TestMarginalizations(unittest.TestCase): For time, this is strongly dependent on the specific time grid used. The `time_jitter` parameter makes this a weaker dependence during sampling. """ + _parameters = product( ["regular", "roq", "relbin", "multiband"], ["luminosity_distance", "geocent_time", "phase"], @@ -218,16 +220,13 @@ def setUp(self): reference_frequency=20.0, minimum_frequency=20.0, waveform_approximant="IMRPhenomPv2", - ) - ) - self.interferometers.inject_signal( - parameters=self.parameters, waveform_generator=self.waveform_generator + ), ) + self.interferometers.inject_signal(parameters=self.parameters, waveform_generator=self.waveform_generator) self.priors = bilby.gw.prior.BBHPriorDict() self.priors["geocent_time"] = bilby.prior.Uniform( - minimum=self.parameters["geocent_time"] - 0.1, - maximum=self.parameters["geocent_time"] + 0.1 + minimum=self.parameters["geocent_time"] - 0.1, maximum=self.parameters["geocent_time"] + 0.1 ) trial_roq_paths = [ @@ -253,7 +252,7 @@ def setUp(self): waveform_approximant="IMRPhenomPv2", frequency_nodes_linear=np.load(f"{roq_dir}/fnodes_linear.npy"), frequency_nodes_quadratic=np.load(f"{roq_dir}/fnodes_quadratic.npy"), - ) + ), ) self.roq_linear_matrix_file = f"{roq_dir}/B_linear.npy" self.roq_quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" @@ -267,7 +266,7 @@ def setUp(self): reference_frequency=20.0, minimum_frequency=20.0, waveform_approximant="IMRPhenomPv2", - ) + ), ) self.multiband_waveform_generator = bilby.gw.WaveformGenerator( @@ -278,7 +277,7 @@ def setUp(self): waveform_arguments=dict( reference_frequency=20.0, waveform_approximant="IMRPhenomPv2", - ) + ), ) def tearDown(self): @@ -316,11 +315,13 @@ def likelihood_kwargs(self, kind, time_marginalization, phase_marginalization, d priors=priors, ) if kind == "roq": - kwargs.update(dict( - linear_matrix=self.roq_linear_matrix_file, - quadratic_matrix=self.roq_quadratic_matrix_file, - waveform_generator=self.roq_waveform_generator, - )) + kwargs.update( + dict( + linear_matrix=self.roq_linear_matrix_file, + quadratic_matrix=self.roq_quadratic_matrix_file, + waveform_generator=self.roq_waveform_generator, + ) + ) if os.path.exists(self.__class__.path_to_roq_weights): kwargs["weights"] = self.__class__.path_to_roq_weights elif kind == "relbin": @@ -328,19 +329,13 @@ def likelihood_kwargs(self, kind, time_marginalization, phase_marginalization, d kwargs["waveform_generator"] = self.relbin_waveform_generator elif kind == "multiband": kwargs["waveform_generator"] = self.multiband_waveform_generator - kwargs["reference_chirp_mass"] = ( - (self.parameters["mass_1"] * self.parameters["mass_2"])**0.6 / - (self.parameters["mass_1"] + self.parameters["mass_2"])**0.2 - ) + kwargs["reference_chirp_mass"] = (self.parameters["mass_1"] * self.parameters["mass_2"]) ** 0.6 / ( + self.parameters["mass_1"] + self.parameters["mass_2"] + ) ** 0.2 return kwargs def get_likelihood( - self, - kind, - time_marginalization=False, - phase_marginalization=False, - distance_marginalization=False, - priors=None + self, kind, time_marginalization=False, phase_marginalization=False, distance_marginalization=False, priors=None ): kwargs = self.likelihood_kwargs( kind, time_marginalization, phase_marginalization, distance_marginalization, priors @@ -382,18 +377,14 @@ def _template(self, marginalized, non_marginalized, key, prior=None, values=None like = np.exp(ln_likes - max(ln_likes)) marg_like = np.log(trapezoid(like * prior_values, values)) + max(ln_likes) - self.assertAlmostEqual( - marg_like, marginalized.log_likelihood_ratio(), delta=0.5 - ) + self.assertAlmostEqual(marg_like, marginalized.log_likelihood_ratio(), delta=0.5) @parameterized.expand( _parameters, name_func=lambda func, num, param: ( - f"{func.__name__}_{num}__{param.args[0]}_{param.args[1]}_" + "_".join([ - ["D", "T", "P"][ii] for ii, val - in enumerate(param.args[-3:]) if val - ]) - ) + f"{func.__name__}_{num}__{param.args[0]}_{param.args[1]}_" + + "_".join([["D", "T", "P"][ii] for ii, val in enumerate(param.args[-3:]) if val]) + ), ) def test_marginalisation(self, kind, key, distance, time, phase): if all([distance, time, phase]): @@ -437,11 +428,9 @@ def test_time_marginalisation_full_segment(self, kind): @parameterized.expand( itertools.product(["regular", "roq", "relbin", "multiband"], *itertools.repeat([True, False], 3)), name_func=lambda func, num, param: ( - f"{func.__name__}_{num}__{param.args[0]}_" + "_".join([ - ["D", "P", "T"][ii] for ii, val - in enumerate(param.args[1:]) if val - ]) - ) + f"{func.__name__}_{num}__{param.args[0]}_" + + "_".join([["D", "P", "T"][ii] for ii, val in enumerate(param.args[1:]) if val]) + ), ) def test_marginalization_reconstruction(self, kind, distance, phase, time): marginalizations = dict( @@ -470,12 +459,9 @@ def test_marginalization_reconstruction(self, kind, distance, phase, time): class CalibrationMarginalization(unittest.TestCase): - def setUp(self): self.ifos = bilby.gw.detector.InterferometerList(["H1", "L1"]) - self.ifos.set_strain_data_from_power_spectral_densities( - duration=4, sampling_frequency=1024 - ) + self.ifos.set_strain_data_from_power_spectral_densities(duration=4, sampling_frequency=1024) self.ifos[0].calibration_model = calibration.CubicSpline( prefix="recalib_H1_", minimum_frequency=20, @@ -492,14 +478,16 @@ def setUp(self): ) self.priors = bilby.gw.prior.BBHPriorDict() self.priors["geocent_time"] = bilby.core.prior.Uniform(0, 4) - self.priors.update(bilby.gw.prior.CalibrationPriorDict.constant_uncertainty_spline( - amplitude_sigma=0.1, - phase_sigma=0.1, - minimum_frequency=20, - maximum_frequency=512, - n_nodes=5, - label="H1", - )) + self.priors.update( + bilby.gw.prior.CalibrationPriorDict.constant_uncertainty_spline( + amplitude_sigma=0.1, + phase_sigma=0.1, + minimum_frequency=20, + maximum_frequency=512, + n_nodes=5, + label="H1", + ) + ) self.wfg = bilby.gw.waveform_generator.WaveformGenerator( frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters, diff --git a/test/gw/likelihood/relative_binning_test.py b/test/gw/likelihood/relative_binning_test.py index 7687453ab..e4a8a7e5d 100644 --- a/test/gw/likelihood/relative_binning_test.py +++ b/test/gw/likelihood/relative_binning_test.py @@ -1,10 +1,11 @@ import unittest from copy import deepcopy -import bilby import numpy as np from parameterized import parameterized +import bilby + class TestRelativeBinningLikelihood(unittest.TestCase): def setUp(self): @@ -41,8 +42,9 @@ def setUp(self): ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"]) ifos.set_strain_data_from_power_spectral_densities( - sampling_frequency=sampling_frequency, duration=duration, - start_time=self.test_parameters['geocent_time'] - duration + 2. + sampling_frequency=sampling_frequency, + duration=duration, + start_time=self.test_parameters["geocent_time"] - duration + 2.0, ) for ifo in ifos: ifo.minimum_frequency = fmin @@ -54,16 +56,16 @@ def setUp(self): prefix=f"recalib_{ifo.name}_", minimum_frequency=ifo.minimum_frequency, maximum_frequency=ifo.maximum_frequency, - n_points=spline_calibration_nodes + n_points=spline_calibration_nodes, ) for i in range(spline_calibration_nodes): self.test_parameters[f"recalib_{ifo.name}_amplitude_{i}"] = 0 self.test_parameters[f"recalib_{ifo.name}_phase_{i}"] = 0 # Calibration errors of 5% in amplitude and 5 degrees in phase - self.calibration_parameters[f"recalib_{ifo.name}_amplitude_{i}"] = \ - np.random.normal(loc=0, scale=0.05) - self.calibration_parameters[f"recalib_{ifo.name}_phase_{i}"] = \ - np.random.normal(loc=0, scale=5 * np.pi / 180) + self.calibration_parameters[f"recalib_{ifo.name}_amplitude_{i}"] = np.random.normal(loc=0, scale=0.05) + self.calibration_parameters[f"recalib_{ifo.name}_phase_{i}"] = np.random.normal( + loc=0, scale=5 * np.pi / 180 + ) priors = bilby.gw.prior.BBHPriorDict() priors.pop("mass_1") @@ -76,16 +78,16 @@ def setUp(self): approximant = "IMRPhenomXP" non_bin_wfg = bilby.gw.WaveformGenerator( - duration=duration, sampling_frequency=sampling_frequency, + duration=duration, + sampling_frequency=sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=dict( - reference_frequency=fmin, minimum_frequency=fmin, waveform_approximant=approximant) + waveform_arguments=dict(reference_frequency=fmin, minimum_frequency=fmin, waveform_approximant=approximant), ) bin_wfg = bilby.gw.waveform_generator.WaveformGenerator( - duration=duration, sampling_frequency=sampling_frequency, + duration=duration, + sampling_frequency=sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning, - waveform_arguments=dict( - reference_frequency=fmin, waveform_approximant=approximant, minimum_frequency=fmin) + waveform_arguments=dict(reference_frequency=fmin, waveform_approximant=approximant, minimum_frequency=fmin), ) ifos.inject_signal( parameters=self.test_parameters, @@ -95,11 +97,11 @@ def setUp(self): self.ifos = ifos self.non_bin = bilby.gw.likelihood.GravitationalWaveTransient( - interferometers=ifos, waveform_generator=deepcopy(non_bin_wfg), - priors=priors.copy() + interferometers=ifos, waveform_generator=deepcopy(non_bin_wfg), priors=priors.copy() ) self.binned = bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient( - interferometers=ifos, waveform_generator=deepcopy(bin_wfg), + interferometers=ifos, + waveform_generator=deepcopy(bin_wfg), fiducial_parameters=self.fiducial_parameters, priors=priors.copy(), epsilon=0.05, @@ -119,11 +121,7 @@ def test_matches_non_binned_many(self): parameters = self.priors.sample() regular_ln_l = self.non_bin.log_likelihood_ratio(parameters) binned_ln_l = self.binned.log_likelihood_ratio(parameters) - self.assertLess( - abs(regular_ln_l - binned_ln_l) - / abs(self.reference_ln_l - regular_ln_l), - 0.1 - ) + self.assertLess(abs(regular_ln_l - binned_ln_l) / abs(self.reference_ln_l - regular_ln_l), 0.1) def test_matches_non_binned_many_state(self): for _ in range(100): @@ -132,13 +130,9 @@ def test_matches_non_binned_many_state(self): self.binned.parameters.update(parameters) regular_ln_l = self.non_bin.log_likelihood_ratio() binned_ln_l = self.binned.log_likelihood_ratio() - self.assertLess( - abs(regular_ln_l - binned_ln_l) - / abs(self.reference_ln_l - regular_ln_l), - 0.1 - ) + self.assertLess(abs(regular_ln_l - binned_ln_l) / abs(self.reference_ln_l - regular_ln_l), 0.1) - @parameterized.expand([(False, ), (True, )]) + @parameterized.expand([(False,), (True,)]) def test_matches_non_binned(self, add_cal_errors): parameters = deepcopy(self.test_parameters) if add_cal_errors: @@ -157,12 +151,24 @@ def test_optimization_gives_good_match(self): fiducial_parameters["chirp_mass"] *= 0.99 priors = self.priors.copy() for key in [ - "ra", "dec", "geocent_time", "phase", "psi", "theta_jn", "luminosity_distance", - "a_1", "a_2", "tilt_1", "tilt_2", "phi_12", "phi_jl", + "ra", + "dec", + "geocent_time", + "phase", + "psi", + "theta_jn", + "luminosity_distance", + "a_1", + "a_2", + "tilt_1", + "tilt_2", + "phi_12", + "phi_jl", ]: priors[key] = self.test_parameters[key] binned = bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=deepcopy(self.bin_wfg), + interferometers=self.ifos, + waveform_generator=deepcopy(self.bin_wfg), priors=priors, fiducial_parameters=fiducial_parameters, epsilon=0.05, @@ -183,7 +189,8 @@ def test_very_small_epsilon_returns_good_value(self): test that we avoid this. """ binned = bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=deepcopy(self.bin_wfg), + interferometers=self.ifos, + waveform_generator=deepcopy(self.bin_wfg), fiducial_parameters=self.fiducial_parameters, priors=self.priors.copy(), epsilon=0.001, @@ -227,7 +234,7 @@ def test_likelihood_when_waveform_extends_beyond_maximum_frequency(self): ifos.set_strain_data_from_zero_noise( sampling_frequency=sampling_frequency, duration=duration, - start_time=test_parameters['geocent_time'] - duration + 2. + start_time=test_parameters["geocent_time"] - duration + 2.0, ) for ifo in ifos: ifo.minimum_frequency = fmin @@ -243,21 +250,13 @@ def test_likelihood_when_waveform_extends_beyond_maximum_frequency(self): duration=duration, sampling_frequency=sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=dict( - reference_frequency=fmin, - minimum_frequency=fmin, - waveform_approximant=approximant - ) + waveform_arguments=dict(reference_frequency=fmin, minimum_frequency=fmin, waveform_approximant=approximant), ) bin_wfg = bilby.gw.waveform_generator.WaveformGenerator( duration=duration, sampling_frequency=sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole_relative_binning, - waveform_arguments=dict( - reference_frequency=fmin, - waveform_approximant=approximant, - minimum_frequency=fmin - ) + waveform_arguments=dict(reference_frequency=fmin, waveform_approximant=approximant, minimum_frequency=fmin), ) ifos.inject_signal( parameters=test_parameters, @@ -266,9 +265,7 @@ def test_likelihood_when_waveform_extends_beyond_maximum_frequency(self): ) non_bin = bilby.gw.likelihood.GravitationalWaveTransient( - interferometers=ifos, - waveform_generator=deepcopy(non_bin_wfg), - priors=priors.copy() + interferometers=ifos, waveform_generator=deepcopy(non_bin_wfg), priors=priors.copy() ) binned = bilby.gw.likelihood.RelativeBinningGravitationalWaveTransient( interferometers=ifos, diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 4ee18e6a4..4106bfe78 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -1,13 +1,14 @@ import os -import unittest import tempfile -from itertools import product -from parameterized import parameterized -import pytest +import unittest from copy import deepcopy +from itertools import product import h5py import numpy as np +import pytest +from parameterized import parameterized + import bilby from bilby.gw.likelihood import BilbyROQParamsRangeError @@ -33,9 +34,7 @@ def setUp(self): dec=-1.2108, ) self.interferometers = bilby.gw.detector.InterferometerList(["H1"]) - self.interferometers.set_strain_data_from_power_spectral_densities( - sampling_frequency=2048, duration=4 - ) + self.interferometers.set_strain_data_from_power_spectral_densities(sampling_frequency=2048, duration=4) self.waveform_generator = bilby.gw.waveform_generator.GWSignalWaveformGenerator( duration=4, sampling_frequency=2048, @@ -57,9 +56,7 @@ def tearDown(self): def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" self.likelihood.noise_log_likelihood() - self.assertAlmostEqual( - -4014.1787704539474, self.likelihood.noise_log_likelihood(), 3 - ) + self.assertAlmostEqual(-4014.1787704539474, self.likelihood.noise_log_likelihood(), 3) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" @@ -81,8 +78,9 @@ def test_likelihood_zero_when_waveform_is_none(self): self.assertEqual(self.likelihood.log_likelihood_ratio(), np.nan_to_num(-np.inf)) def test_repr(self): - expected = "BasicGravitationalWaveTransient(interferometers={},\n\twaveform_generator={})".format( - self.interferometers, self.waveform_generator + expected = ( + f"BasicGravitationalWaveTransient(interferometers={self.interferometers}," + f"\n\twaveform_generator={self.waveform_generator})" ) self.assertEqual(expected, repr(self.likelihood)) @@ -142,15 +140,12 @@ def tearDown(self): def test_noise_log_likelihood(self): """Test noise log likelihood matches precomputed value""" self.likelihood.noise_log_likelihood() - self.assertAlmostEqual( - -4014.1787704539474, self.likelihood.noise_log_likelihood(), 3 - ) + self.assertAlmostEqual(-4014.1787704539474, self.likelihood.noise_log_likelihood(), 3) def test_log_likelihood(self): """Test log likelihood matches precomputed value""" self.likelihood.log_likelihood() - self.assertAlmostEqual(self.likelihood.log_likelihood(), - -4032.4397343470005, 3) + self.assertAlmostEqual(self.likelihood.log_likelihood(), -4032.4397343470005, 3) def test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" @@ -168,42 +163,24 @@ def test_likelihood_zero_when_waveform_is_none(self): def test_repr(self): expected = ( - "GravitationalWaveTransient(interferometers={},\n\twaveform_generator={},\n\t" - "time_marginalization={}, distance_marginalization={}, phase_marginalization={}, " - "calibration_marginalization={}, priors={})".format( - self.interferometers, - self.waveform_generator, - False, - False, - False, - False, - self.prior, - ) + f"GravitationalWaveTransient(interferometers={self.interferometers},\n\twaveform_generator={self.waveform_generator},\n\t" + f"time_marginalization={False}, distance_marginalization={False}, phase_marginalization={False}, " + f"calibration_marginalization={False}, priors={self.prior})" ) self.assertEqual(expected, repr(self.likelihood)) def test_interferometers_setting_list(self): - ifos = [ - bilby.gw.detector.get_empty_interferometer(name=name) - for name in ["H1", "L1"] - ] + ifos = [bilby.gw.detector.get_empty_interferometer(name=name) for name in ["H1", "L1"]] self.likelihood.interferometers = ifos - self.assertListEqual( - bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers - ) + self.assertListEqual(bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers) self.assertIsInstance(self.likelihood.interferometers, bilby.gw.detector.InterferometerList) def test_interferometers_setting_interferometer_list(self): ifos = bilby.gw.detector.InterferometerList( - [ - bilby.gw.detector.get_empty_interferometer(name=name) - for name in ["H1", "L1"] - ] + [bilby.gw.detector.get_empty_interferometer(name=name) for name in ["H1", "L1"]] ) self.likelihood.interferometers = ifos - self.assertListEqual( - bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers - ) + self.assertListEqual(bilby.gw.detector.InterferometerList(ifos), self.likelihood.interferometers) self.assertIsInstance(self.likelihood.interferometers, bilby.gw.detector.InterferometerList) def test_meta_data(self): @@ -233,7 +210,7 @@ def test_reference_frame_agrees_with_default(self): interferometers=self.interferometers, waveform_generator=self.waveform_generator, priors=self.prior.copy(), - reference_frame="H1L1" + reference_frame="H1L1", ) parameters = self.parameters.copy() del parameters["ra"], parameters["dec"] @@ -243,44 +220,34 @@ def test_reference_frame_agrees_with_default(self): zenith=parameters["zenith"], azimuth=parameters["azimuth"], geocent_time=parameters["geocent_time"], - ifos=bilby.gw.detector.InterferometerList(["H1", "L1"]) + ifos=bilby.gw.detector.InterferometerList(["H1", "L1"]), ) self.assertEqual( - new_likelihood.log_likelihood_ratio(parameters), - self.likelihood.log_likelihood_ratio(parameters) + new_likelihood.log_likelihood_ratio(parameters), self.likelihood.log_likelihood_ratio(parameters) ) new_likelihood.parameters.update(parameters) self.likelihood.parameters.update(parameters) - self.assertEqual( - new_likelihood.log_likelihood_ratio(), - self.likelihood.log_likelihood_ratio() - ) + self.assertEqual(new_likelihood.log_likelihood_ratio(), self.likelihood.log_likelihood_ratio()) def test_time_reference_agrees_with_default(self): new_likelihood = bilby.gw.likelihood.GravitationalWaveTransient( interferometers=self.interferometers, waveform_generator=self.waveform_generator, priors=self.prior.copy(), - time_reference="H1" + time_reference="H1", ) ifo = bilby.gw.detector.get_empty_interferometer("H1") time_delay = ifo.time_delay_from_geocenter( - ra=self.parameters["ra"], - dec=self.parameters["dec"], - time=self.parameters["geocent_time"] + ra=self.parameters["ra"], dec=self.parameters["dec"], time=self.parameters["geocent_time"] ) parameters = self.parameters.copy() parameters["H1_time"] = parameters["geocent_time"] + time_delay self.assertEqual( - new_likelihood.log_likelihood_ratio(parameters), - self.likelihood.log_likelihood_ratio(parameters) + new_likelihood.log_likelihood_ratio(parameters), self.likelihood.log_likelihood_ratio(parameters) ) new_likelihood.parameters.update(parameters) self.likelihood.parameters.update(parameters) - self.assertEqual( - new_likelihood.log_likelihood_ratio(), - self.likelihood.log_likelihood_ratio() - ) + self.assertEqual(new_likelihood.log_likelihood_ratio(), self.likelihood.log_likelihood_ratio()) @pytest.mark.requires_roqs @@ -303,16 +270,16 @@ def setUp(self): if roq_dir is None: raise Exception("Unable to load ROQ basis: cannot proceed with tests") - linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) + linear_matrix_file = f"{roq_dir}/B_linear.npy" + quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" - fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir) + fnodes_linear_file = f"{roq_dir}/fnodes_linear.npy" fnodes_linear = np.load(fnodes_linear_file).T - fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir) + fnodes_quadratic_file = f"{roq_dir}/fnodes_quadratic.npy" fnodes_quadratic = np.load(fnodes_quadratic_file).T - self.linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - self.quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - self.params_file = "{}/params.dat".format(roq_dir) + self.linear_matrix_file = f"{roq_dir}/B_linear.npy" + self.quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" + self.params_file = f"{roq_dir}/params.dat" self.test_parameters = dict( mass_1=36.0, @@ -356,9 +323,7 @@ def setUp(self): ), ) - ifos.inject_signal( - parameters=self.test_parameters, waveform_generator=non_roq_wfg - ) + ifos.inject_signal(parameters=self.test_parameters, waveform_generator=non_roq_wfg) self.ifos = ifos @@ -419,7 +384,8 @@ def test_matches_non_roq(self): abs( self.non_roq.log_likelihood_ratio(self.test_parameters) - self.roq.log_likelihood_ratio(self.test_parameters) - ) / self.non_roq.log_likelihood_ratio(self.test_parameters), + ) + / self.non_roq.log_likelihood_ratio(self.test_parameters), 1e-3, ) self.non_roq.parameters.update(self.test_parameters) @@ -447,15 +413,13 @@ def test_create_roq_weights_with_params(self): priors=self.priors, ) self.assertEqual( - roq.log_likelihood_ratio(self.test_parameters), - self.roq.log_likelihood_ratio(self.test_parameters) + roq.log_likelihood_ratio(self.test_parameters), self.roq.log_likelihood_ratio(self.test_parameters) ) roq.parameters.update(self.test_parameters) self.roq.parameters.update(self.test_parameters) self.assertEqual(roq.log_likelihood_ratio(), self.roq.log_likelihood_ratio()) def test_create_roq_weights_frequency_mismatch_works_with_params(self): - self.ifos[0].maximum_frequency = self.ifos[0].maximum_frequency / 2 bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=self.ifos, @@ -519,10 +483,8 @@ def test_create_roq_weights_fails_with_min_component_mass_outside_bounds(self): def test_create_roq_weights_fails_with_max_frequency(self): ifos = bilby.gw.detector.InterferometerList(["H1"]) - ifos.set_strain_data_from_power_spectral_densities( - sampling_frequency=2 ** 14, duration=4 - ) - ifos[0].maximum_frequency = 2 ** 13 + ifos.set_strain_data_from_power_spectral_densities(sampling_frequency=2**14, duration=4) + ifos[0].maximum_frequency = 2**13 with self.assertRaises(BilbyROQParamsRangeError): bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=ifos, @@ -547,9 +509,7 @@ def test_create_roq_weights_fails_due_to_min_frequency(self): def test_create_roq_weights_fails_due_to_duration(self): ifos = bilby.gw.detector.InterferometerList(["H1"]) - ifos.set_strain_data_from_power_spectral_densities( - sampling_frequency=self.sampling_frequency, duration=16 - ) + ifos.set_strain_data_from_power_spectral_densities(sampling_frequency=self.sampling_frequency, duration=16) with self.assertRaises(BilbyROQParamsRangeError): bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=ifos, @@ -564,7 +524,6 @@ def test_create_roq_weights_fails_due_to_duration(self): @pytest.mark.requires_roqs class TestRescaledROQLikelihood(unittest.TestCase): def test_rescaling(self): - # Possible locations for the ROQ: in the docker image, local, or on CIT trial_roq_paths = [ "/roq_basis", @@ -579,16 +538,16 @@ def test_rescaling(self): if roq_dir is None: raise Exception("Unable to load ROQ basis: cannot proceed with tests") - linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) + linear_matrix_file = f"{roq_dir}/B_linear.npy" + quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" - fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir) + fnodes_linear_file = f"{roq_dir}/fnodes_linear.npy" fnodes_linear = np.load(fnodes_linear_file).T - fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir) + fnodes_quadratic_file = f"{roq_dir}/fnodes_quadratic.npy" fnodes_quadratic = np.load(fnodes_quadratic_file).T - self.linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - self.quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - self.params_file = "{}/params.dat".format(roq_dir) + self.linear_matrix_file = f"{roq_dir}/B_linear.npy" + self.quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" + self.params_file = f"{roq_dir}/params.dat" scale_factor = 0.5 params = np.genfromtxt(self.params_file, names=True) @@ -606,9 +565,7 @@ def test_rescaling(self): self.priors.pop("mass_1") self.priors.pop("mass_2") # Testing is done with the 4s IMRPhenomPV2 ROQ basis - self.priors["chirp_mass"] = bilby.core.prior.Uniform( - 12.299703 / scale_factor, 45 / scale_factor - ) + self.priors["chirp_mass"] = bilby.core.prior.Uniform(12.299703 / scale_factor, 45 / scale_factor) self.priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1) self.priors["geocent_time"] = bilby.core.prior.Uniform(1.19, 1.21) @@ -665,26 +622,25 @@ def setUp(self): phase=1.3, geocent_time=1.2, ra=1.3, - dec=-1.2 + dec=-1.2, ) self.priors = bilby.gw.prior.BBHPriorDict() self.priors.pop("mass_1") self.priors.pop("mass_2") self.priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1) self.priors["geocent_time"] = bilby.core.prior.Uniform( - self.injection_parameters["geocent_time"] - 0.1, - self.injection_parameters["geocent_time"] + 0.1 + self.injection_parameters["geocent_time"] - 0.1, self.injection_parameters["geocent_time"] + 0.1 ) @parameterized.expand( - [(_path_to_basis, 20., 2048., 16), - (_path_to_basis, 10., 1024., 16), - (_path_to_basis, 20., 1024., 32), - (_path_to_basis_mb, 20., 2048., 16)] + [ + (_path_to_basis, 20.0, 2048.0, 16), + (_path_to_basis, 10.0, 1024.0, 16), + (_path_to_basis, 20.0, 1024.0, 32), + (_path_to_basis_mb, 20.0, 2048.0, 16), + ] ) - def test_fails_with_frequency_duration_mismatch( - self, basis, minimum_frequency, maximum_frequency, duration - ): + def test_fails_with_frequency_duration_mismatch(self, basis, minimum_frequency, maximum_frequency, duration): """Test if likelihood fails as expected, when data frequency range is not within the basis range or data duration does not match the basis duration. The basis frequency range and duration are 20--1024Hz and @@ -695,7 +651,7 @@ def test_fails_with_frequency_duration_mismatch( interferometers.set_strain_data_from_power_spectral_densities( sampling_frequency=2 * maximum_frequency, duration=duration, - start_time=self.injection_parameters["geocent_time"] - duration + 1 + start_time=self.injection_parameters["geocent_time"] - duration + 1, ) for ifo in interferometers: ifo.minimum_frequency = minimum_frequency @@ -705,9 +661,8 @@ def test_fails_with_frequency_duration_mismatch( sampling_frequency=2 * maximum_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, waveform_arguments=dict( - reference_frequency=self.reference_frequency, - waveform_approximant=self.waveform_approximant - ) + reference_frequency=self.reference_frequency, waveform_approximant=self.waveform_approximant + ), ) with self.assertRaises(BilbyROQParamsRangeError): bilby.gw.likelihood.ROQGravitationalWaveTransient( @@ -728,7 +683,7 @@ def test_fails_with_prior_mismatch(self, basis, chirp_mass_min, chirp_mass_max): interferometers.set_strain_data_from_power_spectral_densities( sampling_frequency=self.sampling_frequency, duration=self.duration, - start_time=self.injection_parameters["geocent_time"] - self.duration + 1 + start_time=self.injection_parameters["geocent_time"] - self.duration + 1, ) for ifo in interferometers: ifo.minimum_frequency = self.minimum_frequency @@ -737,9 +692,8 @@ def test_fails_with_prior_mismatch(self, basis, chirp_mass_min, chirp_mass_max): sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, waveform_arguments=dict( - reference_frequency=self.reference_frequency, - waveform_approximant=self.waveform_approximant - ) + reference_frequency=self.reference_frequency, waveform_approximant=self.waveform_approximant + ), ) with self.assertRaises(BilbyROQParamsRangeError): bilby.gw.likelihood.ROQGravitationalWaveTransient( @@ -755,7 +709,7 @@ def test_fails_with_prior_mismatch(self, basis, chirp_mass_min, chirp_mass_max): [_path_to_basis, _path_to_basis_mb], [_path_to_basis, _path_to_basis_mb], [(8, 9), (8, 10.5), (8, 11.5), (8, 12.5), (8, 14)], - [1, 2] + [1, 2], ) ) def test_number_of_loaded_bases(self, basis_linear, basis_quadratic, mc_range, roq_scale_factor): @@ -774,7 +728,7 @@ def test_number_of_loaded_bases(self, basis_linear, basis_quadratic, mc_range, r interferometers.set_strain_data_from_power_spectral_densities( sampling_frequency=self.sampling_frequency, duration=self.duration, - start_time=self.injection_parameters["geocent_time"] - self.duration + 1 + start_time=self.injection_parameters["geocent_time"] - self.duration + 1, ) for ifo in interferometers: ifo.minimum_frequency = self.minimum_frequency @@ -784,9 +738,8 @@ def test_number_of_loaded_bases(self, basis_linear, basis_quadratic, mc_range, r sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, waveform_arguments=dict( - reference_frequency=self.reference_frequency, - waveform_approximant=self.waveform_approximant - ) + reference_frequency=self.reference_frequency, waveform_approximant=self.waveform_approximant + ), ) likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient( @@ -795,7 +748,7 @@ def test_number_of_loaded_bases(self, basis_linear, basis_quadratic, mc_range, r waveform_generator=search_waveform_generator, linear_matrix=basis_linear, quadratic_matrix=basis_quadratic, - roq_scale_factor=roq_scale_factor + roq_scale_factor=roq_scale_factor, ) with h5py.File(basis_linear, "r") as f: @@ -803,21 +756,21 @@ def test_number_of_loaded_bases(self, basis_linear, basis_quadratic, mc_range, r with h5py.File(basis_quadratic, "r") as f: mc_ranges_quadratic = f["prior_range_quadratic"]["chirp_mass"][()] / roq_scale_factor number_of_bases_linear = np.sum( - (mc_ranges_linear[:, 1] >= self.priors["chirp_mass"].minimum) * - (mc_ranges_linear[:, 0] <= self.priors["chirp_mass"].maximum) + (mc_ranges_linear[:, 1] >= self.priors["chirp_mass"].minimum) + * (mc_ranges_linear[:, 0] <= self.priors["chirp_mass"].maximum) ) number_of_bases_quadratic = np.sum( - (mc_ranges_quadratic[:, 1] >= self.priors["chirp_mass"].minimum) * - (mc_ranges_quadratic[:, 0] <= self.priors["chirp_mass"].maximum) + (mc_ranges_quadratic[:, 1] >= self.priors["chirp_mass"].minimum) + * (mc_ranges_quadratic[:, 0] <= self.priors["chirp_mass"].maximum) ) self.assertEqual(likelihood.number_of_bases_linear, number_of_bases_linear) self.assertEqual(likelihood.number_of_bases_quadratic, number_of_bases_quadratic) - self.assertEqual(len(likelihood.weights['frequency_nodes_linear']), number_of_bases_linear) - self.assertEqual(len(likelihood.weights['frequency_nodes_quadratic']), number_of_bases_quadratic) + self.assertEqual(len(likelihood.weights["frequency_nodes_linear"]), number_of_bases_linear) + self.assertEqual(len(likelihood.weights["frequency_nodes_quadratic"]), number_of_bases_quadratic) for ifo in interferometers: - self.assertEqual(len(likelihood.weights['{}_linear'.format(ifo.name)]), number_of_bases_linear) - self.assertEqual(len(likelihood.weights['{}_quadratic'.format(ifo.name)]), number_of_bases_quadratic) + self.assertEqual(len(likelihood.weights[f"{ifo.name}_linear"]), number_of_bases_linear) + self.assertEqual(len(likelihood.weights[f"{ifo.name}_quadratic"]), number_of_bases_quadratic) @parameterized.expand( product( @@ -825,7 +778,7 @@ def test_number_of_loaded_bases(self, basis_linear, basis_quadratic, mc_range, r [_path_to_basis, _path_to_basis_mb], [(8, 9), (8, 14)], [1, 2], - [False, True] + [False, True], ) ) def test_likelihood_accuracy(self, basis_linear, basis_quadratic, mc_range, roq_scale_factor, add_cal_errors): @@ -848,13 +801,26 @@ def test_likelihood_accuracy_narrower_frequency_range(self, basis, minimum_frequ """Compare with log likelihood ratios computed by the non-ROQ likelihood in the case where analyzed frequency range is narrower than the basis frequency range""" self.assertLess_likelihood_errors( - basis, basis, (8, 9), 1, False, 1.5e-1, - minimum_frequency=minimum_frequency, maximum_frequency=maximum_frequency + basis, + basis, + (8, 9), + 1, + False, + 1.5e-1, + minimum_frequency=minimum_frequency, + maximum_frequency=maximum_frequency, ) def assertLess_likelihood_errors( - self, basis_linear, basis_quadratic, mc_range, roq_scale_factor, add_cal_errors, max_llr_error, - minimum_frequency=None, maximum_frequency=None + self, + basis_linear, + basis_quadratic, + mc_range, + roq_scale_factor, + add_cal_errors, + max_llr_error, + minimum_frequency=None, + maximum_frequency=None, ): self.minimum_frequency *= roq_scale_factor self.sampling_frequency *= roq_scale_factor @@ -878,7 +844,7 @@ def assertLess_likelihood_errors( interferometers.set_strain_data_from_zero_noise( sampling_frequency=self.sampling_frequency, duration=self.duration, - start_time=self.injection_parameters["geocent_time"] - self.duration + 1 + start_time=self.injection_parameters["geocent_time"] - self.duration + 1, ) if add_cal_errors: @@ -891,30 +857,25 @@ def assertLess_likelihood_errors( prefix=prefix, minimum_frequency=ifo.minimum_frequency, maximum_frequency=ifo.maximum_frequency, - n_points=spline_calibration_nodes + n_points=spline_calibration_nodes, ) for i in range(spline_calibration_nodes): # 5% in amplitude, 5deg in phase - self.injection_parameters[f"{prefix}amplitude_{i}"] = \ - rng.normal(loc=0, scale=0.05) - self.injection_parameters[f"{prefix}phase_{i}"] = \ - rng.normal(loc=0, scale=5 * np.pi / 180) + self.injection_parameters[f"{prefix}amplitude_{i}"] = rng.normal(loc=0, scale=0.05) + self.injection_parameters[f"{prefix}phase_{i}"] = rng.normal(loc=0, scale=5 * np.pi / 180) waveform_generator = bilby.gw.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, waveform_arguments=dict( - reference_frequency=self.reference_frequency, - waveform_approximant=self.waveform_approximant - ) + reference_frequency=self.reference_frequency, waveform_approximant=self.waveform_approximant + ), ) interferometers.inject_signal(waveform_generator=waveform_generator, parameters=self.injection_parameters) likelihood = bilby.gw.GravitationalWaveTransient( - interferometers=interferometers, - waveform_generator=waveform_generator, - priors=self.priors + interferometers=interferometers, waveform_generator=waveform_generator, priors=self.priors ) search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( @@ -922,9 +883,8 @@ def assertLess_likelihood_errors( sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, waveform_arguments=dict( - reference_frequency=self.reference_frequency, - waveform_approximant=self.waveform_approximant - ) + reference_frequency=self.reference_frequency, waveform_approximant=self.waveform_approximant + ), ) likelihood_roq = bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=interferometers, @@ -932,7 +892,7 @@ def assertLess_likelihood_errors( waveform_generator=search_waveform_generator, linear_matrix=basis_linear, quadratic_matrix=basis_quadratic, - roq_scale_factor=roq_scale_factor + roq_scale_factor=roq_scale_factor, ) for mc in np.linspace(self.priors["chirp_mass"].minimum, self.priors["chirp_mass"].maximum, 11): parameters = self.injection_parameters.copy() @@ -986,10 +946,7 @@ def test_from_hdf5(self, basis_linear, basis_quadratic): duration=duration, sampling_frequency=sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, - waveform_arguments=dict( - reference_frequency=reference_frequency, - waveform_approximant=waveform_approximant - ) + waveform_arguments=dict(reference_frequency=reference_frequency, waveform_approximant=waveform_approximant), ) bilby.gw.likelihood.ROQGravitationalWaveTransient( @@ -997,10 +954,10 @@ def test_from_hdf5(self, basis_linear, basis_quadratic): priors=priors, waveform_generator=search_waveform_generator, linear_matrix=basis_linear, - quadratic_matrix=basis_quadratic + quadratic_matrix=basis_quadratic, ) - @parameterized.expand([(False, ), (True, )]) + @parameterized.expand([(False,), (True,)]) def test_from_npy(self, from_array): # Possible locations for the ROQ: in the docker image, local, or on CIT trial_roq_paths = [ @@ -1016,15 +973,15 @@ def test_from_npy(self, from_array): if roq_dir is None: raise Exception("Unable to load ROQ basis: cannot proceed with tests") - basis_linear = "{}/B_linear.npy".format(roq_dir) + basis_linear = f"{roq_dir}/B_linear.npy" if from_array: basis_linear = np.load(basis_linear).T - basis_quadratic = "{}/B_quadratic.npy".format(roq_dir) + basis_quadratic = f"{roq_dir}/B_quadratic.npy" if from_array: basis_quadratic = np.load(basis_quadratic).T - fnodes_linear = np.load("{}/fnodes_linear.npy".format(roq_dir)) - fnodes_quadratic = np.load("{}/fnodes_quadratic.npy".format(roq_dir)) - params_file = "{}/params.dat".format(roq_dir) + fnodes_linear = np.load(f"{roq_dir}/fnodes_linear.npy") + fnodes_quadratic = np.load(f"{roq_dir}/fnodes_quadratic.npy") + params_file = f"{roq_dir}/params.dat" minimum_frequency = 20 sampling_frequency = 2048 @@ -1054,8 +1011,8 @@ def test_from_npy(self, from_array): frequency_nodes_linear=fnodes_linear, frequency_nodes_quadratic=fnodes_quadratic, reference_frequency=reference_frequency, - waveform_approximant=waveform_approximant - ) + waveform_approximant=waveform_approximant, + ), ) bilby.gw.likelihood.ROQGravitationalWaveTransient( @@ -1064,69 +1021,68 @@ def test_from_npy(self, from_array): waveform_generator=search_waveform_generator, linear_matrix=basis_linear, quadratic_matrix=basis_quadratic, - roq_params=params_file + roq_params=params_file, ) @pytest.mark.requires_roqs class TestInOutROQWeights(unittest.TestCase): - - @parameterized.expand(['npz', 'hdf5']) + @parameterized.expand(["npz", "hdf5"]) def test_out_single_basis(self, format): likelihood = self.create_likelihood_single_basis() - filename = f'weights.{format}' + filename = f"weights.{format}" likelihood.save_weights(filename, format=format) self.assertTrue(os.path.exists(filename)) def test_saving_wrong_format_fails(self): likelihood = self.create_likelihood_single_basis() - filename = 'weights.json' + filename = "weights.json" with self.assertRaises(IOError): - likelihood.save_weights(filename, format='json') + likelihood.save_weights(filename, format="json") - @parameterized.expand(['npz', 'hdf5']) + @parameterized.expand(["npz", "hdf5"]) def test_in_single_basis(self, format): likelihood = self.create_likelihood_single_basis() - filename = f'weights.{format}' + filename = f"weights.{format}" likelihood.save_weights(filename, format=format) likelihood_from_weights = bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=likelihood.interferometers, priors=likelihood.priors, waveform_generator=likelihood.waveform_generator, - weights=filename + weights=filename, ) self.check_weights_are_same(likelihood, likelihood_from_weights) - @parameterized.expand([(False, ), (True, )]) + @parameterized.expand([(False,), (True,)]) def test_out_multiple_bases(self, multiband): - filename = 'weights.hdf5' + filename = "weights.hdf5" likelihood = self.create_likelihood_multiple_bases(multiband) - likelihood.save_weights(filename, format='hdf5') + likelihood.save_weights(filename, format="hdf5") self.assertTrue(os.path.exists(filename)) - @parameterized.expand([(False, ), (True, )]) + @parameterized.expand([(False,), (True,)]) def test_in_multiple_bases(self, multiband): - filename = 'weights.hdf5' + filename = "weights.hdf5" likelihood = self.create_likelihood_multiple_bases(multiband) - likelihood.save_weights(filename, format='hdf5') + likelihood.save_weights(filename, format="hdf5") likelihood_from_weights = bilby.gw.likelihood.ROQGravitationalWaveTransient( interferometers=likelihood.interferometers, priors=likelihood.priors, waveform_generator=likelihood.waveform_generator, - weights=filename + weights=filename, ) self.check_weights_are_same(likelihood, likelihood_from_weights) - @parameterized.expand([(False, ), (True, )]) + @parameterized.expand([(False,), (True,)]) def test_out_multiple_bases_inconsistent_format(self, multiband): "npz format is not compatible with multiple bases" likelihood = self.create_likelihood_multiple_bases(multiband) with self.assertRaises(ValueError): - likelihood.save_weights('weights', format="npz") + likelihood.save_weights("weights", format="npz") def tearDown(self): - for format in ['npz', 'hdf5']: - filename = f'weights.{format}' + for format in ["npz", "hdf5"]: + filename = f"weights.{format}" if os.path.exists(filename): os.remove(filename) @@ -1139,20 +1095,20 @@ def check_weights_are_same(l1, l2): l1, l2: bilby.gw.likelihood.ROQGravitationalWaveTransient """ - np.testing.assert_array_almost_equal(l1.weights['time_samples'], l2.weights['time_samples']) - for basis_type in ['linear', 'quadratic']: + np.testing.assert_array_almost_equal(l1.weights["time_samples"], l2.weights["time_samples"]) + for basis_type in ["linear", "quadratic"]: # check weights for ifo in l1.interferometers: - key = f'{ifo.name}_{basis_type}' + key = f"{ifo.name}_{basis_type}" for i in range(len(l1.weights[key])): np.testing.assert_array_almost_equal(l1.weights[key][i], l2.weights[key][i]) # check prior ranges - key = f'prior_range_{basis_type}' + key = f"prior_range_{basis_type}" if key in l1.weights: for param_name in l1.weights[key]: np.testing.assert_array_almost_equal(l1.weights[key][param_name], l2.weights[key][param_name]) # check frequency nodes - key = f'frequency_nodes_{basis_type}' + key = f"frequency_nodes_{basis_type}" if key in l1.weights: for i in range(len(l1.weights[key])): np.testing.assert_array_almost_equal(l1.weights[key][i], l2.weights[key][i]) @@ -1172,10 +1128,10 @@ def create_likelihood_single_basis(self): if roq_dir is None: raise Exception("Unable to load ROQ basis: cannot proceed with tests") - linear_matrix_file = "{}/B_linear.npy".format(roq_dir) - quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir) - fnodes_linear = np.load("{}/fnodes_linear.npy".format(roq_dir)) - fnodes_quadratic = np.load("{}/fnodes_quadratic.npy".format(roq_dir)) + linear_matrix_file = f"{roq_dir}/B_linear.npy" + quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy" + fnodes_linear = np.load(f"{roq_dir}/fnodes_linear.npy") + fnodes_quadratic = np.load(f"{roq_dir}/fnodes_quadratic.npy") minimum_frequency = 20 sampling_frequency = 2048 @@ -1205,8 +1161,8 @@ def create_likelihood_single_basis(self): frequency_nodes_linear=fnodes_linear, frequency_nodes_quadratic=fnodes_quadratic, reference_frequency=reference_frequency, - waveform_approximant=waveform_approximant - ) + waveform_approximant=waveform_approximant, + ), ) return bilby.gw.likelihood.ROQGravitationalWaveTransient( @@ -1214,7 +1170,7 @@ def create_likelihood_single_basis(self): priors=priors, waveform_generator=search_waveform_generator, linear_matrix=linear_matrix_file, - quadratic_matrix=quadratic_matrix_file + quadratic_matrix=quadratic_matrix_file, ) def create_likelihood_multiple_bases(self, multiband): @@ -1242,10 +1198,7 @@ def create_likelihood_multiple_bases(self, multiband): duration=duration, sampling_frequency=sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, - waveform_arguments=dict( - reference_frequency=reference_frequency, - waveform_approximant=waveform_approximant - ) + waveform_arguments=dict(reference_frequency=reference_frequency, waveform_approximant=waveform_approximant), ) if multiband: @@ -1257,7 +1210,7 @@ def create_likelihood_multiple_bases(self, multiband): priors=priors, waveform_generator=search_waveform_generator, linear_matrix=path_to_basis, - quadratic_matrix=path_to_basis + quadratic_matrix=path_to_basis, ) @@ -1275,8 +1228,8 @@ def test_instantiation(self): class TestMBLikelihood(unittest.TestCase): def setUp(self): self.duration = 16 - self.fmin = 20. - self.sampling_frequency = 2048. + self.fmin = 20.0 + self.sampling_frequency = 2048.0 self.test_parameters = dict( chirp_mass=6.0, mass_ratio=0.5, @@ -1292,15 +1245,16 @@ def setUp(self): phase=1.3, geocent_time=1187008882, ra=1.3, - dec=-1.2 + dec=-1.2, ) # Network SNR is ~50 self.ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"]) bilby.core.utils.random.seed(70817) rng = bilby.core.utils.random.rng self.ifos.set_strain_data_from_power_spectral_densities( - sampling_frequency=self.sampling_frequency, duration=self.duration, - start_time=self.test_parameters['geocent_time'] - self.duration + 2. + sampling_frequency=self.sampling_frequency, + duration=self.duration, + start_time=self.test_parameters["geocent_time"] - self.duration + 2.0, ) for ifo in self.ifos: ifo.minimum_frequency = self.fmin @@ -1312,16 +1266,14 @@ def setUp(self): prefix=f"recalib_{ifo.name}_", minimum_frequency=ifo.minimum_frequency, maximum_frequency=ifo.maximum_frequency, - n_points=spline_calibration_nodes + n_points=spline_calibration_nodes, ) for i in range(spline_calibration_nodes): self.test_parameters[f"recalib_{ifo.name}_amplitude_{i}"] = 0 self.test_parameters[f"recalib_{ifo.name}_phase_{i}"] = 0 # Calibration errors of 5% in amplitude and 5 degrees in phase - self.calibration_parameters[f"recalib_{ifo.name}_amplitude_{i}"] = \ - rng.normal(loc=0, scale=0.05) - self.calibration_parameters[f"recalib_{ifo.name}_phase_{i}"] = \ - rng.normal(loc=0, scale=5 * np.pi / 180) + self.calibration_parameters[f"recalib_{ifo.name}_amplitude_{i}"] = rng.normal(loc=0, scale=0.05) + self.calibration_parameters[f"recalib_{ifo.name}_phase_{i}"] = rng.normal(loc=0, scale=5 * np.pi / 180) self.priors = bilby.gw.prior.BBHPriorDict() self.priors.pop("mass_1") @@ -1329,23 +1281,22 @@ def setUp(self): self.priors["chirp_mass"] = bilby.core.prior.Uniform(5.5, 6.5) self.priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1) self.priors["geocent_time"] = bilby.core.prior.Uniform( - self.test_parameters['geocent_time'] - 0.1, - self.test_parameters['geocent_time'] + 0.1) + self.test_parameters["geocent_time"] - 0.1, self.test_parameters["geocent_time"] + 0.1 + ) def tearDown(self): - del ( - self.ifos, - self.priors - ) + del (self.ifos, self.priors) - @parameterized.expand([ - ("IMRPhenomD", True, 2, False, 1.5e-2), - ("IMRPhenomD", True, 2, True, 1.5e-2), - ("IMRPhenomD", False, 2, False, 5e-3), - ("IMRPhenomD", False, 2, True, 6e-3), - ("IMRPhenomHM", False, 4, False, 8e-4), - ("IMRPhenomHM", False, 4, True, 1e-3) - ]) + @parameterized.expand( + [ + ("IMRPhenomD", True, 2, False, 1.5e-2), + ("IMRPhenomD", True, 2, True, 1.5e-2), + ("IMRPhenomD", False, 2, False, 5e-3), + ("IMRPhenomD", False, 2, True, 6e-3), + ("IMRPhenomHM", False, 4, False, 8e-4), + ("IMRPhenomHM", False, 4, True, 1e-3), + ] + ) def test_matches_original_likelihood( self, waveform_approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance ): @@ -1353,43 +1304,37 @@ def test_matches_original_likelihood( Check if multi-band likelihood values match original likelihood values """ wfg = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg) wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) - ) - likelihood = bilby.gw.likelihood.GravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) + likelihood = bilby.gw.likelihood.GravitationalWaveTransient(interferometers=self.ifos, waveform_generator=wfg) likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg_mb, - reference_chirp_mass=self.test_parameters['chirp_mass'], - priors=self.priors.copy(), linear_interpolation=linear_interpolation, - highest_mode=highest_mode + interferometers=self.ifos, + waveform_generator=wfg_mb, + reference_chirp_mass=self.test_parameters["chirp_mass"], + priors=self.priors.copy(), + linear_interpolation=linear_interpolation, + highest_mode=highest_mode, ) parameters = deepcopy(self.test_parameters) if add_cal_errors: parameters.update(self.calibration_parameters) self.assertLess( - abs(likelihood.log_likelihood_ratio(parameters) - likelihood_mb.log_likelihood_ratio(parameters)), - tolerance + abs(likelihood.log_likelihood_ratio(parameters) - likelihood_mb.log_likelihood_ratio(parameters)), tolerance ) likelihood.parameters.update(parameters) likelihood_mb.parameters.update(parameters) - self.assertLess( - abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()), - tolerance - ) + self.assertLess(abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()), tolerance) def test_large_accuracy_factor(self): """ @@ -1397,33 +1342,33 @@ def test_large_accuracy_factor(self): """ waveform_approximant = "IMRPhenomD" wfg = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg) wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) - ) - likelihood = bilby.gw.likelihood.GravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) + likelihood = bilby.gw.likelihood.GravitationalWaveTransient(interferometers=self.ifos, waveform_generator=wfg) likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg_mb, - reference_chirp_mass=self.test_parameters['chirp_mass'], - priors=self.priors.copy(), accuracy_factor=5 + interferometers=self.ifos, + waveform_generator=wfg_mb, + reference_chirp_mass=self.test_parameters["chirp_mass"], + priors=self.priors.copy(), + accuracy_factor=5, ) likelihood_mb_more_accurate = bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg_mb, - reference_chirp_mass=self.test_parameters['chirp_mass'], - priors=self.priors.copy(), accuracy_factor=50 + interferometers=self.ifos, + waveform_generator=wfg_mb, + reference_chirp_mass=self.test_parameters["chirp_mass"], + priors=self.priors.copy(), + accuracy_factor=50, ) self.assertLess( abs( @@ -1433,14 +1378,15 @@ def test_large_accuracy_factor(self): abs( likelihood.log_likelihood_ratio(self.test_parameters) - likelihood_mb.log_likelihood_ratio(self.test_parameters) - ) / 2 + ) + / 2, ) likelihood.parameters.update(self.test_parameters) likelihood_mb.parameters.update(self.test_parameters) likelihood_mb_more_accurate.parameters.update(self.test_parameters) self.assertLess( abs(likelihood.log_likelihood_ratio() - likelihood_mb_more_accurate.log_likelihood_ratio()), - abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()) / 2 + abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()) / 2, ) def test_reference_chirp_mass_from_prior(self): @@ -1448,20 +1394,19 @@ def test_reference_chirp_mass_from_prior(self): Check if reference chirp mass is automatically determined from prior if no number has been passed """ wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant="IMRPhenomD" - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant="IMRPhenomD"), ) likelihood1 = bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg_mb, + interferometers=self.ifos, + waveform_generator=wfg_mb, reference_chirp_mass=self.priors["chirp_mass"].minimum, - priors=self.priors.copy() + priors=self.priors.copy(), ) likelihood2 = bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg_mb, - priors=self.priors.copy() + interferometers=self.ifos, waveform_generator=wfg_mb, priors=self.priors.copy() ) self.assertAlmostEqual(likelihood1.reference_chirp_mass, likelihood2.reference_chirp_mass) @@ -1470,27 +1415,23 @@ def test_no_reference_chirp_mass(self): Check if an error is raised if either reference_chirp_mass or priors is not specified. """ wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant="IMRPhenomD" - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant="IMRPhenomD"), ) with self.assertRaises(TypeError): - bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg_mb - ) + bilby.gw.likelihood.MBGravitationalWaveTransient(interferometers=self.ifos, waveform_generator=wfg_mb) def test_cannot_determine_reference_chirp_mass(self): """ Check if an error is raised if priors does not contain necessary information to determine reference chirp mass """ wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant="IMRPhenomD" - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant="IMRPhenomD"), ) for key in ["chirp_mass", "mass_1", "mass_2"]: if key in self.priors: @@ -1500,7 +1441,7 @@ def test_cannot_determine_reference_chirp_mass(self): interferometers=self.ifos, waveform_generator=wfg_mb, priors=self.priors ) - @parameterized.expand([(True, ), (False, )]) + @parameterized.expand([(True,), (False,)]) def test_inout_weights(self, linear_interpolation): """ Check if multiband weights can be saved as a file, and a likelihood object constructed from the weights file @@ -1508,26 +1449,23 @@ def test_inout_weights(self, linear_interpolation): """ waveform_approximant = "IMRPhenomD" wfg = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) - ) - self.ifos.inject_signal( - parameters=self.test_parameters, waveform_generator=wfg + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) + self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg) wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg_mb, - reference_chirp_mass=self.test_parameters['chirp_mass'], + interferometers=self.ifos, + waveform_generator=wfg_mb, + reference_chirp_mass=self.test_parameters["chirp_mass"], linear_interpolation=linear_interpolation, ) likelihood_mb.parameters.update(self.test_parameters) @@ -1542,11 +1480,10 @@ def test_inout_weights(self, linear_interpolation): # reset waveform generator to check if likelihood recovered from the weights file properly adds banded # frequency points to waveform arguments wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient( interferometers=self.ifos, waveform_generator=wfg_mb, weights=filepath @@ -1557,33 +1494,30 @@ def test_inout_weights(self, linear_interpolation): self.assertAlmostEqual(llr, llr_from_weights) - @parameterized.expand([(True, ), (False, )]) + @parameterized.expand([(True,), (False,)]) def test_from_dict_weights(self, linear_interpolation): """ Check if a likelihood object constructed from dictionary-like weights produce the same likelihood value """ waveform_approximant = "IMRPhenomD" wfg = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) - ) - self.ifos.inject_signal( - parameters=self.test_parameters, waveform_generator=wfg + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) + self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg) wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg_mb, - reference_chirp_mass=self.test_parameters['chirp_mass'], + interferometers=self.ifos, + waveform_generator=wfg_mb, + reference_chirp_mass=self.test_parameters["chirp_mass"], linear_interpolation=linear_interpolation, ) likelihood_mb.parameters.update(self.test_parameters) @@ -1592,11 +1526,10 @@ def test_from_dict_weights(self, linear_interpolation): # reset waveform generator to check if likelihood recovered from the weights properly adds banded # frequency points to waveform arguments wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) weights = likelihood_mb.weights likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient( @@ -1607,11 +1540,13 @@ def test_from_dict_weights(self, linear_interpolation): self.assertAlmostEqual(llr, llr_from_weights) - @parameterized.expand([ - ("IMRPhenomD", True, 2, False, 1e-2), - ("IMRPhenomD", True, 2, True, 1e-2), - ("IMRPhenomHM", False, 4, False, 5e-3), - ]) + @parameterized.expand( + [ + ("IMRPhenomD", True, 2, False, 1e-2), + ("IMRPhenomD", True, 2, True, 1e-2), + ("IMRPhenomHM", False, 4, False, 5e-3), + ] + ) def test_matches_original_likelihood_low_maximum_frequency( self, waveform_approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance ): @@ -1622,43 +1557,37 @@ def test_matches_original_likelihood_low_maximum_frequency( ifo.maximum_frequency = self.sampling_frequency / 8 wfg = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg) wfg_mb = bilby.gw.WaveformGenerator( - duration=self.duration, sampling_frequency=self.sampling_frequency, + duration=self.duration, + sampling_frequency=self.sampling_frequency, frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, - waveform_arguments=dict( - reference_frequency=self.fmin, waveform_approximant=waveform_approximant - ) - ) - likelihood = bilby.gw.likelihood.GravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg + waveform_arguments=dict(reference_frequency=self.fmin, waveform_approximant=waveform_approximant), ) + likelihood = bilby.gw.likelihood.GravitationalWaveTransient(interferometers=self.ifos, waveform_generator=wfg) likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient( - interferometers=self.ifos, waveform_generator=wfg_mb, - reference_chirp_mass=self.test_parameters['chirp_mass'], - priors=self.priors.copy(), linear_interpolation=linear_interpolation, - highest_mode=highest_mode + interferometers=self.ifos, + waveform_generator=wfg_mb, + reference_chirp_mass=self.test_parameters["chirp_mass"], + priors=self.priors.copy(), + linear_interpolation=linear_interpolation, + highest_mode=highest_mode, ) parameters = deepcopy(self.test_parameters) if add_cal_errors: parameters.update(self.calibration_parameters) self.assertLess( - abs(likelihood.log_likelihood_ratio(parameters) - likelihood_mb.log_likelihood_ratio(parameters)), - tolerance + abs(likelihood.log_likelihood_ratio(parameters) - likelihood_mb.log_likelihood_ratio(parameters)), tolerance ) likelihood.parameters.update(parameters) likelihood_mb.parameters.update(parameters) - self.assertLess( - abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()), - tolerance - ) + self.assertLess(abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()), tolerance) if __name__ == "__main__": diff --git a/test/gw/plot_test.py b/test/gw/plot_test.py index bd9414212..a6646838d 100644 --- a/test/gw/plot_test.py +++ b/test/gw/plot_test.py @@ -20,9 +20,7 @@ def setUp(self): time_marginalization=True, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, time_domain_source_model=None, - waveform_arguments=dict( - reference_frequency=20.0, waveform_approximant="IMRPhenomPv2" - ), + waveform_arguments=dict(reference_frequency=20.0, waveform_approximant="IMRPhenomPv2"), interferometers=dict( H1=dict(optimal_SNR=1, parameters=injection_parameters), L1=dict(optimal_SNR=1, parameters=injection_parameters), @@ -66,9 +64,7 @@ def test_calibration_plot(self): label="recalib_H1_", n_nodes=5, ) - calibration_filename = ( - f"{self.result.outdir}/{self.result.label}_calibration.png" - ) + calibration_filename = f"{self.result.outdir}/{self.result.label}_calibration.png" for key in calibration_prior: self.result.posterior[key] = calibration_prior[key].sample(100) self.result.plot_calibration_posterior() @@ -76,9 +72,7 @@ def test_calibration_plot(self): def test_calibration_plot_returns_none_with_no_calibration_parameters(self): self.assertIsNone(self.result.plot_calibration_posterior()) - calibration_filename = ( - f"{self.result.outdir}/{self.result.label}_calibration.png" - ) + calibration_filename = f"{self.result.outdir}/{self.result.label}_calibration.png" self.assertFalse(os.path.exists(calibration_filename)) def test_calibration_pdf_plot(self): @@ -90,9 +84,7 @@ def test_calibration_pdf_plot(self): label="recalib_H1_", n_nodes=5, ) - calibration_filename = ( - f"{self.result.outdir}/{self.result.label}_calibration.pdf" - ) + calibration_filename = f"{self.result.outdir}/{self.result.label}_calibration.pdf" for key in calibration_prior: self.result.posterior[key] = calibration_prior[key].sample(100) self.result.plot_calibration_posterior(format="pdf") @@ -105,11 +97,7 @@ def test_calibration_invalid_format_raises_error(self): def test_waveform_plotting_png(self): self.result.plot_waveform_posterior(n_samples=200) for ifo in self.result.interferometers: - self.assertTrue( - os.path.exists( - f"{self.result.outdir}/{self.result.label}_{ifo}_waveform.png" - ) - ) + self.assertTrue(os.path.exists(f"{self.result.outdir}/{self.result.label}_{ifo}_waveform.png")) def test_plot_skymap_meta_data(self): from ligo.skymap import io diff --git a/test/gw/prior_test.py b/test/gw/prior_test.py index 1dcc11096..fe7b16b5d 100644 --- a/test/gw/prior_test.py +++ b/test/gw/prior_test.py @@ -1,28 +1,26 @@ -from collections import OrderedDict -import unittest import glob import os -import sys import pickle +import sys +import unittest +from collections import OrderedDict +import matplotlib.pyplot as plt import numpy as np +import pandas as pd from astropy import cosmology from scipy.stats import ks_2samp -import matplotlib.pyplot as plt -import pandas as pd import bilby -from bilby.core.prior import Uniform, Constraint -from bilby.gw.prior import BBHPriorDict +from bilby.core.prior import Constraint, Uniform from bilby.gw import conversion +from bilby.gw.prior import BBHPriorDict class TestBBHPriorDict(unittest.TestCase): def setUp(self): self.prior_dict = dict() - self.base_directory = "/".join( - os.path.dirname(os.path.abspath(sys.argv[0])).split("/")[:-1] - ) + self.base_directory = "/".join(os.path.dirname(os.path.abspath(sys.argv[0])).split("/")[:-1]) self.filename = os.path.join( os.path.dirname(os.path.realpath(__file__)), "prior_files/precessing_spins_bbh.prior", @@ -47,30 +45,10 @@ def test_read_write_default_prior(self): def test_create_default_prior(self): default = bilby.gw.prior.BBHPriorDict() - minima = all( - [ - self.bbh_prior_dict[key].minimum == default[key].minimum - for key in default.keys() - ] - ) - maxima = all( - [ - self.bbh_prior_dict[key].maximum == default[key].maximum - for key in default.keys() - ] - ) - names = all( - [ - self.bbh_prior_dict[key].name == default[key].name - for key in default.keys() - ] - ) - boundaries = all( - [ - self.bbh_prior_dict[key].boundary == default[key].boundary - for key in default.keys() - ] - ) + minima = all([self.bbh_prior_dict[key].minimum == default[key].minimum for key in default.keys()]) + maxima = all([self.bbh_prior_dict[key].maximum == default[key].maximum for key in default.keys()]) + names = all([self.bbh_prior_dict[key].name == default[key].name for key in default.keys()]) + boundaries = all([self.bbh_prior_dict[key].boundary == default[key].boundary for key in default.keys()]) self.assertTrue(all([minima, maxima, names, boundaries])) @@ -165,9 +143,7 @@ def test_test_has_redundant_priors(self): del self.bbh_prior_dict[prior] def test_add_constraint_prior_not_redundant(self): - self.bbh_prior_dict["chirp_mass"] = bilby.prior.Constraint( - minimum=20, maximum=40, name="chirp_mass" - ) + self.bbh_prior_dict["chirp_mass"] = bilby.prior.Constraint(minimum=20, maximum=40, name="chirp_mass") self.assertFalse(self.bbh_prior_dict.test_has_redundant_keys()) def test_is_cosmological_true(self): @@ -188,10 +164,16 @@ def test_check_valid_cosmology(self): def test_check_valid_cosmology_raises_error(self): self.bbh_prior_dict["luminosity_distance"] = bilby.gw.prior.UniformComovingVolume( - minimum=10, maximum=10000, name="luminosity_distance", cosmology="Planck15", + minimum=10, + maximum=10000, + name="luminosity_distance", + cosmology="Planck15", ) self.bbh_prior_dict["redshift"] = bilby.gw.prior.UniformComovingVolume( - minimum=0.1, maximum=1, name="redshift", cosmology="Planck15_LAL", + minimum=0.1, + maximum=1, + name="redshift", + cosmology="Planck15_LAL", ) self.assertEqual( self.bbh_prior_dict._cosmological_priors, @@ -230,12 +212,8 @@ def test_bilby_to_lalinference(self): bilby_prior = BBHPriorDict( dictionary=dict( - chirp_mass=Uniform( - name="chirp_mass", minimum=chirp_mass[0], maximum=chirp_mass[1] - ), - mass_ratio=Uniform( - name="mass_ratio", minimum=mass_ratio[0], maximum=mass_ratio[1] - ), + chirp_mass=Uniform(name="chirp_mass", minimum=chirp_mass[0], maximum=chirp_mass[1]), + mass_ratio=Uniform(name="mass_ratio", minimum=mass_ratio[0], maximum=mass_ratio[1]), mass_2=Constraint(name="mass_2", minimum=mass_1[0], maximum=mass_1[1]), mass_1=Constraint(name="mass_1", minimum=mass_2[0], maximum=mass_2[1]), ) @@ -243,12 +221,8 @@ def test_bilby_to_lalinference(self): lalinf_prior = BBHPriorDict( dictionary=dict( - mass_ratio=Constraint( - name="mass_ratio", minimum=mass_ratio[0], maximum=mass_ratio[1] - ), - chirp_mass=Constraint( - name="chirp_mass", minimum=chirp_mass[0], maximum=chirp_mass[1] - ), + mass_ratio=Constraint(name="mass_ratio", minimum=mass_ratio[0], maximum=mass_ratio[1]), + chirp_mass=Constraint(name="chirp_mass", minimum=chirp_mass[0], maximum=chirp_mass[1]), mass_2=Uniform(name="mass_2", minimum=mass_1[0], maximum=mass_1[1]), mass_1=Uniform(name="mass_1", minimum=mass_2[0], maximum=mass_2[1]), ) @@ -256,9 +230,7 @@ def test_bilby_to_lalinference(self): nsamples = 5000 bilby_samples = bilby_prior.sample(nsamples) - bilby_samples, _ = conversion.convert_to_lal_binary_black_hole_parameters( - bilby_samples - ) + bilby_samples, _ = conversion.convert_to_lal_binary_black_hole_parameters(bilby_samples) # Quicker way to generate LA prior samples (rather than specifying Constraint) lalinf_samples = [] @@ -269,9 +241,7 @@ def test_bilby_to_lalinference(self): if s["mass_2"] / s["mass_1"] > 0.125: lalinf_samples.append(s) lalinf_samples = pd.DataFrame(lalinf_samples) - lalinf_samples["mass_ratio"] = ( - lalinf_samples["mass_2"] / lalinf_samples["mass_1"] - ) + lalinf_samples["mass_ratio"] = lalinf_samples["mass_2"] / lalinf_samples["mass_1"] # Construct fake result object result = bilby.core.result.Result() @@ -279,9 +249,7 @@ def test_bilby_to_lalinference(self): result.meta_data = dict() result.priors = bilby_prior result.posterior = pd.DataFrame(bilby_samples) - result_converted = bilby.gw.prior.convert_to_flat_in_component_mass_prior( - result - ) + result_converted = bilby.gw.prior.convert_to_flat_in_component_mass_prior(result) if "plot" in sys.argv: # Useful for debugging @@ -301,15 +269,13 @@ def test_bilby_to_lalinference(self): self.assertFalse(ks.pvalue > 0.05) # Check that the non-reweighted posteriors pass a KS test - ks = ks_2samp( - result_converted.posterior["mass_ratio"], lalinf_samples["mass_ratio"] - ) + ks = ks_2samp(result_converted.posterior["mass_ratio"], lalinf_samples["mass_ratio"]) print("Reweighted KS test = ", ks) self.assertTrue(ks.pvalue > 0.001) class TestPackagedPriors(unittest.TestCase): - """ Test that the prepackaged priors load """ + """Test that the prepackaged priors load""" def test_aligned(self): filename = "aligned_spins_bbh.prior" @@ -325,7 +291,7 @@ def test_binary_black_holes(self): def test_all(self): prior_files = glob.glob(bilby.gw.prior.DEFAULT_PRIOR_DIR + "/*prior") for ff in prior_files: - print("Checking prior file {}".format(ff)) + print(f"Checking prior file {ff}") prior_dict = bilby.gw.prior.BBHPriorDict(filename=ff) self.assertTrue("chirp_mass" in prior_dict) self.assertTrue("mass_ratio" in prior_dict) @@ -338,9 +304,7 @@ def test_all(self): class TestBNSPriorDict(unittest.TestCase): def setUp(self): self.prior_dict = OrderedDict() - self.base_directory = "/".join( - os.path.dirname(os.path.abspath(sys.argv[0])).split("/")[:-1] - ) + self.base_directory = "/".join(os.path.dirname(os.path.abspath(sys.argv[0])).split("/")[:-1]) self.filename = os.path.join( os.path.dirname(os.path.realpath(__file__)), "prior_files/aligned_spins_bns_tides_on.prior", @@ -357,30 +321,10 @@ def tearDown(self): def test_create_default_prior(self): default = bilby.gw.prior.BNSPriorDict() - minima = all( - [ - self.bns_prior_dict[key].minimum == default[key].minimum - for key in default.keys() - ] - ) - maxima = all( - [ - self.bns_prior_dict[key].maximum == default[key].maximum - for key in default.keys() - ] - ) - names = all( - [ - self.bns_prior_dict[key].name == default[key].name - for key in default.keys() - ] - ) - boundaries = all( - [ - self.bns_prior_dict[key].boundary == default[key].boundary - for key in default.keys() - ] - ) + minima = all([self.bns_prior_dict[key].minimum == default[key].minimum for key in default.keys()]) + maxima = all([self.bns_prior_dict[key].maximum == default[key].maximum for key in default.keys()]) + names = all([self.bns_prior_dict[key].name == default[key].name for key in default.keys()]) + boundaries = all([self.bns_prior_dict[key].boundary == default[key].boundary for key in default.keys()]) self.assertTrue(all([minima, maxima, names, boundaries])) @@ -458,9 +402,7 @@ def test_test_has_redundant_priors(self): del self.bns_prior_dict[prior] def test_add_constraint_prior_not_redundant(self): - self.bns_prior_dict["chirp_mass"] = bilby.prior.Constraint( - minimum=1, maximum=2, name="chirp_mass" - ) + self.bns_prior_dict["chirp_mass"] = bilby.prior.Constraint(minimum=1, maximum=2, name="chirp_mass") self.assertFalse(self.bns_prior_dict.test_has_redundant_keys()) @@ -492,29 +434,21 @@ def setUp(self): pass def test_minimum(self): - prior = bilby.gw.prior.UniformComovingVolume( - minimum=10, maximum=10000, name="luminosity_distance" - ) + prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=10000, name="luminosity_distance") self.assertEqual(prior.minimum, 10) def test_maximum(self): - prior = bilby.gw.prior.UniformComovingVolume( - minimum=10, maximum=10000, name="luminosity_distance" - ) + prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=10000, name="luminosity_distance") self.assertEqual(prior.maximum, 10000) def test_increase_maximum(self): - prior = bilby.gw.prior.UniformComovingVolume( - minimum=10, maximum=10000, name="luminosity_distance" - ) + prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=10000, name="luminosity_distance") prior.maximum = 20000 prior_sample = prior.sample(5000) self.assertGreater(np.mean(prior_sample), 10000) def test_zero_minimum_works(self): - prior = bilby.gw.prior.UniformComovingVolume( - minimum=0, maximum=10000, name="luminosity_distance" - ) + prior = bilby.gw.prior.UniformComovingVolume(minimum=0, maximum=10000, name="luminosity_distance") self.assertEqual(prior.minimum, 0) def test_specify_cosmology(self): @@ -524,35 +458,25 @@ def test_specify_cosmology(self): self.assertEqual(repr(prior.cosmology), repr(cosmology.Planck13)) def test_comoving_prior_creation(self): - prior = bilby.gw.prior.UniformComovingVolume( - minimum=10, maximum=1000, name="comoving_distance" - ) + prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=1000, name="comoving_distance") self.assertEqual(prior.latex_label, "$d_C$") def test_redshift_prior_creation(self): - prior = bilby.gw.prior.UniformComovingVolume( - minimum=0.1, maximum=1, name="redshift" - ) + prior = bilby.gw.prior.UniformComovingVolume(minimum=0.1, maximum=1, name="redshift") self.assertEqual(prior.latex_label, "$z$") def test_redshift_to_luminosity_distance(self): - prior = bilby.gw.prior.UniformComovingVolume( - minimum=0.1, maximum=1, name="redshift" - ) + prior = bilby.gw.prior.UniformComovingVolume(minimum=0.1, maximum=1, name="redshift") new_prior = prior.get_corresponding_prior("luminosity_distance") self.assertEqual(new_prior.name, "luminosity_distance") def test_luminosity_distance_to_redshift(self): - prior = bilby.gw.prior.UniformComovingVolume( - minimum=10, maximum=10000, name="luminosity_distance" - ) + prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=10000, name="luminosity_distance") new_prior = prior.get_corresponding_prior("redshift") self.assertEqual(new_prior.name, "redshift") def test_luminosity_distance_to_comoving_distance(self): - prior = bilby.gw.prior.UniformComovingVolume( - minimum=10, maximum=10000, name="luminosity_distance" - ) + prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=10000, name="luminosity_distance") new_prior = prior.get_corresponding_prior("comoving_distance") self.assertEqual(new_prior.name, "comoving_distance") @@ -579,7 +503,6 @@ def test_non_analytic_form_has_correct_statistics(self): class TestConditionalChiUniformSpinMagnitude(unittest.TestCase): - def setUp(self): pass diff --git a/test/gw/result_test.py b/test/gw/result_test.py index d3b0977f2..698139d38 100644 --- a/test/gw/result_test.py +++ b/test/gw/result_test.py @@ -1,16 +1,15 @@ -import os import logging +import os import unittest import pandas as pd -from parameterized import parameterized, parameterized_class import pytest +from parameterized import parameterized, parameterized_class import bilby class BaseCBCResultTest(unittest.TestCase): - @pytest.fixture(autouse=True) def init_outdir(self, tmp_path): # Use pytest's tmp_path fixture to create a temporary directory @@ -27,9 +26,7 @@ def setUp(self): distance_marginalization=False, time_marginalization=True, frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, - waveform_arguments=dict( - reference_frequency=20.0, waveform_approximant="IMRPhenomPv2" - ), + waveform_arguments=dict(reference_frequency=20.0, waveform_approximant="IMRPhenomPv2"), interferometers=dict( H1=dict(optimal_SNR=1, parameters=injection_parameters), L1=dict(optimal_SNR=1, parameters=injection_parameters), @@ -62,7 +59,6 @@ def tearDown(self): class TestCBCResult(BaseCBCResultTest): - @pytest.fixture(autouse=True) def set_caplog(self, caplog): self._caplog = caplog @@ -107,9 +103,7 @@ def test_reference_frequency(self): ) def test_reference_frequency_unset(self): - self.result.meta_data["likelihood"]["waveform_arguments"].pop( - "reference_frequency" - ) + self.result.meta_data["likelihood"]["waveform_arguments"].pop("reference_frequency") with self.assertRaises(AttributeError): self.result.reference_frequency @@ -133,9 +127,7 @@ def test_duration_unset(self): self.result.duration def test_start_time(self): - self.assertEqual( - self.result.start_time, self.meta_data["likelihood"]["start_time"] - ) + self.assertEqual(self.result.start_time, self.meta_data["likelihood"]["start_time"]) def test_start_time_unset(self): self.result.meta_data["likelihood"].pop("start_time") @@ -149,9 +141,7 @@ def test_waveform_approximant(self): ) def test_waveform_approximant_unset(self): - self.result.meta_data["likelihood"]["waveform_arguments"].pop( - "waveform_approximant" - ) + self.result.meta_data["likelihood"]["waveform_arguments"].pop("waveform_approximant") with self.assertRaises(AttributeError): self.result.waveform_approximant @@ -207,16 +197,11 @@ def test_detector_injection_properties(self): ) def test_detector_injection_properties_no_injection(self): - self.assertEqual( - self.result.detector_injection_properties("not_a_detector"), None - ) + self.assertEqual(self.result.detector_injection_properties("not_a_detector"), None) -@parameterized_class( - ["include_global_meta_data"], [["True"], ["False"]] -) +@parameterized_class(["include_global_meta_data"], [["True"], ["False"]]) class CBCResultsGlobalMetaDataTest(BaseCBCResultTest): - @pytest.fixture(autouse=True) def set_caplog(self, caplog): self._caplog = caplog @@ -251,12 +236,8 @@ def test_cosmology(self): assert "not containing global meta data" in str(self._caplog.text) -@parameterized_class( - ["cosmology_name"], - [["Planck15"], ["Planck15_LAL"]] -) +@parameterized_class(["cosmology_name"], [["Planck15"], ["Planck15_LAL"]]) class TestCBCResultSaveAndLoad(BaseCBCResultTest): - def setUp(self): self.orig_cosmology = bilby.gw.cosmology.get_cosmology() bilby.gw.cosmology.set_cosmology(self.cosmology_name) diff --git a/test/gw/sampler/proposal_test.py b/test/gw/sampler/proposal_test.py index 3ad800084..77a634e9a 100644 --- a/test/gw/sampler/proposal_test.py +++ b/test/gw/sampler/proposal_test.py @@ -16,9 +16,7 @@ def setUp(self): dec=prior.Uniform(minimum=0.0, maximum=np.pi, boundary="reflective"), ) ) - self.jump_proposal = bilby.gw.sampler.proposal.SkyLocationWanderJump( - priors=self.priors - ) + self.jump_proposal = bilby.gw.sampler.proposal.SkyLocationWanderJump(priors=self.priors) def tearDown(self): del self.priors @@ -53,9 +51,7 @@ def setUp(self): psi=prior.Uniform(minimum=0.0, maximum=np.pi), ) ) - self.jump_proposal = bilby.gw.sampler.proposal.CorrelatedPolarisationPhaseJump( - priors=self.priors - ) + self.jump_proposal = bilby.gw.sampler.proposal.CorrelatedPolarisationPhaseJump(priors=self.priors) def tearDown(self): del self.priors @@ -67,9 +63,7 @@ def test_jump_proposal_call_case_1(self): sample = proposal.Sample(dict(phase=0.2, psi=0.5)) alpha = 3.0 * np.pi * 0.3 beta = 0.3 - expected = proposal.Sample( - dict(phase=0.5 * (alpha - beta), psi=0.5 * (alpha + beta)) - ) + expected = proposal.Sample(dict(phase=0.5 * (alpha - beta), psi=0.5 * (alpha + beta))) self.assertEqual(expected, self.jump_proposal(sample, coordinates=None)) def test_jump_proposal_call_case_2(self): @@ -78,9 +72,7 @@ def test_jump_proposal_call_case_2(self): sample = proposal.Sample(dict(phase=0.2, psi=0.5)) alpha = 0.7 beta = 3.0 * np.pi * 0.7 - 2 * np.pi - expected = proposal.Sample( - dict(phase=0.5 * (alpha - beta), psi=0.5 * (alpha + beta)) - ) + expected = proposal.Sample(dict(phase=0.5 * (alpha - beta), psi=0.5 * (alpha + beta))) self.assertEqual(expected, self.jump_proposal(sample)) @@ -92,9 +84,7 @@ def setUp(self): psi=prior.Uniform(minimum=0.0, maximum=np.pi), ) ) - self.jump_proposal = bilby.gw.sampler.proposal.PolarisationPhaseJump( - priors=self.priors - ) + self.jump_proposal = bilby.gw.sampler.proposal.PolarisationPhaseJump(priors=self.priors) def tearDown(self): del self.priors diff --git a/test/gw/source_test.py b/test/gw/source_test.py index c55db9464..d57fac983 100644 --- a/test/gw/source_test.py +++ b/test/gw/source_test.py @@ -1,13 +1,13 @@ -import unittest import logging -import pytest +import unittest +from copy import copy -import bilby import lal import lalsimulation - import numpy as np -from copy import copy +import pytest + +import bilby class TestLalBBH(unittest.TestCase): @@ -48,28 +48,20 @@ def tearDown(self): def test_lal_bbh_works_runs_valid_parameters(self): self.parameters.update(self.waveform_kwargs) self.assertIsInstance( - bilby.gw.source.lal_binary_black_hole( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_binary_black_hole(self.frequency_array, **self.parameters), dict, ) def test_waveform_error_catching(self): self.bad_parameters.update(self.waveform_kwargs) - self.assertIsNone( - bilby.gw.source.lal_binary_black_hole( - self.frequency_array, **self.bad_parameters - ) - ) + self.assertIsNone(bilby.gw.source.lal_binary_black_hole(self.frequency_array, **self.bad_parameters)) def test_waveform_error_raising(self): raise_error_parameters = copy(self.bad_parameters) raise_error_parameters.update(self.waveform_kwargs) raise_error_parameters["catch_waveform_errors"] = False with self.assertRaises(Exception): - bilby.gw.source.lal_binary_black_hole( - self.frequency_array, **raise_error_parameters - ) + bilby.gw.source.lal_binary_black_hole(self.frequency_array, **raise_error_parameters) def test_unused_waveform_kwargs_message(self): self.parameters.update(self.waveform_kwargs) @@ -77,18 +69,14 @@ def test_unused_waveform_kwargs_message(self): bilby.gw.source.logger.propagate = True with self._caplog.at_level(logging.WARNING, logger="bilby"): - bilby.gw.source.lal_binary_black_hole( - self.frequency_array, **self.parameters - ) + bilby.gw.source.lal_binary_black_hole(self.frequency_array, **self.parameters) assert "There are unused waveform kwargs" in self._caplog.text del self.parameters["unused_waveform_parameter"] def test_lal_bbh_works_without_waveform_parameters(self): self.assertIsInstance( - bilby.gw.source.lal_binary_black_hole( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_binary_black_hole(self.frequency_array, **self.parameters), dict, ) @@ -148,28 +136,21 @@ def tearDown(self): def test_gwsignal_bbh_works_runs_valid_parameters(self): self.parameters.update(self.waveform_kwargs) self.assertIsInstance( - bilby.gw.source.gwsignal_binary_black_hole( - self.frequency_array, **self.parameters - ), + bilby.gw.source.gwsignal_binary_black_hole(self.frequency_array, **self.parameters), dict, ) def test_waveform_error_catching(self): self.bad_parameters.update(self.waveform_kwargs) - self.assertIsNone( - bilby.gw.source.gwsignal_binary_black_hole( - self.frequency_array, **self.bad_parameters - ) - ) + self.assertIsNone(bilby.gw.source.gwsignal_binary_black_hole(self.frequency_array, **self.bad_parameters)) def test_waveform_error_raising(self): raise_error_parameters = copy(self.bad_parameters) raise_error_parameters.update(self.waveform_kwargs) raise_error_parameters["catch_waveform_errors"] = False with self.assertRaises(Exception): - bilby.gw.source.gwsignal_binary_black_hole( - self.frequency_array, **raise_error_parameters - ) + bilby.gw.source.gwsignal_binary_black_hole(self.frequency_array, **raise_error_parameters) + # def test_gwsignal_bbh_works_without_waveform_parameters(self): # self.assertIsInstance( # bilby.gw.source.gwsignal_binary_black_hole( @@ -180,18 +161,10 @@ def test_waveform_error_raising(self): def test_gwsignal_lal_bbh_consistency(self): self.parameters.update(self.waveform_kwargs) - hpc_gwsignal = bilby.gw.source.gwsignal_binary_black_hole( - self.frequency_array, **self.parameters - ) - hpc_lal = bilby.gw.source.lal_binary_black_hole( - self.frequency_array, **self.parameters - ) - self.assertTrue( - np.allclose(hpc_gwsignal["plus"], hpc_lal["plus"], atol=0, rtol=1e-7) - ) - self.assertTrue( - np.allclose(hpc_gwsignal["cross"], hpc_lal["cross"], atol=0, rtol=1e-7) - ) + hpc_gwsignal = bilby.gw.source.gwsignal_binary_black_hole(self.frequency_array, **self.parameters) + hpc_lal = bilby.gw.source.lal_binary_black_hole(self.frequency_array, **self.parameters) + self.assertTrue(np.allclose(hpc_gwsignal["plus"], hpc_lal["plus"], atol=0, rtol=1e-7)) + self.assertTrue(np.allclose(hpc_gwsignal["cross"], hpc_lal["cross"], atol=0, rtol=1e-7)) class TestLalBNS(unittest.TestCase): @@ -226,17 +199,13 @@ def tearDown(self): def test_lal_bns_runs_with_valid_parameters(self): self.parameters.update(self.waveform_kwargs) self.assertIsInstance( - bilby.gw.source.lal_binary_neutron_star( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_binary_neutron_star(self.frequency_array, **self.parameters), dict, ) def test_lal_bns_works_without_waveform_parameters(self): self.assertIsInstance( - bilby.gw.source.lal_binary_neutron_star( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_binary_neutron_star(self.frequency_array, **self.parameters), dict, ) @@ -245,9 +214,7 @@ def test_fails_without_tidal_parameters(self): self.parameters.pop("lambda_2") self.parameters.update(self.waveform_kwargs) with self.assertRaises(TypeError): - bilby.gw.source.lal_binary_neutron_star( - self.frequency_array, **self.parameters - ) + bilby.gw.source.lal_binary_neutron_star(self.frequency_array, **self.parameters) class TestEccentricLalBBH(unittest.TestCase): @@ -275,17 +242,13 @@ def tearDown(self): def test_lal_ebbh_works_runs_valid_parameters(self): self.parameters.update(self.waveform_kwargs) self.assertIsInstance( - bilby.gw.source.lal_eccentric_binary_black_hole_no_spins( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_eccentric_binary_black_hole_no_spins(self.frequency_array, **self.parameters), dict, ) def test_lal_ebbh_works_without_waveform_parameters(self): self.assertIsInstance( - bilby.gw.source.lal_eccentric_binary_black_hole_no_spins( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_eccentric_binary_black_hole_no_spins(self.frequency_array, **self.parameters), dict, ) @@ -293,9 +256,7 @@ def test_fails_without_eccentricity(self): self.parameters.pop("eccentricity") self.parameters.update(self.waveform_kwargs) with self.assertRaises(TypeError): - bilby.gw.source.lal_eccentric_binary_black_hole_no_spins( - self.frequency_array, **self.parameters - ) + bilby.gw.source.lal_eccentric_binary_black_hole_no_spins(self.frequency_array, **self.parameters) @pytest.mark.requires_roqs @@ -303,9 +264,9 @@ class TestROQBBH(unittest.TestCase): def setUp(self): roq_dir = "/roq_basis" - fnodes_linear_file = "{}/fnodes_linear.npy".format(roq_dir) + fnodes_linear_file = f"{roq_dir}/fnodes_linear.npy" fnodes_linear = np.load(fnodes_linear_file).T - fnodes_quadratic_file = "{}/fnodes_quadratic.npy".format(roq_dir) + fnodes_quadratic_file = f"{roq_dir}/fnodes_quadratic.npy" fnodes_quadratic = np.load(fnodes_quadratic_file).T self.parameters = dict( @@ -336,9 +297,7 @@ def tearDown(self): def test_roq_runs_valid_parameters(self): self.parameters.update(self.waveform_kwargs) - self.assertIsInstance( - bilby.gw.source.binary_black_hole_roq(self.frequency_array, **self.parameters), dict - ) + self.assertIsInstance(bilby.gw.source.binary_black_hole_roq(self.frequency_array, **self.parameters), dict) def test_roq_fails_without_frequency_nodes(self): self.parameters.update(self.waveform_kwargs) @@ -355,13 +314,13 @@ def setUp(self): mass_2=30.0, luminosity_distance=400.0, a_1=0.4, - tilt_1=0., - phi_12=0., + tilt_1=0.0, + phi_12=0.0, a_2=0.8, - tilt_2=0., - phi_jl=0., + tilt_2=0.0, + phi_jl=0.0, theta_jn=0.3, - phase=0.0 + phase=0.0, ) self.minimum_frequency = 20.0 self.frequency_array = bilby.core.utils.create_frequency_series(2048, 8) @@ -388,7 +347,7 @@ def test_valid_parameters(self): bilby.gw.source.binary_black_hole_frequency_sequence( self.frequency_array, frequencies=self.frequencies, **self.parameters ), - dict + dict, ) def test_waveform_error_catching(self): @@ -418,14 +377,14 @@ def test_match_LalBBH(self): ) self.assertEqual(freqseq.keys(), lalbbh.keys()) for mode in freqseq: - diff = np.sum(np.abs(freqseq[mode] - lalbbh[mode][self.full_frequencies_to_sequence])**2.) - norm = np.sum(np.abs(freqseq[mode])**2.) + diff = np.sum(np.abs(freqseq[mode] - lalbbh[mode][self.full_frequencies_to_sequence]) ** 2.0) + norm = np.sum(np.abs(freqseq[mode]) ** 2.0) self.assertLess(diff / norm, 1e-15) def test_match_LalBBH_specify_modes(self): parameters = copy(self.parameters) parameters.update(self.waveform_kwargs) - parameters['mode_array'] = [[2, 2]] + parameters["mode_array"] = [[2, 2]] freqseq = bilby.gw.source.binary_black_hole_frequency_sequence( self.frequency_array, frequencies=self.frequencies, **parameters ) @@ -434,16 +393,16 @@ def test_match_LalBBH_specify_modes(self): ) self.assertEqual(freqseq.keys(), lalbbh.keys()) for mode in freqseq: - diff = np.sum(np.abs(freqseq[mode] - lalbbh[mode][self.full_frequencies_to_sequence])**2.) - norm = np.sum(np.abs(freqseq[mode])**2.) + diff = np.sum(np.abs(freqseq[mode] - lalbbh[mode][self.full_frequencies_to_sequence]) ** 2.0) + norm = np.sum(np.abs(freqseq[mode]) ** 2.0) self.assertLess(diff / norm, 1e-15) def test_match_LalBBH_nonGR(self): parameters = copy(self.parameters) parameters.update(self.waveform_kwargs) wf_dict = lal.CreateDict() - lalsimulation.SimInspiralWaveformParamsInsertNonGRDChi0(wf_dict, 1.) - parameters['lal_waveform_dictionary'] = wf_dict + lalsimulation.SimInspiralWaveformParamsInsertNonGRDChi0(wf_dict, 1.0) + parameters["lal_waveform_dictionary"] = wf_dict freqseq = bilby.gw.source.binary_black_hole_frequency_sequence( self.frequency_array, frequencies=self.frequencies, **parameters ) @@ -452,8 +411,8 @@ def test_match_LalBBH_nonGR(self): ) self.assertEqual(freqseq.keys(), lalbbh.keys()) for mode in freqseq: - diff = np.sum(np.abs(freqseq[mode] - lalbbh[mode][self.full_frequencies_to_sequence])**2.) - norm = np.sum(np.abs(freqseq[mode])**2.) + diff = np.sum(np.abs(freqseq[mode] - lalbbh[mode][self.full_frequencies_to_sequence]) ** 2.0) + norm = np.sum(np.abs(freqseq[mode]) ** 2.0) self.assertLess(diff / norm, 1e-15) @@ -472,7 +431,7 @@ def setUp(self): theta_jn=1.7, phase=0.0, lambda_1=1000.0, - lambda_2=1000.0 + lambda_2=1000.0, ) self.minimum_frequency = 50.0 self.frequency_array = bilby.core.utils.create_frequency_series(2048, 16) @@ -495,7 +454,7 @@ def test_with_valid_parameters(self): bilby.gw.source.binary_neutron_star_frequency_sequence( self.frequency_array, frequencies=self.frequencies, **self.parameters ), - dict + dict, ) def test_fails_without_tidal_parameters(self): @@ -517,8 +476,8 @@ def test_match_LalBNS(self): ) self.assertEqual(freqseq.keys(), lalbns.keys()) for mode in freqseq: - diff = np.sum(np.abs(freqseq[mode] - lalbns[mode][self.full_frequencies_to_sequence])**2.) - norm = np.sum(np.abs(freqseq[mode])**2.) + diff = np.sum(np.abs(freqseq[mode] - lalbns[mode][self.full_frequencies_to_sequence]) ** 2.0) + norm = np.sum(np.abs(freqseq[mode]) ** 2.0) self.assertLess(diff / norm, 1e-5) @@ -550,7 +509,7 @@ def setUp(self): minimum_frequency=20.0, catch_waveform_errors=True, fiducial=False, - frequency_bin_edges=np.arange(20, 1500, 50) + frequency_bin_edges=np.arange(20, 1500, 50), ) self.frequency_array = bilby.core.utils.create_frequency_series(2048, 4) self.bad_parameters = copy(self.parameters) @@ -566,35 +525,27 @@ def tearDown(self): def test_relbin_fiducial_bbh_works_runs_valid_parameters(self): self.parameters.update(self.waveform_kwargs_fiducial) self.assertIsInstance( - bilby.gw.source.lal_binary_black_hole_relative_binning( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_binary_black_hole_relative_binning(self.frequency_array, **self.parameters), dict, ) def test_relbin_binned_bbh_works_runs_valid_parameters(self): self.parameters.update(self.waveform_kwargs_binned) self.assertIsInstance( - bilby.gw.source.lal_binary_black_hole_relative_binning( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_binary_black_hole_relative_binning(self.frequency_array, **self.parameters), dict, ) def test_waveform_error_catching_fiducial(self): self.bad_parameters.update(self.waveform_kwargs_fiducial) self.assertIsNone( - bilby.gw.source.lal_binary_black_hole_relative_binning( - self.frequency_array, **self.bad_parameters - ) + bilby.gw.source.lal_binary_black_hole_relative_binning(self.frequency_array, **self.bad_parameters) ) def test_waveform_error_catching_binned(self): self.bad_parameters.update(self.waveform_kwargs_binned) self.assertIsNone( - bilby.gw.source.lal_binary_black_hole_relative_binning( - self.frequency_array, **self.bad_parameters - ) + bilby.gw.source.lal_binary_black_hole_relative_binning(self.frequency_array, **self.bad_parameters) ) def test_waveform_error_raising_fiducial(self): @@ -602,18 +553,14 @@ def test_waveform_error_raising_fiducial(self): raise_error_parameters.update(self.waveform_kwargs_fiducial) raise_error_parameters["catch_waveform_errors"] = False with self.assertRaises(Exception): - bilby.gw.source.lal_binary_black_hole_relative_binning( - self.frequency_array, **raise_error_parameters - ) + bilby.gw.source.lal_binary_black_hole_relative_binning(self.frequency_array, **raise_error_parameters) def test_waveform_error_raising_binned(self): raise_error_parameters = copy(self.bad_parameters) raise_error_parameters.update(self.waveform_kwargs_binned) raise_error_parameters["catch_waveform_errors"] = False with self.assertRaises(Exception): - bilby.gw.source.lal_binary_black_hole_relative_binning( - self.frequency_array, **raise_error_parameters - ) + bilby.gw.source.lal_binary_black_hole_relative_binning(self.frequency_array, **raise_error_parameters) def test_relbin_bbh_runs_without_fiducial_option(self): self.assertIsInstance( @@ -680,18 +627,14 @@ def tearDown(self): def test_relbin_fiducial_bns_runs_with_valid_parameters(self): self.parameters.update(self.waveform_kwargs_fiducial) self.assertIsInstance( - bilby.gw.source.lal_binary_neutron_star_relative_binning( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_binary_neutron_star_relative_binning(self.frequency_array, **self.parameters), dict, ) def test_relbin_binned_bns_runs_with_valid_parameters(self): self.parameters.update(self.waveform_kwargs_binned) self.assertIsInstance( - bilby.gw.source.lal_binary_neutron_star_relative_binning( - self.frequency_array, **self.parameters - ), + bilby.gw.source.lal_binary_neutron_star_relative_binning(self.frequency_array, **self.parameters), dict, ) @@ -710,18 +653,14 @@ def test_fiducial_fails_without_tidal_parameters(self): self.parameters.pop("lambda_2") self.parameters.update(self.waveform_kwargs_fiducial) with self.assertRaises(TypeError): - bilby.gw.source.lal_binary_neutron_star_relative_binning( - self.frequency_array, **self.parameters - ) + bilby.gw.source.lal_binary_neutron_star_relative_binning(self.frequency_array, **self.parameters) def test_binned_fails_without_tidal_parameters(self): self.parameters.pop("lambda_1") self.parameters.pop("lambda_2") self.parameters.update(self.waveform_kwargs_binned) with self.assertRaises(TypeError): - bilby.gw.source.lal_binary_neutron_star_relative_binning( - self.frequency_array, **self.parameters - ) + bilby.gw.source.lal_binary_neutron_star_relative_binning(self.frequency_array, **self.parameters) if __name__ == "__main__": diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py index cf78849c7..6610c07ad 100644 --- a/test/gw/utils_test.py +++ b/test/gw/utils_test.py @@ -1,15 +1,15 @@ -import unittest import os -from shutil import rmtree +import unittest from importlib.metadata import version +from shutil import rmtree -import numpy as np import lal import lalsimulation as lalsim -from gwpy.timeseries import TimeSeries +import numpy as np +import pytest from gwpy.detector import Channel +from gwpy.timeseries import TimeSeries from scipy.stats import ks_2samp -import pytest import bilby from bilby.gw import utils as gwutils @@ -30,13 +30,13 @@ def test_asd_from_freq_series(self): freq_data = np.array([1, 2, 3]) df = 0.1 asd = gwutils.asd_from_freq_series(freq_data, df) - self.assertTrue(np.all(asd == freq_data * 2 * df ** 0.5)) + self.assertTrue(np.all(asd == freq_data * 2 * df**0.5)) def test_psd_from_freq_series(self): freq_data = np.array([1, 2, 3]) df = 0.1 psd = gwutils.psd_from_freq_series(freq_data, df) - self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2)) + self.assertTrue(np.all(psd == (freq_data * 2 * df**0.5) ** 2)) def test_inner_product(self): aa = np.array([1, 2, 3]) @@ -69,14 +69,13 @@ def test_matched_filter_snr(self): psd = PSD.power_spectral_density_interpolated(frequency) duration = 4 - mfsnr = gwutils.matched_filter_snr( - signal, frequency_domain_strain, psd, duration - ) + mfsnr = gwutils.matched_filter_snr(signal, frequency_domain_strain, psd, duration) self.assertEqual(mfsnr, 25.510869054168282) @pytest.mark.skip(reason="GWOSC unstable: avoiding this test") def test_get_event_time(self): from urllib3.exceptions import NewConnectionError + events = [ "GW150914", "GW170104", @@ -120,18 +119,14 @@ def test_read_frame_file(self): ts.write(filename, format="gwf") # Check reading without time limits - strain = gwutils.read_frame_file( - filename, start_time=None, end_time=None, channel=channel - ) + strain = gwutils.read_frame_file(filename, start_time=None, end_time=None, channel=channel) self.assertEqual(strain.name, channel) self.assertTrue(np.all(strain.value == data)) # Check reading with time limits start_cut = 2 end_cut = 8 - strain = gwutils.read_frame_file( - filename, start_time=start_cut, end_time=end_cut, channel=channel - ) + strain = gwutils.read_frame_file(filename, start_time=start_cut, end_time=end_cut, channel=channel) idxs = (times >= start_cut) & (times < end_cut) self.assertTrue(np.all(strain.value == data[idxs])) @@ -140,9 +135,7 @@ def test_read_frame_file(self): self.assertTrue(np.all(strain.value == data)) # Check reading with incorrect channel - strain = gwutils.read_frame_file( - filename, start_time=None, end_time=None, channel="WRONG" - ) + strain = gwutils.read_frame_file(filename, start_time=None, end_time=None, channel="WRONG") self.assertTrue(np.all(strain.value == data)) ts = TimeSeries(data=data, times=times, t0=0) @@ -152,9 +145,7 @@ def test_read_frame_file(self): self.assertEqual(strain, None) def test_convert_args_list_to_float(self): - self.assertEqual( - gwutils.convert_args_list_to_float(1, "2", 3.0), [1.0, 2.0, 3.0] - ) + self.assertEqual(gwutils.convert_args_list_to_float(1, "2", 3.0), [1.0, 2.0, 3.0]) with self.assertRaises(ValueError): gwutils.convert_args_list_to_float(1, "2", "ten") @@ -265,7 +256,6 @@ def test_safe_cast_mode_to_int(self): class TestSkyFrameConversion(unittest.TestCase): - def setUp(self) -> None: self.priors = bilby.core.prior.PriorDict() self.priors["ra"] = bilby.core.prior.Uniform(0, 2 * np.pi) @@ -285,10 +275,7 @@ def test_conversion_gives_correct_prior(self) -> None: zeniths = self.samples["zenith"] azimuths = self.samples["azimuth"] times = self.samples["time"] - args = zip(*[ - (zenith, azimuth, time, self.ifos) - for zenith, azimuth, time in zip(zeniths, azimuths, times) - ]) + args = zip(*[(zenith, azimuth, time, self.ifos) for zenith, azimuth, time in zip(zeniths, azimuths, times)]) ras, decs = zip(*map(bilby.gw.utils.zenith_azimuth_to_ra_dec, *args)) self.assertGreaterEqual(ks_2samp(self.samples["ra"], ras).pvalue, 0.01) self.assertGreaterEqual(ks_2samp(self.samples["dec"], decs).pvalue, 0.01) @@ -296,6 +283,7 @@ def test_conversion_gives_correct_prior(self) -> None: def test_ln_i0_mathces_scipy(): from scipy.special import i0 + values = np.linspace(-10, 10, 101) assert max(abs(gwutils.ln_i0(values) - np.log(i0(values)))) < 1e-10 diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py index 1840e521a..89f4c703b 100644 --- a/test/gw/waveform_generator_test.py +++ b/test/gw/waveform_generator_test.py @@ -1,44 +1,25 @@ import unittest from unittest import mock -import bilby import lalsimulation import numpy as np +import bilby + -def dummy_func_array_return_value( - frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs -): +def dummy_func_array_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs): return amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi -def dummy_func_dict_return_value( - frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs -): +def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs): ht = { - "plus": amplitude - + mu - + frequency_array - + sigma - + ra - + dec - + geocent_time - + psi, - "cross": amplitude - + mu - + frequency_array - + sigma - + ra - + dec - + geocent_time - + psi, + "plus": amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi, + "cross": amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi, } return ht -def dummy_func_array_return_value_2( - array, amplitude, mu, sigma, ra, dec, geocent_time, psi -): +def dummy_func_array_return_value_2(array, amplitude, mu, sigma, ra, dec, geocent_time, psi): return dict(plus=np.array(array), cross=np.array(array)) @@ -62,18 +43,19 @@ def tearDown(self): del self.simulation_parameters def test_repr(self): + frequency_domain_model = bilby.core.utils.get_function_path( + self.waveform_generator.frequency_domain_source_model + ) + time_domain_model = bilby.core.utils.get_function_path(self.waveform_generator.time_domain_source_model) + conversion = bilby.core.utils.get_function_path(bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters) expected = ( - "WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, " - "frequency_domain_source_model={}, time_domain_source_model={}, " - "parameter_conversion={}, waveform_arguments={})".format( - self.waveform_generator.duration, - self.waveform_generator.sampling_frequency, - self.waveform_generator.start_time, - bilby.core.utils.get_function_path(self.waveform_generator.frequency_domain_source_model), - bilby.core.utils.get_function_path(self.waveform_generator.time_domain_source_model), - bilby.core.utils.get_function_path(bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters), - self.waveform_generator.waveform_arguments, - ) + f"WaveformGenerator(duration={self.waveform_generator.duration}, " + f"sampling_frequency={self.waveform_generator.sampling_frequency}, " + f"start_time={self.waveform_generator.start_time}, " + f"frequency_domain_source_model={frequency_domain_model}, " + f"time_domain_source_model={time_domain_model}, " + f"parameter_conversion={conversion}, " + f"waveform_arguments={self.waveform_generator.waveform_arguments})" ) self.assertEqual(expected, repr(self.waveform_generator)) @@ -81,18 +63,19 @@ def test_repr_with_time_domain_source_model(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( 1, 4096, time_domain_source_model=dummy_func_dict_return_value ) + frequency_domain_model = bilby.core.utils.get_function_path( + self.waveform_generator.frequency_domain_source_model + ) + time_domain_model = bilby.core.utils.get_function_path(self.waveform_generator.time_domain_source_model) + conversion = bilby.core.utils.get_function_path(bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters) expected = ( - "WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, " - "frequency_domain_source_model={}, time_domain_source_model={}, " - "parameter_conversion={}, waveform_arguments={})".format( - self.waveform_generator.duration, - self.waveform_generator.sampling_frequency, - self.waveform_generator.start_time, - bilby.core.utils.get_function_path(self.waveform_generator.frequency_domain_source_model), - bilby.core.utils.get_function_path(self.waveform_generator.time_domain_source_model), - bilby.core.utils.get_function_path(bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters), - self.waveform_generator.waveform_arguments, - ) + f"WaveformGenerator(duration={self.waveform_generator.duration}, " + f"sampling_frequency={self.waveform_generator.sampling_frequency}, " + f"start_time={self.waveform_generator.start_time}, " + f"frequency_domain_source_model={frequency_domain_model}, " + f"time_domain_source_model={time_domain_model}, " + f"parameter_conversion={conversion}, " + f"waveform_arguments={self.waveform_generator.waveform_arguments})" ) self.assertEqual(expected, repr(self.waveform_generator)) @@ -101,18 +84,19 @@ def conversion_func(): pass self.waveform_generator.parameter_conversion = conversion_func + frequency_domain_model = bilby.core.utils.get_function_path( + self.waveform_generator.frequency_domain_source_model + ) + time_domain_model = bilby.core.utils.get_function_path(self.waveform_generator.time_domain_source_model) + conversion = bilby.core.utils.get_function_path(conversion_func) expected = ( - "WaveformGenerator(duration={}, sampling_frequency={}, start_time={}, " - "frequency_domain_source_model={}, time_domain_source_model={}, " - "parameter_conversion={}, waveform_arguments={})".format( - self.waveform_generator.duration, - self.waveform_generator.sampling_frequency, - self.waveform_generator.start_time, - bilby.core.utils.get_function_path(self.waveform_generator.frequency_domain_source_model), - bilby.core.utils.get_function_path(self.waveform_generator.time_domain_source_model), - bilby.core.utils.get_function_path(conversion_func), - self.waveform_generator.waveform_arguments, - ) + f"WaveformGenerator(duration={self.waveform_generator.duration}, " + f"sampling_frequency={self.waveform_generator.sampling_frequency}, " + f"start_time={self.waveform_generator.start_time}, " + f"frequency_domain_source_model={frequency_domain_model}, " + f"time_domain_source_model={time_domain_model}, " + f"parameter_conversion={conversion}, " + f"waveform_arguments={self.waveform_generator.waveform_arguments})" ) self.assertEqual(expected, repr(self.waveform_generator)) @@ -246,9 +230,7 @@ def tearDown(self): def test_parameter_setter_sets_expected_values_with_expected_keys(self): self.waveform_generator.parameters = self.simulation_parameters.copy() for key in self.simulation_parameters: - self.assertEqual( - self.waveform_generator.parameters[key], self.simulation_parameters[key] - ) + self.assertEqual(self.waveform_generator.parameters[key], self.simulation_parameters[key]) def test_parameter_setter_none_handling(self): with self.assertRaises(TypeError): @@ -259,21 +241,15 @@ def test_parameter_setter_none_handling(self): def test_frequency_array_setter(self): new_frequency_array = np.arange(1, 100) self.waveform_generator.frequency_array = new_frequency_array - self.assertTrue( - np.array_equal(new_frequency_array, self.waveform_generator.frequency_array) - ) + self.assertTrue(np.array_equal(new_frequency_array, self.waveform_generator.frequency_array)) def test_time_array_setter(self): new_time_array = np.arange(1, 100) self.waveform_generator.time_array = new_time_array - self.assertTrue( - np.array_equal(new_time_array, self.waveform_generator.time_array) - ) + self.assertTrue(np.array_equal(new_time_array, self.waveform_generator.time_array)) def test_parameters_set_from_frequency_domain_source_model(self): - self.waveform_generator.frequency_domain_source_model = ( - dummy_func_dict_return_value - ) + self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value self.waveform_generator.parameters = self.simulation_parameters.copy() self.assertListEqual( sorted(list(self.waveform_generator.parameters.keys())), @@ -323,13 +299,9 @@ def tearDown(self): del self.simulation_parameters def test_parameter_conversion_is_called(self): - self.waveform_generator.parameter_conversion = mock.MagicMock( - side_effect=KeyError("test") - ) + self.waveform_generator.parameter_conversion = mock.MagicMock(side_effect=KeyError("test")) with self.assertRaises(KeyError): - self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) + self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) def test_frequency_domain_source_model_call(self): expected = self.waveform_generator.frequency_domain_source_model( @@ -342,9 +314,7 @@ def test_frequency_domain_source_model_call(self): self.simulation_parameters["geocent_time"], self.simulation_parameters["psi"], ) - actual = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) + actual = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) @@ -357,12 +327,8 @@ def side_effect(value, value2): with mock.patch("bilby.core.utils.nfft") as m: m.side_effect = side_effect - expected = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) - actual = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) + expected = self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) + actual = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) self.assertTrue(np.array_equal(expected, actual)) def test_time_domain_source_model_call_with_dict(self): @@ -374,12 +340,8 @@ def side_effect(value, value2): with mock.patch("bilby.core.utils.nfft") as m: m.side_effect = side_effect - expected = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) - actual = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) + expected = self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) + actual = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) @@ -387,9 +349,7 @@ def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None self.waveform_generator.frequency_domain_source_model = None with self.assertRaises(RuntimeError): - self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) + self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) def test_key_popping(self): self.waveform_generator.parameter_conversion = mock.MagicMock( @@ -409,9 +369,7 @@ def test_key_popping(self): ) ) try: - self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) + self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) except RuntimeError: pass self.assertListEqual( @@ -420,93 +378,79 @@ def test_key_popping(self): ) def test_caching_with_parameters(self): - original_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) + original_waveform = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) + new_waveform = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) self.assertDictEqual(original_waveform, new_waveform) def test_caching_without_parameters(self): - original_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) + original_waveform = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) new_waveform = self.waveform_generator.frequency_domain_strain() self.assertDictEqual(original_waveform, new_waveform) def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters( self, ): - self.assertFalse(_test_caching_different_domain( - self.waveform_generator.frequency_domain_strain, - self.waveform_generator.time_domain_strain, - self.simulation_parameters, - None, - )) + self.assertFalse( + _test_caching_different_domain( + self.waveform_generator.frequency_domain_strain, + self.waveform_generator.time_domain_strain, + self.simulation_parameters, + None, + ) + ) def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters( self, ): - self.assertFalse(_test_caching_different_domain( - self.waveform_generator.frequency_domain_strain, - self.waveform_generator.time_domain_strain, - self.simulation_parameters, - self.simulation_parameters, - )) + self.assertFalse( + _test_caching_different_domain( + self.waveform_generator.frequency_domain_strain, + self.waveform_generator.time_domain_strain, + self.simulation_parameters, + self.simulation_parameters, + ) + ) def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters( self, ): - self.assertFalse(_test_caching_different_domain( - self.waveform_generator.time_domain_strain, - self.waveform_generator.frequency_domain_strain, - self.simulation_parameters, - None, - )) + self.assertFalse( + _test_caching_different_domain( + self.waveform_generator.time_domain_strain, + self.waveform_generator.frequency_domain_strain, + self.simulation_parameters, + None, + ) + ) def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters( self, ): - self.assertFalse(_test_caching_different_domain( - self.waveform_generator.time_domain_strain, - self.waveform_generator.frequency_domain_strain, - self.simulation_parameters, - self.simulation_parameters, - )) - - def test_frequency_domain_caching_changing_model(self): - original_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - self.waveform_generator.frequency_domain_source_model = ( - dummy_func_array_return_value_2 - ) - new_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) self.assertFalse( - np.array_equal(original_waveform["plus"], new_waveform["plus"]) + _test_caching_different_domain( + self.waveform_generator.time_domain_strain, + self.waveform_generator.frequency_domain_strain, + self.simulation_parameters, + self.simulation_parameters, + ) ) + def test_frequency_domain_caching_changing_model(self): + original_waveform = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) + self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value_2 + new_waveform = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) + self.assertFalse(np.array_equal(original_waveform["plus"], new_waveform["plus"])) + def test_time_domain_caching_changing_model(self): self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( duration=1, sampling_frequency=4096, time_domain_source_model=dummy_func_dict_return_value, ) - original_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - self.waveform_generator.time_domain_source_model = ( - dummy_func_array_return_value_2 - ) - new_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - self.assertFalse( - np.array_equal(original_waveform["plus"], new_waveform["plus"]) - ) + original_waveform = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) + self.waveform_generator.time_domain_source_model = dummy_func_array_return_value_2 + new_waveform = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) + self.assertFalse(np.array_equal(original_waveform["plus"], new_waveform["plus"])) class TestTimeDomainStrainMethod(unittest.TestCase): @@ -529,13 +473,9 @@ def tearDown(self): del self.simulation_parameters def test_parameter_conversion_is_called(self): - self.waveform_generator.parameter_conversion = mock.MagicMock( - side_effect=KeyError("test") - ) + self.waveform_generator.parameter_conversion = mock.MagicMock(side_effect=KeyError("test")) with self.assertRaises(KeyError): - self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) + self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) def test_time_domain_source_model_call(self): expected = self.waveform_generator.time_domain_source_model( @@ -548,48 +488,34 @@ def test_time_domain_source_model_call(self): self.simulation_parameters["geocent_time"], self.simulation_parameters["psi"], ) - actual = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) + actual = self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) def test_frequency_domain_source_model_call_with_ndarray(self): self.waveform_generator.time_domain_source_model = None - self.waveform_generator.frequency_domain_source_model = ( - dummy_func_array_return_value - ) + self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value def side_effect(value, value2): return value with mock.patch("bilby.core.utils.infft") as m: m.side_effect = side_effect - expected = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - actual = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) + expected = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) + actual = self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) self.assertTrue(np.array_equal(expected, actual)) def test_frequency_domain_source_model_call_with_dict(self): self.waveform_generator.time_domain_source_model = None - self.waveform_generator.frequency_domain_source_model = ( - dummy_func_dict_return_value - ) + self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value def side_effect(value, value2): return value with mock.patch("bilby.core.utils.infft") as m: m.side_effect = side_effect - expected = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - actual = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) + expected = self.waveform_generator.frequency_domain_strain(parameters=self.simulation_parameters) + actual = self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) self.assertTrue(np.array_equal(expected["plus"], actual["plus"])) self.assertTrue(np.array_equal(expected["cross"], actual["cross"])) @@ -597,9 +523,7 @@ def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None self.waveform_generator.frequency_domain_source_model = None with self.assertRaises(RuntimeError): - self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) + self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) def test_key_popping(self): self.waveform_generator.parameter_conversion = mock.MagicMock( @@ -619,9 +543,7 @@ def test_key_popping(self): ) ) try: - self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) + self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) except RuntimeError: pass self.assertListEqual( @@ -630,60 +552,62 @@ def test_key_popping(self): ) def test_caching_with_parameters(self): - original_waveform = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) + original_waveform = self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) + new_waveform = self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) self.assertDictEqual(original_waveform, new_waveform) def test_caching_without_parameters(self): - original_waveform = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) + original_waveform = self.waveform_generator.time_domain_strain(parameters=self.simulation_parameters) new_waveform = self.waveform_generator.time_domain_strain() self.assertDictEqual(original_waveform, new_waveform) def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters( self, ): - self.assertFalse(_test_caching_different_domain( - self.waveform_generator.frequency_domain_strain, - self.waveform_generator.time_domain_strain, - self.simulation_parameters, - None, - )) + self.assertFalse( + _test_caching_different_domain( + self.waveform_generator.frequency_domain_strain, + self.waveform_generator.time_domain_strain, + self.simulation_parameters, + None, + ) + ) def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters( self, ): - self.assertFalse(_test_caching_different_domain( - self.waveform_generator.frequency_domain_strain, - self.waveform_generator.time_domain_strain, - self.simulation_parameters, - self.simulation_parameters, - )) + self.assertFalse( + _test_caching_different_domain( + self.waveform_generator.frequency_domain_strain, + self.waveform_generator.time_domain_strain, + self.simulation_parameters, + self.simulation_parameters, + ) + ) def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters( self, ): - self.assertFalse(_test_caching_different_domain( - self.waveform_generator.time_domain_strain, - self.waveform_generator.frequency_domain_strain, - self.simulation_parameters, - None, - )) + self.assertFalse( + _test_caching_different_domain( + self.waveform_generator.time_domain_strain, + self.waveform_generator.frequency_domain_strain, + self.simulation_parameters, + None, + ) + ) def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters( self, ): - self.assertFalse(_test_caching_different_domain( - self.waveform_generator.time_domain_strain, - self.waveform_generator.frequency_domain_strain, - self.simulation_parameters, - self.simulation_parameters, - )) + self.assertFalse( + _test_caching_different_domain( + self.waveform_generator.time_domain_strain, + self.waveform_generator.frequency_domain_strain, + self.simulation_parameters, + self.simulation_parameters, + ) + ) def _test_caching_different_domain(func1, func2, params1, params2): @@ -696,7 +620,6 @@ def _test_caching_different_domain(func1, func2, params1, params2): class TestGWSignalGenerator(unittest.TestCase): - def get_wfgen(self, **kwargs): default_kwargs = dict( duration=4, @@ -769,7 +692,9 @@ def test_default_parameters_are_ignored(self): def test_eccentric_parameters_work(self): wfg = self.get_wfgen( - eccentric=True, spinning=False, waveform_approximant="EccentricFD", + eccentric=True, + spinning=False, + waveform_approximant="EccentricFD", ) parameters_1 = dict( mass_1=10, @@ -795,7 +720,9 @@ def test_eccentric_parameters_work(self): def test_tidal_parameters_work(self): wfg = self.get_wfgen( - tidal=True, spinning=False, waveform_approximant="IMRPhenomD_NRTidalv2", + tidal=True, + spinning=False, + waveform_approximant="IMRPhenomD_NRTidalv2", sampling_frequency=16384, ) parameters_1 = dict( diff --git a/test/hyper/hyper_pe_test.py b/test/hyper/hyper_pe_test.py index 4ca58927d..9f2e02ade 100644 --- a/test/hyper/hyper_pe_test.py +++ b/test/hyper/hyper_pe_test.py @@ -1,4 +1,5 @@ import unittest + import numpy as np import pandas as pd from parameterized import parameterized @@ -16,7 +17,6 @@ def __call__(self, a, b, c): class _ToyClassVariableNames: - variable_names = ["a", "b", "c"] def __call__(self, **kwargs): @@ -29,9 +29,7 @@ def setUp(self): self.lengths = [300, 400, 500] self.posteriors = list() for ii, length in enumerate(self.lengths): - self.posteriors.append( - pd.DataFrame({key: np.random.normal(0, 1, length) for key in self.keys}) - ) + self.posteriors.append(pd.DataFrame({key: np.random.normal(0, 1, length) for key in self.keys})) self.log_evidences = [2, 2, 2] self.model = hyp.model.Model(list()) self.sampling_model = hyp.model.Model(list()) @@ -52,16 +50,16 @@ def test_evidence_factor_with_evidences(self): self.assertEqual(like.evidence_factor, 6) def test_evidence_factor_without_evidences(self): - like = hyp.likelihood.HyperparameterLikelihood( - self.posteriors, self.model, self.sampling_model - ) + like = hyp.likelihood.HyperparameterLikelihood(self.posteriors, self.model, self.sampling_model) self.assertTrue(np.isnan(like.evidence_factor)) - @parameterized.expand([ - ("func", _toy_function), - ("class_no_names", _ToyClassNoVariableNames()), - ("class_with_names", _ToyClassVariableNames()), - ]) + @parameterized.expand( + [ + ("func", _toy_function), + ("class_no_names", _ToyClassNoVariableNames()), + ("class_with_names", _ToyClassVariableNames()), + ] + ) def test_get_function_parameters(self, _, model): expected = dict(a=1, b=2, c=3) model = hyp.model.Model([model]) diff --git a/test/import_test.py b/test/import_test.py index 111dbc322..6d3328022 100644 --- a/test/import_test.py +++ b/test/import_test.py @@ -10,18 +10,23 @@ unique_packages = set(sys.modules) unwanted = { - "lal", "lalsimulation", "matplotlib", - "h5py", "dill", "tqdm", "tables", "deepdish", "corner", + "lal", + "lalsimulation", + "matplotlib", + "h5py", + "dill", + "tqdm", + "tables", + "deepdish", + "corner", } for filename in ["sampler_requirements.txt", "optional_requirements.txt"]: - with open(filename, "r") as ff: + with open(filename) as ff: packages = ff.readlines() for package in packages: package = package.split(">")[0].split("<")[0].split("=")[0].strip() unwanted.add(package) if not unique_packages.isdisjoint(unwanted): - raise ImportError( - f"{' '.join(unique_packages.intersection(unwanted))} imported with Bilby" - ) + raise ImportError(f"{' '.join(unique_packages.intersection(unwanted))} imported with Bilby") diff --git a/test/integration/example_test.py b/test/integration/example_test.py index 266001fc1..3be1c7b58 100644 --- a/test/integration/example_test.py +++ b/test/integration/example_test.py @@ -4,12 +4,13 @@ import glob import importlib.util -import unittest +import logging import os +import shutil +import unittest + import parameterized import pytest -import shutil -import logging import bilby.core.utils @@ -48,7 +49,7 @@ def tearDown(self): try: shutil.rmtree(self.outdir) except OSError: - logging.warning("{} not removed after tests".format(self.outdir)) + logging.warning(f"{self.outdir} not removed after tests") os.chdir(self.init_dir) @classmethod @@ -57,7 +58,7 @@ def setUpClass(cls): try: shutil.rmtree(cls.outdir) except OSError: - logging.warning("{} not removed prior to tests".format(cls.outdir)) + logging.warning(f"{cls.outdir} not removed prior to tests") @classmethod def tearDownClass(cls): @@ -65,11 +66,11 @@ def tearDownClass(cls): try: shutil.rmtree(cls.outdir) except OSError: - logging.warning("{} not removed prior to tests".format(cls.outdir)) + logging.warning(f"{cls.outdir} not removed prior to tests") @parameterized.parameterized.expand(core_args) def test_core_examples(self, name, fname): - """ Loop over examples to check they run """ + """Loop over examples to check they run""" bilby.core.utils.command_line_args.bilby_test_mode = False ignore = ["15d_gaussian"] if any([item in fname for item in ignore]): @@ -78,7 +79,7 @@ def test_core_examples(self, name, fname): @parameterized.parameterized.expand(gw_args) def test_gw_examples(self, name, fname): - """ Loop over examples to check they run """ + """Loop over examples to check they run""" bilby.core.utils.command_line_args.bilby_test_mode = True _execute_file(name, fname) diff --git a/test/integration/make_standard_data.py b/test/integration/make_standard_data.py index 4f5d4073d..f60f630f5 100644 --- a/test/integration/make_standard_data.py +++ b/test/integration/make_standard_data.py @@ -49,17 +49,13 @@ ) hf_signal_and_noise = IFO.strain_data.frequency_domain_strain -frequencies = bilby.core.utils.create_frequency_series( - sampling_frequency=sampling_frequency, duration=time_duration -) +frequencies = bilby.core.utils.create_frequency_series(sampling_frequency=sampling_frequency, duration=time_duration) if __name__ == "__main__": dir_path = os.path.dirname(os.path.realpath(__file__)) with open(dir_path + "/standard_data.txt", "w+") as f: np.savetxt( f, - np.column_stack( - [frequencies, hf_signal_and_noise.view(float).reshape(-1, 2)] - ), + np.column_stack([frequencies, hf_signal_and_noise.view(float).reshape(-1, 2)]), header="frequency hf_real hf_imag", ) diff --git a/test/integration/noise_realisation_test.py b/test/integration/noise_realisation_test.py index 6bada5f58..05649d93e 100644 --- a/test/integration/noise_realisation_test.py +++ b/test/integration/noise_realisation_test.py @@ -1,5 +1,7 @@ -import numpy as np import unittest + +import numpy as np + import bilby @@ -20,28 +22,20 @@ def test_averaged_noise(self): psd_avg = psd_avg / n_avg asd_avg = np.sqrt(abs(psd_avg)) * interferometer.frequency_mask - a = np.nan_to_num( - interferometer.amplitude_spectral_density_array - / factor - * interferometer.frequency_mask - ) + a = np.nan_to_num(interferometer.amplitude_spectral_density_array / factor * interferometer.frequency_mask) b = asd_avg self.assertTrue(np.allclose(a, b, rtol=1e-1)) def test_noise_normalisation(self): duration = 1.0 sampling_frequency = 4096.0 - time_array = bilby.core.utils.create_time_series( - sampling_frequency=sampling_frequency, duration=duration - ) + time_array = bilby.core.utils.create_time_series(sampling_frequency=sampling_frequency, duration=duration) interferometer = bilby.gw.detector.get_empty_interferometer("H1") # generate some toy-model signal for matched filtering SNR testing n_avg = 1000 snr = np.zeros(n_avg) - mu = np.exp(-((time_array - duration / 2.0) ** 2) / (2.0 * 0.1 ** 2)) * np.sin( - 2 * np.pi * 100 * time_array - ) + mu = np.exp(-((time_array - duration / 2.0) ** 2) / (2.0 * 0.1**2)) * np.sin(2 * np.pi * 100 * time_array) muf, frequency_array = bilby.core.utils.nfft(mu, sampling_frequency) for x in range(0, n_avg): interferometer.set_strain_data_from_power_spectral_density( diff --git a/test/integration/other_test.py b/test/integration/other_test.py index 7f49825f9..b74444230 100644 --- a/test/integration/other_test.py +++ b/test/integration/other_test.py @@ -36,12 +36,10 @@ def tearDownClass(self): """ def test_make_standard_data(self): - " Load in the saved standard data and compare with new data " + "Load in the saved standard data and compare with new data" # Load in the saved standard data - frequencies_saved, hf_real_saved, hf_imag_saved = np.loadtxt( - self.dir_path + "/integration/standard_data.txt" - ).T + frequencies_saved, hf_real_saved, hf_imag_saved = np.loadtxt(self.dir_path + "/integration/standard_data.txt").T hf_signal_and_noise_saved = hf_real_saved + 1j * hf_imag_saved self.assertTrue(np.array_equal(self.msd["frequencies"], frequencies_saved)) @@ -52,9 +50,7 @@ def test_make_standard_data(self): ) def test_recover_luminosity_distance(self): - likelihood = bilby.gw.likelihood.GravitationalWaveTransient( - [self.msd["IFO"]], self.msd["waveform_generator"] - ) + likelihood = bilby.gw.likelihood.GravitationalWaveTransient([self.msd["IFO"]], self.msd["waveform_generator"]) priors = {} for key in self.msd["simulation_parameters"]: @@ -65,9 +61,7 @@ def test_recover_luminosity_distance(self): name="luminosity_distance", minimum=dL - 10, maximum=dL + 10 ) - result = bilby.core.sampler.run_sampler( - likelihood, priors, sampler="dynesty", verbose=False, npoints=100 - ) + result = bilby.core.sampler.run_sampler(likelihood, priors, sampler="dynesty", verbose=False, npoints=100) self.assertAlmostEqual( np.mean(result.posterior.luminosity_distance), dL, diff --git a/test/integration/sample_from_the_prior_test.py b/test/integration/sample_from_the_prior_test.py index 7fce8b162..9367d7847 100644 --- a/test/integration/sample_from_the_prior_test.py +++ b/test/integration/sample_from_the_prior_test.py @@ -1,13 +1,14 @@ -import shutil -import os import logging -from packaging import version - +import os +import shutil import unittest -import bilby + import scipy +from packaging import version from scipy.stats import ks_2samp, kstest +import bilby + def ks_2samp_wrapper(data1, data2): if version.parse(scipy.__version__) >= version.parse("1.3.0"): @@ -25,7 +26,7 @@ def setUpClass(self): try: shutil.rmtree(self.outdir) except OSError: - logging.warning("{} not removed prior to tests".format(self.outdir)) + logging.warning(f"{self.outdir} not removed prior to tests") @classmethod def tearDownClass(self): @@ -33,7 +34,7 @@ def tearDownClass(self): try: shutil.rmtree(self.outdir) except OSError: - logging.warning("{} not removed prior to tests".format(self.outdir)) + logging.warning(f"{self.outdir} not removed prior to tests") def test_fifteen_dimensional_cbc(self): duration = 4.0 @@ -69,9 +70,7 @@ def test_fifteen_dimensional_cbc(self): maximum=100.0, unit="$M_{\\odot}$", ) - priors["mass_ratio"] = bilby.prior.Uniform( - name="mass_ratio", latex_label="$q$", minimum=0.5, maximum=1.0 - ) + priors["mass_ratio"] = bilby.prior.Uniform(name="mass_ratio", latex_label="$q$", minimum=0.5, maximum=1.0) priors["geocent_time"] = bilby.core.prior.Uniform(minimum=-0.1, maximum=0.1) likelihood = bilby.gw.GravitationalWaveTransient( @@ -93,12 +92,10 @@ def test_fifteen_dimensional_cbc(self): nact=10, outdir=self.outdir, label=label, - save=False + save=False, ) pvalues = [ - ks_2samp_wrapper( - result.priors[key].sample(10000), result.posterior[key].values - ).pvalue + ks_2samp_wrapper(result.priors[key].sample(10000), result.posterior[key].values).pvalue for key in priors.keys() ] print("P values per parameter") diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py index f35cf07c6..eba315c5c 100644 --- a/test/integration/sampler_run_test.py +++ b/test/integration/sampler_run_test.py @@ -1,3 +1,5 @@ +# ruff: noqa: E402 + import multiprocessing import os import sys @@ -7,14 +9,14 @@ multiprocessing.set_start_method("fork") # noqa +import shutil import unittest + +import numpy as np import pytest from parameterized import parameterized -import shutil import bilby -import numpy as np - _sampler_kwargs = dict( bilby_mcmc=dict(nsamples=200, printdt=1), @@ -47,13 +49,10 @@ pymc=dict(draws=50, tune=50, n_init=250), pymultinest=dict(nlive=100), ultranest=dict(nlive=100, temporary_directory=False), - zeus=dict(nwalkers=10, iterations=100) + zeus=dict(nwalkers=10, iterations=100), ) -sampler_imports = dict( - bilby_mcmc="bilby", - dynamic_dynesty="dynesty" -) +sampler_imports = dict(bilby_mcmc="bilby", dynamic_dynesty="dynesty") no_pool_test = ["pymultinest", "nestle", "ptmcmcsampler", "ultranest", "pymc"] @@ -77,12 +76,8 @@ def setUp(self): self.x = np.linspace(0, 1, 11) self.injection_parameters = dict(m=0.5, c=0.2) self.sigma = 0.1 - self.y = model(self.x, **self.injection_parameters) + rng.normal( - 0, self.sigma, len(self.x) - ) - self.likelihood = bilby.likelihood.GaussianLikelihood( - self.x, self.y, model, self.sigma - ) + self.y = model(self.x, **self.injection_parameters) + rng.normal(0, self.sigma, len(self.x)) + self.likelihood = bilby.likelihood.GaussianLikelihood(self.x, self.y, model, self.sigma) self.priors = bilby.core.prior.PriorDict() self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic") diff --git a/test/integration/test_waveforms.py b/test/integration/test_waveforms.py index c564c1052..331af628a 100644 --- a/test/integration/test_waveforms.py +++ b/test/integration/test_waveforms.py @@ -1,9 +1,11 @@ import unittest -import bilby -import numpy as np -from bilby.gw.utils import overlap + import lal import lalsimulation as lalsim +import numpy as np + +import bilby +from bilby.gw.utils import overlap class TestWaveformDirectAgainstLALSIM(unittest.TestCase): @@ -67,7 +69,6 @@ def test_TaylorF2(self): self.run_for_approximant(waveform_approximant, source="bns") def run_for_approximant(self, waveform_approximant, source): - if source == "bbh": injection_parameters = self.BBH_precessing_injection_parameters frequency_domain_source_model = bilby.gw.source.lal_binary_black_hole @@ -75,7 +76,7 @@ def run_for_approximant(self, waveform_approximant, source): injection_parameters = self.BNS_precessing_injection_parameters frequency_domain_source_model = bilby.gw.source.lal_binary_neutron_star else: - raise ValueError("Source can only be 'bbh' or 'bns', but was '{}'".format(source)) + raise ValueError(f"Source can only be 'bbh' or 'bns', but was '{source}'") # create a waveform generator for bilby duration = 4.0 @@ -119,9 +120,7 @@ def run_for_approximant(self, waveform_approximant, source): waveform_arguments=waveform_arguments, ) - bilby_strain = waveform_generator.frequency_domain_strain( - parameters=injection_parameters - ) + bilby_strain = waveform_generator.frequency_domain_strain(parameters=injection_parameters) # LALSIM Waveform @@ -144,16 +143,14 @@ def run_for_approximant(self, waveform_approximant, source): (waveform_generator.frequency_array)[-1], lambda_1, lambda_2, - **waveform_arguments + **waveform_arguments, ) h_plus = get_lalsim_waveform["plus"] h_cross = get_lalsim_waveform["cross"] if waveform_approximant == "TaylorF2": - upper_freq = ISCO( - injection_parameters["mass_1"], injection_parameters["mass_2"] - ) + upper_freq = ISCO(injection_parameters["mass_1"], injection_parameters["mass_2"]) else: upper_freq = waveform_generator.frequency_array[-1] @@ -164,9 +161,7 @@ def run_for_approximant(self, waveform_approximant, source): f_len = int((2 * sampling_frequency) / delta_f) # PSD aLIGO - psd_aLIGO = generate_PSD( - psd_name="aLIGOZeroDetHighPower", length=f_len, delta_f=delta_f - ) + psd_aLIGO = generate_PSD(psd_name="aLIGOZeroDetHighPower", length=f_len, delta_f=delta_f) norm_hp_bilby = normalize_strain( bilby_strain["plus"], @@ -232,22 +227,7 @@ def ISCO(m1, m2): def lalsim_FD_waveform( - m1, - m2, - s1x, - s1y, - s1z, - s2x, - s2y, - s2z, - theta_jn, - phase, - duration, - dL, - fmax, - lambda_1=None, - lambda_2=None, - **kwarg + m1, m2, s1x, s1y, s1z, s2x, s2y, s2z, theta_jn, phase, duration, dL, fmax, lambda_1=None, lambda_2=None, **kwarg ): mass1 = m1 * lal.MSUN_SI mass2 = m2 * lal.MSUN_SI @@ -275,13 +255,9 @@ def lalsim_FD_waveform( waveform_dictionary = lal.CreateDict() if lambda_1 is not None: - lalsim.SimInspiralWaveformParamsInsertTidalLambda1( - waveform_dictionary, float(lambda_1) - ) + lalsim.SimInspiralWaveformParamsInsertTidalLambda1(waveform_dictionary, float(lambda_1)) if lambda_2 is not None: - lalsim.SimInspiralWaveformParamsInsertTidalLambda2( - waveform_dictionary, float(lambda_2) - ) + lalsim.SimInspiralWaveformParamsInsertTidalLambda2(waveform_dictionary, float(lambda_2)) hplus, hcross = lalsim.SimInspiralChooseFDWaveform( mass1, @@ -331,18 +307,10 @@ def get_lalsim_psd_list(): psd_list = [] # Avoid the string 'SimNoisePSD' for name in lalsim.__dict__: - if ( - name != PSD_prefix - and name.startswith(PSD_prefix) - and not name.endswith(PSD_suffix) - ): + if name != PSD_prefix and name.startswith(PSD_prefix) and not name.endswith(PSD_suffix): # if name in blacklist: name = name[len(PSD_prefix) :] - if ( - name not in blacklist - and not name.startswith("iLIGO") - and not name.startswith("eLIGO") - ): + if name not in blacklist and not name.startswith("iLIGO") and not name.startswith("eLIGO"): psd_list.append(name) return sorted(psd_list) @@ -356,18 +324,14 @@ def generate_PSD(psd_name="aLIGOZeroDetHighPower", length=None, delta_f=None): # Function for PSD func = lalsim.__dict__["SimNoisePSD" + psd_name + "Ptr"] # Generate a lal frequency series - PSDseries = lal.CreateREAL8FrequencySeries( - "", lal.LIGOTimeGPS(0), 0, delta_f, lal.DimensionlessUnit, length - ) + PSDseries = lal.CreateREAL8FrequencySeries("", lal.LIGOTimeGPS(0), 0, delta_f, lal.DimensionlessUnit, length) # func(PSDseries) lalsim.SimNoisePSD(PSDseries, 0, func) return PSDseries # Normalizing a waveform -def normalize_strain( - signal, psd=None, delta_f=None, lower_cut_off=None, upper_cut_off=None -): +def normalize_strain(signal, psd=None, delta_f=None, lower_cut_off=None, upper_cut_off=None): low_index = int(lower_cut_off / delta_f) up_index = int(upper_cut_off / delta_f) integrand = np.conj(signal) * signal diff --git a/test/test_samplers_import.py b/test/test_samplers_import.py index d3d12f851..7be477d80 100644 --- a/test/test_samplers_import.py +++ b/test/test_samplers_import.py @@ -1,10 +1,9 @@ -import bilby import pytest +import bilby + -@pytest.mark.parametrize( - "sampler_name", bilby.core.sampler.IMPLEMENTED_SAMPLERS.keys() -) +@pytest.mark.parametrize("sampler_name", bilby.core.sampler.IMPLEMENTED_SAMPLERS.keys()) def test_sampler_import(sampler_name): """ Tests that all of the implemented samplers can be initialized.