diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 423042570..19017de43 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -117,6 +117,7 @@ dependencies { testImplementation(platform("org.junit:junit-bom:${JUNIT_VERSION}")) testImplementation("org.junit.jupiter:junit-jupiter") testRuntimeOnly("org.junit.platform:junit-platform-launcher") + testRuntimeOnly("org.jetbrains.kotlin:kotlin-stdlib:${properties.get("kotlin.version")}") api("com.google.protobuf:protobuf-java:${PROTOBUF_VERSION}") implementation("com.fasterxml.jackson.core:jackson-databind:${JACKSON_VERSION}") implementation("com.fasterxml.jackson.core:jackson-annotations:${JACKSON_VERSION}") diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 4e8e428f7..ca6cd4ce3 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -699,7 +699,7 @@ public Expression.WindowFunctionInvocation windowFn( // Types public Type.UserDefined userDefinedType(String namespace, String typeName) { - return Type.UserDefined.builder().uri(namespace).name(typeName).nullable(false).build(); + return Type.UserDefined.builder().urn(namespace).name(typeName).nullable(false).build(); } // Misc diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index e5c45e19e..42c3c5118 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -666,13 +666,13 @@ public R accept( abstract class UserDefinedLiteral implements Literal { public abstract ByteString value(); - public abstract String uri(); + public abstract String urn(); public abstract String name(); @Override public Type getType() { - return Type.withNullability(nullable()).userDefined(uri(), name()); + return 
Type.withNullability(nullable()).userDefined(urn(), name()); } public static ImmutableExpression.UserDefinedLiteral.Builder builder() { diff --git a/core/src/main/java/io/substrait/expression/ExpressionCreator.java b/core/src/main/java/io/substrait/expression/ExpressionCreator.java index 4a18026ab..adf157d7b 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionCreator.java +++ b/core/src/main/java/io/substrait/expression/ExpressionCreator.java @@ -287,10 +287,10 @@ public static Expression.StructLiteral struct( } public static Expression.UserDefinedLiteral userDefinedLiteral( - boolean nullable, String uri, String name, Any value) { + boolean nullable, String urn, String name, Any value) { return Expression.UserDefinedLiteral.builder() .nullable(nullable) - .uri(uri) + .urn(urn) .name(name) .value(value.toByteString()) .build(); diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 093e3fff3..caf145dfc 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -361,7 +361,7 @@ public Expression visit( public Expression visit( io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) { int typeReference = - extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); return lit( bldr -> { try { diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 859ddfca5..8f95cdf07 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ 
b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -495,7 +495,7 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) { SimpleExtension.Type type = lookup.getType(userDefinedLiteral.getTypeReference(), extensions); return ExpressionCreator.userDefinedLiteral( - literal.getNullable(), type.uri(), type.name(), userDefinedLiteral.getValue()); + literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue()); } default: throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase()); diff --git a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java index e05aa1f5b..7b74e1e48 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java @@ -34,7 +34,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp // fill in simple extension information through a discovery in the current proto-extended // expression ExtensionLookup functionLookup = - ImmutableExtensionLookup.builder().from(extendedExpression).build(); + ImmutableExtensionLookup.builder().from(extendedExpression, this.extensionCollection.uriUrnMap()).build(); NamedStruct baseSchemaProto = extendedExpression.getBaseSchema(); diff --git a/core/src/main/java/io/substrait/extension/BidiMap.java b/core/src/main/java/io/substrait/extension/BidiMap.java new file mode 100644 index 000000000..245925a28 --- /dev/null +++ b/core/src/main/java/io/substrait/extension/BidiMap.java @@ -0,0 +1,60 @@ +package io.substrait.extension; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** We don't depend on guava... 
*/ +class BidiMap { + private final Map forwardMap; + private final Map reverseMap; + + public BidiMap(Map forwardMap) { + this.forwardMap = forwardMap; + this.reverseMap = new HashMap<>(); + } + + public BidiMap() { + this.forwardMap = new HashMap<>(); + this.reverseMap = new HashMap<>(); + } + + public T2 get(T1 t1) { + return forwardMap.get(t1); + } + + public T1 reverseGet(T2 t2) { + return reverseMap.get(t2); + } + + public void put(T1 t1, T2 t2) { + // Check for conflicting mappings (different values for same key) + T2 existingForward = forwardMap.get(t1); + T1 existingReverse = reverseMap.get(t2); + + if (existingForward != null && !existingForward.equals(t2)) { + throw new IllegalArgumentException("Key already exists in map with different value"); + } + if (existingReverse != null && !existingReverse.equals(t1)) { + throw new IllegalArgumentException("Key already exists in map with different value"); + } + + // Allow identical mappings, only add if not already present + forwardMap.put(t1, t2); + reverseMap.put(t2, t1); + } + + public void merge(BidiMap other) { + for (Map.Entry entry : other.forwardEntrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + public Set> forwardEntrySet() { + return forwardMap.entrySet(); + } + + public Set> reverseEntrySet() { + return reverseMap.entrySet(); + } +} diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 46b7a920c..6c46a48ce 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -1,17 +1,17 @@ package io.substrait.extension; public class DefaultExtensionCatalog { - public static final String FUNCTIONS_AGGREGATE_APPROX = "/functions_aggregate_approx.yaml"; - public static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml"; - public static final String 
FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml"; - public static final String FUNCTIONS_ARITHMETIC_DECIMAL = "/functions_arithmetic_decimal.yaml"; - public static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml"; - public static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml"; - public static final String FUNCTIONS_DATETIME = "/functions_datetime.yaml"; - public static final String FUNCTIONS_GEOMETRY = "/functions_geometry.yaml"; - public static final String FUNCTIONS_LOGARITHMIC = "/functions_logarithmic.yaml"; - public static final String FUNCTIONS_ROUNDING = "/functions_rounding.yaml"; - public static final String FUNCTIONS_ROUNDING_DECIMAL = "/functions_rounding_decimal.yaml"; - public static final String FUNCTIONS_SET = "/functions_set.yaml"; - public static final String FUNCTIONS_STRING = "/functions_string.yaml"; + public static final String FUNCTIONS_AGGREGATE_APPROX = "extension:io.substrait:functions_aggregate_approx"; + public static final String FUNCTIONS_AGGREGATE_GENERIC = "extension:io.substrait:functions_aggregate_generic"; + public static final String FUNCTIONS_ARITHMETIC = "extension:io.substrait:functions_arithmetic"; + public static final String FUNCTIONS_ARITHMETIC_DECIMAL = "extension:io.substrait:functions_arithmetic_decimal"; + public static final String FUNCTIONS_BOOLEAN = "extension:io.substrait:functions_boolean"; + public static final String FUNCTIONS_COMPARISON = "extension:io.substrait:functions_comparison"; + public static final String FUNCTIONS_DATETIME = "extension:io.substrait:functions_datetime"; + public static final String FUNCTIONS_GEOMETRY = "extension:io.substrait:functions_geometry"; + public static final String FUNCTIONS_LOGARITHMIC = "extension:io.substrait:functions_logarithmic"; + public static final String FUNCTIONS_ROUNDING = "extension:io.substrait:functions_rounding"; + public static final String FUNCTIONS_ROUNDING_DECIMAL = "extension:io.substrait:functions_rounding_decimal"; + public 
static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; + public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; } diff --git a/core/src/main/java/io/substrait/extension/ExtensionCollector.java b/core/src/main/java/io/substrait/extension/ExtensionCollector.java index d408c600b..b871a7bc0 100644 --- a/core/src/main/java/io/substrait/extension/ExtensionCollector.java +++ b/core/src/main/java/io/substrait/extension/ExtensionCollector.java @@ -3,7 +3,7 @@ import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; -import io.substrait.proto.SimpleExtensionURI; +import io.substrait.proto.SimpleExtensionURN; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; @@ -52,30 +52,30 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) { public void addExtensionsToPlan(Plan.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); - builder.addAllExtensionUris(simpleExtensions.uris.values()); + builder.addAllExtensionUrns(simpleExtensions.urns.values()); builder.addAllExtensions(simpleExtensions.extensionList); } public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) { SimpleExtensions simpleExtensions = getExtensions(); - builder.addAllExtensionUris(simpleExtensions.uris.values()); + builder.addAllExtensionUrns(simpleExtensions.urns.values()); builder.addAllExtensions(simpleExtensions.extensionList); } private SimpleExtensions getExtensions() { - AtomicInteger uriPos = new AtomicInteger(1); - HashMap uris = new HashMap<>(); + AtomicInteger urnPos = new AtomicInteger(1); + HashMap urns = new HashMap<>(); ArrayList extensionList = new ArrayList<>(); - for (Map.Entry e : funcMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), + for (Map.Entry e : funcMap.forwardEntrySet()) { + SimpleExtensionURN urn = + 
urns.computeIfAbsent( + e.getValue().urn(), k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(urnPos.getAndIncrement()) + .setUrn(k) .build()); SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder() @@ -83,18 +83,18 @@ private SimpleExtensions getExtensions() { SimpleExtensionDeclaration.ExtensionFunction.newBuilder() .setFunctionAnchor(e.getKey()) .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) + .setExtensionUrnReference(urn.getExtensionUrnAnchor())) .build(); extensionList.add(decl); } - for (Map.Entry e : typeMap.forwardMap.entrySet()) { - SimpleExtensionURI uri = - uris.computeIfAbsent( - e.getValue().namespace(), + for (Map.Entry e : typeMap.forwardEntrySet()) { + SimpleExtensionURN urn = + urns.computeIfAbsent( + e.getValue().urn(), k -> - SimpleExtensionURI.newBuilder() - .setExtensionUriAnchor(uriPos.getAndIncrement()) - .setUri(k) + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(urnPos.getAndIncrement()) + .setUrn(k) .build()); SimpleExtensionDeclaration decl = SimpleExtensionDeclaration.newBuilder() @@ -102,46 +102,22 @@ private SimpleExtensions getExtensions() { SimpleExtensionDeclaration.ExtensionType.newBuilder() .setTypeAnchor(e.getKey()) .setName(e.getValue().key()) - .setExtensionUriReference(uri.getExtensionUriAnchor())) + .setExtensionUrnReference(urn.getExtensionUrnAnchor())) .build(); extensionList.add(decl); } - return new SimpleExtensions(uris, extensionList); + return new SimpleExtensions(urns, extensionList); } private static final class SimpleExtensions { - final HashMap uris; + final HashMap urns; final ArrayList extensionList; SimpleExtensions( - HashMap uris, + HashMap urns, ArrayList extensionList) { - this.uris = uris; + this.urns = urns; this.extensionList = extensionList; } } - - /** We don't depend on guava... 
*/ - private static class BidiMap { - private final Map forwardMap; - private final Map reverseMap; - - public BidiMap(Map forwardMap) { - this.forwardMap = forwardMap; - this.reverseMap = new HashMap<>(); - } - - public T2 get(T1 t1) { - return forwardMap.get(t1); - } - - public T1 reverseGet(T2 t2) { - return reverseMap.get(t2); - } - - public void put(T1 t1, T2 t2) { - forwardMap.put(t1, t2); - reverseMap.put(t2, t1); - } - } } diff --git a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java index 6ac4fe922..940d4f9c7 100644 --- a/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java +++ b/core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java @@ -4,6 +4,7 @@ import io.substrait.proto.Plan; import io.substrait.proto.SimpleExtensionDeclaration; import io.substrait.proto.SimpleExtensionURI; +import io.substrait.proto.SimpleExtensionURN; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -29,21 +30,33 @@ public static class Builder { private final Map functionMap = new HashMap<>(); private final Map typeMap = new HashMap<>(); - public Builder from(Plan plan) { - return from(plan.getExtensionUrisList(), plan.getExtensionsList()); + public Builder from(Plan plan, BidiMap uriUrnMap) { + return from( + plan.getExtensionUrnsList(), plan.getExtensionUrisList(), plan.getExtensionsList(), uriUrnMap); } - public Builder from(ExtendedExpression extendedExpression) { + public Builder from(ExtendedExpression extendedExpression, BidiMap uriUrnMap) { return from( - extendedExpression.getExtensionUrisList(), extendedExpression.getExtensionsList()); + extendedExpression.getExtensionUrnsList(), + extendedExpression.getExtensionUrisList(), + extendedExpression.getExtensionsList(), + uriUrnMap); } private Builder from( + List simpleExtensionURNs, List simpleExtensionURIs, - List simpleExtensionDeclarations) { - Map 
namespaceMap = new HashMap<>(); + List simpleExtensionDeclarations, + BidiMap uriUrnMap) { + Map urnMap = new HashMap<>(); + Map uriMap = new HashMap<>(); + // Handle URN format + for (SimpleExtensionURN extension : simpleExtensionURNs) { + urnMap.put(extension.getExtensionUrnAnchor(), extension.getUrn()); + } + for (SimpleExtensionURI extension : simpleExtensionURIs) { - namespaceMap.put(extension.getExtensionUriAnchor(), extension.getUri()); + uriMap.put(extension.getExtensionUriAnchor(), extension.getUri()); } // Add all functions used in plan to the functionMap @@ -53,13 +66,27 @@ private Builder from( } SimpleExtensionDeclaration.ExtensionFunction func = extension.getExtensionFunction(); int reference = func.getFunctionAnchor(); - String namespace = namespaceMap.get(func.getExtensionUriReference()); - if (namespace == null) { - throw new IllegalStateException( - "Could not find extension URI of " + func.getExtensionUriReference()); + String urn = urnMap.get(func.getExtensionUrnReference()); + if (urn == null) { + int uriReference = func.getExtensionUriReference(); + String uri = uriMap.get(uriReference); + if (uri == null) { + throw new IllegalStateException( + "Could not find extension URN for function reference " + + func.getExtensionUrnReference() + + " or extension URI for function reference " + + func.getExtensionUriReference()); + } + // Translate URI to URN using the BidiMap + urn = uriUrnMap.get(uri); + if (urn == null) { + throw new IllegalStateException( + "Could not translate URI '" + uri + "' to URN. 
" + + "URI-URN mapping not found in the provided mapping."); + } } String name = func.getName(); - SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(namespace, name); + SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(urn, name); functionMap.put(reference, anchor); } @@ -70,13 +97,27 @@ private Builder from( } SimpleExtensionDeclaration.ExtensionType type = extension.getExtensionType(); int reference = type.getTypeAnchor(); - String namespace = namespaceMap.get(type.getExtensionUriReference()); - if (namespace == null) { - throw new IllegalStateException( - "Could not find extension URI of " + type.getExtensionUriReference()); + String urn = urnMap.get(type.getExtensionUrnReference()); + if (urn == null) { + int uriReference = type.getExtensionUriReference(); + String uri = uriMap.get(uriReference); + if (uri == null) { + throw new IllegalStateException( + "Could not find extension URN for type reference " + + type.getExtensionUrnReference() + + " or extension URI for type reference " + + type.getExtensionUriReference()); + } + // Translate URI to URN using the BidiMap + urn = uriUrnMap.get(uri); + if (urn == null) { + throw new IllegalStateException( + "Could not translate URI '" + uri + "' to URN. 
" + + "URI-URN mapping not found in the provided mapping."); + } } String name = type.getName(); - SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(namespace, name); + SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(urn, name); typeMap.put(reference, anchor); } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index ad0548920..7cc1c7ecf 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -1,10 +1,10 @@ package io.substrait.extension; import com.fasterxml.jackson.annotation.JacksonInject; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.InjectableValues; import com.fasterxml.jackson.databind.ObjectMapper; @@ -23,6 +23,7 @@ import java.io.InputStream; import java.io.UncheckedIOException; import java.util.Arrays; +import java.util.Scanner; import java.util.List; import java.util.Map; import java.util.Optional; @@ -32,6 +33,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; + import org.immutables.value.Value; import org.jspecify.annotations.Nullable; import org.slf4j.Logger; @@ -42,12 +44,22 @@ public class SimpleExtension { private static final Logger LOGGER = LoggerFactory.getLogger(SimpleExtension.class); - // Key for looking up URI in InjectableValues - public static final String URI_LOCATOR_KEY = "uri"; + // Key for looking up URN in InjectableValues + public static final String URN_LOCATOR_KEY = "urn"; - private static ObjectMapper objectMapper(String 
namespace) { + private static void validateUrn(String urn) { + if (urn == null || urn.trim().isEmpty()) { + throw new IllegalArgumentException("URN cannot be null or empty"); + } + if (!urn.matches("^extension:[^:]+:[^:]+$")) { + throw new IllegalArgumentException( + "URN must follow format 'extension::', got: " + urn); + } + } + + private static ObjectMapper objectMapper(String urn) { InjectableValues.Std iv = new InjectableValues.Std(); - iv.addValue(URI_LOCATOR_KEY, namespace); + iv.addValue(URN_LOCATOR_KEY, urn); return new ObjectMapper(new YAMLFactory()) .enable(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY) @@ -185,25 +197,22 @@ public static ImmutableSimpleExtension.EnumArgument.Builder builder() { } public interface Anchor { - String namespace(); + String urn(); String key(); } @Value.Immutable public interface FunctionAnchor extends Anchor { - static FunctionAnchor of(String namespace, String key) { - return ImmutableSimpleExtension.FunctionAnchor.builder() - .namespace(namespace) - .key(key) - .build(); + static FunctionAnchor of(String urn, String key) { + return ImmutableSimpleExtension.FunctionAnchor.builder().urn(urn).key(key).build(); } } @Value.Immutable public interface TypeAnchor extends Anchor { - static TypeAnchor of(String namespace, String name) { - return ImmutableSimpleExtension.TypeAnchor.builder().namespace(namespace).key(name).build(); + static TypeAnchor of(String urn, String name) { + return ImmutableSimpleExtension.TypeAnchor.builder().urn(urn).key(name).build(); } } @@ -227,7 +236,7 @@ default ParameterConsistency parameterConsistency() { public abstract static class Function { private final Supplier anchorSupplier = - Util.memoize(() -> FunctionAnchor.of(uri(), key())); + Util.memoize(() -> FunctionAnchor.of(urn(), key())); private final Supplier keySupplier = Util.memoize(() -> constructKey(name(), args())); private final Supplier> requiredArgsSupplier = Util.memoize( @@ -242,8 +251,8 @@ public String name() { } @Value.Default 
- public String uri() { - // we can't use null detection here since we initially construct this without a uri, then + public String urn() { + // we can't use null detection here since we initially construct this without a urn, then // resolve later. return ""; } @@ -367,8 +376,8 @@ public abstract static class ScalarFunction { public abstract List impls(); - public Stream resolve(String uri) { - return impls().stream().map(f -> f.resolve(uri, name(), description())); + public Stream resolve(String urn) { + return impls().stream().map(f -> f.resolve(urn, name(), description())); } } @@ -376,9 +385,9 @@ public Stream resolve(String uri) { @JsonSerialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class) @Value.Immutable public abstract static class ScalarFunctionVariant extends Function { - public ScalarFunctionVariant resolve(String uri, String name, String description) { + public ScalarFunctionVariant resolve(String urn, String name, String description) { return ImmutableSimpleExtension.ScalarFunctionVariant.builder() - .uri(uri) + .urn(urn) .name(name) .description(description) .nullability(nullability()) @@ -403,8 +412,8 @@ public abstract static class AggregateFunction { public abstract List impls(); - public Stream resolve(String uri) { - return impls().stream().map(f -> f.resolve(uri, name(), description())); + public Stream resolve(String urn) { + return impls().stream().map(f -> f.resolve(urn, name(), description())); } } @@ -420,8 +429,8 @@ public abstract static class WindowFunction { public abstract List impls(); - public Stream resolve(String uri) { - return impls().stream().map(f -> f.resolve(uri, name(), description())); + public Stream resolve(String urn) { + return impls().stream().map(f -> f.resolve(urn, name(), description())); } public static ImmutableSimpleExtension.WindowFunction.Builder builder() { @@ -447,9 +456,9 @@ public String toString() { @Nullable public abstract TypeExpression intermediate(); - AggregateFunctionVariant 
resolve(String uri, String name, String description) { + AggregateFunctionVariant resolve(String urn, String name, String description) { return ImmutableSimpleExtension.AggregateFunctionVariant.builder() - .uri(uri) + .urn(urn) .name(name) .description(description) .nullability(nullability()) @@ -489,9 +498,9 @@ public String toString() { return super.toString(); } - WindowFunctionVariant resolve(String uri, String name, String description) { + WindowFunctionVariant resolve(String urn, String name, String description) { return ImmutableSimpleExtension.WindowFunctionVariant.builder() - .uri(uri) + .urn(urn) .name(name) .description(description) .nullability(nullability()) @@ -516,12 +525,12 @@ public static ImmutableSimpleExtension.WindowFunctionVariant.Builder builder() { @Value.Immutable public abstract static class Type { private final Supplier anchorSupplier = - Util.memoize(() -> TypeAnchor.of(uri(), name())); + Util.memoize(() -> TypeAnchor.of(urn(), name())); public abstract String name(); - @JacksonInject(SimpleExtension.URI_LOCATOR_KEY) - public abstract String uri(); + @JacksonInject(SimpleExtension.URN_LOCATOR_KEY) + public abstract String urn(); // TODO: Handle conversion of structure object to Named Struct representation protected abstract Optional structure(); @@ -533,11 +542,22 @@ public TypeAnchor getAnchor() { @JsonDeserialize(as = ImmutableSimpleExtension.ExtensionSignatures.class) @JsonSerialize(as = ImmutableSimpleExtension.ExtensionSignatures.class) + @JsonIgnoreProperties(ignoreUnknown = true) @Value.Immutable public abstract static class ExtensionSignatures { @JsonProperty("types") public abstract List types(); + @JsonProperty("urn") + public abstract String urn(); + + // URI is not from YAML, but from the loading context + // this only needs to be present temporarily to handle the URI -> URN migration + @Value.Default + public String uri() { + return ""; + } + @JsonProperty("scalar_functions") public abstract List scalars(); @@ -554,27 
+574,32 @@ public int size() { + (windows() == null ? 0 : windows().size()); } - public Stream resolve(String uri) { + public Stream resolve(String urn) { return Stream.concat( Stream.concat( - scalars() == null ? Stream.of() : scalars().stream().flatMap(f -> f.resolve(uri)), + scalars() == null ? Stream.of() : scalars().stream().flatMap(f -> f.resolve(urn)), aggregates() == null ? Stream.of() - : aggregates().stream().flatMap(f -> f.resolve(uri))), - windows() == null ? Stream.of() : windows().stream().flatMap(f -> f.resolve(uri))); + : aggregates().stream().flatMap(f -> f.resolve(urn))), + windows() == null ? Stream.of() : windows().stream().flatMap(f -> f.resolve(urn))); } } @Value.Immutable public abstract static class ExtensionCollection { - private final Supplier> namespaceSupplier = + @Value.Default + public BidiMap uriUrnMap() { + return new BidiMap<>(); + } + + private final Supplier> urnSupplier = Util.memoize( () -> { return Stream.concat( Stream.concat( - scalarFunctions().stream().map(Function::uri), - aggregateFunctions().stream().map(Function::uri)), - windowFunctions().stream().map(Function::uri)) + scalarFunctions().stream().map(Function::urn), + aggregateFunctions().stream().map(Function::urn)), + windowFunctions().stream().map(Function::urn)) .collect(Collectors.toSet()); }); @@ -628,11 +653,11 @@ public Type getType(TypeAnchor anchor) { if (type != null) { return type; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected type with name %s. The namespace %s is loaded but no type with this name found.", - anchor.key(), anchor.namespace())); + "Unexpected type with name %s. 
The URN %s is loaded but no type with this name found.", + anchor.key(), anchor.urn())); } public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { @@ -640,16 +665,16 @@ public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { if (variant != null) { return variant; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected scalar function with key %s. The namespace %s is loaded " + "Unexpected scalar function with key %s. The URN %s is loaded " + "but no scalar function with this key found.", - anchor.key(), anchor.namespace())); + anchor.key(), anchor.urn())); } - private void checkNamespace(String name) { - if (namespaceSupplier.get().contains(name)) { + private void checkUrn(String name) { + if (urnSupplier.get().contains(name)) { return; } @@ -666,12 +691,12 @@ public AggregateFunctionVariant getAggregateFunction(FunctionAnchor anchor) { return variant; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected aggregate function with key %s. The namespace %s is loaded " + "Unexpected aggregate function with key %s. The URN %s is loaded " + "but no aggregate function with this key was found.", - anchor.key(), anchor.namespace())); + anchor.key(), anchor.urn())); } public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { @@ -679,15 +704,63 @@ public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) { if (variant != null) { return variant; } - checkNamespace(anchor.namespace()); + checkUrn(anchor.urn()); throw new IllegalArgumentException( String.format( - "Unexpected window aggregate function with key %s. The namespace %s is loaded " + "Unexpected window aggregate function with key %s. 
The URN %s is loaded " + "but no window aggregate function with this key was found.", - anchor.key(), anchor.namespace())); + anchor.key(), anchor.urn())); + } + + /** + * Gets the URN for a given URI. This is only useful during the URI -> URN migration, and will + * be dropped when the migration is complete. + * + * @param uri The URI to look up + * @return The corresponding URN, or null if not found + */ + public String getUrn(String uri) { + return uriUrnMap().get(uri); + } + + /** + * Gets the URI for a given URN. This is only useful during the URI -> URN migration, and will + * be dropped when the migration is complete. + * + * @param urn The URN to look up + * @return The corresponding URI, or null if not found + */ + public String getUri(String urn) { + return uriUrnMap().reverseGet(urn); + } + + /** + * Checks if a URI has a corresponding URN mapping. This is only useful during the URI -> URN + * migration, and will be dropped when the migration is complete. + * + * @param uri The URI to check + * @return true if the URI has a URN mapping, false otherwise + */ + public boolean hasUrn(String uri) { + return uriUrnMap().get(uri) != null; + } + + /** + * Checks if a URN has a corresponding URI mapping. This is only useful during the URI -> URN + * migration, and will be dropped when the migration is complete. 
+ * + * @param urn The URN to check + * @return true if the URN has a URI mapping, false otherwise + */ + public boolean hasUri(String urn) { + return uriUrnMap().reverseGet(urn) != null; } public ExtensionCollection merge(ExtensionCollection extensionCollection) { + BidiMap mergedUriUrnMap = new BidiMap<>(); + mergedUriUrnMap.merge(uriUrnMap()); + mergedUriUrnMap.merge(extensionCollection.uriUrnMap()); + return ImmutableSimpleExtension.ExtensionCollection.builder() .addAllAggregateFunctions(aggregateFunctions()) .addAllAggregateFunctions(extensionCollection.aggregateFunctions()) @@ -697,6 +770,7 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) { .addAllWindowFunctions(extensionCollection.windowFunctions()) .addAllTypes(types()) .addAllTypes(extensionCollection.types()) + .uriUrnMap(mergedUriUrnMap) .build(); } } @@ -722,7 +796,7 @@ public static ExtensionCollection loadDefaults() { return load(defaultFiles); } - public static ExtensionCollection load(List resourcePaths) { + private static ExtensionCollection load(List resourcePaths) { if (resourcePaths.isEmpty()) { throw new IllegalArgumentException("Require at least one resource path."); } @@ -745,41 +819,60 @@ public static ExtensionCollection load(List resourcePaths) { return complete; } - public static ExtensionCollection load(String namespace, String str) { + public static ExtensionCollection load(String uri, String content) { try { - ExtensionSignatures doc = objectMapper(namespace).readValue(str, ExtensionSignatures.class); - return buildExtensionCollection(namespace, doc); - } catch (JsonProcessingException e) { + // Parse with basic YAML mapper first to extract URN + ObjectMapper basicYamlMapper = new ObjectMapper(new YAMLFactory()); + com.fasterxml.jackson.databind.JsonNode rootNode = basicYamlMapper.readTree(content); + com.fasterxml.jackson.databind.JsonNode urnNode = rootNode.get("urn"); + if (urnNode == null) { + throw new IllegalArgumentException("Extension YAML file must 
contain a 'urn' field"); + } + String urn = urnNode.asText(); + validateUrn(urn); + + ExtensionSignatures docWithoutUri = + objectMapper(urn).readValue(content, ExtensionSignatures.class); + + ExtensionSignatures doc = + ImmutableSimpleExtension.ExtensionSignatures.builder() + .from(docWithoutUri) + .uri(uri) + .build(); + + return buildExtensionCollection(uri, doc); + } catch (IOException e) { throw new IllegalStateException(e); } } - public static ExtensionCollection load(String namespace, InputStream stream) { - try { - ExtensionSignatures doc = - objectMapper(namespace).readValue(stream, ExtensionSignatures.class); - return buildExtensionCollection(namespace, doc); - } catch (RuntimeException ex) { - throw ex; - } catch (Exception ex) { - throw new IllegalStateException("Failure while parsing " + namespace, ex); + public static ExtensionCollection load(String uri, InputStream stream) { + try (Scanner scanner = new Scanner(stream, "UTF-8")) { /* decode as UTF-8, not the platform default */ + scanner.useDelimiter("\\A"); /* \A = start of input: one token holds the whole stream */ + String content = scanner.hasNext() ? scanner.next() : ""; /* empty stream: no NoSuchElementException */ + return load(uri, content); } } public static ExtensionCollection buildExtensionCollection( - String namespace, ExtensionSignatures extensionSignatures) { + String uri, ExtensionSignatures extensionSignatures) { + String urn = extensionSignatures.urn(); + validateUrn(urn); + if (uri == null || uri.isEmpty()) { /* compare by value; == "" tested identity */ + throw new IllegalArgumentException("URI cannot be null or empty"); + } List scalarFunctionVariants = extensionSignatures.scalars().stream() - .flatMap(t -> t.resolve(namespace)) + .flatMap(t -> t.resolve(urn)) .collect(Collectors.toList()); List aggregateFunctionVariants = extensionSignatures.aggregates().stream() - .flatMap(t -> t.resolve(namespace)) + .flatMap(t -> t.resolve(urn)) .collect(Collectors.toList()); Stream windowFunctionVariants = - extensionSignatures.windows().stream().flatMap(t -> t.resolve(namespace)); + extensionSignatures.windows().stream().flatMap(t -> t.resolve(urn)); // Aggregate functions can be used as Window Functions Stream
windowAggFunctionVariants = @@ -800,18 +893,23 @@ public static ExtensionCollection buildExtensionCollection( Stream.concat(windowFunctionVariants, windowAggFunctionVariants) .collect(Collectors.toList()); + BidiMap uriUrnMap = new BidiMap<>(); + uriUrnMap.put(uri, urn); + ImmutableSimpleExtension.ExtensionCollection collection = ImmutableSimpleExtension.ExtensionCollection.builder() .scalarFunctions(scalarFunctionVariants) .aggregateFunctions(aggregateFunctionVariants) .windowFunctions(allWindowFunctionVariants) .addAllTypes(extensionSignatures.types()) + .uriUrnMap(uriUrnMap) .build(); + LOGGER.atDebug().log( "Loaded {} aggregate functions and {} scalar functions from {}.", collection.aggregateFunctions().size(), collection.scalarFunctions().size(), - namespace); + extensionSignatures.urn()); return collection; } } diff --git a/core/src/main/java/io/substrait/plan/Plan.java b/core/src/main/java/io/substrait/plan/Plan.java index ff4bfcdda..7b8b81fd7 100644 --- a/core/src/main/java/io/substrait/plan/Plan.java +++ b/core/src/main/java/io/substrait/plan/Plan.java @@ -21,6 +21,10 @@ public Version getVersion() { public abstract Optional getAdvancedExtension(); + public abstract List getExtensionUrns(); + + public abstract List getExtensionUris(); + public static ImmutablePlan.Builder builder() { return ImmutablePlan.builder(); } diff --git a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java index b6d5c3bc2..f94457709 100644 --- a/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java +++ b/core/src/main/java/io/substrait/plan/ProtoPlanConverter.java @@ -29,7 +29,8 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup) } public Plan from(io.substrait.proto.Plan plan) { - ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build(); + ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder().from(plan, 
extensionCollection.uriUrnMap()).build(); ProtoRelConverter relConverter = getProtoRelConverter(functionLookup); List roots = new ArrayList<>(); for (PlanRel planRel : plan.getRelationsList()) { @@ -54,12 +55,23 @@ public Plan from(io.substrait.proto.Plan plan) { versionBuilder.producer(Optional.of(plan.getVersion().getProducer())); } + List extensionUrns = + plan.getExtensionUrnsList().stream() + .map(urn -> urn.getUrn()) + .collect(java.util.stream.Collectors.toList()); + List extensionUris = + extensionUrns.stream() + .map(urn -> extensionCollection.getUri(urn)) + .collect(java.util.stream.Collectors.toList()); + return Plan.builder() .roots(roots) .expectedTypeUrls(plan.getExpectedTypeUrlsList()) .advancedExtension( Optional.ofNullable(plan.hasAdvancedExtensions() ? plan.getAdvancedExtensions() : null)) .version(versionBuilder.build()) + .extensionUrns(extensionUrns) + .extensionUris(extensionUris) .build(); } } diff --git a/core/src/main/java/io/substrait/type/Deserializers.java b/core/src/main/java/io/substrait/type/Deserializers.java index 160004a3b..efdbc3306 100644 --- a/core/src/main/java/io/substrait/type/Deserializers.java +++ b/core/src/main/java/io/substrait/type/Deserializers.java @@ -44,7 +44,7 @@ public T deserialize(final JsonParser p, final DeserializationContext ctxt) String typeString = p.getValueAsString(); try { String namespace = - (String) ctxt.findInjectableValue(SimpleExtension.URI_LOCATOR_KEY, null, null); + (String) ctxt.findInjectableValue(SimpleExtension.URN_LOCATOR_KEY, null, null); return TypeStringParser.parse(typeString, namespace, converter); } catch (Exception ex) { throw JsonMappingException.from( diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 362adc3a8..aaf97aa12 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -389,7 +389,7 @@ public R accept(final TypeVisitor typeVisitor) th @Value.Immutable 
abstract class UserDefined implements Type { - public abstract String uri(); + public abstract String urn(); public abstract String name(); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 880b72ed9..43358e505 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -108,8 +108,8 @@ public Type.Map map(Type key, Type value) { return Type.Map.builder().nullable(nullable).key(key).value(value).build(); } - public Type userDefined(String uri, String name) { - return Type.UserDefined.builder().nullable(nullable).uri(uri).name(name).build(); + public Type userDefined(String urn, String name) { + return Type.UserDefined.builder().nullable(nullable).urn(urn).name(name).build(); } public static TypeCreator of(boolean nullability) { diff --git a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java index 07093e925..8085ea6b8 100644 --- a/core/src/main/java/io/substrait/type/parser/TypeStringParser.java +++ b/core/src/main/java/io/substrait/type/parser/TypeStringParser.java @@ -16,16 +16,16 @@ public class TypeStringParser { private TypeStringParser() {} - public static Type parseSimple(String str, String namespace) { - return parse(str, namespace, ParseToPojo::type); + public static Type parseSimple(String str, String urn) { + return parse(str, urn, ParseToPojo::type); } - public static ParameterizedType parseParameterized(String str, String namespace) { - return parse(str, namespace, ParseToPojo::parameterizedType); + public static ParameterizedType parseParameterized(String str, String urn) { + return parse(str, urn, ParseToPojo::parameterizedType); } - public static TypeExpression parseExpression(String str, String namespace) { - return parse(str, namespace, ParseToPojo::typeExpression); + public static TypeExpression parseExpression(String 
str, String urn) { + return parse(str, urn, ParseToPojo::typeExpression); } private static SubstraitTypeParser.StartContext parse(String str) { @@ -40,8 +40,8 @@ private static SubstraitTypeParser.StartContext parse(String str) { } public static T parse( - String str, String namespace, BiFunction func) { - return func.apply(namespace, parse(str)); + String str, String urn, BiFunction func) { + return func.apply(urn, parse(str)); } public static TypeExpression parse(String str, ParseToPojo.Visitor visitor) { diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index a8c64db95..691d4bce5 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -164,7 +164,7 @@ public final T visit(final Type.Map expr) { @Override public final T visit(final Type.UserDefined expr) { int ref = - extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); + extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); return typeContainer(expr).userDefined(ref); } } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 661f57fea..95d42328a 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -90,7 +90,7 @@ public Type from(io.substrait.proto.Type type) { { io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined(); SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions); - return n(userDefined.getNullability()).userDefined(t.uri(), t.name()); + return n(userDefined.getNullability()).userDefined(t.urn(), t.name()); } case USER_DEFINED_TYPE_REFERENCE: throw new 
UnsupportedOperationException("Unsupported user defined reference: " + type); diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index cd2522090..2e8fd8403 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -28,13 +28,13 @@ public class TypeExtensionTest { static final TypeCreator R = TypeCreator.of(false); - static final String NAMESPACE = "/custom_extensions"; + static final String NAMESPACE = "extension:test:custom_extensions"; final SimpleExtension.ExtensionCollection extensionCollection; { - InputStream inputStream = - this.getClass().getResourceAsStream("/extensions/custom_extensions.yaml"); - extensionCollection = SimpleExtension.load(NAMESPACE, inputStream); + String path = "/extensions/custom_extensions.yaml"; + InputStream inputStream = this.getClass().getResourceAsStream(path); + extensionCollection = SimpleExtension.load(path, inputStream); } final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); diff --git a/core/src/test/java/io/substrait/extension/UriUrnMigrationTest.java b/core/src/test/java/io/substrait/extension/UriUrnMigrationTest.java new file mode 100644 index 000000000..3f9321907 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/UriUrnMigrationTest.java @@ -0,0 +1,152 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.plan.Plan; +import io.substrait.plan.PlanProtoConverter; +import io.substrait.plan.ProtoPlanConverter; +import io.substrait.proto.SimpleExtensionURI; +import io.substrait.proto.SimpleExtensionURN; +import io.substrait.proto.SimpleExtensionDeclaration; 
+import io.substrait.proto.PlanRel; + +import org.junit.jupiter.api.Test; + +/** + * Tests describing the desired URI ↔ URN migration behaviour: a plan converted from protobuf + * exposes both its extension URNs and the URIs mapped to them. + */ +public class UriUrnMigrationTest { + + private static final String SAMPLE_URI = "https://example.com/extensions/sample.yaml"; + private static final String SAMPLE_YAML = + "%YAML 1.2\n" + + "---\n" + + "urn: extension:test:sample\n" + + "scalar_functions:\n" + + " - name: add\n" + + " impls:\n" + + " - args:\n" + + " - value: i32\n" + + " - value: i32\n" + + " return: i32\n"; + + @Test + void uriOnlyPlanShouldHaveUrn() throws Exception { + SimpleExtension.ExtensionCollection extensions = SimpleExtension.load(SAMPLE_URI, SAMPLE_YAML); + io.substrait.proto.Plan protoPlan = + io.substrait.proto.Plan.newBuilder() + .addExtensionUrns( + SimpleExtensionURN.newBuilder() + .setExtensionUrnAnchor(1) + .setUrn("extension:test:sample") + .build()) + .addExtensions( + SimpleExtensionDeclaration.newBuilder() + .setExtensionFunction( + SimpleExtensionDeclaration.ExtensionFunction.newBuilder() + .setFunctionAnchor(1) + .setName("add:i32_i32") + .setExtensionUrnReference(1) + .build()) + .build()) + .addRelations( + PlanRel.newBuilder() + .setRoot( + io.substrait.proto.RelRoot.newBuilder() + .setInput( + io.substrait.proto.Rel.newBuilder() + .setProject( + io.substrait.proto.ProjectRel.newBuilder() + .setInput( + io.substrait.proto.Rel.newBuilder() + .setRead( + io.substrait.proto.ReadRel.newBuilder() + .setNamedTable( + io.substrait.proto.ReadRel + .NamedTable.newBuilder() + .addNames("dummy") + .build()) + .setBaseSchema( + io.substrait.proto.NamedStruct + .newBuilder() + .addNames("col") + .setStruct( + io.substrait.proto.Type + .Struct.newBuilder() + .addTypes( + io.substrait.proto + .Type + .newBuilder() + .setI32( + io.substrait + .proto + .Type + .I32 + .newBuilder()) + .build()) + .build()) + .build()) + .build()) + .build()) + .addExpressions( +
io.substrait.proto.Expression.newBuilder() + .setScalarFunction( + io.substrait.proto.Expression.ScalarFunction + .newBuilder() + .setFunctionReference( + 1) // Uses our add function + .addArguments( + io.substrait.proto.FunctionArgument + .newBuilder() + .setValue( + io.substrait.proto + .Expression.newBuilder() + .setLiteral( + io.substrait.proto + .Expression + .Literal + .newBuilder() + .setI32(1) + .build()) + .build()) + .build()) + .addArguments( + io.substrait.proto.FunctionArgument + .newBuilder() + .setValue( + io.substrait.proto + .Expression.newBuilder() + .setLiteral( + io.substrait.proto + .Expression + .Literal + .newBuilder() + .setI32(2) + .build()) + .build()) + .build()) + .setOutputType( + io.substrait.proto.Type.newBuilder() + .setI32( + io.substrait.proto.Type.I32 + .newBuilder()) + .build()) + .build()) + .build()) + .build()) + .build()) + .addNames("result") + .build()) + .build()) + .build(); + + Plan planFromProto = new ProtoPlanConverter(extensions).from(protoPlan); + + assertTrue(planFromProto.getExtensionUris().size() > 0, "Plan should have URI"); + assertTrue(planFromProto.getExtensionUrns().size() > 0, "Plan should have URN"); + } +} diff --git a/core/src/test/java/io/substrait/extension/UrnValidationTest.java b/core/src/test/java/io/substrait/extension/UrnValidationTest.java new file mode 100644 index 000000000..c9b0bb465 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/UrnValidationTest.java @@ -0,0 +1,38 @@ +package io.substrait.extension; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +public class UrnValidationTest { + + @Test + public void testMissingUrnThrowsException() { + String yamlWithoutUrn = "%YAML 1.2\n" + "---\n" + "scalar_functions:\n" + " - name: test\n"; + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load("some/uri", yamlWithoutUrn)); + assertTrue(exception.getMessage().contains("Extension YAML file 
must contain a 'urn' field")); + } + + @Test + public void testInvalidUrnFormatThrowsException() { + String yamlWithInvalidUrn = "%YAML 1.2\n" + "---\n" + "urn: invalid:format\n" + "scalar_functions:\n" + " - name: test\n"; + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load("some/uri", yamlWithInvalidUrn)); + assertTrue(exception.getMessage().contains("URN must follow format 'extension::'")); + } + + @Test + public void testValidUrnWorks() { + String yamlWithValidUrn = "%YAML 1.2\n" + "---\n" + "urn: extension:test:valid\n" + "scalar_functions:\n" + " - name: test\n"; + assertDoesNotThrow(() -> SimpleExtension.load("some/uri", yamlWithValidUrn)); + } + + @Test + public void testUriUrnMapIsPopulated() { + String yamlWithValidUrn = "%YAML 1.2\n" + + "---\n" + + "urn: extension:test:valid\n" + + "scalar_functions:\n" + + " - name: test\n"; + SimpleExtension.ExtensionCollection collection = SimpleExtension.load("test://uri", yamlWithValidUrn); + assertEquals("extension:test:valid", collection.getUrn("test://uri")); + } +} diff --git a/core/src/test/resources/extensions/custom_extensions.yaml b/core/src/test/resources/extensions/custom_extensions.yaml index 204a5f9ac..4776312ac 100644 --- a/core/src/test/resources/extensions/custom_extensions.yaml +++ b/core/src/test/resources/extensions/custom_extensions.yaml @@ -1,5 +1,6 @@ %YAML 1.2 --- +urn: extension:test:custom_extensions types: - name: "customType1" - name: "customType2" diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index c0e1796da..3406de7de 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -71,7 +71,7 @@ public class CallConverters { Type.UserDefined t = (Type.UserDefined) type; return 
Expression.UserDefinedLiteral.builder() - .uri(t.uri()) + .urn(t.urn()) .name(t.name()) .value(literal.value()) .build(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java index cec68fcb7..b69ef9b02 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java @@ -185,7 +185,7 @@ private static ArgAnchor argAnchor(String fnNS, String fnSig, int argIdx) { private static ArgAnchor argAnchor(SimpleExtension.Function fnDef, int argIdx) { return new ArgAnchor( - SimpleExtension.FunctionAnchor.of(fnDef.getAnchor().namespace(), fnDef.getAnchor().key()), + SimpleExtension.FunctionAnchor.of(fnDef.getAnchor().urn(), fnDef.getAnchor().key()), argIdx); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java index a12a1eac6..176552f7d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteTypeTest.java @@ -37,7 +37,7 @@ public Type toSubstrait(RelDataType relDataType) { @Nullable @Override public RelDataType toCalcite(Type.UserDefined type) { - if (type.uri().equals(uTypeURI) && type.name().equals(uTypeName)) { + if (type.urn().equals(uTypeURI) && type.name().equals(uTypeName)) { return uTypeFactory.createCalcite(type.nullable()); } return null; diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index fabddf56e..d351b1bca 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -43,7 +43,7 @@ public class CustomFunctionTest extends PlanTestBase { // Define custom functions in a "functions_custom.yaml" 
extension - static final String NAMESPACE = "/functions_custom"; + static final String NAMESPACE = "extension:substrait:functions_custom"; static final String FUNCTIONS_CUSTOM; static { @@ -56,7 +56,7 @@ public class CustomFunctionTest extends PlanTestBase { // Load custom extension into an ExtensionCollection static final SimpleExtension.ExtensionCollection extensionCollection = - SimpleExtension.load("/functions_custom", FUNCTIONS_CUSTOM); + SimpleExtension.load("custom.yaml", FUNCTIONS_CUSTOM); final SubstraitBuilder b = new SubstraitBuilder(extensionCollection); @@ -84,7 +84,7 @@ public Type toSubstrait(RelDataType relDataType) { @Nullable @Override public RelDataType toCalcite(Type.UserDefined type) { - if (type.uri().equals(NAMESPACE)) { + if (type.urn().equals(NAMESPACE)) { if (type.name().equals(aTypeName)) { return aTypeFactory.createCalcite(type.nullable()); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index 6e95e3dcd..47e8d2b71 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -29,14 +29,15 @@ public class RelCopyOnWriteVisitorTest extends PlanTestBase { public static SimpleExtension.FunctionAnchor APPROX_COUNT_DISTINCT = SimpleExtension.FunctionAnchor.of( - "/functions_aggregate_approx.yaml", "approx_count_distinct:any"); + "extension:io.substrait:functions_aggregate_approx", "approx_count_distinct:any"); public static SimpleExtension.FunctionAnchor COUNT = - SimpleExtension.FunctionAnchor.of("/functions_aggregate_generic.yaml", "count:any"); + SimpleExtension.FunctionAnchor.of( + "extension:io.substrait:functions_aggregate_generic", "count:any"); private static final String COUNT_DISTINCT_SUBBQUERY = "select\n" + " count(distinct l.l_orderkey),\n" - + " count(distinct l.l_orderkey) + 1,\n" + + " count(distinct 
l.l_orderkey) + 1,\n" + " sum(l.l_extendedprice * (1 - l.l_discount)) as revenue,\n" + " o.o_orderdate,\n" + " count(distinct o.o_shippriority)\n" diff --git a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java index 99fbc7d09..2c90f133d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java +++ b/isthmus/src/test/java/io/substrait/isthmus/utils/UserTypeFactory.java @@ -13,11 +13,11 @@ public class UserTypeFactory { private final InnerType N; private final InnerType R; - private final String uri; + private final String urn; private final String name; - public UserTypeFactory(String uri, String name) { - this.uri = uri; + public UserTypeFactory(String urn, String name) { + this.urn = urn; this.name = name; this.N = new InnerType(true, name); this.R = new InnerType(false, name); @@ -32,7 +32,7 @@ public RelDataType createCalcite(boolean nullable) { } public Type createSubstrait(boolean nullable) { - return TypeCreator.of(nullable).userDefined(uri, name); + return TypeCreator.of(nullable).userDefined(urn, name); } public boolean isTypeFromFactory(RelDataType type) { diff --git a/isthmus/src/test/resources/extensions/functions_custom.yaml b/isthmus/src/test/resources/extensions/functions_custom.yaml index 9fb8b010a..03160f723 100644 --- a/isthmus/src/test/resources/extensions/functions_custom.yaml +++ b/isthmus/src/test/resources/extensions/functions_custom.yaml @@ -1,5 +1,6 @@ %YAML 1.2 --- +urn: extension:substrait:functions_custom types: - name: "a_type" - name: "b_type" diff --git a/spark/src/main/resources/spark.yml b/spark/src/main/resources/spark.yml index fb33385a1..48281ea03 100644 --- a/spark/src/main/resources/spark.yml +++ b/spark/src/main/resources/spark.yml @@ -14,6 +14,7 @@ # limitations under the License. 
%YAML 1.2 --- +urn: extension:substrait:spark scalar_functions: - name: add description: >- diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index c470c7a42..367365430 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -26,10 +26,10 @@ import scala.collection.JavaConverters import scala.collection.JavaConverters.asScalaBufferConverter object SparkExtension { - final val uri = "/spark.yml" + final val file = "/spark.yml" private val SparkImpls: SimpleExtension.ExtensionCollection = - SimpleExtension.load(Collections.singletonList(uri)) + SimpleExtension.load(file, getClass.getResourceAsStream(file)) private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = SimpleExtension.loadDefaults() diff --git a/substrait b/substrait index 793c64ba2..4c3531872 160000 --- a/substrait +++ b/substrait @@ -1 +1 @@ -Subproject commit 793c64ba26e337c22f5e91b658be58b1eea7efd3 +Subproject commit 4c35318727c36d6e49779c06daf9f4ced722fe43