Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/contributor-guide/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -587,9 +587,9 @@

- [ ] parse_url
- [ ] try_parse_url
- [ ] try_url_decode
- [ ] url_decode
- [ ] url_encode
- [x] try_url_decode
- [x] url_decode
- [x] url_encode

### window_funcs

Expand Down
6 changes: 6 additions & 0 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ use datafusion_spark::function::string::char::CharFunc;
use datafusion_spark::function::string::concat::SparkConcat;
use datafusion_spark::function::string::luhn_check::SparkLuhnCheck;
use datafusion_spark::function::string::space::SparkSpace;
use datafusion_spark::function::url::try_url_decode::TryUrlDecode as SparkTryUrlDecode;
use datafusion_spark::function::url::url_decode::UrlDecode as SparkUrlDecode;
use datafusion_spark::function::url::url_encode::UrlEncode as SparkUrlEncode;
use futures::poll;
use futures::stream::StreamExt;
use futures::FutureExt;
Expand Down Expand Up @@ -567,6 +570,9 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkArrayContains::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBin::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkStrToMap::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkUrlDecode::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkUrlEncode::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkTryUrlDecode::default()));
}

/// Prepares arrow arrays for output.
Expand Down
34 changes: 32 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/statics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils}
import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils, Literal, UrlCodec}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}

object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {

Expand All @@ -35,7 +36,9 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
Map(
("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction(
"read_side_padding"),
("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometScalarFunction("luhn_check"))
("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometScalarFunction("luhn_check"),
("encode", UrlCodec.getClass) -> CometUrlEncodeStaticInvoke,
("decode", UrlCodec.getClass) -> CometUrlDecodeStaticInvoke)
Comment on lines +40 to +41
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
("encode", UrlCodec.getClass) -> CometUrlEncodeStaticInvoke,
("decode", UrlCodec.getClass) -> CometUrlDecodeStaticInvoke)
("url_encode", UrlCodec.getClass) -> CometUrlEncodeStaticInvoke,
("url_decode", UrlCodec.getClass) -> CometUrlDecodeStaticInvoke)

?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the Spark source for UrlEncode / UrlDecode, the rewrite uses:

StaticInvoke(UrlCodec.getClass, dataType, "encode", Seq(child), ...)
StaticInvoke(UrlCodec.getClass, dataType, "decode", Seq(child, Literal(failOnError)), ...)

The third argument is the JVM method name on UrlCodec, which is literally "encode" and "decode". The user-facing SQL names url_encode / url_decode come from prettyName, not functionName.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@comphead These correspond to the names in Spark code - UrlCodec.encode(...) and UrlCodec.decode(...).


override def convert(
expr: StaticInvoke,
Expand All @@ -53,3 +56,30 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] {
}
}
}

object CometUrlEncodeStaticInvoke extends CometExpressionSerde[StaticInvoke] {

  /**
   * Serializes a `StaticInvoke(UrlCodec, "encode", ...)` expression to the native
   * `url_encode` scalar function.
   *
   * Spark rewrites the user-facing `url_encode` SQL function into a `StaticInvoke`
   * of `UrlCodec.encode(child)`; the single relevant argument is the URL string.
   */
  override def convert(
      expr: StaticInvoke,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // The string to encode is the first child of the StaticInvoke.
    val urlArg = exprToProtoInternal(expr.children.head, inputs, binding)
    val encodeExpr = scalarFunctionExprToProto("url_encode", urlArg)
    optExprWithInfo(encodeExpr, expr, expr.children: _*)
  }
}

object CometUrlDecodeStaticInvoke extends CometExpressionSerde[StaticInvoke] {

  /**
   * Serializes a `StaticInvoke(UrlCodec, "decode", ...)` expression to a native
   * scalar function.
   *
   * Spark's rewrite passes `failOnError` as a second literal child:
   * `StaticInvoke(UrlCodec.getClass, dataType, "decode", Seq(child, Literal(failOnError)), ...)`.
   * A literal `false` selects the non-throwing `try_url_decode`; any other shape
   * (including a literal `true`) selects the strict `url_decode`.
   */
  override def convert(
      expr: StaticInvoke,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // Default to the strict (erroring) variant unless failOnError is literally false.
    val strict = expr.children match {
      case Seq(_, Literal(false, _)) => false
      case _ => true
    }
    val funcName = if (strict) "url_decode" else "try_url_decode"
    // Only the first child (the encoded string) is forwarded to the native side.
    val urlArg = exprToProtoInternal(expr.children.head, inputs, binding)
    optExprWithInfo(scalarFunctionExprToProto(funcName, urlArg), expr, expr.children: _*)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- try_url_decode is Spark 4.0+. It rewrites to UrlDecode(_, failOnError=false),
-- which becomes StaticInvoke(UrlCodec, "decode", [child, Literal(false)], ...).
-- CometUrlDecodeStaticInvoke detects failOnError=false and emits try_url_decode.

-- MinSparkVersion: 4.0

statement
CREATE TABLE test_try_decode(s string) USING parquet

-- Inputs cover: normal percent-encoding, '+' as space, mixed-case hex,
-- multibyte UTF-8, empty string, NULL, and (last row) a malformed sequence
-- that try_url_decode must turn into NULL rather than an error.
statement
INSERT INTO test_try_decode VALUES
('https%3A%2F%2Fspark.apache.org'),
('hello+world'),
('a%2Bb%3Dc%26d%3De'),
('caf%C3%A9'),
(''),
(NULL),
('no+encoding+needed'),
('%21%40%23%24%25%5E%26%2A%28%29%5F%2B'),
('%2a%2b%2c'),
('http%3A%2F%2spark.apache.org')

-- Column input: exercises the table-scan path (not constant folding).
query
SELECT try_url_decode(s) FROM test_try_decode

-- literal arguments
query
SELECT try_url_decode('https%3A%2F%2Fspark.apache.org')

query
SELECT try_url_decode('hello+world')

query
SELECT try_url_decode('')

query
SELECT try_url_decode(NULL)

-- roundtrip: encode then decode
query
SELECT try_url_decode(url_encode('hello world & goodbye'))

-- multibyte UTF-8
query
SELECT try_url_decode('%E6%97%A5%E6%9C%AC%E8%AA%9E%E3%83%86%E3%82%B9%E3%83%88')

-- lowercase hex (RFC 3986 says hex digits are case-insensitive)
query
SELECT try_url_decode('%2a%2b%2c')

-- malformed percent-encoding: try_url_decode returns NULL instead of erroring
query
SELECT try_url_decode('http%3A%2F%2spark.apache.org')
65 changes: 65 additions & 0 deletions spark/src/test/resources/sql-tests/expressions/url/url_decode.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- url_decode function
statement
CREATE TABLE test_decode(s string) USING parquet

-- Inputs cover: normal percent-encoding, '+' as space, lowercase hex,
-- multibyte UTF-8, empty string, and NULL. All rows here are valid; the
-- malformed-input case is tested separately below with expect_error.
statement
INSERT INTO test_decode VALUES
('https%3A%2F%2Fspark.apache.org'),
('hello+world'),
('a%2Bb%3Dc%26d%3De'),
('caf%C3%A9'),
(''),
(NULL),
('no+encoding+needed'),
('%21%40%23%24%25%5E%26%2A%28%29%5F%2B'),
('%2a%2b%2c')

-- Column input: exercises the table-scan path (not constant folding).
query
SELECT url_decode(s) FROM test_decode

-- literal arguments
query
SELECT url_decode('https%3A%2F%2Fspark.apache.org')

query
SELECT url_decode('hello+world')

query
SELECT url_decode('')

query
SELECT url_decode(NULL)

-- roundtrip: encode then decode
query
SELECT url_decode(url_encode('hello world & goodbye'))

-- multibyte UTF-8
query
SELECT url_decode('%E6%97%A5%E6%9C%AC%E8%AA%9E%E3%83%86%E3%82%B9%E3%83%88')

-- lowercase hex (RFC 3986 says hex digits are case-insensitive)
query
SELECT url_decode('%2a%2b%2c')

-- malformed percent-encoding: both Spark and Comet must error and the bad
-- sequence must appear in the error message
query expect_error(%2s)
SELECT url_decode('http%3A%2F%2spark.apache.org')
53 changes: 53 additions & 0 deletions spark/src/test/resources/sql-tests/expressions/url/url_encode.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- url_encode function
statement
CREATE TABLE test_encode(s string) USING parquet

-- Inputs cover: reserved URL characters (/:?&=+), spaces, empty string,
-- and NULL.
statement
INSERT INTO test_encode VALUES
('https://spark.apache.org'),
('hello world'),
('a+b=c&d=e'),
(''),
(NULL),
('foo bar/baz?x=1&y=2')

-- Column input: exercises the table-scan path (not constant folding).
query
SELECT url_encode(s) FROM test_encode

-- literal arguments
query
SELECT url_encode('https://spark.apache.org')

query
SELECT url_encode('hello world')

query
SELECT url_encode('')

query
SELECT url_encode(NULL)

-- special characters
query
SELECT url_encode('a b+c&d=e/f')

-- multibyte UTF-8
query
SELECT url_encode('日本語テスト')
Loading