Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 68 additions & 11 deletions gplately/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,41 @@ def _guess_data_variable_name(cdf: netCDF4.Dataset, x_name: str, y_name: str) ->
return None


def _is_a_common_name_for_longitude(name: str) -> bool:
"""Return True if the `name` parameter is a possible common name for longitude."""
return name in ["lon", "lons", "longitude", "x", "east", "easting", "eastings"]


def _is_a_common_name_for_latitude(name: str) -> bool:
"""Return True if the `name` parameter is a possible common name for latitude."""
return name in ["lat", "lats", "latitude", "y", "north", "northing", "northings"]


def _find_extent_from_data(data) -> Union[Tuple[float, float, float, float], None]:
"""Try to find the extent from data. Return None if data doesn't contain coordinates.
As of 2025-12-10, only support xarray.DataArray."""
extent = None
lons = None
lats = None
try:
for name in data.coords:
if not lats and _is_a_common_name_for_latitude(name):
lats = data.coords[name]
elif not lons and _is_a_common_name_for_longitude(name):
lons = data.coords[name]
if lons is not None and lats is not None:
extent = (
float(lons.min()),
float(lons.max()),
float(lats.min()),
float(lats.max()),
)
except Exception as ex:
logger.debug(ex)
return None
return extent


def read_netcdf_grid(
filename,
return_grids: bool = False,
Expand Down Expand Up @@ -294,7 +329,9 @@ def find_label(keys, labels):
if (unique_grid == [0, 1]).all():
cdf_grid = cdf_grid.astype(bool)

if realign:
# we realign the grid to -180/180 when the longitudes are from 0 to 360
# this is a temporary fix. we need a more sophisticated solution.
if np.max(cdf_lon) > 180:
# realign longitudes to -180/180 dateline
cdf_grid_z, cdf_lon, cdf_lat = _realign_grid(cdf_grid, cdf_lon, cdf_lat)
else:
Expand Down Expand Up @@ -1564,6 +1601,7 @@ def __init__(
The raster data, either as a file path (:class:`str`) or array data or a ``Raster`` object.
If a ``Raster`` object is specified then all other arguments are ignored except ``plate_reconstruction``
which, if it is not ``None``, will override the plate reconstruction of the ``Raster`` object.
The data parameter accepts `numpy.ndarray`, `xarray.DataArray` or or any object that can be converted to a `numpy.ndarray`.

plate_reconstruction : PlateReconstruction
A :class:`PlateReconstruction` object to provide the following essential components for reconstructing points.
Expand All @@ -1588,7 +1626,7 @@ def __init__(
2-tuple e.g. resample=(resX, resY).

time : float, default: 0.0
The geological time the time-dependant raster data.
The geological time of the time-dependant raster data.

origin : {'lower', 'upper'}, optional
When ``data`` is an array, use this parameter to specify the origin
Expand Down Expand Up @@ -1645,6 +1683,7 @@ def __init__(
+ "'{}'".format(key)
)

# if the "data" parameter is a "Raster" object
if isinstance(data, self.__class__):
self._data = data._data.copy()
# Use specified plate reconstruction (if specified),
Expand All @@ -1661,15 +1700,20 @@ def __init__(

self.plate_reconstruction = plate_reconstruction

if time < 0.0:
raise ValueError("Invalid time: {}".format(time))
time = float(time)
self._time = time
# get the geological time parameter for the time-dependant raster data
try:
time = float(time)
if time < 0.0:
raise ValueError()
self._time = time
except ValueError:
raise ValueError(f"Invalid time parameter: {time}")

if data is None:
raise TypeError("`data` argument (or `filename` or `array`) is required")

# if the user has passed a NetCDF file path
if isinstance(data, str):
# Filename
self._filename = data
self._data, lons, lats = read_netcdf_grid(
data,
Expand All @@ -1683,17 +1727,30 @@ def __init__(
)
self._lons = lons
self._lats = lats

else:
# numpy array
# if the "data" parameter is a numpy array or xarray.DataArray object
self._filename = None
extent = _parse_extent_origin(extent, origin)

# try to extract the extent from input data
# if the extent from data is different from the extent parameter, use the extent from data
extent_from_data = _find_extent_from_data(data)
if extent_from_data is not None and extent != extent_from_data:
extent = extent_from_data
logger.info(
f"Raster.__init__(): Use the extent extracted from data: {extent}."
)

data = _check_grid(data)
self._data = np.array(data)
self._lons = np.linspace(extent[0], extent[1], self.data.shape[1])
self._lats = np.linspace(extent[2], extent[3], self.data.shape[0])
if realign:
# realign to -180,180 and flip grid

# we realign the grid to -180/180 when the longitudes are from 0 to 360
# this is a temporary fix. we need a more sophisticated solution.
# for example, some people may use (-360-0) or some other ranges for longitudes. It is unlikely, but possible.
if np.max(self._lons) > 180:
# realign to -180,180 and flip grid if needed
self._data, self._lons, self._lats = _realign_grid(
self._data, self._lons, self._lats
)
Expand Down
Loading
Loading