diff --git a/airflow-core/src/airflow/migrations/utils.py b/airflow-core/src/airflow/migrations/utils.py index 4eeaf373c6a87..87a1e970f861e 100644 --- a/airflow-core/src/airflow/migrations/utils.py +++ b/airflow-core/src/airflow/migrations/utils.py @@ -19,13 +19,23 @@ import contextlib from contextlib import contextmanager +import sqlalchemy as sa +from alembic import op as alembic_op +from sqlalchemy import text + + +def get_dialect_name(op) -> str: + conn = op.get_bind() + return conn.dialect.name if conn is not None else op.get_context().dialect.name + @contextmanager def disable_sqlite_fkeys(op): - if op.get_bind().dialect.name == "sqlite": - op.execute("PRAGMA foreign_keys=off") - yield op - op.execute("PRAGMA foreign_keys=on") + if get_dialect_name(op) == "sqlite": + with contextlib.ExitStack() as exit_stack: + op.execute("PRAGMA foreign_keys=off") + exit_stack.callback(op.execute, "PRAGMA foreign_keys=on") + yield op else: yield op @@ -56,8 +66,196 @@ def mysql_drop_foreignkey_if_exists(constraint_name, table_name, op): def ignore_sqlite_value_error(): - from alembic import op - - if op.get_bind().dialect.name == "sqlite": + if get_dialect_name(alembic_op) == "sqlite": return contextlib.suppress(ValueError) return contextlib.nullcontext() + + +def create_index_if_not_exists(op, index_name, table_name, columns, unique=False) -> None: + """ + Create an index if it does not already exist. + + MySQL does not support CREATE INDEX IF NOT EXISTS, so a stored procedure is used. + PostgreSQL and SQLite support it natively. + """ + dialect_name = get_dialect_name(op) + + if dialect_name == "mysql": + unique_kw = "UNIQUE " if unique else "" + col_list = ", ".join(f"`{c}`" for c in columns) + op.execute( + text(f""" + DROP PROCEDURE IF EXISTS CreateIndexIfNotExists; + CREATE PROCEDURE CreateIndexIfNotExists() + BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM information_schema.STATISTICS + WHERE + TABLE_SCHEMA = DATABASE() AND + TABLE_NAME = '{table_name}' AND + INDEX_NAME = '{index_name}' + ) THEN + CREATE {unique_kw}INDEX `{index_name}` ON `{table_name}` ({col_list}); + END IF; + END; + CALL CreateIndexIfNotExists(); + DROP PROCEDURE IF EXISTS CreateIndexIfNotExists; + """) + ) + else: + op.create_index(index_name, table_name, columns, unique=unique, if_not_exists=True) + + +def drop_index_if_exists(op, index_name, table_name) -> None: + """ + Drop an index if it exists. + + Works in both online and offline mode by using raw SQL for PostgreSQL and MySQL. + SQLite and PostgreSQL support DROP INDEX IF EXISTS natively. + MySQL requires a stored procedure since it does not support IF EXISTS for DROP INDEX. + """ + dialect_name = get_dialect_name(op) + + if dialect_name == "mysql": + op.execute( + text(f""" + DROP PROCEDURE IF EXISTS DropIndexIfExists; + CREATE PROCEDURE DropIndexIfExists() + BEGIN + IF EXISTS ( + SELECT 1 + FROM information_schema.STATISTICS + WHERE + TABLE_SCHEMA = DATABASE() AND + TABLE_NAME = '{table_name}' AND + INDEX_NAME = '{index_name}' + ) THEN + DROP INDEX `{index_name}` ON `{table_name}`; + END IF; + END; + CALL DropIndexIfExists(); + DROP PROCEDURE DropIndexIfExists; + """) + ) + else: + # PostgreSQL and SQLite both support DROP INDEX IF EXISTS + op.drop_index(index_name, table_name=table_name, if_exists=True) + + +def drop_unique_constraints_on_columns(op, table_name, columns) -> None: + """ + Drop all unique constraints covering any of the given columns, regardless of constraint name. + + Works in both online and offline mode by using raw SQL for PostgreSQL and MySQL. + SQLite falls back to batch mode and requires a live connection. + """ + dialect_name = get_dialect_name(op) + + if dialect_name == "postgresql": + cols_array = ", ".join(f"'{c}'" for c in columns) + op.execute( + text(f""" + DO $$ + DECLARE r record; + BEGIN + FOR r IN + SELECT DISTINCT tc.constraint_name + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.table_name = '{table_name}' + AND tc.constraint_type = 'UNIQUE' + AND kcu.column_name = ANY(ARRAY[{cols_array}]::text[]) + LOOP + EXECUTE 'ALTER TABLE ' || quote_ident('{table_name}') || ' DROP CONSTRAINT IF EXISTS ' + || quote_ident(r.constraint_name); + END LOOP; + END $$ + """) + ) + elif dialect_name == "mysql": + cols_in = ", ".join(f"'{c}'" for c in columns) + op.execute( + text(f""" + DROP PROCEDURE IF EXISTS DropUniqueOnColumns; + CREATE PROCEDURE DropUniqueOnColumns() + BEGIN + DECLARE done INT DEFAULT FALSE; + DECLARE v_name VARCHAR(255); + DECLARE cur CURSOR FOR + SELECT DISTINCT kcu.CONSTRAINT_NAME + FROM information_schema.KEY_COLUMN_USAGE kcu + JOIN information_schema.TABLE_CONSTRAINTS tc + ON kcu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME + AND kcu.TABLE_SCHEMA = tc.TABLE_SCHEMA + AND kcu.TABLE_NAME = tc.TABLE_NAME + WHERE kcu.TABLE_NAME = '{table_name}' + AND kcu.TABLE_SCHEMA = DATABASE() + AND tc.CONSTRAINT_TYPE = 'UNIQUE' + AND kcu.COLUMN_NAME IN ({cols_in}); + DECLARE CONTINUE HANDLER FOR NOT FOUND SET done = TRUE; + OPEN cur; + drop_loop: LOOP + FETCH cur INTO v_name; + IF done THEN LEAVE drop_loop; END IF; + SET @stmt = CONCAT('ALTER TABLE `{table_name}` DROP INDEX `', v_name, '`'); + PREPARE s FROM @stmt; + EXECUTE s; + DEALLOCATE PREPARE s; + END LOOP; + CLOSE cur; + END; + CALL DropUniqueOnColumns(); + DROP PROCEDURE DropUniqueOnColumns; + """) + ) + else: + # SQLite — batch mode rewrites the table; requires a live connection + with op.batch_alter_table(table_name, schema=None) as batch_op: + for uq in sa.inspect(op.get_bind()).get_unique_constraints(table_name): + if any(col in uq["column_names"] for col in columns): + batch_op.drop_constraint(uq["name"], type_="unique") + + +def drop_unique_constraint_if_exists(op, table_name, constraint_name) -> None: + """ + Drop a unique constraint by name if it exists. + + Works in both online and offline mode by using raw SQL for PostgreSQL and MySQL. + SQLite falls back to batch mode and requires a live connection. + """ + dialect_name = get_dialect_name(op) + + if dialect_name == "postgresql": + op.execute(text(f'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{constraint_name}"')) + elif dialect_name == "mysql": + op.execute( + text(f""" + DROP PROCEDURE IF EXISTS DropUniqueIfExists; + CREATE PROCEDURE DropUniqueIfExists() + BEGIN + IF EXISTS ( + SELECT 1 + FROM information_schema.TABLE_CONSTRAINTS + WHERE + CONSTRAINT_SCHEMA = DATABASE() AND + TABLE_NAME = '{table_name}' AND + CONSTRAINT_NAME = '{constraint_name}' AND + CONSTRAINT_TYPE = 'UNIQUE' + ) THEN + ALTER TABLE `{table_name}` DROP INDEX `{constraint_name}`; + ELSE + SELECT 1; + END IF; + END; + CALL DropUniqueIfExists(); + DROP PROCEDURE DropUniqueIfExists; + """) + ) + else: + # SQLite — batch mode rewrites the table; requires a live connection + with op.batch_alter_table(table_name, schema=None) as batch_op: + with contextlib.suppress(ValueError): + batch_op.drop_constraint(constraint_name, type_="unique") diff --git a/airflow-core/src/airflow/migrations/versions/0041_3_0_0_rename_dataset_as_asset.py b/airflow-core/src/airflow/migrations/versions/0041_3_0_0_rename_dataset_as_asset.py index 7234b8233c4c1..737961a62a9fd 100644 --- a/airflow-core/src/airflow/migrations/versions/0041_3_0_0_rename_dataset_as_asset.py +++ b/airflow-core/src/airflow/migrations/versions/0041_3_0_0_rename_dataset_as_asset.py @@ -30,9 +30,8 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy import text -from airflow.migrations.utils import mysql_drop_foreignkey_if_exists +from airflow.migrations.utils import disable_sqlite_fkeys, mysql_drop_foreignkey_if_exists # revision identifiers, used by Alembic. revision = "05234396c6fc" @@ -103,19 +102,18 @@ def _drop_fkey_if_exists(table, constraint_name): conn = op.get_bind() dialect_name = conn.dialect.name - if dialect_name == "sqlite": - # SQLite requires foreign key constraints to be disabled during batch operations - conn.execute(text("PRAGMA foreign_keys=OFF")) - try: - with op.batch_alter_table(table, schema=None) as batch_op: - batch_op.drop_constraint(op.f(constraint_name), type_="foreignkey") - except ValueError: - pass - conn.execute(text("PRAGMA foreign_keys=ON")) - elif dialect_name == "mysql": + if dialect_name == "mysql": mysql_drop_foreignkey_if_exists(constraint_name, table, op) - else: + elif dialect_name == "postgresql": op.execute(f"ALTER TABLE {table} DROP CONSTRAINT IF EXISTS {constraint_name}") + else: + # SQLite requires foreign key constraints to be disabled during batch operations. + with disable_sqlite_fkeys(op): + try: + with op.batch_alter_table(table, schema=None) as batch_op: + batch_op.drop_constraint(op.f(constraint_name), type_="foreignkey") + except ValueError: + pass # original table name to new table name diff --git a/airflow-core/src/airflow/migrations/versions/0082_3_1_0_make_bundle_name_not_nullable.py b/airflow-core/src/airflow/migrations/versions/0082_3_1_0_make_bundle_name_not_nullable.py index e4507454bb43f..b899fd1cda00e 100644 --- a/airflow-core/src/airflow/migrations/versions/0082_3_1_0_make_bundle_name_not_nullable.py +++ b/airflow-core/src/airflow/migrations/versions/0082_3_1_0_make_bundle_name_not_nullable.py @@ -31,7 +31,7 @@ from sqlalchemy.sql import text from airflow.migrations.db_types import StringID -from airflow.migrations.utils import ignore_sqlite_value_error +from airflow.migrations.utils import disable_sqlite_fkeys, get_dialect_name, ignore_sqlite_value_error # revision identifiers, used by Alembic. revision = "7582ea3f3dd5" @@ -43,79 +43,66 @@ def upgrade(): """Make bundle_name not nullable.""" - dialect_name = op.get_bind().dialect.name - if dialect_name == "postgresql": - op.execute( - text(""" - INSERT INTO dag_bundle (name) VALUES - ('example_dags'), - ('dags-folder') - ON CONFLICT (name) DO NOTHING; - """) - ) - if dialect_name == "mysql": - op.execute( - text(""" - INSERT IGNORE INTO dag_bundle (name) VALUES - ('example_dags'), - ('dags-folder'); + with disable_sqlite_fkeys(op): + dialect_name = get_dialect_name(op) + if dialect_name == "postgresql": + op.execute( + text(""" + INSERT INTO dag_bundle (name) VALUES + ('example_dags'), + ('dags-folder') + ON CONFLICT (name) DO NOTHING; """) - ) - if dialect_name == "sqlite": - op.execute(text("PRAGMA foreign_keys=OFF")) - op.execute( - text(""" - INSERT OR IGNORE INTO dag_bundle (name) VALUES - ('example_dags'), - ('dags-folder'); - """) - ) - - conn = op.get_bind() - with ignore_sqlite_value_error(), op.batch_alter_table("dag", schema=None) as batch_op: - conn.execute( - text( - """ - UPDATE dag - SET bundle_name = - CASE - WHEN fileloc LIKE '%/airflow/example_dags/%' THEN 'example_dags' - ELSE 'dags-folder' - END - WHERE bundle_name IS NULL - """ ) - ) - # drop the foreign key temporarily and recreate it once both columns are changed - batch_op.drop_constraint(batch_op.f("dag_bundle_name_fkey"), type_="foreignkey") - batch_op.alter_column("bundle_name", nullable=False, existing_type=StringID()) + if dialect_name == "mysql": + op.execute( + text(""" + INSERT IGNORE INTO dag_bundle (name) VALUES + ('example_dags'), + ('dags-folder'); + """) + ) - with op.batch_alter_table("dag_bundle", schema=None) as batch_op: - batch_op.alter_column("name", nullable=False, existing_type=StringID()) + if dialect_name == "sqlite": + op.execute( + text(""" + INSERT OR IGNORE INTO dag_bundle (name) VALUES + ('example_dags'), + ('dags-folder'); + """) + ) - with op.batch_alter_table("dag", schema=None) as batch_op: - batch_op.create_foreign_key( - batch_op.f("dag_bundle_name_fkey"), "dag_bundle", ["bundle_name"], ["name"] - ) + conn = op.get_bind() + with ignore_sqlite_value_error(), op.batch_alter_table("dag", schema=None) as batch_op: + conn.execute( + text( + """ + UPDATE dag + SET bundle_name = + CASE + WHEN fileloc LIKE '%/airflow/example_dags/%' THEN 'example_dags' + ELSE 'dags-folder' + END + WHERE bundle_name IS NULL + """ + ) + ) + # drop the foreign key temporarily and recreate it once both columns are changed + batch_op.drop_constraint(batch_op.f("dag_bundle_name_fkey"), type_="foreignkey") + batch_op.alter_column("bundle_name", nullable=False, existing_type=StringID()) - if dialect_name == "sqlite": - op.execute(text("PRAGMA foreign_keys=ON")) + with op.batch_alter_table("dag_bundle", schema=None) as batch_op: + batch_op.alter_column("name", nullable=False, existing_type=StringID()) + + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.create_foreign_key( + batch_op.f("dag_bundle_name_fkey"), "dag_bundle", ["bundle_name"], ["name"] + ) def downgrade(): """Make bundle_name nullable.""" - import contextlib - - dialect_name = op.get_bind().dialect.name - exitstack = contextlib.ExitStack() - - if dialect_name == "sqlite": - # SQLite requires foreign key constraints to be disabled during batch operations - conn = op.get_bind() - conn.execute(text("PRAGMA foreign_keys=OFF")) - exitstack.callback(conn.execute, text("PRAGMA foreign_keys=ON")) - - with exitstack: + with disable_sqlite_fkeys(op): with op.batch_alter_table("dag", schema=None) as batch_op: batch_op.drop_constraint(batch_op.f("dag_bundle_name_fkey"), type_="foreignkey") batch_op.alter_column("bundle_name", nullable=True, existing_type=StringID()) diff --git a/airflow-core/src/airflow/migrations/versions/0084_3_1_0_add_last_parse_duration_to_dag_model.py b/airflow-core/src/airflow/migrations/versions/0084_3_1_0_add_last_parse_duration_to_dag_model.py index 1a922d97deed5..61731e75264b3 100644 --- a/airflow-core/src/airflow/migrations/versions/0084_3_1_0_add_last_parse_duration_to_dag_model.py +++ b/airflow-core/src/airflow/migrations/versions/0084_3_1_0_add_last_parse_duration_to_dag_model.py @@ -29,7 +29,8 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy import text + +from airflow.migrations.utils import disable_sqlite_fkeys # revision identifiers, used by Alembic. revision = "eaf332f43c7c" @@ -47,15 +48,6 @@ def upgrade(): def downgrade(): """Unapply add last_parse_duration to dag model.""" - conn = op.get_bind() - dialect_name = conn.dialect.name - - if dialect_name == "sqlite": - # SQLite requires foreign key constraints to be disabled during batch operations - conn.execute(text("PRAGMA foreign_keys=OFF")) - with op.batch_alter_table("dag", schema=None) as batch_op: - batch_op.drop_column("last_parse_duration") - conn.execute(text("PRAGMA foreign_keys=ON")) - else: + with disable_sqlite_fkeys(op): with op.batch_alter_table("dag", schema=None) as batch_op: batch_op.drop_column("last_parse_duration") diff --git a/airflow-core/src/airflow/migrations/versions/0101_3_2_0_ui_improvements_for_deadlines.py b/airflow-core/src/airflow/migrations/versions/0101_3_2_0_ui_improvements_for_deadlines.py index 3e1a5c5d54f8b..704cb8277ab8e 100644 --- a/airflow-core/src/airflow/migrations/versions/0101_3_2_0_ui_improvements_for_deadlines.py +++ b/airflow-core/src/airflow/migrations/versions/0101_3_2_0_ui_improvements_for_deadlines.py @@ -45,6 +45,7 @@ from airflow._shared.timezones import timezone from airflow.configuration import conf +from airflow.migrations.utils import disable_sqlite_fkeys from airflow.serialization.enums import Encoding from airflow.utils.hashlib_wrapper import md5 from airflow.utils.sqlalchemy import UtcDateTime @@ -171,64 +172,58 @@ def upgrade() -> None: # user-provided DeadlineDefinition, and the actual instance of a Definition is (still) the Deadline. # This feels more intuitive than DeadlineAlert defining the Deadline. - op.create_table( - "deadline_alert", - sa.Column("id", sa.Uuid(), default=uuid6.uuid7), - sa.Column("created_at", UtcDateTime, nullable=False), - sa.Column("serialized_dag_id", sa.Uuid(), nullable=False), - sa.Column("name", sa.String(250), nullable=True), - sa.Column("description", sa.Text(), nullable=True), - sa.Column("reference", sa.JSON(), nullable=False), - sa.Column("interval", sa.Float(), nullable=False), - sa.Column("callback_def", sa.JSON(), nullable=False), - sa.PrimaryKeyConstraint("id", name=op.f("deadline_alert_pkey")), - ) - - conn = op.get_bind() - dialect_name = conn.dialect.name - - if dialect_name == "sqlite": - conn.execute(sa.text("PRAGMA foreign_keys=OFF")) - - with op.batch_alter_table("deadline", schema=None) as batch_op: - batch_op.add_column(sa.Column("deadline_alert_id", sa.Uuid(), nullable=True)) - batch_op.add_column(sa.Column("created_at", UtcDateTime, nullable=True)) - batch_op.add_column(sa.Column("last_updated_at", UtcDateTime, nullable=True)) - batch_op.create_foreign_key( - batch_op.f("deadline_deadline_alert_id_fkey"), + with disable_sqlite_fkeys(op): + op.create_table( "deadline_alert", - ["deadline_alert_id"], - ["id"], - ondelete="SET NULL", + sa.Column("id", sa.Uuid(), default=uuid6.uuid7), + sa.Column("created_at", UtcDateTime, nullable=False), + sa.Column("serialized_dag_id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(250), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("reference", sa.JSON(), nullable=False), + sa.Column("interval", sa.Float(), nullable=False), + sa.Column("callback_def", sa.JSON(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("deadline_alert_pkey")), ) - # For migration/backcompat purposes if no timestamp is there from the migration, use now() - # then lock the columns down so all new entries require the timestamps to be provided. - now = timezone.utcnow() - conn.execute( - sa.text(""" - UPDATE deadline - SET created_at = :now, last_updated_at = :now - WHERE created_at IS NULL OR last_updated_at IS NULL - """), - {"now": now}, - ) + conn = op.get_bind() + + with op.batch_alter_table("deadline", schema=None) as batch_op: + batch_op.add_column(sa.Column("deadline_alert_id", sa.Uuid(), nullable=True)) + batch_op.add_column(sa.Column("created_at", UtcDateTime, nullable=True)) + batch_op.add_column(sa.Column("last_updated_at", UtcDateTime, nullable=True)) + batch_op.create_foreign_key( + batch_op.f("deadline_deadline_alert_id_fkey"), + "deadline_alert", + ["deadline_alert_id"], + ["id"], + ondelete="SET NULL", + ) - with op.batch_alter_table("deadline", schema=None) as batch_op: - batch_op.alter_column("created_at", existing_type=UtcDateTime, nullable=False) - batch_op.alter_column("last_updated_at", existing_type=UtcDateTime, nullable=False) - - with op.batch_alter_table("deadline_alert", schema=None) as batch_op: - batch_op.create_foreign_key( - batch_op.f("deadline_alert_serialized_dag_id_fkey"), - "serialized_dag", - ["serialized_dag_id"], - ["id"], - ondelete="CASCADE", + # For migration/backcompat purposes if no timestamp is there from the migration, use now() + # then lock the columns down so all new entries require the timestamps to be provided. + now = timezone.utcnow() + conn.execute( + sa.text(""" + UPDATE deadline + SET created_at = :now, last_updated_at = :now + WHERE created_at IS NULL OR last_updated_at IS NULL + """), + {"now": now}, ) - if dialect_name == "sqlite": - conn.execute(sa.text("PRAGMA foreign_keys=ON")) + with op.batch_alter_table("deadline", schema=None) as batch_op: + batch_op.alter_column("created_at", existing_type=UtcDateTime, nullable=False) + batch_op.alter_column("last_updated_at", existing_type=UtcDateTime, nullable=False) + + with op.batch_alter_table("deadline_alert", schema=None) as batch_op: + batch_op.create_foreign_key( + batch_op.f("deadline_alert_serialized_dag_id_fkey"), + "serialized_dag", + ["serialized_dag_id"], + ["id"], + ondelete="CASCADE", + ) migrate_existing_deadline_alert_data_from_serialized_dag() @@ -237,23 +232,15 @@ def downgrade() -> None: """Remove changes that were added to enable adding DeadlineAlerts to the UI.""" migrate_deadline_alert_data_back_to_serialized_dag() - conn = op.get_bind() - dialect_name = conn.dialect.name - - if dialect_name == "sqlite": - conn.execute(sa.text("PRAGMA foreign_keys=OFF")) - - with op.batch_alter_table("deadline", schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f("deadline_deadline_alert_id_fkey"), type_="foreignkey") - batch_op.drop_column("deadline_alert_id") - batch_op.drop_column("last_updated_at") - batch_op.drop_column("created_at") - - with op.batch_alter_table("deadline_alert", schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f("deadline_alert_serialized_dag_id_fkey"), type_="foreignkey") + with disable_sqlite_fkeys(op): + with op.batch_alter_table("deadline", schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f("deadline_deadline_alert_id_fkey"), type_="foreignkey") + batch_op.drop_column("deadline_alert_id") + batch_op.drop_column("last_updated_at") + batch_op.drop_column("created_at") - if dialect_name == "sqlite": - conn.execute(sa.text("PRAGMA foreign_keys=ON")) + with op.batch_alter_table("deadline_alert", schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f("deadline_alert_serialized_dag_id_fkey"), type_="foreignkey") op.drop_table("deadline_alert") diff --git a/airflow-core/tests/unit/migrations/test_sqlite_migration_utils.py b/airflow-core/tests/unit/migrations/test_sqlite_migration_utils.py new file mode 100644 index 0000000000000..d4f9b375c3c5f --- /dev/null +++ b/airflow-core/tests/unit/migrations/test_sqlite_migration_utils.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.migrations.utils import disable_sqlite_fkeys + + +class _Dialect: + def __init__(self, name: str) -> None: + self.name = name + + +class _Bind: + def __init__(self, dialect_name: str) -> None: + self.dialect = _Dialect(name=dialect_name) + + +class _Context: + def __init__(self, dialect_name: str) -> None: + self.dialect = _Dialect(name=dialect_name) + + +class _FakeOp: + def __init__(self, dialect_name: str) -> None: + self._bind = _Bind(dialect_name=dialect_name) + self.executed: list[str] = [] + + def get_bind(self) -> _Bind: + return self._bind + + def execute(self, statement: str) -> None: + self.executed.append(statement) + + +class _OfflineFakeOp: + """Simulates Alembic offline mode where get_bind() returns None.""" + + def __init__(self, dialect_name: str) -> None: + self._context = _Context(dialect_name=dialect_name) + self.executed: list[str] = [] + + def get_bind(self) -> None: + return None + + def get_context(self) -> _Context: + return self._context + + def execute(self, statement: str) -> None: + self.executed.append(statement) + + +def test_disable_sqlite_fkeys_restores_pragma_on_success() -> None: + op = _FakeOp(dialect_name="sqlite") + + with disable_sqlite_fkeys(op) as yielded_op: + assert yielded_op is op + + assert op.executed == ["PRAGMA foreign_keys=off", "PRAGMA foreign_keys=on"] + + +def test_disable_sqlite_fkeys_restores_pragma_on_exception() -> None: + op = _FakeOp(dialect_name="sqlite") + + with pytest.raises(RuntimeError, match="boom"): + with disable_sqlite_fkeys(op): + raise RuntimeError("boom") + + assert op.executed == ["PRAGMA foreign_keys=off", "PRAGMA foreign_keys=on"] + + +def test_disable_sqlite_fkeys_noop_for_non_sqlite() -> None: + op = _FakeOp(dialect_name="postgresql") + + with disable_sqlite_fkeys(op) as yielded_op: + assert yielded_op is op + + assert op.executed == [] + + +def test_disable_sqlite_fkeys_offline_mode_sqlite() -> None: + """Alembic offline mode: get_bind() returns None; dialect comes from the migration context.""" + op = _OfflineFakeOp(dialect_name="sqlite") + + with disable_sqlite_fkeys(op) as yielded_op: + assert yielded_op is op + + assert op.executed == ["PRAGMA foreign_keys=off", "PRAGMA foreign_keys=on"] + + +def test_disable_sqlite_fkeys_offline_mode_non_sqlite() -> None: + """Alembic offline mode: get_bind() returns None; non-SQLite dialect is a noop.""" + op = _OfflineFakeOp(dialect_name="postgresql") + + with disable_sqlite_fkeys(op) as yielded_op: + assert yielded_op is op + + assert op.executed == []