Skip to content

Commit

Permalink
feat: initial NestedLoopJoin support (#188)
Browse files Browse the repository at this point in the history
* feat: more builder support for field references
  • Loading branch information
danepitkin authored Nov 3, 2023
1 parent 6548670 commit b66d5b1
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 0 deletions.
35 changes: 35 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.ImmutableType;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -218,6 +219,30 @@ private NamedScan namedScan(
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
}

public NestedLoopJoin nestedLoopJoin(
Function<JoinInput, Expression> conditionFn,
NestedLoopJoin.JoinType joinType,
Rel left,
Rel right) {
return nestedLoopJoin(conditionFn, joinType, Optional.empty(), left, right);
}

private NestedLoopJoin nestedLoopJoin(
Function<JoinInput, Expression> conditionFn,
NestedLoopJoin.JoinType joinType,
Optional<Rel.Remap> remap,
Rel left,
Rel right) {
var condition = conditionFn.apply(new JoinInput(left, right));
return NestedLoopJoin.builder()
.left(left)
.right(right)
.condition(condition)
.joinType(joinType)
.remap(remap)
.build();
}

public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
return project(expressionsFn, Optional.empty(), input);
}
Expand Down Expand Up @@ -286,6 +311,16 @@ public List<FieldReference> fieldReferences(Rel input, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

public FieldReference fieldReference(List<Rel> inputs, int index) {
return ImmutableFieldReference.newInputRelReference(index, inputs);
}

public List<FieldReference> fieldReferences(List<Rel> inputs, int... indexes) {
return Arrays.stream(indexes)
.mapToObj(index -> fieldReference(inputs, index))
.collect(java.util.stream.Collectors.toList());
}

public Expression cast(Expression input, Type type) {
return Cast.builder()
.input(input)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;

public abstract class AbstractRelVisitor<OUTPUT, EXCEPTION extends Exception>
implements RelVisitor<OUTPUT, EXCEPTION> {
Expand Down Expand Up @@ -31,6 +32,11 @@ public OUTPUT visit(Join join) throws EXCEPTION {
return visitFallback(join);
}

@Override
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
return visitFallback(nestedLoopJoin);
}

@Override
public OUTPUT visit(Set set) throws EXCEPTION {
return visitFallback(set);
Expand Down
32 changes: 32 additions & 0 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
import io.substrait.proto.SetRel;
Expand All @@ -27,6 +28,7 @@
import io.substrait.relation.files.ImmutableFileFormat;
import io.substrait.relation.files.ImmutableFileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -77,6 +79,9 @@ public Rel from(io.substrait.proto.Rel rel) {
case JOIN -> {
return newJoin(rel.getJoin());
}
case NESTED_LOOP_JOIN -> {
return newNestedLoopJoin(rel.getNestedLoopJoin());
}
case SET -> {
return newSet(rel.getSet());
}
Expand Down Expand Up @@ -532,6 +537,33 @@ private Rel newHashJoin(HashJoinRel rel) {
return builder.build();
}

private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) {
Rel left = from(rel.getLeft());
Rel right = from(rel.getRight());
Type.Struct leftStruct = left.getRecordType();
Type.Struct rightStruct = right.getRecordType();
Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build();
var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this);
var builder =
NestedLoopJoin.builder()
.left(left)
.right(right)
.condition(
// defaults to true (aka cartesian join) if the join expression is missing
rel.hasExpression()
? converter.from(rel.getExpression())
: Expression.BoolLiteral.builder().value(true).build())
.joinType(NestedLoopJoin.JoinType.fromProto(rel.getType()));

builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()));
if (rel.hasAdvancedExtension()) {
builder.extension(advancedExtension(rel.getAdvancedExtension()));
}
return builder.build();
}

private static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon relCommon) {
return Optional.ofNullable(
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.ImmutableHashJoin;
import io.substrait.relation.physical.ImmutableNestedLoopJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -120,6 +122,23 @@ public Optional<Rel> visit(Join join) throws RuntimeException {
.build());
}

@Override
public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
var left = nestedLoopJoin.getLeft().accept(this);
var right = nestedLoopJoin.getRight().accept(this);
var condition = visitExpression(nestedLoopJoin.getCondition());
if (allEmpty(left, right, condition)) {
return Optional.empty();
}
return Optional.of(
ImmutableNestedLoopJoin.builder()
.from(nestedLoopJoin)
.left(left.orElse(nestedLoopJoin.getLeft()))
.right(right.orElse(nestedLoopJoin.getRight()))
.condition(condition.orElse(nestedLoopJoin.getCondition()))
.build());
}

@Override
public Optional<Rel> visit(Set set) throws RuntimeException {
return transformList(set.getInputs(), t -> t.accept(this))
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/java/io/substrait/relation/RelProtoConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
import io.substrait.proto.Rel;
Expand All @@ -24,6 +25,7 @@
import io.substrait.proto.SortRel;
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.Collection;
import java.util.List;
Expand Down Expand Up @@ -179,6 +181,20 @@ public Rel visit(Join join) throws RuntimeException {
return Rel.newBuilder().setJoin(builder).build();
}

@Override
public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
var builder =
NestedLoopJoinRel.newBuilder()
.setCommon(common(nestedLoopJoin))
.setLeft(toProto(nestedLoopJoin.getLeft()))
.setRight(toProto(nestedLoopJoin.getRight()))
.setExpression(toProto(nestedLoopJoin.getCondition()))
.setType(nestedLoopJoin.getJoinType().toProto());

nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setNestedLoopJoin(builder).build();
}

@Override
public Rel visit(Set set) throws RuntimeException {
var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto());
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/java/io/substrait/relation/RelVisitor.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;

public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
OUTPUT visit(Aggregate aggregate) throws EXCEPTION;
Expand All @@ -13,6 +14,8 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {

OUTPUT visit(Join join) throws EXCEPTION;

OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION;

OUTPUT visit(Set set) throws EXCEPTION;

OUTPUT visit(NamedScan namedScan) throws EXCEPTION;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package io.substrait.relation.physical;

import io.substrait.expression.Expression;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.relation.BiRel;
import io.substrait.relation.HasExtension;
import io.substrait.relation.RelVisitor;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.stream.Stream;
import org.immutables.value.Value;

@Value.Immutable
public abstract class NestedLoopJoin extends BiRel implements HasExtension {

public abstract Expression getCondition();

public abstract JoinType getJoinType();

public static enum JoinType {
UNKNOWN(NestedLoopJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED),
INNER(NestedLoopJoinRel.JoinType.JOIN_TYPE_INNER),
OUTER(NestedLoopJoinRel.JoinType.JOIN_TYPE_OUTER),
LEFT(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT),
RIGHT(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT),
LEFT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI),
RIGHT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI),
LEFT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI),
RIGHT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI);

private NestedLoopJoinRel.JoinType proto;

JoinType(NestedLoopJoinRel.JoinType proto) {
this.proto = proto;
}

public NestedLoopJoinRel.JoinType toProto() {
return proto;
}

public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) {
for (var v : values()) {
if (v.proto == proto) {
return v;
}
}

throw new IllegalArgumentException("Unknown type: " + proto);
}
}

@Override
protected Type.Struct deriveRecordType() {
Stream<Type> leftTypes =
switch (getJoinType()) {
case RIGHT, OUTER -> getLeft().getRecordType().fields().stream()
.map(TypeCreator::asNullable);
case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty();
default -> getLeft().getRecordType().fields().stream();
};
Stream<Type> rightTypes =
switch (getJoinType()) {
case LEFT, OUTER -> getRight().getRecordType().fields().stream()
.map(TypeCreator::asNullable);
case LEFT_ANTI, LEFT_SEMI -> Stream.empty();
default -> getRight().getRecordType().fields().stream();
};
return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes));
}

@Override
public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
return visitor.visit(this);
}

public static ImmutableNestedLoopJoin.Builder builder() {
return ImmutableNestedLoopJoin.builder();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.relation.utils.StringHolder;
import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter;
import io.substrait.type.NamedStruct;
Expand Down Expand Up @@ -186,6 +187,19 @@ void hashJoin() {
verifyRoundTrip(relWithoutKeys);
}

@Test
void nestedLoopJoin() {
Rel rel =
NestedLoopJoin.builder()
.from(
b.nestedLoopJoin(
__ -> b.bool(true), NestedLoopJoin.JoinType.INNER, commonTable, commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
verifyRoundTrip(rel);
}

@Test
void project() {
Rel rel =
Expand Down
16 changes: 16 additions & 0 deletions core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.substrait.TestBase;
import io.substrait.relation.Rel;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import java.util.Arrays;
import java.util.List;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -31,4 +32,19 @@ void hashJoin() {
.build();
verifyRoundTrip(relWithoutKeys);
}

@Test
void nestedLoopJoin() {
List<Rel> inputRels = Arrays.asList(leftTable, rightTable);
Rel rel =
NestedLoopJoin.builder()
.from(
b.nestedLoopJoin(
__ -> b.equal(b.fieldReference(inputRels, 0), b.fieldReference(inputRels, 5)),
NestedLoopJoin.JoinType.INNER,
leftTable,
rightTable))
.build();
verifyRoundTrip(rel);
}
}

0 comments on commit b66d5b1

Please sign in to comment.