Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,6 +1600,15 @@ def round_sql(self, expression: exp.Round) -> str:
decimals = expression.args.get("decimals")
truncate = expression.args.get("truncate")

# DuckDB requires the scale (decimals) argument to be an INT
# Some dialects (e.g., Snowflake) allow non-integer scales and cast to an integer internally
if decimals is not None and expression.args.get("casts_non_integer_decimals"):
if isinstance(decimals, exp.Literal):
if not decimals.is_int:
decimals = exp.cast(decimals, exp.DataType.Type.INT)
elif not decimals.is_type(*exp.DataType.INTEGER_TYPES):
decimals = exp.cast(decimals, exp.DataType.Type.INT)

func = "ROUND"
if truncate:
# BigQuery uses ROUND_HALF_EVEN; Snowflake uses HALF_TO_EVEN
Expand Down
31 changes: 31 additions & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,36 @@ def _build_timestamp_from_parts(args: t.List) -> exp.Func:
return exp.TimestampFromParts.from_arg_list(args)


def _build_round(args: t.List) -> exp.Round:
"""
Build Round expression, unwrapping Snowflake's named parameters.

Maps EXPR => this, SCALE => decimals, ROUNDING_MODE => truncate.

Note: Snowflake does not support mixing named and positional arguments.
Arguments are either all named or all positional.
"""
kwarg_map = {"EXPR": "this", "SCALE": "decimals", "ROUNDING_MODE": "truncate"}
round_args = {}
positional_keys = ["this", "decimals", "truncate"]
positional_idx = 0

for arg in args:
if isinstance(arg, exp.Kwarg):
key = arg.this.name.upper()
round_key = kwarg_map.get(key)
if round_key:
round_args[round_key] = arg.expression
else:
if positional_idx < len(positional_keys):
round_args[positional_keys[positional_idx]] = arg
positional_idx += 1

expression = exp.Round(**round_args)
expression.set("casts_non_integer_decimals", True)
return expression


class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
Expand Down Expand Up @@ -707,6 +737,7 @@ class Parser(parser.Parser):
"REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpExtractAll),
"REPLACE": build_replace_with_optional_replacement,
"RLIKE": exp.RegexpLike.from_arg_list,
"ROUND": _build_round,
"SHA1_BINARY": exp.SHA1Digest.from_arg_list,
"SHA1_HEX": exp.SHA.from_arg_list,
"SHA2_BINARY": exp.SHA2Digest.from_arg_list,
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7655,7 +7655,12 @@ class Radians(Func):
# https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16
# tsql third argument function == trunctaion if not 0
class Round(Func):
arg_types = {"this": True, "decimals": False, "truncate": False}
arg_types = {
"this": True,
"decimals": False,
"truncate": False,
"casts_non_integer_decimals": False,
}


class RowNumber(Func):
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,7 @@ def test_duckdb(self):
)
self.validate_identity("SELECT GREATEST(1.0, 2.5, NULL, 3.7)")
self.validate_identity("FROM t1, t2 SELECT *", "SELECT * FROM t1, t2")
self.validate_identity("ROUND(2.256, 1)")

# TODO: This is incorrect AST, DATE_PART creates a STRUCT of values but it's stored in 'year' arg
self.validate_identity("SELECT MAKE_DATE(DATE_PART(['year', 'month', 'day'], TODAY()))")
Expand Down
66 changes: 65 additions & 1 deletion tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3670,6 +3670,22 @@ def test_round(self):
},
)

self.validate_all(
"SELECT ROUND(EXPR => 2.25, SCALE => 1) AS value",
write={
"snowflake": "SELECT ROUND(2.25, 1) AS value",
"duckdb": "SELECT ROUND(2.25, 1) AS value",
},
)

self.validate_all(
"SELECT ROUND(SCALE => 1, EXPR => 2.25) AS value",
write={
"snowflake": "SELECT ROUND(2.25, 1) AS value",
"duckdb": "SELECT ROUND(2.25, 1) AS value",
},
)

self.validate_all(
"SELECT ROUND(2.25, 1, 'HALF_AWAY_FROM_ZERO') AS value",
write={
Expand All @@ -3678,10 +3694,58 @@ def test_round(self):
},
)

self.validate_all(
"SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_AWAY_FROM_ZERO') AS value",
write={
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_AWAY_FROM_ZERO') AS value",
"duckdb": "SELECT ROUND(2.25, 1) AS value",
},
)

self.validate_all(
"SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
write={
"snowflake": """SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value""",
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
"duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value",
},
)

self.validate_all(
"SELECT ROUND(ROUNDING_MODE => 'HALF_TO_EVEN', EXPR => 2.25, SCALE => 1) AS value",
write={
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
"duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value",
},
)

self.validate_all(
"SELECT ROUND(SCALE => 1, EXPR => 2.25, , ROUNDING_MODE => 'HALF_TO_EVEN') AS value",
write={
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
"duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value",
},
)

self.validate_all(
"SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_TO_EVEN') AS value",
write={
"snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value",
"duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value",
},
)

self.validate_all(
"SELECT ROUND(2.256, 1.8) AS value",
write={
"snowflake": "SELECT ROUND(2.256, 1.8) AS value",
"duckdb": "SELECT ROUND(2.256, CAST(1.8 AS INT)) AS value",
},
)

self.validate_all(
"SELECT ROUND(2.256, CAST(1.8 AS DECIMAL(38, 0))) AS value",
write={
"snowflake": "SELECT ROUND(2.256, CAST(1.8 AS DECIMAL(38, 0))) AS value",
"duckdb": "SELECT ROUND(2.256, CAST(CAST(1.8 AS DECIMAL(38, 0)) AS INT)) AS value",
},
)