diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index a2fdc0a9b0..79f8d8f8b3 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1324,6 +1324,9 @@ class Tokenizer(tokens.Tokenizer): "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TOP": TokenType.TOP, "WAREHOUSE": TokenType.WAREHOUSE, + # https://docs.snowflake.com/en/sql-reference/data-types-numeric#float + # FLOAT is a synonym for DOUBLE in Snowflake + "FLOAT": TokenType.DOUBLE, } KEYWORDS.pop("/*+") @@ -1573,6 +1576,15 @@ def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> st return super().values_sql(expression, values_as_table=values_as_table) def datatype_sql(self, expression: exp.DataType) -> str: + # Check if this is a FLOAT type nested inside a VECTOR type + # VECTOR only accepts FLOAT (not DOUBLE), INT, and STRING as element types + # https://docs.snowflake.com/en/sql-reference/data-types-vector + if expression.is_type(exp.DataType.Type.DOUBLE): + parent = expression.parent + if isinstance(parent, exp.DataType) and parent.is_type(exp.DataType.Type.VECTOR): + # Preserve FLOAT for VECTOR types instead of mapping to synonym DOUBLE + return "FLOAT" + expressions = expression.expressions if expressions and expression.is_type(*exp.DataType.STRUCT_TYPES): for field_type in expressions: diff --git a/sqlglot/typing/snowflake.py b/sqlglot/typing/snowflake.py index d2103ff637..b77e92dc5c 100644 --- a/sqlglot/typing/snowflake.py +++ b/sqlglot/typing/snowflake.py @@ -110,7 +110,7 @@ def _annotate_median(self: TypeAnnotator, expression: exp.Median) -> exp.Median: """Annotate MEDIAN function with correct return type. Based on Snowflake documentation: - - If the expr is FLOAT -> annotate as FLOAT + - If the expr is FLOAT/DOUBLE -> annotate as DOUBLE (FLOAT is a synonym for DOUBLE) - If the expr is NUMBER(p, s) -> annotate as NUMBER(min(p+3, 38), min(s+3, 37)) """ # First annotate the argument to get its type @@ -119,9 +119,9 @@ def _annotate_median(self: TypeAnnotator, expression: exp.Median) -> exp.Median: # Get the input type input_type = expression.this.type - if input_type.is_type(exp.DataType.Type.FLOAT): - # If input is FLOAT, return FLOAT - self._set_type(expression, exp.DataType.Type.FLOAT) + if input_type.is_type(exp.DataType.Type.DOUBLE): + # If input is FLOAT/DOUBLE, return DOUBLE (FLOAT is normalized to DOUBLE in Snowflake) + self._set_type(expression, exp.DataType.Type.DOUBLE) else: # If input is NUMBER(p, s), return NUMBER(min(p+3, 38), min(s+3, 37)) exprs = input_type.expressions diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 7e1af1d1f5..ccd3e4427d 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -324,7 +324,7 @@ def test_snowflake(self): self.validate_identity("$x") # parameter self.validate_identity("a$b") # valid snowflake identifier self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") - self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)") + self.validate_identity("CREATE TABLE foo (bar DOUBLE AUTOINCREMENT START 0 INCREMENT 1)") self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'") self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)") self.validate_identity("SELECT CURRENT_ORGANIZATION_NAME()") @@ -1797,6 +1797,13 @@ def test_snowflake(self): "duckdb": "SET VARIABLE a = 1", }, ) + self.validate_all( + "CAST(6.43 AS FLOAT)", + write={ + "snowflake": "CAST(6.43 AS DOUBLE)", + "duckdb": "CAST(6.43 AS DOUBLE)", + }, + ) def test_null_treatment(self): self.validate_all(