Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ public OUTPUT visit(Expression.StructLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.StructNested expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.Switch expr) throws EXCEPTION {
return visitFallback(expr);
Expand Down
21 changes: 21 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,27 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

@Value.Immutable
abstract static class StructNested implements Expression {
public abstract List<Expression> fields();

public Type getType() {
return Type.withNullability(false)
.struct(
fields().stream()
.map(Expression::getType)
.collect(java.util.stream.Collectors.toList()));
}

public static ImmutableExpression.StructNested.Builder builder() {
return ImmutableExpression.StructNested.builder();
}

public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class UserDefinedLiteral implements Literal {
public abstract ByteString value();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ public interface ExpressionVisitor<R, E extends Throwable> {

R visit(Expression.StructLiteral expr) throws E;

R visit(Expression.StructNested expr) throws E;

R visit(Expression.UserDefinedLiteral expr) throws E;

R visit(Expression.Switch expr) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ private Expression lit(Consumer<Expression.Literal.Builder> consumer) {
return Expression.newBuilder().setLiteral(builder).build();
}

private Expression nested(Consumer<Expression.Nested.Builder> consumer) {
var builder = Expression.Nested.newBuilder();
consumer.accept(builder);
return Expression.newBuilder().setNested(builder).build();
}

@Override
public Expression visit(io.substrait.expression.Expression.BoolLiteral expr) {
return lit(bldr -> bldr.setNullable(expr.nullable()).setBoolean(expr.value()));
Expand Down Expand Up @@ -323,6 +329,18 @@ public Expression visit(io.substrait.expression.Expression.StructLiteral expr) {
});
}

@Override
public Expression visit(io.substrait.expression.Expression.StructNested expr) {
return nested(
bldr -> {
var values =
expr.fields().stream()
.map(this::toProto)
.collect(java.util.stream.Collectors.toList());
bldr.setStruct(Expression.Nested.Struct.newBuilder().addAllFields(values));
});
}

@Override
public Expression visit(io.substrait.expression.Expression.UserDefinedLiteral expr) {
var typeReference =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,14 @@ public Optional<Expression> visit(Expression.StructLiteral expr) throws EXCEPTIO
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.StructNested expr) throws EXCEPTION {
var expressions = visitExprList(expr.fields());
return expressions.map(
expressionList ->
Expression.StructNested.builder().from(expr).fields(expressionList).build());
}

@Override
public Optional<Expression> visit(Expression.UserDefinedLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
Expand Down
30 changes: 23 additions & 7 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public Rel from(io.substrait.proto.Rel rel) {
protected Rel newRead(ReadRel rel) {
if (rel.hasVirtualTable()) {
var virtualTable = rel.getVirtualTable();
if (virtualTable.getValuesCount() == 0) {
if (virtualTable.getValuesCount() == 0 && virtualTable.getExpressionsCount() == 0) {
return newEmptyScan(rel);
} else {
return newVirtualTable(rel);
Expand Down Expand Up @@ -417,17 +417,33 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) {

protected VirtualTableScan newVirtualTable(ReadRel rel) {
var virtualTable = rel.getVirtualTable();
// If both values and expressions are set, raise an error
if (virtualTable.getValuesCount() > 0 && virtualTable.getExpressionsCount() > 0) {
throw new IllegalArgumentException(
"Virtual table cannot have both values and expressions set");
}

var virtualTableSchema = newNamedStruct(rel);

var converter =
new ProtoExpressionConverter(lookup, extensions, virtualTableSchema.struct(), this);
List<Expression.StructLiteral> structLiterals = new ArrayList<>(virtualTable.getValuesCount());

List<Expression> expressions =
new ArrayList<>(virtualTable.getValuesCount() + virtualTable.getExpressionsCount());

for (var struct : virtualTable.getValuesList()) {
structLiterals.add(
expressions.add(
ImmutableExpression.StructLiteral.builder()
.fields(
struct.getFieldsList().stream()
.map(converter::from)
.collect(java.util.stream.Collectors.toList()))
struct.getFieldsList().stream().map(converter::from).collect(Collectors.toList()))
.build());
}

for (var expr : virtualTable.getExpressionsList()) {
expressions.add(
ImmutableExpression.StructNested.builder()
.fields(
expr.getFieldsList().stream().map(converter::from).collect(Collectors.toList()))
.build());
}

Expand All @@ -438,7 +454,7 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) {
rel.hasBestEffortFilter() ? converter.from(rel.getBestEffortFilter()) : null))
.filter(Optional.ofNullable(rel.hasFilter() ? converter.from(rel.getFilter()) : null))
.initialSchema(NamedStruct.fromProto(rel.getBaseSchema(), protoTypeConverter))
.rows(structLiterals);
.rows(expressions);

builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import io.substrait.type.Type;
import io.substrait.type.TypeVisitor;
import java.util.List;
import java.util.Objects;
import org.immutables.value.Value;

@Value.Immutable
public abstract class VirtualTableScan extends AbstractReadRel {

public abstract List<Expression.StructLiteral> getRows();
public abstract List<Expression> getRows();

/**
*
Expand All @@ -29,9 +30,9 @@ protected void check() {
== NamedFieldCountingTypeVisitor.countNames(this.getInitialSchema().struct());
var rows = getRows();

assert rows.size() > 0
&& names.stream().noneMatch(s -> s == null)
&& rows.stream().noneMatch(r -> r == null)
assert !rows.isEmpty()
&& names.stream().noneMatch(Objects::isNull)
&& rows.stream().noneMatch(Objects::isNull)
&& rows.stream()
.allMatch(r -> NamedFieldCountingTypeVisitor.countNames(r.getType()) == names.size());
}
Expand Down
28 changes: 28 additions & 0 deletions core/src/test/java/io/substrait/relation/VirtualTableScanTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package io.substrait.relation;

import static io.substrait.expression.ExpressionCreator.bool;
import static io.substrait.expression.ExpressionCreator.fp32;
import static io.substrait.expression.ExpressionCreator.fp64;
import static io.substrait.expression.ExpressionCreator.i8;
import static io.substrait.expression.ExpressionCreator.i16;
import static io.substrait.expression.ExpressionCreator.i32;
import static io.substrait.expression.ExpressionCreator.i64;
import static io.substrait.expression.ExpressionCreator.list;
import static io.substrait.expression.ExpressionCreator.map;
import static io.substrait.expression.ExpressionCreator.string;
Expand All @@ -25,6 +32,13 @@ void check() {
NamedStruct.of(
Arrays.stream(
new String[] {
"bool_field",
"i8_field",
"i16_field",
"i32_field",
"i64_field",
"fp32_field",
"fp64_field",
"string",
"struct",
"struct_field1",
Expand All @@ -37,13 +51,27 @@ void check() {
})
.collect(Collectors.toList()),
R.struct(
R.BOOLEAN,
R.I8,
R.I16,
R.I32,
R.I64,
R.FP32,
R.FP64,
R.STRING,
R.struct(R.STRING, R.STRING),
R.list(R.struct(R.STRING)),
R.map(R.struct(R.STRING), R.struct(R.STRING)))))
.addRows(
struct(
false,
bool(false, true),
i8(false, 42),
i16(false, 1234),
i32(false, 123456),
i64(false, 9876543210L),
fp32(false, 3.14f),
fp64(false, 2.718281828),
string(false, "string_val"),
struct(
false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,21 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
}

override def visit(virtualTableScan: relation.VirtualTableScan): LogicalPlan = {
val rows = virtualTableScan.getRows.asScala.map(
row =>
val rows = virtualTableScan.getRows.asScala.map {
case structLit: SExpression.StructLiteral =>
InternalRow.fromSeq(
row
.fields()
.asScala
.map(field => field.accept(expressionConverter).asInstanceOf[Literal].value)))
structLit.fields.asScala
.map(field => field.accept(expressionConverter).asInstanceOf[Literal].value)
)
case structNested: SExpression.StructNested =>
InternalRow.fromSeq(
structNested.fields.asScala
.map(expr => expr.accept(expressionConverter))
)
case other =>
throw new UnsupportedOperationException(
s"Unsupported row type in VirtualTableScan: ${other.getClass}")
}
virtualTableScan.getInitialSchema match {
case ns: NamedStruct if ns.names().isEmpty && rows.length == 1 =>
OneRowRelation()
Expand Down