Skip to content

Commit 6ad28ae

Browse files
feat: add ExpandRel support to core and spark
Signed-off-by: Andrew Coleman <[email protected]>
1 parent e24ce6f commit 6ad28ae

18 files changed

+385
-37
lines changed

Diff for: core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

+33
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import io.substrait.plan.ImmutablePlan;
2020
import io.substrait.plan.ImmutableRoot;
2121
import io.substrait.plan.Plan;
22+
import io.substrait.proto.RelCommon;
2223
import io.substrait.relation.Aggregate;
2324
import io.substrait.relation.Cross;
25+
import io.substrait.relation.Expand;
2426
import io.substrait.relation.Fetch;
2527
import io.substrait.relation.Filter;
2628
import io.substrait.relation.Join;
@@ -313,6 +315,37 @@ private Project project(
313315
return Project.builder().input(input).expressions(expressions).remap(remap).build();
314316
}
315317

318+
public Expand expand(Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn, Rel input) {
319+
return expand(fieldsFn, Optional.empty(), Optional.empty(), input);
320+
}
321+
322+
public Expand expand(
323+
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn,
324+
List<String> outputNames,
325+
Rel input) {
326+
return expand(fieldsFn, Optional.empty(), Optional.of(outputNames), input);
327+
}
328+
329+
public Expand expand(
330+
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn,
331+
Rel.Remap remap,
332+
List<String> outputNames,
333+
Rel input) {
334+
return expand(fieldsFn, Optional.of(remap), Optional.of(outputNames), input);
335+
}
336+
337+
private Expand expand(
338+
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn,
339+
Optional<Rel.Remap> remap,
340+
Optional<List<String>> outputNames,
341+
Rel input) {
342+
var fields = fieldsFn.apply(input);
343+
var expand = Expand.builder().input(input).fields(fields).remap(remap);
344+
outputNames.ifPresent(
345+
names -> expand.hint(RelCommon.Hint.newBuilder().addAllOutputNames(names).build()));
346+
return expand.build();
347+
}
348+
316349
public Set set(Set.SetOp op, Rel... inputs) {
317350
return set(op, Optional.empty(), inputs);
318351
}

Diff for: core/src/main/java/io/substrait/relation/AbstractRelVisitor.java

+5
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ public OUTPUT visit(Project project) throws EXCEPTION {
5353
return visitFallback(project);
5454
}
5555

56+
@Override
57+
public OUTPUT visit(Expand expand) throws EXCEPTION {
58+
return visitFallback(expand);
59+
}
60+
5661
@Override
5762
public OUTPUT visit(Sort sort) throws EXCEPTION {
5863
return visitFallback(sort);

Diff for: core/src/main/java/io/substrait/relation/Expand.java

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package io.substrait.relation;
2+
3+
import io.substrait.expression.Expression;
4+
import io.substrait.type.Type;
5+
import io.substrait.type.TypeCreator;
6+
import java.util.List;
7+
import java.util.Optional;
8+
import java.util.stream.Stream;
9+
import org.immutables.value.Value;
10+
11+
@Value.Enclosing
12+
@Value.Immutable
13+
public abstract class Expand extends SingleInputRel {
14+
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Expand.class);
15+
16+
public abstract List<ExpandField> getFields();
17+
18+
@Override
19+
public Type.Struct deriveRecordType() {
20+
Type.Struct initial = getInput().getRecordType();
21+
return TypeCreator.of(initial.nullable())
22+
.struct(
23+
Stream.concat(
24+
initial.fields().stream(),
25+
getFields().stream()
26+
.map(
27+
f -> {
28+
if (f.getSwitchingField().isPresent()) {
29+
return f.getSwitchingField().get().getDuplicates().get(0).getType();
30+
} else {
31+
return f.getConsistentField().get().getType();
32+
}
33+
})));
34+
}
35+
36+
@Override
37+
public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
38+
return visitor.visit(this);
39+
}
40+
41+
public static ImmutableExpand.Builder builder() {
42+
return ImmutableExpand.builder();
43+
}
44+
45+
@Value.Immutable
46+
public abstract static class ExpandField {
47+
public abstract Optional<SwitchingField> getSwitchingField();
48+
49+
public abstract Optional<Expression> getConsistentField();
50+
51+
public static ImmutableExpand.ExpandField.Builder builder() {
52+
return ImmutableExpand.ExpandField.builder();
53+
}
54+
}
55+
56+
@Value.Immutable
57+
public abstract static class SwitchingField {
58+
public abstract List<Expression> getDuplicates();
59+
60+
public static ImmutableExpand.SwitchingField.Builder builder() {
61+
return ImmutableExpand.SwitchingField.builder();
62+
}
63+
}
64+
}

Diff for: core/src/main/java/io/substrait/relation/ProtoRelConverter.java

+56-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import io.substrait.proto.AggregateRel;
1111
import io.substrait.proto.ConsistentPartitionWindowRel;
1212
import io.substrait.proto.CrossRel;
13+
import io.substrait.proto.ExpandRel;
1314
import io.substrait.proto.ExtensionLeafRel;
1415
import io.substrait.proto.ExtensionMultiRel;
1516
import io.substrait.proto.ExtensionSingleRel;
@@ -21,6 +22,7 @@
2122
import io.substrait.proto.NestedLoopJoinRel;
2223
import io.substrait.proto.ProjectRel;
2324
import io.substrait.proto.ReadRel;
25+
import io.substrait.proto.RelCommon;
2426
import io.substrait.proto.SetRel;
2527
import io.substrait.proto.SortRel;
2628
import io.substrait.relation.extensions.EmptyDetail;
@@ -87,6 +89,9 @@ public Rel from(io.substrait.proto.Rel rel) {
8789
case PROJECT -> {
8890
return newProject(rel.getProject());
8991
}
92+
case EXPAND -> {
93+
return newExpand(rel.getExpand());
94+
}
9095
case CROSS -> {
9196
return newCross(rel.getCross());
9297
}
@@ -155,7 +160,10 @@ protected Filter newFilter(FilterRel rel) {
155160
}
156161

157162
protected NamedStruct newNamedStruct(ReadRel rel) {
158-
var namedStruct = rel.getBaseSchema();
163+
return newNamedStruct(rel.getBaseSchema());
164+
}
165+
166+
protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) {
159167
var struct = namedStruct.getStruct();
160168
return ImmutableNamedStruct.builder()
161169
.names(namedStruct.getNamesList())
@@ -389,6 +397,43 @@ protected Project newProject(ProjectRel rel) {
389397
return builder.build();
390398
}
391399

400+
protected Expand newExpand(ExpandRel rel) {
401+
var input = from(rel.getInput());
402+
var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
403+
var builder =
404+
Expand.builder()
405+
.input(input)
406+
.fields(
407+
rel.getFieldsList().stream()
408+
.map(
409+
expandField ->
410+
switch (expandField.getFieldTypeCase()) {
411+
case CONSISTENT_FIELD -> Expand.ExpandField.builder()
412+
.consistentField(converter.from(expandField.getConsistentField()))
413+
.build();
414+
case SWITCHING_FIELD -> Expand.ExpandField.builder()
415+
.switchingField(
416+
Expand.SwitchingField.builder()
417+
.duplicates(
418+
expandField
419+
.getSwitchingField()
420+
.getDuplicatesList()
421+
.stream()
422+
.map(converter::from)
423+
.collect(java.util.stream.Collectors.toList()))
424+
.build())
425+
.build();
426+
case FIELDTYPE_NOT_SET -> Expand.ExpandField.builder().build();
427+
})
428+
.collect(java.util.stream.Collectors.toList()));
429+
430+
builder
431+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
432+
.remap(optionalRelmap(rel.getCommon()))
433+
.hint(optionalHint(rel.getCommon()));
434+
return builder.build();
435+
}
436+
392437
protected Aggregate newAggregate(AggregateRel rel) {
393438
var input = from(rel.getInput());
394439
var protoExprConverter =
@@ -647,6 +692,16 @@ protected static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon
647692
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
648693
}
649694

695+
protected static Optional<RelCommon.Hint> optionalHint(io.substrait.proto.RelCommon relCommon) {
696+
return Optional.ofNullable(
697+
relCommon.hasHint()
698+
? RelCommon.Hint.newBuilder()
699+
.setAlias(relCommon.getHint().getAlias())
700+
.addAllOutputNames(relCommon.getHint().getOutputNamesList())
701+
.build()
702+
: null);
703+
}
704+
650705
protected Optional<AdvancedExtension> optionalAdvancedExtension(
651706
io.substrait.proto.RelCommon relCommon) {
652707
return Optional.ofNullable(

Diff for: core/src/main/java/io/substrait/relation/Rel.java

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.substrait.relation;
22

33
import io.substrait.extension.AdvancedExtension;
4+
import io.substrait.proto.RelCommon;
45
import io.substrait.type.Type;
56
import io.substrait.type.TypeCreator;
67
import java.util.List;
@@ -21,6 +22,8 @@ public interface Rel {
2122

2223
List<Rel> getInputs();
2324

25+
Optional<RelCommon.Hint> getHint();
26+
2427
@Value.Immutable
2528
public abstract static class Remap {
2629
public abstract List<Integer> indices();

Diff for: core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java

+5
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,11 @@ public Optional<Rel> visit(Project project) throws EXCEPTION {
201201
.build());
202202
}
203203

204+
@Override
205+
public Optional<Rel> visit(Expand expand) throws EXCEPTION {
206+
throw new UnsupportedOperationException();
207+
}
208+
204209
@Override
205210
public Optional<Rel> visit(Sort sort) throws EXCEPTION {
206211
var input = sort.getInput().accept(this);

0 commit comments

Comments
 (0)