Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions spharpy/plot/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Private utility functions for plot module."""

import numpy as np
import matplotlib.pyplot as plt


def _prepare_plot(ax=None, projection=None):
"""
Returns a figure to plot on.

Parameters
----------
ax : matplotlib.axes.Axes or list, tuple or ndarray of maplotlib.axes.Axes
Axes to plot on. The default is None in which case the axes are
obtained from the current figure. A new figure is created if it does
not exist.
projection : str, optional
Type of projection for the axes. This is only applied if new axes are
created. Default is ``None`` (2D projection). See
`matplotlib.projections <https://matplotlib.org/stable/api/projections_api.html>`_
for more information on projections.

Returns
-------
fig : matplotlib.figure.Figure
Returns the active figure.
ax : maptlotlib.axes.Axes
Returns the current axes.
""" # noqa: E501
if ax is None:
# get current figure or create a new one
fig = plt.gcf()
if fig.axes:
ax = plt.gca()
else:
ax = plt.axes(projection=projection)

else:
# get figure from axis
# ax can be array or tuple of two ax objects
# first axis for the plot, second axis for colorbar placement
if isinstance(ax, np.ndarray):
fig = ax.flatten()[0].figure
elif isinstance(ax, (list, tuple)):
fig = ax[0].figure
else:
fig = ax.figure

return fig, ax
32 changes: 32 additions & 0 deletions tests/test_plot__utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
import matplotlib.pyplot as plt
from spharpy.plot._utils import _prepare_plot


@pytest.mark.parametrize(
("ax_case", "output_type", "projection"),
[("none", plt.Axes, "3d"), ("single", plt.Axes, None),
("two", list, None)],
)
def test_prepare_plot(ax_case, output_type, projection):
"""
Test output of :py:func:`~spharpy.plot._utils._prepare_plot`.
"""
if ax_case == "none":
input_ax = None
elif ax_case == "single":
_, input_ax = plt.subplots()
else:
_, axs = plt.subplots(1, 2)
input_ax = [axs[0], axs[1]]

fig, ax = _prepare_plot(input_ax, projection)

assert isinstance(fig, plt.Figure)
assert isinstance(ax, output_type)

if isinstance(ax, list):
assert all(isinstance(ax_, plt.Axes) for ax_ in ax)

if ax is None and projection is not None:
assert ax.name == projection