Skip to content

Commit

Permalink
Merge pull request #81 from astronomy-commons/sean/healpix-pixel-para…
Browse files Browse the repository at this point in the history
…meter-decorator

Add `healpix_pixel_decorator` to accept `HealpixPixel` objects or tuples and change functions to use it
  • Loading branch information
smcguire-cmu authored Apr 18, 2023
2 parents b61e885 + 4fa4bcc commit deecbdf
Show file tree
Hide file tree
Showing 18 changed files with 395 additions and 171 deletions.
1 change: 0 additions & 1 deletion src/hipscat/catalog/catalog_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class CatalogParameters:
def __post_init__(
self,
):

if self.catalog_type not in self.CATALOG_TYPES:
raise ValueError(f"Unknown catalog type: {self.catalog_type}")

Expand Down
10 changes: 4 additions & 6 deletions src/hipscat/io/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,14 @@ def pixel_association_file(
)


def create_hive_directory_name(
base_dir, partition_token_names, partition_token_values
):
def create_hive_directory_name(base_dir, partition_token_names, partition_token_values):
"""Create path *pointer* for a directory with hive partitioning naming.
This will not create the directory.
The directory name will have the form of:
<catalog_base_dir>/<name_1>=<value_1>/.../<name_n>=<value_n>
Args:
catalog_base_dir (FilePointer): base directory of the catalog (includes catalog name)
partition_token_names (list[string]): list of partition name parts.
Expand All @@ -212,11 +210,11 @@ def create_hive_parquet_file_name(
base_dir, partition_token_names, partition_token_values
):
"""Create path *pointer* for a single parquet with hive partitioning naming.
The file name will have the form of:
<catalog_base_dir>/<name_1>=<value_1>/.../<name_n>=<value_n>.parquet
Args:
catalog_base_dir (FilePointer): base directory of the catalog (includes catalog name)
partition_token_names (list[string]): list of partition name parts.
Expand Down
4 changes: 4 additions & 0 deletions src/hipscat/pixel_math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@
get_truncated_margin_pixels,
pixel_is_polar,
)
from .healpix_pixel_convertor import (
get_healpix_pixel,
HealpixInputTypes,
)
28 changes: 28 additions & 0 deletions src/hipscat/pixel_math/healpix_pixel_convertor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Tuple, Union

from hipscat.pixel_math.healpix_pixel import HealpixPixel

HealpixInputTypes = Union[HealpixPixel, Tuple[int, int]]


def get_healpix_pixel(pixel: HealpixInputTypes) -> HealpixPixel:
"""Function to convert argument of either HealpixPixel or a tuple of (order, pixel) to a
HealpixPixel
Args:
pixel: an object to be converted to a HealpixPixel object
"""

if isinstance(pixel, tuple):
if len(pixel) != 2:
raise ValueError(
"Tuple must contain two values: HEALPix order and HEALPix pixel number"
)
return HealpixPixel(order=pixel[0], pixel=pixel[1])
if isinstance(pixel, HealpixPixel):
return pixel
raise TypeError(
"pixel must either be of type `HealpixPixel` or tuple (order, pixel)"
)
45 changes: 23 additions & 22 deletions src/hipscat/pixel_math/margin_bounding.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Utilities to build bounding boxes around healpixels that include a neighor margin."""

import astropy.units as u
import astropy.wcs as world_coordinate_system
import healpy as hp
import numpy as np

from astropy.coordinates import SkyCoord
from regions import PixCoord, PolygonSkyRegion
import astropy.wcs as world_coordinate_system
import astropy.units as u

from . import pixel_margins as pm

Expand Down Expand Up @@ -51,10 +50,10 @@ def get_margin_bounds_and_wcs(pixel_order, pix, scale, step=10):
hp.boundaries(2**pixel_order, pix, step=1, nest=True), lonlat=True
)

min_ra = corners[0][1] # western corner
max_ra = corners[0][3] # eastern corner
min_dec = corners[1][2] # southern corner
max_dec = corners[1][0] # northern corner
min_ra = corners[0][1] # western corner
max_ra = corners[0][3] # eastern corner
min_dec = corners[1][2] # southern corner
max_dec = corners[1][0] # northern corner

pixel_boundaries = hp.vec2dir(
hp.boundaries(2**pixel_order, pix, step=step, nest=True), lonlat=True
Expand All @@ -63,9 +62,9 @@ def get_margin_bounds_and_wcs(pixel_order, pix, scale, step=10):
# if the eastern corner is less than the western corner, then we've hit the
# ra rollover and need to normalize to 0 -> 360.
if max_ra < min_ra:
max_ra = max_ra + 360.
max_ra = max_ra + 360.0
ra_vals = pixel_boundaries[0]
normal = np.where(ra_vals < 0., ra_vals + 360., ra_vals)
normal = np.where(ra_vals < 0.0, ra_vals + 360.0, ra_vals)
pixel_boundaries[0] = normal

# find the translation values to keep the bounding box
Expand All @@ -90,7 +89,7 @@ def get_margin_bounds_and_wcs(pixel_order, pix, scale, step=10):

# if the transform places the declination of any points outside of
# the range 90 > dec > -90, change it to a proper dec value.
transformed_bounding_box[1] = np.clip(transformed_bounding_box[1], -90., 90.)
transformed_bounding_box[1] = np.clip(transformed_bounding_box[1], -90.0, 90.0)

# one arcsecond
pix_size = 0.0002777777778
Expand Down Expand Up @@ -162,7 +161,7 @@ def check_margin_bounds(r_asc, dec, poly_and_wcs):
"""Get the astropy `regions.PolygonPixelRegion` and `astropy.wcs` objects
for a given margin bounding box scale.
For pixels that fall along the poles, this code must be used in conjunction
with `check_polar_margin_bounds` and
with `check_polar_margin_bounds` and
`pixel_margins.get_truncated_margin_pixels`.
Args:
Expand All @@ -189,8 +188,11 @@ def check_margin_bounds(r_asc, dec, poly_and_wcs):
bound_vals.append(vals)
return np.array(bound_vals).any(axis=0)


# pylint: disable=too-many-locals
def check_polar_margin_bounds(r_asc, dec, order, pix, margin_order, margin_threshold, step=1000):
def check_polar_margin_bounds(
r_asc, dec, order, pix, margin_order, margin_threshold, step=1000
):
"""Given a set of ra and dec values that are around one of the poles,
determine if they are within `margin_threshold` of a provided
partition pixel. This method helps us solve the edge cases that
Expand Down Expand Up @@ -226,19 +228,18 @@ def check_polar_margin_bounds(r_asc, dec, order, pix, margin_order, margin_thres
# on the boundary of the main pixel
boundary_range = int((marg_pix_res / part_pix_res) * step)
pixel_boundaries = hp.vec2dir(
hp.boundaries(2**order, pix, step=step, nest=True),
lonlat=True
hp.boundaries(2**order, pix, step=step, nest=True), lonlat=True
)

# to optimize our code, we only want to take boundary samples from the part
# of the pixel that directly abuts the polar margin pixels.
if pole == "North":
end = len(pixel_boundaries[0])
east_ra = pixel_boundaries[0][0:boundary_range+1]
east_dec = pixel_boundaries[1][0:boundary_range+1]
east_ra = pixel_boundaries[0][0 : boundary_range + 1]
east_dec = pixel_boundaries[1][0 : boundary_range + 1]

west_ra = pixel_boundaries[0][end-boundary_range:end]
west_dec = pixel_boundaries[1][end-boundary_range:end]
west_ra = pixel_boundaries[0][end - boundary_range : end]
west_dec = pixel_boundaries[1][end - boundary_range : end]

bound_ra = np.concatenate((east_ra, west_ra), axis=None)
bound_dec = np.concatenate((east_dec, west_dec), axis=None)
Expand All @@ -252,17 +253,17 @@ def check_polar_margin_bounds(r_asc, dec, order, pix, margin_order, margin_thres

# healpy.boundaries sometimes returns dec values greater than 90, especially
# when taking many samples...
polar_boundaries[1] = np.clip(polar_boundaries[1], -90., 90.)
polar_boundaries[1] = np.clip(polar_boundaries[1], -90.0, 90.0)

sky_coords = SkyCoord(r_asc, dec, unit='deg')
sky_coords = SkyCoord(r_asc, dec, unit="deg")

checks = []
for i in range(len(polar_boundaries[0])):
lon = polar_boundaries[0][i]
lat = polar_boundaries[1][i]
bound_coord = SkyCoord(lon, lat, unit='deg')
bound_coord = SkyCoord(lon, lat, unit="deg")

ang_dist = bound_coord.separation(sky_coords)
checks.append(ang_dist <= margin_threshold*u.deg)
checks.append(ang_dist <= margin_threshold * u.deg)

return np.array(checks).any(axis=0)
16 changes: 9 additions & 7 deletions src/hipscat/pixel_math/pixel_margins.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def get_margin(order, pix, d_order):
margins = np.concatenate(margins)
return margins


def pixel_is_polar(order, pix):
"""Checks if a healpixel is a polar pixel.
Expand All @@ -130,10 +131,11 @@ def pixel_is_polar(order, pix):

# in the ring numbering scheme, the first and last 4 pixels are the poles.
if ring_pix <= 3:
return (True, 'North')
return (True, "North")
if ring_pix >= npix - 4:
return (True, 'South')
return (False, '')
return (True, "South")
return (False, "")


def get_truncated_margin_pixels(order, pix, margin_order):
"""Given a polar healpixel, find the margin pixels at order highest_k that will be
Expand All @@ -142,13 +144,13 @@ def get_truncated_margin_pixels(order, pix, margin_order):
Args:
order (int): the healpix order of the pixel that we wish to
find the truncated margin pixels for.
find the truncated margin pixels for.
pix (int): the healpixel we wish to find the truncated margin pixels for.
margin_order (int): the healpixel order that our margin pixels are at. Must
be larger than `order`.
Returns:
a list of margin pixels at margin_order order that will be truncated at
the poles, i.e. the margin pixels that are also polar pixels
a list of margin pixels at margin_order order that will be truncated at
the poles, i.e. the margin pixels that are also polar pixels
themselves. In the case that pix is a polar pixel, it will return
3 pixels, otherwise it will return an empty list.
"""
Expand All @@ -172,7 +174,7 @@ def get_truncated_margin_pixels(order, pix, margin_order):
truncs.append(hp.ring2nest(margin_nside, i))
else:
d_order = margin_order - order
excluded_pixel_nest = (4 ** d_order) * pix
excluded_pixel_nest = (4**d_order) * pix
excluded_pixel_ring = hp.nest2ring(margin_nside, excluded_pixel_nest)
npix = hp.nside2npix(margin_nside)

Expand Down
27 changes: 19 additions & 8 deletions src/hipscat/pixel_tree/pixel_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

from hipscat.pixel_tree.pixel_node_type import PixelNodeType
from hipscat.pixel_math import HealpixInputTypes, get_healpix_pixel


class PixelNode:
Expand All @@ -26,37 +27,37 @@ class PixelNode:

def __init__(
self,
hp_order: int,
hp_pixel: int,
pixel: HealpixInputTypes,
node_type: PixelNodeType,
parent: PixelNode | None,
children: List[PixelNode] | None = None,
):
) -> None:
"""Inits PixelNode with its attributes
Raises:
ValueError: Invalid arguments for the specified pixel type
"""

pixel = get_healpix_pixel(pixel)

if node_type == PixelNodeType.ROOT:
if parent is not None:
raise ValueError("Root node cannot have a parent")
if hp_order != -1:
if pixel.order != -1:
raise ValueError("Root node must be at order -1")

if node_type in (PixelNodeType.INNER, PixelNodeType.LEAF):
if parent is None:
raise ValueError("Inner and leaf nodes must have a parent")
if hp_pixel < 0 or hp_order < 0:
if pixel.order < 0 or pixel.pixel < 0:
raise ValueError(
"Inner and leaf nodes must have an order and pixel number >= 0"
)

if parent is not None and parent.hp_order != hp_order - 1:
if parent is not None and parent.hp_order != pixel.order - 1:
raise ValueError("Parent node must be at order one less than current node")

self.hp_order = hp_order
self.hp_pixel = hp_pixel
self.pixel = pixel
self.node_type = node_type
self.parent = parent
self.children = []
Expand All @@ -68,6 +69,16 @@ def __init__(
if self.parent is not None:
self.parent.add_child_node(self)

@property
def hp_order(self):
"""The order of the HealpixPixel the node is at"""
return self.pixel.order

@property
def hp_pixel(self):
"""The pixel number in NESTED ordering of the HealpixPixel the node is at"""
return self.pixel.pixel

def add_child_node(self, child: PixelNode):
"""Adds a child node to the node
Expand Down
35 changes: 23 additions & 12 deletions src/hipscat/pixel_tree/pixel_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from hipscat.pixel_tree.pixel_node import PixelNode
from hipscat.pixel_math import HealpixInputTypes, get_healpix_pixel


class PixelTree:
Expand All @@ -14,12 +15,14 @@ class PixelTree:
Attributes:
pixels: Nested dictionary of pixel nodes stored in the tree. Indexed by HEALPix
order then pixel number
order then pixel number
root_pixel: Root node of the tree. Its children are a subset of the
12 base HEALPix pixels
"""

def __init__(self, root_pixel: PixelNode, pixels: dict[int, dict[int, PixelNode]]) -> None:
def __init__(
self, root_pixel: PixelNode, pixels: dict[int, dict[int, PixelNode]]
) -> None:
"""Initialises a tree object from the nodes in the tree
Args:
Expand All @@ -40,28 +43,36 @@ def __len__(self):
pixel_count += len(order_pixels)
return pixel_count

def contains(self, hp_order: int, hp_pixel: int) -> bool:
def contains(self, pixel: HealpixInputTypes) -> bool:
"""Check if tree contains a node at a given order and pixel
Args:
hp_order: HEALPix order to check
hp_pixel: HEALPix pixel number to check
pixel: HEALPix pixel to check. Either of type `HealpixPixel`
or a tuple of (order, pixel)
Returns:
True if the tree contains the pixel, False if not
"""
return hp_order in self.pixels and hp_pixel in self.pixels[hp_order]
pixel = get_healpix_pixel(pixel)
return pixel.order in self.pixels and pixel.pixel in self.pixels[pixel.order]

def get_node(self, hp_order: int, hp_pixel: int) -> PixelNode | None:
"""Get the node at a given order and pixel
def __contains__(self, item):
return self.contains(item)

def get_node(self, pixel: HealpixInputTypes) -> PixelNode | None:
"""Get the node at a given pixel
Args:
hp_order: HEALPix order to get
hp_pixel: HEALPix pixel number to get
pixel: HEALPix pixel to get. Either of type `HealpixPixel`
or a tuple of (order, pixel)
Returns:
The PixelNode at the index, or None if a node does not exist
"""
if self.contains(hp_order, hp_pixel):
return self.pixels[hp_order][hp_pixel]
pixel = get_healpix_pixel(pixel)
if self.contains(pixel):
return self.pixels[pixel.order][pixel.pixel]
return None

def __getitem__(self, item):
return self.get_node(item)
Loading

0 comments on commit deecbdf

Please sign in to comment.