diff --git a/spark/src/main/resources/spark.yml b/spark/src/main/resources/spark.yml index e398aa3a6..4f610ded7 100644 --- a/spark/src/main/resources/spark.yml +++ b/spark/src/main/resources/spark.yml @@ -32,3 +32,12 @@ scalar_functions: - args: - value: DECIMAL
return: i64 + - + name: make_decimal + description: >- + Return the Decimal value of an unscaled Long. + Note: this expression is internal and created only by the optimizer, + impls: + - args: + - value: i64 + return: DECIMAL
diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index 0d6d84b71..d61a06d3e 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -26,8 +26,10 @@ import scala.collection.JavaConverters import scala.collection.JavaConverters.asScalaBufferConverter object SparkExtension { + final val uri = "/spark.yml" + private val SparkImpls: SimpleExtension.ExtensionCollection = - SimpleExtension.load(Collections.singletonList("/spark.yml")) + SimpleExtension.load(Collections.singletonList(uri)) private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = SimpleExtension.loadDefaults() diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala index 4bcae8dd9..27eebfc67 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -58,6 +58,7 @@ class FunctionMappings { s[Year]("year"), // internal + s[MakeDecimal]("make_decimal"), s[UnscaledValue]("unscaled") ) diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index e928689fa..2279d5496 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -16,10 +16,10 @@ */ package io.substrait.spark.expression -import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, ToSubstraitType} +import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType} import io.substrait.spark.logical.ToLogicalPlan -import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, NamedExpression, ScalarSubquery} -import org.apache.spark.sql.types.{Decimal, NullType} +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.types.Decimal import org.apache.spark.unsafe.types.UTF8String import io.substrait.`type`.{StringTypeVisitor, Type} import io.substrait.{expression => exp} @@ -131,23 +131,32 @@ class ToSparkExpression( arg.accept(expr.declaration(), i, this) } - scalarFunctionConverter - .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType()) - .flatMap(sig => Option(sig.makeCall(args))) - .getOrElse({ - val msg = String.format( - "Unable to convert scalar function %s(%s).", - expr.declaration.name, - expr.arguments.asScala - .map { - case ea: exp.EnumArg => ea.value.toString - case e: SExpression => e.getType.accept(new StringTypeVisitor) - case t: Type => t.accept(new StringTypeVisitor) - case a => throw new IllegalStateException("Unexpected value: " + a) - } - .mkString(", ") - ) - throw new IllegalArgumentException(msg) - }) + expr.declaration.name match { + case "make_decimal" if expr.declaration.uri == SparkExtension.uri => expr.outputType match { + // Need special case handing of this internal function. + // Because the precision and scale arguments are extracted from the output type, + // we can't use the generic scalar function conversion mechanism here. + case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale) + case _ => throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type") + } + case _ => scalarFunctionConverter + .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType()) + .flatMap(sig => Option(sig.makeCall(args))) + .getOrElse({ + val msg = String.format( + "Unable to convert scalar function %s(%s).", + expr.declaration.name, + expr.arguments.asScala + .map { + case ea: exp.EnumArg => ea.value.toString + case e: SExpression => e.getType.accept(new StringTypeVisitor) + case t: Type => t.accept(new StringTypeVisitor) + case a => throw new IllegalStateException("Unexpected value: " + a) + } + .mkString(", ") + ) + throw new IllegalArgumentException(msg) + }) + } } } diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index fd35b551a..f880b25a4 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -32,7 +32,16 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { } // "q9" failed in spark 3.3 - val successfulSQL: Set[String] = Set("q4", "q7", "q18", "q22", "q26", "q28", "q29", "q37", "q41", "q48", "q50", "q62", "q69", "q82", "q85", "q88", "q90", "q93", "q96", "q97", "q99") + val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", + "q11", "q13", "q15", "q16", "q18", "q19", + "q22", "q25", "q26", "q28", "q29", + "q30", "q31", "q32", "q37", + "q41", "q42", "q43", "q46", "q48", + "q50", "q52", "q55", "q58", "q59", + "q61", "q62", "q65", "q68", "q69", + "q79", + "q81", "q82", "q85", "q88", + "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q99") tpcdsQueries.foreach { q =>