From bca1dee46561273cb75d8c5eaed770e02e25e851 Mon Sep 17 00:00:00 2001 From: Rui Xue Date: Wed, 5 Feb 2025 14:31:00 -0600 Subject: [PATCH] Format codes to comply with the Black code style standard. --- .../_casacore/casacore_from_casatools.py | 64 +++++++++++-------- .../_util/_casacore/xds_from_casacore.py | 4 +- .../image/_util/_casacore/xds_to_casacore.py | 1 + src/xradio/image/_util/_fits/xds_from_fits.py | 19 ++++-- .../_utils/_msv2/_tables/load_main_table.py | 2 +- .../_utils/_msv2/_tables/read.py | 1 + .../_utils/_msv2/_tables/table_query.py | 7 +- .../_utils/_msv2/_tables/write.py | 3 +- .../_utils/_msv2/conversion.py | 9 +-- 9 files changed, 66 insertions(+), 44 deletions(-) diff --git a/src/xradio/_utils/_casacore/casacore_from_casatools.py b/src/xradio/_utils/_casacore/casacore_from_casatools.py index d6cd0e46..ee36fefa 100644 --- a/src/xradio/_utils/_casacore/casacore_from_casatools.py +++ b/src/xradio/_utils/_casacore/casacore_from_casatools.py @@ -141,7 +141,9 @@ def __init__( concatsubtables: List = [], **kwargs, ): - super().__init__(tablename=tablename, lockoptions=lockoptions, nomodify=True, **kwargs) + super().__init__( + tablename=tablename, lockoptions=lockoptions, nomodify=True, **kwargs + ) def row(self, columnnames: List[str] = [], exclude: bool = False) -> "tablerow": """Access rows in the table. @@ -210,7 +212,9 @@ def taql(self, taqlcommand="TaQL expression"): else: tablename = self.name() tb_query_from = self - tb_query = taqlcommand.replace("$mtable", tablename).replace("$gtable", tablename) + tb_query = taqlcommand.replace("$mtable", tablename).replace( + "$gtable", tablename + ) logger.debug(f"tb_query_from: {tb_query_from.name()}") logger.debug(f"tb_query_cmd: {tb_query}") tb_query_to = _wrap_table(swig_object=tb_query_from._swigobj.taql(tb_query)) @@ -251,7 +255,7 @@ def getcolshapestring(self, *args, **kwargs): def getcellslice(self, columnname, rownr, blc, trc, incr=1): """Retrieve a sliced portion of a cell from a specified column. - This method extracts a subarray from a cell within a table column, + This method extracts a subarray from a cell within a table column, given the bottom-left corner (BLC) and top-right corner (TRC) indices. It also supports an optional increment (`incr`) to control step size. @@ -260,15 +264,15 @@ def getcellslice(self, columnname, rownr, blc, trc, incr=1): columnname : str The name of the column from which to extract data. rownr : int or Sequence[int] - The row number(s) from which to extract data. If a sequence is provided, + The row number(s) from which to extract data. If a sequence is provided, it is reversed before processing. blc : Sequence[int] The bottom-left corner indices of the slice. trc : Sequence[int] The top-right corner indices of the slice. incr : int or Sequence[int], optional - Step size for slicing. If a sequence is provided, it is reversed. - If a single integer is given, it is expanded to match `blc` dimensions. + Step size for slicing. If a sequence is provided, it is reversed. + If a single integer is given, it is expanded to match `blc` dimensions. Defaults to 1. Returns @@ -286,7 +290,7 @@ def getcellslice(self, columnname, rownr, blc, trc, incr=1): if isinstance(rownr, Sequence): rownr = rownr[::-1] else: - rownr = [rownr]*len(blc) + rownr = [rownr] * len(blc) rownr = 0 if isinstance(blc, Sequence): blc = list(map(int, blc[::-1])) @@ -295,7 +299,7 @@ def getcellslice(self, columnname, rownr, blc, trc, incr=1): if isinstance(incr, Sequence): incr = incr[::-1] else: - incr = [incr]*len(blc) + incr = [incr] * len(blc) ret = super().getcellslice(columnname=columnname, rownr=rownr, blc=blc, trc=trc) return ret @@ -400,9 +404,11 @@ def info(self): imageinfo = self._flatten_multibeam(imageinfo) # table.getdesc() - return {'imageinfo': imageinfo, - 'coordinates': self.coordsys().torecord(), - 'miscinfo': self.miscinfo()} + return { + "imageinfo": imageinfo, + "coordinates": self.coordsys().torecord(), + "miscinfo": self.miscinfo(), + } def datatype(self): return self.pixeltype() @@ -423,18 +429,20 @@ def _flatten_multibeam(self, imageinfo): dict Updated `imageinfo` dictionary with flattened per-plane beam data. """ - if 'perplanebeams' in imageinfo: - perplanebeams = imageinfo['perplanebeams']['beams'] + if "perplanebeams" in imageinfo: + perplanebeams = imageinfo["perplanebeams"]["beams"] perplanebeams_flat = {} - nchan = imageinfo['perplanebeams']['nChannels'] - npol = imageinfo['perplanebeams']['nStokes'] + nchan = imageinfo["perplanebeams"]["nChannels"] + npol = imageinfo["perplanebeams"]["nStokes"] for c in range(nchan): for p in range(npol): k = nchan * p + c - perplanebeams_flat["*" + str(k)] = perplanebeams['*'+str(c)]['*'+str(p)] - imageinfo['perplanebeams'].pop('beams', None) - imageinfo['perplanebeams'].update(perplanebeams_flat) + perplanebeams_flat["*" + str(k)] = perplanebeams["*" + str(c)][ + "*" + str(p) + ] + imageinfo["perplanebeams"].pop("beams", None) + imageinfo["perplanebeams"].update(perplanebeams_flat) return imageinfo @@ -461,9 +469,9 @@ def get_axes(self): axes = [] axis_names = self._cs.names() for axis_type in self.get_names(): - axis_inds = self._cs.findcoordinate(axis_type).get('pixel') + axis_inds = self._cs.findcoordinate(axis_type).get("pixel") axes_list = [axis_names[idx] for idx in axis_inds[::-1]] - if axis_type == 'spectral': + if axis_type == "spectral": axes_list = axes_list[0] axes.append(axes_list) return axes @@ -476,7 +484,7 @@ def get_referencepixel(self): list of float The numeric reference pixel values, with axes reversed. """ - return self._cs.referencepixel()['numeric'][::-1] + return self._cs.referencepixel()["numeric"][::-1] def get_referencevalue(self): """Get the reference value at the reference pixel. @@ -486,7 +494,7 @@ def get_referencevalue(self): list of float The numeric reference values, with axes reversed. """ - return self._cs.referencevalue()['numeric'][::-1] + return self._cs.referencevalue()["numeric"][::-1] def get_increment(self): """Get the coordinate increments per pixel. @@ -496,7 +504,7 @@ def get_increment(self): list of float The coordinate increment values, with axes reversed. """ - return self._cs.increment()['numeric'][::-1] + return self._cs.increment()["numeric"][::-1] def get_unit(self): """Get the units of the coordinate axes. @@ -536,7 +544,7 @@ def __init__(self, rec): self._rec = rec def get_projection(self): - return self._rec['projection'] + return self._rec["projection"] @wrap_class_methods @@ -553,7 +561,9 @@ class tablerow(casatools.tablerow): Whether to exclude the specified columns. """ - def __init__(self, table: table, columnnames: List[str] = [], exclude: bool = False): + def __init__( + self, table: table, columnnames: List[str] = [], exclude: bool = False + ): super().__init__(table, columnnames=columnnames, exclude=exclude) @method_wrapper @@ -572,7 +582,9 @@ def get(self, rownr: int) -> Dict[str, Any]: """ return super().get(rownr) - def __getitem__(self, key: Union[int, slice]) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + def __getitem__( + self, key: Union[int, slice] + ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: """Retrieve rows using indexing or slicing. Parameters diff --git a/src/xradio/image/_util/_casacore/xds_from_casacore.py b/src/xradio/image/_util/_casacore/xds_from_casacore.py index ee5ed437..deb856af 100644 --- a/src/xradio/image/_util/_casacore/xds_from_casacore.py +++ b/src/xradio/image/_util/_casacore/xds_from_casacore.py @@ -9,6 +9,7 @@ import numpy as np import xarray as xr from astropy import units as u + try: from casacore import tables from casacore.images import coordinates @@ -313,9 +314,6 @@ def _casa_image_to_xds_coords( "note": attr_note[c], } if do_sky_coords: - #from pprint import pprint - #print('--->',coord_dict) - #pprint(coord_dict) for k in coord_dict.keys(): if k.startswith("direction"): dc = coordinates.directioncoordinate(coord_dict[k]) diff --git a/src/xradio/image/_util/_casacore/xds_to_casacore.py b/src/xradio/image/_util/_casacore/xds_to_casacore.py index 618c626c..325867b9 100644 --- a/src/xradio/image/_util/_casacore/xds_to_casacore.py +++ b/src/xradio/image/_util/_casacore/xds_to_casacore.py @@ -5,6 +5,7 @@ import numpy as np import xarray as xr from astropy.coordinates import Angle + try: from casacore import tables except ImportError: diff --git a/src/xradio/image/_util/_fits/xds_from_fits.py b/src/xradio/image/_util/_fits/xds_from_fits.py index 27786ca1..219c939b 100644 --- a/src/xradio/image/_util/_fits/xds_from_fits.py +++ b/src/xradio/image/_util/_fits/xds_from_fits.py @@ -2,7 +2,6 @@ import re from typing import Union -import astropy as ap import dask import dask.array as da import numpy as np @@ -15,11 +14,19 @@ from xradio._utils.dict_helpers import make_quantity from ....measurement_set._utils._utils.stokes_types import stokes_types -from ..common import (_compute_linear_world_values, _compute_velocity_values, - _compute_world_sph_dims, _convert_beam_to_rad, - _default_freq_info, _doppler_types, _freq_from_vel, - _get_unit, _get_xds_dim_order, _image_type, - _l_m_attr_notes) +from ..common import ( + _compute_linear_world_values, + _compute_velocity_values, + _compute_world_sph_dims, + _convert_beam_to_rad, + _default_freq_info, + _doppler_types, + _freq_from_vel, + _get_unit, + _get_xds_dim_order, + _image_type, + _l_m_attr_notes, +) def _fits_image_to_xds( diff --git a/src/xradio/measurement_set/_utils/_msv2/_tables/load_main_table.py b/src/xradio/measurement_set/_utils/_msv2/_tables/load_main_table.py index 5feba217..60a53f1e 100644 --- a/src/xradio/measurement_set/_utils/_msv2/_tables/load_main_table.py +++ b/src/xradio/measurement_set/_utils/_msv2/_tables/load_main_table.py @@ -9,7 +9,7 @@ from casacore import tables except ImportError: from ....._utils._casacore import casacore_from_casatools as tables - + from .load import load_col_chunk from .read_main_table import get_partition_ids, redim_id_data_vars, rename_vars diff --git a/src/xradio/measurement_set/_utils/_msv2/_tables/read.py b/src/xradio/measurement_set/_utils/_msv2/_tables/read.py index 892c53b7..88d36d4b 100644 --- a/src/xradio/measurement_set/_utils/_msv2/_tables/read.py +++ b/src/xradio/measurement_set/_utils/_msv2/_tables/read.py @@ -9,6 +9,7 @@ import xarray as xr import astropy.units + try: from casacore import tables except ImportError: diff --git a/src/xradio/measurement_set/_utils/_msv2/_tables/table_query.py b/src/xradio/measurement_set/_utils/_msv2/_tables/table_query.py index ae2390ee..ffadc6ed 100644 --- a/src/xradio/measurement_set/_utils/_msv2/_tables/table_query.py +++ b/src/xradio/measurement_set/_utils/_msv2/_tables/table_query.py @@ -6,6 +6,7 @@ except ImportError: from ....._utils._casacore import casacore_from_casatools as tables + @contextmanager def open_table_ro(infile: str) -> Generator[tables.table, None, None]: table = tables.table( @@ -20,10 +21,10 @@ def open_table_ro(infile: str) -> Generator[tables.table, None, None]: @contextmanager def open_query(table: tables.table, query: str) -> Generator[tables.table, None, None]: - if hasattr(tables, 'taql'): - ttq=tables.taql(query) + if hasattr(tables, "taql"): + ttq = tables.taql(query) else: - ttq=table.taql(query) + ttq = table.taql(query) try: yield ttq finally: diff --git a/src/xradio/measurement_set/_utils/_msv2/_tables/write.py b/src/xradio/measurement_set/_utils/_msv2/_tables/write.py index cc0464d8..2f6e8735 100644 --- a/src/xradio/measurement_set/_utils/_msv2/_tables/write.py +++ b/src/xradio/measurement_set/_utils/_msv2/_tables/write.py @@ -8,7 +8,8 @@ from casacore import tables except ImportError: from ....._utils._casacore import casacore_from_casatools as tables - + + def revert_time(datetimes: np.ndarray) -> np.ndarray: """ Convert time back from pandas datetime ref to casacore ref diff --git a/src/xradio/measurement_set/_utils/_msv2/conversion.py b/src/xradio/measurement_set/_utils/_msv2/conversion.py index 313715bc..b986a570 100644 --- a/src/xradio/measurement_set/_utils/_msv2/conversion.py +++ b/src/xradio/measurement_set/_utils/_msv2/conversion.py @@ -11,6 +11,7 @@ import traceback import toolviper.utils.logger as logger + try: from casacore import tables except ImportError: @@ -610,15 +611,15 @@ def create_data_variables( "Time to read column " + str(col) + " : " + str(time.time() - start) ) except Exception as exc: - logger.debug(f'Could not load column {col}, exception: {exc}') + logger.debug(f"Could not load column {col}, exception: {exc}") logger.debug(traceback.format_exc()) - if ('WEIGHT_SPECTRUM' == col) and ( - 'WEIGHT' in col_names + if ("WEIGHT_SPECTRUM" == col) and ( + "WEIGHT" in col_names ): # Bogus WEIGHT_SPECTRUM column, need to use WEIGHT. xds = get_weight( xds, - 'WEIGHT', + "WEIGHT", tb_tool, time_baseline_shape, tidxs,