Skip to content

Commit 82031d6

Browse files
committed
Merge branch 'tickets/DM-44159'
2 parents 3dd589e + 7fc0c7d commit 82031d6

File tree

4 files changed

+145
-97
lines changed

4 files changed

+145
-97
lines changed

python/lsst/pipe/tasks/diff_matched_tract_catalog.py

Lines changed: 89 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
from abc import ABCMeta, abstractmethod
3939
from astropy.stats import mad_std
40+
import astropy.table
4041
import astropy.units as u
4142
from dataclasses import dataclass
4243
from decimal import Decimal
@@ -48,6 +49,7 @@
4849
from smatch.matcher import sphdist
4950
from types import SimpleNamespace
5051
from typing import Sequence
52+
import warnings
5153

5254

5355
def is_sequence_set(x: Sequence):
@@ -75,14 +77,14 @@ class DiffMatchedTractCatalogConnections(
7577
cat_ref = cT.Input(
7678
doc="Reference object catalog to match from",
7779
name="{name_input_cat_ref}",
78-
storageClass="DataFrame",
80+
storageClass="ArrowAstropy",
7981
dimensions=("tract", "skymap"),
8082
deferLoad=True,
8183
)
8284
cat_target = cT.Input(
8385
doc="Target object catalog to match",
8486
name="{name_input_cat_target}",
85-
storageClass="DataFrame",
87+
storageClass="ArrowAstropy",
8688
dimensions=("tract", "skymap"),
8789
deferLoad=True,
8890
)
@@ -95,33 +97,33 @@ class DiffMatchedTractCatalogConnections(
9597
cat_match_ref = cT.Input(
9698
doc="Reference match catalog with indices of target matches",
9799
name="match_ref_{name_input_cat_ref}_{name_input_cat_target}",
98-
storageClass="DataFrame",
100+
storageClass="ArrowAstropy",
99101
dimensions=("tract", "skymap"),
100102
deferLoad=True,
101103
)
102104
cat_match_target = cT.Input(
103105
doc="Target match catalog with indices of references matches",
104106
name="match_target_{name_input_cat_ref}_{name_input_cat_target}",
105-
storageClass="DataFrame",
107+
storageClass="ArrowAstropy",
106108
dimensions=("tract", "skymap"),
107109
deferLoad=True,
108110
)
109111
columns_match_target = cT.Input(
110112
doc="Target match catalog columns",
111113
name="match_target_{name_input_cat_ref}_{name_input_cat_target}.columns",
112-
storageClass="DataFrameIndex",
114+
storageClass="ArrowColumnList",
113115
dimensions=("tract", "skymap"),
114116
)
115117
cat_matched = cT.Output(
116118
doc="Catalog with reference and target columns for joined sources",
117119
name="matched_{name_input_cat_ref}_{name_input_cat_target}",
118-
storageClass="DataFrame",
120+
storageClass="ArrowAstropy",
119121
dimensions=("tract", "skymap"),
120122
)
121123
diff_matched = cT.Output(
122124
doc="Table with aggregated counts, difference and chi statistics",
123125
name="diff_matched_{name_input_cat_ref}_{name_input_cat_target}",
124-
storageClass="DataFrame",
126+
storageClass="ArrowAstropy",
125127
dimensions=("tract", "skymap"),
126128
)
127129

@@ -137,6 +139,8 @@ def __init__(self, *, config=None):
137139
dimensions=(),
138140
deferLoad=old.deferLoad,
139141
)
142+
if not (config.compute_stats and len(config.columns_flux) > 0):
143+
del self.diff_matched
140144

141145

142146
class MatchedCatalogFluxesConfig(pexConfig.Config):
@@ -685,25 +689,25 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
685689

686690
def run(
687691
self,
688-
catalog_ref: pd.DataFrame,
689-
catalog_target: pd.DataFrame,
690-
catalog_match_ref: pd.DataFrame,
691-
catalog_match_target: pd.DataFrame,
692+
catalog_ref: pd.DataFrame | astropy.table.Table,
693+
catalog_target: pd.DataFrame | astropy.table.Table,
694+
catalog_match_ref: pd.DataFrame | astropy.table.Table,
695+
catalog_match_target: pd.DataFrame | astropy.table.Table,
692696
wcs: afwGeom.SkyWcs = None,
693697
) -> pipeBase.Struct:
694698
"""Load matched reference and target (measured) catalogs, measure summary statistics, and output
695699
a combined matched catalog with columns from both inputs.
696700
697701
Parameters
698702
----------
699-
catalog_ref : `pandas.DataFrame`
703+
catalog_ref : `pandas.DataFrame` | `astropy.table.Table`
700704
A reference catalog to diff objects/sources from.
701-
catalog_target : `pandas.DataFrame`
705+
catalog_target : `pandas.DataFrame` | `astropy.table.Table`
702706
A target catalog to diff reference objects/sources to.
703-
catalog_match_ref : `pandas.DataFrame`
707+
catalog_match_ref : `pandas.DataFrame` | `astropy.table.Table`
704708
A catalog with match indices of target sources and selection flags
705709
for each reference source.
706-
catalog_match_target : `pandas.DataFrame`
710+
catalog_match_target : `pandas.DataFrame` | `astropy.table.Table`
707711
A catalog with selection flags for each target source.
708712
wcs : `lsst.afw.geom.SkyWcs`
709713
A coordinate system to convert catalog positions to sky coordinates,
@@ -718,16 +722,33 @@ def run(
718722
# Would be nice if this could refer directly to ConfigClass
719723
config: DiffMatchedTractCatalogConfig = self.config
720724

721-
select_ref = catalog_match_ref['match_candidate'].values
725+
is_ref_pd = isinstance(catalog_ref, pd.DataFrame)
726+
is_target_pd = isinstance(catalog_target, pd.DataFrame)
727+
is_match_ref_pd = isinstance(catalog_match_ref, pd.DataFrame)
728+
is_match_target_pd = isinstance(catalog_match_target, pd.DataFrame)
729+
if is_ref_pd:
730+
catalog_ref = astropy.table.Table.from_pandas(catalog_ref)
731+
if is_target_pd:
732+
catalog_target = astropy.table.Table.from_pandas(catalog_target)
733+
if is_match_ref_pd:
734+
catalog_match_ref = astropy.table.Table.from_pandas(catalog_match_ref)
735+
if is_match_target_pd:
736+
catalog_match_target = astropy.table.Table.from_pandas(catalog_match_target)
737+
# TODO: Remove pandas support in DM-46523
738+
if is_ref_pd or is_target_pd or is_match_ref_pd or is_match_target_pd:
739+
warnings.warn("pandas usage in MatchProbabilisticTask is deprecated; it will be removed "
740+
" in favour of astropy.table after release 28.0.0", category=FutureWarning)
741+
742+
select_ref = catalog_match_ref['match_candidate']
722743
# Add additional selection criteria for target sources beyond those for matching
723744
# (not recommended, but can be done anyway)
724-
select_target = (catalog_match_target['match_candidate'].values
745+
select_target = (catalog_match_target['match_candidate']
725746
if 'match_candidate' in catalog_match_target.columns
726747
else np.ones(len(catalog_match_target), dtype=bool))
727748
for column in config.columns_target_select_true:
728-
select_target &= catalog_target[column].values
749+
select_target &= catalog_target[column]
729750
for column in config.columns_target_select_false:
730-
select_target &= ~catalog_target[column].values
751+
select_target &= ~catalog_target[column]
731752

732753
ref, target = config.coord_format.format_catalogs(
733754
catalog_ref=catalog_ref, catalog_target=catalog_target,
@@ -739,9 +760,9 @@ def run(
739760

740761
if config.include_unmatched:
741762
for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)):
742-
cat_add['match_candidate'] = cat_match['match_candidate'].values
763+
cat_add['match_candidate'] = cat_match['match_candidate']
743764

744-
match_row = catalog_match_ref['match_row'].values
765+
match_row = catalog_match_ref['match_row']
745766
matched_ref = match_row >= 0
746767
matched_row = match_row[matched_ref]
747768
matched_target = np.zeros(n_target, dtype=bool)
@@ -761,48 +782,44 @@ def run(
761782
) if config.coord_format.coords_spherical else np.hypot(
762783
target_match_c1 - target_ref_c1, target_match_c2 - target_ref_c2,
763784
)
785+
cat_target_matched = cat_target[matched_row]
786+
# np.ma.getdata drops the mask and returns the underlying array, so masked
787+
# entries keep their stored values (presumably nan); otherwise sphdist can raise
788+
c1_err, c2_err = (
789+
np.ma.getdata(cat_target_matched[c_err]) for c_err in (coord1_target_err, coord2_target_err)
790+
)
764791
# Should probably explicitly add cosine terms if ref has errors too
765792
dist_err[matched_row] = sphdist(
766-
target_match_c1, target_match_c2,
767-
target_match_c1 + cat_target.iloc[matched_row][coord1_target_err].values,
768-
target_match_c2 + cat_target.iloc[matched_row][coord2_target_err].values,
769-
) if config.coord_format.coords_spherical else np.hypot(
770-
cat_target.iloc[matched_row][coord1_target_err].values,
771-
cat_target.iloc[matched_row][coord2_target_err].values
772-
)
793+
target_match_c1, target_match_c2, target_match_c1 + c1_err, target_match_c2 + c2_err
794+
) if config.coord_format.coords_spherical else np.hypot(c1_err, c2_err)
773795
cat_target[column_dist], cat_target[column_dist_err] = dist, dist_err
774796

775797
# Create a matched table: matched target rows stacked side by side with the prefixed reference columns
776-
cat_left = cat_target.iloc[matched_row]
777-
has_index_left = cat_left.index.name is not None
778-
cat_right = cat_ref[matched_ref].reset_index()
779-
cat_right.columns = [f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns]
780-
cat_matched = pd.concat(objs=(cat_left.reset_index(drop=not has_index_left), cat_right), axis=1)
798+
cat_left = cat_target[matched_row]
799+
cat_right = cat_ref[matched_ref]
800+
cat_right.rename_columns(
801+
list(cat_right.columns),
802+
new_names=[f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns],
803+
)
804+
cat_matched = astropy.table.hstack((cat_left, cat_right))
781805

782806
if config.include_unmatched:
783807
# Create an unmatched table with the same schema as the matched one
784808
# ... but only for objects with no matches (for completeness/purity)
785809
# and that were selected for matching (or inclusion via config)
786-
cat_right = cat_ref[~matched_ref & select_ref].reset_index(drop=False)
787-
cat_right.columns = (f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns)
788-
match_row_target = catalog_match_target['match_row'].values
789-
cat_left = cat_target[~(match_row_target >= 0) & select_target].reset_index(
790-
drop=not has_index_left)
810+
cat_right = astropy.table.Table(
811+
cat_ref[~matched_ref & select_ref]
812+
)
813+
cat_right.rename_columns(
814+
cat_right.colnames,
815+
[f"{config.column_matched_prefix_ref}{col}" for col in cat_right.colnames],
816+
)
817+
match_row_target = catalog_match_target['match_row']
818+
cat_left = cat_target[~(match_row_target >= 0) & select_target]
819+
# This may be slower than pandas but will, for example, create
820+
# masked columns for booleans, which pandas does not support.
791821
# See https://github.com/pandas-dev/pandas/issues/46662
792-
# astropy masked columns would handle this much more gracefully
793-
# Unfortunately, that would require storageClass migration
794-
# So we use pandas "extended" nullable types for now
795-
for cat_i in (cat_left, cat_right):
796-
for colname in cat_i.columns:
797-
column = cat_i[colname]
798-
dtype = str(column.dtype)
799-
if dtype == "bool":
800-
cat_i[colname] = column.astype("boolean")
801-
elif dtype.startswith("int"):
802-
cat_i[colname] = column.astype(f"Int{dtype[3:]}")
803-
elif dtype.startswith("uint"):
804-
cat_i[colname] = column.astype(f"UInt{dtype[3:]}")
805-
cat_unmatched = pd.concat(objs=(cat_left, cat_right))
822+
cat_unmatched = astropy.table.vstack([cat_left, cat_right])
806823

807824
for columns_convert_base, prefix in (
808825
(config.columns_ref_mag_to_nJy, config.column_matched_prefix_ref),
@@ -812,8 +829,14 @@ def run(
812829
columns_convert = {
813830
f"{prefix}{k}": f"{prefix}{v}" for k, v in columns_convert_base.items()
814831
} if prefix else columns_convert_base
815-
for cat_convert in (cat_matched, cat_unmatched):
816-
cat_convert.rename(columns=columns_convert, inplace=True)
832+
to_convert = [cat_matched]
833+
if config.include_unmatched:
834+
to_convert.append(cat_unmatched)
835+
for cat_convert in to_convert:
836+
cat_convert.rename_columns(
837+
tuple(columns_convert.keys()),
838+
tuple(columns_convert.values()),
839+
)
817840
for column_flux in columns_convert.values():
818841
cat_convert[column_flux] = u.ABmag.to(u.nJy, cat_convert[column_flux])
819842

@@ -822,7 +845,8 @@ def run(
822845
n_bands = len(band_fluxes)
823846

824847
# TODO: Deprecated by RFC-1017 and to be removed in DM-44988
825-
if self.config.compute_stats and (n_bands > 0):
848+
do_stats = self.config.compute_stats and (n_bands > 0)
849+
if do_stats:
826850
# Slightly smelly hack for when a column (like distance) is already relative to truth
827851
column_dummy = 'dummy'
828852
cat_ref[column_dummy] = np.zeros_like(ref.coord1)
@@ -831,7 +855,7 @@ def run(
831855
# TODO: remove the assumption of a boolean column
832856
extended_ref = cat_ref[config.column_ref_extended] == (not config.column_ref_extended_inverted)
833857

834-
extended_target = cat_target[config.column_target_extended].values >= config.extendedness_cut
858+
extended_target = cat_target[config.column_target_extended] >= config.extendedness_cut
835859

836860
# Define difference/chi columns and statistics thereof
837861
suffixes = {MeasurementType.DIFF: 'diff', MeasurementType.CHI: 'chi'}
@@ -999,7 +1023,7 @@ def run(
9991023

10001024
if n_match > 0:
10011025
rows_matched = match_row_bin[match_good]
1002-
subset_target = cat_target.iloc[rows_matched]
1026+
subset_target = cat_target[rows_matched]
10031027
if (is_extended is not None) and (idx_model == 0):
10041028
right_type = extended_target[rows_matched] == is_extended
10051029
n_total = len(right_type)
@@ -1016,15 +1040,15 @@ def run(
10161040
# compute stats for this bin, for all columns
10171041
for column, (column_ref, column_target, column_err_target, skip_diff) \
10181042
in columns_target.items():
1019-
values_ref = cat_ref[column_ref][match_good].values
1043+
values_ref = cat_ref[column_ref][match_good]
10201044
errors_target = (
1021-
subset_target[column_err_target].values
1045+
subset_target[column_err_target]
10221046
if column_err_target is not None
10231047
else None
10241048
)
10251049
compute_stats(
10261050
values_ref,
1027-
subset_target[column_target].values,
1051+
subset_target[column_target],
10281052
errors_target,
10291053
row,
10301054
stats,
@@ -1066,7 +1090,10 @@ def run(
10661090
mag_ref_first = mag_ref
10671091

10681092
if config.include_unmatched:
1069-
cat_matched = pd.concat((cat_matched, cat_unmatched))
1093+
# This is probably less efficient than just doing an outer join originally; worth checking
1094+
cat_matched = astropy.table.vstack([cat_matched, cat_unmatched])
10701095

1071-
retStruct = pipeBase.Struct(cat_matched=cat_matched, diff_matched=pd.DataFrame(data))
1096+
retStruct = pipeBase.Struct(cat_matched=cat_matched)
1097+
if do_stats:
1098+
retStruct.diff_matched = astropy.table.Table(data)
10721099
return retStruct

0 commit comments

Comments
 (0)