@@ -21,12 +21,12 @@ package org.apache.comet.serde
 
 import scala.math.min
 
-import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Subtract}
+import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EmptyRow, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Round, Subtract}
 import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.expressions.CometEvalMode
-import org.apache.comet.serde.QueryPlanSerde.{castToProto, evalModeToProto, exprToProtoInternal, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde.{castToProto, evalModeToProto, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType, serializeDataType}
 import org.apache.comet.shims.CometEvalModeUtil
 
 trait MathBase {
@@ -261,3 +261,50 @@ object CometRemainder extends CometExpressionSerde[Remainder] with MathBase {
       (builder, mathExpr) => builder.setRemainder(mathExpr))
   }
 }
+
+object CometRound extends CometExpressionSerde[Round] {
+
+  override def convert(
+      r: Round,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    // _scale is a constant; copied from Spark's RoundBase because it is a protected val
+    val scaleV: Any = r.scale.eval(EmptyRow)
+    val _scale: Int = scaleV.asInstanceOf[Int]
+
+    lazy val childExpr = exprToProtoInternal(r.child, inputs, binding)
+    r.child.dataType match {
+      case t: DecimalType if t.scale < 0 => // Spark disallows negative scale, see SPARK-30252
+        withInfo(r, "Decimal type has negative scale")
+        None
+      case _ if scaleV == null =>
+        exprToProtoInternal(Literal(null), inputs, binding)
+      case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
+        childExpr // _scale (i.e. decimal places) >= 0 is a no-op for integer types in Spark
+      case _: FloatType | DoubleType =>
+        // We cannot exactly match Spark's behavior for floating-point numbers. Spark uses
+        // BigDecimal for rounding float/double, and BigDecimal first converts the double to
+        // a string internally in order to create its own internal representation. The problem
+        // is that BigDecimal uses java.lang.Double.toString(), which has a complicated
+        // rounding algorithm. E.g. -5.81855622136895E8 is actually
+        // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
+        // 5. Java's (and Scala's) toString() rounds it up to -581855622.136895. This makes a
+        // difference when rounding at the 5th digit, i.e. round(-5.81855622136895E8, 5)
+        // should be -5.818556221369E8 instead of -5.8185562213689E8. There is also an example
+        // where toString() does NOT round up: 6.1317116247283497E18 is 6131711624728349696.
+        // It could be rounded up to 6.13171162472835E18, which still represents the same
+        // double number, i.e. 6.13171162472835E18 == 6.1317116247283497E18, but toString()
+        // does not. That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18
+        // instead of 6.1317116247283999E18.
+        withInfo(r, "Comet does not support Spark's BigDecimal rounding")
+        None
+      case _ =>
+        // `scale` must be Int64 type in DataFusion
+        val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding)
+        val optExpr =
+          scalarFunctionExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
+        optExprWithInfo(optExpr, r, r.child)
+    }
+
+  }
+}
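For reference, a minimal standalone sketch (not part of the patch) reproducing the first example from the comment above. It assumes HALF_UP rounding, the mode Spark's `round` uses: `java.math.BigDecimal.valueOf(double)` is defined in terms of `Double.toString()`, while the raw `BigDecimal(double)` constructor keeps the exact binary value, so the two disagree at the 5th decimal place.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object BigDecimalRoundingDemo extends App {
  val d = -5.81855622136895e8

  // The exact binary value of the double: the 5th fractional digit is 4, not 5.
  println(new JBigDecimal(d)) // -581855622.13689494132995605468750

  // valueOf goes through Double.toString(), whose shortest round-trippable
  // representation rounds that digit up.
  println(JBigDecimal.valueOf(d)) // -581855622.136895

  // Rounding at 5 decimal places therefore diverges:
  println(new JBigDecimal(d).setScale(5, RoundingMode.HALF_UP)) // -581855622.13689
  println(JBigDecimal.valueOf(d).setScale(5, RoundingMode.HALF_UP)) // -581855622.13690
}
```

Since a native round operates on the binary value directly (the first behavior), the serde returns `None` for float/double here and lets Spark handle it rather than producing results that differ from Spark's.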