class AssociationCatalog(Dataset):
    """HiPSCat catalog holding the row pairings used to join two other HiPSCat catalogs.

    The data is partitioned by the partitioning of both catalogs being joined, using the
    layout 'Norder=/Dir=/Npix=/join_Norder=/join_Dir=/join_Npix=.parquet'. Every partition
    stores the matching pairs of hipscat indexes taken from the respective partitions of
    the two catalogs. The `partition_join_info` metadata file enumerates all pixel pairs
    present in the association catalog, i.e. each pair of partitions with rows to join.
    """

    CatalogInfoClass = AssociationCatalogInfo
    # NOTE(review): `list` is advertised here, but `_get_partition_join_info_from_pixels`
    # raises TypeError for anything other than a PartitionJoinInfo or DataFrame.
    JoinPixelInputTypes = Union[list, pd.DataFrame, PartitionJoinInfo]

    def __init__(
        self,
        catalog_info: CatalogInfoClass,
        join_pixels: JoinPixelInputTypes,
        catalog_path=None,
    ) -> None:
        """Create an association catalog from its info object and its join pixel listing.

        Args:
            catalog_info: catalog info whose `catalog_type` must be the association type
            join_pixels: a `PartitionJoinInfo`, or a DataFrame that is wrapped in one
            catalog_path: forwarded unchanged to the `Dataset` initializer

        Raises:
            ValueError: if `catalog_info.catalog_type` is not `CatalogType.ASSOCIATION`
            TypeError: if `join_pixels` is neither a `PartitionJoinInfo` nor a DataFrame
        """
        if catalog_info.catalog_type != CatalogType.ASSOCIATION:
            raise ValueError("Catalog info `catalog_type` must be 'association'")
        super().__init__(catalog_info, catalog_path)
        self.join_info = self._get_partition_join_info_from_pixels(join_pixels)

    def get_join_pixels(self) -> pd.DataFrame:
        """Return the table of pixel pairs that hold matching association rows.

        Returns:
            pd.DataFrame in which every row is one pair of pixels, one from the primary
            catalog and one from the join catalog
        """
        join_info = self.join_info
        return join_info.data_frame

    @staticmethod
    def _get_partition_join_info_from_pixels(
        join_pixels: JoinPixelInputTypes,
    ) -> PartitionJoinInfo:
        """Normalize the accepted `join_pixels` inputs to a `PartitionJoinInfo`."""
        # Already the target type: hand it back untouched.
        if isinstance(join_pixels, PartitionJoinInfo):
            return join_pixels
        if not isinstance(join_pixels, pd.DataFrame):
            raise TypeError("join_pixels must be of type PartitionJoinInfo or DataFrame")
        return PartitionJoinInfo(join_pixels)

    @classmethod
    def _read_args(
        cls, catalog_base_dir: FilePointer
    ) -> Tuple[CatalogInfoClass, JoinPixelInputTypes]:
        """Extend the parent's constructor arguments with the join info read from disk."""
        base_args = super()._read_args(catalog_base_dir)
        join_info_pointer = paths.get_partition_join_info_pointer(catalog_base_dir)
        join_info = PartitionJoinInfo.read_from_file(join_info_pointer)
        return base_args + (join_info,)
class PartitionJoinInfo:
    """Association catalog metadata listing the partition pairs in which join matches occur.

    Wraps a DataFrame that has one row per (primary partition, join partition) pair and
    checks on construction that all four identifying columns are present.
    """

    PRIMARY_ORDER_COLUMN_NAME = "primary_Norder"
    PRIMARY_PIXEL_COLUMN_NAME = "primary_Npix"
    JOIN_ORDER_COLUMN_NAME = "join_Norder"
    JOIN_PIXEL_COLUMN_NAME = "join_Npix"

    # Every column a join info frame must carry. The primary *order* column has to be
    # listed here as well, otherwise a frame lacking it would pass validation.
    COLUMN_NAMES = [
        PRIMARY_ORDER_COLUMN_NAME,
        PRIMARY_PIXEL_COLUMN_NAME,
        JOIN_ORDER_COLUMN_NAME,
        JOIN_PIXEL_COLUMN_NAME,
    ]

    def __init__(self, join_info_df: pd.DataFrame) -> None:
        """Store the join info frame after validating its columns.

        Args:
            join_info_df: DataFrame containing at least the columns in `COLUMN_NAMES`

        Raises:
            ValueError: if any column in `COLUMN_NAMES` is missing from `join_info_df`
        """
        self.data_frame = join_info_df
        self._check_column_names()

    def _check_column_names(self) -> None:
        """Raise ValueError naming the first required column missing from the frame."""
        for column in self.COLUMN_NAMES:
            if column not in self.data_frame.columns:
                raise ValueError(f"join_info_df does not contain column {column}")

    # Annotations are quoted so they are not evaluated when the class is defined.
    @classmethod
    def read_from_file(cls, partition_join_info_file: "FilePointer") -> "Self":
        """Read partition join info from a `partition_join_info.csv` file to create an object

        Args:
            partition_join_info_file: FilePointer to the `partition_join_info.csv` file

        Returns:
            A `PartitionJoinInfo` object with the data from the file

        Raises:
            FileNotFoundError: if nothing exists at `partition_join_info_file`
            ValueError: if the loaded table lacks one of the columns in `COLUMN_NAMES`
        """
        if not file_io.does_file_or_directory_exist(partition_join_info_file):
            raise FileNotFoundError(
                f"No partition info found where expected: {str(partition_join_info_file)}"
            )

        data_frame = file_io.load_csv_to_pandas(partition_join_info_file)
        return cls(data_frame)
""" return append_paths_to_pointer(catalog_base_dir, POINT_MAP_FILENAME) + + +def get_partition_join_info_pointer(catalog_base_dir: FilePointer) -> FilePointer: + """Get file pointer to `partition_join_info.csv` association metadata file + + Args: + catalog_base_dir: pointer to base catalog directory + Returns: + File Pointer to the catalog's `partition_join_info.csv` association metadata file + """ + return append_paths_to_pointer(catalog_base_dir, PARTITION_JOIN_INFO_FILENAME) diff --git a/src/hipscat/py.typed b/src/hipscat/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index 672bacb0..1e986a5d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ DATA_DIR_NAME = "data" SMALL_SKY_DIR_NAME = "small_sky" SMALL_SKY_ORDER1_DIR_NAME = "small_sky_order1" +SMALL_SKY_TO_SMALL_SKY_ORDER1_DIR_NAME = "small_sky_to_small_sky_order1" TEST_DIR = os.path.dirname(__file__) # pylint: disable=missing-function-docstring, redefined-outer-name @@ -23,3 +24,8 @@ def small_sky_dir(test_data_dir): @pytest.fixture def small_sky_order1_dir(test_data_dir): return os.path.join(test_data_dir, SMALL_SKY_ORDER1_DIR_NAME) + + +@pytest.fixture +def small_sky_to_small_sky_order1_dir(test_data_dir): + return os.path.join(test_data_dir, SMALL_SKY_TO_SMALL_SKY_ORDER1_DIR_NAME) diff --git a/tests/data/small_sky_to_small_sky_order1/catalog_info.json b/tests/data/small_sky_to_small_sky_order1/catalog_info.json new file mode 100644 index 00000000..5f0f09ec --- /dev/null +++ b/tests/data/small_sky_to_small_sky_order1/catalog_info.json @@ -0,0 +1,9 @@ +{ + "catalog_name": "small_sky_to_small_sky_order1", + "catalog_type": "association", + "primary_catalog": "small_sky", + "primary_column": "id", + "join_catalog": "small_sky_order1", + "join_column": "id", + "total_rows": 131 +} \ No newline at end of file diff --git a/tests/data/small_sky_to_small_sky_order1/partition_join_info.csv 
def test_init_catalog(association_catalog_info, association_catalog_join_pixels):
    """A catalog built from a pixel DataFrame exposes the given info and join pixels."""
    catalog = AssociationCatalog(association_catalog_info, association_catalog_join_pixels)
    assert catalog.catalog_name == association_catalog_info.catalog_name
    pd.testing.assert_frame_equal(catalog.get_join_pixels(), association_catalog_join_pixels)
    assert catalog.catalog_info == association_catalog_info


def test_wrong_catalog_type(association_catalog_info, association_catalog_join_pixels):
    """A non-association `catalog_type` is rejected with a ValueError."""
    association_catalog_info.catalog_type = CatalogType.OBJECT
    with pytest.raises(ValueError, match="catalog_type"):
        AssociationCatalog(association_catalog_info, association_catalog_join_pixels)


def test_wrong_catalog_info_type(catalog_info, association_catalog_join_pixels):
    """An info object of the wrong class is rejected even if its type says association."""
    catalog_info.catalog_type = CatalogType.ASSOCIATION
    with pytest.raises(TypeError, match="catalog_info"):
        AssociationCatalog(catalog_info, association_catalog_join_pixels)


def test_wrong_join_pixels_type(association_catalog_info):
    """Join pixels that are neither a DataFrame nor PartitionJoinInfo raise TypeError."""
    with pytest.raises(TypeError, match="join_pixels"):
        AssociationCatalog(association_catalog_info, "test")


def test_different_join_pixels_type(association_catalog_info, association_catalog_join_pixels):
    """A ready-made PartitionJoinInfo is accepted in place of a DataFrame."""
    partition_join_info = PartitionJoinInfo(association_catalog_join_pixels)
    catalog = AssociationCatalog(association_catalog_info, partition_join_info)
    pd.testing.assert_frame_equal(catalog.get_join_pixels(), association_catalog_join_pixels)


def test_read_from_file(association_catalog_path, association_catalog_join_pixels):
    """Reading the on-disk test catalog yields the expected path and join pixels."""
    catalog = AssociationCatalog.read_from_hipscat(association_catalog_path)
    assert catalog.on_disk
    assert catalog.catalog_path == association_catalog_path
    assert len(catalog.get_join_pixels()) == 4
    pd.testing.assert_frame_equal(catalog.get_join_pixels(), association_catalog_join_pixels)


def test_empty_directory(tmp_path, association_catalog_info_data, association_catalog_join_pixels):
    """Test loading empty or incomplete data"""
    ## Path doesn't exist
    with pytest.raises(FileNotFoundError):
        AssociationCatalog.read_from_hipscat(os.path.join("path", "empty"))

    catalog_path = os.path.join(tmp_path, "empty")
    os.makedirs(catalog_path, exist_ok=True)

    ## Path exists but there's nothing there
    with pytest.raises(FileNotFoundError):
        AssociationCatalog.read_from_hipscat(catalog_path)

    ## catalog_info file exists - getting closer
    file_name = os.path.join(catalog_path, "catalog_info.json")
    with open(file_name, "w", encoding="utf-8") as metadata_file:
        metadata_file.write(json.dumps(association_catalog_info_data))

    with pytest.raises(FileNotFoundError):
        AssociationCatalog.read_from_hipscat(catalog_path)

    ## partition_join_info file exists - enough to create a catalog
    file_name = os.path.join(catalog_path, "partition_join_info.csv")
    # index=False keeps the file to the four join columns only; the default would
    # prepend the DataFrame index as an extra unnamed column.
    association_catalog_join_pixels.to_csv(file_name, index=False)

    catalog = AssociationCatalog.read_from_hipscat(catalog_path)
    assert catalog.catalog_name == association_catalog_info_data["catalog_name"]
    pd.testing.assert_frame_equal(catalog.get_join_pixels(), association_catalog_join_pixels)
def test_association_catalog_info(association_catalog_info_data, assert_catalog_info_matches_dict):
    """An info object built from the fixture dict matches that dict."""
    catalog_info = AssociationCatalogInfo(**association_catalog_info_data)
    assert_catalog_info_matches_dict(catalog_info, association_catalog_info_data)


def test_str(association_catalog_info_data):
    """`str()` of the info object lists each field as a name/value line."""
    expected = "".join(
        f" {name} {value}\n" for name, value in association_catalog_info_data.items()
    )
    cat_info = AssociationCatalogInfo(**association_catalog_info_data)
    assert str(cat_info) == expected


def test_read_from_file(association_catalog_info_file, assert_catalog_info_matches_dict):
    """The info read from the metadata file has all fields and matches the raw JSON."""
    cat_info_fp = file_io.get_file_pointer_from_path(association_catalog_info_file)
    catalog_info = AssociationCatalogInfo.read_from_metadata_file(cat_info_fp)

    expected_columns = (
        "catalog_name",
        "catalog_type",
        "total_rows",
        "primary_column",
        "primary_catalog",
        "join_column",
        "join_catalog",
    )
    info_fields = dataclasses.asdict(catalog_info)
    for column in expected_columns:
        assert column in info_fields

    with open(association_catalog_info_file, "r", encoding="utf-8") as cat_info_file:
        catalog_info_json = json.load(cat_info_file)
    assert_catalog_info_matches_dict(catalog_info, catalog_info_json)


def test_required_fields_missing(association_catalog_info_data):
    """Setting any required field to None makes construction fail, naming that field."""
    required_fields = AssociationCatalogInfo.required_fields
    for association_field in ("primary_catalog", "join_catalog"):
        assert association_field in required_fields

    for field in required_fields:
        # Fresh dict per iteration so only one field is blanked at a time.
        init_data = {**association_catalog_info_data, field: None}
        with pytest.raises(ValueError, match=field):
            AssociationCatalogInfo(**init_data)
@pytest.fixture
def association_catalog_info_data() -> dict:
    """Raw field values for a small association catalog info object."""
    return dict(
        catalog_name="test_name",
        catalog_type="association",
        total_rows=10,
        primary_catalog="small_sky",
        primary_column="id",
        join_catalog="small_sky_order1",
        join_column="id",
    )


@pytest.fixture
def association_catalog_path(test_data_dir) -> str:
    """Directory of the on-disk association test catalog."""
    catalog_dir_name = "small_sky_to_small_sky_order1"
    return os.path.join(test_data_dir, catalog_dir_name)


@pytest.fixture
def association_catalog_info_file(association_catalog_path) -> str:
    """Path of the test catalog's `catalog_info.json`."""
    info_file_name = "catalog_info.json"
    return os.path.join(association_catalog_path, info_file_name)


@pytest.fixture
def association_catalog_info(association_catalog_info_data) -> AssociationCatalogInfo:
    """Info object built from the `association_catalog_info_data` values."""
    info_kwargs = association_catalog_info_data
    return AssociationCatalogInfo(**info_kwargs)


@pytest.fixture
def association_catalog_partition_join_file(association_catalog_path) -> str:
    """Path of the test catalog's `partition_join_info.csv`."""
    join_file_name = "partition_join_info.csv"
    return os.path.join(association_catalog_path, join_file_name)


@pytest.fixture
def association_catalog_join_pixels() -> pd.DataFrame:
    """Four pixel pairs: primary order 0 / pixel 11 against join order 1 / pixels 44-47."""
    pixel_columns = {
        PartitionJoinInfo.PRIMARY_ORDER_COLUMN_NAME: [0, 0, 0, 0],
        PartitionJoinInfo.PRIMARY_PIXEL_COLUMN_NAME: [11, 11, 11, 11],
        PartitionJoinInfo.JOIN_ORDER_COLUMN_NAME: [1, 1, 1, 1],
        PartitionJoinInfo.JOIN_PIXEL_COLUMN_NAME: [44, 45, 46, 47],
    }
    return pd.DataFrame.from_dict(pixel_columns)