From 173ec596ce2da8d5aed14f4a89dad5c3434c7224 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 12 Nov 2025 17:03:44 -0500 Subject: [PATCH 01/13] feat: add structured UDT literal support with dual encoding Support both opaque (google.protobuf.Any) and structured (Literal.Struct) encodings for user-defined type literals per Substrait spec. - Split UserDefinedLiteral into UserDefinedAny and UserDefinedStruct - Move type parameters to interface level for parameterized types - Comprehensive test coverage including roundtrip tests - Throw exception on unhandled struct-based representation in isthmus --- .../expression/AbstractExpressionVisitor.java | 10 + .../io/substrait/expression/Expression.java | 87 +++++- .../expression/ExpressionCreator.java | 46 +++- .../expression/ExpressionVisitor.java | 4 +- .../proto/ExpressionProtoConverter.java | 46 +++- .../proto/ProtoExpressionConverter.java | 30 ++- .../extension/DefaultExtensionCatalog.java | 3 + .../ExpressionCopyOnWriteVisitor.java | 8 +- .../src/main/java/io/substrait/type/Type.java | 17 ++ .../type/proto/BaseProtoConverter.java | 2 +- .../substrait/type/proto/BaseProtoTypes.java | 3 + .../proto/ParameterizedProtoConverter.java | 7 + .../type/proto/ProtoTypeConverter.java | 8 +- .../proto/TypeExpressionProtoVisitor.java | 7 + .../type/proto/TypeProtoConverter.java | 11 + core/src/test/java/io/substrait/TestBase.java | 17 ++ .../type/proto/LiteralRoundtripTest.java | 255 +++++++++++++++++- .../examples/util/ExpressionStringify.java | 12 +- .../examples/util/SubstraitStringify.java | 2 +- .../isthmus/expression/CallConverters.java | 27 +- .../expression/ExpressionRexConverter.java | 11 +- .../substrait/isthmus/CustomFunctionTest.java | 4 +- .../substrait/debug/ExpressionToString.scala | 6 +- .../spark/DefaultExpressionVisitor.scala | 10 +- 24 files changed, 576 insertions(+), 57 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java 
b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 072507295..d190542f8 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E { return visitFallback(expr, context); } + @Override + public O visit(Expression.UserDefinedAny expr, C context) throws E { + return visitFallback(expr, context); + } + + @Override + public O visit(Expression.UserDefinedStruct expr, C context) throws E { + return visitFallback(expr, context); + } + @Override public O visit(Expression.Switch expr, C context) throws E { return visitFallback(expr, context); diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 42c3c5118..1b0a8362d 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -662,21 +662,96 @@ public R accept( } } + /** + * Base interface for user-defined literals. + * + *

User-defined literals can be encoded in one of two ways as per the Substrait spec: + * + *

+ * + * @see UserDefinedAny + * @see UserDefinedStruct + */ + interface UserDefinedLiteral extends Literal { + String urn(); + + String name(); + + List typeParameters(); + } + + /** + * User-defined literal with value encoded as {@code google.protobuf.Any}. + * + *

This encoding allows for arbitrary binary data to be stored in the literal value. + */ @Value.Immutable - abstract class UserDefinedLiteral implements Literal { - public abstract ByteString value(); + abstract class UserDefinedAny implements UserDefinedLiteral { + @Override + public abstract String urn(); + + @Override + public abstract String name(); + + @Override + public abstract List typeParameters(); + + public abstract com.google.protobuf.Any value(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); + } + + public static ImmutableExpression.UserDefinedAny.Builder builder() { + return ImmutableExpression.UserDefinedAny.builder(); + } + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + + /** + * User-defined literal with value encoded as {@code Literal.Struct}. + * + *

This encoding uses a structured list of fields to represent the literal value. + */ + @Value.Immutable + abstract class UserDefinedStruct implements UserDefinedLiteral { + @Override public abstract String urn(); + @Override public abstract String name(); @Override - public Type getType() { - return Type.withNullability(nullable()).userDefined(urn(), name()); + public abstract List typeParameters(); + + public abstract List fields(); + + @Override + public Type.UserDefined getType() { + return Type.UserDefined.builder() + .nullable(nullable()) + .urn(urn()) + .name(name()) + .typeParameters(typeParameters()) + .build(); } - public static ImmutableExpression.UserDefinedLiteral.Builder builder() { - return ImmutableExpression.UserDefinedLiteral.builder(); + public static ImmutableExpression.UserDefinedStruct.Builder builder() { + return ImmutableExpression.UserDefinedStruct.builder(); } @Override diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index adf157d7b..2f924bef8 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -286,13 +286,51 @@ public static Expression.StructLiteral struct( return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build(); } - public static Expression.UserDefinedLiteral userDefinedLiteral( - boolean nullable, String urn, String name, Any value) { - return Expression.UserDefinedLiteral.builder() + /** + * Create a UserDefinedAny with google.protobuf.Any representation. 
+ * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) + * @param value the value, encoded as google.protobuf.Any + */ + public static Expression.UserDefinedAny userDefinedLiteralAny( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + Any value) { + return Expression.UserDefinedAny.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .addAllTypeParameters(typeParameters) + .value(value) + .build(); + } + + /** + * Create a UserDefinedStruct with Struct representation. + * + * @param nullable whether the literal is nullable + * @param urn the URN of the user-defined type + * @param name the name of the user-defined type + * @param typeParameters the type parameters for the user-defined type (can be empty list) + * @param fields the fields, as a list of Literal values + */ + public static Expression.UserDefinedStruct userDefinedLiteralStruct( + boolean nullable, + String urn, + String name, + java.util.List typeParameters, + java.util.List fields) { + return Expression.UserDefinedStruct.builder() .nullable(nullable) .urn(urn) .name(name) - .value(value.toByteString()) + .addAllTypeParameters(typeParameters) + .addAllFields(fields) .build(); } diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index d64cab48c..7cec9b953 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -62,7 +62,9 @@ public interface ExpressionVisitor { - try { - bldr.setNullable(expr.nullable()) - .setUserDefined( - Expression.Literal.UserDefined.newBuilder() - .setTypeReference(typeReference) - .setValue(Any.parseFrom(expr.value()))) - .build(); - } catch 
(InvalidProtocolBufferException e) { - throw new IllegalStateException(e); - } + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters(expr.typeParameters()) + .setValue(expr.value()); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); + }); + } + + @Override + public Expression visit( + io.substrait.expression.Expression.UserDefinedStruct expr, EmptyVisitationContext context) { + int typeReference = + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); + return lit( + bldr -> { + Expression.Literal.Struct structLiteral = + Expression.Literal.Struct.newBuilder() + .addAllFields( + expr.fields().stream() + .map(this::toLiteral) + .collect(java.util.stream.Collectors.toList())) + .build(); + + Expression.Literal.UserDefined.Builder userDefinedBuilder = + Expression.Literal.UserDefined.newBuilder() + .setTypeReference(typeReference) + .addAllTypeParameters(expr.typeParameters()) + .setStruct(structLiteral); + + bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); }); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 8f95cdf07..847fcae55 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -492,10 +492,36 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { { io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral = literal.getUserDefined(); + SimpleExtension.Type type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); - return ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue()); + 
String urn = type.urn(); + String name = type.name(); + + switch (userDefinedLiteral.getValCase()) { + case VALUE: + return ExpressionCreator.userDefinedLiteralAny( + literal.getNullable(), + urn, + name, + userDefinedLiteral.getTypeParametersList(), + userDefinedLiteral.getValue()); + case STRUCT: + return ExpressionCreator.userDefinedLiteralStruct( + literal.getNullable(), + urn, + name, + userDefinedLiteral.getTypeParametersList(), + userDefinedLiteral.getStruct().getFieldsList().stream() + .map(this::from) + .collect(Collectors.toList())); + case VAL_NOT_SET: + throw new IllegalStateException( + "UserDefined literal has no value (neither 'value' nor 'struct' is set)"); + default: + throw new IllegalStateException( + "Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase()); + } } default: throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 89aad954e..31214878c 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -22,6 +22,7 @@ public class DefaultExtensionCatalog { "extension:io.substrait:functions_rounding_decimal"; public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; + public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types"; public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION = loadDefaultCollection(); @@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { .map(c -> String.format("/functions_%s.yaml", c)) .collect(Collectors.toList()); + defaultFiles.add("/extension_types.yaml"); + return SimpleExtension.load(defaultFiles); } } 
diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 57132a940..68395ac0d 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -203,9 +203,15 @@ public Optional visit(Expression.StructLiteral expr, EmptyVisitation return visitLiteral(expr); } + @Override + public Optional visit(Expression.UserDefinedAny expr, EmptyVisitationContext context) + throws E { + return visitLiteral(expr); + } + @Override public Optional visit( - Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E { + Expression.UserDefinedStruct expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index aaf97aa12..7ef2d75a7 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -393,6 +393,23 @@ abstract class UserDefined implements Type { public abstract String name(); + /** + * Returns the type parameters for this user-defined type. + * + *

Type parameters are used to represent parameterized/generic types, such as {@code + * List} or {@code Map}. Each parameter in the list represents a type argument + * that specializes the generic user-defined type. + * + *

For example, a user-defined type {@code MyList} parameterized by {@code i32} would have + * one type parameter containing the {@code i32} type definition. + * + * @return a list of type parameters, or an empty list if this type is not parameterized + */ + @Value.Default + public java.util.List typeParameters() { + return java.util.Collections.emptyList(); + } + public static ImmutableType.UserDefined.Builder builder() { return ImmutableType.UserDefined.builder(); } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 691d4bce5..67d7bc9b5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -165,6 +165,6 @@ public final T visit(final Type.Map expr) { public final T visit(final Type.UserDefined expr) { int ref = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); - return typeContainer(expr).userDefined(ref); + return typeContainer(expr).userDefined(ref, expr.typeParameters()); } } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 6a1bc3186..1009fe52a 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -131,6 +131,9 @@ public final T struct(T... 
types) { public abstract T userDefined(int ref); + public abstract T userDefined( + int ref, java.util.List typeParameters); + protected abstract T wrap(Object o); protected abstract I i(int integerValue); diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 4e0caa7c2..137c1fba3 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -262,6 +262,13 @@ public ParameterizedType userDefined(int ref) { "User defined types are not supported in Parameterized Types for now"); } + @Override + public ParameterizedType userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Parameterized Types for now"); + } + @Override protected ParameterizedType wrap(final Object o) { ParameterizedType.Builder bldr = ParameterizedType.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 95d42328a..ee77e1445 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -90,7 +90,13 @@ public Type from(io.substrait.proto.Type type) { { io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); - return n(userDefined.getNullability()).userDefined(t.urn(), t.name()); + boolean nullable = isNullable(userDefined.getNullability()); + return io.substrait.type.Type.UserDefined.builder() + .nullable(nullable) + .urn(t.urn()) + .name(t.name()) + .typeParameters(userDefined.getTypeParametersList()) + .build(); } case USER_DEFINED_TYPE_REFERENCE: throw new 
UnsupportedOperationException("Unsupported user defined reference: " + type); diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index 96cddd395..a3412a9e3 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -355,6 +355,13 @@ public DerivationExpression userDefined(int ref) { "User defined types are not supported in Derivation Expressions for now"); } + @Override + public DerivationExpression userDefined( + int ref, java.util.List typeParameters) { + throw new UnsupportedOperationException( + "User defined types are not supported in Derivation Expressions for now"); + } + @Override protected DerivationExpression wrap(final Object o) { DerivationExpression.Builder bldr = DerivationExpression.newBuilder(); diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 2d0ed0ffc..7cb98263f 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -133,6 +133,17 @@ public Type userDefined(int ref) { Type.UserDefined.newBuilder().setTypeReference(ref).setNullability(nullability).build()); } + @Override + public Type userDefined( + int ref, java.util.List typeParameters) { + return wrap( + Type.UserDefined.newBuilder() + .setTypeReference(ref) + .setNullability(nullability) + .addAllTypeParameters(typeParameters) + .build()); + } + @Override protected Type wrap(final Object o) { Type.Builder bldr = Type.newBuilder(); diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index 3defbf78f..b5f1dd4f1 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ 
b/core/src/test/java/io/substrait/TestBase.java @@ -1,8 +1,12 @@ package io.substrait; +import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; import static org.junit.jupiter.api.Assertions.assertEquals; import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.proto.ExpressionProtoConverter; +import io.substrait.expression.proto.ProtoExpressionConverter; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; @@ -25,9 +29,22 @@ public abstract class TestBase { protected ProtoRelConverter protoRelConverter = new ProtoRelConverter(functionCollector, defaultExtensionCollection); + protected ExpressionProtoConverter expressionProtoConverter = + new ExpressionProtoConverter(functionCollector, relProtoConverter); + + protected ProtoExpressionConverter protoExpressionConverter = + new ProtoExpressionConverter( + functionCollector, defaultExtensionCollection, EMPTY_TYPE, protoRelConverter); + protected void verifyRoundTrip(Rel rel) { io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); Rel relReturned = protoRelConverter.from(protoRel); assertEquals(rel, relReturned); } + + protected void verifyRoundTrip(Expression expression) { + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(expression); + Expression expressionReturned = protoExpressionConverter.from(protoExpression); + assertEquals(expression, expressionReturned); + } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index ccac93bcb..ca3d80447 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -3,23 +3,268 @@ import static 
io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE; import static org.junit.jupiter.api.Assertions.assertEquals; +import com.google.protobuf.Any; import io.substrait.TestBase; +import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.expression.proto.ProtoExpressionConverter; -import io.substrait.util.EmptyVisitationContext; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.relation.RelProtoConverter; import java.math.BigDecimal; +import java.util.Collections; import org.junit.jupiter.api.Test; public class LiteralRoundtripTest extends TestBase { + private static final String NESTED_TYPES_URN = "extension:io.substrait:test_nested_types"; + + private static final String NESTED_TYPES_YAML = + "---\n" + + "urn: " + + NESTED_TYPES_URN + + "\n" + + "types:\n" + + " - name: point\n" + + " structure:\n" + + " latitude: i32\n" + + " longitude: i32\n" + + " - name: triangle\n" + + " structure:\n" + + " p1: point\n" + + " p2: point\n" + + " p3: point\n"; + + private static final SimpleExtension.ExtensionCollection NESTED_TYPES_EXTENSIONS = + SimpleExtension.load("nested_types.yaml", NESTED_TYPES_YAML); + + private static final ExtensionCollector NESTED_TYPES_FUNCTION_COLLECTOR = + new ExtensionCollector(); + private static final RelProtoConverter NESTED_TYPES_REL_PROTO_CONVERTER = + new RelProtoConverter(NESTED_TYPES_FUNCTION_COLLECTOR); + private static final ProtoRelConverter NESTED_TYPES_PROTO_REL_CONVERTER = + new ProtoRelConverter(NESTED_TYPES_FUNCTION_COLLECTOR, NESTED_TYPES_EXTENSIONS); + private static final ExpressionProtoConverter NESTED_TYPES_EXPRESSION_TO_PROTO = + new ExpressionProtoConverter( + NESTED_TYPES_FUNCTION_COLLECTOR, 
NESTED_TYPES_REL_PROTO_CONVERTER); + private static final ProtoExpressionConverter NESTED_TYPES_PROTO_TO_EXPRESSION = + new ProtoExpressionConverter( + NESTED_TYPES_FUNCTION_COLLECTOR, + NESTED_TYPES_EXTENSIONS, + EMPTY_TYPE, + NESTED_TYPES_PROTO_REL_CONVERTER); + @Test void decimal() { io.substrait.expression.Expression.DecimalLiteral val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); - ExpressionProtoConverter to = new ExpressionProtoConverter(null, null); - ProtoExpressionConverter from = - new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(val, from.from(val.accept(to, EmptyVisitationContext.INSTANCE))); + verifyRoundTrip(val); + } + + /** Verifies round-trip conversion of a simple user-defined type using Any representation. */ + @Test + void userDefinedLiteralWithAnyRepresentation() { + // Create a struct literal inline representing a point with latitude=42, longitude=100 + io.substrait.proto.Expression.Literal.Struct pointStruct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(42)) + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(100)) + .build(); + io.substrait.proto.Expression.Literal innerLiteral = + io.substrait.proto.Expression.Literal.newBuilder().setStruct(pointStruct).build(); + Any anyValue = Any.pack(innerLiteral); + + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + anyValue); + + verifyRoundTrip(val); + } + + /** Verifies round-trip conversion of a simple user-defined type using Struct representation. 
*/ + @Test + void userDefinedLiteralWithStructRepresentation() { + java.util.List fields = + java.util.Arrays.asList( + ExpressionCreator.i32(false, 42), ExpressionCreator.i32(false, 100)); + Expression.UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralStruct( + false, + DefaultExtensionCatalog.EXTENSION_TYPES, + "point", + java.util.Collections.emptyList(), + fields); + + verifyRoundTrip(val); + } + + /** + * Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three + * point UDTs. Both outer and nested types use Struct representation. + */ + @Test + void nestedUserDefinedLiteralWithStructRepresentation() { + Expression.UserDefinedStruct p1 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 0), ExpressionCreator.i32(false, 0))); + + Expression.UserDefinedStruct p2 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 10), ExpressionCreator.i32(false, 0))); + + Expression.UserDefinedStruct p3 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "point", + Collections.emptyList(), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 5), ExpressionCreator.i32(false, 10))); + + Expression.UserDefinedStruct triangle = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "triangle", + Collections.emptyList(), + java.util.Arrays.asList(p1, p2, p3)); + + io.substrait.proto.Expression protoExpression = + NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(triangle); + Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); + assertEquals(triangle, result); + } + + /** + * Verifies round-trip conversion of nested user-defined types where a triangle UDT contains three + * point UDTs. 
Both outer and nested types use Any representation. + */ + @Test + void nestedUserDefinedLiteralWithAnyRepresentation() { + + // Create three point UDTs using Any representation + io.substrait.proto.Expression.Literal.Struct p1Struct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) + .build(); + Any p1Any = + Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p1Struct).build()); + Expression.UserDefinedAny p1 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), p1Any); + + io.substrait.proto.Expression.Literal.Struct p2Struct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10)) + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) + .build(); + Any p2Any = + Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p2Struct).build()); + Expression.UserDefinedAny p2 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), p2Any); + + io.substrait.proto.Expression.Literal.Struct p3Struct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(5)) + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10)) + .build(); + Any p3Any = + Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p3Struct).build()); + Expression.UserDefinedAny p3 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), p3Any); + + // Create a "triangle" struct containing three point UDTs + io.substrait.proto.Expression.Literal.Struct triangleStruct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + 
.addFields(NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(p1).getLiteral()) + .addFields(NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(p2).getLiteral()) + .addFields(NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(p3).getLiteral()) + .build(); + Any triangleAny = + Any.pack( + io.substrait.proto.Expression.Literal.newBuilder().setStruct(triangleStruct).build()); + + Expression.UserDefinedAny triangle = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "triangle", Collections.emptyList(), triangleAny); + + io.substrait.proto.Expression protoExpression = + NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(triangle); + Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); + assertEquals(triangle, result); + } + + /** + * Verifies round-trip conversion of nested user-defined types with mixed representations. The + * triangle UDT uses Struct representation while the nested point UDTs use Any representation. + */ + @Test + void mixedRepresentationNestedUserDefinedLiteral() { + io.substrait.proto.Expression.Literal.Struct p1Struct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) + .build(); + Any p1Any = + Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p1Struct).build()); + Expression.UserDefinedAny p1 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), p1Any); + + io.substrait.proto.Expression.Literal.Struct p2Struct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10)) + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) + .build(); + Any p2Any = + Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p2Struct).build()); + Expression.UserDefinedAny p2 = + ExpressionCreator.userDefinedLiteralAny( + 
false, NESTED_TYPES_URN, "point", Collections.emptyList(), p2Any); + + io.substrait.proto.Expression.Literal.Struct p3Struct = + io.substrait.proto.Expression.Literal.Struct.newBuilder() + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(5)) + .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10)) + .build(); + Any p3Any = + Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p3Struct).build()); + Expression.UserDefinedAny p3 = + ExpressionCreator.userDefinedLiteralAny( + false, NESTED_TYPES_URN, "point", Collections.emptyList(), p3Any); + + // Create a "triangle" UDT using Struct representation, but with Any-encoded point fields + Expression.UserDefinedStruct triangle = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "triangle", + Collections.emptyList(), + java.util.Arrays.asList(p1, p2, p3)); + + io.substrait.proto.Expression protoExpression = + NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(triangle); + Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); + assertEquals(triangle, result); } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index 71de9a7d5..cee84d13d 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -37,7 +37,8 @@ import io.substrait.expression.Expression.TimestampLiteral; import io.substrait.expression.Expression.TimestampTZLiteral; import io.substrait.expression.Expression.UUIDLiteral; -import io.substrait.expression.Expression.UserDefinedLiteral; +import io.substrait.expression.Expression.UserDefinedAny; +import io.substrait.expression.Expression.UserDefinedStruct; import io.substrait.expression.Expression.VarCharLiteral; import 
io.substrait.expression.Expression.WindowFunctionInvocation; import io.substrait.expression.ExpressionVisitor; @@ -188,9 +189,14 @@ public String visit(StructLiteral expr, EmptyVisitationContext context) throws R } @Override - public String visit(UserDefinedLiteral expr, EmptyVisitationContext context) + public String visit(UserDefinedAny expr, EmptyVisitationContext context) throws RuntimeException { + return ""; + } + + @Override + public String visit(UserDefinedStruct expr, EmptyVisitationContext context) throws RuntimeException { - return ""; + return ""; } @Override diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java index f3b34f6c2..de67a3ee8 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/SubstraitStringify.java @@ -91,7 +91,7 @@ public static List explain(io.substrait.plan.Plan plan) { /** * Explains the Sustrait relation * - * @param plan Subsrait relation + * @param rel Subsrait relation * @return List of strings; typically these would then be logged or sent to stdout */ public static List explain(io.substrait.relation.Rel rel) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 3406de7de..5191ecca8 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -44,15 +44,15 @@ public class CallConverters { * {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent and store {@link * Expression.UserDefinedLiteral}s within Calcite. * - *

When converting from Substrait to Calcite, the {@link Expression.UserDefinedLiteral#value()} - * is stored within a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link + *

When converting from Substrait to Calcite, the {@link Expression.UserDefinedAny#value()} is + * stored within a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link * org.apache.calcite.rex.RexLiteral} and then re-interpreted to have the correct type. * - *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedLiteral, + *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedAny, * SubstraitRelNodeConverter.Context)} for this conversion. * *

When converting from Calcite to Substrait, this call converter extracts the {@link - * Expression.UserDefinedLiteral} that was stored. + * Expression.UserDefinedAny} that was stored. */ public static Function REINTERPRET = typeConverter -> @@ -70,11 +70,20 @@ public class CallConverters { Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand; Type.UserDefined t = (Type.UserDefined) type; - return Expression.UserDefinedLiteral.builder() - .urn(t.urn()) - .name(t.name()) - .value(literal.value()) - .build(); + // The binary literal contains the serialized protobuf Any - just parse it directly + try { + com.google.protobuf.Any anyValue = + com.google.protobuf.Any.parseFrom(literal.value().toByteArray()); + + return Expression.UserDefinedAny.builder() + .urn(t.urn()) + .name(t.name()) + .addAllTypeParameters(t.typeParameters()) + .value(anyValue) + .build(); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw new IllegalStateException("Failed to parse UserDefinedAny value", e); + } } return null; }; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 2b8052889..e4f82731a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -109,14 +109,19 @@ public RexNode visit(Expression.NullLiteral expr, Context context) throws Runtim } @Override - public RexNode visit(Expression.UserDefinedLiteral expr, Context context) - throws RuntimeException { + public RexNode visit(Expression.UserDefinedAny expr, Context context) throws RuntimeException { RexLiteral binaryLiteral = - rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteArray())); + rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteString().toByteArray())); RelDataType type = 
typeConverter.toCalcite(typeFactory, expr.getType()); return rexBuilder.makeReinterpretCast(type, binaryLiteral, rexBuilder.makeLiteral(false)); } + @Override + public RexNode visit(Expression.UserDefinedStruct expr, Context context) throws RuntimeException { + throw new UnsupportedOperationException( + "UserDefinedStruct representation is not yet supported in Isthmus"); + } + @Override public RexNode visit(Expression.BoolLiteral expr, Context context) throws RuntimeException { return rexBuilder.makeLiteral(expr.value()); diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index c80f6a4ba..9ef12f418 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -587,7 +587,9 @@ void customTypesInFunctionsRoundtrip() { void customTypesLiteralInFunctionsRoundtrip() { Builder bldr = Expression.Literal.newBuilder(); Any anyValue = Any.pack(bldr.setI32(10).build()); - UserDefinedLiteral val = ExpressionCreator.userDefinedLiteral(false, URN, "a_type", anyValue); + UserDefinedLiteral val = + ExpressionCreator.userDefinedLiteralAny( + false, URN, "a_type", java.util.Collections.emptyList(), anyValue); Rel rel1 = b.project( diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala index 5377f4257..133052fa1 100644 --- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala +++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala @@ -76,8 +76,12 @@ class ExpressionToString extends DefaultExpressionVisitor[String] { s"${expr.declaration().key()}[${expr.outputType().accept(ToTypeString.INSTANCE)}]($args)" } + override def visit(expr: Expression.UserDefinedAny, context: EmptyVisitationContext): String = { + expr.toString + } + override def visit( - expr: 
Expression.UserDefinedLiteral, + expr: Expression.UserDefinedStruct, context: EmptyVisitationContext): String = { expr.toString } diff --git a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala index 5f7137b14..07594d3bf 100644 --- a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala +++ b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala @@ -65,9 +65,9 @@ class DefaultExpressionVisitor[T] context: EmptyVisitationContext): T = e.accept(this, context) - override def visit( - userDefinedLiteral: Expression.UserDefinedLiteral, - context: EmptyVisitationContext): T = { - visitFallback(userDefinedLiteral, context) - } + override def visit(expr: Expression.UserDefinedAny, context: EmptyVisitationContext): T = + visitFallback(expr, context) + + override def visit(expr: Expression.UserDefinedStruct, context: EmptyVisitationContext): T = + visitFallback(expr, context) } From f0e13efee75853618e79d3e0a940040fd081914e Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Fri, 21 Nov 2025 14:55:20 -0500 Subject: [PATCH 02/13] test: add test to ensure struct-UDTs work with type parameters --- .../type/proto/LiteralRoundtripTest.java | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index ca3d80447..a1355e9e9 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -36,7 +36,15 @@ public class LiteralRoundtripTest extends TestBase { + " structure:\n" + " p1: point\n" + " p2: point\n" - + " p3: point\n"; + + " p3: point\n" + + " - name: vector\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " structure:\n" + + " x: T\n" + + " y: T\n" + + " z: T\n"; private 
static final SimpleExtension.ExtensionCollection NESTED_TYPES_EXTENSIONS = SimpleExtension.load("nested_types.yaml", NESTED_TYPES_YAML); @@ -267,4 +275,43 @@ void mixedRepresentationNestedUserDefinedLiteral() { Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); assertEquals(triangle, result); } + + /** + * Verifies round-trip conversion of a parameterized user-defined type. Tests that type parameters + * are correctly preserved during serialization and deserialization. + */ + @Test + void userDefinedLiteralWithTypeParameters() { + // Create a type parameter for i32 + io.substrait.proto.Type i32Type = + io.substrait.proto.Type.newBuilder() + .setI32( + io.substrait.proto.Type.I32 + .newBuilder() + .setNullability(io.substrait.proto.Type.Nullability.NULLABILITY_REQUIRED)) + .build(); + io.substrait.proto.Type.Parameter typeParam = + io.substrait.proto.Type.Parameter.newBuilder().setDataType(i32Type).build(); + + // Create a vector instance with fields (x: 1, y: 2, z: 3) + Expression.UserDefinedStruct vectorI32 = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "vector", + java.util.Arrays.asList(typeParam), + java.util.Arrays.asList( + ExpressionCreator.i32(false, 1), + ExpressionCreator.i32(false, 2), + ExpressionCreator.i32(false, 3))); + + io.substrait.proto.Expression protoExpression = + NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(vectorI32); + Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); + assertEquals(vectorI32, result); + + Expression.UserDefinedStruct resultStruct = (Expression.UserDefinedStruct) result; + assertEquals(1, resultStruct.typeParameters().size()); + assertEquals(typeParam, resultStruct.typeParameters().get(0)); + } } From f5b6341ad9fefcc485d80bd623224e982162d878 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Fri, 21 Nov 2025 15:28:42 -0500 Subject: [PATCH 03/13] test: improve UDT literal test by making opaque-ness explicit --- 
.../type/proto/LiteralRoundtripTest.java | 87 +++---------------- 1 file changed, 10 insertions(+), 77 deletions(-) diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index a1355e9e9..3fcb9f7aa 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -75,15 +75,8 @@ void decimal() { /** Verifies round-trip conversion of a simple user-defined type using Any representation. */ @Test void userDefinedLiteralWithAnyRepresentation() { - // Create a struct literal inline representing a point with latitude=42, longitude=100 - io.substrait.proto.Expression.Literal.Struct pointStruct = - io.substrait.proto.Expression.Literal.Struct.newBuilder() - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(42)) - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(100)) - .build(); - io.substrait.proto.Expression.Literal innerLiteral = - io.substrait.proto.Expression.Literal.newBuilder().setStruct(pointStruct).build(); - Any anyValue = Any.pack(innerLiteral); + Any anyValue = + Any.pack(com.google.protobuf.StringValue.of("")); Expression.UserDefinedLiteral val = ExpressionCreator.userDefinedLiteralAny( @@ -166,51 +159,8 @@ void nestedUserDefinedLiteralWithStructRepresentation() { */ @Test void nestedUserDefinedLiteralWithAnyRepresentation() { - - // Create three point UDTs using Any representation - io.substrait.proto.Expression.Literal.Struct p1Struct = - io.substrait.proto.Expression.Literal.Struct.newBuilder() - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) - .build(); - Any p1Any = - Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p1Struct).build()); - Expression.UserDefinedAny p1 = - 
ExpressionCreator.userDefinedLiteralAny( - false, NESTED_TYPES_URN, "point", Collections.emptyList(), p1Any); - - io.substrait.proto.Expression.Literal.Struct p2Struct = - io.substrait.proto.Expression.Literal.Struct.newBuilder() - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10)) - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) - .build(); - Any p2Any = - Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p2Struct).build()); - Expression.UserDefinedAny p2 = - ExpressionCreator.userDefinedLiteralAny( - false, NESTED_TYPES_URN, "point", Collections.emptyList(), p2Any); - - io.substrait.proto.Expression.Literal.Struct p3Struct = - io.substrait.proto.Expression.Literal.Struct.newBuilder() - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(5)) - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10)) - .build(); - Any p3Any = - Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p3Struct).build()); - Expression.UserDefinedAny p3 = - ExpressionCreator.userDefinedLiteralAny( - false, NESTED_TYPES_URN, "point", Collections.emptyList(), p3Any); - - // Create a "triangle" struct containing three point UDTs - io.substrait.proto.Expression.Literal.Struct triangleStruct = - io.substrait.proto.Expression.Literal.Struct.newBuilder() - .addFields(NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(p1).getLiteral()) - .addFields(NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(p2).getLiteral()) - .addFields(NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(p3).getLiteral()) - .build(); Any triangleAny = - Any.pack( - io.substrait.proto.Expression.Literal.newBuilder().setStruct(triangleStruct).build()); + Any.pack(com.google.protobuf.StringValue.of("")); Expression.UserDefinedAny triangle = ExpressionCreator.userDefinedLiteralAny( @@ -228,38 +178,21 @@ void nestedUserDefinedLiteralWithAnyRepresentation() { */ @Test void mixedRepresentationNestedUserDefinedLiteral() { - 
io.substrait.proto.Expression.Literal.Struct p1Struct = - io.substrait.proto.Expression.Literal.Struct.newBuilder() - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) - .build(); - Any p1Any = - Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p1Struct).build()); + Any anyValue = + Any.pack(com.google.protobuf.StringValue.of("")); + + // Create point UDTs using Any representation Expression.UserDefinedAny p1 = ExpressionCreator.userDefinedLiteralAny( - false, NESTED_TYPES_URN, "point", Collections.emptyList(), p1Any); + false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); - io.substrait.proto.Expression.Literal.Struct p2Struct = - io.substrait.proto.Expression.Literal.Struct.newBuilder() - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10)) - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(0)) - .build(); - Any p2Any = - Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p2Struct).build()); Expression.UserDefinedAny p2 = ExpressionCreator.userDefinedLiteralAny( - false, NESTED_TYPES_URN, "point", Collections.emptyList(), p2Any); + false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); - io.substrait.proto.Expression.Literal.Struct p3Struct = - io.substrait.proto.Expression.Literal.Struct.newBuilder() - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(5)) - .addFields(io.substrait.proto.Expression.Literal.newBuilder().setI32(10)) - .build(); - Any p3Any = - Any.pack(io.substrait.proto.Expression.Literal.newBuilder().setStruct(p3Struct).build()); Expression.UserDefinedAny p3 = ExpressionCreator.userDefinedLiteralAny( - false, NESTED_TYPES_URN, "point", Collections.emptyList(), p3Any); + false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); // Create a "triangle" UDT using Struct representation, but with Any-encoded point 
fields Expression.UserDefinedStruct triangle = From 8c5470673ed9e91991b615d0ef12058065ca16e2 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 24 Nov 2025 11:26:48 -0500 Subject: [PATCH 04/13] refactor: rename UserDefined{kind} to UserDefined{kind}Literal --- .../expression/AbstractExpressionVisitor.java | 4 ++-- .../io/substrait/expression/Expression.java | 20 ++++++++--------- .../expression/ExpressionCreator.java | 12 +++++----- .../expression/ExpressionVisitor.java | 4 ++-- .../proto/ExpressionProtoConverter.java | 6 +++-- .../ExpressionCopyOnWriteVisitor.java | 6 ++--- .../type/proto/LiteralRoundtripTest.java | 22 +++++++++---------- .../examples/util/ExpressionStringify.java | 13 ++++++----- .../isthmus/expression/CallConverters.java | 15 +++++++------ .../expression/ExpressionRexConverter.java | 8 ++++--- .../substrait/debug/ExpressionToString.scala | 6 +++-- .../spark/DefaultExpressionVisitor.scala | 6 +++-- 12 files changed, 66 insertions(+), 56 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index d190542f8..6b49cae19 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -152,12 +152,12 @@ public O visit(Expression.StructLiteral expr, C context) throws E { } @Override - public O visit(Expression.UserDefinedAny expr, C context) throws E { + public O visit(Expression.UserDefinedAnyLiteral expr, C context) throws E { return visitFallback(expr, context); } @Override - public O visit(Expression.UserDefinedStruct expr, C context) throws E { + public O visit(Expression.UserDefinedStructLiteral expr, C context) throws E { return visitFallback(expr, context); } diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 1b0a8362d..2adb6f77b 
100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -668,12 +668,12 @@ public R accept( *

User-defined literals can be encoded in one of two ways as per the Substrait spec: * *

    - *
  • As {@code google.protobuf.Any} - see {@link UserDefinedAny} - *
  • As {@code Literal.Struct} - see {@link UserDefinedStruct} + *
  • As {@code google.protobuf.Any} - see {@link UserDefinedAnyLiteral} + *
  • As {@code Literal.Struct} - see {@link UserDefinedStructLiteral} *
* - * @see UserDefinedAny - * @see UserDefinedStruct + * @see UserDefinedAnyLiteral + * @see UserDefinedStructLiteral */ interface UserDefinedLiteral extends Literal { String urn(); @@ -689,7 +689,7 @@ interface UserDefinedLiteral extends Literal { *

This encoding allows for arbitrary binary data to be stored in the literal value. */ @Value.Immutable - abstract class UserDefinedAny implements UserDefinedLiteral { + abstract class UserDefinedAnyLiteral implements UserDefinedLiteral { @Override public abstract String urn(); @@ -711,8 +711,8 @@ public Type.UserDefined getType() { .build(); } - public static ImmutableExpression.UserDefinedAny.Builder builder() { - return ImmutableExpression.UserDefinedAny.builder(); + public static ImmutableExpression.UserDefinedAnyLiteral.Builder builder() { + return ImmutableExpression.UserDefinedAnyLiteral.builder(); } @Override @@ -728,7 +728,7 @@ public R accept( *

This encoding uses a structured list of fields to represent the literal value. */ @Value.Immutable - abstract class UserDefinedStruct implements UserDefinedLiteral { + abstract class UserDefinedStructLiteral implements UserDefinedLiteral { @Override public abstract String urn(); @@ -750,8 +750,8 @@ public Type.UserDefined getType() { .build(); } - public static ImmutableExpression.UserDefinedStruct.Builder builder() { - return ImmutableExpression.UserDefinedStruct.builder(); + public static ImmutableExpression.UserDefinedStructLiteral.Builder builder() { + return ImmutableExpression.UserDefinedStructLiteral.builder(); } @Override diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 2f924bef8..57945b536 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -287,7 +287,7 @@ public static Expression.StructLiteral struct( } /** - * Create a UserDefinedAny with google.protobuf.Any representation. + * Create a UserDefinedAnyLiteral with google.protobuf.Any representation. 
* * @param nullable whether the literal is nullable * @param urn the URN of the user-defined type @@ -295,13 +295,13 @@ public static Expression.StructLiteral struct( * @param typeParameters the type parameters for the user-defined type (can be empty list) * @param value the value, encoded as google.protobuf.Any */ - public static Expression.UserDefinedAny userDefinedLiteralAny( + public static Expression.UserDefinedAnyLiteral userDefinedLiteralAny( boolean nullable, String urn, String name, java.util.List typeParameters, Any value) { - return Expression.UserDefinedAny.builder() + return Expression.UserDefinedAnyLiteral.builder() .nullable(nullable) .urn(urn) .name(name) @@ -311,7 +311,7 @@ public static Expression.UserDefinedAny userDefinedLiteralAny( } /** - * Create a UserDefinedStruct with Struct representation. + * Create a UserDefinedStructLiteral with Struct representation. * * @param nullable whether the literal is nullable * @param urn the URN of the user-defined type @@ -319,13 +319,13 @@ public static Expression.UserDefinedAny userDefinedLiteralAny( * @param typeParameters the type parameters for the user-defined type (can be empty list) * @param fields the fields, as a list of Literal values */ - public static Expression.UserDefinedStruct userDefinedLiteralStruct( + public static Expression.UserDefinedStructLiteral userDefinedLiteralStruct( boolean nullable, String urn, String name, java.util.List typeParameters, java.util.List fields) { - return Expression.UserDefinedStruct.builder() + return Expression.UserDefinedStructLiteral.builder() .nullable(nullable) .urn(urn) .name(name) diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index 7cec9b953..43f54cadf 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -62,9 +62,9 @@ public interface ExpressionVisitor 
visit(Expression.StructLiteral expr, EmptyVisitation } @Override - public Optional visit(Expression.UserDefinedAny expr, EmptyVisitationContext context) - throws E { + public Optional visit( + Expression.UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override public Optional visit( - Expression.UserDefinedStruct expr, EmptyVisitationContext context) throws E { + Expression.UserDefinedStructLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index 3fcb9f7aa..d51b66389 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -112,7 +112,7 @@ void userDefinedLiteralWithStructRepresentation() { */ @Test void nestedUserDefinedLiteralWithStructRepresentation() { - Expression.UserDefinedStruct p1 = + Expression.UserDefinedStructLiteral p1 = ExpressionCreator.userDefinedLiteralStruct( false, NESTED_TYPES_URN, @@ -121,7 +121,7 @@ void nestedUserDefinedLiteralWithStructRepresentation() { java.util.Arrays.asList( ExpressionCreator.i32(false, 0), ExpressionCreator.i32(false, 0))); - Expression.UserDefinedStruct p2 = + Expression.UserDefinedStructLiteral p2 = ExpressionCreator.userDefinedLiteralStruct( false, NESTED_TYPES_URN, @@ -130,7 +130,7 @@ void nestedUserDefinedLiteralWithStructRepresentation() { java.util.Arrays.asList( ExpressionCreator.i32(false, 10), ExpressionCreator.i32(false, 0))); - Expression.UserDefinedStruct p3 = + Expression.UserDefinedStructLiteral p3 = ExpressionCreator.userDefinedLiteralStruct( false, NESTED_TYPES_URN, @@ -139,7 +139,7 @@ void nestedUserDefinedLiteralWithStructRepresentation() { java.util.Arrays.asList( ExpressionCreator.i32(false, 5), ExpressionCreator.i32(false, 10))); - Expression.UserDefinedStruct 
triangle = + Expression.UserDefinedStructLiteral triangle = ExpressionCreator.userDefinedLiteralStruct( false, NESTED_TYPES_URN, @@ -162,7 +162,7 @@ void nestedUserDefinedLiteralWithAnyRepresentation() { Any triangleAny = Any.pack(com.google.protobuf.StringValue.of("")); - Expression.UserDefinedAny triangle = + Expression.UserDefinedAnyLiteral triangle = ExpressionCreator.userDefinedLiteralAny( false, NESTED_TYPES_URN, "triangle", Collections.emptyList(), triangleAny); @@ -182,20 +182,20 @@ void mixedRepresentationNestedUserDefinedLiteral() { Any.pack(com.google.protobuf.StringValue.of("")); // Create point UDTs using Any representation - Expression.UserDefinedAny p1 = + Expression.UserDefinedAnyLiteral p1 = ExpressionCreator.userDefinedLiteralAny( false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); - Expression.UserDefinedAny p2 = + Expression.UserDefinedAnyLiteral p2 = ExpressionCreator.userDefinedLiteralAny( false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); - Expression.UserDefinedAny p3 = + Expression.UserDefinedAnyLiteral p3 = ExpressionCreator.userDefinedLiteralAny( false, NESTED_TYPES_URN, "point", Collections.emptyList(), anyValue); // Create a "triangle" UDT using Struct representation, but with Any-encoded point fields - Expression.UserDefinedStruct triangle = + Expression.UserDefinedStructLiteral triangle = ExpressionCreator.userDefinedLiteralStruct( false, NESTED_TYPES_URN, @@ -227,7 +227,7 @@ void userDefinedLiteralWithTypeParameters() { io.substrait.proto.Type.Parameter.newBuilder().setDataType(i32Type).build(); // Create a vector instance with fields (x: 1, y: 2, z: 3) - Expression.UserDefinedStruct vectorI32 = + Expression.UserDefinedStructLiteral vectorI32 = ExpressionCreator.userDefinedLiteralStruct( false, NESTED_TYPES_URN, @@ -243,7 +243,7 @@ void userDefinedLiteralWithTypeParameters() { Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); assertEquals(vectorI32, result); - 
Expression.UserDefinedStruct resultStruct = (Expression.UserDefinedStruct) result; + Expression.UserDefinedStructLiteral resultStruct = (Expression.UserDefinedStructLiteral) result; assertEquals(1, resultStruct.typeParameters().size()); assertEquals(typeParam, resultStruct.typeParameters().get(0)); } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index cee84d13d..fe1e7a965 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -37,8 +37,8 @@ import io.substrait.expression.Expression.TimestampLiteral; import io.substrait.expression.Expression.TimestampTZLiteral; import io.substrait.expression.Expression.UUIDLiteral; -import io.substrait.expression.Expression.UserDefinedAny; -import io.substrait.expression.Expression.UserDefinedStruct; +import io.substrait.expression.Expression.UserDefinedAnyLiteral; +import io.substrait.expression.Expression.UserDefinedStructLiteral; import io.substrait.expression.Expression.VarCharLiteral; import io.substrait.expression.Expression.WindowFunctionInvocation; import io.substrait.expression.ExpressionVisitor; @@ -189,14 +189,15 @@ public String visit(StructLiteral expr, EmptyVisitationContext context) throws R } @Override - public String visit(UserDefinedAny expr, EmptyVisitationContext context) throws RuntimeException { - return ""; + public String visit(UserDefinedAnyLiteral expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; } @Override - public String visit(UserDefinedStruct expr, EmptyVisitationContext context) + public String visit(UserDefinedStructLiteral expr, EmptyVisitationContext context) throws RuntimeException { - return ""; + return ""; } @Override diff --git 
a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 5191ecca8..6ca93b950 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -44,15 +44,16 @@ public class CallConverters { * {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent and store {@link * Expression.UserDefinedLiteral}s within Calcite. * - *

When converting from Substrait to Calcite, the {@link Expression.UserDefinedAny#value()} is - * stored within a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link - * org.apache.calcite.rex.RexLiteral} and then re-interpreted to have the correct type. + *

When converting from Substrait to Calcite, the {@link + * Expression.UserDefinedAnyLiteral#value()} is stored within a {@link + * org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link org.apache.calcite.rex.RexLiteral} and + * then re-interpreted to have the correct type. * - *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedAny, + *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedAnyLiteral, * SubstraitRelNodeConverter.Context)} for this conversion. * *

When converting from Calcite to Substrait, this call converter extracts the {@link - * Expression.UserDefinedAny} that was stored. + * Expression.UserDefinedAnyLiteral} that was stored. */ public static Function REINTERPRET = typeConverter -> @@ -75,14 +76,14 @@ public class CallConverters { com.google.protobuf.Any anyValue = com.google.protobuf.Any.parseFrom(literal.value().toByteArray()); - return Expression.UserDefinedAny.builder() + return Expression.UserDefinedAnyLiteral.builder() .urn(t.urn()) .name(t.name()) .addAllTypeParameters(t.typeParameters()) .value(anyValue) .build(); } catch (com.google.protobuf.InvalidProtocolBufferException e) { - throw new IllegalStateException("Failed to parse UserDefinedAny value", e); + throw new IllegalStateException("Failed to parse UserDefinedAnyLiteral value", e); } } return null; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index e4f82731a..3127ce642 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -109,7 +109,8 @@ public RexNode visit(Expression.NullLiteral expr, Context context) throws Runtim } @Override - public RexNode visit(Expression.UserDefinedAny expr, Context context) throws RuntimeException { + public RexNode visit(Expression.UserDefinedAnyLiteral expr, Context context) + throws RuntimeException { RexLiteral binaryLiteral = rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteString().toByteArray())); RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType()); @@ -117,9 +118,10 @@ public RexNode visit(Expression.UserDefinedAny expr, Context context) throws Run } @Override - public RexNode visit(Expression.UserDefinedStruct expr, Context context) throws RuntimeException { + public RexNode 
visit(Expression.UserDefinedStructLiteral expr, Context context) + throws RuntimeException { throw new UnsupportedOperationException( - "UserDefinedStruct representation is not yet supported in Isthmus"); + "UserDefinedStructLiteral representation is not yet supported in Isthmus"); } @Override diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala index 133052fa1..10c134658 100644 --- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala +++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala @@ -76,12 +76,14 @@ class ExpressionToString extends DefaultExpressionVisitor[String] { s"${expr.declaration().key()}[${expr.outputType().accept(ToTypeString.INSTANCE)}]($args)" } - override def visit(expr: Expression.UserDefinedAny, context: EmptyVisitationContext): String = { + override def visit( + expr: Expression.UserDefinedAnyLiteral, + context: EmptyVisitationContext): String = { expr.toString } override def visit( - expr: Expression.UserDefinedStruct, + expr: Expression.UserDefinedStructLiteral, context: EmptyVisitationContext): String = { expr.toString } diff --git a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala index 07594d3bf..d1b7a32a6 100644 --- a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala +++ b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala @@ -65,9 +65,11 @@ class DefaultExpressionVisitor[T] context: EmptyVisitationContext): T = e.accept(this, context) - override def visit(expr: Expression.UserDefinedAny, context: EmptyVisitationContext): T = + override def visit(expr: Expression.UserDefinedAnyLiteral, context: EmptyVisitationContext): T = visitFallback(expr, context) - override def visit(expr: Expression.UserDefinedStruct, context: EmptyVisitationContext): T = + override def visit( + expr: 
Expression.UserDefinedStructLiteral, + context: EmptyVisitationContext): T = visitFallback(expr, context) } From 36c81f00e445c959c705534182c1c115b8d4fa1b Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 24 Nov 2025 14:09:54 -0500 Subject: [PATCH 05/13] feat: first class support for TypeParameters --- .../io/substrait/expression/Expression.java | 6 +-- .../expression/ExpressionCreator.java | 4 +- .../proto/ExpressionProtoConverter.java | 10 +++- .../proto/ProtoExpressionConverter.java | 8 +++- .../src/main/java/io/substrait/type/Type.java | 48 ++++++++++++++++++- .../substrait/type/proto/BaseProtoTypes.java | 2 +- .../proto/ParameterizedProtoConverter.java | 2 +- .../type/proto/ProtoTypeConverter.java | 30 +++++++++++- .../proto/TypeExpressionProtoVisitor.java | 2 +- .../type/proto/TypeProtoConverter.java | 41 +++++++++++++++- .../type/proto/LiteralRoundtripTest.java | 11 ++--- 11 files changed, 140 insertions(+), 24 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 2adb6f77b..7f0fc8748 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -680,7 +680,7 @@ interface UserDefinedLiteral extends Literal { String name(); - List typeParameters(); + List typeParameters(); } /** @@ -697,7 +697,7 @@ abstract class UserDefinedAnyLiteral implements UserDefinedLiteral { public abstract String name(); @Override - public abstract List typeParameters(); + public abstract List typeParameters(); public abstract com.google.protobuf.Any value(); @@ -736,7 +736,7 @@ abstract class UserDefinedStructLiteral implements UserDefinedLiteral { public abstract String name(); @Override - public abstract List typeParameters(); + public abstract List typeParameters(); public abstract List fields(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java 
b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 57945b536..0b3fb6cd8 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -299,7 +299,7 @@ public static Expression.UserDefinedAnyLiteral userDefinedLiteralAny( boolean nullable, String urn, String name, - java.util.List typeParameters, + java.util.List typeParameters, Any value) { return Expression.UserDefinedAnyLiteral.builder() .nullable(nullable) @@ -323,7 +323,7 @@ public static Expression.UserDefinedStructLiteral userDefinedLiteralStruct( boolean nullable, String urn, String name, - java.util.List typeParameters, + java.util.List typeParameters, java.util.List fields) { return Expression.UserDefinedStructLiteral.builder() .nullable(nullable) diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 97dc63196..2a966c9d5 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -366,7 +366,10 @@ public Expression visit( Expression.Literal.UserDefined.Builder userDefinedBuilder = Expression.Literal.UserDefined.newBuilder() .setTypeReference(typeReference) - .addAllTypeParameters(expr.typeParameters()) + .addAllTypeParameters( + expr.typeParameters().stream() + .map(io.substrait.type.proto.TypeProtoConverter::toProto) + .collect(java.util.stream.Collectors.toList())) .setValue(expr.value()); bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); @@ -392,7 +395,10 @@ public Expression visit( Expression.Literal.UserDefined.Builder userDefinedBuilder = Expression.Literal.UserDefined.newBuilder() .setTypeReference(typeReference) - .addAllTypeParameters(expr.typeParameters()) + .addAllTypeParameters( + expr.typeParameters().stream() + 
.map(io.substrait.type.proto.TypeProtoConverter::toProto) + .collect(java.util.stream.Collectors.toList())) .setStruct(structLiteral); bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build(); diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 847fcae55..39c5e2f9a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -504,14 +504,18 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { literal.getNullable(), urn, name, - userDefinedLiteral.getTypeParametersList(), + userDefinedLiteral.getTypeParametersList().stream() + .map(protoTypeConverter::from) + .collect(Collectors.toList()), userDefinedLiteral.getValue()); case STRUCT: return ExpressionCreator.userDefinedLiteralStruct( literal.getNullable(), urn, name, - userDefinedLiteral.getTypeParametersList(), + userDefinedLiteral.getTypeParametersList().stream() + .map(protoTypeConverter::from) + .collect(Collectors.toList()), userDefinedLiteral.getStruct().getFieldsList().stream() .map(this::from) .collect(Collectors.toList())); diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 7ef2d75a7..354535125 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -406,7 +406,7 @@ abstract class UserDefined implements Type { * @return a list of type parameters, or an empty list if this type is not parameterized */ @Value.Default - public java.util.List typeParameters() { + public java.util.List typeParameters() { return java.util.Collections.emptyList(); } @@ -419,4 +419,50 @@ public R accept(TypeVisitor typeVisitor) throws E return typeVisitor.visit(this); } } + + /** + * Represents a type parameter for 
user-defined types. + * + *

Type parameters can be data types (like {@code i32} in {@code List<i32>}), or value + * parameters (like the {@code 10} in {@code VARCHAR<10>}). This interface provides a type-safe + * representation of all possible parameter kinds. + */ + interface Parameter {} + + /** A data type parameter, such as the {@code i32} in {@code List<i32>}. */ + @Value.Immutable + abstract class ParameterDataType implements Parameter { + public abstract Type type(); + } + + /** A boolean value parameter. */ + @Value.Immutable + abstract class ParameterBooleanValue implements Parameter { + public abstract boolean value(); + } + + /** An integer value parameter, such as the {@code 10} in {@code VARCHAR<10>}. */ + @Value.Immutable + abstract class ParameterIntegerValue implements Parameter { + public abstract long value(); + } + + /** An enum value parameter (represented as a string). */ + @Value.Immutable + abstract class ParameterEnumValue implements Parameter { + public abstract String value(); + } + + /** A string value parameter. */ + @Value.Immutable + abstract class ParameterStringValue implements Parameter { + public abstract String value(); + } + + /** An explicitly null/unspecified parameter, used to select the default value (if any). */ + class ParameterNull implements Parameter { + public static final ParameterNull INSTANCE = new ParameterNull(); + + private ParameterNull() {} + } } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 1009fe52a..57b1f26b5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -132,7 +132,7 @@ public final T struct(T...
types) { public abstract T userDefined(int ref); public abstract T userDefined( - int ref, java.util.List typeParameters); + int ref, java.util.List typeParameters); protected abstract T wrap(Object o); diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 137c1fba3..817f7b0b3 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -264,7 +264,7 @@ public ParameterizedType userDefined(int ref) { @Override public ParameterizedType userDefined( - int ref, java.util.List typeParameters) { + int ref, java.util.List typeParameters) { throw new UnsupportedOperationException( "User defined types are not supported in Parameterized Types for now"); } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index ee77e1445..bdb600c1c 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -2,6 +2,7 @@ import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; +import io.substrait.type.ImmutableType; import io.substrait.type.Type; import io.substrait.type.TypeCreator; @@ -95,7 +96,10 @@ public Type from(io.substrait.proto.Type type) { .nullable(nullable) .urn(t.urn()) .name(t.name()) - .typeParameters(userDefined.getTypeParametersList()) + .typeParameters( + userDefined.getTypeParametersList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList())) .build(); } case USER_DEFINED_TYPE_REFERENCE: @@ -124,4 +128,28 @@ private static TypeCreator n(io.substrait.proto.Type.Nullability n) { ? 
TypeCreator.NULLABLE : TypeCreator.REQUIRED; } + + public io.substrait.type.Type.Parameter from(io.substrait.proto.Type.Parameter parameter) { + switch (parameter.getParameterCase()) { + case NULL: + return io.substrait.type.Type.ParameterNull.INSTANCE; + case DATA_TYPE: + return ImmutableType.ParameterDataType.builder() + .type(from(parameter.getDataType())) + .build(); + case BOOLEAN: + return ImmutableType.ParameterBooleanValue.builder().value(parameter.getBoolean()).build(); + case INTEGER: + return ImmutableType.ParameterIntegerValue.builder().value(parameter.getInteger()).build(); + case ENUM: + return ImmutableType.ParameterEnumValue.builder().value(parameter.getEnum()).build(); + case STRING: + return ImmutableType.ParameterStringValue.builder().value(parameter.getString()).build(); + case PARAMETER_NOT_SET: + throw new IllegalArgumentException("Parameter type is not set: " + parameter); + default: + throw new UnsupportedOperationException( + "Unsupported parameter type: " + parameter.getParameterCase()); + } + } } diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index a3412a9e3..f9d2129d2 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -357,7 +357,7 @@ public DerivationExpression userDefined(int ref) { @Override public DerivationExpression userDefined( - int ref, java.util.List typeParameters) { + int ref, java.util.List typeParameters) { throw new UnsupportedOperationException( "User defined types are not supported in Derivation Expressions for now"); } diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 7cb98263f..b9070cebb 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ 
b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -135,12 +135,15 @@ public Type userDefined(int ref) { @Override public Type userDefined( - int ref, java.util.List typeParameters) { + int ref, java.util.List typeParameters) { return wrap( Type.UserDefined.newBuilder() .setTypeReference(ref) .setNullability(nullability) - .addAllTypeParameters(typeParameters) + .addAllTypeParameters( + typeParameters.stream() + .map(TypeProtoConverter::toProto) + .collect(java.util.stream.Collectors.toList())) .build()); } @@ -210,4 +213,38 @@ protected Integer i(final int integerValue) { return integerValue; } } + + public static io.substrait.proto.Type.Parameter toProto( + io.substrait.type.Type.Parameter parameter) { + if (parameter instanceof io.substrait.type.Type.ParameterNull) { + return Type.Parameter.newBuilder() + .setNull(com.google.protobuf.Empty.getDefaultInstance()) + .build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterDataType) { + io.substrait.type.Type.ParameterDataType dataType = + (io.substrait.type.Type.ParameterDataType) parameter; + TypeProtoConverter converter = + new TypeProtoConverter(new io.substrait.extension.ExtensionCollector()); + return Type.Parameter.newBuilder().setDataType(converter.toProto(dataType.type())).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterBooleanValue) { + io.substrait.type.Type.ParameterBooleanValue boolValue = + (io.substrait.type.Type.ParameterBooleanValue) parameter; + return Type.Parameter.newBuilder().setBoolean(boolValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterIntegerValue) { + io.substrait.type.Type.ParameterIntegerValue intValue = + (io.substrait.type.Type.ParameterIntegerValue) parameter; + return Type.Parameter.newBuilder().setInteger(intValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterEnumValue) { + io.substrait.type.Type.ParameterEnumValue enumValue = + 
(io.substrait.type.Type.ParameterEnumValue) parameter; + return Type.Parameter.newBuilder().setEnum(enumValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterStringValue) { + io.substrait.type.Type.ParameterStringValue stringValue = + (io.substrait.type.Type.ParameterStringValue) parameter; + return Type.Parameter.newBuilder().setString(stringValue.value()).build(); + } else { + throw new UnsupportedOperationException( + "Unsupported parameter type: " + parameter.getClass()); + } + } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index d51b66389..f14267fc9 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -216,15 +216,10 @@ void mixedRepresentationNestedUserDefinedLiteral() { @Test void userDefinedLiteralWithTypeParameters() { // Create a type parameter for i32 - io.substrait.proto.Type i32Type = - io.substrait.proto.Type.newBuilder() - .setI32( - io.substrait.proto.Type.I32 - .newBuilder() - .setNullability(io.substrait.proto.Type.Nullability.NULLABILITY_REQUIRED)) + io.substrait.type.Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) .build(); - io.substrait.proto.Type.Parameter typeParam = - io.substrait.proto.Type.Parameter.newBuilder().setDataType(i32Type).build(); // Create a vector instance with fields (x: 1, y: 2, z: 3) Expression.UserDefinedStructLiteral vectorI32 = From 680ed292d264eb15d6a9a085e654bb33944ae94a Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 24 Nov 2025 14:19:02 -0500 Subject: [PATCH 06/13] docs: update docstring to explain use of Parameter --- core/src/main/java/io/substrait/type/Type.java | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git 
a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 354535125..c597e7e67 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -397,11 +397,18 @@ abstract class UserDefined implements Type { * Returns the type parameters for this user-defined type. * *

Type parameters are used to represent parameterized/generic types, such as {@code - * List} or {@code Map}. Each parameter in the list represents a type argument - * that specializes the generic user-defined type. + * vector<i32>} or custom types like {@code FixedArray<100>}. Each parameter in the list can be + * either a type (like {@code i32}) or a value (like the integer {@code 100}). * - *

For example, a user-defined type {@code MyList} parameterized by {@code i32} would have - * one type parameter containing the {@code i32} type definition. + *

Unlike built-in parameterized types ({@link Map}, {@link ListType}, {@link Decimal}), + * which have fixed, known schemas with concrete typed fields, user-defined types have variable, + * unknown schemas. This is why UserDefined uses a generic {@link Parameter} list that can hold + * any mix of types or values, while other parameterized types use concrete fields like {@code + * Type key()} or {@code int precision()}. + * + *

For example, a user-defined {@code vector} type parameterized by {@code i32} would have + * one type parameter containing the {@code i32} type definition, while a {@code FixedArray} + * type might take an integer parameter specifying its size. * * @return a list of type parameters, or an empty list if this type is not parameterized */ From 6ac7aab03de1199ff78fe4e7a67191c144b186d1 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 24 Nov 2025 14:42:00 -0500 Subject: [PATCH 07/13] test: add test showing all Type.Parameters preserved --- .../type/proto/LiteralRoundtripTest.java | 85 ++++++++++++++----- 1 file changed, 65 insertions(+), 20 deletions(-) diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index f14267fc9..e70fe685e 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -44,7 +44,23 @@ public class LiteralRoundtripTest extends TestBase { + " structure:\n" + " x: T\n" + " y: T\n" - + " z: T\n"; + + " z: T\n" + + " - name: multi_param\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " - name: size\n" + + " type: integer\n" + + " - name: nullable\n" + + " type: boolean\n" + + " - name: encoding\n" + + " type: string\n" + + " - name: precision\n" + + " type: dataType\n" + + " - name: mode\n" + + " type: enum\n" + + " structure:\n" + + " value: T\n"; private static final SimpleExtension.ExtensionCollection NESTED_TYPES_EXTENSIONS = SimpleExtension.load("nested_types.yaml", NESTED_TYPES_YAML); @@ -65,6 +81,13 @@ public class LiteralRoundtripTest extends TestBase { EMPTY_TYPE, NESTED_TYPES_PROTO_REL_CONVERTER); + private void verifyNestedTypesRoundTrip(Expression expression) { + io.substrait.proto.Expression protoExpression = + NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(expression); + Expression result = 
NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); + assertEquals(expression, result); + } + @Test void decimal() { io.substrait.expression.Expression.DecimalLiteral val = @@ -147,10 +170,7 @@ void nestedUserDefinedLiteralWithStructRepresentation() { Collections.emptyList(), java.util.Arrays.asList(p1, p2, p3)); - io.substrait.proto.Expression protoExpression = - NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(triangle); - Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); - assertEquals(triangle, result); + verifyNestedTypesRoundTrip(triangle); } /** @@ -166,10 +186,7 @@ void nestedUserDefinedLiteralWithAnyRepresentation() { ExpressionCreator.userDefinedLiteralAny( false, NESTED_TYPES_URN, "triangle", Collections.emptyList(), triangleAny); - io.substrait.proto.Expression protoExpression = - NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(triangle); - Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); - assertEquals(triangle, result); + verifyNestedTypesRoundTrip(triangle); } /** @@ -203,10 +220,7 @@ void mixedRepresentationNestedUserDefinedLiteral() { Collections.emptyList(), java.util.Arrays.asList(p1, p2, p3)); - io.substrait.proto.Expression protoExpression = - NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(triangle); - Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); - assertEquals(triangle, result); + verifyNestedTypesRoundTrip(triangle); } /** @@ -233,13 +247,44 @@ void userDefinedLiteralWithTypeParameters() { ExpressionCreator.i32(false, 2), ExpressionCreator.i32(false, 3))); - io.substrait.proto.Expression protoExpression = - NESTED_TYPES_EXPRESSION_TO_PROTO.toProto(vectorI32); - Expression result = NESTED_TYPES_PROTO_TO_EXPRESSION.from(protoExpression); - assertEquals(vectorI32, result); + verifyNestedTypesRoundTrip(vectorI32); + } + + /** + * Verifies round-trip conversion of a user-defined type with all parameter types. 
Tests that all + * parameter kinds (type, integer, boolean, string, null, enum) are correctly preserved during + * serialization and deserialization. + */ + @Test + void userDefinedLiteralWithAllParameterTypes() { + io.substrait.type.Type.Parameter typeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder() + .type(io.substrait.type.Type.I32.builder().nullable(false).build()) + .build(); + + io.substrait.type.Type.Parameter intParam = + io.substrait.type.ImmutableType.ParameterIntegerValue.builder().value(100L).build(); + + io.substrait.type.Type.Parameter boolParam = + io.substrait.type.ImmutableType.ParameterBooleanValue.builder().value(true).build(); + + io.substrait.type.Type.Parameter stringParam = + io.substrait.type.ImmutableType.ParameterStringValue.builder().value("utf8").build(); + + io.substrait.type.Type.Parameter nullParam = io.substrait.type.Type.ParameterNull.INSTANCE; + + io.substrait.type.Type.Parameter enumParam = + io.substrait.type.ImmutableType.ParameterEnumValue.builder().value("FAST").build(); + + Expression.UserDefinedStructLiteral multiParam = + ExpressionCreator.userDefinedLiteralStruct( + false, + NESTED_TYPES_URN, + "multi_param", + java.util.Arrays.asList( + typeParam, intParam, boolParam, stringParam, nullParam, enumParam), + java.util.Arrays.asList(ExpressionCreator.i32(false, 42))); - Expression.UserDefinedStructLiteral resultStruct = (Expression.UserDefinedStructLiteral) result; - assertEquals(1, resultStruct.typeParameters().size()); - assertEquals(typeParam, resultStruct.typeParameters().get(0)); + verifyNestedTypesRoundTrip(multiParam); } } From b2063d047e80269db5c721a9640c1d68fe6c011c Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 24 Nov 2025 16:45:02 -0500 Subject: [PATCH 08/13] fix: correctly handle type conversion with paramterized types --- .../proto/ExpressionProtoConverter.java | 4 +- .../type/proto/TypeProtoConverter.java | 89 ++++++------ .../io/substrait/plan/PlanConverterTest.java | 131 
++++++++++++++++++ 3 files changed, 182 insertions(+), 42 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 2a966c9d5..d33365a5b 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -368,7 +368,7 @@ public Expression visit( .setTypeReference(typeReference) .addAllTypeParameters( expr.typeParameters().stream() - .map(io.substrait.type.proto.TypeProtoConverter::toProto) + .map(typeProtoConverter::toProto) .collect(java.util.stream.Collectors.toList())) .setValue(expr.value()); @@ -397,7 +397,7 @@ public Expression visit( .setTypeReference(typeReference) .addAllTypeParameters( expr.typeParameters().stream() - .map(io.substrait.type.proto.TypeProtoConverter::toProto) + .map(typeProtoConverter::toProto) .collect(java.util.stream.Collectors.toList())) .setStruct(structLiteral); diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index b9070cebb..6422904c4 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -5,25 +5,68 @@ /** Convert from {@link io.substrait.type.Type} to {@link io.substrait.proto.Type} */ public class TypeProtoConverter extends BaseProtoConverter { - private static final BaseProtoTypes NULLABLE = - new Types(Type.Nullability.NULLABILITY_NULLABLE); - private static final BaseProtoTypes REQUIRED = - new Types(Type.Nullability.NULLABILITY_REQUIRED); + // Instance fields (not static) because Types is a non-static inner class that calls + // TypeProtoConverter.this.toProto() to recursively convert nested type parameters. 
+ // Each converter instance needs its own Types instances to ensure type registrations + // use the correct ExtensionCollector. + private final BaseProtoTypes NULLABLE; + private final BaseProtoTypes REQUIRED; public TypeProtoConverter(ExtensionCollector extensionCollector) { super(extensionCollector, "Type literals cannot contain parameters or expressions."); + NULLABLE = new Types(Type.Nullability.NULLABILITY_NULLABLE); + REQUIRED = new Types(Type.Nullability.NULLABILITY_REQUIRED); } public io.substrait.proto.Type toProto(io.substrait.type.Type type) { return type.accept(this); } + public io.substrait.proto.Type.Parameter toProto(io.substrait.type.Type.Parameter parameter) { + if (parameter instanceof io.substrait.type.Type.ParameterNull) { + return Type.Parameter.newBuilder() + .setNull(com.google.protobuf.Empty.getDefaultInstance()) + .build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterDataType) { + io.substrait.type.Type.ParameterDataType dataType = + (io.substrait.type.Type.ParameterDataType) parameter; + return Type.Parameter.newBuilder().setDataType(toProto(dataType.type())).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterBooleanValue) { + io.substrait.type.Type.ParameterBooleanValue boolValue = + (io.substrait.type.Type.ParameterBooleanValue) parameter; + return Type.Parameter.newBuilder().setBoolean(boolValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterIntegerValue) { + io.substrait.type.Type.ParameterIntegerValue intValue = + (io.substrait.type.Type.ParameterIntegerValue) parameter; + return Type.Parameter.newBuilder().setInteger(intValue.value()).build(); + } else if (parameter instanceof io.substrait.type.Type.ParameterEnumValue) { + io.substrait.type.Type.ParameterEnumValue enumValue = + (io.substrait.type.Type.ParameterEnumValue) parameter; + return Type.Parameter.newBuilder().setEnum(enumValue.value()).build(); + } else if (parameter instanceof 
io.substrait.type.Type.ParameterStringValue) { + io.substrait.type.Type.ParameterStringValue stringValue = + (io.substrait.type.Type.ParameterStringValue) parameter; + return Type.Parameter.newBuilder().setString(stringValue.value()).build(); + } else { + throw new UnsupportedOperationException( + "Unsupported parameter type: " + parameter.getClass()); + } + } + @Override public BaseProtoTypes typeContainer(final boolean nullable) { return nullable ? NULLABLE : REQUIRED; } - private static class Types extends BaseProtoTypes { + /** + * Non-static inner class that can access the outer TypeProtoConverter instance. + * + *

This class must be non-static to access TypeProtoConverter.this.toProto() for converting + * nested type parameters (e.g., ParameterDataType containing another Type). Being non-static + * means instances are bound to a specific outer TypeProtoConverter instance, ensuring parameter + * conversions use the correct ExtensionCollector. + */ + private class Types extends BaseProtoTypes { public Types(final Type.Nullability nullability) { super(nullability); @@ -142,7 +185,7 @@ public Type userDefined( .setNullability(nullability) .addAllTypeParameters( typeParameters.stream() - .map(TypeProtoConverter::toProto) + .map(TypeProtoConverter.this::toProto) .collect(java.util.stream.Collectors.toList())) .build()); } @@ -213,38 +256,4 @@ protected Integer i(final int integerValue) { return integerValue; } } - - public static io.substrait.proto.Type.Parameter toProto( - io.substrait.type.Type.Parameter parameter) { - if (parameter instanceof io.substrait.type.Type.ParameterNull) { - return Type.Parameter.newBuilder() - .setNull(com.google.protobuf.Empty.getDefaultInstance()) - .build(); - } else if (parameter instanceof io.substrait.type.Type.ParameterDataType) { - io.substrait.type.Type.ParameterDataType dataType = - (io.substrait.type.Type.ParameterDataType) parameter; - TypeProtoConverter converter = - new TypeProtoConverter(new io.substrait.extension.ExtensionCollector()); - return Type.Parameter.newBuilder().setDataType(converter.toProto(dataType.type())).build(); - } else if (parameter instanceof io.substrait.type.Type.ParameterBooleanValue) { - io.substrait.type.Type.ParameterBooleanValue boolValue = - (io.substrait.type.Type.ParameterBooleanValue) parameter; - return Type.Parameter.newBuilder().setBoolean(boolValue.value()).build(); - } else if (parameter instanceof io.substrait.type.Type.ParameterIntegerValue) { - io.substrait.type.Type.ParameterIntegerValue intValue = - (io.substrait.type.Type.ParameterIntegerValue) parameter; - return 
Type.Parameter.newBuilder().setInteger(intValue.value()).build(); - } else if (parameter instanceof io.substrait.type.Type.ParameterEnumValue) { - io.substrait.type.Type.ParameterEnumValue enumValue = - (io.substrait.type.Type.ParameterEnumValue) parameter; - return Type.Parameter.newBuilder().setEnum(enumValue.value()).build(); - } else if (parameter instanceof io.substrait.type.Type.ParameterStringValue) { - io.substrait.type.Type.ParameterStringValue stringValue = - (io.substrait.type.Type.ParameterStringValue) parameter; - return Type.Parameter.newBuilder().setString(stringValue.value()).build(); - } else { - throw new UnsupportedOperationException( - "Unsupported parameter type: " + parameter.getClass()); - } - } } diff --git a/core/src/test/java/io/substrait/plan/PlanConverterTest.java b/core/src/test/java/io/substrait/plan/PlanConverterTest.java index dd49cf207..97268e627 100644 --- a/core/src/test/java/io/substrait/plan/PlanConverterTest.java +++ b/core/src/test/java/io/substrait/plan/PlanConverterTest.java @@ -3,14 +3,22 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; import io.substrait.extension.AdvancedExtension; +import io.substrait.extension.SimpleExtension; import io.substrait.plan.Plan.Root; import io.substrait.relation.EmptyScan; +import io.substrait.relation.ImmutableVirtualTableScan; +import io.substrait.relation.VirtualTableScan; import io.substrait.type.NamedStruct; +import io.substrait.type.Type; import io.substrait.type.TypeCreator; import io.substrait.utils.StringHolder; import io.substrait.utils.StringHolderHandlingExtensionProtoConverter; import io.substrait.utils.StringHolderHandlingProtoExtensionConverter; +import java.util.Arrays; +import java.util.Collections; import org.junit.jupiter.api.Test; class PlanConverterTest { @@ -189,4 +197,127 @@ void 
planIncludingRelationWithAdvancedExtension() { assertEquals(plan, plan2); } + + /** + * Verifies that nested UserDefined types with type parameters share the same ExtensionCollector + * and don't create duplicate type references. Tests that a plan containing both a standalone + * UserDefined literal (point) and a parameterized UserDefined literal (vector) correctly + * registers both types in the extension collection without duplication. + */ + @Test + void nestedUserDefinedTypesShareExtensionCollector() { + // Define custom types: point and vector + String urn = "extension:test:nested_types"; + String yaml = + "---\n" + + "urn: " + + urn + + "\n" + + "types:\n" + + " - name: point\n" + + " structure:\n" + + " x: i32\n" + + " y: i32\n" + + " - name: vector\n" + + " parameters:\n" + + " - name: T\n" + + " type: dataType\n" + + " structure:\n" + + " x: T\n" + + " y: T\n" + + " z: T\n"; + + SimpleExtension.ExtensionCollection extensions = SimpleExtension.load("test.yaml", yaml); + + // Create type objects + Type pointType = Type.UserDefined.builder().nullable(false).urn(urn).name("point").build(); + + Type.Parameter pointTypeParam = + io.substrait.type.ImmutableType.ParameterDataType.builder().type(pointType).build(); + + Type vectorOfPointType = + Type.UserDefined.builder() + .nullable(false) + .urn(urn) + .name("vector") + .addTypeParameters(pointTypeParam) + .build(); + + // Create literals + Expression.UserDefinedStructLiteral pointLiteral = + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList(ExpressionCreator.i32(false, 10), ExpressionCreator.i32(false, 20))); + + // Create vector literal: vector{(1,2), (3,4), (5,6)} + Expression.UserDefinedStructLiteral vectorOfPointLiteral = + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "vector", + Arrays.asList(pointTypeParam), + Arrays.asList( + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + 
Arrays.asList( + ExpressionCreator.i32(false, 1), ExpressionCreator.i32(false, 2))), + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, 3), ExpressionCreator.i32(false, 4))), + ExpressionCreator.userDefinedLiteralStruct( + false, + urn, + "point", + Collections.emptyList(), + Arrays.asList( + ExpressionCreator.i32(false, 5), ExpressionCreator.i32(false, 6))))); + + Type nullablePointType = + Type.UserDefined.builder().nullable(true).urn(urn).name("point").build(); + + Expression.UserDefinedStructLiteral nullablePointLiteral = + ExpressionCreator.userDefinedLiteralStruct( + true, + urn, + "point", + Collections.emptyList(), + Arrays.asList(ExpressionCreator.i32(false, 30), ExpressionCreator.i32(false, 40))); + + // Create virtual table with all three columns (nullable point, required point, required vector) + VirtualTableScan virtualTable = + ImmutableVirtualTableScan.builder() + .initialSchema( + NamedStruct.of( + Arrays.asList("nullable_point_col", "point_col", "vector_col"), + TypeCreator.REQUIRED.struct(nullablePointType, pointType, vectorOfPointType))) + .addRows( + ExpressionCreator.struct( + false, nullablePointLiteral, pointLiteral, vectorOfPointLiteral)) + .build(); + + Plan plan = Plan.builder().addRoots(Root.builder().input(virtualTable).build()).build(); + + PlanProtoConverter toProtoConverter = new PlanProtoConverter(); + io.substrait.proto.Plan protoPlan = toProtoConverter.toProto(plan); + + assertEquals(1, protoPlan.getExtensionUrnsCount(), "Should have exactly 1 extension URN"); + assertEquals( + 2, + protoPlan.getExtensionsCount(), + "Should have exactly 2 type extensions (point and vector), no duplicates"); + + ProtoPlanConverter fromProtoConverter = new ProtoPlanConverter(extensions); + Plan roundTrippedPlan = fromProtoConverter.from(protoPlan); + assertEquals(plan, roundTrippedPlan, "Plan should roundtrip correctly"); + } } From 
ce7985fcade463600fec9ea666da246d71dc2fc5 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 24 Nov 2025 16:51:01 -0500 Subject: [PATCH 09/13] refactor: revert to simpler binary construction for any rep UDT in calcite --- .../io/substrait/isthmus/expression/ExpressionRexConverter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 3127ce642..68936a6d4 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -112,7 +112,7 @@ public RexNode visit(Expression.NullLiteral expr, Context context) throws Runtim public RexNode visit(Expression.UserDefinedAnyLiteral expr, Context context) throws RuntimeException { RexLiteral binaryLiteral = - rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteString().toByteArray())); + rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteArray())); RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType()); return rexBuilder.makeReinterpretCast(type, binaryLiteral, rexBuilder.makeLiteral(false)); } From e8cb8622df1d96acd9024ef3d3012daae659ae07 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Tue, 25 Nov 2025 13:56:13 -0500 Subject: [PATCH 10/13] fix: ensure nullability preserved in calcite roundtrip --- .../isthmus/expression/CallConverters.java | 1 + .../substrait/isthmus/CustomFunctionTest.java | 25 +++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 6ca93b950..72198d49c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ 
b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -77,6 +77,7 @@ public class CallConverters { com.google.protobuf.Any.parseFrom(literal.value().toByteArray()); return Expression.UserDefinedAnyLiteral.builder() + .nullable(t.nullable()) .urn(t.urn()) .name(t.name()) .addAllTypeParameters(t.typeParameters()) diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 9ef12f418..66eff0769 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -7,7 +7,6 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression.UserDefinedLiteral; import io.substrait.expression.ExpressionCreator; -import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.FunctionMappings; @@ -16,9 +15,7 @@ import io.substrait.isthmus.utils.UserTypeFactory; import io.substrait.proto.Expression; import io.substrait.proto.Expression.Literal.Builder; -import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; -import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.io.IOException; @@ -601,10 +598,24 @@ void customTypesLiteralInFunctionsRoundtrip() { RelNode calciteRel = substraitToCalcite.convert(rel1); Rel rel2 = calciteToSubstrait.apply(calciteRel); assertEquals(rel1, rel2); + } + + @Test + void customNullableUserDefinedLiteralRoundtrip() { + Builder bldr = Expression.Literal.newBuilder(); + Any anyValue = Any.pack(bldr.setI32(10).build()); + UserDefinedLiteral nullableLiteral = + ExpressionCreator.userDefinedLiteralAny( + true, URN, "a_type", java.util.Collections.emptyList(), anyValue); + + Rel rel = + b.project( + 
input -> List.of(nullableLiteral), + b.remap(1), + b.namedScan(List.of("example"), List.of("a"), List.of(N.userDefined(URN, "a_type")))); - ExtensionCollector extensionCollector = new ExtensionCollector(); - io.substrait.proto.Rel protoRel = new RelProtoConverter(extensionCollector).toProto(rel1); - Rel rel3 = new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); - assertEquals(rel1, rel3); + RelNode calciteRel = substraitToCalcite.convert(rel); + Rel relReturned = calciteToSubstrait.apply(calciteRel); + assertEquals(rel, relReturned); } } From 77f0bd74894184202f99292e8715b2755319884f Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 26 Nov 2025 12:10:00 -0600 Subject: [PATCH 11/13] chore: addressing some PR comments --- .../src/main/java/io/substrait/expression/Expression.java | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 7f0fc8748..0a1244df3 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -671,9 +671,6 @@ public R accept( *

  <li>As {@code google.protobuf.Any} - see {@link UserDefinedAnyLiteral} *
  <li>As {@code Literal.Struct} - see {@link UserDefinedStructLiteral} * </ul> - * - * @see UserDefinedAnyLiteral - * @see UserDefinedStructLiteral */ interface UserDefinedLiteral extends Literal { String urn(); @@ -684,7 +681,7 @@ interface UserDefinedLiteral extends Literal { } /** - * User-defined literal with value encoded as {@code google.protobuf.Any}. + * User-defined literal with value encoded as {@link com.google.protobuf.Any}. * *

    <p>This encoding allows for arbitrary binary data to be stored in the literal value. */ @@ -723,7 +720,8 @@ public R accept( } /** - * User-defined literal with value encoded as {@code Literal.Struct}. + * User-defined literal with value encoded as {@link + * io.substrait.proto.Expression.Literal.Struct}. * *

    This encoding uses a structured list of fields to represent the literal value. */ From 4b4f5eef6d5a6c61fd563c99669010c6e3202ee1 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 26 Nov 2025 12:20:24 -0600 Subject: [PATCH 12/13] chore: addressing more PR comments --- .../proto/ProtoExpressionConverter.java | 16 ++++++---------- core/src/main/java/io/substrait/type/Type.java | 10 +++++----- core/src/test/java/io/substrait/TestBase.java | 2 +- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 39c5e2f9a..4d7515222 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -497,25 +497,21 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { lookup.getType(userDefinedLiteral.getTypeReference(), extensions); String urn = type.urn(); String name = type.name(); + List typeParameters = + userDefinedLiteral.getTypeParametersList().stream() + .map(protoTypeConverter::from) + .collect(Collectors.toList()); switch (userDefinedLiteral.getValCase()) { case VALUE: return ExpressionCreator.userDefinedLiteralAny( - literal.getNullable(), - urn, - name, - userDefinedLiteral.getTypeParametersList().stream() - .map(protoTypeConverter::from) - .collect(Collectors.toList()), - userDefinedLiteral.getValue()); + literal.getNullable(), urn, name, typeParameters, userDefinedLiteral.getValue()); case STRUCT: return ExpressionCreator.userDefinedLiteralStruct( literal.getNullable(), urn, name, - userDefinedLiteral.getTypeParametersList().stream() - .map(protoTypeConverter::from) - .collect(Collectors.toList()), + typeParameters, userDefinedLiteral.getStruct().getFieldsList().stream() .map(this::from) .collect(Collectors.toList())); diff --git 
a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index c597e7e67..90ec21a14 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -400,11 +400,11 @@ abstract class UserDefined implements Type { * vector} or custom types like {@code FixedArray<100>}. Each parameter in the list can be * either a type (like {@code i32}) or a value (like the integer {@code 100}). * - *

    <p>Unlike built-in parameterized types ({@link Map}, {@link ListType}, {@link Decimal}), - * which have fixed, known schemas with concrete typed fields, user-defined types have variable, - * unknown schemas. This is why UserDefined uses a generic {@link Parameter} list that can hold - * any mix of types or values, while other parameterized types use concrete fields like {@code - * Type key()} or {@code int precision()}. + *

    <p>Unlike built-in parameterized types ({@link Type.Map}, {@link Type.ListType}, {@link + * Type.Decimal}), which have fixed, known schemas with concrete typed fields, user-defined + * types have variable, unknown schemas. This is why UserDefined uses a generic {@link + * Parameter} list that can hold any mix of types or values, while other parameterized types use + * concrete fields like {@code Type key()} or {@code int precision()}. * *

    For example, a user-defined {@code vector} type parameterized by {@code i32} would have * one type parameter containing the {@code i32} type definition, while a {@code FixedArray} diff --git a/core/src/test/java/io/substrait/TestBase.java b/core/src/test/java/io/substrait/TestBase.java index b5f1dd4f1..e2fa7a91c 100644 --- a/core/src/test/java/io/substrait/TestBase.java +++ b/core/src/test/java/io/substrait/TestBase.java @@ -30,7 +30,7 @@ public abstract class TestBase { new ProtoRelConverter(functionCollector, defaultExtensionCollection); protected ExpressionProtoConverter expressionProtoConverter = - new ExpressionProtoConverter(functionCollector, relProtoConverter); + relProtoConverter.getExpressionProtoConverter(); protected ProtoExpressionConverter protoExpressionConverter = new ProtoExpressionConverter( From 3732cf98e6ccd7a4429010274fdbd250154c13ab Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 26 Nov 2025 12:32:32 -0600 Subject: [PATCH 13/13] chore: simplify documentation --- core/src/main/java/io/substrait/type/Type.java | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 90ec21a14..c9cb3f780 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -397,18 +397,7 @@ abstract class UserDefined implements Type { * Returns the type parameters for this user-defined type. * *

    <p>Type parameters are used to represent parameterized/generic types, such as {@code - * vector<i32>} or custom types like {@code FixedArray<100>}. Each parameter in the list can be - * either a type (like {@code i32}) or a value (like the integer {@code 100}). - * - *

    <p>Unlike built-in parameterized types ({@link Type.Map}, {@link Type.ListType}, {@link - * Type.Decimal}), which have fixed, known schemas with concrete typed fields, user-defined - * types have variable, unknown schemas. This is why UserDefined uses a generic {@link - * Parameter} list that can hold any mix of types or values, while other parameterized types use - * concrete fields like {@code Type key()} or {@code int precision()}. - * - *

    <p>For example, a user-defined {@code vector} type parameterized by {@code i32} would have - * one type parameter containing the {@code i32} type definition, while a {@code FixedArray} - * type might take an integer parameter specifying its size. + * vector<i32>}. * * @return a list of type parameters, or an empty list if this type is not parameterized */