Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions gplately/mapping/cartopy_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import cartopy.crs as ccrs
from geopandas.geodataframe import GeoDataFrame

from ..grids import Raster

from ..tools import EARTH_RADIUS
from ..utils.plot_utils import plot_subduction_teeth
from .plot_engine import PlotEngine
Expand Down Expand Up @@ -133,3 +135,42 @@ def plot_subduction_zones(
color=color,
**kwargs,
)

def plot_grid(
self, ax_or_fig, grid, projection=None, extent=(-180, 180, -90, 90), **kwargs
):
"""Plot a grid onto a map using Cartopy

Parameters
----------
ax_or_fig : cartopy.mpl.geoaxes.GeoAxes
Cartopy GeoAxes instance
grid : 2D array-like
The grid data to be plotted
projection : cartopy.crs.Projection
The projection to use for the grid
extent : tuple
The extent of the grid in the form (min_lon, max_lon, min_lat, max_lat)
**kwargs :
Keyword arguments for plotting the grid. See Matplotlib's ``imshow()`` keyword arguments
`here <https://matplotlib.org/3.5.1/api/_as_gen/matplotlib.axes.Axes.imshow.html>`__.

"""
# Override matplotlib default origin ('upper')
origin = kwargs.pop("origin", "lower")

if isinstance(grid, Raster):
# extract extent and origin
extent = grid.extent
origin = grid.origin
data = grid.data
else:
data = grid

return ax_or_fig.imshow(
data,
extent=extent,
transform=projection,
origin=origin,
**kwargs,
)
7 changes: 7 additions & 0 deletions gplately/mapping/plot_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,10 @@ def plot_subduction_zones(
):
"""Plot subduction zones with "teeth"(abstract method)"""
pass # This is an abstract method, no implementation here.

@abstractmethod
def plot_grid(
self, ax_or_fig, grid, projection=None, extent=(-180, 180, -90, 90), **kwargs
):
"""Plot a grid (abstract method)"""
pass # This is an abstract method, no implementation here.
60 changes: 60 additions & 0 deletions gplately/mapping/pygmt_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,63 @@ def plot_subduction_zones(
fill=color,
style="f0.2/0.08+r+t",
)

def plot_grid(
self,
ax_or_fig,
grid,
projection=None,
extent=(-180, 180, -90, 90),
nan_transparent=False,
**kwargs,
):
"""Use PyGMT to plot a grid onto a map.

Parameters
----------
ax_or_fig : pygmt.Figure()
pygmt Figure object
grid : Raster
gplately Raster object or 2D array-like grid data
projection : str
GMT projection string, e.g., "M6i" for Mercator projection with 6-inch width.
extent : tuple
(min_lon, max_lon, min_lat, max_lat)
cmap : str
Colormap name
shading : str
Shading method, e.g., "a" for artificial illumination.
"""
from ..grids import Raster
import xarray as xr

if isinstance(grid, Raster):
# extract extent and origin
extent = grid.extent
origin = grid.origin
data = xr.DataArray(
data=grid.data,
dims=["lat", "lon"],
coords=dict(
lon=(["lon"], grid.lons),
lat=(["lat"], grid.lats),
),
)
else:
data = xr.DataArray(grid)

region = [extent[0], extent[1], extent[2], extent[3]]

ax_or_fig.grdimage(
grid=data,
cmap="gmt/geo",
nan_transparent=nan_transparent,
)
"""
region=region,
projection=projection,
# cmap=cmap, cmap="YlGnBu",
# shading=shading,
frame=False,
**kwargs,
)"""
67 changes: 33 additions & 34 deletions gplately/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (C) 2024-2025 The University of Sydney, Australia
# Copyright (C) 2024-2026 The University of Sydney, Australia
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License, version 2, as published by
Expand Down Expand Up @@ -1175,7 +1175,7 @@ def plot_subduction_teeth(
)

def plot_plate_polygon_by_id(self, ax, plate_id, color="black", **kwargs):
"""Plot a plate polygon with the given``plate_id`` on a map.
"""Plot a plate polygon with the given ``plate_id`` on a map.

Parameters
----------
Expand Down Expand Up @@ -1221,53 +1221,52 @@ def plot_plate_id(self, *args, **kwargs):
return self.plot_plate_polygon_by_id(*args, **kwargs)

def plot_grid(self, ax, grid, extent=(-180, 180, -90, 90), **kwargs):
"""Plot a `MaskedArray`_ raster or grid onto a map.

.. note::

Plotting grid with pygmt has not been implemented yet!
"""Plot a grid onto a map. The grid can be a NumPy `MaskedArray`_ object, a GPlately `Raster` object
or a time-dependent raster name.

Parameters
----------
ax :
Cartopy ax.
Cartopy ax or pygmt figure object.

grid : MaskedArray or Raster
grid : NumPy `MaskedArray`_, GPlately `Raster` or a time-dependent raster name.
A `MaskedArray`_ with elements that define a grid. The number of rows in the raster
corresponds to the number of latitudinal coordinates, while the number of raster
columns corresponds to the number of longitudinal coordinates.
Alternatively, a GPlately `Raster` object can be provided.
If a raster name is provided, the raster will be looked up from the time-dependent rasters registered in the Plate Model Manager.
The :class:`gplately.PlateReconstruction` object must be created with a valid :class:`PlateModel` object.

extent : tuple, default=(-180, 180, -90, 90)
A tuple of 4 (min_lon, max_lon, min_lat, max_lat) representing the extent of gird.

**kwargs :
Keyword arguments for plotting the grid.
See Matplotlib's ``imshow()`` keyword arguments
`here <https://matplotlib.org/3.5.1/api/_as_gen/matplotlib.axes.Axes.imshow.html>`__.

"""
if not isinstance(self._plot_engine, CartopyPlotEngine):
raise NotImplementedError(
f"Plotting grid has not been implemented for {self._plot_engine.__class__} yet."
)
# Override matplotlib default origin ('upper')
origin = kwargs.pop("origin", "lower")
.. note::

if isinstance(grid, Raster):
# extract extent and origin
extent = grid.extent
origin = grid.origin
data = grid.data
else:
data = grid
The parameters of this function are different for different plot engines. See `CartopyPlotEngine.plot_grid`
and `PyGMTPlotEngine.plot_grid` for details.

return ax.imshow(
data,
extent=extent,
transform=self.base_projection,
origin=origin,
**kwargs,
)
"""

if isinstance(grid, str): # grid is a raster name
if not self.plate_reconstruction.plate_model:
raise Exception(
"The 'plate_reconstruction' does not have a valid 'plate_model'. "
"Cannot look up the raster by name. Make sure to create the 'plate_reconstruction' with a valid 'plate_model'."
)

grid_data = Raster(
data=self.plate_reconstruction.plate_model.get_raster(grid, self.time),
plate_reconstruction=self.plate_reconstruction,
extent=(-180, 180, -90, 90),
)
return self._plot_engine.plot_grid(
ax, grid_data, extent=extent, projection=self.base_projection, **kwargs
)
else: # grid is a MaskedArray or Raster object
return self._plot_engine.plot_grid(
ax, grid, extent=extent, projection=self.base_projection, **kwargs
)

def plot_grid_from_netCDF(self, ax, filename, **kwargs):
"""Read raster data from a netCDF file, convert the data into a `MaskedArray`_ object and plot it on a map.
Expand Down
22 changes: 14 additions & 8 deletions tests-dir/unittest/test_pygmt_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,19 @@
from gplately.mapping.pygmt_plot import PygmtPlotEngine

if __name__ == "__main__":
model_name = "muller2019"
reconstruction_name = 55

gplot = get_gplot(
"merdith2021", "plate-model-repo", time=55, plot_engine=PygmtPlotEngine()
model_name,
"plate-model-repo",
time=reconstruction_name,
plot_engine=PygmtPlotEngine(),
)
fig = get_pygmt_basemap_figure(projection="N180/10c", region="d")

gplot.plot_grid(fig, "AgeGrids", nan_transparent=True)

# fig.coast(shorelines=True)

gplot.plot_topological_plate_boundaries(
Expand All @@ -27,18 +36,13 @@
gplot.plot_transforms(fig, pen="0.5p,red", gmtlabel="transforms")
gplot.plot_subduction_teeth(fig, color="blue", gmtlabel="subduction zones")

try:
gplot.plot_grid(fig, None)
except NotImplementedError as e:
print(e)

try:
gplot.plot_plate_motion_vectors(fig)
except NotImplementedError as e:
print(e)

fig.text(
text="55Ma (Merdith2021)",
text=f"{reconstruction_name}Ma ({model_name})",
position="TC",
no_clip=True,
font="12p,Helvetica,black",
Expand All @@ -47,4 +51,6 @@
fig.legend(position="jBL+o-2.7/0", box="+gwhite+p0.5p")

# fig.show(width=1200)
fig.savefig("./output/test-pygmt-plot.pdf")
output_file = "./output/test-pygmt-plot.pdf"
fig.savefig(output_file)
print(f"The figure has been saved to: {output_file}.")