Merge pull request #378 from LSSTDESC/u/yymao/use-native-parquet-writer
Ensure columns in merged_cat are in same order to take advantage of native merge
yymao authored Dec 4, 2019
2 parents 16f6ebb + e4a03e2 commit d938097
Showing 2 changed files with 57 additions and 29 deletions.
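The commit message is about column order: make_object_catalog.py is changed so that every per-patch merged_cat comes out with the same columns in the same order, which is what lets merge_parquet_files.py append the per-patch files with pyarrow's native ParquetWriter instead of concatenating them in pandas. Below is a minimal sketch of the failure mode this avoids (toy data and file name, not from this PR; to my understanding an Arrow schema is ordered, so ParquetWriter.write_table rejects a table whose columns come in a different order):

# Sketch only: two frames with the same columns in a different order produce
# different Arrow schemas, and a ParquetWriter only accepts tables that match
# the schema it was opened with.
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

df_a = pd.DataFrame({'u_flux': [1.0], 'g_flux': [2.0]})
df_b = pd.DataFrame({'g_flux': [3.0], 'u_flux': [4.0]})  # same columns, swapped order

t_a = pa.Table.from_pandas(df_a, preserve_index=False)
t_b = pa.Table.from_pandas(df_b, preserve_index=False)

with pq.ParquetWriter('example.parquet', t_a.schema) as writer:  # 'example.parquet' is a placeholder
    writer.write_table(t_a)
    try:
        writer.write_table(t_b)  # column order differs from the writer schema
    except ValueError as exc:
        print('cannot append:', exc)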
scripts/make_object_catalog.py: 43 changes (22 additions, 21 deletions)
@@ -7,7 +7,6 @@
 import re
 
 import numpy as np
-from astropy.table import hstack
 
 from lsst.daf.persistence import Butler
 from lsst.daf.persistence.butlerExceptions import NoResults
@@ -21,12 +20,13 @@ def _ensure_butler_instance(butler_or_repo):
 
 _default_fill_value = {'i': -1, 'b': False, 'U': ''}
 
 
 def _get_fill_value(name, dtype):
     kind = np.dtype(dtype).kind
     fill_value = _default_fill_value.get(kind, np.nan)
     if kind == 'b' and (name.endswith('_flag_bad') or name.endswith('_flag_noGoodPixels')):
         fill_value = True
-    return fill_value
+    return np.array(fill_value, dtype=np.dtype(dtype))
 
 
 def generate_object_catalog(output_dir, butler, tract, patches=None,
@@ -139,10 +139,7 @@ def merge_coadd_forced_src(butler, tract, patch, filters='ugrizy',
     ref_table['tract'] = int(tract)
     ref_table['patch'] = str(patch)
 
-    cat_dtype = None
-    missing_filters = list()
-    tables_to_merge = [ref_table]
-
+    tables_to_merge = dict()
     for filter_this in filters:
         filter_name = filters.get(filter_this) if hasattr(filters, 'get') else filter_this
         this_data_id = dict(tract_patch_data_id, filter=filter_name)
@@ -152,7 +149,6 @@
         except NoResults as e:
             if verbose:
                 print(" ", e)
-            missing_filters.append(filter_this)
             continue
 
         cat = cat.asAstropy()
@@ -166,24 +162,29 @@
         calib = butler.get('deepCoadd_calexp_photoCalib', this_data_id)
         cat['FLUXMAG0'] = calib.getInstFluxAtZeroMagnitude()
 
-        if cat_dtype is None:
-            cat_dtype = cat.dtype
+        tables_to_merge[filter_this] = cat
 
-        cat.meta = None
-        for name in cat_dtype.names:
-            cat.rename_column(name, '{}_{}'.format(filter_this, name))
-
-        tables_to_merge.append(cat)
+    try:
+        cat_dtype = next(iter(tables_to_merge.values())).dtype
+    except StopIteration:
+        if verbose:
+            print(" No filter can be found in deepCoadd_forced_src")
+        return
 
     if debug:
-        assert cat_dtype is not None
+        assert all(cat_dtype == cat.dtype for cat in tables_to_merge.values())
 
-    merged_cat = hstack(tables_to_merge, join_type='exact')
-    del tables_to_merge
-
-    for filter_this in missing_filters:
-        for name, (dt, _) in cat_dtype.fields.items():
-            merged_cat['{}_{}'.format(filter_this, name)] = _get_fill_value(name, dt)
+    merged_cat = ref_table  # merged_cat will start with the reference table
+    merged_cat.meta = None
+    for filter_this in filters:
+        if filter_this in tables_to_merge:
+            cat = tables_to_merge[filter_this]
+            for name in cat_dtype.names:
+                merged_cat['{}_{}'.format(filter_this, name)] = cat[name]
+            del cat, tables_to_merge[filter_this]
+        else:
+            for name, (dt, _) in cat_dtype.fields.items():
+                merged_cat['{}_{}'.format(filter_this, name)] = _get_fill_value(name, dt)
 
     return merged_cat.to_pandas() if return_pandas else merged_cat

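For orientation, here is a stripped-down sketch of the merge pattern the new code follows (toy tables and filter set, not from the repository): per-filter tables are kept in a dict, their columns are copied onto the reference table in a fixed filter order, and filters with no data get whole columns of typed fill values, so every patch produces the same columns in the same order.

# Toy illustration of the per-filter merge; names and values are invented.
import numpy as np
from astropy.table import Table

ref_table = Table({'objectId': [1, 2, 3]})
tables_to_merge = {  # pretend only g and r had forced-src data for this patch
    'g': Table({'flux': [1.0, 2.0, 3.0], 'flag': [False, True, False]}),
    'r': Table({'flux': [4.0, 5.0, 6.0], 'flag': [False, False, False]}),
}
cat_dtype = next(iter(tables_to_merge.values())).dtype

merged_cat = ref_table  # start from the reference table, as in the new code
for filter_this in 'ugr':  # fixed filter order -> fixed column order
    if filter_this in tables_to_merge:
        cat = tables_to_merge[filter_this]
        for name in cat_dtype.names:
            merged_cat['{}_{}'.format(filter_this, name)] = cat[name]
    else:  # missing filter: fill whole columns with typed defaults
        for name, (dt, _) in cat_dtype.fields.items():
            fill = {'b': False, 'i': -1}.get(dt.kind, np.nan)
            merged_cat['{}_{}'.format(filter_this, name)] = np.full(len(merged_cat), fill, dtype=dt)

print(merged_cat.colnames)  # every patch yields the same column names in the same order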
scripts/merge_parquet_files.py: 43 changes (35 additions, 8 deletions)
@@ -4,23 +4,48 @@
 
 import pandas as pd
 
+try:
+    import pyarrow.parquet as pq
+except ImportError:
+    _HAS_PYARROW_ = False
+else:
+    _HAS_PYARROW_ = True
+
 
 def load_parquet_files_into_dataframe(parquet_files):
     return pd.concat(
         [pd.read_parquet(f) for f in parquet_files],
         axis=0,
         ignore_index=True,
     )
 
 
-def run(input_files, output_file, sort_input_files=False, parquet_engine='pyarrow'):
+def run(input_files, output_file, sort_input_files=False,
+        parquet_engine='pyarrow', assume_consistent_schema=False):
     if sort_input_files:
         input_files = sorted(input_files)
-    df = load_parquet_files_into_dataframe(input_files)
-    df.to_parquet(
-        output_file,
-        engine=parquet_engine,
-        compression=None,
-        index=False,
-    )
+
+    if assume_consistent_schema:
+        if parquet_engine != "pyarrow" or not _HAS_PYARROW_:
+            raise ValueError("Must use/have pyarrow when assume_consistent_schema is set to True")
+        if not input_files:
+            raise ValueError("No input files to merge")
+
+        t = pq.read_table(input_files[0])
+        with pq.ParquetWriter(output_file, t.schema, flavor='spark') as pqwriter:
+            pqwriter.write_table(t)
+            for input_file in input_files[1:]:
+                t = pq.read_table(input_file)
+                pqwriter.write_table(t)
+
+    else:
+        df = load_parquet_files_into_dataframe(input_files)
+        df.to_parquet(
+            output_file,
+            engine=parquet_engine,
+            compression=None,
+            index=False,
+        )
 
 
 if __name__ == "__main__":
@@ -48,6 +73,8 @@
     parser.add_argument('--parquet_engine', default='pyarrow',
                         choices=['fastparquet', 'pyarrow'],
                         help="""(default: %(default)s)""")
+    parser.add_argument('--assume-consistent-schema', action='store_true',
+                        help='Assume schema is consistent across input files')
     args = parser.parse_args()
 
     if not args.input_files:
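A usage sketch for the new option (hypothetical file names; the import assumes the script directory is importable): with assume_consistent_schema=True, run() streams the inputs one at a time through a single ParquetWriter, so peak memory stays around the size of one input file instead of the whole concatenated catalog.

# Hypothetical call of run() as defined above; file names are placeholders.
from merge_parquet_files import run

patch_files = ['object_tract4850_patch00.parquet', 'object_tract4850_patch01.parquet']

# Fast path: every input was written with the same columns in the same order,
# so row groups are appended directly with pyarrow's native writer.
run(patch_files, 'object_tract4850.parquet', assume_consistent_schema=True)

# Fallback: schemas may differ, so everything is read into pandas and concatenated.
run(patch_files, 'object_tract4850.parquet')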
