|
11 | 11 | from packaging.version import Version
|
12 | 12 | from pygmt.clib.conversion import _to_numpy
|
13 | 13 |
|
| 14 | +try: |
| 15 | + import pyarrow as pa |
| 16 | + |
| 17 | + _HAS_PYARROW = True |
| 18 | +except ImportError: |
| 19 | + _HAS_PYARROW = False |
| 20 | + |
14 | 21 |
|
15 | 22 | def _check_result(result, expected_dtype):
|
16 | 23 | """
|
@@ -121,7 +128,7 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
|
121 | 128 | #
|
122 | 129 | # 1. NumPy dtypes (see above)
|
123 | 130 | # 2. pandas dtypes
|
124 |
| -# 3. PyArrow dtypes |
| 131 | +# 3. PyArrow types (see below) |
125 | 132 | #
|
126 | 133 | # pandas provides following dtypes:
|
127 | 134 | #
|
@@ -152,3 +159,82 @@ def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
|
152 | 159 | result = _to_numpy(series)
|
153 | 160 | _check_result(result, expected_dtype)
|
154 | 161 | npt.assert_array_equal(result, series)
|
| 162 | + |
| 163 | + |
| 164 | +######################################################################################## |
| 165 | +# Test the _to_numpy function with PyArrow arrays. |
| 166 | +# |
| 167 | +# PyArrow provides the following types: |
| 168 | +# |
| 169 | +# - Numeric types: |
| 170 | +# - int8, int16, int32, int64 |
| 171 | +# - uint8, uint16, uint32, uint64 |
| 172 | +# - float16, float32, float64 |
| 173 | +# |
| 174 | +# In PyArrow, array types can be specified in two ways: |
| 175 | +# |
| 176 | +# - Using string aliases (e.g., "int8") |
| 177 | +# - Using pyarrow.DataType (e.g., ``pa.int8()``) |
| 178 | +# |
| 179 | +# Reference: https://arrow.apache.org/docs/python/api/datatypes.html |
| 180 | +######################################################################################## |
| 181 | +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") |
| 182 | +@pytest.mark.parametrize( |
| 183 | + ("dtype", "expected_dtype"), |
| 184 | + [ |
| 185 | + pytest.param("int8", np.int8, id="int8"), |
| 186 | + pytest.param("int16", np.int16, id="int16"), |
| 187 | + pytest.param("int32", np.int32, id="int32"), |
| 188 | + pytest.param("int64", np.int64, id="int64"), |
| 189 | + pytest.param("uint8", np.uint8, id="uint8"), |
| 190 | + pytest.param("uint16", np.uint16, id="uint16"), |
| 191 | + pytest.param("uint32", np.uint32, id="uint32"), |
| 192 | + pytest.param("uint64", np.uint64, id="uint64"), |
| 193 | + pytest.param("float16", np.float16, id="float16"), |
| 194 | + pytest.param("float32", np.float32, id="float32"), |
| 195 | + pytest.param("float64", np.float64, id="float64"), |
| 196 | + ], |
| 197 | +) |
| 198 | +def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype): |
| 199 | + """ |
| 200 | + Test the _to_numpy function with PyArrow arrays of PyArrow numeric types. |
| 201 | + """ |
| 202 | + data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] |
| 203 | + if dtype == "float16": # float16 needs special handling |
| 204 | + # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html |
| 205 | + data = np.array(data, dtype=np.float16) |
| 206 | + array = pa.array(data, type=dtype)[::2] |
| 207 | + result = _to_numpy(array) |
| 208 | + _check_result(result, expected_dtype) |
| 209 | + npt.assert_array_equal(result, array) |
| 210 | + |
| 211 | + |
| 212 | +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") |
| 213 | +@pytest.mark.parametrize( |
| 214 | + ("dtype", "expected_dtype"), |
| 215 | + [ |
| 216 | + pytest.param("int8", np.float64, id="int8"), |
| 217 | + pytest.param("int16", np.float64, id="int16"), |
| 218 | + pytest.param("int32", np.float64, id="int32"), |
| 219 | + pytest.param("int64", np.float64, id="int64"), |
| 220 | + pytest.param("uint8", np.float64, id="uint8"), |
| 221 | + pytest.param("uint16", np.float64, id="uint16"), |
| 222 | + pytest.param("uint32", np.float64, id="uint32"), |
| 223 | + pytest.param("uint64", np.float64, id="uint64"), |
| 224 | + pytest.param("float16", np.float16, id="float16"), |
| 225 | + pytest.param("float32", np.float32, id="float32"), |
| 226 | + pytest.param("float64", np.float64, id="float64"), |
| 227 | + ], |
| 228 | +) |
| 229 | +def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype): |
| 230 | + """ |
| 231 | + Test the _to_numpy function with PyArrow arrays of PyArrow numeric types and NA. |
| 232 | + """ |
| 233 | + data = [1.0, 2.0, None, 4.0, 5.0, 6.0] |
| 234 | + if dtype == "float16": # float16 needs special handling |
| 235 | + # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html |
| 236 | + data = np.array(data, dtype=np.float16) |
| 237 | + array = pa.array(data, type=dtype)[::2] |
| 238 | + result = _to_numpy(array) |
| 239 | + _check_result(result, expected_dtype) |
| 240 | + npt.assert_array_equal(result, array) |
0 commit comments