feat(spark): add MakeDecimal support (#298)
The Spark query optimiser injects an internal function
(MakeDecimal) when numeric literals appear in a query.
This commit adds support for that function, which drastically
improves the pass rate of the TPC-DS test suite.
andrew-coleman authored Oct 2, 2024
1 parent f22b3d0 commit eec9727
Showing 5 changed files with 53 additions and 23 deletions.
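
For background: MakeDecimal is Spark's internal catalyst expression that reinterprets an unscaled Long as a decimal with a given precision and scale; the optimiser emits it (for instance via the DecimalAggregates rule) when lowering narrow decimal arithmetic onto plain longs. A minimal sketch of its semantics, illustrative only and not part of this commit:

import org.apache.spark.sql.catalyst.expressions.{Literal, MakeDecimal}

// An unscaled Long plus (precision, scale) denotes a decimal value:
// 12345L read as DECIMAL(5,2) is 123.45.
val dec = MakeDecimal(Literal(12345L), 5, 2)
println(dec.eval()) // 123.45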
9 changes: 9 additions & 0 deletions spark/src/main/resources/spark.yml
@@ -32,3 +32,12 @@ scalar_functions:
   - args:
       - value: DECIMAL<P,S>
     return: i64
+  -
+    name: make_decimal
+    description: >-
+      Return the Decimal value of an unscaled Long.
+      Note: this expression is internal and created only by the optimizer.
+    impls:
+      - args:
+          - value: i64
+        return: DECIMAL<P,S>
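
The new entry complements the existing unscaled function directly above it: over in-range values the two are inverses. A quick illustrative round trip (assumed usage, not taken from the commit):

import org.apache.spark.sql.catalyst.expressions.{Literal, MakeDecimal, UnscaledValue}
import org.apache.spark.sql.types.Decimal

// unscaled(value) extracts the underlying Long from a decimal;
// make_decimal re-applies the precision and scale to recover it.
val original = Literal(Decimal("123.45")) // DECIMAL(5,2), unscaled value 12345L
val roundTrip = MakeDecimal(UnscaledValue(original), 5, 2)
println(roundTrip.eval()) // 123.45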
4 changes: 3 additions & 1 deletion spark/src/main/scala/io/substrait/spark/SparkExtension.scala
@@ -26,8 +26,10 @@ import scala.collection.JavaConverters
 import scala.collection.JavaConverters.asScalaBufferConverter
 
 object SparkExtension {
+  final val uri = "/spark.yml"
+
   private val SparkImpls: SimpleExtension.ExtensionCollection =
-    SimpleExtension.load(Collections.singletonList("/spark.yml"))
+    SimpleExtension.load(Collections.singletonList(uri))
 
   private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection =
     SimpleExtension.loadDefaults()
@@ -58,6 +58,7 @@ class FunctionMappings {
     s[Year]("year"),
 
     // internal
+    s[MakeDecimal]("make_decimal"),
     s[UnscaledValue]("unscaled")
   )
 
@@ -16,10 +16,10 @@
  */
 package io.substrait.spark.expression
 
-import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, ToSubstraitType}
+import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
 import io.substrait.spark.logical.ToLogicalPlan
-import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, NamedExpression, ScalarSubquery}
-import org.apache.spark.sql.types.{Decimal, NullType}
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.UTF8String
 import io.substrait.`type`.{StringTypeVisitor, Type}
 import io.substrait.{expression => exp}

@@ -131,23 +131,32 @@ class ToSparkExpression(
         arg.accept(expr.declaration(), i, this)
     }
 
-    scalarFunctionConverter
-      .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
-      .flatMap(sig => Option(sig.makeCall(args)))
-      .getOrElse({
-        val msg = String.format(
-          "Unable to convert scalar function %s(%s).",
-          expr.declaration.name,
-          expr.arguments.asScala
-            .map {
-              case ea: exp.EnumArg => ea.value.toString
-              case e: SExpression => e.getType.accept(new StringTypeVisitor)
-              case t: Type => t.accept(new StringTypeVisitor)
-              case a => throw new IllegalStateException("Unexpected value: " + a)
-            }
-            .mkString(", ")
-        )
-        throw new IllegalArgumentException(msg)
-      })
+    expr.declaration.name match {
+      case "make_decimal" if expr.declaration.uri == SparkExtension.uri => expr.outputType match {
+          // Need special-case handling of this internal function.
+          // Because the precision and scale arguments are extracted from the output type,
+          // we can't use the generic scalar function conversion mechanism here.
+          case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale)
+          case _ => throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type")
+        }
+      case _ => scalarFunctionConverter
+          .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
+          .flatMap(sig => Option(sig.makeCall(args)))
+          .getOrElse({
+            val msg = String.format(
+              "Unable to convert scalar function %s(%s).",
+              expr.declaration.name,
+              expr.arguments.asScala
+                .map {
+                  case ea: exp.EnumArg => ea.value.toString
+                  case e: SExpression => e.getType.accept(new StringTypeVisitor)
+                  case t: Type => t.accept(new StringTypeVisitor)
+                  case a => throw new IllegalStateException("Unexpected value: " + a)
+                }
+                .mkString(", ")
+            )
+            throw new IllegalArgumentException(msg)
+          })
+    }
   }
 }
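
The special case exists because make_decimal carries its precision and scale only in the declared output type, never as call arguments, so the generic signature-driven converter has nothing to bind them to. A sketch of the mapping, assuming substrait-java's TypeCreator factory for constructing the output type:

import io.substrait.`type`.{Type, TypeCreator}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, MakeDecimal}

// Hypothetical invocation: make_decimal(x: i64) -> DECIMAL<10,2>.
// Precision and scale are recovered from the output type alone.
val outputType: Type = TypeCreator.NULLABLE.decimal(10, 2)
val converted: Expression = outputType match {
  case d: Type.Decimal => MakeDecimal(Literal(123456L), d.precision, d.scale)
  case _ => throw new IllegalArgumentException("expected a decimal output type")
}
println(converted.eval()) // 1234.56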
11 changes: 10 additions & 1 deletion spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
@@ -32,7 +32,16 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
   }
 
   // "q9" failed in spark 3.3
-  val successfulSQL: Set[String] = Set("q4", "q7", "q18", "q22", "q26", "q28", "q29", "q37", "q41", "q48", "q50", "q62", "q69", "q82", "q85", "q88", "q90", "q93", "q96", "q97", "q99")
+  val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7",
+    "q11", "q13", "q15", "q16", "q18", "q19",
+    "q22", "q25", "q26", "q28", "q29",
+    "q30", "q31", "q32", "q37",
+    "q41", "q42", "q43", "q46", "q48",
+    "q50", "q52", "q55", "q58", "q59",
+    "q61", "q62", "q65", "q68", "q69",
+    "q79",
+    "q81", "q82", "q85", "q88",
+    "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q99")
 
   tpcdsQueries.foreach {
     q =>
