diff --git a/sqlglot/typing/snowflake.py b/sqlglot/typing/snowflake.py
index dbfa551fd4..2b5ea7c65e 100644
--- a/sqlglot/typing/snowflake.py
+++ b/sqlglot/typing/snowflake.py
@@ -142,6 +142,47 @@ def _annotate_median(self: TypeAnnotator, expression: exp.Median) -> exp.Median:
     return expression
 
 
+def _annotate_variance(self: TypeAnnotator, expression: exp.Expression) -> exp.Expression:
+    """Annotate Snowflake variance aggregates with their return type.
+
+    Registered for exp.Variance and exp.VariancePop, i.e. VAR_SAMP, VARIANCE,
+    VARIANCE_SAMP, VAR_POP and VARIANCE_POP.
+
+    Based on Snowflake behavior:
+    - DECFLOAT -> DECFLOAT
+    - FLOAT/DOUBLE -> DOUBLE (FLOAT is a synonym for DOUBLE in Snowflake)
+    - INT, NUMBER(p, 0) -> NUMBER(38, 6)
+    - NUMBER(p, s) with s > 0 -> NUMBER(38, max(12, s))
+
+    Args:
+        expression: The variance expression; its argument is read from `this`.
+
+    Returns:
+        The annotated expression.
+    """
+    # First annotate the argument to get its type
+    expression = self._annotate_by_args(expression, "this")
+    input_type = expression.this.type
+
+    if input_type.is_type(exp.DataType.Type.DECFLOAT):
+        self._set_type(expression, exp.DataType.Type.DECFLOAT)
+    elif input_type.is_type(exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE):
+        self._set_type(expression, exp.DataType.Type.DOUBLE)
+    else:
+        # Every other input is handled as NUMBER; a type without a scale
+        # parameter (e.g. INT) counts as scale 0
+        scale_expr = seq_get(input_type.expressions, 1)
+        scale = scale_expr.this.to_py() if scale_expr else 0
+
+        # Scale 0 yields NUMBER(38, 6); otherwise Snowflake appears to assign
+        # the scale through the formula MAX(12, s)
+        new_scale = 6 if scale == 0 else max(12, scale)
+        new_type = exp.DataType.build(f"NUMBER({MAX_PRECISION}, {new_scale})", dialect="snowflake")
+        self._set_type(expression, new_type)
+
+    return expression
+
+
 def _annotate_math_with_float_decfloat(
     self: TypeAnnotator, expression: exp.Expression
 ) -> exp.Expression:
@@ -408,6 +449,13 @@ def _annotate_math_with_float_decfloat(
             exp.MinhashCombine,
         }
     },
+    **{
+        expr_type: {"annotator": _annotate_variance}
+        for expr_type in (
+            exp.Variance,
+            exp.VariancePop,
+        )
+    },
     exp.ArgMax: {"annotator": _annotate_arg_max_min},
     exp.ArgMin: {"annotator": _annotate_arg_max_min},
     exp.ConcatWs: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions")},
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 6e0836d87c..5e04d3ce73 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -639,6 +639,7 @@ def test_duckdb(self):
             },
             write={
                 "": "VARIANCE_POP(x)",
+                "duckdb": "VAR_POP(x)",
             },
         )
         self.validate_all(
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 978f14e71f..155ee07be2 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -147,6 +147,22 @@ def test_snowflake(self):
         self.validate_identity("SELECT REGR_SXY(y, x)")
         self.validate_identity("SELECT REGR_SYY(y, x)")
         self.validate_identity("SELECT REGR_SLOPE(y, x)")
+        self.validate_all(
+            "SELECT VAR_SAMP(x)",
+            write={
+                "snowflake": "SELECT VARIANCE(x)",
+                "duckdb": "SELECT VARIANCE(x)",
+                "postgres": "SELECT VAR_SAMP(x)",
+            },
+        )
+        self.validate_all(
+            "SELECT VAR_POP(x)",
+            write={
+                "snowflake": "SELECT VARIANCE_POP(x)",
+                "duckdb": "SELECT VAR_POP(x)",
+                "postgres": "SELECT VAR_POP(x)",
+            },
+        )
         self.validate_all(
             "SELECT SKEW(a)",
             write={
diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql
index f20b53a18d..8485133d26 100644
--- a/tests/fixtures/optimizer/annotate_functions.sql
+++ b/tests/fixtures/optimizer/annotate_functions.sql
@@ -4320,6 +4320,78 @@ INT;
 MODE(tbl.str_col) OVER (PARTITION BY tbl.int_col);
 VARCHAR;
 
+# dialect: snowflake
+VAR_SAMP(tbl.decfloat_col);
+DECFLOAT;
+
+# dialect: snowflake
+VAR_SAMP(tbl.double_col);
+DOUBLE;
+
+# dialect: snowflake
+VAR_SAMP(tbl.int_col);
+NUMBER(38, 6);
+
+# dialect: snowflake
+VARIANCE_SAMP(tbl.decfloat_col);
+DECFLOAT;
+
+# dialect: snowflake
+VARIANCE_SAMP(tbl.double_col);
+DOUBLE;
+
+# dialect: snowflake
+VARIANCE_SAMP(tbl.int_col);
+NUMBER(38, 6);
+
+# dialect: snowflake
+VARIANCE(tbl.decfloat_col);
+DECFLOAT;
+
+# dialect: snowflake
+VARIANCE(tbl.double_col);
+DOUBLE;
+
+# dialect: snowflake
+VARIANCE(tbl.int_col);
+NUMBER(38, 6);
+
+# dialect: snowflake
+VAR_POP(tbl.decfloat_col);
+DECFLOAT;
+
+# dialect: snowflake
+VAR_POP(tbl.double_col);
+DOUBLE;
+
+# dialect: snowflake
+VAR_POP(tbl.int_col);
+NUMBER(38, 6);
+
+# dialect: snowflake
+VARIANCE_POP(tbl.decfloat_col);
+DECFLOAT;
+
+# dialect: snowflake
+VARIANCE_POP(tbl.double_col);
+DOUBLE;
+
+# dialect: snowflake
+VARIANCE_POP(tbl.int_col);
+NUMBER(38, 6);
+
+# dialect: snowflake
+VARIANCE_POP(1::NUMBER(38, 6));
+NUMBER(38, 12);
+
+# dialect: snowflake
+VARIANCE_POP(1::NUMBER(38, 15));
+NUMBER(38, 15);
+
+# dialect: snowflake
+VARIANCE_POP(1::NUMBER(30, 5));
+NUMBER(38, 12);
+
 --------------------------------------
 -- T-SQL
 --------------------------------------