Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions example/workbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
UniformSpherePrior,
)
from jimgw.core.single_event.detector import get_H1, get_L1, get_V1
from jimgw.core.single_event.likelihood import TransientLikelihoodFD
from jimgw.core.single_event.likelihood import BaseTransientLikelihoodFD
from jimgw.core.single_event.data import Data
from jimgw.core.single_event.waveform import RippleIMRPhenomPv2
from jimgw.core.transforms import BoundToUnbound
Expand Down Expand Up @@ -131,7 +131,9 @@
gps_time=gps, ifos=ifos, dL_min=dL_prior.xmin, dL_max=dL_prior.xmax
),
GeocentricArrivalPhaseToDetectorArrivalPhaseTransform(gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]),
GeocentricArrivalTimeToDetectorArrivalTimeTransform(
tc_min=t_c_prior.xmin, tc_max=t_c_prior.xmax, gps_time=gps, ifo=ifos[0]
),
SkyFrameToDetectorFrameSkyPositionTransform(gps_time=gps, ifos=ifos),
BoundToUnbound(
name_mapping=(["M_c"], ["M_c_unbounded"]),
Expand Down Expand Up @@ -207,13 +209,12 @@
]


likelihood = TransientLikelihoodFD(
likelihood = BaseTransientLikelihoodFD(
ifos,
waveform=waveform,
trigger_time=gps,
f_min=fmin,
f_max=fmax,
# marginalization="time",
)

jim = Jim(
Expand Down
176 changes: 148 additions & 28 deletions jim_dagster/InjectionRecovery/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# Sample a fiducial population


@dg.asset(
group_name="prerun",
key_prefix="InjectionRecovery",
Expand All @@ -32,10 +33,12 @@ def sample_population():
path_prefix="./data/",
)


# TODO: Add diagnostics regarding the sampled population.

# Create asset group for run and configuration


@dg.asset(
group_name="prerun",
description="Configuration file for the run.",
Expand Down Expand Up @@ -76,10 +79,17 @@ def config_file():
run.local_data_prefix = f"./data/runs/{idx}/strains/"
run.serialize(f"./data/runs/{idx}/config.yaml")


@dg.multi_asset(
specs=[
dg.AssetSpec(key=["InjectionRecovery", "strain"], deps=[["InjectionRecovery", "config_file"]]),
dg.AssetSpec(key=["InjectionRecovery", "psd"], deps=[["InjectionRecovery", "config_file"]]),
dg.AssetSpec(
key=["InjectionRecovery", "strain"],
deps=[["InjectionRecovery", "config_file"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "psd"],
deps=[["InjectionRecovery", "config_file"]],
),
],
group_name="prerun",
)
Expand Down Expand Up @@ -122,19 +132,53 @@ def raw_data():
detector.data.to_file(f"./data/runs/{idx}/strains/{ifo}_data")
detector.psd.to_file(f"./data/runs/{idx}/strains/{ifo}_psd")


@dg.multi_asset(
specs=[
dg.AssetSpec(key=["InjectionRecovery", "training_chains"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "training_log_prob"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "training_local_acceptance"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "training_global_acceptance"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "training_loss"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "production_chains"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "production_log_prob"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "production_local_acceptance"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "production_global_acceptance"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "auxiliary_nf_samples"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(key=["InjectionRecovery", "auxiliary_prior_samples"], deps=[["InjectionRecovery", "raw_data"]]),
dg.AssetSpec(
key=["InjectionRecovery", "training_chains"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "training_log_prob"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "training_local_acceptance"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "training_global_acceptance"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "training_loss"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "production_chains"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "production_log_prob"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "production_local_acceptance"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "production_global_acceptance"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "auxiliary_nf_samples"],
deps=[["InjectionRecovery", "raw_data"]],
),
dg.AssetSpec(
key=["InjectionRecovery", "auxiliary_prior_samples"],
deps=[["InjectionRecovery", "raw_data"]],
),
],
group_name="run",
)
Expand All @@ -145,64 +189,140 @@ def run():
"""
pass


# Create asset group for diagnostics

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_loss"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "training_loss"]],
key_prefix="InjectionRecovery",
)
def loss_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_chains"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "training_chains"]],
key_prefix="InjectionRecovery",
)
def training_chains_corner_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_chains"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "training_chains"]],
key_prefix="InjectionRecovery",
)
def training_chains_trace_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_chains"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "training_chains"]],
key_prefix="InjectionRecovery",
)
def training_chains_rhat_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_log_prob"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "training_log_prob"]],
key_prefix="InjectionRecovery",
)
def training_log_prob_distribution():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_log_prob"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "training_log_prob"]],
key_prefix="InjectionRecovery",
)
def training_log_prob_evolution():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_local_acceptance"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "training_local_acceptance"]],
key_prefix="InjectionRecovery",
)
def training_local_acceptance_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "training_global_acceptance"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "training_global_acceptance"]],
key_prefix="InjectionRecovery",
)
def training_global_acceptance_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_chains"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "production_chains"]],
key_prefix="InjectionRecovery",
)
def production_chains_corner_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_chains"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "production_chains"]],
key_prefix="InjectionRecovery",
)
def production_chains_trace_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_chains"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "production_chains"]],
key_prefix="InjectionRecovery",
)
def production_chains_rhat_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_log_prob"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "production_log_prob"]],
key_prefix="InjectionRecovery",
)
def production_log_prob_distribution():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_log_prob"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "production_log_prob"]],
key_prefix="InjectionRecovery",
)
def production_log_prob_evolution():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_local_acceptance"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "production_local_acceptance"]],
key_prefix="InjectionRecovery",
)
def production_local_acceptance_plot():
pass

@dg.asset(group_name="diagnostics", deps=[["InjectionRecovery", "production_global_acceptance"]], key_prefix="InjectionRecovery")

@dg.asset(
group_name="diagnostics",
deps=[["InjectionRecovery", "production_global_acceptance"]],
key_prefix="InjectionRecovery",
)
def production_global_acceptance_plot():
pass
Loading
Loading