Skip to content

Commit 237179f

Browse files
authored
feat: add MergeJoinRel (#201)
1 parent 496d1a8 commit 237179f

File tree

9 files changed

+304
-61
lines changed

9 files changed

+304
-61
lines changed

Diff for: core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

+48 -19
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import io.substrait.relation.Set;
2929
import io.substrait.relation.Sort;
3030
import io.substrait.relation.physical.HashJoin;
31+
import io.substrait.relation.physical.MergeJoin;
3132
import io.substrait.relation.physical.NestedLoopJoin;
3233
import io.substrait.type.ImmutableType;
3334
import io.substrait.type.NamedStruct;
@@ -201,27 +202,32 @@ public HashJoin hashJoin(
201202
.build();
202203
}
203204

204-
public NamedScan namedScan(
205-
Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
206-
return namedScan(tableName, columnNames, types, Optional.empty());
207-
}
208-
209-
public NamedScan namedScan(
210-
Iterable<String> tableName,
211-
Iterable<String> columnNames,
212-
Iterable<Type> types,
213-
Rel.Remap remap) {
214-
return namedScan(tableName, columnNames, types, Optional.of(remap));
205+
/**
 * Builds a {@link MergeJoin} of {@code left} and {@code right} with no output remap.
 *
 * <p>Convenience overload: forwards every argument to the overload that accepts an
 * {@code Optional<Rel.Remap>}, supplying {@code Optional.empty()} for the remap.
 *
 * @param leftKeys integer keys for the left input (presumably field ordinals; confirm against
 *     {@code fieldReferences})
 * @param rightKeys integer keys for the right input (same caveat as {@code leftKeys})
 * @param joinType the merge join type
 * @param left the left input relation
 * @param right the right input relation
 * @return the merge join produced by the delegated overload
 */
public MergeJoin mergeJoin(
206+
List<Integer> leftKeys,
207+
List<Integer> rightKeys,
208+
MergeJoin.JoinType joinType,
209+
Rel left,
210+
Rel right) {
211+
return mergeJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right);
215212
}
216213

217-
private NamedScan namedScan(
218-
Iterable<String> tableName,
219-
Iterable<String> columnNames,
220-
Iterable<Type> types,
221-
Optional<Rel.Remap> remap) {
222-
var struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
223-
var namedStruct = NamedStruct.of(columnNames, struct);
224-
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
214+
/**
 * Builds a {@link MergeJoin} of {@code left} and {@code right}.
 *
 * <p>Each key list is unboxed to an {@code int[]} and handed to {@code fieldReferences} together
 * with the corresponding input relation; the results become the join's left and right keys.
 *
 * @param leftKeys integer keys resolved against {@code left} (presumably field ordinals; confirm
 *     against {@code fieldReferences})
 * @param rightKeys integer keys resolved against {@code right} (same caveat as {@code leftKeys})
 * @param joinType the merge join type
 * @param remap optional remap passed straight to the builder
 * @param left the left input relation
 * @param right the right input relation
 * @return the built merge join
 */
public MergeJoin mergeJoin(
215+
List<Integer> leftKeys,
216+
List<Integer> rightKeys,
217+
MergeJoin.JoinType joinType,
218+
Optional<Rel.Remap> remap,
219+
Rel left,
220+
Rel right) {
221+
return MergeJoin.builder()
222+
.left(left)
223+
.right(right)
224+
.leftKeys(
225+
// List<Integer> -> int[] because fieldReferences is called with a primitive array here.
this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray()))
226+
.rightKeys(
227+
this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray()))
228+
.joinType(joinType)
229+
.remap(remap)
230+
.build();
225231
}
226232

227233
public NestedLoopJoin nestedLoopJoin(
@@ -248,6 +254,29 @@ private NestedLoopJoin nestedLoopJoin(
248254
.build();
249255
}
250256

257+
/**
 * Builds a {@link NamedScan} with no remap.
 *
 * <p>Forwards to the private overload with {@code Optional.empty()} as the remap.
 *
 * @param tableName the name parts of the table
 * @param columnNames the column names for the scan's schema
 * @param types the column types for the scan's schema
 * @return the named scan produced by the delegated overload
 */
public NamedScan namedScan(
258+
Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
259+
return namedScan(tableName, columnNames, types, Optional.empty());
260+
}
261+
262+
/**
 * Builds a {@link NamedScan} with the given remap.
 *
 * <p>Forwards to the private overload, wrapping {@code remap} with {@code Optional.of}.
 *
 * @param tableName the name parts of the table
 * @param columnNames the column names for the scan's schema
 * @param types the column types for the scan's schema
 * @param remap the remap to apply; must not be {@code null} because {@code Optional.of} rejects
 *     null
 * @return the named scan produced by the delegated overload
 * @throws NullPointerException if {@code remap} is {@code null}
 */
public NamedScan namedScan(
263+
Iterable<String> tableName,
264+
Iterable<String> columnNames,
265+
Iterable<Type> types,
266+
Rel.Remap remap) {
267+
return namedScan(tableName, columnNames, types, Optional.of(remap));
268+
}
269+
270+
/**
 * Shared implementation behind the public {@code namedScan} overloads.
 *
 * <p>Builds a struct from {@code types} with {@code nullable(false)}, pairs it with
 * {@code columnNames} via {@code NamedStruct.of}, and uses the result as the scan's initial
 * schema.
 *
 * @param tableName the name parts of the table
 * @param columnNames the column names for the scan's schema
 * @param types the column types for the scan's schema
 * @param remap optional remap passed straight to the builder
 * @return the built named scan
 */
private NamedScan namedScan(
271+
Iterable<String> tableName,
272+
Iterable<String> columnNames,
273+
Iterable<Type> types,
274+
Optional<Rel.Remap> remap) {
275+
// The struct itself is built as non-nullable; the field types are taken as given.
var struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
276+
var namedStruct = NamedStruct.of(columnNames, struct);
277+
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
278+
}
279+
251280
/**
 * Builds a {@link Project} over {@code input} with no remap.
 *
 * <p>Forwards to the three-argument {@code project} overload with {@code Optional.empty()} in
 * the middle position.
 *
 * @param expressionsFn function that receives a relation and yields the projection expressions
 *     (presumably invoked with {@code input}; confirm in the delegated overload)
 * @param input the input relation
 * @return the project relation produced by the delegated overload
 */
public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
252281
return project(expressionsFn, Optional.empty(), input);
253282
}

Diff for: core/src/main/java/io/substrait/relation/AbstractRelVisitor.java

+11 -5
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
package io.substrait.relation;
22

33
import io.substrait.relation.physical.HashJoin;
4+
import io.substrait.relation.physical.MergeJoin;
45
import io.substrait.relation.physical.NestedLoopJoin;
56

67
public abstract class AbstractRelVisitor<OUTPUT, EXCEPTION extends Exception>
@@ -32,11 +33,6 @@ public OUTPUT visit(Join join) throws EXCEPTION {
3233
return visitFallback(join);
3334
}
3435

35-
@Override
36-
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
37-
return visitFallback(nestedLoopJoin);
38-
}
39-
4036
@Override
4137
public OUTPUT visit(Set set) throws EXCEPTION {
4238
return visitFallback(set);
@@ -96,4 +92,14 @@ public OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION {
9692
public OUTPUT visit(HashJoin hashJoin) throws EXCEPTION {
9793
return visitFallback(hashJoin);
9894
}
95+
96+
/**
 * Default handling for {@link MergeJoin}: routes the relation to {@code visitFallback}.
 *
 * @param mergeJoin the merge join being visited
 * @return whatever {@code visitFallback} returns for this relation
 * @throws EXCEPTION if {@code visitFallback} throws
 */
@Override
97+
public OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION {
98+
return visitFallback(mergeJoin);
99+
}
100+
101+
/**
 * Default handling for {@link NestedLoopJoin}: routes the relation to {@code visitFallback}.
 *
 * @param nestedLoopJoin the nested loop join being visited
 * @return whatever {@code visitFallback} returns for this relation
 * @throws EXCEPTION if {@code visitFallback} throws
 */
@Override
102+
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
103+
return visitFallback(nestedLoopJoin);
104+
}
99105
}

Diff for: core/src/main/java/io/substrait/relation/ProtoRelConverter.java

+40 -3
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import io.substrait.proto.FilterRel;
1818
import io.substrait.proto.HashJoinRel;
1919
import io.substrait.proto.JoinRel;
20+
import io.substrait.proto.MergeJoinRel;
2021
import io.substrait.proto.NestedLoopJoinRel;
2122
import io.substrait.proto.ProjectRel;
2223
import io.substrait.proto.ReadRel;
@@ -28,6 +29,7 @@
2829
import io.substrait.relation.files.ImmutableFileFormat;
2930
import io.substrait.relation.files.ImmutableFileOrFiles;
3031
import io.substrait.relation.physical.HashJoin;
32+
import io.substrait.relation.physical.MergeJoin;
3133
import io.substrait.relation.physical.NestedLoopJoin;
3234
import io.substrait.type.ImmutableNamedStruct;
3335
import io.substrait.type.NamedStruct;
@@ -79,9 +81,6 @@ public Rel from(io.substrait.proto.Rel rel) {
7981
case JOIN -> {
8082
return newJoin(rel.getJoin());
8183
}
82-
case NESTED_LOOP_JOIN -> {
83-
return newNestedLoopJoin(rel.getNestedLoopJoin());
84-
}
8584
case SET -> {
8685
return newSet(rel.getSet());
8786
}
@@ -103,6 +102,12 @@ public Rel from(io.substrait.proto.Rel rel) {
103102
case HASH_JOIN -> {
104103
return newHashJoin(rel.getHashJoin());
105104
}
105+
case MERGE_JOIN -> {
106+
return newMergeJoin(rel.getMergeJoin());
107+
}
108+
case NESTED_LOOP_JOIN -> {
109+
return newNestedLoopJoin(rel.getNestedLoopJoin());
110+
}
106111
default -> {
107112
throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType);
108113
}
@@ -537,6 +542,38 @@ private Rel newHashJoin(HashJoinRel rel) {
537542
return builder.build();
538543
}
539544

545+
/**
 * Converts a protobuf {@link MergeJoinRel} into a {@link MergeJoin}.
 *
 * <p>Both inputs are converted first. Left keys are converted with a converter built on the
 * left input's record type, right keys with one built on the right input's record type, and the
 * optional post-join filter with one built on a struct derived from both.
 *
 * @param rel the protobuf merge join relation
 * @return the converted merge join
 */
private Rel newMergeJoin(MergeJoinRel rel) {
546+
Rel left = from(rel.getLeft());
547+
Rel right = from(rel.getRight());
548+
var leftKeys = rel.getLeftKeysList();
549+
var rightKeys = rel.getRightKeysList();
550+
551+
Type.Struct leftStruct = left.getRecordType();
552+
Type.Struct rightStruct = right.getRecordType();
553+
// NOTE(review): presumably yields the left fields followed by the right fields; this depends
// on how the struct builder's from(...) merges fields - confirm.
Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build();
554+
var leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this);
555+
var rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this);
556+
var unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this);
557+
var builder =
558+
MergeJoin.builder()
559+
.left(left)
560+
.right(right)
561+
.leftKeys(leftKeys.stream().map(leftConverter::from).collect(Collectors.toList()))
562+
.rightKeys(rightKeys.stream().map(rightConverter::from).collect(Collectors.toList()))
563+
.joinType(MergeJoin.JoinType.fromProto(rel.getType()))
564+
.postJoinFilter(
565+
// The filter is only converted when the proto carries one; otherwise the Optional is empty.
Optional.ofNullable(
566+
rel.hasPostJoinFilter() ? unionConverter.from(rel.getPostJoinFilter()) : null));
567+
568+
builder
569+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
570+
.remap(optionalRelmap(rel.getCommon()));
571+
// The relation-level advanced extension is copied only when the proto has it set.
if (rel.hasAdvancedExtension()) {
572+
builder.extension(advancedExtension(rel.getAdvancedExtension()));
573+
}
574+
return builder.build();
575+
}
576+
540577
private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) {
541578
Rel left = from(rel.getLeft());
542579
Rel right = from(rel.getRight());

Diff for: core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java

+41 -18
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
import io.substrait.expression.FieldReference;
1010
import io.substrait.expression.FunctionArg;
1111
import io.substrait.relation.physical.HashJoin;
12+
import io.substrait.relation.physical.MergeJoin;
1213
import io.substrait.relation.physical.NestedLoopJoin;
1314
import java.util.List;
1415
import java.util.Optional;
@@ -156,24 +157,6 @@ public Optional<Rel> visit(Join join) throws EXCEPTION {
156157
.build());
157158
}
158159

159-
@Override
160-
public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
161-
var left = nestedLoopJoin.getLeft().accept(this);
162-
var right = nestedLoopJoin.getRight().accept(this);
163-
var condition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor());
164-
165-
if (allEmpty(left, right, condition)) {
166-
return Optional.empty();
167-
}
168-
return Optional.of(
169-
NestedLoopJoin.builder()
170-
.from(nestedLoopJoin)
171-
.left(left.orElse(nestedLoopJoin.getLeft()))
172-
.right(right.orElse(nestedLoopJoin.getRight()))
173-
.condition(condition.orElse(nestedLoopJoin.getCondition()))
174-
.build());
175-
}
176-
177160
@Override
178161
public Optional<Rel> visit(Set set) throws EXCEPTION {
179162
return transformList(set.getInputs(), t -> t.accept(this))
@@ -319,6 +302,46 @@ public Optional<Rel> visit(HashJoin hashJoin) throws EXCEPTION {
319302
.build());
320303
}
321304

305+
/**
 * Visits a {@link MergeJoin}: both inputs, both key lists and the optional post-join filter.
 *
 * <p>Returns {@code Optional.empty()} when {@code allEmpty} holds for all five visit results.
 * Otherwise returns a new {@code MergeJoin} seeded from the original, where each part uses the
 * visit result if present and the original value if not.
 *
 * @param mergeJoin the merge join being visited
 * @return an empty optional, or a rebuilt merge join as described above
 * @throws EXCEPTION if visiting any part throws
 */
@Override
306+
public Optional<Rel> visit(MergeJoin mergeJoin) throws EXCEPTION {
307+
var left = mergeJoin.getLeft().accept(this);
308+
var right = mergeJoin.getRight().accept(this);
309+
var leftKeys = transformList(mergeJoin.getLeftKeys(), this::visitFieldReference);
310+
var rightKeys = transformList(mergeJoin.getRightKeys(), this::visitFieldReference);
311+
var postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter());
312+
313+
// NOTE(review): allEmpty presumably means "no part produced a replacement" - confirm in helper.
if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
314+
return Optional.empty();
315+
}
316+
return Optional.of(
317+
MergeJoin.builder()
318+
.from(mergeJoin)
319+
.left(left.orElse(mergeJoin.getLeft()))
320+
.right(right.orElse(mergeJoin.getRight()))
321+
.leftKeys(leftKeys.orElse(mergeJoin.getLeftKeys()))
322+
.rightKeys(rightKeys.orElse(mergeJoin.getRightKeys()))
323+
// The filter is itself optional, so the fallback goes through or(...) instead of orElse.
.postJoinFilter(or(postFilter, mergeJoin::getPostJoinFilter))
324+
.build());
325+
}
326+
327+
/**
 * Visits a {@link NestedLoopJoin}: both inputs and the join condition.
 *
 * <p>The condition is visited with {@code getExpressionCopyOnWriteVisitor()}. Returns
 * {@code Optional.empty()} when {@code allEmpty} holds for all three visit results. Otherwise
 * returns a new {@code NestedLoopJoin} seeded from the original, where each part uses the visit
 * result if present and the original value if not.
 *
 * @param nestedLoopJoin the nested loop join being visited
 * @return an empty optional, or a rebuilt nested loop join as described above
 * @throws EXCEPTION if visiting any part throws
 */
@Override
328+
public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
329+
var left = nestedLoopJoin.getLeft().accept(this);
330+
var right = nestedLoopJoin.getRight().accept(this);
331+
var condition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor());
332+
333+
// NOTE(review): allEmpty presumably means "no part produced a replacement" - confirm in helper.
if (allEmpty(left, right, condition)) {
334+
return Optional.empty();
335+
}
336+
return Optional.of(
337+
NestedLoopJoin.builder()
338+
.from(nestedLoopJoin)
339+
.left(left.orElse(nestedLoopJoin.getLeft()))
340+
.right(right.orElse(nestedLoopJoin.getRight()))
341+
.condition(condition.orElse(nestedLoopJoin.getCondition()))
342+
.build());
343+
}
344+
322345
// utilities
323346

324347
protected Optional<List<Expression>> visitExprList(List<Expression> exprs) throws EXCEPTION {

Diff for: core/src/main/java/io/substrait/relation/RelProtoConverter.java

+41 -14
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import io.substrait.proto.FilterRel;
1616
import io.substrait.proto.HashJoinRel;
1717
import io.substrait.proto.JoinRel;
18+
import io.substrait.proto.MergeJoinRel;
1819
import io.substrait.proto.NestedLoopJoinRel;
1920
import io.substrait.proto.ProjectRel;
2021
import io.substrait.proto.ReadRel;
@@ -25,6 +26,7 @@
2526
import io.substrait.proto.SortRel;
2627
import io.substrait.relation.files.FileOrFiles;
2728
import io.substrait.relation.physical.HashJoin;
29+
import io.substrait.relation.physical.MergeJoin;
2830
import io.substrait.relation.physical.NestedLoopJoin;
2931
import io.substrait.type.proto.TypeProtoConverter;
3032
import java.util.Collection;
@@ -181,20 +183,6 @@ public Rel visit(Join join) throws RuntimeException {
181183
return Rel.newBuilder().setJoin(builder).build();
182184
}
183185

184-
@Override
185-
public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
186-
var builder =
187-
NestedLoopJoinRel.newBuilder()
188-
.setCommon(common(nestedLoopJoin))
189-
.setLeft(toProto(nestedLoopJoin.getLeft()))
190-
.setRight(toProto(nestedLoopJoin.getRight()))
191-
.setExpression(toProto(nestedLoopJoin.getCondition()))
192-
.setType(nestedLoopJoin.getJoinType().toProto());
193-
194-
nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
195-
return Rel.newBuilder().setNestedLoopJoin(builder).build();
196-
}
197-
198186
@Override
199187
public Rel visit(Set set) throws RuntimeException {
200188
var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto());
@@ -280,6 +268,45 @@ public Rel visit(HashJoin hashJoin) throws RuntimeException {
280268
return Rel.newBuilder().setHashJoin(builder).build();
281269
}
282270

271+
/**
 * Converts a {@link MergeJoin} into a protobuf {@code Rel} wrapping a {@link MergeJoinRel}.
 *
 * <p>Sets the common section, both inputs and the join type, then adds the left and right keys.
 * The post-join filter and the advanced extension are set only when present.
 *
 * @param mergeJoin the merge join to convert
 * @return a protobuf {@code Rel} with its merge join field set
 * @throws RuntimeException if the left and right key lists differ in size
 */
@Override
272+
public Rel visit(MergeJoin mergeJoin) throws RuntimeException {
273+
var builder =
274+
MergeJoinRel.newBuilder()
275+
.setCommon(common(mergeJoin))
276+
.setLeft(toProto(mergeJoin.getLeft()))
277+
.setRight(toProto(mergeJoin.getRight()))
278+
.setType(mergeJoin.getJoinType().toProto());
279+
280+
List<FieldReference> leftKeys = mergeJoin.getLeftKeys();
281+
List<FieldReference> rightKeys = mergeJoin.getRightKeys();
282+
283+
// Mismatched key counts are rejected before any keys are added to the builder.
if (leftKeys.size() != rightKeys.size()) {
284+
throw new RuntimeException("Number of left and right keys must be equal.");
285+
}
286+
287+
builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList()));
288+
builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList()));
289+
290+
mergeJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));
291+
292+
mergeJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
293+
return Rel.newBuilder().setMergeJoin(builder).build();
294+
}
295+
296+
/**
 * Converts a {@link NestedLoopJoin} into a protobuf {@code Rel} wrapping a
 * {@link NestedLoopJoinRel}.
 *
 * <p>Sets the common section, both inputs, the join condition (as the proto's expression) and
 * the join type. The advanced extension is set only when present.
 *
 * @param nestedLoopJoin the nested loop join to convert
 * @return a protobuf {@code Rel} with its nested loop join field set
 */
@Override
297+
public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
298+
var builder =
299+
NestedLoopJoinRel.newBuilder()
300+
.setCommon(common(nestedLoopJoin))
301+
.setLeft(toProto(nestedLoopJoin.getLeft()))
302+
.setRight(toProto(nestedLoopJoin.getRight()))
303+
.setExpression(toProto(nestedLoopJoin.getCondition()))
304+
.setType(nestedLoopJoin.getJoinType().toProto());
305+
306+
nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
307+
return Rel.newBuilder().setNestedLoopJoin(builder).build();
308+
}
309+
283310
@Override
284311
public Rel visit(Project project) throws RuntimeException {
285312
var builder =

Diff for: core/src/main/java/io/substrait/relation/RelVisitor.java

+5 -2
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
package io.substrait.relation;
22

33
import io.substrait.relation.physical.HashJoin;
4+
import io.substrait.relation.physical.MergeJoin;
45
import io.substrait.relation.physical.NestedLoopJoin;
56

67
public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
@@ -14,8 +15,6 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
1415

1516
OUTPUT visit(Join join) throws EXCEPTION;
1617

17-
OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION;
18-
1918
OUTPUT visit(Set set) throws EXCEPTION;
2019

2120
OUTPUT visit(NamedScan namedScan) throws EXCEPTION;
@@ -39,4 +38,8 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
3938
OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION;
4039

4140
OUTPUT visit(HashJoin hashJoin) throws EXCEPTION;
41+
42+
OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION;
43+
44+
OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION;
4245
}

0 commit comments

Comments
 (0)