Skip to content
45 changes: 45 additions & 0 deletions sqlglot/typing/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,44 @@ def _annotate_median(self: TypeAnnotator, expression: exp.Median) -> exp.Median:
return expression


def _annotate_variance(self: TypeAnnotator, expression: exp.Expression) -> exp.Expression:
"""Annotate variance functions (VAR_POP, VAR_SAMP, VARIANCE, VARIANCE_POP) with correct return type.

Based on Snowflake behavior:
- DECFLOAT -> DECFLOAT(38)
- FLOAT/DOUBLE -> FLOAT
- INT, NUMBER(p, 0) -> NUMBER(38, 6)
- NUMBER(p, s) -> NUMBER(38, max(12, s))
"""
# First annotate the argument to get its type
expression = self._annotate_by_args(expression, "this")

# Get the input type
input_type = expression.this.type

# Special case: DECFLOAT -> DECFLOAT(38)
if input_type.is_type(exp.DataType.Type.DECFLOAT):
self._set_type(expression, exp.DataType.build("DECFLOAT", dialect="snowflake"))
# Special case: FLOAT/DOUBLE -> DOUBLE
elif input_type.is_type(exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE):
self._set_type(expression, exp.DataType.Type.DOUBLE)
# For NUMBER types: determine the scale
else:
exprs = input_type.expressions
scale_expr = seq_get(exprs, 1)
scale = scale_expr.this.to_py() if scale_expr else 0

# If scale is 0 (INT, BIGINT, NUMBER(p,0)): return NUMBER(38, 6)
# Otherwise, Snowflake appears to assign scale through the formula MAX(12, s)
new_scale = 6 if scale == 0 else max(12, scale)

# Build the new NUMBER type
new_type = exp.DataType.build(f"NUMBER({MAX_PRECISION}, {new_scale})", dialect="snowflake")
self._set_type(expression, new_type)

return expression


def _annotate_math_with_float_decfloat(
self: TypeAnnotator, expression: exp.Expression
) -> exp.Expression:
Expand Down Expand Up @@ -408,6 +446,13 @@ def _annotate_math_with_float_decfloat(
exp.MinhashCombine,
}
},
**{
expr_type: {"annotator": _annotate_variance}
for expr_type in (
exp.Variance,
exp.VariancePop,
)
},
exp.ArgMax: {"annotator": _annotate_arg_max_min},
exp.ArgMin: {"annotator": _annotate_arg_max_min},
exp.ConcatWs: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions")},
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ def test_duckdb(self):
},
write={
"": "VARIANCE_POP(x)",
"duckdb": "VAR_POP(x)",
},
)
self.validate_all(
Expand Down
16 changes: 16 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,22 @@ def test_snowflake(self):
self.validate_identity("SELECT REGR_SXY(y, x)")
self.validate_identity("SELECT REGR_SYY(y, x)")
self.validate_identity("SELECT REGR_SLOPE(y, x)")
self.validate_all(
"SELECT VAR_SAMP(x)",
write={
"snowflake": "SELECT VARIANCE(x)",
"duckdb": "SELECT VARIANCE(x)",
"postgres": "SELECT VAR_SAMP(x)",
},
)
self.validate_all(
"SELECT VAR_POP(x)",
write={
"snowflake": "SELECT VARIANCE_POP(x)",
"duckdb": "SELECT VAR_POP(x)",
"postgres": "SELECT VAR_POP(x)",
},
)
self.validate_all(
"SELECT SKEW(a)",
write={
Expand Down
72 changes: 72 additions & 0 deletions tests/fixtures/optimizer/annotate_functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4320,6 +4320,78 @@ INT;
MODE(tbl.str_col) OVER (PARTITION BY tbl.int_col);
VARCHAR;

# dialect: snowflake
VAR_SAMP(tbl.decfloat_col);
DECFLOAT;

# dialect: snowflake
VAR_SAMP(tbl.double_col);
DOUBLE;

# dialect: snowflake
VAR_SAMP(tbl.int_col);
NUMBER(38, 6);

# dialect: snowflake
VARIANCE_SAMP(tbl.decfloat_col);
DECFLOAT;

# dialect: snowflake
VARIANCE_SAMP(tbl.double_col);
DOUBLE;

# dialect: snowflake
VARIANCE_SAMP(tbl.int_col);
NUMBER(38, 6);

# dialect: snowflake
VARIANCE(tbl.decfloat_col);
DECFLOAT;

# dialect: snowflake
VARIANCE(tbl.double_col);
DOUBLE;

# dialect: snowflake
VARIANCE(tbl.int_col);
NUMBER(38, 6);

# dialect: snowflake
VAR_POP(tbl.decfloat_col);
DECFLOAT;

# dialect: snowflake
VAR_POP(tbl.double_col);
DOUBLE;

# dialect: snowflake
VAR_POP(tbl.int_col);
NUMBER(38, 6);

# dialect: snowflake
VARIANCE_POP(tbl.decfloat_col);
DECFLOAT;

# dialect: snowflake
VARIANCE_POP(tbl.double_col);
DOUBLE;

# dialect: snowflake
VARIANCE_POP(tbl.int_col);
NUMBER(38, 6);

# dialect: snowflake
VARIANCE_POP(1::NUMBER(38, 6));
NUMBER(38, 12);

# dialect: snowflake
VARIANCE_POP(1::NUMBER(38, 15));
NUMBER(38, 15);

# dialect: snowflake
VARIANCE_POP(1::NUMBER(30, 5));
NUMBER(38, 12);

--------------------------------------
-- T-SQL
--------------------------------------
Expand Down