Skip to content

Commit

Permalink
Add casatools as the fallback backend for casacore image reading when…
Browse files Browse the repository at this point in the history
… python-casacore is unavailable.
  • Loading branch information
r-xue committed Feb 5, 2025
1 parent 997a556 commit 9141d66
Show file tree
Hide file tree
Showing 18 changed files with 316 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import inspect
import shutil
from functools import wraps
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Sequence
import os

import casatools
Expand Down Expand Up @@ -248,8 +248,62 @@ def getcolshapestring(self, *args, **kwargs):
ret = super().getcolshapestring(*args, **kwargs)
return [str(list(reversed(ast.literal_eval(shape)))) for shape in ret]

def getcellslice(self, columnname, rownr, blc, trc, incr=1):
"""Retrieve a sliced portion of a cell from a specified column.
This method extracts a subarray from a cell within a table column,
given the bottom-left corner (BLC) and top-right corner (TRC) indices.
It also supports an optional increment (`incr`) to control step size.
Parameters
----------
columnname : str
The name of the column from which to extract data.
rownr : int or Sequence[int]
The row number(s) from which to extract data. If a sequence is provided,
it is reversed before processing.
blc : Sequence[int]
The bottom-left corner indices of the slice.
trc : Sequence[int]
The top-right corner indices of the slice.
incr : int or Sequence[int], optional
Step size for slicing. If a sequence is provided, it is reversed.
If a single integer is given, it is expanded to match `blc` dimensions.
Defaults to 1.
Returns
-------
Any
The extracted slice from the specified column and row(s).
Notes
-----
- If `rownr` is a sequence, it is reversed before processing.
- The `blc`, `trc`, and `incr` parameters are converted to lists of integers.
- Calls the superclass method `getcellslice` for actual data retrieval.
"""

if isinstance(rownr, Sequence):
rownr = rownr[::-1]
else:
rownr = [rownr]*len(blc)
rownr = 0
if isinstance(blc, Sequence):
blc = list(map(int, blc[::-1]))
if isinstance(trc, Sequence):
trc = list(map(int, trc[::-1]))
if isinstance(incr, Sequence):
incr = incr[::-1]
else:
incr = [incr]*len(blc)
ret = super().getcellslice(columnname=columnname, rownr=rownr, blc=blc, trc=trc)
return ret


@wrap_class_methods
class image(casatools.image):
"""A Wrapper class around `casatools.image` that provides python-casacore-like methods."""

def __init__(
self,
imagename,
Expand All @@ -265,7 +319,224 @@ def __init__(
tileshape=(),
):
super().__init__()
self.open(*arg, **kwargs)
self._imagename = imagename
if shape is None:
# self.open(*arg, **kwargs)
self.open(imagename)
else:
if value is None:
self.newimagefromshape(outfile, shape=shape)
else:
self.newimagefromarray(outfile, pixels=self.makearray(value, shape))

def getdata(self, blc=None, trc=None, inc=None):
"""Retrieve image data as a chunk.
Parameters
----------
blc : list of int, optional
Bottom-left corner of the region to extract. Defaults to `[-1]` (entire image).
trc : list of int, optional
Top-right corner of the region to extract. Defaults to `[-1]` (entire image).
inc : list of int, optional
Step size for slicing. Defaults to `[1]`.
Returns
-------
numpy.ndarray
The extracted data chunk.
"""

if blc is None:
blc = [-1]
if trc is None:
trc = [-1]
if inc is None:
inc = [1]
return super().getchunk(blc, trc, inc)

def shape(self):
"""Get the shape of the image.
Returns
-------
list of int
The shape of the image, with axes reversed for consistency.
"""
return list(map(int, super().shape()[::-1]))

def coordinates(self):
"""Get the coordinate system of the image.
Returns
-------
casatools.coordinatesystem
The coordinate system associated with the image.
"""
return coordinatesystem(self)

def unit(self):
"""Get the brightness unit of the image.
Returns
-------
str
The brightness unit of the image.
"""
return self.brightnessunit()

def info(self):
"""Retrieve image metadata including coordinates, misc info, and beam information.
Returns
-------
dict
Dictionary containing:
- 'imageinfo': Flattened image summary.
- 'coordinates': Coordinate system as a dictionary.
- 'miscinfo': Miscellaneous metadata.
"""
imageinfo = self.summary(list=False)
imageinfo = self._flatten_multibeam(imageinfo)
# table.getdesc()

return {'imageinfo': imageinfo,
'coordinates': self.coordsys().torecord(),
'miscinfo': self.miscinfo()}

def datatype(self):
return self.pixeltype()

def _flatten_multibeam(self, imageinfo):
"""Flatten the per-plane beam information in the image metadata.
This method restructures the `perplanebeams` field in `imageinfo`
to make it more accessible by flattening the nested structure.
Parameters
----------
imageinfo : dict
The image metadata containing per-plane beam information.
Returns
-------
dict
Updated `imageinfo` dictionary with flattened per-plane beam data.
"""
if 'perplanebeams' in imageinfo:
perplanebeams = imageinfo['perplanebeams']['beams']
perplanebeams_flat = {}
nchan = imageinfo['perplanebeams']['nChannels']
npol = imageinfo['perplanebeams']['nStokes']

for c in range(nchan):
for p in range(npol):
k = nchan * p + c
perplanebeams_flat["*" + str(k)] = perplanebeams['*'+str(c)]['*'+str(p)]
imageinfo['perplanebeams'].pop('beams', None)
imageinfo['perplanebeams'].update(perplanebeams_flat)

return imageinfo


class coordinatesystem(casatools.coordsys):
"""A wrapper around `casatools.coordsys` that provides python-casacore like methods"""

def __init__(self, image=None):
self._image = image
if image is None:
self._cs = casatools.coordsys()
else:
self._cs = image.coordsys()

def get_axes(self):
"""Retrieve the names of the coordinate axes.
Returns
-------
list of str or list of lists
A list containing the names of each axis, grouped by coordinate type.
Spectral axes are returned as a single string instead of a list.
"""
axes = []
axis_names = self._cs.names()
for axis_type in self.get_names():
axis_inds = self._cs.findcoordinate(axis_type).get('pixel')
axes_list = [axis_names[idx] for idx in axis_inds[::-1]]
if axis_type == 'spectral':
axes_list = axes_list[0]
axes.append(axes_list)
return axes

def get_referencepixel(self):
"""Get the reference pixel coordinates.
Returns
-------
list of float
The numeric reference pixel values, with axes reversed.
"""
return self._cs.referencepixel()['numeric'][::-1]

def get_referencevalue(self):
"""Get the reference value at the reference pixel.
Returns
-------
list of float
The numeric reference values, with axes reversed.
"""
return self._cs.referencevalue()['numeric'][::-1]

def get_increment(self):
"""Get the coordinate increments per pixel.
Returns
-------
list of float
The coordinate increment values, with axes reversed.
"""
return self._cs.increment()['numeric'][::-1]

def get_unit(self):
"""Get the units of the coordinate axes.
Returns
-------
list of str
The units of each axis, with axes reversed.
"""
return self._cs.units()[::-1]

def get_names(self):
"""Get the coordinate type names in lowercase.
Returns
-------
list of str
The coordinate type names, with axes reversed.
"""
return list(map(str.lower, self._cs.coordinatetype()[::-1]))

def dict(self):
"""Convert the coordinate system to a dictionary representation.
Returns
-------
dict
The coordinate system in CASA's dictionary format.
"""
return self._cs.torecord()


class directioncoordinate(coordinatesystem):

def __init__(self, rec):
super().__init__()
self._rec = rec

def get_projection(self):
return self._rec['projection']


@wrap_class_methods
Expand Down
2 changes: 1 addition & 1 deletion src/xradio/_utils/_casacore/tables.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
try:
from casacore import tables
except ImportError:
from . import casatools_to_casacore as tables
from . import casacore_from_casatools as tables

from contextlib import contextmanager
from typing import Dict, Generator
Expand Down
6 changes: 5 additions & 1 deletion src/xradio/image/_util/_casacore/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from casacore import images
try:
from casacore import images
except ImportError:
from ...._utils._casacore import casacore_from_casatools as images

from contextlib import contextmanager
import numpy as np
from typing import Dict, Generator, List, Union
Expand Down
8 changes: 6 additions & 2 deletions src/xradio/image/_util/_casacore/xds_from_casacore.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from astropy import units as u
try:
from casacore import tables
from casacore.images import coordinates
except ImportError:
from ...._utils._casacore import casatools_to_casacore as tables
from casacore.images import coordinates
from ...._utils._casacore import casacore_from_casatools as tables
from ...._utils._casacore import casacore_from_casatools as coordinates

from .common import (
_active_mask,
Expand Down Expand Up @@ -312,6 +313,9 @@ def _casa_image_to_xds_coords(
"note": attr_note[c],
}
if do_sky_coords:
#from pprint import pprint
#print('--->',coord_dict)
#pprint(coord_dict)
for k in coord_dict.keys():
if k.startswith("direction"):
dc = coordinates.directioncoordinate(coord_dict[k])
Expand Down
2 changes: 1 addition & 1 deletion src/xradio/image/_util/_casacore/xds_to_casacore.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
try:
from casacore import tables
except ImportError:
from ...._utils._casacore import casatools_to_casacore as tables
from ...._utils._casacore import casacore_from_casatools as tables

from .common import _active_mask, _create_new_image, _object_name, _pointing_center
from ..common import _aperture_or_sky, _compute_sky_reference_pixel, _doppler_types
Expand Down
Loading

0 comments on commit 9141d66

Please sign in to comment.