def _to_binary_sql(self: DuckDB.Generator, expression: exp.ToBinary) -> str:
    """
    Generate DuckDB SQL for TO_BINARY(value[, format]).

    When the expression is typed as BINARY, the call is rewritten according to
    the (case-insensitive) format name:
    - 'HEX' (also used when no format is given):
        TO_BINARY('48454C50', 'HEX') -> UNHEX('48454C50')
    - 'UTF-8' / 'UTF8':
        TO_BINARY('TEST', 'UTF-8') -> ENCODE('TEST')
    - 'BASE64':
        TO_BINARY('SEVMUA==', 'BASE64') -> FROM_BASE64('SEVMUA==')

    In every other case (the expression is not typed as BINARY, or the format
    name is not one of the above) the call is emitted as TO_BINARY with its
    original arguments. An unrecognized format is additionally reported via
    `self.unsupported` instead of being silently treated as HEX.

    Args:
        expression: the TO_BINARY node; `this` is the value and the optional
            "format" arg is the format name.

    Returns:
        The generated SQL string.
    """
    value = expression.this
    format_arg = expression.args.get("format")

    if expression.is_type(exp.DataType.Type.BINARY):
        # HEX is the default when the format argument is omitted.
        fmt = format_arg.name.upper() if format_arg else "HEX"

        if fmt in ("UTF-8", "UTF8"):
            return self.func("ENCODE", value)
        if fmt == "BASE64":
            return self.func("FROM_BASE64", value)
        if fmt == "HEX":
            return self.func("UNHEX", value)

        # Unknown format: do not guess a decoding; flag it and fall through.
        self.unsupported(f"Unsupported TO_BINARY format: {fmt}")

    # Fallback: the return type is not known to be BINARY (e.g. the tree was not
    # type-annotated). Keep the format argument so the call's meaning is not
    # silently altered. This needs to be revisited to support transpilation
    # from dialects other than Snowflake.
    fallback_args = [value, format_arg] if format_arg else [value]
    return self.func("TO_BINARY", *fallback_args)
def test_to_binary(self): self.validate_all( - "TO_BINARY('test')", + "TO_BINARY('1C')", read={ - "": "TO_BINARY('test')", - "snowflake": "TO_BINARY('test')", - "starrocks": "TO_BINARY('test')", - "duckdb": "TO_BINARY('test')", - "spark": "TO_BINARY('test')", - "databricks": "TO_BINARY('test')", + "": "TO_BINARY('1C')", + "snowflake": "TO_BINARY('1C')", + "starrocks": "TO_BINARY('1C')", + "duckdb": "TO_BINARY('1C')", + "spark": "TO_BINARY('1C')", + "databricks": "TO_BINARY('1C')", }, write={ - "snowflake": "TO_BINARY('test')", - "starrocks": "TO_BINARY('test')", - "duckdb": "TO_BINARY('test')", - "spark": "TO_BINARY('test')", - "databricks": "TO_BINARY('test')", + "snowflake": "TO_BINARY('1C')", + "starrocks": "TO_BINARY('1C')", + "duckdb": "TO_BINARY('1C')", + "spark": "TO_BINARY('1C')", + "databricks": "TO_BINARY('1C')", }, ) self.validate_all( - "TO_BINARY('test', 'HEX')", + "TO_BINARY('1C', 'HEX')", read={ - "": "TO_BINARY('test', 'HEX')", - "snowflake": "TO_BINARY('test', 'HEX')", - "starrocks": "TO_BINARY('test', 'HEX')", - "spark": "TO_BINARY('test', 'HEX')", - "databricks": "TO_BINARY('test', 'HEX')", + "": "TO_BINARY('1C', 'HEX')", + "snowflake": "TO_BINARY('1C', 'HEX')", + "starrocks": "TO_BINARY('1C', 'HEX')", + "spark": "TO_BINARY('1C', 'HEX')", + "databricks": "TO_BINARY('1C', 'HEX')", }, write={ - "snowflake": "TO_BINARY('test', 'HEX')", - "starrocks": "TO_BINARY('test', 'HEX')", - "spark": "TO_BINARY('test', 'HEX')", - "databricks": "TO_BINARY('test', 'HEX')", + "snowflake": "TO_BINARY('1C', 'HEX')", + "starrocks": "TO_BINARY('1C', 'HEX')", + "spark": "TO_BINARY('1C', 'HEX')", + "databricks": "TO_BINARY('1C', 'HEX')", }, ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 6e0836d87c..b3d14e1dcf 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -380,6 +380,7 @@ def test_duckdb(self): self.validate_identity("SELECT LIST_TRANSFORM([5, NULL, 6], LAMBDA x : COALESCE(x, 0) + 1)") 
self.validate_identity("SELECT LIST_TRANSFORM(nbr, LAMBDA x : x + 1) FROM article AS a") self.validate_identity("SELECT * FROM my_ducklake.demo AT (VERSION => 2)") + self.validate_identity("SELECT TO_BINARY('test')") self.validate_identity("SELECT UUIDV7()") self.validate_identity("SELECT TRY(LOG(0))") self.validate_identity("x::timestamp", "CAST(x AS TIMESTAMP)") @@ -707,6 +708,23 @@ def test_duckdb(self): "duckdb": "CREATE TABLE IF NOT EXISTS t (cola INT, colb TEXT)", }, ) + + expr = self.parse_one("TO_BINARY('48454C50', 'HEX')", dialect="snowflake") + annotated = annotate_types(expr, dialect="snowflake") + self.assertEqual(annotated.sql("duckdb"), "UNHEX('48454C50')") + + expr = self.parse_one("TO_BINARY('48454C50')", dialect="snowflake") + annotated = annotate_types(expr, dialect="snowflake") + self.assertEqual(annotated.sql("duckdb"), "UNHEX('48454C50')") + + expr = self.parse_one("TO_BINARY('TEST', 'UTF-8')", dialect="snowflake") + annotated = annotate_types(expr, dialect="snowflake") + self.assertEqual(annotated.sql("duckdb"), "ENCODE('TEST')") + + expr = self.parse_one("TO_BINARY('SEVMUA==', 'BASE64')", dialect="snowflake") + annotated = annotate_types(expr, dialect="snowflake") + self.assertEqual(annotated.sql("duckdb"), "FROM_BASE64('SEVMUA==')") + self.validate_all( "[0, 1, 2]", read={ diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 978f14e71f..06447a6902 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -288,6 +288,9 @@ def test_snowflake(self): self.validate_identity("SELECT a, exclude, b FROM xxx") self.validate_identity("SELECT ARRAY_SORT(x, TRUE, FALSE)") self.validate_identity("SELECT BOOLXOR_AGG(col) FROM tbl") + self.validate_identity("SELECT TO_BINARY('C2')") + self.validate_identity("SELECT TO_BINARY('C2', 'HEX')") + self.validate_identity("SELECT TO_BINARY('café', 'UTF-8')") self.validate_identity( "SELECT PERCENTILE_DISC(0.9) WITHIN GROUP (ORDER BY col) OVER 
(PARTITION BY category)" )