From b3862d8b4f79ce85357287737796e54c288c5b9b Mon Sep 17 00:00:00 2001 From: Tori Wei <41123940+toriwei@users.noreply.github.com> Date: Fri, 5 Dec 2025 10:52:25 -0800 Subject: [PATCH 1/3] support named arguments and non-integer scale values --- sqlglot/dialects/duckdb.py | 16 +++++++++++++ sqlglot/dialects/snowflake.py | 8 +++++++ sqlglot/expressions.py | 7 +++++- tests/dialects/test_duckdb.py | 2 ++ tests/dialects/test_snowflake.py | 40 ++++++++++++++++++++++++++++++++ 5 files changed, 72 insertions(+), 1 deletion(-) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 4698df93e7..fb2c46828e 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -1600,6 +1600,22 @@ def round_sql(self, expression: exp.Round) -> str: decimals = expression.args.get("decimals") truncate = expression.args.get("truncate") + if isinstance(this, exp.Kwarg): + this = this.expression + if isinstance(decimals, exp.Kwarg): + decimals = decimals.expression + if isinstance(truncate, exp.Kwarg): + truncate = truncate.expression + + # DuckDB requires the scale (decimals) argument to be an INT + # Some dialects (e.g., Snowflake) allow non-integer scales and cast to an integer internally + if decimals is not None and expression.args.get("casts_non_integer_decimals"): + if isinstance(decimals, exp.Literal): + if not decimals.is_int: + decimals = exp.cast(decimals, exp.DataType.Type.INT) + elif not decimals.is_type(*exp.DataType.INTEGER_TYPES): + decimals = exp.cast(decimals, exp.DataType.Type.INT) + func = "ROUND" if truncate: # BigQuery uses ROUND_HALF_EVEN; Snowflake uses HALF_TO_EVEN diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 4e9663ab0f..e14af3edfc 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -540,6 +540,12 @@ def _build_timestamp_from_parts(args: t.List) -> exp.Func: return exp.TimestampFromParts.from_arg_list(args) +def _build_round(args: t.List) -> exp.Round: + expression = exp.Round.from_arg_list(args) + expression.set("casts_non_integer_decimals", True) + return expression + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -554,6 +560,7 @@ class Snowflake(Dialect): ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False TRY_CAST_REQUIRES_STRING = True SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS = True + ROUND_casts_non_integer_decimals = True EXPRESSION_METADATA = EXPRESSION_METADATA.copy() @@ -707,6 +714,7 @@ class Parser(parser.Parser): "REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpExtractAll), "REPLACE": build_replace_with_optional_replacement, "RLIKE": exp.RegexpLike.from_arg_list, + "ROUND": _build_round, "SHA1_BINARY": exp.SHA1Digest.from_arg_list, "SHA1_HEX": exp.SHA.from_arg_list, "SHA2_BINARY": exp.SHA2Digest.from_arg_list, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8f63d06af0..c6eb5a35ae 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -7655,7 +7655,12 @@ class Radians(Func): # https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16 # tsql third argument function == trunctaion if not 0 class Round(Func): - arg_types = {"this": True, "decimals": False, "truncate": False} + arg_types = { + "this": True, + "decimals": False, + "truncate": False, + "casts_non_integer_decimals": False, + } class RowNumber(Func): diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 0aa247f625..a7b2933ebb 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1247,6 +1247,8 @@ def test_duckdb(self): ) self.validate_identity("SELECT GREATEST(1.0, 2.5, NULL, 3.7)") self.validate_identity("FROM t1, t2 SELECT *", "SELECT * FROM t1, t2") + self.validate_identity("ROUND(2.256, 1.8)") + self.validate_identity("ROUND(2.256, 1)") # TODO: This is incorrect AST, DATE_PART creates a STRUCT of values but it's stored in 'year' arg self.validate_identity("SELECT MAKE_DATE(DATE_PART(['year', 'month', 'day'], TODAY()))") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index adfee79282..9783802575 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -3670,6 +3670,14 @@ def test_round(self): }, ) + self.validate_all( + "SELECT ROUND(EXPR => 2.25, SCALE => 1) AS value", + write={ + "snowflake": "SELECT ROUND(EXPR => 2.25, SCALE => 1) AS value", + "duckdb": "SELECT ROUND(2.25, 1) AS value", + }, + ) + self.validate_all( "SELECT ROUND(2.25, 1, 'HALF_AWAY_FROM_ZERO') AS value", write={ @@ -3678,6 +3686,14 @@ def test_round(self): }, ) + self.validate_all( + "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_AWAY_FROM_ZERO') AS value", + write={ + "snowflake": "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_AWAY_FROM_ZERO') AS value", + "duckdb": "SELECT ROUND(2.25, 1) AS value", + }, + ) + self.validate_all( "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", write={ @@ -3685,3 +3701,27 @@ def test_round(self): "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", }, ) + + self.validate_all( + "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_TO_EVEN') AS value", + write={ + "snowflake": "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_TO_EVEN') AS value", + "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(2.256, 1.8) AS value", + write={ + "snowflake": "SELECT ROUND(2.256, 1.8) AS value", + "duckdb": "SELECT ROUND(2.256, CAST(1.8 AS INT)) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(2.256, CAST(1.8 AS DECIMAL(38, 0))) AS value", + write={ + "snowflake": "SELECT ROUND(2.256, CAST(1.8 AS DECIMAL(38, 0))) AS value", + "duckdb": "SELECT ROUND(2.256, CAST(CAST(1.8 AS DECIMAL(38, 0)) AS INT)) AS value", + }, + ) From f80ee134bdc87312d0973c3b4a28aa73414b233c Mon Sep 17 00:00:00 2001 From: Tori Wei <41123940+toriwei@users.noreply.github.com> Date: Fri, 5 Dec 2025 11:22:01 -0800 Subject: [PATCH 2/3] remove --- sqlglot/dialects/snowflake.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index e14af3edfc..b5f287d044 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -560,7 +560,6 @@ class Snowflake(Dialect): ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False TRY_CAST_REQUIRES_STRING = True SUPPORTS_ALIAS_REFS_IN_JOIN_CONDITIONS = True - ROUND_casts_non_integer_decimals = True EXPRESSION_METADATA = EXPRESSION_METADATA.copy() From cd408aaf185239fbfdc94d3afa729c2107346675 Mon Sep 17 00:00:00 2001 From: Tori Wei <41123940+toriwei@users.noreply.github.com> Date: Mon, 8 Dec 2025 09:33:27 -0800 Subject: [PATCH 3/3] unpack Kwargs at Snowflake parse time + address other Kwarg feedback --- sqlglot/dialects/duckdb.py | 7 ------- sqlglot/dialects/snowflake.py | 26 +++++++++++++++++++++++++- tests/dialects/test_duckdb.py | 1 - tests/dialects/test_snowflake.py | 32 ++++++++++++++++++++++++++++---- 4 files changed, 53 insertions(+), 13 deletions(-) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index fb2c46828e..62fcccf614 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -1600,13 +1600,6 @@ def round_sql(self, expression: exp.Round) -> str: decimals = expression.args.get("decimals") truncate = expression.args.get("truncate") - if isinstance(this, exp.Kwarg): - this = this.expression - if isinstance(decimals, exp.Kwarg): - decimals = decimals.expression - if isinstance(truncate, exp.Kwarg): - truncate = truncate.expression - # DuckDB requires the scale (decimals) argument to be an INT # Some dialects (e.g., Snowflake) allow non-integer scales and cast to an integer internally if decimals is not None and expression.args.get("casts_non_integer_decimals"): diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index b5f287d044..033a46a2e4 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -541,7 +541,31 @@ def _build_timestamp_from_parts(args: t.List) -> exp.Func: def _build_round(args: t.List) -> exp.Round: - expression = exp.Round.from_arg_list(args) + """ + Build Round expression, unwrapping Snowflake's named parameters. + + Maps EXPR => this, SCALE => decimals, ROUNDING_MODE => truncate. + + Note: Snowflake does not support mixing named and positional arguments. + Arguments are either all named or all positional. + """ + kwarg_map = {"EXPR": "this", "SCALE": "decimals", "ROUNDING_MODE": "truncate"} + round_args = {} + positional_keys = ["this", "decimals", "truncate"] + positional_idx = 0 + + for arg in args: + if isinstance(arg, exp.Kwarg): + key = arg.this.name.upper() + round_key = kwarg_map.get(key) + if round_key: + round_args[round_key] = arg.expression + else: + if positional_idx < len(positional_keys): + round_args[positional_keys[positional_idx]] = arg + positional_idx += 1 + + expression = exp.Round(**round_args) expression.set("casts_non_integer_decimals", True) return expression diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index a7b2933ebb..0f3abe44ef 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1247,7 +1247,6 @@ def test_duckdb(self): ) self.validate_identity("SELECT GREATEST(1.0, 2.5, NULL, 3.7)") self.validate_identity("FROM t1, t2 SELECT *", "SELECT * FROM t1, t2") - self.validate_identity("ROUND(2.256, 1.8)") self.validate_identity("ROUND(2.256, 1)") # TODO: This is incorrect AST, DATE_PART creates a STRUCT of values but it's stored in 'year' arg diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 9783802575..abed49b489 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -3673,7 +3673,15 @@ def test_round(self): self.validate_all( "SELECT ROUND(EXPR => 2.25, SCALE => 1) AS value", write={ - "snowflake": "SELECT ROUND(EXPR => 2.25, SCALE => 1) AS value", + "snowflake": "SELECT ROUND(2.25, 1) AS value", + "duckdb": "SELECT ROUND(2.25, 1) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(SCALE => 1, EXPR => 2.25) AS value", + write={ + "snowflake": "SELECT ROUND(2.25, 1) AS value", "duckdb": "SELECT ROUND(2.25, 1) AS value", }, ) @@ -3689,7 +3697,7 @@ def test_round(self): self.validate_all( "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_AWAY_FROM_ZERO') AS value", write={ - "snowflake": "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_AWAY_FROM_ZERO') AS value", + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_AWAY_FROM_ZERO') AS value", "duckdb": "SELECT ROUND(2.25, 1) AS value", }, ) @@ -3697,7 +3705,23 @@ def test_round(self): self.validate_all( "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", write={ - "snowflake": """SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value""", + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", + "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(ROUNDING_MODE => 'HALF_TO_EVEN', EXPR => 2.25, SCALE => 1) AS value", + write={ + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", + "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", + }, + ) + + self.validate_all( + "SELECT ROUND(SCALE => 1, EXPR => 2.25, , ROUNDING_MODE => 'HALF_TO_EVEN') AS value", + write={ + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", }, ) @@ -3705,7 +3729,7 @@ def test_round(self): self.validate_all( "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_TO_EVEN') AS value", write={ - "snowflake": "SELECT ROUND(EXPR => 2.25, SCALE => 1, ROUNDING_MODE => 'HALF_TO_EVEN') AS value", + "snowflake": "SELECT ROUND(2.25, 1, 'HALF_TO_EVEN') AS value", "duckdb": "SELECT ROUND_EVEN(2.25, 1) AS value", }, )