diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 4698df93e7..62fcccf614 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -1600,6 +1600,15 @@ def round_sql(self, expression: exp.Round) -> str: decimals = expression.args.get("decimals") truncate = expression.args.get("truncate") + # DuckDB requires the scale (decimals) argument to be an INT + # Some dialects (e.g., Snowflake) allow non-integer scales and cast to an integer internally + if decimals is not None and expression.args.get("casts_non_integer_decimals"): + if isinstance(decimals, exp.Literal): + if not decimals.is_int: + decimals = exp.cast(decimals, exp.DataType.Type.INT) + elif not decimals.is_type(*exp.DataType.INTEGER_TYPES): + decimals = exp.cast(decimals, exp.DataType.Type.INT) + func = "ROUND" if truncate: # BigQuery uses ROUND_HALF_EVEN; Snowflake uses HALF_TO_EVEN diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 4e9663ab0f..033a46a2e4 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -540,6 +540,36 @@ def _build_timestamp_from_parts(args: t.List) -> exp.Func: return exp.TimestampFromParts.from_arg_list(args) +def _build_round(args: t.List) -> exp.Round: + """ + Build Round expression, unwrapping Snowflake's named parameters. + + Maps EXPR => this, SCALE => decimals, ROUNDING_MODE => truncate. + + Note: Snowflake does not support mixing named and positional arguments. + Arguments are either all named or all positional. + """ + kwarg_map = {"EXPR": "this", "SCALE": "decimals", "ROUNDING_MODE": "truncate"} + round_args = {} + positional_keys = ["this", "decimals", "truncate"] + positional_idx = 0 + + for arg in args: + if isinstance(arg, exp.Kwarg): + key = arg.this.name.upper() + round_key = kwarg_map.get(key) + if round_key: + round_args[round_key] = arg.expression + else: + if positional_idx < len(positional_keys): + round_args[positional_keys[positional_idx]] = arg + positional_idx += 1 + + expression = exp.Round(**round_args) + expression.set("casts_non_integer_decimals", True) + return expression + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -707,6 +737,7 @@ class Parser(parser.Parser): "REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpExtractAll), "REPLACE": build_replace_with_optional_replacement, "RLIKE": exp.RegexpLike.from_arg_list, + "ROUND": _build_round, "SHA1_BINARY": exp.SHA1Digest.from_arg_list, "SHA1_HEX": exp.SHA.from_arg_list, "SHA2_BINARY": exp.SHA2Digest.from_arg_list, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8f63d06af0..c6eb5a35ae 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -7655,7 +7655,12 @@ class Radians(Func): # https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16 # tsql third argument function == trunctaion if not 0 class Round(Func): - arg_types = {"this": True, "decimals": False, "truncate": False} + arg_types = { + "this": True, + "decimals": False, + "truncate": False, + "casts_non_integer_decimals": False, + } class RowNumber(Func): diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 0aa247f625..0f3abe44ef 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1247,6 +1247,7 @@ def test_duckdb(self): ) self.validate_identity("SELECT GREATEST(1.0, 2.5, NULL, 3.7)") self.validate_identity("FROM t1, t2 SELECT *", "SELECT * FROM t1, t2") + self.validate_identity("ROUND(2.256, 1)") # TODO: This is incorrect AST, DATE_PART creates a STRUCT of values but it's stored in 'year' arg self.validate_identity("SELECT MAKE_DATE(DATE_PART(['year', 'month', 'day'], TODAY()))") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index adfee79282..abed49b489 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -3670,6 +3670,22 @@ def test_round(self): }, ) + self.validate_all( + "SELECT ROUND(EXPR => 2.25, SCALE => 1) AS value", + write={ + "snowflake": "SELECT ROUND(2.25, 1) AS value", + "duckdb": "SELECT ROUND(2.25, 1) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(SCALE => 1, EXPR => 2.25) AS value", + write={ + "snowflake": "SELECT ROUND(2.25, 1) AS value", + "duckdb": "SELECT ROUND(2.25, 1) AS value", + }, + ) + self.validate_all( "SELECT ROUND(2.25, 1, 'HALF_AWAY_FROM_ZERO') AS value", write={ @@ -3678,10 +3694,58 @@ def test_round(self): }, ) + self.validate_all( + "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_AWAY_FROM_ZERO') AS value", + write={ + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_AWAY_FROM_ZERO') AS value", + "duckdb": "SELECT ROUND(2.25, 1) AS value", + }, + ) + self.validate_all( "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", write={ - "snowflake": """SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value""", + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", + "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(ROUNDING_MODE => 'HALF_TO_EVEN', EXPR => 2.25, SCALE => 1) AS value", + write={ + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", + "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(SCALE => 1, EXPR => 2.25, , ROUNDING_MODE => 'HALF_TO_EVEN') AS value", + write={ + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", + "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_TO_EVEN') AS value", + write={ + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", }, ) + + self.validate_all( + "SELECT ROUND(2.256, 1.8) AS value", + write={ + "snowflake": "SELECT ROUND(2.256, 1.8) AS value", + "duckdb": "SELECT ROUND(2.256, CAST(1.8 AS INT)) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(2.256, CAST(1.8 AS DECIMAL(38, 0))) AS value", + write={ + "snowflake": "SELECT ROUND(2.256, CAST(1.8 AS DECIMAL(38, 0))) AS value", + "duckdb": "SELECT ROUND(2.256, CAST(CAST(1.8 AS DECIMAL(38, 0)) AS INT)) AS value", + }, + )