Skip to content

Commit 1c517e6

Browse files
committed
Be more strict about the definitions of matrix/vectors
1 parent 366a341 commit 1c517e6

File tree

9 files changed

+51
-46
lines changed

9 files changed

+51
-46
lines changed

pygmt/clib/session.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -1701,7 +1701,7 @@ def virtualfile_in( # noqa: PLR0912
17011701
if check_kind == "raster":
17021702
valid_kinds += ("grid", "image")
17031703
elif check_kind == "vector":
1704-
valid_kinds += ("matrix", "vectors", "geojson")
1704+
valid_kinds += ("none", "matrix", "vectors", "geojson")
17051705
if kind not in valid_kinds:
17061706
raise GMTInvalidInput(
17071707
f"Unrecognized data type for {check_kind}: {type(data)}"
@@ -1714,11 +1714,9 @@ def virtualfile_in( # noqa: PLR0912
17141714
"geojson": tempfile_from_geojson,
17151715
"grid": self.virtualfile_from_grid,
17161716
"image": tempfile_from_image,
1717-
# Note: virtualfile_from_matrix is not used because a matrix can be
1718-
# converted to vectors instead, and using vectors allows for better
1719-
# handling of string type inputs (e.g. for datetime data types)
1720-
"matrix": self.virtualfile_from_vectors,
1717+
"matrix": self.virtualfile_from_matrix,
17211718
"vectors": self.virtualfile_from_vectors,
1719+
"none": self.virtualfile_from_vectors,
17221720
}[kind]
17231721

17241722
# Ensure the data is an iterable (Python list or tuple)
@@ -1733,25 +1731,28 @@ def virtualfile_in( # noqa: PLR0912
17331731
)
17341732
warnings.warn(message=msg, category=RuntimeWarning, stacklevel=2)
17351733
_data = (data,) if not isinstance(data, pathlib.PurePath) else (str(data),)
1736-
elif kind == "vectors":
1734+
elif kind == "none":
17371735
_data = [np.atleast_1d(x), np.atleast_1d(y)]
17381736
if z is not None:
17391737
_data.append(np.atleast_1d(z))
17401738
if extra_arrays:
17411739
_data.extend(extra_arrays)
1742-
elif kind == "matrix": # turn 2-D arrays into list of vectors
1740+
elif kind == "vectors":
17431741
if hasattr(data, "items") and not hasattr(data, "to_frame"):
17441742
# pandas.DataFrame or xarray.Dataset types.
17451743
# pandas.Series will be handled below like a 1-D numpy.ndarray.
17461744
_data = [array for _, array in data.items()]
1747-
elif hasattr(data, "ndim") and data.ndim == 2 and data.dtype.kind in "iuf":
1745+
else:
1746+
# Python list, tuple, numpy.ndarray, and pandas.Series types
1747+
_data = np.atleast_2d(np.asanyarray(data).T)
1748+
elif kind == "matrix":
1749+
if data.dtype.kind in "iuf":
17481750
# Just use virtualfile_from_matrix for 2-D numpy.ndarray
17491751
# which are signed integer (i), unsigned integer (u) or
17501752
# floating point (f) types
1751-
_virtualfile_from = self.virtualfile_from_matrix
17521753
_data = (data,)
1753-
else:
1754-
# Python list, tuple, numpy.ndarray, and pandas.Series types
1754+
else: # turn 2-D arrays into list of vectors
1755+
_virtualfile_from = self.virtualfile_from_vectors
17551756
_data = np.atleast_2d(np.asanyarray(data).T)
17561757

17571758
# Finally create the virtualfile from the data, to be passed into GMT

pygmt/helpers/utils.py

+24-21
Original file line numberDiff line numberDiff line change
@@ -187,31 +187,30 @@ def _check_encoding(
187187
return "ISOLatin1+"
188188

189189

190-
def data_kind(
190+
def data_kind( # noqa: PLR0911
191191
data: Any, required: bool = True
192-
) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
192+
) -> Literal["none", "arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
193193
"""
194194
Check the kind of data that is provided to a module.
195195
196196
Recognized data kinds are:
197197
198-
- ``"arg"``: bool, int or float, representing an optional argument, mainly used for
199-
dealing with optional virtual files
198+
- ``"none"``: ``None`` when data is required. In this case, the data is usually given via
199+
a series of vectors (e.g., x/y/z)
200+
- ``"arg"``: bool, int, float, or None (only when ``required`` is False),
201+
representing an optional argument, mainly used for dealing with optional virtual
202+
files
200203
- ``"file"``: a string or a :class:`pathlib.PurePath` object or a sequence of them,
201204
representing a file name or a list of file names
202205
- ``"geojson"``: a geo-like Python object that implements ``__geo_interface__``
203206
(e.g., geopandas.GeoDataFrame or shapely.geometry)
204207
- ``"grid"``: a :class:`xarray.DataArray` object with dimensions not equal to 3
205208
- ``"image"``: a :class:`xarray.DataArray` object with 3 dimensions
206-
- ``"matrix"``: a :class:`pandas.DataFrame` object, a 2-D :class:`numpy.ndarray`,
207-
a dictionary with array-like values, or a sequence of sequences
209+
- ``"matrix"``: a 2-D :class:`numpy.ndarray` object
210+
- ``"vectors"``: a :class:`pandas.DataFrame` object, a dictionary with array-like
211+
values, or a sequence of sequences
208212
209-
In addition, the data can be given via a series of vectors (e.g., x/y/z). In this
210-
case, the ``data`` argument is ``None`` and the data kind is determined by the
211-
``required`` argument. The data kind is ``"vectors"`` if ``required`` is ``True``,
212-
otherwise the data kind is ``"arg"``.
213-
214-
The function will fallback to ``"matrix"`` for any unrecognized data.
213+
The function will fall back to ``"vectors"`` for any unrecognized data.
215214
216215
Parameters
217216
----------
@@ -232,12 +231,12 @@ def data_kind(
232231
>>> import xarray as xr
233232
>>> import pandas as pd
234233
>>> import pathlib
235-
>>> [data_kind(data=data) for data in (2, 2.0, True, False)]
236-
['arg', 'arg', 'arg', 'arg']
237234
>>> data_kind(data=None)
238-
'vectors'
235+
'none'
239236
>>> data_kind(data=None, required=False)
240237
'arg'
238+
>>> [data_kind(data=data) for data in (2, 2.0, True, False)]
239+
['arg', 'arg', 'arg', 'arg']
241240
>>> data_kind(data="my-data-file.txt")
242241
'file'
243242
>>> data_kind(data=pathlib.Path("my-data-file.txt"))
@@ -251,16 +250,16 @@ def data_kind(
251250
>>> data_kind(data=np.arange(10).reshape((5, 2)))
252251
'matrix'
253252
>>> data_kind(data=pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}))
254-
'matrix'
253+
'vectors'
255254
>>> data_kind(data={"x": [1, 2], "y": [3, 4]})
256-
'matrix'
255+
'vectors'
257256
>>> data_kind(data=[1, 2, 3])
258-
'matrix'
257+
'vectors'
259258
"""
260259
# data is None, so data must be given via a series of vectors (i.e., x/y/z).
261260
# The only exception is when dealing with optional virtual files.
262261
if data is None:
263-
return "vectors" if required else "arg"
262+
return "none" if required else "arg"
264263

265264
# A file or a list of files
266265
if isinstance(data, str | pathlib.PurePath) or (
@@ -282,8 +281,12 @@ def data_kind(
282281
if hasattr(data, "__geo_interface__"):
283282
return "geojson"
284283

285-
# Fallback to "matrix" for anything else
286-
return "matrix"
284+
# A 2-D numpy.ndarray
285+
if hasattr(data, "__array_interface__") and data.ndim == 2:
286+
return "matrix"
287+
288+
# Fallback to "vectors" for anything else
289+
return "vectors"
287290

288291

289292
def non_ascii_to_octal(

pygmt/src/legend.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,13 @@ def legend(self, spec=None, position="JTR+jTR+o0.2c", box="+gwhite+p1p", **kwarg
7575
if kwargs.get("F") is None:
7676
kwargs["F"] = box
7777

78-
with Session() as lib:
79-
if spec is None:
78+
match data_kind(spec):
79+
case "none":
8080
specfile = ""
81-
elif data_kind(spec) == "file" and not is_nonstr_iter(spec):
82-
# Is a file but not a list of files
81+
case kind if kind == "file" and not is_nonstr_iter(spec):
8382
specfile = spec
84-
else:
83+
case _:
8584
raise GMTInvalidInput(f"Unrecognized data type: {type(spec)}")
85+
86+
with Session() as lib:
8687
lib.call_module(module="legend", args=build_arg_list(kwargs, infile=specfile))

pygmt/src/plot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def plot( # noqa: PLR0912
210210

211211
kind = data_kind(data)
212212
extra_arrays = []
213-
if kind == "vectors": # Add more columns for vectors input
213+
if kind == "none": # Add more columns for vectors input
214214
# Parameters for vector styles
215215
if (
216216
kwargs.get("S") is not None

pygmt/src/plot3d.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def plot3d( # noqa: PLR0912
186186
kind = data_kind(data)
187187
extra_arrays = []
188188

189-
if kind == "vectors": # Add more columns for vectors input
189+
if kind == "none": # Add more columns for vectors input
190190
# Parameters for vector styles
191191
if (
192192
kwargs.get("S") is not None

pygmt/src/text.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def text_( # noqa: PLR0912
185185
"Provide either position only, or x/y pairs, or textfiles."
186186
)
187187
kind = data_kind(textfiles)
188-
if kind == "vectors" and text is None:
188+
if kind == "none" and text is None:
189189
raise GMTInvalidInput("Must provide text with x/y pairs")
190190
else:
191191
if any(v is not None for v in (x, y, textfiles)):
@@ -227,7 +227,7 @@ def text_( # noqa: PLR0912
227227

228228
# Append text at last column. Text must be passed in as str type.
229229
confdict = {}
230-
if kind == "vectors":
230+
if kind == "none":
231231
text = np.atleast_1d(text).astype(str)
232232
encoding = _check_encoding("".join(text))
233233
if encoding != "ascii":

pygmt/src/x2sys_cross.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def x2sys_cross(
195195
match data_kind(track):
196196
case "file":
197197
file_contexts.append(contextlib.nullcontext(track))
198-
case "matrix":
198+
case "vectors":
199199
# find suffix (-E) of trackfiles used (e.g. xyz, csv, etc) from
200200
# $X2SYS_HOME/TAGNAME/TAGNAME.tag file
201201
tagfile = Path(

pygmt/tests/test_grdtrack.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def test_grdtrack_profile(dataarray):
126126

127127
def test_grdtrack_wrong_kind_of_points_input(dataarray, dataframe):
128128
"""
129-
Run grdtrack using points input that is not a pandas.DataFrame (matrix) or file.
129+
Run grdtrack using points input that is not a pandas.DataFrame or file.
130130
"""
131131
invalid_points = dataframe.longitude.to_xarray()
132132

@@ -141,7 +141,7 @@ def test_grdtrack_wrong_kind_of_grid_input(dataarray, dataframe):
141141
"""
142142
invalid_grid = dataarray.to_dataset()
143143

144-
assert data_kind(invalid_grid) == "matrix"
144+
assert data_kind(invalid_grid) == "vectors"
145145
with pytest.raises(GMTInvalidInput):
146146
grdtrack(points=dataframe, grid=invalid_grid, newcolname="bathymetry")
147147

pygmt/tests/test_grdview.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_grdview_wrong_kind_of_grid(xrgrid):
5858
Run grdview using grid input that is not an xarray.DataArray or file.
5959
"""
6060
dataset = xrgrid.to_dataset() # convert xarray.DataArray to xarray.Dataset
61-
assert data_kind(dataset) == "matrix"
61+
assert data_kind(dataset) == "vectors"
6262

6363
fig = Figure()
6464
with pytest.raises(GMTInvalidInput):
@@ -238,7 +238,7 @@ def test_grdview_wrong_kind_of_drapegrid(xrgrid):
238238
Run grdview using drapegrid input that is not an xarray.DataArray or file.
239239
"""
240240
dataset = xrgrid.to_dataset() # convert xarray.DataArray to xarray.Dataset
241-
assert data_kind(dataset) == "matrix"
241+
assert data_kind(dataset) == "vectors"
242242

243243
fig = Figure()
244244
with pytest.raises(GMTInvalidInput):

0 commit comments

Comments
 (0)