Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.coordinatedcommits

import java.net.{URI, URISyntaxException}

import io.delta.storage.commit.uccommitcoordinator.{OAuthUCTokenProvider, UCClient}

import org.apache.spark.sql.delta.logging.DeltaLogKeys
import org.apache.spark.sql.delta.metering.DeltaLogging

import org.apache.spark.internal.MDC

/**
 * Connection parameters for a Unity Catalog client.
 *
 * The trait is sealed; each implementation carries the credentials for one authentication
 * mechanism (see [[UCTokenClientParams]] and [[UCOAuthClientParams]]) and knows how to turn
 * itself into a [[UCClient]].
 */
sealed trait UCClientParams {

  /** URI of the Unity Catalog server these parameters refer to. */
  def uri: String

  /**
   * Creates a [[UCClient]] for [[uri]] using the authentication mechanism of the implementation.
   *
   * @param ucClientFactory factory that instantiates the client
   * @return the client produced by `ucClientFactory`
   */
  def buildUCClient(ucClientFactory: UCClientFactory): UCClient
}

/**
 * Unity Catalog client parameters for token-based authentication.
 *
 * @param uri The Unity Catalog server URI
 * @param token The authentication token
 */
case class UCTokenClientParams(uri: String, token: String) extends UCClientParams {

  /** Delegates to the factory overload that takes the URI and the fixed token. */
  override def buildUCClient(ucClientFactory: UCClientFactory): UCClient =
    ucClientFactory.createUCClient(uri, token)

  /**
   * Renders the parameters with the token masked, so that logging or string-interpolating an
   * instance does not write the credential out. Only the string form is affected: the
   * compiler-generated `equals`/`hashCode` still take the real token into account.
   */
  override def toString: String = s"UCTokenClientParams($uri,***)"
}

/**
 * Unity Catalog client parameters for OAuth-based authentication.
 *
 * @param uri The Unity Catalog server URI
 * @param oauthUri The OAuth token endpoint URI
 * @param oauthClientId The OAuth client ID
 * @param oauthClientSecret The OAuth client secret
 */
case class UCOAuthClientParams(
    uri: String,
    oauthUri: String,
    oauthClientId: String,
    oauthClientSecret: String) extends UCClientParams {

  /**
   * Wraps the OAuth settings in an [[OAuthUCTokenProvider]] and hands it, together with the
   * server URI, to the factory overload that accepts a token provider.
   */
  override def buildUCClient(ucClientFactory: UCClientFactory): UCClient = {
    val provider = new OAuthUCTokenProvider(oauthUri, oauthClientId, oauthClientSecret)
    ucClientFactory.createUCClient(uri, provider)
  }

  /**
   * Renders the parameters with the client secret masked, so that logging or
   * string-interpolating an instance does not write the credential out. Only the string form
   * is affected: the compiler-generated `equals`/`hashCode` still take the real secret into
   * account.
   */
  override def toString: String =
    s"UCOAuthClientParams($uri,$oauthUri,$oauthClientId,***)"
}

object UCClientParams extends DeltaLogging {

  /**
   * Factory method to create UCClientParams from optional configuration values.
   *
   * Resolution order:
   *  1. `uri` must be defined and accepted by [[isValidURI]]. An absent URI yields None without
   *     logging; a malformed one yields None after a warning.
   *  2. If `token` is defined, token-based parameters are returned and any OAuth values are
   *     ignored.
   *  3. Otherwise, if all three OAuth values are defined, OAuth parameters are returned, provided
   *     `oauthUri` is accepted by [[isValidURI]]; if it is not, a warning is logged and None is
   *     returned.
   *  4. Any other combination logs a warning and yields None.
   *
   * @param catalogName The catalog name; used only in warning messages.
   * @param uri The Unity Catalog server URI
   * @param token The authentication token (for token-based auth)
   * @param oauthUri The OAuth token endpoint URI (for OAuth auth)
   * @param oauthClientId The OAuth client ID (for OAuth auth)
   * @param oauthClientSecret The OAuth client secret (for OAuth auth)
   * @return Some(UCClientParams) if valid, None otherwise
   */
  def create(
      catalogName: String,
      uri: Option[String],
      token: Option[String] = None,
      oauthUri: Option[String] = None,
      oauthClientId: Option[String] = None,
      oauthClientSecret: Option[String] = None): Option[UCClientParams] = {
    uri match {
      case None =>
        // No URI configured: return None without logging.
        None

      case Some(u) if !isValidURI(u) =>
        logWarning(log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it " +
          log"does not have a valid URI ${MDC(DeltaLogKeys.URI, u)}.")
        None

      case Some(u) =>
        (token, oauthUri, oauthClientId, oauthClientSecret) match {
          case (Some(t), _, _, _) =>
            // A fixed token takes precedence over any OAuth settings.
            Some(UCTokenClientParams(uri = u, token = t))

          case (_, Some(oUri), Some(oClientId), Some(oClientSecret)) if isValidURI(oUri) =>
            Some(UCOAuthClientParams(
              uri = u,
              oauthUri = oUri,
              oauthClientId = oClientId,
              oauthClientSecret = oClientSecret))

          case (_, Some(_), Some(_), Some(_)) =>
            // All OAuth values are present, but the guard above rejected the OAuth URI.
            logWarning(log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} " +
              log"as it does not have a valid OAuth URI")
            None

          case _ =>
            // Both fragments use the `log` interpolator, like the other warnings in this
            // method, so the whole message is built as one structured log entry.
            logWarning(log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it does " +
              log"not have configured fixed token or oauth credential in Spark Session.")
            None
        }
    }
  }

  /**
   * Returns true when `java.net.URI`'s constructor accepts `uri`, false when it throws
   * `URISyntaxException`. This is a syntax check only; nothing is resolved or contacted.
   */
  private def isValidURI(uri: String): Boolean = {
    try {
      new URI(uri)
      true
    } catch {
      case _: URISyntaxException => false
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,17 @@

package org.apache.spark.sql.delta.coordinatedcommits

import java.net.{URI, URISyntaxException}
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import io.delta.storage.commit.CommitCoordinatorClient
import io.delta.storage.commit.uccommitcoordinator.{FixedUCTokenProvider, UCClient, UCCommitCoordinatorClient, UCTokenBasedRestClient, UCTokenProvider}

import org.apache.spark.sql.delta.logging.DeltaLogKeys
import org.apache.spark.sql.delta.metering.DeltaLogging
import io.delta.storage.commit.CommitCoordinatorClient
import io.delta.storage.commit.uccommitcoordinator.{UCClient, UCCommitCoordinatorClient, UCTokenBasedRestClient}

import org.apache.spark.internal.MDC
import org.apache.spark.internal.MDC
import org.apache.spark.sql.SparkSession

Expand All @@ -40,7 +39,7 @@ import org.apache.spark.sql.SparkSession
* It caches the UCCommitCoordinatorClient instance for a given metastore ID upon its first access.
*/
object UCCommitCoordinatorBuilder
extends CatalogOwnedCommitCoordinatorBuilder with DeltaLogging {
extends CatalogOwnedCommitCoordinatorBuilder with DeltaLogging {

/** Prefix for Spark SQL catalog configurations. */
final private val SPARK_SQL_CATALOG_PREFIX = "spark.sql.catalog."
Expand All @@ -55,13 +54,22 @@ object UCCommitCoordinatorBuilder
/** Suffix for the token configuration of a catalog. */
final private val TOKEN_SUFFIX = "token"

/** Suffix for the OAuth URI configuration of a catalog */
final private val OAUTH_URI_SUFFIX = "auth.oauth.uri"

/** Suffix for the OAuth client id configuration of a catalog */
final private val OAUTH_CLIENT_ID_SUFFIX = "auth.oauth.clientId"

/** Suffix for the OAuth client secret configuration of a catalog */
final private val OAUTH_CLIENT_SECRET_SUFFIX = "auth.oauth.clientSecret"

/** Cache for UCCommitCoordinatorClient instances. */
private val commitCoordinatorClientCache =
new ConcurrentHashMap[String, UCCommitCoordinatorClient]()

// Helper cache from UCClientParams to metastoreId to avoid redundant getMetastoreId calls to the
// catalog.
private val uriTokenToMetastoreIdCache = new ConcurrentHashMap[(String, String), String]()
private val ucClientParamsToMetastoreIdCache = new ConcurrentHashMap[UCClientParams, String]()

// Use a var instead of val for ease of testing by injecting different UCClientFactory.
private[delta] var ucClientFactory: UCClientFactory = UCTokenBasedRestClientFactory
Expand All @@ -76,14 +84,14 @@ object UCCommitCoordinatorBuilder

commitCoordinatorClientCache.computeIfAbsent(
metastoreId,
_ => new UCCommitCoordinatorClient(conf.asJava, getMatchingUCClient(spark, metastoreId))
)
_ => new UCCommitCoordinatorClient(conf.asJava, getMatchingUCClient(spark, metastoreId)))
}

override def buildForCatalog(
spark: SparkSession, catalogName: String): CommitCoordinatorClient = {
spark: SparkSession,
catalogName: String): CommitCoordinatorClient = {
val client = getCatalogConfigs(spark).find(_._1 == catalogName) match {
case Some((_, uri, token)) => ucClientFactory.createUCClient(uri, token)
case Some((_, ucClientParams)) => ucClientParams.buildUCClient(ucClientFactory)
case None =>
throw new IllegalArgumentException(
s"Catalog $catalogName not found in the provided SparkSession configurations.")
Expand All @@ -101,31 +109,32 @@ object UCCommitCoordinatorBuilder
* appropriate exception.
*/
private def getMatchingUCClient(spark: SparkSession, metastoreId: String): UCClient = {
val matchingClients: List[(String, String)] = getCatalogConfigs(spark)
.map { case (name, uri, token) => (uri, token) }
val matchingClients: List[UCClientParams] = getCatalogConfigs(spark)
.map { case (_, ucClientParams) => ucClientParams }
.distinct // Remove duplicates since multiple catalogs can have the same uri and token
.filter { case (uri, token) => getMetastoreId(uri, token).contains(metastoreId) }
.filter(ucClientParams => getMetastoreId(ucClientParams).contains(metastoreId))

matchingClients match {
case Nil => throw noMatchingCatalogException(metastoreId)
case (uri, token) :: Nil => ucClientFactory.createUCClient(uri, token)
case multiple => throw multipleMatchingCatalogs(metastoreId, multiple.map(_._1))
case ucClientParams :: Nil => ucClientParams.buildUCClient(ucClientFactory)
case multiple => throw multipleMatchingCatalogs(metastoreId, multiple.map(_.uri))
}
}

/**
* Retrieves the metastore ID for a given URI and token.
* Retrieves the metastore ID for a given UCClientParams.
*
* This method creates a UCClient using the provided URI and token, then retrieves its metastore
* This method creates a UCClient using the provided UCClientParams, then retrieves its metastore
* ID. The result is cached to avoid unnecessary getMetastoreId requests in future calls. If
* there's an error, it returns None and logs a warning.
*/
private def getMetastoreId(uri: String, token: String): Option[String] = {
private def getMetastoreId(ucClientParams: UCClientParams): Option[String] = {
val uri = ucClientParams.uri
try {
val metastoreId = uriTokenToMetastoreIdCache.computeIfAbsent(
(uri, token),
val metastoreId = ucClientParamsToMetastoreIdCache.computeIfAbsent(
ucClientParams,
_ => {
val ucClient = ucClientFactory.createUCClient(uri, token)
val ucClient = ucClientParams.buildUCClient(ucClientFactory)
try {
ucClient.getMetastoreId
} finally {
Expand Down Expand Up @@ -164,39 +173,46 @@ object UCCommitCoordinatorBuilder
* Retrieves the catalog configurations from the SparkSession.
*
* Example; Given Spark configurations:
* spark.sql.catalog.catalog1 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog1.uri = "https://dbc-123abc.databricks.com"
* spark.sql.catalog.catalog1.token = "dapi1234567890"
* spark.sql.catalog.catalog1 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog1.uri = "https://dbc-123abc.databricks.com"
* spark.sql.catalog.catalog1.token = "dapi1234567890"
*
* spark.sql.catalog.catalog2 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog2.uri = "https://dbc-456def.databricks.com"
* spark.sql.catalog.catalog2.token = "dapi0987654321"
*
* spark.sql.catalog.catalog2 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog2.uri = "https://dbc-456def.databricks.com"
* spark.sql.catalog.catalog2.token = "dapi0987654321"
* spark.sql.catalog.catalog3 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog3.uri = "https://dbc-789ghi.databricks.com"
*
* spark.sql.catalog.catalog3 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog3.uri = "https://dbc-789ghi.databricks.com"
* spark.sql.catalog.catalog4 = "com.databricks.sql.lakehouse.catalog3"
* spark.sql.catalog.catalog4.uri = "https://dbc-456def.databricks.com"
* spark.sql.catalog.catalog4.token = "dapi0987654321"
*
* spark.sql.catalog.catalog4 = "com.databricks.sql.lakehouse.catalog3"
* spark.sql.catalog.catalog4.uri = "https://dbc-456def.databricks.com"
* spark.sql.catalog.catalog4.token = "dapi0987654321"
* spark.sql.catalog.catalog5 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog5.uri = "random-string"
* spark.sql.catalog.catalog5.token = "dapi0987654321"
*
* spark.sql.catalog.catalog5 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog5.uri = "random-string"
* spark.sql.catalog.catalog5.token = "dapi0987654321"
* spark.sql.catalog.catalog6 = "io.unitycatalog.connectors.spark.UCSingleCatalog"
* spark.sql.catalog.catalog6.uri = "https://local:8080/"
* spark.sql.catalog.catalog6.auth.oauth.uri = "https://local:8081/"
* spark.sql.catalog.catalog6.auth.oauth.clientId = "client-id"
* spark.sql.catalog.catalog6.auth.oauth.clientSecret = "client-secret"
*
* This method would return:
* List(
* ("catalog1", "https://dbc-123abc.databricks.com", "dapi1234567890"),
* ("catalog2", "https://dbc-456def.databricks.com", "dapi0987654321")
* ("catalog1", UCClientParams(..)),
* ("catalog2", UCClientParams(..)),
* ("catalog6", UCClientParams(..))
* )
*
* Note: catalog3 is not included in the result because it has neither a token nor OAuth credentials configured.
* Note: catalog4 is not included in the result because it's not a UCSingleCatalog connector.
* Note: catalog5 is not included in the result because its URI is not a valid URI.
*
* @return
* A list of tuples containing (catalogName, uri, token) for each properly configured catalog
* A list of (catalogName, UCClientParams) tuples, one for each properly configured catalog
*/
private[delta] def getCatalogConfigs(spark: SparkSession): List[(String, String, String)] = {
private[delta] def getCatalogConfigs(spark: SparkSession): List[(String, UCClientParams)] = {
val catalogConfigs = spark.conf.getAll.filterKeys(_.startsWith(SPARK_SQL_CATALOG_PREFIX))

catalogConfigs
Expand All @@ -206,26 +222,31 @@ object UCCommitCoordinatorBuilder
.map(_(3))
.filter { catalogName: String =>
val connector = catalogConfigs.get(s"$SPARK_SQL_CATALOG_PREFIX$catalogName")
connector.contains(UNITY_CATALOG_CONNECTOR_CLASS)}
connector.contains(UNITY_CATALOG_CONNECTOR_CLASS)
}
.flatMap { catalogName: String =>
val uri = catalogConfigs.get(s"$SPARK_SQL_CATALOG_PREFIX$catalogName.$URI_SUFFIX")
val token = catalogConfigs.get(s"$SPARK_SQL_CATALOG_PREFIX$catalogName.$TOKEN_SUFFIX")
(uri, token) match {
case (Some(u), Some(t)) =>
try {
new URI(u) // Validate the URI
Some((catalogName, u, t))
} catch {
case _: URISyntaxException =>
logWarning(log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it " +
log"does not have a valid URI ${MDC(DeltaLogKeys.URI, u)}.")
None
}
val oauthUri = catalogConfigs.get(
s"$SPARK_SQL_CATALOG_PREFIX$catalogName.$OAUTH_URI_SUFFIX")
val oauthClientId = catalogConfigs.get(
s"$SPARK_SQL_CATALOG_PREFIX$catalogName.$OAUTH_CLIENT_ID_SUFFIX")
val oauthClientSecret = catalogConfigs.get(
s"$SPARK_SQL_CATALOG_PREFIX$catalogName.$OAUTH_CLIENT_SECRET_SUFFIX")

UCClientParams.create(
catalogName,
uri,
token,
oauthUri,
oauthClientId,
oauthClientSecret) match {
case Some(ucClientParams) =>
Some(catalogName, ucClientParams)
case _ =>
logWarning(log"Skipping catalog ${MDC(DeltaLogKeys.CATALOG, catalogName)} as it does " +
"not have both uri and token configured in Spark Session.")
None
}}
}
}
.toList
}

Expand All @@ -240,15 +261,18 @@ object UCCommitCoordinatorBuilder

def clearCache(): Unit = {
commitCoordinatorClientCache.clear()
uriTokenToMetastoreIdCache.clear()
ucClientParamsToMetastoreIdCache.clear()
}
}

trait UCClientFactory {
def createUCClient(uri: String, token: String): UCClient
def createUCClient(uri: String, token: String): UCClient =
createUCClient(uri, new FixedUCTokenProvider(token))

def createUCClient(uri: String, provider: UCTokenProvider): UCClient
}

object UCTokenBasedRestClientFactory extends UCClientFactory {
override def createUCClient(uri: String, token: String): UCClient =
new UCTokenBasedRestClient(uri, token)
override def createUCClient(uri: String, provider: UCTokenProvider): UCClient =
new UCTokenBasedRestClient(uri, provider)
}
Loading
Loading