From bf928fffa4ee70c44d135406ad5fc42b9858cdda Mon Sep 17 00:00:00 2001 From: M Abulazm Date: Fri, 21 Nov 2025 10:11:46 +0100 Subject: [PATCH 1/6] add load_credentials interface for data sources and impl --- .../reconcile/connectors/data_source.py | 28 ++++- .../reconcile/connectors/databricks.py | 12 +- .../reconcile/connectors/dialect_utils.py | 8 +- .../reconcile/connectors/jdbc_reader.py | 1 - .../lakebridge/reconcile/connectors/models.py | 7 -- .../lakebridge/reconcile/connectors/oracle.py | 53 +++++--- .../reconcile/connectors/secrets.py | 49 -------- .../reconcile/connectors/snowflake.py | 115 +++++++++-------- .../reconcile/connectors/source_adapter.py | 9 +- .../lakebridge/reconcile/connectors/tsql.py | 52 ++++++-- .../reconcile/trigger_recon_service.py | 2 +- .../labs/lakebridge/reconcile/utils.py | 10 +- tests/conftest.py | 7 +- .../reconcile/connectors/test_read_schema.py | 14 +-- .../reconcile/query_builder/test_execute.py | 30 ++++- .../reconcile/test_oracle_reconcile.py | 2 +- .../reconcile/connectors/test_databricks.py | 26 ++-- .../unit/reconcile/connectors/test_oracle.py | 36 ++++-- .../unit/reconcile/connectors/test_secrets.py | 65 ---------- .../reconcile/connectors/test_snowflake.py | 117 ++++++++++-------- .../reconcile/connectors/test_sql_server.py | 46 +++---- tests/unit/reconcile/test_source_adapter.py | 18 ++- 22 files changed, 353 insertions(+), 354 deletions(-) delete mode 100644 src/databricks/labs/lakebridge/reconcile/connectors/models.py delete mode 100644 src/databricks/labs/lakebridge/reconcile/connectors/secrets.py delete mode 100644 tests/unit/reconcile/connectors/test_secrets.py diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py index 9294768b77..0b3c6e6388 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py @@ -3,14 +3,31 @@ from pyspark.sql import DataFrame -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ReconcileCredentialConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema logger = logging.getLogger(__name__) +def build_credentials(vault_type: str, source: str, credentials: dict) -> dict: + """Build credentials dictionary with secret vault type included. + + Args: + vault_type: The type of secret vault (e.g., 'local', 'databricks'). + source: The source system name. + credentials: The original credentials dictionary. + + Returns: + A new credentials dictionary including the secret vault type. 
+ """ + return { + source: credentials, + 'secret_vault_type': vault_type.lower(), + } + + class DataSource(ABC): @abstractmethod @@ -34,6 +51,10 @@ def get_schema( ) -> list[Schema]: return NotImplemented + @abstractmethod + def load_credentials(self, creds: ReconcileCredentialConfig) -> "DataSource": + return NotImplemented + @abstractmethod def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: pass @@ -94,5 +115,8 @@ def get_schema(self, catalog: str | None, schema: str, table: str, normalize: bo return self.log_and_throw_exception(self._exception, "schema", f"({catalog}, {schema}, {table})") return mock_schema + def load_credentials(self, creds: ReconcileCredentialConfig) -> "MockDataSource": + return self + def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: return DialectUtils.normalize_identifier(identifier, self._delimiter, self._delimiter) diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py b/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py index 89d05b3e4c..ef7fe9c7ce 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/databricks.py @@ -7,10 +7,9 @@ from pyspark.sql.functions import col from sqlglot import Dialect +from databricks.labs.lakebridge.config import ReconcileCredentialConfig from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -36,7 +35,7 @@ def _get_schema_query(catalog: str, schema: str, table: str): return re.sub(r'\s+', ' ', query) -class DatabricksDataSource(DataSource, SecretsMixin): +class DatabricksDataSource(DataSource): _IDENTIFIER_DELIMITER = "`" def __init__( @@ -44,12 +43,10 @@ def __init__( engine: Dialect, spark: SparkSession, ws: WorkspaceClient, - secret_scope: str, ): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope def read_data( self, @@ -96,6 +93,9 @@ def get_schema( except (RuntimeError, PySparkException) as e: return self.log_and_throw_exception(e, "schema", schema_query) + def load_credentials(self, creds: ReconcileCredentialConfig) -> "DatabricksDataSource": + return self + def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: return DialectUtils.normalize_identifier( identifier, diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py b/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py index 665755e85c..2785fd8002 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py @@ -1,4 +1,10 @@ -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +import dataclasses + + +@dataclasses.dataclass() +class NormalizedIdentifier: + ansi_normalized: str + source_normalized: str class DialectUtils: diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py 
b/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py index f2313e7a90..7159d04368 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py @@ -8,7 +8,6 @@ class JDBCReaderMixin: _spark: SparkSession - # TODO update the url def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None = None): driver_class = { "oracle": "oracle.jdbc.OracleDriver", diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/models.py b/src/databricks/labs/lakebridge/reconcile/connectors/models.py deleted file mode 100644 index c98cbef7dd..0000000000 --- a/src/databricks/labs/lakebridge/reconcile/connectors/models.py +++ /dev/null @@ -1,7 +0,0 @@ -import dataclasses - - -@dataclasses.dataclass -class NormalizedIdentifier: - ansi_normalized: str - source_normalized: str diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py index ebcd5f0991..26c042c9f3 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py @@ -7,18 +7,18 @@ from pyspark.sql.functions import col from sqlglot import Dialect -from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource +from databricks.labs.lakebridge.config import ReconcileCredentialConfig +from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager +from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, build_credentials from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient logger = logging.getLogger(__name__) -class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin): +class OracleDataSource(DataSource, JDBCReaderMixin): _DRIVER = "oracle" _IDENTIFIER_DELIMITER = "\"" _SCHEMA_QUERY = """select column_name, case when (data_precision is not null @@ -34,23 +34,17 @@ class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin): FROM ALL_TAB_COLUMNS WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}'""" - def __init__( - self, - engine: Dialect, - spark: SparkSession, - ws: WorkspaceClient, - secret_scope: str, - ): + def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope + self._creds: dict[str, str] = {} @property def get_jdbc_url(self) -> str: return ( - f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._get_secret('host')}" - f":{self._get_secret('port')}/{self._get_secret('database')}" + f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._creds.get('host')}" + f":{self._creds.get('port')}/{self._creds.get('database')}" ) def read_data( @@ -108,13 +102,38 @@ def _get_timestamp_options() -> dict[str, str]: } def reader(self, query: str) -> DataFrameReader: - user = self._get_secret('user') - password = self._get_secret('password') 
+ user = self._creds.get('user') + password = self._creds.get('password') logger.debug(f"Using user: {user} to connect to Oracle") return self._get_jdbc_reader( query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password} ) + def load_credentials(self, creds: ReconcileCredentialConfig) -> "OracleDataSource": + connector_creds = [ + "host", + "port", + "database", + "user", + "password", + ] + + use_scope = creds.source_creds.get("__secret_scope") + if use_scope: + source_creds = {key: f"{use_scope}/{key}" for key in connector_creds} + + assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'" + parsed_creds = build_credentials(creds.vault_type, "oracle", source_creds) + else: + parsed_creds = build_credentials(creds.vault_type, "oracle", creds.source_creds) + + self._creds = create_credential_manager(parsed_creds, self._ws).get_credentials("oracle") + assert all( + self._creds.get(k) for k in connector_creds + ), f"Missing mandatory Oracle credentials. Please configure all of {connector_creds}." + + return self + def normalize_identifier(self, identifier: str) -> NormalizedIdentifier: normalized = DialectUtils.normalize_identifier( identifier, diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/secrets.py b/src/databricks/labs/lakebridge/reconcile/connectors/secrets.py deleted file mode 100644 index daa213afc8..0000000000 --- a/src/databricks/labs/lakebridge/reconcile/connectors/secrets.py +++ /dev/null @@ -1,49 +0,0 @@ -import base64 -import logging - -from databricks.sdk import WorkspaceClient -from databricks.sdk.errors import NotFound - -logger = logging.getLogger(__name__) - - -# TODO use CredentialManager to allow for changing secret provider for tests -class SecretsMixin: - _ws: WorkspaceClient - _secret_scope: str - - def _get_secret_or_none(self, secret_key: str) -> str | None: - """ - Get the secret value given a secret scope & secret key. Log a warning if secret does not exist - Used To ensure backwards compatibility when supporting new secrets - """ - try: - # Return the decoded secret value in string format - return self._get_secret(secret_key) - except NotFound as e: - logger.warning(f"Secret not found: key={secret_key}") - logger.debug("Secret lookup failed", exc_info=e) - return None - - def _get_secret(self, secret_key: str) -> str: - """Get the secret value given a secret scope & secret key. - - Raises: - NotFound: The secret could not be found. - UnicodeDecodeError: The secret value was not Base64-encoded UTF-8. 
- """ - try: - # Return the decoded secret value in string format - secret = self._ws.secrets.get_secret(self._secret_scope, secret_key) - assert secret.value is not None - return base64.b64decode(secret.value).decode("utf-8") - except NotFound as e: - raise NotFound(f'Secret does not exist with scope: {self._secret_scope} and key: {secret_key} : {e}') from e - except UnicodeDecodeError as e: - raise UnicodeDecodeError( - "utf-8", - secret_key.encode(), - 0, - 1, - f"Secret {self._secret_scope}/{secret_key} has Base64 bytes that cannot be decoded to utf-8 string: {e}.", - ) from e diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py index e66751d29b..173e834043 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py @@ -9,20 +9,21 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource +from databricks.labs.lakebridge.config import ReconcileCredentialConfig +from databricks.labs.lakebridge.connections.credential_manager import ( + create_credential_manager, +) +from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, build_credentials from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.exception import InvalidSnowflakePemPrivateKey from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient -from databricks.sdk.errors import NotFound logger = logging.getLogger(__name__) -class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin): +class SnowflakeDataSource(DataSource, JDBCReaderMixin): _DRIVER = "snowflake" _IDENTIFIER_DELIMITER = "\"" @@ -51,33 +52,60 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin): where lower(table_name)='{table}' and table_schema = '{schema}' order by ordinal_position""" - def __init__( - self, - engine: Dialect, - spark: SparkSession, - ws: WorkspaceClient, - secret_scope: str, - ): + def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope + self._creds: dict[str, str] = {} + + def load_credentials(self, creds: ReconcileCredentialConfig) -> "SnowflakeDataSource": + connector_creds = [ + "sfUser", + "sfUrl", + "sfDatabase", + "sfSchema", + "sfWarehouse", + "sfRole", + ] + + use_scope = creds.source_creds.get("__secret_scope") + if use_scope: + # to use pem key and/or pem password, migrate to source_creds approach + connector_creds += ["sfPassword"] + source_creds = {key: f"{use_scope}/{key}" for key in connector_creds} + + assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'" + parsed_creds = build_credentials(creds.vault_type, "snowflake", source_creds) + else: + parsed_creds = build_credentials(creds.vault_type, 
"snowflake", creds.source_creds) + + self._creds = create_credential_manager(parsed_creds, self._ws).get_credentials("snowflake") + assert all( + self._creds.get(k) for k in connector_creds + ), f"Missing mandatory Snowflake credentials. Please configure all of {connector_creds}." + assert any( + self._creds.get(k) for k in ("sfPassword", "pem_private_key") + ), "Missing Snowflake credentials. Please configure any of [sfPassword, pem_private_key]." + + if self._creds.get("pem_private_key"): + self._creds["pem_private_key"] = SnowflakeDataSource._get_private_key( + self._creds["pem_private_key"], + self._creds.get("pem_private_key_password"), + ) + + return self @property def get_jdbc_url(self) -> str: - try: - sf_password = self._get_secret('sfPassword') - except (NotFound, KeyError) as e: - message = "sfPassword is mandatory for jdbc connectivity with Snowflake." - logger.error(message) - raise NotFound(message) from e + if not self._creds: + raise RuntimeError("Credentials not loaded. Please call `load_credentials(ReconcileCredentialConfig)`.") return ( - f"jdbc:{SnowflakeDataSource._DRIVER}://{self._get_secret('sfAccount')}.snowflakecomputing.com" - f"/?user={self._get_secret('sfUser')}&password={sf_password}" - f"&db={self._get_secret('sfDatabase')}&schema={self._get_secret('sfSchema')}" - f"&warehouse={self._get_secret('sfWarehouse')}&role={self._get_secret('sfRole')}" - ) + f"jdbc:{SnowflakeDataSource._DRIVER}://{self._creds['sfUrl']}" + f"/?user={self._creds['sfUser']}&password={self._creds['sfPassword']}" + f"&db={self._creds['sfDatabase']}&schema={self._creds['sfSchema']}" + f"&warehouse={self._creds['sfWarehouse']}&role={self._creds['sfRole']}" + ) # TODO Support PEM key auth def read_data( self, @@ -132,39 +160,10 @@ def get_schema( return self.log_and_throw_exception(e, "schema", schema_query) def reader(self, query: str) -> DataFrameReader: - options = self._get_snowflake_options() - return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**options) - - # TODO cache this method using @functools.cache - # Pay attention to https://pylint.pycqa.org/en/latest/user_guide/messages/warning/method-cache-max-size-none.html - def _get_snowflake_options(self): - options = { - "sfUrl": self._get_secret('sfUrl'), - "sfUser": self._get_secret('sfUser'), - "sfDatabase": self._get_secret('sfDatabase'), - "sfSchema": self._get_secret('sfSchema'), - "sfWarehouse": self._get_secret('sfWarehouse'), - "sfRole": self._get_secret('sfRole'), - } - options = options | self._get_snowflake_auth_options() - - return options - - def _get_snowflake_auth_options(self): - try: - key = SnowflakeDataSource._get_private_key( - self._get_secret('pem_private_key'), self._get_secret_or_none('pem_private_key_password') - ) - return {"pem_private_key": key} - except (NotFound, KeyError): - logger.warning("pem_private_key not found. Checking for sfPassword") - try: - password = self._get_secret('sfPassword') - return {"sfPassword": password} - except (NotFound, KeyError) as e: - message = "sfPassword and pem_private_key not found. Either one is required for snowflake auth." - logger.error(message) - raise NotFound(message) from e + if not self._creds: + raise RuntimeError("Credentials not loaded. 
Please call `load_credentials(ReconcileCredentialConfig)`.") + + return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**self._creds) @staticmethod def _get_private_key(pem_private_key: str, pem_private_key_password: str | None) -> str: diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py b/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py index 71039f4494..286bb36a8f 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/source_adapter.py @@ -17,14 +17,13 @@ def create_adapter( engine: Dialect, spark: SparkSession, ws: WorkspaceClient, - secret_scope: str, ) -> DataSource: if isinstance(engine, Snowflake): - return SnowflakeDataSource(engine, spark, ws, secret_scope) + return SnowflakeDataSource(engine, spark, ws) if isinstance(engine, Oracle): - return OracleDataSource(engine, spark, ws, secret_scope) + return OracleDataSource(engine, spark, ws) if isinstance(engine, Databricks): - return DatabricksDataSource(engine, spark, ws, secret_scope) + return DatabricksDataSource(engine, spark, ws) if isinstance(engine, TSQL): - return TSQLServerDataSource(engine, spark, ws, secret_scope) + return TSQLServerDataSource(engine, spark, ws) raise ValueError(f"Unsupported source type --> {engine}") diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py index a5241b7ac8..c0ddf524db 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py @@ -7,11 +7,11 @@ from pyspark.sql.functions import col from sqlglot import Dialect -from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource +from databricks.labs.lakebridge.config import ReconcileCredentialConfig +from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager +from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, build_credentials from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -49,7 +49,7 @@ """ -class TSQLServerDataSource(DataSource, SecretsMixin, JDBCReaderMixin): +class TSQLServerDataSource(DataSource, JDBCReaderMixin): _DRIVER = "sqlserver" _IDENTIFIER_DELIMITER = {"prefix": "[", "suffix": "]"} @@ -58,23 +58,22 @@ def __init__( engine: Dialect, spark: SparkSession, ws: WorkspaceClient, - secret_scope: str, ): self._engine = engine self._spark = spark self._ws = ws - self._secret_scope = secret_scope + self._creds: dict[str, str] = {} @property def get_jdbc_url(self) -> str: # Construct the JDBC URL return ( - f"jdbc:{self._DRIVER}://{self._get_secret('host')}:{self._get_secret('port')};" - f"databaseName={self._get_secret('database')};" - f"user={self._get_secret('user')};" - f"password={self._get_secret('password')};" - f"encrypt={self._get_secret('encrypt')};" - 
f"trustServerCertificate={self._get_secret('trustServerCertificate')};" + f"jdbc:{self._DRIVER}://{self._creds.get('host')}:{self._creds.get('port')};" + f"databaseName={self._creds.get('database')};" + f"user={self._creds.get('user')};" + f"password={self._creds.get('password')};" + f"encrypt={self._creds.get('encrypt')};" + f"trustServerCertificate={self._creds.get('trustServerCertificate')};" ) def read_data( @@ -104,6 +103,33 @@ def read_data( except (RuntimeError, PySparkException) as e: return self.log_and_throw_exception(e, "data", table_query) + def load_credentials(self, creds: ReconcileCredentialConfig) -> "TSQLServerDataSource": + connector_creds = [ + "host", + "port", + "database", + "user", + "password", + "encrypt", + "trustServerCertificate", + ] + + use_scope = creds.source_creds.get("__secret_scope") + if use_scope: + source_creds = {key: f"{use_scope}/{key}" for key in connector_creds} + + assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'" + parsed_creds = build_credentials(creds.vault_type, "mssql", source_creds) + else: + parsed_creds = build_credentials(creds.vault_type, "mssql", creds.source_creds) + + self._creds = create_credential_manager(parsed_creds, self._ws).get_credentials("mssql") + assert all( + self._creds.get(k) for k in connector_creds + ), f"Missing mandatory MS SQL credentials. Please configure all of {connector_creds}." + + return self + def get_schema( self, catalog: str | None, diff --git a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py index 3fd837d668..9873a177ca 100644 --- a/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py +++ b/src/databricks/labs/lakebridge/reconcile/trigger_recon_service.py @@ -74,7 +74,7 @@ def create_recon_dependencies( engine=reconcile_config.data_source, spark=spark, ws=ws_client, - secret_scope=reconcile_config.secret_scope, + creds=reconcile_config.creds, ) recon_id = str(uuid4()) diff --git a/src/databricks/labs/lakebridge/reconcile/utils.py b/src/databricks/labs/lakebridge/reconcile/utils.py index 42a309d8da..1fa80b6a37 100644 --- a/src/databricks/labs/lakebridge/reconcile/utils.py +++ b/src/databricks/labs/lakebridge/reconcile/utils.py @@ -4,7 +4,7 @@ from databricks.sdk import WorkspaceClient -from databricks.labs.lakebridge.config import ReconcileMetadataConfig +from databricks.labs.lakebridge.config import ReconcileMetadataConfig, ReconcileCredentialConfig from databricks.labs.lakebridge.reconcile.connectors.source_adapter import create_adapter from databricks.labs.lakebridge.reconcile.exception import InvalidInputException from databricks.labs.lakebridge.reconcile.recon_config import Table @@ -17,10 +17,12 @@ def initialise_data_source( ws: WorkspaceClient, spark: SparkSession, engine: str, - secret_scope: str, + creds: ReconcileCredentialConfig, ): - source = create_adapter(engine=get_dialect(engine), spark=spark, ws=ws, secret_scope=secret_scope) - target = create_adapter(engine=get_dialect("databricks"), spark=spark, ws=ws, secret_scope=secret_scope) + source = create_adapter(engine=get_dialect(engine), spark=spark, ws=ws) + target = create_adapter(engine=get_dialect("databricks"), spark=spark, ws=ws) + source.load_credentials(creds) + target.load_credentials(creds) return source, target diff --git a/tests/conftest.py b/tests/conftest.py index 2b2419231e..86cdf726e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,8 +17,8 @@ from databricks.sdk 
import WorkspaceClient from databricks.sdk.service import iam -from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ReconcileCredentialConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, MockDataSource from databricks.labs.lakebridge.reconcile.recon_config import ( Table, @@ -344,6 +344,9 @@ def read_data( ) -> DataFrame: raise RuntimeError("Not implemented") + def load_credentials(self, creds: ReconcileCredentialConfig) -> "FakeDataSource": + raise RuntimeError("Not implemented") + @pytest.fixture def fake_oracle_datasource() -> FakeDataSource: diff --git a/tests/integration/reconcile/connectors/test_read_schema.py b/tests/integration/reconcile/connectors/test_read_schema.py index b1d2752c7f..5f5a7f3f77 100644 --- a/tests/integration/reconcile/connectors/test_read_schema.py +++ b/tests/integration/reconcile/connectors/test_read_schema.py @@ -17,8 +17,8 @@ class TSQLServerDataSourceUnderTest(TSQLServerDataSource): def __init__(self, spark, ws): - super().__init__(get_dialect("tsql"), spark, ws, "secret_scope") - self._test_env = TestEnvGetter(True) + super().__init__(get_dialect("tsql"), spark, ws) + self._test_env = TestEnvGetter(True) # TODO use load_credentials @property def get_jdbc_url(self) -> str: @@ -31,8 +31,8 @@ def get_jdbc_url(self) -> str: class OracleDataSourceUnderTest(OracleDataSource): def __init__(self, spark, ws): - super().__init__(get_dialect("oracle"), spark, ws, "secret_scope") - self._test_env = TestEnvGetter(False) + super().__init__(get_dialect("oracle"), spark, ws) + self._test_env = TestEnvGetter(False) # TODO use load_credentials @property def get_jdbc_url(self) -> str: @@ -48,8 +48,8 @@ def reader(self, query: str) -> DataFrameReader: class SnowflakeDataSourceUnderTest(SnowflakeDataSource): def __init__(self, spark, ws): - super().__init__(get_dialect("snowflake"), spark, ws, "secret_scope") - self._test_env = TestEnvGetter(True) + super().__init__(get_dialect("snowflake"), spark, ws) + self._test_env = TestEnvGetter(True) # TODO use load_credentials @property def get_jdbc_url(self) -> str: @@ -86,7 +86,7 @@ def test_sql_server_read_schema_happy(mock_spark): def test_databricks_read_schema_happy(mock_spark): mock_ws = create_autospec(WorkspaceClient) - connector = DatabricksDataSource(get_dialect("databricks"), mock_spark, mock_ws, "my_secret") + connector = DatabricksDataSource(get_dialect("databricks"), mock_spark, mock_ws) mock_spark.sql("CREATE DATABASE IF NOT EXISTS my_test_db") mock_spark.sql("CREATE TABLE IF NOT EXISTS my_test_db.my_test_table (id INT, name STRING) USING parquet") diff --git a/tests/integration/reconcile/query_builder/test_execute.py b/tests/integration/reconcile/query_builder/test_execute.py index 0015ff9dbb..c510a1d04b 100644 --- a/tests/integration/reconcile/query_builder/test_execute.py +++ b/tests/integration/reconcile/query_builder/test_execute.py @@ -1,18 +1,23 @@ +import base64 from pathlib import Path from dataclasses import dataclass from datetime import datetime -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, create_autospec import pytest from pyspark import Row from pyspark.errors import PySparkException from pyspark.testing import assertDataFrameEqual +from 
databricks.sdk import WorkspaceClient +from databricks.sdk.service.workspace import GetSecretResponse + from databricks.labs.lakebridge.config import ( DatabaseConfig, TableRecon, ReconcileMetadataConfig, ReconcileConfig, + ReconcileCredentialConfig, ) from databricks.labs.lakebridge.reconcile.reconciliation import Reconciliation from databricks.labs.lakebridge.reconcile.trigger_recon_service import TriggerReconService @@ -1901,12 +1906,22 @@ def test_data_recon_with_source_exception( def test_initialise_data_source(mock_workspace_client, mock_spark): src_engine = get_dialect("snowflake") - secret_scope = "test" - source, target = initialise_data_source(mock_workspace_client, mock_spark, src_engine, secret_scope) + sf_creds = { + "sfUser": "user", + "sfPassword": "password", + "sfUrl": "account.snowflakecomputing.com", + "sfDatabase": "database", + "sfSchema": "schema", + "sfWarehouse": "warehouse", + "sfRole": "role", + } + source, target = initialise_data_source( + mock_workspace_client, mock_spark, "snowflake", ReconcileCredentialConfig("local", sf_creds) + ) - snowflake_data_source = SnowflakeDataSource(src_engine, mock_spark, mock_workspace_client, secret_scope).__class__ - databricks_data_source = DatabricksDataSource(src_engine, mock_spark, mock_workspace_client, secret_scope).__class__ + snowflake_data_source = SnowflakeDataSource(src_engine, mock_spark, mock_workspace_client).__class__ + databricks_data_source = DatabricksDataSource(src_engine, mock_spark, mock_workspace_client).__class__ assert isinstance(source, snowflake_data_source) assert isinstance(target, databricks_data_source) @@ -2020,7 +2035,10 @@ def test_reconcile_data_with_threshold_and_row_report_type( @patch('databricks.labs.lakebridge.reconcile.recon_capture.generate_final_reconcile_output') def test_recon_output_without_exception(mock_gen_final_recon_output): - mock_workspace_client = MagicMock() + mock_workspace_client = create_autospec(WorkspaceClient) + mock_workspace_client.secrets.get_secret.return_value = GetSecretResponse( + key="key", value=base64.b64encode(bytes('value', 'utf-8')).decode('utf-8') + ) mock_spark = MagicMock() mock_table_recon = MagicMock() mock_gen_final_recon_output.return_value = ReconcileOutput( diff --git a/tests/integration/reconcile/test_oracle_reconcile.py b/tests/integration/reconcile/test_oracle_reconcile.py index 841b130168..29ceb47549 100644 --- a/tests/integration/reconcile/test_oracle_reconcile.py +++ b/tests/integration/reconcile/test_oracle_reconcile.py @@ -17,7 +17,7 @@ class DatabricksDataSourceUnderTest(DatabricksDataSource): def __init__(self, databricks, ws, local_spark): - super().__init__(get_dialect("databricks"), databricks, ws, "not used") + super().__init__(get_dialect("databricks"), databricks, ws) self._local_spark = local_spark def read_data( diff --git a/tests/unit/reconcile/connectors/test_databricks.py b/tests/unit/reconcile/connectors/test_databricks.py index 7f89612e85..2f69dbd317 100644 --- a/tests/unit/reconcile/connectors/test_databricks.py +++ b/tests/unit/reconcile/connectors/test_databricks.py @@ -3,7 +3,7 @@ import pytest -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.connectors.databricks import DatabricksDataSource from databricks.labs.lakebridge.reconcile.exception import 
DataSourceRuntimeException @@ -23,10 +23,10 @@ def initial_setup(): def test_get_schema(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # catalog as catalog - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) ddds.get_schema("catalog", "schema", "supplier") spark.sql.assert_called_with( re.sub( @@ -56,10 +56,10 @@ def test_get_schema(): def test_read_data_from_uc(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # create object for DatabricksDataSource - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) # Test with query ddds.read_data("org", "data", "employee", "select id as id, name as name from :tbl", None) @@ -72,10 +72,10 @@ def test_read_data_from_uc(): def test_read_data_from_hive(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # create object for DatabricksDataSource - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) # Test with query ddds.read_data("hive_metastore", "data", "employee", "select id as id, name as name from :tbl", None) @@ -88,10 +88,10 @@ def test_read_data_from_hive(): def test_read_data_exception_handling(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # create object for DatabricksDataSource - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) spark.sql.side_effect = RuntimeError("Test Exception") with pytest.raises( @@ -104,10 +104,10 @@ def test_read_data_exception_handling(): def test_get_schema_exception_handling(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # create object for DatabricksDataSource - ddds = DatabricksDataSource(engine, spark, ws, scope) + ddds = DatabricksDataSource(engine, spark, ws) spark.sql.side_effect = RuntimeError("Test Exception") with pytest.raises(DataSourceRuntimeException) as exception: ddds.get_schema("org", "data", "employee") @@ -121,8 +121,8 @@ def test_get_schema_exception_handling(): def test_normalize_identifier(): - engine, spark, ws, scope = initial_setup() - data_source = DatabricksDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = DatabricksDataSource(engine, spark, ws) assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", '`a`') assert data_source.normalize_identifier('`b`') == NormalizedIdentifier("`b`", '`b`') diff --git a/tests/unit/reconcile/connectors/test_oracle.py b/tests/unit/reconcile/connectors/test_oracle.py index 086b48f19c..11b917c3e1 100644 --- a/tests/unit/reconcile/connectors/test_oracle.py +++ b/tests/unit/reconcile/connectors/test_oracle.py @@ -4,7 +4,8 @@ import pytest -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ReconcileCredentialConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.connectors.oracle import OracleDataSource from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException @@ -31,6 +32,16 @@ def mock_secret(scope, key): return secret_mock[scope][key] 
+def oracle_creds(scope): + return { + "host": f"{scope}/host", + "port": f"{scope}/port", + "database": f"{scope}/database", + "user": f"{scope}/user", + "password": f"{scope}/password", + } + + def initial_setup(): pyspark_sql_session = MagicMock() spark = pyspark_sql_session.SparkSession.builder.getOrCreate() @@ -47,8 +58,9 @@ def test_read_data_with_options(): # initial setup engine, spark, ws, scope = initial_setup() - # create object for SnowflakeDataSource - ords = OracleDataSource(engine, spark, ws, scope) + # create object for OracleDataSource + ords = OracleDataSource(engine, spark, ws) + ords.load_credentials(ReconcileCredentialConfig("databricks", oracle_creds(scope))) # Create a Tables configuration object with JDBC reader options table_conf = Table( source_name="supplier", @@ -96,10 +108,10 @@ def test_read_data_with_options(): def test_get_schema(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() - # create object for SnowflakeDataSource - ords = OracleDataSource(engine, spark, ws, scope) + # create object for OracleDataSource + ords = OracleDataSource(engine, spark, ws) # call test method ords.get_schema(None, "data", "employee") # spark assertions @@ -127,8 +139,8 @@ def test_get_schema(): def test_read_data_exception_handling(): # initial setup - engine, spark, ws, scope = initial_setup() - ords = OracleDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + ords = OracleDataSource(engine, spark, ws) # Create a Tables configuration object table_conf = Table( source_name="supplier", @@ -157,8 +169,8 @@ def test_read_data_exception_handling(): def test_get_schema_exception_handling(): # initial setup - engine, spark, ws, scope = initial_setup() - ords = OracleDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + ords = OracleDataSource(engine, spark, ws) spark.read.format().option().option().option().option().option().load.side_effect = RuntimeError("Test Exception") @@ -184,8 +196,8 @@ def test_get_schema_exception_handling(): @pytest.mark.skip("Turned off till we can handle case sensitivity.") def test_normalize_identifier(): - engine, spark, ws, scope = initial_setup() - data_source = OracleDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = OracleDataSource(engine, spark, ws) assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", '"a"') assert data_source.normalize_identifier('"b"') == NormalizedIdentifier("`b`", '"b"') diff --git a/tests/unit/reconcile/connectors/test_secrets.py b/tests/unit/reconcile/connectors/test_secrets.py deleted file mode 100644 index dea7515b09..0000000000 --- a/tests/unit/reconcile/connectors/test_secrets.py +++ /dev/null @@ -1,65 +0,0 @@ -import base64 -from unittest.mock import create_autospec - -import pytest - -from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin -from databricks.sdk import WorkspaceClient -from databricks.sdk.errors import NotFound -from databricks.sdk.service.workspace import GetSecretResponse - - -class SecretsMixinUnderTest(SecretsMixin): - def __init__(self, ws: WorkspaceClient, secret_scope: str): - self._ws = ws - self._secret_scope = secret_scope - - def get_secret(self, secret_key: str) -> str: - return self._get_secret(secret_key) - - def get_secret_or_none(self, secret_key: str) -> str | None: - return self._get_secret_or_none(secret_key) - - -def mock_secret(scope, key): - secret_mock = { - "scope": { - 'user_name': 
GetSecretResponse( - key='user_name', value=base64.b64encode(bytes('my_user', 'utf-8')).decode('utf-8') - ), - 'password': GetSecretResponse( - key='password', value=base64.b64encode(bytes('my_password', 'utf-8')).decode('utf-8') - ), - } - } - - return secret_mock.get(scope).get(key) - - -def test_get_secrets_happy(): - ws = create_autospec(WorkspaceClient) - ws.secrets.get_secret.side_effect = mock_secret - - sut = SecretsMixinUnderTest(ws, "scope") - - assert sut.get_secret("user_name") == "my_user" - assert sut.get_secret_or_none("user_name") == "my_user" - assert sut.get_secret("password") == "my_password" - assert sut.get_secret_or_none("password") == "my_password" - - -def test_get_secrets_not_found_exception(): - ws = create_autospec(WorkspaceClient) - ws.secrets.get_secret.side_effect = NotFound("Test Exception") - sut = SecretsMixinUnderTest(ws, "scope") - - with pytest.raises(NotFound, match="Secret does not exist with scope: scope and key: unknown : Test Exception"): - sut.get_secret("unknown") - - -def test_get_secrets_not_found_swallow(): - ws = create_autospec(WorkspaceClient) - ws.secrets.get_secret.side_effect = NotFound("Test Exception") - sut = SecretsMixinUnderTest(ws, "scope") - - assert sut.get_secret_or_none("unknown") is None diff --git a/tests/unit/reconcile/connectors/test_snowflake.py b/tests/unit/reconcile/connectors/test_snowflake.py index 114aa42f2a..566f58dd45 100644 --- a/tests/unit/reconcile/connectors/test_snowflake.py +++ b/tests/unit/reconcile/connectors/test_snowflake.py @@ -6,7 +6,8 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import serialization -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ReconcileCredentialConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.connectors.snowflake import SnowflakeDataSource from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException, InvalidSnowflakePemPrivateKey @@ -19,9 +20,6 @@ def mock_secret(scope, key): secret_mock = { "scope": { - 'sfAccount': GetSecretResponse( - key='sfAccount', value=base64.b64encode(bytes('my_account', 'utf-8')).decode('utf-8') - ), 'sfUser': GetSecretResponse( key='sfUser', value=base64.b64encode(bytes('my_user', 'utf-8')).decode('utf-8') ), @@ -40,13 +38,39 @@ def mock_secret(scope, key): 'sfRole': GetSecretResponse( key='sfRole', value=base64.b64encode(bytes('my_role', 'utf-8')).decode('utf-8') ), - 'sfUrl': GetSecretResponse(key='sfUrl', value=base64.b64encode(bytes('my_url', 'utf-8')).decode('utf-8')), + 'sfUrl': GetSecretResponse( + key='sfUrl', value=base64.b64encode(bytes('my_account.snowflakecomputing.com', 'utf-8')).decode('utf-8') + ), } } return secret_mock[scope][key] +@pytest.fixture() +def snowflake_creds(): + def _snowflake_creds(scope, use_private_key=False, use_pem_password=False): + creds = { + 'sfUser': f'{scope}/sfUser', + 'sfDatabase': f'{scope}/sfDatabase', + 'sfSchema': f'{scope}/sfSchema', + 'sfWarehouse': f'{scope}/sfWarehouse', + 'sfRole': f'{scope}/sfRole', + 'sfUrl': f'{scope}/sfUrl', + } + + if use_private_key: + creds['pem_private_key'] = f'{scope}/pem_private_key' + if use_pem_password: + creds['pem_private_key_password'] = f'{scope}/pem_private_key_password' + else: + creds['sfPassword'] = f'{scope}/sfPassword' + + 
return creds + + return _snowflake_creds + + def generate_pkcs8_pem_key(malformed=False): private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) pem_key = private_key.private_bytes( @@ -91,11 +115,12 @@ def initial_setup(): return engine, spark, ws, scope -def test_get_jdbc_url_happy(): +def test_get_jdbc_url_happy(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialConfig("databricks", snowflake_creds(scope))) url = dfds.get_jdbc_url # Assert that the URL is generated correctly assert url == ( @@ -106,28 +131,13 @@ def test_get_jdbc_url_happy(): ) -def test_get_jdbc_url_fail(): - # initial setup - engine, spark, ws, scope = initial_setup() - ws.secrets.get_secret.side_effect = mock_secret - # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) - url = dfds.get_jdbc_url - # Assert that the URL is generated correctly - assert url == ( - "jdbc:snowflake://my_account.snowflakecomputing.com" - "/?user=my_user&password=my_password" - "&db=my_database&schema=my_schema" - "&warehouse=my_warehouse&role=my_role" - ) - - -def test_read_data_with_out_options(): +def test_read_data_with_out_options(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialConfig("databricks", snowflake_creds(scope))) # Create a Tables configuration object with no JDBC reader options table_conf = Table( source_name="supplier", @@ -141,7 +151,7 @@ def test_read_data_with_out_options(): spark.read.format.assert_called_with("snowflake") spark.read.format().option.assert_called_with("dbtable", "(select 1 from org.data.employee) as tmp") spark.read.format().option().options.assert_called_with( - sfUrl="my_url", + sfUrl="my_account.snowflakecomputing.com", sfUser="my_user", sfPassword="my_password", sfDatabase="my_database", @@ -152,12 +162,13 @@ def test_read_data_with_out_options(): spark.read.format().option().options().load.assert_called_once() -def test_read_data_with_options(): +def test_read_data_with_options(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialConfig("databricks", snowflake_creds(scope))) # Create a Tables configuration object with JDBC reader options table_conf = Table( source_name="supplier", @@ -192,12 +203,13 @@ def test_read_data_with_options(): spark.read.format().option().option().option().options().load.assert_called_once() -def test_get_schema(): +def test_get_schema(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() # Mocking get secret method to return the required values # create object for SnowflakeDataSource - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialConfig("databricks", snowflake_creds(scope))) # call test method dfds.get_schema("catalog", "schema", "supplier") # spark assertions @@ -215,7 +227,7 @@ def test_get_schema(): ), ) spark.read.format().option().options.assert_called_with( 
- sfUrl="my_url", + sfUrl="my_account.snowflakecomputing.com", sfUser="my_user", sfPassword="my_password", sfDatabase="my_database", @@ -226,10 +238,11 @@ def test_get_schema(): spark.read.format().option().options().load.assert_called_once() -def test_read_data_exception_handling(): +def test_read_data_exception_handling(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialConfig("databricks", snowflake_creds(scope))) # Create a Tables configuration object table_conf = Table( source_name="supplier", @@ -254,11 +267,12 @@ def test_read_data_exception_handling(): dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) -def test_get_schema_exception_handling(): +def test_get_schema_exception_handling(snowflake_creds): # initial setup engine, spark, ws, scope = initial_setup() - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialConfig("databricks", snowflake_creds(scope))) spark.read.format().option().options().load.side_effect = RuntimeError("Test Exception") @@ -276,16 +290,17 @@ def test_get_schema_exception_handling(): dfds.get_schema("catalog", "schema", "supplier") -def test_read_data_without_options_private_key(): +def test_read_data_without_options_private_key(snowflake_creds): engine, spark, ws, scope = initial_setup() ws.secrets.get_secret.side_effect = mock_private_key_secret - dfds = SnowflakeDataSource(engine, spark, ws, scope) + dfds = SnowflakeDataSource(engine, spark, ws) + dfds.load_credentials(ReconcileCredentialConfig("databricks", snowflake_creds(scope, use_private_key=True))) table_conf = Table(source_name="supplier", target_name="supplier") dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) spark.read.format.assert_called_with("snowflake") spark.read.format().option.assert_called_with("dbtable", "(select 1 from org.data.employee) as tmp") expected_options = { - "sfUrl": "my_url", + "sfUrl": "my_account.snowflakecomputing.com", "sfUser": "my_user", "sfDatabase": "my_database", "sfSchema": "my_schema", @@ -298,30 +313,30 @@ def test_read_data_without_options_private_key(): spark.read.format().option().options().load.assert_called_once() -def test_read_data_without_options_malformed_private_key(): +def test_read_data_without_options_malformed_private_key(snowflake_creds): engine, spark, ws, scope = initial_setup() ws.secrets.get_secret.side_effect = mock_malformed_private_key_secret - dfds = SnowflakeDataSource(engine, spark, ws, scope) - table_conf = Table(source_name="supplier", target_name="supplier") + dfds = SnowflakeDataSource(engine, spark, ws) + with pytest.raises(InvalidSnowflakePemPrivateKey, match="Failed to load or process the provided PEM private key."): - dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + dfds.load_credentials(ReconcileCredentialConfig("databricks", snowflake_creds(scope, use_private_key=True))) -def test_read_data_without_any_auth(): +def test_read_data_without_any_auth(snowflake_creds): engine, spark, ws, scope = initial_setup() ws.secrets.get_secret.side_effect = mock_no_auth_key_secret - dfds = SnowflakeDataSource(engine, spark, ws, scope) - table_conf = Table(source_name="supplier", target_name="supplier") - with pytest.raises( - NotFound, 
match='sfPassword and pem_private_key not found. Either one is required for snowflake auth.' - ): - dfds.read_data("org", "data", "employee", "select 1 from :tbl", table_conf.jdbc_reader_options) + dfds = SnowflakeDataSource(engine, spark, ws) + creds = snowflake_creds(scope) + creds.pop('sfPassword') + + with pytest.raises(AssertionError, match='Missing Snowflake credentials. Please configure any of .*'): + dfds.load_credentials(ReconcileCredentialConfig("databricks", creds)) @pytest.mark.skip("Turned off till we can handle case sensitivity.") def test_normalize_identifier(): - engine, spark, ws, scope = initial_setup() - data_source = SnowflakeDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = SnowflakeDataSource(engine, spark, ws) assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", '"a"') assert data_source.normalize_identifier('"b"') == NormalizedIdentifier("`b`", '"b"') diff --git a/tests/unit/reconcile/connectors/test_sql_server.py b/tests/unit/reconcile/connectors/test_sql_server.py index fa6ad90415..32c81d3bd7 100644 --- a/tests/unit/reconcile/connectors/test_sql_server.py +++ b/tests/unit/reconcile/connectors/test_sql_server.py @@ -4,7 +4,8 @@ import pytest -from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier +from databricks.labs.lakebridge.config import ReconcileCredentialConfig +from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import NormalizedIdentifier from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect from databricks.labs.lakebridge.reconcile.connectors.tsql import TSQLServerDataSource from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException @@ -35,6 +36,18 @@ def mock_secret(scope, key): return scope_secret_mock[scope][key] +def mssql_creds(scope): + return { + "host": f"{scope}/host", + "port": f"{scope}/port", + "database": f"{scope}/database", + "user": f"{scope}/user", + "password": f"{scope}/password", + "encrypt": f"{scope}/encrypt", + "trustServerCertificate": f"{scope}/trustServerCertificate", + } + + def initial_setup(): pyspark_sql_session = MagicMock() spark = pyspark_sql_session.SparkSession.builder.getOrCreate() @@ -51,20 +64,8 @@ def test_get_jdbc_url_happy(): # initial setup engine, spark, ws, scope = initial_setup() # create object for TSQLServerDataSource - data_source = TSQLServerDataSource(engine, spark, ws, scope) - url = data_source.get_jdbc_url - # Assert that the URL is generated correctly - assert url == ( - """jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;""" - ) - - -def test_get_jdbc_url_fail(): - # initial setup - engine, spark, ws, scope = initial_setup() - ws.secrets.get_secret.side_effect = mock_secret - # create object for TSQLServerDataSource - data_source = TSQLServerDataSource(engine, spark, ws, scope) + data_source = TSQLServerDataSource(engine, spark, ws) + data_source.load_credentials(ReconcileCredentialConfig("databricks", mssql_creds(scope))) url = data_source.get_jdbc_url # Assert that the URL is generated correctly assert url == ( @@ -77,7 +78,8 @@ def test_read_data_with_options(): engine, spark, ws, scope = initial_setup() # create object for MSSQLServerDataSource - data_source = TSQLServerDataSource(engine, spark, ws, scope) + data_source = TSQLServerDataSource(engine, spark, ws) + data_source.load_credentials(ReconcileCredentialConfig("databricks", 
mssql_creds(scope))) # Create a Tables configuration object with JDBC reader options table_conf = Table( source_name="src_supplier", @@ -116,9 +118,9 @@ def test_read_data_with_options(): def test_get_schema(): # initial setup - engine, spark, ws, scope = initial_setup() + engine, spark, ws, _ = initial_setup() # Mocking get secret method to return the required values - data_source = TSQLServerDataSource(engine, spark, ws, scope) + data_source = TSQLServerDataSource(engine, spark, ws) # call test method data_source.get_schema("org", "schema", "supplier") # spark assertions @@ -163,8 +165,8 @@ def test_get_schema(): def test_get_schema_exception_handling(): # initial setup - engine, spark, ws, scope = initial_setup() - data_source = TSQLServerDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = TSQLServerDataSource(engine, spark, ws) spark.read.format().option().option().option().option().load.side_effect = RuntimeError("Test Exception") @@ -180,8 +182,8 @@ def test_get_schema_exception_handling(): def test_normalize_identifier(): - engine, spark, ws, scope = initial_setup() - data_source = TSQLServerDataSource(engine, spark, ws, scope) + engine, spark, ws, _ = initial_setup() + data_source = TSQLServerDataSource(engine, spark, ws) assert data_source.normalize_identifier("a") == NormalizedIdentifier("`a`", "[a]") assert data_source.normalize_identifier('"b"') == NormalizedIdentifier("`b`", "[b]") diff --git a/tests/unit/reconcile/test_source_adapter.py b/tests/unit/reconcile/test_source_adapter.py index 5a9cc4032d..68b093e2da 100644 --- a/tests/unit/reconcile/test_source_adapter.py +++ b/tests/unit/reconcile/test_source_adapter.py @@ -15,10 +15,9 @@ def test_create_adapter_for_snowflake_dialect(): spark = create_autospec(DatabricksSession) engine = get_dialect("snowflake") ws = create_autospec(WorkspaceClient) - scope = "scope" - data_source = create_adapter(engine, spark, ws, scope) - snowflake_data_source = SnowflakeDataSource(engine, spark, ws, scope).__class__ + data_source = create_adapter(engine, spark, ws) + snowflake_data_source = SnowflakeDataSource(engine, spark, ws).__class__ assert isinstance(data_source, snowflake_data_source) @@ -27,10 +26,9 @@ def test_create_adapter_for_oracle_dialect(): spark = create_autospec(DatabricksSession) engine = get_dialect("oracle") ws = create_autospec(WorkspaceClient) - scope = "scope" - data_source = create_adapter(engine, spark, ws, scope) - oracle_data_source = OracleDataSource(engine, spark, ws, scope).__class__ + data_source = create_adapter(engine, spark, ws) + oracle_data_source = OracleDataSource(engine, spark, ws).__class__ assert isinstance(data_source, oracle_data_source) @@ -39,10 +37,9 @@ def test_create_adapter_for_databricks_dialect(): spark = create_autospec(DatabricksSession) engine = get_dialect("databricks") ws = create_autospec(WorkspaceClient) - scope = "scope" - data_source = create_adapter(engine, spark, ws, scope) - databricks_data_source = DatabricksDataSource(engine, spark, ws, scope).__class__ + data_source = create_adapter(engine, spark, ws) + databricks_data_source = DatabricksDataSource(engine, spark, ws).__class__ assert isinstance(data_source, databricks_data_source) @@ -51,7 +48,6 @@ def test_raise_exception_for_unknown_dialect(): spark = create_autospec(DatabricksSession) engine = get_dialect("trino") ws = create_autospec(WorkspaceClient) - scope = "scope" with pytest.raises(ValueError, match=f"Unsupported source type --> {engine}"): - create_adapter(engine, spark, ws, 
scope) + create_adapter(engine, spark, ws) From 74a1f8846f73526db918b64243fb2d9338707ce9 Mon Sep 17 00:00:00 2001 From: M Abulazm Date: Mon, 24 Nov 2025 10:09:25 +0100 Subject: [PATCH 2/6] add check before using creds --- .../connections/credential_manager.py | 17 +++++++++++++++++ .../reconcile/connectors/data_source.py | 17 ----------------- .../lakebridge/reconcile/connectors/oracle.py | 14 ++++++++++---- .../reconcile/connectors/snowflake.py | 16 ++++++++++------ .../lakebridge/reconcile/connectors/tsql.py | 14 ++++++++++---- 5 files changed, 47 insertions(+), 31 deletions(-) diff --git a/src/databricks/labs/lakebridge/connections/credential_manager.py b/src/databricks/labs/lakebridge/connections/credential_manager.py index 204c85bef1..7cb5e6140f 100644 --- a/src/databricks/labs/lakebridge/connections/credential_manager.py +++ b/src/databricks/labs/lakebridge/connections/credential_manager.py @@ -107,6 +107,23 @@ def _get_secret_value(self, key: str) -> str: return self._provider.get_secret(key) +def build_credentials(vault_type: str, source: str, credentials: dict) -> dict: + """Build credentials dictionary with secret vault type included. + + Args: + vault_type: The type of secret vault (e.g., 'local', 'databricks'). + source: The source system name. + credentials: The original credentials dictionary. + + Returns: + A new credentials dictionary including the secret vault type. + """ + return { + source: credentials, + 'secret_vault_type': vault_type.lower(), + } + + def _get_home() -> Path: return Path(__file__).home() diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py index 0b3c6e6388..3d548d2722 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py @@ -11,23 +11,6 @@ logger = logging.getLogger(__name__) -def build_credentials(vault_type: str, source: str, credentials: dict) -> dict: - """Build credentials dictionary with secret vault type included. - - Args: - vault_type: The type of secret vault (e.g., 'local', 'databricks'). - source: The source system name. - credentials: The original credentials dictionary. - - Returns: - A new credentials dictionary including the secret vault type. 
- """ - return { - source: credentials, - 'secret_vault_type': vault_type.lower(), - } - - class DataSource(ABC): @abstractmethod diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py index 26c042c9f3..b8ba1b1c79 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py @@ -8,8 +8,8 @@ from sqlglot import Dialect from databricks.labs.lakebridge.config import ReconcileCredentialConfig -from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager -from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, build_credentials +from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager, build_credentials +from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema @@ -38,7 +38,13 @@ def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): self._engine = engine self._spark = spark self._ws = ws - self._creds: dict[str, str] = {} + self._creds_or_empty: dict[str, str] = {} + + @property + def _creds(self): + if self._creds_or_empty: + return self._creds_or_empty + raise ValueError("Oracle credentials have not been loaded. Please call load_credentials() first.") @property def get_jdbc_url(self) -> str: @@ -127,7 +133,7 @@ def load_credentials(self, creds: ReconcileCredentialConfig) -> "OracleDataSourc else: parsed_creds = build_credentials(creds.vault_type, "oracle", creds.source_creds) - self._creds = create_credential_manager(parsed_creds, self._ws).get_credentials("oracle") + self._creds_or_empty = create_credential_manager(parsed_creds, self._ws).get_credentials("oracle") assert all( self._creds.get(k) for k in connector_creds ), f"Missing mandatory Oracle credentials. Please configure all of {connector_creds}." 
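
The oracle.py hunk above replaces the eagerly populated `self._creds` dictionary with a `_creds_or_empty` field behind a guarding `_creds` property, so any code path that needs credentials before `load_credentials()` has run fails with an explicit message instead of a missing-key lookup. Below is a minimal, self-contained sketch of that guard pattern; the class name and message are illustrative, not repository code, and later commits in this series adjust the exception type.

class _ExampleConnector:
    """Illustrative only: the lazy-credential guard this series adds to each connector."""

    def __init__(self) -> None:
        self._creds_or_empty: dict[str, str] = {}

    @property
    def _creds(self) -> dict[str, str]:
        # Fail fast when credentials are accessed before they are loaded.
        if self._creds_or_empty:
            return self._creds_or_empty
        raise ValueError("Credentials have not been loaded. Please call load_credentials() first.")

    def load_credentials(self, creds: dict[str, str]) -> "_ExampleConnector":
        self._creds_or_empty = dict(creds)
        return self
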
diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py index 173e834043..1427a872ee 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py @@ -10,10 +10,8 @@ from cryptography.hazmat.primitives import serialization from databricks.labs.lakebridge.config import ReconcileCredentialConfig -from databricks.labs.lakebridge.connections.credential_manager import ( - create_credential_manager, -) -from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, build_credentials +from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager, build_credentials +from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.exception import InvalidSnowflakePemPrivateKey @@ -56,7 +54,13 @@ def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): self._engine = engine self._spark = spark self._ws = ws - self._creds: dict[str, str] = {} + self._creds_or_empty: dict[str, str] = {} + + @property + def _creds(self): + if self._creds_or_empty: + return self._creds_or_empty + raise ValueError("Snowflake credentials have not been loaded. Please call load_credentials() first.") def load_credentials(self, creds: ReconcileCredentialConfig) -> "SnowflakeDataSource": connector_creds = [ @@ -79,7 +83,7 @@ def load_credentials(self, creds: ReconcileCredentialConfig) -> "SnowflakeDataSo else: parsed_creds = build_credentials(creds.vault_type, "snowflake", creds.source_creds) - self._creds = create_credential_manager(parsed_creds, self._ws).get_credentials("snowflake") + self._creds_or_empty = create_credential_manager(parsed_creds, self._ws).get_credentials("snowflake") assert all( self._creds.get(k) for k in connector_creds ), f"Missing mandatory Snowflake credentials. Please configure all of {connector_creds}." 
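
For reference, `build_credentials` (relocated to credential_manager.py earlier in this patch) only nests the per-source credentials under the source name and records a lower-cased vault type. A small worked example of the shape it produces, using a made-up secret reference:

from databricks.labs.lakebridge.connections.credential_manager import build_credentials

# "my_scope/sfPassword" is a hypothetical Databricks secret reference, for illustration only.
parsed = build_credentials("Databricks", "snowflake", {"sfPassword": "my_scope/sfPassword"})
assert parsed == {
    "snowflake": {"sfPassword": "my_scope/sfPassword"},
    "secret_vault_type": "databricks",  # vault type is lower-cased by build_credentials
}
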
diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py index c0ddf524db..8589196557 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py @@ -8,8 +8,8 @@ from sqlglot import Dialect from databricks.labs.lakebridge.config import ReconcileCredentialConfig -from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager -from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, build_credentials +from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager, build_credentials +from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema @@ -62,7 +62,13 @@ def __init__( self._engine = engine self._spark = spark self._ws = ws - self._creds: dict[str, str] = {} + self._creds_or_empty: dict[str, str] = {} + + @property + def _creds(self): + if self._creds_or_empty: + return self._creds_or_empty + raise ValueError("MS SQL/Synapse credentials have not been loaded. Please call load_credentials() first.") @property def get_jdbc_url(self) -> str: @@ -123,7 +129,7 @@ def load_credentials(self, creds: ReconcileCredentialConfig) -> "TSQLServerDataS else: parsed_creds = build_credentials(creds.vault_type, "mssql", creds.source_creds) - self._creds = create_credential_manager(parsed_creds, self._ws).get_credentials("mssql") + self._creds_or_empty = create_credential_manager(parsed_creds, self._ws).get_credentials("mssql") assert all( self._creds.get(k) for k in connector_creds ), f"Missing mandatory MS SQL credentials. Please configure all of {connector_creds}." 
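
Taken together, the first two patches change the caller contract: connectors are constructed without a secret scope, and credentials are supplied explicitly through `load_credentials()`. The sketch below mirrors the unit-test setup; the mocks and the `scope/key` secret references are placeholders, and resolving them for real needs a workspace with that secret scope (the tests mock `ws.secrets.get_secret` instead).

from unittest.mock import MagicMock, create_autospec

from databricks.sdk import WorkspaceClient

from databricks.labs.lakebridge.config import ReconcileCredentialConfig
from databricks.labs.lakebridge.reconcile.connectors.tsql import TSQLServerDataSource

engine = MagicMock()  # stands in for the sqlglot Dialect used in the tests
spark = MagicMock().SparkSession.builder.getOrCreate()
ws = create_autospec(WorkspaceClient)

creds = {key: f"my_scope/{key}" for key in (
    "host", "port", "database", "user", "password", "encrypt", "trustServerCertificate",
)}

data_source = TSQLServerDataSource(engine, spark, ws)
data_source.load_credentials(ReconcileCredentialConfig("databricks", creds))
url = data_source.get_jdbc_url  # secrets are resolved via the credential manager
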
From 36dbc29e1a7fd80d7fe3f0404126cb2cbcd3a19a Mon Sep 17 00:00:00 2001 From: M Abulazm Date: Mon, 24 Nov 2025 10:50:31 +0100 Subject: [PATCH 3/6] add check before using creds and unit tests --- .../lakebridge/reconcile/connectors/oracle.py | 5 ++++- .../reconcile/connectors/snowflake.py | 6 +++-- .../lakebridge/reconcile/connectors/tsql.py | 5 ++++- .../unit/reconcile/connectors/test_oracle.py | 17 +++++++++++++- .../reconcile/connectors/test_snowflake.py | 13 +++++++++++ .../reconcile/connectors/test_sql_server.py | 22 ++++++++++++++----- 6 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py index b8ba1b1c79..851a1c9ea9 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py @@ -12,6 +12,7 @@ from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier +from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -44,7 +45,9 @@ def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): def _creds(self): if self._creds_or_empty: return self._creds_or_empty - raise ValueError("Oracle credentials have not been loaded. Please call load_credentials() first.") + raise DataSourceRuntimeException( + "Oracle credentials have not been loaded. Please call load_credentials() first." + ) @property def get_jdbc_url(self) -> str: diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py index 1427a872ee..76825f40ac 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py @@ -14,7 +14,7 @@ from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.exception import InvalidSnowflakePemPrivateKey +from databricks.labs.lakebridge.reconcile.exception import InvalidSnowflakePemPrivateKey, DataSourceRuntimeException from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -60,7 +60,9 @@ def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): def _creds(self): if self._creds_or_empty: return self._creds_or_empty - raise ValueError("Snowflake credentials have not been loaded. Please call load_credentials() first.") + raise DataSourceRuntimeException( + "Snowflake credentials have not been loaded. Please call load_credentials() first." 
+ ) def load_credentials(self, creds: ReconcileCredentialConfig) -> "SnowflakeDataSource": connector_creds = [ diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py index 8589196557..0b1af4d984 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py @@ -12,6 +12,7 @@ from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier +from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -68,7 +69,9 @@ def __init__( def _creds(self): if self._creds_or_empty: return self._creds_or_empty - raise ValueError("MS SQL/Synapse credentials have not been loaded. Please call load_credentials() first.") + raise DataSourceRuntimeException( + "MS SQL/Synapse credentials have not been loaded. Please call load_credentials() first." + ) @property def get_jdbc_url(self) -> str: diff --git a/tests/unit/reconcile/connectors/test_oracle.py b/tests/unit/reconcile/connectors/test_oracle.py index 11b917c3e1..74a60f677d 100644 --- a/tests/unit/reconcile/connectors/test_oracle.py +++ b/tests/unit/reconcile/connectors/test_oracle.py @@ -112,6 +112,7 @@ def test_get_schema(): # create object for OracleDataSource ords = OracleDataSource(engine, spark, ws) + ords.load_credentials(ReconcileCredentialConfig("databricks", oracle_creds("scope"))) # call test method ords.get_schema(None, "data", "employee") # spark assertions @@ -141,6 +142,7 @@ def test_read_data_exception_handling(): # initial setup engine, spark, ws, _ = initial_setup() ords = OracleDataSource(engine, spark, ws) + ords.load_credentials(ReconcileCredentialConfig("databricks", oracle_creds("scope"))) # Create a Tables configuration object table_conf = Table( source_name="supplier", @@ -171,7 +173,7 @@ def test_get_schema_exception_handling(): # initial setup engine, spark, ws, _ = initial_setup() ords = OracleDataSource(engine, spark, ws) - + ords.load_credentials(ReconcileCredentialConfig("databricks", oracle_creds("scope"))) spark.read.format().option().option().option().option().option().load.side_effect = RuntimeError("Test Exception") # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException @@ -194,6 +196,19 @@ def test_get_schema_exception_handling(): ords.get_schema(None, "data", "employee") +def test_credentials_not_loaded_fails(): + engine, spark, ws, _ = initial_setup() + data_source = OracleDataSource(engine, spark, ws) + + # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException + # is raised + with pytest.raises( + DataSourceRuntimeException, + match=re.escape("Oracle credentials have not been loaded. 
Please call load_credentials() first."), + ): + data_source.get_schema("org", "schema", "supplier") + + @pytest.mark.skip("Turned off till we can handle case sensitivity.") def test_normalize_identifier(): engine, spark, ws, _ = initial_setup() diff --git a/tests/unit/reconcile/connectors/test_snowflake.py b/tests/unit/reconcile/connectors/test_snowflake.py index 566f58dd45..981e0f63c0 100644 --- a/tests/unit/reconcile/connectors/test_snowflake.py +++ b/tests/unit/reconcile/connectors/test_snowflake.py @@ -333,6 +333,19 @@ def test_read_data_without_any_auth(snowflake_creds): dfds.load_credentials(ReconcileCredentialConfig("databricks", creds)) +def test_credentials_not_loaded_fails(): + engine, spark, ws, _ = initial_setup() + data_source = SnowflakeDataSource(engine, spark, ws) + + # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException + # is raised + with pytest.raises( + DataSourceRuntimeException, + match=re.escape("Snowflake credentials have not been loaded. Please call load_credentials() first."), + ): + data_source.get_schema("org", "schema", "supplier") + + @pytest.mark.skip("Turned off till we can handle case sensitivity.") def test_normalize_identifier(): engine, spark, ws, _ = initial_setup() diff --git a/tests/unit/reconcile/connectors/test_sql_server.py b/tests/unit/reconcile/connectors/test_sql_server.py index 32c81d3bd7..347533f82d 100644 --- a/tests/unit/reconcile/connectors/test_sql_server.py +++ b/tests/unit/reconcile/connectors/test_sql_server.py @@ -117,13 +117,12 @@ def test_read_data_with_options(): def test_get_schema(): - # initial setup engine, spark, ws, _ = initial_setup() - # Mocking get secret method to return the required values data_source = TSQLServerDataSource(engine, spark, ws) - # call test method + data_source.load_credentials(ReconcileCredentialConfig("databricks", mssql_creds("scope"))) + data_source.get_schema("org", "schema", "supplier") - # spark assertions + spark.read.format.assert_called_with("jdbc") spark.read.format().option().option().option.assert_called_with( "dbtable", @@ -164,9 +163,9 @@ def test_get_schema(): def test_get_schema_exception_handling(): - # initial setup engine, spark, ws, _ = initial_setup() data_source = TSQLServerDataSource(engine, spark, ws) + data_source.load_credentials(ReconcileCredentialConfig("databricks", mssql_creds("scope"))) spark.read.format().option().option().option().option().load.side_effect = RuntimeError("Test Exception") @@ -181,6 +180,19 @@ def test_get_schema_exception_handling(): data_source.get_schema("org", "schema", "supplier") +def test_credentials_not_loaded_fails(): + engine, spark, ws, _ = initial_setup() + data_source = TSQLServerDataSource(engine, spark, ws) + + # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException + # is raised + with pytest.raises( + DataSourceRuntimeException, + match=re.escape("MS SQL/Synapse credentials have not been loaded. 
Please call load_credentials() first."), + ): + data_source.get_schema("org", "schema", "supplier") + + def test_normalize_identifier(): engine, spark, ws, _ = initial_setup() data_source = TSQLServerDataSource(engine, spark, ws) From 320348042f8f7deba7acbcc740994c8b4d3f4656 Mon Sep 17 00:00:00 2001 From: M Abulazm Date: Mon, 24 Nov 2025 11:16:15 +0100 Subject: [PATCH 4/6] fmt --- .../labs/lakebridge/reconcile/connectors/oracle.py | 5 +---- .../labs/lakebridge/reconcile/connectors/snowflake.py | 6 ++---- src/databricks/labs/lakebridge/reconcile/connectors/tsql.py | 5 +---- tests/unit/reconcile/connectors/test_oracle.py | 2 +- tests/unit/reconcile/connectors/test_snowflake.py | 2 +- tests/unit/reconcile/connectors/test_sql_server.py | 2 +- 6 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py index 851a1c9ea9..24ccbc00c3 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py @@ -12,7 +12,6 @@ from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -45,9 +44,7 @@ def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): def _creds(self): if self._creds_or_empty: return self._creds_or_empty - raise DataSourceRuntimeException( - "Oracle credentials have not been loaded. Please call load_credentials() first." - ) + raise RuntimeError("Oracle credentials have not been loaded. Please call load_credentials() first.") @property def get_jdbc_url(self) -> str: diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py index 76825f40ac..8812ff76d5 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py @@ -14,7 +14,7 @@ from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.exception import InvalidSnowflakePemPrivateKey, DataSourceRuntimeException +from databricks.labs.lakebridge.reconcile.exception import InvalidSnowflakePemPrivateKey from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -60,9 +60,7 @@ def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient): def _creds(self): if self._creds_or_empty: return self._creds_or_empty - raise DataSourceRuntimeException( - "Snowflake credentials have not been loaded. Please call load_credentials() first." - ) + raise RuntimeError("Snowflake credentials have not been loaded. 
Please call load_credentials() first.") def load_credentials(self, creds: ReconcileCredentialConfig) -> "SnowflakeDataSource": connector_creds = [ diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py index 0b1af4d984..5b8381df50 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py @@ -12,7 +12,6 @@ from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier -from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema from databricks.sdk import WorkspaceClient @@ -69,9 +68,7 @@ def __init__( def _creds(self): if self._creds_or_empty: return self._creds_or_empty - raise DataSourceRuntimeException( - "MS SQL/Synapse credentials have not been loaded. Please call load_credentials() first." - ) + raise RuntimeError("MS SQL/Synapse credentials have not been loaded. Please call load_credentials() first.") @property def get_jdbc_url(self) -> str: diff --git a/tests/unit/reconcile/connectors/test_oracle.py b/tests/unit/reconcile/connectors/test_oracle.py index 74a60f677d..f6821d32c3 100644 --- a/tests/unit/reconcile/connectors/test_oracle.py +++ b/tests/unit/reconcile/connectors/test_oracle.py @@ -203,7 +203,7 @@ def test_credentials_not_loaded_fails(): # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException # is raised with pytest.raises( - DataSourceRuntimeException, + RuntimeError, match=re.escape("Oracle credentials have not been loaded. Please call load_credentials() first."), ): data_source.get_schema("org", "schema", "supplier") diff --git a/tests/unit/reconcile/connectors/test_snowflake.py b/tests/unit/reconcile/connectors/test_snowflake.py index 981e0f63c0..1aa88a5dc3 100644 --- a/tests/unit/reconcile/connectors/test_snowflake.py +++ b/tests/unit/reconcile/connectors/test_snowflake.py @@ -340,7 +340,7 @@ def test_credentials_not_loaded_fails(): # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException # is raised with pytest.raises( - DataSourceRuntimeException, + RuntimeError, match=re.escape("Snowflake credentials have not been loaded. Please call load_credentials() first."), ): data_source.get_schema("org", "schema", "supplier") diff --git a/tests/unit/reconcile/connectors/test_sql_server.py b/tests/unit/reconcile/connectors/test_sql_server.py index 347533f82d..0f37afc548 100644 --- a/tests/unit/reconcile/connectors/test_sql_server.py +++ b/tests/unit/reconcile/connectors/test_sql_server.py @@ -187,7 +187,7 @@ def test_credentials_not_loaded_fails(): # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException # is raised with pytest.raises( - DataSourceRuntimeException, + RuntimeError, match=re.escape("MS SQL/Synapse credentials have not been loaded. 
Please call load_credentials() first."), ): data_source.get_schema("org", "schema", "supplier") From ceb42c34e32d100c535b99a5d94dd037962db721 Mon Sep 17 00:00:00 2001 From: M Abulazm Date: Mon, 24 Nov 2025 11:22:38 +0100 Subject: [PATCH 5/6] assert DataSourceRuntimeException instead of RuntimeError --- tests/unit/reconcile/connectors/test_oracle.py | 2 +- tests/unit/reconcile/connectors/test_snowflake.py | 2 +- tests/unit/reconcile/connectors/test_sql_server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/reconcile/connectors/test_oracle.py b/tests/unit/reconcile/connectors/test_oracle.py index f6821d32c3..74a60f677d 100644 --- a/tests/unit/reconcile/connectors/test_oracle.py +++ b/tests/unit/reconcile/connectors/test_oracle.py @@ -203,7 +203,7 @@ def test_credentials_not_loaded_fails(): # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException # is raised with pytest.raises( - RuntimeError, + DataSourceRuntimeException, match=re.escape("Oracle credentials have not been loaded. Please call load_credentials() first."), ): data_source.get_schema("org", "schema", "supplier") diff --git a/tests/unit/reconcile/connectors/test_snowflake.py b/tests/unit/reconcile/connectors/test_snowflake.py index 1aa88a5dc3..981e0f63c0 100644 --- a/tests/unit/reconcile/connectors/test_snowflake.py +++ b/tests/unit/reconcile/connectors/test_snowflake.py @@ -340,7 +340,7 @@ def test_credentials_not_loaded_fails(): # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException # is raised with pytest.raises( - RuntimeError, + DataSourceRuntimeException, match=re.escape("Snowflake credentials have not been loaded. Please call load_credentials() first."), ): data_source.get_schema("org", "schema", "supplier") diff --git a/tests/unit/reconcile/connectors/test_sql_server.py b/tests/unit/reconcile/connectors/test_sql_server.py index 0f37afc548..347533f82d 100644 --- a/tests/unit/reconcile/connectors/test_sql_server.py +++ b/tests/unit/reconcile/connectors/test_sql_server.py @@ -187,7 +187,7 @@ def test_credentials_not_loaded_fails(): # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException # is raised with pytest.raises( - RuntimeError, + DataSourceRuntimeException, match=re.escape("MS SQL/Synapse credentials have not been loaded. 
Please call load_credentials() first."), ): data_source.get_schema("org", "schema", "supplier") From e9fef640b7c7f5108f90a54e106df458e744194e Mon Sep 17 00:00:00 2001 From: M Abulazm Date: Mon, 24 Nov 2025 12:57:49 +0100 Subject: [PATCH 6/6] Add warning logs if using secret scope --- .../labs/lakebridge/reconcile/connectors/data_source.py | 1 + src/databricks/labs/lakebridge/reconcile/connectors/oracle.py | 3 +++ .../labs/lakebridge/reconcile/connectors/snowflake.py | 3 +++ src/databricks/labs/lakebridge/reconcile/connectors/tsql.py | 3 +++ 4 files changed, 10 insertions(+) diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py index 3d548d2722..fe01200e9b 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/data_source.py @@ -12,6 +12,7 @@ class DataSource(ABC): + _DOCS_URL = "https://databrickslabs.github.io/lakebridge/docs/reconcile/" @abstractmethod def read_data( diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py index 24ccbc00c3..de5d8b19bb 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/oracle.py @@ -127,6 +127,9 @@ def load_credentials(self, creds: ReconcileCredentialConfig) -> "OracleDataSourc use_scope = creds.source_creds.get("__secret_scope") if use_scope: source_creds = {key: f"{use_scope}/{key}" for key in connector_creds} + logger.warning( + f"Secret scope configuration is deprecated. Please refer to the docs {self._DOCS_URL} to update." + ) assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'" parsed_creds = build_credentials(creds.vault_type, "oracle", source_creds) diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py index 8812ff76d5..6af85bae3a 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/snowflake.py @@ -75,6 +75,9 @@ def load_credentials(self, creds: ReconcileCredentialConfig) -> "SnowflakeDataSo use_scope = creds.source_creds.get("__secret_scope") if use_scope: # to use pem key and/or pem password, migrate to source_creds approach + logger.warning( + f"Secret scope configuration is deprecated. Using secret scopes supports password authentication only. Please refer to the docs {self._DOCS_URL} to update and to access full features." + ) connector_creds += ["sfPassword"] source_creds = {key: f"{use_scope}/{key}" for key in connector_creds} diff --git a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py index 5b8381df50..bb1608945a 100644 --- a/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py +++ b/src/databricks/labs/lakebridge/reconcile/connectors/tsql.py @@ -122,6 +122,9 @@ def load_credentials(self, creds: ReconcileCredentialConfig) -> "TSQLServerDataS use_scope = creds.source_creds.get("__secret_scope") if use_scope: + logger.warning( + f"Secret scope configuration is deprecated. Please refer to the docs {self._DOCS_URL} to update." + ) source_creds = {key: f"{use_scope}/{key}" for key in connector_creds} assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'"
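
The final patch keeps the legacy `__secret_scope` configuration working but logs a deprecation warning pointing at the reconcile docs and expands the scope into per-key secret references before delegating to `build_credentials()`. A hedged sketch of that expansion follows; the key list and scope name are illustrative, not a connector's actual key set.

connector_creds = ["user", "password", "host", "port", "database"]  # illustrative keys
use_scope = "legacy_scope"  # value read from creds.source_creds["__secret_scope"]

# What the deprecated path builds before handing off to build_credentials():
source_creds = {key: f"{use_scope}/{key}" for key in connector_creds}
assert source_creds["password"] == "legacy_scope/password"

# The scope path only works with the Databricks secret vault.
vault_type = "databricks"
assert vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'"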