diff --git a/src/main/scala/za/co/absa/abris/avro/read/confluent/SchemaManagerFactory.scala b/src/main/scala/za/co/absa/abris/avro/read/confluent/SchemaManagerFactory.scala index 6fc8c289..6dd70b12 100644 --- a/src/main/scala/za/co/absa/abris/avro/read/confluent/SchemaManagerFactory.scala +++ b/src/main/scala/za/co/absa/abris/avro/read/confluent/SchemaManagerFactory.scala @@ -17,6 +17,7 @@ package za.co.absa.abris.avro.read.confluent +import io.confluent.kafka.schemaregistry.client.rest.RestService import org.apache.spark.internal.Logging import za.co.absa.abris.avro.registry.{AbrisRegistryClient, ConfluentMockRegistryClient, ConfluentRegistryClient} import za.co.absa.abris.config.AbrisConfig @@ -34,7 +35,8 @@ import scala.util.control.NonFatal */ object SchemaManagerFactory extends Logging { - private val clientInstances: concurrent.Map[Map[String,String], AbrisRegistryClient] = concurrent.TrieMap() + private val clientInstances: concurrent.Map[Map[String, String], AbrisRegistryClient] = concurrent.TrieMap() + private val restClientInstances: concurrent.Map[RestService, AbrisRegistryClient] = concurrent.TrieMap() @DeveloperApi def addSRClientInstance(configs: Map[String, String], client: AbrisRegistryClient): Unit = { @@ -43,12 +45,17 @@ object SchemaManagerFactory extends Logging { @DeveloperApi def resetSRClientInstance(): Unit = { - clientInstances.clear() + clientInstances.clear() } - def create(configs: Map[String,String]): SchemaManager = new SchemaManager(getOrCreateRegistryClient(configs)) + def create(configs: Map[String, String]): SchemaManager = new SchemaManager(getOrCreateRegistryClient(configs)) - private def getOrCreateRegistryClient(configs: Map[String,String]): AbrisRegistryClient = { + def create(restService: RestService, maxSchemaObject: Int): SchemaManager = { + new SchemaManager(getOrCreateRegistryClient(restService, maxSchemaObject)) + } + + + private def getOrCreateRegistryClient(configs: Map[String, String]): AbrisRegistryClient = { clientInstances.getOrElseUpdate(configs, { if (configs.contains(AbrisConfig.REGISTRY_CLIENT_CLASS)) { try { @@ -74,4 +81,11 @@ object SchemaManagerFactory extends Logging { } }) } + + private def getOrCreateRegistryClient(restService: RestService, maxSchemaObject: Int): AbrisRegistryClient = { + restClientInstances.getOrElseUpdate(restService, { + logInfo(msg = s"Configuring new Schema Registry client of type ConfluentRegistryClient") + new ConfluentRegistryClient(restService, maxSchemaObject) + }) + } } diff --git a/src/main/scala/za/co/absa/abris/avro/registry/ConfluentRegistryClient.scala b/src/main/scala/za/co/absa/abris/avro/registry/ConfluentRegistryClient.scala index bbf33145..ba484105 100644 --- a/src/main/scala/za/co/absa/abris/avro/registry/ConfluentRegistryClient.scala +++ b/src/main/scala/za/co/absa/abris/avro/registry/ConfluentRegistryClient.scala @@ -15,6 +15,7 @@ */ package za.co.absa.abris.avro.registry +import io.confluent.kafka.schemaregistry.client.rest.RestService import io.confluent.kafka.schemaregistry.client.{CachedSchemaRegistryClient, SchemaRegistryClient} import io.confluent.kafka.serializers.KafkaAvroDeserializerConfig @@ -23,6 +24,10 @@ import scala.collection.JavaConverters._ class ConfluentRegistryClient(client: SchemaRegistryClient) extends AbstractConfluentRegistryClient(client) { def this(configs: Map[String,String]) = this(ConfluentRegistryClient.createClient(configs)) + + def this(restService: RestService, maxSchemaObject: Int) = { + this(ConfluentRegistryClient.createClient(restService, maxSchemaObject)) + } } object ConfluentRegistryClient { @@ -35,4 +40,8 @@ object ConfluentRegistryClient { new CachedSchemaRegistryClient(urls, maxSchemaObject, configs.asJava) } + private def createClient(restService: RestService, maxSchemaObject: Int) = { + new CachedSchemaRegistryClient(restService, maxSchemaObject) + } + } diff --git a/src/main/scala/za/co/absa/abris/config/Config.scala b/src/main/scala/za/co/absa/abris/config/Config.scala index e7ca9b45..74da082c 100644 --- a/src/main/scala/za/co/absa/abris/config/Config.scala +++ b/src/main/scala/za/co/absa/abris/config/Config.scala @@ -16,9 +16,10 @@ package za.co.absa.abris.config +import io.confluent.kafka.schemaregistry.client.rest.RestService import za.co.absa.abris.avro.errors.DeserializationExceptionHandler import za.co.absa.abris.avro.parsing.utils.AvroSchemaUtils -import za.co.absa.abris.avro.read.confluent.SchemaManagerFactory +import za.co.absa.abris.avro.read.confluent.{SchemaManager, SchemaManagerFactory} import za.co.absa.abris.avro.registry._ object AbrisConfig { @@ -78,15 +79,23 @@ class ToStrategyConfigFragment(version: SchemaVersion, confluent: Boolean) { class ToSchemaDownloadingConfigFragment(schemaCoordinates: SchemaCoordinate, confluent: Boolean) { def usingSchemaRegistry(url: String): ToAvroConfig = usingSchemaRegistry(Map(AbrisConfig.SCHEMA_REGISTRY_URL -> url)) - def usingSchemaRegistry(config: Map[String, String]): ToAvroConfig = { - val schemaManager = SchemaManagerFactory.create(config) - val (schemaId, schemaString) = schemaCoordinates match { + private def getSchemaIDAndSchemaString(schemaManager: SchemaManager) : (Int, String) = { + schemaCoordinates match { case ic: IdCoordinate => (ic.schemaId, schemaManager.getSchemaById(ic.schemaId).toString) case sc: SubjectCoordinate => { val metadata = schemaManager.getSchemaMetadataBySubjectAndVersion(sc.subject, sc.version) (metadata.getId, metadata.getSchema) } } + } + def usingSchemaRegistry(config: Map[String, String]): ToAvroConfig = { + val schemaManager = SchemaManagerFactory.create(config) + val (schemaId, schemaString) = getSchemaIDAndSchemaString(schemaManager) + new ToAvroConfig(schemaString, if (confluent) Some(schemaId) else None) + } + def usingSchemaRegistry(restService: RestService, maxSchemaObject: Int): ToAvroConfig = { + val schemaManager = SchemaManagerFactory.create(restService, maxSchemaObject) + val (schemaId, schemaString) = getSchemaIDAndSchemaString(schemaManager) new ToAvroConfig(schemaString, if (confluent) Some(schemaId) else None) } } @@ -137,6 +146,12 @@ class ToSchemaRegisteringConfigFragment( val schemaId = schemaManager.getIfExistsOrElseRegisterSchema(schema, subject) new ToAvroConfig(schemaString, if (confluent) Some(schemaId) else None) } + def usingSchemaRegistry(restService: RestService, maxSchemaObject: Int): ToAvroConfig = { + val schemaManager = SchemaManagerFactory.create(restService, maxSchemaObject) + val schema = AvroSchemaUtils.parse(schemaString) + val schemaId = schemaManager.getIfExistsOrElseRegisterSchema(schema, subject) + new ToAvroConfig(schemaString, if (confluent) Some(schemaId) else None) + } } /** @@ -258,6 +273,21 @@ class FromSchemaDownloadingConfigFragment( throw new UnsupportedOperationException("Unsupported config permutation") } } + + def usingSchemaRegistry(restService: RestService, maxSchemaObject: Int): FromAvroConfig = { + schemaCoordinatesOrSchemaString match { + case Left(coordinate) => + val schemaManager = SchemaManagerFactory.create(restService, maxSchemaObject) + val schema = schemaManager.getSchema(coordinate) + new FromAvroConfig(schema.toString, if (confluent) Some(Map()) else None) + case Right(schemaString) => + if (confluent) { + new FromAvroConfig(schemaString, None) + } else { + throw new UnsupportedOperationException("Unsupported config permutation") + } + } + } } class FromConfluentAvroConfigFragment { diff --git a/src/test/scala/za/co/absa/abris/avro/read/confluent/SchemaManagerFactorySpec.scala b/src/test/scala/za/co/absa/abris/avro/read/confluent/SchemaManagerFactorySpec.scala index ba5334c3..988806fb 100644 --- a/src/test/scala/za/co/absa/abris/avro/read/confluent/SchemaManagerFactorySpec.scala +++ b/src/test/scala/za/co/absa/abris/avro/read/confluent/SchemaManagerFactorySpec.scala @@ -16,11 +16,14 @@ package za.co.absa.abris.avro.read.confluent +import io.confluent.kafka.schemaregistry.client.rest.RestService +import io.confluent.kafka.schemaregistry.client.security.basicauth.UserInfoCredentialProvider import org.scalatest.BeforeAndAfterEach import org.scalatest.flatspec.AnyFlatSpec import za.co.absa.abris.avro.registry.{AbrisRegistryClient, ConfluentMockRegistryClient, ConfluentRegistryClient, TestRegistryClient} import za.co.absa.abris.config.AbrisConfig +import scala.collection.JavaConverters.mapAsJavaMapConverter import scala.reflect.runtime.{universe => ru} class SchemaManagerFactorySpec extends AnyFlatSpec with BeforeAndAfterEach { @@ -34,6 +37,11 @@ class SchemaManagerFactorySpec extends AnyFlatSpec with BeforeAndAfterEach { AbrisConfig.REGISTRY_CLIENT_CLASS -> "za.co.absa.abris.avro.registry.TestRegistryClient" ) + private val restService = new RestService("http://dummy_sr_2") + val provider = new UserInfoCredentialProvider() + provider.configure(schemaRegistryConfig2.asJava) + restService.setBasicAuthCredentialProvider(provider) + override def beforeEach(): Unit = { super.beforeEach() SchemaManagerFactory.resetSRClientInstance() // Reset factory state