Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package za.co.absa.abris.avro.read.confluent

import io.confluent.kafka.schemaregistry.client.rest.RestService
import org.apache.spark.internal.Logging
import za.co.absa.abris.avro.registry.{AbrisRegistryClient, ConfluentMockRegistryClient, ConfluentRegistryClient}
import za.co.absa.abris.config.AbrisConfig
Expand All @@ -34,7 +35,8 @@ import scala.util.control.NonFatal
*/
object SchemaManagerFactory extends Logging {

private val clientInstances: concurrent.Map[Map[String,String], AbrisRegistryClient] = concurrent.TrieMap()
private val clientInstances: concurrent.Map[Map[String, String], AbrisRegistryClient] = concurrent.TrieMap()
private val restClientInstances: concurrent.Map[RestService, AbrisRegistryClient] = concurrent.TrieMap()

@DeveloperApi
def addSRClientInstance(configs: Map[String, String], client: AbrisRegistryClient): Unit = {
Expand All @@ -43,12 +45,17 @@ object SchemaManagerFactory extends Logging {

@DeveloperApi
def resetSRClientInstance(): Unit = {
clientInstances.clear()
clientInstances.clear()
}

def create(configs: Map[String,String]): SchemaManager = new SchemaManager(getOrCreateRegistryClient(configs))
def create(configs: Map[String, String]): SchemaManager = new SchemaManager(getOrCreateRegistryClient(configs))

private def getOrCreateRegistryClient(configs: Map[String,String]): AbrisRegistryClient = {
def create(restService: RestService, maxSchemaObject: Int): SchemaManager = {
new SchemaManager(getOrCreateRegistryClient(restService, maxSchemaObject))
}


private def getOrCreateRegistryClient(configs: Map[String, String]): AbrisRegistryClient = {
clientInstances.getOrElseUpdate(configs, {
if (configs.contains(AbrisConfig.REGISTRY_CLIENT_CLASS)) {
try {
Expand All @@ -74,4 +81,11 @@ object SchemaManagerFactory extends Logging {
}
})
}

private def getOrCreateRegistryClient(restService: RestService, maxSchemaObject: Int): AbrisRegistryClient = {
restClientInstances.getOrElseUpdate(restService, {
logInfo(msg = s"Configuring new Schema Registry client of type ConfluentRegistryClient")
new ConfluentRegistryClient(restService, maxSchemaObject)
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

package za.co.absa.abris.avro.registry
import io.confluent.kafka.schemaregistry.client.rest.RestService
import io.confluent.kafka.schemaregistry.client.{CachedSchemaRegistryClient, SchemaRegistryClient}
import io.confluent.kafka.serializers.KafkaAvroDeserializerConfig

Expand All @@ -23,6 +24,10 @@ import scala.collection.JavaConverters._
class ConfluentRegistryClient(client: SchemaRegistryClient) extends AbstractConfluentRegistryClient(client) {

def this(configs: Map[String,String]) = this(ConfluentRegistryClient.createClient(configs))

def this(restService: RestService, maxSchemaObject: Int) = {
this(ConfluentRegistryClient.createClient(restService, maxSchemaObject))
}
}
Comment on lines +27 to 31

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Lets update the files where this class is being used.
  2. Add Test cases
  3. Add scaladoc for this in terms of how to use this constructor


object ConfluentRegistryClient {
Expand All @@ -35,4 +40,8 @@ object ConfluentRegistryClient {
new CachedSchemaRegistryClient(urls, maxSchemaObject, configs.asJava)
}

private def createClient(restService: RestService, maxSchemaObject: Int) = {
new CachedSchemaRegistryClient(restService, maxSchemaObject)
}

}
38 changes: 34 additions & 4 deletions src/main/scala/za/co/absa/abris/config/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

package za.co.absa.abris.config

import io.confluent.kafka.schemaregistry.client.rest.RestService
import za.co.absa.abris.avro.errors.DeserializationExceptionHandler
import za.co.absa.abris.avro.parsing.utils.AvroSchemaUtils
import za.co.absa.abris.avro.read.confluent.SchemaManagerFactory
import za.co.absa.abris.avro.read.confluent.{SchemaManager, SchemaManagerFactory}
import za.co.absa.abris.avro.registry._

object AbrisConfig {
Expand Down Expand Up @@ -78,15 +79,23 @@ class ToStrategyConfigFragment(version: SchemaVersion, confluent: Boolean) {

class ToSchemaDownloadingConfigFragment(schemaCoordinates: SchemaCoordinate, confluent: Boolean) {
def usingSchemaRegistry(url: String): ToAvroConfig = usingSchemaRegistry(Map(AbrisConfig.SCHEMA_REGISTRY_URL -> url))
def usingSchemaRegistry(config: Map[String, String]): ToAvroConfig = {
val schemaManager = SchemaManagerFactory.create(config)
val (schemaId, schemaString) = schemaCoordinates match {
private def getSchemaIDAndSchemaString(schemaManager: SchemaManager) : (Int, String) = {
schemaCoordinates match {
case ic: IdCoordinate => (ic.schemaId, schemaManager.getSchemaById(ic.schemaId).toString)
case sc: SubjectCoordinate => {
val metadata = schemaManager.getSchemaMetadataBySubjectAndVersion(sc.subject, sc.version)
(metadata.getId, metadata.getSchema)
}
}
}
def usingSchemaRegistry(config: Map[String, String]): ToAvroConfig = {
val schemaManager = SchemaManagerFactory.create(config)
val (schemaId, schemaString) = getSchemaIDAndSchemaString(schemaManager)
new ToAvroConfig(schemaString, if (confluent) Some(schemaId) else None)
}
def usingSchemaRegistry(restService: RestService, maxSchemaObject: Int): ToAvroConfig = {
val schemaManager = SchemaManagerFactory.create(restService, maxSchemaObject)
val (schemaId, schemaString) = getSchemaIDAndSchemaString(schemaManager)
new ToAvroConfig(schemaString, if (confluent) Some(schemaId) else None)
}
}
Expand Down Expand Up @@ -137,6 +146,12 @@ class ToSchemaRegisteringConfigFragment(
val schemaId = schemaManager.getIfExistsOrElseRegisterSchema(schema, subject)
new ToAvroConfig(schemaString, if (confluent) Some(schemaId) else None)
}
def usingSchemaRegistry(restService: RestService, maxSchemaObject: Int): ToAvroConfig = {
val schemaManager = SchemaManagerFactory.create(restService, maxSchemaObject)
val schema = AvroSchemaUtils.parse(schemaString)
val schemaId = schemaManager.getIfExistsOrElseRegisterSchema(schema, subject)
new ToAvroConfig(schemaString, if (confluent) Some(schemaId) else None)
}
}

/**
Expand Down Expand Up @@ -258,6 +273,21 @@ class FromSchemaDownloadingConfigFragment(
throw new UnsupportedOperationException("Unsupported config permutation")
}
}

def usingSchemaRegistry(restService: RestService, maxSchemaObject: Int): FromAvroConfig = {
schemaCoordinatesOrSchemaString match {
case Left(coordinate) =>
val schemaManager = SchemaManagerFactory.create(restService, maxSchemaObject)
val schema = schemaManager.getSchema(coordinate)
new FromAvroConfig(schema.toString, if (confluent) Some(Map()) else None)
case Right(schemaString) =>
if (confluent) {
new FromAvroConfig(schemaString, None)
} else {
throw new UnsupportedOperationException("Unsupported config permutation")
}
}
}
}

class FromConfluentAvroConfigFragment {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

package za.co.absa.abris.avro.read.confluent

import io.confluent.kafka.schemaregistry.client.rest.RestService
import io.confluent.kafka.schemaregistry.client.security.basicauth.UserInfoCredentialProvider
import org.scalatest.BeforeAndAfterEach
import org.scalatest.flatspec.AnyFlatSpec
import za.co.absa.abris.avro.registry.{AbrisRegistryClient, ConfluentMockRegistryClient, ConfluentRegistryClient, TestRegistryClient}
import za.co.absa.abris.config.AbrisConfig

import scala.collection.JavaConverters.mapAsJavaMapConverter
import scala.reflect.runtime.{universe => ru}

class SchemaManagerFactorySpec extends AnyFlatSpec with BeforeAndAfterEach {
Expand All @@ -34,6 +37,11 @@ class SchemaManagerFactorySpec extends AnyFlatSpec with BeforeAndAfterEach {
AbrisConfig.REGISTRY_CLIENT_CLASS -> "za.co.absa.abris.avro.registry.TestRegistryClient"
)

private val restService = new RestService("http://dummy_sr_2")
val provider = new UserInfoCredentialProvider()
provider.configure(schemaRegistryConfig2.asJava)
restService.setBasicAuthCredentialProvider(provider)

override def beforeEach(): Unit = {
super.beforeEach()
SchemaManagerFactory.resetSRClientInstance() // Reset factory state
Expand Down
Loading