Skip to content

Commit 827239b

Browse files
committed
Merge branch 'tickets/DM-50091'
2 parents 5657b65 + 2af6140 commit 827239b

File tree

3 files changed

+110
-29
lines changed

3 files changed

+110
-29
lines changed

python/lsst/pipe/tasks/diff_matched_tract_catalog.py

Lines changed: 87 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -106,14 +106,27 @@ def __init__(self, *, config=None):
106106
if config.refcat_sharding_type != "tract":
107107
if config.refcat_sharding_type == "none":
108108
old = self.cat_ref
109-
del self.cat_ref
110109
self.cat_ref = cT.Input(
111110
doc=old.doc,
112111
name=old.name,
113112
storageClass=old.storageClass,
114113
dimensions=(),
115114
deferLoad=old.deferLoad,
116115
)
116+
else:
117+
raise NotImplementedError(f"{config.refcat_sharding_type=} not implemented")
118+
if config.target_sharding_type != "tract":
119+
if config.target_sharding_type == "none":
120+
old = self.cat_target
121+
self.cat_target = cT.Input(
122+
doc=old.doc,
123+
name=old.name,
124+
storageClass=old.storageClass,
125+
dimensions=(),
126+
deferLoad=old.deferLoad,
127+
)
128+
else:
129+
raise NotImplementedError(f"{config.target_sharding_type=} not implemented")
117130

118131

119132
class MatchedCatalogFluxesConfig(pexConfig.Config):
@@ -149,14 +162,36 @@ class DiffMatchedTractCatalogConfig(
149162
pipeBase.PipelineTaskConfig,
150163
pipelineConnections=DiffMatchedTractCatalogConnections,
151164
):
165+
column_match_candidate_ref = pexConfig.Field[str](
166+
default='match_candidate',
167+
doc='The column name for the boolean field identifying reference objects'
168+
' that were used for matching',
169+
optional=True,
170+
)
171+
column_match_candidate_target = pexConfig.Field[str](
172+
default='match_candidate',
173+
doc='The column name for the boolean field identifying target objects'
174+
' that were used for matching',
175+
optional=True,
176+
)
152177
column_matched_prefix_ref = pexConfig.Field[str](
153178
default='refcat_',
154179
doc='The prefix for matched columns copied from the reference catalog',
155180
)
181+
column_matched_prefix_target = pexConfig.Field[str](
182+
default='',
183+
doc='The prefix for matched columns copied from the target catalog',
184+
)
156185
include_unmatched = pexConfig.Field[bool](
157186
default=False,
158187
doc='Whether to include unmatched rows in the matched table',
159188
)
189+
filter_on_match_candidate = pexConfig.Field[bool](
190+
default=False,
191+
doc='Whether to use provided column_match_candidate_[ref/target] to'
192+
' exclude rows from the output table. If False, any provided'
193+
' columns will be copied instead.'
194+
)
160195
prefix_best_coord = pexConfig.Field[str](
161196
default=None,
162197
doc="A string prefix for ra/dec coordinate columns generated from the reference coordinate if "
@@ -243,6 +278,11 @@ def columns_in_target(self) -> list[str]:
243278
allowed={"tract": "Tract-based shards", "none": "No sharding at all"},
244279
default="tract",
245280
)
281+
target_sharding_type = pexConfig.ChoiceField[str](
282+
doc="The type of sharding (spatial splitting) for the target catalog",
283+
allowed={"tract": "Tract-based shards", "none": "No sharding at all"},
284+
default="tract",
285+
)
246286

247287
def validate(self):
248288
super().validate()
@@ -279,19 +319,21 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
279319
inputs = butlerQC.get(inputRefs)
280320
skymap = inputs.pop("skymap")
281321

322+
columns_match_ref = ['match_row']
323+
if (column := self.config.column_match_candidate_ref) is not None:
324+
columns_match_ref.append(column)
325+
282326
columns_match_target = ['match_row']
283-
if 'match_candidate' in inputs['columns_match_target']:
284-
columns_match_target.append('match_candidate')
327+
if (column := self.config.column_match_candidate_target) is not None and (
328+
column in inputs['columns_match_target']
329+
):
330+
columns_match_target.append(column)
285331

286332
outputs = self.run(
287333
catalog_ref=inputs['cat_ref'].get(parameters={'columns': self.config.columns_in_ref}),
288334
catalog_target=inputs['cat_target'].get(parameters={'columns': self.config.columns_in_target}),
289-
catalog_match_ref=inputs['cat_match_ref'].get(
290-
parameters={'columns': ['match_candidate', 'match_row']},
291-
),
292-
catalog_match_target=inputs['cat_match_target'].get(
293-
parameters={'columns': columns_match_target},
294-
),
335+
catalog_match_ref=inputs['cat_match_ref'].get(parameters={'columns': columns_match_ref}),
336+
catalog_match_target=inputs['cat_match_target'].get(parameters={'columns': columns_match_target}),
295337
wcs=skymap[butlerQC.quantum.dataId["tract"]].wcs,
296338
)
297339
butlerQC.put(outputs, outputRefs)
@@ -339,12 +381,17 @@ def run(
339381
DatasetProvenance.strip_provenance_from_flat_dict(catalog_match_ref.meta)
340382
DatasetProvenance.strip_provenance_from_flat_dict(catalog_match_target.meta)
341383

342-
select_ref = catalog_match_ref['match_candidate']
384+
# It would be nice to make this a Selector but those are
385+
# only available in analysis_tools for now
386+
select_ref, select_target = (
387+
(catalog[column] if column else np.ones(len(catalog), dtype=bool))
388+
for catalog, column in (
389+
(catalog_match_ref, self.config.column_match_candidate_ref),
390+
(catalog_match_target, self.config.column_match_candidate_target),
391+
)
392+
)
343393
# Add additional selection criteria for target sources beyond those for matching
344394
# (not recommended, but can be done anyway)
345-
select_target = (catalog_match_target['match_candidate']
346-
if 'match_candidate' in catalog_match_target.columns
347-
else np.ones(len(catalog_match_target), dtype=bool))
348395
for column in config.columns_target_select_true:
349396
select_target &= catalog_target[column]
350397
for column in config.columns_target_select_false:
@@ -358,9 +405,13 @@ def run(
358405
cat_target = target.catalog
359406
n_target = len(cat_target)
360407

361-
if config.include_unmatched:
362-
for cat_add, cat_match in ((cat_ref, catalog_match_ref), (cat_target, catalog_match_target)):
363-
cat_add['match_candidate'] = cat_match['match_candidate']
408+
if not config.filter_on_match_candidate:
409+
for cat_add, cat_match, column in (
410+
(cat_ref, catalog_match_ref, config.column_match_candidate_ref),
411+
(cat_target, catalog_match_target, config.column_match_candidate_target),
412+
):
413+
if column is not None:
414+
cat_add[column] = cat_match[column]
364415

365416
match_row = catalog_match_ref['match_row']
366417
matched_ref = match_row >= 0
@@ -397,10 +448,16 @@ def run(
397448
# Create a matched table, preserving the target catalog's named index (if it has one)
398449
cat_left = cat_target[matched_row]
399450
cat_right = cat_ref[matched_ref]
400-
cat_right.rename_columns(
401-
list(cat_right.columns),
402-
new_names=[f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns],
403-
)
451+
if config.column_matched_prefix_target:
452+
cat_left.rename_columns(
453+
list(cat_left.columns),
454+
new_names=[f'{config.column_matched_prefix_target}{col}' for col in cat_left.columns],
455+
)
456+
if config.column_matched_prefix_ref:
457+
cat_right.rename_columns(
458+
list(cat_right.columns),
459+
new_names=[f'{config.column_matched_prefix_ref}{col}' for col in cat_right.columns],
460+
)
404461
cat_matched = astropy.table.hstack((cat_left, cat_right))
405462

406463
if config.include_unmatched:
@@ -416,6 +473,10 @@ def run(
416473
)
417474
match_row_target = catalog_match_target['match_row']
418475
cat_left = cat_target[~(match_row_target >= 0) & select_target]
476+
cat_left.rename_columns(
477+
cat_left.colnames,
478+
[f"{config.column_matched_prefix_target}{col}" for col in cat_left.colnames],
479+
)
419480
# This may be slower than pandas but will, for example, create
420481
# masked columns for booleans, which pandas does not support.
421482
# See https://github.com/pandas-dev/pandas/issues/46662
@@ -450,22 +511,23 @@ def run(
450511
)
451512
)
452513
for column_coord_best, column_coord_ref, column_coord_target in zip(
453-
columns_coord_best,
454-
(config.coord_format.column_ref_coord1, config.coord_format.column_ref_coord2),
455-
(config.coord_format.column_target_coord1, config.coord_format.column_target_coord2),
514+
columns_coord_best,
515+
(config.coord_format.column_ref_coord1, config.coord_format.column_ref_coord2),
516+
(config.coord_format.column_target_coord1, config.coord_format.column_target_coord2),
456517
):
457518
column_full_ref = f'{config.column_matched_prefix_ref}{column_coord_ref}'
519+
column_full_target = f'{config.column_matched_prefix_target}{column_coord_target}'
458520
values = cat_matched[column_full_ref]
459521
unit = values.unit
460522
values_bad = np.ma.masked_invalid(values).mask
461523
# Cast to an unmasked array - there will be no bad values
462524
values = np.array(values)
463-
values[values_bad] = cat_matched[column_coord_target][values_bad]
525+
values[values_bad] = cat_matched[column_full_target][values_bad]
464526
cat_matched[column_coord_best] = values
465527
cat_matched[column_coord_best].unit = unit
466528
cat_matched[column_coord_best].description = (
467529
f"Best {columns_coord_best} value from {column_full_ref} if available"
468-
f" else {column_coord_target}"
530+
f" else {column_full_target}"
469531
)
470532

471533
retStruct = pipeBase.Struct(cat_matched=cat_matched)

python/lsst/pipe/tasks/match_tract_catalog.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,14 +85,27 @@ def __init__(self, *, config=None):
8585
if config.refcat_sharding_type != "tract":
8686
if config.refcat_sharding_type == "none":
8787
old = self.cat_ref
88-
del self.cat_ref
8988
self.cat_ref = cT.Input(
9089
doc=old.doc,
9190
name=old.name,
9291
storageClass=old.storageClass,
9392
dimensions=(),
9493
deferLoad=old.deferLoad,
9594
)
95+
else:
96+
raise NotImplementedError(f"{config.refcat_sharding_type=} not implemented")
97+
if config.target_sharding_type != "tract":
98+
if config.target_sharding_type == "none":
99+
old = self.cat_target
100+
self.cat_target = cT.Input(
101+
doc=old.doc,
102+
name=old.name,
103+
storageClass=old.storageClass,
104+
dimensions=(),
105+
deferLoad=old.deferLoad,
106+
)
107+
else:
108+
raise NotImplementedError(f"{config.target_sharding_type=} not implemented")
96109

97110

98111
class MatchTractCatalogSubConfig(pexConfig.Config):
@@ -167,6 +180,11 @@ class MatchTractCatalogConfig(
167180
allowed={"tract": "Tract-based shards", "none": "No sharding at all"},
168181
default="tract",
169182
)
183+
target_sharding_type = pexConfig.ChoiceField[str](
184+
doc="The type of sharding (spatial splitting) for the target catalog",
185+
allowed={"tract": "Tract-based shards", "none": "No sharding at all"},
186+
default="tract",
187+
)
170188

171189
def get_columns_in(self) -> Tuple[Set, Set]:
172190
"""Get the set of input columns required for matching.

tests/test_diff_matched_tract_catalog.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -150,9 +150,10 @@ def test_DiffMatchedTractCatalogTask(self):
150150
wcs=self.wcs,
151151
)
152152
columns_result = list(result.cat_matched.columns)
153-
columns_expect = list(columns_target) + ["match_distance", "match_distanceErr"]
154-
prefix = DiffMatchedTractCatalogConfig.column_matched_prefix_ref.default
155-
columns_expect.extend((f'{prefix}{col}' for col in columns_ref))
153+
columns_expect = list(columns_target) + ["match_candidate", "match_distance", "match_distanceErr"]
154+
prefix = task.config.column_matched_prefix_ref
155+
columns_expect.extend((f"{prefix}{col}" for col in columns_ref))
156+
columns_expect.append(f"{prefix}match_candidate")
156157
self.assertListEqual(columns_expect, columns_result)
157158

158159
def test_spherical(self):

0 commit comments

Comments
 (0)