Skip to content

Commit

Permalink
Add test against catalog info total num rows.
Browse files: browse the repository at this point in the history
  • Loading branch information
delucchi-cmu committed Jan 22, 2025
1 parent d4b98c9 commit 2eff4ed
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
18 changes: 15 additions & 3 deletions src/hats_import/verification/run_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as pds
from hats import read_hats
from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN

from hats_import.verification.arguments import VerificationArguments
Expand Down Expand Up @@ -219,9 +220,20 @@ def test_num_rows(self) -> bool:
description = "Test that number of rows are equal."
print(f"\nStarting: {description}")

catalog_prop_len = read_hats(self.args.input_catalog_path).catalog_info.total_rows

# Get the number of rows in each file (from the file footers), indexed by file path. We treat this as the ground truth.
files_df = self._load_nrows(self.files_ds)
files_df_total = f"file footers ({files_df.num_rows.sum():,})"
files_df_sum = files_df.num_rows.sum()
files_df_total = f"file footers ({files_df_sum:,})"

target = "file footers vs catalog properties"
print(f"\t{target}")
passed_cat = catalog_prop_len == files_df_sum
_description = f" {files_df_total} vs catalog properties ({catalog_prop_len:,})."
self.results.append(
Result(passed=passed_cat, test=test, target=target, description=description + _description)
)

# check _metadata
target = "file footers vs _metadata"
Expand All @@ -245,15 +257,15 @@ def test_num_rows(self) -> bool:
if self.args.truth_total_rows is not None:
target = "file footers vs truth"
print(f"\t{target}")
passed_th = self.args.truth_total_rows == files_df.num_rows.sum()
passed_th = self.args.truth_total_rows == files_df_sum
_description = f" {files_df_total} vs user-provided truth ({self.args.truth_total_rows:,})."
self.results.append(
Result(passed=passed_th, test=test, target=target, description=description + _description)
)
else:
passed_th = True # this test did not fail. this is only needed for the return value.

all_passed = all([passed_md, passed_th])
all_passed = all([passed_md, passed_th, passed_cat])
print(f"Result: {'PASSED' if all_passed else 'FAILED'}")
return all_passed

Expand Down
2 changes: 2 additions & 0 deletions tests/data/wrong_files_and_rows/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Norder,Npix
0,11
8 changes: 8 additions & 0 deletions tests/data/wrong_files_and_rows/properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#HATS catalog
obs_collection=wrong_files_and_rows
dataproduct_type=object
hats_nrows=600
hats_col_ra=source_ra
hats_col_dec=source_dec
hats_order=2

2 changes: 1 addition & 1 deletion tests/hats_import/verification/test_run_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_test_num_rows(small_sky_object_catalog, wrong_files_and_rows_dir, tmp_p
all_failed = not results.passed.any()
assert all_failed, "bad catalog passed"

targets = {"file footers vs _metadata", "file footers vs truth"}
targets = {"file footers vs catalog properties", "file footers vs _metadata", "file footers vs truth"}
assert targets == set(results.target), "wrong targets"

expected_bad_file_names = {
Expand Down

0 comments on commit 2eff4ed

Please sign in to comment.