From f65f969b8d7c55ca12942e2ca951fbc621a112d1 Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Wed, 17 Sep 2025 16:57:50 +0100 Subject: [PATCH] feat(spark): support Hive DDL / Insert operations Add support for DdlRel and WriteRel for Hive in Spark Signed-off-by: Andrew Coleman --- spark/.gitignore | 4 + spark/build.gradle.kts | 6 +- .../spark/logical/ToLogicalPlan.scala | 109 ++++++++++++----- .../spark/logical/ToSubstraitRel.scala | 69 ++++++++++- .../io/substrait/spark/HiveTableSuite.scala | 113 ++++++++++++++++++ 5 files changed, 264 insertions(+), 37 deletions(-) create mode 100644 spark/.gitignore create mode 100644 spark/src/test/scala/io/substrait/spark/HiveTableSuite.scala diff --git a/spark/.gitignore b/spark/.gitignore new file mode 100644 index 000000000..4e591f7d3 --- /dev/null +++ b/spark/.gitignore @@ -0,0 +1,4 @@ +metastore_db +spark-warehouse +/src/test/resources/write-a.csv +derby.log diff --git a/spark/build.gradle.kts b/spark/build.gradle.kts index 41837b386..48d5c75f2 100644 --- a/spark/build.gradle.kts +++ b/spark/build.gradle.kts @@ -116,6 +116,7 @@ dependencies { implementation(libs.scala.library) api(libs.spark.core) api(libs.spark.sql) + implementation(libs.spark.hive) implementation(libs.spark.catalyst) implementation(libs.slf4j.api) @@ -148,6 +149,9 @@ tasks { test { dependsOn(":core:shadowJar") useJUnitPlatform { includeEngines("scalatest") } - jvmArgs("--add-exports=java.base/sun.nio.ch=ALL-UNNAMED") + jvmArgs( + "--add-exports=java.base/sun.nio.ch=ALL-UNNAMED", + "--add-opens=java.base/java.net=ALL-UNNAMED", + ) } } diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index aaf08ba07..e21cecb9a 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -29,12 +29,13 @@ import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOute import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, LeafRunnableCommand} +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, CreateTableCommand, DataWritingCommand, DropTableCommand, LeafRunnableCommand} import org.apache.spark.sql.execution.datasources.{FileFormat => SparkFileFormat, HadoopFsRelation, InMemoryFileIndex, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1Writes} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable} +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} import io.substrait.`type`.{NamedStruct, StringTypeVisitor, Type} @@ -42,7 +43,8 @@ import io.substrait.{expression => exp} import io.substrait.expression.{Expression => SExpression} import io.substrait.plan.Plan import io.substrait.relation -import io.substrait.relation.{ExtensionWrite, LocalFiles, NamedWrite} +import io.substrait.relation.{ExtensionWrite, LocalFiles, NamedDdl, NamedWrite} +import 
io.substrait.relation.AbstractDdlRel.{DdlObject, DdlOp} import io.substrait.relation.AbstractWriteRel.{CreateMode, WriteOp} import io.substrait.relation.Expand.{ConsistentField, SwitchingField} import io.substrait.relation.Set.SetOp @@ -50,6 +52,8 @@ import io.substrait.relation.files.FileFormat import io.substrait.util.EmptyVisitationContext import org.apache.hadoop.fs.Path +import java.net.URI + import scala.collection.JavaConverters.asScalaBufferConverter import scala.collection.mutable.ArrayBuffer @@ -437,35 +441,44 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) override def visit(write: NamedWrite, context: EmptyVisitationContext): LogicalPlan = { val child = write.getInput.accept(this, context) - - val (table, database, catalog) = write.getNames.asScala match { - case Seq(table) => (table, None, None) - case Seq(database, table) => (table, Some(database), None) - case Seq(catalog, database, table) => (table, Some(database), Some(catalog)) - case names => - throw new UnsupportedOperationException( - s"NamedWrite requires up to three names ([[catalog,] database,] table): $names") + val table = catalogTable(write.getNames.asScala) + val isHive = spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION.key) match { + case "hive" => true + case _ => false } - val id = TableIdentifier(table, database, catalog) - val catalogTable = CatalogTable( - id, - CatalogTableType.MANAGED, - CatalogStorageFormat.empty, - new StructType(), - Some("parquet") - ) write.getOperation match { case WriteOp.CTAS => withChild(child) { - CreateDataSourceTableAsSelectCommand( - catalogTable, - saveMode(write.getCreateMode), + if (isHive) { + CreateHiveTableAsSelectCommand( + table, + child, + write.getTableSchema.names().asScala, + saveMode(write.getCreateMode) + ) + } else { + CreateDataSourceTableAsSelectCommand( + table, + saveMode(write.getCreateMode), + child, + write.getTableSchema.names().asScala + ) + } + } + case WriteOp.INSERT => + withChild(child) { + InsertIntoHiveTable( + table, + Map.empty, child, + write.getCreateMode == CreateMode.REPLACE_IF_EXISTS, + false, write.getTableSchema.names().asScala ) } case op => throw new UnsupportedOperationException(s"Write mode $op not supported") } + } override def visit(write: ExtensionWrite, context: EmptyVisitationContext): LogicalPlan = { @@ -491,14 +504,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) val (format, options) = convertFileFormat(file.getFileFormat.get) val name = file.getPath.get.split('/').reverse.head - val id = TableIdentifier(name) - val table = CatalogTable( - id, - CatalogTableType.MANAGED, - CatalogStorageFormat.empty, - new StructType(), - None - ) + val table = catalogTable(Seq(name)) withChild(child) { V1Writes.apply( @@ -519,6 +525,49 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } + override def visit(ddl: NamedDdl, context: EmptyVisitationContext): LogicalPlan = { + val table = catalogTable(ddl.getNames.asScala, ToSparkType.toStructType(ddl.getTableSchema)) + + (ddl.getOperation, ddl.getObject) match { + case (DdlOp.CREATE, DdlObject.TABLE) => CreateTableCommand(table, false) + case (DdlOp.DROP, DdlObject.TABLE) => DropTableCommand(table.identifier, false, false, false) + case (DdlOp.DROP_IF_EXIST, DdlObject.TABLE) => + DropTableCommand(table.identifier, true, false, false) + case op => throw new UnsupportedOperationException(s"Ddl operation $op not supported") + } + } + + private def catalogTable( + names: Seq[String], + 
schema: StructType = new StructType()): CatalogTable = { + val (table, database, catalog) = names match { + case Seq(table) => (table, None, None) + case Seq(database, table) => (table, Some(database), None) + case Seq(catalog, database, table) => (table, Some(database), Some(catalog)) + case names => + throw new UnsupportedOperationException( + s"NamedWrite requires up to three names ([[catalog,] database,] table): $names") + } + + val loc = spark.conf.get(StaticSQLConf.WAREHOUSE_PATH.key) + val storage = CatalogStorageFormat( + locationUri = Some(URI.create(f"$loc/$table")), + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + properties = Map.empty + ) + val id = TableIdentifier(table, database, catalog) + CatalogTable( + id, + CatalogTableType.MANAGED, + storage, + schema, + Some("parquet") + ) + } + private def saveMode(mode: CreateMode): SaveMode = mode match { case CreateMode.APPEND_IF_EXISTS => SaveMode.Append case CreateMode.REPLACE_IF_EXISTS => SaveMode.Overwrite diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 406fce6cb..5bee8873d 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -22,27 +22,31 @@ import io.substrait.spark.expression._ import org.apache.spark.internal.Logging import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.ResolvedIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Sum} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.LogicalRDD -import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, CreateTableCommand, DropTableCommand} import org.apache.spark.sql.execution.datasources.{FileFormat => DSFileFormat, HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1WriteCommand, WriteFiles} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, V2SessionCatalog} +import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable} import org.apache.spark.sql.types.{NullType, StructType} import io.substrait.`type`.{NamedStruct, Type} import io.substrait.{proto, relation} import io.substrait.debug.TreePrinter import io.substrait.expression.{Expression => SExpression, ExpressionCreator} +import io.substrait.expression.Expression.StructLiteral import io.substrait.extension.ExtensionCollector import io.substrait.hint.Hint import io.substrait.plan.Plan +import io.substrait.relation.AbstractDdlRel.{DdlObject, DdlOp} import io.substrait.relation.AbstractWriteRel.{CreateMode, OutputMode, WriteOp} import io.substrait.relation.RelProtoConverter import io.substrait.relation.Set.SetOp 
@@ -54,7 +58,7 @@ import io.substrait.utils.Util import java.util import java.util.{Collections, Optional} -import scala.collection.JavaConverters.asJavaIterableConverter +import scala.collection.JavaConverters.{asJavaIterableConverter, seqAsJavaList} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -75,9 +79,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { override def default(p: LogicalPlan): relation.Rel = p match { case c: CommandResult => visit(c.commandLogicalPlan) case w: WriteFiles => visit(w.child) - case c: V1WriteCommand => convertDataWritingCommand(c) - case CreateDataSourceTableAsSelectCommand(table, mode, query, names) => - convertCTAS(table, mode, query, names) + case c: Command => convertCommand(c) case p: LeafNode => convertReadOperator(p) case s: SubqueryAlias => visit(s.child) case v: View => visit(v.child) @@ -566,6 +568,28 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } } + private def convertCommand(command: Command): relation.Rel = command match { + case c: V1WriteCommand => convertDataWritingCommand(c) + case CreateDataSourceTableAsSelectCommand(table, mode, query, names) => + convertCTAS(table, mode, query, names) + case CreateHiveTableAsSelectCommand(table, query, names, mode) => + convertCTAS(table, mode, query, names) + case CreateTableCommand(table, _) => + convertCreateTable(table.identifier.unquotedString.split("\\."), table.schema) + case DropTableCommand(tableName, ifExists, _, _) => + convertDropTable(tableName.unquotedString.split("\\."), ifExists) + case CreateTable(ResolvedIdentifier(c: V2SessionCatalog, id), tableSchema, _, _, _) + if id.namespace().length > 0 => + val names = Seq(c.name(), id.namespace()(0), id.name()) + convertCreateTable(names, tableSchema) + case DropTable(ResolvedIdentifier(c: V2SessionCatalog, id), ifExists, _) + if id.namespace().length > 0 => + val names = Seq(c.name(), id.namespace()(0), id.name()) + convertDropTable(names, ifExists) + case _ => + throw new UnsupportedOperationException(s"Unable to convert command: $command") + } + private def convertDataWritingCommand(command: V1WriteCommand): relation.AbstractWriteRel = command match { case InsertIntoHadoopFsRelationCommand( @@ -600,6 +624,16 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { .tableSchema(outputSchema(child.output, outputColumnNames)) .detail(FileHolder(file)) .build() + case InsertIntoHiveTable(table, _, child, overwrite, _, outputColumnNames, _, _, _, _, _) => + relation.NamedWrite + .builder() + .input(visit(child)) + .operation(WriteOp.INSERT) + .outputMode(OutputMode.UNSPECIFIED) + .createMode(if (overwrite) CreateMode.REPLACE_IF_EXISTS else CreateMode.ERROR_IF_EXISTS) + .names(seqAsJavaList(table.identifier.unquotedString.split("\\.").toList)) + .tableSchema(outputSchema(child.output, outputColumnNames)) + .build() case _ => throw new UnsupportedOperationException(s"Unable to convert command: ${command.getClass}") } @@ -619,6 +653,29 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { .tableSchema(outputSchema(query.output, outputColumnNames)) .build() + private def convertCreateTable(names: Seq[String], schema: StructType): relation.NamedDdl = { + relation.NamedDdl + .builder() + .operation(DdlOp.CREATE) + .`object`(DdlObject.TABLE) + .names(seqAsJavaList(names)) + .tableSchema(ToSubstraitType.toNamedStruct(schema)) + .tableDefaults(StructLiteral.builder.nullable(true).build()) + .build() + } + + private def 
convertDropTable(names: Seq[String], ifExists: Boolean): relation.NamedDdl = { + relation.NamedDdl + .builder() + .operation(if (ifExists) DdlOp.DROP_IF_EXIST else DdlOp.DROP) + .`object`(DdlObject.TABLE) + .names(seqAsJavaList(names)) + .tableSchema( + NamedStruct.builder().struct(Type.Struct.builder().nullable(true).build()).build()) + .tableDefaults(StructLiteral.builder.nullable(true).build()) + .build() + } + private def createMode(mode: SaveMode): CreateMode = mode match { case SaveMode.Append => CreateMode.APPEND_IF_EXISTS case SaveMode.Overwrite => CreateMode.REPLACE_IF_EXISTS diff --git a/spark/src/test/scala/io/substrait/spark/HiveTableSuite.scala b/spark/src/test/scala/io/substrait/spark/HiveTableSuite.scala new file mode 100644 index 000000000..edb2acd96 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/HiveTableSuite.scala @@ -0,0 +1,113 @@ +package io.substrait.spark + +import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +import io.substrait.extension.ExtensionLookup +import io.substrait.plan.{PlanProtoConverter, ProtoPlanConverter} +import io.substrait.relation.ProtoRelConverter + +class HiveTableSuite extends SparkFunSuite { + var spark: SparkSession = + SparkSession.builder().config("spark.master", "local").enableHiveSupport().getOrCreate() + + override def beforeAll(): Unit = { + super.beforeAll() + spark.sparkContext.setLogLevel("WARN") + + // introduced in spark 3.4 + spark.conf.set("spark.sql.readSideCharPadding", "false") + spark.conf.set("spark.sql.legacy.createHiveTableByDefault", "false") + } + + def assertRoundTripExtension(plan: LogicalPlan): LogicalPlan = { + val toSubstrait = new ToSubstraitRel + val substraitPlan = toSubstrait.convert(plan) + + // Serialize to proto buffer + val bytes = new PlanProtoConverter() + .toProto(substraitPlan) + .toByteArray + + // Read it back + val protoPlan = io.substrait.proto.Plan + .parseFrom(bytes) + val converter = new ProtoPlanConverter { + override protected def getProtoRelConverter( + functionLookup: ExtensionLookup): ProtoRelConverter = { + new FileHolderHandlingProtoRelConverter(functionLookup) + } + } + val substraitPlan2 = converter.from(protoPlan) + val sparkPlan2 = new ToLogicalPlan(spark).convert(substraitPlan2) + val roundTrippedPlan = toSubstrait.convert(sparkPlan2) + assertResult(substraitPlan)(roundTrippedPlan) + + sparkPlan2 + } + + test("Create / Drop table") { + spark.sql("drop table if exists cdtest") + + val create = spark.sql("create table cdtest(ID int, VALUE string) using hive") + assertResult(true)(spark.catalog.tableExists("cdtest")) + + val drop = spark.sql("drop table cdtest") + assertResult(false)(spark.catalog.tableExists("cdtest")) + + // convert the plans to Substrait and back + val cPlan = assertRoundTripExtension(create.queryExecution.optimizedPlan) + val dPlan = assertRoundTripExtension(drop.queryExecution.optimizedPlan) + + // execute the round-tripped 'create' plan and assert the table exists + spark.sessionState.executePlan(cPlan).executedPlan.execute() + assertResult(true)(spark.catalog.tableExists("cdtest")) + + // execute the round-tripped 'drop' plan and assert the table no longer exists + spark.sessionState.executePlan(dPlan).executedPlan.execute() + assertResult(false)(spark.catalog.tableExists("cdtest")) + } + + test("Insert into Hive table") { + // create a Hive table with 2 rows + spark.sql("drop table if exists 
test") + spark.sql("create table test(ID int, VALUE string) using hive") + spark.sql("insert into test values(1001, 'hello')") + spark.sql("insert into test values(1002, 'world')") + assertResult(2)(spark.sql("select * from test").count()) + + // insert a new row - and capture the query plan + val insert = spark.sql("insert into test values(1003, 'again')") + // there are now 3 rows + assertResult(3)(spark.sql("select * from test").count()) + + // convert the plan to Substrait and back + val plan = assertRoundTripExtension(insert.queryExecution.optimizedPlan) + // this should not have affected the table (still 3 rows) + assertResult(3)(spark.sql("select * from test").count()) + + // now execute the round-tripped plan and assert an extra row is appended + spark.sessionState.executePlan(plan).executedPlan.execute() + assertResult(4)(spark.sql("select * from test").count()) + // and again... + spark.sessionState.executePlan(plan).executedPlan.execute() + assertResult(5)(spark.sql("select * from test").count()) + } + + test("CTAS") { + spark.sql("drop table if exists ctas") + val df = spark.sql( + "create table ctas using hive as select * from (values (1, 'a'), (2, 'b') as table(col1, col2))") + assertResult(2)(spark.sql("select * from ctas").count()) + + spark.sql("drop table ctas") + + val plan = assertRoundTripExtension(df.queryExecution.optimizedPlan) + spark.sessionState.executePlan(plan).executedPlan.execute() + assertResult(2)(spark.sql("select * from ctas").count()) + } + +}