Skip to content

Commit e4bcd9d

Browse files
committed
refactor cast
1 parent 5472697 commit e4bcd9d

File tree

5 files changed

+29
-10
lines changed

5 files changed

+29
-10
lines changed

spark/src/main/scala/org/apache/comet/GenerateDocs.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ import scala.collection.mutable.ListBuffer
2525

2626
import org.apache.spark.sql.catalyst.expressions.Cast
2727

28-
import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible}
28+
import org.apache.comet.expressions.CometEvalMode
29+
import org.apache.comet.serde.{CometCast, Compatible, Incompatible}
2930

3031
/**
3132
* Utility for generating markdown documentation from the configs.

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala renamed to spark/src/main/scala/org/apache/comet/serde/CometCast.scala

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,14 @@
1717
* under the License.
1818
*/
1919

20-
package org.apache.comet.expressions
20+
package org.apache.comet.serde
2121

22-
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType}
22+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast}
23+
import org.apache.spark.sql.types._
24+
25+
import org.apache.comet.expressions.CometEvalMode
26+
import org.apache.comet.serde.QueryPlanSerde.handleCast
27+
import org.apache.comet.shims.CometExprShim
2328

2429
sealed trait SupportLevel
2530

@@ -32,7 +37,20 @@ case class Incompatible(notes: Option[String] = None) extends SupportLevel
3237
/** We do not support this feature */
3338
object Unsupported extends SupportLevel
3439

35-
object CometCast {
40+
object CometCast
41+
extends CometExpressionSerde[Cast]
42+
with CometExprShim
43+
with IncompatExpr
44+
with IncompatAnsiExpr[Cast] {
45+
46+
override def isAnsiMode(expr: Cast): Boolean = expr.ansiEnabled
47+
48+
override def convert(
49+
expr: Cast,
50+
inputs: Seq[Attribute],
51+
binding: Boolean): Option[ExprOuterClass.Expr] = {
52+
handleCast(expr, expr.child, inputs, binding, expr.dataType, expr.timeZoneId, evalMode(expr))
53+
}
3654

3755
def supportedTypes: Seq[DataType] =
3856
Seq(

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
169169
classOf[DateAdd] -> CometDateAdd,
170170
classOf[DateSub] -> CometDateSub,
171171
classOf[TruncDate] -> CometTruncDate,
172-
classOf[TruncTimestamp] -> CometTruncTimestamp)
172+
classOf[TruncTimestamp] -> CometTruncTimestamp,
173+
classOf[Cast] -> CometCast)
173174

174175
/**
175176
* Mapping of Spark aggregate expression class to Comet expression handler.
@@ -689,9 +690,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
689690
Some(timeZoneId),
690691
CometEvalMode.TRY)
691692

692-
case c @ Cast(child, dt, timeZoneId, _) =>
693-
handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c))
694-
695693
case EqualTo(left, right) =>
696694
createBinaryExpr(
697695
expr,

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ import org.apache.spark.sql.internal.SQLConf
3333
import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructField, StructType}
3434

3535
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
36-
import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible}
36+
import org.apache.comet.expressions.CometEvalMode
37+
import org.apache.comet.serde.{CometCast, Compatible}
3738

3839
class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
3940

spark/src/test/spark-3.5/org/apache/spark/sql/CometToPrettyStringSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
package org.apache.spark.sql
2121

2222
import org.apache.comet.CometConf
23-
import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible}
23+
import org.apache.comet.expressions.CometEvalMode
24+
import org.apache.comet.serde.{CometCast, Compatible}
2425
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
2526
import org.apache.commons.io.FileUtils
2627
import org.apache.spark.sql.catalyst.TableIdentifier

0 commit comments

Comments
 (0)