21
21
22
22
23
23
def _validate_data_input (
24
- data = None , x = None , y = None , z = None , required_z = False , required_data = True , kind = None
24
+ data = None , x = None , y = None , z = None , required_data = True , required_cols = 2 , kind = None
25
25
):
26
26
"""
27
27
Check if the combination of data/x/y/z is valid.
@@ -44,34 +44,33 @@ def _validate_data_input(
44
44
Traceback (most recent call last):
45
45
...
46
46
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
47
- >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True )
47
+ >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_cols=3 )
48
48
Traceback (most recent call last):
49
49
...
50
50
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
51
51
>>> import numpy as np
52
52
>>> import pandas as pd
53
53
>>> import xarray as xr
54
54
>>> data = np.arange(8).reshape((4, 2))
55
- >>> _validate_data_input(data=data, required_z=True , kind="matrix")
55
+ >>> _validate_data_input(data=data, required_cols=3 , kind="matrix")
56
56
Traceback (most recent call last):
57
57
...
58
- pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns .
58
+ pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given .
59
59
>>> _validate_data_input(
60
60
... data=pd.DataFrame(data, columns=["x", "y"]),
61
- ... required_z=True ,
61
+ ... required_cols=3 ,
62
62
... kind="matrix",
63
63
... )
64
64
Traceback (most recent call last):
65
65
...
66
- pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns .
66
+ pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given .
67
67
>>> _validate_data_input(
68
68
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
69
- ... required_z=True,
70
- ... kind="matrix",
69
+ ... required_cols=3,
71
70
... )
72
71
Traceback (most recent call last):
73
72
...
74
- pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns .
73
+ pygmt.exceptions.GMTInvalidInput: data needs 3 columns but 2 column(s) are given .
75
74
>>> _validate_data_input(data="infile", x=[1, 2, 3])
76
75
Traceback (most recent call last):
77
76
...
@@ -94,26 +93,38 @@ def _validate_data_input(
94
93
GMTInvalidInput
95
94
If the data input is not valid.
96
95
"""
97
- if data is None : # data is None
98
- if x is None and y is None : # both x and y are None
99
- if required_data : # data is not optional
96
+ if kind is None :
97
+ kind = data_kind (data , required = required_data )
98
+
99
+ if data is not None and any (v is not None for v in (x , y , z )):
100
+ raise GMTInvalidInput ("Too much data. Use either data or x/y/z." )
101
+
102
+ match kind :
103
+ case "none" :
104
+ if x is None and y is None : # both x and y are None
100
105
raise GMTInvalidInput ("No input data provided." )
101
- elif x is None or y is None : # either x or y is None
102
- raise GMTInvalidInput ("Must provide both x and y." )
103
- if required_z and z is None : # both x and y are not None, now check z
104
- raise GMTInvalidInput ("Must provide x, y, and z." )
105
- else : # data is not None
106
- if x is not None or y is not None or z is not None :
107
- raise GMTInvalidInput ("Too much data. Use either data or x/y/z." )
108
- # For 'matrix' kind, check if data has the required z column
109
- if kind == "matrix" and required_z :
110
- if hasattr (data , "shape" ): # np.ndarray or pd.DataFrame
111
- if len (data .shape ) == 1 and data .shape [0 ] < 3 :
112
- raise GMTInvalidInput ("data must provide x, y, and z columns." )
113
- if len (data .shape ) > 1 and data .shape [1 ] < 3 :
114
- raise GMTInvalidInput ("data must provide x, y, and z columns." )
115
- if hasattr (data , "data_vars" ) and len (data .data_vars ) < 3 : # xr.Dataset
116
- raise GMTInvalidInput ("data must provide x, y, and z columns." )
106
+ if x is None or y is None : # either x or y is None
107
+ raise GMTInvalidInput ("Must provide both x and y." )
108
+ if required_cols >= 3 and z is None :
109
+ # both x and y are not None, now check z
110
+ raise GMTInvalidInput ("Must provide x, y, and z." )
111
+ case "matrix" : # 2-D numpy.ndarray
112
+ if (actual_cols := data .shape [1 ]) < required_cols :
113
+ msg = f"data needs { required_cols } columns but { actual_cols } column(s) are given."
114
+ raise GMTInvalidInput (msg )
115
+ case "vectors" :
116
+ if hasattr (data , "items" ) and not hasattr (data , "to_frame" ):
117
+ # Dict, pd.DataFrame, xr.Dataset
118
+ arrays = [array for _ , array in data .items ()]
119
+ if (actual_cols := len (arrays )) < required_cols :
120
+ msg = f"data needs { required_cols } columns but { actual_cols } column(s) are given."
121
+ raise GMTInvalidInput (msg )
122
+
123
+ # Loop over columns to make sure they're not None
124
+ for idx , array in enumerate (arrays [:required_cols ]):
125
+ if array is None :
126
+ msg = f"data needs { required_cols } columns but the { idx } column is None."
127
+ raise GMTInvalidInput (msg )
117
128
118
129
119
130
def _check_encoding (
0 commit comments