Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ public static Expression.StructLiteral struct(boolean nullable, Expression.Liter
}

/**
* Creator a nested list expression with one or more elements.
* Creates a nested list expression with one or more elements.
*
* <p>Note: This class cannot be used to construct an empty list. To create an empty list, use
* {@link ExpressionCreator#emptyList(boolean, Type)} which returns an {@link
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ public Expression.Nested from(io.substrait.proto.Expression.Nested nested) {
nested.getList().getValuesList().stream().map(this::from).collect(Collectors.toList());
return ExpressionCreator.nestedList(nested.getNullable(), list);
default:
throw new IllegalStateException("Unimplemented nested type: " + nested.getNestedTypeCase());
throw new UnsupportedOperationException(
"Unimplemented nested type: " + nested.getNestedTypeCase());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ class NestedListExpressionTest extends TestBase {
Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42));

@Test
void DifferentTypedLiteralsNestedListTest() {
void rejectNestedListWithElementsOfDifferentTypes() {
ImmutableExpression.NestedList.Builder builder =
Expression.NestedList.builder().addValues(literalExpression).addValues(b.i32(12));
assertThrows(AssertionError.class, builder::build);
}

@Test
void SameTypedLiteralsNestedListTest() {
void acceptNestedListWithElementsOfSameType() {
ImmutableExpression.NestedList.Builder builder =
Expression.NestedList.builder().addValues(nonLiteralExpression).addValues(b.i32(12));
assertDoesNotThrow(builder::build);
Expand All @@ -35,7 +35,7 @@ void SameTypedLiteralsNestedListTest() {
}

@Test
void EmptyNestedListTest() {
void rejectEmptyNestedListTest() {
ImmutableExpression.NestedList.Builder builder = Expression.NestedList.builder();
assertThrows(AssertionError.class, builder::build);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public class ProjectRelRoundtripTest extends TestBase {
Arrays.asList("col_a", "col_b", "col_c", "col_d"),
Arrays.asList(R.I64, R.FP64, R.STRING, R.I32));

final Rel emptyTable = b.emptyScan();

@Test
void simpleProjection() {
// Project single field
Expand Down Expand Up @@ -146,4 +148,11 @@ void emptyProjection() {

verifyRoundTrip(projection);
}

@Test
void avoidProjectRemapOnEmptyInput() {
Rel projection =
Project.builder().input(emptyTable).addExpressions(b.add(b.i32(1), b.i32(2))).build();
verifyRoundTrip(projection);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

/**
* Substrait-specific constructor to map back to the Expression NestedList type in Substrait. This
* constructor creates a special type of SqlKind.ARRAY_VALUE_CONSTRUCTOR for lists that store
* non-literal expressions.
* constructor creates a special type of SqlKind.ARRAY_VALUE_CONSTRUCTOR for lists that can contain
* both literal and non-literal expressions.
*/
public class NestedListConstructor extends SqlMultisetValueConstructor {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,11 @@ public Rel visit(org.apache.calcite.rel.core.Project project) {
.map(this::toExpression)
.collect(java.util.stream.Collectors.toList());

// if there is no input fields, don’t put a remapping on it
// if there are no input fields, no remap is necessary
if (project.getInput().getRowType().getFieldCount() == 0) {
return Project.builder().expressions(expressions).input(apply(project.getInput())).build();
}

// todo: eliminate excessive projects. This should be done by converting rexinputrefs to remaps.
return Project.builder()
.remap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
CallConverters.CAST.apply(typeConverter),
CallConverters.REINTERPRET.apply(typeConverter),
new LiteralConstructorConverter(typeConverter),
new NestedExpressionConverter());
new NestedExpressionCallConverter());
}

public interface SimpleCallConverter extends CallConverter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.TimeString;
import org.apache.calcite.util.TimestampString;

Expand Down Expand Up @@ -325,7 +326,19 @@ public RexNode visit(Expression.ListLiteral expr, Context context) throws Runtim
public RexNode visit(Expression.NestedList expr, Context context) {
List<RexNode> args =
expr.values().stream().map(e -> e.accept(this, context)).collect(Collectors.toList());
return rexBuilder.makeCall(SubstraitOperatorTable.NESTED_LIST_CONSTRUCTOR, args);

// to preserve NestedList nullability
RelDataType elementType;
if (args.isEmpty()) {
elementType = typeFactory.createSqlType(SqlTypeName.ANY);
} else {
elementType = args.get(0).getType();
}
RelDataType nestedListType = typeFactory.createArrayType(elementType, -1);
nestedListType = typeFactory.createTypeWithNullability(nestedListType, expr.nullable());

return rexBuilder.makeCall(
nestedListType, SubstraitOperatorTable.NESTED_LIST_CONSTRUCTOR, args);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@

import io.substrait.expression.Expression;
import io.substrait.isthmus.CallConverter;
import io.substrait.isthmus.NestedListConstructor;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;

public class NestedExpressionConverter implements CallConverter {
public class NestedExpressionCallConverter implements CallConverter {

public NestedExpressionConverter() {}
public NestedExpressionCallConverter() {}

@Override
public Optional<Expression> convert(
RexCall call, Function<RexNode, Expression> topLevelConverter) {

if (!call.getOperator().getName().equals("NESTEDLIST")) {
if (!(call.getOperator() instanceof NestedListConstructor)) {
return Optional.empty();
}

Expand Down
115 changes: 100 additions & 15 deletions isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.protobuf.ByteString;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.Expression;
import io.substrait.extension.DefaultExtensionCatalog;
Expand All @@ -11,27 +10,29 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.calcite.rel.RelNode;
import org.junit.jupiter.api.Test;

class NestedExpressionsTest extends PlanTestBase {

protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection =
DefaultExtensionCatalog.DEFAULT_COLLECTION;
protected SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection);
SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory);

io.substrait.expression.Expression literalExpression =
Expression.BoolLiteral.builder().value(true).build();
Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42));
Expression.ScalarFunctionInvocation nonLiteralExpression2 = b.add(b.i32(3), b.i32(4));

final List<Type> tableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN);
final Rel commonTable = b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), tableType);
final List<Type> tableType = List.of(R.I32, R.FP32, N.STRING, N.BOOLEAN, N.STRING);
final Rel commonTable =
b.namedScan(List.of("example"), List.of("a", "b", "c", "d", "e"), tableType);
final Rel emptyTable = b.emptyScan();

Expression fieldRef1 = b.fieldReference(commonTable, 2);
Expression fieldRef2 = b.fieldReference(commonTable, 4);

@Test
void NestedListWithJustLiteralsTest() {
void nestedListWithLiteralsTest() {
List<Expression> expressionList = new ArrayList<>();
Expression.NestedList literalNestedList =
Expression.NestedList.builder()
Expand All @@ -46,13 +47,11 @@ void NestedListWithJustLiteralsTest() {
.input(emptyTable)
.build();

RelNode relNode = substraitToCalcite.convert(project); // substrait rel to calcite
Rel project2 = SubstraitRelVisitor.convert(relNode, extensions); // calcite to substrait
assertEquals(project, project2); // pojo -> calcite -> pojo
assertFullRoundTrip(project);
}

@Test
void NestedListWithNonLiteralsTest() {
void nestedListWithNonLiteralsTest() {
List<Expression> expressionList = new ArrayList<>();

Expression.NestedList nonLiteralNestedList =
Expand All @@ -66,11 +65,97 @@ void NestedListWithNonLiteralsTest() {
io.substrait.relation.Project.builder()
.expressions(expressionList)
.input(commonTable)
.remap(Rel.Remap.of(Collections.singleton(4)))
// project only the nestedList expression and exclude the 5 input columns
.remap(Rel.Remap.of(Collections.singleton(5)))
.build();

assertFullRoundTrip(project);
}

@Test
void nestedListWithFieldReferenceTest() {
Expression.NestedList nestedListWithField =
Expression.NestedList.builder().addValues(fieldRef1).addValues(fieldRef2).build();

List<Expression> expressionList = new ArrayList<>();
expressionList.add(nestedListWithField);

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.expressions(expressionList)
.input(commonTable)
.remap(Rel.Remap.of(Collections.singleton(5)))
.build();

assertFullRoundTrip(project);
}

@Test
void nestedListWithStringLiteralsTest() {
Expression.NestedList nestedList =
Expression.NestedList.builder().addValues(b.str("xzy")).addValues(b.str("abc")).build();

Rel project =
io.substrait.relation.Project.builder()
.expressions(List.of(nestedList))
.input(emptyTable)
.build();

assertFullRoundTrip(project);
}

@Test
void nestedListWithBinaryLiteralTest() {
Expression binaryLiteral =
Expression.BinaryLiteral.builder()
.value(ByteString.copyFrom(new byte[] {0x01, 0x02}))
.build();

Expression.NestedList nestedList =
Expression.NestedList.builder().addValues(binaryLiteral).addValues(binaryLiteral).build();

Rel project =
io.substrait.relation.Project.builder()
.expressions(List.of(nestedList))
.input(emptyTable)
.build();

assertFullRoundTrip(project);
}

@Test
void nestedListWithSingleLiteralTest() {
List<Expression> expressionList = new ArrayList<>();
Expression.NestedList literalNestedList =
Expression.NestedList.builder().addValues(literalExpression).build();
expressionList.add(literalNestedList);

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.expressions(expressionList)
.input(emptyTable)
.build();

assertFullRoundTrip(project);
}

@Test
void nullableNestedListTest() {
List<Expression> expressionList = new ArrayList<>();
Expression.NestedList literalNestedList =
Expression.NestedList.builder()
.addValues(literalExpression)
.addValues(literalExpression)
.nullable(true)
.build();
expressionList.add(literalNestedList);

io.substrait.relation.Project project =
io.substrait.relation.Project.builder()
.expressions(expressionList)
.input(emptyTable)
.build();

RelNode relNode = substraitToCalcite.convert(project); // substrait rel to calcite
Rel project2 = SubstraitRelVisitor.convert(relNode, extensions); // calcite to substrait
assertEquals(project, project2); // pojo -> calcite -> pojo
assertFullRoundTrip(project);
}
}