Skip to content

Commit 2a1db22

Browse files
committed
Add new parameter 'required_cols' and remove the parameter 'required_z'
1 parent 9e78da0 commit 2a1db22

12 files changed

+54
-40
lines changed

pygmt/clib/session.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1605,8 +1605,8 @@ def virtualfile_in( # noqa: PLR0912
16051605
x=None,
16061606
y=None,
16071607
z=None,
1608-
required_z=False,
16091608
required_data=True,
1609+
required_cols: int = 2,
16101610
):
16111611
"""
16121612
Store any data inside a virtual file.
@@ -1626,11 +1626,11 @@ def virtualfile_in( # noqa: PLR0912
16261626
data input.
16271627
x/y/z : 1-D arrays or None
16281628
x, y, and z columns as numpy arrays.
1629-
required_z : bool
1630-
State whether the 'z' column is required.
16311629
required_data : bool
16321630
Set to True when 'data' is required, or False when dealing with
16331631
optional virtual files. [Default is True].
1632+
required_cols
1633+
Number of required columns.
16341634
16351635
Returns
16361636
-------
@@ -1664,8 +1664,8 @@ def virtualfile_in( # noqa: PLR0912
16641664
x=x,
16651665
y=y,
16661666
z=z,
1667-
required_z=required_z,
16681667
required_data=required_data,
1668+
required_cols=required_cols,
16691669
kind=kind,
16701670
)
16711671

@@ -1775,8 +1775,8 @@ def virtualfile_from_data(
17751775
x=x,
17761776
y=y,
17771777
z=z,
1778-
required_z=required_z,
17791778
required_data=required_data,
1779+
required_cols=3 if required_z else 2,
17801780
)
17811781

17821782
@contextlib.contextmanager

pygmt/helpers/utils.py

+36-24
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222

2323
def _validate_data_input(
24-
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
24+
data=None, x=None, y=None, z=None, required_data=True, required_cols=2, kind=None
2525
):
2626
"""
2727
Check if the combination of data/x/y/z is valid.
@@ -44,29 +44,29 @@ def _validate_data_input(
4444
Traceback (most recent call last):
4545
...
4646
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
47-
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True)
47+
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_cols=3)
4848
Traceback (most recent call last):
4949
...
5050
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
5151
>>> import numpy as np
5252
>>> import pandas as pd
5353
>>> import xarray as xr
5454
>>> data = np.arange(8).reshape((4, 2))
55-
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
55+
>>> _validate_data_input(data=data, required_cols=3, kind="matrix")
5656
Traceback (most recent call last):
5757
...
5858
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
5959
>>> _validate_data_input(
6060
... data=pd.DataFrame(data, columns=["x", "y"]),
61-
... required_z=True,
61+
... required_cols=3,
6262
... kind="matrix",
6363
... )
6464
Traceback (most recent call last):
6565
...
6666
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
6767
>>> _validate_data_input(
6868
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
69-
... required_z=True,
69+
... required_cols=3,
7070
... kind="matrix",
7171
... )
7272
Traceback (most recent call last):
@@ -94,26 +94,38 @@ def _validate_data_input(
9494
GMTInvalidInput
9595
If the data input is not valid.
9696
"""
97-
if data is None: # data is None
98-
if x is None and y is None: # both x and y are None
99-
if required_data: # data is not optional
97+
if kind is None:
98+
kind = data_kind(data, required=required_data)
99+
100+
if data is not None and any(v is not None for v in (x, y, z)):
101+
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")
102+
103+
match kind:
104+
case "none":
105+
if x is None and y is None: # both x and y are None
100106
raise GMTInvalidInput("No input data provided.")
101-
elif x is None or y is None: # either x or y is None
102-
raise GMTInvalidInput("Must provide both x and y.")
103-
if required_z and z is None: # both x and y are not None, now check z
104-
raise GMTInvalidInput("Must provide x, y, and z.")
105-
else: # data is not None
106-
if x is not None or y is not None or z is not None:
107-
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")
108-
# For 'matrix' kind, check if data has the required z column
109-
if kind == "matrix" and required_z:
110-
if hasattr(data, "shape"): # np.ndarray or pd.DataFrame
111-
if len(data.shape) == 1 and data.shape[0] < 3:
112-
raise GMTInvalidInput("data must provide x, y, and z columns.")
113-
if len(data.shape) > 1 and data.shape[1] < 3:
114-
raise GMTInvalidInput("data must provide x, y, and z columns.")
115-
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
116-
raise GMTInvalidInput("data must provide x, y, and z columns.")
107+
if x is None or y is None: # either x or y is None
108+
raise GMTInvalidInput("Must provide both x and y.")
109+
if required_cols >= 3 and z is None:
110+
# both x and y are not None, now check z
111+
raise GMTInvalidInput("Must provide x, y, and z.")
112+
case "matrix": # 2-D numpy.ndarray
113+
if (actual_cols := data.shape[1]) < required_cols:
114+
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
115+
raise GMTInvalidInput(msg)
116+
case "vectors":
117+
if hasattr(data, "items") and not hasattr(data, "to_frame"):
118+
# Dict, pd.DataFrame, xr.Dataset
119+
arrays = [array for _, array in data.items()]
120+
if (actual_cols := len(arrays)) < required_cols:
121+
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
122+
raise GMTInvalidInput(msg)
123+
124+
# Loop over columns to make sure they're not None
125+
for idx, array in enumerate(arrays[:required_cols]):
126+
if array is None:
127+
msg = f"data needs {required_cols} columns but the {idx} column is None."
128+
raise GMTInvalidInput(msg)
117129

118130

119131
def _check_encoding(

pygmt/src/blockm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _blockm(
5555
with Session() as lib:
5656
with (
5757
lib.virtualfile_in(
58-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
58+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
5959
) as vintbl,
6060
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
6161
):

pygmt/src/contour.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def contour(self, data=None, x=None, y=None, z=None, **kwargs):
145145

146146
with Session() as lib:
147147
with lib.virtualfile_in(
148-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
148+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
149149
) as vintbl:
150150
lib.call_module(
151151
module="contour", args=build_arg_list(kwargs, infile=vintbl)

pygmt/src/nearneighbor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def nearneighbor(
140140
with Session() as lib:
141141
with (
142142
lib.virtualfile_in(
143-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
143+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
144144
) as vintbl,
145145
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
146146
):

pygmt/src/plot.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def plot( # noqa: PLR0912
211211
kind = data_kind(data)
212212
if kind == "none": # Vectors input
213213
data = {"x": x, "y": y}
214+
x, y = None, None
214215
# Parameters for vector styles
215216
if (
216217
kwargs.get("S") is not None
@@ -255,5 +256,5 @@ def plot( # noqa: PLR0912
255256
pass
256257

257258
with Session() as lib:
258-
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
259+
with lib.virtualfile_in(check_kind="vector", data=data, x=x, y=y) as vintbl:
259260
lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl))

pygmt/src/plot3d.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def plot3d( # noqa: PLR0912
186186
kind = data_kind(data)
187187
if kind == "none": # Vectors input
188188
data = {"x": x, "y": y, "z": z}
189+
x, y, z = None, None, None
189190
# Parameters for vector styles
190191
if (
191192
kwargs.get("S") is not None
@@ -231,6 +232,6 @@ def plot3d( # noqa: PLR0912
231232

232233
with Session() as lib:
233234
with lib.virtualfile_in(
234-
check_kind="vector", data=data, required_z=True
235+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
235236
) as vintbl:
236237
lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl))

pygmt/src/project.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def project(
246246
x=x,
247247
y=y,
248248
z=z,
249-
required_z=False,
249+
required_cols=2,
250250
required_data=False,
251251
) as vintbl,
252252
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,

pygmt/src/surface.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def surface(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa
153153
with Session() as lib:
154154
with (
155155
lib.virtualfile_in(
156-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
156+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
157157
) as vintbl,
158158
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
159159
):

pygmt/src/triangulate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def regular_grid(
138138
with Session() as lib:
139139
with (
140140
lib.virtualfile_in(
141-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
141+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=2
142142
) as vintbl,
143143
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
144144
):
@@ -238,7 +238,7 @@ def delaunay_triples(
238238
with Session() as lib:
239239
with (
240240
lib.virtualfile_in(
241-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
241+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=2
242242
) as vintbl,
243243
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
244244
):

pygmt/src/wiggle.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,6 @@ def wiggle(
108108

109109
with Session() as lib:
110110
with lib.virtualfile_in(
111-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
111+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
112112
) as vintbl:
113113
lib.call_module(module="wiggle", args=build_arg_list(kwargs, infile=vintbl))

pygmt/src/xyz2grd.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def xyz2grd(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa
145145
with Session() as lib:
146146
with (
147147
lib.virtualfile_in(
148-
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
148+
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
149149
) as vintbl,
150150
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
151151
):

0 commit comments

Comments
 (0)