diff --git a/gplately/mapping/__init__.py b/gplately/mapping/__init__.py index 2bf22597..cf0a93bb 100644 --- a/gplately/mapping/__init__.py +++ b/gplately/mapping/__init__.py @@ -1 +1,6 @@ # This submodule contains code to plot maps. +# The folder is named "mapping" to avoid name conflicts with the "plot" submodule. +# The folder name is inspired by GMT(The Generic Mapping Tools). +# The PlotEngine abstract base class is defined in plot_engine.py. +# There are different PlotEngine subclasses, CartopyPlotEngine and PygmtPlotEngine, for different plotting libraries, +# such as Cartopy and PyGMT. diff --git a/gplately/mapping/cartopy_plot.py b/gplately/mapping/cartopy_plot.py index 42d3bf53..dcccfdea 100644 --- a/gplately/mapping/cartopy_plot.py +++ b/gplately/mapping/cartopy_plot.py @@ -20,6 +20,8 @@ import cartopy.crs as ccrs from geopandas.geodataframe import GeoDataFrame +from ..grids import Raster + from ..tools import EARTH_RADIUS from ..utils.plot_utils import plot_subduction_teeth from .plot_engine import PlotEngine @@ -133,3 +135,42 @@ def plot_subduction_zones( color=color, **kwargs, ) + + def plot_grid( + self, ax_or_fig, grid, projection=None, extent=(-180, 180, -90, 90), **kwargs + ): + """Plot a grid onto a map using Cartopy + + Parameters + ---------- + ax_or_fig : cartopy.mpl.geoaxes.GeoAxes + Cartopy GeoAxes instance + grid : 2D array-like + The grid data to be plotted + projection : cartopy.crs.Projection + The projection to use for the grid + extent : tuple + The extent of the grid in the form (min_lon, max_lon, min_lat, max_lat) + **kwargs : + Keyword arguments for plotting the grid. See Matplotlib's ``imshow()`` keyword arguments + `here `__. + + """ + # Override matplotlib default origin ('upper') + origin = kwargs.pop("origin", "lower") + + if isinstance(grid, Raster): + # extract extent and origin + extent = grid.extent + origin = grid.origin + data = grid.data + else: + data = grid + + return ax_or_fig.imshow( + data, + extent=extent, + transform=projection, + origin=origin, + **kwargs, + ) diff --git a/gplately/mapping/plot_engine.py b/gplately/mapping/plot_engine.py index 1606e422..e21964b5 100644 --- a/gplately/mapping/plot_engine.py +++ b/gplately/mapping/plot_engine.py @@ -45,3 +45,10 @@ def plot_subduction_zones( ): """Plot subduction zones with "teeth"(abstract method)""" pass # This is an abstract method, no implementation here. + + @abstractmethod + def plot_grid( + self, ax_or_fig, grid, projection=None, extent=(-180, 180, -90, 90), **kwargs + ): + """Plot a grid (abstract method). See :meth:`CartopyPlotEngine.plot_grid()` and :meth:`PygmtPlotEngine.plot_grid()` for details.""" + pass # This is an abstract method, no implementation here. diff --git a/gplately/mapping/pygmt_plot.py b/gplately/mapping/pygmt_plot.py index 7dbb6ea1..5cdbc613 100644 --- a/gplately/mapping/pygmt_plot.py +++ b/gplately/mapping/pygmt_plot.py @@ -15,6 +15,7 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # import logging +from pathlib import Path logger = logging.getLogger("gplately") try: @@ -135,3 +136,61 @@ def plot_subduction_zones( fill=color, style="f0.2/0.08+r+t", ) + + def plot_grid( + self, + ax_or_fig, + grid, + projection=None, + extent=(-180, 180, -90, 90), + cmap="gmt/geo", + nan_transparent=False, + **kwargs, + ): + """Use PyGMT to plot a grid onto a map. + + Parameters + ---------- + ax_or_fig : pygmt.Figure() + A PyGMT Figure object. + grid : Raster + A gplately Raster object or 2D array-like grid data. + projection : str + Not used currently. + extent : str or tuple + (xmin, xmax, ymin, ymax). See details at + https://www.pygmt.org/dev/tutorials/basics/regions.html + cmap : str + A built-in GMT colormaps name or a CPT file path. + nan_transparent : bool + If True, NaN values in the grid will be plotted as transparent. + **kwargs : + Additional keyword arguments. + """ + from ..grids import Raster + import xarray as xr + + # we need to convert the grid data to xarray.DataArray for pygmt.grdimage(). + if isinstance(grid, Raster): + data = xr.DataArray( + data=grid.data, + dims=["lat", "lon"], + coords=dict( + lon=(["lon"], grid.lons), + lat=(["lat"], grid.lats), + ), + ) + else: + data = xr.DataArray(grid) + + # check exisence if cmap is a CPT file + if cmap.endswith(".cpt"): + if not Path(cmap).exists(): + raise FileNotFoundError(f"The CPT file '{cmap}' does not exist.") + + ax_or_fig.grdimage( + grid=data, + cmap=cmap, + region=extent, + nan_transparent=nan_transparent, + ) diff --git a/gplately/plot.py b/gplately/plot.py index c0fe79a7..619c59d8 100644 --- a/gplately/plot.py +++ b/gplately/plot.py @@ -1,5 +1,5 @@ # -# Copyright (C) 2024-2025 The University of Sydney, Australia +# Copyright (C) 2024-2026 The University of Sydney, Australia # # This program is free software; you can redistribute it and/or modify it under # the terms of the GNU General Public License, version 2, as published by @@ -1175,7 +1175,7 @@ def plot_subduction_teeth( ) def plot_plate_polygon_by_id(self, ax, plate_id, color="black", **kwargs): - """Plot a plate polygon with the given``plate_id`` on a map. + """Plot a plate polygon with the given ``plate_id`` on a map. Parameters ---------- @@ -1221,53 +1221,54 @@ def plot_plate_id(self, *args, **kwargs): return self.plot_plate_polygon_by_id(*args, **kwargs) def plot_grid(self, ax, grid, extent=(-180, 180, -90, 90), **kwargs): - """Plot a `MaskedArray`_ raster or grid onto a map. - - .. note:: - - Plotting grid with pygmt has not been implemented yet! + """Plot a grid onto a map. The grid can be a NumPy `MaskedArray`_ object, a GPlately `Raster` object + or a time-dependent raster name. Parameters ---------- ax : - Cartopy ax. + Cartopy ax or pygmt figure object. - grid : MaskedArray or Raster + grid : NumPy `MaskedArray`_, GPlately `Raster` or a time-dependent raster name. A `MaskedArray`_ with elements that define a grid. The number of rows in the raster corresponds to the number of latitudinal coordinates, while the number of raster columns corresponds to the number of longitudinal coordinates. + Alternatively, a GPlately `Raster` object can be provided. + If a raster name is provided, the raster will be looked up from the time-dependent rasters registered in the Plate Model Manager. + The :class:`gplately.PlateReconstruction` object must be created with a valid :class:`PlateModel` object. extent : tuple, default=(-180, 180, -90, 90) A tuple of 4 (min_lon, max_lon, min_lat, max_lat) representing the extent of gird. - **kwargs : - Keyword arguments for plotting the grid. - See Matplotlib's ``imshow()`` keyword arguments - `here `__. + + .. note:: + + The parameters of this function are different for different plot engines. See :meth:`CartopyPlotEngine.plot_grid` + and :meth:`PyGMTPlotEngine.plot_grid` for details. """ - if not isinstance(self._plot_engine, CartopyPlotEngine): - raise NotImplementedError( - f"Plotting grid has not been implemented for {self._plot_engine.__class__} yet." - ) - # Override matplotlib default origin ('upper') - origin = kwargs.pop("origin", "lower") - if isinstance(grid, Raster): - # extract extent and origin - extent = grid.extent - origin = grid.origin - data = grid.data - else: - data = grid + # TODO: the parameters of this function need to be unified for different plot engines. - return ax.imshow( - data, - extent=extent, - transform=self.base_projection, - origin=origin, - **kwargs, - ) + if isinstance(grid, str): # grid is a raster name + if not self.plate_reconstruction.plate_model: + raise Exception( + "The 'plate_reconstruction' does not have a valid 'plate_model' object. " + "Cannot look up the raster by name. Make sure to create the 'plate_reconstruction' with a valid 'plate_model' object." + ) + + grid_data = Raster( + data=self.plate_reconstruction.plate_model.get_raster(grid, self.time), + plate_reconstruction=self.plate_reconstruction, + extent=(-180, 180, -90, 90), + ) + return self._plot_engine.plot_grid( + ax, grid_data, extent=extent, projection=self.base_projection, **kwargs + ) + else: # grid is a MaskedArray or Raster object + return self._plot_engine.plot_grid( + ax, grid, extent=extent, projection=self.base_projection, **kwargs + ) def plot_grid_from_netCDF(self, ax, filename, **kwargs): """Read raster data from a netCDF file, convert the data into a `MaskedArray`_ object and plot it on a map. diff --git a/tests-dir/unittest/test_pygmt_plot.py b/tests-dir/unittest/test_pygmt_plot.py index d97a7774..ace85df2 100755 --- a/tests-dir/unittest/test_pygmt_plot.py +++ b/tests-dir/unittest/test_pygmt_plot.py @@ -1,36 +1,52 @@ #!/usr/bin/env python3 + +# This test script generates a sample plot using the PygmtPlotEngine. + import os +import pygmt + os.environ["DISABLE_GPLATELY_DEV_WARNING"] = "true" from gplately.auxiliary import get_gplot, get_pygmt_basemap_figure from gplately.mapping.pygmt_plot import PygmtPlotEngine if __name__ == "__main__": + model_name = "muller2019" + reconstruction_name = 55 + gplot = get_gplot( - "merdith2021", "plate-model-repo", time=55, plot_engine=PygmtPlotEngine() + model_name, + "plate-model-repo", + time=reconstruction_name, + plot_engine=PygmtPlotEngine(), ) fig = get_pygmt_basemap_figure(projection="N180/10c", region="d") + + gplot.plot_grid( + fig, "AgeGrids", cmap="create-age-grids-video/agegrid.cpt", nan_transparent=True + ) + # fig.coast(shorelines=True) + gplot.plot_coastlines( + fig, + edgecolor="none", + facecolor="gray", + linewidth=0.1, + central_meridian=180, + gmtlabel="Coastlines", + ) gplot.plot_topological_plate_boundaries( fig, edgecolor="black", linewidth=0.25, central_meridian=180, - gmtlabel="plate boundaries", - ) - gplot.plot_coastlines( - fig, edgecolor="none", facecolor="gray", linewidth=0.1, central_meridian=180 + gmtlabel="Plate Boundaries", ) - gplot.plot_ridges(fig, pen="0.5p,red", gmtlabel="ridges") - gplot.plot_transforms(fig, pen="0.5p,red", gmtlabel="transforms") - gplot.plot_subduction_teeth(fig, color="blue", gmtlabel="subduction zones") - - try: - gplot.plot_grid(fig, None) - except NotImplementedError as e: - print(e) + gplot.plot_ridges(fig, pen="0.5p,black", gmtlabel="Ridges") + gplot.plot_transforms(fig, pen="0.5p,green", gmtlabel="Transforms") + gplot.plot_subduction_teeth(fig, color="blue", gmtlabel="Subduction Zones") try: gplot.plot_plate_motion_vectors(fig) @@ -38,13 +54,16 @@ print(e) fig.text( - text="55Ma (Merdith2021)", + text=f"{reconstruction_name}Ma", position="TC", no_clip=True, font="12p,Helvetica,black", offset="j0/-0.5c", ) - fig.legend(position="jBL+o-2.7/0", box="+gwhite+p0.5p") + with pygmt.config(FONT_ANNOT_PRIMARY=4): + fig.legend(position="jBL+o-1.0/0", box="+gwhite+p0.25p") # fig.show(width=1200) - fig.savefig("./output/test-pygmt-plot.pdf") + output_file = "./output/test-pygmt-plot.pdf" + fig.savefig(output_file) + print(f"The figure has been saved to: {output_file}.")