Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,23 @@ def _get_secret_value(self, key: str) -> str:
return self._provider.get_secret(key)


def build_credentials(vault_type: str, source: str, credentials: dict) -> dict:
    """Assemble a credentials mapping that also records the secret vault type.

    Args:
        vault_type: Secret vault kind (e.g. 'local', 'databricks'); stored lowercased.
        source: Name of the source system, used as the key for its credentials.
        credentials: The source system's original credentials mapping.

    Returns:
        A new dict of the form ``{source: credentials, 'secret_vault_type': <lowercase vault_type>}``.
    """
    result: dict = {source: credentials}
    result['secret_vault_type'] = vault_type.lower()
    return result


def _get_home() -> Path:
return Path(__file__).home()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@

from pyspark.sql import DataFrame

from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
from databricks.labs.lakebridge.config import ReconcileCredentialConfig
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema

logger = logging.getLogger(__name__)


class DataSource(ABC):
_DOCS_URL = "https://databrickslabs.github.io/lakebridge/docs/reconcile/"

@abstractmethod
def read_data(
Expand All @@ -34,6 +35,10 @@ def get_schema(
) -> list[Schema]:
return NotImplemented

@abstractmethod
def load_credentials(self, creds: ReconcileCredentialConfig) -> "DataSource":
    """Resolve this source's connection credentials from *creds*.

    Implementations return ``self`` so calls can be chained; sources that
    need no external credentials may implement this as a no-op.
    """
    return NotImplemented

@abstractmethod
def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
    """Normalize *identifier* into its ANSI- and source-dialect-quoted forms."""
    pass
Expand Down Expand Up @@ -94,5 +99,8 @@ def get_schema(self, catalog: str | None, schema: str, table: str, normalize: bo
return self.log_and_throw_exception(self._exception, "schema", f"({catalog}, {schema}, {table})")
return mock_schema

def load_credentials(self, creds: ReconcileCredentialConfig) -> "MockDataSource":
    # The mock source holds its data in memory, so there is nothing to
    # load; return self to keep the chained-call contract.
    return self

def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
    # The mock uses one configurable delimiter for both the ANSI and the
    # source-normalized forms.
    return DialectUtils.normalize_identifier(identifier, self._delimiter, self._delimiter)
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
from pyspark.sql.functions import col
from sqlglot import Dialect

from databricks.labs.lakebridge.config import ReconcileCredentialConfig
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
from databricks.sdk import WorkspaceClient

Expand All @@ -36,20 +35,18 @@ def _get_schema_query(catalog: str, schema: str, table: str):
return re.sub(r'\s+', ' ', query)


class DatabricksDataSource(DataSource, SecretsMixin):
class DatabricksDataSource(DataSource):
_IDENTIFIER_DELIMITER = "`"

def __init__(
    self,
    engine: Dialect,
    spark: SparkSession,
    ws: WorkspaceClient,
    secret_scope: str,
):
    """Store collaborators for later use.

    Args:
        engine: sqlglot dialect used for SQL generation.
        spark: Active Spark session used to run queries.
        ws: Databricks workspace client.
        secret_scope: Secret scope name (stored; its use is not visible here).
    """
    self._engine = engine
    self._spark = spark
    self._ws = ws
    self._secret_scope = secret_scope

def read_data(
self,
Expand Down Expand Up @@ -96,6 +93,9 @@ def get_schema(
except (RuntimeError, PySparkException) as e:
return self.log_and_throw_exception(e, "schema", schema_query)

def load_credentials(self, creds: ReconcileCredentialConfig) -> "DatabricksDataSource":
    # No-op: this source has nothing to resolve from *creds*; return self
    # to keep the chained-call contract.
    return self

def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
return DialectUtils.normalize_identifier(
identifier,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
import dataclasses


@dataclasses.dataclass
class NormalizedIdentifier:
    """An identifier rendered in both ANSI and source-dialect quoting."""

    # The identifier normalized with ANSI quoting.
    ansi_normalized: str
    # The identifier normalized with the source dialect's quoting.
    source_normalized: str


class DialectUtils:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
class JDBCReaderMixin:
_spark: SparkSession

# TODO update the url
def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None = None):
driver_class = {
"oracle": "oracle.jdbc.OracleDriver",
Expand Down

This file was deleted.

60 changes: 44 additions & 16 deletions src/databricks/labs/lakebridge/reconcile/connectors/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
from pyspark.sql.functions import col
from sqlglot import Dialect

from databricks.labs.lakebridge.config import ReconcileCredentialConfig
from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager, build_credentials
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
from databricks.sdk import WorkspaceClient

logger = logging.getLogger(__name__)


class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
class OracleDataSource(DataSource, JDBCReaderMixin):
_DRIVER = "oracle"
_IDENTIFIER_DELIMITER = "\""
_SCHEMA_QUERY = """select column_name, case when (data_precision is not null
Expand All @@ -34,23 +34,23 @@ class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
FROM ALL_TAB_COLUMNS
WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}'"""

def __init__(
self,
engine: Dialect,
spark: SparkSession,
ws: WorkspaceClient,
secret_scope: str,
):
def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient):
    """Store collaborators; connection credentials are supplied later
    via :meth:`load_credentials`.

    Args:
        engine: sqlglot dialect used for SQL generation.
        spark: Active Spark session used to run JDBC reads.
        ws: Databricks workspace client (used to resolve secrets).
    """
    self._engine = engine
    self._spark = spark
    self._ws = ws
    # Empty until load_credentials() populates it; the _creds property
    # raises while this is still empty.
    self._creds_or_empty: dict[str, str] = {}

@property
def _creds(self):
    """The loaded Oracle credentials; raises if none have been loaded yet."""
    if not self._creds_or_empty:
        raise RuntimeError("Oracle credentials have not been loaded. Please call load_credentials() first.")
    return self._creds_or_empty

@property
def get_jdbc_url(self) -> str:
    """JDBC thin-driver URL built from the loaded host, port and database."""
    creds = self._creds
    host = creds.get('host')
    port = creds.get('port')
    database = creds.get('database')
    return f"jdbc:{OracleDataSource._DRIVER}:thin:@//{host}:{port}/{database}"

def read_data(
Expand Down Expand Up @@ -108,13 +108,41 @@ def _get_timestamp_options() -> dict[str, str]:
}

def reader(self, query: str) -> DataFrameReader:
    """Build a JDBC DataFrameReader for *query* using the loaded credentials."""
    credentials = self._creds
    user = credentials.get('user')
    password = credentials.get('password')
    logger.debug(f"Using user: {user} to connect to Oracle")
    auth_options = {"user": user, "password": password}
    return self._get_jdbc_reader(query, self.get_jdbc_url, OracleDataSource._DRIVER, auth_options)

def load_credentials(self, creds: ReconcileCredentialConfig) -> "OracleDataSource":
    """Resolve and cache the Oracle connection credentials.

    Args:
        creds: Reconcile credential configuration. When its source
            credentials contain the legacy ``__secret_scope`` key, every
            mandatory key is rewritten as ``<scope>/<key>`` and resolved
            through the Databricks vault (deprecated path).

    Returns:
        This data source, so the call can be chained.

    Raises:
        ValueError: If a secret scope is combined with a non-Databricks
            vault type, or if any mandatory credential is missing after
            resolution.
    """
    connector_creds = [
        "host",
        "port",
        "database",
        "user",
        "password",
    ]

    use_scope = creds.source_creds.get("__secret_scope")
    if use_scope:
        # Legacy configuration: point each mandatory key at the secret scope.
        source_creds = {key: f"{use_scope}/{key}" for key in connector_creds}
        logger.warning(
            f"Secret scope configuration is deprecated. Please refer to the docs {self._DOCS_URL} to update."
        )

        # Validation must survive `python -O` (asserts are stripped there),
        # so raise explicitly instead of using `assert`.
        if creds.vault_type != "databricks":
            raise ValueError("Secret scope provided, vault_type must be 'databricks'")
        parsed_creds = build_credentials(creds.vault_type, "oracle", source_creds)
    else:
        parsed_creds = build_credentials(creds.vault_type, "oracle", creds.source_creds)

    self._creds_or_empty = create_credential_manager(parsed_creds, self._ws).get_credentials("oracle")
    if not all(self._creds.get(k) for k in connector_creds):
        raise ValueError(f"Missing mandatory Oracle credentials. Please configure all of {connector_creds}.")

    return self

def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
normalized = DialectUtils.normalize_identifier(
identifier,
Expand Down
49 changes: 0 additions & 49 deletions src/databricks/labs/lakebridge/reconcile/connectors/secrets.py

This file was deleted.

Loading
Loading