Skip to content

Commit a3df697

Browse files
authored
Merge pull request #784 from Cadair/gwcs_dev_fixes
2 parents 60a0c79 + 93d19dc commit a3df697

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

Diff for: ndcube/extra_coords/table_coord.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
import copy
33
from numbers import Integral
4+
from functools import partial
45
from collections import defaultdict
56

67
import gwcs
@@ -13,6 +14,7 @@
1314
from astropy.modeling.models import tabular_model
1415
from astropy.modeling.tabular import _Tabular
1516
from astropy.time import Time
17+
from astropy.utils import isiterable
1618
from astropy.wcs.wcsapi.wrappers.sliced_wcs import combine_slices, sanitize_slices
1719

1820
try:
@@ -901,6 +903,17 @@ def frame(self):
901903

902904
return cf.CompositeFrame(frames)
903905

906+
@staticmethod
907+
def _from_high_level_coordinates(dropped_frame, *highlevel_coords):
908+
"""
909+
This is a backwards compatibility wrapper for the new
910+
from_high_level_coordinates method in gwcs.
911+
"""
912+
quantities = dropped_frame.coordinate_to_quantity(*highlevel_coords)
913+
if isiterable(quantities):
914+
quantities = tuple(q.value for q in quantities)
915+
return quantities
916+
904917
@property
905918
def dropped_world_dimensions(self):
906919
dropped_world_dimensions = defaultdict(list)
@@ -919,22 +932,31 @@ def dropped_world_dimensions(self):
919932
dropped_world_dimensions["world_axis_names"] += [name or None for name in dropped_multi_table.frame.axes_names]
920933
dropped_world_dimensions["world_axis_physical_types"] += list(dropped_multi_table.frame.axis_physical_types)
921934
dropped_world_dimensions["world_axis_units"] += [u.to_string() for u in dropped_multi_table.frame.unit]
922-
dropped_world_dimensions["world_axis_object_components"] += dropped_multi_table.frame._world_axis_object_components
923-
dropped_world_dimensions["world_axis_object_classes"].update(dropped_multi_table.frame._world_axis_object_classes)
935+
# In gwcs https://github.com/spacetelescope/gwcs/pull/457 the underscore was dropped
936+
waocomp = getattr(dropped_multi_table.frame, "world_axis_object_components", getattr(dropped_multi_table.frame, "_world_axis_object_components", []))
937+
dropped_world_dimensions["world_axis_object_components"] += waocomp
938+
waocls = getattr(dropped_multi_table.frame, "world_axis_object_classes", getattr(dropped_multi_table.frame, "_world_axis_object_classes", {}))
939+
dropped_world_dimensions["world_axis_object_classes"].update(waocls)
924940

925941
for dropped in self._dropped_coords:
926942
# If the table is a tuple (QuantityTableCoordinate) then we need to
927943
# squish the input
944+
# In gwcs https://github.com/spacetelescope/gwcs/pull/457 coordinate_to_quantity was removed
945+
coord_meth = getattr(
946+
dropped.frame,
947+
"from_high_level_coordinates",
948+
partial(self._from_high_level_coordinates, dropped.frame)
949+
)
928950
if isinstance(dropped.table, tuple):
929-
coord = dropped.frame.coordinate_to_quantity(*dropped.table)
951+
coord = coord_meth(*dropped.table)
930952
else:
931-
coord = dropped.frame.coordinate_to_quantity(dropped.table)
953+
coord = coord_meth(dropped.table)
932954

933955
# We want the value in the output dict to be a flat list of values
934956
# in the order of world_axis_object_components, so if we get a
935957
# tuple of coordinates out of gWCS then append them to the list, if
936958
# we only get one quantity out then append to the list.
937-
if isinstance(coord, tuple):
959+
if isinstance(coord, (tuple, list)):
938960
dropped_world_dimensions["value"] += list(coord)
939961
else:
940962
dropped_world_dimensions["value"].append(coord)

Diff for: ndcube/extra_coords/tests/test_lookup_table_coord.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -520,15 +520,14 @@ def test_mtc_dropped_table_skycoord_join(lut_1d_time, lut_2d_skycoord_mesh):
520520
assert isinstance(dwd, dict)
521521
wao_classes = dwd.pop("world_axis_object_classes")
522522
assert all(isinstance(value, list) for value in dwd.values())
523-
assert all(len(value) == 2 for value in dwd.values())
524523

525524
assert dwd["world_axis_names"] == ["lon", "lat"]
526525
assert all(isinstance(u, str) for u in dwd["world_axis_units"])
527526
assert dwd["world_axis_units"] == ["deg", "deg"]
528527
assert dwd["world_axis_physical_types"] == ["pos.eq.ra", "pos.eq.dec"]
529528
assert [c[:2] for c in dwd["world_axis_object_components"]] == [("celestial", 0), ("celestial", 1)]
530529
assert wao_classes["celestial"][0] is SkyCoord
531-
assert dwd["value"] == [0*u.deg, 0*u.deg]
530+
assert dwd["value"] == [0, 0]
532531

533532

534533
@pytest.mark.xfail(reason=">1D Tables not supported")

0 commit comments

Comments
 (0)