Skip to content

Commit f6b6981

Browse files
committed
Add streaming infiles
1 parent eb15029 commit f6b6981

File tree

4 files changed

+119
-23
lines changed

4 files changed

+119
-23
lines changed

singlestoredb/connection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
"""SingleStoreDB connections and cursors."""
33
import abc
44
import inspect
5+
import io
6+
import queue
57
import re
68
import warnings
79
import weakref
@@ -496,6 +498,14 @@ def close(self) -> None:
496498
def execute(
497499
self, query: str,
498500
args: Optional[Union[Sequence[Any], Dict[str, Any], Any]] = None,
501+
infile_stream: Optional[
502+
Union[
503+
io.RawIOBase,
504+
io.TextIOBase,
505+
Iterator[Union[bytes, str]],
506+
queue.Queue[Union[bytes, str]],
507+
]
508+
] = None,
499509
) -> int:
500510
"""
501511
Execute a SQL statement.
@@ -510,6 +520,8 @@ def execute(
510520
The SQL statement to execute
511521
args : Sequence or dict, optional
512522
Parameters to substitute into the SQL code
523+
infile_stream : io.RawIOBase or io.TextIOBase or Iterator[bytes|str] or queue.Queue[bytes|str], optional
524+
Data stream for ``LOCAL INFILE`` statement
513525
514526
Examples
515527
--------

singlestoredb/http/connection.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import datetime
44
import decimal
55
import functools
6+
import io
67
import json
78
import math
89
import os
10+
import queue
911
import re
1012
import time
1113
from base64 import b64decode
@@ -420,6 +422,14 @@ def close(self) -> None:
420422
def execute(
421423
self, query: str,
422424
args: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
425+
infile_stream: Optional[
426+
Union[
427+
io.RawIOBase,
428+
io.TextIOBase,
429+
Iterable[Union[bytes, str]],
430+
queue.Queue[Union[bytes, str]],
431+
]
432+
] = None,
423433
) -> int:
424434
"""
425435
Execute a SQL statement.
@@ -432,7 +442,7 @@ def execute(
432442
Parameters to substitute into the SQL code
433443
434444
"""
435-
return self._execute(query, args)
445+
return self._execute(query, args, infile_stream=infile_stream)
436446

437447
def _validate_param_subs(
438448
self, query: str,
@@ -496,6 +506,14 @@ def _execute(
496506
self, oper: str,
497507
params: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
498508
is_callproc: bool = False,
509+
infile_stream: Optional[
510+
Union[
511+
io.RawIOBase,
512+
io.TextIOBase,
513+
Iterable[Union[bytes, str]],
514+
queue.Queue[Union[bytes, str]],
515+
]
516+
] = None,
499517
) -> int:
500518
self._descriptions = []
501519
self._schemas = []

singlestoredb/mysql/connection.py

Lines changed: 80 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
66
import errno
77
import functools
8+
import io
89
import os
10+
import queue
911
import socket
1012
import struct
1113
import sys
1214
import traceback
1315
import warnings
16+
from typing import Iterable
1417

1518
try:
1619
import _singlestoredb_accel
@@ -353,6 +356,7 @@ def __init__( # noqa: C901
353356
)
354357

355358
self._local_infile = bool(local_infile)
359+
self._local_infile_stream = None
356360
if self._local_infile:
357361
client_flag |= CLIENT.LOCAL_FILES
358362
if multi_statements:
@@ -843,7 +847,7 @@ def cursor(self):
843847
return self.cursorclass(self)
844848

845849
# The following methods are INTERNAL USE ONLY (called from Cursor)
846-
def query(self, sql, unbuffered=False):
850+
def query(self, sql, unbuffered=False, infile_stream=None):
847851
"""
848852
Run a query on the server.
849853
@@ -859,8 +863,10 @@ def query(self, sql, unbuffered=False):
859863
else:
860864
if isinstance(sql, str):
861865
sql = sql.encode(self.encoding, 'surrogateescape')
866+
self._local_infile_stream = infile_stream
862867
self._execute_command(COMMAND.COM_QUERY, sql)
863868
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
869+
self._local_infile_stream = None
864870
return self._affected_rows
865871

866872
def next_result(self, unbuffered=False):
@@ -1871,24 +1877,82 @@ def __init__(self, filename, connection):
18711877
def send_data(self):
18721878
"""Send data packets from the local file to the server"""
18731879
if not self.connection._sock:
1874-
raise err.InterfaceError(0, '')
1880+
raise err.InterfaceError(0, 'Connection is closed')
1881+
18751882
conn = self.connection
1883+
infile = conn._local_infile_stream
1884+
1885+
# 16KB is efficient enough
1886+
packet_size = min(conn.max_allowed_packet, 16 * 1024)
18761887

18771888
try:
1878-
with open(self.filename, 'rb') as open_file:
1879-
packet_size = min(
1880-
conn.max_allowed_packet, 16 * 1024,
1881-
) # 16KB is efficient enough
1882-
while True:
1883-
chunk = open_file.read(packet_size)
1884-
if not chunk:
1885-
break
1886-
conn.write_packet(chunk)
1887-
except OSError:
1888-
raise err.OperationalError(
1889-
ER.FILE_NOT_FOUND,
1890-
f"Can't find file '{self.filename}'",
1891-
)
1889+
1890+
if self.filename in [':stream:', b':stream:']:
1891+
1892+
if infile is None:
1893+
raise err.OperationalError(
1894+
ER.FILE_NOT_FOUND,
1895+
':stream: specified for LOCAL INFILE, but no stream was supplied',
1896+
)
1897+
1898+
# Binary IO
1899+
elif isinstance(infile, io.RawIOBase):
1900+
while True:
1901+
chunk = infile.read(packet_size)
1902+
if not chunk:
1903+
break
1904+
conn.write_packet(chunk)
1905+
1906+
# Text IO
1907+
elif isinstance(infile, io.TextIOBase):
1908+
while True:
1909+
chunk = infile.read(packet_size)
1910+
if not chunk:
1911+
break
1912+
conn.write_packet(chunk.encode('utf8'))
1913+
1914+
# Iterable of bytes or str
1915+
elif isinstance(infile, Iterable):
1916+
for chunk in infile:
1917+
if not chunk:
1918+
continue
1919+
if isinstance(chunk, str):
1920+
conn.write_packet(chunk.encode('utf8'))
1921+
else:
1922+
conn.write_packet(chunk)
1923+
1924+
# Queue (empty value ends the iteration)
1925+
elif isinstance(infile, queue.Queue):
1926+
while True:
1927+
chunk = infile.get()
1928+
if not chunk:
1929+
break
1930+
if isinstance(chunk, str):
1931+
conn.write_packet(chunk.encode('utf8'))
1932+
else:
1933+
conn.write_packet(chunk)
1934+
1935+
else:
1936+
raise err.OperationalError(
1937+
ER.FILE_NOT_FOUND,
1938+
':stream: specified for LOCAL INFILE, ' +
1939+
f'but stream type is unrecognized: {infile}',
1940+
)
1941+
1942+
else:
1943+
try:
1944+
with open(self.filename, 'rb') as open_file:
1945+
while True:
1946+
chunk = open_file.read(packet_size)
1947+
if not chunk:
1948+
break
1949+
conn.write_packet(chunk)
1950+
except OSError:
1951+
raise err.OperationalError(
1952+
ER.FILE_NOT_FOUND,
1953+
f"Can't find file '{self.filename!s}'",
1954+
)
1955+
18921956
finally:
18931957
if not conn._closed:
18941958
# send the empty packet to signify we are done sending data

singlestoredb/mysql/cursors.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def mogrify(self, query, args=None):
178178

179179
return query
180180

181-
def execute(self, query, args=None):
181+
def execute(self, query, args=None, infile_stream=None):
182182
"""
183183
Execute a query.
184184
@@ -192,6 +192,8 @@ def execute(self, query, args=None):
192192
Query to execute.
193193
args : Sequence[Any] or Dict[str, Any] or Any, optional
194194
Parameters used with query. (optional)
195+
infile_stream : io.RawIOBase or io.TextIOBase or Iterable[bytes|str] or queue.Queue[bytes|str], optional
196+
Data stream for ``LOCAL INFILE`` statements; it is read when the statement names ``:stream:`` as the file
195197
196198
Returns
197199
-------
@@ -205,7 +207,7 @@ def execute(self, query, args=None):
205207

206208
query = self.mogrify(query, args)
207209

208-
result = self._query(query)
210+
result = self._query(query, infile_stream=infile_stream)
209211
self._executed = query
210212
return result
211213

@@ -387,10 +389,10 @@ def scroll(self, value, mode='relative'):
387389
raise IndexError('out of range')
388390
self._rownumber = r
389391

390-
def _query(self, q):
392+
def _query(self, q, infile_stream=None):
391393
conn = self._get_db()
392394
self._clear_result()
393-
conn.query(q)
395+
conn.query(q, infile_stream=infile_stream)
394396
self._do_get_result()
395397
return self.rowcount
396398

@@ -680,10 +682,10 @@ def close(self):
680682

681683
__del__ = close
682684

683-
def _query(self, q):
685+
def _query(self, q, infile_stream=None):
684686
conn = self._get_db()
685687
self._clear_result()
686-
conn.query(q, unbuffered=True)
688+
conn.query(q, unbuffered=True, infile_stream=infile_stream)
687689
self._do_get_result()
688690
return self.rowcount
689691

0 commit comments

Comments
 (0)