Commit 9377338

Add max_result_size feature to Pandas.read_sql_athena.

Parent: 2b676f9

3 files changed: 41 additions, 3 deletions

awswrangler/pandas.py (16 additions, 3 deletions)

@@ -268,7 +268,20 @@ def _read_csv_once(
         buff.close()
         return dataframe
 
-    def read_sql_athena(self, sql, database, s3_output=None):
+    def read_sql_athena(self,
+                        sql,
+                        database,
+                        s3_output=None,
+                        max_result_size=None):
+        """
+        Execute any SQL query on AWS Athena and return the result as a DataFrame.
+        P.S. If max_result_size is passed, then an iterator of DataFrames is returned.
+        :param sql: SQL query
+        :param database: Glue/Athena database
+        :param s3_output: AWS S3 path
+        :param max_result_size: Max number of bytes on each request to S3
+        :return: Pandas DataFrame, or an iterator of Pandas DataFrames if max_result_size is not None
+        """
         if not s3_output:
             account_id = (self._session.boto3_session.client(
                 service_name="sts", config=self._session.botocore_config).
@@ -290,8 +303,8 @@ def read_sql_athena(self, sql, database, s3_output=None):
             raise AthenaQueryError(message_error)
         else:
             path = f"{s3_output}{query_execution_id}.csv"
-            dataframe = self.read_csv(path=path)
-            return dataframe
+            ret = self.read_csv(path=path, max_result_size=max_result_size)
+            return ret
 
     def to_csv(
             self,
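To make the new behavior concrete, here is a minimal usage sketch, assuming a Session constructed as in the project's tests; the table name, database, and size threshold are illustrative placeholders, not part of the commit:

import awswrangler

session = awswrangler.Session()

# Default: the whole Athena result is materialized as one DataFrame.
df = session.pandas.read_sql_athena(sql="SELECT * FROM my_table",
                                    database="my_database")

# With max_result_size: an iterator of DataFrames, each parsed from at most
# roughly that many bytes fetched from S3 per request.
df_iter = session.pandas.read_sql_athena(sql="SELECT * FROM my_table",
                                         database="my_database",
                                         max_result_size=128 * 1024 * 1024)
for chunk in df_iter:
    print(len(chunk.index))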

awswrangler/session.py (1 addition, 0 deletions)

@@ -227,6 +227,7 @@ class SessionPrimitives:
     It is required to "share" the session attributes to other processes.
     That must be "pickable"!
     """
+
     def __init__(
             self,
             profile_name=None,

testing/test_awswrangler/test_pandas.py (24 additions, 0 deletions)

@@ -181,3 +181,27 @@ def test_to_s3(
             break
         sleep(1)
     assert factor * len(dataframe.index) == len(dataframe2.index)
+
+
+@pytest.mark.parametrize("sample, row_num", [("data_samples/micro.csv", 30),
+                                             ("data_samples/small.csv", 100)])
+def test_read_sql_athena_iterator(session, bucket, database, sample, row_num):
+    dataframe_sample = pandas.read_csv(sample)
+    path = f"s3://{bucket}/test/"
+    session.pandas.to_parquet(dataframe=dataframe_sample,
+                              database=database,
+                              path=path,
+                              preserve_index=False,
+                              mode="overwrite")
+    total_count = 0
+    for counter in range(10):
+        dataframe_iter = session.pandas.read_sql_athena(
+            sql="select * from test", database=database, max_result_size=200)
+        total_count = 0
+        for dataframe in dataframe_iter:
+            total_count += len(dataframe.index)
+        if total_count == row_num:
+            break
+        sleep(1)
+    session.s3.delete_objects(path=path)
+    assert total_count == row_num
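The commit shows read_csv gaining a max_result_size pass-through, but not the chunked reader behind it. As background, the general technique can be sketched with plain boto3 ranged GETs: fetch the CSV result object in byte ranges and parse each range into a DataFrame. The helper below is hypothetical and for illustration only, not awswrangler's actual implementation; bucket and key are placeholders, and it ignores quoted newlines inside CSV fields.

import io

import boto3
import pandas as pd


def read_csv_in_chunks(bucket, key, max_result_size):
    """Yield DataFrames parsed from byte ranges of a CSV object on S3.

    Hypothetical helper, for illustration only.
    """
    client = boto3.client("s3")
    size = client.head_object(Bucket=bucket, Key=key)["ContentLength"]
    offset = 0
    leftover = b""  # partial row carried over from the previous range
    header = b""    # header row, re-applied to every chunk after the first
    while offset < size:
        end = min(offset + max_result_size, size) - 1
        body = client.get_object(
            Bucket=bucket, Key=key,
            Range=f"bytes={offset}-{end}")["Body"].read()
        chunk = leftover + body
        if end + 1 < size:
            # Cut at the last newline and carry the partial row forward,
            # so every yielded chunk ends on a record boundary.
            cut = chunk.rfind(b"\n") + 1
            chunk, leftover = chunk[:cut], chunk[cut:]
        if not chunk:           # a single row larger than max_result_size;
            offset = end + 1    # keep accumulating it in `leftover`
            continue
        if not header:
            header = chunk[:chunk.index(b"\n") + 1]
            yield pd.read_csv(io.BytesIO(chunk))
        else:
            yield pd.read_csv(io.BytesIO(header + chunk))
        offset = end + 1

Carrying the partial last row over to the next range is what keeps each yielded DataFrame aligned on record boundaries even though the byte ranges themselves are arbitrary, which is why the test above can simply sum len(dataframe.index) across chunks and compare against the source row count.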
