Skip to content

Commit bf928ff

Browse files
committed
add load_credentials interface for data sources and impl
1 parent c284f50 commit bf928ff

File tree

22 files changed

+353
-354
lines changed

22 files changed

+353
-354
lines changed

src/databricks/labs/lakebridge/reconcile/connectors/data_source.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,31 @@
33

44
from pyspark.sql import DataFrame
55

6-
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
7-
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
6+
from databricks.labs.lakebridge.config import ReconcileCredentialConfig
7+
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
88
from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException
99
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
1010

1111
logger = logging.getLogger(__name__)
1212

1313

14+
def build_credentials(vault_type: str, source: str, credentials: dict) -> dict:
    """Wrap raw source credentials together with their secret vault type.

    Args:
        vault_type: The type of secret vault (e.g., 'local', 'databricks').
        source: The source system name.
        credentials: The original credentials dictionary.

    Returns:
        A new credentials dictionary including the secret vault type.
    """
    # Nest the source's credentials under its name, then record the vault
    # type (lower-cased) alongside it.
    wrapped: dict = {source: credentials}
    wrapped['secret_vault_type'] = vault_type.lower()
    return wrapped
29+
30+
1431
class DataSource(ABC):
1532

1633
@abstractmethod
@@ -34,6 +51,10 @@ def get_schema(
3451
) -> list[Schema]:
3552
return NotImplemented
3653

54+
    @abstractmethod
    def load_credentials(self, creds: ReconcileCredentialConfig) -> "DataSource":
        """Load connection credentials into this data source.

        Implementations return ``self`` so calls can be chained; sources that
        need no external credentials return themselves unchanged.
        """
        return NotImplemented
57+
3758
@abstractmethod
3859
def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
3960
pass
@@ -94,5 +115,8 @@ def get_schema(self, catalog: str | None, schema: str, table: str, normalize: bo
94115
return self.log_and_throw_exception(self._exception, "schema", f"({catalog}, {schema}, {table})")
95116
return mock_schema
96117

118+
    def load_credentials(self, creds: ReconcileCredentialConfig) -> "MockDataSource":
        """No-op: the mock source needs no credentials; returns self unchanged."""
        return self
120+
97121
    def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
        # The same delimiter is used for both the ANSI and the source-normalized
        # forms of the identifier.
        return DialectUtils.normalize_identifier(identifier, self._delimiter, self._delimiter)

src/databricks/labs/lakebridge/reconcile/connectors/databricks.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
from pyspark.sql.functions import col
88
from sqlglot import Dialect
99

10+
from databricks.labs.lakebridge.config import ReconcileCredentialConfig
1011
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
11-
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
12-
from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
13-
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
12+
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
1413
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
1514
from databricks.sdk import WorkspaceClient
1615

@@ -36,20 +35,18 @@ def _get_schema_query(catalog: str, schema: str, table: str):
3635
return re.sub(r'\s+', ' ', query)
3736

3837

39-
class DatabricksDataSource(DataSource, SecretsMixin):
38+
class DatabricksDataSource(DataSource):
4039
_IDENTIFIER_DELIMITER = "`"
4140

4241
    def __init__(
        self,
        engine: Dialect,
        spark: SparkSession,
        ws: WorkspaceClient,
    ):
        # No secret scope is taken here: credentials are handled by
        # load_credentials() on the DataSource interface.
        self._engine = engine
        self._spark = spark
        self._ws = ws
5350

5451
def read_data(
5552
self,
@@ -96,6 +93,9 @@ def get_schema(
9693
except (RuntimeError, PySparkException) as e:
9794
return self.log_and_throw_exception(e, "schema", schema_query)
9895

96+
    def load_credentials(self, creds: ReconcileCredentialConfig) -> "DatabricksDataSource":
        """No-op: this source uses the Spark session / workspace client directly."""
        return self
98+
9999
def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
100100
return DialectUtils.normalize_identifier(
101101
identifier,

src/databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
1+
import dataclasses


@dataclasses.dataclass()
class NormalizedIdentifier:
    """An identifier held in two normalized forms.

    Attributes:
        ansi_normalized: the identifier normalized with ANSI delimiters.
        source_normalized: the identifier normalized with the source
            dialect's own delimiters.
    """

    ansi_normalized: str
    source_normalized: str
28

39

410
class DialectUtils:

src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
class JDBCReaderMixin:
99
_spark: SparkSession
1010

11-
# TODO update the url
1211
def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None = None):
1312
driver_class = {
1413
"oracle": "oracle.jdbc.OracleDriver",

src/databricks/labs/lakebridge/reconcile/connectors/models.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

src/databricks/labs/lakebridge/reconcile/connectors/oracle.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77
from pyspark.sql.functions import col
88
from sqlglot import Dialect
99

10-
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
10+
from databricks.labs.lakebridge.config import ReconcileCredentialConfig
11+
from databricks.labs.lakebridge.connections.credential_manager import create_credential_manager
12+
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource, build_credentials
1113
from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin
12-
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
13-
from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
14-
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
14+
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
1515
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
1616
from databricks.sdk import WorkspaceClient
1717

1818
logger = logging.getLogger(__name__)
1919

2020

21-
class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
21+
class OracleDataSource(DataSource, JDBCReaderMixin):
2222
_DRIVER = "oracle"
2323
_IDENTIFIER_DELIMITER = "\""
2424
_SCHEMA_QUERY = """select column_name, case when (data_precision is not null
@@ -34,23 +34,17 @@ class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
3434
FROM ALL_TAB_COLUMNS
3535
WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}'"""
3636

37-
def __init__(
38-
self,
39-
engine: Dialect,
40-
spark: SparkSession,
41-
ws: WorkspaceClient,
42-
secret_scope: str,
43-
):
37+
    def __init__(self, engine: Dialect, spark: SparkSession, ws: WorkspaceClient):
        self._engine = engine
        self._spark = spark
        self._ws = ws
        # Populated by load_credentials(); empty until that is called.
        self._creds: dict[str, str] = {}
4842

4943
    @property
    def get_jdbc_url(self) -> str:
        # Thin-driver URL: jdbc:oracle:thin:@//<host>:<port>/<database>.
        # Assumes load_credentials() has populated self._creds; a missing key
        # would render as the literal string "None" in the URL.
        return (
            f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._creds.get('host')}"
            f":{self._creds.get('port')}/{self._creds.get('database')}"
        )
5549

5650
def read_data(
@@ -108,13 +102,38 @@ def _get_timestamp_options() -> dict[str, str]:
108102
}
109103

110104
    def reader(self, query: str) -> DataFrameReader:
        """Build a JDBC reader for *query* using the loaded credentials."""
        user = self._creds.get('user')
        password = self._creds.get('password')
        # Only the username is logged; the password never reaches the logs.
        logger.debug(f"Using user: {user} to connect to Oracle")
        return self._get_jdbc_reader(
            query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password}
        )
117111

112+
def load_credentials(self, creds: ReconcileCredentialConfig) -> "OracleDataSource":
113+
connector_creds = [
114+
"host",
115+
"port",
116+
"database",
117+
"user",
118+
"password",
119+
]
120+
121+
use_scope = creds.source_creds.get("__secret_scope")
122+
if use_scope:
123+
source_creds = {key: f"{use_scope}/{key}" for key in connector_creds}
124+
125+
assert creds.vault_type == "databricks", "Secret scope provided, vault_type must be 'databricks'"
126+
parsed_creds = build_credentials(creds.vault_type, "oracle", source_creds)
127+
else:
128+
parsed_creds = build_credentials(creds.vault_type, "oracle", creds.source_creds)
129+
130+
self._creds = create_credential_manager(parsed_creds, self._ws).get_credentials("oracle")
131+
assert all(
132+
self._creds.get(k) for k in connector_creds
133+
), f"Missing mandatory Oracle credentials. Please configure all of {connector_creds}."
134+
135+
return self
136+
118137
def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
119138
normalized = DialectUtils.normalize_identifier(
120139
identifier,

src/databricks/labs/lakebridge/reconcile/connectors/secrets.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

0 commit comments

Comments (0)