Skip to content

Commit ee623f7

Browse files
author
Dane Pitkin
committed
feat: add MergeJoinRel
1 parent b938573 commit ee623f7

File tree

9 files changed

+305
-60
lines changed

9 files changed

+305
-60
lines changed

core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

+48-19
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import io.substrait.relation.Set;
2929
import io.substrait.relation.Sort;
3030
import io.substrait.relation.physical.HashJoin;
31+
import io.substrait.relation.physical.MergeJoin;
3132
import io.substrait.relation.physical.NestedLoopJoin;
3233
import io.substrait.type.ImmutableType;
3334
import io.substrait.type.NamedStruct;
@@ -201,27 +202,32 @@ public HashJoin hashJoin(
201202
.build();
202203
}
203204

204-
public NamedScan namedScan(
205-
Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
206-
return namedScan(tableName, columnNames, types, Optional.empty());
207-
}
208-
209-
public NamedScan namedScan(
210-
Iterable<String> tableName,
211-
Iterable<String> columnNames,
212-
Iterable<Type> types,
213-
Rel.Remap remap) {
214-
return namedScan(tableName, columnNames, types, Optional.of(remap));
205+
public MergeJoin mergeJoin(
206+
List<Integer> leftKeys,
207+
List<Integer> rightKeys,
208+
MergeJoin.JoinType joinType,
209+
Rel left,
210+
Rel right) {
211+
return mergeJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right);
215212
}
216213

217-
private NamedScan namedScan(
218-
Iterable<String> tableName,
219-
Iterable<String> columnNames,
220-
Iterable<Type> types,
221-
Optional<Rel.Remap> remap) {
222-
var struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
223-
var namedStruct = NamedStruct.of(columnNames, struct);
224-
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
214+
public MergeJoin mergeJoin(
215+
List<Integer> leftKeys,
216+
List<Integer> rightKeys,
217+
MergeJoin.JoinType joinType,
218+
Optional<Rel.Remap> remap,
219+
Rel left,
220+
Rel right) {
221+
return MergeJoin.builder()
222+
.left(left)
223+
.right(right)
224+
.leftKeys(
225+
this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray()))
226+
.rightKeys(
227+
this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray()))
228+
.joinType(joinType)
229+
.remap(remap)
230+
.build();
225231
}
226232

227233
public NestedLoopJoin nestedLoopJoin(
@@ -248,6 +254,29 @@ private NestedLoopJoin nestedLoopJoin(
248254
.build();
249255
}
250256

257+
public NamedScan namedScan(
258+
Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
259+
return namedScan(tableName, columnNames, types, Optional.empty());
260+
}
261+
262+
public NamedScan namedScan(
263+
Iterable<String> tableName,
264+
Iterable<String> columnNames,
265+
Iterable<Type> types,
266+
Rel.Remap remap) {
267+
return namedScan(tableName, columnNames, types, Optional.of(remap));
268+
}
269+
270+
private NamedScan namedScan(
271+
Iterable<String> tableName,
272+
Iterable<String> columnNames,
273+
Iterable<Type> types,
274+
Optional<Rel.Remap> remap) {
275+
var struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
276+
var namedStruct = NamedStruct.of(columnNames, struct);
277+
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
278+
}
279+
251280
public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
252281
return project(expressionsFn, Optional.empty(), input);
253282
}

core/src/main/java/io/substrait/relation/AbstractRelVisitor.java

+11-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.substrait.relation;
22

33
import io.substrait.relation.physical.HashJoin;
4+
import io.substrait.relation.physical.MergeJoin;
45
import io.substrait.relation.physical.NestedLoopJoin;
56

67
public abstract class AbstractRelVisitor<OUTPUT, EXCEPTION extends Exception>
@@ -32,11 +33,6 @@ public OUTPUT visit(Join join) throws EXCEPTION {
3233
return visitFallback(join);
3334
}
3435

35-
@Override
36-
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
37-
return visitFallback(nestedLoopJoin);
38-
}
39-
4036
@Override
4137
public OUTPUT visit(Set set) throws EXCEPTION {
4238
return visitFallback(set);
@@ -96,4 +92,14 @@ public OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION {
9692
public OUTPUT visit(HashJoin hashJoin) throws EXCEPTION {
9793
return visitFallback(hashJoin);
9894
}
95+
96+
@Override
97+
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
98+
return visitFallback(nestedLoopJoin);
99+
}
100+
101+
@Override
102+
public OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION {
103+
return visitFallback(mergeJoin);
104+
}
99105
}

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

+40-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import io.substrait.proto.FilterRel;
1818
import io.substrait.proto.HashJoinRel;
1919
import io.substrait.proto.JoinRel;
20+
import io.substrait.proto.MergeJoinRel;
2021
import io.substrait.proto.NestedLoopJoinRel;
2122
import io.substrait.proto.ProjectRel;
2223
import io.substrait.proto.ReadRel;
@@ -28,6 +29,7 @@
2829
import io.substrait.relation.files.ImmutableFileFormat;
2930
import io.substrait.relation.files.ImmutableFileOrFiles;
3031
import io.substrait.relation.physical.HashJoin;
32+
import io.substrait.relation.physical.MergeJoin;
3133
import io.substrait.relation.physical.NestedLoopJoin;
3234
import io.substrait.type.ImmutableNamedStruct;
3335
import io.substrait.type.NamedStruct;
@@ -79,9 +81,6 @@ public Rel from(io.substrait.proto.Rel rel) {
7981
case JOIN -> {
8082
return newJoin(rel.getJoin());
8183
}
82-
case NESTED_LOOP_JOIN -> {
83-
return newNestedLoopJoin(rel.getNestedLoopJoin());
84-
}
8584
case SET -> {
8685
return newSet(rel.getSet());
8786
}
@@ -103,6 +102,12 @@ public Rel from(io.substrait.proto.Rel rel) {
103102
case HASH_JOIN -> {
104103
return newHashJoin(rel.getHashJoin());
105104
}
105+
case NESTED_LOOP_JOIN -> {
106+
return newNestedLoopJoin(rel.getNestedLoopJoin());
107+
}
108+
case MERGE_JOIN -> {
109+
return newMergeJoin(rel.getMergeJoin());
110+
}
106111
default -> {
107112
throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType);
108113
}
@@ -564,6 +569,38 @@ private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) {
564569
return builder.build();
565570
}
566571

572+
private Rel newMergeJoin(MergeJoinRel rel) {
573+
Rel left = from(rel.getLeft());
574+
Rel right = from(rel.getRight());
575+
var leftKeys = rel.getLeftKeysList();
576+
var rightKeys = rel.getRightKeysList();
577+
578+
Type.Struct leftStruct = left.getRecordType();
579+
Type.Struct rightStruct = right.getRecordType();
580+
Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build();
581+
var leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this);
582+
var rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this);
583+
var unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this);
584+
var builder =
585+
MergeJoin.builder()
586+
.left(left)
587+
.right(right)
588+
.leftKeys(leftKeys.stream().map(leftConverter::from).collect(Collectors.toList()))
589+
.rightKeys(rightKeys.stream().map(rightConverter::from).collect(Collectors.toList()))
590+
.joinType(MergeJoin.JoinType.fromProto(rel.getType()))
591+
.postJoinFilter(
592+
Optional.ofNullable(
593+
rel.hasPostJoinFilter() ? unionConverter.from(rel.getPostJoinFilter()) : null));
594+
595+
builder
596+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
597+
.remap(optionalRelmap(rel.getCommon()));
598+
if (rel.hasAdvancedExtension()) {
599+
builder.extension(advancedExtension(rel.getAdvancedExtension()));
600+
}
601+
return builder.build();
602+
}
603+
567604
private static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon relCommon) {
568605
return Optional.ofNullable(
569606
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);

core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java

+42-17
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
import io.substrait.expression.ImmutableFieldReference;
99
import io.substrait.relation.physical.HashJoin;
1010
import io.substrait.relation.physical.ImmutableHashJoin;
11+
import io.substrait.relation.physical.ImmutableMergeJoin;
1112
import io.substrait.relation.physical.ImmutableNestedLoopJoin;
13+
import io.substrait.relation.physical.MergeJoin;
1214
import io.substrait.relation.physical.NestedLoopJoin;
1315
import io.substrait.type.Type;
1416
import java.util.ArrayList;
@@ -122,23 +124,6 @@ public Optional<Rel> visit(Join join) throws RuntimeException {
122124
.build());
123125
}
124126

125-
@Override
126-
public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
127-
var left = nestedLoopJoin.getLeft().accept(this);
128-
var right = nestedLoopJoin.getRight().accept(this);
129-
var condition = visitExpression(nestedLoopJoin.getCondition());
130-
if (allEmpty(left, right, condition)) {
131-
return Optional.empty();
132-
}
133-
return Optional.of(
134-
ImmutableNestedLoopJoin.builder()
135-
.from(nestedLoopJoin)
136-
.left(left.orElse(nestedLoopJoin.getLeft()))
137-
.right(right.orElse(nestedLoopJoin.getRight()))
138-
.condition(condition.orElse(nestedLoopJoin.getCondition()))
139-
.build());
140-
}
141-
142127
@Override
143128
public Optional<Rel> visit(Set set) throws RuntimeException {
144129
return transformList(set.getInputs(), t -> t.accept(this))
@@ -210,6 +195,46 @@ public Optional<Rel> visit(HashJoin hashJoin) throws RuntimeException {
210195
.build());
211196
}
212197

198+
@Override
199+
public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
200+
var left = nestedLoopJoin.getLeft().accept(this);
201+
var right = nestedLoopJoin.getRight().accept(this);
202+
var condition = visitExpression(nestedLoopJoin.getCondition());
203+
if (allEmpty(left, right, condition)) {
204+
return Optional.empty();
205+
}
206+
return Optional.of(
207+
ImmutableNestedLoopJoin.builder()
208+
.from(nestedLoopJoin)
209+
.left(left.orElse(nestedLoopJoin.getLeft()))
210+
.right(right.orElse(nestedLoopJoin.getRight()))
211+
.condition(condition.orElse(nestedLoopJoin.getCondition()))
212+
.build());
213+
}
214+
215+
@Override
216+
public Optional<Rel> visit(MergeJoin mergeJoin) throws RuntimeException {
217+
var left = mergeJoin.getLeft().accept(this);
218+
var right = mergeJoin.getRight().accept(this);
219+
var leftKeys = mergeJoin.getLeftKeys();
220+
var rightKeys = mergeJoin.getRightKeys();
221+
var postFilter = mergeJoin.getPostJoinFilter().flatMap(t -> visitExpression(t));
222+
if (allEmpty(left, right, postFilter)) {
223+
return Optional.empty();
224+
}
225+
return Optional.of(
226+
ImmutableMergeJoin.builder()
227+
.from(mergeJoin)
228+
.left(left.orElse(mergeJoin.getLeft()))
229+
.right(right.orElse(mergeJoin.getRight()))
230+
.leftKeys(leftKeys)
231+
.rightKeys(rightKeys)
232+
.postJoinFilter(
233+
Optional.ofNullable(
234+
postFilter.orElseGet(() -> mergeJoin.getPostJoinFilter().orElse(null))))
235+
.build());
236+
}
237+
213238
private Optional<Expression> visitExpression(Expression expression) {
214239
ExpressionVisitor<Optional<Expression>, RuntimeException> visitor =
215240
new AbstractExpressionVisitor<>() {

core/src/main/java/io/substrait/relation/RelProtoConverter.java

+41-14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import io.substrait.proto.FilterRel;
1616
import io.substrait.proto.HashJoinRel;
1717
import io.substrait.proto.JoinRel;
18+
import io.substrait.proto.MergeJoinRel;
1819
import io.substrait.proto.NestedLoopJoinRel;
1920
import io.substrait.proto.ProjectRel;
2021
import io.substrait.proto.ReadRel;
@@ -25,6 +26,7 @@
2526
import io.substrait.proto.SortRel;
2627
import io.substrait.relation.files.FileOrFiles;
2728
import io.substrait.relation.physical.HashJoin;
29+
import io.substrait.relation.physical.MergeJoin;
2830
import io.substrait.relation.physical.NestedLoopJoin;
2931
import io.substrait.type.proto.TypeProtoConverter;
3032
import java.util.Collection;
@@ -181,20 +183,6 @@ public Rel visit(Join join) throws RuntimeException {
181183
return Rel.newBuilder().setJoin(builder).build();
182184
}
183185

184-
@Override
185-
public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
186-
var builder =
187-
NestedLoopJoinRel.newBuilder()
188-
.setCommon(common(nestedLoopJoin))
189-
.setLeft(toProto(nestedLoopJoin.getLeft()))
190-
.setRight(toProto(nestedLoopJoin.getRight()))
191-
.setExpression(toProto(nestedLoopJoin.getCondition()))
192-
.setType(nestedLoopJoin.getJoinType().toProto());
193-
194-
nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
195-
return Rel.newBuilder().setNestedLoopJoin(builder).build();
196-
}
197-
198186
@Override
199187
public Rel visit(Set set) throws RuntimeException {
200188
var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto());
@@ -280,6 +268,45 @@ public Rel visit(HashJoin hashJoin) throws RuntimeException {
280268
return Rel.newBuilder().setHashJoin(builder).build();
281269
}
282270

271+
@Override
272+
public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
273+
var builder =
274+
NestedLoopJoinRel.newBuilder()
275+
.setCommon(common(nestedLoopJoin))
276+
.setLeft(toProto(nestedLoopJoin.getLeft()))
277+
.setRight(toProto(nestedLoopJoin.getRight()))
278+
.setExpression(toProto(nestedLoopJoin.getCondition()))
279+
.setType(nestedLoopJoin.getJoinType().toProto());
280+
281+
nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
282+
return Rel.newBuilder().setNestedLoopJoin(builder).build();
283+
}
284+
285+
@Override
286+
public Rel visit(MergeJoin mergeJoin) throws RuntimeException {
287+
var builder =
288+
MergeJoinRel.newBuilder()
289+
.setCommon(common(mergeJoin))
290+
.setLeft(toProto(mergeJoin.getLeft()))
291+
.setRight(toProto(mergeJoin.getRight()))
292+
.setType(mergeJoin.getJoinType().toProto());
293+
294+
List<FieldReference> leftKeys = mergeJoin.getLeftKeys();
295+
List<FieldReference> rightKeys = mergeJoin.getRightKeys();
296+
297+
if (leftKeys.size() != rightKeys.size()) {
298+
throw new RuntimeException("Number of left and right keys must be equal.");
299+
}
300+
301+
builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList()));
302+
builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList()));
303+
304+
mergeJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));
305+
306+
mergeJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
307+
return Rel.newBuilder().setMergeJoin(builder).build();
308+
}
309+
283310
@Override
284311
public Rel visit(Project project) throws RuntimeException {
285312
var builder =

core/src/main/java/io/substrait/relation/RelVisitor.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.substrait.relation;
22

33
import io.substrait.relation.physical.HashJoin;
4+
import io.substrait.relation.physical.MergeJoin;
45
import io.substrait.relation.physical.NestedLoopJoin;
56

67
public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
@@ -14,8 +15,6 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
1415

1516
OUTPUT visit(Join join) throws EXCEPTION;
1617

17-
OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION;
18-
1918
OUTPUT visit(Set set) throws EXCEPTION;
2019

2120
OUTPUT visit(NamedScan namedScan) throws EXCEPTION;
@@ -39,4 +38,8 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
3938
OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION;
4039

4140
OUTPUT visit(HashJoin hashJoin) throws EXCEPTION;
41+
42+
OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION;
43+
44+
OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION;
4245
}

0 commit comments

Comments
 (0)