-
Notifications
You must be signed in to change notification settings - Fork 17k
AIP-103: Adding periodic task state garbage collection and retention support #66463
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
|
|
||
| log = logging.getLogger(__name__) | ||
|
|
||
| # Other state operations (list, get, delete per key) will be added here once the | ||
| # Core API endpoints (PR 6) land. For now, inspection is available via the REST | ||
| # API and the Task Instance detail panel in the UI. | ||
|
|
||
|
|
||
def cleanup(args) -> None:
    """
    Remove expired task state rows via the configured state backend.

    With ``--dry-run``, print the rows that would be deleted instead of
    deleting them. Preview is only possible for the built-in metastore
    backend; custom backends expose no row-inspection hook.
    """
    from airflow.state import get_state_backend
    from airflow.state.metastore import MetastoreStateBackend

    backend = get_state_backend()

    def _print_rows(rows) -> None:
        # Shared formatter for both the stale and the expired listings —
        # previously duplicated verbatim in each branch.
        for dag_id, run_id, task_id, map_index, key in rows:
            print(
                f"    DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, "
                f"map_index {map_index!r}, key {key!r}"
            )

    if args.dry_run:
        if not isinstance(backend, MetastoreStateBackend):
            # Custom backends have no dry-run hook; nothing we can show.
            print("Custom backend configured — cannot preview rows.")
            return
        # NOTE(review): this reaches into a private helper of the backend;
        # a public preview API is expected once the Core API endpoints land.
        summary = backend._dry_run_summary()
        stale, expired = summary["stale"], summary["expired"]
        if not stale and not expired:
            print("Nothing to delete.")
            return
        print(f"Would delete {len(stale) + len(expired)} task state row(s):\n")
        if stale:
            print(f"  Older than retention period ({len(stale)}):")
            _print_rows(stale)
        if expired:
            print(f"\n  Per-key expiry reached ({len(expired)}):")
            _print_rows(expired)
        return

    log.info("Running state store cleanup")
    backend.cleanup()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,11 @@ def upgrade(): | |
| sa.Column("run_id", StringID(), nullable=False), | ||
| sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False), | ||
| sa.Column("updated_at", UtcDateTime(), nullable=False), | ||
| # Optional early-expiry override. When set, GC deletes this row when expires_at < now() | ||
| # even if updated_at is recent. NULL means no early expiry — the row is still cleaned | ||
| # up by the global updated_at + default_retention_days check. Populated via | ||
|
Comment on lines
+69
to
+70
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is confusing IMO — an `expires_at` that is sometimes NULL makes the cleanup predicate harder to reason about. We can pre-compute the expires_at value at update time by reading the default_retention config then (i.e. cleanup becomes a simpler `SELECT where expires_at < now()`). This possibly also removes the need for an index on updated_at. |
||
| # task_state.set(retention_days=N) for keys that should expire sooner than the default. | ||
| sa.Column("expires_at", UtcDateTime(), nullable=True), | ||
| sa.ForeignKeyConstraint( | ||
| ["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", ondelete="CASCADE" | ||
| ), | ||
|
|
@@ -74,11 +79,15 @@ def upgrade(): | |
| batch_op.create_index( | ||
| "idx_task_state_lookup", ["dag_id", "run_id", "task_id", "map_index"], unique=False | ||
| ) | ||
| batch_op.create_index("idx_task_state_updated_at", ["updated_at"], unique=False) | ||
| batch_op.create_index("idx_task_state_expires_at", ["expires_at"], unique=False) | ||
|
|
||
|
|
||
def downgrade():
    """Unapply add task_state and asset_state tables."""
    # Remove the secondary indexes first, then the table itself.
    with op.batch_alter_table("task_state", schema=None) as batch:
        for index_name in (
            "idx_task_state_expires_at",
            "idx_task_state_updated_at",
            "idx_task_state_lookup",
        ):
            batch.drop_index(index_name)

    op.drop_table("task_state")
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,17 +17,21 @@ | |
| # under the License. | ||
| from __future__ import annotations | ||
|
|
||
| from datetime import timedelta | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import structlog | ||
| from sqlalchemy import delete, select | ||
| from sqlalchemy.sql.expression import tuple_ | ||
|
|
||
| from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope | ||
| from airflow._shared.timezones import timezone | ||
| from airflow.configuration import conf | ||
| from airflow.models.asset_state import AssetStateModel | ||
| from airflow.models.dagrun import DagRun | ||
| from airflow.models.task_state import TaskStateModel | ||
| from airflow.typing_compat import assert_never | ||
| from airflow.utils.session import NEW_SESSION, create_session_async, provide_session | ||
| from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session | ||
| from airflow.utils.sqlalchemy import get_dialect_name | ||
|
|
||
| if TYPE_CHECKING: | ||
|
|
@@ -38,6 +42,9 @@ | |
| from sqlalchemy.orm import Session | ||
|
|
||
|
|
||
| log = structlog.get_logger(__name__) | ||
|
|
||
|
|
||
| def _build_upsert_stmt( | ||
| dialect: str | None, | ||
| model: type, | ||
|
|
@@ -252,6 +259,84 @@ def _clear_asset_state(self, scope: AssetScope, *, session: Session) -> None: | |
| ) | ||
| ) | ||
|
|
||
| def cleanup(self) -> None: | ||
| """ | ||
| Remove expired task state rows. | ||
|
|
||
| Reads ``[state_store] default_retention_days`` and ``[state_store] state_cleanup_batch_size`` | ||
| from config. Each pass runs in its own transaction so partial progress is committed even if a | ||
| later pass fails. Each pass is batched to avoid long-running locks on the table. | ||
|
|
||
| Two passes: | ||
| a. Rows where updated_at < now() - default_retention_days (global retention) | ||
| b. Rows where expires_at < now() (per-key early expiry set by the operator) | ||
| """ | ||
| retention_days = conf.getint("state_store", "default_retention_days") | ||
| batch_size = conf.getint("state_store", "state_cleanup_batch_size") | ||
| now = timezone.utcnow() | ||
| older_than = now - timedelta(days=retention_days) if retention_days > 0 else None | ||
|
|
||
| pk_cols = ( | ||
| TaskStateModel.dag_run_id, | ||
| TaskStateModel.task_id, | ||
| TaskStateModel.map_index, | ||
| TaskStateModel.key, | ||
| ) | ||
|
Comment on lines
+279
to
+284
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given how this is used, it might be time to add a single column |
||
|
|
||
| def _delete_batched(where_clause) -> int: | ||
| total = 0 | ||
| while True: | ||
| with create_session() as session: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this should be a new session object each time around the loop, but instead one session object that is explicitly |
||
| pk_query = select(*pk_cols).where(where_clause) | ||
| if batch_size > 0: | ||
| pk_query = pk_query.limit(batch_size) | ||
| ids = session.execute(pk_query).all() | ||
| if not ids: | ||
| break | ||
| session.execute(delete(TaskStateModel).where(tuple_(*pk_cols).in_(ids))) | ||
| total += len(ids) | ||
| if batch_size <= 0 or len(ids) < batch_size: | ||
| break | ||
| return total | ||
|
|
||
| if older_than: | ||
| deleted = _delete_batched(TaskStateModel.updated_at < older_than) | ||
| log.info("Deleted stale task_state rows", rows_deleted=deleted, older_than=older_than) | ||
|
|
||
| deleted = _delete_batched((TaskStateModel.expires_at.isnot(None)) & (TaskStateModel.expires_at < now)) | ||
| log.info("Deleted expired task_state rows", rows_deleted=deleted) | ||
|
|
||
| def _dry_run_summary(self) -> dict[str, list]: | ||
| """ | ||
| Return rows that would be deleted by cleanup(), without deleting anything. | ||
|
|
||
| Returns a dict with keys 'stale' and 'expired', each containing a list of | ||
| (dag_id, run_id, task_id, map_index, key) tuples. | ||
| """ | ||
| retention_days = conf.getint("state_store", "default_retention_days") | ||
| now = timezone.utcnow() | ||
| older_than = now - timedelta(days=retention_days) if retention_days > 0 else None | ||
|
|
||
| cols = ( | ||
| TaskStateModel.dag_id, | ||
| TaskStateModel.run_id, | ||
| TaskStateModel.task_id, | ||
| TaskStateModel.map_index, | ||
| TaskStateModel.key, | ||
| ) | ||
|
|
||
| with create_session() as session: | ||
| stale = ( | ||
| session.execute(select(*cols).where(TaskStateModel.updated_at < older_than)).all() | ||
| if older_than | ||
| else [] | ||
| ) | ||
| expired = session.execute( | ||
| select(*cols).where(TaskStateModel.expires_at.isnot(None), TaskStateModel.expires_at < now) | ||
| ).all() | ||
|
|
||
| return {"stale": list(stale), "expired": list(expired)} | ||
|
|
||
|
Comment on lines
+277
to
+339
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's a valid assumption that users might produce a large amount of state records here between the I just double checked the concern with Claude: Yes, this is a real problem. Several compounding issues:
The two predicates the cleanup filters on are not indexed:
So both DELETE WHERE updated_at < cutoff and DELETE WHERE expires_at < now()
Compare to airflow db cleanup (utils/db_cleanup.py:217), which deletes in
with create_session() as session: opens one session; each session.execute()
_cleanup_expired_task_state is registered via call_regular_interval, which is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, good catches.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handled it in: cdc4237 |
||
| async def _aget_task_state(self, scope: TaskScope, key: str, *, session: AsyncSession) -> str | None: | ||
| row = await session.scalar( | ||
| select(TaskStateModel).where( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: