diff --git a/src/hats_import/catalog/resume_plan.py b/src/hats_import/catalog/resume_plan.py
index 3748b64f..5cec0e01 100644
--- a/src/hats_import/catalog/resume_plan.py
+++ b/src/hats_import/catalog/resume_plan.py
@@ -3,8 +3,9 @@
 from __future__ import annotations
 
 import pickle
+import re
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import hats.pixel_math.healpix_shim as hp
 import numpy as np
@@ -28,7 +29,7 @@ class ResumePlan(PipelineResumePlan):
     """list of files (and job keys) that have yet to be mapped"""
     split_keys: List[Tuple[str, str]] = field(default_factory=list)
     """set of files (and job keys) that have yet to be split"""
-    destination_pixel_map: Optional[List[Tuple[int, int, int]]] = None
+    destination_pixel_map: Optional[Dict[HealpixPixel, int]] = None
     """Destination pixels and their expected final count"""
     should_run_mapping: bool = True
     should_run_splitting: bool = True
@@ -144,13 +145,11 @@ def get_remaining_map_keys(self):
         Returns:
             list of mapping keys *not* found in files like /resume/path/mapping_key.npz
         """
-        prefix = file_io.append_paths_to_pointer(self.tmp_path, self.HISTOGRAMS_DIR)
-        mapped_keys = self.get_keys_from_file_names(prefix, ".npz")
-        return [
-            (f"map_{i}", file_path)
-            for i, file_path in enumerate(self.input_paths)
-            if f"map_{i}" not in mapped_keys
-        ]
+        prefix = file_io.get_upath(self.tmp_path) / self.HISTOGRAMS_DIR
+        map_file_pattern = re.compile(r"map_(\d+)\.npz")
+        done_indexes = [int(map_file_pattern.match(path.name).group(1)) for path in prefix.glob("*.npz")]
+        remaining_indexes = list(set(range(0, len(self.input_paths))) - set(done_indexes))
+        return [(f"map_{key}", self.input_paths[key]) for key in remaining_indexes]
 
     def read_histogram(self, healpix_order):
         """Return histogram with healpix_order'd shape
@@ -213,12 +212,11 @@ def get_remaining_split_keys(self):
         Returns:
             list of splitting keys *not* found in files like /resume/path/split_key.done
         """
-        split_keys = set(self.read_done_keys(self.SPLITTING_STAGE))
-        return [
-            (f"split_{i}", file_path)
-            for i, file_path in enumerate(self.input_paths)
-            if f"split_{i}" not in split_keys
-        ]
+        prefix = file_io.get_upath(self.tmp_path) / self.SPLITTING_STAGE
+        split_file_pattern = re.compile(r"split_(\d+)_done")
+        done_indexes = [int(split_file_pattern.match(path.name).group(1)) for path in prefix.glob("*_done")]
+        remaining_indexes = list(set(range(0, len(self.input_paths))) - set(done_indexes))
+        return [(f"split_{key}", self.input_paths[key]) for key in remaining_indexes]
 
     @classmethod
     def splitting_key_done(cls, tmp_path, splitting_key: str):
@@ -301,11 +299,11 @@ def get_alignment_file(
             with open(file_name, "rb") as pickle_file:
                 alignment = pickle.load(pickle_file)
             non_none_elements = alignment[alignment != np.array(None)]
-            self.destination_pixel_map = np.unique(non_none_elements)
-            self.destination_pixel_map = [
-                (order, pix, count) for (order, pix, count) in self.destination_pixel_map if int(count) > 0
-            ]
-            total_rows = sum(count for (_, _, count) in self.destination_pixel_map)
+            pixel_list = np.unique(non_none_elements)
+            self.destination_pixel_map = {
+                HealpixPixel(order, pix): count for (order, pix, count) in pixel_list if int(count) > 0
+            }
+            total_rows = sum(self.destination_pixel_map.values())
             if total_rows != expected_total_rows:
                 raise ValueError(
                     f"Number of rows ({total_rows}) does not match expectation ({expected_total_rows})"
@@ -330,21 +328,22 @@ def get_reduce_items(self):
             - number of rows expected for this pixel
             - reduce key (string of destination order+pixel)
         """
-        reduced_keys = set(self.read_done_keys(self.REDUCING_STAGE))
         if self.destination_pixel_map is None:
             raise RuntimeError("destination pixel map not provided for progress tracking.")
-        reduce_items = [
-            (HealpixPixel(hp_order, hp_pixel), row_count, f"{hp_order}_{hp_pixel}")
-            for hp_order, hp_pixel, row_count in self.destination_pixel_map
-            if f"{hp_order}_{hp_pixel}" not in reduced_keys
+
+        reduced_pixels = self.read_done_pixels(self.REDUCING_STAGE)
+
+        remaining_pixels = list(set(self.destination_pixel_map.keys()) - set(reduced_pixels))
+        return [
+            (hp_pixel, self.destination_pixel_map[hp_pixel], f"{hp_pixel.order}_{hp_pixel.pixel}")
+            for hp_pixel in remaining_pixels
         ]
-        return reduce_items
 
     def get_destination_pixels(self):
         """Create HealpixPixel list of all destination pixels."""
         if self.destination_pixel_map is None:
             raise RuntimeError("destination pixel map not known.")
-        return [HealpixPixel(hp_order, hp_pixel) for hp_order, hp_pixel, _ in self.destination_pixel_map]
+        return list(self.destination_pixel_map.keys())
 
     def wait_for_reducing(self, futures):
         """Wait for reducing futures to complete."""
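Reviewer note on the two rewritten methods above: remaining work is now computed by parsing the task index out of each result file name and taking a set difference against the full input range, rather than testing a string key against a list of done keys for every input (which scaled as inputs x done-files). A minimal standalone sketch of the pattern, with illustrative paths and names only; it also adds an `if m` guard that the patched code omits, so a stray non-matching .npz file is skipped instead of raising AttributeError:

    import re
    import tempfile
    from pathlib import Path

    # Hypothetical layout: four inputs, two already mapped.
    histogram_dir = Path(tempfile.mkdtemp())
    input_paths = [f"input_{i}.parquet" for i in range(4)]
    for i in (0, 2):
        (histogram_dir / f"map_{i}.npz").touch()

    # Parse the task index out of each partial result, e.g. "map_2.npz" -> 2.
    pattern = re.compile(r"map_(\d+)\.npz")
    matches = (pattern.match(p.name) for p in histogram_dir.glob("*.npz"))
    done_indexes = {int(m.group(1)) for m in matches if m}

    # The set difference is exactly the work that still needs to run.
    remaining = sorted(set(range(len(input_paths))) - done_indexes)
    print([(f"map_{i}", input_paths[i]) for i in remaining])
    # [('map_1', 'input_1.parquet'), ('map_3', 'input_3.parquet')]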
diff --git a/src/hats_import/margin_cache/margin_cache_resume_plan.py b/src/hats_import/margin_cache/margin_cache_resume_plan.py
index a26608b4..030bd920 100644
--- a/src/hats_import/margin_cache/margin_cache_resume_plan.py
+++ b/src/hats_import/margin_cache/margin_cache_resume_plan.py
@@ -106,12 +106,9 @@ def get_mapping_total(self) -> int:
 
     def get_remaining_map_keys(self):
         """Fetch a tuple for each pixel/partition left to map."""
-        map_keys = set(self.read_done_keys(self.MAPPING_STAGE))
-        return [
-            (f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel)
-            for hp_pixel in self.partition_pixels
-            if f"{hp_pixel.order}_{hp_pixel.pixel}" not in map_keys
-        ]
+        mapped_pixels = set(self.read_done_pixels(self.MAPPING_STAGE))
+        remaining_pixels = list(set(self.partition_pixels) - mapped_pixels)
+        return [(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel) for hp_pixel in remaining_pixels]
 
     @classmethod
     def reducing_key_done(cls, tmp_path, reducing_key: str):
@@ -125,13 +122,9 @@ def reducing_key_done(cls, tmp_path, reducing_key: str):
 
     def get_remaining_reduce_keys(self):
         """Fetch a tuple for each object catalog pixel to reduce."""
-        reduced_keys = set(self.read_done_keys(self.REDUCING_STAGE))
-        reduce_items = [
-            (f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel)
-            for hp_pixel in self.combined_pixels
-            if f"{hp_pixel.order}_{hp_pixel.pixel}" not in reduced_keys
-        ]
-        return reduce_items
+        reduced_pixels = self.read_done_pixels(self.REDUCING_STAGE)
+        remaining_pixels = list(set(self.combined_pixels) - set(reduced_pixels))
+        return [(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel) for hp_pixel in remaining_pixels]
 
     def is_reducing_done(self) -> bool:
         """Are there partitions left to reduce?"""
diff --git a/src/hats_import/pipeline_resume_plan.py b/src/hats_import/pipeline_resume_plan.py
index d7ded581..449ab53c 100644
--- a/src/hats_import/pipeline_resume_plan.py
+++ b/src/hats_import/pipeline_resume_plan.py
@@ -119,43 +119,27 @@ def read_markers(self, stage_name: str) -> dict[str, list[str]]:
         prefix = file_io.append_paths_to_pointer(self.tmp_path, stage_name)
         result = {}
         result_files = file_io.find_files_matching_path(prefix, "*_done")
+        done_file_pattern = re.compile(r"(.*)_done")
         for file_path in result_files:
-            match = re.match(r"(.*)_done", str(file_path.name))
+            match = done_file_pattern.match(str(file_path.name))
             if not match:
                 raise ValueError(f"Unexpected file found: {file_path.name}")
             key = match.group(1)
             result[key] = file_io.load_text_file(file_path)
         return result
 
-    def read_done_keys(self, stage_name):
-        """Inspect the stage's directory of done files, fetching the keys from done file names.
+    def read_done_pixels(self, stage_name):
+        """Inspect the stage's directory of done files, fetching the pixel keys from done file names.
 
         Args:
             stage_name(str): name of the stage (e.g. mapping, reducing)
         Return:
-            List[str] - all keys found in done directory
+            List[HealpixPixel] - all pixel keys found in done directory
         """
         prefix = file_io.append_paths_to_pointer(self.tmp_path, stage_name)
-        return self.get_keys_from_file_names(prefix, "_done")
-
-    @staticmethod
-    def get_keys_from_file_names(directory, extension):
-        """Gather keys for successful tasks from result file names.
-
-        Args:
-            directory: where to look for result files. this is NOT a recursive lookup
-            extension (str): file suffix to look for and to remove from all file names.
-                if you expect a file like "map_01.csv", extension should be ".csv"
-
-        Returns:
-            list of keys taken from files like /resume/path/{key}{extension}
-        """
-        result_files = file_io.find_files_matching_path(directory, f"*{extension}")
-        keys = []
-        for file_path in result_files:
-            match = re.match(r"(.*)" + extension, str(file_path.name))
-            keys.append(match.group(1))
-        return keys
+        done_file_pattern = re.compile(r"(\d+)_(\d+)_done")
+        pixel_tuples = [done_file_pattern.match(path.name).group(1, 2) for path in prefix.glob("*_done")]
+        return [HealpixPixel(int(match[0]), int(match[1])) for match in pixel_tuples]
 
     def clean_resume_files(self):
         """Remove the intermediate directory created in execution if the user decided
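The `(\d+)_(\d+)_done` pattern above assumes every marker in a stage directory is named `{order}_{pixel}_done`; a stage keyed by anything else would need its own reader. A self-contained sketch of the parse, using a stand-in for hats' HealpixPixel (the real class is hashable, which is what lets the callers do set arithmetic on the result):

    import re
    import tempfile
    from pathlib import Path
    from typing import NamedTuple

    class Pixel(NamedTuple):
        """Stand-in for hats.pixel_math.healpix_pixel.HealpixPixel."""
        order: int
        pixel: int

    # Hypothetical stage directory with two completed tasks.
    stage_dir = Path(tempfile.mkdtemp())
    for name in ("0_11_done", "1_44_done"):
        (stage_dir / name).touch()

    done_file_pattern = re.compile(r"(\d+)_(\d+)_done")
    matches = (done_file_pattern.match(p.name) for p in stage_dir.glob("*_done"))
    done_pixels = [Pixel(int(m.group(1)), int(m.group(2))) for m in matches if m]
    print(sorted(done_pixels))  # [Pixel(order=0, pixel=11), Pixel(order=1, pixel=44)]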
diff --git a/src/hats_import/soap/resume_plan.py b/src/hats_import/soap/resume_plan.py
index bef61335..93a3d710 100644
--- a/src/hats_import/soap/resume_plan.py
+++ b/src/hats_import/soap/resume_plan.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import re
 from dataclasses import dataclass, field
 from typing import List, Optional, Tuple
 
@@ -97,12 +98,16 @@ def get_sources_to_count(self, source_pixel_map=None):
             self.source_pixel_map = source_pixel_map
         if self.source_pixel_map is None:
             raise ValueError("source_pixel_map not provided for progress tracking.")
+        count_file_pattern = re.compile(r"(\d+)_(\d+)\.csv")
+        counted_pixel_tuples = [
+            count_file_pattern.match(path.name).group(1, 2) for path in self.tmp_path.glob("*.csv")
+        ]
+        counted_pixels = [HealpixPixel(int(match[0]), int(match[1])) for match in counted_pixel_tuples]
 
-        counted_keys = set(self.get_keys_from_file_names(self.tmp_path, ".csv"))
+        remaining_pixels = list(set(self.source_pixel_map.keys()) - set(counted_pixels))
         return [
-            (hp_pixel, object_pixels, f"{hp_pixel.order}_{hp_pixel.pixel}")
-            for hp_pixel, object_pixels in source_pixel_map.items()
-            if f"{hp_pixel.order}_{hp_pixel.pixel}" not in counted_keys
+            (hp_pixel, self.source_pixel_map[hp_pixel], f"{hp_pixel.order}_{hp_pixel.pixel}")
+            for hp_pixel in remaining_pixels
         ]
 
     @classmethod
@@ -117,13 +122,9 @@ def reducing_key_done(cls, tmp_path, reducing_key: str):
 
     def get_objects_to_reduce(self):
         """Fetch a tuple for each object catalog pixel to reduce."""
-        reduced_keys = set(self.read_done_keys(self.REDUCING_STAGE))
-        reduce_items = [
-            (hp_pixel, f"{hp_pixel.order}_{hp_pixel.pixel}")
-            for hp_pixel in self.object_catalog.get_healpix_pixels()
-            if f"{hp_pixel.order}_{hp_pixel.pixel}" not in reduced_keys
-        ]
-        return reduce_items
+        reduced_pixels = set(self.read_done_pixels(self.REDUCING_STAGE))
+        remaining_pixels = list(set(self.object_catalog.get_healpix_pixels()) - reduced_pixels)
+        return [(hp_pixel, f"{hp_pixel.order}_{hp_pixel.pixel}") for hp_pixel in remaining_pixels]
 
     def is_reducing_done(self) -> bool:
         """Are there partitions left to reduce?"""
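One behavioral nuance shared by these rewrites: `list(set(...) - set(...))` yields pixels in arbitrary hash order, whereas the old comprehensions preserved catalog order. If deterministic task ordering ever matters (reproducible progress logs, stable retries), a `sorted()` over the difference restores it cheaply. Illustrative sketch only, with plain `(order, pixel)` tuples standing in for HealpixPixel:

    # Hypothetical source map: pixel -> object pixels to join against.
    source_pixel_map = {(0, 4): ["obj_a"], (0, 11): ["obj_b"], (1, 44): ["obj_c"]}
    counted_pixels = {(0, 4)}

    # sorted() gives a stable plan; (order, pixel) tuples sort naturally.
    remaining = sorted(set(source_pixel_map) - counted_pixels)
    tasks = [(pix, source_pixel_map[pix], f"{pix[0]}_{pix[1]}") for pix in remaining]
    print(tasks)  # [((0, 11), ['obj_b'], '0_11'), ((1, 44), ['obj_c'], '1_44')]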
diff --git a/tests/hats_import/catalog/test_resume_plan.py b/tests/hats_import/catalog/test_resume_plan.py
index 61afd5cd..ca8f0535 100644
--- a/tests/hats_import/catalog/test_resume_plan.py
+++ b/tests/hats_import/catalog/test_resume_plan.py
@@ -3,6 +3,7 @@
 import numpy as np
 import numpy.testing as npt
 import pytest
+from hats.pixel_math.healpix_pixel import HealpixPixel
 from hats.pixel_math.sparse_histogram import SparseHistogram
 
 from hats_import.catalog.resume_plan import ResumePlan
@@ -68,6 +69,23 @@ def test_same_input_paths(tmp_path, small_sky_single_file, formats_headers_csv):
     )
 
 
+def test_remaining_map_keys(tmp_path):
+    """Test generation of remaining mapping keys as partial histogram files appear."""
+    num_inputs = 1_000
+    input_paths = [f"foo_{i}" for i in range(0, num_inputs)]
+    plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=input_paths)
+
+    remaining_keys = plan.get_remaining_map_keys()
+    assert len(remaining_keys) == num_inputs
+
+    histogram = SparseHistogram([11], [131], 0)
+    for i in range(0, num_inputs):
+        histogram.to_file(ResumePlan.partial_histogram_file(tmp_path=tmp_path, mapping_key=f"map_{i}"))
+
+    remaining_keys = plan.get_remaining_map_keys()
+    assert len(remaining_keys) == 0
+
+
 def test_read_write_histogram(tmp_path):
     """Test that we can read what we write into a histogram file."""
     plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=["foo1"])
@@ -171,7 +189,6 @@ def test_some_split_task_failures(tmp_path, dask_client):
 def test_get_reduce_items(tmp_path):
     """Test generation of remaining reduce items"""
-    destination_pixel_map = [(0, 11, 131)]
     plan = ResumePlan(tmp_path=tmp_path, progress_bar=False)
 
     with pytest.raises(RuntimeError, match="destination pixel map"):
         remaining_reduce_items = plan.get_reduce_items()
@@ -180,7 +197,7 @@ def test_get_reduce_items(tmp_path):
     with pytest.raises(RuntimeError, match="destination pixel map"):
         remaining_reduce_items = plan.get_destination_pixels()
 
-    plan.destination_pixel_map = destination_pixel_map
+    plan.destination_pixel_map = {HealpixPixel(0, 11): 131}
     remaining_reduce_items = plan.get_reduce_items()
     assert len(remaining_reduce_items) == 1
 
@@ -197,8 +214,7 @@ def test_some_reduce_task_failures(tmp_path, dask_client):
     """Test that we only consider reduce stage successful if all done files are written"""
     plan = ResumePlan(tmp_path=tmp_path, progress_bar=False)
 
-    destination_pixel_map = [(0, 11, 131)]
-    plan.destination_pixel_map = destination_pixel_map
+    plan.destination_pixel_map = {HealpixPixel(0, 11): 131}
     remaining_reduce_items = plan.get_reduce_items()
     assert len(remaining_reduce_items) == 1
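The new test covers only the mapping side; a companion for `get_remaining_split_keys` would follow the same shape, driving the `split_{i}_done` markers through the `splitting_key_done` classmethod visible in the source diff above. A possible sketch, assuming that helper touches exactly the marker file the new glob looks for:

    def test_remaining_split_keys(tmp_path):
        """Sketch: remaining split keys shrink as done markers appear."""
        num_inputs = 100
        input_paths = [f"foo_{i}" for i in range(num_inputs)]
        plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=input_paths)
        assert len(plan.get_remaining_split_keys()) == num_inputs

        for i in range(num_inputs):
            ResumePlan.splitting_key_done(tmp_path=tmp_path, splitting_key=f"split_{i}")
        assert len(plan.get_remaining_split_keys()) == 0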
diff --git a/tests/hats_import/test_pipeline_resume_plan.py b/tests/hats_import/test_pipeline_resume_plan.py
index ab74db56..7db94655 100644
--- a/tests/hats_import/test_pipeline_resume_plan.py
+++ b/tests/hats_import/test_pipeline_resume_plan.py
@@ -8,29 +8,6 @@
 from hats_import.pipeline_resume_plan import PipelineResumePlan, get_formatted_stage_name
 
 
-def test_done_key(tmp_path):
-    """Verify expected behavior of marking stage progress via done files."""
-    plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False)
-    stage = "testing"
-    (tmp_path / stage).mkdir(parents=True)
-
-    keys = plan.read_done_keys(stage)
-    assert len(keys) == 0
-
-    PipelineResumePlan.touch_key_done_file(tmp_path, stage, "key_01")
-    keys = plan.read_done_keys(stage)
-    assert keys == ["key_01"]
-
-    PipelineResumePlan.touch_key_done_file(tmp_path, stage, "key_02")
-    keys = plan.read_done_keys(stage)
-    assert keys == ["key_01", "key_02"]
-
-    plan.clean_resume_files()
-
-    keys = plan.read_done_keys(stage)
-    assert len(keys) == 0
-
-
 def test_done_file(tmp_path):
     """Verify expected behavior of done file"""
     plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False)
@@ -53,25 +30,6 @@ def test_done_file(tmp_path):
     assert not plan.done_file_exists(done_file)
 
 
-def test_get_keys_from_results(tmp_path):
-    """Test that we can create a list of completed keys via the output results files."""
-    plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False, resume=False)
-    keys = PipelineResumePlan.get_keys_from_file_names(tmp_path, ".foo")
-    assert len(keys) == 0
-
-    Path(tmp_path, "file_0.foo").touch()
-    keys = PipelineResumePlan.get_keys_from_file_names(tmp_path, ".foo")
-    assert keys == ["file_0"]
-
-    Path(tmp_path, "file_1.foo").touch()
-    keys = PipelineResumePlan.get_keys_from_file_names(tmp_path, ".foo")
-    assert keys == ["file_0", "file_1"]
-
-    plan.clean_resume_files()
-    keys = PipelineResumePlan.get_keys_from_file_names(tmp_path, ".foo")
-    assert len(keys) == 0
-
-
 def test_safe_to_resume(tmp_path):
     """Check that we throw errors when it's not safe to resume.""" 
     plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False, resume=False)
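Deleting `test_done_key` and `test_get_keys_from_results` removes the only direct coverage of done-marker parsing; the new `read_done_pixels` is exercised only indirectly through stage-level tests. A possible direct replacement, assuming a `from hats.pixel_math.healpix_pixel import HealpixPixel` import at the top of this module and reusing the existing `touch_key_done_file` helper from the deleted test:

    def test_read_done_pixels(tmp_path):
        """Sketch: pixel keys are recovered from {order}_{pixel}_done markers."""
        plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False)
        stage = "testing"
        (tmp_path / stage).mkdir(parents=True)
        assert plan.read_done_pixels(stage) == []

        # Marker name must follow the {order}_{pixel} convention the regex expects.
        PipelineResumePlan.touch_key_done_file(tmp_path, stage, "0_11")
        assert plan.read_done_pixels(stage) == [HealpixPixel(0, 11)]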