diff --git a/pyproject.toml b/pyproject.toml index 0ecc7e2..df3de5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,11 +24,23 @@ dependencies = [ # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes) [project.optional-dependencies] full = [ - "pz-rail[algos]", + "pz-rail-astro-tools", + "pz-rail-bpz", + "pz-rail-cmnn", + "pz-rail-dnf", + "pz-rail-dsps", + "pz-rail-flexzboost", + "pz-rail-fsps", + "pz-rail-gpz-v1", + "pz-rail-pzflow", + "pz-rail-sklearn", + "pz-rail-yaw", + "pz-rail-lephare", + "qp-prob[full]>=1.0.0", ] dev = [ - "pz-rail[algos]", + "pz-rail-pipelines[full]", "coverage", "pytest", "pytest-cov", # Used to report total code coverage diff --git a/src/rail/pipelines/estimation/pz_all.py b/src/rail/pipelines/estimation/pz_all.py index e0b2052..64ac89e 100644 --- a/src/rail/pipelines/estimation/pz_all.py +++ b/src/rail/pipelines/estimation/pz_all.py @@ -5,7 +5,7 @@ # Various rail modules from rail.core.stage import RailStage, RailPipeline -from rail.utils.catalog_utils import CatalogConfigBase +from rail.utils import catalog_utils from rail.evaluation.single_evaluator import SingleEvaluator from rail.utils.algo_library import PZ_ALGORITHMS @@ -27,14 +27,14 @@ def __init__(self, algorithms: dict|None=None): DS = RailStage.data_store DS.__class__.allow_overwrite = True - active_catalog = CatalogConfigBase.active_class() + active_catalog_config = catalog_utils.get_active_tag() eval_shared_stage_opts = dict( metrics=['all'], exclude_metrics=['rmse', 'ks', 'kld', 'cvm', 'ad', 'rbpe', 'outlier'], hdf5_groupname="", limits=[0, 3.5], - truth_point_estimates=[active_catalog.redshift_col], + truth_point_estimates=[active_catalog_config.config['redshift_col']], point_estimates=['zmode'], ) diff --git a/src/rail/pipelines/utils/prepare_observed.py b/src/rail/pipelines/utils/prepare_observed.py index cccc0a6..23e8159 100644 --- a/src/rail/pipelines/utils/prepare_observed.py +++ b/src/rail/pipelines/utils/prepare_observed.py @@ -32,14 +32,14 @@ def __init__( DS = RailStage.data_store DS.__class__.allow_overwrite = True - active_catalog_config = catalog_utils.CatalogConfigBase.active_class() + active_catalog_config = catalog_utils.get_active_tag() self.flux_to_mag = LSSTFluxToMagConverter.build( flux_name="{band}Flux", flux_err_name="{band}FluxErr", mag_name="{band}Mag", mag_err_name="{band}MagErr", - bands=active_catalog_config.bandlist, + bands=active_catalog_config.config['band_list'], copy_cols=dict( objectId='objectId', coord_ra='coord_ra',