From 35e1002a3a6ad9d11d8b226caaa300fd0abfb707 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 31 Mar 2026 18:12:28 +0800 Subject: [PATCH 01/16] feat(AIP-76): window --- .../src/airflow/jobs/scheduler_job_runner.py | 47 ++++++-- .../src/airflow/partition_mappers/base.py | 19 ++++ .../src/airflow/partition_mappers/temporal.py | 100 ++++++++++++++++-- .../src/airflow/serialization/encoders.py | 4 + airflow-core/src/airflow/timetables/base.py | 11 ++ task-sdk/src/airflow/sdk/__init__.py | 6 ++ task-sdk/src/airflow/sdk/__init__.pyi | 4 + .../sdk/definitions/partition_mappers/base.py | 15 +++ .../definitions/partition_mappers/temporal.py | 32 +++++- 9 files changed, 215 insertions(+), 23 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 3eed95a8bb030..e6f9f08f29ac4 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -116,6 +116,7 @@ from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.executors.workloads.types import SchedulerWorkload + from airflow.partition_mappers.base import RollupMapper from airflow.serialization.definitions.dag import SerializedDAG from airflow.utils.sqlalchemy import CommitProhibitorGuard @@ -1831,6 +1832,18 @@ def _do_scheduling(self, session: Session) -> int: return num_queued_tis + def _check_rollup_asset_status( + self, + asset_model: AssetModel, + apdr: AssetPartitionDagRun, + mapper: RollupMapper, + actual_by_asset: dict[int, set[str]], + ) -> bool: + if TYPE_CHECKING: + assert apdr.partition_key is not None + expected = mapper.to_upstream(apdr.partition_key) + return expected.issubset(actual_by_asset.get(asset_model.id, set())) + def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]: partition_dag_ids: set[str] = set() @@ -1855,16 +1868,30 @@ def 
_create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st ) ) ) - ) - statuses: dict[SerializedAssetUniqueKey, bool] = { - SerializedAssetUniqueKey.from_asset(a): True for a in asset_models - } - # todo: AIP-76 so, this basically works when we only require one partition from each asset to be there - # but, we ultimately need rollup ability - # that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure - # that all the required keys are there - # one way to do this would be just to figure out what the count should be - if not evaluator.run(dag.timetable.asset_condition, statuses=statuses): + ).all() + actual_by_asset: dict[int, set[str]] = defaultdict(set) + for asset_id, source_key in session.execute( + select(PartitionedAssetKeyLog.asset_id, PartitionedAssetKeyLog.source_partition_key).where( + PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id + ) + ): + actual_by_asset[asset_id].add(source_key) + + timetable = dag.timetable + statuses: dict[SerializedAssetUniqueKey, bool] = {} + for asset_model in asset_models: + if timetable.partitioned: + mapper = cast( + "RollupMapper", + timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri), + ) + if mapper.is_rollup: + statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = ( + self._check_rollup_asset_status(asset_model, apdr, mapper, actual_by_asset) + ) + continue + statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = True + if not evaluator.run(timetable.asset_condition, statuses=statuses): continue partition_dag_ids.add(apdr.target_dag_id) diff --git a/airflow-core/src/airflow/partition_mappers/base.py b/airflow-core/src/airflow/partition_mappers/base.py index 7c64d05625855..e97d128027ffa 100644 --- a/airflow-core/src/airflow/partition_mappers/base.py +++ b/airflow-core/src/airflow/partition_mappers/base.py @@ -31,6 +31,8 @@ class PartitionMapper(ABC): Maps keys from asset events to target dag run partitions. 
""" + is_rollup: bool = False + @abstractmethod def to_downstream(self, key: str) -> str | Iterable[str]: """Return the target key that the given source partition key maps to.""" @@ -41,3 +43,20 @@ def serialize(self) -> dict[str, Any]: @classmethod def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: return cls() + + +class RollupMapper(PartitionMapper, ABC): + """ + Partition mapper that supports rollup (many upstream keys → one downstream key). + + Subclass this when the downstream Dag should wait for a complete set of upstream + partition keys before triggering. The scheduler calls ``to_upstream`` to discover + which source keys are required and only creates a Dag run once all of them have + arrived in ``PartitionedAssetKeyLog``. + """ + + is_rollup: bool = True + + @abstractmethod + def to_upstream(self, downstream_key: str) -> frozenset[str]: + """Return the complete set of upstream partition keys required for *downstream_key*.""" diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py index 49f4162a12e2c..7e3824d9030cd 100644 --- a/airflow-core/src/airflow/partition_mappers/temporal.py +++ b/airflow-core/src/airflow/partition_mappers/temporal.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any from airflow._shared.timezones.timezone import make_aware, parse_timezone -from airflow.partition_mappers.base import PartitionMapper +from airflow.partition_mappers.base import PartitionMapper, RollupMapper if TYPE_CHECKING: from pendulum import FixedTimezone, Timezone @@ -99,30 +99,110 @@ def normalize(self, dt: datetime) -> datetime: class StartOfWeekMapper(_BaseTemporalMapper): - """Map a time-based partition key to week.""" + """Map a time-based partition key to the start of its week.""" default_output_format = "%Y-%m-%d (W%V)" + def __init__( + self, + *, + week_start: int = 0, + timezone: str | Timezone | FixedTimezone = "UTC", + input_format: str = "%Y-%m-%dT%H:%M:%S", + 
output_format: str | None = None, + ) -> None: + super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) + self.week_start = week_start # 0 = Monday (ISO default), 6 = Sunday + def normalize(self, dt: datetime) -> datetime: - start = dt - timedelta(days=dt.weekday()) + days_since_start = (dt.weekday() - self.week_start) % 7 + start = dt - timedelta(days=days_since_start) return start.replace(hour=0, minute=0, second=0, microsecond=0) + def serialize(self) -> dict[str, Any]: + return {**super().serialize(), "week_start": self.week_start} + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: + return cls( + week_start=data.get("week_start", 0), + timezone=parse_timezone(data.get("timezone", "UTC")), + input_format=data["input_format"], + output_format=data["output_format"], + ) + + +class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper): + """ + Map a time-based partition key to the start of its week, requiring all 7 daily keys. + + Use this when a partitioned Dag should only run once every daily asset partition + for a full week has been produced. Configure ``week_start`` to set which day begins + the week (0 = Monday, 6 = Sunday). + """ + + def to_upstream(self, downstream_key: str) -> frozenset[str]: + # The output format always embeds the week-start date as the first 10 chars. 
+ week_start_dt = datetime.strptime(downstream_key[:10], "%Y-%m-%d") + return frozenset((week_start_dt + timedelta(days=i)).strftime(self.input_format) for i in range(7)) + class StartOfMonthMapper(_BaseTemporalMapper): - """Map a time-based partition key to month.""" + """Map a time-based partition key to the start of its month.""" default_output_format = "%Y-%m" + def __init__( + self, + *, + month_start_day: int = 1, + timezone: str | Timezone | FixedTimezone = "UTC", + input_format: str = "%Y-%m-%dT%H:%M:%S", + output_format: str | None = None, + ) -> None: + super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) + self.month_start_day = month_start_day # 1–28; use >1 for fiscal-month offsets + def normalize(self, dt: datetime) -> datetime: - return dt.replace( - day=1, - hour=0, - minute=0, - second=0, - microsecond=0, + if dt.day < self.month_start_day: + month = dt.month - 1 or 12 + year = dt.year - (1 if dt.month == 1 else 0) + start = dt.replace(year=year, month=month, day=self.month_start_day) + else: + start = dt.replace(day=self.month_start_day) + return start.replace(hour=0, minute=0, second=0, microsecond=0) + + def serialize(self) -> dict[str, Any]: + return {**super().serialize(), "month_start_day": self.month_start_day} + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: + return cls( + month_start_day=data.get("month_start_day", 1), + timezone=parse_timezone(data.get("timezone", "UTC")), + input_format=data["input_format"], + output_format=data["output_format"], ) +class MonthlyRollupMapper(StartOfMonthMapper, RollupMapper): + """ + Map a time-based partition key to the start of its month, requiring all daily keys in that month. + + Use this when a partitioned Dag should only run once every daily asset partition + for a full calendar month has been produced. Configure ``month_start_day`` for + fiscal-month offsets (e.g. ``month_start_day=15`` for a mid-month period). 
+ """ + + def to_upstream(self, downstream_key: str) -> frozenset[str]: + period_start = datetime.strptime(downstream_key, self.output_format).replace(day=self.month_start_day) + next_month = period_start.month % 12 + 1 + next_year = period_start.year + (1 if period_start.month == 12 else 0) + next_start = period_start.replace(year=next_year, month=next_month) + days = (next_start - period_start).days + return frozenset((period_start + timedelta(days=i)).strftime(self.input_format) for i in range(days)) + + class StartOfQuarterMapper(_BaseTemporalMapper): """Map a time-based partition key to quarter.""" diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index 9e341cbe78340..b84dfb73f05aa 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -41,6 +41,7 @@ DeltaTriggerTimetable, EventsTimetable, IdentityMapper, + MonthlyRollupMapper, MultipleCronTriggerTimetable, PartitionMapper, ProductMapper, @@ -49,6 +50,7 @@ StartOfQuarterMapper, StartOfWeekMapper, StartOfYearMapper, + WeeklyRollupMapper, ) from airflow.sdk.bases.timetable import BaseTimetable from airflow.sdk.definitions.asset import AssetRef @@ -414,6 +416,8 @@ def _(self, timetable: PartitionedAssetTimetable) -> dict[str, Any]: StartOfQuarterMapper: "airflow.partition_mappers.temporal.StartOfQuarterMapper", StartOfWeekMapper: "airflow.partition_mappers.temporal.StartOfWeekMapper", StartOfYearMapper: "airflow.partition_mappers.temporal.StartOfYearMapper", + WeeklyRollupMapper: "airflow.partition_mappers.temporal.WeeklyRollupMapper", + MonthlyRollupMapper: "airflow.partition_mappers.temporal.MonthlyRollupMapper", } @functools.singledispatchmethod diff --git a/airflow-core/src/airflow/timetables/base.py b/airflow-core/src/airflow/timetables/base.py index a92bcd7a1f85d..5420b5a2c5bc0 100644 --- a/airflow-core/src/airflow/timetables/base.py +++ 
b/airflow-core/src/airflow/timetables/base.py @@ -29,6 +29,7 @@ from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun + from airflow.partition_mappers.base import PartitionMapper from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.definitions.assets import ( SerializedAsset, @@ -218,6 +219,16 @@ class Timetable(Protocol): instead of the traditional logic based on logical dates and data intervals. """ + def get_partition_mapper(self, *, name: str = "", uri: str = "") -> PartitionMapper: + """ + Return the partition mapper for the asset identified by *name* or *uri*. + + Only called by the scheduler when ``partitioned`` is *True*. The default + implementation raises :exc:`NotImplementedError`; timetables that set + ``partitioned = True`` must override this. + """ + raise NotImplementedError + @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: """ diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index f304b068237b3..301d118e257c1 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -75,6 +75,8 @@ "StartOfQuarterMapper", "StartOfWeekMapper", "StartOfYearMapper", + "WeeklyRollupMapper", + "MonthlyRollupMapper", "TaskGroup", "TaskInstance", "TaskInstanceState", @@ -136,12 +138,14 @@ from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper from airflow.sdk.definitions.partition_mappers.temporal import ( + MonthlyRollupMapper, StartOfDayMapper, StartOfHourMapper, StartOfMonthMapper, StartOfQuarterMapper, StartOfWeekMapper, StartOfYearMapper, + WeeklyRollupMapper, ) from airflow.sdk.definitions.retry_policy import ( ExceptionRetryPolicy, @@ -232,6 +236,8 @@ "StartOfQuarterMapper": ".definitions.partition_mappers.temporal", "StartOfWeekMapper": ".definitions.partition_mappers.temporal", "StartOfYearMapper": 
".definitions.partition_mappers.temporal", + "WeeklyRollupMapper": ".definitions.partition_mappers.temporal", + "MonthlyRollupMapper": ".definitions.partition_mappers.temporal", "TaskGroup": ".definitions.taskgroup", "TaskInstance": ".types", "TaskInstanceState": ".api.datamodels._generated", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 7e6d211674eba..854cff94335d0 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -68,12 +68,14 @@ from airflow.sdk.definitions.partition_mappers.chain import ChainMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper from airflow.sdk.definitions.partition_mappers.temporal import ( + MonthlyRollupMapper, StartOfDayMapper, StartOfHourMapper, StartOfMonthMapper, StartOfQuarterMapper, StartOfWeekMapper, StartOfYearMapper, + WeeklyRollupMapper, ) from airflow.sdk.definitions.retry_policy import ( ExceptionRetryPolicy as ExceptionRetryPolicy, @@ -160,6 +162,8 @@ __all__ = [ "StartOfQuarterMapper", "StartOfWeekMapper", "StartOfYearMapper", + "WeeklyRollupMapper", + "MonthlyRollupMapper", "TaskGroup", "TaskInstanceState", "TriggerRule", diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py index 728332d506fc7..00c4f57412f2f 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py @@ -23,3 +23,18 @@ class PartitionMapper: Maps keys from asset events to target dag run partitions. """ + + is_rollup: bool = False + + +class RollupMapper(PartitionMapper): + """ + Partition mapper that supports rollup (many upstream keys → one downstream key). + + Subclass this when the downstream Dag should wait for a complete set of upstream + partition keys before triggering. 
The scheduler calls ``to_upstream`` to discover + which source keys are required and only creates a Dag run once all of them have + arrived in ``PartitionedAssetKeyLog``. + """ + + is_rollup: bool = True diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py index 60ca18f5044f3..1dc7cd56f1b42 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from airflow.sdk.definitions.partition_mappers.base import PartitionMapper +from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper class _BaseTemporalMapper(PartitionMapper): @@ -44,16 +44,42 @@ class StartOfDayMapper(_BaseTemporalMapper): class StartOfWeekMapper(_BaseTemporalMapper): - """Map a time-based partition key to week.""" + """Map a time-based partition key to the start of its week.""" default_output_format = "%Y-%m-%d (W%V)" + def __init__(self, *, week_start: int = 0, **kwargs) -> None: + super().__init__(**kwargs) + self.week_start = week_start + + +class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper): + """ + Map a time-based partition key to the start of its week, requiring all 7 daily keys. + + Use this when a partitioned Dag should only run once every daily asset partition + for a full week has been produced. 
+ """ + class StartOfMonthMapper(_BaseTemporalMapper): - """Map a time-based partition key to month.""" + """Map a time-based partition key to the start of its month.""" default_output_format = "%Y-%m" + def __init__(self, *, month_start_day: int = 1, **kwargs) -> None: + super().__init__(**kwargs) + self.month_start_day = month_start_day + + +class MonthlyRollupMapper(StartOfMonthMapper, RollupMapper): + """ + Map a time-based partition key to the start of its month, requiring all daily keys in that month. + + Use this when a partitioned Dag should only run once every daily asset partition + for a full calendar month has been produced. + """ + class StartOfQuarterMapper(_BaseTemporalMapper): """Map a time-based partition key to quarter.""" From 87c2ff016e46d93863dd2e6f422197bd3c4ba772 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 7 Apr 2026 18:13:08 +0800 Subject: [PATCH 02/16] perf: simplify SQL queries --- .../src/airflow/jobs/scheduler_job_runner.py | 128 +++++++++++++----- .../src/airflow/partition_mappers/temporal.py | 39 ++++-- airflow-core/src/airflow/timetables/base.py | 11 -- .../unit/partition_mappers/test_temporal.py | 29 ++-- task-sdk/docs/api.rst | 4 + 5 files changed, 144 insertions(+), 67 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index e6f9f08f29ac4..c4602933e6b4a 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -88,7 +88,7 @@ from airflow.serialization.definitions.assets import SerializedAssetUniqueKey from airflow.serialization.definitions.notset import NOTSET from airflow.ti_deps.dependencies_states import ACTIVE_STATES, EXECUTION_STATES -from airflow.timetables.simple import AssetTriggeredTimetable +from airflow.timetables.simple import AssetTriggeredTimetable, PartitionedAssetTimetable from airflow.utils.event_scheduler import EventScheduler from 
airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries @@ -1834,7 +1834,8 @@ def _do_scheduling(self, session: Session) -> int: def _check_rollup_asset_status( self, - asset_model: AssetModel, + *, + asset_id: int, apdr: AssetPartitionDagRun, mapper: RollupMapper, actual_by_asset: dict[int, set[str]], @@ -1842,55 +1843,112 @@ def _check_rollup_asset_status( if TYPE_CHECKING: assert apdr.partition_key is not None expected = mapper.to_upstream(apdr.partition_key) - return expected.issubset(actual_by_asset.get(asset_model.id, set())) + return expected.issubset(actual_by_asset.get(asset_id, set())) + + def _resolve_asset_partition_status( + self, + *, + asset_id: int, + name: str, + uri: str, + apdr: AssetPartitionDagRun, + timetable: PartitionedAssetTimetable, + actual_by_asset: dict[int, set[str]], + ) -> bool | None: + """ + Return the rollup status for one asset within a pending partitioned Dag run. + + Returns *True*/*False* for rollup assets, or *None* when the asset has no + rollup mapper and should default to satisfied. + """ + try: + mapper = timetable.get_partition_mapper(name=name, uri=uri) + if not mapper.is_rollup: + return None + return self._check_rollup_asset_status( + asset_id=asset_id, + apdr=apdr, + mapper=cast("RollupMapper", mapper), + actual_by_asset=actual_by_asset, + ) + except Exception: + self.log.exception( + "Failed to evaluate rollup status for asset; treating as not-yet-satisfied. 
" + "This likely indicates a misconfigured partition mapper.", + dag_id=apdr.target_dag_id, + partition_key=apdr.partition_key, + asset_name=name, + asset_uri=uri, + ) + return False def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]: partition_dag_ids: set[str] = set() - evaluator = AssetEvaluator(session) - for apdr in session.scalars( + pending_apdrs = session.scalars( select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None)) + ).all() + if not pending_apdrs: + return partition_dag_ids + + pending_apdr_ids = [apdr.id for apdr in pending_apdrs] + + # Pre-fetch all required serialized Dags in one query. + dag_ids = list({apdr.target_dag_id for apdr in pending_apdrs if apdr.target_dag_id}) + # {"dag_id": Serialized Dag} + serialized_dags: dict[str, SerializedDAG] = {} + for serdag in SerializedDagModel.get_latest_serialized_dags(dag_ids=dag_ids, session=session): + try: + serdag.load_op_links = False + serialized_dags[serdag.dag_id] = serdag.dag + except Exception: + self.log.exception("Failed to deserialize Dag '%s'", serdag.dag_id) + + # {apdr_id: {asset_id: set(source_key, ...)} + source_key_by_asset_per_apdr: dict[int, dict[int, set[str]]] = defaultdict(lambda: defaultdict(set)) + # {apdr_id: {asset_id: (asset_name, asset_uri)} + asset_info_per_apdr: dict[int, dict[int, tuple[str, str]]] = defaultdict(dict) + for apdr_id, asset_id, source_key, name, uri in session.execute( + select( + PartitionedAssetKeyLog.asset_partition_dag_run_id, + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + AssetModel.name, + AssetModel.uri, + ) + .join(AssetModel, AssetModel.id == PartitionedAssetKeyLog.asset_id) + .where(PartitionedAssetKeyLog.asset_partition_dag_run_id.in_(pending_apdr_ids)) ): + source_key_by_asset_per_apdr[apdr_id][asset_id].add(source_key) + asset_info_per_apdr[apdr_id][asset_id] = (name, uri) + + evaluator = AssetEvaluator(session) + for apdr in pending_apdrs: if 
TYPE_CHECKING: assert apdr.target_dag_id - if not (dag := self._get_current_dag(dag_id=apdr.target_dag_id, session=session)): + if not (dag := serialized_dags.get(apdr.target_dag_id)): self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id) continue - asset_models = session.scalars( - select(AssetModel).where( - exists( - select(1).where( - PartitionedAssetKeyLog.asset_id == AssetModel.id, - PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id, - PartitionedAssetKeyLog.target_partition_key == apdr.partition_key, - ) - ) - ) - ).all() - actual_by_asset: dict[int, set[str]] = defaultdict(set) - for asset_id, source_key in session.execute( - select(PartitionedAssetKeyLog.asset_id, PartitionedAssetKeyLog.source_partition_key).where( - PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id - ) - ): - actual_by_asset[asset_id].add(source_key) - + source_key_by_asset = source_key_by_asset_per_apdr[apdr.id] timetable = dag.timetable statuses: dict[SerializedAssetUniqueKey, bool] = {} - for asset_model in asset_models: - if timetable.partitioned: - mapper = cast( - "RollupMapper", - timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri), + for asset_id, (name, uri) in asset_info_per_apdr[apdr.id].items(): + key = SerializedAssetUniqueKey(name=name, uri=uri) + if isinstance(timetable, PartitionedAssetTimetable): + status = self._resolve_asset_partition_status( + asset_id=asset_id, + name=name, + uri=uri, + apdr=apdr, + timetable=timetable, + actual_by_asset=source_key_by_asset, ) - if mapper.is_rollup: - statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = ( - self._check_rollup_asset_status(asset_model, apdr, mapper, actual_by_asset) - ) + if status is not None: + statuses[key] = status continue - statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = True + statuses[key] = True if not evaluator.run(timetable.asset_condition, statuses=statuses): continue diff --git 
a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py index 7e3824d9030cd..96e98d3e7fcdb 100644 --- a/airflow-core/src/airflow/partition_mappers/temporal.py +++ b/airflow-core/src/airflow/partition_mappers/temporal.py @@ -141,10 +141,24 @@ class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper): the week (0 = Monday, 6 = Sunday). """ + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + if "%Y-%m-%d" not in self.output_format: + raise ValueError( + f"WeeklyRollupMapper requires output_format to contain '%Y-%m-%d' so that " + f"to_upstream() can recover the week-start date, got: {self.output_format!r}" + ) + def to_upstream(self, downstream_key: str) -> frozenset[str]: - # The output format always embeds the week-start date as the first 10 chars. - week_start_dt = datetime.strptime(downstream_key[:10], "%Y-%m-%d") - return frozenset((week_start_dt + timedelta(days=i)).strftime(self.input_format) for i in range(7)) + # Parse via output_format (not a hardcoded slice) so custom formats work correctly. + # Arithmetic stays on naive datetimes to keep day-counting unambiguous across + # DST transitions; each result is made timezone-aware before formatting so that + # %z in input_format produces the correct offset. 
+ week_start_naive = datetime.strptime(downstream_key, self.output_format) + return frozenset( + make_aware(week_start_naive + timedelta(days=i), self._timezone).strftime(self.input_format) + for i in range(7) + ) class StartOfMonthMapper(_BaseTemporalMapper): @@ -195,12 +209,19 @@ class MonthlyRollupMapper(StartOfMonthMapper, RollupMapper): """ def to_upstream(self, downstream_key: str) -> frozenset[str]: - period_start = datetime.strptime(downstream_key, self.output_format).replace(day=self.month_start_day) - next_month = period_start.month % 12 + 1 - next_year = period_start.year + (1 if period_start.month == 12 else 0) - next_start = period_start.replace(year=next_year, month=next_month) - days = (next_start - period_start).days - return frozenset((period_start + timedelta(days=i)).strftime(self.input_format) for i in range(days)) + # Use naive datetimes for day-counting to avoid DST ambiguity, then make + # each result timezone-aware before formatting so %z produces the correct offset. 
+ period_start_naive = datetime.strptime(downstream_key, self.output_format).replace( + day=self.month_start_day + ) + next_month = period_start_naive.month % 12 + 1 + next_year = period_start_naive.year + (1 if period_start_naive.month == 12 else 0) + next_start_naive = period_start_naive.replace(year=next_year, month=next_month) + days = (next_start_naive - period_start_naive).days + return frozenset( + make_aware(period_start_naive + timedelta(days=i), self._timezone).strftime(self.input_format) + for i in range(days) + ) class StartOfQuarterMapper(_BaseTemporalMapper): diff --git a/airflow-core/src/airflow/timetables/base.py b/airflow-core/src/airflow/timetables/base.py index 5420b5a2c5bc0..a92bcd7a1f85d 100644 --- a/airflow-core/src/airflow/timetables/base.py +++ b/airflow-core/src/airflow/timetables/base.py @@ -29,7 +29,6 @@ from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun - from airflow.partition_mappers.base import PartitionMapper from airflow.serialization.dag_dependency import DagDependency from airflow.serialization.definitions.assets import ( SerializedAsset, @@ -219,16 +218,6 @@ class Timetable(Protocol): instead of the traditional logic based on logical dates and data intervals. """ - def get_partition_mapper(self, *, name: str = "", uri: str = "") -> PartitionMapper: - """ - Return the partition mapper for the asset identified by *name* or *uri*. - - Only called by the scheduler when ``partitioned`` is *True*. The default - implementation raises :exc:`NotImplementedError`; timetables that set - ``partitioned = True`` must override this. 
- """ - raise NotImplementedError - @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: """ diff --git a/airflow-core/tests/unit/partition_mappers/test_temporal.py b/airflow-core/tests/unit/partition_mappers/test_temporal.py index c54ca8a51f9a7..1a32900284912 100644 --- a/airflow-core/tests/unit/partition_mappers/test_temporal.py +++ b/airflow-core/tests/unit/partition_mappers/test_temporal.py @@ -57,29 +57,34 @@ def test_to_downstream( ], ) @pytest.mark.parametrize( - ("mapper_cls", "expected_outut_format"), + ("mapper_cls", "expected_outut_format", "extra_kwargs"), [ - (StartOfHourMapper, "%Y-%m-%dT%H"), - (StartOfDayMapper, "%Y-%m-%d"), - (StartOfWeekMapper, "%Y-%m-%d (W%V)"), - (StartOfMonthMapper, "%Y-%m"), - (StartOfQuarterMapper, "%Y-Q{quarter}"), - (StartOfYearMapper, "%Y"), + (StartOfHourMapper, "%Y-%m-%dT%H", {}), + (StartOfDayMapper, "%Y-%m-%d", {}), + (StartOfWeekMapper, "%Y-%m-%d (W%V)", {"week_start": 0}), + (StartOfMonthMapper, "%Y-%m", {"month_start_day": 1}), + (StartOfQuarterMapper, "%Y-Q{quarter}", {}), + (StartOfYearMapper, "%Y", {}), ], ) def test_serialize( self, mapper_cls: type[_BaseTemporalMapper], expected_outut_format: str, + extra_kwargs: dict[str, int], timezone: str | None, expected_timezone: str, ): pm = mapper_cls() if timezone is None else mapper_cls(timezone=timezone) - assert pm.serialize() == { - "timezone": expected_timezone, - "input_format": "%Y-%m-%dT%H:%M:%S", - "output_format": expected_outut_format, - } + assert ( + pm.serialize() + == { + "timezone": expected_timezone, + "input_format": "%Y-%m-%dT%H:%M:%S", + "output_format": expected_outut_format, + } + | extra_kwargs + ) @pytest.mark.parametrize( "mapper_cls", diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index cb9789f5bb69f..1598fc71afe91 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -233,6 +233,10 @@ Partition Mapper .. autoapiclass:: airflow.sdk.StartOfYearMapper +.. 
autoapiclass:: airflow.sdk.WeeklyRollupMapper + +.. autoapiclass:: airflow.sdk.MonthlyRollupMapper + .. autoapiclass:: airflow.sdk.ProductMapper .. autoapiclass:: airflow.sdk.AllowedKeyMapper From ddcd24a89cf2072fe4efa141e24b9b7c7c2bb055 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 8 Apr 2026 15:51:21 +0800 Subject: [PATCH 03/16] feat(ui): add one to many partition key support to UI this is still not ideal, but at least it's not super wrong now --- .../datamodels/ui/partitioned_dag_runs.py | 2 + .../core_api/openapi/_private_ui.yaml | 8 + .../api_fastapi/core_api/routes/ui/assets.py | 64 +++++++ .../routes/ui/partitioned_dag_runs.py | 168 ++++++++++++------ .../src/airflow/partition_mappers/temporal.py | 11 +- .../ui/openapi-gen/requests/schemas.gen.ts | 10 +- .../ui/openapi-gen/requests/types.gen.ts | 2 + .../AssetExpression/AssetExpression.tsx | 8 + .../components/AssetExpression/AssetNode.tsx | 70 +++++--- .../src/components/AssetExpression/types.ts | 9 +- .../ui/src/components/AssetProgressCell.tsx | 6 +- .../ui/src/pages/DagsList/AssetSchedule.tsx | 21 ++- .../routes/ui/test_partitioned_dag_runs.py | 2 +- 13 files changed, 291 insertions(+), 90 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py index 628f8560e96af..19fc6b62129bc 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py @@ -47,6 +47,8 @@ class PartitionedDagRunAssetResponse(BaseModel): asset_name: str asset_uri: str received: bool + received_count: int + required_count: int class PartitionedDagRunDetailResponse(BaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 2983263bbc59b..995fc004e46b3 100644 
--- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -3152,12 +3152,20 @@ components: received: type: boolean title: Received + received_count: + type: integer + title: Received Count + required_count: + type: integer + title: Required Count type: object required: - asset_id - asset_name - asset_uri - received + - received_count + - required_count title: PartitionedDagRunAssetResponse description: Asset info within a partitioned Dag run detail. PartitionedDagRunCollectionResponse: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py index 155b49726f6e4..84b1e5fdf28fc 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py @@ -16,6 +16,9 @@ # under the License. from __future__ import annotations +from contextlib import suppress +from typing import TYPE_CHECKING, cast + from fastapi import Depends, HTTPException, status from sqlalchemy import ColumnElement, and_, case, exists, func, select, true @@ -31,6 +34,9 @@ DagScheduleAssetReference, PartitionedAssetKeyLog, ) +from airflow.models.serialized_dag import SerializedDagModel +from airflow.partition_mappers.base import RollupMapper +from airflow.timetables.simple import PartitionedAssetTimetable assets_router = AirflowRouter(tags=["Asset"]) @@ -115,6 +121,64 @@ def next_run_assets( if not event.pop("queued", None): event["lastUpdate"] = None + # For partitioned Dags: enrich events with per-asset received/required counts, + # using to_upstream for rollup mappers, and fix lastUpdate for partial receipt. 
+ if is_partitioned: + pending_apdr = session.execute( + select(AssetPartitionDagRun.id, AssetPartitionDagRun.partition_key) + .where( + AssetPartitionDagRun.target_dag_id == dag_id, + AssetPartitionDagRun.created_dag_run_id.is_(None), + ) + .order_by(AssetPartitionDagRun.created_at.desc()) + .limit(1) + ).one_or_none() + + if pending_apdr is not None: + # Count received log entries per asset for this partition + received_by_asset: dict[int, int] = { + row.asset_id: row.cnt + for row in session.execute( + select( + PartitionedAssetKeyLog.asset_id, + func.count(PartitionedAssetKeyLog.id).label("cnt"), + ) + .where(PartitionedAssetKeyLog.asset_partition_dag_run_id == pending_apdr.id) + .group_by(PartitionedAssetKeyLog.asset_id) + ).all() + } + + timetable = None + serdag = SerializedDagModel.get(dag_id=dag_id, session=session) + if serdag is not None: + with suppress(Exception): + dag_obj = serdag.dag + if isinstance(dag_obj.timetable, PartitionedAssetTimetable): + timetable = dag_obj.timetable + + for event in events: + asset_id = event["id"] + received_count = received_by_asset.get(asset_id, 0) + required_count = 1 + if timetable is not None: + with suppress(Exception): + mapper = timetable.get_partition_mapper( + name=event.get("name") or "", + uri=event.get("uri") or "", + ) + mapper.is_rollup + if isinstance(mapper, RollupMapper): + required_count = len(mapper.to_upstream(pending_apdr.partition_key)) + event["receivedCount"] = received_count + event["requiredCount"] = required_count + # Only show lastUpdate when all required upstream keys are received + if received_count < required_count: + event["lastUpdate"] = None + else: + for event in events: + event["receivedCount"] = 0 + event["requiredCount"] = 1 + data: dict = {"asset_expression": dag_model.asset_expression, "events": events} if pending_partition_count is not None: data["pending_partition_count"] = pending_partition_count diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index fe4696efbb5a0..0860957374521 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -16,8 +16,11 @@ # under the License. from __future__ import annotations +from contextlib import suppress +from typing import TYPE_CHECKING, cast + from fastapi import Depends, HTTPException, status -from sqlalchemy import exists, func, select +from sqlalchemy import func, select from airflow.api_fastapi.common.db.common import SessionDep, apply_filters_to_select from airflow.api_fastapi.common.parameters import ( @@ -44,6 +47,50 @@ PartitionedAssetKeyLog, ) from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel +from airflow.timetables.simple import PartitionedAssetTimetable + +if TYPE_CHECKING: + from airflow.partition_mappers.base import RollupMapper + + +def _load_timetable_and_assets( + dag_id: str, session +) -> tuple[PartitionedAssetTimetable | None, list[tuple[str, str]]]: + """Load the DAG timetable and its active required assets as (name, uri) pairs.""" + timetable = None + serdag = SerializedDagModel.get(dag_id=dag_id, session=session) + if serdag is not None: + with suppress(Exception): + dag = serdag.dag + if isinstance(dag.timetable, PartitionedAssetTimetable): + timetable = dag.timetable + + asset_rows = session.execute( + select(AssetModel.name, AssetModel.uri) + .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) + .where(DagScheduleAssetReference.dag_id == dag_id, AssetModel.active.has()) + ).all() + return timetable, [(r.name or "", r.uri or "") for r in asset_rows] + + +def _compute_total_required( + timetable: PartitionedAssetTimetable | None, + asset_info: list[tuple[str, str]], + 
partition_key: str, +) -> int: + """Sum required upstream events across all assets, using to_upstream for rollup mappers.""" + if timetable is None: + return len(asset_info) + total = 0 + for name, uri in asset_info: + try: + mapper = timetable.get_partition_mapper(name=name, uri=uri) + total += len(mapper.to_upstream(partition_key)) if isinstance(mapper, RollupMapper) else 1 + except Exception: + total += 1 + return total or len(asset_info) + partitioned_dag_runs_router = AirflowRouter(tags=["PartitionedDagRun"]) @@ -73,21 +120,8 @@ def get_partitioned_dag_runs( ) -> PartitionedDagRunCollectionResponse: """Return PartitionedDagRuns. Filter by dag_id and/or has_created_dag_run_id.""" if dag_id.value is not None: - # Single query: validate Dag + get required count dag_info = session.execute( - select( - DagModel.timetable_summary, - func.count(DagScheduleAssetReference.asset_id).label("required_count"), - ) - .outerjoin( - DagScheduleAssetReference, - (DagScheduleAssetReference.dag_id == DagModel.dag_id) - & DagScheduleAssetReference.asset_id.in_( - select(AssetModel.id).where(AssetModel.active.has()) - ), - ) - .where(DagModel.dag_id == dag_id.value) - .group_by(DagModel.dag_id) + select(DagModel.timetable_summary).where(DagModel.dag_id == dag_id.value) ).one_or_none() if dag_info is None: @@ -95,9 +129,7 @@ def get_partitioned_dag_runs( if dag_info.timetable_summary != "Partitioned Asset": return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) - required_count = dag_info.required_count - - # Subquery for received count per partition (only count required assets) + # Subquery: count received events per partition (PartitionedAssetKeyLog rows for required assets) required_assets_subq = ( select(DagScheduleAssetReference.asset_id) .join(AssetModel, AssetModel.id == DagScheduleAssetReference.asset_id) @@ -108,7 +140,7 @@ def get_partitioned_dag_runs( .correlate(AssetPartitionDagRun) ) received_subq = ( - 
select(func.count(func.distinct(PartitionedAssetKeyLog.asset_id))) + select(func.count(PartitionedAssetKeyLog.id)) .where( PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, PartitionedAssetKeyLog.asset_id.in_(required_assets_subq), @@ -137,30 +169,34 @@ def get_partitioned_dag_runs( return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) if dag_id.value is not None: - results = [_build_response(row, required_count) for row in rows] + timetable, asset_info = _load_timetable_and_assets(dag_id.value, session) + results = [ + _build_response(row, _compute_total_required(timetable, asset_info, row.partition_key)) + for row in rows + ] return PartitionedDagRunCollectionResponse(partitioned_dag_runs=results, total=len(results)) - # No dag_id: need to get required counts and expressions per dag - dag_ids = list({row.target_dag_id for row in rows}) + # No dag_id filter: load timetables and assets for each unique DAG + unique_dag_ids = list({row.target_dag_id for row in rows}) + dag_timetables_assets: dict[str, tuple[PartitionedAssetTimetable | None, list[tuple[str, str]]]] = { + did: _load_timetable_and_assets(did, session) for did in unique_dag_ids + } dag_rows = session.execute( - select( - DagModel.dag_id, - DagModel.asset_expression, - func.count(DagScheduleAssetReference.asset_id).label("required_count"), - ) - .outerjoin( - DagScheduleAssetReference, - (DagScheduleAssetReference.dag_id == DagModel.dag_id) - & DagScheduleAssetReference.asset_id.in_(select(AssetModel.id).where(AssetModel.active.has())), - ) - .where(DagModel.dag_id.in_(dag_ids)) - .group_by(DagModel.dag_id) + select(DagModel.dag_id, DagModel.asset_expression).where(DagModel.dag_id.in_(unique_dag_ids)) ).all() - - required_counts = {r.dag_id: r.required_count for r in dag_rows} asset_expressions = {r.dag_id: r.asset_expression for r in dag_rows} - results = [_build_response(row, required_counts.get(row.target_dag_id, 0)) for row in rows] + results = [ + 
_build_response( + row, + _compute_total_required( + dag_timetables_assets[row.target_dag_id][0], + dag_timetables_assets[row.target_dag_id][1], + row.partition_key, + ), + ) + for row in rows + ] return PartitionedDagRunCollectionResponse( partitioned_dag_runs=results, total=len(results), @@ -201,13 +237,17 @@ def get_pending_partitioned_dag_run( f"No PartitionedDagRun for dag={dag_id} partition={partition_key}", ) - received_subq = ( - select(PartitionedAssetKeyLog.asset_id).where( - PartitionedAssetKeyLog.asset_partition_dag_run_id == partitioned_dag_run.id + # Count received PartitionedAssetKeyLog entries per asset for this partition + received_count_col = ( + select(func.count(PartitionedAssetKeyLog.id)) + .where( + PartitionedAssetKeyLog.asset_partition_dag_run_id == partitioned_dag_run.id, + PartitionedAssetKeyLog.asset_id == AssetModel.id, ) - ).correlate(AssetModel) - - received_expr = exists(received_subq.where(PartitionedAssetKeyLog.asset_id == AssetModel.id)) + .correlate(AssetModel) + .scalar_subquery() + .label("received_count") + ) asset_expression_subq = ( select(DagModel.asset_expression).where(DagModel.dag_id == dag_id).scalar_subquery() @@ -217,21 +257,45 @@ def get_pending_partitioned_dag_run( AssetModel.id, AssetModel.uri, AssetModel.name, - received_expr.label("received"), + received_count_col, asset_expression_subq.label("asset_expression"), ) .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) .where(DagScheduleAssetReference.dag_id == dag_id, AssetModel.active.has()) - .order_by(received_expr.asc(), AssetModel.uri) + .order_by(received_count_col, AssetModel.uri) ).all() - assets = [ - PartitionedDagRunAssetResponse( - asset_id=row.id, asset_name=row.name, asset_uri=row.uri, received=row.received + # Load serialized DAG to compute required counts for rollup assets + timetable = None + serdag = SerializedDagModel.get(dag_id=dag_id, session=session) + if serdag is not None: + with suppress(Exception): + dag = 
serdag.dag + if isinstance(dag.timetable, PartitionedAssetTimetable): + timetable = dag.timetable + + assets = [] + for row in asset_rows: + received_count = row.received_count or 0 + required_count = 1 + if timetable is not None: + with suppress(Exception): + mapper = timetable.get_partition_mapper(name=row.name or "", uri=row.uri or "") + if isinstance(mapper, RollupMapper): + required_count = len(mapper.to_upstream(partition_key)) + assets.append( + PartitionedDagRunAssetResponse( + asset_id=row.id, + asset_name=row.name, + asset_uri=row.uri, + received=received_count >= required_count and required_count > 0, + received_count=received_count, + required_count=required_count, + ) ) - for row in asset_rows - ] - total_received = sum(1 for a in assets if a.received) + + total_received = sum(a.received_count for a in assets) + total_required = sum(a.required_count for a in assets) asset_expression = asset_rows[0].asset_expression if asset_rows else None return PartitionedDagRunDetailResponse( @@ -242,7 +306,7 @@ def get_pending_partitioned_dag_run( updated_at=partitioned_dag_run.updated_at.isoformat() if partitioned_dag_run.updated_at else None, created_dag_run_id=partitioned_dag_run.created_dag_run_id, assets=assets, - total_required=len(assets), + total_required=total_required, total_received=total_received, asset_expression=asset_expression, ) diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py index 96e98d3e7fcdb..565c11d351c7d 100644 --- a/airflow-core/src/airflow/partition_mappers/temporal.py +++ b/airflow-core/src/airflow/partition_mappers/temporal.py @@ -150,11 +150,18 @@ def __init__(self, **kwargs) -> None: ) def to_upstream(self, downstream_key: str) -> frozenset[str]: - # Parse via output_format (not a hardcoded slice) so custom formats work correctly.
+ # Python strptime raises ValueError when %V (ISO week number) appears without + # %G and a weekday directive, so we cannot parse via the full output_format. + # Instead, locate %Y-%m-%d in the format string — __init__ guarantees it is + # present — and parse only the matching 10-char slice of the key. + # The prefix before %Y-%m-%d is literal text (no format directives), so its + # length in the format string equals its length in the formatted output. + ymd_fmt = "%Y-%m-%d" + key_start = len(self.output_format[: self.output_format.index(ymd_fmt)]) + week_start_naive = datetime.strptime(downstream_key[key_start : key_start + 10], ymd_fmt) # Arithmetic stays on naive datetimes to keep day-counting unambiguous across # DST transitions; each result is made timezone-aware before formatting so that # %z in input_format produces the correct offset. - week_start_naive = datetime.strptime(downstream_key, self.output_format) return frozenset( make_aware(week_start_naive + timedelta(days=i), self._timezone).strftime(self.input_format) for i in range(7) diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index bc4c7179c2dbb..18062b0454be6 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -8877,10 +8877,18 @@ export const $PartitionedDagRunAssetResponse = { received: { type: 'boolean', title: 'Received' + }, + received_count: { + type: 'integer', + title: 'Received Count' + }, + required_count: { + type: 'integer', + title: 'Required Count' } }, type: 'object', - required: ['asset_id', 'asset_name', 'asset_uri', 'received'], + required: ['asset_id', 'asset_name', 'asset_uri', 'received', 'received_count', 'required_count'], title: 'PartitionedDagRunAssetResponse', description: 'Asset info within a partitioned Dag run detail.' 
} as const; diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 5b216da03a58b..98136d4688cae 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2211,6 +2211,8 @@ export type PartitionedDagRunAssetResponse = { asset_name: string; asset_uri: string; received: boolean; + received_count: number; + required_count: number; }; /** diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx index 9b2a8e3b4f2b8..4b9b7ccc03579 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx @@ -39,6 +39,14 @@ export const AssetExpression = ({ return undefined; } + // A bare asset or alias at the top level (no all/any wrapper) — render it directly. + if ("asset" in expression) { + return ev.id === expression.asset.id)} />; + } + if ("alias" in expression) { + return ; + } + return ( <> {"any" in expression ? ( diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx index 6289309ec59f4..e4d2368abe3ce 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx @@ -30,32 +30,44 @@ export const AssetNode = ({ }: { readonly asset: AssetSummary; readonly event?: NextRunEvent; -}) => ( - - - {"asset" in asset ? : } - {"alias" in asset ? ( - {asset.alias.name} - ) : ( - - {asset.asset.name} - - )} - - {event?.lastUpdate === undefined ? 
undefined : ( - - - )} - -); +}) => { + const isFullyReceived = Boolean(event?.lastUpdate); + const isPartial = + !isFullyReceived && + (event?.receivedCount ?? 0) > 0 && + (event?.receivedCount ?? 0) < (event?.requiredCount ?? 1); + + return ( + + + {"asset" in asset ? : } + {"alias" in asset ? ( + {asset.alias.name} + ) : ( + + {asset.asset.name} + + )} + + {isFullyReceived ? ( + + + ) : isPartial ? ( + + {event?.receivedCount} / {event?.requiredCount} + + ) : undefined} + + ); +}; diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts index 7328dcfcc0eec..8766634b68dfb 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts @@ -33,7 +33,14 @@ type Alias = { }; }; -export type NextRunEvent = { id: number; lastUpdate: string | null; name: string | null; uri: string }; +export type NextRunEvent = { + id: number; + lastUpdate: string | null; + name: string | null; + receivedCount?: number; + requiredCount?: number; + uri: string; +}; export type AssetSummary = Alias | Asset; diff --git a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx index 7d305a03fb765..2a70cf7517d32 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx @@ -39,11 +39,13 @@ export const AssetProgressCell = ({ dagId, partitionKey, totalReceived, totalReq const assets: Array = data?.assets ?? []; const events: Array = assets - .filter((ak: PartitionedDagRunAssetResponse) => ak.received) + .filter((ak: PartitionedDagRunAssetResponse) => ak.received_count > 0) .map((ak: PartitionedDagRunAssetResponse) => ({ id: ak.asset_id, - lastUpdate: "received", + lastUpdate: ak.received ? 
"received" : null, name: ak.asset_name, + receivedCount: ak.received_count, + requiredCount: ak.required_count, uri: ak.asset_uri, })); diff --git a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx index 78536cc1c1a46..790a806fef590 100644 --- a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx +++ b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx @@ -83,6 +83,7 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti } } + // Fully satisfied assets (used for the button count label). const pendingEvents = nextRunEvents.flatMap((event) => { if (timetablePartitioned) { return event.lastUpdate === null ? [] : [event]; @@ -92,6 +93,22 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti return queuedAt === undefined ? [] : [{ ...event, lastUpdate: event.lastUpdate ?? queuedAt }]; }); + + // For partitioned Dags, also include partially-received assets in the popover visualization. + const popoverEvents = timetablePartitioned + ? nextRunEvents.filter((event) => (event.receivedCount ?? (event.lastUpdate === null ? 0 : 1)) > 0) + : pendingEvents; + + // For partitioned Dags (which may use rollup mappers), compute event-level totals so the + // button label reflects received/required partition-key events, not just asset counts. + // For non-partitioned Dags, fall back to asset counts (existing behaviour). + const scheduledCount = timetablePartitioned + ? nextRunEvents.reduce((sum, event) => sum + (event.receivedCount ?? 0), 0) + : pendingEvents.length; + const scheduledTotal = timetablePartitioned + ? nextRunEvents.reduce((sum, event) => sum + (event.requiredCount ?? 
1), 0) + : nextRunEvents.length; + const isLoading = isNextRunLoading || (!timetablePartitioned && isQueuedEventsLoading); if (!nextRunEvents.length) { @@ -143,14 +160,14 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py index a2a86da327b86..3bd09cc5644d8 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py @@ -146,7 +146,7 @@ def test_should_response_200( ) session.commit() - with assert_queries_count(3): + with assert_queries_count(4): resp = test_client.get( f"/partitioned_dag_runs?dag_id=list_dag" f"&has_created_dag_run_id={str(has_created_dag_run_id).lower()}" From 8c2a7e87cbbac8337ea426eb875cb03652286218 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 15 Apr 2026 18:25:57 +0800 Subject: [PATCH 04/16] feat(ui): update UI for to_upstream --- .../datamodels/ui/partitioned_dag_runs.py | 2 + .../core_api/openapi/_private_ui.yaml | 12 ++++ .../api_fastapi/core_api/routes/ui/assets.py | 33 +++++----- .../routes/ui/partitioned_dag_runs.py | 35 +++++------ .../ui/openapi-gen/requests/schemas.gen.ts | 16 ++++- .../ui/openapi-gen/requests/types.gen.ts | 2 + .../src/components/AssetExpression/types.ts | 2 + .../ui/src/components/AssetProgressCell.tsx | 45 +++++++++++++- .../ui/src/pages/DagsList/AssetSchedule.tsx | 60 ++++++++++++++++++- 9 files changed, 169 insertions(+), 38 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py index 19fc6b62129bc..377af683811c4 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py +++ 
b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py @@ -49,6 +49,8 @@ class PartitionedDagRunAssetResponse(BaseModel): received: bool received_count: int required_count: int + received_keys: list[str] + required_keys: list[str] class PartitionedDagRunDetailResponse(BaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 995fc004e46b3..bd89a8794e213 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -3158,6 +3158,16 @@ components: required_count: type: integer title: Required Count + received_keys: + items: + type: string + type: array + title: Received Keys + required_keys: + items: + type: string + type: array + title: Required Keys type: object required: - asset_id @@ -3166,6 +3176,8 @@ components: - received - received_count - required_count + - received_keys + - required_keys title: PartitionedDagRunAssetResponse description: Asset info within a partitioned Dag run detail. 
PartitionedDagRunCollectionResponse: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py index 84b1e5fdf28fc..387e0215a127d 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py @@ -135,18 +135,15 @@ def next_run_assets( ).one_or_none() if pending_apdr is not None: - # Count received log entries per asset for this partition - received_by_asset: dict[int, int] = { - row.asset_id: row.cnt - for row in session.execute( - select( - PartitionedAssetKeyLog.asset_id, - func.count(PartitionedAssetKeyLog.id).label("cnt"), - ) - .where(PartitionedAssetKeyLog.asset_partition_dag_run_id == pending_apdr.id) - .group_by(PartitionedAssetKeyLog.asset_id) - ).all() - } + # Collect received upstream partition keys per asset for this partition run. + received_keys_by_asset: dict[int, list[str]] = {} + for row in session.execute( + select( + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id == pending_apdr.id) + ): + received_keys_by_asset.setdefault(row.asset_id, []).append(row.source_partition_key or "") timetable = None serdag = SerializedDagModel.get(dag_id=dag_id, session=session) @@ -158,8 +155,8 @@ def next_run_assets( for event in events: asset_id = event["id"] - received_count = received_by_asset.get(asset_id, 0) - required_count = 1 + received_keys = received_keys_by_asset.get(asset_id, []) + required_keys: list[str] = [pending_apdr.partition_key] if timetable is not None: with suppress(Exception): mapper = timetable.get_partition_mapper( @@ -168,9 +165,13 @@ def next_run_assets( ) mapper.is_rollup if isinstance(mapper, RollupMapper): - required_count = len(mapper.to_upstream(pending_apdr.partition_key)) + required_keys = sorted(mapper.to_upstream(pending_apdr.partition_key)) + received_count = 
len(received_keys) + required_count = len(required_keys) event["receivedCount"] = received_count event["requiredCount"] = required_count + event["receivedKeys"] = sorted(received_keys) + event["requiredKeys"] = required_keys # Only show lastUpdate when all required upstream keys are received if received_count < required_count: event["lastUpdate"] = None @@ -178,6 +179,8 @@ def next_run_assets( for event in events: event["receivedCount"] = 0 event["requiredCount"] = 1 + event["receivedKeys"] = [] + event["requiredKeys"] = [] data: dict = {"asset_expression": dag_model.asset_expression, "events": events} if pending_partition_count is not None: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index 0860957374521..e5c0810eec344 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -237,17 +237,15 @@ def get_pending_partitioned_dag_run( f"No PartitionedDagRun for dag={dag_id} partition={partition_key}", ) - # Count received PartitionedAssetKeyLog entries per asset for this partition - received_count_col = ( - select(func.count(PartitionedAssetKeyLog.id)) - .where( - PartitionedAssetKeyLog.asset_partition_dag_run_id == partitioned_dag_run.id, - PartitionedAssetKeyLog.asset_id == AssetModel.id, - ) - .correlate(AssetModel) - .scalar_subquery() - .label("received_count") - ) + # Collect received upstream partition keys per asset for this partition run. 
+ received_keys_by_asset: dict[int, list[str]] = {} + for row in session.execute( + select( + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id == partitioned_dag_run.id) + ): + received_keys_by_asset.setdefault(row.asset_id, []).append(row.source_partition_key or "") asset_expression_subq = ( select(DagModel.asset_expression).where(DagModel.dag_id == dag_id).scalar_subquery() @@ -257,15 +255,14 @@ AssetModel.id, AssetModel.uri, AssetModel.name, - received_count_col, asset_expression_subq.label("asset_expression"), ) .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) .where(DagScheduleAssetReference.dag_id == dag_id, AssetModel.active.has()) - .order_by(received_count_col, AssetModel.uri) + .order_by(AssetModel.uri) ).all() - # Load serialized DAG to compute required counts for rollup assets + # Load serialized DAG to compute required keys for rollup assets timetable = None serdag = SerializedDagModel.get(dag_id=dag_id, session=session) if serdag is not None: with suppress(Exception): dag = serdag.dag @@ -276,13 +273,15 @@ assets = [] for row in asset_rows: - received_count = row.received_count or 0 - required_count = 1 + received_keys = received_keys_by_asset.get(row.id, []) + required_keys: list[str] = [partition_key] if timetable is not None: with suppress(Exception): mapper = timetable.get_partition_mapper(name=row.name or "", uri=row.uri or "") if isinstance(mapper, RollupMapper): - required_count = len(mapper.to_upstream(partition_key)) + required_keys = sorted(mapper.to_upstream(partition_key)) + received_count = len(received_keys) + required_count = len(required_keys) assets.append( PartitionedDagRunAssetResponse( asset_id=row.id, @@ -291,6 +290,8 @@ received=received_count >= required_count and required_count > 0, received_count=received_count,
required_count=required_count, + received_keys=sorted(received_keys), + required_keys=required_keys, ) ) diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 18062b0454be6..4a1e9ba8fb0ce 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -8885,10 +8885,24 @@ export const $PartitionedDagRunAssetResponse = { required_count: { type: 'integer', title: 'Required Count' + }, + received_keys: { + items: { + type: 'string' + }, + type: 'array', + title: 'Received Keys' + }, + required_keys: { + items: { + type: 'string' + }, + type: 'array', + title: 'Required Keys' } }, type: 'object', - required: ['asset_id', 'asset_name', 'asset_uri', 'received', 'received_count', 'required_count'], + required: ['asset_id', 'asset_name', 'asset_uri', 'received', 'received_count', 'required_count', 'received_keys', 'required_keys'], title: 'PartitionedDagRunAssetResponse', description: 'Asset info within a partitioned Dag run detail.' 
} as const; diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 98136d4688cae..23a597c66b2d7 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2213,6 +2213,8 @@ export type PartitionedDagRunAssetResponse = { received: boolean; received_count: number; required_count: number; + received_keys: Array<(string)>; + required_keys: Array<(string)>; }; /** diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts index 8766634b68dfb..0ee56f8d7f889 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts @@ -38,7 +38,9 @@ export type NextRunEvent = { lastUpdate: string | null; name: string | null; receivedCount?: number; + receivedKeys?: Array; requiredCount?: number; + requiredKeys?: Array; uri: string; }; diff --git a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx index 2a70cf7517d32..d0f90285718d6 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx @@ -16,8 +16,9 @@ * specific language governing permissions and limitations * under the License. 
*/ -import { Button } from "@chakra-ui/react"; -import { FiDatabase } from "react-icons/fi"; +import { Button, HStack, Link, Text, VStack } from "@chakra-ui/react"; +import { FiCheck, FiDatabase, FiMinus } from "react-icons/fi"; +import { Link as RouterLink } from "react-router-dom"; import { usePartitionedDagRunServiceGetPendingPartitionedDagRun } from "openapi/queries"; import type { PartitionedDagRunAssetResponse } from "openapi/requests/types.gen"; @@ -38,6 +39,8 @@ export const AssetProgressCell = ({ dagId, partitionKey, totalReceived, totalReq const assetExpression = data?.asset_expression as ExpressionType | undefined; const assets: Array = data?.assets ?? []; + const hasRollup = assets.some((ak) => ak.required_count > 1); + const events: Array = assets .filter((ak: PartitionedDagRunAssetResponse) => ak.received_count > 0) .map((ak: PartitionedDagRunAssetResponse) => ({ @@ -61,7 +64,43 @@ export const AssetProgressCell = ({ dagId, partitionKey, totalReceived, totalReq - + {hasRollup ? ( + + {assets + .filter((ak) => ak.required_count > 1) + .map((ak) => { + const receivedKeySet = new Set(ak.received_keys); + + return ( + + {assets.length > 1 ? ( + + {ak.asset_name} + + ) : undefined} + {ak.required_keys.map((key) => { + const isReceived = receivedKeySet.has(key); + + return ( + + {isReceived ? ( + + ) : ( + + )} + + {key} + + + ); + })} + + ); + })} + + ) : ( + + )} diff --git a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx index 790a806fef590..69d74cb38c45a 100644 --- a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx +++ b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx @@ -16,11 +16,11 @@ * specific language governing permissions and limitations * under the License. 
*/ -import { Button, HStack, Link, Text } from "@chakra-ui/react"; +import { Button, HStack, Link, Text, VStack } from "@chakra-ui/react"; import dayjs from "dayjs"; import { useState } from "react"; import { useTranslation } from "react-i18next"; -import { FiDatabase } from "react-icons/fi"; +import { FiCheck, FiDatabase, FiMinus } from "react-icons/fi"; import { Link as RouterLink } from "react-router-dom"; import { useAssetServiceGetDagAssetQueuedEvents, useAssetServiceNextRunAssets } from "openapi/queries"; @@ -142,6 +142,62 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti const [asset] = nextRunEvents; if (nextRunEvents.length === 1 && asset !== undefined) { + const requiredCount = asset.requiredCount ?? 1; + const receivedCount = asset.receivedCount ?? 0; + + if (requiredCount > 1) { + const requiredKeys = asset.requiredKeys ?? []; + const receivedKeySet = new Set(asset.receivedKeys ?? []); + + return ( + + + + + + + + {/* eslint-disable-next-line jsx-a11y/no-autofocus */} + + + + + + + + + {requiredKeys.map((key) => { + const isReceived = receivedKeySet.has(key); + + return ( + + {isReceived ? 
( + + ) : ( + + )} + + {key} + + + ); + })} + + + + + + ); + } + return ( From 3e24fac0dc4f8ea7368e076fd5ef37e04eeba9f9 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 15 Apr 2026 18:27:15 +0800 Subject: [PATCH 05/16] docs: add example Dag for window partition case --- .../example_dags/example_asset_partition.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py b/airflow-core/src/airflow/example_dags/example_asset_partition.py index 75d582f6cad6e..1569c850bf7a3 100644 --- a/airflow-core/src/airflow/example_dags/example_asset_partition.py +++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py @@ -23,6 +23,7 @@ DAG, AllowedKeyMapper, Asset, + AssetAll, CronPartitionTimetable, IdentityMapper, PartitionedAssetTimetable, @@ -30,6 +31,7 @@ StartOfDayMapper, StartOfHourMapper, StartOfYearMapper, + WeeklyRollupMapper, asset, task, ) @@ -225,3 +227,68 @@ def regional_stats_breakdown(): keys belong to a fixed set of allowed values (``us``, ``eu``, ``apac``) rather than time-based partitions. 
""" pass + + +daily_sales = Asset(uri="s3://sales/daily", name="daily_sales") +daily_costs = Asset(uri="s3://costs/daily", name="daily_costs") + + +with DAG( + dag_id="produce_daily_sales", + schedule=CronPartitionTimetable("0 0 * * *", timezone="UTC"), + tags=["sales", "ingestion"], +): + """Produce daily sales data partitioned by date.""" + + @task(outlets=[daily_sales]) + def upload_daily_sales(dag_run=None): + """Upload sales data for the current daily partition.""" + if TYPE_CHECKING: + assert dag_run + print(f"Producing partition: {dag_run.partition_key}") + + upload_daily_sales() + + +with DAG( + dag_id="produce_daily_costs", + schedule=CronPartitionTimetable("0 0 * * *", timezone="UTC"), + tags=["costs", "ingestion"], +): + """Produce daily cost data partitioned by date.""" + + @task(outlets=[daily_costs]) + def upload_daily_costs(dag_run=None): + """Upload cost data for the current daily partition.""" + if TYPE_CHECKING: + assert dag_run + print(f"Producing partition: {dag_run.partition_key}") + + upload_daily_costs() + + +with DAG( + dag_id="weekly_sales_report", + schedule=PartitionedAssetTimetable( + assets=AssetAll(daily_sales, daily_costs), + default_partition_mapper=WeeklyRollupMapper(), + ), + catchup=False, + tags=["sales", "reporting"], +): + """ + Generate a weekly sales report once all daily partitions for both assets have arrived. + + This Dag demonstrates WeeklyRollupMapper with multiple assets: it waits for all 7 + daily partitions of ``daily_sales`` and ``daily_costs`` before triggering for a given week. + The partition key is the week identifier, e.g. ``2024-01-15 (W03)``. + """ + + @task + def generate_weekly_report(dag_run=None): + """Combine the full week of sales and cost data into a report.""" + if TYPE_CHECKING: + assert dag_run + print(f"All 7 daily partitions for both assets received. 
Week: {dag_run.partition_key}") + + generate_weekly_report() From 39d1b2c067505a74036539e2d5562110fe69b3f9 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 16 Apr 2026 18:23:57 +0800 Subject: [PATCH 06/16] feat: improve UX for both roll up and non-rollup cases --- .../datamodels/ui/partitioned_dag_runs.py | 1 + .../core_api/openapi/_private_ui.yaml | 4 + .../api_fastapi/core_api/routes/ui/assets.py | 75 ++++--- .../routes/ui/partitioned_dag_runs.py | 142 ++++++++---- .../src/airflow/jobs/scheduler_job_runner.py | 9 +- .../src/airflow/partition_mappers/temporal.py | 21 +- .../ui/openapi-gen/requests/schemas.gen.ts | 5 + .../ui/openapi-gen/requests/types.gen.ts | 1 + .../components/AssetExpression/AssetNode.tsx | 85 ++++++- .../src/components/AssetExpression/types.ts | 11 +- .../ui/src/components/AssetProgressCell.tsx | 23 +- .../ui/src/pages/DagsList/AssetSchedule.tsx | 23 +- .../core_api/routes/ui/test_assets.py | 10 +- .../routes/ui/test_partitioned_dag_runs.py | 208 +++++++++++++++++- 14 files changed, 493 insertions(+), 125 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py index 377af683811c4..aca1d9a331dee 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py @@ -51,6 +51,7 @@ class PartitionedDagRunAssetResponse(BaseModel): required_count: int received_keys: list[str] required_keys: list[str] + is_rollup: bool = False class PartitionedDagRunDetailResponse(BaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index bd89a8794e213..5f4701b21e926 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ 
b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -3168,6 +3168,10 @@ components: type: string type: array title: Required Keys + is_rollup: + type: boolean + title: Is Rollup + default: false type: object required: - asset_id diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py index 387e0215a127d..b35e33fa861c4 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py @@ -24,6 +24,7 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.core_api.routes.ui.partitioned_dag_runs import _load_timetable from airflow.api_fastapi.core_api.security import requires_access_asset, requires_access_dag from airflow.models import DagModel from airflow.models.asset import ( @@ -34,9 +35,9 @@ DagScheduleAssetReference, PartitionedAssetKeyLog, ) -from airflow.models.serialized_dag import SerializedDagModel -from airflow.partition_mappers.base import RollupMapper -from airflow.timetables.simple import PartitionedAssetTimetable + +if TYPE_CHECKING: + from airflow.partition_mappers.base import RollupMapper assets_router = AirflowRouter(tags=["Asset"]) @@ -96,7 +97,7 @@ def next_run_assets( AssetModel.id, AssetModel.uri, AssetModel.name, - func.max(AssetEvent.timestamp).label("lastUpdate"), + func.max(AssetEvent.timestamp).label("last_update"), queued_expr.label("queued"), ) .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) @@ -119,11 +120,13 @@ def next_run_assets( events = [dict(info._mapping) for info in session.execute(query)] for event in events: if not event.pop("queued", None): - event["lastUpdate"] = None + event["last_update"] = None # For partitioned Dags: enrich events with per-asset received/required counts, - # using to_upstream for 
rollup mappers, and fix lastUpdate for partial receipt. + # using to_upstream for rollup mappers, and fix last_update for partial receipt. if is_partitioned: + timetable = _load_timetable(dag_id, session) + pending_apdr = session.execute( select(AssetPartitionDagRun.id, AssetPartitionDagRun.partition_key) .where( @@ -136,51 +139,59 @@ def next_run_assets( if pending_apdr is not None: # Collect received upstream partition keys per asset for this partition run. - received_keys_by_asset: dict[int, list[str]] = {} + # Use a set to deduplicate: multiple events for the same key count as one. + received_keys_by_asset: dict[int, set[str]] = {} for row in session.execute( select( PartitionedAssetKeyLog.asset_id, PartitionedAssetKeyLog.source_partition_key, ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id == pending_apdr.id) ): - received_keys_by_asset.setdefault(row.asset_id, []).append(row.source_partition_key or "") - - timetable = None - serdag = SerializedDagModel.get(dag_id=dag_id, session=session) - if serdag is not None: - with suppress(Exception): - dag_obj = serdag.dag - if isinstance(dag_obj.timetable, PartitionedAssetTimetable): - timetable = dag_obj.timetable + received_keys_by_asset.setdefault(row.asset_id, set()).add(row.source_partition_key or "") for event in events: asset_id = event["id"] - received_keys = received_keys_by_asset.get(asset_id, []) + received_keys = sorted(received_keys_by_asset.get(asset_id, set())) required_keys: list[str] = [pending_apdr.partition_key] + is_rollup = False if timetable is not None: with suppress(Exception): mapper = timetable.get_partition_mapper( - name=event.get("name") or "", - uri=event.get("uri") or "", + name=event["name"], + uri=event["uri"], ) - mapper.is_rollup - if isinstance(mapper, RollupMapper): - required_keys = sorted(mapper.to_upstream(pending_apdr.partition_key)) + if mapper.is_rollup: + required_keys = sorted( + cast("RollupMapper", mapper).to_upstream(pending_apdr.partition_key) + ) + is_rollup = 
True received_count = len(received_keys) required_count = len(required_keys) - event["receivedCount"] = received_count - event["requiredCount"] = required_count - event["receivedKeys"] = sorted(received_keys) - event["requiredKeys"] = required_keys - # Only show lastUpdate when all required upstream keys are received + event["received_count"] = received_count + event["required_count"] = required_count + event["received_keys"] = received_keys + event["required_keys"] = required_keys + event["is_rollup"] = is_rollup + # Only show last_update when all required upstream keys are received if received_count < required_count: - event["lastUpdate"] = None + event["last_update"] = None else: + # No pending APDR yet — mark rollup assets so the UI can handle them + # correctly (e.g. skip "Asset Triggered" in favour of the asset name view). for event in events: - event["receivedCount"] = 0 - event["requiredCount"] = 1 - event["receivedKeys"] = [] - event["requiredKeys"] = [] + is_rollup = False + if timetable is not None: + with suppress(Exception): + mapper = timetable.get_partition_mapper( + name=event["name"], + uri=event["uri"], + ) + is_rollup = mapper.is_rollup + event["received_count"] = 0 + event["required_count"] = 1 + event["received_keys"] = [] + event["required_keys"] = [] + event["is_rollup"] = is_rollup data: dict = {"asset_expression": dag_model.asset_expression, "events": events} if pending_partition_count is not None: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index e5c0810eec344..2216b4908efc5 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -54,24 +54,34 @@ from airflow.partition_mappers.base import RollupMapper +def _load_timetable(dag_id: str, session) -> PartitionedAssetTimetable | None: + 
"""Return the PartitionedAssetTimetable for *dag_id*, or None if absent or not partitioned.""" + serdag = SerializedDagModel.get(dag_id=dag_id, session=session) + if serdag is None: + return None + with suppress(Exception): + if isinstance(serdag.dag.timetable, PartitionedAssetTimetable): + return serdag.dag.timetable + return None + + def _load_timetable_and_assets( dag_id: str, session -) -> tuple[PartitionedAssetTimetable | None, list[tuple[str, str]]]: - """Load the DAG timetable and its active required assets as (name, uri) pairs.""" - timetable = None - serdag = SerializedDagModel.get(dag_id=dag_id, session=session) - if serdag is not None: - with suppress(Exception): - dag = serdag.dag - if isinstance(dag.timetable, PartitionedAssetTimetable): - timetable = dag.timetable +) -> tuple[PartitionedAssetTimetable | None, list[tuple[str, str]], dict[int, tuple[str, str]]]: + """ + Load timetable and active required assets. + Returns (timetable, [(name, uri), ...], {asset_id: (name, uri)}). 
+ """ + timetable = _load_timetable(dag_id, session) asset_rows = session.execute( - select(AssetModel.name, AssetModel.uri) + select(AssetModel.id, AssetModel.name, AssetModel.uri) .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) .where(DagScheduleAssetReference.dag_id == dag_id, AssetModel.active.has()) ).all() - return timetable, [(r.name or "", r.uri or "") for r in asset_rows] + asset_info = [(r.name, r.uri) for r in asset_rows] + asset_id_to_info = {r.id: (r.name, r.uri) for r in asset_rows} + return timetable, asset_info, asset_id_to_info def _compute_total_required( @@ -84,24 +94,48 @@ def _compute_total_required( return len(asset_info) total = 0 for name, uri in asset_info: - try: + mapper = timetable.get_partition_mapper(name=name, uri=uri) + total += len(cast("RollupMapper", mapper).to_upstream(partition_key)) if mapper.is_rollup else 1 + return total + + +def _compute_received_count( + received_by_asset: dict[int, set[str]], + timetable: PartitionedAssetTimetable | None, + asset_id_to_info: dict[int, tuple[str, str]], + partition_key: str, +) -> int: + """ + Count received events using rollup-aware deduplication. + + For rollup assets: count distinct upstream keys that intersect the required set. + For non-rollup assets: count 1 per asset if any event has been logged — the + source_partition_key value is irrelevant; having any event satisfies the requirement. 
+ """ + total = 0 + for asset_id, received_keys in received_by_asset.items(): + if timetable is not None: + name, uri = asset_id_to_info[asset_id] mapper = timetable.get_partition_mapper(name=name, uri=uri) - total += len(mapper.to_upstream(partition_key)) if isinstance(mapper, RollupMapper) else 1 - except Exception: - total += 1 - return total or len(asset_info) + if mapper.is_rollup: + required_keys = frozenset(cast("RollupMapper", mapper).to_upstream(partition_key)) + total += len(received_keys & required_keys) + continue + # Non-rollup: any logged event satisfies this asset's requirement. + total += 1 if received_keys else 0 + return total partitioned_dag_runs_router = AirflowRouter(tags=["PartitionedDagRun"]) -def _build_response(row, required_count: int) -> PartitionedDagRunResponse: +def _build_response(row, required_count: int, received_count: int | None = None) -> PartitionedDagRunResponse: return PartitionedDagRunResponse( id=row.id, dag_id=row.target_dag_id, partition_key=row.partition_key, created_at=row.created_at.isoformat() if row.created_at else None, - total_received=row.total_received or 0, + total_received=received_count if received_count is not None else (row.total_received or 0), total_required=required_count, state=row.dag_run_state if row.created_dag_run_id else "pending", created_dag_run_id=row.dag_run_id, @@ -129,7 +163,9 @@ def get_partitioned_dag_runs( if dag_info.timetable_summary != "Partitioned Asset": return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) - # Subquery: count received events per partition (PartitionedAssetKeyLog rows for required assets) + # Subquery for received count per partition (count of required assets that have any log). + # This matches the non-rollup contract "any event for an asset = that asset is satisfied". + # Rollup-aware counts are computed in Python in _compute_received_count when dag_id is set. 
required_assets_subq = ( select(DagScheduleAssetReference.asset_id) .join(AssetModel, AssetModel.id == DagScheduleAssetReference.asset_id) @@ -140,7 +176,7 @@ def get_partitioned_dag_runs( .correlate(AssetPartitionDagRun) ) received_subq = ( - select(func.count(PartitionedAssetKeyLog.id)) + select(func.count(func.distinct(PartitionedAssetKeyLog.asset_id))) .where( PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, PartitionedAssetKeyLog.asset_id.in_(required_assets_subq), @@ -169,18 +205,43 @@ def get_partitioned_dag_runs( return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) if dag_id.value is not None: - timetable, asset_info = _load_timetable_and_assets(dag_id.value, session) + timetable, asset_info, asset_id_to_info = _load_timetable_and_assets(dag_id.value, session) + + # Batch-fetch all log entries for these APDRs in one query. + apdr_ids = [row.id for row in rows] + log_by_apdr: dict[int, dict[int, set[str]]] = {} + for log_row in session.execute( + select( + PartitionedAssetKeyLog.asset_partition_dag_run_id, + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + ).where( + PartitionedAssetKeyLog.asset_partition_dag_run_id.in_(apdr_ids), + PartitionedAssetKeyLog.asset_id.in_(list(asset_id_to_info)), + ) + ).all(): + log_by_apdr.setdefault(log_row.asset_partition_dag_run_id, {}).setdefault( + log_row.asset_id, set() + ).add(log_row.source_partition_key or "") + results = [ - _build_response(row, _compute_total_required(timetable, asset_info, row.partition_key)) + _build_response( + row, + _compute_total_required(timetable, asset_info, row.partition_key), + _compute_received_count( + log_by_apdr.get(row.id, {}), timetable, asset_id_to_info, row.partition_key + ), + ) for row in rows ] return PartitionedDagRunCollectionResponse(partitioned_dag_runs=results, total=len(results)) - # No dag_id filter: load timetables and assets for each unique DAG + # No dag_id filter: load timetables 
and assets for each unique Dag. + # total_received is approximated via SQL for this global view. unique_dag_ids = list({row.target_dag_id for row in rows}) - dag_timetables_assets: dict[str, tuple[PartitionedAssetTimetable | None, list[tuple[str, str]]]] = { - did: _load_timetable_and_assets(did, session) for did in unique_dag_ids - } + dag_timetables_assets: dict[ + str, tuple[PartitionedAssetTimetable | None, list[tuple[str, str]], dict[int, tuple[str, str]]] + ] = {did: _load_timetable_and_assets(did, session) for did in unique_dag_ids} dag_rows = session.execute( select(DagModel.dag_id, DagModel.asset_expression).where(DagModel.dag_id.in_(unique_dag_ids)) ).all() @@ -238,14 +299,15 @@ def get_pending_partitioned_dag_run( ) # Collect received upstream partition keys per asset for this partition run. - received_keys_by_asset: dict[int, list[str]] = {} + # Use a set to deduplicate: multiple events for the same key count as one. + received_keys_by_asset: dict[int, set[str]] = {} for row in session.execute( select( PartitionedAssetKeyLog.asset_id, PartitionedAssetKeyLog.source_partition_key, ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id == partitioned_dag_run.id) ): - received_keys_by_asset.setdefault(row.asset_id, []).append(row.source_partition_key or "") + received_keys_by_asset.setdefault(row.asset_id, set()).add(row.source_partition_key or "") asset_expression_subq = ( select(DagModel.asset_expression).where(DagModel.dag_id == dag_id).scalar_subquery() @@ -262,24 +324,19 @@ def get_pending_partitioned_dag_run( .order_by(AssetModel.uri) ).all() - # Load serialized DAG to compute required keys for rollup assets - timetable = None - serdag = SerializedDagModel.get(dag_id=dag_id, session=session) - if serdag is not None: - with suppress(Exception): - dag = serdag.dag - if isinstance(dag.timetable, PartitionedAssetTimetable): - timetable = dag.timetable + timetable = _load_timetable(dag_id, session) assets = [] - for row in asset_rows: - received_keys = 
received_keys_by_asset.get(row.id, []) + for asset_row in asset_rows: + received_keys = sorted(received_keys_by_asset.get(asset_row.id, set())) required_keys: list[str] = [partition_key] + is_rollup = False if timetable is not None: with suppress(Exception): - mapper = timetable.get_partition_mapper(name=row.name or "", uri=row.uri or "") - if isinstance(mapper, RollupMapper): - required_keys = sorted(mapper.to_upstream(partition_key)) + mapper = timetable.get_partition_mapper(name=row.name, uri=row.uri) + if mapper.is_rollup: + required_keys = sorted(cast("RollupMapper", mapper).to_upstream(partition_key)) + is_rollup = True received_count = len(received_keys) required_count = len(required_keys) assets.append( @@ -290,8 +347,9 @@ def get_pending_partitioned_dag_run( received=received_count >= required_count and required_count > 0, received_count=received_count, required_count=required_count, - received_keys=sorted(received_keys), + received_keys=received_keys, required_keys=required_keys, + is_rollup=is_rollup, ) ) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index c4602933e6b4a..1b2096a8cb9a8 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1885,8 +1885,15 @@ def _resolve_asset_partition_status( def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]: partition_dag_ids: set[str] = set() + # Cap per-tick work so the scheduler transaction stays bounded and other + # scheduling work isn't starved. Remaining APDRs drain across subsequent ticks. + # Note: with strict FIFO ordering, >BATCH persistently-unsatisfied APDRs would + # block newer ones; switch to updated_at-based ordering if that becomes an issue. 
pending_apdrs = session.scalars( - select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None)) + select(AssetPartitionDagRun) + .where(AssetPartitionDagRun.created_dag_run_id.is_(None)) + .order_by(AssetPartitionDagRun.created_at) + .limit(500) ).all() if not pending_apdrs: return partition_dag_ids diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py index 565c11d351c7d..ba1c3ee99d1e0 100644 --- a/airflow-core/src/airflow/partition_mappers/temporal.py +++ b/airflow-core/src/airflow/partition_mappers/temporal.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import re from abc import ABC, abstractmethod from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any @@ -23,6 +24,8 @@ from airflow._shared.timezones.timezone import make_aware, parse_timezone from airflow.partition_mappers.base import PartitionMapper, RollupMapper +_YMD_RE = re.compile(r"\d{4}-\d{2}-\d{2}") + if TYPE_CHECKING: from pendulum import FixedTimezone, Timezone @@ -150,15 +153,15 @@ def __init__(self, **kwargs) -> None: ) def to_upstream(self, downstream_key: str) -> frozenset[str]: - # Python strptime raises ValueError when %V (ISO week number) appears without - # %G and a weekday directive, so we cannot parse via the full output_format. - # Instead, locate %Y-%m-%d in the format string — __init__ guarantees it is - # present — and parse only the matching 10-char slice of the key. - # The prefix before %Y-%m-%d is literal text (no format directives), so its - # length in the format string equals its length in the formatted output. - ymd_fmt = "%Y-%m-%d" - key_start = len(self.output_format[: self.output_format.index(ymd_fmt)]) - week_start_naive = datetime.strptime(downstream_key[key_start : key_start + 10], ymd_fmt) + # strptime cannot consume %V (ISO week) without %G+weekday, so parse by + # locating the YYYY-MM-DD slice directly. 
Regex is robust against + # variable-width directives (e.g. %B, %A, %Z) appearing elsewhere in the key. + match = _YMD_RE.search(downstream_key) + if match is None: + raise ValueError( + f"WeeklyRollupMapper.to_upstream could not locate YYYY-MM-DD in {downstream_key!r}" + ) + week_start_naive = datetime.strptime(match.group(), "%Y-%m-%d") # Arithmetic stays on naive datetimes to keep day-counting unambiguous across # DST transitions; each result is made timezone-aware before formatting so that # %z in input_format produces the correct offset. diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 4a1e9ba8fb0ce..2d4fd67fc420c 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -8899,6 +8899,11 @@ export const $PartitionedDagRunAssetResponse = { }, type: 'array', title: 'Required Keys' + }, + is_rollup: { + type: 'boolean', + title: 'Is Rollup', + default: false } }, type: 'object', diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 23a597c66b2d7..304095701ae4a 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2215,6 +2215,7 @@ export type PartitionedDagRunAssetResponse = { required_count: number; received_keys: Array<(string)>; required_keys: Array<(string)>; + is_rollup?: boolean; }; /** diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx index e4d2368abe3ce..72f06bbcca421 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx @@ -16,14 +16,69 @@ * specific language 
governing permissions and limitations * under the License. */ -import { Box, Text, HStack, Link } from "@chakra-ui/react"; -import { FiDatabase } from "react-icons/fi"; +import { Box, Button, HStack, Link, Text, VStack } from "@chakra-ui/react"; +import { FiCheck, FiDatabase, FiMinus } from "react-icons/fi"; import { PiRectangleDashed } from "react-icons/pi"; import { Link as RouterLink } from "react-router-dom"; +import { Popover } from "src/components/ui"; + import Time from "../Time"; import type { AssetSummary, NextRunEvent } from "./types"; +const RollupKeyChecklistPopover = ({ + receivedCount, + receivedKeys, + requiredCount, + requiredKeys, +}: { + readonly receivedCount: number; + readonly receivedKeys: Array; + readonly requiredCount: number; + readonly requiredKeys: Array; +}) => { + const receivedKeySet = new Set(receivedKeys); + + return ( + // eslint-disable-next-line jsx-a11y/no-autofocus + + + + + + + + + {requiredKeys.map((key) => { + const isReceived = receivedKeySet.has(key); + + return ( + + {isReceived ? ( + + ) : ( + + )} + + {key} + + + ); + })} + + + + + ); +}; + export const AssetNode = ({ asset, event, @@ -31,11 +86,18 @@ export const AssetNode = ({ readonly asset: AssetSummary; readonly event?: NextRunEvent; }) => { - const isFullyReceived = Boolean(event?.lastUpdate); + const isFullyReceived = Boolean(event?.last_update); const isPartial = !isFullyReceived && - (event?.receivedCount ?? 0) > 0 && - (event?.receivedCount ?? 0) < (event?.requiredCount ?? 1); + (event?.received_count ?? 0) > 0 && + (event?.received_count ?? 0) < (event?.required_count ?? 1); + // In partitioned dags the `last_update` timestamp is the last asset event, not the + // pending partition key's arrival — so it isn't meaningful here. Suppress it for + // non-rollup partitioned nodes. + const isPartitionedNonRollup = event?.is_rollup === false; + const showTime = isFullyReceived && !isPartitionedNonRollup; + const showRollupChecklist = + (event?.is_rollup ?? 
false) && (event?.required_keys?.length ?? 0) > 0 && (isPartial || isFullyReceived); return ( )} - {isFullyReceived ? ( + {showRollupChecklist ? ( + + ) : showTime ? ( - ) : isPartial ? ( - {event?.receivedCount} / {event?.requiredCount} + {event?.received_count} / {event?.required_count} ) : undefined} diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts index 0ee56f8d7f889..c24af1900d349 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts @@ -35,12 +35,13 @@ type Alias = { export type NextRunEvent = { id: number; - lastUpdate: string | null; + is_rollup?: boolean; + last_update: string | null; name: string | null; - receivedCount?: number; - receivedKeys?: Array; - requiredCount?: number; - requiredKeys?: Array; + received_count?: number; + received_keys?: Array; + required_count?: number; + required_keys?: Array; uri: string; }; diff --git a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx index d0f90285718d6..83d1f2321f93a 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx @@ -39,16 +39,17 @@ export const AssetProgressCell = ({ dagId, partitionKey, totalReceived, totalReq const assetExpression = data?.asset_expression as ExpressionType | undefined; const assets: Array = data?.assets ?? []; - const hasRollup = assets.some((ak) => ak.required_count > 1); + const hasRollup = assets.some((ak) => ak.is_rollup); const events: Array = assets .filter((ak: PartitionedDagRunAssetResponse) => ak.received_count > 0) .map((ak: PartitionedDagRunAssetResponse) => ({ id: ak.asset_id, - lastUpdate: ak.received ? "received" : null, + is_rollup: ak.is_rollup, + last_update: ak.received ? 
"received" : null, name: ak.asset_name, - receivedCount: ak.received_count, - requiredCount: ak.required_count, + received_count: ak.received_count, + required_count: ak.required_count, uri: ak.asset_uri, })); @@ -65,19 +66,17 @@ export const AssetProgressCell = ({ dagId, partitionKey, totalReceived, totalReq {hasRollup ? ( - + {assets - .filter((ak) => ak.required_count > 1) + .filter((ak) => ak.is_rollup) .map((ak) => { const receivedKeySet = new Set(ak.received_keys); return ( - {assets.length > 1 ? ( - - {ak.asset_name} - - ) : undefined} + + {ak.asset_name} + {ak.required_keys.map((key) => { const isReceived = receivedKeySet.has(key); @@ -97,7 +96,7 @@ export const AssetProgressCell = ({ dagId, partitionKey, totalReceived, totalReq ); })} - + ) : ( )} diff --git a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx index 69d74cb38c45a..e00b46a72dbfe 100644 --- a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx +++ b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx @@ -86,27 +86,30 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti // Fully satisfied assets (used for the button count label). const pendingEvents = nextRunEvents.flatMap((event) => { if (timetablePartitioned) { - return event.lastUpdate === null ? [] : [event]; + return event.last_update === null ? [] : [event]; } const queuedAt = queuedAssetEvents.get(event.id); - return queuedAt === undefined ? [] : [{ ...event, lastUpdate: event.lastUpdate ?? queuedAt }]; + return queuedAt === undefined ? [] : [{ ...event, last_update: event.last_update ?? queuedAt }]; }); // For partitioned Dags, also include partially-received assets in the popover visualization. const popoverEvents = timetablePartitioned - ? nextRunEvents.filter((event) => (event.receivedCount ?? (event.lastUpdate === null ? 0 : 1)) > 0) + ? nextRunEvents.filter((event) => (event.received_count ?? 
(event.last_update === null ? 0 : 1)) > 0) : pendingEvents; // For partitioned Dags (which may use rollup mappers), compute event-level totals so the // button label reflects received/required partition-key events, not just asset counts. // For non-partitioned Dags, fall back to asset counts (existing behaviour). const scheduledCount = timetablePartitioned - ? nextRunEvents.reduce((sum, event) => sum + (event.receivedCount ?? 0), 0) + ? nextRunEvents.reduce( + (sum, event) => sum + Math.min(event.received_count ?? 0, event.required_count ?? 1), + 0, + ) : pendingEvents.length; const scheduledTotal = timetablePartitioned - ? nextRunEvents.reduce((sum, event) => sum + (event.requiredCount ?? 1), 0) + ? nextRunEvents.reduce((sum, event) => sum + (event.required_count ?? 1), 0) : nextRunEvents.length; const isLoading = isNextRunLoading || (!timetablePartitioned && isQueuedEventsLoading); @@ -142,12 +145,12 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti const [asset] = nextRunEvents; if (nextRunEvents.length === 1 && asset !== undefined) { - const requiredCount = asset.requiredCount ?? 1; - const receivedCount = asset.receivedCount ?? 0; + const requiredCount = asset.required_count ?? 1; + const receivedCount = asset.received_count ?? 0; + const requiredKeys = asset.required_keys ?? []; - if (requiredCount > 1) { - const requiredKeys = asset.requiredKeys ?? []; - const receivedKeySet = new Set(asset.receivedKeys ?? []); + if (asset.is_rollup && requiredKeys.length > 0) { + const receivedKeySet = new Set(asset.received_keys ?? 
[]); return ( diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py index eaf3e5031fc48..0b6f1b218578d 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py @@ -77,7 +77,7 @@ def test_should_response_200(self, test_client, dag_maker): ] }, "events": [ - {"id": mock.ANY, "uri": "s3://bucket/next-run-asset/1", "name": "asset1", "lastUpdate": None} + {"id": mock.ANY, "uri": "s3://bucket/next-run-asset/1", "name": "asset1", "last_update": None} ], } @@ -141,8 +141,8 @@ def test_should_set_last_update_only_for_queued_and_hide_flag(self, test_client, }, # events are ordered by uri "events": [ - {"id": mock.ANY, "uri": "s3://bucket/A", "name": "A", "lastUpdate": mock.ANY}, - {"id": mock.ANY, "uri": "s3://bucket/B", "name": "B", "lastUpdate": None}, + {"id": mock.ANY, "uri": "s3://bucket/A", "name": "A", "last_update": mock.ANY}, + {"id": mock.ANY, "uri": "s3://bucket/B", "name": "B", "last_update": None}, ], } @@ -169,7 +169,7 @@ def test_last_update_respects_latest_run_filter(self, test_client, dag_maker, se resp = test_client.get("/next_run_assets/filter_run") assert resp.status_code == 200 ev = resp.json()["events"][0] - assert ev["lastUpdate"] is not None + assert ev["last_update"] is not None assert "queued" not in ev @pytest.mark.parametrize( @@ -221,4 +221,4 @@ def test_partitioned_dag_last_update( resp = test_client.get("/next_run_assets/part_dag") assert resp.status_code == 200 ev = resp.json()["events"][0] - assert (ev["lastUpdate"] is not None) == expect_last_update + assert (ev["last_update"] is not None) == expect_last_update diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py index 3bd09cc5644d8..e05eee99187bc 100644 --- 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py @@ -23,6 +23,7 @@ from sqlalchemy import select from airflow.models.asset import AssetEvent, AssetModel, AssetPartitionDagRun, PartitionedAssetKeyLog +from airflow.partition_mappers.temporal import WeeklyRollupMapper from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.definitions.asset import Asset from airflow.sdk.definitions.timetables.assets import PartitionedAssetTimetable @@ -146,7 +147,7 @@ def test_should_response_200( ) session.commit() - with assert_queries_count(4): + with assert_queries_count(5): resp = test_client.get( f"/partitioned_dag_runs?dag_id=list_dag" f"&has_created_dag_run_id={str(has_created_dag_run_id).lower()}" @@ -238,6 +239,141 @@ def test_partitioned_dag_runs_filters_unreadable_dags(self, _, test_client, dag_ dag_ids = {r["dag_id"] for r in body["partitioned_dag_runs"]} assert "restricted_dag" not in dag_ids + def test_duplicate_events_count_as_one(self, test_client, dag_maker, session): + """Multiple log entries for the same asset count as 1 received, not N.""" + asset_def = Asset(uri="s3://bucket/dup0", name="dup0") + with dag_maker( + dag_id="dup_dag", + schedule=PartitionedAssetTimetable(assets=asset_def), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + asset = session.scalar(select(AssetModel).where(AssetModel.uri == "s3://bucket/dup0")) + pdr = AssetPartitionDagRun(target_dag_id="dup_dag", partition_key="2024-06-01") + session.add(pdr) + session.flush() + + # Log 3 events for the same asset — all should collapse to total_received = 1. 
+ for _ in range(3): + event = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(event) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=event.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key="2024-06-01", + target_dag_id="dup_dag", + target_partition_key="2024-06-01", + ) + ) + session.commit() + + resp = test_client.get("/partitioned_dag_runs?dag_id=dup_dag&has_created_dag_run_id=false") + assert resp.status_code == 200 + pdr_resp = resp.json()["partitioned_dag_runs"][0] + assert pdr_resp["total_required"] == 1 + assert pdr_resp["total_received"] == 1 + + def test_non_rollup_any_event_counts_as_one(self, test_client, dag_maker, session): + """For non-rollup, an event with a different source_partition_key still counts as 1.""" + asset_def = Asset(uri="s3://bucket/nr0", name="nr0") + with dag_maker( + dag_id="nr_dag", + schedule=PartitionedAssetTimetable(assets=asset_def), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + asset = session.scalar(select(AssetModel).where(AssetModel.uri == "s3://bucket/nr0")) + pdr = AssetPartitionDagRun(target_dag_id="nr_dag", partition_key="2024-06-01") + session.add(pdr) + session.flush() + + # Log an event whose source_partition_key differs from the APDR partition_key. 
+ event = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(event) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=event.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key="different-key", + target_dag_id="nr_dag", + target_partition_key="2024-06-01", + ) + ) + session.commit() + + resp = test_client.get("/partitioned_dag_runs?dag_id=nr_dag&has_created_dag_run_id=false") + assert resp.status_code == 200 + pdr_resp = resp.json()["partitioned_dag_runs"][0] + assert pdr_resp["total_required"] == 1 + assert pdr_resp["total_received"] == 1 + + def test_rollup_mapper_counts_received_upstream_keys(self, test_client, dag_maker, session): + """For a rollup mapper, only upstream keys in to_upstream() are counted.""" + asset_def = Asset(uri="s3://bucket/daily", name="daily") + mapper = WeeklyRollupMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0) + with dag_maker( + dag_id="rollup_dag", + schedule=PartitionedAssetTimetable(assets=asset_def, partition_mapper_config={asset_def: mapper}), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + asset = session.scalar(select(AssetModel).where(AssetModel.uri == "s3://bucket/daily")) + # Week starting 2024-06-03 (Monday) needs 7 daily keys. + pdr = AssetPartitionDagRun(target_dag_id="rollup_dag", partition_key="2024-06-03") + session.add(pdr) + session.flush() + + # Receive 2 of the 7 required upstream daily keys. + for day in ("2024-06-03", "2024-06-04"): + event = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(event) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=event.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key=day, + target_dag_id="rollup_dag", + target_partition_key="2024-06-03", + ) + ) + # Also log a key outside the required week — it must not inflate the count. 
+ stray = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(stray) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=stray.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key="2024-05-27", # previous week + target_dag_id="rollup_dag", + target_partition_key="2024-06-03", + ) + ) + session.commit() + + resp = test_client.get("/partitioned_dag_runs?dag_id=rollup_dag&has_created_dag_run_id=false") + assert resp.status_code == 200 + pdr_resp = resp.json()["partitioned_dag_runs"][0] + assert pdr_resp["total_required"] == 7 + assert pdr_resp["total_received"] == 2 # only the 2 in-week keys count + class TestGetPendingPartitionedDagRun: def test_should_response_401(self, unauthenticated_test_client): @@ -354,3 +490,73 @@ def test_should_response_200(self, test_client, dag_maker, session, num_assets, received_uris = {a["asset_uri"] for a in body["assets"] if a["received"]} assert received_uris == set(uris[:received_count]) + + def test_is_rollup_false_for_non_rollup_asset(self, test_client, dag_maker, session): + """is_rollup is False for assets that use the identity (non-rollup) mapper.""" + asset_def = Asset(uri="s3://bucket/nr1", name="nr1") + with dag_maker( + dag_id="nr_detail_dag", + schedule=PartitionedAssetTimetable(assets=asset_def), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + session.add(AssetPartitionDagRun(target_dag_id="nr_detail_dag", partition_key="2024-07-01")) + session.commit() + + resp = test_client.get("/pending_partitioned_dag_run/nr_detail_dag/2024-07-01") + assert resp.status_code == 200 + assets = resp.json()["assets"] + assert len(assets) == 1 + assert assets[0]["is_rollup"] is False + + def test_is_rollup_true_for_rollup_asset(self, test_client, dag_maker, session): + """is_rollup is True for assets that use a RollupMapper, and keys are populated.""" + asset_def = Asset(uri="s3://bucket/weekly", 
name="weekly") + mapper = WeeklyRollupMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0) + with dag_maker( + dag_id="rollup_detail_dag", + schedule=PartitionedAssetTimetable(assets=asset_def, partition_mapper_config={asset_def: mapper}), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + asset = session.scalar(select(AssetModel).where(AssetModel.uri == "s3://bucket/weekly")) + pdr = AssetPartitionDagRun(target_dag_id="rollup_detail_dag", partition_key="2024-06-03") + session.add(pdr) + session.flush() + + # Receive one upstream daily key. + event = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(event) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=event.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key="2024-06-03", + target_dag_id="rollup_detail_dag", + target_partition_key="2024-06-03", + ) + ) + session.commit() + + resp = test_client.get("/pending_partitioned_dag_run/rollup_detail_dag/2024-06-03") + assert resp.status_code == 200 + body = resp.json() + assert body["total_required"] == 7 + assert body["total_received"] == 1 + assets = body["assets"] + assert len(assets) == 1 + a = assets[0] + assert a["is_rollup"] is True + assert a["required_count"] == 7 + assert a["received_count"] == 1 + assert len(a["required_keys"]) == 7 + assert "2024-06-03" in a["required_keys"] + assert a["received_keys"] == ["2024-06-03"] From 209ff6863194da1dc381fdd36d62f111c2620755 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 21 Apr 2026 18:23:40 +0800 Subject: [PATCH 07/16] refactor: simplify event data handling --- .../core_api/datamodels/ui/assets.py | 45 +++++ .../core_api/openapi/_private_ui.yaml | 77 +++++++- .../api_fastapi/core_api/routes/ui/assets.py | 180 ++++++++++-------- .../routes/ui/partitioned_dag_runs.py | 32 ++-- .../ui/openapi-gen/queries/ensureQueryData.ts | 2 +- 
.../ui/openapi-gen/queries/prefetch.ts | 2 +- .../airflow/ui/openapi-gen/queries/queries.ts | 2 +- .../ui/openapi-gen/queries/suspense.ts | 2 +- .../ui/openapi-gen/requests/schemas.gen.ts | 111 ++++++++++- .../ui/openapi-gen/requests/services.gen.ts | 6 +- .../ui/openapi-gen/requests/types.gen.ts | 37 +++- .../AssetExpression/AssetExpression.tsx | 6 +- .../components/AssetExpression/AssetNode.tsx | 5 +- .../src/components/AssetExpression/types.ts | 12 -- .../ui/src/components/AssetProgressCell.tsx | 5 +- .../ui/src/pages/DagsList/AssetSchedule.tsx | 8 +- .../core_api/routes/ui/test_assets.py | 38 +++- 17 files changed, 429 insertions(+), 141 deletions(-) create mode 100644 airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/assets.py diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/assets.py new file mode 100644 index 0000000000000..cd5aba6a68f1e --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/assets.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from datetime import datetime + +from pydantic import Field + +from airflow.api_fastapi.core_api.base import BaseModel + + +class NextRunAssetEventResponse(BaseModel): + """One asset event in the ``next_run_assets`` payload.""" + + id: int + name: str | None + uri: str + last_update: datetime | None = None + received_count: int = 0 + required_count: int = 1 + received_keys: list[str] = Field(default_factory=list) + required_keys: list[str] = Field(default_factory=list) + is_rollup: bool = False + + +class NextRunAssetsResponse(BaseModel): + """Response for the ``next_run_assets`` endpoint.""" + + asset_expression: dict | None = None + events: list[NextRunAssetEventResponse] + pending_partition_count: int | None = None diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 5f4701b21e926..6e31538413d29 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -91,9 +91,7 @@ paths: content: application/json: schema: - type: object - additionalProperties: true - title: Response Next Run Assets + $ref: '#/components/schemas/NextRunAssetsResponse' '422': description: Validation Error content: @@ -2564,8 +2562,6 @@ components: - text - href title: ExtraMenuItem - description: Define a menu item that can be added to the menu by auth managers - or plugins. GanttResponse: properties: dag_id: @@ -3069,6 +3065,77 @@ components: - extra_menu_items title: MenuItemCollectionResponse description: Menu Item Collection serializer for responses. 
+ NextRunAssetEventResponse: + properties: + id: + type: integer + title: Id + name: + anyOf: + - type: string + - type: 'null' + title: Name + uri: + type: string + title: Uri + last_update: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Last Update + received_count: + type: integer + title: Received Count + default: 0 + required_count: + type: integer + title: Required Count + default: 1 + received_keys: + items: + type: string + type: array + title: Received Keys + required_keys: + items: + type: string + type: array + title: Required Keys + is_rollup: + type: boolean + title: Is Rollup + default: false + type: object + required: + - id + - name + - uri + title: NextRunAssetEventResponse + description: One asset event in the ``next_run_assets`` payload. + NextRunAssetsResponse: + properties: + asset_expression: + anyOf: + - additionalProperties: true + type: object + - type: 'null' + title: Asset Expression + events: + items: + $ref: '#/components/schemas/NextRunAssetEventResponse' + type: array + title: Events + pending_partition_count: + anyOf: + - type: integer + - type: 'null' + title: Pending Partition Count + type: object + required: + - events + title: NextRunAssetsResponse + description: Response for the ``next_run_assets`` endpoint. 
NodeResponse: properties: id: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py index b35e33fa861c4..43684bbee354b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py @@ -24,6 +24,10 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.core_api.datamodels.ui.assets import ( + NextRunAssetEventResponse, + NextRunAssetsResponse, +) from airflow.api_fastapi.core_api.routes.ui.partitioned_dag_runs import _load_timetable from airflow.api_fastapi.core_api.security import requires_access_asset, requires_access_dag from airflow.models import DagModel @@ -49,7 +53,7 @@ def next_run_assets( dag_id: str, session: SessionDep, -) -> dict: +) -> NextRunAssetsResponse: dag_model = DagModel.get_dagmodel(dag_id, session=session) if dag_model is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") @@ -62,7 +66,7 @@ def next_run_assets( pending_partition_count: int | None = None queued_expr: ColumnElement[int] - if is_partitioned := dag_model.timetable_summary == "Partitioned Asset": + if is_partitioned := dag_model.timetable_partitioned: pending_partition_count = session.scalar( select(func.count()) .select_from(AssetPartitionDagRun) @@ -117,83 +121,101 @@ def next_run_assets( isouter=True, ) - events = [dict(info._mapping) for info in session.execute(query)] - for event in events: - if not event.pop("queued", None): - event["last_update"] = None + raw_rows = list(session.execute(query)) - # For partitioned Dags: enrich events with per-asset received/required counts, - # using to_upstream for rollup mappers, and fix last_update for partial receipt. 
- if is_partitioned: - timetable = _load_timetable(dag_id, session) + if not is_partitioned: + events = [ + NextRunAssetEventResponse( + id=row.id, + name=row.name, + uri=row.uri, + last_update=row.last_update if row.queued else None, + ) + for row in raw_rows + ] + return NextRunAssetsResponse(asset_expression=dag_model.asset_expression, events=events) + + # Partitioned Dags: enrich with per-asset received/required counts and rollup flag. + timetable = _load_timetable(dag_id, session) + pending_apdr = session.execute( + select(AssetPartitionDagRun.id, AssetPartitionDagRun.partition_key) + .where( + AssetPartitionDagRun.target_dag_id == dag_id, + AssetPartitionDagRun.created_dag_run_id.is_(None), + ) + .order_by(AssetPartitionDagRun.created_at.desc()) + .limit(1) + ).one_or_none() + + if pending_apdr is None: + # No pending APDR yet — mark rollup assets so the UI can handle them + # correctly (e.g. skip "Asset Triggered" in favour of the asset name view). + events = [] + for row in raw_rows: + is_rollup = False + if timetable is not None: + with suppress(Exception): + mapper = timetable.get_partition_mapper(name=row.name, uri=row.uri) + is_rollup = mapper.is_rollup + events.append( + NextRunAssetEventResponse( + id=row.id, + name=row.name, + uri=row.uri, + last_update=row.last_update if row.queued else None, + is_rollup=is_rollup, + ) + ) + return NextRunAssetsResponse( + asset_expression=dag_model.asset_expression, + events=events, + pending_partition_count=pending_partition_count, + ) - pending_apdr = session.execute( - select(AssetPartitionDagRun.id, AssetPartitionDagRun.partition_key) - .where( - AssetPartitionDagRun.target_dag_id == dag_id, - AssetPartitionDagRun.created_dag_run_id.is_(None), + # Collect received upstream partition keys per asset for this partition run. + # Use a set to deduplicate: multiple events for the same key count as one. 
+ received_keys_by_asset: dict[int, set[str]] = {} + for log_row in session.execute( + select( + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id == pending_apdr.id) + ): + received_keys_by_asset.setdefault(log_row.asset_id, set()).add(log_row.source_partition_key or "") + + events = [] + for row in raw_rows: + received_keys = sorted(received_keys_by_asset.get(row.id, set())) + required_keys: list[str] = [pending_apdr.partition_key] + is_rollup = False + if timetable is not None: + with suppress(Exception): + mapper = timetable.get_partition_mapper(name=row.name, uri=row.uri) + if mapper.is_rollup: + required_keys = sorted( + cast("RollupMapper", mapper).to_upstream(pending_apdr.partition_key) + ) + is_rollup = True + received_count = len(received_keys) + required_count = len(required_keys) + # Only surface last_update once all required upstream keys have arrived. + last_update = row.last_update if row.queued and received_count >= required_count else None + events.append( + NextRunAssetEventResponse( + id=row.id, + name=row.name, + uri=row.uri, + last_update=last_update, + received_count=received_count, + required_count=required_count, + received_keys=received_keys, + required_keys=required_keys, + is_rollup=is_rollup, ) - .order_by(AssetPartitionDagRun.created_at.desc()) - .limit(1) - ).one_or_none() - - if pending_apdr is not None: - # Collect received upstream partition keys per asset for this partition run. - # Use a set to deduplicate: multiple events for the same key count as one. 
- received_keys_by_asset: dict[int, set[str]] = {} - for row in session.execute( - select( - PartitionedAssetKeyLog.asset_id, - PartitionedAssetKeyLog.source_partition_key, - ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id == pending_apdr.id) - ): - received_keys_by_asset.setdefault(row.asset_id, set()).add(row.source_partition_key or "") - - for event in events: - asset_id = event["id"] - received_keys = sorted(received_keys_by_asset.get(asset_id, set())) - required_keys: list[str] = [pending_apdr.partition_key] - is_rollup = False - if timetable is not None: - with suppress(Exception): - mapper = timetable.get_partition_mapper( - name=event["name"], - uri=event["uri"], - ) - if mapper.is_rollup: - required_keys = sorted( - cast("RollupMapper", mapper).to_upstream(pending_apdr.partition_key) - ) - is_rollup = True - received_count = len(received_keys) - required_count = len(required_keys) - event["received_count"] = received_count - event["required_count"] = required_count - event["received_keys"] = received_keys - event["required_keys"] = required_keys - event["is_rollup"] = is_rollup - # Only show last_update when all required upstream keys are received - if received_count < required_count: - event["last_update"] = None - else: - # No pending APDR yet — mark rollup assets so the UI can handle them - # correctly (e.g. skip "Asset Triggered" in favour of the asset name view). 
- for event in events: - is_rollup = False - if timetable is not None: - with suppress(Exception): - mapper = timetable.get_partition_mapper( - name=event["name"], - uri=event["uri"], - ) - is_rollup = mapper.is_rollup - event["received_count"] = 0 - event["required_count"] = 1 - event["received_keys"] = [] - event["required_keys"] = [] - event["is_rollup"] = is_rollup - - data: dict = {"asset_expression": dag_model.asset_expression, "events": events} - if pending_partition_count is not None: - data["pending_partition_count"] = pending_partition_count - return data + ) + + return NextRunAssetsResponse( + asset_expression=dag_model.asset_expression, + events=events, + pending_partition_count=pending_partition_count, + ) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index 2216b4908efc5..7bf4f2392b2cb 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -17,7 +17,7 @@ from __future__ import annotations from contextlib import suppress -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, TypeAlias, cast from fastapi import Depends, HTTPException, status from sqlalchemy import func, select @@ -51,10 +51,16 @@ from airflow.timetables.simple import PartitionedAssetTimetable if TYPE_CHECKING: + from sqlalchemy.orm import Session + from airflow.partition_mappers.base import RollupMapper -def _load_timetable(dag_id: str, session) -> PartitionedAssetTimetable | None: +AssetNameUri: TypeAlias = tuple[str, str] +"""A ``(name, uri)`` pair identifying an asset.""" + + +def _load_timetable(dag_id: str, session: Session) -> PartitionedAssetTimetable | None: """Return the PartitionedAssetTimetable for *dag_id*, or None if absent or not partitioned.""" serdag = SerializedDagModel.get(dag_id=dag_id, 
session=session) if serdag is None: @@ -66,8 +72,8 @@ def _load_timetable(dag_id: str, session) -> PartitionedAssetTimetable | None: def _load_timetable_and_assets( - dag_id: str, session -) -> tuple[PartitionedAssetTimetable | None, list[tuple[str, str]], dict[int, tuple[str, str]]]: + dag_id: str, session: Session +) -> tuple[PartitionedAssetTimetable | None, list[AssetNameUri], dict[int, AssetNameUri]]: """ Load timetable and active required assets. @@ -86,7 +92,7 @@ def _load_timetable_and_assets( def _compute_total_required( timetable: PartitionedAssetTimetable | None, - asset_info: list[tuple[str, str]], + asset_info: list[AssetNameUri], partition_key: str, ) -> int: """Sum required upstream events across all assets, using to_upstream for rollup mappers.""" @@ -102,7 +108,7 @@ def _compute_total_required( def _compute_received_count( received_by_asset: dict[int, set[str]], timetable: PartitionedAssetTimetable | None, - asset_id_to_info: dict[int, tuple[str, str]], + asset_id_to_info: dict[int, AssetNameUri], partition_key: str, ) -> int: """ @@ -155,12 +161,12 @@ def get_partitioned_dag_runs( """Return PartitionedDagRuns. Filter by dag_id and/or has_created_dag_run_id.""" if dag_id.value is not None: dag_info = session.execute( - select(DagModel.timetable_summary).where(DagModel.dag_id == dag_id.value) + select(DagModel.timetable_partitioned).where(DagModel.dag_id == dag_id.value) ).one_or_none() if dag_info is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id.value} was not found") - if dag_info.timetable_summary != "Partitioned Asset": + if not dag_info.timetable_partitioned: return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) # Subquery for received count per partition (count of required assets that have any log). @@ -210,7 +216,7 @@ def get_partitioned_dag_runs( # Batch-fetch all log entries for these APDRs in one query. 
apdr_ids = [row.id for row in rows] log_by_apdr: dict[int, dict[int, set[str]]] = {} - for log_row in session.execute( + for pakl_row in session.execute( select( PartitionedAssetKeyLog.asset_partition_dag_run_id, PartitionedAssetKeyLog.asset_id, @@ -220,9 +226,9 @@ def get_partitioned_dag_runs( PartitionedAssetKeyLog.asset_id.in_(list(asset_id_to_info)), ) ).all(): - log_by_apdr.setdefault(log_row.asset_partition_dag_run_id, {}).setdefault( - log_row.asset_id, set() - ).add(log_row.source_partition_key or "") + log_by_apdr.setdefault(pakl_row.asset_partition_dag_run_id, {}).setdefault( + pakl_row.asset_id, set() + ).add(pakl_row.source_partition_key or "") results = [ _build_response( @@ -240,7 +246,7 @@ def get_partitioned_dag_runs( # total_received is approximated via SQL for this global view. unique_dag_ids = list({row.target_dag_id for row in rows}) dag_timetables_assets: dict[ - str, tuple[PartitionedAssetTimetable | None, list[tuple[str, str]], dict[int, tuple[str, str]]] + str, tuple[PartitionedAssetTimetable | None, list[AssetNameUri], dict[int, AssetNameUri]] ] = {did: _load_timetable_and_assets(did, session) for did in unique_dag_ids} dag_rows = session.execute( select(DagModel.dag_id, DagModel.asset_expression).where(DagModel.dag_id.in_(unique_dag_ids)) diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts index 5e0595ec8f6b4..91d5df8edeaff 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts @@ -162,7 +162,7 @@ export const ensureUseAssetServiceGetDagAssetQueuedEventData = (queryClient: Que * Next Run Assets * @param data The data for the request. 
* @param data.dagId -* @returns unknown Successful Response +* @returns NextRunAssetsResponse Successful Response * @throws ApiError */ export const ensureUseAssetServiceNextRunAssetsData = (queryClient: QueryClient, { dagId }: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts index 2a17cdfefc00a..4dd3b0067c57c 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts @@ -162,7 +162,7 @@ export const prefetchUseAssetServiceGetDagAssetQueuedEvent = (queryClient: Query * Next Run Assets * @param data The data for the request. * @param data.dagId -* @returns unknown Successful Response +* @returns NextRunAssetsResponse Successful Response * @throws ApiError */ export const prefetchUseAssetServiceNextRunAssets = (queryClient: QueryClient, { dagId }: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index 52e4776d510a5..1e8c86df6f095 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -162,7 +162,7 @@ export const useAssetServiceGetDagAssetQueuedEvent = = unknown[]>({ dagId }: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts index 9d42191b88606..d67f99e2669a4 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts @@ -162,7 +162,7 @@ export const useAssetServiceGetDagAssetQueuedEventSuspense = = unknown[]>({ dagId }: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 2d4fd67fc420c..fd5ed96975257 100644 --- 
a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -8323,8 +8323,7 @@ export const $ExtraMenuItem = { }, type: 'object', required: ['text', 'href'], - title: 'ExtraMenuItem', - description: 'Define a menu item that can be added to the menu by auth managers or plugins.' + title: 'ExtraMenuItem' } as const; export const $GanttResponse = { @@ -8763,6 +8762,114 @@ export const $MenuItemCollectionResponse = { description: 'Menu Item Collection serializer for responses.' } as const; +export const $NextRunAssetEventResponse = { + properties: { + id: { + type: 'integer', + title: 'Id' + }, + name: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Name' + }, + uri: { + type: 'string', + title: 'Uri' + }, + last_update: { + anyOf: [ + { + type: 'string', + format: 'date-time' + }, + { + type: 'null' + } + ], + title: 'Last Update' + }, + received_count: { + type: 'integer', + title: 'Received Count', + default: 0 + }, + required_count: { + type: 'integer', + title: 'Required Count', + default: 1 + }, + received_keys: { + items: { + type: 'string' + }, + type: 'array', + title: 'Received Keys' + }, + required_keys: { + items: { + type: 'string' + }, + type: 'array', + title: 'Required Keys' + }, + is_rollup: { + type: 'boolean', + title: 'Is Rollup', + default: false + } + }, + type: 'object', + required: ['id', 'name', 'uri'], + title: 'NextRunAssetEventResponse', + description: 'One asset event in the ``next_run_assets`` payload.' 
+} as const; + +export const $NextRunAssetsResponse = { + properties: { + asset_expression: { + anyOf: [ + { + additionalProperties: true, + type: 'object' + }, + { + type: 'null' + } + ], + title: 'Asset Expression' + }, + events: { + items: { + '$ref': '#/components/schemas/NextRunAssetEventResponse' + }, + type: 'array', + title: 'Events' + }, + pending_partition_count: { + anyOf: [ + { + type: 'integer' + }, + { + type: 'null' + } + ], + title: 'Pending Partition Count' + } + }, + type: 'object', + required: ['events'], + title: 'NextRunAssetsResponse', + description: 'Response for the ``next_run_assets`` endpoint.' +} as const; + export const $NodeResponse = { properties: { id: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index 83c5d5c17b78d..b9f7e7ce77a08 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, 
CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, 
GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, 
GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; +import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, 
NextRunAssetsResponse2, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, 
DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, 
GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; export class AssetService { /** @@ -411,10 +411,10 @@ export class AssetService { * Next Run Assets * @param data The data for the request. 
* @param data.dagId - * @returns unknown Successful Response + * @returns NextRunAssetsResponse Successful Response * @throws ApiError */ - public static nextRunAssets(data: NextRunAssetsData): CancelablePromise { + public static nextRunAssets(data: NextRunAssetsData): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/ui/next_run_assets/{dag_id}', diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 304095701ae4a..57948231797be 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2063,9 +2063,6 @@ export type EdgeResponse = { is_source_asset?: boolean | null; }; -/** - * Define a menu item that can be added to the menu by auth managers or plugins. - */ export type ExtraMenuItem = { text: string; href: string; @@ -2186,6 +2183,32 @@ export type MenuItemCollectionResponse = { extra_menu_items: Array; }; +/** + * One asset event in the ``next_run_assets`` payload. + */ +export type NextRunAssetEventResponse = { + id: number; + name: string | null; + uri: string; + last_update?: string | null; + received_count?: number; + required_count?: number; + received_keys?: Array<(string)>; + required_keys?: Array<(string)>; + is_rollup?: boolean; +}; + +/** + * Response for the ``next_run_assets`` endpoint. + */ +export type NextRunAssetsResponse = { + asset_expression?: { + [key: string]: unknown; +} | null; + events: Array; + pending_partition_count?: number | null; +}; + /** * Node serializer for responses. 
*/ @@ -2527,9 +2550,7 @@ export type NextRunAssetsData = { dagId: string; }; -export type NextRunAssetsResponse = { - [key: string]: unknown; -}; +export type NextRunAssetsResponse2 = NextRunAssetsResponse; export type ListBackfillsData = { dagId: string; @@ -4492,9 +4513,7 @@ export type $OpenApiTs = { /** * Successful Response */ - 200: { - [key: string]: unknown; - }; + 200: NextRunAssetsResponse; /** * Validation Error */ diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx index 4b9b7ccc03579..c8c0382fd7326 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx @@ -21,16 +21,18 @@ import { Fragment } from "react"; import { useTranslation } from "react-i18next"; import { TbLogicOr } from "react-icons/tb"; +import type { NextRunAssetEventResponse } from "openapi/requests/types.gen"; + import { AndGateNode } from "./AndGateNode"; import { AssetNode } from "./AssetNode"; import { OrGateNode } from "./OrGateNode"; -import type { ExpressionType, NextRunEvent } from "./types"; +import type { ExpressionType } from "./types"; export const AssetExpression = ({ events, expression, }: { - readonly events?: Array; + readonly events?: Array; readonly expression: ExpressionType | undefined; }) => { const { t: translate } = useTranslation("common"); diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx index 72f06bbcca421..ba61aeb944dfa 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx @@ -21,10 +21,11 @@ import { FiCheck, FiDatabase, FiMinus } from "react-icons/fi"; import { PiRectangleDashed } from "react-icons/pi"; 
import { Link as RouterLink } from "react-router-dom"; +import type { NextRunAssetEventResponse } from "openapi/requests/types.gen"; import { Popover } from "src/components/ui"; import Time from "../Time"; -import type { AssetSummary, NextRunEvent } from "./types"; +import type { AssetSummary } from "./types"; const RollupKeyChecklistPopover = ({ receivedCount, @@ -84,7 +85,7 @@ export const AssetNode = ({ event, }: { readonly asset: AssetSummary; - readonly event?: NextRunEvent; + readonly event?: NextRunAssetEventResponse; }) => { const isFullyReceived = Boolean(event?.last_update); const isPartial = diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts index c24af1900d349..cae9bfa524f04 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts @@ -33,18 +33,6 @@ type Alias = { }; }; -export type NextRunEvent = { - id: number; - is_rollup?: boolean; - last_update: string | null; - name: string | null; - received_count?: number; - received_keys?: Array; - required_count?: number; - required_keys?: Array; - uri: string; -}; - export type AssetSummary = Alias | Asset; export type ExpressionType = diff --git a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx index 83d1f2321f93a..50676016e2869 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx @@ -21,9 +21,8 @@ import { FiCheck, FiDatabase, FiMinus } from "react-icons/fi"; import { Link as RouterLink } from "react-router-dom"; import { usePartitionedDagRunServiceGetPendingPartitionedDagRun } from "openapi/queries"; -import type { PartitionedDagRunAssetResponse } from "openapi/requests/types.gen"; +import type { NextRunAssetEventResponse, 
PartitionedDagRunAssetResponse } from "openapi/requests/types.gen"; import { AssetExpression, type ExpressionType } from "src/components/AssetExpression"; -import type { NextRunEvent } from "src/components/AssetExpression/types"; import { Popover } from "src/components/ui"; type Props = { @@ -41,7 +40,7 @@ export const AssetProgressCell = ({ dagId, partitionKey, totalReceived, totalReq const hasRollup = assets.some((ak) => ak.is_rollup); - const events: Array = assets + const events: Array = assets .filter((ak: PartitionedDagRunAssetResponse) => ak.received_count > 0) .map((ak: PartitionedDagRunAssetResponse) => ({ id: ak.asset_id, diff --git a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx index e00b46a72dbfe..f5baf81c30bd0 100644 --- a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx +++ b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx @@ -24,8 +24,8 @@ import { FiCheck, FiDatabase, FiMinus } from "react-icons/fi"; import { Link as RouterLink } from "react-router-dom"; import { useAssetServiceGetDagAssetQueuedEvents, useAssetServiceNextRunAssets } from "openapi/queries"; +import type { NextRunAssetEventResponse } from "openapi/requests/types.gen"; import { AssetExpression, type ExpressionType } from "src/components/AssetExpression"; -import type { NextRunEvent } from "src/components/AssetExpression/types"; import { TruncatedText } from "src/components/TruncatedText"; import { Popover } from "src/components/ui"; @@ -69,7 +69,7 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti { enabled: !timetablePartitioned }, ); - const nextRunEvents = (nextRun?.events ?? []) as Array; + const nextRunEvents: Array = nextRun?.events ?? 
[]; const queuedAssetEvents = new Map(); if (!timetablePartitioned) { @@ -124,7 +124,7 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti } if (timetablePartitioned) { - const pendingCount = (nextRun?.pending_partition_count as number | undefined) ?? 0; + const pendingCount = nextRun?.pending_partition_count ?? 0; if (pendingCount === 0) { return ( @@ -227,7 +227,7 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py index 0b6f1b218578d..431d2271f4ae6 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py @@ -77,8 +77,19 @@ def test_should_response_200(self, test_client, dag_maker): ] }, "events": [ - {"id": mock.ANY, "uri": "s3://bucket/next-run-asset/1", "name": "asset1", "last_update": None} + { + "id": mock.ANY, + "uri": "s3://bucket/next-run-asset/1", + "name": "asset1", + "last_update": None, + "received_count": 0, + "required_count": 1, + "received_keys": [], + "required_keys": [], + "is_rollup": False, + } ], + "pending_partition_count": None, } def test_should_respond_401(self, unauthenticated_test_client): @@ -141,9 +152,30 @@ def test_should_set_last_update_only_for_queued_and_hide_flag(self, test_client, }, # events are ordered by uri "events": [ - {"id": mock.ANY, "uri": "s3://bucket/A", "name": "A", "last_update": mock.ANY}, - {"id": mock.ANY, "uri": "s3://bucket/B", "name": "B", "last_update": None}, + { + "id": mock.ANY, + "uri": "s3://bucket/A", + "name": "A", + "last_update": mock.ANY, + "received_count": 0, + "required_count": 1, + "received_keys": [], + "required_keys": [], + "is_rollup": False, + }, + { + "id": mock.ANY, + "uri": "s3://bucket/B", + "name": "B", + "last_update": None, + "received_count": 0, + "required_count": 
1, + "received_keys": [], + "required_keys": [], + "is_rollup": False, + }, ], + "pending_partition_count": None, } def test_last_update_respects_latest_run_filter(self, test_client, dag_maker, session): From 1d1b75e9392614f06222cf0d631dd0a6ba4c7b70 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 24 Apr 2026 17:49:20 +0800 Subject: [PATCH 08/16] feat: Add RollUpMapper and remove other outdated mappers --- .../core_api/openapi/_private_ui.yaml | 2 + .../routes/ui/partitioned_dag_runs.py | 20 +- .../example_dags/example_asset_partition.py | 94 +++--- .../src/airflow/jobs/scheduler_job_runner.py | 18 +- .../src/airflow/partition_mappers/base.py | 66 ++++- .../src/airflow/partition_mappers/temporal.py | 124 ++++---- .../src/airflow/partition_mappers/window.py | 106 +++++++ .../src/airflow/serialization/decoders.py | 11 + .../src/airflow/serialization/encoders.py | 70 ++++- .../ui/openapi-gen/requests/schemas.gen.ts | 3 +- .../ui/openapi-gen/requests/types.gen.ts | 3 + .../components/AssetExpression/AssetNode.tsx | 2 +- .../ui/src/pages/DagsList/AssetSchedule.tsx | 2 +- .../routes/ui/test_partitioned_dag_runs.py | 14 +- .../unit/partition_mappers/test_temporal.py | 22 ++ .../unit/partition_mappers/test_window.py | 276 ++++++++++++++++++ docs/spelling_wordlist.txt | 1 + task-sdk/docs/api.rst | 21 +- task-sdk/src/airflow/sdk/__init__.py | 33 ++- task-sdk/src/airflow/sdk/__init__.pyi | 23 +- .../sdk/definitions/partition_mappers/base.py | 19 +- .../definitions/partition_mappers/temporal.py | 24 +- .../definitions/partition_mappers/window.py | 53 ++++ 23 files changed, 832 insertions(+), 175 deletions(-) create mode 100644 airflow-core/src/airflow/partition_mappers/window.py create mode 100644 airflow-core/tests/unit/partition_mappers/test_window.py create mode 100644 task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml 
b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 6e31538413d29..eea4a738a83e2 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -2562,6 +2562,8 @@ components: - text - href title: ExtraMenuItem + description: Define a menu item that can be added to the menu by auth managers + or plugins. GanttResponse: properties: dag_id: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index 7bf4f2392b2cb..f7504b6df3d5e 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -159,15 +159,11 @@ def get_partitioned_dag_runs( has_created_dag_run_id: QueryPartitionedDagRunHasCreatedDagRunIdFilter, ) -> PartitionedDagRunCollectionResponse: """Return PartitionedDagRuns. Filter by dag_id and/or has_created_dag_run_id.""" - if dag_id.value is not None: - dag_info = session.execute( - select(DagModel.timetable_partitioned).where(DagModel.dag_id == dag_id.value) - ).one_or_none() - - if dag_info is None: - raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id.value} was not found") - if not dag_info.timetable_partitioned: - return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) + # The dag-existence / partitioned-timetable check is intentionally deferred to the + # empty-results branch below. In the happy path (rows exist), filtering by dag_id + # already restricts to that Dag, so an extra DagModel lookup just to validate + # existence wastes a query. We only consult DagModel when we have no rows to + # report — that's the only case where the distinction (404 vs empty) matters. # Subquery for received count per partition (count of required assets that have any log). 
# This matches the non-rollup contract "any event for an asset = that asset is satisfied". @@ -208,6 +204,12 @@ def get_partitioned_dag_runs( query = query.order_by(AssetPartitionDagRun.created_at.desc()) if not (rows := session.execute(query).all()): + if dag_id.value is not None: + dag_info = session.execute( + select(DagModel.timetable_partitioned).where(DagModel.dag_id == dag_id.value) + ).one_or_none() + if dag_info is None: + raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id.value} was not found") return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) if dag_id.value is not None: diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py b/airflow-core/src/airflow/example_dags/example_asset_partition.py index 1569c850bf7a3..12b23f8e8bc75 100644 --- a/airflow-core/src/airflow/example_dags/example_asset_partition.py +++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py @@ -23,15 +23,17 @@ DAG, AllowedKeyMapper, Asset, - AssetAll, CronPartitionTimetable, + DayWindow, IdentityMapper, + MonthWindow, PartitionedAssetTimetable, ProductMapper, + RollupMapper, StartOfDayMapper, StartOfHourMapper, + StartOfMonthMapper, StartOfYearMapper, - WeeklyRollupMapper, asset, task, ) @@ -229,66 +231,74 @@ def regional_stats_breakdown(): pass -daily_sales = Asset(uri="s3://sales/daily", name="daily_sales") -daily_costs = Asset(uri="s3://costs/daily", name="daily_costs") +# --- Chained rollup: hourly → daily → monthly -------------------------------- +# The hourly source asset already exists above (``team_a_player_stats``). +# Each rollup Dag publishes its own asset so the next level can consume it. 
- -with DAG( - dag_id="produce_daily_sales", - schedule=CronPartitionTimetable("0 0 * * *", timezone="UTC"), - tags=["sales", "ingestion"], -): - """Produce daily sales data partitioned by date.""" - - @task(outlets=[daily_sales]) - def upload_daily_sales(dag_run=None): - """Upload sales data for the current daily partition.""" - if TYPE_CHECKING: - assert dag_run - print(f"Producing partition: {dag_run.partition_key}") - - upload_daily_sales() +daily_team_a = Asset(uri="s3://team-a/daily", name="daily_team_a") +monthly_team_a = Asset(uri="s3://team-a/monthly", name="monthly_team_a") with DAG( - dag_id="produce_daily_costs", - schedule=CronPartitionTimetable("0 0 * * *", timezone="UTC"), - tags=["costs", "ingestion"], + dag_id="daily_team_a_rollup", + schedule=PartitionedAssetTimetable( + assets=team_a_player_stats, + default_partition_mapper=RollupMapper( + source_mapper=StartOfDayMapper(), + window=DayWindow(), + ), + ), + catchup=False, + tags=["player-stats", "rollup"], ): - """Produce daily cost data partitioned by date.""" + """ + First rollup level: 24 hourly partitions of ``team_a_player_stats`` → one daily summary. - @task(outlets=[daily_costs]) - def upload_daily_costs(dag_run=None): - """Upload cost data for the current daily partition.""" + ``StartOfDayMapper`` normalizes each upstream hourly timestamp (``%Y-%m-%dT%H:%M:%S``) + to its day-start (``%Y-%m-%d``); ``DayWindow`` declares the downstream run needs + all 24 hourly partitions before firing. Publishes ``daily_team_a`` so the + monthly rollup below can consume it. + """ + + @task(outlets=[daily_team_a]) + def summarise_team_a_day(dag_run=None): + """Produce the full-day rollup once every hour has arrived.""" if TYPE_CHECKING: assert dag_run - print(f"Producing partition: {dag_run.partition_key}") + print(f"All 24 hourly partitions received. 
Day: {dag_run.partition_key}") - upload_daily_costs() + summarise_team_a_day() with DAG( - dag_id="weekly_sales_report", + dag_id="monthly_team_a_rollup", schedule=PartitionedAssetTimetable( - assets=AssetAll(daily_sales, daily_costs), - default_partition_mapper=WeeklyRollupMapper(), + assets=daily_team_a, + # The upstream (``daily_team_a``) emits day-formatted partition keys + # (``%Y-%m-%d``), so the source mapper here must accept that format. + default_partition_mapper=RollupMapper( + source_mapper=StartOfMonthMapper(input_format="%Y-%m-%d"), + window=MonthWindow(), + ), ), catchup=False, - tags=["sales", "reporting"], + tags=["player-stats", "rollup"], ): """ - Generate a weekly sales report once all daily partitions for both assets have arrived. + Chained rollup: every day of ``daily_team_a`` (itself a rollup) → one monthly summary. - This Dag demonstrates WeeklyRollupMapper with multiple assets: it waits for all 7 - daily partitions of ``daily_sales`` and ``daily_costs`` before triggering for a given week. - The partition key is the week identifier, e.g. ``2024-01-15 (W03)``. + Demonstrates how a rollup output can feed another rollup. ``StartOfMonthMapper`` + is configured with ``input_format="%Y-%m-%d"`` so it can parse the day keys + emitted by ``daily_team_a_rollup``; ``MonthWindow`` waits for every day of the + calendar month (28–31 depending on the month). The partition key is the month + identifier, e.g. ``2024-01``. """ - @task - def generate_weekly_report(dag_run=None): - """Combine the full week of sales and cost data into a report.""" + @task(outlets=[monthly_team_a]) + def summarise_team_a_month(dag_run=None): + """Produce the full-month rollup once every day has arrived.""" if TYPE_CHECKING: assert dag_run - print(f"All 7 daily partitions for both assets received. Week: {dag_run.partition_key}") + print(f"All daily partitions received. 
Month: {dag_run.partition_key}") - generate_weekly_report() + summarise_team_a_month() diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 1b2096a8cb9a8..7ade6e22036c4 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -127,6 +127,11 @@ TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule" """:meta private:""" +# Per-tick cap on pending ``AssetPartitionDagRun`` rows the scheduler will evaluate. +# Bounds the transaction size so other scheduling work isn't starved; remaining +# rows drain across subsequent ticks. +MAX_PARTITION_DAG_RUNS_PER_TICK = 500 + def _eager_load_dag_run_for_validation() -> tuple[LoaderOption, LoaderOption]: """ @@ -1887,13 +1892,18 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st # Cap per-tick work so the scheduler transaction stays bounded and other # scheduling work isn't starved. Remaining APDRs drain across subsequent ticks. - # Note: with strict FIFO ordering, >BATCH persistently-unsatisfied APDRs would - # block newer ones; switch to updated_at-based ordering if that becomes an issue. + # Note: with strict FIFO ordering, persistently-unsatisfied APDRs at the head + # of the queue would block newer ones; switch to updated_at-based ordering if + # that becomes an issue. 
pending_apdrs = session.scalars( select(AssetPartitionDagRun) - .where(AssetPartitionDagRun.created_dag_run_id.is_(None)) + .join(DagModel, DagModel.dag_id == AssetPartitionDagRun.target_dag_id) + .where( + AssetPartitionDagRun.created_dag_run_id.is_(None), + DagModel.is_stale.is_(False), + ) .order_by(AssetPartitionDagRun.created_at) - .limit(500) + .limit(MAX_PARTITION_DAG_RUNS_PER_TICK) ).all() if not pending_apdrs: return partition_dag_ids diff --git a/airflow-core/src/airflow/partition_mappers/base.py b/airflow-core/src/airflow/partition_mappers/base.py index e97d128027ffa..eafd60bbeae62 100644 --- a/airflow-core/src/airflow/partition_mappers/base.py +++ b/airflow-core/src/airflow/partition_mappers/base.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: from collections.abc import Iterable + from airflow.partition_mappers.window import Window + class PartitionMapper(ABC): """ @@ -37,6 +39,28 @@ class PartitionMapper(ABC): def to_downstream(self, key: str) -> str | Iterable[str]: """Return the target key that the given source partition key maps to.""" + def decode_downstream(self, downstream_key: str) -> Any: + """ + Recover the canonical decoded form of *downstream_key*. + + Used by :class:`RollupMapper` to hand the window an opaque "anchor" + for the downstream period; the window iterates in this decoded space + and the mapper re-encodes each expected upstream via + :meth:`encode_upstream`. Default is identity (string in, string out) + — temporal mappers override to return ``datetime``, future segment + mappers will return whatever shape suits them. + """ + return downstream_key + + def encode_upstream(self, decoded: Any) -> str: + """ + Encode an expected upstream object back into a key string. + + Pair of :meth:`decode_downstream`. Default is identity. Temporal + mappers override to apply timezone + ``input_format``. 
+ """ + return decoded + def serialize(self) -> dict[str, Any]: return {} @@ -45,18 +69,46 @@ def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: return cls() -class RollupMapper(PartitionMapper, ABC): +class RollupMapper(PartitionMapper): """ - Partition mapper that supports rollup (many upstream keys → one downstream key). + Partition mapper that rolls up many upstream keys into one downstream key. - Subclass this when the downstream Dag should wait for a complete set of upstream - partition keys before triggering. The scheduler calls ``to_upstream`` to discover - which source keys are required and only creates a Dag run once all of them have - arrived in ``PartitionedAssetKeyLog``. + Compose a ``source_mapper`` (which normalizes each upstream key to the + downstream granularity) with a ``window`` that declares the full set of + upstream keys required for a given downstream key. The scheduler holds + the Dag run until every upstream key in the window has arrived. """ is_rollup: bool = True - @abstractmethod + def __init__(self, *, source_mapper: PartitionMapper, window: Window) -> None: + self.source_mapper = source_mapper + self.window = window + + def to_downstream(self, key: str) -> str | Iterable[str]: + return self.source_mapper.to_downstream(key) + def to_upstream(self, downstream_key: str) -> frozenset[str]: """Return the complete set of upstream partition keys required for *downstream_key*.""" + decoded = self.source_mapper.decode_downstream(downstream_key) + return frozenset( + self.source_mapper.encode_upstream(expected_upstream) + for expected_upstream in self.window.to_upstream(decoded) + ) + + def serialize(self) -> dict[str, Any]: + from airflow.serialization.encoders import encode_partition_mapper, encode_window + + return { + "source_mapper": encode_partition_mapper(self.source_mapper), + "window": encode_window(self.window), + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: + from 
airflow.serialization.decoders import decode_partition_mapper, decode_window + + return cls( + source_mapper=decode_partition_mapper(data["source_mapper"]), + window=decode_window(data["window"]), + ) diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py index ba1c3ee99d1e0..e82e943333d6b 100644 --- a/airflow-core/src/airflow/partition_mappers/temporal.py +++ b/airflow-core/src/airflow/partition_mappers/temporal.py @@ -22,9 +22,7 @@ from typing import TYPE_CHECKING, Any from airflow._shared.timezones.timezone import make_aware, parse_timezone -from airflow.partition_mappers.base import PartitionMapper, RollupMapper - -_YMD_RE = re.compile(r"\d{4}-\d{2}-\d{2}") +from airflow.partition_mappers.base import PartitionMapper if TYPE_CHECKING: from pendulum import FixedTimezone, Timezone @@ -65,6 +63,29 @@ def format(self, dt: datetime) -> str: """Format the normalized datetime.""" return dt.strftime(self.output_format) + def decode_downstream(self, downstream_key: str) -> datetime: + """ + Recover the period-start datetime from a previously formatted downstream key. + + Inverse of ``format``. The default implementation uses ``strptime`` with + ``output_format``, which works for any format made of standard strptime + directives. Subclasses with custom format markers (e.g. ``{quarter}``) or + ambiguous directives (e.g. bare ``%V``) override this. + """ + return datetime.strptime(downstream_key, self.output_format) + + def encode_upstream(self, dt: datetime) -> str: + """ + Format *dt* as an upstream partition key string. + + Pair of :meth:`decode_downstream`: takes a (decoded) period-start + datetime and produces a key string in the upstream's ``input_format`` + with ``timezone`` applied. Used by :class:`RollupMapper` to render each + upstream member yielded by the window back into the form upstream + producers actually emit. 
+ """ + return make_aware(dt, self._timezone).strftime(self.input_format) + def serialize(self) -> dict[str, Any]: from airflow.serialization.encoders import encode_timezone @@ -105,6 +126,7 @@ class StartOfWeekMapper(_BaseTemporalMapper): """Map a time-based partition key to the start of its week.""" default_output_format = "%Y-%m-%d (W%V)" + _YMD_RE = re.compile(r"\d{4}-\d{2}-\d{2}") def __init__( self, @@ -114,6 +136,8 @@ def __init__( input_format: str = "%Y-%m-%dT%H:%M:%S", output_format: str | None = None, ) -> None: + if not 0 <= week_start <= 6: + raise ValueError(f"week_start must be between 0 (Monday) and 6 (Sunday), got {week_start!r}") super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) self.week_start = week_start # 0 = Monday (ISO default), 6 = Sunday @@ -122,6 +146,18 @@ def normalize(self, dt: datetime) -> datetime: start = dt - timedelta(days=days_since_start) return start.replace(hour=0, minute=0, second=0, microsecond=0) + def decode_downstream(self, downstream_key: str) -> datetime: + # %V (ISO week) cannot be parsed by strptime without %G+%u, so locate + # the YYYY-MM-DD slice with a regex. Robust across formats that mix + # the date with extras like "(W%V)". + match = self._YMD_RE.search(downstream_key) + if match is None: + raise ValueError( + f"StartOfWeekMapper.decode_downstream could not locate YYYY-MM-DD in {downstream_key!r}; " + "output_format must include '%Y-%m-%d'." + ) + return datetime.strptime(match.group(), "%Y-%m-%d") + def serialize(self) -> dict[str, Any]: return {**super().serialize(), "week_start": self.week_start} @@ -135,42 +171,6 @@ def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: ) -class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper): - """ - Map a time-based partition key to the start of its week, requiring all 7 daily keys. - - Use this when a partitioned Dag should only run once every daily asset partition - for a full week has been produced. 
Configure ``week_start`` to set which day begins - the week (0 = Monday, 6 = Sunday). - """ - - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) - if "%Y-%m-%d" not in self.output_format: - raise ValueError( - f"WeeklyRollupMapper requires output_format to contain '%Y-%m-%d' so that " - f"to_upstream() can recover the week-start date, got: {self.output_format!r}" - ) - - def to_upstream(self, downstream_key: str) -> frozenset[str]: - # strptime cannot consume %V (ISO week) without %G+weekday, so parse by - # locating the YYYY-MM-DD slice directly. Regex is robust against - # variable-width directives (e.g. %B, %A, %Z) appearing elsewhere in the key. - match = _YMD_RE.search(downstream_key) - if match is None: - raise ValueError( - f"WeeklyRollupMapper.to_upstream could not locate YYYY-MM-DD in {downstream_key!r}" - ) - week_start_naive = datetime.strptime(match.group(), "%Y-%m-%d") - # Arithmetic stays on naive datetimes to keep day-counting unambiguous across - # DST transitions; each result is made timezone-aware before formatting so that - # %z in input_format produces the correct offset. 
- return frozenset( - make_aware(week_start_naive + timedelta(days=i), self._timezone).strftime(self.input_format) - for i in range(7) - ) - - class StartOfMonthMapper(_BaseTemporalMapper): """Map a time-based partition key to the start of its month.""" @@ -184,6 +184,8 @@ def __init__( input_format: str = "%Y-%m-%dT%H:%M:%S", output_format: str | None = None, ) -> None: + if not 1 <= month_start_day <= 28: + raise ValueError(f"month_start_day must be between 1 and 28, got {month_start_day!r}") super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) self.month_start_day = month_start_day # 1–28; use >1 for fiscal-month offsets @@ -196,6 +198,11 @@ def normalize(self, dt: datetime) -> datetime: start = dt.replace(day=self.month_start_day) return start.replace(hour=0, minute=0, second=0, microsecond=0) + def decode_downstream(self, downstream_key: str) -> datetime: + # The default strptime returns day=1; pin to month_start_day so fiscal + # months recover the correct period start. + return super().decode_downstream(downstream_key).replace(day=self.month_start_day) + def serialize(self) -> dict[str, Any]: return {**super().serialize(), "month_start_day": self.month_start_day} @@ -209,35 +216,11 @@ def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: ) -class MonthlyRollupMapper(StartOfMonthMapper, RollupMapper): - """ - Map a time-based partition key to the start of its month, requiring all daily keys in that month. - - Use this when a partitioned Dag should only run once every daily asset partition - for a full calendar month has been produced. Configure ``month_start_day`` for - fiscal-month offsets (e.g. ``month_start_day=15`` for a mid-month period). - """ - - def to_upstream(self, downstream_key: str) -> frozenset[str]: - # Use naive datetimes for day-counting to avoid DST ambiguity, then make - # each result timezone-aware before formatting so %z produces the correct offset. 
- period_start_naive = datetime.strptime(downstream_key, self.output_format).replace( - day=self.month_start_day - ) - next_month = period_start_naive.month % 12 + 1 - next_year = period_start_naive.year + (1 if period_start_naive.month == 12 else 0) - next_start_naive = period_start_naive.replace(year=next_year, month=next_month) - days = (next_start_naive - period_start_naive).days - return frozenset( - make_aware(period_start_naive + timedelta(days=i), self._timezone).strftime(self.input_format) - for i in range(days) - ) - - class StartOfQuarterMapper(_BaseTemporalMapper): """Map a time-based partition key to quarter.""" default_output_format = "%Y-Q{quarter}" + _YEAR_QUARTER_RE = re.compile(r"(\d{4}).*?Q([1-4])") def normalize(self, dt: datetime) -> datetime: quarter = (dt.month - 1) // 3 @@ -255,6 +238,19 @@ def format(self, dt: datetime) -> str: quarter = (dt.month - 1) // 3 + 1 return dt.strftime(self.output_format).format(quarter=quarter) + def decode_downstream(self, downstream_key: str) -> datetime: + # output_format carries a ``{quarter}`` placeholder, so strptime doesn't + # apply directly. Locate ``YYYY...Q`` and rebuild the period start. + match = self._YEAR_QUARTER_RE.search(downstream_key) + if match is None: + raise ValueError( + f"StartOfQuarterMapper.decode_downstream could not locate YYYY...Q in " + f"{downstream_key!r}; output_format must include the year and 'Q{{quarter}}'." 
+ ) + year = int(match.group(1)) + quarter = int(match.group(2)) + return datetime(year, (quarter - 1) * 3 + 1, 1) + class StartOfYearMapper(_BaseTemporalMapper): """Map a time-based partition key to year.""" diff --git a/airflow-core/src/airflow/partition_mappers/window.py b/airflow-core/src/airflow/partition_mappers/window.py new file mode 100644 index 0000000000000..1913c2295b06b --- /dev/null +++ b/airflow-core/src/airflow/partition_mappers/window.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Iterable + + +class Window(ABC): + """ + Describes a rollup window: which decoded upstream items make up one decoded downstream period. + + Paired with a source mapper inside a :class:`RollupMapper`. The window + operates purely in the source mapper's *decoded* form (``datetime`` for + temporal mappers, domain-specific types for future segment / runtime + mappers). It does not touch key strings, timezones, or formats — those + belong to the source mapper. 
``RollupMapper`` orchestrates the three: + decode the downstream key, expand via the window, encode each upstream. + """ + + @abstractmethod + def to_upstream(self, decoded_downstream: Any) -> Iterable[Any]: + """Yield each decoded upstream item composing *decoded_downstream*.""" + + def serialize(self) -> dict[str, Any]: + return {} + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> Window: + return cls() + + +class HourWindow(Window): + """Sixty consecutive minute period-starts making up one hour.""" + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (period_start + timedelta(minutes=i) for i in range(60)) + + +class DayWindow(Window): + """ + Twenty-four consecutive hourly period-starts making up one day. + + Arithmetic is done on naive datetime steps so the 24-hour stride is + unambiguous across DST transitions; the source mapper handles timezone + awareness when it encodes each upstream member back to a key string. + """ + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (period_start + timedelta(hours=i) for i in range(24)) + + +class WeekWindow(Window): + """Seven consecutive daily period-starts making up one week.""" + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (period_start + timedelta(days=i) for i in range(7)) + + +class MonthWindow(Window): + """ + All daily period-starts making up one calendar month. + + The source mapper's ``decode_downstream`` already accounts for fiscal + ``month_start_day``, so this just iterates from the period start to the + matching day of the next month. 
+ """ + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + next_month = period_start.month % 12 + 1 + next_year = period_start.year + (1 if period_start.month == 12 else 0) + next_start = period_start.replace(year=next_year, month=next_month) + days = (next_start - period_start).days + return (period_start + timedelta(days=i) for i in range(days)) + + +class QuarterWindow(Window): + """Three monthly period-starts making up one calendar quarter (e.g. Jan/Feb/Mar for Q1).""" + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (period_start.replace(month=period_start.month + i) for i in range(3)) + + +class YearWindow(Window): + """Twelve consecutive monthly period-starts making up one calendar year.""" + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (period_start.replace(month=i + 1) for i in range(12)) diff --git a/airflow-core/src/airflow/serialization/decoders.py b/airflow-core/src/airflow/serialization/decoders.py index 683efdc87e6ec..577cdc703e3cc 100644 --- a/airflow-core/src/airflow/serialization/decoders.py +++ b/airflow-core/src/airflow/serialization/decoders.py @@ -49,6 +49,7 @@ if TYPE_CHECKING: from airflow.partition_mappers.base import PartitionMapper + from airflow.partition_mappers.window import Window from airflow.timetables.base import Timetable as CoreTimetable R = TypeVar("R") @@ -202,3 +203,13 @@ def decode_partition_mapper(var: dict[str, Any]) -> PartitionMapper: else: partition_mapper_cls = find_registered_custom_partition_mapper(importable_string) return partition_mapper_cls.deserialize(var[Encoding.VAR]) + + +def decode_window(var: dict[str, Any]) -> Window: + """ + Decode a previously serialized :class:`Window`. 
+ + :meta private: + """ + window_cls: type[Window] = import_string(var[Encoding.TYPE]) + return window_cls.deserialize(var[Encoding.VAR]) diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index b84dfb73f05aa..17c80727762ba 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -27,6 +27,7 @@ from airflow._shared.module_loading import qualname from airflow.partition_mappers.base import PartitionMapper as CorePartitionMapper +from airflow.partition_mappers.window import Window as CoreWindow from airflow.sdk import ( AllowedKeyMapper, Asset, @@ -37,20 +38,26 @@ ChainMapper, CronDataIntervalTimetable, CronTriggerTimetable, + DayWindow, DeltaDataIntervalTimetable, DeltaTriggerTimetable, EventsTimetable, + HourWindow, IdentityMapper, - MonthlyRollupMapper, + MonthWindow, MultipleCronTriggerTimetable, PartitionMapper, ProductMapper, + QuarterWindow, + RollupMapper, StartOfDayMapper, StartOfMonthMapper, StartOfQuarterMapper, StartOfWeekMapper, StartOfYearMapper, - WeeklyRollupMapper, + WeekWindow, + Window, + YearWindow, ) from airflow.sdk.bases.timetable import BaseTimetable from airflow.sdk.definitions.asset import AssetRef @@ -410,14 +417,13 @@ def _(self, timetable: PartitionedAssetTimetable) -> dict[str, Any]: ChainMapper: "airflow.partition_mappers.chain.ChainMapper", IdentityMapper: "airflow.partition_mappers.identity.IdentityMapper", ProductMapper: "airflow.partition_mappers.product.ProductMapper", + RollupMapper: "airflow.partition_mappers.base.RollupMapper", StartOfDayMapper: "airflow.partition_mappers.temporal.StartOfDayMapper", StartOfHourMapper: "airflow.partition_mappers.temporal.StartOfHourMapper", StartOfMonthMapper: "airflow.partition_mappers.temporal.StartOfMonthMapper", StartOfQuarterMapper: "airflow.partition_mappers.temporal.StartOfQuarterMapper", StartOfWeekMapper: 
"airflow.partition_mappers.temporal.StartOfWeekMapper", StartOfYearMapper: "airflow.partition_mappers.temporal.StartOfYearMapper", - WeeklyRollupMapper: "airflow.partition_mappers.temporal.WeeklyRollupMapper", - MonthlyRollupMapper: "airflow.partition_mappers.temporal.MonthlyRollupMapper", } @functools.singledispatchmethod @@ -467,6 +473,40 @@ def _(self, partition_mapper: ProductMapper) -> dict[str, Any]: def _(self, partition_mapper: AllowedKeyMapper) -> dict[str, Any]: return {"allowed_keys": partition_mapper.allowed_keys} + @serialize_partition_mapper.register + def _(self, partition_mapper: RollupMapper) -> dict[str, Any]: + return { + "source_mapper": encode_partition_mapper(partition_mapper.source_mapper), + "window": encode_window(partition_mapper.window), + } + + BUILTIN_WINDOWS: dict[type, str] = { + HourWindow: "airflow.partition_mappers.window.HourWindow", + DayWindow: "airflow.partition_mappers.window.DayWindow", + WeekWindow: "airflow.partition_mappers.window.WeekWindow", + MonthWindow: "airflow.partition_mappers.window.MonthWindow", + QuarterWindow: "airflow.partition_mappers.window.QuarterWindow", + YearWindow: "airflow.partition_mappers.window.YearWindow", + } + + @functools.singledispatchmethod + def serialize_window(self, window: Window | CoreWindow) -> dict[str, Any]: + if not isinstance(window, CoreWindow): + raise NotImplementedError(f"can not serialize window {type(window).__name__}") + return window.serialize() + + @serialize_window.register(HourWindow) + @serialize_window.register(DayWindow) + @serialize_window.register(WeekWindow) + @serialize_window.register(MonthWindow) + @serialize_window.register(QuarterWindow) + @serialize_window.register(YearWindow) + def _( + self, + window: HourWindow | DayWindow | WeekWindow | MonthWindow | QuarterWindow | YearWindow, + ) -> dict[str, Any]: + return {} + _serializer = _Serializer() @@ -556,3 +596,25 @@ def encode_partition_mapper(var: PartitionMapper | CorePartitionMapper) -> dict[ Encoding.TYPE: 
qn, Encoding.VAR: _serializer.serialize_partition_mapper(var), } + + +def encode_window(var: Window | CoreWindow) -> dict[str, Any]: + """ + Encode a :class:`Window` instance. + + :meta private: + """ + var_type = type(var) + importable_string = _serializer.BUILTIN_WINDOWS.get(var_type) + if importable_string is not None: + return { + Encoding.TYPE: importable_string, + Encoding.VAR: _serializer.serialize_window(var), + } + + # Custom Window subclasses must live at an importable path so the + # scheduler can reconstruct them via import_string during deserialization. + return { + Encoding.TYPE: qualname(var), + Encoding.VAR: _serializer.serialize_window(var), + } diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index fd5ed96975257..61589be2ed320 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -8323,7 +8323,8 @@ export const $ExtraMenuItem = { }, type: 'object', required: ['text', 'href'], - title: 'ExtraMenuItem' + title: 'ExtraMenuItem', + description: 'Define a menu item that can be added to the menu by auth managers or plugins.' } as const; export const $GanttResponse = { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 57948231797be..8acc694124bea 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2063,6 +2063,9 @@ export type EdgeResponse = { is_source_asset?: boolean | null; }; +/** + * Define a menu item that can be added to the menu by auth managers or plugins. 
+ */ export type ExtraMenuItem = { text: string; href: string; diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx index ba61aeb944dfa..c29b89e83324d 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx @@ -56,7 +56,7 @@ const RollupKeyChecklistPopover = ({ - + {requiredKeys.map((key) => { const isReceived = receivedKeySet.has(key); diff --git a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx index f5baf81c30bd0..d2477e11caae7 100644 --- a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx +++ b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx @@ -176,7 +176,7 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti - + {requiredKeys.map((key) => { const isReceived = receivedKeySet.has(key); diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py index e05eee99187bc..72c05a68d20ea 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py @@ -23,7 +23,9 @@ from sqlalchemy import select from airflow.models.asset import AssetEvent, AssetModel, AssetPartitionDagRun, PartitionedAssetKeyLog -from airflow.partition_mappers.temporal import WeeklyRollupMapper +from airflow.partition_mappers.base import RollupMapper +from airflow.partition_mappers.temporal import StartOfWeekMapper +from airflow.partition_mappers.window import WeekWindow from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.definitions.asset import Asset from 
airflow.sdk.definitions.timetables.assets import PartitionedAssetTimetable @@ -321,7 +323,10 @@ def test_non_rollup_any_event_counts_as_one(self, test_client, dag_maker, sessio def test_rollup_mapper_counts_received_upstream_keys(self, test_client, dag_maker, session): """For a rollup mapper, only upstream keys in to_upstream() are counted.""" asset_def = Asset(uri="s3://bucket/daily", name="daily") - mapper = WeeklyRollupMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0) + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0), + window=WeekWindow(), + ) with dag_maker( dag_id="rollup_dag", schedule=PartitionedAssetTimetable(assets=asset_def, partition_mapper_config={asset_def: mapper}), @@ -515,7 +520,10 @@ def test_is_rollup_false_for_non_rollup_asset(self, test_client, dag_maker, sess def test_is_rollup_true_for_rollup_asset(self, test_client, dag_maker, session): """is_rollup is True for assets that use a RollupMapper, and keys are populated.""" asset_def = Asset(uri="s3://bucket/weekly", name="weekly") - mapper = WeeklyRollupMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0) + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0), + window=WeekWindow(), + ) with dag_maker( dag_id="rollup_detail_dag", schedule=PartitionedAssetTimetable(assets=asset_def, partition_mapper_config={asset_def: mapper}), diff --git a/airflow-core/tests/unit/partition_mappers/test_temporal.py b/airflow-core/tests/unit/partition_mappers/test_temporal.py index 1a32900284912..170856f883182 100644 --- a/airflow-core/tests/unit/partition_mappers/test_temporal.py +++ b/airflow-core/tests/unit/partition_mappers/test_temporal.py @@ -130,3 +130,25 @@ def test_to_downstream_input_timezone_differs_from_mapper_timezone(self): # 2026-02-11T06:00:00+00:00 UTC == 2026-02-11T01:00:00-05:00 New York # → start-of-day in New 
York is 2026-02-11 assert pm.to_downstream("2026-02-11T06:00:00+0000") == "2026-02-11" + + +class TestStartOfWeekMapperValidation: + @pytest.mark.parametrize("week_start", [-1, 7, 100]) + def test_rejects_out_of_range(self, week_start): + with pytest.raises(ValueError, match="week_start must be between 0"): + StartOfWeekMapper(week_start=week_start) + + @pytest.mark.parametrize("week_start", [0, 3, 6]) + def test_accepts_valid_range(self, week_start): + assert StartOfWeekMapper(week_start=week_start).week_start == week_start + + +class TestStartOfMonthMapperValidation: + @pytest.mark.parametrize("month_start_day", [0, 29, 31, -1]) + def test_rejects_out_of_range(self, month_start_day): + with pytest.raises(ValueError, match="month_start_day must be between 1 and 28"): + StartOfMonthMapper(month_start_day=month_start_day) + + @pytest.mark.parametrize("month_start_day", [1, 15, 28]) + def test_accepts_valid_range(self, month_start_day): + assert StartOfMonthMapper(month_start_day=month_start_day).month_start_day == month_start_day diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py b/airflow-core/tests/unit/partition_mappers/test_window.py new file mode 100644 index 0000000000000..027ef32bf4ebe --- /dev/null +++ b/airflow-core/tests/unit/partition_mappers/test_window.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +import pytest + +from airflow.partition_mappers.base import RollupMapper +from airflow.partition_mappers.temporal import ( + StartOfDayMapper, + StartOfHourMapper, + StartOfMonthMapper, + StartOfQuarterMapper, + StartOfWeekMapper, + StartOfYearMapper, +) +from airflow.partition_mappers.window import ( + DayWindow, + HourWindow, + MonthWindow, + QuarterWindow, + WeekWindow, + YearWindow, +) + + +class TestHourWindow: + def test_yields_60_minute_period_starts(self): + period_start = datetime(2024, 6, 10, 14) + members = list(HourWindow().to_upstream(period_start)) + assert len(members) == 60 + assert members[0] == datetime(2024, 6, 10, 14, 0) + assert members[-1] == datetime(2024, 6, 10, 14, 59) + + +class TestDayWindow: + def test_yields_24_hourly_period_starts(self): + period_start = datetime(2024, 6, 10) + members = list(DayWindow().to_upstream(period_start)) + assert len(members) == 24 + assert members[0] == datetime(2024, 6, 10, 0) + assert members[-1] == datetime(2024, 6, 10, 23) + + +class TestWeekWindow: + def test_yields_seven_daily_period_starts(self): + # 2024-06-10 is a Monday. 
+ period_start = datetime(2024, 6, 10) + members = list(WeekWindow().to_upstream(period_start)) + assert members == [ + datetime(2024, 6, 10), + datetime(2024, 6, 11), + datetime(2024, 6, 12), + datetime(2024, 6, 13), + datetime(2024, 6, 14), + datetime(2024, 6, 15), + datetime(2024, 6, 16), + ] + + +class TestMonthWindow: + def test_yields_all_days_in_february_leap_year(self): + members = list(MonthWindow().to_upstream(datetime(2024, 2, 1))) + assert len(members) == 29 + assert members[0] == datetime(2024, 2, 1) + assert members[-1] == datetime(2024, 2, 29) + + def test_honours_fiscal_month_offset(self): + # 2024-06-15 → 2024-07-14 covers 30 days when month_start_day=15. + members = list(MonthWindow().to_upstream(datetime(2024, 6, 15))) + assert len(members) == 30 + assert members[0] == datetime(2024, 6, 15) + assert members[-1] == datetime(2024, 7, 14) + + def test_crosses_year_boundary(self): + members = list(MonthWindow().to_upstream(datetime(2024, 12, 1))) + assert len(members) == 31 + assert members[-1] == datetime(2024, 12, 31) + + +class TestQuarterWindow: + def test_yields_three_monthly_period_starts(self): + # Q2 starts at 2024-04-01. + members = list(QuarterWindow().to_upstream(datetime(2024, 4, 1))) + assert members == [datetime(2024, 4, 1), datetime(2024, 5, 1), datetime(2024, 6, 1)] + + +class TestYearWindow: + def test_yields_twelve_monthly_period_starts(self): + members = list(YearWindow().to_upstream(datetime(2024, 1, 1))) + assert members == [datetime(2024, m, 1) for m in range(1, 13)] + + +class TestRollupMapperComposition: + def test_to_downstream_delegates_to_source_mapper(self): + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(week_start=0), + window=WeekWindow(), + ) + # Wednesday 2024-06-12 belongs to the week starting Monday 2024-06-10. 
+ assert mapper.to_downstream("2024-06-12T14:30:00") == "2024-06-10 (W24)" + + def test_is_rollup_flag(self): + mapper = RollupMapper(source_mapper=StartOfWeekMapper(), window=WeekWindow()) + assert mapper.is_rollup is True + + def test_hour_rollup_to_upstream_keys(self): + mapper = RollupMapper( + source_mapper=StartOfHourMapper(input_format="%Y-%m-%dT%H:%M", output_format="%Y-%m-%dT%H"), + window=HourWindow(), + ) + upstream = mapper.to_upstream("2024-06-10T14") + assert upstream == frozenset(f"2024-06-10T14:{m:02d}" for m in range(60)) + + def test_day_rollup_to_upstream_keys(self): + mapper = RollupMapper( + source_mapper=StartOfDayMapper(input_format="%Y-%m-%dT%H", output_format="%Y-%m-%d"), + window=DayWindow(), + ) + upstream = mapper.to_upstream("2024-06-10") + assert upstream == frozenset(f"2024-06-10T{h:02d}" for h in range(24)) + + def test_day_rollup_honours_timezone_in_encode(self): + mapper = RollupMapper( + source_mapper=StartOfDayMapper( + input_format="%Y-%m-%dT%H%z", + output_format="%Y-%m-%d", + timezone="Asia/Taipei", + ), + window=DayWindow(), + ) + upstream = mapper.to_upstream("2024-06-10") + assert "2024-06-10T00+0800" in upstream + assert "2024-06-10T23+0800" in upstream + + def test_week_rollup_with_default_format(self): + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0), + window=WeekWindow(), + ) + upstream = mapper.to_upstream("2024-06-10") + assert upstream == frozenset( + { + "2024-06-10", + "2024-06-11", + "2024-06-12", + "2024-06-13", + "2024-06-14", + "2024-06-15", + "2024-06-16", + } + ) + + def test_week_rollup_accepts_custom_output_format(self): + # decode_downstream lives on the mapper, so any format embedding the date works. 
+ mapper = RollupMapper( + source_mapper=StartOfWeekMapper( + input_format="%Y-%m-%d", + output_format="Week-of-%Y-%m-%d", + week_start=0, + ), + window=WeekWindow(), + ) + upstream = mapper.to_upstream("Week-of-2024-06-10") + assert len(upstream) == 7 + assert "2024-06-10" in upstream + + def test_week_rollup_raises_when_downstream_key_has_no_date(self): + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d"), + window=WeekWindow(), + ) + with pytest.raises(ValueError, match="could not locate YYYY-MM-DD"): + mapper.to_upstream("week-24") + + def test_month_rollup_fiscal_offset(self): + mapper = RollupMapper( + source_mapper=StartOfMonthMapper( + input_format="%Y-%m-%d", output_format="%Y-%m", month_start_day=15 + ), + window=MonthWindow(), + ) + upstream = mapper.to_upstream("2024-06") + assert "2024-06-15" in upstream + assert "2024-07-14" in upstream + + def test_quarter_rollup_to_upstream_keys(self): + mapper = RollupMapper( + source_mapper=StartOfQuarterMapper(input_format="%Y-%m"), + window=QuarterWindow(), + ) + assert mapper.to_upstream("2024-Q2") == frozenset({"2024-04", "2024-05", "2024-06"}) + + def test_quarter_rollup_raises_when_marker_missing(self): + mapper = RollupMapper( + source_mapper=StartOfQuarterMapper(input_format="%Y-%m"), + window=QuarterWindow(), + ) + with pytest.raises(ValueError, match="could not locate YYYY...Q"): + mapper.to_upstream("2024-06") + + def test_year_rollup_to_upstream_keys(self): + mapper = RollupMapper( + source_mapper=StartOfYearMapper(input_format="%Y-%m"), + window=YearWindow(), + ) + assert mapper.to_upstream("2024") == frozenset(f"2024-{m:02d}" for m in range(1, 13)) + + def test_serialize_round_trip(self): + from airflow.serialization.decoders import decode_partition_mapper + from airflow.serialization.encoders import encode_partition_mapper + + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(week_start=0, input_format="%Y-%m-%d", output_format="%Y-%m-%d"), + 
window=WeekWindow(), + ) + restored = decode_partition_mapper(encode_partition_mapper(mapper)) + assert isinstance(restored, RollupMapper) + assert isinstance(restored.source_mapper, StartOfWeekMapper) + assert restored.source_mapper.week_start == 0 + assert isinstance(restored.window, WeekWindow) + assert restored.to_upstream("2024-06-10") == mapper.to_upstream("2024-06-10") + + @pytest.mark.parametrize( + ("source_factory", "window", "downstream_key"), + [ + pytest.param( + lambda: StartOfHourMapper(input_format="%Y-%m-%dT%H:%M", output_format="%Y-%m-%dT%H"), + HourWindow(), + "2024-06-10T14", + id="hour", + ), + pytest.param( + lambda: StartOfDayMapper(input_format="%Y-%m-%dT%H", output_format="%Y-%m-%d"), + DayWindow(), + "2024-06-10", + id="day", + ), + pytest.param( + lambda: StartOfQuarterMapper(input_format="%Y-%m"), + QuarterWindow(), + "2024-Q2", + id="quarter", + ), + pytest.param( + lambda: StartOfYearMapper(input_format="%Y-%m"), + YearWindow(), + "2024", + id="year", + ), + ], + ) + def test_window_serialize_round_trip(self, source_factory, window, downstream_key): + from airflow.serialization.decoders import decode_partition_mapper + from airflow.serialization.encoders import encode_partition_mapper + + mapper = RollupMapper(source_mapper=source_factory(), window=window) + restored = decode_partition_mapper(encode_partition_mapper(mapper)) + assert isinstance(restored, RollupMapper) + assert isinstance(restored.window, type(window)) + assert restored.to_upstream(downstream_key) == mapper.to_upstream(downstream_key) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b4460b21e7563..12469172aca48 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1374,6 +1374,7 @@ Roadmap roadmap roc RoleBinding +rollup rom romeoandjuliet rootcss diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index 1598fc71afe91..d76aa4b872b18 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -233,14 +233,29 
@@ Partition Mapper .. autoapiclass:: airflow.sdk.StartOfYearMapper -.. autoapiclass:: airflow.sdk.WeeklyRollupMapper - -.. autoapiclass:: airflow.sdk.MonthlyRollupMapper +.. autoapiclass:: airflow.sdk.RollupMapper .. autoapiclass:: airflow.sdk.ProductMapper .. autoapiclass:: airflow.sdk.AllowedKeyMapper +Rollup Windows +~~~~~~~~~~~~~~ + +.. autoapiclass:: airflow.sdk.Window + +.. autoapiclass:: airflow.sdk.HourWindow + +.. autoapiclass:: airflow.sdk.DayWindow + +.. autoapiclass:: airflow.sdk.WeekWindow + +.. autoapiclass:: airflow.sdk.MonthWindow + +.. autoapiclass:: airflow.sdk.QuarterWindow + +.. autoapiclass:: airflow.sdk.YearWindow + I/O Helpers ----------- .. autoapiclass:: airflow.sdk.ObjectStoragePath diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 301d118e257c1..4077bfa5bf41e 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -45,6 +45,7 @@ "CronPartitionTimetable", "DAG", "DagRunState", + "DayWindow", "DeadlineAlert", "DeadlineReference", "DeltaDataIntervalTimetable", @@ -52,9 +53,11 @@ "EdgeModifier", "EventsTimetable", "ExceptionRetryPolicy", + "HourWindow", "IdentityMapper", "Label", "Metadata", + "MonthWindow", "MultipleCronTriggerTimetable", "ObjectStoragePath", "Param", @@ -63,10 +66,12 @@ "PartitionMapper", "PokeReturnValue", "ProductMapper", + "QuarterWindow", "RetryAction", "RetryDecision", "RetryPolicy", "RetryRule", + "RollupMapper", "SkipMixin", "SyncCallback", "StartOfDayMapper", @@ -75,15 +80,16 @@ "StartOfQuarterMapper", "StartOfWeekMapper", "StartOfYearMapper", - "WeeklyRollupMapper", - "MonthlyRollupMapper", "TaskGroup", "TaskInstance", "TaskInstanceState", "TriggerRule", "Variable", + "WeekWindow", "WeightRule", + "Window", "XComArg", + "YearWindow", "asset", "chain", "chain_linear", @@ -133,19 +139,26 @@ from airflow.sdk.definitions.edges import EdgeModifier, Label from airflow.sdk.definitions.param import Param, ParamsDict from 
airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper - from airflow.sdk.definitions.partition_mappers.base import PartitionMapper + from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper from airflow.sdk.definitions.partition_mappers.chain import ChainMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper from airflow.sdk.definitions.partition_mappers.temporal import ( - MonthlyRollupMapper, StartOfDayMapper, StartOfHourMapper, StartOfMonthMapper, StartOfQuarterMapper, StartOfWeekMapper, StartOfYearMapper, - WeeklyRollupMapper, + ) + from airflow.sdk.definitions.partition_mappers.window import ( + DayWindow, + HourWindow, + MonthWindow, + QuarterWindow, + WeekWindow, + Window, + YearWindow, ) from airflow.sdk.definitions.retry_policy import ( ExceptionRetryPolicy, @@ -205,6 +218,7 @@ "CronPartitionTimetable": ".definitions.timetables.trigger", "DAG": ".definitions.dag", "DagRunState": ".api.datamodels._generated", + "DayWindow": ".definitions.partition_mappers.window", "DeadlineAlert": ".definitions.deadline", "DeadlineReference": ".definitions.deadline", "DeltaDataIntervalTimetable": ".definitions.timetables.interval", @@ -212,9 +226,11 @@ "EdgeModifier": ".definitions.edges", "EventsTimetable": ".definitions.timetables.events", "ExceptionRetryPolicy": ".definitions.retry_policy", + "HourWindow": ".definitions.partition_mappers.window", "IdentityMapper": ".definitions.partition_mappers.identity", "Label": ".definitions.edges", "Metadata": ".definitions.asset.metadata", + "MonthWindow": ".definitions.partition_mappers.window", "MultipleCronTriggerTimetable": ".definitions.timetables.trigger", "ObjectStoragePath": ".io.path", "Param": ".definitions.param", @@ -223,10 +239,12 @@ "PartitionMapper": ".definitions.partition_mappers.base", "PokeReturnValue": ".bases.sensor", "ProductMapper": 
".definitions.partition_mappers.product", + "QuarterWindow": ".definitions.partition_mappers.window", "RetryAction": ".definitions.retry_policy", "RetryDecision": ".definitions.retry_policy", "RetryPolicy": ".definitions.retry_policy", "RetryRule": ".definitions.retry_policy", + "RollupMapper": ".definitions.partition_mappers.base", "SecretCache": ".execution_time.cache", "SkipMixin": ".bases.skipmixin", "SyncCallback": ".definitions.callback", @@ -236,15 +254,16 @@ "StartOfQuarterMapper": ".definitions.partition_mappers.temporal", "StartOfWeekMapper": ".definitions.partition_mappers.temporal", "StartOfYearMapper": ".definitions.partition_mappers.temporal", - "WeeklyRollupMapper": ".definitions.partition_mappers.temporal", - "MonthlyRollupMapper": ".definitions.partition_mappers.temporal", "TaskGroup": ".definitions.taskgroup", "TaskInstance": ".types", "TaskInstanceState": ".api.datamodels._generated", "TriggerRule": ".api.datamodels._generated", "Variable": ".definitions.variable", + "WeekWindow": ".definitions.partition_mappers.window", "WeightRule": ".api.datamodels._generated", + "Window": ".definitions.partition_mappers.window", "XComArg": ".definitions.xcom_arg", + "YearWindow": ".definitions.partition_mappers.window", "asset": ".definitions.asset.decorators", "chain": ".bases.operator", "chain_linear": ".bases.operator", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 854cff94335d0..2db7e57aa671a 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -63,19 +63,26 @@ from airflow.sdk.definitions.decorators.task_group import task_group as task_gro from airflow.sdk.definitions.edges import EdgeModifier as EdgeModifier, Label as Label from airflow.sdk.definitions.param import Param as Param from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper -from airflow.sdk.definitions.partition_mappers.base import PartitionMapper +from 
airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper from airflow.sdk.definitions.partition_mappers.chain import ChainMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper from airflow.sdk.definitions.partition_mappers.temporal import ( - MonthlyRollupMapper, StartOfDayMapper, StartOfHourMapper, StartOfMonthMapper, StartOfQuarterMapper, StartOfWeekMapper, StartOfYearMapper, - WeeklyRollupMapper, +) +from airflow.sdk.definitions.partition_mappers.window import ( + DayWindow, + HourWindow, + MonthWindow, + QuarterWindow, + WeekWindow, + Window, + YearWindow, ) from airflow.sdk.definitions.retry_policy import ( ExceptionRetryPolicy as ExceptionRetryPolicy, @@ -135,14 +142,17 @@ __all__ = [ "CronPartitionTimetable", "DAG", "DagRunState", + "DayWindow", "DeltaDataIntervalTimetable", "DeltaTriggerTimetable", "EdgeModifier", "EventsTimetable", "ExceptionRetryPolicy", + "HourWindow", "IdentityMapper", "Label", "Metadata", + "MonthWindow", "MultipleCronTriggerTimetable", "ObjectStoragePath", "Param", @@ -150,10 +160,12 @@ __all__ = [ "PartitionedAssetTimetable", "PartitionMapper", "ProductMapper", + "QuarterWindow", "RetryAction", "RetryDecision", "RetryPolicy", "RetryRule", + "RollupMapper", "SecretCache", "SkipMixin", "StartOfDayMapper", @@ -162,14 +174,15 @@ __all__ = [ "StartOfQuarterMapper", "StartOfWeekMapper", "StartOfYearMapper", - "WeeklyRollupMapper", - "MonthlyRollupMapper", "TaskGroup", "TaskInstanceState", "TriggerRule", "Variable", + "WeekWindow", "WeightRule", + "Window", "XComArg", + "YearWindow", "asset", "chain", "chain_linear", diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py index 00c4f57412f2f..78383b0b34dda 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py +++ 
b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py @@ -16,6 +16,11 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.sdk.definitions.partition_mappers.window import Window + class PartitionMapper: """ @@ -29,12 +34,16 @@ class PartitionMapper: class RollupMapper(PartitionMapper): """ - Partition mapper that supports rollup (many upstream keys → one downstream key). + Partition mapper that rolls up many upstream keys into one downstream key. - Subclass this when the downstream Dag should wait for a complete set of upstream - partition keys before triggering. The scheduler calls ``to_upstream`` to discover - which source keys are required and only creates a Dag run once all of them have - arrived in ``PartitionedAssetKeyLog``. + Compose a ``source_mapper`` (which normalizes each upstream key to the + downstream granularity) with a ``window`` that declares the full set of + upstream keys required for a given downstream key. The scheduler holds + the Dag run until every upstream key in the window has arrived. """ is_rollup: bool = True + + def __init__(self, *, source_mapper: PartitionMapper, window: Window) -> None: + self.source_mapper = source_mapper + self.window = window diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py index 1dc7cd56f1b42..030d0fd32cb60 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper +from airflow.sdk.definitions.partition_mappers.base import PartitionMapper class _BaseTemporalMapper(PartitionMapper): @@ -49,38 +49,24 @@ class StartOfWeekMapper(_BaseTemporalMapper): default_output_format = "%Y-%m-%d (W%V)" def __init__(self, *, week_start: int = 0, **kwargs) -> None: + if not 0 <= week_start <= 6: + raise ValueError(f"week_start must be between 0 (Monday) and 6 (Sunday), got {week_start!r}") super().__init__(**kwargs) self.week_start = week_start -class WeeklyRollupMapper(StartOfWeekMapper, RollupMapper): - """ - Map a time-based partition key to the start of its week, requiring all 7 daily keys. - - Use this when a partitioned Dag should only run once every daily asset partition - for a full week has been produced. - """ - - class StartOfMonthMapper(_BaseTemporalMapper): """Map a time-based partition key to the start of its month.""" default_output_format = "%Y-%m" def __init__(self, *, month_start_day: int = 1, **kwargs) -> None: + if not 1 <= month_start_day <= 28: + raise ValueError(f"month_start_day must be between 1 and 28, got {month_start_day!r}") super().__init__(**kwargs) self.month_start_day = month_start_day -class MonthlyRollupMapper(StartOfMonthMapper, RollupMapper): - """ - Map a time-based partition key to the start of its month, requiring all daily keys in that month. - - Use this when a partitioned Dag should only run once every daily asset partition - for a full calendar month has been produced. 
- """ - - class StartOfQuarterMapper(_BaseTemporalMapper): """Map a time-based partition key to quarter.""" diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py new file mode 100644 index 0000000000000..92b389a326251 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +class Window: + """ + Describes a rollup window: which upstream keys make up one downstream key. + + Paired with a ``source_mapper`` :class:`PartitionMapper` inside a + :class:`RollupMapper`. The source_mapper normalizes upstream keys to the + downstream granularity; the window enumerates the complete set of + upstream keys that roll up into one downstream key. Runtime logic + lives in ``airflow.partition_mappers.window`` on the scheduler side. 
+ """ + + +class HourWindow(Window): + """Sixty consecutive minute keys making up one hour.""" + + +class DayWindow(Window): + """Twenty-four consecutive hourly keys making up one day.""" + + +class WeekWindow(Window): + """Seven consecutive daily keys making up one week.""" + + +class MonthWindow(Window): + """All daily keys making up one calendar month.""" + + +class QuarterWindow(Window): + """Three consecutive monthly keys making up one calendar quarter.""" + + +class YearWindow(Window): + """Twelve consecutive monthly keys making up one calendar year.""" From 2d6cfe06ba1569b8bfcf8ef2a0b33cf774c39353 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 4 May 2026 21:35:08 +0800 Subject: [PATCH 09/16] refactor(partition-mappers): decode rollup keys from any output_format ordering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit StartOfWeekMapper and StartOfQuarterMapper now derive their decode_downstream regex from output_format itself, so users can re-order strftime directives and {name} placeholders (e.g. "Q{quarter}/%Y") without having to override decode_downstream. Malformed output_format — empty {}, non-identifier placeholder names, duplicate %X directives, duplicate {name} placeholders — raises ValueError at mapper construction instead of an opaque re.error from deep inside a scheduler tick or UI route. 
---
 .../src/airflow/partition_mappers/temporal.py | 134 +++++++++++++++---
 .../unit/partition_mappers/test_temporal.py   |  28 ++++
 .../unit/partition_mappers/test_window.py     |  13 +-
 3 files changed, 155 insertions(+), 20 deletions(-)

diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py
index e82e943333d6b..5101f982cf2fc 100644
--- a/airflow-core/src/airflow/partition_mappers/temporal.py
+++ b/airflow-core/src/airflow/partition_mappers/temporal.py
@@ -19,6 +19,7 @@
 import re
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
+from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
 from airflow._shared.timezones.timezone import make_aware, parse_timezone
@@ -28,6 +29,81 @@
     from pendulum import FixedTimezone, Timezone
 
 
+_STRPTIME_PATTERNS: dict[str, str] = {
+    "%Y": r"\d{4}",
+    "%m": r"\d{2}",
+    "%d": r"\d{2}",
+    "%H": r"\d{2}",
+    "%M": r"\d{2}",
+    "%S": r"\d{2}",
+    "%V": r"\d{2}",
+    "%U": r"\d{2}",
+    "%W": r"\d{2}",
+}
+
+
+def _compile_output_format_regex(
+    fmt: str, placeholder_patterns: dict[str, str] | None = None
+) -> re.Pattern[str]:
+    r"""
+    Compile *fmt* into a regex with named groups so a formatted key can be parsed back.
+
+    Strftime directives in ``_STRPTIME_PATTERNS`` (``%Y``, ``%m``, ``%V`` …) become
+    named groups keyed on the directive letter — e.g. ``%Y`` → ``(?P<Y>\d{4})``.
+    FastAPI-style ``{name}`` placeholders become named groups keyed on *name*; the
+    regex defaults to ``\w+``, with *placeholder_patterns* letting callers narrow
+    it (e.g. ``{"quarter": r"[1-4]"}``). Anything else is escaped as a literal.
+
+    The regex is matched with ``.search`` rather than ``.fullmatch`` so callers
+    can extract fields without caring about literal prefix/suffix length.
+ + Raises :exc:`ValueError` if *fmt* cannot produce a valid regex — empty + ``{}`` placeholders, names that aren't valid Python identifiers, or any + duplicate group name (a directive or placeholder appearing twice). Catching + this at construction time keeps a misconfigured mapper from surfacing as an + opaque ``re.error`` deep inside the scheduler tick or a UI route. + """ + placeholder_patterns = placeholder_patterns or {} + parts: list[str] = [] + seen_groups: set[str] = set() + i = 0 + while i < len(fmt): + if fmt[i] == "%" and i + 1 < len(fmt) and fmt[i : i + 2] in _STRPTIME_PATTERNS: + directive = fmt[i : i + 2] + name = directive[1] + if name in seen_groups: + raise ValueError( + f"output_format {fmt!r} reuses directive {directive!r}; " + "each strftime directive may appear at most once." + ) + seen_groups.add(name) + parts.append(f"(?P<{name}>{_STRPTIME_PATTERNS[directive]})") + i += 2 + continue + if fmt[i] == "{": + end = fmt.find("}", i) + if end != -1: + name = fmt[i + 1 : end] + if not name.isidentifier(): + raise ValueError( + f"output_format {fmt!r} has invalid placeholder {{{name}}}; " + "placeholder names must be valid Python identifiers." + ) + if name in seen_groups: + raise ValueError( + f"output_format {fmt!r} reuses placeholder {{{name}}}; " + "each placeholder may appear at most once." 
+ ) + seen_groups.add(name) + pattern = placeholder_patterns.get(name, r"\w+") + parts.append(f"(?P<{name}>{pattern})") + i = end + 1 + continue + parts.append(re.escape(fmt[i])) + i += 1 + return re.compile("".join(parts)) + + class _BaseTemporalMapper(PartitionMapper, ABC): """Base class for Temporal Partition Mappers.""" @@ -126,7 +202,6 @@ class StartOfWeekMapper(_BaseTemporalMapper): """Map a time-based partition key to the start of its week.""" default_output_format = "%Y-%m-%d (W%V)" - _YMD_RE = re.compile(r"\d{4}-\d{2}-\d{2}") def __init__( self, @@ -140,23 +215,30 @@ def __init__( raise ValueError(f"week_start must be between 0 (Monday) and 6 (Sunday), got {week_start!r}") super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) self.week_start = week_start # 0 = Monday (ISO default), 6 = Sunday + # Force-compile so a malformed output_format raises ValueError at + # construction time instead of an opaque re.error inside the scheduler. + _ = self._key_pattern def normalize(self, dt: datetime) -> datetime: days_since_start = (dt.weekday() - self.week_start) % 7 start = dt - timedelta(days=days_since_start) return start.replace(hour=0, minute=0, second=0, microsecond=0) + @cached_property + def _key_pattern(self) -> re.Pattern[str]: + # %V (ISO week) cannot be round-tripped through strptime without %G+%u, + # so derive a named-group regex from output_format and pull out %Y/%m/%d. + return _compile_output_format_regex(self.output_format) + def decode_downstream(self, downstream_key: str) -> datetime: - # %V (ISO week) cannot be parsed by strptime without %G+%u, so locate - # the YYYY-MM-DD slice with a regex. Robust across formats that mix - # the date with extras like "(W%V)". 
- match = self._YMD_RE.search(downstream_key) - if match is None: + match = self._key_pattern.search(downstream_key) + if match is None or not {"Y", "m", "d"}.issubset(match.groupdict()): raise ValueError( - f"StartOfWeekMapper.decode_downstream could not locate YYYY-MM-DD in {downstream_key!r}; " - "output_format must include '%Y-%m-%d'." + f"StartOfWeekMapper.decode_downstream could not parse {downstream_key!r} " + f"with output_format {self.output_format!r}; " + "output_format must include the %Y, %m and %d directives." ) - return datetime.strptime(match.group(), "%Y-%m-%d") + return datetime(int(match["Y"]), int(match["m"]), int(match["d"])) def serialize(self) -> dict[str, Any]: return {**super().serialize(), "week_start": self.week_start} @@ -220,7 +302,18 @@ class StartOfQuarterMapper(_BaseTemporalMapper): """Map a time-based partition key to quarter.""" default_output_format = "%Y-Q{quarter}" - _YEAR_QUARTER_RE = re.compile(r"(\d{4}).*?Q([1-4])") + + def __init__( + self, + *, + timezone: str | Timezone | FixedTimezone = "UTC", + input_format: str = "%Y-%m-%dT%H:%M:%S", + output_format: str | None = None, + ) -> None: + super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) + # Force-compile so a malformed output_format raises ValueError at + # construction time instead of an opaque re.error inside the scheduler. + _ = self._key_pattern def normalize(self, dt: datetime) -> datetime: quarter = (dt.month - 1) // 3 @@ -238,17 +331,22 @@ def format(self, dt: datetime) -> str: quarter = (dt.month - 1) // 3 + 1 return dt.strftime(self.output_format).format(quarter=quarter) + @cached_property + def _key_pattern(self) -> re.Pattern[str]: + # ``{quarter}`` is a Python-format placeholder, not a strftime directive, + # so derive a named-group regex from output_format that handles both. 
+ return _compile_output_format_regex(self.output_format, {"quarter": r"[1-4]"}) + def decode_downstream(self, downstream_key: str) -> datetime: - # output_format carries a ``{quarter}`` placeholder, so strptime doesn't - # apply directly. Locate ``YYYY...Q`` and rebuild the period start. - match = self._YEAR_QUARTER_RE.search(downstream_key) - if match is None: + match = self._key_pattern.search(downstream_key) + if match is None or not {"Y", "quarter"}.issubset(match.groupdict()): raise ValueError( - f"StartOfQuarterMapper.decode_downstream could not locate YYYY...Q in " - f"{downstream_key!r}; output_format must include the year and 'Q{{quarter}}'." + f"StartOfQuarterMapper.decode_downstream could not parse {downstream_key!r} " + f"with output_format {self.output_format!r}; " + "output_format must include the %Y directive and the {quarter} placeholder." ) - year = int(match.group(1)) - quarter = int(match.group(2)) + year = int(match["Y"]) + quarter = int(match["quarter"]) return datetime(year, (quarter - 1) * 3 + 1, 1) diff --git a/airflow-core/tests/unit/partition_mappers/test_temporal.py b/airflow-core/tests/unit/partition_mappers/test_temporal.py index 170856f883182..60d6ffc2a0352 100644 --- a/airflow-core/tests/unit/partition_mappers/test_temporal.py +++ b/airflow-core/tests/unit/partition_mappers/test_temporal.py @@ -152,3 +152,31 @@ def test_rejects_out_of_range(self, month_start_day): @pytest.mark.parametrize("month_start_day", [1, 15, 28]) def test_accepts_valid_range(self, month_start_day): assert StartOfMonthMapper(month_start_day=month_start_day).month_start_day == month_start_day + + +class TestOutputFormatValidation: + """Malformed output_format must fail fast at mapper construction, not as an opaque re.error inside the scheduler.""" + + @pytest.mark.parametrize( + ("output_format", "match"), + [ + ("a{}b", "invalid placeholder"), + ("%Y-{1bad}", "invalid placeholder"), + ("%Y-%Y-%m-%d", "reuses directive"), + ("%Y-Q{quarter}-{quarter}", "reuses 
placeholder"), + ], + ) + def test_quarter_mapper_rejects_malformed_format(self, output_format, match): + with pytest.raises(ValueError, match=match): + StartOfQuarterMapper(output_format=output_format) + + @pytest.mark.parametrize( + ("output_format", "match"), + [ + ("week-{}", "invalid placeholder"), + ("%Y-%m-%d-%Y", "reuses directive"), + ], + ) + def test_week_mapper_rejects_malformed_format(self, output_format, match): + with pytest.raises(ValueError, match=match): + StartOfWeekMapper(output_format=output_format) diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py b/airflow-core/tests/unit/partition_mappers/test_window.py index 027ef32bf4ebe..269f9b2ad422f 100644 --- a/airflow-core/tests/unit/partition_mappers/test_window.py +++ b/airflow-core/tests/unit/partition_mappers/test_window.py @@ -185,7 +185,7 @@ def test_week_rollup_raises_when_downstream_key_has_no_date(self): source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d"), window=WeekWindow(), ) - with pytest.raises(ValueError, match="could not locate YYYY-MM-DD"): + with pytest.raises(ValueError, match="StartOfWeekMapper.decode_downstream could not parse"): mapper.to_upstream("week-24") def test_month_rollup_fiscal_offset(self): @@ -211,9 +211,18 @@ def test_quarter_rollup_raises_when_marker_missing(self): source_mapper=StartOfQuarterMapper(input_format="%Y-%m"), window=QuarterWindow(), ) - with pytest.raises(ValueError, match="could not locate YYYY...Q"): + with pytest.raises(ValueError, match="StartOfQuarterMapper.decode_downstream could not parse"): mapper.to_upstream("2024-06") + def test_quarter_rollup_accepts_reordered_output_format(self): + # ``{quarter}`` and ``%Y`` can appear in any order; the format-derived + # regex pulls out both via named groups. 
+ mapper = RollupMapper( + source_mapper=StartOfQuarterMapper(input_format="%Y-%m", output_format="Q{quarter}/%Y"), + window=QuarterWindow(), + ) + assert mapper.to_upstream("Q2/2024") == frozenset({"2024-04", "2024-05", "2024-06"}) + def test_year_rollup_to_upstream_keys(self): mapper = RollupMapper( source_mapper=StartOfYearMapper(input_format="%Y-%m"), From 7929264d402d655d766b69961dd7b94ece1101a7 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 5 May 2026 17:02:14 +0800 Subject: [PATCH 10/16] fix: harden rollup mapper edge cases (boundary, error suppression, UI) --- .../routes/ui/partitioned_dag_runs.py | 33 ++++-- .../src/airflow/partition_mappers/window.py | 18 ++- .../components/AssetExpression/AssetNode.tsx | 12 +- .../tests/unit/jobs/test_scheduler_job.py | 111 ++++++++++++++++++ .../unit/partition_mappers/test_window.py | 15 +++ 5 files changed, 173 insertions(+), 16 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index f7504b6df3d5e..31915b4dd563c 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -95,13 +95,23 @@ def _compute_total_required( asset_info: list[AssetNameUri], partition_key: str, ) -> int: - """Sum required upstream events across all assets, using to_upstream for rollup mappers.""" + """ + Sum required upstream events across all assets, using to_upstream for rollup mappers. + + A misconfigured custom mapper that raises in ``to_upstream`` falls back to the + non-rollup count (1 per asset) for that asset rather than failing the whole + request — matches the suppression already used in the detail route. 
+ """ if timetable is None: return len(asset_info) total = 0 for name, uri in asset_info: - mapper = timetable.get_partition_mapper(name=name, uri=uri) - total += len(cast("RollupMapper", mapper).to_upstream(partition_key)) if mapper.is_rollup else 1 + count = 1 + with suppress(Exception): + mapper = timetable.get_partition_mapper(name=name, uri=uri) + if mapper.is_rollup: + count = len(cast("RollupMapper", mapper).to_upstream(partition_key)) + total += count return total @@ -117,17 +127,20 @@ def _compute_received_count( For rollup assets: count distinct upstream keys that intersect the required set. For non-rollup assets: count 1 per asset if any event has been logged — the source_partition_key value is irrelevant; having any event satisfies the requirement. + A misconfigured rollup mapper falls back to the non-rollup behavior for that + asset so the route doesn't 500. """ total = 0 for asset_id, received_keys in received_by_asset.items(): if timetable is not None: - name, uri = asset_id_to_info[asset_id] - mapper = timetable.get_partition_mapper(name=name, uri=uri) - if mapper.is_rollup: - required_keys = frozenset(cast("RollupMapper", mapper).to_upstream(partition_key)) - total += len(received_keys & required_keys) - continue - # Non-rollup: any logged event satisfies this asset's requirement. + with suppress(Exception): + name, uri = asset_id_to_info[asset_id] + mapper = timetable.get_partition_mapper(name=name, uri=uri) + if mapper.is_rollup: + required_keys = frozenset(cast("RollupMapper", mapper).to_upstream(partition_key)) + total += len(received_keys & required_keys) + continue + # Non-rollup (or rollup mapper raised): any logged event satisfies this asset. 
total += 1 if received_keys else 0 return total diff --git a/airflow-core/src/airflow/partition_mappers/window.py b/airflow-core/src/airflow/partition_mappers/window.py index 1913c2295b06b..ea2b2677150db 100644 --- a/airflow-core/src/airflow/partition_mappers/window.py +++ b/airflow-core/src/airflow/partition_mappers/window.py @@ -24,6 +24,20 @@ from collections.abc import Iterable +def _shift_months(dt: datetime, months: int) -> datetime: + """ + Return *dt* shifted forward by *months*, wrapping the year as needed. + + Built-in temporal mappers always emit period-starts on day 1 (or + ``month_start_day`` <= 28 for fiscal months), so :meth:`datetime.replace` + on the new month is always valid. A user-provided source mapper that + emits a higher day for a month with fewer days remains the caller's + responsibility. + """ + total = dt.month - 1 + months + return dt.replace(year=dt.year + total // 12, month=total % 12 + 1) + + class Window(ABC): """ Describes a rollup window: which decoded upstream items make up one decoded downstream period. @@ -96,11 +110,11 @@ class QuarterWindow(Window): """Three monthly period-starts making up one calendar quarter (e.g. 
Jan/Feb/Mar for Q1).""" def to_upstream(self, period_start: datetime) -> Iterable[datetime]: - return (period_start.replace(month=period_start.month + i) for i in range(3)) + return (_shift_months(period_start, i) for i in range(3)) class YearWindow(Window): """Twelve consecutive monthly period-starts making up one calendar year.""" def to_upstream(self, period_start: datetime) -> Iterable[datetime]: - return (period_start.replace(month=i + 1) for i in range(12)) + return (_shift_months(period_start, i) for i in range(12)) diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx index c29b89e83324d..678d66760b274 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx @@ -92,10 +92,14 @@ export const AssetNode = ({ !isFullyReceived && (event?.received_count ?? 0) > 0 && (event?.received_count ?? 0) < (event?.required_count ?? 1); - // In partitioned dags the `last_update` timestamp is the last asset event, not the - // pending partition key's arrival — so it isn't meaningful here. Suppress it for - // non-rollup partitioned nodes. - const isPartitionedNonRollup = event?.is_rollup === false; + // In a partitioned Dag with a pending partition, `last_update` is the last + // asset event, not the pending partition key's arrival — so suppress it for + // non-rollup partitioned nodes. We detect that case via `required_keys`, + // which only the partitioned + pending-APDR branch of `next_run_assets` + // populates. Plain non-partitioned events leave `required_keys` empty and + // must keep showing their timestamp. + const isPartitionedNonRollup = + event?.is_rollup === false && (event?.required_keys?.length ?? 0) > 0; const showTime = isFullyReceived && !isPartitionedNonRollup; const showRollupChecklist = (event?.is_rollup ?? 
false) && (event?.required_keys?.length ?? 0) > 0 && (isPartial || isFullyReceived); diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 56e69459104ea..ff15f91485152 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -92,7 +92,9 @@ AssetAlias, AssetWatcher, CronPartitionTimetable, + HourWindow, IdentityMapper, + RollupMapper, StartOfHourMapper, task, ) @@ -9927,6 +9929,115 @@ def test_consumer_dag_listen_to_two_partitioned_asset( assert asset_event.source_run_id == "test" +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_dag_run_rollup_holds_until_window_complete( + dag_maker: DagMaker, + session: Session, +): + """A rollup APDR stays pending until every upstream key in the window has arrived.""" + asset_1 = Asset(name="asset-1") + with dag_maker( + dag_id="rollup-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=RollupMapper( + source_mapper=StartOfHourMapper(), + window=HourWindow(), + ), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + # First minute key arrives — only 1 / 60 upstream keys, so the APDR must + # not fire yet. 
+ apdr = _produce_and_register_asset_event( + dag_id="rollup-producer-0", + asset=asset_1, + partition_key="2024-01-01T00:00:00", + session=session, + dag_maker=dag_maker, + expected_partition_key="2024-01-01T00", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is None + assert partition_dags == set() + + # Send the remaining 59 minute keys — once all 60 are present the rollup is + # satisfied and the APDR creates its Dag run on the next tick. + for minute in range(1, 60): + _produce_and_register_asset_event( + dag_id=f"rollup-producer-{minute}", + asset=asset_1, + partition_key=f"2024-01-01T00:{minute:02d}:00", + session=session, + dag_maker=dag_maker, + expected_partition_key="2024-01-01T00", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is not None + assert partition_dags == {"rollup-consumer"} + + +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_dag_run_rollup_treats_mapper_exception_as_not_satisfied( + dag_maker: DagMaker, + session: Session, +): + """ + A misconfigured rollup mapper that raises during status evaluation must not crash + the scheduler tick — the asset is treated as not-yet-satisfied and the APDR + remains pending. 
+ """ + asset_1 = Asset(name="asset-1") + with dag_maker( + dag_id="rollup-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=RollupMapper( + source_mapper=StartOfHourMapper(), + window=HourWindow(), + ), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + apdr = _produce_and_register_asset_event( + dag_id="rollup-producer-0", + asset=asset_1, + partition_key="2024-01-01T00:00:00", + session=session, + dag_maker=dag_maker, + expected_partition_key="2024-01-01T00", + ) + + with mock.patch.object( + SchedulerJobRunner, + "_check_rollup_asset_status", + side_effect=RuntimeError("misconfigured rollup mapper"), + ): + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + + session.refresh(apdr) + assert apdr.created_dag_run_id is None + assert partition_dags == set() + + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") def test_consumer_dag_listen_to_two_partitioned_asset_with_key_1_mapper( diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py b/airflow-core/tests/unit/partition_mappers/test_window.py index 269f9b2ad422f..0faaa2bd92f4d 100644 --- a/airflow-core/tests/unit/partition_mappers/test_window.py +++ b/airflow-core/tests/unit/partition_mappers/test_window.py @@ -99,12 +99,27 @@ def test_yields_three_monthly_period_starts(self): members = list(QuarterWindow().to_upstream(datetime(2024, 4, 1))) assert members == [datetime(2024, 4, 1), datetime(2024, 5, 1), datetime(2024, 6, 1)] + def test_wraps_year_for_non_calendar_quarter_start(self): + # A custom source mapper might decode a quarter starting in November + # (e.g. fiscal Q4); the window must wrap into the next year. 
+ members = list(QuarterWindow().to_upstream(datetime(2024, 11, 1))) + assert members == [datetime(2024, 11, 1), datetime(2024, 12, 1), datetime(2025, 1, 1)] + class TestYearWindow: def test_yields_twelve_monthly_period_starts(self): members = list(YearWindow().to_upstream(datetime(2024, 1, 1))) assert members == [datetime(2024, m, 1) for m in range(1, 13)] + def test_wraps_year_for_fiscal_year_start(self): + # A fiscal year starting in April rolls forward 12 months from April, + # crossing the calendar boundary into the following year. + members = list(YearWindow().to_upstream(datetime(2024, 4, 1))) + assert members[0] == datetime(2024, 4, 1) + assert members[8] == datetime(2024, 12, 1) + assert members[9] == datetime(2025, 1, 1) + assert members[-1] == datetime(2025, 3, 1) + class TestRollupMapperComposition: def test_to_downstream_delegates_to_source_mapper(self): From 3cb92344455c2b27200c7984d6de5d488b46e062 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 5 May 2026 18:12:41 +0800 Subject: [PATCH 11/16] fix: report rollup-aware total_received consistently in partitioned_dag_runs list MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop the SQL "count distinct assets with any log" subquery and always compute total_received via the Python rollup-aware helper. The list endpoint previously returned different numbers for the same APDR depending on whether the caller filtered by dag_id (rollup-aware, counts upstream window keys) or queried globally (SQL approximation, counts assets with any log) — same field, different semantics, very confusing for any UI consumer. The N+1 cost of per-Dag timetable loads was already paid in the global branch for total_required, so adding a single batched log fetch keeps the existing query budget while making the contract identical across both views. _compute_received_count now skips asset_ids that are no longer required (active=False) so the relaxed log query doesn't over-count. 
--- .../routes/ui/partitioned_dag_runs.py | 107 +++++++----------- 1 file changed, 42 insertions(+), 65 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index 31915b4dd563c..254d799412034 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, TypeAlias, cast from fastapi import Depends, HTTPException, status -from sqlalchemy import func, select +from sqlalchemy import select from airflow.api_fastapi.common.db.common import SessionDep, apply_filters_to_select from airflow.api_fastapi.common.parameters import ( @@ -132,6 +132,11 @@ def _compute_received_count( """ total = 0 for asset_id, received_keys in received_by_asset.items(): + # Logs may reference assets that are no longer required (active=False) + # — skip them rather than over-counting. The list route relies on this + # because it batch-fetches logs without a per-Dag asset_id filter. 
+ if asset_id not in asset_id_to_info: + continue if timetable is not None: with suppress(Exception): name, uri = asset_id_to_info[asset_id] @@ -148,13 +153,13 @@ def _compute_received_count( partitioned_dag_runs_router = AirflowRouter(tags=["PartitionedDagRun"]) -def _build_response(row, required_count: int, received_count: int | None = None) -> PartitionedDagRunResponse: +def _build_response(row, required_count: int, received_count: int) -> PartitionedDagRunResponse: return PartitionedDagRunResponse( id=row.id, dag_id=row.target_dag_id, partition_key=row.partition_key, created_at=row.created_at.isoformat() if row.created_at else None, - total_received=received_count if received_count is not None else (row.total_received or 0), + total_received=received_count, total_required=required_count, state=row.dag_run_state if row.created_dag_run_id else "pending", created_dag_run_id=row.dag_run_id, @@ -172,34 +177,12 @@ def get_partitioned_dag_runs( has_created_dag_run_id: QueryPartitionedDagRunHasCreatedDagRunIdFilter, ) -> PartitionedDagRunCollectionResponse: """Return PartitionedDagRuns. Filter by dag_id and/or has_created_dag_run_id.""" - # The dag-existence / partitioned-timetable check is intentionally deferred to the + # The Dag-existence / partitioned-timetable check is intentionally deferred to the # empty-results branch below. In the happy path (rows exist), filtering by dag_id # already restricts to that Dag, so an extra DagModel lookup just to validate # existence wastes a query. We only consult DagModel when we have no rows to # report — that's the only case where the distinction (404 vs empty) matters. - # Subquery for received count per partition (count of required assets that have any log). - # This matches the non-rollup contract "any event for an asset = that asset is satisfied". - # Rollup-aware counts are computed in Python in _compute_received_count when dag_id is set. 
- required_assets_subq = ( - select(DagScheduleAssetReference.asset_id) - .join(AssetModel, AssetModel.id == DagScheduleAssetReference.asset_id) - .where( - DagScheduleAssetReference.dag_id == AssetPartitionDagRun.target_dag_id, - AssetModel.active.has(), - ) - .correlate(AssetPartitionDagRun) - ) - received_subq = ( - select(func.count(func.distinct(PartitionedAssetKeyLog.asset_id))) - .where( - PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, - PartitionedAssetKeyLog.asset_id.in_(required_assets_subq), - ) - .correlate(AssetPartitionDagRun) - .scalar_subquery() - ) - query = select( AssetPartitionDagRun.id, AssetPartitionDagRun.target_dag_id, @@ -208,7 +191,6 @@ def get_partitioned_dag_runs( AssetPartitionDagRun.created_dag_run_id, DagRun.run_id.label("dag_run_id"), DagRun.state.label("dag_run_state"), - received_subq.label("total_received"), ).outerjoin(DagRun, AssetPartitionDagRun.created_dag_run_id == DagRun.id) query = apply_filters_to_select(statement=query, filters=[dag_id, has_created_dag_run_id]) readable_dag_ids = readable_dags_filter.value @@ -225,48 +207,29 @@ def get_partitioned_dag_runs( raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id.value} was not found") return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) - if dag_id.value is not None: - timetable, asset_info, asset_id_to_info = _load_timetable_and_assets(dag_id.value, session) - - # Batch-fetch all log entries for these APDRs in one query. 
- apdr_ids = [row.id for row in rows] - log_by_apdr: dict[int, dict[int, set[str]]] = {} - for pakl_row in session.execute( - select( - PartitionedAssetKeyLog.asset_partition_dag_run_id, - PartitionedAssetKeyLog.asset_id, - PartitionedAssetKeyLog.source_partition_key, - ).where( - PartitionedAssetKeyLog.asset_partition_dag_run_id.in_(apdr_ids), - PartitionedAssetKeyLog.asset_id.in_(list(asset_id_to_info)), - ) - ).all(): - log_by_apdr.setdefault(pakl_row.asset_partition_dag_run_id, {}).setdefault( - pakl_row.asset_id, set() - ).add(pakl_row.source_partition_key or "") - - results = [ - _build_response( - row, - _compute_total_required(timetable, asset_info, row.partition_key), - _compute_received_count( - log_by_apdr.get(row.id, {}), timetable, asset_id_to_info, row.partition_key - ), - ) - for row in rows - ] - return PartitionedDagRunCollectionResponse(partitioned_dag_runs=results, total=len(results)) - - # No dag_id filter: load timetables and assets for each unique Dag. - # total_received is approximated via SQL for this global view. + # Load per-Dag timetable + required assets and batch-fetch the log entries + # for every APDR in a single query, so total_received uses the rollup-aware + # Python computation uniformly across single-Dag and global views. The + # alternative (a SQL count subquery) cannot honour rollup windows without + # also running the mapper, and silently divergent semantics between branches + # are worse than paying the per-Dag timetable load cost. 
unique_dag_ids = list({row.target_dag_id for row in rows}) dag_timetables_assets: dict[ str, tuple[PartitionedAssetTimetable | None, list[AssetNameUri], dict[int, AssetNameUri]] ] = {did: _load_timetable_and_assets(did, session) for did in unique_dag_ids} - dag_rows = session.execute( - select(DagModel.dag_id, DagModel.asset_expression).where(DagModel.dag_id.in_(unique_dag_ids)) - ).all() - asset_expressions = {r.dag_id: r.asset_expression for r in dag_rows} + + apdr_ids = [row.id for row in rows] + log_by_apdr: dict[int, dict[int, set[str]]] = {} + for pakl_row in session.execute( + select( + PartitionedAssetKeyLog.asset_partition_dag_run_id, + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id.in_(apdr_ids)) + ).all(): + log_by_apdr.setdefault(pakl_row.asset_partition_dag_run_id, {}).setdefault( + pakl_row.asset_id, set() + ).add(pakl_row.source_partition_key or "") results = [ _build_response( @@ -276,9 +239,23 @@ def get_partitioned_dag_runs( dag_timetables_assets[row.target_dag_id][1], row.partition_key, ), + _compute_received_count( + log_by_apdr.get(row.id, {}), + dag_timetables_assets[row.target_dag_id][0], + dag_timetables_assets[row.target_dag_id][2], + row.partition_key, + ), ) for row in rows ] + + asset_expressions: dict[str, dict | None] | None = None + if dag_id.value is None: + dag_rows = session.execute( + select(DagModel.dag_id, DagModel.asset_expression).where(DagModel.dag_id.in_(unique_dag_ids)) + ).all() + asset_expressions = {r.dag_id: r.asset_expression for r in dag_rows} + return PartitionedDagRunCollectionResponse( partitioned_dag_runs=results, total=len(results), From 5506d3681f834ce593da46031c6526169652cf5b Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 5 May 2026 18:28:34 +0800 Subject: [PATCH 12/16] feat: write audit log when rollup mapper evaluation fails --- .../src/airflow/jobs/scheduler_job_runner.py | 75 ++++++++++++++----- 
.../tests/unit/jobs/test_scheduler_job.py | 15 +++- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 7ade6e22036c4..cd9dd7c6430e2 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -127,10 +127,22 @@ TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule" """:meta private:""" -# Per-tick cap on pending ``AssetPartitionDagRun`` rows the scheduler will evaluate. -# Bounds the transaction size so other scheduling work isn't starved; remaining -# rows drain across subsequent ticks. MAX_PARTITION_DAG_RUNS_PER_TICK = 500 +""" +Per-tick cap on pending :class:`~airflow.models.asset.AssetPartitionDagRun` +rows the scheduler will evaluate when creating partition-driven Dag runs. + +A single tick reads this many APDRs in FIFO order, bulk-loads the +serialized Dags + per-asset event logs, and evaluates each. Bounding the +per-tick batch keeps the transaction short so executor heartbeats and +regular scheduling work aren't starved on busy deployments. Remaining +APDRs drain across subsequent ticks; the cap therefore sets the +*steady-state throughput*, not the maximum backlog. + +500 was chosen as a default that keeps the per-tick critical section +comfortably under one second on typical hardware while still draining a +modest backlog within a few ticks. +""" def _eager_load_dag_run_for_validation() -> tuple[LoaderOption, LoaderOption]: @@ -1853,30 +1865,38 @@ def _check_rollup_asset_status( def _resolve_asset_partition_status( self, *, + session: Session, asset_id: int, name: str, uri: str, apdr: AssetPartitionDagRun, timetable: PartitionedAssetTimetable, actual_by_asset: dict[int, set[str]], - ) -> bool | None: + ) -> bool: """ - Return the rollup status for one asset within a pending partitioned Dag run. 
+ Return whether *asset_id* has been satisfied for *apdr*. + + Non-rollup assets resolve to ``True`` because the caller only invokes + this for assets that already have at least one logged event for *APDR* + (see :class:`~airflow.models.asset.PartitionedAssetKeyLog`), which is + the non-rollup contract for "received". Rollup assets defer to + :meth:`_check_rollup_asset_status` for the upstream-window check. - Returns *True*/*False* for rollup assets, or *None* when the asset has no - rollup mapper and should default to satisfied. + A misconfigured mapper that raises returns ``False`` (treated as + not-yet-satisfied) and an audit log entry is written so the operator + can see why the Dag run is being held in the UI. """ try: mapper = timetable.get_partition_mapper(name=name, uri=uri) if not mapper.is_rollup: - return None + return True return self._check_rollup_asset_status( asset_id=asset_id, apdr=apdr, mapper=cast("RollupMapper", mapper), actual_by_asset=actual_by_asset, ) - except Exception: + except Exception as err: self.log.exception( "Failed to evaluate rollup status for asset; treating as not-yet-satisfied. " "This likely indicates a misconfigured partition mapper.", @@ -1885,11 +1905,32 @@ def _resolve_asset_partition_status( asset_name=name, asset_uri=uri, ) + session.add( + Log( + event="failed to evaluate rollup status", + dag_id=apdr.target_dag_id, + extra=( + "Could not evaluate rollup status for partition_key " + f"'{apdr.partition_key}' on asset (name='{name}', uri='{uri}') " + f"in target Dag '{apdr.target_dag_id}'. This likely indicates " + "that the rollup mapper is misconfigured or does not support " + f"this partition key.\n{type(err).__name__}: {err}" + ), + ) + ) return False def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]: - partition_dag_ids: set[str] = set() - + """ + Create Dag runs for pending :class:`AssetPartitionDagRun` rows whose partition is satisfied. 
+ + Returns the set of ``dag_id`` strings that received a new partition-driven Dag run in this + tick. The caller (:meth:`_create_dagruns_for_dags`) uses this set to exclude the same Dags + from the standard schedule-driven and asset-triggered creation paths so a single Dag never + gets two Dag runs for the same tick when it appears in more than one creation path. We + return ``dag_id`` strings rather than full Dag/DagRun objects because the only downstream + use is membership lookup, and a heavier return type would just be discarded. + """ # Cap per-tick work so the scheduler transaction stays bounded and other # scheduling work isn't starved. Remaining APDRs drain across subsequent ticks. # Note: with strict FIFO ordering, persistently-unsatisfied APDRs at the head @@ -1906,8 +1947,9 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st .limit(MAX_PARTITION_DAG_RUNS_PER_TICK) ).all() if not pending_apdrs: - return partition_dag_ids + return set() + partition_dag_ids: set[str] = set() pending_apdr_ids = [apdr.id for apdr in pending_apdrs] # Pre-fetch all required serialized Dags in one query. 
@@ -1954,7 +1996,8 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st for asset_id, (name, uri) in asset_info_per_apdr[apdr.id].items(): key = SerializedAssetUniqueKey(name=name, uri=uri) if isinstance(timetable, PartitionedAssetTimetable): - status = self._resolve_asset_partition_status( + statuses[key] = self._resolve_asset_partition_status( + session=session, asset_id=asset_id, name=name, uri=uri, @@ -1962,10 +2005,8 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st timetable=timetable, actual_by_asset=source_key_by_asset, ) - if status is not None: - statuses[key] = status - continue - statuses[key] = True + else: + statuses[key] = True if not evaluator.run(timetable.asset_condition, statuses=statuses): continue diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index ff15f91485152..4b4c72bb03484 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -9995,9 +9995,10 @@ def test_partitioned_dag_run_rollup_treats_mapper_exception_as_not_satisfied( ): """ A misconfigured rollup mapper that raises during status evaluation must not crash - the scheduler tick — the asset is treated as not-yet-satisfied and the APDR - remains pending. + the scheduler tick — the asset is treated as not-yet-satisfied, the APDR remains + pending, and an audit log entry surfaces the reason to the operator. """ + session.execute(delete(Log)) asset_1 = Asset(name="asset-1") with dag_maker( dag_id="rollup-consumer", @@ -10037,6 +10038,16 @@ def test_partitioned_dag_run_rollup_treats_mapper_exception_as_not_satisfied( assert apdr.created_dag_run_id is None assert partition_dags == set() + # Audit log is added via session.add inside the scheduler tick and only + # visible to a subsequent read after a flush. 
+ session.flush() + audit_log = session.scalar(select(Log).where(Log.event == "failed to evaluate rollup status")) + assert audit_log is not None + assert audit_log.dag_id == "rollup-consumer" + assert "misconfigured rollup mapper" in audit_log.extra + assert "asset-1" in audit_log.extra + assert "2024-01-01T00" in audit_log.extra + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") From 846c0d8379af48fe2878fe299f30340c925f2169 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 5 May 2026 19:12:14 +0800 Subject: [PATCH 13/16] refactor: compile temporal mapper key pattern eagerly in __init__ --- .../src/airflow/partition_mappers/temporal.py | 29 +++++++------------ .../ui/openapi-gen/requests/schemas.gen.ts | 3 +- .../ui/openapi-gen/requests/types.gen.ts | 3 -- .../components/AssetExpression/AssetNode.tsx | 3 +- .../tests/unit/jobs/test_scheduler_job.py | 1 + 5 files changed, 13 insertions(+), 26 deletions(-) diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py index 5101f982cf2fc..089eb358c398c 100644 --- a/airflow-core/src/airflow/partition_mappers/temporal.py +++ b/airflow-core/src/airflow/partition_mappers/temporal.py @@ -19,7 +19,6 @@ import re from abc import ABC, abstractmethod from datetime import datetime, timedelta -from functools import cached_property from typing import TYPE_CHECKING, Any from airflow._shared.timezones.timezone import make_aware, parse_timezone @@ -215,21 +214,17 @@ def __init__( raise ValueError(f"week_start must be between 0 (Monday) and 6 (Sunday), got {week_start!r}") super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) self.week_start = week_start # 0 = Monday (ISO default), 6 = Sunday - # Force-compile so a malformed output_format raises ValueError at - # construction time instead of an opaque re.error inside the scheduler. 
- _ = self._key_pattern + # %V (ISO week) cannot be round-tripped through strptime without %G+%u, + # so derive a named-group regex from output_format and pull out %Y/%m/%d. + # Compile eagerly so a malformed output_format raises ValueError here + # instead of an opaque re.error inside the scheduler. + self._key_pattern = _compile_output_format_regex(self.output_format) def normalize(self, dt: datetime) -> datetime: days_since_start = (dt.weekday() - self.week_start) % 7 start = dt - timedelta(days=days_since_start) return start.replace(hour=0, minute=0, second=0, microsecond=0) - @cached_property - def _key_pattern(self) -> re.Pattern[str]: - # %V (ISO week) cannot be round-tripped through strptime without %G+%u, - # so derive a named-group regex from output_format and pull out %Y/%m/%d. - return _compile_output_format_regex(self.output_format) - def decode_downstream(self, downstream_key: str) -> datetime: match = self._key_pattern.search(downstream_key) if match is None or not {"Y", "m", "d"}.issubset(match.groupdict()): @@ -311,9 +306,11 @@ def __init__( output_format: str | None = None, ) -> None: super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) - # Force-compile so a malformed output_format raises ValueError at - # construction time instead of an opaque re.error inside the scheduler. - _ = self._key_pattern + # ``{quarter}`` is a Python-format placeholder, not a strftime directive, + # so derive a named-group regex from output_format that handles both. + # Compile eagerly so a malformed output_format raises ValueError here + # instead of an opaque re.error inside the scheduler. 
+ self._key_pattern = _compile_output_format_regex(self.output_format, {"quarter": r"[1-4]"}) def normalize(self, dt: datetime) -> datetime: quarter = (dt.month - 1) // 3 @@ -331,12 +328,6 @@ def format(self, dt: datetime) -> str: quarter = (dt.month - 1) // 3 + 1 return dt.strftime(self.output_format).format(quarter=quarter) - @cached_property - def _key_pattern(self) -> re.Pattern[str]: - # ``{quarter}`` is a Python-format placeholder, not a strftime directive, - # so derive a named-group regex from output_format that handles both. - return _compile_output_format_regex(self.output_format, {"quarter": r"[1-4]"}) - def decode_downstream(self, downstream_key: str) -> datetime: match = self._key_pattern.search(downstream_key) if match is None or not {"Y", "quarter"}.issubset(match.groupdict()): diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 61589be2ed320..fd5ed96975257 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -8323,8 +8323,7 @@ export const $ExtraMenuItem = { }, type: 'object', required: ['text', 'href'], - title: 'ExtraMenuItem', - description: 'Define a menu item that can be added to the menu by auth managers or plugins.' + title: 'ExtraMenuItem' } as const; export const $GanttResponse = { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 8acc694124bea..57948231797be 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2063,9 +2063,6 @@ export type EdgeResponse = { is_source_asset?: boolean | null; }; -/** - * Define a menu item that can be added to the menu by auth managers or plugins. 
- */ export type ExtraMenuItem = { text: string; href: string; diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx index 678d66760b274..84dfb9a459389 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx @@ -98,8 +98,7 @@ export const AssetNode = ({ // which only the partitioned + pending-APDR branch of `next_run_assets` // populates. Plain non-partitioned events leave `required_keys` empty and // must keep showing their timestamp. - const isPartitionedNonRollup = - event?.is_rollup === false && (event?.required_keys?.length ?? 0) > 0; + const isPartitionedNonRollup = event?.is_rollup === false && (event.required_keys?.length ?? 0) > 0; const showTime = isFullyReceived && !isPartitionedNonRollup; const showRollupChecklist = (event?.is_rollup ?? false) && (event?.required_keys?.length ?? 
0) > 0 && (isPartial || isFullyReceived); diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 4b4c72bb03484..b8a684be29c13 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -10044,6 +10044,7 @@ def test_partitioned_dag_run_rollup_treats_mapper_exception_as_not_satisfied( audit_log = session.scalar(select(Log).where(Log.event == "failed to evaluate rollup status")) assert audit_log is not None assert audit_log.dag_id == "rollup-consumer" + assert audit_log.extra is not None assert "misconfigured rollup mapper" in audit_log.extra assert "asset-1" in audit_log.extra assert "2024-01-01T00" in audit_log.extra From c776f3841360cd9d87565ce7412501f43d7b5776 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 6 May 2026 17:01:56 +0800 Subject: [PATCH 14/16] refactor: use timetable.partitioned property for partition checks --- .../api_fastapi/core_api/routes/ui/partitioned_dag_runs.py | 2 +- airflow-core/src/airflow/jobs/scheduler_job_runner.py | 4 +++- .../src/airflow/ui/openapi-gen/requests/schemas.gen.ts | 3 ++- airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts | 3 +++ 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index 254d799412034..b30570618f977 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -66,7 +66,7 @@ def _load_timetable(dag_id: str, session: Session) -> PartitionedAssetTimetable if serdag is None: return None with suppress(Exception): - if isinstance(serdag.dag.timetable, PartitionedAssetTimetable): + if serdag.dag.timetable.partitioned: return serdag.dag.timetable return None diff --git 
a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index cd9dd7c6430e2..23019c8a3d438 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -1995,7 +1995,9 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st statuses: dict[SerializedAssetUniqueKey, bool] = {} for asset_id, (name, uri) in asset_info_per_apdr[apdr.id].items(): key = SerializedAssetUniqueKey(name=name, uri=uri) - if isinstance(timetable, PartitionedAssetTimetable): + if timetable.partitioned is True: + if TYPE_CHECKING: + assert isinstance(timetable, PartitionedAssetTimetable) statuses[key] = self._resolve_asset_partition_status( session=session, asset_id=asset_id, diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index fd5ed96975257..61589be2ed320 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -8323,7 +8323,8 @@ export const $ExtraMenuItem = { }, type: 'object', required: ['text', 'href'], - title: 'ExtraMenuItem' + title: 'ExtraMenuItem', + description: 'Define a menu item that can be added to the menu by auth managers or plugins.' } as const; export const $GanttResponse = { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 57948231797be..8acc694124bea 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2063,6 +2063,9 @@ export type EdgeResponse = { is_source_asset?: boolean | null; }; +/** + * Define a menu item that can be added to the menu by auth managers or plugins. 
+ */ export type ExtraMenuItem = { text: string; href: string; From e008f8dfd779fdf24954646a1fcc02a5af5a8bbb Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 7 May 2026 13:38:02 +0800 Subject: [PATCH 15/16] refactor: drop week_start and month_start_day from temporal mappers StartOfWeekMapper now always uses ISO weeks (Monday) and StartOfMonthMapper always emits the 1st of the month. Custom fiscal boundaries can still be expressed by pairing a user-defined source mapper with the existing windows. --- .../src/airflow/partition_mappers/temporal.py | 61 +++---------------- .../src/airflow/partition_mappers/window.py | 15 +++-- .../routes/ui/test_partitioned_dag_runs.py | 4 +- .../unit/partition_mappers/test_temporal.py | 51 ++++------------ .../unit/partition_mappers/test_window.py | 25 ++------ .../definitions/partition_mappers/temporal.py | 12 ---- 6 files changed, 34 insertions(+), 134 deletions(-) diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py index 089eb358c398c..1f42b4755302d 100644 --- a/airflow-core/src/airflow/partition_mappers/temporal.py +++ b/airflow-core/src/airflow/partition_mappers/temporal.py @@ -205,15 +205,11 @@ class StartOfWeekMapper(_BaseTemporalMapper): def __init__( self, *, - week_start: int = 0, timezone: str | Timezone | FixedTimezone = "UTC", input_format: str = "%Y-%m-%dT%H:%M:%S", output_format: str | None = None, ) -> None: - if not 0 <= week_start <= 6: - raise ValueError(f"week_start must be between 0 (Monday) and 6 (Sunday), got {week_start!r}") super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) - self.week_start = week_start # 0 = Monday (ISO default), 6 = Sunday # %V (ISO week) cannot be round-tripped through strptime without %G+%u, # so derive a named-group regex from output_format and pull out %Y/%m/%d. 
# Compile eagerly so a malformed output_format raises ValueError here @@ -221,8 +217,7 @@ def __init__( self._key_pattern = _compile_output_format_regex(self.output_format) def normalize(self, dt: datetime) -> datetime: - days_since_start = (dt.weekday() - self.week_start) % 7 - start = dt - timedelta(days=days_since_start) + start = dt - timedelta(days=dt.weekday()) return start.replace(hour=0, minute=0, second=0, microsecond=0) def decode_downstream(self, downstream_key: str) -> datetime: @@ -235,61 +230,19 @@ def decode_downstream(self, downstream_key: str) -> datetime: ) return datetime(int(match["Y"]), int(match["m"]), int(match["d"])) - def serialize(self) -> dict[str, Any]: - return {**super().serialize(), "week_start": self.week_start} - - @classmethod - def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: - return cls( - week_start=data.get("week_start", 0), - timezone=parse_timezone(data.get("timezone", "UTC")), - input_format=data["input_format"], - output_format=data["output_format"], - ) - class StartOfMonthMapper(_BaseTemporalMapper): """Map a time-based partition key to the start of its month.""" default_output_format = "%Y-%m" - def __init__( - self, - *, - month_start_day: int = 1, - timezone: str | Timezone | FixedTimezone = "UTC", - input_format: str = "%Y-%m-%dT%H:%M:%S", - output_format: str | None = None, - ) -> None: - if not 1 <= month_start_day <= 28: - raise ValueError(f"month_start_day must be between 1 and 28, got {month_start_day!r}") - super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) - self.month_start_day = month_start_day # 1–28; use >1 for fiscal-month offsets - def normalize(self, dt: datetime) -> datetime: - if dt.day < self.month_start_day: - month = dt.month - 1 or 12 - year = dt.year - (1 if dt.month == 1 else 0) - start = dt.replace(year=year, month=month, day=self.month_start_day) - else: - start = dt.replace(day=self.month_start_day) - return start.replace(hour=0, 
minute=0, second=0, microsecond=0) - - def decode_downstream(self, downstream_key: str) -> datetime: - # The default strptime returns day=1; pin to month_start_day so fiscal - # months recover the correct period start. - return super().decode_downstream(downstream_key).replace(day=self.month_start_day) - - def serialize(self) -> dict[str, Any]: - return {**super().serialize(), "month_start_day": self.month_start_day} - - @classmethod - def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: - return cls( - month_start_day=data.get("month_start_day", 1), - timezone=parse_timezone(data.get("timezone", "UTC")), - input_format=data["input_format"], - output_format=data["output_format"], + return dt.replace( + day=1, + hour=0, + minute=0, + second=0, + microsecond=0, ) diff --git a/airflow-core/src/airflow/partition_mappers/window.py b/airflow-core/src/airflow/partition_mappers/window.py index ea2b2677150db..f47bcb386b48a 100644 --- a/airflow-core/src/airflow/partition_mappers/window.py +++ b/airflow-core/src/airflow/partition_mappers/window.py @@ -28,11 +28,10 @@ def _shift_months(dt: datetime, months: int) -> datetime: """ Return *dt* shifted forward by *months*, wrapping the year as needed. - Built-in temporal mappers always emit period-starts on day 1 (or - ``month_start_day`` <= 28 for fiscal months), so :meth:`datetime.replace` - on the new month is always valid. A user-provided source mapper that - emits a higher day for a month with fewer days remains the caller's - responsibility. + Built-in temporal mappers always emit period-starts on day 1, so + :meth:`datetime.replace` on the new month is always valid. A + user-provided source mapper that emits a higher day for a month with + fewer days remains the caller's responsibility. """ total = dt.month - 1 + months return dt.replace(year=dt.year + total // 12, month=total % 12 + 1) @@ -93,9 +92,9 @@ class MonthWindow(Window): """ All daily period-starts making up one calendar month. 
- The source mapper's ``decode_downstream`` already accounts for fiscal - ``month_start_day``, so this just iterates from the period start to the - matching day of the next month. + Iterates from the period start to the matching day of the next month, + so a user-provided source mapper that emits non-1st period-starts + (e.g. fiscal months) is handled transparently. """ def to_upstream(self, period_start: datetime) -> Iterable[datetime]: diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py index 72c05a68d20ea..10ca5304961eb 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py @@ -324,7 +324,7 @@ def test_rollup_mapper_counts_received_upstream_keys(self, test_client, dag_make """For a rollup mapper, only upstream keys in to_upstream() are counted.""" asset_def = Asset(uri="s3://bucket/daily", name="daily") mapper = RollupMapper( - source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0), + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d"), window=WeekWindow(), ) with dag_maker( @@ -521,7 +521,7 @@ def test_is_rollup_true_for_rollup_asset(self, test_client, dag_maker, session): """is_rollup is True for assets that use a RollupMapper, and keys are populated.""" asset_def = Asset(uri="s3://bucket/weekly", name="weekly") mapper = RollupMapper( - source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0), + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d"), window=WeekWindow(), ) with dag_maker( diff --git a/airflow-core/tests/unit/partition_mappers/test_temporal.py b/airflow-core/tests/unit/partition_mappers/test_temporal.py index 60d6ffc2a0352..6d61c8b38e72c 100644 --- 
a/airflow-core/tests/unit/partition_mappers/test_temporal.py +++ b/airflow-core/tests/unit/partition_mappers/test_temporal.py @@ -57,34 +57,29 @@ def test_to_downstream( ], ) @pytest.mark.parametrize( - ("mapper_cls", "expected_outut_format", "extra_kwargs"), + ("mapper_cls", "expected_outut_format"), [ - (StartOfHourMapper, "%Y-%m-%dT%H", {}), - (StartOfDayMapper, "%Y-%m-%d", {}), - (StartOfWeekMapper, "%Y-%m-%d (W%V)", {"week_start": 0}), - (StartOfMonthMapper, "%Y-%m", {"month_start_day": 1}), - (StartOfQuarterMapper, "%Y-Q{quarter}", {}), - (StartOfYearMapper, "%Y", {}), + (StartOfHourMapper, "%Y-%m-%dT%H"), + (StartOfDayMapper, "%Y-%m-%d"), + (StartOfWeekMapper, "%Y-%m-%d (W%V)"), + (StartOfMonthMapper, "%Y-%m"), + (StartOfQuarterMapper, "%Y-Q{quarter}"), + (StartOfYearMapper, "%Y"), ], ) def test_serialize( self, mapper_cls: type[_BaseTemporalMapper], expected_outut_format: str, - extra_kwargs: dict[str, int], timezone: str | None, expected_timezone: str, ): pm = mapper_cls() if timezone is None else mapper_cls(timezone=timezone) - assert ( - pm.serialize() - == { - "timezone": expected_timezone, - "input_format": "%Y-%m-%dT%H:%M:%S", - "output_format": expected_outut_format, - } - | extra_kwargs - ) + assert pm.serialize() == { + "timezone": expected_timezone, + "input_format": "%Y-%m-%dT%H:%M:%S", + "output_format": expected_outut_format, + } @pytest.mark.parametrize( "mapper_cls", @@ -132,28 +127,6 @@ def test_to_downstream_input_timezone_differs_from_mapper_timezone(self): assert pm.to_downstream("2026-02-11T06:00:00+0000") == "2026-02-11" -class TestStartOfWeekMapperValidation: - @pytest.mark.parametrize("week_start", [-1, 7, 100]) - def test_rejects_out_of_range(self, week_start): - with pytest.raises(ValueError, match="week_start must be between 0"): - StartOfWeekMapper(week_start=week_start) - - @pytest.mark.parametrize("week_start", [0, 3, 6]) - def test_accepts_valid_range(self, week_start): - assert 
StartOfWeekMapper(week_start=week_start).week_start == week_start - - -class TestStartOfMonthMapperValidation: - @pytest.mark.parametrize("month_start_day", [0, 29, 31, -1]) - def test_rejects_out_of_range(self, month_start_day): - with pytest.raises(ValueError, match="month_start_day must be between 1 and 28"): - StartOfMonthMapper(month_start_day=month_start_day) - - @pytest.mark.parametrize("month_start_day", [1, 15, 28]) - def test_accepts_valid_range(self, month_start_day): - assert StartOfMonthMapper(month_start_day=month_start_day).month_start_day == month_start_day - - class TestOutputFormatValidation: """Malformed output_format must fail fast at mapper construction, not as an opaque re.error inside the scheduler.""" diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py b/airflow-core/tests/unit/partition_mappers/test_window.py index 0faaa2bd92f4d..f3b40caf685a3 100644 --- a/airflow-core/tests/unit/partition_mappers/test_window.py +++ b/airflow-core/tests/unit/partition_mappers/test_window.py @@ -24,7 +24,6 @@ from airflow.partition_mappers.temporal import ( StartOfDayMapper, StartOfHourMapper, - StartOfMonthMapper, StartOfQuarterMapper, StartOfWeekMapper, StartOfYearMapper, @@ -80,8 +79,9 @@ def test_yields_all_days_in_february_leap_year(self): assert members[0] == datetime(2024, 2, 1) assert members[-1] == datetime(2024, 2, 29) - def test_honours_fiscal_month_offset(self): - # 2024-06-15 → 2024-07-14 covers 30 days when month_start_day=15. + def test_honours_non_calendar_month_offset(self): + # A custom source mapper might decode a period start mid-month; + # the window must walk forward to the same day in the next month. 
members = list(MonthWindow().to_upstream(datetime(2024, 6, 15))) assert len(members) == 30 assert members[0] == datetime(2024, 6, 15) @@ -124,7 +124,7 @@ def test_wraps_year_for_fiscal_year_start(self): class TestRollupMapperComposition: def test_to_downstream_delegates_to_source_mapper(self): mapper = RollupMapper( - source_mapper=StartOfWeekMapper(week_start=0), + source_mapper=StartOfWeekMapper(), window=WeekWindow(), ) # Wednesday 2024-06-12 belongs to the week starting Monday 2024-06-10. @@ -165,7 +165,7 @@ def test_day_rollup_honours_timezone_in_encode(self): def test_week_rollup_with_default_format(self): mapper = RollupMapper( - source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d", week_start=0), + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d"), window=WeekWindow(), ) upstream = mapper.to_upstream("2024-06-10") @@ -187,7 +187,6 @@ def test_week_rollup_accepts_custom_output_format(self): source_mapper=StartOfWeekMapper( input_format="%Y-%m-%d", output_format="Week-of-%Y-%m-%d", - week_start=0, ), window=WeekWindow(), ) @@ -203,17 +202,6 @@ def test_week_rollup_raises_when_downstream_key_has_no_date(self): with pytest.raises(ValueError, match="StartOfWeekMapper.decode_downstream could not parse"): mapper.to_upstream("week-24") - def test_month_rollup_fiscal_offset(self): - mapper = RollupMapper( - source_mapper=StartOfMonthMapper( - input_format="%Y-%m-%d", output_format="%Y-%m", month_start_day=15 - ), - window=MonthWindow(), - ) - upstream = mapper.to_upstream("2024-06") - assert "2024-06-15" in upstream - assert "2024-07-14" in upstream - def test_quarter_rollup_to_upstream_keys(self): mapper = RollupMapper( source_mapper=StartOfQuarterMapper(input_format="%Y-%m"), @@ -250,13 +238,12 @@ def test_serialize_round_trip(self): from airflow.serialization.encoders import encode_partition_mapper mapper = RollupMapper( - source_mapper=StartOfWeekMapper(week_start=0, input_format="%Y-%m-%d", 
output_format="%Y-%m-%d"), + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d"), window=WeekWindow(), ) restored = decode_partition_mapper(encode_partition_mapper(mapper)) assert isinstance(restored, RollupMapper) assert isinstance(restored.source_mapper, StartOfWeekMapper) - assert restored.source_mapper.week_start == 0 assert isinstance(restored.window, WeekWindow) assert restored.to_upstream("2024-06-10") == mapper.to_upstream("2024-06-10") diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py index 030d0fd32cb60..eb1fdd049f598 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py @@ -48,24 +48,12 @@ class StartOfWeekMapper(_BaseTemporalMapper): default_output_format = "%Y-%m-%d (W%V)" - def __init__(self, *, week_start: int = 0, **kwargs) -> None: - if not 0 <= week_start <= 6: - raise ValueError(f"week_start must be between 0 (Monday) and 6 (Sunday), got {week_start!r}") - super().__init__(**kwargs) - self.week_start = week_start - class StartOfMonthMapper(_BaseTemporalMapper): """Map a time-based partition key to the start of its month.""" default_output_format = "%Y-%m" - def __init__(self, *, month_start_day: int = 1, **kwargs) -> None: - if not 1 <= month_start_day <= 28: - raise ValueError(f"month_start_day must be between 1 and 28, got {month_start_day!r}") - super().__init__(**kwargs) - self.month_start_day = month_start_day - class StartOfQuarterMapper(_BaseTemporalMapper): """Map a time-based partition key to quarter.""" From 8243d0ec13a504924314e02ceaed739b5adbe40c Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 6 May 2026 21:22:34 +0800 Subject: [PATCH 16/16] perf(ui): avoid timetable deserialization in partitioned Dag UI routes The next_run_assets and partitioned_dag_runs endpoints used to load and deserialize the full 
timetable on every request just to read mapper attributes (is_rollup) and required-key counts. Cache mapper metadata per asset on DagModel during Dag sync via a new ``partition_mapper_info`` JSON column, so the UI resolves mapper attributes from the cache and only loads the timetable when ``to_upstream`` evaluation for rollup mappers is actually needed. --- airflow-core/docs/migrations-ref.rst | 4 +- .../api_fastapi/core_api/routes/ui/assets.py | 44 ++--- .../routes/ui/partitioned_dag_runs.py | 181 +++++++++++------- .../src/airflow/dag_processing/collection.py | 1 + ..._3_3_0_add_partition_mapper_info_to_dag.py | 73 +++++++ airflow-core/src/airflow/models/dag.py | 34 +++- airflow-core/src/airflow/timetables/base.py | 31 ++- airflow-core/src/airflow/timetables/simple.py | 26 ++- airflow-core/src/airflow/utils/db.py | 2 +- .../unit/dag_processing/test_collection.py | 64 +++++++ airflow-core/tests/unit/models/test_dag.py | 66 +++++++ .../timetables/test_partitioned_timetable.py | 41 ++++ .../in_container/run_migration_round_trip.py | 1 + 13 files changed, 470 insertions(+), 98 deletions(-) create mode 100644 airflow-core/src/airflow/migrations/versions/0115_3_3_0_add_partition_mapper_info_to_dag.py diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index a9502d28fe944..3cff5e5449ce4 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``a7f3b2c1d4e5`` (head) | ``b8f3e4a1d2c9`` | ``3.3.0`` | Add allow_producer_teams column to | +| ``c20871fbf23a`` (head) | ``a7f3b2c1d4e5`` | 
``3.3.0`` | Add partition_mapper_info to DagModel. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``a7f3b2c1d4e5`` | ``b8f3e4a1d2c9`` | ``3.3.0`` | Add allow_producer_teams column to | | | | | dag_schedule_asset_reference table. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``b8f3e4a1d2c9`` | ``fde9ed84d07b`` | ``3.3.0`` | Add retry_delay_override and retry_reason to task_instance. | diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py index 43684bbee354b..80d7cdad9e1c0 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py @@ -136,7 +136,6 @@ def next_run_assets( return NextRunAssetsResponse(asset_expression=dag_model.asset_expression, events=events) # Partitioned Dags: enrich with per-asset received/required counts and rollup flag. - timetable = _load_timetable(dag_id, session) pending_apdr = session.execute( select(AssetPartitionDagRun.id, AssetPartitionDagRun.partition_key) .where( @@ -147,25 +146,22 @@ def next_run_assets( .limit(1) ).one_or_none() + has_rollup_mappers = dag_model.has_rollup_mappers + if pending_apdr is None: # No pending APDR yet — mark rollup assets so the UI can handle them # correctly (e.g. skip "Asset Triggered" in favour of the asset name view). 
- events = [] - for row in raw_rows: - is_rollup = False - if timetable is not None: - with suppress(Exception): - mapper = timetable.get_partition_mapper(name=row.name, uri=row.uri) - is_rollup = mapper.is_rollup - events.append( - NextRunAssetEventResponse( - id=row.id, - name=row.name, - uri=row.uri, - last_update=row.last_update if row.queued else None, - is_rollup=is_rollup, - ) + # Reads from the cached partition_mapper_info so no timetable load is needed. + events = [ + NextRunAssetEventResponse( + id=row.id, + name=row.name, + uri=row.uri, + last_update=row.last_update if row.queued else None, + is_rollup=has_rollup_mappers and dag_model.is_rollup_asset(name=row.name, uri=row.uri), ) + for row in raw_rows + ] return NextRunAssetsResponse( asset_expression=dag_model.asset_expression, events=events, @@ -183,19 +179,19 @@ def next_run_assets( ): received_keys_by_asset.setdefault(log_row.asset_id, set()).add(log_row.source_partition_key or "") + # The timetable is only needed to call ``to_upstream`` for rollup mappers. + # When the cached info shows no rollup mappers, skip loading it entirely. 
+ rollup_timetable = _load_timetable(dag_id, session) if has_rollup_mappers else None + events = [] for row in raw_rows: received_keys = sorted(received_keys_by_asset.get(row.id, set())) required_keys: list[str] = [pending_apdr.partition_key] - is_rollup = False - if timetable is not None: + is_rollup = has_rollup_mappers and dag_model.is_rollup_asset(name=row.name, uri=row.uri) + if is_rollup and rollup_timetable is not None: with suppress(Exception): - mapper = timetable.get_partition_mapper(name=row.name, uri=row.uri) - if mapper.is_rollup: - required_keys = sorted( - cast("RollupMapper", mapper).to_upstream(pending_apdr.partition_key) - ) - is_rollup = True + mapper = rollup_timetable.get_partition_mapper(name=row.name, uri=row.uri) + required_keys = sorted(cast("RollupMapper", mapper).to_upstream(pending_apdr.partition_key)) received_count = len(received_keys) required_count = len(required_keys) # Only surface last_update once all required upstream keys have arrived. diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index b30570618f977..e90c5e81eda7f 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -61,63 +61,99 @@ def _load_timetable(dag_id: str, session: Session) -> PartitionedAssetTimetable | None: - """Return the PartitionedAssetTimetable for *dag_id*, or None if absent or not partitioned.""" + """ + Return the PartitionedAssetTimetable for *dag_id*, or None if absent or not partitioned. + + Callers gate this behind ``DagModel.has_rollup_mappers``, which is only + populated for ``PartitionedAssetTimetable``. The ``TYPE_CHECKING`` assert + narrows the type for mypy without a runtime ``isinstance`` cost. 
+ """ serdag = SerializedDagModel.get(dag_id=dag_id, session=session) if serdag is None: return None with suppress(Exception): if serdag.dag.timetable.partitioned: + if TYPE_CHECKING: + assert isinstance(serdag.dag.timetable, PartitionedAssetTimetable) return serdag.dag.timetable return None -def _load_timetable_and_assets( - dag_id: str, session: Session -) -> tuple[PartitionedAssetTimetable | None, list[AssetNameUri], dict[int, AssetNameUri]]: +def _fetch_active_assets_per_dag( + dag_ids: list[str], session: Session +) -> dict[str, tuple[list[AssetNameUri], dict[int, AssetNameUri]]]: """ - Load timetable and active required assets. + Batch-fetch active required assets for multiple Dags in a single query. - Returns (timetable, [(name, uri), ...], {asset_id: (name, uri)}). + Returns ``{dag_id: ([(name, uri), ...], {asset_id: (name, uri)})}``. + Dags with no active references are still included with empty containers + so callers can index by ``dag_id`` without ``KeyError``. """ - timetable = _load_timetable(dag_id, session) - asset_rows = session.execute( - select(AssetModel.id, AssetModel.name, AssetModel.uri) + rows = session.execute( + select( + DagScheduleAssetReference.dag_id, + AssetModel.id, + AssetModel.name, + AssetModel.uri, + ) .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) - .where(DagScheduleAssetReference.dag_id == dag_id, AssetModel.active.has()) + .where(DagScheduleAssetReference.dag_id.in_(dag_ids), AssetModel.active.has()) ).all() - asset_info = [(r.name, r.uri) for r in asset_rows] - asset_id_to_info = {r.id: (r.name, r.uri) for r in asset_rows} - return timetable, asset_info, asset_id_to_info + result: dict[str, tuple[list[AssetNameUri], dict[int, AssetNameUri]]] = { + dag_id: ([], {}) for dag_id in dag_ids + } + for row in rows: + info, id_to_info = result[row.dag_id] + info.append((row.name, row.uri)) + id_to_info[row.id] = (row.name, row.uri) + return result + + +def _resolve_rollup_upstream_keys( + 
dag_model: DagModel | None, + rollup_timetable: PartitionedAssetTimetable | None, + name: str, + uri: str, + partition_key: str, +) -> frozenset[str] | None: + """ + Return the upstream rollup keys for an asset, or ``None`` for non-rollup. + + Returns ``None`` when the asset is not rollup, when ``dag_model`` is missing + (Dag was deleted while runs still reference it), when ``rollup_timetable`` + couldn't be loaded, or when the mapper raises — callers fall back to + non-rollup behaviour for that asset rather than 500-ing the request. + """ + if dag_model is None or rollup_timetable is None or not dag_model.is_rollup_asset(name=name, uri=uri): + return None + with suppress(Exception): + mapper = rollup_timetable.get_partition_mapper(name=name, uri=uri) + return frozenset(cast("RollupMapper", mapper).to_upstream(partition_key)) + return None def _compute_total_required( - timetable: PartitionedAssetTimetable | None, + dag_model: DagModel | None, + rollup_timetable: PartitionedAssetTimetable | None, asset_info: list[AssetNameUri], partition_key: str, ) -> int: """ Sum required upstream events across all assets, using to_upstream for rollup mappers. - A misconfigured custom mapper that raises in ``to_upstream`` falls back to the - non-rollup count (1 per asset) for that asset rather than failing the whole - request — matches the suppression already used in the detail route. + Non-rollup assets (and rollup assets whose mapper raises) count as 1. 
""" - if timetable is None: - return len(asset_info) total = 0 for name, uri in asset_info: - count = 1 - with suppress(Exception): - mapper = timetable.get_partition_mapper(name=name, uri=uri) - if mapper.is_rollup: - count = len(cast("RollupMapper", mapper).to_upstream(partition_key)) - total += count + keys = _resolve_rollup_upstream_keys(dag_model, rollup_timetable, name, uri, partition_key) + total += len(keys) if keys is not None else 1 return total def _compute_received_count( + dag_model: DagModel | None, received_by_asset: dict[int, set[str]], - timetable: PartitionedAssetTimetable | None, + rollup_timetable: PartitionedAssetTimetable | None, asset_id_to_info: dict[int, AssetNameUri], partition_key: str, ) -> int: @@ -125,10 +161,8 @@ def _compute_received_count( Count received events using rollup-aware deduplication. For rollup assets: count distinct upstream keys that intersect the required set. - For non-rollup assets: count 1 per asset if any event has been logged — the - source_partition_key value is irrelevant; having any event satisfies the requirement. - A misconfigured rollup mapper falls back to the non-rollup behavior for that - asset so the route doesn't 500. + For non-rollup assets (or rollup mapper failures): any logged event satisfies + the asset — the source_partition_key value is irrelevant. """ total = 0 for asset_id, received_keys in received_by_asset.items(): @@ -137,16 +171,12 @@ def _compute_received_count( # because it batch-fetches logs without a per-Dag asset_id filter. if asset_id not in asset_id_to_info: continue - if timetable is not None: - with suppress(Exception): - name, uri = asset_id_to_info[asset_id] - mapper = timetable.get_partition_mapper(name=name, uri=uri) - if mapper.is_rollup: - required_keys = frozenset(cast("RollupMapper", mapper).to_upstream(partition_key)) - total += len(received_keys & required_keys) - continue - # Non-rollup (or rollup mapper raised): any logged event satisfies this asset. 
- total += 1 if received_keys else 0 + name, uri = asset_id_to_info[asset_id] + keys = _resolve_rollup_upstream_keys(dag_model, rollup_timetable, name, uri, partition_key) + if keys is not None: + total += len(received_keys & keys) + else: + total += 1 if received_keys else 0 return total @@ -207,16 +237,27 @@ def get_partitioned_dag_runs( raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id.value} was not found") return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) - # Load per-Dag timetable + required assets and batch-fetch the log entries - # for every APDR in a single query, so total_received uses the rollup-aware - # Python computation uniformly across single-Dag and global views. The - # alternative (a SQL count subquery) cannot honour rollup windows without - # also running the mapper, and silently divergent semantics between branches - # are worse than paying the per-Dag timetable load cost. + # Batch-fetch DagModels (for cached partition_mapper_info), required assets, + # and APDR log entries in three single queries instead of N per-Dag queries. + # Timetables are only loaded for Dags that actually have rollup mappers, + # since that's the only case where ``to_upstream`` evaluation is needed. + # A SQL count subquery for total_received cannot honour rollup windows + # without running the mapper, so the rollup-aware Python computation runs + # uniformly across single-Dag and global views. 
unique_dag_ids = list({row.target_dag_id for row in rows}) - dag_timetables_assets: dict[ - str, tuple[PartitionedAssetTimetable | None, list[AssetNameUri], dict[int, AssetNameUri]] - ] = {did: _load_timetable_and_assets(did, session) for did in unique_dag_ids} + dag_models: dict[str, DagModel] = { + dm.dag_id: dm + for dm in session.scalars(select(DagModel).where(DagModel.dag_id.in_(unique_dag_ids))).all() + } + assets_by_dag = _fetch_active_assets_per_dag(unique_dag_ids, session) + rollup_timetables_by_dag: dict[str, PartitionedAssetTimetable | None] = { + did: ( + _load_timetable(did, session) + if (dm := dag_models.get(did)) is not None and dm.has_rollup_mappers + else None + ) + for did in unique_dag_ids + } apdr_ids = [row.id for row in rows] log_by_apdr: dict[int, dict[int, set[str]]] = {} @@ -235,14 +276,16 @@ def get_partitioned_dag_runs( _build_response( row, _compute_total_required( - dag_timetables_assets[row.target_dag_id][0], - dag_timetables_assets[row.target_dag_id][1], + dag_models.get(row.target_dag_id), + rollup_timetables_by_dag[row.target_dag_id], + assets_by_dag[row.target_dag_id][0], row.partition_key, ), _compute_received_count( + dag_models.get(row.target_dag_id), log_by_apdr.get(row.id, {}), - dag_timetables_assets[row.target_dag_id][0], - dag_timetables_assets[row.target_dag_id][2], + rollup_timetables_by_dag[row.target_dag_id], + assets_by_dag[row.target_dag_id][1], row.partition_key, ), ) @@ -251,10 +294,7 @@ def get_partitioned_dag_runs( asset_expressions: dict[str, dict | None] | None = None if dag_id.value is None: - dag_rows = session.execute( - select(DagModel.dag_id, DagModel.asset_expression).where(DagModel.dag_id.in_(unique_dag_ids)) - ).all() - asset_expressions = {r.dag_id: r.asset_expression for r in dag_rows} + asset_expressions = {dm.dag_id: dm.asset_expression for dm in dag_models.values()} return PartitionedDagRunCollectionResponse( partitioned_dag_runs=results, @@ -307,34 +347,37 @@ def get_pending_partitioned_dag_run( 
): received_keys_by_asset.setdefault(row.asset_id, set()).add(row.source_partition_key or "") - asset_expression_subq = ( - select(DagModel.asset_expression).where(DagModel.dag_id == dag_id).scalar_subquery() - ) + dag_model = session.get(DagModel, dag_id) asset_rows = session.execute( select( AssetModel.id, AssetModel.uri, AssetModel.name, - asset_expression_subq.label("asset_expression"), ) .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) .where(DagScheduleAssetReference.dag_id == dag_id, AssetModel.active.has()) .order_by(AssetModel.uri) ).all() - timetable = _load_timetable(dag_id, session) + # Skip the timetable load when no rollup mapper is configured — the cached + # ``partition_mapper_info`` already tells us whether we will need + # ``to_upstream`` evaluation, which is the only thing the timetable adds here. + has_rollup_mappers = dag_model is not None and dag_model.has_rollup_mappers + rollup_timetable = _load_timetable(dag_id, session) if has_rollup_mappers else None assets = [] for asset_row in asset_rows: received_keys = sorted(received_keys_by_asset.get(asset_row.id, set())) required_keys: list[str] = [partition_key] - is_rollup = False - if timetable is not None: + is_rollup = ( + has_rollup_mappers + and dag_model is not None + and dag_model.is_rollup_asset(name=asset_row.name, uri=asset_row.uri) + ) + if is_rollup and rollup_timetable is not None: with suppress(Exception): - mapper = timetable.get_partition_mapper(name=row.name, uri=row.uri) - if mapper.is_rollup: - required_keys = sorted(cast("RollupMapper", mapper).to_upstream(partition_key)) - is_rollup = True + mapper = rollup_timetable.get_partition_mapper(name=asset_row.name, uri=asset_row.uri) + required_keys = sorted(cast("RollupMapper", mapper).to_upstream(partition_key)) received_count = len(received_keys) required_count = len(required_keys) assets.append( @@ -353,7 +396,7 @@ def get_pending_partitioned_dag_run( total_received = sum(a.received_count for 
a in assets) total_required = sum(a.required_count for a in assets) - asset_expression = asset_rows[0].asset_expression if asset_rows else None + asset_expression = dag_model.asset_expression if dag_model is not None else None return PartitionedDagRunDetailResponse( id=partitioned_dag_run.id, diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py index 6f8bc752bbc4f..2e780825cd577 100644 --- a/airflow-core/src/airflow/dag_processing/collection.py +++ b/airflow-core/src/airflow/dag_processing/collection.py @@ -622,6 +622,7 @@ def update_dags( dm.timetable_description = dag.timetable.description dm.timetable_partitioned = dag.timetable.partitioned dm.timetable_periodic = dag.timetable.periodic + dm.partition_mapper_info = dag.timetable.partition_mapper_info dm.fail_fast = dag.fail_fast if dag.fail_fast is not None else False allowed_types = dag.allowed_run_types diff --git a/airflow-core/src/airflow/migrations/versions/0115_3_3_0_add_partition_mapper_info_to_dag.py b/airflow-core/src/airflow/migrations/versions/0115_3_3_0_add_partition_mapper_info_to_dag.py new file mode 100644 index 0000000000000..7a451a632a232 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0115_3_3_0_add_partition_mapper_info_to_dag.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add partition_mapper_info to DagModel. + +The new JSON column caches per-asset partition mapper metadata produced +during Dag serialization (one entry per asset, see ``PartitionMapperInfo``) +so UI endpoints can resolve mapper attributes such as ``is_rollup`` without +deserializing the timetable on every request. Defaults to ``[]`` for +timetables without per-asset partition mappers. + +Revision ID: c20871fbf23a +Revises: a7f3b2c1d4e5 +Create Date: 2026-05-06 00:00:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.migrations.utils import disable_sqlite_fkeys + +revision = "c20871fbf23a" +down_revision = "a7f3b2c1d4e5" +branch_labels = None +depends_on = None +airflow_version = "3.3.0" + + +def upgrade(): + """ + Add partition_mapper_info JSON column to dag table. + + The column is added nullable, existing rows are backfilled with ``[]``, + then the column is altered to NOT NULL. MySQL refuses literal defaults on + JSON columns, and the model intentionally carries no ``server_default``, + so we use the same backfill strategy on every backend to avoid leaving a + stray DB-side default that diverges from the ORM definition. The final + ``alter_column`` triggers a table rebuild on SQLite; foreign keys are + disabled around the whole upgrade so dependent tables' references stay + intact regardless of which step ends up rebuilding the table. 
+ """ + with disable_sqlite_fkeys(op): + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.add_column(sa.Column("partition_mapper_info", sa.JSON(), nullable=True)) + op.execute(sa.text("UPDATE dag SET partition_mapper_info = '[]' WHERE partition_mapper_info IS NULL")) + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.alter_column("partition_mapper_info", existing_type=sa.JSON(), nullable=False) + + +def downgrade(): + """Remove partition_mapper_info column from dag table.""" + with disable_sqlite_fkeys(op): + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.drop_column("partition_mapper_info") diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py index ca84b7047b435..a50dd10298c1a 100644 --- a/airflow-core/src/airflow/models/dag.py +++ b/airflow-core/src/airflow/models/dag.py @@ -66,7 +66,7 @@ from airflow.serialization.definitions.assets import SerializedAssetUniqueKey from airflow.serialization.encoders import DAT, encode_deadline_alert from airflow.serialization.enums import Encoding -from airflow.timetables.base import DataInterval, Timetable +from airflow.timetables.base import DataInterval, PartitionMapperInfo, Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable from airflow.timetables.simple import AssetTriggeredTimetable, NullTimetable, OnceTimetable from airflow.utils.session import NEW_SESSION, provide_session @@ -388,6 +388,16 @@ class DagModel(Base): timetable_partitioned: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="0") # Whether the timetable is periodic (supports backfilling). timetable_periodic: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="0") + # Cached partition mapper metadata for partitioned timetables, populated + # during Dag serialization so the UI can resolve mapper attributes without + # deserializing the timetable. 
See ``PartitionMapperInfo`` for the per-asset + # entry shape; empty list for timetables without per-asset partition mappers. + # No ``server_default`` — MySQL refuses literal defaults on JSON columns; + # the migration backfills existing rows and ``default=list`` covers new + # ORM inserts that don't pass an explicit value. + partition_mapper_info: Mapped[list[PartitionMapperInfo]] = mapped_column( + sa.JSON(), nullable=False, default=list + ) # Asset expression based on asset triggers asset_expression: Mapped[dict[str, Any] | None] = mapped_column(sa.JSON(), nullable=True) # DAG deadline information @@ -481,6 +491,28 @@ def __init__(self, **kwargs): def __repr__(self): return f"" + def is_rollup_asset(self, *, name: str, uri: str) -> bool: + """ + Return whether the asset identified by *name*/*uri* uses a rollup mapper. + + Reads the cached ``partition_mapper_info`` populated during Dag + serialization, mirroring ``PartitionedAssetTimetable.get_partition_mapper`` + (name lookup wins over uri lookup). Returns ``False`` when no entry + matches the asset. + """ + for entry in self.partition_mapper_info: + if entry.get("name") == name: + return entry["is_rollup"] + for entry in self.partition_mapper_info: + if entry.get("uri") == uri: + return entry["is_rollup"] + return False + + @property + def has_rollup_mappers(self) -> bool: + """Whether any cached partition mapper is a rollup mapper.""" + return any(entry["is_rollup"] for entry in self.partition_mapper_info) + @property def next_dagrun_data_interval(self) -> DataInterval | None: return _get_model_data_interval( diff --git a/airflow-core/src/airflow/timetables/base.py b/airflow-core/src/airflow/timetables/base.py index a92bcd7a1f85d..d9ed4f574af4b 100644 --- a/airflow-core/src/airflow/timetables/base.py +++ b/airflow-core/src/airflow/timetables/base.py @@ -16,7 +16,9 @@ # under the License. 
from __future__ import annotations -from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict, runtime_checkable + +from typing_extensions import NotRequired from airflow._shared.module_loading import qualname from airflow._shared.timezones import timezone @@ -39,6 +41,21 @@ from airflow.utils.types import DagRunType +class PartitionMapperInfo(TypedDict): + """ + JSON-serializable snapshot of one asset's partition mapper attributes. + + Stored as ``DagModel.partition_mapper_info`` (a list of these) so the UI can + resolve mapper attributes without deserializing the timetable on each request. + Either ``name``, ``uri``, or both identify the asset; ``Asset.ref(name=...)`` + omits ``uri`` and ``Asset.ref(uri=...)`` omits ``name``. + """ + + is_rollup: bool + name: NotRequired[str] + uri: NotRequired[str] + + class DataInterval(NamedTuple): """ A data interval for a DagRun to operate over. @@ -218,6 +235,18 @@ class Timetable(Protocol): instead of the traditional logic based on logical dates and data intervals. """ + @property + def partition_mapper_info(self) -> list[PartitionMapperInfo]: + """ + JSON-serializable per-asset partition mapper attributes. + + Empty list for timetables without asset-level partition mappers (the + default, including non-partitioned timetables and cron-driven partitioned + timetables). Asset-driven partitioned timetables override this with one + entry per asset (or asset ref) — see :class:`PartitionMapperInfo`. 
+ """ + return [] + @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: """ diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py index 01fb12f81dd0c..be47f56957fcb 100644 --- a/airflow-core/src/airflow/timetables/simple.py +++ b/airflow-core/src/airflow/timetables/simple.py @@ -32,7 +32,7 @@ SerializedAssetUriRef, ) from airflow.serialization.encoders import encode_asset_like, encode_partition_mapper -from airflow.timetables.base import DagRunInfo, DataInterval, Timetable +from airflow.timetables.base import DagRunInfo, DataInterval, PartitionMapperInfo, Timetable try: from airflow.sdk.definitions.asset import BaseAsset @@ -299,6 +299,30 @@ def get_partition_mapper(self, *, name: str = "", uri: str = "") -> PartitionMap return self.default_partition_mapper + @property + def partition_mapper_info(self) -> list[PartitionMapperInfo]: + """ + JSON-serializable snapshot of partition mapper attributes per asset. + + One :class:`~airflow.timetables.base.PartitionMapperInfo` entry per asset + (or asset ref) in ``partition_mapper_config``. The UI reads this from + the cached ``DagModel.partition_mapper_info`` instead of deserializing + the timetable on each request. 
+ """ + entries: list[PartitionMapperInfo] = [] + for base_asset, partition_mapper in self.partition_mapper_config.items(): + is_rollup = partition_mapper.is_rollup + for unique_key, _ in base_asset.iter_assets(): + entries.append( + PartitionMapperInfo(name=unique_key.name, uri=unique_key.uri, is_rollup=is_rollup) + ) + for s_asset_ref in base_asset.iter_asset_refs(): + if isinstance(s_asset_ref, SerializedAssetNameRef): + entries.append(PartitionMapperInfo(name=s_asset_ref.name, is_rollup=is_rollup)) + elif isinstance(s_asset_ref, SerializedAssetUriRef): + entries.append(PartitionMapperInfo(uri=s_asset_ref.uri, is_rollup=is_rollup)) + return entries + def serialize(self) -> dict[str, Any]: from airflow.serialization.serialized_objects import encode_asset_like diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 4caa9901bfb95..d863e8433d05f 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -116,7 +116,7 @@ class MappedClassProtocol(Protocol): "3.1.0": "cc92b33c6709", "3.1.8": "509b94a1042d", "3.2.0": "1d6611b6ab7c", - "3.3.0": "a7f3b2c1d4e5", + "3.3.0": "c20871fbf23a", } # Prefix used to identify tables holding data moved during migration. 
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py index 42be7792381fb..943012b79c578 100644 --- a/airflow-core/tests/unit/dag_processing/test_collection.py +++ b/airflow-core/tests/unit/dag_processing/test_collection.py @@ -54,9 +54,13 @@ from airflow.models.errors import ParseImportError from airflow.models.serialized_dag import SerializedDagModel from airflow.models.trigger import Trigger +from airflow.partition_mappers.base import RollupMapper +from airflow.partition_mappers.temporal import StartOfDayMapper +from airflow.partition_mappers.window import DayWindow from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.triggers.file import FileDeleteTrigger from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher +from airflow.sdk.definitions.timetables.assets import PartitionedAssetTimetable from airflow.serialization.definitions.assets import SerializedAsset from airflow.serialization.encoders import encode_trigger, ensure_serialized_asset from airflow.serialization.serialized_objects import LazyDeserializedDAG @@ -1193,3 +1197,63 @@ def test_update_dag_tags(self, testing_dag_bundle, session, initial_tags, new_ta session.commit() assert {t.name for t in dag_model.tags} == expected_tags + + +@pytest.mark.db_test +class TestPartitionMapperInfoSync: + """Verify partition_mapper_info is populated on DagModel during Dag sync.""" + + @pytest.fixture(autouse=True) + def clean_db_around_test(self) -> Generator: + def reset() -> None: + clear_db_dags() + clear_db_assets() + clear_db_serialized_dags() + + reset() + yield + reset() + + def test_partitioned_dag_with_rollup_mapper(self, dag_maker, session): + """Cover regular Asset, name ref, and uri ref entries in partition_mapper_info.""" + rollup_asset = Asset(uri="s3://bucket/rollup", name="rollup") + name_ref = Asset.ref(name="ref_by_name") + uri_ref = Asset.ref(uri="s3://ref") + 
rollup_mapper = RollupMapper(source_mapper=StartOfDayMapper(), window=DayWindow()) + + with dag_maker( + dag_id="partitioned_with_rollup", + schedule=PartitionedAssetTimetable( + assets=rollup_asset, + partition_mapper_config={ + rollup_asset: rollup_mapper, + name_ref: rollup_mapper, + uri_ref: rollup_mapper, + }, + ), + serialized=True, + ): + EmptyOperator(task_id="t") + + dag_model = session.get(DagModel, "partitioned_with_rollup") + assert dag_model.partition_mapper_info == [ + {"name": "rollup", "uri": "s3://bucket/rollup", "is_rollup": True}, + {"name": "ref_by_name", "is_rollup": True}, + {"uri": "s3://ref", "is_rollup": True}, + ] + assert dag_model.has_rollup_mappers is True + assert dag_model.is_rollup_asset(name="rollup", uri="s3://bucket/rollup") is True + assert dag_model.is_rollup_asset(name="ref_by_name", uri="") is True + assert dag_model.is_rollup_asset(name="", uri="s3://ref") is True + + def test_non_partitioned_dag_leaves_info_empty(self, dag_maker, session): + with dag_maker( + dag_id="non_partitioned_dag", + schedule=[Asset(uri="s3://bucket/A", name="A")], + serialized=True, + ): + EmptyOperator(task_id="t") + + dag_model = session.get(DagModel, "non_partitioned_dag") + assert dag_model.partition_mapper_info == [] + assert dag_model.has_rollup_mappers is False diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index b34ab12dd4aef..c959a8a624f74 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -2961,6 +2961,72 @@ def test_get_dag_id_to_team_name_mapping(self, testing_team): } +class TestDagModelPartitionMapperInfo: + """Pure-function tests for DagModel.is_rollup_asset / has_rollup_mappers.""" + + @pytest.mark.parametrize( + ("info", "name", "uri", "expected"), + [ + pytest.param([], "a", "s3://a", False, id="empty-info"), + pytest.param( + [ + {"name": "asset", "is_rollup": True}, + {"uri": "s3://asset", "is_rollup": False}, + ], + 
"asset", + "s3://asset", + True, + id="name-match-wins-over-uri", + ), + pytest.param( + [{"uri": "s3://asset", "is_rollup": True}], + "asset", + "s3://asset", + True, + id="uri-match-fallback", + ), + pytest.param( + [{"name": "other", "is_rollup": True}], + "asset", + "s3://asset", + False, + id="unknown-asset", + ), + ], + ) + def test_is_rollup_asset(self, info, name, uri, expected): + dm = DagModel(dag_id="d") + dm.partition_mapper_info = info + assert dm.is_rollup_asset(name=name, uri=uri) is expected + + @pytest.mark.parametrize( + ("info", "expected"), + [ + pytest.param([], False, id="empty"), + pytest.param( + [ + {"name": "a", "is_rollup": False}, + {"uri": "s3://b", "is_rollup": False}, + ], + False, + id="no-rollup-entries", + ), + pytest.param( + [ + {"name": "a", "is_rollup": False}, + {"uri": "s3://b", "is_rollup": True}, + ], + True, + id="at-least-one-rollup", + ), + ], + ) + def test_has_rollup_mappers(self, info, expected): + dm = DagModel(dag_id="d") + dm.partition_mapper_info = info + assert dm.has_rollup_mappers is expected + + class TestQueries: def setup_method(self) -> None: clear_db_runs() diff --git a/airflow-core/tests/unit/timetables/test_partitioned_timetable.py b/airflow-core/tests/unit/timetables/test_partitioned_timetable.py index b1befe9f5d2cd..0a66c124808b1 100644 --- a/airflow-core/tests/unit/timetables/test_partitioned_timetable.py +++ b/airflow-core/tests/unit/timetables/test_partitioned_timetable.py @@ -25,7 +25,10 @@ import pytest from airflow._shared.module_loading import qualname +from airflow.partition_mappers.base import RollupMapper from airflow.partition_mappers.identity import IdentityMapper as IdentityMapper +from airflow.partition_mappers.temporal import StartOfDayMapper +from airflow.partition_mappers.window import DayWindow from airflow.sdk import Asset from airflow.serialization.definitions.assets import SerializedAsset from airflow.serialization.encoders import ensure_serialized_asset @@ -110,6 +113,44 @@ def 
test_get_partition_mapper_with_mapping(self, asset_obj): assert isinstance(timetable.get_partition_mapper(name="test_1"), Key1Mapper) assert isinstance(timetable.get_partition_mapper(uri="test_1"), Key1Mapper) + def test_partition_mapper_info_empty(self): + timetable = PartitionedAssetTimetable(assets=Asset("a")) + assert timetable.partition_mapper_info == [] + + def test_partition_mapper_info_mixes_rollup_and_non_rollup(self): + non_rollup = ensure_serialized_asset(Asset(name="non_rollup_name", uri="s3://bucket/non_rollup")) + rollup = ensure_serialized_asset(Asset(name="rollup_name", uri="s3://bucket/rollup")) + timetable = PartitionedAssetTimetable( + assets=non_rollup, + partition_mapper_config={ + non_rollup: IdentityMapper(), + rollup: RollupMapper(source_mapper=StartOfDayMapper(), window=DayWindow()), + }, + ) + + info = timetable.partition_mapper_info + assert info == [ + {"name": "non_rollup_name", "uri": "s3://bucket/non_rollup", "is_rollup": False}, + {"name": "rollup_name", "uri": "s3://bucket/rollup", "is_rollup": True}, + ] + + def test_partition_mapper_info_handles_asset_refs(self): + timetable = PartitionedAssetTimetable( + assets=Asset(name="x", uri="x"), + partition_mapper_config={ + ensure_serialized_asset(Asset.ref(name="ref_by_name")): RollupMapper( + source_mapper=StartOfDayMapper(), window=DayWindow() + ), + ensure_serialized_asset(Asset.ref(uri="s3://ref")): IdentityMapper(), + }, + ) + + info = timetable.partition_mapper_info + assert info == [ + {"name": "ref_by_name", "is_rollup": True}, + {"uri": "s3://ref", "is_rollup": False}, + ] + def test_serialize(self): ser_asset = ensure_serialized_asset(Asset("test")) timetable = PartitionedAssetTimetable( diff --git a/scripts/in_container/run_migration_round_trip.py b/scripts/in_container/run_migration_round_trip.py index 20e15c672c3db..f8b12d3dc9454 100755 --- a/scripts/in_container/run_migration_round_trip.py +++ b/scripts/in_container/run_migration_round_trip.py @@ -105,6 +105,7 @@ 
"timetable_type": "'cron'", "timetable_partitioned": "0", "timetable_periodic": "0", + "partition_mapper_info": "'[]'", }, "dag_version": { "id": f"'{SEED_DAG_VERSION_ID}'",