Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable support from/to pojo/protobuf for extended expressions #206

Merged
merged 14 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.substrait.extendedexpression;

import io.substrait.expression.Expression;
import io.substrait.proto.AdvancedExtension;
import io.substrait.proto.AggregateFunction;
import io.substrait.type.NamedStruct;
import java.util.List;
import java.util.Optional;
import org.immutables.value.Value;

@Value.Immutable
public abstract class ExtendedExpression {
public abstract List<ExpressionReference> getReferredExpressions();

public abstract NamedStruct getBaseSchema();

public abstract List<String> getExpectedTypeUrls();

// creating simple extensions, such as extensionURIs and extensions, is performed on the fly

public abstract Optional<AdvancedExtension> getAdvancedExtension();

@Value.Immutable
public abstract static class ExpressionReference {
public abstract ExpressionTypeReference getExpressionType();

public abstract List<String> getOutputNames();
}

public abstract static class ExpressionTypeReference {}

@Value.Immutable
public abstract static class ExpressionType extends ExpressionTypeReference {
public abstract Expression getExpression();
}

@Value.Immutable
public abstract static class AggregateFunctionType extends ExpressionTypeReference {
public abstract AggregateFunction getMeasure();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.substrait.extendedexpression;

import io.substrait.expression.Expression;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.proto.ExpressionReference;
import io.substrait.proto.ExtendedExpression;
import io.substrait.type.proto.TypeProtoConverter;

/** Converts from {@link ExtendedExpression} to {@link ExtendedExpression} */
public class ExtendedExpressionProtoConverter {
public ExtendedExpression toProto(
io.substrait.extendedexpression.ExtendedExpression extendedExpression) {

ExtendedExpression.Builder builder = ExtendedExpression.newBuilder();
ExtensionCollector functionCollector = new ExtensionCollector();

final ExpressionProtoConverter expressionProtoConverter =
new ExpressionProtoConverter(functionCollector, null);

for (io.substrait.extendedexpression.ExtendedExpression.ExpressionReference
expressionReference : extendedExpression.getReferredExpressions()) {
io.substrait.extendedexpression.ExtendedExpression.ExpressionTypeReference expressionType =
expressionReference.getExpressionType();
if (expressionType
instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionType) {
io.substrait.proto.Expression expressionProto =
expressionProtoConverter.visit(
(Expression.ScalarFunctionInvocation)
((io.substrait.extendedexpression.ExtendedExpression.ExpressionType)
expressionType)
.getExpression());
ExpressionReference.Builder expressionReferenceBuilder =
ExpressionReference.newBuilder()
.setExpression(expressionProto)
.addAllOutputNames(expressionReference.getOutputNames());
builder.addReferredExpr(expressionReferenceBuilder);
} else if (expressionType
instanceof io.substrait.extendedexpression.ExtendedExpression.AggregateFunctionType) {
throw new UnsupportedOperationException(
"Aggregate function types are not supported in conversion to proto Extended Expressions for now");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we go ahead and implement in this PR? Would it be a lot of work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just added.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this actually may be a fair bit of work to do properly. In the interest of keeping this PR small and moving this work along, I think we could include AggregateFunction support as a future change.

What do you think @danepitkin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 I agree. I didn't realize the level of effort required here. Sorry, @davisusanibar !

} else {
throw new UnsupportedOperationException(
"Only Expression or Aggregate Function type are supported in conversion to proto Extended Expressions for now");
}
}
builder.setBaseSchema(
extendedExpression.getBaseSchema().toProto(new TypeProtoConverter(functionCollector)));

// the process of adding simple extensions, such as extensionURIs and extensions, is handled on
// the fly
functionCollector.addExtensionsToExtendedExpression(builder);
if (extendedExpression.getAdvancedExtension().isPresent()) {
builder.setAdvancedExtensions(extendedExpression.getAdvancedExtension().get());
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package io.substrait.extendedexpression;

import io.substrait.expression.Expression;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.*;
import io.substrait.proto.ExpressionReference;
import io.substrait.proto.NamedStruct;
import io.substrait.type.proto.ProtoTypeConverter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Converts from {@link io.substrait.proto.ExtendedExpression} to {@link ExtendedExpression} */
public class ProtoExtendedExpressionConverter {
private final SimpleExtension.ExtensionCollection extensionCollection;

public ProtoExtendedExpressionConverter() throws IOException {
this(SimpleExtension.loadDefaults());
}

public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) {
this.extensionCollection = extensionCollection;
}

private final ProtoTypeConverter protoTypeConverter =
new ProtoTypeConverter(
new ExtensionCollector(), ImmutableSimpleExtension.ExtensionCollection.builder().build());

public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExpression) {
// fill in simple extension information through a discovery in the current proto-extended
// expression
ExtensionLookup functionLookup =
ImmutableExtensionLookup.builder().from(extendedExpression).build();

NamedStruct baseSchemaProto = extendedExpression.getBaseSchema();

io.substrait.type.NamedStruct namedStruct =
io.substrait.type.NamedStruct.convertNamedStructProtoToPojo(
baseSchemaProto, protoTypeConverter);

ProtoExpressionConverter protoExpressionConverter =
new ProtoExpressionConverter(
functionLookup, this.extensionCollection, namedStruct.struct(), null);

List<ExtendedExpression.ExpressionReference> expressionReferences = new ArrayList<>();
for (ExpressionReference expressionReference : extendedExpression.getReferredExprList()) {
if (expressionReference.getExprTypeCase().getNumber() == 1) { // Expression
Expression expressionPojo =
protoExpressionConverter.from(expressionReference.getExpression());
expressionReferences.add(
ImmutableExpressionReference.builder()
.expressionType(
ImmutableExpressionType.builder().expression(expressionPojo).build())
.addAllOutputNames(expressionReference.getOutputNamesList())
.build());
} else if (expressionReference.getExprTypeCase().getNumber() == 2) { // AggregateFunction
throw new UnsupportedOperationException(
"Aggregate function types are not supported in conversion from proto Extended Expressions for now");
} else {
throw new UnsupportedOperationException(
"Only Expression or Aggregate Function type are supported in conversion from proto Extended Expressions for now");
}
}

ImmutableExtendedExpression.Builder builder =
ImmutableExtendedExpression.builder()
.referredExpressions(expressionReferences)
.advancedExtension(
Optional.ofNullable(
extendedExpression.hasAdvancedExtensions()
? extendedExpression.getAdvancedExtensions()
: null))
.baseSchema(namedStruct);
return builder.build();
}
}
25 changes: 22 additions & 3 deletions core/src/main/java/io/substrait/extension/ExtensionCollector.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.substrait.extension;

import com.github.bsideup.jabel.Desugar;
import io.substrait.proto.ExtendedExpression;
import io.substrait.proto.Plan;
import io.substrait.proto.SimpleExtensionDeclaration;
import io.substrait.proto.SimpleExtensionURI;
Expand Down Expand Up @@ -51,6 +53,20 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) {
}

public void addExtensionsToPlan(Plan.Builder builder) {
SimpleExtensions simpleExtensions = getExtensions();

builder.addAllExtensionUris(simpleExtensions.uris().values());
builder.addAllExtensions(simpleExtensions.extensionList());
}

public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) {
SimpleExtensions simpleExtensions = getExtensions();

builder.addAllExtensionUris(simpleExtensions.uris().values());
builder.addAllExtensions(simpleExtensions.extensionList());
}

private SimpleExtensions getExtensions() {
var uriPos = new AtomicInteger(1);
var uris = new HashMap<String, SimpleExtensionURI>();

Expand Down Expand Up @@ -93,11 +109,14 @@ public void addExtensionsToPlan(Plan.Builder builder) {
.build();
extensionList.add(decl);
}

builder.addAllExtensionUris(uris.values());
builder.addAllExtensions(extensionList);
return new SimpleExtensions(uris, extensionList);
}

@Desugar
private record SimpleExtensions(
HashMap<String, SimpleExtensionURI> uris,
ArrayList<SimpleExtensionDeclaration> extensionList) {}

/** We don't depend on guava... */
private static class BidiMap<T1, T2> {
private final Map<T1, T2> forwardMap;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package io.substrait.extension;

import io.substrait.proto.ExtendedExpression;
import io.substrait.proto.Plan;
import io.substrait.proto.SimpleExtensionDeclaration;
import io.substrait.proto.SimpleExtensionURI;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
Expand All @@ -30,14 +33,25 @@ public static class Builder {
private final Map<Integer, SimpleExtension.FunctionAnchor> functionMap = new HashMap<>();
private final Map<Integer, SimpleExtension.TypeAnchor> typeMap = new HashMap<>();

public Builder from(Plan p) {
public Builder from(Plan plan) {
return from(plan.getExtensionUrisList(), plan.getExtensionsList());
}

public Builder from(ExtendedExpression extendedExpression) {
return from(
extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList());
}

private Builder from(
List<SimpleExtensionURI> simpleExtensionURIs,
List<SimpleExtensionDeclaration> simpleExtensionDeclarations) {
Map<Integer, String> namespaceMap = new HashMap<>();
for (var extension : p.getExtensionUrisList()) {
for (var extension : simpleExtensionURIs) {
namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri());
}

// Add all functions used in plan to the functionMap
for (var extension : p.getExtensionsList()) {
for (var extension : simpleExtensionDeclarations) {
if (!extension.hasExtensionFunction()) {
continue;
}
Expand All @@ -54,7 +68,7 @@ public Builder from(Plan p) {
}

// Add all types used in plan to the typeMap
for (var extension : p.getExtensionsList()) {
for (var extension : simpleExtensionDeclarations) {
if (!extension.hasExtensionType()) {
continue;
}
Expand Down
17 changes: 17 additions & 0 deletions core/src/main/java/io/substrait/type/NamedStruct.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.type;

import io.substrait.type.proto.ProtoTypeConverter;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.List;
import org.immutables.value.Value;
Expand All @@ -21,4 +22,20 @@ default io.substrait.proto.NamedStruct toProto(TypeProtoConverter typeProtoConve
.addAllNames(names())
.build();
}

static io.substrait.type.NamedStruct convertNamedStructProtoToPojo(
io.substrait.proto.NamedStruct namedStruct, ProtoTypeConverter protoTypeConverter) {
var struct = namedStruct.getStruct();
return ImmutableNamedStruct.builder()
.names(namedStruct.getNamesList())
.struct(
Type.Struct.builder()
.fields(
struct.getTypesList().stream()
.map(protoTypeConverter::from)
.collect(java.util.stream.Collectors.toList()))
.nullable(ProtoTypeConverter.isNullable(struct.getNullability()))
.build())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package io.substrait.extendedexpression;

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.TestBase;
import io.substrait.expression.*;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;

public class ExtendedExpressionProtoConverterTest extends TestBase {
static final String NAMESPACE = "/functions_arithmetic_decimal.yaml";

@Test
public void toProtoTest() {
// create predefined POJO extended expression
Expression.ScalarFunctionInvocation scalarFunctionInvocation =
b.scalarFn(
NAMESPACE,
"add:dec_dec",
TypeCreator.REQUIRED.BOOLEAN,
ImmutableFieldReference.builder()
.addSegments(FieldReference.StructField.of(0))
.type(TypeCreator.REQUIRED.decimal(10, 2))
.build(),
ExpressionCreator.i32(false, 183));

ImmutableExpressionReference expressionReference =
ImmutableExpressionReference.builder()
.expressionType(
ImmutableExpressionType.builder().expression(scalarFunctionInvocation).build())
.addOutputNames("new-column")
.build();

List<ExtendedExpression.ExpressionReference> expressionReferences = new ArrayList<>();
expressionReferences.add(expressionReference);

ImmutableNamedStruct namedStruct =
ImmutableNamedStruct.builder()
.addNames("N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT")
.struct(
Type.Struct.builder()
.nullable(false)
.addFields(
TypeCreator.NULLABLE.decimal(10, 2),
TypeCreator.REQUIRED.STRING,
TypeCreator.REQUIRED.decimal(10, 2),
TypeCreator.REQUIRED.STRING)
.build())
.build();

ImmutableExtendedExpression.Builder extendedExpression =
ImmutableExtendedExpression.builder()
.referredExpressions(expressionReferences)
.baseSchema(namedStruct);

// convert POJO extended expression into PROTOBUF extended expression
io.substrait.proto.ExtendedExpression proto =
new ExtendedExpressionProtoConverter().toProto(extendedExpression.build());

assertEquals(NAMESPACE, proto.getExtensionUrisList().get(0).getUri());
assertEquals("add:dec_dec", proto.getExtensionsList().get(0).getExtensionFunction().getName());
}
}
Loading