Skip to content

Commit 189f376

Browse files
authored
clib.conversion._to_numpy: Add tests for pyarrow.array with pyarrow numeric types (#3599)
1 parent b6f3e2b commit 189f376

File tree

1 file changed

+87
-1
lines changed

1 file changed

+87
-1
lines changed

Diff for: pygmt/tests/test_clib_to_numpy.py

+87-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
from packaging.version import Version
1212
from pygmt.clib.conversion import _to_numpy
1313

14+
try:
15+
import pyarrow as pa
16+
17+
_HAS_PYARROW = True
18+
except ImportError:
19+
_HAS_PYARROW = False
20+
1421

1522
def _check_result(result, expected_dtype):
1623
"""
@@ -121,7 +128,7 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
121128
#
122129
# 1. NumPy dtypes (see above)
123130
# 2. pandas dtypes
124-
# 3. PyArrow dtypes
131+
# 3. PyArrow types (see below)
125132
#
126133
# pandas provides following dtypes:
127134
#
@@ -152,3 +159,82 @@ def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
152159
result = _to_numpy(series)
153160
_check_result(result, expected_dtype)
154161
npt.assert_array_equal(result, series)
162+
163+
164+
########################################################################################
165+
# Test the _to_numpy function with PyArrow arrays.
166+
#
167+
# PyArrow provides the following types:
168+
#
169+
# - Numeric types:
170+
# - int8, int16, int32, int64
171+
# - uint8, uint16, uint32, uint64
172+
# - float16, float32, float64
173+
#
174+
# In PyArrow, array types can be specified in two ways:
175+
#
176+
# - Using string aliases (e.g., "int8")
177+
# - Using pyarrow.DataType (e.g., ``pa.int8()``)
178+
#
179+
# Reference: https://arrow.apache.org/docs/python/api/datatypes.html
180+
########################################################################################
181+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
182+
@pytest.mark.parametrize(
183+
("dtype", "expected_dtype"),
184+
[
185+
pytest.param("int8", np.int8, id="int8"),
186+
pytest.param("int16", np.int16, id="int16"),
187+
pytest.param("int32", np.int32, id="int32"),
188+
pytest.param("int64", np.int64, id="int64"),
189+
pytest.param("uint8", np.uint8, id="uint8"),
190+
pytest.param("uint16", np.uint16, id="uint16"),
191+
pytest.param("uint32", np.uint32, id="uint32"),
192+
pytest.param("uint64", np.uint64, id="uint64"),
193+
pytest.param("float16", np.float16, id="float16"),
194+
pytest.param("float32", np.float32, id="float32"),
195+
pytest.param("float64", np.float64, id="float64"),
196+
],
197+
)
198+
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype):
199+
"""
200+
Test the _to_numpy function with PyArrow arrays of PyArrow numeric types.
201+
"""
202+
data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
203+
if dtype == "float16": # float16 needs special handling
204+
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
205+
data = np.array(data, dtype=np.float16)
206+
array = pa.array(data, type=dtype)[::2]
207+
result = _to_numpy(array)
208+
_check_result(result, expected_dtype)
209+
npt.assert_array_equal(result, array)
210+
211+
212+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
213+
@pytest.mark.parametrize(
214+
("dtype", "expected_dtype"),
215+
[
216+
pytest.param("int8", np.float64, id="int8"),
217+
pytest.param("int16", np.float64, id="int16"),
218+
pytest.param("int32", np.float64, id="int32"),
219+
pytest.param("int64", np.float64, id="int64"),
220+
pytest.param("uint8", np.float64, id="uint8"),
221+
pytest.param("uint16", np.float64, id="uint16"),
222+
pytest.param("uint32", np.float64, id="uint32"),
223+
pytest.param("uint64", np.float64, id="uint64"),
224+
pytest.param("float16", np.float16, id="float16"),
225+
pytest.param("float32", np.float32, id="float32"),
226+
pytest.param("float64", np.float64, id="float64"),
227+
],
228+
)
229+
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype):
230+
"""
231+
Test the _to_numpy function with PyArrow arrays of PyArrow numeric types and NA.
232+
"""
233+
data = [1.0, 2.0, None, 4.0, 5.0, 6.0]
234+
if dtype == "float16": # float16 needs special handling
235+
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
236+
data = np.array(data, dtype=np.float16)
237+
array = pa.array(data, type=dtype)[::2]
238+
result = _to_numpy(array)
239+
_check_result(result, expected_dtype)
240+
npt.assert_array_equal(result, array)

0 commit comments

Comments
 (0)