From 7428a3981d5f869af282efd292c1059c4d565417 Mon Sep 17 00:00:00 2001
From: Brad Paskewitz
Date: Thu, 4 Dec 2025 16:19:14 -0800
Subject: [PATCH] fix(optimizer)!: query schema directly when type annotation
 fails for processing UNNEST source

---
 sqlglot/optimizer/resolver.py | 75 +++++++++++++++++++++++++++++++++--
 tests/test_optimizer.py       | 42 ++++++++++++++++++++
 2 files changed, 114 insertions(+), 3 deletions(-)

diff --git a/sqlglot/optimizer/resolver.py b/sqlglot/optimizer/resolver.py
index 8256e7d1ca..b6afeabb50 100644
--- a/sqlglot/optimizer/resolver.py
+++ b/sqlglot/optimizer/resolver.py
@@ -144,9 +144,25 @@ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequenc
             # in bigquery, unnest structs are automatically scoped as tables, so you can
             # directly select a struct field in a query.
             # this handles the case where the unnest is statically defined.
-            if self.dialect.UNNEST_COLUMN_ONLY:
-                if source.expression.is_type(exp.DataType.Type.STRUCT):
-                    for k in source.expression.type.expressions:  # type: ignore
+            if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
+                unnest = source.expression
+
+                # if type is not annotated yet, try to get it from the schema
+                if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
+                    unnest_expr = seq_get(unnest.expressions, 0)
+                    if isinstance(unnest_expr, exp.Column) and self.scope.parent:
+                        col_type = self._get_unnest_column_type(unnest_expr)
+                        # extract element type if it's an ARRAY
+                        if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
+                            element_types = col_type.expressions
+                            if element_types:
+                                unnest.type = element_types[0].copy()
+                        else:
+                            if col_type:
+                                unnest.type = col_type.copy()
+                # check if the result type is a STRUCT - extract struct field names
+                if unnest.is_type(exp.DataType.Type.STRUCT):
+                    for k in unnest.type.expressions:  # type: ignore
                         columns.append(k.name)
         elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
             columns = self.get_source_columns_from_set_op(source.expression)
@@ -299,3 +315,56 @@ def _get_unambiguous_columns(
                 unambiguous_columns[column] = table
 
         return unambiguous_columns
+
+    def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
+        """
+        Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
+
+        Args:
+            column: The column expression being unnested.
+
+        Returns:
+            The DataType of the column, or None if not found.
+ """ + if isinstance(source, exp.Table): + # base table - get the column type from schema + col_type: t.Optional[exp.DataType] = self.schema.get_column_type(source, column) + if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN): + return col_type + elif isinstance(source, Scope): + # iterate over all sources in the scope + for source_name, nested_source in source.sources.items(): + col_type = self._get_column_type_from_scope(nested_source, column) + if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN): + return col_type + + return None diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 05d0f79f6d..040c558eaa 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -516,6 +516,48 @@ def test_qualify_columns(self, logger): "SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.d_id = d.d_id", ) + self.assertEqual( + optimizer.qualify.qualify( + parse_one( + """ + SELECT + (SELECT SUM(c.amount) + FROM UNNEST(credits) AS c + WHERE type != 'promotion') as total + FROM billing + """, + read="bigquery", + ), + schema={"billing": {"credits": "ARRAY>"}}, + dialect="bigquery", + ).sql(dialect="bigquery"), + "SELECT (SELECT SUM(`c`.`amount`) AS `_col_0` FROM UNNEST(`billing`.`credits`) AS `c` WHERE `type` <> 'promotion') AS `total` FROM `billing` AS `billing`", + ) + + self.assertEqual( + optimizer.qualify.qualify( + parse_one( + """ + WITH cte AS (SELECT * FROM base_table) + SELECT + (SELECT SUM(item.price) + FROM UNNEST(items) AS item + WHERE category = 'electronics') as electronics_total + FROM cte + """, + read="bigquery", + ), + schema={ + "base_table": { + "id": "INT64", + "items": "ARRAY>", + } + }, + dialect="bigquery", + ).sql(dialect="bigquery"), + "WITH `cte` AS (SELECT `base_table`.`id` AS `id`, `base_table`.`items` AS `items` FROM `base_table` AS `base_table`) SELECT (SELECT SUM(`item`.`price`) AS `_col_0` FROM UNNEST(`cte`.`items`) AS `item` WHERE `category` = 'electronics') AS `electronics_total` FROM `cte` AS `cte`", + ) + self.check_file( "qualify_columns", qualify_columns,