@@ -22,8 +22,7 @@ def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None
.option("dbtable", f"({query}) tmp")
)
if isinstance(additional_options, dict):
-for key, value in additional_options.items():
-    reader = reader.option(key, value)
+reader = reader.options(**additional_options)
return reader

@staticmethod
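Reviewer note: the base-connector hunk above collapses the per-key `.option()` loop into a single `.options(**...)` call with the same effect. A minimal standalone sketch of that pattern; the JDBC URL, driver, and option values here are illustrative, not from the repo:

```python
# Sketch: DataFrameReader.options(**kwargs) sets many options in one call,
# replacing a loop of reader = reader.option(key, value) reassignments.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

additional_options = {"fetchsize": "100", "numPartitions": "4"}  # illustrative values

reader = (
    spark.read.format("jdbc")
    .option("url", "jdbc:postgresql://localhost:5432/demo")  # hypothetical endpoint
    .option("driver", "org.postgresql.Driver")
    .option("dbtable", "(select 1) tmp")
)
if isinstance(additional_options, dict):
    # Before: one .option(key, value) call per entry; after: a single call.
    reader = reader.options(**additional_options)
```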
src/databricks/labs/lakebridge/reconcile/connectors/oracle.py (10 changes: 6 additions & 4 deletions)
@@ -64,9 +64,9 @@ def read_data(
table_query = query.replace(":tbl", f"{schema}.{table}")
try:
if options is None:
-return self.reader(table_query).options(**self._get_timestamp_options()).load()
+return self.reader(table_query, self._get_timestamp_options()).load()
reader_options = self._get_jdbc_reader_options(options) | self._get_timestamp_options()
-df = self.reader(table_query).options(**reader_options).load()
+df = self.reader(table_query, reader_options).load()
logger.warning(f"Fetching data using query: \n`{table_query}`")

# Convert all column names to lower case
@@ -107,12 +107,14 @@ def _get_timestamp_options() -> dict[str, str]:
"HH24:MI:SS''');END;",
}

-def reader(self, query: str) -> DataFrameReader:
+def reader(self, query: str, options: dict | None = None) -> DataFrameReader:
+    if options is None:
+        options = {}
user = self._get_secret('user')
password = self._get_secret('password')
logger.debug(f"Using user: {user} to connect to Oracle")
return self._get_jdbc_reader(
query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password}
query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password}
)

def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
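The key behavioral detail in the new Oracle `reader` signature is the merge order: caller-supplied options are spread first, so the connector's own credentials win any key collision. A tiny sketch of that dict-merge precedence, using made-up values in place of secret lookups:

```python
# Later entries win in a dict literal, so the explicit "user"/"password"
# keys override anything a caller passed in via options.
options = {"fetchsize": "100", "user": "caller_supplied"}  # hypothetical caller input
user, password = "my_user", "my_password"  # stand-ins for secret-scope lookups

merged = {**options, "user": user, "password": password}
assert merged == {"fetchsize": "100", "user": "my_user", "password": "my_password"}
```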
src/databricks/labs/lakebridge/reconcile/connectors/tsql.py (21 changes: 13 additions & 8 deletions)
@@ -71,8 +71,6 @@ def get_jdbc_url(self) -> str:
return (
f"jdbc:{self._DRIVER}://{self._get_secret('host')}:{self._get_secret('port')};"
f"databaseName={self._get_secret('database')};"
f"user={self._get_secret('user')};"
f"password={self._get_secret('password')};"
f"encrypt={self._get_secret('encrypt')};"
f"trustServerCertificate={self._get_secret('trustServerCertificate')};"
)
@@ -96,10 +94,10 @@ def read_data(
prepare_query_string = ""
try:
if options is None:
-df = self.reader(query, prepare_query_string).load()
+df = self.reader(query, {"prepareQuery": prepare_query_string}).load()
else:
-    options = self._get_jdbc_reader_options(options)
-    df = self._get_jdbc_reader(table_query, self.get_jdbc_url, self._DRIVER).options(**options).load()
+    spark_options = self._get_jdbc_reader_options(options)
+    df = self.reader(table_query, spark_options).load()
return df.select([col(column).alias(column.lower()) for column in df.columns])
except (RuntimeError, PySparkException) as e:
return self.log_and_throw_exception(e, "data", table_query)
@@ -126,15 +124,22 @@ def get_schema(
try:
logger.debug(f"Fetching schema using query: \n`{schema_query}`")
logger.info(f"Fetching Schema: Started at: {datetime.now()}")
-df = self.reader(schema_query).load()
+df = self.reader(schema_query, {}).load()
schema_metadata = df.select([col(c).alias(c.lower()) for c in df.columns]).collect()
logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
return [self._map_meta_column(field, normalize) for field in schema_metadata]
except (RuntimeError, PySparkException) as e:
return self.log_and_throw_exception(e, "schema", schema_query)

-def reader(self, query: str, prepare_query_str="") -> DataFrameReader:
-    return self._get_jdbc_reader(query, self.get_jdbc_url, self._DRIVER, {"prepareQuery": prepare_query_str})
+def reader(self, query: str, options: dict) -> DataFrameReader:
+    creds = self._get_user_password()
+    return self._get_jdbc_reader(query, self.get_jdbc_url, self._DRIVER, {**options, **creds})
+
+def _get_user_password(self) -> dict:
+    return {
+        "user": self._get_secret("user"),
+        "password": self._get_secret("password"),
+    }

def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
return DialectUtils.normalize_identifier(
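The TSQL change moves `user`/`password` out of the JDBC URL and into the Spark reader options via the new `_get_user_password()` hook. A sketch of the resulting shape, assuming placeholder secret values:

```python
# Credentials are no longer interpolated into the connection string, so the
# URL can be logged or asserted on without leaking secrets.
jdbc_url = (
    "jdbc:sqlserver://my_host:777;"
    "databaseName=my_database;"
    "encrypt=true;trustServerCertificate=true;"
)
creds = {"user": "my_user", "password": "my_password"}  # placeholder secrets
spark_options = {"numPartitions": 100, "fetchsize": 100}

# Mirrors reader(): per-call options first, credentials layered on top.
reader_options = {**spark_options, **creds}
assert "password" not in jdbc_url and reader_options["user"] == "my_user"
```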
tests/integration/reconcile/connectors/test_read_schema.py (21 changes: 12 additions & 9 deletions)
@@ -22,11 +22,12 @@ def __init__(self, spark, ws):

@property
def get_jdbc_url(self) -> str:
-return (
-    self._test_env.get("TEST_TSQL_JDBC")
-    + f"user={self._test_env.get('TEST_TSQL_USER')};"
-    + f"password={self._test_env.get('TEST_TSQL_PASS')};"
-)
+return self._test_env.get("TEST_TSQL_JDBC")
+
+def _get_user_password(self) -> dict:
+    user = self._test_env.get("TEST_TSQL_USER")
+    password = self._test_env.get("TEST_TSQL_PASS")
+    return {"user": user, "password": password}


class OracleDataSourceUnderTest(OracleDataSource):
@@ -38,11 +39,13 @@ def __init__(self, spark, ws):
def get_jdbc_url(self) -> str:
return self._test_env.get("TEST_ORACLE_JDBC")

-def reader(self, query: str) -> DataFrameReader:
+def reader(self, query: str, options: dict | None = None) -> DataFrameReader:
+    if options is None:
+        options = {}
user = self._test_env.get("TEST_ORACLE_USER")
password = self._test_env.get("TEST_ORACLE_PASSWORD")
return self._get_jdbc_reader(
query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password}
query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password}
)


@@ -75,12 +78,12 @@ def _get_snowflake_options(self):
return opts


@pytest.mark.skip(reason="Add the creds to Github secrets and populate the actions' env to enable this test")
@pytest.mark.skip(reason="Run in acceptance environment only")
def test_sql_server_read_schema_happy(mock_spark):
mock_ws = create_autospec(WorkspaceClient)
connector = TSQLServerDataSourceUnderTest(mock_spark, mock_ws)

columns = connector.get_schema("labs_azure_sandbox_remorph", "dbo", "Employees")
columns = connector.get_schema("labs_azure_sandbox_remorph", "dbo", "reconcile_in")
assert columns


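With credentials funneled through `_get_user_password()`, the integration-test doubles only need to override that one hook plus the URL, instead of string-concatenating user and password into the JDBC URL. A minimal sketch of the pattern, written independently of the real base classes (env-var names follow the test file):

```python
import os


class FakeTSQLSource:  # stand-in for the real TSQLServerDataSource subclass
    @property
    def get_jdbc_url(self) -> str:
        # The URL carries no credentials anymore.
        return os.environ.get("TEST_TSQL_JDBC", "jdbc:sqlserver://localhost:1433;")

    def _get_user_password(self) -> dict:
        # Overriding one hook swaps secret-scope lookups for env vars.
        return {
            "user": os.environ.get("TEST_TSQL_USER", "test_user"),
            "password": os.environ.get("TEST_TSQL_PASS", "test_pass"),
        }
```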
tests/unit/reconcile/connectors/test_oracle.py (14 changes: 6 additions & 8 deletions)
@@ -76,9 +76,7 @@ def test_read_data_with_options():
)
spark.read.format().option().option.assert_called_with("driver", "oracle.jdbc.OracleDriver")
spark.read.format().option().option().option.assert_called_with("dbtable", "(select 1 from data.employee) tmp")
spark.read.format().option().option().option().option.assert_called_with("user", "my_user")
spark.read.format().option().option().option().option().option.assert_called_with("password", "my_password")
jdbc_actual_args = spark.read.format().option().option().option().option().option().options.call_args.kwargs
jdbc_actual_args = spark.read.format().option().option().option().options.call_args.kwargs
jdbc_expected_args = {
"numPartitions": 50,
"partitionColumn": "s_nationkey",
@@ -89,9 +87,11 @@
"sessionInitStatement": r"BEGIN dbms_session.set_nls('nls_date_format', "
r"'''YYYY-MM-DD''');dbms_session.set_nls('nls_timestamp_format', '''YYYY-MM-DD "
r"HH24:MI:SS''');END;",
"user": "my_user",
"password": "my_password",
}
assert jdbc_actual_args == jdbc_expected_args
-spark.read.format().option().option().option().option().option().options().load.assert_called_once()
+spark.read.format().option().option().option().options().load.assert_called_once()


def test_get_schema():
@@ -143,9 +143,7 @@ def test_read_data_exception_handling():
filters=None,
)

-spark.read.format().option().option().option().option().option().options().load.side_effect = RuntimeError(
-    "Test Exception"
-)
+spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")

# Call the read_data method with the Tables configuration and assert that a PySparkException is raised
with pytest.raises(
@@ -160,7 +158,7 @@ def test_get_schema_exception_handling():
engine, spark, ws, scope = initial_setup()
ords = OracleDataSource(engine, spark, ws, scope)

-spark.read.format().option().option().option().option().option().load.side_effect = RuntimeError("Test Exception")
+spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")

# Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
# is raised
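The unit-test updates above are mechanical: dropping the dedicated `.option("user", ...)` and `.option("password", ...)` calls shortens every mocked call chain by two links. A sketch of why the chain length tracks the number of `.option()` calls; the URL and option values are illustrative:

```python
from unittest.mock import MagicMock

spark = MagicMock()
# Three .option() calls followed by one .options() call...
(
    spark.read.format("jdbc")
    .option("url", "jdbc:oracle:thin:@//host:1521/db")  # hypothetical URL
    .option("driver", "oracle.jdbc.OracleDriver")
    .option("dbtable", "(select 1 from dual) tmp")
    .options(fetchsize=100, user="my_user", password="my_password")
    .load()
)
# ...so the assertion walks format().option().option().option().options().load.
spark.read.format().option().option().option().options().load.assert_called_once()
```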
tests/unit/reconcile/connectors/test_sql_server.py (21 changes: 5 additions & 16 deletions)
@@ -55,20 +55,7 @@ def test_get_jdbc_url_happy():
url = data_source.get_jdbc_url
# Assert that the URL is generated correctly
assert url == (
"""jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
)


def test_get_jdbc_url_fail():
# initial setup
engine, spark, ws, scope = initial_setup()
ws.secrets.get_secret.side_effect = mock_secret
# create object for TSQLServerDataSource
data_source = TSQLServerDataSource(engine, spark, ws, scope)
url = data_source.get_jdbc_url
# Assert that the URL is generated correctly
assert url == (
"""jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
"""jdbc:sqlserver://my_host:777;databaseName=my_database;encrypt=true;trustServerCertificate=true;"""
)


@@ -96,7 +83,7 @@ def test_read_data_with_options():
spark.read.format.assert_called_with("jdbc")
spark.read.format().option.assert_called_with(
"url",
"jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;",
"jdbc:sqlserver://my_host:777;databaseName=my_database;encrypt=true;trustServerCertificate=true;",
)
spark.read.format().option().option.assert_called_with("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
spark.read.format().option().option().option.assert_called_with(
@@ -109,6 +96,8 @@
"lowerBound": '0',
"upperBound": "100",
"fetchsize": 100,
"user": "my_user",
"password": "my_password",
}
assert actual_args == expected_args
spark.read.format().option().option().option().options().load.assert_called_once()
@@ -166,7 +155,7 @@ def test_get_schema_exception_handling():
engine, spark, ws, scope = initial_setup()
data_source = TSQLServerDataSource(engine, spark, ws, scope)

-spark.read.format().option().option().option().option().load.side_effect = RuntimeError("Test Exception")
+spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")

# Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
# is raised
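Same pattern in the SQL Server tests: credentials are now asserted inside the kwargs of the single `.options()` call rather than through dedicated `.option()` assertions. A sketch of inspecting those kwargs via `call_args`, with illustrative values:

```python
from unittest.mock import MagicMock

spark = MagicMock()
(
    spark.read.format("jdbc")
    .option("url", "jdbc:sqlserver://my_host:777;databaseName=my_database;")
    .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
    .option("dbtable", "(select 1) tmp")
    .options(numPartitions=100, fetchsize=100, user="my_user", password="my_password")
    .load()
)
# call_args.kwargs exposes the merged options dict, credentials included.
actual = spark.read.format().option().option().option().options.call_args.kwargs
assert actual["user"] == "my_user" and actual["password"] == "my_password"
```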