This repository was archived by the owner on Feb 14, 2023. It is now read-only.

Define a maximum allowable depth for protocol buffer messages #2

Open
wants to merge 2 commits into master
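Deeply nested and self-referential messages (like the `Person.other_person` field added to demo.proto below) previously made `schemaFor` recurse without bound. This change threads a `msgDepth` counter through schema and row conversion and caps it at `maxSchemaDepth` (30): past the cap, `dataTypeFor` falls back to `StringType` and `toRowData` throws an `UnsupportedOperationException`. A rough usage sketch follows; the local SparkSession setup and the ScalaPB-generated `Person` class are assumptions for illustration, not part of this diff.

```scala
// Sketch only: assumes the ScalaPB-generated classes for demo.proto are on the
// classpath (import omitted; it depends on the proto's package options) and that
// Spark runs with a local master.
import com.trueaccord.scalapb.spark.ProtoSQL
import org.apache.spark.sql.SparkSession

object DepthCapExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("depth-cap").getOrCreate()

    // Two levels of nesting via the new `other_person` field; well under
    // maxSchemaDepth, so this converts normally.
    val leaf = Person(tags = Seq("leaf"))
    val root = Person(tags = Seq("root"), otherPerson = Some(leaf))

    val df = ProtoSQL.protoToDataFrame(spark, spark.sparkContext.parallelize(Seq(root)))
    df.printSchema()            // nested struct columns, cut off at maxSchemaDepth levels
    df.show(truncate = false)

    spark.stop()
  }
}
```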
1 change: 1 addition & 0 deletions project/build.properties
@@ -0,0 +1 @@
+sbt.version=0.13.16
sparksql-scalapb/src/main/scala/com/trueaccord/scalapb/spark/ProtoSQL.scala
@@ -1,18 +1,22 @@
 package com.trueaccord.scalapb.spark
 
 import com.google.protobuf.ByteString
-import com.google.protobuf.Descriptors.{ EnumValueDescriptor, FieldDescriptor }
+import com.google.protobuf.Descriptors.{EnumValueDescriptor, FieldDescriptor}
 import com.google.protobuf.Descriptors.FieldDescriptor.JavaType
-import com.trueaccord.scalapb.{ GeneratedMessage, GeneratedMessageCompanion, Message }
-import org.apache.spark.sql.types.{ ArrayType, StructField }
-import org.apache.spark.sql.{ DataFrame, Row, SQLContext, SparkSession }
+import com.trueaccord.scalapb.{GeneratedMessage, GeneratedMessageCompanion, Message}
+import org.apache.spark.sql.types.{ArrayType, StringType, StructField}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
 
 object ProtoSQL {
   import scala.language.existentials
 
+  val maxSchemaDepth = 30
+
   def protoToDataFrame[T <: GeneratedMessage with Message[T] : GeneratedMessageCompanion](
     sparkSession: SparkSession, protoRdd: org.apache.spark.rdd.RDD[T]): DataFrame = {
-    sparkSession.createDataFrame(protoRdd.map(messageToRow[T]), schemaFor[T])
+    sparkSession.createDataFrame(protoRdd.map( x =>
+      messageToRow[T](x, 0)
+    ), schemaFor[T])
   }
 
   def protoToDataFrame[T <: GeneratedMessage with Message[T] : GeneratedMessageCompanion](
@@ -23,52 +27,67 @@ object ProtoSQL {
   def schemaFor[T <: GeneratedMessage with Message[T]](implicit cmp: GeneratedMessageCompanion[T]) = {
     import org.apache.spark.sql.types._
     import collection.JavaConverters._
-    StructType(cmp.javaDescriptor.getFields.asScala.map(structFieldFor))
+    StructType(cmp.javaDescriptor.getFields.asScala.map( x =>
+      structFieldFor(x, 0))
+    )
   }
 
-  private def toRowData(fd: FieldDescriptor, obj: Any) = fd.getJavaType match {
-    case JavaType.BYTE_STRING => obj.asInstanceOf[ByteString].toByteArray
-    case JavaType.ENUM => obj.asInstanceOf[EnumValueDescriptor].getName
-    case JavaType.MESSAGE => messageToRow(obj.asInstanceOf[T forSome { type T <: GeneratedMessage with Message[T] }])
-    case _ => obj
+  private def toRowData(fd: FieldDescriptor, obj: Any, msgDepth: Integer) = {
+    if (msgDepth > maxSchemaDepth) {
+      throw new UnsupportedOperationException(s"Protobufs with schema depth of more than $maxSchemaDepth are not supported.")
+    } else {
+      fd.getJavaType match {
+        case JavaType.BYTE_STRING => obj.asInstanceOf[ByteString].toByteArray
+        case JavaType.ENUM => obj.asInstanceOf[EnumValueDescriptor].getName
+        case JavaType.MESSAGE => messageToRow(obj.asInstanceOf[T forSome {type T <: GeneratedMessage with Message[T]}], msgDepth + 1)
+        case _ => obj
+      }
+    }
   }

-  def messageToRow[T <: GeneratedMessage with Message[T]](msg: T): Row = {
+  def messageToRow[T <: GeneratedMessage with Message[T]](msg: T, msgDepth: Integer): Row = {
     import collection.JavaConversions._
     Row(
       msg.companion.javaDescriptor.getFields.map {
         fd =>
           val obj = msg.getField(fd)
           if (obj != null) {
             if (fd.isRepeated) {
-              obj.asInstanceOf[Traversable[Any]].map(toRowData(fd, _))
+              obj.asInstanceOf[Traversable[Any]].map(toRowData(fd, _, msgDepth))
             } else {
-              toRowData(fd, obj)
+              toRowData(fd, obj, msgDepth)
             }
           } else null
       }: _*)
   }

-  def dataTypeFor(fd: FieldDescriptor) = {
-    import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
-    import org.apache.spark.sql.types._
-    fd.getJavaType match {
-      case INT => IntegerType
-      case LONG => LongType
-      case FLOAT => FloatType
-      case DOUBLE => DoubleType
-      case BOOLEAN => BooleanType
-      case STRING => StringType
-      case BYTE_STRING => BinaryType
-      case ENUM => StringType
-      case MESSAGE =>
-        import collection.JavaConverters._
-        StructType(fd.getMessageType.getFields.asScala.map(structFieldFor))
+  def dataTypeFor(fd: FieldDescriptor, msgDepth: Integer) = {
+    if(msgDepth > maxSchemaDepth) {
+      StringType
+    } else {
+      import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
+      import org.apache.spark.sql.types._
+      fd.getJavaType match {
+        case INT => IntegerType
+        case LONG => LongType
+        case FLOAT => FloatType
+        case DOUBLE => DoubleType
+        case BOOLEAN => BooleanType
+        case STRING => StringType
+        case BYTE_STRING => BinaryType
+        case ENUM => StringType
+        case MESSAGE =>
+          import collection.JavaConverters._
+          StructType(fd.getMessageType.getFields.asScala.map( x =>
+            structFieldFor(x, msgDepth + 1)
+          )
+          )
+      }
     }
   }

-  def structFieldFor(fd: FieldDescriptor): StructField = {
-    val dataType = dataTypeFor(fd)
+  def structFieldFor(fd: FieldDescriptor, msgDepth: Integer): StructField = {
+    val dataType = dataTypeFor(fd, msgDepth)
     StructField(
       fd.getName,
       if (fd.isRepeated) ArrayType(dataType, containsNull = false) else dataType,
1 change: 1 addition & 0 deletions sparksql-scalapb/src/test/protobuf/demo.proto
@@ -21,4 +21,5 @@ message Person {
   repeated string tags = 4;
   repeated Address addresses = 5;
   optional Base base = 6;
+  optional Person other_person = 7;
 }
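For reference, a chain nested deeper than `maxSchemaDepth` now fails with a clear error during row conversion, while schema generation still terminates. A sketch under the same assumptions as above (the generated `Person` class and its `tags`/`otherPerson` fields):

```scala
// Sketch only: builds a Person chain 40 levels deep via the new `other_person` field.
val deep: Person = (1 to 40).foldLeft(Person(tags = Seq("p0"))) { (inner, i) =>
  Person(tags = Seq(s"p$i"), otherPerson = Some(inner))
}

// schemaFor[Person] itself terminates (struct nesting is cut off at maxSchemaDepth,
// deeper levels map to StringType), but converting this value throws:
//   UnsupportedOperationException: Protobufs with schema depth of more than 30 are not supported.
ProtoSQL.messageToRow(deep, 0)
```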