Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove typing imports for List, Tuple, Union #441

Merged
merged 3 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/hats/catalog/association_catalog/association_catalog.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

from typing import Union

import pandas as pd
import pyarrow as pa
from mocpy import MOC

from hats.catalog.association_catalog.partition_join_info import PartitionJoinInfo
from hats.catalog.dataset.table_properties import TableProperties
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset, PixelInputTypes
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset
from hats.catalog.partition_info import PartitionInfo
from hats.pixel_math import HealpixPixel
from hats.pixel_tree.pixel_tree import PixelTree


class AssociationCatalog(HealpixDataset):
Expand All @@ -19,13 +20,11 @@ class AssociationCatalog(HealpixDataset):
Catalog, corresponding to each pair of partitions in each catalog that contain rows to join.
"""

JoinPixelInputTypes = Union[list, pd.DataFrame, PartitionJoinInfo]

def __init__(
self,
catalog_info: TableProperties,
pixels: PixelInputTypes,
join_pixels: JoinPixelInputTypes,
pixels: PartitionInfo | PixelTree | list[HealpixPixel],
join_pixels: list | pd.DataFrame | PartitionJoinInfo,
catalog_path=None,
moc: MOC | None = None,
schema: pa.Schema | None = None,
Expand All @@ -44,7 +43,7 @@ def get_join_pixels(self) -> pd.DataFrame:

@staticmethod
def _get_partition_join_info_from_pixels(
join_pixels: JoinPixelInputTypes,
join_pixels: list | pd.DataFrame | PartitionJoinInfo,
) -> PartitionJoinInfo:
if isinstance(join_pixels, PartitionJoinInfo):
return join_pixels
Expand Down
3 changes: 1 addition & 2 deletions src/hats/catalog/association_catalog/partition_join_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import warnings
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -46,7 +45,7 @@ def _check_column_names(self):
if column not in self.data_frame.columns:
raise ValueError(f"join_info_df does not contain column {column}")

def primary_to_join_map(self) -> Dict[HealpixPixel, List[HealpixPixel]]:
def primary_to_join_map(self) -> dict[HealpixPixel, list[HealpixPixel]]:
"""Generate a map from a single primary pixel to one or more pixels in the join catalog.

Lots of cute comprehension is happening here, so watch out!
Expand Down
4 changes: 1 addition & 3 deletions src/hats/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

from typing import List

from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset
from hats.pixel_math import HealpixPixel
from hats.pixel_tree.negative_tree import compute_negative_tree_pixels
Expand All @@ -17,7 +15,7 @@ class Catalog(HealpixDataset):
`Norder=/Dir=/Npix=.parquet`
"""

def generate_negative_tree_pixels(self) -> List[HealpixPixel]:
def generate_negative_tree_pixels(self) -> list[HealpixPixel]:
"""Get the leaf nodes at each healpix order that have zero catalog data.

For example, if an example catalog only had data points in pixel 0 at
Expand Down
5 changes: 2 additions & 3 deletions src/hats/catalog/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from pathlib import Path
from typing import List

import pyarrow as pa
from upath import UPath
Expand Down Expand Up @@ -41,8 +40,8 @@ def __init__(
def aggregate_column_statistics(
self,
exclude_hats_columns: bool = True,
exclude_columns: List[str] = None,
include_columns: List[str] = None,
exclude_columns: list[str] = None,
include_columns: list[str] = None,
):
"""Read footer statistics in parquet metadata, and report on global min/max values.

Expand Down
12 changes: 6 additions & 6 deletions src/hats/catalog/dataset/table_properties.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from pathlib import Path
from typing import Iterable, List, Optional, Union
from typing import Iterable, Optional

from jproperties import Properties
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_validator
Expand Down Expand Up @@ -90,7 +90,7 @@ class TableProperties(BaseModel):

ra_column: Optional[str] = Field(default=None, alias="hats_col_ra")
dec_column: Optional[str] = Field(default=None, alias="hats_col_dec")
default_columns: Optional[List[str]] = Field(default=None, alias="hats_cols_default")
default_columns: Optional[list[str]] = Field(default=None, alias="hats_cols_default")
"""Which columns should be read from parquet files, when user doesn't otherwise specify."""

primary_catalog: Optional[str] = Field(default=None, alias="hats_primary_table_url")
Expand Down Expand Up @@ -120,15 +120,15 @@ class TableProperties(BaseModel):
indexing_column: Optional[str] = Field(default=None, alias="hats_index_column")
"""Column that we provide an index over."""

extra_columns: Optional[List[str]] = Field(default=None, alias="hats_index_extra_column")
extra_columns: Optional[list[str]] = Field(default=None, alias="hats_index_extra_column")
"""Any additional payload columns included in index."""

## Allow any extra keyword args to be stored on the properties object.
model_config = ConfigDict(extra="allow", populate_by_name=True, use_enum_values=True)

@field_validator("default_columns", "extra_columns", mode="before")
@classmethod
def space_delimited_list(cls, str_value: str) -> List[str]:
def space_delimited_list(cls, str_value: str) -> list[str]:
"""Convert a space-delimited list string into a python list of strings."""
if isinstance(str_value, str):
# Split on a few kinds of delimiters (just to be safe), and remove duplicates
Expand Down Expand Up @@ -193,7 +193,7 @@ def __str__(self):
return formatted_string

@classmethod
def read_from_dir(cls, catalog_dir: Union[str, Path, UPath]) -> Self:
def read_from_dir(cls, catalog_dir: str | Path | UPath) -> Self:
"""Read field values from a java-style properties file."""
file_path = file_io.get_upath(catalog_dir) / "properties"
if not file_io.does_file_or_directory_exist(file_path):
Expand All @@ -203,7 +203,7 @@ def read_from_dir(cls, catalog_dir: Union[str, Path, UPath]) -> Self:
p.load(f, "utf-8")
return cls(**p.properties)

def to_properties_file(self, catalog_dir: Union[str, Path, UPath]) -> Self:
def to_properties_file(self, catalog_dir: str | Path | UPath) -> Self:
"""Write fields to a java-style properties file."""
# pylint: disable=protected-access
parameters = self.model_dump(by_alias=True, exclude_none=True)
Expand Down
17 changes: 8 additions & 9 deletions src/hats/catalog/healpix_dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from pathlib import Path
from typing import List, Tuple, Union

import astropy.units as u
import numpy as np
Expand Down Expand Up @@ -30,8 +29,6 @@
from hats.pixel_tree.pixel_alignment import align_with_mocs
from hats.pixel_tree.pixel_tree import PixelTree

PixelInputTypes = Union[PartitionInfo, PixelTree, List[HealpixPixel]]


class HealpixDataset(Dataset):
"""A HATS dataset partitioned with a HEALPix partitioning structure.
Expand All @@ -45,7 +42,7 @@ class HealpixDataset(Dataset):
def __init__(
self,
catalog_info: TableProperties,
pixels: PixelInputTypes,
pixels: PartitionInfo | PixelTree | list[HealpixPixel],
catalog_path: str | Path | UPath | None = None,
moc: MOC | None = None,
schema: pa.Schema | None = None,
Expand All @@ -66,7 +63,7 @@ def __init__(
self.pixel_tree = self._get_pixel_tree_from_pixels(pixels)
self.moc = moc

def get_healpix_pixels(self) -> List[HealpixPixel]:
def get_healpix_pixels(self) -> list[HealpixPixel]:
"""Get healpix pixel objects for all pixels contained in the catalog.

Returns:
Expand All @@ -75,7 +72,9 @@ def get_healpix_pixels(self) -> List[HealpixPixel]:
return self.partition_info.get_healpix_pixels()

@staticmethod
def _get_partition_info_from_pixels(pixels: PixelInputTypes) -> PartitionInfo:
def _get_partition_info_from_pixels(
pixels: PartitionInfo | PixelTree | list[HealpixPixel],
) -> PartitionInfo:
if isinstance(pixels, PartitionInfo):
return pixels
if isinstance(pixels, PixelTree):
Expand All @@ -85,7 +84,7 @@ def _get_partition_info_from_pixels(pixels: PixelInputTypes) -> PartitionInfo:
raise TypeError("Pixels must be of type PartitionInfo, PixelTree, or List[HealpixPixel]")

@staticmethod
def _get_pixel_tree_from_pixels(pixels: PixelInputTypes) -> PixelTree:
def _get_pixel_tree_from_pixels(pixels: PartitionInfo | PixelTree | list[HealpixPixel]) -> PixelTree:
if isinstance(pixels, PartitionInfo):
return PixelTree.from_healpix(pixels.get_healpix_pixels())
if isinstance(pixels, PixelTree):
Expand Down Expand Up @@ -118,7 +117,7 @@ def get_max_coverage_order(self) -> int:
)
return max_order

def filter_from_pixel_list(self, pixels: List[HealpixPixel]) -> Self:
def filter_from_pixel_list(self, pixels: list[HealpixPixel]) -> Self:
"""Filter the pixels in the catalog to only include any that overlap with the requested pixels.

Args:
Expand Down Expand Up @@ -155,7 +154,7 @@ def filter_by_cone(self, ra: float, dec: float, radius_arcsec: float) -> Self:
)
return self.filter_by_moc(cone_moc)

def filter_by_box(self, ra: Tuple[float, float], dec: Tuple[float, float]) -> Self:
def filter_by_box(self, ra: tuple[float, float], dec: tuple[float, float]) -> Self:
"""Filter the pixels in the catalog to only include the pixels that overlap with a
zone, defined by right ascension and declination ranges. The right ascension edges follow
great arc circles and the declination edges follow small arc circles.
Expand Down
4 changes: 1 addition & 3 deletions src/hats/catalog/index/index_catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import numpy as np
import pyarrow.compute as pc
import pyarrow.dataset as pds
Expand All @@ -16,7 +14,7 @@ class IndexCatalog(Dataset):
Note that this is not a true "HATS Catalog", as it is not partitioned spatially.
"""

def loc_partitions(self, ids) -> List[HealpixPixel]:
def loc_partitions(self, ids) -> list[HealpixPixel]:
"""Find the set of partitions in the primary catalog for the ids provided.

Args:
Expand Down
9 changes: 4 additions & 5 deletions src/hats/catalog/partition_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import warnings
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
Expand All @@ -27,11 +26,11 @@ class PartitionInfo:
METADATA_ORDER_COLUMN_NAME = "Norder"
METADATA_PIXEL_COLUMN_NAME = "Npix"

def __init__(self, pixel_list: List[HealpixPixel], catalog_base_dir: str = None) -> None:
def __init__(self, pixel_list: list[HealpixPixel], catalog_base_dir: str = None) -> None:
self.pixel_list = pixel_list
self.catalog_base_dir = catalog_base_dir

def get_healpix_pixels(self) -> List[HealpixPixel]:
def get_healpix_pixels(self) -> list[HealpixPixel]:
"""Get healpix pixel objects for all pixels represented as partitions.

Returns:
Expand Down Expand Up @@ -158,7 +157,7 @@ def read_from_file(cls, metadata_file: str | Path | UPath, strict: bool = False)
@classmethod
def _read_from_metadata_file(
cls, metadata_file: str | Path | UPath, strict: bool = False
) -> List[HealpixPixel]:
) -> list[HealpixPixel]:
"""Read partition info list from a `_metadata` file.

Args:
Expand Down Expand Up @@ -260,7 +259,7 @@ def as_dataframe(self):
return pd.DataFrame.from_dict(partition_info_dict)

@classmethod
def from_healpix(cls, healpix_pixels: List[HealpixPixel]) -> PartitionInfo:
def from_healpix(cls, healpix_pixels: list[HealpixPixel]) -> PartitionInfo:
"""Create a partition info object from a list of constituent healpix pixels.

Args:
Expand Down
3 changes: 1 addition & 2 deletions src/hats/inspection/almanac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
import warnings
from typing import List

import pandas as pd

Expand Down Expand Up @@ -218,7 +217,7 @@ def _get_linked_catalog(self, linked_text, namespace) -> AlmanacInfo | None:
return None
return self.entries[resolved_name]

def catalogs(self, include_deprecated=False, types: List[str] | None = None):
def catalogs(self, include_deprecated=False, types: list[str] | None = None):
"""Get names of catalogs in the almanac, matching the provided conditions.

Catalogs must meet all criteria provided in order to be returned (e.g.
Expand Down
17 changes: 8 additions & 9 deletions src/hats/inspection/almanac_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
from dataclasses import dataclass, field
from typing import List

import yaml
from typing_extensions import Self
Expand All @@ -24,14 +23,14 @@ class AlmanacInfo:
join: str | None = None
primary_link: Self | None = None
join_link: Self | None = None
sources: List[Self] = field(default_factory=list)
objects: List[Self] = field(default_factory=list)
margins: List[Self] = field(default_factory=list)
associations: List[Self] = field(default_factory=list)
associations_right: List[Self] = field(default_factory=list)
indexes: List[Self] = field(default_factory=list)

creators: List[str] = field(default_factory=list)
sources: list[Self] = field(default_factory=list)
objects: list[Self] = field(default_factory=list)
margins: list[Self] = field(default_factory=list)
associations: list[Self] = field(default_factory=list)
associations_right: list[Self] = field(default_factory=list)
indexes: list[Self] = field(default_factory=list)

creators: list[str] = field(default_factory=list)
description: str = ""
version: str = ""
deprecated: str = ""
Expand Down
Loading
Loading