diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py
index a09851e5e24ff..d4d4611243bb0 100644
--- a/airflow-core/src/airflow/cli/cli_config.py
+++ b/airflow-core/src/airflow/cli/cli_config.py
@@ -1519,6 +1519,19 @@ class GroupCommand(NamedTuple):
         args=(ARG_OUTPUT, ARG_VERBOSE),
     ),
 )
+STATE_STORE_COMMANDS = (
+    ActionCommand(
+        name="cleanup",
+        help="Remove expired stored state via the configured state backend",
+        description=(
+            "Reads [state_store] default_retention_days from config and deletes task_state rows "
+            "older than the configured threshold. Use --dry-run to preview without deleting."
+        ),
+        func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup"),
+        args=(ARG_DB_DRY_RUN, ARG_VERBOSE),
+    ),
+)
+
 DB_COMMANDS = (
     ActionCommand(
         name="check-migrations",
@@ -2102,6 +2115,11 @@ class GroupCommand(NamedTuple):
         help="Display providers",
         subcommands=PROVIDERS_COMMANDS,
     ),
+    GroupCommand(
+        name="state-store",
+        help="Manage task and asset state storage",
+        subcommands=STATE_STORE_COMMANDS,
+    ),
     ActionCommand(
         name="rotate-fernet-key",
         func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"),
diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py b/airflow-core/src/airflow/cli/commands/state_store_command.py
new file mode 100644
index 0000000000000..6aa5a83c200cd
--- /dev/null
+++ b/airflow-core/src/airflow/cli/commands/state_store_command.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+
+log = logging.getLogger(__name__)
+
+# Other state operations (list, get, delete per key) will be added here in the future.
+
+
+def cleanup(args) -> None:
+    """Remove expired task state rows via the configured state backend."""
+    from airflow.state import get_state_backend
+    from airflow.state.metastore import MetastoreStateBackend
+
+    backend = get_state_backend()
+
+    if args.dry_run:
+        if isinstance(backend, MetastoreStateBackend):
+            summary = backend._dry_run_summary()
+            expired = summary["expired"]
+            if not expired:
+                print("Nothing to delete.")
+                return
+            print(f"Would delete {len(expired)} task state row(s):\n")
+            for dag_id, run_id, task_id, map_index, key in expired:
+                print(
+                    f"  DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}"
+                )
+        else:
+            print("Custom backend configured — cannot preview rows.")
+        return
+
+    log.info("Running state store cleanup")
+    backend.cleanup()
diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml
index 8a060007978ee..3f1b80fbf1d17 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -3025,6 +3025,24 @@ state_store:
       type: string
       example: "mypackage.state.CustomStateBackend"
       default: "airflow.state.metastore.MetastoreStateBackend"
+    default_retention_days:
+      description: |
+        Number of days to retain task_state rows after their last update.
+        Rows older than this are removed by the scheduler's periodic cleanup.
+        This config does not affect asset_state rows.
+        Set to 0 to disable time-based cleanup entirely.
+      version_added: 3.3.0
+      type: integer
+      example: "7"
+      default: "30"
+    state_cleanup_batch_size:
+      description: |
+        Number of rows deleted per batch during cleanup. Defaults to 0 (no batching).
+        Tune this on deployments with large task_state tables to keep individual delete transactions small.
+      version_added: 3.3.0
+      type: integer
+      example: "10000"
+      default: "0"
 profiling:
   description: |
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 3eed95a8bb030..eeabfc1427678 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -33,7 +33,22 @@
 from itertools import groupby
 from typing import TYPE_CHECKING, Any, cast
 
-from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_, select, text, tuple_, update
+from sqlalchemy import (
+    CTE,
+    and_,
+    case,
+    delete,
+    delete as _delete,
+    exists,
+    func,
+    inspect,
+    or_,
+    select,
+    select as _select,
+    text,
+    tuple_,
+    update,
+)
 from sqlalchemy.exc import DBAPIError, OperationalError
 from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload
 from sqlalchemy.sql import expression
@@ -58,10 +73,12 @@
 from airflow.models import Deadline, Log
 from airflow.models.asset import (
     AssetActive,
+    AssetActive as _AssetActive,
     AssetAliasModel,
     AssetDagRunQueue,
     AssetEvent,
     AssetModel,
+    AssetModel as _AssetModel,
     AssetPartitionDagRun,
     AssetWatcherModel,
     DagScheduleAssetAliasReference,
@@ -70,6 +87,7 @@
     TaskInletAssetReference,
     TaskOutletAssetReference,
 )
+from airflow.models.asset_state import AssetStateModel
 from airflow.models.backfill import Backfill, BackfillDagRun
 from airflow.models.callback import Callback, CallbackType, ExecutorCallback
 from airflow.models.dag import DagModel
@@ -3092,6 +3110,7 @@ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:
 
         self._orphan_unreferenced_assets(orphan_query, session=session)
         self._activate_referenced_assets(activate_query, session=session)
+        self._cleanup_orphaned_asset_state(session=session)
 
     @staticmethod
     def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> None:
@@ -3200,6 +3219,21 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
                 session.add(warning)
                 existing_warned_dag_ids.add(warning.dag_id)
 
+    @staticmethod
+    def _cleanup_orphaned_asset_state(*, session: Session) -> None:
+        """
+        Delete asset_state rows for assets no longer active in any DAG.
+
+        When _orphan_unreferenced_assets removes an asset from asset_active, its
+        asset_state rows become unreachable — no task can write to them anymore.
+        This runs in the same pass as asset orphanage to keep the table clean.
+ """ + active_asset_ids = _select(_AssetModel.id).join( + _AssetActive, + (_AssetActive.name == _AssetModel.name) & (_AssetActive.uri == _AssetModel.uri), + ) + session.execute(_delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids))) + def _executor_to_workloads( self, workloads: Iterable[SchedulerWorkload], diff --git a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py index 7f852d05c6ca6..e64f80a05b119 100644 --- a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py +++ b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py @@ -57,6 +57,7 @@ def upgrade(): ) op.create_table( "task_state", + sa.Column("id", sa.Integer(), nullable=False, autoincrement=True), sa.Column("dag_run_id", sa.Integer(), nullable=False), sa.Column("task_id", StringID(), nullable=False), sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False), @@ -65,20 +66,24 @@ def upgrade(): sa.Column("run_id", StringID(), nullable=False), sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False), sa.Column("updated_at", UtcDateTime(), nullable=False), + sa.Column("expires_at", UtcDateTime(), nullable=True), sa.ForeignKeyConstraint( ["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", ondelete="CASCADE" ), - sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"), + sa.PrimaryKeyConstraint("id", name="task_state_pkey"), + sa.UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"), ) with op.batch_alter_table("task_state", schema=None) as batch_op: batch_op.create_index( "idx_task_state_lookup", ["dag_id", "run_id", "task_id", "map_index"], unique=False ) + batch_op.create_index("idx_task_state_expires_at", ["expires_at"], unique=False) def downgrade(): """Unapply add task_state and asset_state tables.""" with op.batch_alter_table("task_state", schema=None) as batch_op: + batch_op.drop_index("idx_task_state_expires_at") batch_op.drop_index("idx_task_state_lookup") op.drop_table("task_state") diff --git a/airflow-core/src/airflow/models/task_state.py b/airflow-core/src/airflow/models/task_state.py index dbc17e3b06950..72a7624eddd6e 100644 --- a/airflow-core/src/airflow/models/task_state.py +++ b/airflow-core/src/airflow/models/task_state.py @@ -19,7 +19,7 @@ from datetime import datetime -from sqlalchemy import ForeignKeyConstraint, Index, Integer, PrimaryKeyConstraint, String, Text +from sqlalchemy import ForeignKeyConstraint, Index, Integer, String, Text, UniqueConstraint from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.orm import Mapped, mapped_column @@ -39,19 +39,27 @@ class TaskStateModel(Base): __tablename__ = "task_state" - dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False, primary_key=True) - task_id: Mapped[str] = mapped_column(StringID(), nullable=False, primary_key=True) - map_index: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False, server_default="-1") - key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + + dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False) + task_id: Mapped[str] = mapped_column(StringID(), nullable=False) + map_index: Mapped[int] = 
+    key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False)
 
     dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
     run_id: Mapped[str] = mapped_column(StringID(), nullable=False)
     value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT, "mysql"), nullable=False)
     updated_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    # Expiry timestamp for this row, computed at write time from
+    # [state_store] default_retention_days. Garbage collection deletes the row
+    # once expires_at < now(), even if updated_at is recent. NULL (written when
+    # default_retention_days is 0) means the row never expires and is left
+    # alone by cleanup.
+    expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True)
 
     __table_args__ = (
-        PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"),
+        UniqueConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_uq"),
         ForeignKeyConstraint(
             ["dag_run_id"],
             ["dag_run.id"],
@@ -59,4 +67,5 @@ class TaskStateModel(Base):
             ondelete="CASCADE",
         ),
         Index("idx_task_state_lookup", "dag_id", "run_id", "task_id", "map_index"),
+        Index("idx_task_state_expires_at", "expires_at"),
     )
diff --git a/airflow-core/src/airflow/state/metastore.py b/airflow-core/src/airflow/state/metastore.py
index 3382dad81fc65..cdf595d2be634 100644
--- a/airflow-core/src/airflow/state/metastore.py
+++ b/airflow-core/src/airflow/state/metastore.py
@@ -17,17 +17,20 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import datetime, timedelta
 from typing import TYPE_CHECKING
 
+import structlog
 from sqlalchemy import delete, select
 
 from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope
 from airflow._shared.timezones import timezone
+from airflow.configuration import conf
 from airflow.models.asset_state import AssetStateModel
 from airflow.models.dagrun import DagRun
 from airflow.models.task_state import TaskStateModel
 from airflow.typing_compat import assert_never
-from airflow.utils.session import NEW_SESSION, create_session_async, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session
 from airflow.utils.sqlalchemy import get_dialect_name
 
 if TYPE_CHECKING:
@@ -38,6 +41,21 @@
     from sqlalchemy.orm import Session
 
 
+log = structlog.get_logger(__name__)
+
+
+def _compute_expires_at(now: datetime) -> datetime | None:
+    """
+    Return the expiry timestamp for a new task state row based on config.
+
+    Returns None if default_retention_days is 0 or negative (rows never expire).
+ """ + retention_days = conf.getint("state_store", "default_retention_days") + if retention_days <= 0: + return None + return now + timedelta(days=retention_days) + + def _build_upsert_stmt( dialect: str | None, model: type, @@ -176,6 +194,7 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se if dag_run_id is None: raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}") now = timezone.utcnow() + expires_at = _compute_expires_at(now) values = dict( dag_run_id=dag_run_id, dag_id=scope.dag_id, @@ -185,13 +204,14 @@ def _set_task_state(self, scope: TaskScope, key: str, value: str, *, session: Se key=key, value=value, updated_at=now, + expires_at=expires_at, ) stmt = _build_upsert_stmt( get_dialect_name(session), TaskStateModel, ["dag_run_id", "task_id", "map_index", "key"], values, - dict(value=value, updated_at=now), + dict(value=value, updated_at=now, expires_at=expires_at), ) session.execute(stmt) @@ -252,6 +272,51 @@ def _clear_asset_state(self, scope: AssetScope, *, session: Session) -> None: ) ) + def cleanup(self) -> None: + """ + Remove expired task state rows. + + ``expires_at`` is set at write time on every ``set()`` call, so cleanup is a single + ``WHERE expires_at < now()`` pass. Rows with ``expires_at=NULL`` (default_retention_days=0) + are never deleted. Batching is configurable via ``[state_store] state_cleanup_batch_size``. + """ + batch_size = conf.getint("state_store", "state_cleanup_batch_size") + now = timezone.utcnow() + + def _delete_batched(where_clause) -> int: + total = 0 + with create_session() as session: + while True: + id_query = select(TaskStateModel.id).where(where_clause) + if batch_size > 0: + id_query = id_query.limit(batch_size) + ids = session.scalars(id_query).all() + if not ids: + break + session.execute(delete(TaskStateModel).where(TaskStateModel.id.in_(ids))) + session.commit() + total += len(ids) + if batch_size <= 0 or len(ids) < batch_size: + break + return total + + deleted = _delete_batched(TaskStateModel.expires_at < now) + log.info("Deleted expired task_state rows", rows_deleted=deleted) + + def _dry_run_summary(self) -> dict[str, list]: + """Return rows that would be deleted by cleanup() without deleting anything.""" + now = timezone.utcnow() + cols = ( + TaskStateModel.dag_id, + TaskStateModel.run_id, + TaskStateModel.task_id, + TaskStateModel.map_index, + TaskStateModel.key, + ) + with create_session() as session: + expired = session.execute(select(*cols).where(TaskStateModel.expires_at < now)).all() + return {"expired": list(expired)} + async def _aget_task_state(self, scope: TaskScope, key: str, *, session: AsyncSession) -> str | None: row = await session.scalar( select(TaskStateModel).where( @@ -276,6 +341,7 @@ async def _aset_task_state( if dag_run_id is None: raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} run_id={scope.run_id!r}") now = timezone.utcnow() + expires_at = _compute_expires_at(now) values = dict( dag_run_id=dag_run_id, dag_id=scope.dag_id, @@ -285,6 +351,7 @@ async def _aset_task_state( key=key, value=value, updated_at=now, + expires_at=expires_at, ) # get_dialect_name expects a sync Session; sync_session is the underlying Session the async wrapper delegates to stmt = _build_upsert_stmt( @@ -292,7 +359,7 @@ async def _aset_task_state( TaskStateModel, ["dag_run_id", "task_id", "map_index", "key"], values, - dict(value=value, updated_at=now), + dict(value=value, updated_at=now, expires_at=expires_at), ) await session.execute(stmt) diff --git 
diff --git a/airflow-core/tests/unit/cli/commands/test_state_store_command.py b/airflow-core/tests/unit/cli/commands/test_state_store_command.py
new file mode 100644
index 0000000000000..a6ad669181156
--- /dev/null
+++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from argparse import Namespace
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+from airflow.cli.commands.state_store_command import cleanup
+from airflow.state.metastore import MetastoreStateBackend
+
+
+class TestStateStoreCleanupCommand:
+    def test_cleanup_calls_backend(self):
+        args = Namespace(dry_run=False, verbose=False)
+        with mock.patch("airflow.state.get_state_backend") as mock_get_backend:
+            mock_backend = MagicMock()
+            mock_get_backend.return_value = mock_backend
+
+            cleanup(args)
+
+            mock_backend.cleanup.assert_called_once_with()
+
+    def test_dry_run_does_not_call_backend(self, capsys):
+        args = Namespace(dry_run=True, verbose=False)
+        backend = MetastoreStateBackend()
+        with (
+            mock.patch("airflow.state.get_state_backend", return_value=backend),
+            patch.object(backend, "_dry_run_summary", return_value={"expired": []}),
+            patch.object(backend, "cleanup") as mock_cleanup,
+        ):
+            cleanup(args)
+
+        captured = capsys.readouterr()
+        assert "Nothing to delete" in captured.out
+        mock_cleanup.assert_not_called()
diff --git a/airflow-core/tests/unit/state/test_metastore.py b/airflow-core/tests/unit/state/test_metastore.py
index 98993d7133c41..2407f21d51bc4 100644
--- a/airflow-core/tests/unit/state/test_metastore.py
+++ b/airflow-core/tests/unit/state/test_metastore.py
@@ -17,13 +17,18 @@
 # under the License.
 from __future__ import annotations
 
+from contextlib import contextmanager
+from datetime import timedelta
 from typing import TYPE_CHECKING
+from unittest.mock import patch
 
 import pytest
-from sqlalchemy import select
+from sqlalchemy import Delete, select
 
 from airflow._shared.timezones import timezone
+from airflow.configuration import conf
 from airflow.models.asset import AssetModel
+from airflow.models.asset_state import AssetStateModel
 from airflow.models.dagrun import DagRun, DagRunType
 from airflow.models.task_state import TaskStateModel
 from airflow.state import AssetScope, TaskScope, resolve_state_backend
@@ -234,6 +239,112 @@ def test_clear_with_all_map_indices_flag_wipes_wide(
         assert backend.get(scope0, "job_id", session=session) is None
         assert backend.get(scope1, "job_id", session=session) is None
 
+    def test_set_populates_expires_at(
+        self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+    ):
+        """set() always populates expires_at so cleanup has a single pass."""
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        backend.set(scope, "job_id", "app_1234", session=session)
+        session.flush()
+
+        row = session.scalar(select(TaskStateModel).where(TaskStateModel.key == "job_id"))
+        assert row is not None
+        assert row.expires_at is not None
+
+    def test_cleanup_removes_expired_rows(
+        self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+    ):
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        backend.set(scope, "old_key", "old_value", session=session)
+        backend.set(scope, "new_key", "new_value", session=session)
+        session.flush()
+
+        # Backdate expires_at on old_key to simulate it having expired
+        old_row = session.scalar(
+            select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID, TaskStateModel.key == "old_key")
+        )
+        assert old_row is not None
+        old_row.expires_at = timezone.utcnow() - timedelta(hours=1)
+        session.flush()
+        session.commit()
+
+        backend.cleanup()
+
+        session.expire_all()
+        assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "old_key")) is None
+        assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "new_key")) is not None
+
+    def test_cleanup_removes_expires_at_rows(
+        self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+    ):
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        backend.set(scope, "short_lived", "value", session=session)
+        session.flush()
+
+        row = session.scalar(
+            select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID, TaskStateModel.key == "short_lived")
+        )
+        assert row is not None
+        row.expires_at = timezone.utcnow() - timedelta(hours=1)
+        session.flush()
+        session.commit()
+
+        backend.cleanup()
+
+        session.expire_all()
+
+        # cleaned up via expires_at, even though updated_at is recent
+        assert session.scalar(select(TaskStateModel).where(TaskStateModel.key == "short_lived")) is None
+
+    @conf_vars({("state_store", "state_cleanup_batch_size"): "2"})
+    def test_cleanup_batches_deletes(self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun):
+        """cleanup() issues one DELETE per batch, not one DELETE for all rows at once.
+
+        Verifying this is not straightforward because cleanup() creates its own internal session,
+        so we cannot simply inspect it from outside. What we do instead is:
+
+        1. Patch `create_session` in the metastore module with a thin wrapper (`tracking_cs`) that
+           yields the real session but replaces `session.execute` with a spy.
+        2. The spy checks whether each executed statement is a SQLAlchemy ``Delete`` object and
+           records it if so.
+        3. After cleanup() returns, we assert that exactly ceil(n_rows / batch_size) DELETE statements were recorded.
+        """
+        import airflow.state.metastore as metastore_mod
+
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        for key in ("k1", "k2", "k3", "k4", "k5"):
+            backend.set(scope, key, "v", session=session)
+        session.flush()
+
+        session.execute(
+            TaskStateModel.__table__.update().values(expires_at=timezone.utcnow() - timedelta(hours=1))
+        )
+        session.commit()
+
+        deletes = []
+        original_cs = metastore_mod.create_session
+
+        @contextmanager
+        def tracking_cs(*args, **kwargs):
+            with original_cs(*args, **kwargs) as s:
+                orig_execute = s.execute
+
+                def tracked(stmt, *a, **kw):
+                    if isinstance(stmt, Delete):
+                        deletes.append(stmt)
+                    return orig_execute(stmt, *a, **kw)
+
+                s.execute = tracked
+                yield s
+
+        with patch.object(metastore_mod, "create_session", side_effect=tracking_cs):
+            backend.cleanup()
+
+        session.expire_all()
+
+        # batch_size=2, 5 rows -> delete runs 3 times (2+2+1)
+        assert len(deletes) == 3
+
 
 class TestMetastoreStateBackendAssetScope:
     def test_get_returns_none_for_missing_key(
@@ -306,6 +417,19 @@ def test_different_assets_are_isolated(
 
         assert backend.get(scope2, "watermark", session=session) is None
 
+    def test_cleanup_does_not_touch_asset_state(
+        self, session: Session, backend: MetastoreStateBackend, asset: AssetModel
+    ):
+        scope = AssetScope(asset_id=asset.id)
+        backend.set(scope, "watermark", "2026-01-01", session=session)
+        session.flush()
+        session.commit()
+
+        backend.cleanup()
+
+        session.expire_all()
+        assert session.scalar(select(AssetStateModel).where(AssetStateModel.asset_id == asset.id)) is not None
+
 
 @pytest.mark.asyncio(loop_scope="class")
 class TestMetastoreStateBackendAsync:
@@ -380,6 +504,19 @@ async def test_aset_task_raises_for_missing_dag_run(self, backend: MetastoreStat
         await backend.aset(scope, "job_id", "app_async")
 
 
+class TestStateStoreConfig:
+    def test_defaults(self):
+        assert conf.getint("state_store", "default_retention_days") == 30
+        assert conf.getint("state_store", "state_cleanup_batch_size") == 0
+
+    @conf_vars(
+        {("state_store", "default_retention_days"): "7", ("state_store", "state_cleanup_batch_size"): "50"}
+    )
+    def test_overrides(self):
+        assert conf.getint("state_store", "default_retention_days") == 7
+        assert conf.getint("state_store", "state_cleanup_batch_size") == 50
+
+
 class TestResolveStateBackend:
     @conf_vars({("state_store", "backend"): "airflow.state.metastore.MetastoreStateBackend"})
     def test_resolve_returns_configured_backend(self):
diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py
index 463d9f378f315..1e03a381957d5 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -122,3 +122,14 @@ async def aclear(self, scope: StateScope, *, all_map_indices: bool = False) -> N
         scope are cleared. Pass ``all_map_indices=True`` to wipe state across every mapped
         instance of the task. For ``AssetScope`` the flag has no effect.
         """
+
+    def cleanup(self) -> None:
+        """
+        Remove expired and orphaned state records.
+
+        This is a no-op by default. Custom backends override this to implement their own
+        retention policy. The backend is responsible for reading any relevant config (e.g.
+        ``[state_store] default_retention_days``) and deciding what to delete.
+        The scheduler triggers cleanup via ``call_regular_interval`` for the default backend;
+        the ``airflow state-store cleanup`` CLI command invokes it directly on any backend.
+        """
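
A few runnable sketches of the behavior this change introduces follow. First, the retention arithmetic: the snippet below mirrors `_compute_expires_at` from the diff in plain stdlib Python rather than importing Airflow, including the guard that treats a zero or negative `default_retention_days` as "never expires".

    from __future__ import annotations

    from datetime import datetime, timedelta, timezone


    def compute_expires_at(now: datetime, retention_days: int) -> datetime | None:
        # Mirror of metastore._compute_expires_at: non-positive retention disables expiry.
        if retention_days <= 0:
            return None
        return now + timedelta(days=retention_days)


    now = datetime(2026, 1, 1, tzinfo=timezone.utc)
    assert compute_expires_at(now, 30) == datetime(2026, 1, 31, tzinfo=timezone.utc)
    assert compute_expires_at(now, 7) == datetime(2026, 1, 8, tzinfo=timezone.utc)
    assert compute_expires_at(now, 0) is None  # expires_at stays NULL; cleanup never touches the row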
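
Second, a consequence of writing `expires_at` in both the insert values and the conflict-update path of the upsert: every `set()` pushes a key's expiry forward, so state rewritten more often than the retention window never becomes eligible for cleanup. A worked example under the 30-day default:

    from datetime import datetime, timedelta, timezone

    retention = timedelta(days=30)
    # A key rewritten weekly; each upsert refreshes updated_at and expires_at together.
    writes = [datetime(2026, 1, 1, tzinfo=timezone.utc) + timedelta(weeks=i) for i in range(4)]
    expires_at = None
    for write_time in writes:
        expires_at = write_time + retention  # mirrors dict(value=..., updated_at=now, expires_at=...)
    # The last write lands on Jan 22, so the row survives until Feb 21.
    assert writes[-1] == datetime(2026, 1, 22, tzinfo=timezone.utc)
    assert expires_at == datetime(2026, 2, 21, tzinfo=timezone.utc)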
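
Third, the batched-delete loop in `cleanup()` is easier to reason about away from SQLAlchemy. The stand-in below uses a list in place of the task_state table but keeps the loop shape from the diff: select up to `batch_size` ids, delete them, stop on an empty or short batch. With five expired rows and a batch size of two it issues three deletes (2 + 2 + 1), which is exactly what `test_cleanup_batches_deletes` asserts.

    def delete_batched(rows: list[int], expired: set[int], batch_size: int) -> tuple[int, int]:
        """Return (rows_deleted, delete_statements_issued)."""
        deleted = statements = 0
        while True:
            batch = [r for r in rows if r in expired]
            if batch_size > 0:
                batch = batch[:batch_size]  # the LIMIT on the real id_query
            if not batch:
                break
            for r in batch:  # one DELETE ... WHERE id IN (...) per batch in the real code
                rows.remove(r)
            statements += 1
            deleted += len(batch)
            if batch_size <= 0 or len(batch) < batch_size:
                break  # a short batch means nothing is left to delete
        return deleted, statements


    rows = [1, 2, 3, 4, 5]
    assert delete_batched(rows, expired={1, 2, 3, 4, 5}, batch_size=2) == (5, 3)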
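
Finally, on the `cleanup()` hook itself: the base implementation is a no-op, and the CLI command calls `backend.cleanup()` for whatever backend is configured. A hypothetical override is sketched below; the `RedisStateBackend` name and its TTL strategy are illustrative only, not part of this change.

    from airflow._shared.state import BaseStateBackend


    class RedisStateBackend(BaseStateBackend):
        """Hypothetical backend whose store expires keys natively (e.g. per-key TTLs)."""

        def cleanup(self) -> None:
            # Nothing to scan or delete: this sketch assumes set() already attaches a TTL,
            # derived from [state_store] default_retention_days, to every key it writes,
            # so the store garbage-collects expired state on its own.
            pass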