diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index e87c071a0..b330c0d09 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -23,6 +23,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.Set; import io.substrait.relation.Sort; +import io.substrait.relation.physical.HashJoin; import io.substrait.type.ImmutableType; import io.substrait.type.NamedStruct; import io.substrait.type.Type; @@ -165,6 +166,34 @@ private Join join( .build(); } + public HashJoin hashJoin( + List leftKeys, + List rightKeys, + HashJoin.JoinType joinType, + Rel left, + Rel right) { + return hashJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right); + } + + public HashJoin hashJoin( + List leftKeys, + List rightKeys, + HashJoin.JoinType joinType, + Optional remap, + Rel left, + Rel right) { + return HashJoin.builder() + .left(left) + .right(right) + .leftKeys( + this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray())) + .rightKeys( + this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray())) + .joinType(joinType) + .remap(remap) + .build(); + } + public NamedScan namedScan( Iterable tableName, Iterable columnNames, Iterable types) { return namedScan(tableName, columnNames, types, Optional.empty()); diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index f46e8899a..645f692e2 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -1,5 +1,7 @@ package io.substrait.relation; +import io.substrait.relation.physical.HashJoin; + public abstract class AbstractRelVisitor implements RelVisitor { public abstract OUTPUT visitFallback(Rel rel); @@ -83,4 +85,9 @@ public OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION { public OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION { return visitFallback(extensionTable); } + + @Override + public OUTPUT visit(HashJoin hashJoin) throws EXCEPTION { + return visitFallback(hashJoin); + } } diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 0cc846f25..9ae2cb959 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -17,6 +17,7 @@ import io.substrait.proto.ExtensionSingleRel; import io.substrait.proto.FetchRel; import io.substrait.proto.FilterRel; +import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; @@ -27,6 +28,7 @@ import io.substrait.relation.files.FileOrFiles; import io.substrait.relation.files.ImmutableFileFormat; import io.substrait.relation.files.ImmutableFileOrFiles; +import io.substrait.relation.physical.HashJoin; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.NamedStruct; import io.substrait.type.Type; @@ -95,6 +97,9 @@ public Rel from(io.substrait.proto.Rel rel) { case EXTENSION_MULTI -> { return newExtensionMulti(rel.getExtensionMulti()); } + case HASH_JOIN -> { + return newHashJoin(rel.getHashJoin()); + } default -> { throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType); } @@ -490,6 +495,38 @@ private Set newSet(SetRel rel) { return builder.build(); } + private Rel newHashJoin(HashJoinRel rel) { + Rel left = from(rel.getLeft()); + Rel right = from(rel.getRight()); + var leftKeys = rel.getLeftKeysList(); + var rightKeys = rel.getRightKeysList(); + + Type.Struct leftStruct = left.getRecordType(); + Type.Struct rightStruct = right.getRecordType(); + Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + var leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this); + var rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this); + var unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + var builder = + HashJoin.builder() + .left(left) + .right(right) + .leftKeys(leftKeys.stream().map(leftConverter::from).collect(Collectors.toList())) + .rightKeys(rightKeys.stream().map(rightConverter::from).collect(Collectors.toList())) + .joinType(HashJoin.JoinType.fromProto(rel.getType())) + .postJoinFilter( + Optional.ofNullable( + rel.hasPostJoinFilter() ? unionConverter.from(rel.getPostJoinFilter()) : null)); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(advancedExtension(rel.getAdvancedExtension())); + } + return builder.build(); + } + private static Optional optionalRelmap(io.substrait.proto.RelCommon relCommon) { return Optional.ofNullable( relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 14ef98a03..0dddfbd98 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -6,6 +6,8 @@ import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.expression.ImmutableFieldReference; +import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.ImmutableHashJoin; import io.substrait.type.Type; import java.util.ArrayList; import java.util.List; @@ -166,6 +168,29 @@ public Optional visit(Cross cross) throws RuntimeException { .build()); } + @Override + public Optional visit(HashJoin hashJoin) throws RuntimeException { + var left = hashJoin.getLeft().accept(this); + var right = hashJoin.getRight().accept(this); + var leftKeys = hashJoin.getLeftKeys(); + var rightKeys = hashJoin.getRightKeys(); + var postFilter = hashJoin.getPostJoinFilter().flatMap(t -> visitExpression(t)); + if (allEmpty(left, right, postFilter)) { + return Optional.empty(); + } + return Optional.of( + ImmutableHashJoin.builder() + .from(hashJoin) + .left(left.orElse(hashJoin.getLeft())) + .right(right.orElse(hashJoin.getRight())) + .leftKeys(leftKeys) + .rightKeys(rightKeys) + .postJoinFilter( + Optional.ofNullable( + postFilter.orElseGet(() -> hashJoin.getPostJoinFilter().orElse(null)))) + .build()); + } + private Optional visitExpression(Expression expression) { ExpressionVisitor, RuntimeException> visitor = new AbstractExpressionVisitor<>() { diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 25c181037..d9b19f951 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; @@ -12,6 +13,7 @@ import io.substrait.proto.ExtensionSingleRel; import io.substrait.proto.FetchRel; import io.substrait.proto.FilterRel; +import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; @@ -21,6 +23,7 @@ import io.substrait.proto.SortField; import io.substrait.proto.SortRel; import io.substrait.relation.files.FileOrFiles; +import io.substrait.relation.physical.HashJoin; import io.substrait.type.proto.TypeProtoConverter; import java.util.Collection; import java.util.List; @@ -68,6 +71,10 @@ private List toProtoS(Collection sorts) { .collect(java.util.stream.Collectors.toList()); } + private io.substrait.proto.Expression.FieldReference toProto(FieldReference fieldReference) { + return fieldReference.accept(exprProtoConverter).getSelection(); + } + @Override public Rel visit(Aggregate aggregate) throws RuntimeException { var builder = @@ -166,6 +173,8 @@ public Rel visit(Join join) throws RuntimeException { join.getCondition().ifPresent(t -> builder.setExpression(toProto(t))); + join.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t))); + join.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setJoin(builder).build(); } @@ -228,6 +237,31 @@ public Rel visit(ExtensionTable extensionTable) throws RuntimeException { return Rel.newBuilder().setRead(builder).build(); } + @Override + public Rel visit(HashJoin hashJoin) throws RuntimeException { + var builder = + HashJoinRel.newBuilder() + .setCommon(common(hashJoin)) + .setLeft(toProto(hashJoin.getLeft())) + .setRight(toProto(hashJoin.getRight())) + .setType(hashJoin.getJoinType().toProto()); + + List leftKeys = hashJoin.getLeftKeys(); + List rightKeys = hashJoin.getRightKeys(); + + if (leftKeys.size() != rightKeys.size()) { + throw new RuntimeException("Number of left and right keys must be equal."); + } + + builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList())); + builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList())); + + hashJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t))); + + hashJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); + return Rel.newBuilder().setHashJoin(builder).build(); + } + @Override public Rel visit(Project project) throws RuntimeException { var builder = diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index c685cd98d..e8e78aaf7 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -1,5 +1,7 @@ package io.substrait.relation; +import io.substrait.relation.physical.HashJoin; + public interface RelVisitor { OUTPUT visit(Aggregate aggregate) throws EXCEPTION; @@ -32,4 +34,6 @@ public interface RelVisitor { OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION; OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION; + + OUTPUT visit(HashJoin hashJoin) throws EXCEPTION; } diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java new file mode 100644 index 000000000..6d0e68f8a --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -0,0 +1,85 @@ +package io.substrait.relation.physical; + +import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; +import io.substrait.proto.HashJoinRel; +import io.substrait.relation.BiRel; +import io.substrait.relation.HasExtension; +import io.substrait.relation.RelVisitor; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class HashJoin extends BiRel implements HasExtension { + + public abstract List getLeftKeys(); + + public abstract List getRightKeys(); + + public abstract JoinType getJoinType(); + + public abstract Optional getPostJoinFilter(); + + public static enum JoinType { + UNKNOWN(HashJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED), + INNER(HashJoinRel.JoinType.JOIN_TYPE_INNER), + OUTER(HashJoinRel.JoinType.JOIN_TYPE_OUTER), + LEFT(HashJoinRel.JoinType.JOIN_TYPE_LEFT), + RIGHT(HashJoinRel.JoinType.JOIN_TYPE_RIGHT), + LEFT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI), + RIGHT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI), + LEFT_ANTI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI), + RIGHT_ANTI(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI); + + private HashJoinRel.JoinType proto; + + JoinType(HashJoinRel.JoinType proto) { + this.proto = proto; + } + + public static JoinType fromProto(HashJoinRel.JoinType proto) { + for (var v : values()) { + if (v.proto == proto) { + return v; + } + } + throw new IllegalArgumentException("Unknown type: " + proto); + } + + public HashJoinRel.JoinType toProto() { + return proto; + } + } + + @Override + protected Type.Struct deriveRecordType() { + Stream leftTypes = + switch (getJoinType()) { + case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); + default -> getLeft().getRecordType().fields().stream(); + }; + Stream rightTypes = + switch (getJoinType()) { + case LEFT, OUTER -> getRight().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); + default -> getRight().getRecordType().fields().stream(); + }; + return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); + } + + @Override + public O accept(RelVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableHashJoin.Builder builder() { + return ImmutableHashJoin.builder(); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index d340cd012..ad492470f 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -26,12 +26,14 @@ import io.substrait.relation.Set; import io.substrait.relation.Sort; import io.substrait.relation.VirtualTableScan; +import io.substrait.relation.physical.HashJoin; import io.substrait.relation.utils.StringHolder; import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter; import io.substrait.type.NamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.Collections; +import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; import org.junit.jupiter.api.Nested; @@ -174,6 +176,26 @@ void join() { verifyRoundTrip(rel); } + @Test + void hashJoin() { + // with empty keys + List leftEmptyKeys = Collections.emptyList(); + List rightEmptyKeys = Collections.emptyList(); + Rel relWithoutKeys = + HashJoin.builder() + .from( + b.hashJoin( + leftEmptyKeys, + rightEmptyKeys, + HashJoin.JoinType.INNER, + commonTable, + commonTable)) + .commonExtension(commonExtension) + .extension(relExtension) + .build(); + verifyRoundTrip(relWithoutKeys); + } + @Test void project() { Rel rel = diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java new file mode 100644 index 000000000..2204178f6 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -0,0 +1,60 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.TestBase; +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.relation.Rel; +import io.substrait.relation.RelProtoConverter; +import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter; +import io.substrait.type.TypeCreator; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class JoinRoundtripTest extends TestBase { + + final SimpleExtension.ExtensionCollection extensions = defaultExtensionCollection; + + TypeCreator R = TypeCreator.REQUIRED; + + final SubstraitBuilder b = new SubstraitBuilder(extensions); + + final ExtensionCollector functionCollector = new ExtensionCollector(); + final RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); + final ProtoRelConverter protoRelConverter = + new StringHolderHandlingProtoRelConverter(functionCollector, extensions); + + final Rel leftTable = + b.namedScan( + Arrays.asList("T1"), + Arrays.asList("a", "b", "c"), + Arrays.asList(R.I64, R.FP64, R.STRING)); + + final Rel rightTable = + b.namedScan( + Arrays.asList("T2"), + Arrays.asList("d", "e", "f"), + Arrays.asList(R.FP64, R.STRING, R.I64)); + + void verifyRoundTrip(Rel rel) { + io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + Rel relReturned = protoRelConverter.from(protoRel); + assertEquals(rel, relReturned); + } + + @Test + void hashJoin() { + List leftKeys = Arrays.asList(0, 1); + List rightKeys = Arrays.asList(2, 0); + Rel relWithoutKeys = + HashJoin.builder() + .from(b.hashJoin(leftKeys, rightKeys, HashJoin.JoinType.INNER, leftTable, rightTable)) + .build(); + verifyRoundTrip(relWithoutKeys); + } +}