Skip to content

Commit c3e45a5

Browse files
committed
Add sql scalar function to transform value in a row with a single expression
1 parent 974b759 commit c3e45a5

File tree

6 files changed

+293
-1
lines changed

6 files changed

+293
-1
lines changed

core/trino-main/src/main/java/io/trino/metadata/SystemFunctionBundle.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@
311311
import static io.trino.operator.scalar.Re2JCastToRegexpFunction.castVarcharToRe2JRegexp;
312312
import static io.trino.operator.scalar.RowToJsonCast.ROW_TO_JSON;
313313
import static io.trino.operator.scalar.RowToRowCast.ROW_TO_ROW_CAST;
314+
import static io.trino.operator.scalar.RowTransformFunction.ROW_TRANSFORM_FUNCTION;
314315
import static io.trino.operator.scalar.TryCastFunction.TRY_CAST;
315316
import static io.trino.operator.scalar.ZipFunction.ZIP_FUNCTIONS;
316317
import static io.trino.operator.scalar.ZipWithFunction.ZIP_WITH_FUNCTION;
@@ -577,7 +578,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
577578
.aggregates(DecimalAverageAggregation.class)
578579
.aggregates(DecimalSumAggregation.class)
579580
.function(DECIMAL_MOD_FUNCTION)
580-
.functions(ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION)
581+
.functions(ROW_TRANSFORM_FUNCTION, ARRAY_TRANSFORM_FUNCTION, ARRAY_REDUCE_FUNCTION)
581582
.functions(MAP_FILTER_FUNCTION, new MapTransformKeysFunction(blockTypeOperators), MAP_TRANSFORM_VALUES_FUNCTION)
582583
.function(FORMAT_FUNCTION)
583584
.function(TRY_CAST)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.operator.scalar;
15+
16+
import com.google.common.collect.ImmutableList;
17+
import io.airlift.slice.Slice;
18+
import io.trino.annotation.UsedByGeneratedCode;
19+
import io.trino.metadata.SqlScalarFunction;
20+
import io.trino.spi.StandardErrorCode;
21+
import io.trino.spi.TrinoException;
22+
import io.trino.spi.block.Block;
23+
import io.trino.spi.block.SqlRow;
24+
import io.trino.spi.function.BoundSignature;
25+
import io.trino.spi.function.FunctionMetadata;
26+
import io.trino.spi.function.Signature;
27+
import io.trino.spi.type.RowType;
28+
import io.trino.spi.type.RowType.Field;
29+
import io.trino.spi.type.Type;
30+
import io.trino.spi.type.TypeSignature;
31+
import io.trino.sql.gen.lambda.UnaryFunctionInterface;
32+
33+
import java.lang.invoke.MethodHandle;
34+
import java.util.List;
35+
import java.util.Optional;
36+
37+
import static io.trino.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
38+
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FUNCTION;
39+
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
40+
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
41+
import static io.trino.spi.type.TypeSignature.functionType;
42+
import static io.trino.spi.type.TypeUtils.readNativeValue;
43+
import static io.trino.spi.type.TypeUtils.writeNativeValue;
44+
import static io.trino.spi.type.VarcharType.VARCHAR;
45+
import static io.trino.util.Reflection.methodHandle;
46+
47+
public final class RowTransformFunction
48+
extends SqlScalarFunction
49+
{
50+
public static final RowTransformFunction ROW_TRANSFORM_FUNCTION = new RowTransformFunction();
51+
private static final String ROW_TRANSFORM_NAME = "transform";
52+
private static final MethodHandle METHOD_HANDLE = methodHandle(RowTransformFunction.class, "transform", RowType.class, Type.class, SqlRow.class, Slice.class, Object.class, UnaryFunctionInterface.class);
53+
54+
private RowTransformFunction()
55+
{
56+
super(FunctionMetadata.scalarBuilder(ROW_TRANSFORM_NAME)
57+
.signature(Signature.builder()
58+
.variadicTypeParameter("T", "row")
59+
.typeVariable("V")
60+
.returnType(new TypeSignature("T"))
61+
.argumentType(new TypeSignature("T"))
62+
.argumentType(VARCHAR.getTypeSignature())
63+
.argumentType(new TypeSignature("V"))
64+
.argumentType(functionType(new TypeSignature("V"), new TypeSignature("V")))
65+
.build())
66+
.description("Apply lambda to the value of a field, returning the transformed row")
67+
.build());
68+
}
69+
70+
@Override
71+
protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
72+
{
73+
RowType rowType = (RowType) boundSignature.getArgumentType(0);
74+
Type valueType = boundSignature.getArgumentType(2);
75+
76+
return new ChoicesSpecializedSqlScalarFunction(
77+
boundSignature,
78+
FAIL_ON_NULL,
79+
ImmutableList.of(NEVER_NULL, NEVER_NULL, NEVER_NULL, FUNCTION),
80+
ImmutableList.of(UnaryFunctionInterface.class),
81+
METHOD_HANDLE.asType(
82+
METHOD_HANDLE.type()
83+
.changeParameterType(4, valueType.getJavaType())
84+
).bindTo(rowType).bindTo(valueType),
85+
Optional.empty());
86+
}
87+
88+
@UsedByGeneratedCode
89+
public static SqlRow transform(RowType rowType, Type valueType, SqlRow sqlRow, Slice fieldNameSlice, Object dummyValue, UnaryFunctionInterface function)
90+
{
91+
int fieldIndex = -1;
92+
Field match = null;
93+
String fieldName = fieldNameSlice.toStringUtf8();
94+
List<Field> fields = rowType.getFields();
95+
for (int i = 0; i < fields.size(); i++) {
96+
Field field = fields.get(i);
97+
if (field.getName().orElse("").equals(fieldName)) {
98+
match = field;
99+
fieldIndex = i;
100+
break;
101+
}
102+
}
103+
104+
if (match == null) {
105+
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, String.format("Field with name %s not found in row", fieldName));
106+
}
107+
if (match.getType().getClass() != valueType.getClass()) {
108+
throw new TrinoException(INVALID_FUNCTION_ARGUMENT, String.format("Incompatible function types: field is of type %s but lambda returns %s", match.getType(), valueType));
109+
}
110+
111+
Block[] blocks = new Block[fields.size()];
112+
for (int i = 0; i < fields.size(); i++) {
113+
if (i != fieldIndex) {
114+
blocks[i] = sqlRow.getRawFieldBlock(i).getSingleValueBlock(sqlRow.getRawIndex());
115+
}
116+
else {
117+
Object value = readNativeValue(valueType, sqlRow.getRawFieldBlock(i), sqlRow.getRawIndex());
118+
blocks[i] = writeNativeValue(valueType, function.apply(value));
119+
}
120+
}
121+
return new SqlRow(0, blocks);
122+
}
123+
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.operator.scalar;
15+
16+
import com.google.common.collect.ImmutableList;
17+
import io.trino.spi.type.ArrayType;
18+
import io.trino.sql.query.QueryAssertions;
19+
import org.junit.jupiter.api.AfterAll;
20+
import org.junit.jupiter.api.BeforeAll;
21+
import org.junit.jupiter.api.Test;
22+
import org.junit.jupiter.api.TestInstance;
23+
import org.junit.jupiter.api.parallel.Execution;
24+
25+
import static io.trino.spi.type.ArrayType.arrayType;
26+
import static io.trino.spi.type.IntegerType.INTEGER;
27+
import static io.trino.spi.type.RowType.field;
28+
import static io.trino.spi.type.RowType.rowType;
29+
import static io.trino.spi.type.VarcharType.VARCHAR;
30+
import static org.apache.commons.io.IOUtils.closeQuietly;
31+
import static org.assertj.core.api.Assertions.assertThat;
32+
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
33+
import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;
34+
35+
@TestInstance(PER_CLASS)
36+
@Execution(CONCURRENT)
37+
public class TestRowTransformFunction
38+
{
39+
private QueryAssertions assertions;
40+
41+
@BeforeAll
42+
public void init()
43+
{
44+
assertions = new QueryAssertions();
45+
}
46+
47+
@AfterAll
48+
public void teardown()
49+
{
50+
closeQuietly(assertions);
51+
assertions = null;
52+
}
53+
54+
@Test
55+
public void testInteger()
56+
{
57+
assertThat(assertions.expression("transform(a, 'greeting', 1337, greeting -> greeting * 2)")
58+
.binding("a", "CAST(ROW(2, 3) as ROW(greeting integer, planet integer))"))
59+
.hasType(rowType(field("greeting", INTEGER), field("planet", INTEGER)))
60+
.isEqualTo(ImmutableList.of(4, 3));
61+
}
62+
63+
@Test
64+
public void testVarchar()
65+
{
66+
assertThat(assertions.expression("transform(a, 'greeting', '', greeting -> concat(greeting, ' or goodbye'))")
67+
.binding("a", "CAST(ROW('hello', 'world') as ROW(greeting varchar, planet varchar))"))
68+
.hasType(rowType(field("greeting", VARCHAR), field("planet", VARCHAR)))
69+
.isEqualTo(ImmutableList.of("hello or goodbye", "world"));
70+
}
71+
72+
@Test
73+
public void testIntegerArray()
74+
{
75+
assertThat(assertions.expression("transform(a, 'greeting', ARRAY[0], greeting -> greeting || 2)")
76+
.binding("a", "CAST(ROW(ARRAY[1], 'world') as ROW(greeting array(integer), planet varchar))"))
77+
.hasType(rowType(field("greeting", new ArrayType(INTEGER)), field("planet", VARCHAR)))
78+
.isEqualTo(ImmutableList.of(ImmutableList.of(1, 2), "world"));
79+
}
80+
81+
@Test
82+
public void testVarcharArray()
83+
{
84+
assertThat(assertions.expression("transform(a, 'greeting', ARRAY[''], greeting -> greeting || 'or' || 'goodbye')")
85+
.binding("a", "CAST(ROW(ARRAY['hello'], 'world') AS ROW(greeting array(varchar), planet varchar))"))
86+
.hasType(rowType(field("greeting", new ArrayType(VARCHAR)), field("planet", VARCHAR)))
87+
.isEqualTo(ImmutableList.of(ImmutableList.of("hello", "or", "goodbye"), "world"));
88+
}
89+
90+
@Test
91+
public void testVarcharRowType()
92+
{
93+
assertThat(assertions.expression("transform(a, 'greeting', CAST(ROW('') as ROW(text varchar)), greeting -> transform(greeting, 'text', '', old_text -> concat(old_text, ' or goodbye')))")
94+
.binding("a", "CAST(ROW(ROW('hello'), 'world') as ROW(greeting ROW(text varchar), planet varchar))"))
95+
.hasType(rowType(field("greeting", rowType(field("text", VARCHAR))), field("planet", VARCHAR)))
96+
.isEqualTo(ImmutableList.of(ImmutableList.of("hello or goodbye"), "world"));
97+
}
98+
99+
@Test
100+
public void testVarcharRowTypeArrayType()
101+
{
102+
assertThat(assertions.expression("""
103+
transform(a, data ->
104+
transform(data, 'greeting', '', greeting -> concat(greeting, ' or goodbye')))
105+
""")
106+
.binding("a", """
107+
ARRAY[CAST(ROW('hello', 'world') as ROW(greeting varchar, planet varchar)),
108+
CAST(ROW('hi', 'mars') as ROW(greeting varchar, planet varchar)),
109+
CAST(ROW('hey', 'jupiter') as ROW(greeting varchar, planet varchar))]
110+
"""))
111+
.hasType(arrayType(rowType(field("greeting", VARCHAR), field("planet", VARCHAR))))
112+
.isEqualTo(ImmutableList.of(
113+
ImmutableList.of("hello or goodbye", "world"),
114+
ImmutableList.of("hi or goodbye", "mars"),
115+
ImmutableList.of("hey or goodbye", "jupiter")));
116+
}
117+
}

core/trino-spi/src/main/java/io/trino/spi/type/ArrayType.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ public TypeOperatorDeclaration getTypeOperatorDeclaration(TypeOperators typeOper
117117
return operatorDeclaration;
118118
}
119119

120+
public static ArrayType arrayType(Type elementType)
121+
{
122+
return new ArrayType(elementType);
123+
}
124+
120125
private synchronized void generateTypeOperators(TypeOperators typeOperators)
121126
{
122127
if (operatorDeclaration != null) {

docs/src/main/sphinx/functions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Map <functions/map>
5353
Math <functions/math>
5454
Quantile digest <functions/qdigest>
5555
Regular expression <functions/regexp>
56+
Row <functions/row>
5657
Session <functions/session>
5758
Set Digest <functions/setdigest>
5859
String <functions/string>
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Row functions
2+
3+
Row functions use the [ROW type](row-type).
4+
Create a row by explicitly casting the field names and types:
5+
6+
```sql
7+
SELECT CAST(ROW('hello', 'world') AS ROW(greeting varchar, planet varchar))
8+
-- ROW('hello', 'world')
9+
```
10+
11+
Fields can be accessed via the dot fieldname:
12+
```sql
13+
SELECT CAST(ROW('hello', 'world') AS ROW(greeting varchar, planet varchar)).greeting
14+
-- 'hello'
15+
```
16+
17+
## Row functions
18+
19+
:::{function} transform(T, varchar, V, function(V, V)) -> T
20+
With this function, a field in the row can be updated with the lambda function.
21+
The returned value is the original value with the updated field. The second
22+
argument is the name of the field to update. The third argument, `V` is a dummy
23+
so the type of the function can be resolved. It can be any value, as long as the
24+
type of the value is equal to the type of the argument and return type of the
25+
lambda function.
26+
27+
```
28+
SELECT transform(
29+
CAST(ROW('hello', 'world') AS ROW(greeting varchar, planet varchar))",
30+
'greeting',
31+
'',
32+
greeting -> concat(greeting, ' or goodbye'));
33+
-- ROW('hello or goodbye', 'world')
34+
```
35+
36+
The transform can be used to reach fields in nested rows or fields in rows
37+
in arrays:
38+
```
39+
SELECT transform(ARRAY[
40+
CAST(ROW('hello', 'world') AS ROW(greeting varchar, planet varchar)),
41+
CAST(ROW('hi', 'mars') AS ROW(greeting varchar, planet varchar))],
42+
data -> transform(data, 'greeting', '', greeting -> concat(greeting, ' or goodbye')));
43+
-- ARRAY[ROW('hello or goodbye', 'world'), ROW('hi or goodbye', 'mars')]
44+
```
45+
:::

0 commit comments

Comments
 (0)