Skip to content

Commit

Permalink
Merge pull request #86 from astronomy-commons/sean/add-association-ca…
Browse files Browse the repository at this point in the history
…talog

Add Association Catalog
  • Loading branch information
smcguire-cmu authored Apr 26, 2023
2 parents bdde9aa + 3485b69 commit 816cade
Show file tree
Hide file tree
Showing 13 changed files with 369 additions and 2 deletions.
66 changes: 66 additions & 0 deletions src/hipscat/catalog/association_catalog/association_catalog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Tuple, Union

import pandas as pd

from hipscat.catalog import CatalogType
from hipscat.catalog.association_catalog.association_catalog_info import \
AssociationCatalogInfo
from hipscat.catalog.association_catalog.partition_join_info import \
PartitionJoinInfo
from hipscat.catalog.dataset.dataset import Dataset
from hipscat.io import FilePointer, paths


class AssociationCatalog(Dataset):
"""A HiPSCat Catalog for enabling fast joins between two HiPSCat catalogs
Catalogs of this type are partitioned based on the partitioning of both joining catalogs in the
form 'Norder=/Dir=/Npix=/join_Norder=/join_Dir=/join_Npix=.parquet'. Where each partition
contains the matching pair of hipscat indexes from each catalog's respective partitions to join.
The `partition_join_info` metadata file specifies all pairs of pixels in the Association
Catalog, corresponding to each pair of partitions in each catalog that contain rows to join.
"""

CatalogInfoClass = AssociationCatalogInfo
JoinPixelInputTypes = Union[list, pd.DataFrame, PartitionJoinInfo]

def __init__(
self,
catalog_info: CatalogInfoClass,
join_pixels: JoinPixelInputTypes,
catalog_path=None,
) -> None:
if not catalog_info.catalog_type == CatalogType.ASSOCIATION:
raise ValueError(
"Catalog info `catalog_type` must be 'association'"
)
super().__init__(catalog_info, catalog_path)
self.join_info = self._get_partition_join_info_from_pixels(join_pixels)

def get_join_pixels(self) -> pd.DataFrame:
"""Get join pixels listing all pairs of pixels from left and right catalogs that contain
matching association rows
Returns:
pd.DataFrame with each row being a pair of pixels from the primary and join catalogs
"""
return self.join_info.data_frame

@staticmethod
def _get_partition_join_info_from_pixels(
join_pixels: JoinPixelInputTypes
) -> PartitionJoinInfo:
if isinstance(join_pixels, PartitionJoinInfo):
return join_pixels
if isinstance(join_pixels, pd.DataFrame):
return PartitionJoinInfo(join_pixels)
raise TypeError("join_pixels must be of type PartitionJoinInfo or DataFrame")

@classmethod
def _read_args(
cls, catalog_base_dir: FilePointer
) -> Tuple[CatalogInfoClass, JoinPixelInputTypes]:
args = super()._read_args(catalog_base_dir)
partition_join_info_file = paths.get_partition_join_info_pointer(catalog_base_dir)
partition_join_info = PartitionJoinInfo.read_from_file(partition_join_info_file)
return args + (partition_join_info,)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dataclasses import dataclass

from hipscat.catalog.dataset.base_catalog_info import BaseCatalogInfo


@dataclass
class AssociationCatalogInfo(BaseCatalogInfo):
"""Catalog Info for a HiPSCat Association Catalog"""

primary_catalog: str = None
primary_column: str = None
join_catalog: str = None
join_column: str = None

required_fields = BaseCatalogInfo.required_fields + [
"primary_catalog",
"join_catalog",
]
47 changes: 47 additions & 0 deletions src/hipscat/catalog/association_catalog/partition_join_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pandas as pd
from typing_extensions import Self

from hipscat.io import FilePointer, file_io


class PartitionJoinInfo:
"""Association catalog metadata with which partitions matches occur in the join"""

PRIMARY_ORDER_COLUMN_NAME = "primary_Norder"
PRIMARY_PIXEL_COLUMN_NAME = "primary_Npix"
JOIN_ORDER_COLUMN_NAME = "join_Norder"
JOIN_PIXEL_COLUMN_NAME = "join_Npix"

COLUMN_NAMES = [
PRIMARY_PIXEL_COLUMN_NAME,
PRIMARY_PIXEL_COLUMN_NAME,
JOIN_ORDER_COLUMN_NAME,
JOIN_PIXEL_COLUMN_NAME,
]

def __init__(self, join_info_df: pd.DataFrame) -> None:
self.data_frame = join_info_df
self._check_column_names()

def _check_column_names(self):
for column in self.COLUMN_NAMES:
if column not in self.data_frame.columns:
raise ValueError(f"join_info_df does not contain column {column}")

@classmethod
def read_from_file(cls, partition_join_info_file: FilePointer) -> Self:
"""Read partition join info from a `partition_join_info.csv` file to create an object
Args:
partition_join_info_file: FilePointer to the `partition_join_info.csv` file
Returns:
A `PartitionJoinInfo` object with the data from the file
"""
if not file_io.does_file_or_directory_exist(partition_join_info_file):
raise FileNotFoundError(
f"No partition info found where expected: {str(partition_join_info_file)}"
)

data_frame = file_io.load_csv_to_pandas(partition_join_info_file)
return cls(data_frame)
2 changes: 1 addition & 1 deletion src/hipscat/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
self,
catalog_info: CatalogInfoClass,
pixels: PixelInputTypes,
catalog_path=None,
catalog_path: str = None,
) -> None:
"""Initializes a Catalog
Expand Down
12 changes: 12 additions & 0 deletions src/hipscat/io/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

CATALOG_INFO_FILENAME = "catalog_info.json"
PARTITION_INFO_FILENAME = "partition_info.csv"
PARTITION_JOIN_INFO_FILENAME = "partition_join_info.csv"
PROVENANCE_INFO_FILENAME = "provenance_info.json"
PARQUET_METADATA_FILENAME = "_metadata"
PARQUET_COMMON_METADATA_FILENAME = "_common_metadata"
Expand Down Expand Up @@ -292,3 +293,14 @@ def get_point_map_file_pointer(catalog_base_dir: FilePointer) -> FilePointer:
File Pointer to the catalog's `point_map.fits` FITS image file.
"""
return append_paths_to_pointer(catalog_base_dir, POINT_MAP_FILENAME)


def get_partition_join_info_pointer(catalog_base_dir: FilePointer) -> FilePointer:
"""Get file pointer to `partition_join_info.csv` association metadata file
Args:
catalog_base_dir: pointer to base catalog directory
Returns:
File Pointer to the catalog's `partition_join_info.csv` association metadata file
"""
return append_paths_to_pointer(catalog_base_dir, PARTITION_JOIN_INFO_FILENAME)
Empty file added src/hipscat/py.typed
Empty file.
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
DATA_DIR_NAME = "data"
SMALL_SKY_DIR_NAME = "small_sky"
SMALL_SKY_ORDER1_DIR_NAME = "small_sky_order1"
SMALL_SKY_TO_SMALL_SKY_ORDER1_DIR_NAME = "small_sky_to_small_sky_order1"
TEST_DIR = os.path.dirname(__file__)

# pylint: disable=missing-function-docstring, redefined-outer-name
Expand All @@ -23,3 +24,8 @@ def small_sky_dir(test_data_dir):
@pytest.fixture
def small_sky_order1_dir(test_data_dir):
return os.path.join(test_data_dir, SMALL_SKY_ORDER1_DIR_NAME)


@pytest.fixture
def small_sky_to_small_sky_order1_dir(test_data_dir):
return os.path.join(test_data_dir, SMALL_SKY_TO_SMALL_SKY_ORDER1_DIR_NAME)
9 changes: 9 additions & 0 deletions tests/data/small_sky_to_small_sky_order1/catalog_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"catalog_name": "small_sky_to_small_sky_order1",
"catalog_type": "association",
"primary_catalog": "small_sky",
"primary_column": "id",
"join_catalog": "small_sky_order1",
"join_column": "id",
"total_rows": 131
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
primary_Norder,primary_Npix,join_Norder,join_Npix
0,11,1,44
0,11,1,45
0,11,1,46
0,11,1,47
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import json
import os

import pandas as pd
import pytest

from hipscat.catalog import CatalogType
from hipscat.catalog.association_catalog.association_catalog import \
AssociationCatalog
from hipscat.catalog.association_catalog.partition_join_info import \
PartitionJoinInfo


def test_init_catalog(association_catalog_info, association_catalog_join_pixels):
catalog = AssociationCatalog(association_catalog_info, association_catalog_join_pixels)
assert catalog.catalog_name == association_catalog_info.catalog_name
pd.testing.assert_frame_equal(catalog.get_join_pixels(), association_catalog_join_pixels)
assert catalog.catalog_info == association_catalog_info


def test_wrong_catalog_type(association_catalog_info, association_catalog_join_pixels):
association_catalog_info.catalog_type = CatalogType.OBJECT
with pytest.raises(ValueError, match="catalog_type"):
AssociationCatalog(association_catalog_info, association_catalog_join_pixels)


def test_wrong_catalog_info_type(catalog_info, association_catalog_join_pixels):
catalog_info.catalog_type = CatalogType.ASSOCIATION
with pytest.raises(TypeError, match="catalog_info"):
AssociationCatalog(catalog_info, association_catalog_join_pixels)


def test_wrong_join_pixels_type(association_catalog_info):
with pytest.raises(TypeError, match="join_pixels"):
AssociationCatalog(association_catalog_info, "test")


def test_different_join_pixels_type(association_catalog_info, association_catalog_join_pixels):
partition_join_info = PartitionJoinInfo(association_catalog_join_pixels)
catalog = AssociationCatalog(association_catalog_info, partition_join_info)
pd.testing.assert_frame_equal(catalog.get_join_pixels(), association_catalog_join_pixels)


def test_read_from_file(association_catalog_path, association_catalog_join_pixels):
catalog = AssociationCatalog.read_from_hipscat(association_catalog_path)
assert catalog.on_disk
assert catalog.catalog_path == association_catalog_path
assert len(catalog.get_join_pixels()) == 4
pd.testing.assert_frame_equal(catalog.get_join_pixels(), association_catalog_join_pixels)


def test_empty_directory(tmp_path, association_catalog_info_data, association_catalog_join_pixels):
"""Test loading empty or incomplete data"""
## Path doesn't exist
with pytest.raises(FileNotFoundError):
AssociationCatalog.read_from_hipscat(os.path.join("path", "empty"))

catalog_path = os.path.join(tmp_path, "empty")
os.makedirs(catalog_path, exist_ok=True)

## Path exists but there's nothing there
with pytest.raises(FileNotFoundError):
AssociationCatalog.read_from_hipscat(catalog_path)

## catalog_info file exists - getting closer
file_name = os.path.join(catalog_path, "catalog_info.json")
with open(
file_name,
"w",
encoding="utf-8",
) as metadata_file:
metadata_file.write(json.dumps(association_catalog_info_data))

with pytest.raises(FileNotFoundError):
AssociationCatalog.read_from_hipscat(catalog_path)

## partition_info file exists - enough to create a catalog
file_name = os.path.join(catalog_path, "partition_join_info.csv")
association_catalog_join_pixels.to_csv(file_name)

catalog = AssociationCatalog.read_from_hipscat(catalog_path)
assert catalog.catalog_name == association_catalog_info_data["catalog_name"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import dataclasses
import json

import pytest

from hipscat.catalog.association_catalog.association_catalog_info import \
AssociationCatalogInfo
from hipscat.io import file_io


def test_association_catalog_info(association_catalog_info_data, assert_catalog_info_matches_dict):
info = AssociationCatalogInfo(**association_catalog_info_data)
assert_catalog_info_matches_dict(info, association_catalog_info_data)


def test_str(association_catalog_info_data):
correct_string = ""
for name, value in association_catalog_info_data.items():
correct_string += f" {name} {value}\n"
cat_info = AssociationCatalogInfo(**association_catalog_info_data)
assert str(cat_info) == correct_string


def test_read_from_file(association_catalog_info_file, assert_catalog_info_matches_dict):
cat_info_fp = file_io.get_file_pointer_from_path(association_catalog_info_file)
catalog_info = AssociationCatalogInfo.read_from_metadata_file(cat_info_fp)
for column in [
"catalog_name",
"catalog_type",
"total_rows",
"primary_column",
"primary_catalog",
"join_column",
"join_catalog"
]:
assert column in dataclasses.asdict(catalog_info)

with open(association_catalog_info_file, "r", encoding="utf-8") as cat_info_file:
catalog_info_json = json.load(cat_info_file)
assert_catalog_info_matches_dict(catalog_info, catalog_info_json)


def test_required_fields_missing(association_catalog_info_data):
for required_field in ["primary_catalog", "join_catalog"]:
assert required_field in AssociationCatalogInfo.required_fields
for field in AssociationCatalogInfo.required_fields:
init_data = association_catalog_info_data.copy()
init_data[field] = None
with pytest.raises(ValueError, match=field):
AssociationCatalogInfo(**init_data)
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pandas as pd
import pytest

from hipscat.catalog.association_catalog.partition_join_info import \
PartitionJoinInfo
from hipscat.io import file_io


def test_init(association_catalog_join_pixels):
partition_join_info = PartitionJoinInfo(association_catalog_join_pixels)
pd.testing.assert_frame_equal(partition_join_info.data_frame, association_catalog_join_pixels)


def test_wrong_columns(association_catalog_join_pixels):
for column in PartitionJoinInfo.COLUMN_NAMES:
join_pixels = association_catalog_join_pixels.copy()
join_pixels = join_pixels.rename(columns={column: "wrong_name"})
with pytest.raises(ValueError, match=column):
PartitionJoinInfo(join_pixels)


def test_read_from_file(association_catalog_partition_join_file, association_catalog_join_pixels):
file_pointer = file_io.get_file_pointer_from_path(association_catalog_partition_join_file)
info = PartitionJoinInfo.read_from_file(file_pointer)
pd.testing.assert_frame_equal(info.data_frame, association_catalog_join_pixels)
Loading

0 comments on commit 816cade

Please sign in to comment.