From d98227544604954840f6251b0ec5b7e774a3fef3 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Fri, 15 Nov 2024 16:13:13 +1300 Subject: [PATCH] pyarrow: Check compatibility of pyarrow.array with string type (#2933) Co-authored-by: Dongdong Tian --- doc/conf.py | 1 + doc/ecosystem.md | 6 ++-- pygmt/_typing.py | 10 ++++++ pygmt/clib/conversion.py | 5 +-- pygmt/src/text.py | 6 ++-- .../test_clib_virtualfile_from_vectors.py | 33 +++++++++++++++---- pygmt/tests/test_text.py | 22 +++++++++++-- 7 files changed, 66 insertions(+), 17 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 1586601804d..613c860aa75 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -85,6 +85,7 @@ "contextily": ("https://contextily.readthedocs.io/en/stable/", None), "geopandas": ("https://geopandas.org/en/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), + "pyarrow": ("https://arrow.apache.org/docs/", None), "python": ("https://docs.python.org/3/", None), "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), "rasterio": ("https://rasterio.readthedocs.io/en/stable/", None), diff --git a/doc/ecosystem.md b/doc/ecosystem.md index 3e265c2c5eb..0f43835d7d3 100644 --- a/doc/ecosystem.md +++ b/doc/ecosystem.md @@ -94,9 +94,9 @@ Python objects. They are based on the C++ implementation of Arrow. ```{note} If you have [PyArrow][] installed, PyGMT does have some initial support for `pandas.Series` and `pandas.DataFrame` objects with Apache Arrow-backed arrays. -Specifically, only uint/int/float and date32/date64 are supported for now. -Support for string Array dtypes, Duration types and GeoArrow geometry types is still a work in progress. -For more details, see +Specifically, only uint/int/float, date32/date64 and string types are supported for now. +Support for Duration types and GeoArrow geometry types is still a work in progress. For +more details, see [issue #2800](https://github.com/GenericMappingTools/pygmt/issues/2800). ``` diff --git a/pygmt/_typing.py b/pygmt/_typing.py index bbc7d596c65..4a57c3c7678 100644 --- a/pygmt/_typing.py +++ b/pygmt/_typing.py @@ -2,7 +2,17 @@ Type aliases for type hints. """ +import contextlib +import importlib +from collections.abc import Sequence from typing import Literal +import numpy as np + # Anchor codes AnchorCode = Literal["TL", "TC", "TR", "ML", "MC", "MR", "BL", "BC", "BR"] + +# String array types +StringArrayTypes = Sequence[str] | np.ndarray +with contextlib.suppress(ImportError): + StringArrayTypes |= importlib.import_module(name="pyarrow").StringArray diff --git a/pygmt/clib/conversion.py b/pygmt/clib/conversion.py index 4716fe09b5e..68cddd63549 100644 --- a/pygmt/clib/conversion.py +++ b/pygmt/clib/conversion.py @@ -280,12 +280,13 @@ def sequence_to_ctypes_array( def strings_to_ctypes_array(strings: Sequence[str] | np.ndarray) -> ctp.Array: """ - Convert a sequence (e.g., a list) of strings into a ctypes array. + Convert a sequence (e.g., a list) of strings or numpy.ndarray of strings into a + ctypes array. Parameters ---------- strings - A sequence of strings. + A sequence of strings, or a numpy.ndarray of str dtype. Returns ------- diff --git a/pygmt/src/text.py b/pygmt/src/text.py index 2ed475c9ac2..ad98711824b 100644 --- a/pygmt/src/text.py +++ b/pygmt/src/text.py @@ -5,7 +5,7 @@ from collections.abc import Sequence import numpy as np -from pygmt._typing import AnchorCode +from pygmt._typing import AnchorCode, StringArrayTypes from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( @@ -48,7 +48,7 @@ def text_( # noqa: PLR0912 x=None, y=None, position: AnchorCode | None = None, - text=None, + text: str | StringArrayTypes | None = None, angle=None, font=None, justify: bool | None | AnchorCode | Sequence[AnchorCode] = None, @@ -104,7 +104,7 @@ def text_( # noqa: PLR0912 For example, ``position="TL"`` plots the text at the Top Left corner of the map. - text : str or 1-D array + text The text string, or an array of strings to plot on the figure. angle: float, str, bool or list Set the angle measured in degrees counter-clockwise from diff --git a/pygmt/tests/test_clib_virtualfile_from_vectors.py b/pygmt/tests/test_clib_virtualfile_from_vectors.py index 041bc7a803c..b76a9bfe168 100644 --- a/pygmt/tests/test_clib_virtualfile_from_vectors.py +++ b/pygmt/tests/test_clib_virtualfile_from_vectors.py @@ -11,6 +11,14 @@ from pygmt.clib.session import DTYPES_NUMERIC from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile +from pygmt.helpers.testing import skip_if_no + +try: + import pyarrow as pa + + pa_array = pa.array +except ImportError: + pa_array = None @pytest.fixture(scope="module", name="dtypes") @@ -53,17 +61,30 @@ def test_virtualfile_from_vectors(dtypes): @pytest.mark.benchmark -@pytest.mark.parametrize("dtype", [str, object]) -def test_virtualfile_from_vectors_one_string_or_object_column(dtype): - """ - Test passing in one column with string or object dtype into virtual file dataset. +@pytest.mark.parametrize( + ("array_func", "dtype"), + [ + pytest.param(np.array, {"dtype": np.str_}, id="str"), + pytest.param(np.array, {"dtype": np.object_}, id="object"), + pytest.param( + pa_array, + {}, # {"type": pa.string()} + marks=skip_if_no(package="pyarrow"), + id="pyarrow", + ), + ], +) +def test_virtualfile_from_vectors_one_string_or_object_column(array_func, dtype): + """ + Test passing in one column with string (numpy/pyarrow) or object (numpy) + dtype into virtual file dataset. """ size = 5 x = np.arange(size, dtype=np.int32) y = np.arange(size, size * 2, 1, dtype=np.int32) - strings = np.array(["a", "bc", "defg", "hijklmn", "opqrst"], dtype=dtype) + strings = array_func(["a", "bc", "defg", "hijklmn", "opqrst"], **dtype) with clib.Session() as lib: - with lib.virtualfile_from_vectors((x, y, strings)) as vfile: + with lib.virtualfile_from_vectors(vectors=(x, y, strings)) as vfile: with GMTTempFile() as outfile: lib.call_module("convert", [vfile, f"->{outfile.name}"]) output = outfile.read(keep_tabs=True) diff --git a/pygmt/tests/test_text.py b/pygmt/tests/test_text.py index 64781c514bc..593c07a7b4d 100644 --- a/pygmt/tests/test_text.py +++ b/pygmt/tests/test_text.py @@ -9,6 +9,14 @@ from pygmt import Figure from pygmt.exceptions import GMTCLibError, GMTInvalidInput from pygmt.helpers import GMTTempFile +from pygmt.helpers.testing import skip_if_no + +try: + import pyarrow as pa + + pa_array = pa.array +except ImportError: + pa_array = None TEST_DATA_DIR = Path(__file__).parent / "data" POINTS_DATA = TEST_DATA_DIR / "points.txt" @@ -48,8 +56,16 @@ def test_text_single_line_of_text(region, projection): @pytest.mark.benchmark -@pytest.mark.mpl_image_compare -def test_text_multiple_lines_of_text(region, projection): +@pytest.mark.mpl_image_compare(filename="test_text_multiple_lines_of_text.png") +@pytest.mark.parametrize( + "array_func", + [ + list, + pytest.param(np.array, id="numpy"), + pytest.param(pa_array, marks=skip_if_no(package="pyarrow"), id="pyarrow"), + ], +) +def test_text_multiple_lines_of_text(region, projection, array_func): """ Place multiple lines of text at their respective x, y locations. """ @@ -59,7 +75,7 @@ def test_text_multiple_lines_of_text(region, projection): projection=projection, x=[1.2, 1.6], y=[0.6, 0.3], - text=["This is a line of text", "This is another line of text"], + text=array_func(["This is a line of text", "This is another line of text"]), ) return fig