Commit 6f526e8

Merge branch 'apache:master' into aggregation-memory-usage-optimization

akupchinskiy authored Dec 27, 2024
2 parents 44a263d + b309db0
Showing 108 changed files with 3,821 additions and 1,015 deletions.
Dataset.scala (Spark Connect Scala client)
@@ -42,7 +42,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions.{struct, to_json}
import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex}
import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SubqueryExpressionNode, SubqueryType, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types.{Metadata, StructType}
import org.apache.spark.storage.StorageLevel
@@ -288,9 +288,10 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())

private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = {
private def buildJoin(right: Dataset[_], cols: Seq[Column] = Seq.empty)(
f: proto.Join.Builder => Unit): DataFrame = {
checkSameSparkSession(right)
sparkSession.newDataFrame { builder =>
sparkSession.newDataFrame(cols) { builder =>
val joinBuilder = builder.getJoinBuilder
joinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
f(joinBuilder)
@@ -334,7 +335,7 @@

/** @inheritdoc */
def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = {
buildJoin(right) { builder =>
buildJoin(right, Seq(joinExprs)) { builder =>
builder
.setJoinType(toJoinType(joinType))
.setJoinCondition(joinExprs.expr)
@@ -394,7 +395,7 @@
case _ =>
throw new IllegalArgumentException(s"Unsupported lateral join type $joinType")
}
sparkSession.newDataFrame { builder =>
sparkSession.newDataFrame(joinExprs.toSeq) { builder =>
val lateralJoinBuilder = builder.getLateralJoinBuilder
lateralJoinBuilder.setLeft(plan.getRoot).setRight(right.plan.getRoot)
joinExprs.foreach(c => lateralJoinBuilder.setJoinCondition(c.expr))
@@ -426,7 +427,7 @@
val sortExprs = sortCols.map { c =>
ColumnNodeToProtoConverter(c.sortOrder).getSortOrder
}
sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder, sortCols) { builder =>
builder.getSortBuilder
.setInput(plan.getRoot)
.setIsGlobal(global)
@@ -502,37 +503,40 @@
* methods and typed select methods is the encoder used to build the return dataset.
*/
private def selectUntyped(encoder: AgnosticEncoder[_], cols: Seq[Column]): Dataset[_] = {
sparkSession.newDataset(encoder) { builder =>
sparkSession.newDataset(encoder, cols) { builder =>
builder.getProjectBuilder
.setInput(plan.getRoot)
.addAllExpressions(cols.map(_.typedExpr(this.encoder)).asJava)
}
}

/** @inheritdoc */
def filter(condition: Column): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
builder =>
def filter(condition: Column): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, Seq(condition)) { builder =>
builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
}
}

private def buildUnpivot(
ids: Array[Column],
valuesOption: Option[Array[Column]],
variableColumnName: String,
valueColumnName: String): DataFrame = sparkSession.newDataFrame { builder =>
val unpivot = builder.getUnpivotBuilder
.setInput(plan.getRoot)
.addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
.setVariableColumnName(variableColumnName)
.setValueColumnName(valueColumnName)
valuesOption.foreach { values =>
unpivot.getValuesBuilder
.addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
valueColumnName: String): DataFrame = {
sparkSession.newDataFrame(ids.toSeq ++ valuesOption.toSeq.flatten) { builder =>
val unpivot = builder.getUnpivotBuilder
.setInput(plan.getRoot)
.addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
.setVariableColumnName(variableColumnName)
.setValueColumnName(valueColumnName)
valuesOption.foreach { values =>
unpivot.getValuesBuilder
.addAllValues(values.toImmutableArraySeq.map(_.expr).asJava)
}
}
}

private def buildTranspose(indices: Seq[Column]): DataFrame =
sparkSession.newDataFrame { builder =>
sparkSession.newDataFrame(indices) { builder =>
val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
indices.foreach { indexColumn =>
transpose.addIndexColumns(indexColumn.expr)
@@ -624,18 +628,15 @@ class Dataset[T] private[sql] (
def transpose(): DataFrame =
buildTranspose(Seq.empty)

// TODO(SPARK-50134): Support scalar Subquery API in Spark Connect
// scalastyle:off not.implemented.error.usage
/** @inheritdoc */
def scalar(): Column = {
???
Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.SCALAR))
}

/** @inheritdoc */
def exists(): Column = {
???
Column(SubqueryExpressionNode(plan.getRoot, SubqueryType.EXISTS))
}
// scalastyle:on not.implemented.error.usage

/** @inheritdoc */
def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
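
For context only (not part of this commit): a minimal usage sketch of the scalar() and exists() APIs that this hunk implements in place of the previous `???` stubs. It assumes an existing Spark Connect SparkSession named `spark`; the table and column names are hypothetical.

// --- illustrative example, not part of this commit ---
import org.apache.spark.sql.functions.{col, max}

val customers = spark.table("customers")   // hypothetical table
val orders    = spark.table("orders")      // hypothetical table

// Scalar subquery: a single-value Dataset used as a Column.
val maxOrder = orders.agg(max(col("amount"))).scalar()
val withMax  = customers.select(col("id"), maxOrder.as("max_order_amount"))

// EXISTS subquery: the Dataset becomes a boolean predicate Column.
val withAnyOrder = customers.filter(orders.exists())
// --- end example ---
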
@@ -782,7 +783,7 @@
val aliases = values.zip(names).map { case (value, name) =>
value.name(name).expr.getAlias
}
sparkSession.newDataFrame { builder =>
sparkSession.newDataFrame(values) { builder =>
builder.getWithColumnsBuilder
.setInput(plan.getRoot)
.addAllAliases(aliases.asJava)
@@ -842,10 +843,12 @@
@scala.annotation.varargs
def drop(col: Column, cols: Column*): DataFrame = buildDrop(col +: cols)

private def buildDrop(cols: Seq[Column]): DataFrame = sparkSession.newDataFrame { builder =>
builder.getDropBuilder
.setInput(plan.getRoot)
.addAllColumns(cols.map(_.expr).asJava)
private def buildDrop(cols: Seq[Column]): DataFrame = {
sparkSession.newDataFrame(cols) { builder =>
builder.getDropBuilder
.setInput(plan.getRoot)
.addAllColumns(cols.map(_.expr).asJava)
}
}

private def buildDropByNames(cols: Seq[String]): DataFrame = sparkSession.newDataFrame {
@@ -1015,12 +1018,13 @@

private def buildRepartitionByExpression(
numPartitions: Option[Int],
partitionExprs: Seq[Column]): Dataset[T] = sparkSession.newDataset(agnosticEncoder) {
builder =>
partitionExprs: Seq[Column]): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder, partitionExprs) { builder =>
val repartitionBuilder = builder.getRepartitionByExpressionBuilder
.setInput(plan.getRoot)
.addAllPartitionExprs(partitionExprs.map(_.expr).asJava)
numPartitions.foreach(repartitionBuilder.setNumPartitions)
}
}

/** @inheritdoc */
@@ -1152,7 +1156,7 @@
/** @inheritdoc */
@scala.annotation.varargs
def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = {
sparkSession.newDataset(agnosticEncoder) { builder =>
sparkSession.newDataset(agnosticEncoder, expr +: exprs) { builder =>
builder.getCollectMetricsBuilder
.setInput(plan.getRoot)
.setName(name)
RelationalGroupedDataset.scala (Spark Connect Scala client)
@@ -45,7 +45,7 @@ class RelationalGroupedDataset private[sql] (
import df.sparkSession.RichColumn

protected def toDF(aggExprs: Seq[Column]): DataFrame = {
df.sparkSession.newDataFrame { builder =>
df.sparkSession.newDataFrame(groupingExprs ++ aggExprs) { builder =>
val aggBuilder = builder.getAggregateBuilder
.setInput(df.plan.getRoot)
groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr))
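
For context only (not part of this commit): because `toDF` now forwards both the grouping and the aggregate Columns to `newDataFrame`, a subquery Column can appear among the aggregate expressions and its plan is attached as a reference. A sketch with hypothetical tables, assuming an existing `spark` session:

// --- illustrative example, not part of this commit ---
import org.apache.spark.sql.functions.{avg, col, sum}

val sales  = spark.table("sales")    // hypothetical table
val quotas = spark.table("quotas")   // hypothetical table

// The scalar subquery Column reaches newDataFrame via groupingExprs ++ aggExprs.
val perRegion = sales
  .groupBy(col("region"))
  .agg(
    sum(col("amount")).as("total"),
    quotas.agg(avg(col("target"))).scalar().as("global_avg_target"))
// --- end example ---
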
SparkSession.scala (Spark Connect Scala client)
@@ -46,7 +46,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig, SessionCleaner, SessionState, SharedState, SqlApiConf}
import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig, SessionCleaner, SessionState, SharedState, SqlApiConf, SubqueryExpressionNode}
import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming.DataStreamReader
@@ -324,20 +324,111 @@ class SparkSession private[sql] (
}
}

/**
* Create a DataFrame including the proto plan built by the given function.
*
* @param f
* The function to build the proto plan.
* @return
* The DataFrame created from the proto plan.
*/
@Since("4.0.0")
@DeveloperApi
def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = {
newDataset(UnboundRowEncoder)(f)
}

/**
* Create a DataFrame including the proto plan built by the given function.
*
* Use this method when columns are used to create a new DataFrame. When there are columns
* referring to another Dataset or DataFrame, the plan will be wrapped with a `WithRelations`.
*
* {{{
* with_relations [id 10]
* root: plan [id 9] using columns referring to other Dataset or DataFrame, holding plan ids
* reference:
* refs#1: [id 8] plan for the reference 1
* refs#2: [id 5] plan for the reference 2
* }}}
*
* @param cols
* The columns to be used in the DataFrame.
* @param f
* The function to build the proto plan.
* @return
* The DataFrame created from the proto plan.
*/
@Since("4.0.0")
@DeveloperApi
def newDataFrame(cols: Seq[Column])(f: proto.Relation.Builder => Unit): DataFrame = {
newDataset(UnboundRowEncoder, cols)(f)
}

/**
* Create a Dataset including the proto plan built by the given function.
*
* @param encoder
* The encoder for the Dataset.
* @param f
* The function to build the proto plan.
* @return
* The Dataset created from the proto plan.
*/
@Since("4.0.0")
@DeveloperApi
def newDataset[T](encoder: AgnosticEncoder[T])(
f: proto.Relation.Builder => Unit): Dataset[T] = {
newDataset[T](encoder, Seq.empty)(f)
}

/**
* Create a Dataset including the proto plan built by the given function.
*
* Use this method when columns are used to create a new Dataset. When there are columns
* referring to another Dataset or DataFrame, the plan will be wrapped with a `WithRelations`.
*
* {{{
* with_relations [id 10]
* root: plan [id 9] using columns referring to other Dataset or DataFrame, holding plan ids
* reference:
* refs#1: [id 8] plan for the reference 1
* refs#2: [id 5] plan for the reference 2
* }}}
*
* @param encoder
* The encoder for the Dataset.
* @param cols
* The columns to be used in the Dataset.
* @param f
* The function to build the proto plan.
* @return
* The Dataset created from the proto plan.
*/
@Since("4.0.0")
@DeveloperApi
def newDataset[T](encoder: AgnosticEncoder[T], cols: Seq[Column])(
f: proto.Relation.Builder => Unit): Dataset[T] = {
val references = cols.flatMap(_.node.collect { case n: SubqueryExpressionNode =>
n.relation
})

val builder = proto.Relation.newBuilder()
f(builder)
builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
val plan = proto.Plan.newBuilder().setRoot(builder).build()

val rootBuilder = if (references.length == 0) {
builder
} else {
val rootBuilder = proto.Relation.newBuilder()
rootBuilder.getWithRelationsBuilder
.setRoot(builder)
.addAllReferences(references.asJava)
rootBuilder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
rootBuilder
}

val plan = proto.Plan.newBuilder().setRoot(rootBuilder).build()
new Dataset[T](this, plan, encoder)
}

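
For context only (not part of this commit): a sketch of what the new `newDataset`/`newDataFrame` overloads produce when one of the passed Columns carries a subquery. It assumes an existing `spark` session; the table names and the plan ids in the comment are invented.

// --- illustrative example, not part of this commit ---
import org.apache.spark.sql.functions.{col, max}

val other     = spark.table("other_table")           // hypothetical table
val scalarCol = other.agg(max(col("x"))).scalar()    // carries a SubqueryExpressionNode

// Dataset.select forwards its Columns to newDataFrame, which collects the
// subquery's proto relation and wraps the Project roughly like this:
//
//   with_relations [id 11]
//     root: Project [id 10]  -- the subquery expression refers to plan id 7
//     references:
//       [id 7] Aggregate(max(x)) over other_table
//
val df = spark.table("main_table").select(col("id"), scalarCol)
// --- end example ---
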
TableValuedFunction.scala (Spark Connect Scala client)
@@ -47,7 +47,7 @@ class TableValuedFunction(sparkSession: SparkSession) extends api.TableValuedFun
}

private def fn(name: String, args: Seq[Column]): Dataset[Row] = {
sparkSession.newDataFrame { builder =>
sparkSession.newDataFrame(args) { builder =>
builder.getUnresolvedTableValuedFunctionBuilder
.setFunctionName(name)
.addAllArguments(args.map(toExpr).asJava)
ColumnNodeToProtoConverter.scala (Spark Connect Scala client)
@@ -167,6 +167,15 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
case LazyExpression(child, _) =>
builder.getLazyExpressionBuilder.setChild(apply(child, e))

case SubqueryExpressionNode(relation, subqueryType, _) =>
val b = builder.getSubqueryExpressionBuilder
b.setSubqueryType(subqueryType match {
case SubqueryType.SCALAR => proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_SCALAR
case SubqueryType.EXISTS => proto.SubqueryExpression.SubqueryType.SUBQUERY_TYPE_EXISTS
})
assert(relation.hasCommon && relation.getCommon.hasPlanId)
b.setPlanId(relation.getCommon.getPlanId)

case ProtoColumnNode(e, _) =>
return e

@@ -217,4 +226,24 @@ case class ProtoColumnNode(
override val origin: Origin = CurrentOrigin.get)
extends ColumnNode {
override def sql: String = expr.toString
override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}

sealed trait SubqueryType

object SubqueryType {
case object SCALAR extends SubqueryType
case object EXISTS extends SubqueryType
}

case class SubqueryExpressionNode(
relation: proto.Relation,
subqueryType: SubqueryType,
override val origin: Origin = CurrentOrigin.get)
extends ColumnNode {
override def sql: String = subqueryType match {
case SubqueryType.SCALAR => s"($relation)"
case _ => s"$subqueryType ($relation)"
}
override private[internal] def children: Seq[ColumnNodeLike] = Seq.empty
}
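
For context only (not part of this commit): a sketch of how `ColumnNodeToProtoConverter` serializes a `SubqueryExpressionNode`, assuming the `org.apache.spark.sql.internal` types above are accessible from the calling code. The relation is built by hand purely to show the plan-id handshake; in practice it comes from a Dataset's `plan.getRoot`, and the proto getters named in the comments are assumed to follow the usual protobuf conventions.

// --- illustrative example, not part of this commit ---
import org.apache.spark.connect.proto
import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, SubqueryExpressionNode, SubqueryType}

val relation = proto.Relation.newBuilder()
relation.getCommonBuilder.setPlanId(7L)             // plan id normally assigned by the session
relation.getRangeBuilder.setEnd(10L).setStep(1L)    // any concrete relation body

val node = SubqueryExpressionNode(relation.build(), SubqueryType.SCALAR)
val expr = ColumnNodeToProtoConverter(node)
// expr.getSubqueryExpression.getSubqueryType == SUBQUERY_TYPE_SCALAR
// expr.getSubqueryExpression.getPlanId == 7
// --- end example ---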