Add max_result_size feature to Pandas.read_sql_athena.
igorborgest committed Jul 25, 2019
1 parent 2b676f9 commit 9377338
Showing 3 changed files with 41 additions and 3 deletions.
19 changes: 16 additions & 3 deletions awswrangler/pandas.py
@@ -268,7 +268,20 @@ def _read_csv_once(
        buff.close()
        return dataframe

    def read_sql_athena(self, sql, database, s3_output=None):
    def read_sql_athena(self,
                        sql,
                        database,
                        s3_output=None,
                        max_result_size=None):
        """
        Executes any SQL query on AWS Athena and returns a Dataframe of the result.
        P.S. If max_result_size is passed, then an iterator of Dataframes is returned.
        :param sql: SQL query
        :param database: Glue/Athena database
        :param s3_output: AWS S3 path
        :param max_result_size: Max number of bytes on each request to S3
        :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
        """
        if not s3_output:
            account_id = (self._session.boto3_session.client(
                service_name="sts", config=self._session.botocore_config).
@@ -290,8 +303,8 @@ def read_sql_athena(self, sql, database, s3_output=None):
            raise AthenaQueryError(message_error)
        else:
            path = f"{s3_output}{query_execution_id}.csv"
            dataframe = self.read_csv(path=path)
            return dataframe
            ret = self.read_csv(path=path, max_result_size=max_result_size)
            return ret

    def to_csv(
            self,
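For orientation, here is a minimal usage sketch of the new parameter (not part of the commit). It assumes an awswrangler.Session() can be created from the ambient AWS credentials, and my_database / my_table are placeholder names:

import awswrangler

session = awswrangler.Session()

# Without max_result_size: a single Pandas DataFrame holding the full result.
df = session.pandas.read_sql_athena(sql="SELECT * FROM my_table",
                                    database="my_database")

# With max_result_size: an iterator of DataFrames, each chunk built from
# roughly that many bytes of the query results stored in S3.
for chunk in session.pandas.read_sql_athena(sql="SELECT * FROM my_table",
                                            database="my_database",
                                            max_result_size=128 * 1024 * 1024):
    print(len(chunk.index))

When max_result_size is omitted the call behaves as before and returns one DataFrame; when it is set, the results are yielded chunk by chunk, as exercised by the new test below.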
1 change: 1 addition & 0 deletions awswrangler/session.py
@@ -227,6 +227,7 @@ class SessionPrimitives:
    It is required to "share" the session attributes to other processes.
    That must be "pickable"!
    """

    def __init__(
            self,
            profile_name=None,
24 changes: 24 additions & 0 deletions testing/test_awswrangler/test_pandas.py
@@ -181,3 +181,27 @@ def test_to_s3(
            break
        sleep(1)
    assert factor * len(dataframe.index) == len(dataframe2.index)


@pytest.mark.parametrize("sample, row_num", [("data_samples/micro.csv", 30),
                                             ("data_samples/small.csv", 100)])
def test_read_sql_athena_iterator(session, bucket, database, sample, row_num):
    dataframe_sample = pandas.read_csv(sample)
    path = f"s3://{bucket}/test/"
    session.pandas.to_parquet(dataframe=dataframe_sample,
                              database=database,
                              path=path,
                              preserve_index=False,
                              mode="overwrite")
    total_count = 0
    for counter in range(10):
        dataframe_iter = session.pandas.read_sql_athena(
            sql="select * from test", database=database, max_result_size=200)
        total_count = 0
        for dataframe in dataframe_iter:
            total_count += len(dataframe.index)
        if total_count == row_num:
            break
        sleep(1)
    session.s3.delete_objects(path=path)
    assert total_count == row_num
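
The chunked reading itself lives in read_csv and is not shown in this diff. As a rough, standalone illustration of the idea suggested by the docstring (byte-range GETs of about max_result_size bytes against a CSV in S3, cut at line boundaries), the sketch below uses boto3 directly; it is an assumption about the approach, not awswrangler's implementation, and read_csv_chunks, my-bucket and the key are made-up names:

import io

import boto3
import pandas as pd


def read_csv_chunks(bucket, key, max_result_size=200):
    """Yield DataFrames parsed from byte-range requests of roughly max_result_size bytes."""
    client = boto3.client("s3")
    total = client.head_object(Bucket=bucket, Key=key)["ContentLength"]
    header = None
    leftover = b""
    start = 0
    while start < total:
        end = min(start + max_result_size, total) - 1
        body = client.get_object(Bucket=bucket, Key=key,
                                 Range=f"bytes={start}-{end}")["Body"].read()
        start = end + 1
        chunk = leftover + body
        # Parse only up to the last complete line; carry the tail into the next request.
        cut = chunk.rfind(b"\n") + 1
        complete, leftover = chunk[:cut], chunk[cut:]
        if header is None and complete:
            header_end = complete.find(b"\n") + 1
            header, complete = complete[:header_end], complete[header_end:]
        if complete:
            yield pd.read_csv(io.BytesIO(header + complete))
    if leftover and header is not None:
        # The final line may lack a trailing newline.
        yield pd.read_csv(io.BytesIO(header + leftover))


# for df in read_csv_chunks("my-bucket", "results/query.csv", max_result_size=64 * 1024 * 1024):
#     process(df)

Cutting each chunk at the last newline and re-prepending the header keeps every yielded piece a valid, independently parseable CSV.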
