Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ dependencies {
testImplementation(platform("org.junit:junit-bom:${JUNIT_VERSION}"))
testImplementation("org.junit.jupiter:junit-jupiter")
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
testRuntimeOnly("org.jetbrains.kotlin:kotlin-stdlib:${properties.get("kotlin.version")}")
api("com.google.protobuf:protobuf-java:${PROTOBUF_VERSION}")
implementation("com.fasterxml.jackson.core:jackson-databind:${JACKSON_VERSION}")
implementation("com.fasterxml.jackson.core:jackson-annotations:${JACKSON_VERSION}")
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ public Expression.WindowFunctionInvocation windowFn(
// Types

public Type.UserDefined userDefinedType(String namespace, String typeName) {
return Type.UserDefined.builder().uri(namespace).name(typeName).nullable(false).build();
return Type.UserDefined.builder().urn(namespace).name(typeName).nullable(false).build();
}

// Misc
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -666,13 +666,13 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
abstract class UserDefinedLiteral implements Literal {
public abstract ByteString value();

public abstract String uri();
public abstract String urn();

public abstract String name();

@Override
public Type getType() {
return Type.withNullability(nullable()).userDefined(uri(), name());
return Type.withNullability(nullable()).userDefined(urn(), name());
}

public static ImmutableExpression.UserDefinedLiteral.Builder builder() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ public static Expression.StructLiteral struct(
}

public static Expression.UserDefinedLiteral userDefinedLiteral(
boolean nullable, String uri, String name, Any value) {
boolean nullable, String urn, String name, Any value) {
return Expression.UserDefinedLiteral.builder()
.nullable(nullable)
.uri(uri)
.urn(urn)
.name(name)
.value(value.toByteString())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ public Expression visit(
public Expression visit(
io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) {
int typeReference =
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name()));
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
return lit(
bldr -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
SimpleExtension.Type type =
lookup.getType(userDefinedLiteral.getTypeReference(), extensions);
return ExpressionCreator.userDefinedLiteral(
literal.getNullable(), type.uri(), type.name(), userDefinedLiteral.getValue());
literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue());
}
default:
throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
// fill in simple extension information through a discovery in the current proto-extended
// expression
ExtensionLookup functionLookup =
ImmutableExtensionLookup.builder().from(extendedExpression).build();
ImmutableExtensionLookup.builder().from(extendedExpression, this.extensionCollection.uriUrnMap()).build();

NamedStruct baseSchemaProto = extendedExpression.getBaseSchema();

Expand Down
60 changes: 60 additions & 0 deletions core/src/main/java/io/substrait/extension/BidiMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.substrait.extension;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/** We don't depend on guava... */
class BidiMap<T1, T2> {
private final Map<T1, T2> forwardMap;
private final Map<T2, T1> reverseMap;

public BidiMap(Map<T1, T2> forwardMap) {
this.forwardMap = forwardMap;
this.reverseMap = new HashMap<>();
}

public BidiMap() {
this.forwardMap = new HashMap<>();
this.reverseMap = new HashMap<>();
}

public T2 get(T1 t1) {
return forwardMap.get(t1);
}

public T1 reverseGet(T2 t2) {
return reverseMap.get(t2);
}

public void put(T1 t1, T2 t2) {
// Check for conflicting mappings (different values for same key)
T2 existingForward = forwardMap.get(t1);
T1 existingReverse = reverseMap.get(t2);

if (existingForward != null && !existingForward.equals(t2)) {
throw new IllegalArgumentException("Key already exists in map with different value");
}
if (existingReverse != null && !existingReverse.equals(t1)) {
throw new IllegalArgumentException("Key already exists in map with different value");
}

// Allow identical mappings, only add if not already present
forwardMap.put(t1, t2);
reverseMap.put(t2, t1);
}

public void merge(BidiMap<T1, T2> other) {
for (Map.Entry<T1, T2> entry : other.forwardEntrySet()) {
put(entry.getKey(), entry.getValue());
}
}

public Set<Map.Entry<T1, T2>> forwardEntrySet() {
return forwardMap.entrySet();
}

public Set<Map.Entry<T2, T1>> reverseEntrySet() {
return reverseMap.entrySet();
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package io.substrait.extension;

public class DefaultExtensionCatalog {
public static final String FUNCTIONS_AGGREGATE_APPROX = "/functions_aggregate_approx.yaml";
public static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
public static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
public static final String FUNCTIONS_ARITHMETIC_DECIMAL = "/functions_arithmetic_decimal.yaml";
public static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml";
public static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml";
public static final String FUNCTIONS_DATETIME = "/functions_datetime.yaml";
public static final String FUNCTIONS_GEOMETRY = "/functions_geometry.yaml";
public static final String FUNCTIONS_LOGARITHMIC = "/functions_logarithmic.yaml";
public static final String FUNCTIONS_ROUNDING = "/functions_rounding.yaml";
public static final String FUNCTIONS_ROUNDING_DECIMAL = "/functions_rounding_decimal.yaml";
public static final String FUNCTIONS_SET = "/functions_set.yaml";
public static final String FUNCTIONS_STRING = "/functions_string.yaml";
public static final String FUNCTIONS_AGGREGATE_APPROX = "extension:io.substrait:functions_aggregate_approx";
public static final String FUNCTIONS_AGGREGATE_GENERIC = "extension:io.substrait:functions_aggregate_generic";
public static final String FUNCTIONS_ARITHMETIC = "extension:io.substrait:functions_arithmetic";
public static final String FUNCTIONS_ARITHMETIC_DECIMAL = "extension:io.substrait:functions_arithmetic_decimal";
public static final String FUNCTIONS_BOOLEAN = "extension:io.substrait:functions_boolean";
public static final String FUNCTIONS_COMPARISON = "extension:io.substrait:functions_comparison";
public static final String FUNCTIONS_DATETIME = "extension:io.substrait:functions_datetime";
public static final String FUNCTIONS_GEOMETRY = "extension:io.substrait:functions_geometry";
public static final String FUNCTIONS_LOGARITHMIC = "extension:io.substrait:functions_logarithmic";
public static final String FUNCTIONS_ROUNDING = "extension:io.substrait:functions_rounding";
public static final String FUNCTIONS_ROUNDING_DECIMAL = "extension:io.substrait:functions_rounding_decimal";
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";
}
74 changes: 25 additions & 49 deletions core/src/main/java/io/substrait/extension/ExtensionCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import io.substrait.proto.ExtendedExpression;
import io.substrait.proto.Plan;
import io.substrait.proto.SimpleExtensionDeclaration;
import io.substrait.proto.SimpleExtensionURI;
import io.substrait.proto.SimpleExtensionURN;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -52,96 +52,72 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) {
public void addExtensionsToPlan(Plan.Builder builder) {
SimpleExtensions simpleExtensions = getExtensions();

builder.addAllExtensionUris(simpleExtensions.uris.values());
builder.addAllExtensionUrns(simpleExtensions.urns.values());
builder.addAllExtensions(simpleExtensions.extensionList);
}

public void addExtensionsToExtendedExpression(ExtendedExpression.Builder builder) {
SimpleExtensions simpleExtensions = getExtensions();

builder.addAllExtensionUris(simpleExtensions.uris.values());
builder.addAllExtensionUrns(simpleExtensions.urns.values());
builder.addAllExtensions(simpleExtensions.extensionList);
}

private SimpleExtensions getExtensions() {
AtomicInteger uriPos = new AtomicInteger(1);
HashMap<String, SimpleExtensionURI> uris = new HashMap<>();
AtomicInteger urnPos = new AtomicInteger(1);
HashMap<String, SimpleExtensionURN> urns = new HashMap<>();

ArrayList<SimpleExtensionDeclaration> extensionList = new ArrayList<>();
for (Map.Entry<Integer, SimpleExtension.FunctionAnchor> e : funcMap.forwardMap.entrySet()) {
SimpleExtensionURI uri =
uris.computeIfAbsent(
e.getValue().namespace(),
for (Map.Entry<Integer, SimpleExtension.FunctionAnchor> e : funcMap.forwardEntrySet()) {
SimpleExtensionURN urn =
urns.computeIfAbsent(
e.getValue().urn(),
k ->
SimpleExtensionURI.newBuilder()
.setExtensionUriAnchor(uriPos.getAndIncrement())
.setUri(k)
SimpleExtensionURN.newBuilder()
.setExtensionUrnAnchor(urnPos.getAndIncrement())
.setUrn(k)
.build());
SimpleExtensionDeclaration decl =
SimpleExtensionDeclaration.newBuilder()
.setExtensionFunction(
SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
.setFunctionAnchor(e.getKey())
.setName(e.getValue().key())
.setExtensionUriReference(uri.getExtensionUriAnchor()))
.setExtensionUrnReference(urn.getExtensionUrnAnchor()))
.build();
extensionList.add(decl);
}
for (Map.Entry<Integer, SimpleExtension.TypeAnchor> e : typeMap.forwardMap.entrySet()) {
SimpleExtensionURI uri =
uris.computeIfAbsent(
e.getValue().namespace(),
for (Map.Entry<Integer, SimpleExtension.TypeAnchor> e : typeMap.forwardEntrySet()) {
SimpleExtensionURN urn =
urns.computeIfAbsent(
e.getValue().urn(),
k ->
SimpleExtensionURI.newBuilder()
.setExtensionUriAnchor(uriPos.getAndIncrement())
.setUri(k)
SimpleExtensionURN.newBuilder()
.setExtensionUrnAnchor(urnPos.getAndIncrement())
.setUrn(k)
.build());
SimpleExtensionDeclaration decl =
SimpleExtensionDeclaration.newBuilder()
.setExtensionType(
SimpleExtensionDeclaration.ExtensionType.newBuilder()
.setTypeAnchor(e.getKey())
.setName(e.getValue().key())
.setExtensionUriReference(uri.getExtensionUriAnchor()))
.setExtensionUrnReference(urn.getExtensionUrnAnchor()))
.build();
extensionList.add(decl);
}
return new SimpleExtensions(uris, extensionList);
return new SimpleExtensions(urns, extensionList);
}

private static final class SimpleExtensions {
final HashMap<String, SimpleExtensionURI> uris;
final HashMap<String, SimpleExtensionURN> urns;
final ArrayList<SimpleExtensionDeclaration> extensionList;

SimpleExtensions(
HashMap<String, SimpleExtensionURI> uris,
HashMap<String, SimpleExtensionURN> urns,
ArrayList<SimpleExtensionDeclaration> extensionList) {
this.uris = uris;
this.urns = urns;
this.extensionList = extensionList;
}
}

/** We don't depend on guava... */
private static class BidiMap<T1, T2> {
private final Map<T1, T2> forwardMap;
private final Map<T2, T1> reverseMap;

public BidiMap(Map<T1, T2> forwardMap) {
this.forwardMap = forwardMap;
this.reverseMap = new HashMap<>();
}

public T2 get(T1 t1) {
return forwardMap.get(t1);
}

public T1 reverseGet(T2 t2) {
return reverseMap.get(t2);
}

public void put(T1 t1, T2 t2) {
forwardMap.put(t1, t2);
reverseMap.put(t2, t1);
}
}
}
Loading