Merge branch 'main' into raen/add/verification
delucchi-cmu authored Jan 21, 2025
2 parents a7053a2 + c161725 commit d4b98c9
Showing 16 changed files with 138 additions and 217 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -1,6 +1,7 @@
[project]
name = "hats-import"
license = {file = "LICENSE"}
description = "Utility for ingesting large survey data into HATS structure"
readme = "README.md"
authors = [
{ name = "LINCC Frameworks", email = "[email protected]" }
13 changes: 6 additions & 7 deletions src/hats_import/catalog/map_reduce.py
@@ -10,7 +10,7 @@
from hats import pixel_math
from hats.io import file_io, paths
from hats.pixel_math.healpix_pixel import HealpixPixel
from hats.pixel_math.sparse_histogram import SparseHistogram
from hats.pixel_math.sparse_histogram import HistogramAggregator, SparseHistogram
from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN, spatial_index_to_healpix
from upath import UPath

@@ -105,7 +105,7 @@ def map_to_pixels(
FileNotFoundError: if the file does not exist, or is a directory
"""
try:
histo = SparseHistogram.make_empty(highest_order)
histo = HistogramAggregator(highest_order)

if use_healpix_29:
read_columns = [SPATIAL_INDEX_COLUMN]
@@ -123,12 +123,11 @@
):
mapped_pixel, count_at_pixel = np.unique(mapped_pixels, return_counts=True)

partial = SparseHistogram.make_from_counts(
mapped_pixel, count_at_pixel, healpix_order=highest_order
)
histo.add(partial)
histo.add(SparseHistogram(mapped_pixel, count_at_pixel, highest_order))

histo.to_file(ResumePlan.partial_histogram_file(tmp_path=resume_path, mapping_key=mapping_key))
histo.to_sparse().to_file(
ResumePlan.partial_histogram_file(tmp_path=resume_path, mapping_key=mapping_key)
)
except Exception as exception: # pylint: disable=broad-exception-caught
print_task_failure(f"Failed MAPPING stage with file {input_file}", exception)
raise exception
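An aside on the change above: each mapping worker now keeps a single dense running total (a `HistogramAggregator`) and only converts it to a sparse histogram when the partial file is written; `read_histogram` in resume_plan.py (next file) merges the partials with the same aggregator. Below is a minimal, self-contained sketch of that accumulate-then-sparsify pattern, using stand-in classes rather than the real hats API:

```python
import numpy as np

class MiniSparseHistogram:
    """Stand-in for SparseHistogram: nonzero healpix cells and their counts at one order."""
    def __init__(self, indexes, counts, order):
        self.indexes = np.asarray(indexes)
        self.counts = np.asarray(counts)
        self.order = order

class MiniAggregator:
    """Stand-in for HistogramAggregator: one dense running total at a fixed order."""
    def __init__(self, order):
        self.order = order
        self.full_histogram = np.zeros(12 * 4**order, dtype=np.int64)

    def add(self, partial):
        # Scatter-add the sparse partial into the dense total.
        self.full_histogram[partial.indexes] += partial.counts

    def to_sparse(self):
        nonzero = np.nonzero(self.full_histogram)[0]
        return MiniSparseHistogram(nonzero, self.full_histogram[nonzero], self.order)

# Accumulate two chunks of mapped pixel indexes, then sparsify once at the end.
aggregator = MiniAggregator(order=1)
for chunk in ([3, 3, 7, 40], [7, 7, 40, 41]):
    mapped_pixel, count_at_pixel = np.unique(chunk, return_counts=True)
    aggregator.add(MiniSparseHistogram(mapped_pixel, count_at_pixel, order=1))

partial = aggregator.to_sparse()
print(partial.indexes, partial.counts)  # [ 3  7 40 41] [2 3 2 1]
```

Summing into one preallocated dense array avoids re-sparsifying after every chunk, and likely explains why the explicit shape-compatibility check in `read_histogram` could be dropped.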
68 changes: 30 additions & 38 deletions src/hats_import/catalog/resume_plan.py
@@ -3,16 +3,16 @@
from __future__ import annotations

import pickle
import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import hats.pixel_math.healpix_shim as hp
import numpy as np
from hats import pixel_math
from hats.io import file_io
from hats.pixel_math import empty_histogram
from hats.pixel_math.healpix_pixel import HealpixPixel
from hats.pixel_math.sparse_histogram import SparseHistogram
from hats.pixel_math.sparse_histogram import HistogramAggregator, SparseHistogram
from numpy import frombuffer
from upath import UPath

@@ -29,7 +29,7 @@ class ResumePlan(PipelineResumePlan):
"""list of files (and job keys) that have yet to be mapped"""
split_keys: List[Tuple[str, str]] = field(default_factory=list)
"""set of files (and job keys) that have yet to be split"""
destination_pixel_map: Optional[List[Tuple[int, int, int]]] = None
destination_pixel_map: Optional[Dict[HealpixPixel, int]] = None
"""Destination pixels and their expected final count"""
should_run_mapping: bool = True
should_run_splitting: bool = True
@@ -145,13 +145,11 @@ def get_remaining_map_keys(self):
Returns:
list of mapping keys *not* found in files like /resume/path/mapping_key.npz
"""
prefix = file_io.append_paths_to_pointer(self.tmp_path, self.HISTOGRAMS_DIR)
mapped_keys = self.get_keys_from_file_names(prefix, ".npz")
return [
(f"map_{i}", file_path)
for i, file_path in enumerate(self.input_paths)
if f"map_{i}" not in mapped_keys
]
prefix = file_io.get_upath(self.tmp_path) / self.HISTOGRAMS_DIR
map_file_pattern = re.compile(r"map_(\d+).npz")
done_indexes = [int(map_file_pattern.match(path.name).group(1)) for path in prefix.glob("*.npz")]
remaining_indexes = list(set(range(0, len(self.input_paths))) - (set(done_indexes)))
return [(f"map_{key}", self.input_paths[key]) for key in remaining_indexes]
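The remaining-work lookup just above no longer goes through a generic key-from-filename helper: it pulls the numeric index out of each `map_{i}.npz` partial and set-subtracts it from the full range of input indexes. A small standalone sketch of that idea (paths and names here are hypothetical, not the real ResumePlan layout):

```python
import re
from pathlib import Path

def remaining_map_keys(histogram_dir: Path, input_paths: list[str]):
    """Return (key, input_path) pairs whose map_{i}.npz partial is not on disk yet."""
    map_file_pattern = re.compile(r"map_(\d+)\.npz")
    done_indexes = {
        int(match.group(1))
        for path in histogram_dir.glob("*.npz")
        if (match := map_file_pattern.match(path.name))
    }
    remaining = sorted(set(range(len(input_paths))) - done_indexes)
    return [(f"map_{i}", input_paths[i]) for i in remaining]

# With map_0.npz and map_2.npz already written, only input 1 is left:
# remaining_map_keys(Path("/tmp/resume/histograms"), ["a.csv", "b.csv", "c.csv"])
# -> [("map_1", "b.csv")]
```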

def read_histogram(self, healpix_order):
"""Return histogram with healpix_order'd shape
@@ -167,20 +165,14 @@ def read_histogram(self, healpix_order):
if len(remaining_map_files) > 0:
raise RuntimeError(f"{len(remaining_map_files)} map stages did not complete successfully.")
histogram_files = file_io.find_files_matching_path(self.tmp_path, self.HISTOGRAMS_DIR, "*.npz")
aggregate_histogram = empty_histogram(healpix_order)
aggregate_histogram = HistogramAggregator(healpix_order)
for partial_file_name in histogram_files:
partial = SparseHistogram.from_file(partial_file_name)
partial_as_array = partial.to_array()
if aggregate_histogram.shape != partial_as_array.shape:
raise ValueError(
"The histogram partials have incompatible sizes due to different healpix orders. "
+ "To start the pipeline from scratch with the current order set `resume` to False."
)
aggregate_histogram = np.add(aggregate_histogram, partial_as_array)
aggregate_histogram.add(partial)

file_name = file_io.append_paths_to_pointer(self.tmp_path, self.HISTOGRAM_BINARY_FILE)
with open(file_name, "wb+") as file_handle:
file_handle.write(aggregate_histogram.data)
file_handle.write(aggregate_histogram.full_histogram)
if self.delete_resume_log_files:
file_io.remove_directory(
file_io.append_paths_to_pointer(self.tmp_path, self.HISTOGRAMS_DIR),
@@ -220,12 +212,11 @@ def get_remaining_split_keys(self):
Returns:
list of splitting keys *not* found in files like /resume/path/split_key.done
"""
split_keys = set(self.read_done_keys(self.SPLITTING_STAGE))
return [
(f"split_{i}", file_path)
for i, file_path in enumerate(self.input_paths)
if f"split_{i}" not in split_keys
]
prefix = file_io.get_upath(self.tmp_path) / self.SPLITTING_STAGE
split_file_pattern = re.compile(r"split_(\d+)_done")
done_indexes = [int(split_file_pattern.match(path.name).group(1)) for path in prefix.glob("*_done")]
remaining_indexes = list(set(range(0, len(self.input_paths))) - set(done_indexes))
return [(f"split_{key}", self.input_paths[key]) for key in remaining_indexes]

@classmethod
def splitting_key_done(cls, tmp_path, splitting_key: str):
@@ -308,11 +299,11 @@ def get_alignment_file(
with open(file_name, "rb") as pickle_file:
alignment = pickle.load(pickle_file)
non_none_elements = alignment[alignment != np.array(None)]
self.destination_pixel_map = np.unique(non_none_elements)
self.destination_pixel_map = [
(order, pix, count) for (order, pix, count) in self.destination_pixel_map if int(count) > 0
]
total_rows = sum(count for (_, _, count) in self.destination_pixel_map)
pixel_list = np.unique(non_none_elements)
self.destination_pixel_map = {
HealpixPixel(order, pix): count for (order, pix, count) in pixel_list if int(count) > 0
}
total_rows = sum(self.destination_pixel_map.values())
if total_rows != expected_total_rows:
raise ValueError(
f"Number of rows ({total_rows}) does not match expectation ({expected_total_rows})"
@@ -337,21 +328,22 @@ def get_reduce_items(self):
- number of rows expected for this pixel
- reduce key (string of destination order+pixel)
"""
reduced_keys = set(self.read_done_keys(self.REDUCING_STAGE))
if self.destination_pixel_map is None:
raise RuntimeError("destination pixel map not provided for progress tracking.")
reduce_items = [
(HealpixPixel(hp_order, hp_pixel), row_count, f"{hp_order}_{hp_pixel}")
for hp_order, hp_pixel, row_count in self.destination_pixel_map
if f"{hp_order}_{hp_pixel}" not in reduced_keys

reduced_pixels = self.read_done_pixels(self.REDUCING_STAGE)

remaining_pixels = list(set(self.destination_pixel_map.keys()) - set(reduced_pixels))
return [
(hp_pixel, self.destination_pixel_map[hp_pixel], f"{hp_pixel.order}_{hp_pixel.pixel}")
for hp_pixel in remaining_pixels
]
return reduce_items

def get_destination_pixels(self):
"""Create HealpixPixel list of all destination pixels."""
if self.destination_pixel_map is None:
raise RuntimeError("destination pixel map not known.")
return [HealpixPixel(hp_order, hp_pixel) for hp_order, hp_pixel, _ in self.destination_pixel_map]
return list(self.destination_pixel_map.keys())

def wait_for_reducing(self, futures):
"""Wait for reducing futures to complete."""
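Two related changes in this file: `destination_pixel_map` is now a `Dict[HealpixPixel, int]` keyed by pixel rather than a list of `(order, pixel, count)` tuples, and `get_reduce_items` computes the leftover work by set difference over pixels. A compact sketch with a stand-in pixel type (hashable, so it can key a dict and sit in a set):

```python
from typing import Dict, List, NamedTuple

class Pixel(NamedTuple):
    """Stand-in for HealpixPixel."""
    order: int
    pixel: int

def reduce_items(destination_pixel_map: Dict[Pixel, int], done: List[Pixel]):
    """(pixel, expected_row_count, reduce_key) for every pixel not yet reduced."""
    remaining = set(destination_pixel_map) - set(done)
    return [(p, destination_pixel_map[p], f"{p.order}_{p.pixel}") for p in remaining]

pixel_map = {Pixel(1, 44): 131, Pixel(1, 45): 29, Pixel(2, 190): 6}
print(reduce_items(pixel_map, done=[Pixel(1, 45)]))
# Two items remain: (Pixel(1, 44), 131, '1_44') and (Pixel(2, 190), 6, '2_190'); set order may vary.
```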
27 changes: 7 additions & 20 deletions src/hats_import/catalog/run_import.py
@@ -6,7 +6,6 @@

import os
import pickle
from pathlib import Path

import hats.io.file_io as io
from hats.catalog import PartitionInfo
@@ -19,27 +18,13 @@
from hats_import.catalog.resume_plan import ResumePlan


def _validate_arguments(args):
"""
Verify that the args for run are valid: they exist, are of the appropriate type,
and do not specify an output which is a valid catalog.
Raises ValueError if they are invalid.
"""
def run(args, client):
"""Run catalog creation pipeline."""
if not args:
raise ValueError("args is required and should be type ImportArguments")
if not isinstance(args, ImportArguments):
raise ValueError("args must be type ImportArguments")

potential_path = Path(args.output_path) / args.output_artifact_name
if is_valid_catalog(potential_path):
raise ValueError(f"Output path {potential_path} already contains a valid catalog")


def run(args, client):
"""Run catalog creation pipeline."""
_validate_arguments(args)

resume_plan = ResumePlan(import_args=args)

pickled_reader_file = os.path.join(resume_plan.tmp_path, "reader.pickle")
@@ -137,7 +122,7 @@ def run(args, client):

# All done - write out the metadata
if resume_plan.should_run_finishing:
with resume_plan.print_progress(total=4, stage_name="Finishing") as step_progress:
with resume_plan.print_progress(total=5, stage_name="Finishing") as step_progress:
partition_info = PartitionInfo.from_healpix(resume_plan.get_destination_pixels())
partition_info_file = paths.get_partition_info_pointer(args.catalog_path)
partition_info.write_to_file(partition_info_file)
@@ -151,12 +136,14 @@
else:
partition_info.write_to_metadata_files(args.catalog_path)
step_progress.update(1)
io.write_fits_image(raw_histogram, paths.get_point_map_file_pointer(args.catalog_path))
step_progress.update(1)
catalog_info = args.to_table_properties(
total_rows, partition_info.get_highest_order(), partition_info.calculate_fractional_coverage()
)
catalog_info.to_properties_file(args.catalog_path)
step_progress.update(1)
io.write_fits_image(raw_histogram, paths.get_point_map_file_pointer(args.catalog_path))
step_progress.update(1)
resume_plan.clean_resume_files()
step_progress.update(1)
assert is_valid_catalog(args.catalog_path)
step_progress.update(1)
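The finishing stage's progress total was bumped from 4 to 5, apparently because the closing `assert is_valid_catalog(...)` is now tracked as its own step (and the point-map FITS write moved after the properties file). A toy progress context manager, assuming nothing about the real `print_progress` implementation, showing why the declared total has to match the number of `update(1)` calls:

```python
from contextlib import contextmanager

@contextmanager
def print_progress(total: int, stage_name: str):
    """Toy stand-in for the pipeline's progress tracker (not the real implementation)."""
    class _Step:
        done = 0
        def update(self, n):
            self.done += n
            print(f"{stage_name}: {self.done}/{total}")
    step = _Step()
    yield step
    # If a new tracked step is added without bumping `total`, this makes it obvious.
    assert step.done == total, f"{stage_name}: ran {step.done} steps, declared {total}"

with print_progress(total=5, stage_name="Finishing") as step_progress:
    for _ in range(5):  # partition info, properties, point map, cleanup, validation
        step_progress.update(1)
```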
8 changes: 0 additions & 8 deletions src/hats_import/margin_cache/margin_cache.py
@@ -1,8 +1,5 @@
from pathlib import Path

from hats.catalog import PartitionInfo
from hats.io import file_io, parquet_metadata, paths
from hats.io.validation import is_valid_catalog

import hats_import.margin_cache.margin_cache_map_reduce as mcmr
from hats_import.margin_cache.margin_cache_resume_plan import MarginCachePlan
@@ -18,11 +15,6 @@ def generate_margin_cache(args, client):
args (MarginCacheArguments): A valid `MarginCacheArguments` object.
client (dask.distributed.Client): A dask distributed client object.
"""
potential_path = Path(args.output_path) / args.output_artifact_name
# Verify that the planned output path is not occupied by a valid catalog
if is_valid_catalog(potential_path):
raise ValueError(f"Output path {potential_path} already contains a valid catalog")

resume_plan = MarginCachePlan(args)
original_catalog_metadata = paths.get_common_metadata_pointer(args.input_catalog_path)

19 changes: 6 additions & 13 deletions src/hats_import/margin_cache/margin_cache_resume_plan.py
@@ -106,12 +106,9 @@ def get_mapping_total(self) -> int:

def get_remaining_map_keys(self):
"""Fetch a tuple for each pixel/partition left to map."""
map_keys = set(self.read_done_keys(self.MAPPING_STAGE))
return [
(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel)
for hp_pixel in self.partition_pixels
if f"{hp_pixel.order}_{hp_pixel.pixel}" not in map_keys
]
mapped_pixels = set(self.read_done_pixels(self.MAPPING_STAGE))
remaining_pixels = list(set(self.partition_pixels) - set(mapped_pixels))
return [(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel) for hp_pixel in remaining_pixels]

@classmethod
def reducing_key_done(cls, tmp_path, reducing_key: str):
@@ -125,13 +122,9 @@ def reducing_key_done(cls, tmp_path, reducing_key: str):

def get_remaining_reduce_keys(self):
"""Fetch a tuple for each object catalog pixel to reduce."""
reduced_keys = set(self.read_done_keys(self.REDUCING_STAGE))
reduce_items = [
(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel)
for hp_pixel in self.combined_pixels
if f"{hp_pixel.order}_{hp_pixel.pixel}" not in reduced_keys
]
return reduce_items
reduced_pixels = self.read_done_pixels(self.REDUCING_STAGE)
remaining_pixels = list(set(self.combined_pixels) - set(reduced_pixels))
return [(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel) for hp_pixel in remaining_pixels]

def is_reducing_done(self) -> bool:
"""Are there partitions left to reduce?"""
32 changes: 8 additions & 24 deletions src/hats_import/pipeline_resume_plan.py
@@ -119,43 +119,27 @@ def read_markers(self, stage_name: str) -> dict[str, list[str]]:
prefix = file_io.append_paths_to_pointer(self.tmp_path, stage_name)
result = {}
result_files = file_io.find_files_matching_path(prefix, "*_done")
done_file_pattern = re.compile(r"(.*)_done")
for file_path in result_files:
match = re.match(r"(.*)_done", str(file_path.name))
match = done_file_pattern.match(str(file_path.name))
if not match:
raise ValueError(f"Unexpected file found: {file_path.name}")
key = match.group(1)
result[key] = file_io.load_text_file(file_path)
return result

def read_done_keys(self, stage_name):
"""Inspect the stage's directory of done files, fetching the keys from done file names.
def read_done_pixels(self, stage_name):
"""Inspect the stage's directory of done files, fetching the pixel keys from done file names.
Args:
stage_name(str): name of the stage (e.g. mapping, reducing)
Return:
List[str] - all keys found in done directory
List[HealpixPixel] - all pixel keys found in done directory
"""
prefix = file_io.append_paths_to_pointer(self.tmp_path, stage_name)
return self.get_keys_from_file_names(prefix, "_done")

@staticmethod
def get_keys_from_file_names(directory, extension):
"""Gather keys for successful tasks from result file names.
Args:
directory: where to look for result files. this is NOT a recursive lookup
extension (str): file suffix to look for and to remove from all file names.
if you expect a file like "map_01.csv", extension should be ".csv"
Returns:
list of keys taken from files like /resume/path/{key}{extension}
"""
result_files = file_io.find_files_matching_path(directory, f"*{extension}")
keys = []
for file_path in result_files:
match = re.match(r"(.*)" + extension, str(file_path.name))
keys.append(match.group(1))
return keys
done_file_pattern = re.compile(r"(\d+)_(\d+)_done")
pixel_tuples = [done_file_pattern.match(path.name).group(1, 2) for path in prefix.glob("*_done")]
return [HealpixPixel(int(match[0]), int(match[1])) for match in pixel_tuples]

def clean_resume_files(self):
"""Remove the intermediate directory created in execution if the user decided
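The shared `get_keys_from_file_names` helper is gone; stages that mark completed work with `{order}_{pixel}_done` files now decode those names straight into pixels via `read_done_pixels`. A standalone sketch of the parsing step, returning plain `(order, pixel)` tuples where the real method builds `HealpixPixel` objects:

```python
import re
from pathlib import Path

def read_done_pixels(stage_dir: Path) -> list[tuple[int, int]]:
    """Parse every '{order}_{pixel}_done' marker file in a stage directory."""
    done_file_pattern = re.compile(r"(\d+)_(\d+)_done")
    pixels = []
    for path in sorted(stage_dir.glob("*_done")):
        match = done_file_pattern.match(path.name)
        if match:  # skip anything that does not look like an order_pixel marker
            pixels.append((int(match.group(1)), int(match.group(2))))
    return pixels

# A directory holding 0_11_done and 1_44_done yields [(0, 11), (1, 44)].
```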
3 changes: 3 additions & 0 deletions src/hats_import/runtime_arguments.py
@@ -9,6 +9,7 @@
from pathlib import Path

from hats.io import file_io
from hats.io.validation import is_valid_catalog
from upath import UPath

# pylint: disable=too-many-instance-attributes
@@ -89,6 +90,8 @@ def _check_arguments(self):
raise ValueError("dask_threads_per_worker should be greater than 0")

self.catalog_path = file_io.get_upath(self.output_path) / self.output_artifact_name
if is_valid_catalog(self.catalog_path):
raise ValueError(f"Output path {self.catalog_path} already contains a valid catalog")
if not self.resume:
file_io.remove_directory(self.catalog_path, ignore_errors=True)
file_io.make_directory(self.catalog_path, exist_ok=True)
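With this change every pipeline gets the same pre-flight guard: shared argument checking refuses to run if the planned output directory already holds a valid catalog, replacing the per-pipeline checks removed from run_import.py and margin_cache.py above. A minimal sketch of the guard; `is_valid_catalog` here is a placeholder, since the real check lives in `hats.io.validation`:

```python
from pathlib import Path

def is_valid_catalog(path: Path) -> bool:
    """Placeholder heuristic; the real check inspects catalog metadata files."""
    return (path / "properties").exists()

def check_output_path(output_path: str, output_artifact_name: str) -> Path:
    """Raise rather than silently overwrite an existing, valid catalog."""
    catalog_path = Path(output_path) / output_artifact_name
    if is_valid_catalog(catalog_path):
        raise ValueError(f"Output path {catalog_path} already contains a valid catalog")
    return catalog_path
```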