diff --git a/sqlglot/dialects/exasol.py b/sqlglot/dialects/exasol.py index da9ed97493..de8b7741f7 100644 --- a/sqlglot/dialects/exasol.py +++ b/sqlglot/dialects/exasol.py @@ -20,6 +20,7 @@ from sqlglot.generator import unsupported_args from sqlglot.helper import seq_get from sqlglot.tokens import TokenType +from sqlglot.optimizer.scope import build_scope if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType @@ -169,6 +170,66 @@ def _substring_index_sql(self: Exasol.Generator, expression: exp.SubstringIndex) return self.func("SUBSTR", haystack_sql, direction, length) +# https://docs.exasol.com/db/latest/sql/select.htm#:~:text=The%20select_list%20defines%20the%20columns%20of%20the%20result%20table.%20If%20*%20is%20used%2C%20all%20columns%20are%20listed.%20You%20can%20use%20an%20expression%20like%20t.*%20to%20list%20all%20columns%20of%20the%20table%20t%2C%20the%20view%20t%2C%20or%20the%20object%20with%20the%20table%20alias%20t. +def _qualify_unscoped_star(expression: exp.Expression) -> exp.Expression: + """ + Exasol doesn't support a bare * alongside other select items, so we rewrite it + Rewrite: SELECT *, FROM + Into: SELECT T.*, FROM
AS T + """ + + if not isinstance(expression, exp.Select): + return expression + + select_expressions = expression.expressions or [] + + def is_bare_star(expr: exp.Expression) -> bool: + return isinstance(expr, exp.Star) and expr.this is None + + has_other_expression = False + bare_star_expr: exp.Expression | None = None + for expr in select_expressions: + has_bare_star = is_bare_star(expr) + if has_bare_star and bare_star_expr is None: + bare_star_expr = expr + elif not has_bare_star: + has_other_expression = True + if bare_star_expr and has_other_expression: + break + + if not (bare_star_expr and has_other_expression): + return expression + + scope = build_scope(expression) + + if not scope or not scope.selected_sources: + return expression + + table_identifiers: list[exp.Identifier] = [] + + for source_name, (source_expr, _) in scope.selected_sources.items(): + ident = ( + source_expr.this.copy() + if isinstance(source_expr, exp.Table) and isinstance(source_expr.this, exp.Identifier) + else exp.to_identifier(source_name) + ) + table_identifiers.append(ident) + + qualified_star_columns = [ + exp.Column(this=bare_star_expr.copy(), table=ident) for ident in table_identifiers + ] + + new_select_expressions: list[exp.Expression] = [] + + for select_expr in select_expressions: + new_select_expressions.extend(qualified_star_columns) if is_bare_star( + select_expr + ) else new_select_expressions.append(select_expr) + + expression.set("expressions", new_select_expressions) + return expression + + def _add_date_sql(self: Exasol.Generator, expression: DATE_ADD_OR_SUB) -> str: interval = expression.expression if isinstance(expression.expression, exp.Interval) else None @@ -453,6 +514,7 @@ def datatype_sql(self, expression: exp.DataType) -> str: exp.CommentColumnConstraint: lambda self, e: f"COMMENT IS {self.sql(e, 'this')}", exp.Select: transforms.preprocess( [ + _qualify_unscoped_star, _add_local_prefix_for_aliases, ] ), diff --git a/tests/dialects/test_exasol.py b/tests/dialects/test_exasol.py index 8ff110d25d..5469d41bf5 100644 --- a/tests/dialects/test_exasol.py +++ b/tests/dialects/test_exasol.py @@ -11,6 +11,44 @@ def test_exasol(self): 'SELECT 1 AS "x"', ) + def test_qualify_unscoped_star(self): + self.validate_all( + "SELECT TEST.*, 1 FROM TEST", + read={ + "": "SELECT *, 1 FROM TEST", + }, + ) + self.validate_identity( + "SELECT t.*, 1 FROM t", + ) + self.validate_identity( + "SELECT t.* FROM t", + ) + self.validate_identity( + "SELECT * FROM t", + ) + self.validate_identity( + "WITH t AS (SELECT 1 AS x) SELECT t.*, 3 FROM t", + ) + self.validate_all( + "WITH t1 AS (SELECT 1 AS c1), t2 AS (SELECT 2 AS c2) SELECT t1.*, t2.*, 3 FROM t1, t2", + read={ + "": "WITH t1 AS (SELECT 1 AS c1), t2 AS (SELECT 2 AS c2) SELECT *, 3 FROM t1, t2", + }, + ) + self.validate_all( + 'SELECT "A".*, "B".*, 3 FROM "A" JOIN "B" ON 1 = 1', + read={ + "": 'SELECT *, 3 FROM "A" JOIN "B" ON 1=1', + }, + ) + self.validate_all( + "SELECT s.*, q.*, 7 FROM (SELECT 1 AS x) AS s CROSS JOIN (SELECT 2 AS y) AS q", + read={ + "": "SELECT *, 7 FROM (SELECT 1 AS x) s CROSS JOIN (SELECT 2 AS y) q", + }, + ) + def test_type_mappings(self): self.validate_identity("CAST(x AS BLOB)", "CAST(x AS VARCHAR)") self.validate_identity("CAST(x AS LONGBLOB)", "CAST(x AS VARCHAR)")