Skip to content

Commit 890626d

Browse files
authored
Add the Session.virtualfile_from_stringio method to allow StringIO input for certain functions/methods (#3326)
1 parent a592ade commit 890626d

File tree

4 files changed

+222
-5
lines changed

4 files changed

+222
-5
lines changed

doc/api/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -317,5 +317,6 @@ Low level access (these are mostly used by the :mod:`pygmt.clib` package):
317317
clib.Session.get_libgmt_func
318318
clib.Session.virtualfile_from_data
319319
clib.Session.virtualfile_from_grid
320+
clib.Session.virtualfile_from_stringio
320321
clib.Session.virtualfile_from_matrix
321322
clib.Session.virtualfile_from_vectors

pygmt/clib/session.py

+104-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import contextlib
99
import ctypes as ctp
10+
import io
1011
import pathlib
1112
import sys
1213
import warnings
@@ -60,6 +61,7 @@
6061
"GMT_IS_PLP", # items could be any one of POINT, LINE, or POLY
6162
"GMT_IS_SURFACE", # items are 2-D grid
6263
"GMT_IS_VOLUME", # items are 3-D grid
64+
"GMT_IS_TEXT", # Text strings which triggers ASCII text reading
6365
]
6466

6567
METHODS = [
@@ -70,6 +72,11 @@
7072
DIRECTIONS = ["GMT_IN", "GMT_OUT"]
7173

7274
MODES = ["GMT_CONTAINER_ONLY", "GMT_IS_OUTPUT"]
75+
MODE_MODIFIERS = [
76+
"GMT_GRID_IS_CARTESIAN",
77+
"GMT_GRID_IS_GEO",
78+
"GMT_WITH_STRINGS",
79+
]
7380

7481
REGISTRATIONS = ["GMT_GRID_PIXEL_REG", "GMT_GRID_NODE_REG"]
7582

@@ -728,7 +735,7 @@ def create_data(
728735
mode_int = self._parse_constant(
729736
mode,
730737
valid=MODES,
731-
valid_modifiers=["GMT_GRID_IS_CARTESIAN", "GMT_GRID_IS_GEO"],
738+
valid_modifiers=MODE_MODIFIERS,
732739
)
733740
geometry_int = self._parse_constant(geometry, valid=GEOMETRIES)
734741
registration_int = self._parse_constant(registration, valid=REGISTRATIONS)
@@ -1603,6 +1610,100 @@ def virtualfile_from_grid(self, grid):
16031610
with self.open_virtualfile(*args) as vfile:
16041611
yield vfile
16051612

1613+
@contextlib.contextmanager
1614+
def virtualfile_from_stringio(self, stringio: io.StringIO):
1615+
r"""
1616+
Store a :class:`io.StringIO` object in a virtual file.
1617+
1618+
Store the contents of a :class:`io.StringIO` object in a GMT_DATASET container
1619+
and create a virtual file to pass to a GMT module.
1620+
1621+
For simplicity, currently we make following assumptions in the StringIO object
1622+
1623+
- ``"#"`` indicates a comment line.
1624+
- ``">"`` indicates a segment header.
1625+
1626+
Parameters
1627+
----------
1628+
stringio
1629+
The :class:`io.StringIO` object containing the data to be stored in the
1630+
virtual file.
1631+
1632+
Yields
1633+
------
1634+
fname
1635+
The name of the virtual file.
1636+
1637+
Examples
1638+
--------
1639+
>>> import io
1640+
>>> from pygmt.clib import Session
1641+
>>> # A StringIO object containing legend specifications
1642+
>>> stringio = io.StringIO(
1643+
... "# Comment\n"
1644+
... "H 24p Legend\n"
1645+
... "N 2\n"
1646+
... "S 0.1i c 0.15i p300/12 0.25p 0.3i My circle\n"
1647+
... )
1648+
>>> with Session() as lib:
1649+
... with lib.virtualfile_from_stringio(stringio) as fin:
1650+
... lib.virtualfile_to_dataset(vfname=fin, output_type="pandas")
1651+
0
1652+
0 H 24p Legend
1653+
1 N 2
1654+
2 S 0.1i c 0.15i p300/12 0.25p 0.3i My circle
1655+
"""
1656+
# Parse the io.StringIO object.
1657+
segments = []
1658+
current_segment = {"header": "", "data": []}
1659+
for line in stringio.getvalue().splitlines():
1660+
if line.startswith("#"): # Skip comments
1661+
continue
1662+
if line.startswith(">"): # Segment header
1663+
if current_segment["data"]: # If we have data, start a new segment
1664+
segments.append(current_segment)
1665+
current_segment = {"header": "", "data": []}
1666+
current_segment["header"] = line.strip(">").lstrip()
1667+
else:
1668+
current_segment["data"].append(line) # type: ignore[attr-defined]
1669+
if current_segment["data"]: # Add the last segment if it has data
1670+
segments.append(current_segment)
1671+
1672+
# One table with one or more segments.
1673+
# n_rows is the maximum number of rows/records for all segments.
1674+
# n_columns is the number of numeric data columns, so it's 0 here.
1675+
n_tables = 1
1676+
n_segments = len(segments)
1677+
n_rows = max(len(segment["data"]) for segment in segments)
1678+
n_columns = 0
1679+
1680+
# Create the GMT_DATASET container
1681+
family, geometry = "GMT_IS_DATASET", "GMT_IS_TEXT"
1682+
dataset = self.create_data(
1683+
family,
1684+
geometry,
1685+
mode="GMT_CONTAINER_ONLY|GMT_WITH_STRINGS",
1686+
dim=[n_tables, n_segments, n_rows, n_columns],
1687+
)
1688+
dataset = ctp.cast(dataset, ctp.POINTER(_GMT_DATASET))
1689+
table = dataset.contents.table[0].contents
1690+
for i, segment in enumerate(segments):
1691+
seg = table.segment[i].contents
1692+
if segment["header"]:
1693+
seg.header = segment["header"].encode() # type: ignore[attr-defined]
1694+
seg.text = strings_to_ctypes_array(segment["data"])
1695+
1696+
with self.open_virtualfile(family, geometry, "GMT_IN", dataset) as vfile:
1697+
try:
1698+
yield vfile
1699+
finally:
1700+
# Must set the pointers to None to avoid double freeing the memory.
1701+
# Maybe upstream bug.
1702+
for i in range(n_segments):
1703+
seg = table.segment[i].contents
1704+
seg.header = None
1705+
seg.text = None
1706+
16061707
def virtualfile_in( # noqa: PLR0912
16071708
self,
16081709
check_kind=None,
@@ -1696,6 +1797,7 @@ def virtualfile_in( # noqa: PLR0912
16961797
"geojson": tempfile_from_geojson,
16971798
"grid": self.virtualfile_from_grid,
16981799
"image": tempfile_from_image,
1800+
"stringio": self.virtualfile_from_stringio,
16991801
# Note: virtualfile_from_matrix is not used because a matrix can be
17001802
# converted to vectors instead, and using vectors allows for better
17011803
# handling of string type inputs (e.g. for datetime data types)
@@ -1704,7 +1806,7 @@ def virtualfile_in( # noqa: PLR0912
17041806
}[kind]
17051807

17061808
# Ensure the data is an iterable (Python list or tuple)
1707-
if kind in {"geojson", "grid", "image", "file", "arg"}:
1809+
if kind in {"geojson", "grid", "image", "file", "arg", "stringio"}:
17081810
if kind == "image" and data.dtype != "uint8":
17091811
msg = (
17101812
f"Input image has dtype: {data.dtype} which is unsupported, "

pygmt/helpers/utils.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Utilities and common tasks for wrapping the GMT modules.
33
"""
44

5+
import io
56
import os
67
import pathlib
78
import shutil
@@ -188,8 +189,10 @@ def _check_encoding(
188189

189190
def data_kind(
190191
data: Any = None, required: bool = True
191-
) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
192-
"""
192+
) -> Literal[
193+
"arg", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
194+
]:
195+
r"""
193196
Check the kind of data that is provided to a module.
194197
195198
The ``data`` argument can be in any type, but only following types are supported:
@@ -222,6 +225,7 @@ def data_kind(
222225
>>> import numpy as np
223226
>>> import xarray as xr
224227
>>> import pathlib
228+
>>> import io
225229
>>> data_kind(data=None)
226230
'vectors'
227231
>>> data_kind(data=np.arange(10).reshape((5, 2)))
@@ -240,8 +244,12 @@ def data_kind(
240244
'grid'
241245
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5)))
242246
'image'
247+
>>> data_kind(data=io.StringIO("TEXT1\nTEXT23\n"))
248+
'stringio'
243249
"""
244-
kind: Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]
250+
kind: Literal[
251+
"arg", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
252+
]
245253
if isinstance(data, str | pathlib.PurePath) or (
246254
isinstance(data, list | tuple)
247255
and all(isinstance(_file, str | pathlib.PurePath) for _file in data)
@@ -250,6 +258,8 @@ def data_kind(
250258
kind = "file"
251259
elif isinstance(data, bool | int | float) or (data is None and not required):
252260
kind = "arg"
261+
elif isinstance(data, io.StringIO):
262+
kind = "stringio"
253263
elif isinstance(data, xr.DataArray):
254264
kind = "image" if len(data.dims) == 3 else "grid"
255265
elif hasattr(data, "__geo_interface__"):

pygmt/tests/test_clib_virtualfiles.py

+104
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Test the C API functions related to virtual files.
33
"""
44

5+
import io
56
from importlib.util import find_spec
67
from itertools import product
78
from pathlib import Path
@@ -407,3 +408,106 @@ def test_inquire_virtualfile():
407408
]:
408409
with lib.open_virtualfile(family, geometry, "GMT_OUT", None) as vfile:
409410
assert lib.inquire_virtualfile(vfile) == lib[family]
411+
412+
413+
class TestVirtualfileFromStringIO:
414+
"""
415+
Test the virtualfile_from_stringio method.
416+
"""
417+
418+
def _stringio_to_dataset(self, data: io.StringIO):
419+
"""
420+
A helper function for check the virtualfile_from_stringio method.
421+
422+
The function does the following:
423+
424+
1. Creates a virtual file from the input StringIO object.
425+
2. Pass the virtual file to the ``read`` module, which reads the virtual file
426+
and writes it to another virtual file.
427+
3. Reads the output virtual file as a GMT_DATASET object.
428+
4. Extracts the header and the trailing text from the dataset and returns it as
429+
a string.
430+
"""
431+
with clib.Session() as lib:
432+
with (
433+
lib.virtualfile_from_stringio(data) as vintbl,
434+
lib.virtualfile_out(kind="dataset") as vouttbl,
435+
):
436+
lib.call_module("read", args=[vintbl, vouttbl, "-Td"])
437+
ds = lib.read_virtualfile(vouttbl, kind="dataset").contents
438+
439+
output = []
440+
table = ds.table[0].contents
441+
for segment in table.segment[: table.n_segments]:
442+
seg = segment.contents
443+
output.append(f"> {seg.header.decode()}" if seg.header else ">")
444+
output.extend(np.char.decode(seg.text[: seg.n_rows]))
445+
return "\n".join(output) + "\n"
446+
447+
def test_virtualfile_from_stringio(self):
448+
"""
449+
Test the virtualfile_from_stringio method.
450+
"""
451+
data = io.StringIO(
452+
"# Comment\n"
453+
"H 24p Legend\n"
454+
"N 2\n"
455+
"S 0.1i c 0.15i p300/12 0.25p 0.3i My circle\n"
456+
)
457+
expected = (
458+
">\n"
459+
"H 24p Legend\n"
460+
"N 2\n"
461+
"S 0.1i c 0.15i p300/12 0.25p 0.3i My circle\n"
462+
)
463+
assert self._stringio_to_dataset(data) == expected
464+
465+
def test_one_segment(self):
466+
"""
467+
Test the virtualfile_from_stringio method with one segment.
468+
"""
469+
data = io.StringIO(
470+
"# Comment\n"
471+
"> Segment 1\n"
472+
"1 2 3 ABC\n"
473+
"4 5 DE\n"
474+
"6 7 8 9 FGHIJK LMN OPQ\n"
475+
"RSTUVWXYZ\n"
476+
)
477+
expected = (
478+
"> Segment 1\n"
479+
"1 2 3 ABC\n"
480+
"4 5 DE\n"
481+
"6 7 8 9 FGHIJK LMN OPQ\n"
482+
"RSTUVWXYZ\n"
483+
)
484+
assert self._stringio_to_dataset(data) == expected
485+
486+
def test_multiple_segments(self):
487+
"""
488+
Test the virtualfile_from_stringio method with multiple segments.
489+
"""
490+
data = io.StringIO(
491+
"# Comment line 1\n"
492+
"# Comment line 2\n"
493+
"> Segment 1\n"
494+
"1 2 3 ABC\n"
495+
"4 5 DE\n"
496+
"6 7 8 9 FG\n"
497+
"# Comment line 3\n"
498+
"> Segment 2\n"
499+
"1 2 3 ABC\n"
500+
"4 5 DE\n"
501+
"6 7 8 9 FG\n"
502+
)
503+
expected = (
504+
"> Segment 1\n"
505+
"1 2 3 ABC\n"
506+
"4 5 DE\n"
507+
"6 7 8 9 FG\n"
508+
"> Segment 2\n"
509+
"1 2 3 ABC\n"
510+
"4 5 DE\n"
511+
"6 7 8 9 FG\n"
512+
)
513+
assert self._stringio_to_dataset(data) == expected

0 commit comments

Comments
 (0)