Address slow list operations in progress tracking. (#468)
* Address slow list operations in progress tracking.

* Dict comprehension

* Apply suggestions from code review

Co-authored-by: Derek T. Jones <[email protected]>

* Compile regular expressions.

---------

Co-authored-by: Derek T. Jones <[email protected]>
delucchi-cmu and gitosaurus authored Jan 8, 2025
1 parent bacabec commit c161725
Showing 6 changed files with 72 additions and 121 deletions.
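The core change, illustrated: the progress-tracking code stops testing membership against Python lists inside comprehensions (an O(n^2) pass over the inputs), parses the finished work once, takes a set difference, and reuses compiled regular expressions. The sketch below is a hypothetical illustration of that before/after pattern, not code from hats_import; the helper names and file naming are placeholders.

```python
import re

def remaining_before(input_paths, done_keys):
    # Old shape: "key not in done_keys" scans a list for every input file,
    # so the whole comprehension is O(n^2) for n inputs.
    return [
        (f"map_{i}", path)
        for i, path in enumerate(input_paths)
        if f"map_{i}" not in done_keys
    ]

# Compiled once and reused, instead of re-compiled on every match.
MAP_FILE_PATTERN = re.compile(r"map_(\d+).npz")

def remaining_after(input_paths, done_file_names):
    # New shape: pull the finished indexes out of the done-file names once,
    # then a single set difference yields the remaining work.
    done_indexes = {int(MAP_FILE_PATTERN.match(name).group(1)) for name in done_file_names}
    remaining = set(range(len(input_paths))) - done_indexes
    return [(f"map_{i}", input_paths[i]) for i in sorted(remaining)]
```

For example, remaining_after(["a.csv", "b.csv", "c.csv"], ["map_0.npz", "map_2.npz"]) leaves only the entry for index 1; with tens of thousands of input files the set-based version avoids the quadratic scan entirely.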
53 changes: 26 additions & 27 deletions src/hats_import/catalog/resume_plan.py
@@ -3,8 +3,9 @@
from __future__ import annotations

import pickle
import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import hats.pixel_math.healpix_shim as hp
import numpy as np
@@ -28,7 +29,7 @@ class ResumePlan(PipelineResumePlan):
"""list of files (and job keys) that have yet to be mapped"""
split_keys: List[Tuple[str, str]] = field(default_factory=list)
"""set of files (and job keys) that have yet to be split"""
destination_pixel_map: Optional[List[Tuple[int, int, int]]] = None
destination_pixel_map: Optional[Dict[HealpixPixel, int]] = None
"""Destination pixels and their expected final count"""
should_run_mapping: bool = True
should_run_splitting: bool = True
@@ -144,13 +145,11 @@ def get_remaining_map_keys(self):
Returns:
list of mapping keys *not* found in files like /resume/path/mapping_key.npz
"""
prefix = file_io.append_paths_to_pointer(self.tmp_path, self.HISTOGRAMS_DIR)
mapped_keys = self.get_keys_from_file_names(prefix, ".npz")
return [
(f"map_{i}", file_path)
for i, file_path in enumerate(self.input_paths)
if f"map_{i}" not in mapped_keys
]
prefix = file_io.get_upath(self.tmp_path) / self.HISTOGRAMS_DIR
map_file_pattern = re.compile(r"map_(\d+).npz")
done_indexes = [int(map_file_pattern.match(path.name).group(1)) for path in prefix.glob("*.npz")]
remaining_indexes = list(set(range(0, len(self.input_paths))) - (set(done_indexes)))
return [(f"map_{key}", self.input_paths[key]) for key in remaining_indexes]

def read_histogram(self, healpix_order):
"""Return histogram with healpix_order'd shape
@@ -213,12 +212,11 @@ def get_remaining_split_keys(self):
Returns:
list of splitting keys *not* found in files like /resume/path/split_key.done
"""
split_keys = set(self.read_done_keys(self.SPLITTING_STAGE))
return [
(f"split_{i}", file_path)
for i, file_path in enumerate(self.input_paths)
if f"split_{i}" not in split_keys
]
prefix = file_io.get_upath(self.tmp_path) / self.SPLITTING_STAGE
split_file_pattern = re.compile(r"split_(\d+)_done")
done_indexes = [int(split_file_pattern.match(path.name).group(1)) for path in prefix.glob("*_done")]
remaining_indexes = list(set(range(0, len(self.input_paths))) - set(done_indexes))
return [(f"split_{key}", self.input_paths[key]) for key in remaining_indexes]

@classmethod
def splitting_key_done(cls, tmp_path, splitting_key: str):
@@ -301,11 +299,11 @@ def get_alignment_file(
with open(file_name, "rb") as pickle_file:
alignment = pickle.load(pickle_file)
non_none_elements = alignment[alignment != np.array(None)]
self.destination_pixel_map = np.unique(non_none_elements)
self.destination_pixel_map = [
(order, pix, count) for (order, pix, count) in self.destination_pixel_map if int(count) > 0
]
total_rows = sum(count for (_, _, count) in self.destination_pixel_map)
pixel_list = np.unique(non_none_elements)
self.destination_pixel_map = {
HealpixPixel(order, pix): count for (order, pix, count) in pixel_list if int(count) > 0
}
total_rows = sum(self.destination_pixel_map.values())
if total_rows != expected_total_rows:
raise ValueError(
f"Number of rows ({total_rows}) does not match expectation ({expected_total_rows})"
@@ -330,21 +328,22 @@ def get_reduce_items(self):
- number of rows expected for this pixel
- reduce key (string of destination order+pixel)
"""
reduced_keys = set(self.read_done_keys(self.REDUCING_STAGE))
if self.destination_pixel_map is None:
raise RuntimeError("destination pixel map not provided for progress tracking.")
reduce_items = [
(HealpixPixel(hp_order, hp_pixel), row_count, f"{hp_order}_{hp_pixel}")
for hp_order, hp_pixel, row_count in self.destination_pixel_map
if f"{hp_order}_{hp_pixel}" not in reduced_keys

reduced_pixels = self.read_done_pixels(self.REDUCING_STAGE)

remaining_pixels = list(set(self.destination_pixel_map.keys()) - set(reduced_pixels))
return [
(hp_pixel, self.destination_pixel_map[hp_pixel], f"{hp_pixel.order}_{hp_pixel.pixel}")
for hp_pixel in remaining_pixels
]
return reduce_items

def get_destination_pixels(self):
"""Create HealpixPixel list of all destination pixels."""
if self.destination_pixel_map is None:
raise RuntimeError("destination pixel map not known.")
return [HealpixPixel(hp_order, hp_pixel) for hp_order, hp_pixel, _ in self.destination_pixel_map]
return list(self.destination_pixel_map.keys())

def wait_for_reducing(self, futures):
"""Wait for reducing futures to complete."""
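In the catalog plan above, destination_pixel_map changes from a list of (order, pixel, count) tuples to a dictionary keyed by HealpixPixel, so per-pixel counts, totals, and the pixel list all come straight from dict operations. A small sketch of how the reworked field is used; the pixel values are made up, and only the HealpixPixel class is real.

```python
from hats.pixel_math.healpix_pixel import HealpixPixel

# Hypothetical destination map in the new shape: HealpixPixel -> expected row count.
destination_pixel_map = {HealpixPixel(0, 11): 131, HealpixPixel(1, 44): 12}

count_for_pixel = destination_pixel_map[HealpixPixel(0, 11)]  # direct lookup, no tuple scan
total_rows = sum(destination_pixel_map.values())              # as in get_alignment_file
all_pixels = list(destination_pixel_map.keys())                # as in get_destination_pixels
```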
19 changes: 6 additions & 13 deletions src/hats_import/margin_cache/margin_cache_resume_plan.py
@@ -106,12 +106,9 @@ def get_mapping_total(self) -> int:

def get_remaining_map_keys(self):
"""Fetch a tuple for each pixel/partition left to map."""
map_keys = set(self.read_done_keys(self.MAPPING_STAGE))
return [
(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel)
for hp_pixel in self.partition_pixels
if f"{hp_pixel.order}_{hp_pixel.pixel}" not in map_keys
]
mapped_pixels = set(self.read_done_pixels(self.MAPPING_STAGE))
remaining_pixels = list(set(self.partition_pixels) - set(mapped_pixels))
return [(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel) for hp_pixel in remaining_pixels]

@classmethod
def reducing_key_done(cls, tmp_path, reducing_key: str):
@@ -125,13 +122,9 @@ def reducing_key_done(cls, tmp_path, reducing_key: str):

def get_remaining_reduce_keys(self):
"""Fetch a tuple for each object catalog pixel to reduce."""
reduced_keys = set(self.read_done_keys(self.REDUCING_STAGE))
reduce_items = [
(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel)
for hp_pixel in self.combined_pixels
if f"{hp_pixel.order}_{hp_pixel.pixel}" not in reduced_keys
]
return reduce_items
reduced_pixels = self.read_done_pixels(self.REDUCING_STAGE)
remaining_pixels = list(set(self.combined_pixels) - set(reduced_pixels))
return [(f"{hp_pixel.order}_{hp_pixel.pixel}", hp_pixel) for hp_pixel in remaining_pixels]

def is_reducing_done(self) -> bool:
"""Are there partitions left to reduce?"""
32 changes: 8 additions & 24 deletions src/hats_import/pipeline_resume_plan.py
@@ -119,43 +119,27 @@ def read_markers(self, stage_name: str) -> dict[str, list[str]]:
prefix = file_io.append_paths_to_pointer(self.tmp_path, stage_name)
result = {}
result_files = file_io.find_files_matching_path(prefix, "*_done")
done_file_pattern = re.compile(r"(.*)_done")
for file_path in result_files:
match = re.match(r"(.*)_done", str(file_path.name))
match = done_file_pattern.match(str(file_path.name))
if not match:
raise ValueError(f"Unexpected file found: {file_path.name}")
key = match.group(1)
result[key] = file_io.load_text_file(file_path)
return result

def read_done_keys(self, stage_name):
"""Inspect the stage's directory of done files, fetching the keys from done file names.
def read_done_pixels(self, stage_name):
"""Inspect the stage's directory of done files, fetching the pixel keys from done file names.
Args:
stage_name(str): name of the stage (e.g. mapping, reducing)
Return:
List[str] - all keys found in done directory
List[HealpixPixel] - all pixel keys found in done directory
"""
prefix = file_io.append_paths_to_pointer(self.tmp_path, stage_name)
return self.get_keys_from_file_names(prefix, "_done")

@staticmethod
def get_keys_from_file_names(directory, extension):
"""Gather keys for successful tasks from result file names.
Args:
directory: where to look for result files. this is NOT a recursive lookup
extension (str): file suffix to look for and to remove from all file names.
if you expect a file like "map_01.csv", extension should be ".csv"
Returns:
list of keys taken from files like /resume/path/{key}{extension}
"""
result_files = file_io.find_files_matching_path(directory, f"*{extension}")
keys = []
for file_path in result_files:
match = re.match(r"(.*)" + extension, str(file_path.name))
keys.append(match.group(1))
return keys
done_file_pattern = re.compile(r"(\d+)_(\d+)_done")
pixel_tuples = [done_file_pattern.match(path.name).group(1, 2) for path in prefix.glob("*_done")]
return [HealpixPixel(int(match[0]), int(match[1])) for match in pixel_tuples]

def clean_resume_files(self):
"""Remove the intermediate directory created in execution if the user decided
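read_done_pixels replaces the generic get_keys_from_file_names helper removed above: it assumes every done marker is named "{order}_{pixel}_done" and parses both numbers with one precompiled regex. Below is a stand-alone sketch of that parsing under the same naming assumption, returning plain (order, pixel) tuples so the example has no hats dependency.

```python
import re
from pathlib import Path

DONE_FILE_PATTERN = re.compile(r"(\d+)_(\d+)_done")  # "<order>_<pixel>_done"

def parse_done_pixels(stage_dir: Path):
    # Collect (order, pixel) for every completed marker file in the stage directory.
    pixels = []
    for path in stage_dir.glob("*_done"):
        match = DONE_FILE_PATTERN.match(path.name)
        if match:  # skip files that don't follow the pixel naming convention
            pixels.append((int(match.group(1)), int(match.group(2))))
    return pixels
```

The committed read_done_pixels calls .group() directly on the match, assuming every *_done file follows the pixel naming; the guard here only keeps the stand-alone sketch from crashing on stray files.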
23 changes: 12 additions & 11 deletions src/hats_import/soap/resume_plan.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

@@ -97,12 +98,16 @@ def get_sources_to_count(self, source_pixel_map=None):
self.source_pixel_map = source_pixel_map
if self.source_pixel_map is None:
raise ValueError("source_pixel_map not provided for progress tracking.")
count_file_pattern = re.compile(r"(\d+)_(\d+).csv")
counted_pixel_tuples = [
count_file_pattern.match(path.name).group(1, 2) for path in self.tmp_path.glob("*.csv")
]
counted_pixels = [HealpixPixel(int(match[0]), int(match[1])) for match in counted_pixel_tuples]

counted_keys = set(self.get_keys_from_file_names(self.tmp_path, ".csv"))
remaining_pixels = list(set(source_pixel_map.keys()) - set(counted_pixels))
return [
(hp_pixel, object_pixels, f"{hp_pixel.order}_{hp_pixel.pixel}")
for hp_pixel, object_pixels in source_pixel_map.items()
if f"{hp_pixel.order}_{hp_pixel.pixel}" not in counted_keys
(hp_pixel, source_pixel_map[hp_pixel], f"{hp_pixel.order}_{hp_pixel.pixel}")
for hp_pixel in remaining_pixels
]

@classmethod
@@ -117,13 +122,9 @@ def reducing_key_done(cls, tmp_path, reducing_key: str):

def get_objects_to_reduce(self):
"""Fetch a tuple for each object catalog pixel to reduce."""
reduced_keys = set(self.read_done_keys(self.REDUCING_STAGE))
reduce_items = [
(hp_pixel, f"{hp_pixel.order}_{hp_pixel.pixel}")
for hp_pixel in self.object_catalog.get_healpix_pixels()
if f"{hp_pixel.order}_{hp_pixel.pixel}" not in reduced_keys
]
return reduce_items
reduced_pixels = set(self.read_done_pixels(self.REDUCING_STAGE))
remaining_pixels = list(set(self.object_catalog.get_healpix_pixels()) - set(reduced_pixels))
return [(hp_pixel, f"{hp_pixel.order}_{hp_pixel.pixel}") for hp_pixel in remaining_pixels]

def is_reducing_done(self) -> bool:
"""Are there partitions left to reduce?"""
24 changes: 20 additions & 4 deletions tests/hats_import/catalog/test_resume_plan.py
@@ -3,6 +3,7 @@
import numpy as np
import numpy.testing as npt
import pytest
from hats.pixel_math.healpix_pixel import HealpixPixel
from hats.pixel_math.sparse_histogram import SparseHistogram

from hats_import.catalog.resume_plan import ResumePlan
@@ -68,6 +69,23 @@ def test_same_input_paths(tmp_path, small_sky_single_file, formats_headers_csv):
)


def test_remaining_map_keys(tmp_path):
"""Test that we can read what we write into a histogram file."""
num_inputs = 1_000
input_paths = [f"foo_{i}" for i in range(0, num_inputs)]
plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=input_paths)

remaining_keys = plan.get_remaining_map_keys()
assert len(remaining_keys) == num_inputs

histogram = SparseHistogram([11], [131], 0)
for i in range(0, num_inputs):
histogram.to_file(ResumePlan.partial_histogram_file(tmp_path=tmp_path, mapping_key=f"map_{i}"))

remaining_keys = plan.get_remaining_map_keys()
assert len(remaining_keys) == 0


def test_read_write_histogram(tmp_path):
"""Test that we can read what we write into a histogram file."""
plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=["foo1"])
Expand Down Expand Up @@ -171,7 +189,6 @@ def test_some_split_task_failures(tmp_path, dask_client):

def test_get_reduce_items(tmp_path):
"""Test generation of remaining reduce items"""
destination_pixel_map = [(0, 11, 131)]
plan = ResumePlan(tmp_path=tmp_path, progress_bar=False)

with pytest.raises(RuntimeError, match="destination pixel map"):
@@ -180,7 +197,7 @@ def test_get_reduce_items(tmp_path):
with pytest.raises(RuntimeError, match="destination pixel map"):
remaining_reduce_items = plan.get_destination_pixels()

plan.destination_pixel_map = destination_pixel_map
plan.destination_pixel_map = {HealpixPixel(0, 11): 131}
remaining_reduce_items = plan.get_reduce_items()
assert len(remaining_reduce_items) == 1

@@ -197,8 +214,7 @@ def test_some_reduce_task_failures(tmp_path, dask_client):
"""Test that we only consider reduce stage successful if all done files are written"""
plan = ResumePlan(tmp_path=tmp_path, progress_bar=False)

destination_pixel_map = [(0, 11, 131)]
plan.destination_pixel_map = destination_pixel_map
plan.destination_pixel_map = {HealpixPixel(0, 11): 131}
remaining_reduce_items = plan.get_reduce_items()
assert len(remaining_reduce_items) == 1

42 changes: 0 additions & 42 deletions tests/hats_import/test_pipeline_resume_plan.py
@@ -8,29 +8,6 @@
from hats_import.pipeline_resume_plan import PipelineResumePlan, get_formatted_stage_name


def test_done_key(tmp_path):
"""Verify expected behavior of marking stage progress via done files."""
plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False)
stage = "testing"
(tmp_path / stage).mkdir(parents=True)

keys = plan.read_done_keys(stage)
assert len(keys) == 0

PipelineResumePlan.touch_key_done_file(tmp_path, stage, "key_01")
keys = plan.read_done_keys(stage)
assert keys == ["key_01"]

PipelineResumePlan.touch_key_done_file(tmp_path, stage, "key_02")
keys = plan.read_done_keys(stage)
assert keys == ["key_01", "key_02"]

plan.clean_resume_files()

keys = plan.read_done_keys(stage)
assert len(keys) == 0


def test_done_file(tmp_path):
"""Verify expected behavior of done file"""
plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False)
@@ -53,25 +30,6 @@ def test_done_file(tmp_path):
assert not plan.done_file_exists(done_file)


def test_get_keys_from_results(tmp_path):
"""Test that we can create a list of completed keys via the output results files."""
plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False, resume=False)
keys = PipelineResumePlan.get_keys_from_file_names(tmp_path, ".foo")
assert len(keys) == 0

Path(tmp_path, "file_0.foo").touch()
keys = PipelineResumePlan.get_keys_from_file_names(tmp_path, ".foo")
assert keys == ["file_0"]

Path(tmp_path, "file_1.foo").touch()
keys = PipelineResumePlan.get_keys_from_file_names(tmp_path, ".foo")
assert keys == ["file_0", "file_1"]

plan.clean_resume_files()
keys = PipelineResumePlan.get_keys_from_file_names(tmp_path, ".foo")
assert len(keys) == 0


def test_safe_to_resume(tmp_path):
"""Check that we throw errors when it's not safe to resume."""
plan = PipelineResumePlan(tmp_path=tmp_path, progress_bar=False, resume=False)
