Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ public O visit(Expression.MultiOrList expr, C context) throws E {
return visitFallback(expr, context);
}

/** Routes {@link Expression.NestedList} to {@code visitFallback}, like the neighbouring overloads. */
@Override
public O visit(Expression.NestedList expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(FieldReference expr, C context) throws E {
return visitFallback(expr, context);
Expand Down
41 changes: 41 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ default boolean nullable() {
}
}

/**
 * Expression sub-interface for nested expressions (implemented by {@code NestedList}).
 *
 * <p>Nullability is an attribute of the expression and is {@code false} unless set explicitly.
 */
interface Nested extends Expression {
// Default attribute value: non-nullable when the attribute is never set.
@Value.Default
default boolean nullable() {
return false;
}
}

<R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E;

Expand Down Expand Up @@ -922,6 +929,40 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

/**
 * A nested list expression with one or more elements.
 *
 * <p>The elements are arbitrary {@link Expression}s and must all have the same type; that shared
 * type is the element type of the list returned by {@link #getType()}.
 *
 * <p>Note: This class cannot be used to construct an empty list. To create an empty list, use
 * {@link ExpressionCreator#emptyList(boolean, Type)} which returns an {@link EmptyListLiteral}.
 */
@Value.Immutable
abstract class NestedList implements Nested {
  /** The element expressions of the list. */
  public abstract List<Expression> values();

  /**
   * Validates the element list when an instance is built.
   *
   * <p>The conditions are checked with explicit throws instead of {@code assert} statements:
   * assertions are disabled by default at runtime, which would let an empty or mixed-type list
   * through and defer the failure to {@link #getType()}. {@link AssertionError} is kept as the
   * failure type so the observable behaviour with assertions enabled is unchanged.
   *
   * @throws AssertionError if the list is empty or its elements do not all have the same type
   */
  @Value.Check
  protected void check() {
    if (values().isEmpty()) {
      throw new AssertionError("To specify an empty list, use ExpressionCreator.emptyList()");
    }

    if (values().stream().map(Expression::getType).distinct().count() > 1) {
      throw new AssertionError("All values in NestedList must have the same type");
    }
  }

  /**
   * Returns the list type of this expression.
   *
   * @return a list type whose element type is the type of the first element and whose
   *     nullability is {@link #nullable()}
   */
  @Override
  public Type getType() {
    // check() guarantees at least one element, so get(0) is safe here.
    return Type.withNullability(nullable()).list(values().get(0).getType());
  }

  @Override
  public <R, C extends VisitationContext, E extends Throwable> R accept(
      ExpressionVisitor<R, C, E> visitor, C context) throws E {
    return visitor.visit(this, context);
  }

  /** Returns a new builder for {@link NestedList} instances. */
  public static ImmutableExpression.NestedList.Builder builder() {
    return ImmutableExpression.NestedList.builder();
  }
}

@Value.Immutable
abstract class MultiOrListRecord {
public abstract List<Expression> values();
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,17 @@ public static Expression.StructLiteral struct(boolean nullable, Expression.Liter
return Expression.StructLiteral.builder().nullable(nullable).addFields(values).build();
}

/**
 * Builds an {@link Expression.NestedList} from the given element expressions.
 *
 * <p>Note: This method cannot be used to construct an empty list. To create an empty list, use
 * {@link ExpressionCreator#emptyList(boolean, Type)} which returns an {@link
 * Expression.EmptyListLiteral}.
 *
 * @param nullable whether the resulting list expression is nullable
 * @param values the element expressions of the list
 * @return the built nested list expression
 */
public static Expression.NestedList nestedList(boolean nullable, List<Expression> values) {
  ImmutableExpression.NestedList.Builder listBuilder = Expression.NestedList.builder();
  listBuilder.nullable(nullable);
  listBuilder.addAllValues(values);
  return listBuilder.build();
}

public static Expression.StructLiteral struct(
boolean nullable, Iterable<? extends Expression.Literal> values) {
return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr

R visit(Expression.MultiOrList expr, C context) throws E;

R visit(Expression.NestedList expr, C context) throws E;

R visit(FieldReference expr, C context) throws E;

R visit(Expression.SetPredicate expr, C context) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,22 @@ public Expression visit(
.build();
}

/**
 * Converts a nested list expression to its protobuf form: every element is converted with
 * {@code toProto} and wrapped in a {@code Nested} message carrying the list and its nullability.
 */
@Override
public Expression visit(
    io.substrait.expression.Expression.NestedList expr, EmptyVisitationContext context)
    throws RuntimeException {
  Expression.Nested.List.Builder protoList = Expression.Nested.List.newBuilder();
  for (io.substrait.expression.Expression element : expr.values()) {
    protoList.addValues(toProto(element));
  }

  Expression.Nested.Builder nested = Expression.Nested.newBuilder();
  nested.setList(protoList);
  nested.setNullable(expr.nullable());

  return Expression.newBuilder().setNested(nested).build();
}

@Override
public Expression visit(FieldReference expr, EmptyVisitationContext context) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ public Expression from(io.substrait.proto.Expression expr) {
multiOrList.getValueList().stream().map(this::from).collect(Collectors.toList()))
.build();
}
case NESTED:
return from(expr.getNested());
case CAST:
return ExpressionCreator.cast(
protoTypeConverter.from(expr.getCast().getType()),
Expand Down Expand Up @@ -361,6 +363,18 @@ private WindowBound toWindowBound(io.substrait.proto.Expression.WindowFunction.B
}
}

/**
 * Converts a protobuf {@code Nested} message into an {@link Expression.Nested}.
 *
 * <p>Only the {@code LIST} variant is handled; its values are converted element by element.
 *
 * @param nested the protobuf nested expression
 * @return the converted nested expression
 * @throws UnsupportedOperationException for any nested type other than {@code LIST}
 */
public Expression.Nested from(io.substrait.proto.Expression.Nested nested) {
  switch (nested.getNestedTypeCase()) {
    case LIST:
      return ExpressionCreator.nestedList(
          nested.getNullable(),
          nested.getList().getValuesList().stream()
              .map(this::from)
              .collect(Collectors.toList()));
    default:
      throw new UnsupportedOperationException(
          "Unimplemented nested type: " + nested.getNestedTypeCase());
  }
}

public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
switch (literal.getLiteralTypeCase()) {
case BOOLEAN:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,16 @@ public Optional<Expression> visit(
.build());
}

/**
 * Visits the elements of a nested list via {@code visitExprList}; when that yields a list, a
 * copy of {@code expr} carrying those values is returned, otherwise the result is empty.
 */
@Override
public Optional<Expression> visit(Expression.NestedList expr, EmptyVisitationContext context)
    throws E {
  return visitExprList(expr.values(), context)
      .map(rewritten -> Expression.NestedList.builder().from(expr).values(rewritten).build());
}

protected Optional<Expression.MultiOrListRecord> visitMultiOrListRecord(
Expression.MultiOrListRecord multiOrListRecord, EmptyVisitationContext context) throws E {
return visitExprList(multiOrListRecord.values(), context)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package io.substrait.type.proto;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.TestBase;
import io.substrait.expression.Expression;
import io.substrait.expression.ImmutableExpression;
import org.junit.jupiter.api.Test;

/**
 * Tests build-time validation of {@link Expression.NestedList} and its round trip when used as a
 * projected expression.
 */
class NestedListExpressionTest extends TestBase {
  private final Expression literalExpression =
      Expression.BoolLiteral.builder().value(true).build();
  private final Expression.ScalarFunctionInvocation nonLiteralExpression =
      b.add(b.i32(7), b.i32(42));

  /** Projects {@code expression} over an empty scan and verifies the relation round-trips. */
  private void verifyProjectedRoundTrip(Expression expression) {
    io.substrait.relation.Project project =
        io.substrait.relation.Project.builder()
            .addExpressions(expression)
            .input(b.emptyScan())
            .build();
    verifyRoundTrip(project);
  }

  @Test
  void rejectNestedListWithElementsOfDifferentTypes() {
    ImmutableExpression.NestedList.Builder builder =
        Expression.NestedList.builder().addValues(literalExpression).addValues(b.i32(12));
    assertThrows(AssertionError.class, builder::build);
  }

  @Test
  void acceptNestedListWithElementsOfSameType() {
    ImmutableExpression.NestedList.Builder builder =
        Expression.NestedList.builder().addValues(nonLiteralExpression).addValues(b.i32(12));
    assertDoesNotThrow(builder::build);

    verifyProjectedRoundTrip(builder.build());
  }

  @Test
  void rejectEmptyNestedListTest() {
    ImmutableExpression.NestedList.Builder builder = Expression.NestedList.builder();
    assertThrows(AssertionError.class, builder::build);
  }

  @Test
  void literalNestedListTest() {
    Expression.NestedList literalNestedList =
        Expression.NestedList.builder()
            .addValues(literalExpression)
            .addValues(literalExpression)
            .build();

    verifyProjectedRoundTrip(literalNestedList);
  }

  @Test
  void literalNullableNestedListTest() {
    Expression.NestedList literalNestedList =
        Expression.NestedList.builder()
            .addValues(literalExpression)
            .addValues(literalExpression)
            .nullable(true)
            .build();

    verifyProjectedRoundTrip(literalNestedList);
  }

  @Test
  void nonLiteralNestedListTest() {
    Expression.NestedList nonLiteralNestedList =
        Expression.NestedList.builder()
            .addValues(nonLiteralExpression)
            .addValues(nonLiteralExpression)
            .build();

    verifyProjectedRoundTrip(nonLiteralNestedList);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public class ProjectRelRoundtripTest extends TestBase {
Arrays.asList("col_a", "col_b", "col_c", "col_d"),
Arrays.asList(R.I64, R.FP64, R.STRING, R.I32));

final Rel emptyTable = b.emptyScan();

@Test
void simpleProjection() {
// Project single field
Expand Down Expand Up @@ -146,4 +148,11 @@ void emptyProjection() {

verifyRoundTrip(projection);
}

@Test
void avoidProjectRemapOnEmptyInput() {
  // Project a computed expression over the empty scan and check the relation round-trips.
  verifyRoundTrip(
      Project.builder().addExpressions(b.add(b.i32(1), b.i32(2))).input(emptyTable).build());
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.examples.util;

import io.substrait.expression.Expression;
import io.substrait.expression.Expression.BinaryLiteral;
import io.substrait.expression.Expression.BoolLiteral;
import io.substrait.expression.Expression.Cast;
Expand Down Expand Up @@ -256,6 +257,12 @@ public String visit(MultiOrList expr, EmptyVisitationContext context) throws Run
return sb.toString();
}

/**
 * Renders a nested list as the fixed placeholder {@code <NestedList>}; the list's elements are
 * not included in the output.
 */
@Override
public String visit(Expression.NestedList expr, EmptyVisitationContext context)
throws RuntimeException {
return "<NestedList>";
}

@Override
public String visit(FieldReference expr, EmptyVisitationContext context) throws RuntimeException {
StringBuilder sb = new StringBuilder("FieldRef#");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package io.substrait.isthmus;

import static java.util.Objects.requireNonNull;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.fun.SqlMultisetValueConstructor;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;

/**
* Substrait-specific constructor to map back to the Expression NestedList type in Substrait. This
* constructor creates a special type of SqlKind.ARRAY_VALUE_CONSTRUCTOR for lists that can contain
* both literal and non-literal expressions.
*/
public class NestedListConstructor extends SqlMultisetValueConstructor {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you picked SqlMultisetValueConstructor instead of the more natural SqlArrayValueConstructor because CallConverters.java#L144 would then match on your NestedList and treat it as the regular SqlArrayValueConstructor (then invoke LiteralConstructorConverter.java#L32 which is not what we want here).

SqlMultisetValueConstructor is conceptually wrong, as a multiset is radically different from an array/list in standard SQL. On top of that, I imagine that tomorrow you might want to support something similar and hit the same issue you are trying to avoid here by using this class in the first place.

The impedance mismatch is that in SQL (and Calcite), arrays and lists are technically the same entity, while IIRC in Substrait they are treated as different entities (@benbellick can you confirm this?).

By looking at LiteralConstructorConverter, there is an implicit assumption that arrays store only literals, we go down that route without checking if elements in the array are really literals (LiteralConstructorConverter.java#L62).

It's probably enough to change LiteralConstructorConverter::toNonEmptyListLiteral to something like this (haven't tested it):

private Optional<Expression> toNonEmptyListLiteral(
      RexCall call, Function<RexNode, Expression> topLevelConverter) {
    List<Expression> expressions = call.operands.stream()
        .map(topLevelConverter)
        .collect(Collectors.toList());

    // Check if all operands are actually literals
    if (expressions.stream().allMatch(e -> e instanceof Expression.Literal)) {
      return Optional.of(ExpressionCreator.list(
          call.getType().isNullable(),
          expressions.stream()
              .map(e -> (Expression.Literal) e)
              .collect(Collectors.toList())));
    }

    return Optional.empty();
  }

I suggest extending SqlArrayValueConstructor (which, I know, extends SqlMultisetValueConstructor, but still), then fixing LiteralConstructorConverter as suggested, so that we can continue with NestedExpressionConverter, which comes just after, and we should be good.

Copy link
Contributor Author

@gord02 gord02 Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to make sure I am understanding it correctly, let me know your thoughts on the following scenarios: We want to ensure that the roundtrip of both a nested list with literals and non-literals are both returned to a nested list. If the literalConstructorConverter is run first on a list of just literals, then it would pass and then wouldn't be mapped back to a nested list. In the other case, where the nestedExpressionConverter is run first, the literal lists that were originally not a NestedList would become a nested list. Does the above account for this or is the difference not important?

Also, is there a way to meaningfully extend the SqlArrayValueConstructor class? Its definition is bare with just a constructor to its parent type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no distinction between lists and arrays in substrait. Only list is a legitimate type. Array is informally used in some places in the docs, but that doesn't actually exist as a distinct type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To your point about

We want to ensure that the roundtrip of both a nested list with literals and non-literals are both returned to a nested list.

The design you have works to solve the problem of "NestedList in is always NestedList out". Thinking about what Alessandro is saying though, I actually don't think it's worth preserving that. There is a little bit of redundancy on the Substrait side in that list containing only literal values can be expressed equivalently as a Literal.List or a Nested.List.

In Calcite, these can both be mapped to a SqlArrayValueConstructor. You introduced your NestedListConstructor to be able to distinguish between the two incoming cases so as to roundtrip them, but what if instead of doing that we just said:

  1. A SqlArrayValueConstructor call with all literals gets turned into a Substrait Literal.List
  2. A SqlArrayValueConstructor call with any non-literals gets turned into Substrait Nested.List

This isn't as nice from a roundtrip perspective, but from a pure Calcite perspective we're mapping to the most specific Substrait construct that we can use, which is better in my opinion.

One thing I would push for that's different from what Alessandro is suggesting is to have a single call converter SqlArrayValueConstructorCallConverter that combines the array converting code in LiteralConstructorConverter and your code in NestedExpressionCallConverter. Effectively, let's not split the code that handles arrays with only literals, and arrays with non-literals, and just put it all in one class so that it's easy to see how the behaviour works and you don't have to look at 2 different converters for 1 Calcite construct.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @benbellick and @vbarua for the clarification, the fact that only List exists in Substrait actually makes things easier, as we can map it to Array in standard SQL (and in Calcite, where it's informally called List too sometimes).

In light of this clarification, my previous suggestion is an unnecessary overcomplication, and I fully agree with Victor's proposal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created a new file SqlArrayValueConstructorCallConverter that has the code from the LiteralConstructorConverter and the nestedListConstructor and therefore removed both files. I also updated the tests to check for the listLiteral in the event of converting from a nestedList of literals, and compared the resulting list of expressions to the beginning list.


/**
 * Creates the operator under the name {@code NESTEDLIST} with kind {@link
 * SqlKind#ARRAY_VALUE_CONSTRUCTOR}.
 */
public NestedListConstructor() {
super("NESTEDLIST", SqlKind.ARRAY_VALUE_CONSTRUCTOR);
}

/**
 * Infers the array type produced by a call to this constructor.
 *
 * <p>The element type is derived from the operand types of the call; the result is an array of
 * that element type, created with the nullable flag set to {@code false}.
 *
 * @param opBinding the operator binding describing the call and its operands
 * @return the inferred array type
 * @throws NullPointerException if no element type can be derived from the operands
 */
@Override
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
  RelDataType elementType =
      requireNonNull(
          getComponentType(opBinding.getTypeFactory(), opBinding.collectOperandTypes()),
          "inferred array element type");

  // explicit cast elements to component type if they are not same
  SqlValidatorUtil.adjustTypeForArrayConstructor(elementType, opBinding);

  return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), elementType, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ public Rel visit(org.apache.calcite.rel.core.Project project) {
.map(this::toExpression)
.collect(java.util.stream.Collectors.toList());

// if there are no input fields, no remap is necessary
if (project.getInput().getRowType().getFieldCount() == 0) {
return Project.builder().expressions(expressions).input(apply(project.getInput())).build();
}

// todo: eliminate excessive projects. This should be done by converting rexinputrefs to remaps.
return Project.builder()
.remap(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.isthmus.calcite;

import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.NestedListConstructor;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
Expand Down Expand Up @@ -36,6 +37,8 @@ public class SubstraitOperatorTable implements SqlOperatorTable {
AggregateFunctions.SUM,
AggregateFunctions.SUM0));

/** Shared instance of the Substrait-specific {@link NestedListConstructor} operator. */
public static final NestedListConstructor NESTED_LIST_CONSTRUCTOR =
    new NestedListConstructor();

// SQL Kinds for which Substrait specific operators are provided
private static final Set<SqlKind> OVERRIDE_KINDS =
EnumSet.copyOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
CallConverters.CASE,
CallConverters.CAST.apply(typeConverter),
CallConverters.REINTERPRET.apply(typeConverter),
new LiteralConstructorConverter(typeConverter));
new LiteralConstructorConverter(typeConverter),
new NestedExpressionCallConverter());
}

public interface SimpleCallConverter extends CallConverter {
Expand Down
Loading