18 changes: 18 additions & 0 deletions airflow-core/src/airflow/cli/cli_config.py
@@ -1519,6 +1519,19 @@ class GroupCommand(NamedTuple):
args=(ARG_OUTPUT, ARG_VERBOSE),
),
)
STATE_STORE_COMMANDS = (
ActionCommand(
name="cleanup",
help="Remove expired task state rows via the configured state backend",
[Review comment, Member] Nit. Suggested change:
- help="Remove expired task state rows via the configured state backend",
+ help="Remove expired stored state via the configured state backend",

description=(
"Reads [state_store] default_retention_days from config and deletes task_state rows "
"older than the configured threshold. Use --dry-run to preview without deleting."
),
func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup"),
args=(ARG_DB_DRY_RUN, ARG_VERBOSE),
),
)

DB_COMMANDS = (
ActionCommand(
name="check-migrations",
@@ -2102,6 +2115,11 @@ class GroupCommand(NamedTuple):
help="Display providers",
subcommands=PROVIDERS_COMMANDS,
),
GroupCommand(
name="state-store",
help="Manage task and asset state storage",
subcommands=STATE_STORE_COMMANDS,
),
ActionCommand(
name="rotate-fernet-key",
func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"),
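Both entries defer their imports through lazy_load_command. As a rough sketch of what that helper does (the real implementation lives in airflow.cli.cli_config and differs in detail, so treat this as illustrative only):

```python
# Illustrative sketch only -- not the actual airflow.cli.cli_config source.
from importlib import import_module


def lazy_load_command(import_path: str):
    """Return a callable that imports its target only when invoked."""
    module_path, _, name = import_path.rpartition(".")

    def command(*args, **kwargs):
        # The import happens at call time, keeping `airflow --help`
        # fast no matter how many commands are registered.
        func = getattr(import_module(module_path), name)
        return func(*args, **kwargs)

    return command
```

With this registration the new surface is `airflow state-store cleanup`, taking `--dry-run` and `--verbose`.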
61 changes: 61 additions & 0 deletions airflow-core/src/airflow/cli/commands/state_store_command.py
@@ -0,0 +1,61 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging

log = logging.getLogger(__name__)

# Other state operations (list, get, delete per key) will be added here once the
# Core API endpoints (PR 6) land. For now, inspection is available via the REST
# API and the Task Instance detail panel in the UI.


def cleanup(args) -> None:
"""Remove expired task state rows via the configured state backend."""
from airflow.state import get_state_backend
from airflow.state.metastore import MetastoreStateBackend

backend = get_state_backend()

if args.dry_run:
if isinstance(backend, MetastoreStateBackend):
summary = backend._dry_run_summary()
stale, expired = summary["stale"], summary["expired"]
total = len(stale) + len(expired)
if not total:
print("Nothing to delete.")
return
print(f"Would delete {total} task state row(s):\n")
if stale:
print(f" Older than retention period ({len(stale)}):")
for dag_id, run_id, task_id, map_index, key in stale:
print(
f" DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}"
)
if expired:
print(f"\n Per-key expiry reached ({len(expired)}):")
for dag_id, run_id, task_id, map_index, key in expired:
print(
f" DAG {dag_id!r}, run {run_id!r}, task {task_id!r}, map_index {map_index!r}, key {key!r}"
)
else:
print("Custom backend configured — cannot preview rows.")
return

log.info("Running state store cleanup")
backend.cleanup()
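As a quick sanity check of the command's contract, here is a minimal sketch of calling it directly. The CLI parser normally supplies `args`, but any object with a `dry_run` attribute satisfies the function; this is an illustration, not a test from the PR:

```python
# Hypothetical direct invocation of the new command function (not part of the PR).
from types import SimpleNamespace

from airflow.cli.commands.state_store_command import cleanup

cleanup(SimpleNamespace(dry_run=True))   # prints the rows cleanup would delete
cleanup(SimpleNamespace(dry_run=False))  # actually runs backend.cleanup()
```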
18 changes: 18 additions & 0 deletions airflow-core/src/airflow/config_templates/config.yml
@@ -3014,6 +3014,24 @@ state_store:
type: string
example: "mypackage.state.CustomStateBackend"
default: "airflow.state.metastore.MetastoreStateBackend"
default_retention_days:
description: |
Number of days to retain task_state rows after their last update.
Rows older than this are removed by the scheduler's periodic cleanup.
This config does not affect asset_state rows.
Set to 0 to disable time-based cleanup entirely.
version_added: 3.3.0
type: integer
example: "7"
default: "30"
state_cleanup_batch_size:
description: |
Number of rows deleted per batch during cleanup. Defaults to 0 (no batching).
Set a positive value on deployments with large task_state tables to keep each delete transaction short.
version_added: 3.3.0
type: integer
example: "10000"
default: "0"

profiling:
description: |
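For reference, a minimal sketch of how these two options are consumed, mirroring the conf.getint calls in metastore.py further down; the values in comments are the defaults declared above:

```python
from airflow.configuration import conf

retention_days = conf.getint("state_store", "default_retention_days")  # 30; 0 disables
batch_size = conf.getint("state_store", "state_cleanup_batch_size")    # 0 = no batching
```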
36 changes: 35 additions & 1 deletion airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -33,7 +33,22 @@
from itertools import groupby
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_, select, text, tuple_, update
from sqlalchemy import (
CTE,
and_,
case,
delete,
delete as _delete,
exists,
func,
inspect,
or_,
select,
select as _select,
text,
tuple_,
update,
)
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload
from sqlalchemy.sql import expression
@@ -58,10 +73,12 @@
from airflow.models import Deadline, Log
from airflow.models.asset import (
AssetActive,
AssetActive as _AssetActive,
AssetAliasModel,
AssetDagRunQueue,
AssetEvent,
AssetModel,
AssetModel as _AssetModel,
AssetPartitionDagRun,
AssetWatcherModel,
DagScheduleAssetAliasReference,
@@ -70,6 +87,7 @@
TaskInletAssetReference,
TaskOutletAssetReference,
)
from airflow.models.asset_state import AssetStateModel
from airflow.models.backfill import Backfill, BackfillDagRun
from airflow.models.callback import Callback, CallbackType, ExecutorCallback
from airflow.models.dag import DagModel
@@ -3080,6 +3098,7 @@ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:

self._orphan_unreferenced_assets(orphan_query, session=session)
self._activate_referenced_assets(activate_query, session=session)
self._cleanup_orphaned_asset_state(session=session)

@staticmethod
def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> None:
@@ -3188,6 +3207,21 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
session.add(warning)
existing_warned_dag_ids.add(warning.dag_id)

@staticmethod
def _cleanup_orphaned_asset_state(*, session: Session) -> None:
"""
Delete asset_state rows for assets no longer active in any DAG.

When _orphan_unreferenced_assets removes an asset from asset_active, its
asset_state rows become unreachable — no task can write to them anymore.
This runs in the same pass as asset orphanage to keep the table clean.
"""
active_asset_ids = _select(_AssetModel.id).join(
_AssetActive,
(_AssetActive.name == _AssetModel.name) & (_AssetActive.uri == _AssetModel.uri),
)
session.execute(_delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids)))

def _executor_to_workloads(
self,
workloads: Iterable[SchedulerWorkload],
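To make the shape of the statement `_cleanup_orphaned_asset_state` emits easier to see, here is a self-contained sketch using plain SQLAlchemy Core tables. The table and column names are assumed to mirror the models above; this is not the PR's code:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, delete, select

md = MetaData()
asset = Table(
    "asset", md,
    Column("id", Integer, primary_key=True),
    Column("name", String),
    Column("uri", String),
)
asset_active = Table("asset_active", md, Column("name", String), Column("uri", String))
asset_state = Table("asset_state", md, Column("asset_id", Integer))

# Assets still referenced by an active (name, uri) pair...
active_ids = select(asset.c.id).join(
    asset_active,
    (asset_active.c.name == asset.c.name) & (asset_active.c.uri == asset.c.uri),
)
# ...and a DELETE of every asset_state row whose asset is not among them.
stmt = delete(asset_state).where(asset_state.c.asset_id.not_in(active_ids))
print(stmt)  # DELETE FROM asset_state WHERE asset_state.asset_id NOT IN (SELECT ...)
```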
@@ -65,6 +65,11 @@ def upgrade():
sa.Column("run_id", StringID(), nullable=False),
sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), nullable=False),
sa.Column("updated_at", UtcDateTime(), nullable=False),
# Optional early-expiry override. When set, GC deletes this row when expires_at < now()
# even if updated_at is recent. NULL means no early expiry — the row is still cleaned
# up by the global updated_at + default_retention_days check. Populated via
# task_state.set(retention_days=N) for keys that should expire sooner than the default.
sa.Column("expires_at", UtcDateTime(), nullable=True),

[Review comment, Member, on lines +69 to +70]
This is confusing IMO -- an expires_at of None should mean it never expires.

We can pre-compute the expires_at value at update time by reading the default_retention config then (i.e. cleanup becomes a simpler `SELECT ... WHERE expires_at < now()`).

This possibly also removes the need for an index on updated_at.

sa.ForeignKeyConstraint(
["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey", ondelete="CASCADE"
),
@@ -74,11 +79,15 @@
batch_op.create_index(
"idx_task_state_lookup", ["dag_id", "run_id", "task_id", "map_index"], unique=False
)
batch_op.create_index("idx_task_state_updated_at", ["updated_at"], unique=False)
batch_op.create_index("idx_task_state_expires_at", ["expires_at"], unique=False)


def downgrade():
"""Unapply add task_state and asset_state tables."""
with op.batch_alter_table("task_state", schema=None) as batch_op:
batch_op.drop_index("idx_task_state_expires_at")
batch_op.drop_index("idx_task_state_updated_at")
batch_op.drop_index("idx_task_state_lookup")

op.drop_table("task_state")
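A hypothetical sketch of the reviewer's suggestion above: compute expires_at when the row is written, so that NULL can mean "never expires" and cleanup collapses to a single indexed predicate. The helper name and the retention_days parameter are assumptions, not the PR's API:

```python
from __future__ import annotations

from datetime import datetime, timedelta

from airflow._shared.timezones import timezone
from airflow.configuration import conf


def compute_expires_at(retention_days: int | None = None) -> datetime | None:
    """Resolve a row's expiry at write time instead of at cleanup time."""
    days = (
        retention_days
        if retention_days is not None
        else conf.getint("state_store", "default_retention_days")
    )
    if days <= 0:
        return None  # never expires
    return timezone.utcnow() + timedelta(days=days)

# Cleanup then needs only one predicate (and only the expires_at index):
#   DELETE FROM task_state WHERE expires_at IS NOT NULL AND expires_at < now()
```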
8 changes: 8 additions & 0 deletions airflow-core/src/airflow/models/task_state.py
@@ -49,6 +49,12 @@ class TaskStateModel(Base):

value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT, "mysql"), nullable=False)
updated_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
# Optional override for early expiry. When set, garbage collection deletes this row when
# expires_at < now(), even if updated_at is recent. NULL means no early expiry —
# the row is still cleaned up by the global `updated_at + default_retention_days` check.
# Populated via task_state.set(retention_days=N) for keys that should expire differently
# than the deployment-wide default.
expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True)

__table_args__ = (
PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key", name="task_state_pkey"),
@@ -59,4 +65,6 @@
ondelete="CASCADE",
),
Index("idx_task_state_lookup", "dag_id", "run_id", "task_id", "map_index"),
Index("idx_task_state_updated_at", "updated_at"),
Index("idx_task_state_expires_at", "expires_at"),
)
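A toy illustration of how the two checks described in the model comments combine, assuming default_retention_days=30 (plain datetime arithmetic, not actual Airflow API calls):

```python
from datetime import timedelta

from airflow._shared.timezones import timezone

now = timezone.utcnow()
updated_at = now - timedelta(days=10)  # row was touched recently
expires_at = now - timedelta(hours=1)  # per-key override already in the past

stale = updated_at < now - timedelta(days=30)          # False: global retention not reached
expired = expires_at is not None and expires_at < now  # True: early expiry applies
assert not stale and expired  # the per-key expiry pass deletes this row anyway
```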
87 changes: 86 additions & 1 deletion airflow-core/src/airflow/state/metastore.py
@@ -17,17 +17,21 @@
# under the License.
from __future__ import annotations

from datetime import timedelta
from typing import TYPE_CHECKING

import structlog
from sqlalchemy import delete, select
from sqlalchemy.sql.expression import tuple_

from airflow._shared.state import AssetScope, BaseStateBackend, StateScope, TaskScope
from airflow._shared.timezones import timezone
from airflow.configuration import conf
from airflow.models.asset_state import AssetStateModel
from airflow.models.dagrun import DagRun
from airflow.models.task_state import TaskStateModel
from airflow.typing_compat import assert_never
from airflow.utils.session import NEW_SESSION, create_session_async, provide_session
from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session
from airflow.utils.sqlalchemy import get_dialect_name

if TYPE_CHECKING:
@@ -38,6 +42,9 @@
from sqlalchemy.orm import Session


log = structlog.get_logger(__name__)


def _build_upsert_stmt(
dialect: str | None,
model: type,
@@ -252,6 +259,84 @@ def _clear_asset_state(self, scope: AssetScope, *, session: Session) -> None:
)
)

def cleanup(self) -> None:
"""
Remove expired task state rows.

Reads ``[state_store] default_retention_days`` and ``[state_store] state_cleanup_batch_size``
from config. Each pass runs in its own transaction so partial progress is committed even if a
later pass fails. Each pass is batched to avoid long-running locks on the table.

Two passes:
a. Rows where updated_at < now() - default_retention_days (global retention)
b. Rows where expires_at < now() (per-key early expiry set by the operator)
"""
retention_days = conf.getint("state_store", "default_retention_days")
batch_size = conf.getint("state_store", "state_cleanup_batch_size")
now = timezone.utcnow()
older_than = now - timedelta(days=retention_days) if retention_days > 0 else None

pk_cols = (
TaskStateModel.dag_run_id,
TaskStateModel.task_id,
TaskStateModel.map_index,
TaskStateModel.key,
)
[Review comment, Member, on lines +279 to +284]
Given how this is used, it might be time to add a single-column id PK (either integer or uuid).


def _delete_batched(where_clause) -> int:
total = 0
while True:
with create_session() as session:
[Review comment, Member]
I don't think this should be a new session object each time around the loop, but instead one session object that is explicitly session.commit()-ed after each batch.

pk_query = select(*pk_cols).where(where_clause)
if batch_size > 0:
pk_query = pk_query.limit(batch_size)
ids = session.execute(pk_query).all()
if not ids:
break
session.execute(delete(TaskStateModel).where(tuple_(*pk_cols).in_(ids)))
total += len(ids)
if batch_size <= 0 or len(ids) < batch_size:
break
return total

if older_than:
deleted = _delete_batched(TaskStateModel.updated_at < older_than)
log.info("Deleted stale task_state rows", rows_deleted=deleted, older_than=older_than)

deleted = _delete_batched((TaskStateModel.expires_at.isnot(None)) & (TaskStateModel.expires_at < now))
log.info("Deleted expired task_state rows", rows_deleted=deleted)

def _dry_run_summary(self) -> dict[str, list]:
"""
Return rows that would be deleted by cleanup(), without deleting anything.

Returns a dict with keys 'stale' and 'expired', each containing a list of
(dag_id, run_id, task_id, map_index, key) tuples.
"""
retention_days = conf.getint("state_store", "default_retention_days")
now = timezone.utcnow()
older_than = now - timedelta(days=retention_days) if retention_days > 0 else None

cols = (
TaskStateModel.dag_id,
TaskStateModel.run_id,
TaskStateModel.task_id,
TaskStateModel.map_index,
TaskStateModel.key,
)

with create_session() as session:
stale = (
session.execute(select(*cols).where(TaskStateModel.updated_at < older_than)).all()
if older_than
else []
)
expired = session.execute(
select(*cols).where(TaskStateModel.expires_at.isnot(None), TaskStateModel.expires_at < now)
).all()

return {"stale": list(stale), "expired": list(expired)}

[Review comment, Member (@jason810496), May 7, 2026, on lines +277 to +339]
If it's a valid assumption that users might produce a large amount of state records within the state_cleanup_interval window, I have some concerns about the single-pass delete transaction here.

I just double-checked the concern with Claude:

Yes, this is a real problem. Several compounding issues:

1. Missing indexes — every cleanup will be a full table scan

The two predicates the cleanup filters on are not indexed:

• task_state.updated_at — no index (only task_state_pkey on (dag_run_id, task_id, map_index, key) and idx_task_state_lookup on (dag_id, run_id, task_id, map_index))
• task_state.expires_at — no index (just added in this PR)

So both DELETE WHERE updated_at < cutoff and DELETE WHERE expires_at < now() do full sequential scans. On a deployment with millions of rows that's minutes of scanning every 24h, plus the locks held for the whole duration.

2. No batching / no LIMIT

Compare to airflow db cleanup (utils/db_cleanup.py:217), which deletes in configurable batches and commits between them. The new path runs three plain bulk DELETEs in a single session. A long-running bulk DELETE means:

• Row locks held for the duration (writers calling task_state.set() upserts on matching rows block — they queue behind the cleanup transaction).
• On Postgres: massive WAL churn, autovacuum can't keep up, table bloat.
• On MySQL/InnoDB at REPEATABLE READ (Airflow's default): next-key/gap locks make conflicts even more likely.

3. All three DELETEs share one transaction

with create_session() as session: opens one session; each session.execute() runs inside it; commit happens at exit. If pass 1 takes 90s, the locks from pass 1 are held while pass 2 and pass 3 run. A failure in pass 3 rolls back passes 1 and 2 (cleanup makes no forward progress at all).

4. Scheduler main loop is blocked

_cleanup_expired_task_state is registered via call_regular_interval, which is synchronous in the scheduler loop. Same pattern as _remove_unreferenced_triggers and _update_asset_orphanage — but those have small cardinality. task_state is user-driven and unbounded (the AIP encourages users to write a lot of it). With a multi-minute cleanup the scheduler is not scheduling for those minutes.

[Reply, Contributor Author]
Thanks, good catches. I will address all of them except the last one; that concern no longer applies, since this is now a CLI command rather than a scheduler task.

[Reply, Contributor Author]
Handled it in: cdc4237


async def _aget_task_state(self, scope: TaskScope, key: str, *, session: AsyncSession) -> str | None:
row = await session.scalar(
select(TaskStateModel).where(
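Putting the two review suggestions together (one session committed per batch, and a pre-computed expires_at as the sole predicate), a hedged sketch of what the batched delete could look like; the names and the exact predicate are assumptions, not the fix in cdc4237:

```python
from sqlalchemy import delete, select
from sqlalchemy.sql.expression import tuple_

from airflow._shared.timezones import timezone
from airflow.models.task_state import TaskStateModel
from airflow.utils.session import create_session


def cleanup_batched(batch_size: int = 10_000) -> int:
    pk_cols = (
        TaskStateModel.dag_run_id,
        TaskStateModel.task_id,
        TaskStateModel.map_index,
        TaskStateModel.key,
    )
    total = 0
    with create_session() as session:  # one session for the whole run
        while True:
            ids = session.execute(
                select(*pk_cols)
                .where(TaskStateModel.expires_at.isnot(None))
                .where(TaskStateModel.expires_at < timezone.utcnow())
                .limit(batch_size)
            ).all()
            if not ids:
                break
            session.execute(delete(TaskStateModel).where(tuple_(*pk_cols).in_(ids)))
            session.commit()  # release row locks between batches
            total += len(ids)
    return total
```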