2020# along with this program. If not, see <https://www.gnu.org/licenses/>.
2121
2222__all__ = [
23- "CoaddMultibandFitConfig" , "CoaddMultibandFitSubConfig " , "CoaddMultibandFitSubTask " ,
24- "CoaddMultibandFitTask" ,
23+ "CoaddMultibandFitConfig" , "CoaddMultibandFitConnections " , "CoaddMultibandFitSubConfig " ,
24+ "CoaddMultibandFitSubTask" , " CoaddMultibandFitTask" ,
2525]
2626
2727from .fit_multiband import CatalogExposure , CatalogExposureConfig
4242CoaddMultibandFitBaseTemplates = {
4343 "name_coadd" : "deep" ,
4444 "name_method" : "multiprofit" ,
45+ "name_table" : "objects" ,
4546}
4647
4748
@@ -165,11 +166,15 @@ def adjustQuantum(self, inputs, outputs, label, data_id):
165166 super ().adjustQuantum (inputs , outputs , label , data_id )
166167 return adjusted_inputs , {}
167168
169+ def __init__ (self , * , config = None ):
170+ if config .drop_psf_connection :
171+ del self .models_psf
172+
168173
169174class CoaddMultibandFitConnections (CoaddMultibandFitInputConnections ):
170175 cat_output = cT .Output (
171176 doc = "Output source model fit parameter catalog" ,
172- name = "{name_coadd}Coadd_objects_ {name_method}" ,
177+ name = "{name_coadd}Coadd_{name_table}_ {name_method}" ,
173178 storageClass = "ArrowTable" ,
174179 dimensions = ("tract" , "patch" , "skymap" ),
175180 )
@@ -240,6 +245,10 @@ class CoaddMultibandFitBaseConfig(
240245):
241246 """Base class for multiband fitting."""
242247
248+ drop_psf_connection = pexConfig .Field [bool ](
249+ doc = "Whether to drop the PSF model connection, e.g. because PSF parameters are in the input catalog" ,
250+ default = False ,
251+ )
243252 fit_coadd_multiband = pexConfig .ConfigurableField (
244253 target = CoaddMultibandFitSubTask ,
245254 doc = "Task to fit sources using multiple bands" ,
@@ -281,12 +290,18 @@ class CoaddMultibandFitBase:
281290 def build_catexps (self , butlerQC , inputRefs , inputs ) -> list [CatalogExposureInputs ]:
282291 id_tp = self .config .idGenerator .apply (butlerQC .quantum .dataId ).catalog_id
283292 # This is a roundabout way of ensuring all inputs get sorted and matched
284- input_refs_objs = [(getattr (inputRefs , key ), inputs [key ])
285- for key in ("cats_meas" , "coadds" , "models_psf" )]
286- cats , exps , models_psf = [
293+ keys = ["cats_meas" , "coadds" ]
294+ has_psf_models = "models_psf" in inputs
295+ if has_psf_models :
296+ keys .append ("models_psf" )
297+ input_refs_objs = ((getattr (inputRefs , key ), inputs [key ]) for key in keys )
298+ inputs_sorted = tuple (
287299 {dRef .dataId : obj for dRef , obj in zip (refs , objs )}
288300 for refs , objs in input_refs_objs
289- ]
301+ )
302+ cats = inputs_sorted [0 ]
303+ exps = inputs_sorted [1 ]
304+ models_psf = inputs_sorted [2 ] if has_psf_models else None
290305 dataIds = set (cats ).union (set (exps ))
291306 models_scarlet = inputs ["models_scarlet" ]
292307 catexps = {}
@@ -302,8 +317,11 @@ def build_catexps(self, butlerQC, inputRefs, inputs) -> list[CatalogExposureInpu
302317 updateFluxColumns = False ,
303318 )
304319 catexps [dataId ['band' ]] = CatalogExposureInputs (
305- catalog = catalog , exposure = exposure , table_psf_fits = models_psf [dataId ],
306- dataId = dataId , id_tract_patch = id_tp ,
320+ catalog = catalog ,
321+ exposure = exposure ,
322+ table_psf_fits = models_psf [dataId ] if has_psf_models else astropy .table .Table (),
323+ dataId = dataId ,
324+ id_tract_patch = id_tp ,
307325 )
308326 catexps = [catexps [band ] for band in self .config .get_band_sets ()[0 ]]
309327 return catexps
@@ -318,7 +336,7 @@ class CoaddMultibandFitTask(CoaddMultibandFitBase, pipeBase.PipelineTask):
318336 """
319337
320338 ConfigClass = CoaddMultibandFitConfig
321- _DefaultName = "CoaddMultibandFit "
339+ _DefaultName = "coaddMultibandFit "
322340
323341 def __init__ (self , initInputs , ** kwargs ):
324342 super ().__init__ (initInputs = initInputs , ** kwargs )
0 commit comments