Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable support from/to pojo/protobuf for extended expressions #206

Merged
merged 14 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.substrait.extendedexpression;

import io.substrait.expression.Expression;
import io.substrait.proto.AdvancedExtension;
import io.substrait.proto.AggregateFunction;
import io.substrait.type.NamedStruct;
import java.util.List;
import java.util.Optional;
import org.immutables.value.Value;

@Value.Immutable
public abstract class ExtendedExpression {
public abstract List<ExpressionReference> getReferredExpressions();

public abstract NamedStruct getBaseSchema();

public abstract List<String> getExpectedTypeUrls();

// creating simple extensions, such as extensionURIs and extensions, is performed on the fly

public abstract Optional<AdvancedExtension> getAdvancedExtension();
vbarua marked this conversation as resolved.
Show resolved Hide resolved

@Value.Immutable
public abstract static class ExpressionReference {
public abstract ExpressionTypeReference getExpressionType();
vbarua marked this conversation as resolved.
Show resolved Hide resolved

public abstract List<String> getOutputNames();
}

public abstract static class ExpressionTypeReference {}

@Value.Immutable
public abstract static class ExpressionType extends ExpressionTypeReference {
public abstract Expression getExpression();
}

@Value.Immutable
public abstract static class AggregateFunctionType extends ExpressionTypeReference {
public abstract AggregateFunction getMeasure();
vbarua marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package io.substrait.extendedexpression;

import io.substrait.expression.Expression;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.AggregateFunction;
import io.substrait.proto.ExpressionReference;
import io.substrait.proto.ExtendedExpression;
import io.substrait.type.proto.TypeProtoConverter;

/** Converts from {@link ExtendedExpression} to {@link ExtendedExpression} */
vbarua marked this conversation as resolved.
Show resolved Hide resolved
public class ExtendedExpressionProtoConverter {
public ExtendedExpression toProto(
io.substrait.extendedexpression.ExtendedExpression extendedExpression) {

ExtendedExpression.Builder builder = ExtendedExpression.newBuilder();
ExtensionCollector functionCollector = new ExtensionCollector();

final ExpressionProtoConverter expressionProtoConverter =
new ExpressionProtoConverter(functionCollector, null);

for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference
expressionReference : extendedExpression.getReferredExpressions()) {
io.substrait.extendedexpression.ExtendedExpression.ExpressionTypeReference expressionType =
expressionReference.getExpressionType();
if (expressionType
instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionType) {
io.substrait.proto.Expression expressionProto =
expressionProtoConverter.visit(
(Expression.ScalarFunctionInvocation)
((io.substrait.extendedexpression.ExtendedExpression.ExpressionType)
expressionType)
.getExpression());
ExpressionReference.Builder expressionReferenceBuilder =
ExpressionReference.newBuilder()
.setExpression(expressionProto)
.addAllOutputNames(expressionReference.getOutputNames());
builder.addReferredExpr(expressionReferenceBuilder);
} else if (expressionType
instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) {
AggregateFunction measure =
((io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType)
expressionType)
.getMeasure();
ExpressionReference.Builder expressionReferenceBuilder =
ExpressionReference.newBuilder()
.setMeasure(measure.toBuilder())
.addAllOutputNames(expressionReference.getOutputNames());
builder.addReferredExpr(expressionReferenceBuilder);
} else {
throw new UnsupportedOperationException(
"Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions");
}
}
builder.setBaseSchema(
extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector)));

// the process of adding simple extensions, such as extensionURIs and extensions, is handled on
// the fly
functionCollector.addExtensionsToExtendedExpression(builder);
if (extendedExpression.getAdvancedExtension().isPresent()) {
builder.setAdvancedExtensions(extendedExpression.getAdvancedExtension().get());
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package io.substrait.extendedexpression;

import io.substrait.expression.Expression;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.ImmutableExtensionLookup;
import io.substrait.extension.ImmutableSimpleExtension;
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.AggregateFunction;
import io.substrait.proto.ExpressionReference;
import io.substrait.proto.NamedStruct;
import io.substrait.type.proto.ProtoTypeConverter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Converts from {@link io.substrait.proto.ExtendedExpression} to {@link ExtendedExpression} */
public class ProtoExtendedExpressionConverter {
private final SimpleExtension.ExtensionCollection extensionCollection;

public ProtoExtendedExpressionConverter() throws IOException {
this(SimpleExtension.loadDefaults());
}

public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) {
this.extensionCollection = extensionCollection;
}

private final ProtoTypeConverter protoTypeConverter =
new ProtoTypeConverter(
new ExtensionCollector(), ImmutableSimpleExtension.ExtensionCollection.builder().build());

public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpression) {
// fill in simple extension information through a discovery in the current proto-extended
// expression
ExtensionLookup functionLookup =
ImmutableExtensionLookup.builder().from(extendedExpression).build();

NamedStruct baseSchemaProto = extendedExpression.getBaseSchema();

io.substrait.type.NamedStruct namedStruct =
io.substrait.type.NamedStruct.fromProto(baseSchemaProto, protoTypeConverter);

ProtoExpressionConverter protoExpressionConverter =
new ProtoExpressionConverter(
functionLookup, this.extensionCollection, namedStruct.struct(), null);

List<ExtendedExpression.ExpressionReference> expressionReferences = new ArrayList<>();
for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) {
if (expressionReference.getExprTypeCase().getNumber() == 1) { // Expression
vbarua marked this conversation as resolved.
Show resolved Hide resolved
Expression expressionPojo =
protoExpressionConverter.from(expressionReference.getExpression());
expressionReferences.add(
ImmutableExpressionReference.builder()
.expressionType(
ImmutableExpressionType.builder().expression(expressionPojo).build())
.addAllOutputNames(expressionReference.getOutputNamesList())
.build());
} else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction
AggregateFunction measure = expressionReference.getMeasure();
vbarua marked this conversation as resolved.
Show resolved Hide resolved
ImmutableExpressionReference.Builder builder =
ImmutableExpressionReference.builder()
.expressionType(ImmutableAggregateFunctionType.builder().measure(measure).build())
.addAllOutputNames(expressionReference.getOutputNamesList());
expressionReferences.add(builder.build());
} else {
throw new UnsupportedOperationException(
"Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions");
}
}

ImmutableExtendedExpression.Builder builder =
ImmutableExtendedExpression.builder()
.referredExpressions(expressionReferences)
.advancedExtension(
Optional.ofNullable(
extendedExpression.hasAdvancedExtensions()
? extendedExpression.getAdvancedExtensions()
: null))
.baseSchema(namedStruct);
return builder.build();
}
}
25 changes: 22 additions & 3 deletions core/src/main/java/io/substrait/extension/ExtensionCollector.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.substrait.extension;

import com.github.bsideup.jabel.Desugar;
import io.substrait.proto.ExtendedExpression;
import io.substrait.proto.Plan;
import io.substrait.proto.SimpleExtensionDeclaration;
import io.substrait.proto.SimpleExtensionURI;
Expand Down Expand Up @@ -51,6 +53,20 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) {
}

/**
 * Adds the collected extension URIs and extension declarations to the given plan builder.
 *
 * @param builder the plan builder to populate
 */
public void addExtensionsToPlan(Plan.Builder builder) {
  SimpleExtensions collected = getExtensions();
  builder
      .addAllExtensionUris(collected.uris().values())
      .addAllExtensions(collected.extensionList());
}

/**
 * Adds the collected extension URIs and extension declarations to the given extended expression
 * builder.
 *
 * @param builder the extended expression builder to populate
 */
public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) {
  SimpleExtensions collected = getExtensions();
  builder
      .addAllExtensionUris(collected.uris().values())
      .addAllExtensions(collected.extensionList());
}

private SimpleExtensions getExtensions() {
var uriPos = new AtomicInteger(1);
var uris = new HashMap<String, SimpleExtensionURI>();

Expand Down Expand Up @@ -93,11 +109,14 @@ public void addExtensionsToPlan(Plan.Builder builder) {
.build();
extensionList.add(decl);
}

builder.addAllExtensionUris(uris.values());
builder.addAllExtensions(extensionList);
return new SimpleExtensions(uris, extensionList);
}

@Desugar
vbarua marked this conversation as resolved.
Show resolved Hide resolved
private record SimpleExtensions(
HashMap<String, SimpleExtensionURI> uris,
ArrayList<SimpleExtensionDeclaration> extensionList) {}

/** We don't depend on guava... */
private static class BidiMap<T1, T2> {
private final Map<T1, T2> forwardMap;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package io.substrait.extension;

import io.substrait.proto.ExtendedExpression;
import io.substrait.proto.Plan;
import io.substrait.proto.SimpleExtensionDeclaration;
import io.substrait.proto.SimpleExtensionURI;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
Expand All @@ -30,14 +33,25 @@ public static class Builder {
private final Map<Integer, SimpleExtension.FunctionAnchor> functionMap = new HashMap<>();
private final Map<Integer, SimpleExtension.TypeAnchor> typeMap = new HashMap<>();

public Builder from(Plan p) {
/**
 * Populates this builder from the extension URIs and extension declarations of a plan.
 *
 * @param plan the protobuf plan to read extensions from
 * @return the result of the shared list-based {@code from} overload
 */
public Builder from(Plan plan) {
  List<SimpleExtensionURI> planUris = plan.getExtensionUrisList();
  List<SimpleExtensionDeclaration> planDeclarations = plan.getExtensionsList();
  return from(planUris, planDeclarations);
}

/**
 * Populates this builder from the extension URIs and extension declarations of an extended
 * expression.
 *
 * @param extendedExpression the protobuf extended expression to read extensions from
 * @return the result of the shared list-based {@code from} overload
 */
public Builder from(ExtendedExpression extendedExpression) {
  List<SimpleExtensionURI> expressionUris = extendedExpression.getExtensionUrisList();
  List<SimpleExtensionDeclaration> expressionDeclarations =
      extendedExpression.getExtensionsList();
  return from(expressionUris, expressionDeclarations);
}

private Builder from(
List<SimpleExtensionURI> simpleExtensionURIs,
List<SimpleExtensionDeclaration> simpleExtensionDeclarations) {
Map<Integer, String> namespaceMap = new HashMap<>();
for (var extension : p.getExtensionUrisList()) {
for (var extension : simpleExtensionURIs) {
namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri());
}

// Add all functions used in plan to the functionMap
for (var extension : p.getExtensionsList()) {
for (var extension : simpleExtensionDeclarations) {
if (!extension.hasExtensionFunction()) {
continue;
}
Expand All @@ -54,7 +68,7 @@ public Builder from(Plan p) {
}

// Add all types used in plan to the typeMap
for (var extension : p.getExtensionsList()) {
for (var extension : simpleExtensionDeclarations) {
if (!extension.hasExtensionType()) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package io.substrait.relation;

import io.substrait.expression.FunctionArg;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.AggregateFunction;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.stream.IntStream;

/**
 * Converts from {@link io.substrait.relation.Aggregate.Measure} to {@link
 * io.substrait.proto.AggregateFunction}
 */
public class AggregateFunctionProtoController {

  private final ExpressionProtoConverter exprProtoConverter;
  private final TypeProtoConverter typeProtoConverter;
  private final ExtensionCollector functionCollector;

  /**
   * Creates a converter whose expression and type converters share the given collector.
   *
   * @param functionCollector collector used to obtain function references and passed to the
   *     expression and type converters
   */
  public AggregateFunctionProtoController(ExtensionCollector functionCollector) {
    this.functionCollector = functionCollector;
    this.exprProtoConverter = new ExpressionProtoConverter(functionCollector, null);
    this.typeProtoConverter = new TypeProtoConverter(functionCollector);
  }

  /**
   * Builds the protobuf aggregate function for the function of the given measure.
   *
   * @param measure the measure whose function is converted
   * @return the protobuf aggregate function
   */
  public AggregateFunction toProto(Aggregate.Measure measure) {
    var function = measure.getFunction();
    var argVisitor = FunctionArg.toProto(typeProtoConverter, exprProtoConverter);
    var args = function.arguments();
    var aggFuncDef = function.declaration();

    // Keep the original conversion order (output type, arguments, function reference): each step
    // goes through converters that share the same collector.
    var phase = function.aggregationPhase().toProto();
    var invocation = function.invocation().toProto();
    var outputType = function.getType().accept(typeProtoConverter);
    var protoArgs =
        IntStream.range(0, args.size())
            .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor))
            .collect(java.util.stream.Collectors.toList());
    var functionReference = functionCollector.getFunctionReference(function.declaration());

    return AggregateFunction.newBuilder()
        .setPhase(phase)
        .setInvocation(invocation)
        .setOutputType(outputType)
        .addAllArguments(protoArgs)
        .setFunctionReference(functionReference)
        .build();
  }
}
17 changes: 17 additions & 0 deletions core/src/main/java/io/substrait/type/NamedStruct.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.type;

import io.substrait.type.proto.ProtoTypeConverter;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.List;
import org.immutables.value.Value;
Expand All @@ -21,4 +22,20 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve
.addAllNames(names())
.build();
}

/**
 * Builds a POJO named struct from its protobuf form.
 *
 * @param namedStruct the protobuf named struct providing the names and the struct type
 * @param protoTypeConverter converter applied to each field type of the struct
 * @return the POJO named struct with the same names, converted field types and nullability
 */
static io.substrait.type.NamedStruct fromProto(
    io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) {
  var protoStruct = namedStruct.getStruct();

  var fieldTypes =
      protoStruct.getTypesList().stream()
          .map(protoTypeConverter::from)
          .collect(java.util.stream.Collectors.toList());

  var pojoStruct =
      Type.Struct.builder()
          .fields(fieldTypes)
          .nullable(ProtoTypeConverter.isNullable(protoStruct.getNullability()))
          .build();

  return ImmutableNamedStruct.builder()
      .names(namedStruct.getNamesList())
      .struct(pojoStruct)
      .build();
}
}
Loading