Skip to content

Commit

Permalink
fix(spark): incorrect deriveRecordType() for Expand
Browse files Browse the repository at this point in the history
In the Expand relation, the record type was being calculated
incorrectly, leading to errors when round-tripping to protobuf and back.

Signed-off-by: Andrew Coleman <[email protected]>
  • Loading branch information
andrew-coleman committed Nov 5, 2024
1 parent e3139c6 commit 67ff12c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
17 changes: 10 additions & 7 deletions core/src/main/java/io/substrait/relation/Expand.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ public abstract class Expand extends SingleInputRel {
@Override
public Type.Struct deriveRecordType() {
Type.Struct initial = getInput().getRecordType();
return TypeCreator.of(initial.nullable())
.struct(Stream.concat(initial.fields().stream(), Stream.of(TypeCreator.REQUIRED.I64)));
var fields =
getFields().isEmpty()
? initial.fields().stream()
: Stream.concat(initial.fields().stream(), getFields().get(0).getTypes());
return TypeCreator.of(initial.nullable()).struct(fields);
}

@Override
Expand All @@ -31,15 +34,15 @@ public static ImmutableExpand.Builder builder() {
}

public interface ExpandField {
Type getType();
Stream<Type> getTypes();
}

@Value.Immutable
public abstract static class ConsistentField implements ExpandField {
public abstract Expression getExpression();

public Type getType() {
return getExpression().getType();
public Stream<Type> getTypes() {
return Stream.of(getExpression().getType());
}

public static ImmutableExpand.ConsistentField.Builder builder() {
Expand All @@ -51,8 +54,8 @@ public static ImmutableExpand.ConsistentField.Builder builder() {
public abstract static class SwitchingField implements ExpandField {
public abstract List<Expression> getDuplicates();

public Type getType() {
return getDuplicates().get(0).getType();
public Stream<Type> getTypes() {
return getDuplicates().stream().map(Expression::getType);
}

public static ImmutableExpand.SwitchingField.Builder builder() {
Expand Down
2 changes: 2 additions & 0 deletions spark/src/main/scala/io/substrait/spark/SparkExtension.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ object SparkExtension {
private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection =
SimpleExtension.loadDefaults()

val COLLECTION: SimpleExtension.ExtensionCollection = EXTENSION_COLLECTION.merge(SparkImpls)

lazy val SparkScalarFunctions: Seq[SimpleExtension.ScalarFunctionVariant] = {
val ret = new collection.mutable.ArrayBuffer[SimpleExtension.ScalarFunctionVariant]()
ret.appendAll(EXTENSION_COLLECTION.scalarFunctions().asScala)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import io.substrait.debug.TreePrinter
import io.substrait.extension.ExtensionCollector
import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter}
import io.substrait.proto
import io.substrait.relation.RelProtoConverter
import io.substrait.relation.{ProtoRelConverter, RelProtoConverter}
import org.scalactic.Equality
import org.scalactic.source.Position
import org.scalatest.Succeeded
Expand Down Expand Up @@ -93,6 +93,10 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
require(logicalPlan2.resolved);
val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2)

val extensionCollector = new ExtensionCollector;
val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)

pojoRel2.shouldEqualPlainly(pojoRel)
logicalPlan2
}
Expand Down

0 comments on commit 67ff12c

Please sign in to comment.