Skip to content

Commit b33476c

Browse files
committed
Add new parameter 'required_cols' and remove the parameter 'required_z'
1 parent 9e78da0 commit b33476c

13 files changed

+60
-47
lines changed

pygmt/clib/session.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1605,8 +1605,8 @@ def virtualfile_in( # noqa: PLR0912
16051605
x=None,
16061606
y=None,
16071607
z=None,
1608-
required_z=False,
16091608
required_data=True,
1609+
required_cols: int = 2,
16101610
):
16111611
"""
16121612
Store any data inside a virtual file.
@@ -1626,11 +1626,11 @@ def virtualfile_in( # noqa: PLR0912
16261626
data input.
16271627
x/y/z : 1-D arrays or None
16281628
x, y, and z columns as numpy arrays.
1629-
required_z : bool
1630-
State whether the 'z' column is required.
16311629
required_data : bool
16321630
Set to True when 'data' is required, or False when dealing with
16331631
optional virtual files. [Default is True].
1632+
required_cols
1633+
Number of required columns.
16341634
16351635
Returns
16361636
-------
@@ -1664,8 +1664,8 @@ def virtualfile_in( # noqa: PLR0912
16641664
x=x,
16651665
y=y,
16661666
z=z,
1667-
required_z=required_z,
16681667
required_data=required_data,
1668+
required_cols=required_cols,
16691669
kind=kind,
16701670
)
16711671

@@ -1775,8 +1775,8 @@ def virtualfile_from_data(
17751775
x=x,
17761776
y=y,
17771777
z=z,
1778-
required_z=required_z,
17791778
required_data=required_data,
1779+
required_cols=3 if required_z else 2,
17801780
)
17811781

17821782
@contextlib.contextmanager

pygmt/helpers/utils.py

+39-28
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
def _validate_data_input(
24-
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
24+
data=None, x=None, y=None, z=None, required_data=True, required_cols=2, kind=None
2525
):
2626
"""
2727
Check if the combination of data/x/y/z is valid.
@@ -44,34 +44,33 @@ def _validate_data_input(
4444
Traceback (most recent call last):
4545
...
4646
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
47-
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True)
47+
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_cols=3)
4848
Traceback (most recent call last):
4949
...
5050
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
5151
>>> import numpy as np
5252
>>> import pandas as pd
5353
>>> import xarray as xr
5454
>>> data = np.arange(8).reshape((4, 2))
55-
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
55+
>>> _validate_data_input(data=data, required_cols=3, kind="matrix")
5656
Traceback (most recent call last):
5757
...
58-
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
58+
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
5959
>>> _validate_data_input(
6060
... data=pd.DataFrame(data, columns=["x", "y"]),
61-
... required_z=True,
61+
... required_cols=3,
6262
... kind="matrix",
6363
... )
6464
Traceback (most recent call last):
6565
...
66-
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
66+
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
6767
>>> _validate_data_input(
6868
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
69-
... required_z=True,
70-
... kind="matrix",
69+
... required_cols=3,
7170
... )
7271
Traceback (most recent call last):
7372
...
74-
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
73+
pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given.
7574
>>> _validate_data_input(data="infile", x=[1, 2, 3])
7675
Traceback (most recent call last):
7776
...
@@ -94,26 +93,38 @@ def _validate_data_input(
9493
GMTInvalidInput
9594
If the data input is not valid.
9695
"""
97-
if data is None: # data is None
98-
if x is None and y is None: # both x and y are None
99-
if required_data: # data is not optional
96+
if kind is None:
97+
kind = data_kind(data, required=required_data)
98+
99+
if data is not None and any(v is not None for v in (x, y, z)):
100+
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")
101+
102+
match kind:
103+
case "none":
104+
if x is None and y is None: # both x and y are None
100105
raise GMTInvalidInput("No input data provided.")
101-
elif x is None or y is None: # either x or y is None
102-
raise GMTInvalidInput("Must provide both x and y.")
103-
if required_z and z is None: # both x and y are not None, now check z
104-
raise GMTInvalidInput("Must provide x, y, and z.")
105-
else: # data is not None
106-
if x is not None or y is not None or z is not None:
107-
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")
108-
# For 'matrix' kind, check if data has the required z column
109-
if kind == "matrix" and required_z:
110-
if hasattr(data, "shape"): # np.ndarray or pd.DataFrame
111-
if len(data.shape) == 1 and data.shape[0] < 3:
112-
raise GMTInvalidInput("data must provide x, y, and z columns.")
113-
if len(data.shape) > 1 and data.shape[1] < 3:
114-
raise GMTInvalidInput("data must provide x, y, and z columns.")
115-
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
116-
raise GMTInvalidInput("data must provide x, y, and z columns.")
106+
if x is None or y is None: # either x or y is None
107+
raise GMTInvalidInput("Must provide both x and y.")
108+
if required_cols >= 3 and z is None:
109+
# both x and y are not None, now check z
110+
raise GMTInvalidInput("Must provide x, y, and z.")
111+
case "matrix": # 2-D numpy.ndarray
112+
if (actual_cols := data.shape[1]) < required_cols:
113+
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
114+
raise GMTInvalidInput(msg)
115+
case "vectors":
116+
if hasattr(data, "items") and not hasattr(data, "to_frame"):
117+
# Dict, pd.DataFrame, xr.Dataset
118+
arrays = [array for _, array in data.items()]
119+
if (actual_cols := len(arrays)) < required_cols:
120+
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
121+
raise GMTInvalidInput(msg)
122+
123+
# Loop over columns to make sure they're not None
124+
for idx, array in enumerate(arrays[:required_cols]):
125+
if array is None:
126+
msg = f"data needs {required_cols} columns but the {idx} column is None."
127+
raise GMTInvalidInput(msg)
117128

118129

119130
def _check_encoding(

pygmt/src/blockm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _blockm(
5555
with Session() as lib:
5656
with (
5757
lib.virtualfile_in(
58-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
58+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
5959
) as vintbl,
6060
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
6161
):

pygmt/src/contour.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def contour(self, data=None, x=None, y=None, z=None, **kwargs):
145145

146146
with Session() as lib:
147147
with lib.virtualfile_in(
148-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
148+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
149149
) as vintbl:
150150
lib.call_module(
151151
module="contour", args=build_arg_list(kwargs, infile=vintbl)

pygmt/src/nearneighbor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def nearneighbor(
140140
with Session() as lib:
141141
with (
142142
lib.virtualfile_in(
143-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
143+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
144144
) as vintbl,
145145
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
146146
):

pygmt/src/plot.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def plot( # noqa: PLR0912
211211
kind = data_kind(data)
212212
if kind == "none": # Vectors input
213213
data = {"x": x, "y": y}
214+
x, y = None, None
214215
# Parameters for vector styles
215216
if (
216217
kwargs.get("S") is not None
@@ -255,5 +256,5 @@ def plot( # noqa: PLR0912
255256
pass
256257

257258
with Session() as lib:
258-
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
259+
with lib.virtualfile_in(check_kind="vector", data=data, x=x, y=y) as vintbl:
259260
lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl))

pygmt/src/plot3d.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def plot3d( # noqa: PLR0912
186186
kind = data_kind(data)
187187
if kind == "none": # Vectors input
188188
data = {"x": x, "y": y, "z": z}
189+
x, y, z = None, None, None
189190
# Parameters for vector styles
190191
if (
191192
kwargs.get("S") is not None
@@ -231,6 +232,6 @@ def plot3d( # noqa: PLR0912
231232

232233
with Session() as lib:
233234
with lib.virtualfile_in(
234-
check_kind="vector", data=data, required_z=True
235+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
235236
) as vintbl:
236237
lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl))

pygmt/src/project.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def project(
246246
x=x,
247247
y=y,
248248
z=z,
249-
required_z=False,
249+
required_cols=2,
250250
required_data=False,
251251
) as vintbl,
252252
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,

pygmt/src/surface.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def surface(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa
153153
with Session() as lib:
154154
with (
155155
lib.virtualfile_in(
156-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
156+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
157157
) as vintbl,
158158
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
159159
):

pygmt/src/triangulate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def regular_grid(
138138
with Session() as lib:
139139
with (
140140
lib.virtualfile_in(
141-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
141+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=2
142142
) as vintbl,
143143
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
144144
):
@@ -238,7 +238,7 @@ def delaunay_triples(
238238
with Session() as lib:
239239
with (
240240
lib.virtualfile_in(
241-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
241+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=2
242242
) as vintbl,
243243
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
244244
):

pygmt/src/wiggle.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,6 @@ def wiggle(
108108

109109
with Session() as lib:
110110
with lib.virtualfile_in(
111-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
111+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
112112
) as vintbl:
113113
lib.call_module(module="wiggle", args=build_arg_list(kwargs, infile=vintbl))

pygmt/src/xyz2grd.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def xyz2grd(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa
145145
with Session() as lib:
146146
with (
147147
lib.virtualfile_in(
148-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
148+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
149149
) as vintbl,
150150
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
151151
):

pygmt/tests/test_clib_virtualfiles.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def test_virtualfile_in_required_z_matrix(array_func, kind):
141141
data = array_func(dataframe)
142142
with clib.Session() as lib:
143143
with lib.virtualfile_in(
144-
data=data, required_z=True, check_kind="vector"
144+
data=data, required_cols=3, check_kind="vector"
145145
) as vfile:
146146
with GMTTempFile() as outfile:
147147
lib.call_module("info", [vfile, f"->{outfile.name}"])
@@ -163,7 +163,7 @@ def test_virtualfile_in_required_z_matrix_missing():
163163
data = np.ones((5, 2))
164164
with clib.Session() as lib:
165165
with pytest.raises(GMTInvalidInput):
166-
with lib.virtualfile_in(data=data, required_z=True, check_kind="vector"):
166+
with lib.virtualfile_in(data=data, required_cols=3, check_kind="vector"):
167167
pass
168168

169169

@@ -190,7 +190,7 @@ def test_virtualfile_in_fail_non_valid_data(data):
190190
with clib.Session() as lib:
191191
with pytest.raises(GMTInvalidInput):
192192
lib.virtualfile_in(
193-
x=variable[0], y=variable[1], z=variable[2], required_z=True
193+
x=variable[0], y=variable[1], z=variable[2], required_cols=3
194194
)
195195

196196
# Should also fail if given too much data

0 commit comments

Comments
 (0)