Skip to content

Commit e871072

Browse files
committed
Refactoring
1 parent 6142886 commit e871072

File tree

2 files changed

+189
-104
lines changed

2 files changed

+189
-104
lines changed

bin.src/curate_templates.py

Lines changed: 114 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@
2929

3030
import argparse
3131
import logging
32-
import numpy as np
3332
import os
3433
import sys
3534

36-
from astropy.table import Table
35+
from astropy.table import Table, vstack
3736

3837
from lsst.daf.butler import Butler, CollectionType
3938

@@ -56,6 +55,7 @@ def _make_parser():
5655
"--collections",
5756
action="extend",
5857
nargs="+",
58+
required=True,
5959
help="The input collections to search for template_coadd and coadd_depth_table datasets.",
6060
)
6161
parser.add_argument(
@@ -86,20 +86,20 @@ def _make_parser():
8686
"--cutoff",
8787
required=False,
8888
default=95,
89+
type=int,
8990
help="The curation process will filter out anything below this cutoff."
9091
" Default is 95.",
9192
)
9293
return parser
9394

9495

95-
def get_tracts(butler, where):
96-
tracts = []
97-
coadd_depth_tables = butler.registry.queryDatasets(datasetType='coadd_depth_table', where=where)
98-
for item in coadd_depth_tables:
99-
tract = item.dataId['tract']
100-
tracts.append(tract)
101-
tracts = set(tracts)
102-
return tracts
96+
def select_ref(drefs, tract, patch, band, dtype="template_coadd"):
97+
if not drefs:
98+
logging.warning(f"No {dtype} found for tract {tract}, patch {patch}, band {band}. Skipping.")
99+
return None
100+
if len(drefs) > 1:
101+
return sorted(drefs, key=lambda ref: ref.run)[-1]
102+
return drefs[0]
103103

104104

105105
def make_threshold_cuts(butler, template_coadds, n_images, tracts, filter_by, cutoff):
@@ -126,13 +126,21 @@ def make_threshold_cuts(butler, template_coadds, n_images, tracts, filter_by, cu
126126
and d.dataId['band'] == band
127127
]
128128

129+
if not dref:
130+
logging.warning(f"No template_coadd found for tract {tract}, patch {patch}, band {band}. "
131+
f"Skipping.")
132+
continue
129133
if len(dref) > 1:
130134
sorted_dupe_entry = sorted(dref, key=lambda ref: ref.run)
131135
ref = sorted_dupe_entry[-1]
132136
else:
133137
ref = dref[0]
134138
accepted_drefs.append(ref)
135139

140+
if not n_image_dref:
141+
logging.warning(f"No template_coadd_n_image found for tract {tract}, patch {patch}, "
142+
f"band {band}. Skipping.")
143+
continue
136144
if len(n_image_dref) > 1:
137145
sorted_dupe_entry = sorted(n_image_dref, key=lambda ref: ref.run)
138146
n_image_ref = sorted_dupe_entry[-1]
@@ -150,6 +158,10 @@ def make_threshold_cuts(butler, template_coadds, n_images, tracts, filter_by, cu
150158
and d.dataId['band'] == band
151159
]
152160

161+
if not dref:
162+
logging.warning(f"No template_coadd found for tract {tract}, patch {patch}, band {band}. "
163+
f"Skipping.")
164+
continue
153165
if len(dref) > 1:
154166
sorted_dupe_entry = sorted(dref, key=lambda ref: ref.run)
155167
ref = sorted_dupe_entry[-1]
@@ -160,68 +172,75 @@ def make_threshold_cuts(butler, template_coadds, n_images, tracts, filter_by, cu
160172

161173

162174
def run_stats(accepted_drefs, rejected_drefs, tracts, stats_records_file):
163-
# Create table of accepted drefs
164-
accepted = Table()
165-
accepted_tracts = []
166-
accepted_patches = []
167-
accepted_bands = []
168-
bands = ['u', 'g', 'r', 'i', 'z', 'y']
169-
170-
for ref in accepted_drefs:
171-
accepted_tracts.append(ref.dataId['tract'])
172-
accepted_patches.append(ref.dataId['patch'])
173-
accepted_bands.append(ref.dataId['band'])
174-
175-
accepted_table_data = [accepted_tracts, accepted_patches, accepted_bands]
176-
accepted = Table(data=accepted_table_data, names=['tract', 'patch', 'band'])
175+
"""
176+
Compute per-tract and per-band accepted/rejected statistics and save to CSV.
177+
178+
Parameters
179+
----------
180+
accepted_drefs : list of DatasetRef
181+
Template coadd references that passed curation.
182+
rejected_drefs : list of DatasetRef
183+
Template coadd references that failed curation.
184+
tracts : iterable of int
185+
List of tract IDs to include in the stats.
186+
stats_records_file : str
187+
Path to save the resulting CSV file.
188+
"""
177189

178-
# Create table of rejected drefs
179-
rejected = Table()
180-
rejected_tracts = []
181-
rejected_patches = []
182-
rejected_bands = []
183-
184-
for ref in rejected_drefs:
185-
rejected_tracts.append(ref.dataId['tract'])
186-
rejected_patches.append(ref.dataId['patch'])
187-
rejected_bands.append(ref.dataId['band'])
190+
bands = ['u', 'g', 'r', 'i', 'z', 'y']
188191

189-
rejected_table_data = [rejected_tracts, rejected_patches, rejected_bands]
190-
rejected = Table(data=rejected_table_data, names=['tract', 'patch', 'band'])
192+
# Build accepted table
193+
if accepted_drefs:
194+
accepted = Table(
195+
{
196+
'tract': [int(r.dataId['tract']) for r in accepted_drefs],
197+
'patch': [int(r.dataId['patch']) for r in accepted_drefs],
198+
'band': [str(r.dataId['band']) for r in accepted_drefs],
199+
'status': ['accepted'] * len(accepted_drefs)
200+
}
201+
)
202+
else:
203+
accepted = Table(names=('tract', 'patch', 'band', 'status'))
204+
205+
# Build rejected table
206+
if rejected_drefs:
207+
rejected = Table(
208+
{
209+
'tract': [int(r.dataId['tract']) for r in rejected_drefs],
210+
'patch': [int(r.dataId['patch']) for r in rejected_drefs],
211+
'band': [str(r.dataId['band']) for r in rejected_drefs],
212+
'status': ['rejected'] * len(rejected_drefs)
213+
}
214+
)
215+
else:
216+
rejected = Table(names=('tract', 'patch', 'band', 'status'))
217+
218+
# Combine tables
219+
all_refs = vstack([accepted, rejected])
220+
221+
# Group by tract and band
222+
grouped = all_refs.group_by(['tract', 'band'])
223+
224+
# Prepare output table
225+
stat_table_data = {'tract': [], }
226+
for band in bands:
227+
stat_table_data[f'{band}_num_accepted'] = []
228+
stat_table_data[f'{band}_percent_accepted'] = []
191229

192-
# Run stats
193-
by_band_stats = []
194230
for tract in tracts:
195-
tract_band_stats = []
196-
for band in bands:
197-
accepted_bands = ((accepted['tract'] == tract) & (accepted['band'] == band)).sum()
198-
rejected_bands = ((rejected['tract'] == tract) & (rejected['band'] == band)).sum()
199-
total_bands = accepted_bands + rejected_bands
200-
if total_bands == 0:
201-
tract_band_stats.append(["0 / 0", np.nan])
202-
else:
203-
tract_band_stats.append([f"{accepted_bands} / {total_bands}",
204-
accepted_bands / total_bands * 100])
205-
by_band_stats.append(tract_band_stats)
206-
207-
# Compile stats into a table and save
208-
accepted_col_names = [f"{band}_{suffix}" for band in bands for suffix
209-
in ("num_accepted", "percent_accepted")]
210-
by_tract_names = ['tract'] + accepted_col_names
211-
212-
stat_table_data = {col: [] for col in by_tract_names}
213-
214-
for tract_index, tract in enumerate(tracts):
215-
band_stats = by_band_stats[tract_index]
216-
217231
stat_table_data['tract'].append(tract)
232+
for band in bands:
233+
mask = (grouped['tract'] == tract) & (grouped['band'] == band)
234+
subset = grouped[mask]
235+
n_total = len(subset)
236+
n_accepted = (subset['status'] == 'accepted').sum() if n_total > 0 else 0
237+
percent = (n_accepted / n_total * 100) if n_total > 0 else float('nan')
238+
stat_table_data[f'{band}_num_accepted'].append(f"{n_accepted} / {n_total}")
239+
stat_table_data[f'{band}_percent_accepted'].append(percent)
218240

219-
for band_idx, band in enumerate(bands):
220-
accepted_str, percent = band_stats[band_idx]
221-
stat_table_data[f"{band}_num_accepted"].append(accepted_str)
222-
stat_table_data[f"{band}_percent_accepted"].append(percent)
223-
by_tract_stats = Table(stat_table_data)
224-
by_tract_stats.write(stats_records_file, format='ascii.csv', overwrite=True)
241+
# Create final table
242+
stat_table = Table(stat_table_data)
243+
stat_table.write(stats_records_file, format='ascii.csv', overwrite=True)
225244

226245

227246
def main():
@@ -254,21 +273,39 @@ def main():
254273
logging.error(f"Collection {tagged_collection} already exists. Aborting.")
255274
sys.exit(1)
256275

257-
logging.info("Collecting template_coadd and template_coadd_n_image refs.")
258-
refs = butler.query_datasets("template_coadd", where=args.where, limit=None)
259-
n_image_refs = butler.query_datasets("template_coadd_n_image", where=args.where, limit=None)
260-
logging.info(f"Found {len(refs)} template_coadd datasets in {args.collections}.")
276+
logging.info("Collecting coadd_depth_table, template_coadd, and template_coadd_n_image refs.")
277+
coadd_depth_table_refs = butler.query_datasets("coadd_depth_table", where=args.where, limit=None)
278+
if not coadd_depth_table_refs:
279+
logging.error("No coadd_depth_table datasets found in the given collections.")
280+
sys.exit(1)
281+
282+
    # Get a list of relevant tracts.
283+
tracts = {item.dataId['tract'] for item in coadd_depth_table_refs}
284+
285+
    # Amend the where argument to restrict refs to relevant tracts.
286+
tracts_str = ",".join(str(t) for t in tracts)
287+
tract_restriction = f"tract IN ({tracts_str})"
288+
args.where = f"({args.where}) AND ({tract_restriction})" if args.where else tract_restriction
261289

262-
# Get a list of the tracts inside the template collection.
263-
tracts = get_tracts(butler, args.where)
290+
    # Get relevant template_coadd and template_coadd_n_image refs.
291+
coadd_refs = butler.query_datasets("template_coadd", where=args.where, limit=None)
292+
if not coadd_refs:
293+
logging.error("No template_coadd datasets found in the given collections.")
294+
sys.exit(1)
295+
n_image_refs = butler.query_datasets("template_coadd_n_image", where=args.where, limit=None)
296+
if not n_image_refs:
297+
logging.error("No template_coadd_n_image datasets found in the given collections.")
298+
sys.exit(1)
299+
logging.info(f"Found {len(coadd_refs)} template_coadd datasets with coadd_depth_tables "
300+
f"in {args.collections}.")
264301

265302
    # Filter out template_coadds that don't meet the cutoff and save them to record.
266303
logging.info("Starting curation.")
267-
accepted_drefs, rejected_drefs, accepted_n_image_refs = make_threshold_cuts(butler, refs,
304+
accepted_drefs, rejected_drefs, accepted_n_image_refs = make_threshold_cuts(butler, coadd_refs,
268305
n_image_refs, tracts,
269306
args.filter_by, args.cutoff
270307
)
271-
logging.info(f"Curation complete. Accepted {len(accepted_drefs)} out of {len(refs)}"
308+
logging.info(f"Curation complete. Accepted {len(accepted_drefs)} out of {len(coadd_refs)}"
272309
f" template_coadd datasets in {args.collections}.")
273310

274311
# Run accepted/rejected statistics and save them to record.
@@ -278,13 +315,10 @@ def main():
278315
logging.info("Stat generation complete. Accepted/rejected stat records written to"
279316
f" {stats_records_file}.")
280317

281-
# Associate accepted template_coadds to tagged collection.
282-
logging.info(f"Associating {len(accepted_drefs)} template_coadds to {tagged_collection}.")
318+
# Associate accepted template_coadds and template_coadd_n_images to tagged collection.
319+
logging.info(f"Associating {len(accepted_drefs)} template_coadds and "
320+
f"{len(accepted_n_image_refs)} template_coadd_n_images to {tagged_collection}.")
283321
butler_write.registry.associate(tagged_collection, accepted_drefs)
284-
logging.info("Association complete.")
285-
286-
# Associate accepted template_coadd_n_images to tagged collection.
287-
logging.info(f"Associating {len(accepted_n_image_refs)} template_coadd_n_images to {tagged_collection}.")
288322
butler_write.registry.associate(tagged_collection, accepted_n_image_refs)
289323
logging.info("Association complete.")
290324

0 commit comments

Comments
 (0)