diff --git a/src/databricks/sql/ae.py b/src/databricks/sql/ae.py index 0751e1bb..efa8e4e0 100644 --- a/src/databricks/sql/ae.py +++ b/src/databricks/sql/ae.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Optional, Union, TYPE_CHECKING from databricks.sql.exc import RequestError -from databricks.sql.results import ResultSet +from databricks.sql.results import ResultSet, execute_response_contains_direct_results from datetime import datetime @@ -81,6 +81,7 @@ class AsyncExecution: ] _last_sync_timestamp: Optional[datetime] = None _result_set: Optional["ResultSet"] = None + _returned_as_direct_result: bool = False def __init__( self, @@ -101,6 +102,8 @@ def __init__( if execute_statement_response: self._execute_statement_response = execute_statement_response + if execute_response_contains_direct_results(execute_statement_response): + self._returned_as_direct_result = True else: self._execute_statement_response = FakeExecuteStatementResponse( directResults=False, operationHandle=self.t_operation_handle @@ -225,6 +228,17 @@ def last_sync_timestamp(self) -> Optional[datetime]: """The timestamp of the last time self.status was synced with the server""" return self._last_sync_timestamp + @property + def returned_as_direct_result(self) -> bool: + """When direct results were returned, this query_id cannot be picked up + with `Connection.get_async_execution()` + + Only returns True if the query returned its results directly when `execute_async` + was called. 
def execute_response_contains_direct_results(
    execute_response: ttypes.TExecuteStatementResp,
) -> bool:
    """Return True if the thrift TExecuteStatementResp carried a direct result.

    When directResults is used the server just batches these rpcs together,
    if the entire result can be returned in a single round-trip:

        struct TSparkDirectResults {
          1: optional TGetOperationStatusResp operationStatus
          2: optional TGetResultSetMetadataResp resultSetMetadata
          3: optional TFetchResultsResp resultSet
          4: optional TCloseOperationResp closeOperation
        }

    Args:
        execute_response: The response object to inspect. Only its
            ``directResults`` attribute is read.

    Returns:
        True only when ``directResults`` is present and all four of its
        members are set (truthy); False otherwise. The result is always a
        real ``bool``, never one of the thrift members themselves.
    """

    # `directResults` is optional: a query that did not finish within the
    # initial round-trip has nothing here. Dereferencing it blindly would
    # raise AttributeError, so treat a missing/falsy container as "no direct
    # results".
    direct_results = getattr(execute_response, "directResults", None)
    if not direct_results:
        return False

    required_members = (
        "operationStatus",
        "resultSet",
        "resultSetMetadata",
        "closeOperation",
    )

    # all() collapses the checks to a genuine bool; chaining the members with
    # `and` would hand back the last member object (or None) instead.
    return all(getattr(direct_results, member, None) for member in required_members)
It would be easier to do this if Databricks SQL had a SLEEP() function :/ - """ - poll_count = 0 - while long_ish_ae.is_running: - time.sleep(1) - long_ish_ae.sync_status() - poll_count += 1 + We could achieve something similar by overriding the directResults setting in our ExecuteStatementReq + """ - assert poll_count > 0 + assert not long_ish_ae.returned_as_direct_result def test_get_async_execution_and_get_results_without_direct_results( self, long_ish_ae: AsyncExecution @@ -162,10 +155,13 @@ def test_serialize(self, long_running_ae: AsyncExecution): assert ae.is_running def test_get_async_execution_no_results_when_direct_results_were_sent(self): - """It remains to be seen whether results can be fetched repeatedly from a "picked up" execution.""" + """Queries that return direct results cannot be picked up with `get_async_execution()`.""" with self.connection() as conn: ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1}) + assert ( + ae.returned_as_direct_result + ), "Queries that return direct results should not be available" query_id, query_secret = ae.serialize().split(":") ae.get_result() @@ -193,9 +189,12 @@ def test_get_async_execution_twice(self): """ with self.connection() as conn_1, self.connection() as conn_2: ae_1 = conn_1.execute_async(LONG_ISH_QUERY) + assert not ae_1.returned_as_direct_result + query_id, query_secret = ae_1.serialize().split(":") ae_2 = conn_2.get_async_execution(query_id, query_secret) + assert not ae_2.returned_as_direct_result while ae_1.is_running: time.sleep(1)