3737
3838from abc import ABCMeta , abstractmethod
3939from astropy .stats import mad_std
40+ import astropy .table
4041import astropy .units as u
4142from dataclasses import dataclass
4243from decimal import Decimal
4849from smatch .matcher import sphdist
4950from types import SimpleNamespace
5051from typing import Sequence
52+ import warnings
5153
5254
5355def is_sequence_set (x : Sequence ):
@@ -75,14 +77,14 @@ class DiffMatchedTractCatalogConnections(
7577 cat_ref = cT .Input (
7678 doc = "Reference object catalog to match from" ,
7779 name = "{name_input_cat_ref}" ,
78- storageClass = "DataFrame " ,
80+ storageClass = "ArrowAstropy " ,
7981 dimensions = ("tract" , "skymap" ),
8082 deferLoad = True ,
8183 )
8284 cat_target = cT .Input (
8385 doc = "Target object catalog to match" ,
8486 name = "{name_input_cat_target}" ,
85- storageClass = "DataFrame " ,
87+ storageClass = "ArrowAstropy " ,
8688 dimensions = ("tract" , "skymap" ),
8789 deferLoad = True ,
8890 )
@@ -95,33 +97,33 @@ class DiffMatchedTractCatalogConnections(
9597 cat_match_ref = cT .Input (
9698 doc = "Reference match catalog with indices of target matches" ,
9799 name = "match_ref_{name_input_cat_ref}_{name_input_cat_target}" ,
98- storageClass = "DataFrame " ,
100+ storageClass = "ArrowAstropy " ,
99101 dimensions = ("tract" , "skymap" ),
100102 deferLoad = True ,
101103 )
102104 cat_match_target = cT .Input (
103105 doc = "Target match catalog with indices of references matches" ,
104106 name = "match_target_{name_input_cat_ref}_{name_input_cat_target}" ,
105- storageClass = "DataFrame " ,
107+ storageClass = "ArrowAstropy " ,
106108 dimensions = ("tract" , "skymap" ),
107109 deferLoad = True ,
108110 )
109111 columns_match_target = cT .Input (
110112 doc = "Target match catalog columns" ,
111113 name = "match_target_{name_input_cat_ref}_{name_input_cat_target}.columns" ,
112- storageClass = "DataFrameIndex " ,
114+ storageClass = "ArrowColumnList " ,
113115 dimensions = ("tract" , "skymap" ),
114116 )
115117 cat_matched = cT .Output (
116118 doc = "Catalog with reference and target columns for joined sources" ,
117119 name = "matched_{name_input_cat_ref}_{name_input_cat_target}" ,
118- storageClass = "DataFrame " ,
120+ storageClass = "ArrowAstropy " ,
119121 dimensions = ("tract" , "skymap" ),
120122 )
121123 diff_matched = cT .Output (
122124 doc = "Table with aggregated counts, difference and chi statistics" ,
123125 name = "diff_matched_{name_input_cat_ref}_{name_input_cat_target}" ,
124- storageClass = "DataFrame " ,
126+ storageClass = "ArrowAstropy " ,
125127 dimensions = ("tract" , "skymap" ),
126128 )
127129
@@ -137,6 +139,8 @@ def __init__(self, *, config=None):
137139 dimensions = (),
138140 deferLoad = old .deferLoad ,
139141 )
142+ if not (config .compute_stats and len (config .columns_flux ) > 0 ):
143+ del self .diff_matched
140144
141145
142146class MatchedCatalogFluxesConfig (pexConfig .Config ):
@@ -685,25 +689,25 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
685689
686690 def run (
687691 self ,
688- catalog_ref : pd .DataFrame ,
689- catalog_target : pd .DataFrame ,
690- catalog_match_ref : pd .DataFrame ,
691- catalog_match_target : pd .DataFrame ,
692+ catalog_ref : pd .DataFrame | astropy . table . Table ,
693+ catalog_target : pd .DataFrame | astropy . table . Table ,
694+ catalog_match_ref : pd .DataFrame | astropy . table . Table ,
695+ catalog_match_target : pd .DataFrame | astropy . table . Table ,
692696 wcs : afwGeom .SkyWcs = None ,
693697 ) -> pipeBase .Struct :
694698 """Load matched reference and target (measured) catalogs, measure summary statistics, and output
695699 a combined matched catalog with columns from both inputs.
696700
697701 Parameters
698702 ----------
699- catalog_ref : `pandas.DataFrame`
703+ catalog_ref : `pandas.DataFrame` | `astropy.table.Table`
700704 A reference catalog to diff objects/sources from.
701- catalog_target : `pandas.DataFrame`
705+ catalog_target : `pandas.DataFrame` | `astropy.table.Table`
702706 A target catalog to diff reference objects/sources to.
703- catalog_match_ref : `pandas.DataFrame`
707+ catalog_match_ref : `pandas.DataFrame` | `astropy.table.Table`
704708 A catalog with match indices of target sources and selection flags
705709 for each reference source.
706- catalog_match_target : `pandas.DataFrame`
710+ catalog_match_target : `pandas.DataFrame` | `astropy.table.Table`
707711 A catalog with selection flags for each target source.
708712 wcs : `lsst.afw.image.SkyWcs`
709713 A coordinate system to convert catalog positions to sky coordinates,
@@ -718,16 +722,33 @@ def run(
718722 # Would be nice if this could refer directly to ConfigClass
719723 config : DiffMatchedTractCatalogConfig = self .config
720724
721- select_ref = catalog_match_ref ['match_candidate' ].values
725+ is_ref_pd = isinstance (catalog_ref , pd .DataFrame )
726+ is_target_pd = isinstance (catalog_target , pd .DataFrame )
727+ is_match_ref_pd = isinstance (catalog_match_ref , pd .DataFrame )
728+ is_match_target_pd = isinstance (catalog_match_target , pd .DataFrame )
729+ if is_ref_pd :
730+ catalog_ref = astropy .table .Table .from_pandas (catalog_ref )
731+ if is_target_pd :
732+ catalog_target = astropy .table .Table .from_pandas (catalog_target )
733+ if is_match_ref_pd :
734+ catalog_match_ref = astropy .table .Table .from_pandas (catalog_match_ref )
735+ if is_match_target_pd :
736+ catalog_match_target = astropy .table .Table .from_pandas (catalog_match_target )
737+ # TODO: Remove pandas support in DM-46523
738+ if is_ref_pd or is_target_pd or is_match_ref_pd or is_match_target_pd :
739+ warnings .warn ("pandas usage in MatchProbabilisticTask is deprecated; it will be removed "
740+ " in favour of astropy.table after release 28.0.0" , category = FutureWarning )
741+
742+ select_ref = catalog_match_ref ['match_candidate' ]
722743 # Add additional selection criteria for target sources beyond those for matching
723744 # (not recommended, but can be done anyway)
724- select_target = (catalog_match_target ['match_candidate' ]. values
745+ select_target = (catalog_match_target ['match_candidate' ]
725746 if 'match_candidate' in catalog_match_target .columns
726747 else np .ones (len (catalog_match_target ), dtype = bool ))
727748 for column in config .columns_target_select_true :
728- select_target &= catalog_target [column ]. values
749+ select_target &= catalog_target [column ]
729750 for column in config .columns_target_select_false :
730- select_target &= ~ catalog_target [column ]. values
751+ select_target &= ~ catalog_target [column ]
731752
732753 ref , target = config .coord_format .format_catalogs (
733754 catalog_ref = catalog_ref , catalog_target = catalog_target ,
@@ -739,9 +760,9 @@ def run(
739760
740761 if config .include_unmatched :
741762 for cat_add , cat_match in ((cat_ref , catalog_match_ref ), (cat_target , catalog_match_target )):
742- cat_add ['match_candidate' ] = cat_match ['match_candidate' ]. values
763+ cat_add ['match_candidate' ] = cat_match ['match_candidate' ]
743764
744- match_row = catalog_match_ref ['match_row' ]. values
765+ match_row = catalog_match_ref ['match_row' ]
745766 matched_ref = match_row >= 0
746767 matched_row = match_row [matched_ref ]
747768 matched_target = np .zeros (n_target , dtype = bool )
@@ -761,48 +782,44 @@ def run(
761782 ) if config .coord_format .coords_spherical else np .hypot (
762783 target_match_c1 - target_ref_c1 , target_match_c2 - target_ref_c2 ,
763784 )
785+ cat_target_matched = cat_target [matched_row ]
786+ # This will convert a masked array to an array filled with nans
787+ # wherever there are bad values (otherwise sphdist can raise)
788+ c1_err , c2_err = (
789+ np .ma .getdata (cat_target_matched [c_err ]) for c_err in (coord1_target_err , coord2_target_err )
790+ )
764791 # Should probably explicitly add cosine terms if ref has errors too
765792 dist_err [matched_row ] = sphdist (
766- target_match_c1 , target_match_c2 ,
767- target_match_c1 + cat_target .iloc [matched_row ][coord1_target_err ].values ,
768- target_match_c2 + cat_target .iloc [matched_row ][coord2_target_err ].values ,
769- ) if config .coord_format .coords_spherical else np .hypot (
770- cat_target .iloc [matched_row ][coord1_target_err ].values ,
771- cat_target .iloc [matched_row ][coord2_target_err ].values
772- )
793+ target_match_c1 , target_match_c2 , target_match_c1 + c1_err , target_match_c2 + c2_err
794+ ) if config .coord_format .coords_spherical else np .hypot (c1_err , c2_err )
773795 cat_target [column_dist ], cat_target [column_dist_err ] = dist , dist_err
774796
775797 # Create a matched table, preserving the target catalog's named index (if it has one)
776- cat_left = cat_target .iloc [matched_row ]
777- has_index_left = cat_left .index .name is not None
778- cat_right = cat_ref [matched_ref ].reset_index ()
779- cat_right .columns = [f'{ config .column_matched_prefix_ref } { col } ' for col in cat_right .columns ]
780- cat_matched = pd .concat (objs = (cat_left .reset_index (drop = not has_index_left ), cat_right ), axis = 1 )
798+ cat_left = cat_target [matched_row ]
799+ cat_right = cat_ref [matched_ref ]
800+ cat_right .rename_columns (
801+ list (cat_right .columns ),
802+ new_names = [f'{ config .column_matched_prefix_ref } { col } ' for col in cat_right .columns ],
803+ )
804+ cat_matched = astropy .table .hstack ((cat_left , cat_right ))
781805
782806 if config .include_unmatched :
783807 # Create an unmatched table with the same schema as the matched one
784808 # ... but only for objects with no matches (for completeness/purity)
785809 # and that were selected for matching (or inclusion via config)
786- cat_right = cat_ref [~ matched_ref & select_ref ].reset_index (drop = False )
787- cat_right .columns = (f'{ config .column_matched_prefix_ref } { col } ' for col in cat_right .columns )
788- match_row_target = catalog_match_target ['match_row' ].values
789- cat_left = cat_target [~ (match_row_target >= 0 ) & select_target ].reset_index (
790- drop = not has_index_left )
810+ cat_right = astropy .table .Table (
811+ cat_ref [~ matched_ref & select_ref ]
812+ )
813+ cat_right .rename_columns (
814+ cat_right .colnames ,
815+ [f"{ config .column_matched_prefix_ref } { col } " for col in cat_right .colnames ],
816+ )
817+ match_row_target = catalog_match_target ['match_row' ]
818+ cat_left = cat_target [~ (match_row_target >= 0 ) & select_target ]
819+ # This may be slower than pandas but will, for example, create
820+ # masked columns for booleans, which pandas does not support.
791821 # See https://github.com/pandas-dev/pandas/issues/46662
792- # astropy masked columns would handle this much more gracefully
793- # Unfortunately, that would require storageClass migration
794- # So we use pandas "extended" nullable types for now
795- for cat_i in (cat_left , cat_right ):
796- for colname in cat_i .columns :
797- column = cat_i [colname ]
798- dtype = str (column .dtype )
799- if dtype == "bool" :
800- cat_i [colname ] = column .astype ("boolean" )
801- elif dtype .startswith ("int" ):
802- cat_i [colname ] = column .astype (f"Int{ dtype [3 :]} " )
803- elif dtype .startswith ("uint" ):
804- cat_i [colname ] = column .astype (f"UInt{ dtype [3 :]} " )
805- cat_unmatched = pd .concat (objs = (cat_left , cat_right ))
822+ cat_unmatched = astropy .table .vstack ([cat_left , cat_right ])
806823
807824 for columns_convert_base , prefix in (
808825 (config .columns_ref_mag_to_nJy , config .column_matched_prefix_ref ),
@@ -812,8 +829,14 @@ def run(
812829 columns_convert = {
813830 f"{ prefix } { k } " : f"{ prefix } { v } " for k , v in columns_convert_base .items ()
814831 } if prefix else columns_convert_base
815- for cat_convert in (cat_matched , cat_unmatched ):
816- cat_convert .rename (columns = columns_convert , inplace = True )
832+ to_convert = [cat_matched ]
833+ if config .include_unmatched :
834+ to_convert .append (cat_unmatched )
835+ for cat_convert in to_convert :
836+ cat_convert .rename_columns (
837+ tuple (columns_convert .keys ()),
838+ tuple (columns_convert .values ()),
839+ )
817840 for column_flux in columns_convert .values ():
818841 cat_convert [column_flux ] = u .ABmag .to (u .nJy , cat_convert [column_flux ])
819842
@@ -822,7 +845,8 @@ def run(
822845 n_bands = len (band_fluxes )
823846
824847 # TODO: Deprecated by RFC-1017 and to be removed in DM-44988
825- if self .config .compute_stats and (n_bands > 0 ):
848+ do_stats = self .config .compute_stats and (n_bands > 0 )
849+ if do_stats :
826850 # Slightly smelly hack for when a column (like distance) is already relative to truth
827851 column_dummy = 'dummy'
828852 cat_ref [column_dummy ] = np .zeros_like (ref .coord1 )
@@ -831,7 +855,7 @@ def run(
831855 # TODO: remove the assumption of a boolean column
832856 extended_ref = cat_ref [config .column_ref_extended ] == (not config .column_ref_extended_inverted )
833857
834- extended_target = cat_target [config .column_target_extended ]. values >= config .extendedness_cut
858+ extended_target = cat_target [config .column_target_extended ] >= config .extendedness_cut
835859
836860 # Define difference/chi columns and statistics thereof
837861 suffixes = {MeasurementType .DIFF : 'diff' , MeasurementType .CHI : 'chi' }
@@ -999,7 +1023,7 @@ def run(
9991023
10001024 if n_match > 0 :
10011025 rows_matched = match_row_bin [match_good ]
1002- subset_target = cat_target . iloc [rows_matched ]
1026+ subset_target = cat_target [rows_matched ]
10031027 if (is_extended is not None ) and (idx_model == 0 ):
10041028 right_type = extended_target [rows_matched ] == is_extended
10051029 n_total = len (right_type )
@@ -1016,15 +1040,15 @@ def run(
10161040 # compute stats for this bin, for all columns
10171041 for column , (column_ref , column_target , column_err_target , skip_diff ) \
10181042 in columns_target .items ():
1019- values_ref = cat_ref [column_ref ][match_good ]. values
1043+ values_ref = cat_ref [column_ref ][match_good ]
10201044 errors_target = (
1021- subset_target [column_err_target ]. values
1045+ subset_target [column_err_target ]
10221046 if column_err_target is not None
10231047 else None
10241048 )
10251049 compute_stats (
10261050 values_ref ,
1027- subset_target [column_target ]. values ,
1051+ subset_target [column_target ],
10281052 errors_target ,
10291053 row ,
10301054 stats ,
@@ -1066,7 +1090,10 @@ def run(
10661090 mag_ref_first = mag_ref
10671091
10681092 if config .include_unmatched :
1069- cat_matched = pd .concat ((cat_matched , cat_unmatched ))
1093+ # This is probably less efficient than just doing an outer join originally; worth checking
1094+ cat_matched = astropy .table .vstack ([cat_matched , cat_unmatched ])
10701095
1071- retStruct = pipeBase .Struct (cat_matched = cat_matched , diff_matched = pd .DataFrame (data ))
1096+ retStruct = pipeBase .Struct (cat_matched = cat_matched )
1097+ if do_stats :
1098+ retStruct .diff_matched = astropy .table .Table (data )
10721099 return retStruct
0 commit comments