Skip to content

Commit

Permalink
Format codes to comply with the Black code style standard.
Browse files Browse the repository at this point in the history
  • Loading branch information
r-xue committed Feb 5, 2025
1 parent d313ae0 commit bca1dee
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 44 deletions.
64 changes: 38 additions & 26 deletions src/xradio/_utils/_casacore/casacore_from_casatools.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def __init__(
concatsubtables: List = [],
**kwargs,
):
super().__init__(tablename=tablename, lockoptions=lockoptions, nomodify=True, **kwargs)
super().__init__(
tablename=tablename, lockoptions=lockoptions, nomodify=True, **kwargs
)

def row(self, columnnames: List[str] = [], exclude: bool = False) -> "tablerow":
"""Access rows in the table.
Expand Down Expand Up @@ -210,7 +212,9 @@ def taql(self, taqlcommand="TaQL expression"):
else:
tablename = self.name()
tb_query_from = self
tb_query = taqlcommand.replace("$mtable", tablename).replace("$gtable", tablename)
tb_query = taqlcommand.replace("$mtable", tablename).replace(
"$gtable", tablename
)
logger.debug(f"tb_query_from: {tb_query_from.name()}")
logger.debug(f"tb_query_cmd: {tb_query}")
tb_query_to = _wrap_table(swig_object=tb_query_from._swigobj.taql(tb_query))
Expand Down Expand Up @@ -251,7 +255,7 @@ def getcolshapestring(self, *args, **kwargs):
def getcellslice(self, columnname, rownr, blc, trc, incr=1):
"""Retrieve a sliced portion of a cell from a specified column.
This method extracts a subarray from a cell within a table column,
This method extracts a subarray from a cell within a table column,
given the bottom-left corner (BLC) and top-right corner (TRC) indices.
It also supports an optional increment (`incr`) to control step size.
Expand All @@ -260,15 +264,15 @@ def getcellslice(self, columnname, rownr, blc, trc, incr=1):
columnname : str
The name of the column from which to extract data.
rownr : int or Sequence[int]
The row number(s) from which to extract data. If a sequence is provided,
The row number(s) from which to extract data. If a sequence is provided,
it is reversed before processing.
blc : Sequence[int]
The bottom-left corner indices of the slice.
trc : Sequence[int]
The top-right corner indices of the slice.
incr : int or Sequence[int], optional
Step size for slicing. If a sequence is provided, it is reversed.
If a single integer is given, it is expanded to match `blc` dimensions.
Step size for slicing. If a sequence is provided, it is reversed.
If a single integer is given, it is expanded to match `blc` dimensions.
Defaults to 1.
Returns
Expand All @@ -286,7 +290,7 @@ def getcellslice(self, columnname, rownr, blc, trc, incr=1):
if isinstance(rownr, Sequence):
rownr = rownr[::-1]
else:
rownr = [rownr]*len(blc)
rownr = [rownr] * len(blc)
rownr = 0
if isinstance(blc, Sequence):
blc = list(map(int, blc[::-1]))
Expand All @@ -295,7 +299,7 @@ def getcellslice(self, columnname, rownr, blc, trc, incr=1):
if isinstance(incr, Sequence):
incr = incr[::-1]
else:
incr = [incr]*len(blc)
incr = [incr] * len(blc)
ret = super().getcellslice(columnname=columnname, rownr=rownr, blc=blc, trc=trc)
return ret

Expand Down Expand Up @@ -400,9 +404,11 @@ def info(self):
imageinfo = self._flatten_multibeam(imageinfo)
# table.getdesc()

return {'imageinfo': imageinfo,
'coordinates': self.coordsys().torecord(),
'miscinfo': self.miscinfo()}
return {
"imageinfo": imageinfo,
"coordinates": self.coordsys().torecord(),
"miscinfo": self.miscinfo(),
}

def datatype(self):
return self.pixeltype()
Expand All @@ -423,18 +429,20 @@ def _flatten_multibeam(self, imageinfo):
dict
Updated `imageinfo` dictionary with flattened per-plane beam data.
"""
if 'perplanebeams' in imageinfo:
perplanebeams = imageinfo['perplanebeams']['beams']
if "perplanebeams" in imageinfo:
perplanebeams = imageinfo["perplanebeams"]["beams"]
perplanebeams_flat = {}
nchan = imageinfo['perplanebeams']['nChannels']
npol = imageinfo['perplanebeams']['nStokes']
nchan = imageinfo["perplanebeams"]["nChannels"]
npol = imageinfo["perplanebeams"]["nStokes"]

for c in range(nchan):
for p in range(npol):
k = nchan * p + c
perplanebeams_flat["*" + str(k)] = perplanebeams['*'+str(c)]['*'+str(p)]
imageinfo['perplanebeams'].pop('beams', None)
imageinfo['perplanebeams'].update(perplanebeams_flat)
perplanebeams_flat["*" + str(k)] = perplanebeams["*" + str(c)][
"*" + str(p)
]
imageinfo["perplanebeams"].pop("beams", None)
imageinfo["perplanebeams"].update(perplanebeams_flat)

return imageinfo

Expand All @@ -461,9 +469,9 @@ def get_axes(self):
axes = []
axis_names = self._cs.names()
for axis_type in self.get_names():
axis_inds = self._cs.findcoordinate(axis_type).get('pixel')
axis_inds = self._cs.findcoordinate(axis_type).get("pixel")
axes_list = [axis_names[idx] for idx in axis_inds[::-1]]
if axis_type == 'spectral':
if axis_type == "spectral":
axes_list = axes_list[0]
axes.append(axes_list)
return axes
Expand All @@ -476,7 +484,7 @@ def get_referencepixel(self):
list of float
The numeric reference pixel values, with axes reversed.
"""
return self._cs.referencepixel()['numeric'][::-1]
return self._cs.referencepixel()["numeric"][::-1]

def get_referencevalue(self):
"""Get the reference value at the reference pixel.
Expand All @@ -486,7 +494,7 @@ def get_referencevalue(self):
list of float
The numeric reference values, with axes reversed.
"""
return self._cs.referencevalue()['numeric'][::-1]
return self._cs.referencevalue()["numeric"][::-1]

def get_increment(self):
"""Get the coordinate increments per pixel.
Expand All @@ -496,7 +504,7 @@ def get_increment(self):
list of float
The coordinate increment values, with axes reversed.
"""
return self._cs.increment()['numeric'][::-1]
return self._cs.increment()["numeric"][::-1]

def get_unit(self):
"""Get the units of the coordinate axes.
Expand Down Expand Up @@ -536,7 +544,7 @@ def __init__(self, rec):
self._rec = rec

def get_projection(self):
return self._rec['projection']
return self._rec["projection"]


@wrap_class_methods
Expand All @@ -553,7 +561,9 @@ class tablerow(casatools.tablerow):
Whether to exclude the specified columns.
"""

def __init__(self, table: table, columnnames: List[str] = [], exclude: bool = False):
def __init__(
self, table: table, columnnames: List[str] = [], exclude: bool = False
):
super().__init__(table, columnnames=columnnames, exclude=exclude)

@method_wrapper
Expand All @@ -572,7 +582,9 @@ def get(self, rownr: int) -> Dict[str, Any]:
"""
return super().get(rownr)

def __getitem__(self, key: Union[int, slice]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
def __getitem__(
self, key: Union[int, slice]
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""Retrieve rows using indexing or slicing.
Parameters
Expand Down
4 changes: 1 addition & 3 deletions src/xradio/image/_util/_casacore/xds_from_casacore.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import xarray as xr
from astropy import units as u

try:
from casacore import tables
from casacore.images import coordinates
Expand Down Expand Up @@ -313,9 +314,6 @@ def _casa_image_to_xds_coords(
"note": attr_note[c],
}
if do_sky_coords:
#from pprint import pprint
#print('--->',coord_dict)
#pprint(coord_dict)
for k in coord_dict.keys():
if k.startswith("direction"):
dc = coordinates.directioncoordinate(coord_dict[k])
Expand Down
1 change: 1 addition & 0 deletions src/xradio/image/_util/_casacore/xds_to_casacore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import xarray as xr
from astropy.coordinates import Angle

try:
from casacore import tables
except ImportError:
Expand Down
19 changes: 13 additions & 6 deletions src/xradio/image/_util/_fits/xds_from_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re
from typing import Union

import astropy as ap
import dask
import dask.array as da
import numpy as np
Expand All @@ -15,11 +14,19 @@
from xradio._utils.dict_helpers import make_quantity

from ....measurement_set._utils._utils.stokes_types import stokes_types
from ..common import (_compute_linear_world_values, _compute_velocity_values,
_compute_world_sph_dims, _convert_beam_to_rad,
_default_freq_info, _doppler_types, _freq_from_vel,
_get_unit, _get_xds_dim_order, _image_type,
_l_m_attr_notes)
from ..common import (
_compute_linear_world_values,
_compute_velocity_values,
_compute_world_sph_dims,
_convert_beam_to_rad,
_default_freq_info,
_doppler_types,
_freq_from_vel,
_get_unit,
_get_xds_dim_order,
_image_type,
_l_m_attr_notes,
)


def _fits_image_to_xds(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from casacore import tables
except ImportError:
from ....._utils._casacore import casacore_from_casatools as tables


from .load import load_col_chunk
from .read_main_table import get_partition_ids, redim_id_data_vars, rename_vars
Expand Down
1 change: 1 addition & 0 deletions src/xradio/measurement_set/_utils/_msv2/_tables/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import xarray as xr

import astropy.units

try:
from casacore import tables
except ImportError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
except ImportError:
from ....._utils._casacore import casacore_from_casatools as tables


@contextmanager
def open_table_ro(infile: str) -> Generator[tables.table, None, None]:
table = tables.table(
Expand All @@ -20,10 +21,10 @@ def open_table_ro(infile: str) -> Generator[tables.table, None, None]:
@contextmanager
def open_query(table: tables.table, query: str) -> Generator[tables.table, None, None]:

if hasattr(tables, 'taql'):
ttq=tables.taql(query)
if hasattr(tables, "taql"):
ttq = tables.taql(query)
else:
ttq=table.taql(query)
ttq = table.taql(query)
try:
yield ttq
finally:
Expand Down
3 changes: 2 additions & 1 deletion src/xradio/measurement_set/_utils/_msv2/_tables/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from casacore import tables
except ImportError:
from ....._utils._casacore import casacore_from_casatools as tables



def revert_time(datetimes: np.ndarray) -> np.ndarray:
"""
Convert time back from pandas datetime ref to casacore ref
Expand Down
9 changes: 5 additions & 4 deletions src/xradio/measurement_set/_utils/_msv2/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import traceback

import toolviper.utils.logger as logger

try:
from casacore import tables
except ImportError:
Expand Down Expand Up @@ -610,15 +611,15 @@ def create_data_variables(
"Time to read column " + str(col) + " : " + str(time.time() - start)
)
except Exception as exc:
logger.debug(f'Could not load column {col}, exception: {exc}')
logger.debug(f"Could not load column {col}, exception: {exc}")
logger.debug(traceback.format_exc())

if ('WEIGHT_SPECTRUM' == col) and (
'WEIGHT' in col_names
if ("WEIGHT_SPECTRUM" == col) and (
"WEIGHT" in col_names
): # Bogus WEIGHT_SPECTRUM column, need to use WEIGHT.
xds = get_weight(
xds,
'WEIGHT',
"WEIGHT",
tb_tool,
time_baseline_shape,
tidxs,
Expand Down

0 comments on commit bca1dee

Please sign in to comment.