Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,45 @@ def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str:
return result


def _upper_string_literal(arg: exp.Expression | None, default: str) -> str | None:
    """
    Return the upper-cased text of a string-literal argument.

    Args:
        arg: The argument node, or None when the argument was not supplied.
        default: Value returned when `arg` is missing.

    Returns:
        `default` if `arg` is missing, the upper-cased literal text if `arg` is a
        string literal, and None for any other node (column, parameter, number, ...)
        because its value is not known at transpile time.
    """
    if not arg:
        return default
    if not arg.is_string:
        return None
    return arg.name.upper()


def _to_binary_sql(self: DuckDB.Generator, expression: exp.ToBinary) -> str:
    """
    TO_BINARY(value, format) transpilation if the return_type is BINARY:
    - 'HEX': TO_BINARY('48454C50', 'HEX') → UNHEX('48454C50')
    - 'UTF-8': TO_BINARY('TEST', 'UTF-8') → ENCODE('TEST')
    - 'BASE64': TO_BINARY('SEVMUA==', 'BASE64') → FROM_BASE64('SEVMUA==')

    format can be 'HEX', 'UTF-8' (or 'UTF8') or 'BASE64'; it defaults to 'UTF-8' when absent.
    return_type can be either VARCHAR or BINARY; it defaults to VARCHAR when absent.

    If the return_type is BINARY but the format is not one of the above, or is not a
    string literal, the conversion is reported through `self.unsupported` and
    TO_BINARY(value) is emitted instead of guessing an encoding.

    Args:
        expression: The ToBinary node to generate SQL for.

    Returns:
        The DuckDB SQL for the conversion.
    """
    value = expression.this
    return_type = _upper_string_literal(expression.args.get("return_type"), "VARCHAR")
    fmt = _upper_string_literal(expression.args.get("format"), "UTF-8")

    if return_type != "BINARY":
        # DuckDB's TO_BINARY takes a UTF-8 string and returns a binary string representation
        # (like '0101010...') of type VARCHAR. This is also the fallback for the other formats,
        # which needs to be updated to support transpilation from dialects other than Snowflake.
        return self.func("TO_BINARY", value)

    # DuckDB function that produces the same bytes for each supported input format.
    binary_funcs = {
        "HEX": "UNHEX",
        "BASE64": "FROM_BASE64",
        "UTF-8": "ENCODE",
        "UTF8": "ENCODE",
    }
    func_name = binary_funcs.get(fmt)
    if func_name:
        return self.func(func_name, value)

    # Unknown encodings must not be silently treated as UTF-8: that would yield different bytes.
    if fmt is None:
        self.unsupported("TO_BINARY with a non-literal format is not supported")
    else:
        self.unsupported(f"TO_BINARY format '{fmt}' is not supported")

    return self.func("TO_BINARY", value)


# BigQuery -> DuckDB conversion for the TIME_DIFF function
def _timediff_sql(self: DuckDB.Generator, expression: exp.TimeDiff) -> str:
this = exp.cast(expression.this, exp.DataType.Type.TIME)
Expand Down Expand Up @@ -586,6 +625,11 @@ class Parser(parser.Parser):
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TIME_BUCKET": exp.DateBin.from_arg_list,
"TO_BINARY": lambda args: exp.ToBinary(
this=seq_get(args, 0),
format=exp.Literal.string("UTF-8"),
return_type=exp.Literal.string("VARCHAR"),
),
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
"XOR": binary_from_function(exp.BitwiseXor),
Expand Down Expand Up @@ -931,6 +975,7 @@ class Generator(generator.Generator):
exp.Time: no_time_sql,
exp.TimeDiff: _timediff_sql,
exp.Timestamp: no_timestamp_sql,
exp.ToBinary: _to_binary_sql,
exp.TimestampAdd: date_delta_to_binary_interval_op(),
exp.TimestampDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
Expand Down
6 changes: 6 additions & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,11 @@ class Parser(parser.Parser):
"TRY_TO_TIMESTAMP": _build_datetime(
"TRY_TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP, safe=True
),
"TO_BINARY": lambda args: exp.ToBinary(
this=seq_get(args, 0),
format=seq_get(args, 1) or exp.Literal.string("HEX"),
return_type=exp.Literal.string("BINARY"),
),
"TO_CHAR": build_timetostr_or_tochar,
"TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
"TO_NUMBER": lambda args: exp.ToNumber(
Expand Down Expand Up @@ -1528,6 +1533,7 @@ class Generator(generator.Generator):
exp.SHA2Digest: lambda self, e: self.func(
"SHA2_BINARY", e.this, e.args.get("length") or exp.Literal.number(256)
),
exp.ToBinary: lambda self, e: self.func("TO_BINARY", e.this, e.args.get("format")),
}

SUPPORTED_JSON_PATH_PARTS = {
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6713,7 +6713,7 @@ class ToBase64(Func):


class ToBinary(Func):
    # this: the value to convert (required).
    # format: optional input encoding, e.g. 'HEX', 'UTF-8' or 'BASE64'.
    # return_type: optional literal naming the result type ('BINARY' or 'VARCHAR').
    arg_types = {"this": True, "format": False, "return_type": False}


# https://docs.snowflake.com/en/sql-reference/functions/base64_decode_binary
Expand Down
4 changes: 0 additions & 4 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,16 +623,13 @@ def test_to_binary(self):
"TO_BINARY('test')",
read={
"": "TO_BINARY('test')",
"snowflake": "TO_BINARY('test')",
"starrocks": "TO_BINARY('test')",
"duckdb": "TO_BINARY('test')",
"spark": "TO_BINARY('test')",
"databricks": "TO_BINARY('test')",
},
write={
"snowflake": "TO_BINARY('test')",
"starrocks": "TO_BINARY('test')",
"duckdb": "TO_BINARY('test')",
"spark": "TO_BINARY('test')",
"databricks": "TO_BINARY('test')",
},
Expand All @@ -641,7 +638,6 @@ def test_to_binary(self):
"TO_BINARY('test', 'HEX')",
read={
"": "TO_BINARY('test', 'HEX')",
"snowflake": "TO_BINARY('test', 'HEX')",
"starrocks": "TO_BINARY('test', 'HEX')",
"spark": "TO_BINARY('test', 'HEX')",
"databricks": "TO_BINARY('test', 'HEX')",
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def test_duckdb(self):
self.validate_identity("SELECT LIST_TRANSFORM([5, NULL, 6], LAMBDA x : COALESCE(x, 0) + 1)")
self.validate_identity("SELECT LIST_TRANSFORM(nbr, LAMBDA x : x + 1) FROM article AS a")
self.validate_identity("SELECT * FROM my_ducklake.demo AT (VERSION => 2)")
self.validate_identity("SELECT TO_BINARY('test')")
self.validate_identity("SELECT UUIDV7()")
self.validate_identity("SELECT TRY(LOG(0))")
self.validate_identity("x::timestamp", "CAST(x AS TIMESTAMP)")
Expand Down
23 changes: 23 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ def test_snowflake(self):
self.validate_identity("SELECT a, exclude, b FROM xxx")
self.validate_identity("SELECT ARRAY_SORT(x, TRUE, FALSE)")
self.validate_identity("SELECT BOOLXOR_AGG(col) FROM tbl")
self.validate_identity("SELECT TO_BINARY('CA')", "SELECT TO_BINARY('CA', 'HEX')")
self.validate_identity("SELECT TO_BINARY('CA', 'HEX')")
self.validate_identity(
"SELECT PERCENTILE_DISC(0.9) WITHIN GROUP (ORDER BY col) OVER (PARTITION BY category)"
)
Expand Down Expand Up @@ -719,6 +721,27 @@ def test_snowflake(self):
"redshift": "SELECT GETBIT(11, 3)",
},
)
self.validate_all(
"SELECT TO_BINARY('48454C50', 'HEX')",
write={
"snowflake": "SELECT TO_BINARY('48454C50', 'HEX')",
"duckdb": "SELECT UNHEX('48454C50')",
},
)
self.validate_all(
"SELECT TO_BINARY('TEST', 'UTF-8')",
write={
"snowflake": "SELECT TO_BINARY('TEST', 'UTF-8')",
"duckdb": "SELECT ENCODE('TEST')",
},
)
self.validate_all(
"SELECT TO_BINARY('SEVMUA==', 'BASE64')",
write={
"snowflake": "SELECT TO_BINARY('SEVMUA==', 'BASE64')",
"duckdb": "SELECT FROM_BASE64('SEVMUA==')",
},
)
self.validate_identity(
"SELECT TIMESTAMPNTZFROMPARTS(2013, 4, 5, 12, 00, 00)",
"SELECT TIMESTAMP_FROM_PARTS(2013, 4, 5, 12, 00, 00)",
Expand Down