diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index a9502d28fe944..3cff5e5449ce4 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``a7f3b2c1d4e5`` (head) | ``b8f3e4a1d2c9`` | ``3.3.0`` | Add allow_producer_teams column to | +| ``c20871fbf23a`` (head) | ``a7f3b2c1d4e5`` | ``3.3.0`` | Add partition_mapper_info to DagModel. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``a7f3b2c1d4e5`` | ``b8f3e4a1d2c9`` | ``3.3.0`` | Add allow_producer_teams column to | | | | | dag_schedule_asset_reference table. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``b8f3e4a1d2c9`` | ``fde9ed84d07b`` | ``3.3.0`` | Add retry_delay_override and retry_reason to task_instance. | diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/assets.py new file mode 100644 index 0000000000000..cd5aba6a68f1e --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/assets.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from pydantic import Field + +from airflow.api_fastapi.core_api.base import BaseModel + + +class NextRunAssetEventResponse(BaseModel): + """One asset event in the ``next_run_assets`` payload.""" + + id: int + name: str | None + uri: str + last_update: datetime | None = None + received_count: int = 0 + required_count: int = 1 + received_keys: list[str] = Field(default_factory=list) + required_keys: list[str] = Field(default_factory=list) + is_rollup: bool = False + + +class NextRunAssetsResponse(BaseModel): + """Response for the ``next_run_assets`` endpoint.""" + + asset_expression: dict | None = None + events: list[NextRunAssetEventResponse] + pending_partition_count: int | None = None diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py index 628f8560e96af..aca1d9a331dee 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/ui/partitioned_dag_runs.py @@ -47,6 +47,11 @@ class PartitionedDagRunAssetResponse(BaseModel): asset_name: str asset_uri: str received: bool + received_count: int + required_count: int + received_keys: list[str] + required_keys: list[str] + is_rollup: bool = False class 
PartitionedDagRunDetailResponse(BaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml index 2983263bbc59b..eea4a738a83e2 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/_private_ui.yaml @@ -91,9 +91,7 @@ paths: content: application/json: schema: - type: object - additionalProperties: true - title: Response Next Run Assets + $ref: '#/components/schemas/NextRunAssetsResponse' '422': description: Validation Error content: @@ -3069,6 +3067,77 @@ components: - extra_menu_items title: MenuItemCollectionResponse description: Menu Item Collection serializer for responses. + NextRunAssetEventResponse: + properties: + id: + type: integer + title: Id + name: + anyOf: + - type: string + - type: 'null' + title: Name + uri: + type: string + title: Uri + last_update: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Last Update + received_count: + type: integer + title: Received Count + default: 0 + required_count: + type: integer + title: Required Count + default: 1 + received_keys: + items: + type: string + type: array + title: Received Keys + required_keys: + items: + type: string + type: array + title: Required Keys + is_rollup: + type: boolean + title: Is Rollup + default: false + type: object + required: + - id + - name + - uri + title: NextRunAssetEventResponse + description: One asset event in the ``next_run_assets`` payload. 
+ NextRunAssetsResponse: + properties: + asset_expression: + anyOf: + - additionalProperties: true + type: object + - type: 'null' + title: Asset Expression + events: + items: + $ref: '#/components/schemas/NextRunAssetEventResponse' + type: array + title: Events + pending_partition_count: + anyOf: + - type: integer + - type: 'null' + title: Pending Partition Count + type: object + required: + - events + title: NextRunAssetsResponse + description: Response for the ``next_run_assets`` endpoint. NodeResponse: properties: id: @@ -3152,12 +3221,36 @@ components: received: type: boolean title: Received + received_count: + type: integer + title: Received Count + required_count: + type: integer + title: Required Count + received_keys: + items: + type: string + type: array + title: Received Keys + required_keys: + items: + type: string + type: array + title: Required Keys + is_rollup: + type: boolean + title: Is Rollup + default: false type: object required: - asset_id - asset_name - asset_uri - received + - received_count + - required_count + - received_keys + - required_keys title: PartitionedDagRunAssetResponse description: Asset info within a partitioned Dag run detail. PartitionedDagRunCollectionResponse: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py index 155b49726f6e4..80d7cdad9e1c0 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py @@ -16,11 +16,19 @@ # under the License. 
from __future__ import annotations +from contextlib import suppress +from typing import TYPE_CHECKING, cast + from fastapi import Depends, HTTPException, status from sqlalchemy import ColumnElement, and_, case, exists, func, select, true from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.core_api.datamodels.ui.assets import ( + NextRunAssetEventResponse, + NextRunAssetsResponse, +) +from airflow.api_fastapi.core_api.routes.ui.partitioned_dag_runs import _load_timetable from airflow.api_fastapi.core_api.security import requires_access_asset, requires_access_dag from airflow.models import DagModel from airflow.models.asset import ( @@ -32,6 +40,9 @@ PartitionedAssetKeyLog, ) +if TYPE_CHECKING: + from airflow.partition_mappers.base import RollupMapper + assets_router = AirflowRouter(tags=["Asset"]) @@ -42,7 +53,7 @@ def next_run_assets( dag_id: str, session: SessionDep, -) -> dict: +) -> NextRunAssetsResponse: dag_model = DagModel.get_dagmodel(dag_id, session=session) if dag_model is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") @@ -55,7 +66,7 @@ def next_run_assets( pending_partition_count: int | None = None queued_expr: ColumnElement[int] - if is_partitioned := dag_model.timetable_summary == "Partitioned Asset": + if is_partitioned := dag_model.timetable_partitioned: pending_partition_count = session.scalar( select(func.count()) .select_from(AssetPartitionDagRun) @@ -90,7 +101,7 @@ def next_run_assets( AssetModel.id, AssetModel.uri, AssetModel.name, - func.max(AssetEvent.timestamp).label("lastUpdate"), + func.max(AssetEvent.timestamp).label("last_update"), queued_expr.label("queued"), ) .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) @@ -110,12 +121,97 @@ def next_run_assets( isouter=True, ) - events = [dict(info._mapping) for info in session.execute(query)] - for event in events: - if not 
event.pop("queued", None): - event["lastUpdate"] = None + raw_rows = list(session.execute(query)) + + if not is_partitioned: + events = [ + NextRunAssetEventResponse( + id=row.id, + name=row.name, + uri=row.uri, + last_update=row.last_update if row.queued else None, + ) + for row in raw_rows + ] + return NextRunAssetsResponse(asset_expression=dag_model.asset_expression, events=events) + + # Partitioned Dags: enrich with per-asset received/required counts and rollup flag. + pending_apdr = session.execute( + select(AssetPartitionDagRun.id, AssetPartitionDagRun.partition_key) + .where( + AssetPartitionDagRun.target_dag_id == dag_id, + AssetPartitionDagRun.created_dag_run_id.is_(None), + ) + .order_by(AssetPartitionDagRun.created_at.desc()) + .limit(1) + ).one_or_none() + + has_rollup_mappers = dag_model.has_rollup_mappers - data: dict = {"asset_expression": dag_model.asset_expression, "events": events} - if pending_partition_count is not None: - data["pending_partition_count"] = pending_partition_count - return data + if pending_apdr is None: + # No pending APDR yet — mark rollup assets so the UI can handle them + # correctly (e.g. skip "Asset Triggered" in favour of the asset name view). + # Reads from the cached partition_mapper_info so no timetable load is needed. + events = [ + NextRunAssetEventResponse( + id=row.id, + name=row.name, + uri=row.uri, + last_update=row.last_update if row.queued else None, + is_rollup=has_rollup_mappers and dag_model.is_rollup_asset(name=row.name, uri=row.uri), + ) + for row in raw_rows + ] + return NextRunAssetsResponse( + asset_expression=dag_model.asset_expression, + events=events, + pending_partition_count=pending_partition_count, + ) + + # Collect received upstream partition keys per asset for this partition run. + # Use a set to deduplicate: multiple events for the same key count as one. 
+ received_keys_by_asset: dict[int, set[str]] = {} + for log_row in session.execute( + select( + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id == pending_apdr.id) + ): + received_keys_by_asset.setdefault(log_row.asset_id, set()).add(log_row.source_partition_key or "") + + # The timetable is only needed to call ``to_upstream`` for rollup mappers. + # When the cached info shows no rollup mappers, skip loading it entirely. + rollup_timetable = _load_timetable(dag_id, session) if has_rollup_mappers else None + + events = [] + for row in raw_rows: + received_keys = sorted(received_keys_by_asset.get(row.id, set())) + required_keys: list[str] = [pending_apdr.partition_key] + is_rollup = has_rollup_mappers and dag_model.is_rollup_asset(name=row.name, uri=row.uri) + if is_rollup and rollup_timetable is not None: + with suppress(Exception): + mapper = rollup_timetable.get_partition_mapper(name=row.name, uri=row.uri) + required_keys = sorted(cast("RollupMapper", mapper).to_upstream(pending_apdr.partition_key)) + received_count = len(received_keys) + required_count = len(required_keys) + # Only surface last_update once all required upstream keys have arrived. 
+ last_update = row.last_update if row.queued and received_count >= required_count else None + events.append( + NextRunAssetEventResponse( + id=row.id, + name=row.name, + uri=row.uri, + last_update=last_update, + received_count=received_count, + required_count=required_count, + received_keys=received_keys, + required_keys=required_keys, + is_rollup=is_rollup, + ) + ) + + return NextRunAssetsResponse( + asset_expression=dag_model.asset_expression, + events=events, + pending_partition_count=pending_partition_count, + ) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py index fe4696efbb5a0..e90c5e81eda7f 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/partitioned_dag_runs.py @@ -16,8 +16,11 @@ # under the License. from __future__ import annotations +from contextlib import suppress +from typing import TYPE_CHECKING, TypeAlias, cast + from fastapi import Depends, HTTPException, status -from sqlalchemy import exists, func, select +from sqlalchemy import select from airflow.api_fastapi.common.db.common import SessionDep, apply_filters_to_select from airflow.api_fastapi.common.parameters import ( @@ -44,17 +47,149 @@ PartitionedAssetKeyLog, ) from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel +from airflow.timetables.simple import PartitionedAssetTimetable + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.partition_mappers.base import RollupMapper + + +AssetNameUri: TypeAlias = tuple[str, str] +"""A ``(name, uri)`` pair identifying an asset.""" + + +def _load_timetable(dag_id: str, session: Session) -> PartitionedAssetTimetable | None: + """ + Return the PartitionedAssetTimetable for *dag_id*, or None if absent or not partitioned. 
+ + Callers gate this behind ``DagModel.has_rollup_mappers``, which is only + populated for ``PartitionedAssetTimetable``. The ``TYPE_CHECKING`` assert + narrows the type for mypy without a runtime ``isinstance`` cost. + """ + serdag = SerializedDagModel.get(dag_id=dag_id, session=session) + if serdag is None: + return None + with suppress(Exception): + if serdag.dag.timetable.partitioned: + if TYPE_CHECKING: + assert isinstance(serdag.dag.timetable, PartitionedAssetTimetable) + return serdag.dag.timetable + return None + + +def _fetch_active_assets_per_dag( + dag_ids: list[str], session: Session +) -> dict[str, tuple[list[AssetNameUri], dict[int, AssetNameUri]]]: + """ + Batch-fetch active required assets for multiple Dags in a single query. + + Returns ``{dag_id: ([(name, uri), ...], {asset_id: (name, uri)})}``. + Dags with no active references are still included with empty containers + so callers can index by ``dag_id`` without ``KeyError``. + """ + rows = session.execute( + select( + DagScheduleAssetReference.dag_id, + AssetModel.id, + AssetModel.name, + AssetModel.uri, + ) + .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) + .where(DagScheduleAssetReference.dag_id.in_(dag_ids), AssetModel.active.has()) + ).all() + result: dict[str, tuple[list[AssetNameUri], dict[int, AssetNameUri]]] = { + dag_id: ([], {}) for dag_id in dag_ids + } + for row in rows: + info, id_to_info = result[row.dag_id] + info.append((row.name, row.uri)) + id_to_info[row.id] = (row.name, row.uri) + return result + + +def _resolve_rollup_upstream_keys( + dag_model: DagModel | None, + rollup_timetable: PartitionedAssetTimetable | None, + name: str, + uri: str, + partition_key: str, +) -> frozenset[str] | None: + """ + Return the upstream rollup keys for an asset, or ``None`` for non-rollup. 
+ + Returns ``None`` when the asset is not rollup, when ``dag_model`` is missing + (Dag was deleted while runs still reference it), when ``rollup_timetable`` + couldn't be loaded, or when the mapper raises — callers fall back to + non-rollup behaviour for that asset rather than 500-ing the request. + """ + if dag_model is None or rollup_timetable is None or not dag_model.is_rollup_asset(name=name, uri=uri): + return None + with suppress(Exception): + mapper = rollup_timetable.get_partition_mapper(name=name, uri=uri) + return frozenset(cast("RollupMapper", mapper).to_upstream(partition_key)) + return None + + +def _compute_total_required( + dag_model: DagModel | None, + rollup_timetable: PartitionedAssetTimetable | None, + asset_info: list[AssetNameUri], + partition_key: str, +) -> int: + """ + Sum required upstream events across all assets, using to_upstream for rollup mappers. + + Non-rollup assets (and rollup assets whose mapper raises) count as 1. + """ + total = 0 + for name, uri in asset_info: + keys = _resolve_rollup_upstream_keys(dag_model, rollup_timetable, name, uri, partition_key) + total += len(keys) if keys is not None else 1 + return total + + +def _compute_received_count( + dag_model: DagModel | None, + received_by_asset: dict[int, set[str]], + rollup_timetable: PartitionedAssetTimetable | None, + asset_id_to_info: dict[int, AssetNameUri], + partition_key: str, +) -> int: + """ + Count received events using rollup-aware deduplication. + + For rollup assets: count distinct upstream keys that intersect the required set. + For non-rollup assets (or rollup mapper failures): any logged event satisfies + the asset — the source_partition_key value is irrelevant. + """ + total = 0 + for asset_id, received_keys in received_by_asset.items(): + # Logs may reference assets that are no longer required (active=False) + # — skip them rather than over-counting. The list route relies on this + # because it batch-fetches logs without a per-Dag asset_id filter. 
+ if asset_id not in asset_id_to_info: + continue + name, uri = asset_id_to_info[asset_id] + keys = _resolve_rollup_upstream_keys(dag_model, rollup_timetable, name, uri, partition_key) + if keys is not None: + total += len(received_keys & keys) + else: + total += 1 if received_keys else 0 + return total + partitioned_dag_runs_router = AirflowRouter(tags=["PartitionedDagRun"]) -def _build_response(row, required_count: int) -> PartitionedDagRunResponse: +def _build_response(row, required_count: int, received_count: int) -> PartitionedDagRunResponse: return PartitionedDagRunResponse( id=row.id, dag_id=row.target_dag_id, partition_key=row.partition_key, created_at=row.created_at.isoformat() if row.created_at else None, - total_received=row.total_received or 0, + total_received=received_count, total_required=required_count, state=row.dag_run_state if row.created_dag_run_id else "pending", created_dag_run_id=row.dag_run_id, @@ -72,50 +207,11 @@ def get_partitioned_dag_runs( has_created_dag_run_id: QueryPartitionedDagRunHasCreatedDagRunIdFilter, ) -> PartitionedDagRunCollectionResponse: """Return PartitionedDagRuns. 
Filter by dag_id and/or has_created_dag_run_id.""" - if dag_id.value is not None: - # Single query: validate Dag + get required count - dag_info = session.execute( - select( - DagModel.timetable_summary, - func.count(DagScheduleAssetReference.asset_id).label("required_count"), - ) - .outerjoin( - DagScheduleAssetReference, - (DagScheduleAssetReference.dag_id == DagModel.dag_id) - & DagScheduleAssetReference.asset_id.in_( - select(AssetModel.id).where(AssetModel.active.has()) - ), - ) - .where(DagModel.dag_id == dag_id.value) - .group_by(DagModel.dag_id) - ).one_or_none() - - if dag_info is None: - raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id.value} was not found") - if dag_info.timetable_summary != "Partitioned Asset": - return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) - - required_count = dag_info.required_count - - # Subquery for received count per partition (only count required assets) - required_assets_subq = ( - select(DagScheduleAssetReference.asset_id) - .join(AssetModel, AssetModel.id == DagScheduleAssetReference.asset_id) - .where( - DagScheduleAssetReference.dag_id == AssetPartitionDagRun.target_dag_id, - AssetModel.active.has(), - ) - .correlate(AssetPartitionDagRun) - ) - received_subq = ( - select(func.count(func.distinct(PartitionedAssetKeyLog.asset_id))) - .where( - PartitionedAssetKeyLog.asset_partition_dag_run_id == AssetPartitionDagRun.id, - PartitionedAssetKeyLog.asset_id.in_(required_assets_subq), - ) - .correlate(AssetPartitionDagRun) - .scalar_subquery() - ) + # The Dag-existence / partitioned-timetable check is intentionally deferred to the + # empty-results branch below. In the happy path (rows exist), filtering by dag_id + # already restricts to that Dag, so an extra DagModel lookup just to validate + # existence wastes a query. We only consult DagModel when we have no rows to + # report — that's the only case where the distinction (404 vs empty) matters. 
query = select( AssetPartitionDagRun.id, @@ -125,7 +221,6 @@ def get_partitioned_dag_runs( AssetPartitionDagRun.created_dag_run_id, DagRun.run_id.label("dag_run_id"), DagRun.state.label("dag_run_state"), - received_subq.label("total_received"), ).outerjoin(DagRun, AssetPartitionDagRun.created_dag_run_id == DagRun.id) query = apply_filters_to_select(statement=query, filters=[dag_id, has_created_dag_run_id]) readable_dag_ids = readable_dags_filter.value @@ -134,32 +229,72 @@ def get_partitioned_dag_runs( query = query.order_by(AssetPartitionDagRun.created_at.desc()) if not (rows := session.execute(query).all()): + if dag_id.value is not None: + dag_info = session.execute( + select(DagModel.timetable_partitioned).where(DagModel.dag_id == dag_id.value) + ).one_or_none() + if dag_info is None: + raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id.value} was not found") return PartitionedDagRunCollectionResponse(partitioned_dag_runs=[], total=0) - if dag_id.value is not None: - results = [_build_response(row, required_count) for row in rows] - return PartitionedDagRunCollectionResponse(partitioned_dag_runs=results, total=len(results)) + # Batch-fetch DagModels (for cached partition_mapper_info), required assets, + # and APDR log entries in three single queries instead of N per-Dag queries. + # Timetables are only loaded for Dags that actually have rollup mappers, + # since that's the only case where ``to_upstream`` evaluation is needed. + # A SQL count subquery for total_received cannot honour rollup windows + # without running the mapper, so the rollup-aware Python computation runs + # uniformly across single-Dag and global views. 
+ unique_dag_ids = list({row.target_dag_id for row in rows}) + dag_models: dict[str, DagModel] = { + dm.dag_id: dm + for dm in session.scalars(select(DagModel).where(DagModel.dag_id.in_(unique_dag_ids))).all() + } + assets_by_dag = _fetch_active_assets_per_dag(unique_dag_ids, session) + rollup_timetables_by_dag: dict[str, PartitionedAssetTimetable | None] = { + did: ( + _load_timetable(did, session) + if (dm := dag_models.get(did)) is not None and dm.has_rollup_mappers + else None + ) + for did in unique_dag_ids + } - # No dag_id: need to get required counts and expressions per dag - dag_ids = list({row.target_dag_id for row in rows}) - dag_rows = session.execute( + apdr_ids = [row.id for row in rows] + log_by_apdr: dict[int, dict[int, set[str]]] = {} + for pakl_row in session.execute( select( - DagModel.dag_id, - DagModel.asset_expression, - func.count(DagScheduleAssetReference.asset_id).label("required_count"), - ) - .outerjoin( - DagScheduleAssetReference, - (DagScheduleAssetReference.dag_id == DagModel.dag_id) - & DagScheduleAssetReference.asset_id.in_(select(AssetModel.id).where(AssetModel.active.has())), + PartitionedAssetKeyLog.asset_partition_dag_run_id, + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id.in_(apdr_ids)) + ).all(): + log_by_apdr.setdefault(pakl_row.asset_partition_dag_run_id, {}).setdefault( + pakl_row.asset_id, set() + ).add(pakl_row.source_partition_key or "") + + results = [ + _build_response( + row, + _compute_total_required( + dag_models.get(row.target_dag_id), + rollup_timetables_by_dag[row.target_dag_id], + assets_by_dag[row.target_dag_id][0], + row.partition_key, + ), + _compute_received_count( + dag_models.get(row.target_dag_id), + log_by_apdr.get(row.id, {}), + rollup_timetables_by_dag[row.target_dag_id], + assets_by_dag[row.target_dag_id][1], + row.partition_key, + ), ) - .where(DagModel.dag_id.in_(dag_ids)) - .group_by(DagModel.dag_id) - 
).all() + for row in rows + ] - required_counts = {r.dag_id: r.required_count for r in dag_rows} - asset_expressions = {r.dag_id: r.asset_expression for r in dag_rows} - results = [_build_response(row, required_counts.get(row.target_dag_id, 0)) for row in rows] + asset_expressions: dict[str, dict | None] | None = None + if dag_id.value is None: + asset_expressions = {dm.dag_id: dm.asset_expression for dm in dag_models.values()} return PartitionedDagRunCollectionResponse( partitioned_dag_runs=results, @@ -201,38 +336,67 @@ def get_pending_partitioned_dag_run( f"No PartitionedDagRun for dag={dag_id} partition={partition_key}", ) - received_subq = ( - select(PartitionedAssetKeyLog.asset_id).where( - PartitionedAssetKeyLog.asset_partition_dag_run_id == partitioned_dag_run.id - ) - ).correlate(AssetModel) - - received_expr = exists(received_subq.where(PartitionedAssetKeyLog.asset_id == AssetModel.id)) + # Collect received upstream partition keys per asset for this partition run. + # Use a set to deduplicate: multiple events for the same key count as one. 
+ received_keys_by_asset: dict[int, set[str]] = {} + for row in session.execute( + select( + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + ).where(PartitionedAssetKeyLog.asset_partition_dag_run_id == partitioned_dag_run.id) + ): + received_keys_by_asset.setdefault(row.asset_id, set()).add(row.source_partition_key or "") - asset_expression_subq = ( - select(DagModel.asset_expression).where(DagModel.dag_id == dag_id).scalar_subquery() - ) + dag_model = session.get(DagModel, dag_id) asset_rows = session.execute( select( AssetModel.id, AssetModel.uri, AssetModel.name, - received_expr.label("received"), - asset_expression_subq.label("asset_expression"), ) .join(DagScheduleAssetReference, DagScheduleAssetReference.asset_id == AssetModel.id) .where(DagScheduleAssetReference.dag_id == dag_id, AssetModel.active.has()) - .order_by(received_expr.asc(), AssetModel.uri) + .order_by(AssetModel.uri) ).all() - assets = [ - PartitionedDagRunAssetResponse( - asset_id=row.id, asset_name=row.name, asset_uri=row.uri, received=row.received + # Skip the timetable load when no rollup mapper is configured — the cached + # ``partition_mapper_info`` already tells us whether we will need + # ``to_upstream`` evaluation, which is the only thing the timetable adds here. 
+ has_rollup_mappers = dag_model is not None and dag_model.has_rollup_mappers + rollup_timetable = _load_timetable(dag_id, session) if has_rollup_mappers else None + + assets = [] + for asset_row in asset_rows: + received_keys = sorted(received_keys_by_asset.get(asset_row.id, set())) + required_keys: list[str] = [partition_key] + is_rollup = ( + has_rollup_mappers + and dag_model is not None + and dag_model.is_rollup_asset(name=asset_row.name, uri=asset_row.uri) ) - for row in asset_rows - ] - total_received = sum(1 for a in assets if a.received) - asset_expression = asset_rows[0].asset_expression if asset_rows else None + if is_rollup and rollup_timetable is not None: + with suppress(Exception): + mapper = rollup_timetable.get_partition_mapper(name=asset_row.name, uri=asset_row.uri) + required_keys = sorted(cast("RollupMapper", mapper).to_upstream(partition_key)) + received_count = len(received_keys) + required_count = len(required_keys) + assets.append( + PartitionedDagRunAssetResponse( + asset_id=asset_row.id, + asset_name=asset_row.name, + asset_uri=asset_row.uri, + received=received_count >= required_count and required_count > 0, + received_count=received_count, + required_count=required_count, + received_keys=received_keys, + required_keys=required_keys, + is_rollup=is_rollup, + ) + ) + + total_received = sum(a.received_count for a in assets) + total_required = sum(a.required_count for a in assets) + asset_expression = dag_model.asset_expression if dag_model is not None else None return PartitionedDagRunDetailResponse( id=partitioned_dag_run.id, @@ -242,7 +406,7 @@ def get_pending_partitioned_dag_run( updated_at=partitioned_dag_run.updated_at.isoformat() if partitioned_dag_run.updated_at else None, created_dag_run_id=partitioned_dag_run.created_dag_run_id, assets=assets, - total_required=len(assets), + total_required=total_required, total_received=total_received, asset_expression=asset_expression, ) diff --git 
a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py index 6f8bc752bbc4f..2e780825cd577 100644 --- a/airflow-core/src/airflow/dag_processing/collection.py +++ b/airflow-core/src/airflow/dag_processing/collection.py @@ -622,6 +622,7 @@ def update_dags( dm.timetable_description = dag.timetable.description dm.timetable_partitioned = dag.timetable.partitioned dm.timetable_periodic = dag.timetable.periodic + dm.partition_mapper_info = dag.timetable.partition_mapper_info dm.fail_fast = dag.fail_fast if dag.fail_fast is not None else False allowed_types = dag.allowed_run_types diff --git a/airflow-core/src/airflow/example_dags/example_asset_partition.py b/airflow-core/src/airflow/example_dags/example_asset_partition.py index 75d582f6cad6e..12b23f8e8bc75 100644 --- a/airflow-core/src/airflow/example_dags/example_asset_partition.py +++ b/airflow-core/src/airflow/example_dags/example_asset_partition.py @@ -24,11 +24,15 @@ AllowedKeyMapper, Asset, CronPartitionTimetable, + DayWindow, IdentityMapper, + MonthWindow, PartitionedAssetTimetable, ProductMapper, + RollupMapper, StartOfDayMapper, StartOfHourMapper, + StartOfMonthMapper, StartOfYearMapper, asset, task, @@ -225,3 +229,76 @@ def regional_stats_breakdown(): keys belong to a fixed set of allowed values (``us``, ``eu``, ``apac``) rather than time-based partitions. """ pass + + +# --- Chained rollup: hourly → daily → monthly -------------------------------- +# The hourly source asset already exists above (``team_a_player_stats``). +# Each rollup Dag publishes its own asset so the next level can consume it. 
+ +daily_team_a = Asset(uri="s3://team-a/daily", name="daily_team_a") +monthly_team_a = Asset(uri="s3://team-a/monthly", name="monthly_team_a") + + +with DAG( + dag_id="daily_team_a_rollup", + schedule=PartitionedAssetTimetable( + assets=team_a_player_stats, + default_partition_mapper=RollupMapper( + source_mapper=StartOfDayMapper(), + window=DayWindow(), + ), + ), + catchup=False, + tags=["player-stats", "rollup"], +): + """ + First rollup level: 24 hourly partitions of ``team_a_player_stats`` → one daily summary. + + ``StartOfDayMapper`` normalizes each upstream hourly timestamp (``%Y-%m-%dT%H:%M:%S``) + to its day-start (``%Y-%m-%d``); ``DayWindow`` declares the downstream run needs + all 24 hourly partitions before firing. Publishes ``daily_team_a`` so the + monthly rollup below can consume it. + """ + + @task(outlets=[daily_team_a]) + def summarise_team_a_day(dag_run=None): + """Produce the full-day rollup once every hour has arrived.""" + if TYPE_CHECKING: + assert dag_run + print(f"All 24 hourly partitions received. Day: {dag_run.partition_key}") + + summarise_team_a_day() + + +with DAG( + dag_id="monthly_team_a_rollup", + schedule=PartitionedAssetTimetable( + assets=daily_team_a, + # The upstream (``daily_team_a``) emits day-formatted partition keys + # (``%Y-%m-%d``), so the source mapper here must accept that format. + default_partition_mapper=RollupMapper( + source_mapper=StartOfMonthMapper(input_format="%Y-%m-%d"), + window=MonthWindow(), + ), + ), + catchup=False, + tags=["player-stats", "rollup"], +): + """ + Chained rollup: every day of ``daily_team_a`` (itself a rollup) → one monthly summary. + + Demonstrates how a rollup output can feed another rollup. ``StartOfMonthMapper`` + is configured with ``input_format="%Y-%m-%d"`` so it can parse the day keys + emitted by ``daily_team_a_rollup``; ``MonthWindow`` waits for every day of the + calendar month (28–31 depending on the month). The partition key is the month + identifier, e.g. ``2024-01``. 
+ """ + + @task(outlets=[monthly_team_a]) + def summarise_team_a_month(dag_run=None): + """Produce the full-month rollup once every day has arrived.""" + if TYPE_CHECKING: + assert dag_run + print(f"All daily partitions received. Month: {dag_run.partition_key}") + + summarise_team_a_month() diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 3eed95a8bb030..23019c8a3d438 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -88,7 +88,7 @@ from airflow.serialization.definitions.assets import SerializedAssetUniqueKey from airflow.serialization.definitions.notset import NOTSET from airflow.ti_deps.dependencies_states import ACTIVE_STATES, EXECUTION_STATES -from airflow.timetables.simple import AssetTriggeredTimetable +from airflow.timetables.simple import AssetTriggeredTimetable, PartitionedAssetTimetable from airflow.utils.event_scheduler import EventScheduler from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries @@ -116,6 +116,7 @@ from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName from airflow.executors.workloads.types import SchedulerWorkload + from airflow.partition_mappers.base import RollupMapper from airflow.serialization.definitions.dag import SerializedDAG from airflow.utils.sqlalchemy import CommitProhibitorGuard @@ -126,6 +127,23 @@ TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT = "stuck in queued reschedule" """:meta private:""" +MAX_PARTITION_DAG_RUNS_PER_TICK = 500 +""" +Per-tick cap on pending :class:`~airflow.models.asset.AssetPartitionDagRun` +rows the scheduler will evaluate when creating partition-driven Dag runs. + +A single tick reads this many APDRs in FIFO order, bulk-loads the +serialized Dags + per-asset event logs, and evaluates each. 
Bounding the +per-tick batch keeps the transaction short so executor heartbeats and +regular scheduling work aren't starved on busy deployments. Remaining +APDRs drain across subsequent ticks; the cap therefore sets the +*steady-state throughput*, not the maximum backlog. + +500 was chosen as a default that keeps the per-tick critical section +comfortably under one second on typical hardware while still draining a +modest backlog within a few ticks. +""" + def _eager_load_dag_run_for_validation() -> tuple[LoaderOption, LoaderOption]: """ @@ -1831,40 +1849,167 @@ def _do_scheduling(self, session: Session) -> int: return num_queued_tis + def _check_rollup_asset_status( + self, + *, + asset_id: int, + apdr: AssetPartitionDagRun, + mapper: RollupMapper, + actual_by_asset: dict[int, set[str]], + ) -> bool: + if TYPE_CHECKING: + assert apdr.partition_key is not None + expected = mapper.to_upstream(apdr.partition_key) + return expected.issubset(actual_by_asset.get(asset_id, set())) + + def _resolve_asset_partition_status( + self, + *, + session: Session, + asset_id: int, + name: str, + uri: str, + apdr: AssetPartitionDagRun, + timetable: PartitionedAssetTimetable, + actual_by_asset: dict[int, set[str]], + ) -> bool: + """ + Return whether *asset_id* has been satisfied for *apdr*. + + Non-rollup assets resolve to ``True`` because the caller only invokes + this for assets that already have at least one logged event for *APDR* + (see :class:`~airflow.models.asset.PartitionedAssetKeyLog`), which is + the non-rollup contract for "received". Rollup assets defer to + :meth:`_check_rollup_asset_status` for the upstream-window check. + + A misconfigured mapper that raises returns ``False`` (treated as + not-yet-satisfied) and an audit log entry is written so the operator + can see why the Dag run is being held in the UI. 
+ """ + try: + mapper = timetable.get_partition_mapper(name=name, uri=uri) + if not mapper.is_rollup: + return True + return self._check_rollup_asset_status( + asset_id=asset_id, + apdr=apdr, + mapper=cast("RollupMapper", mapper), + actual_by_asset=actual_by_asset, + ) + except Exception as err: + self.log.exception( + "Failed to evaluate rollup status for asset; treating as not-yet-satisfied. " + "This likely indicates a misconfigured partition mapper.", + dag_id=apdr.target_dag_id, + partition_key=apdr.partition_key, + asset_name=name, + asset_uri=uri, + ) + session.add( + Log( + event="failed to evaluate rollup status", + dag_id=apdr.target_dag_id, + extra=( + "Could not evaluate rollup status for partition_key " + f"'{apdr.partition_key}' on asset (name='{name}', uri='{uri}') " + f"in target Dag '{apdr.target_dag_id}'. This likely indicates " + "that the rollup mapper is misconfigured or does not support " + f"this partition key.\n{type(err).__name__}: {err}" + ), + ) + ) + return False + def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[str]: + """ + Create Dag runs for pending :class:`AssetPartitionDagRun` rows whose partition is satisfied. + + Returns the set of ``dag_id`` strings that received a new partition-driven Dag run in this + tick. The caller (:meth:`_create_dagruns_for_dags`) uses this set to exclude the same Dags + from the standard schedule-driven and asset-triggered creation paths so a single Dag never + gets two Dag runs for the same tick when it appears in more than one creation path. We + return ``dag_id`` strings rather than full Dag/DagRun objects because the only downstream + use is membership lookup, and a heavier return type would just be discarded. + """ + # Cap per-tick work so the scheduler transaction stays bounded and other + # scheduling work isn't starved. Remaining APDRs drain across subsequent ticks. 
+ # Note: with strict FIFO ordering, persistently-unsatisfied APDRs at the head + # of the queue would block newer ones; switch to updated_at-based ordering if + # that becomes an issue. + pending_apdrs = session.scalars( + select(AssetPartitionDagRun) + .join(DagModel, DagModel.dag_id == AssetPartitionDagRun.target_dag_id) + .where( + AssetPartitionDagRun.created_dag_run_id.is_(None), + DagModel.is_stale.is_(False), + ) + .order_by(AssetPartitionDagRun.created_at) + .limit(MAX_PARTITION_DAG_RUNS_PER_TICK) + ).all() + if not pending_apdrs: + return set() + partition_dag_ids: set[str] = set() + pending_apdr_ids = [apdr.id for apdr in pending_apdrs] - evaluator = AssetEvaluator(session) - for apdr in session.scalars( - select(AssetPartitionDagRun).where(AssetPartitionDagRun.created_dag_run_id.is_(None)) + # Pre-fetch all required serialized Dags in one query. + dag_ids = list({apdr.target_dag_id for apdr in pending_apdrs if apdr.target_dag_id}) + # {"dag_id": Serialized Dag} + serialized_dags: dict[str, SerializedDAG] = {} + for serdag in SerializedDagModel.get_latest_serialized_dags(dag_ids=dag_ids, session=session): + try: + serdag.load_op_links = False + serialized_dags[serdag.dag_id] = serdag.dag + except Exception: + self.log.exception("Failed to deserialize Dag '%s'", serdag.dag_id) + + # {apdr_id: {asset_id: set(source_key, ...)} + source_key_by_asset_per_apdr: dict[int, dict[int, set[str]]] = defaultdict(lambda: defaultdict(set)) + # {apdr_id: {asset_id: (asset_name, asset_uri)} + asset_info_per_apdr: dict[int, dict[int, tuple[str, str]]] = defaultdict(dict) + for apdr_id, asset_id, source_key, name, uri in session.execute( + select( + PartitionedAssetKeyLog.asset_partition_dag_run_id, + PartitionedAssetKeyLog.asset_id, + PartitionedAssetKeyLog.source_partition_key, + AssetModel.name, + AssetModel.uri, + ) + .join(AssetModel, AssetModel.id == PartitionedAssetKeyLog.asset_id) + .where(PartitionedAssetKeyLog.asset_partition_dag_run_id.in_(pending_apdr_ids)) ): 
+ source_key_by_asset_per_apdr[apdr_id][asset_id].add(source_key) + asset_info_per_apdr[apdr_id][asset_id] = (name, uri) + + evaluator = AssetEvaluator(session) + for apdr in pending_apdrs: if TYPE_CHECKING: assert apdr.target_dag_id - if not (dag := self._get_current_dag(dag_id=apdr.target_dag_id, session=session)): + if not (dag := serialized_dags.get(apdr.target_dag_id)): self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id) continue - asset_models = session.scalars( - select(AssetModel).where( - exists( - select(1).where( - PartitionedAssetKeyLog.asset_id == AssetModel.id, - PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id, - PartitionedAssetKeyLog.target_partition_key == apdr.partition_key, - ) + source_key_by_asset = source_key_by_asset_per_apdr[apdr.id] + timetable = dag.timetable + statuses: dict[SerializedAssetUniqueKey, bool] = {} + for asset_id, (name, uri) in asset_info_per_apdr[apdr.id].items(): + key = SerializedAssetUniqueKey(name=name, uri=uri) + if timetable.partitioned is True: + if TYPE_CHECKING: + assert isinstance(timetable, PartitionedAssetTimetable) + statuses[key] = self._resolve_asset_partition_status( + session=session, + asset_id=asset_id, + name=name, + uri=uri, + apdr=apdr, + timetable=timetable, + actual_by_asset=source_key_by_asset, ) - ) - ) - statuses: dict[SerializedAssetUniqueKey, bool] = { - SerializedAssetUniqueKey.from_asset(a): True for a in asset_models - } - # todo: AIP-76 so, this basically works when we only require one partition from each asset to be there - # but, we ultimately need rollup ability - # that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure - # that all the required keys are there - # one way to do this would be just to figure out what the count should be - if not evaluator.run(dag.timetable.asset_condition, statuses=statuses): + else: + statuses[key] = True + if not evaluator.run(timetable.asset_condition, statuses=statuses): 
continue partition_dag_ids.add(apdr.target_dag_id) diff --git a/airflow-core/src/airflow/migrations/versions/0115_3_3_0_add_partition_mapper_info_to_dag.py b/airflow-core/src/airflow/migrations/versions/0115_3_3_0_add_partition_mapper_info_to_dag.py new file mode 100644 index 0000000000000..7a451a632a232 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0115_3_3_0_add_partition_mapper_info_to_dag.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add partition_mapper_info to DagModel. + +The new JSON column caches per-asset partition mapper metadata produced +during Dag serialization (one entry per asset, see ``PartitionMapperInfo``) +so UI endpoints can resolve mapper attributes such as ``is_rollup`` without +deserializing the timetable on every request. Defaults to ``[]`` for +timetables without per-asset partition mappers. 
+ +Revision ID: c20871fbf23a +Revises: a7f3b2c1d4e5 +Create Date: 2026-05-06 00:00:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.migrations.utils import disable_sqlite_fkeys + +revision = "c20871fbf23a" +down_revision = "a7f3b2c1d4e5" +branch_labels = None +depends_on = None +airflow_version = "3.3.0" + + +def upgrade(): + """ + Add partition_mapper_info JSON column to dag table. + + The column is added nullable, existing rows are backfilled with ``[]``, + then the column is altered to NOT NULL. MySQL refuses literal defaults on + JSON columns, and the model intentionally carries no ``server_default``, + so we use the same backfill strategy on every backend to avoid leaving a + stray DB-side default that diverges from the ORM definition. The final + ``alter_column`` triggers a table rebuild on SQLite; foreign keys are + disabled around the whole upgrade so dependent tables' references stay + intact regardless of which step ends up rebuilding the table. 
+ """ + with disable_sqlite_fkeys(op): + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.add_column(sa.Column("partition_mapper_info", sa.JSON(), nullable=True)) + op.execute(sa.text("UPDATE dag SET partition_mapper_info = '[]' WHERE partition_mapper_info IS NULL")) + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.alter_column("partition_mapper_info", existing_type=sa.JSON(), nullable=False) + + +def downgrade(): + """Remove partition_mapper_info column from dag table.""" + with disable_sqlite_fkeys(op): + with op.batch_alter_table("dag", schema=None) as batch_op: + batch_op.drop_column("partition_mapper_info") diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py index ca84b7047b435..a50dd10298c1a 100644 --- a/airflow-core/src/airflow/models/dag.py +++ b/airflow-core/src/airflow/models/dag.py @@ -66,7 +66,7 @@ from airflow.serialization.definitions.assets import SerializedAssetUniqueKey from airflow.serialization.encoders import DAT, encode_deadline_alert from airflow.serialization.enums import Encoding -from airflow.timetables.base import DataInterval, Timetable +from airflow.timetables.base import DataInterval, PartitionMapperInfo, Timetable from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable from airflow.timetables.simple import AssetTriggeredTimetable, NullTimetable, OnceTimetable from airflow.utils.session import NEW_SESSION, provide_session @@ -388,6 +388,16 @@ class DagModel(Base): timetable_partitioned: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="0") # Whether the timetable is periodic (supports backfilling). timetable_periodic: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="0") + # Cached partition mapper metadata for partitioned timetables, populated + # during Dag serialization so the UI can resolve mapper attributes without + # deserializing the timetable. 
See ``PartitionMapperInfo`` for the per-asset + # entry shape; empty list for timetables without per-asset partition mappers. + # No ``server_default`` — MySQL refuses literal defaults on JSON columns; + # the migration backfills existing rows and ``default=list`` covers new + # ORM inserts that don't pass an explicit value. + partition_mapper_info: Mapped[list[PartitionMapperInfo]] = mapped_column( + sa.JSON(), nullable=False, default=list + ) # Asset expression based on asset triggers asset_expression: Mapped[dict[str, Any] | None] = mapped_column(sa.JSON(), nullable=True) # DAG deadline information @@ -481,6 +491,28 @@ def __init__(self, **kwargs): def __repr__(self): return f"" + def is_rollup_asset(self, *, name: str, uri: str) -> bool: + """ + Return whether the asset identified by *name*/*uri* uses a rollup mapper. + + Reads the cached ``partition_mapper_info`` populated during Dag + serialization, mirroring ``PartitionedAssetTimetable.get_partition_mapper`` + (name lookup wins over uri lookup). Returns ``False`` when no entry + matches the asset. 
+ """ + for entry in self.partition_mapper_info: + if entry.get("name") == name: + return entry["is_rollup"] + for entry in self.partition_mapper_info: + if entry.get("uri") == uri: + return entry["is_rollup"] + return False + + @property + def has_rollup_mappers(self) -> bool: + """Whether any cached partition mapper is a rollup mapper.""" + return any(entry["is_rollup"] for entry in self.partition_mapper_info) + @property def next_dagrun_data_interval(self) -> DataInterval | None: return _get_model_data_interval( diff --git a/airflow-core/src/airflow/partition_mappers/base.py b/airflow-core/src/airflow/partition_mappers/base.py index 7c64d05625855..eafd60bbeae62 100644 --- a/airflow-core/src/airflow/partition_mappers/base.py +++ b/airflow-core/src/airflow/partition_mappers/base.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: from collections.abc import Iterable + from airflow.partition_mappers.window import Window + class PartitionMapper(ABC): """ @@ -31,13 +33,82 @@ class PartitionMapper(ABC): Maps keys from asset events to target dag run partitions. """ + is_rollup: bool = False + @abstractmethod def to_downstream(self, key: str) -> str | Iterable[str]: """Return the target key that the given source partition key maps to.""" + def decode_downstream(self, downstream_key: str) -> Any: + """ + Recover the canonical decoded form of *downstream_key*. + + Used by :class:`RollupMapper` to hand the window an opaque "anchor" + for the downstream period; the window iterates in this decoded space + and the mapper re-encodes each expected upstream via + :meth:`encode_upstream`. Default is identity (string in, string out) + — temporal mappers override to return ``datetime``, future segment + mappers will return whatever shape suits them. + """ + return downstream_key + + def encode_upstream(self, decoded: Any) -> str: + """ + Encode an expected upstream object back into a key string. + + Pair of :meth:`decode_downstream`. Default is identity. 
Temporal + mappers override to apply timezone + ``input_format``. + """ + return decoded + def serialize(self) -> dict[str, Any]: return {} @classmethod def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: return cls() + + +class RollupMapper(PartitionMapper): + """ + Partition mapper that rolls up many upstream keys into one downstream key. + + Compose a ``source_mapper`` (which normalizes each upstream key to the + downstream granularity) with a ``window`` that declares the full set of + upstream keys required for a given downstream key. The scheduler holds + the Dag run until every upstream key in the window has arrived. + """ + + is_rollup: bool = True + + def __init__(self, *, source_mapper: PartitionMapper, window: Window) -> None: + self.source_mapper = source_mapper + self.window = window + + def to_downstream(self, key: str) -> str | Iterable[str]: + return self.source_mapper.to_downstream(key) + + def to_upstream(self, downstream_key: str) -> frozenset[str]: + """Return the complete set of upstream partition keys required for *downstream_key*.""" + decoded = self.source_mapper.decode_downstream(downstream_key) + return frozenset( + self.source_mapper.encode_upstream(expected_upstream) + for expected_upstream in self.window.to_upstream(decoded) + ) + + def serialize(self) -> dict[str, Any]: + from airflow.serialization.encoders import encode_partition_mapper, encode_window + + return { + "source_mapper": encode_partition_mapper(self.source_mapper), + "window": encode_window(self.window), + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> PartitionMapper: + from airflow.serialization.decoders import decode_partition_mapper, decode_window + + return cls( + source_mapper=decode_partition_mapper(data["source_mapper"]), + window=decode_window(data["window"]), + ) diff --git a/airflow-core/src/airflow/partition_mappers/temporal.py b/airflow-core/src/airflow/partition_mappers/temporal.py index 49f4162a12e2c..1f42b4755302d 100644 --- 
a/airflow-core/src/airflow/partition_mappers/temporal.py +++ b/airflow-core/src/airflow/partition_mappers/temporal.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import re from abc import ABC, abstractmethod from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any @@ -27,6 +28,81 @@ from pendulum import FixedTimezone, Timezone +_STRPTIME_PATTERNS: dict[str, str] = { + "%Y": r"\d{4}", + "%m": r"\d{2}", + "%d": r"\d{2}", + "%H": r"\d{2}", + "%M": r"\d{2}", + "%S": r"\d{2}", + "%V": r"\d{2}", + "%U": r"\d{2}", + "%W": r"\d{2}", +} + + +def _compile_output_format_regex( + fmt: str, placeholder_patterns: dict[str, str] | None = None +) -> re.Pattern[str]: + r""" + Compile *fmt* into a regex with named groups so a formatted key can be parsed back. + + Strftime directives in ``_STRPTIME_PATTERNS`` (``%Y``, ``%m``, ``%V`` …) become + named groups keyed on the directive letter — e.g. ``%Y`` → ``(?P\d{4})``. + FastAPI-style ``{name}`` placeholders become named groups keyed on *name*; the + regex defaults to ``\w+``, with *placeholder_patterns* letting callers narrow + it (e.g. ``{"quarter": r"[1-4]"}``). Anything else is escaped as a literal. + + The regex is anchored with ``.search`` rather than ``.fullmatch`` so callers + can extract fields without caring about literal prefix/suffix length. + + Raises :exc:`ValueError` if *fmt* cannot produce a valid regex — empty + ``{}`` placeholders, names that aren't valid Python identifiers, or any + duplicate group name (a directive or placeholder appearing twice). Catching + this at construction time keeps a misconfigured mapper from surfacing as an + opaque ``re.error`` deep inside the scheduler tick or a UI route. 
+ """ + placeholder_patterns = placeholder_patterns or {} + parts: list[str] = [] + seen_groups: set[str] = set() + i = 0 + while i < len(fmt): + if fmt[i] == "%" and i + 1 < len(fmt) and fmt[i : i + 2] in _STRPTIME_PATTERNS: + directive = fmt[i : i + 2] + name = directive[1] + if name in seen_groups: + raise ValueError( + f"output_format {fmt!r} reuses directive {directive!r}; " + "each strftime directive may appear at most once." + ) + seen_groups.add(name) + parts.append(f"(?P<{name}>{_STRPTIME_PATTERNS[directive]})") + i += 2 + continue + if fmt[i] == "{": + end = fmt.find("}", i) + if end != -1: + name = fmt[i + 1 : end] + if not name.isidentifier(): + raise ValueError( + f"output_format {fmt!r} has invalid placeholder {{{name}}}; " + "placeholder names must be valid Python identifiers." + ) + if name in seen_groups: + raise ValueError( + f"output_format {fmt!r} reuses placeholder {{{name}}}; " + "each placeholder may appear at most once." + ) + seen_groups.add(name) + pattern = placeholder_patterns.get(name, r"\w+") + parts.append(f"(?P<{name}>{pattern})") + i = end + 1 + continue + parts.append(re.escape(fmt[i])) + i += 1 + return re.compile("".join(parts)) + + class _BaseTemporalMapper(PartitionMapper, ABC): """Base class for Temporal Partition Mappers.""" @@ -62,6 +138,29 @@ def format(self, dt: datetime) -> str: """Format the normalized datetime.""" return dt.strftime(self.output_format) + def decode_downstream(self, downstream_key: str) -> datetime: + """ + Recover the period-start datetime from a previously formatted downstream key. + + Inverse of ``format``. The default implementation uses ``strptime`` with + ``output_format``, which works for any format made of standard strptime + directives. Subclasses with custom format markers (e.g. ``{quarter}``) or + ambiguous directives (e.g. bare ``%V``) override this. 
+ """ + return datetime.strptime(downstream_key, self.output_format) + + def encode_upstream(self, dt: datetime) -> str: + """ + Format *dt* as an upstream partition key string. + + Pair of :meth:`decode_downstream`: takes a (decoded) period-start + datetime and produces a key string in the upstream's ``input_format`` + with ``timezone`` applied. Used by :class:`RollupMapper` to render each + upstream member yielded by the window back into the form upstream + producers actually emit. + """ + return make_aware(dt, self._timezone).strftime(self.input_format) + def serialize(self) -> dict[str, Any]: from airflow.serialization.encoders import encode_timezone @@ -99,17 +198,41 @@ def normalize(self, dt: datetime) -> datetime: class StartOfWeekMapper(_BaseTemporalMapper): - """Map a time-based partition key to week.""" + """Map a time-based partition key to the start of its week.""" default_output_format = "%Y-%m-%d (W%V)" + def __init__( + self, + *, + timezone: str | Timezone | FixedTimezone = "UTC", + input_format: str = "%Y-%m-%dT%H:%M:%S", + output_format: str | None = None, + ) -> None: + super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) + # %V (ISO week) cannot be round-tripped through strptime without %G+%u, + # so derive a named-group regex from output_format and pull out %Y/%m/%d. + # Compile eagerly so a malformed output_format raises ValueError here + # instead of an opaque re.error inside the scheduler. 
+ self._key_pattern = _compile_output_format_regex(self.output_format) + def normalize(self, dt: datetime) -> datetime: start = dt - timedelta(days=dt.weekday()) return start.replace(hour=0, minute=0, second=0, microsecond=0) + def decode_downstream(self, downstream_key: str) -> datetime: + match = self._key_pattern.search(downstream_key) + if match is None or not {"Y", "m", "d"}.issubset(match.groupdict()): + raise ValueError( + f"StartOfWeekMapper.decode_downstream could not parse {downstream_key!r} " + f"with output_format {self.output_format!r}; " + "output_format must include the %Y, %m and %d directives." + ) + return datetime(int(match["Y"]), int(match["m"]), int(match["d"])) + class StartOfMonthMapper(_BaseTemporalMapper): - """Map a time-based partition key to month.""" + """Map a time-based partition key to the start of its month.""" default_output_format = "%Y-%m" @@ -128,6 +251,20 @@ class StartOfQuarterMapper(_BaseTemporalMapper): default_output_format = "%Y-Q{quarter}" + def __init__( + self, + *, + timezone: str | Timezone | FixedTimezone = "UTC", + input_format: str = "%Y-%m-%dT%H:%M:%S", + output_format: str | None = None, + ) -> None: + super().__init__(timezone=timezone, input_format=input_format, output_format=output_format) + # ``{quarter}`` is a Python-format placeholder, not a strftime directive, + # so derive a named-group regex from output_format that handles both. + # Compile eagerly so a malformed output_format raises ValueError here + # instead of an opaque re.error inside the scheduler. 
+ self._key_pattern = _compile_output_format_regex(self.output_format, {"quarter": r"[1-4]"}) + def normalize(self, dt: datetime) -> datetime: quarter = (dt.month - 1) // 3 month = quarter * 3 + 1 @@ -144,6 +281,18 @@ def format(self, dt: datetime) -> str: quarter = (dt.month - 1) // 3 + 1 return dt.strftime(self.output_format).format(quarter=quarter) + def decode_downstream(self, downstream_key: str) -> datetime: + match = self._key_pattern.search(downstream_key) + if match is None or not {"Y", "quarter"}.issubset(match.groupdict()): + raise ValueError( + f"StartOfQuarterMapper.decode_downstream could not parse {downstream_key!r} " + f"with output_format {self.output_format!r}; " + "output_format must include the %Y directive and the {quarter} placeholder." + ) + year = int(match["Y"]) + quarter = int(match["quarter"]) + return datetime(year, (quarter - 1) * 3 + 1, 1) + class StartOfYearMapper(_BaseTemporalMapper): """Map a time-based partition key to year.""" diff --git a/airflow-core/src/airflow/partition_mappers/window.py b/airflow-core/src/airflow/partition_mappers/window.py new file mode 100644 index 0000000000000..f47bcb386b48a --- /dev/null +++ b/airflow-core/src/airflow/partition_mappers/window.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Iterable + + +def _shift_months(dt: datetime, months: int) -> datetime: + """ + Return *dt* shifted forward by *months*, wrapping the year as needed. + + Built-in temporal mappers always emit period-starts on day 1, so + :meth:`datetime.replace` on the new month is always valid. A + user-provided source mapper that emits a higher day for a month with + fewer days remains the caller's responsibility. + """ + total = dt.month - 1 + months + return dt.replace(year=dt.year + total // 12, month=total % 12 + 1) + + +class Window(ABC): + """ + Describes a rollup window: which decoded upstream items make up one decoded downstream period. + + Paired with a source mapper inside a :class:`RollupMapper`. The window + operates purely in the source mapper's *decoded* form (``datetime`` for + temporal mappers, domain-specific types for future segment / runtime + mappers). It does not touch key strings, timezones, or formats — those + belong to the source mapper. ``RollupMapper`` orchestrates the three: + decode the downstream key, expand via the window, encode each upstream. 
+ """ + + @abstractmethod + def to_upstream(self, decoded_downstream: Any) -> Iterable[Any]: + """Yield each decoded upstream item composing *decoded_downstream*.""" + + def serialize(self) -> dict[str, Any]: + return {} + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> Window: + return cls() + + +class HourWindow(Window): + """Sixty consecutive minute period-starts making up one hour.""" + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (period_start + timedelta(minutes=i) for i in range(60)) + + +class DayWindow(Window): + """ + Twenty-four consecutive hourly period-starts making up one day. + + Arithmetic is done on naive datetime steps so the 24-hour stride is + unambiguous across DST transitions; the source mapper handles timezone + awareness when it encodes each upstream member back to a key string. + """ + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (period_start + timedelta(hours=i) for i in range(24)) + + +class WeekWindow(Window): + """Seven consecutive daily period-starts making up one week.""" + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (period_start + timedelta(days=i) for i in range(7)) + + +class MonthWindow(Window): + """ + All daily period-starts making up one calendar month. + + Iterates from the period start to the matching day of the next month, + so a user-provided source mapper that emits non-1st period-starts + (e.g. fiscal months) is handled transparently. + """ + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + next_month = period_start.month % 12 + 1 + next_year = period_start.year + (1 if period_start.month == 12 else 0) + next_start = period_start.replace(year=next_year, month=next_month) + days = (next_start - period_start).days + return (period_start + timedelta(days=i) for i in range(days)) + + +class QuarterWindow(Window): + """Three monthly period-starts making up one calendar quarter (e.g. 
Jan/Feb/Mar for Q1).""" + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (_shift_months(period_start, i) for i in range(3)) + + +class YearWindow(Window): + """Twelve consecutive monthly period-starts making up one calendar year.""" + + def to_upstream(self, period_start: datetime) -> Iterable[datetime]: + return (_shift_months(period_start, i) for i in range(12)) diff --git a/airflow-core/src/airflow/serialization/decoders.py b/airflow-core/src/airflow/serialization/decoders.py index 683efdc87e6ec..577cdc703e3cc 100644 --- a/airflow-core/src/airflow/serialization/decoders.py +++ b/airflow-core/src/airflow/serialization/decoders.py @@ -49,6 +49,7 @@ if TYPE_CHECKING: from airflow.partition_mappers.base import PartitionMapper + from airflow.partition_mappers.window import Window from airflow.timetables.base import Timetable as CoreTimetable R = TypeVar("R") @@ -202,3 +203,13 @@ def decode_partition_mapper(var: dict[str, Any]) -> PartitionMapper: else: partition_mapper_cls = find_registered_custom_partition_mapper(importable_string) return partition_mapper_cls.deserialize(var[Encoding.VAR]) + + +def decode_window(var: dict[str, Any]) -> Window: + """ + Decode a previously serialized :class:`Window`. 
+ + :meta private: + """ + window_cls: type[Window] = import_string(var[Encoding.TYPE]) + return window_cls.deserialize(var[Encoding.VAR]) diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index 9e341cbe78340..17c80727762ba 100644 --- a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -27,6 +27,7 @@ from airflow._shared.module_loading import qualname from airflow.partition_mappers.base import PartitionMapper as CorePartitionMapper +from airflow.partition_mappers.window import Window as CoreWindow from airflow.sdk import ( AllowedKeyMapper, Asset, @@ -37,18 +38,26 @@ ChainMapper, CronDataIntervalTimetable, CronTriggerTimetable, + DayWindow, DeltaDataIntervalTimetable, DeltaTriggerTimetable, EventsTimetable, + HourWindow, IdentityMapper, + MonthWindow, MultipleCronTriggerTimetable, PartitionMapper, ProductMapper, + QuarterWindow, + RollupMapper, StartOfDayMapper, StartOfMonthMapper, StartOfQuarterMapper, StartOfWeekMapper, StartOfYearMapper, + WeekWindow, + Window, + YearWindow, ) from airflow.sdk.bases.timetable import BaseTimetable from airflow.sdk.definitions.asset import AssetRef @@ -408,6 +417,7 @@ def _(self, timetable: PartitionedAssetTimetable) -> dict[str, Any]: ChainMapper: "airflow.partition_mappers.chain.ChainMapper", IdentityMapper: "airflow.partition_mappers.identity.IdentityMapper", ProductMapper: "airflow.partition_mappers.product.ProductMapper", + RollupMapper: "airflow.partition_mappers.base.RollupMapper", StartOfDayMapper: "airflow.partition_mappers.temporal.StartOfDayMapper", StartOfHourMapper: "airflow.partition_mappers.temporal.StartOfHourMapper", StartOfMonthMapper: "airflow.partition_mappers.temporal.StartOfMonthMapper", @@ -463,6 +473,40 @@ def _(self, partition_mapper: ProductMapper) -> dict[str, Any]: def _(self, partition_mapper: AllowedKeyMapper) -> dict[str, Any]: return {"allowed_keys": 
partition_mapper.allowed_keys} + @serialize_partition_mapper.register + def _(self, partition_mapper: RollupMapper) -> dict[str, Any]: + return { + "source_mapper": encode_partition_mapper(partition_mapper.source_mapper), + "window": encode_window(partition_mapper.window), + } + + BUILTIN_WINDOWS: dict[type, str] = { + HourWindow: "airflow.partition_mappers.window.HourWindow", + DayWindow: "airflow.partition_mappers.window.DayWindow", + WeekWindow: "airflow.partition_mappers.window.WeekWindow", + MonthWindow: "airflow.partition_mappers.window.MonthWindow", + QuarterWindow: "airflow.partition_mappers.window.QuarterWindow", + YearWindow: "airflow.partition_mappers.window.YearWindow", + } + + @functools.singledispatchmethod + def serialize_window(self, window: Window | CoreWindow) -> dict[str, Any]: + if not isinstance(window, CoreWindow): + raise NotImplementedError(f"can not serialize window {type(window).__name__}") + return window.serialize() + + @serialize_window.register(HourWindow) + @serialize_window.register(DayWindow) + @serialize_window.register(WeekWindow) + @serialize_window.register(MonthWindow) + @serialize_window.register(QuarterWindow) + @serialize_window.register(YearWindow) + def _( + self, + window: HourWindow | DayWindow | WeekWindow | MonthWindow | QuarterWindow | YearWindow, + ) -> dict[str, Any]: + return {} + _serializer = _Serializer() @@ -552,3 +596,25 @@ def encode_partition_mapper(var: PartitionMapper | CorePartitionMapper) -> dict[ Encoding.TYPE: qn, Encoding.VAR: _serializer.serialize_partition_mapper(var), } + + +def encode_window(var: Window | CoreWindow) -> dict[str, Any]: + """ + Encode a :class:`Window` instance. 
+ + :meta private: + """ + var_type = type(var) + importable_string = _serializer.BUILTIN_WINDOWS.get(var_type) + if importable_string is not None: + return { + Encoding.TYPE: importable_string, + Encoding.VAR: _serializer.serialize_window(var), + } + + # Custom Window subclasses must live at an importable path so the + # scheduler can reconstruct them via import_string during deserialization. + return { + Encoding.TYPE: qualname(var), + Encoding.VAR: _serializer.serialize_window(var), + } diff --git a/airflow-core/src/airflow/timetables/base.py b/airflow-core/src/airflow/timetables/base.py index a92bcd7a1f85d..d9ed4f574af4b 100644 --- a/airflow-core/src/airflow/timetables/base.py +++ b/airflow-core/src/airflow/timetables/base.py @@ -16,7 +16,9 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict, runtime_checkable + +from typing_extensions import NotRequired from airflow._shared.module_loading import qualname from airflow._shared.timezones import timezone @@ -39,6 +41,21 @@ from airflow.utils.types import DagRunType +class PartitionMapperInfo(TypedDict): + """ + JSON-serializable snapshot of one asset's partition mapper attributes. + + Stored as ``DagModel.partition_mapper_info`` (a list of these) so the UI can + resolve mapper attributes without deserializing the timetable on each request. + Either ``name``, ``uri``, or both identify the asset; ``Asset.ref(name=...)`` + omits ``uri`` and ``Asset.ref(uri=...)`` omits ``name``. + """ + + is_rollup: bool + name: NotRequired[str] + uri: NotRequired[str] + + class DataInterval(NamedTuple): """ A data interval for a DagRun to operate over. @@ -218,6 +235,18 @@ class Timetable(Protocol): instead of the traditional logic based on logical dates and data intervals. 
""" + @property + def partition_mapper_info(self) -> list[PartitionMapperInfo]: + """ + JSON-serializable per-asset partition mapper attributes. + + Empty list for timetables without asset-level partition mappers (the + default, including non-partitioned timetables and cron-driven partitioned + timetables). Asset-driven partitioned timetables override this with one + entry per asset (or asset ref) — see :class:`PartitionMapperInfo`. + """ + return [] + @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: """ diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py index 01fb12f81dd0c..be47f56957fcb 100644 --- a/airflow-core/src/airflow/timetables/simple.py +++ b/airflow-core/src/airflow/timetables/simple.py @@ -32,7 +32,7 @@ SerializedAssetUriRef, ) from airflow.serialization.encoders import encode_asset_like, encode_partition_mapper -from airflow.timetables.base import DagRunInfo, DataInterval, Timetable +from airflow.timetables.base import DagRunInfo, DataInterval, PartitionMapperInfo, Timetable try: from airflow.sdk.definitions.asset import BaseAsset @@ -299,6 +299,30 @@ def get_partition_mapper(self, *, name: str = "", uri: str = "") -> PartitionMap return self.default_partition_mapper + @property + def partition_mapper_info(self) -> list[PartitionMapperInfo]: + """ + JSON-serializable snapshot of partition mapper attributes per asset. + + One :class:`~airflow.timetables.base.PartitionMapperInfo` entry per asset + (or asset ref) in ``partition_mapper_config``. The UI reads this from + the cached ``DagModel.partition_mapper_info`` instead of deserializing + the timetable on each request. 
+ """ + entries: list[PartitionMapperInfo] = [] + for base_asset, partition_mapper in self.partition_mapper_config.items(): + is_rollup = partition_mapper.is_rollup + for unique_key, _ in base_asset.iter_assets(): + entries.append( + PartitionMapperInfo(name=unique_key.name, uri=unique_key.uri, is_rollup=is_rollup) + ) + for s_asset_ref in base_asset.iter_asset_refs(): + if isinstance(s_asset_ref, SerializedAssetNameRef): + entries.append(PartitionMapperInfo(name=s_asset_ref.name, is_rollup=is_rollup)) + elif isinstance(s_asset_ref, SerializedAssetUriRef): + entries.append(PartitionMapperInfo(uri=s_asset_ref.uri, is_rollup=is_rollup)) + return entries + def serialize(self) -> dict[str, Any]: from airflow.serialization.serialized_objects import encode_asset_like diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts index 5e0595ec8f6b4..91d5df8edeaff 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts @@ -162,7 +162,7 @@ export const ensureUseAssetServiceGetDagAssetQueuedEventData = (queryClient: Que * Next Run Assets * @param data The data for the request. * @param data.dagId -* @returns unknown Successful Response +* @returns NextRunAssetsResponse Successful Response * @throws ApiError */ export const ensureUseAssetServiceNextRunAssetsData = (queryClient: QueryClient, { dagId }: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts index 2a17cdfefc00a..4dd3b0067c57c 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts @@ -162,7 +162,7 @@ export const prefetchUseAssetServiceGetDagAssetQueuedEvent = (queryClient: Query * Next Run Assets * @param data The data for the request. 
* @param data.dagId -* @returns unknown Successful Response +* @returns NextRunAssetsResponse Successful Response * @throws ApiError */ export const prefetchUseAssetServiceNextRunAssets = (queryClient: QueryClient, { dagId }: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index 52e4776d510a5..1e8c86df6f095 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -162,7 +162,7 @@ export const useAssetServiceGetDagAssetQueuedEvent = = unknown[]>({ dagId }: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts index 9d42191b88606..d67f99e2669a4 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts @@ -162,7 +162,7 @@ export const useAssetServiceGetDagAssetQueuedEventSuspense = = unknown[]>({ dagId }: { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index bc4c7179c2dbb..61589be2ed320 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -8763,6 +8763,114 @@ export const $MenuItemCollectionResponse = { description: 'Menu Item Collection serializer for responses.' 
} as const; +export const $NextRunAssetEventResponse = { + properties: { + id: { + type: 'integer', + title: 'Id' + }, + name: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Name' + }, + uri: { + type: 'string', + title: 'Uri' + }, + last_update: { + anyOf: [ + { + type: 'string', + format: 'date-time' + }, + { + type: 'null' + } + ], + title: 'Last Update' + }, + received_count: { + type: 'integer', + title: 'Received Count', + default: 0 + }, + required_count: { + type: 'integer', + title: 'Required Count', + default: 1 + }, + received_keys: { + items: { + type: 'string' + }, + type: 'array', + title: 'Received Keys' + }, + required_keys: { + items: { + type: 'string' + }, + type: 'array', + title: 'Required Keys' + }, + is_rollup: { + type: 'boolean', + title: 'Is Rollup', + default: false + } + }, + type: 'object', + required: ['id', 'name', 'uri'], + title: 'NextRunAssetEventResponse', + description: 'One asset event in the ``next_run_assets`` payload.' +} as const; + +export const $NextRunAssetsResponse = { + properties: { + asset_expression: { + anyOf: [ + { + additionalProperties: true, + type: 'object' + }, + { + type: 'null' + } + ], + title: 'Asset Expression' + }, + events: { + items: { + '$ref': '#/components/schemas/NextRunAssetEventResponse' + }, + type: 'array', + title: 'Events' + }, + pending_partition_count: { + anyOf: [ + { + type: 'integer' + }, + { + type: 'null' + } + ], + title: 'Pending Partition Count' + } + }, + type: 'object', + required: ['events'], + title: 'NextRunAssetsResponse', + description: 'Response for the ``next_run_assets`` endpoint.' 
+} as const; + export const $NodeResponse = { properties: { id: { @@ -8877,10 +8985,37 @@ export const $PartitionedDagRunAssetResponse = { received: { type: 'boolean', title: 'Received' + }, + received_count: { + type: 'integer', + title: 'Received Count' + }, + required_count: { + type: 'integer', + title: 'Required Count' + }, + received_keys: { + items: { + type: 'string' + }, + type: 'array', + title: 'Received Keys' + }, + required_keys: { + items: { + type: 'string' + }, + type: 'array', + title: 'Required Keys' + }, + is_rollup: { + type: 'boolean', + title: 'Is Rollup', + default: false } }, type: 'object', - required: ['asset_id', 'asset_name', 'asset_uri', 'received'], + required: ['asset_id', 'asset_name', 'asset_uri', 'received', 'received_count', 'required_count', 'received_keys', 'required_keys'], title: 'PartitionedDagRunAssetResponse', description: 'Asset info within a partitioned Dag run detail.' } as const; diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index 83c5d5c17b78d..b9f7e7ce77a08 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, 
DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, 
GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, 
GetProvidersData, GetProvidersResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; +import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, 
GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse2, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, 
GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, 
PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; export class AssetService { /** @@ -411,10 +411,10 @@ export class AssetService { * Next Run Assets * @param data The data for the request. 
* @param data.dagId - * @returns unknown Successful Response + * @returns NextRunAssetsResponse Successful Response * @throws ApiError */ - public static nextRunAssets(data: NextRunAssetsData): CancelablePromise { + public static nextRunAssets(data: NextRunAssetsData): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/ui/next_run_assets/{dag_id}', diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 5b216da03a58b..8acc694124bea 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2186,6 +2186,32 @@ export type MenuItemCollectionResponse = { extra_menu_items: Array; }; +/** + * One asset event in the ``next_run_assets`` payload. + */ +export type NextRunAssetEventResponse = { + id: number; + name: string | null; + uri: string; + last_update?: string | null; + received_count?: number; + required_count?: number; + received_keys?: Array<(string)>; + required_keys?: Array<(string)>; + is_rollup?: boolean; +}; + +/** + * Response for the ``next_run_assets`` endpoint. + */ +export type NextRunAssetsResponse = { + asset_expression?: { + [key: string]: unknown; +} | null; + events: Array; + pending_partition_count?: number | null; +}; + /** * Node serializer for responses. 
*/ @@ -2211,6 +2237,11 @@ export type PartitionedDagRunAssetResponse = { asset_name: string; asset_uri: string; received: boolean; + received_count: number; + required_count: number; + received_keys: Array<(string)>; + required_keys: Array<(string)>; + is_rollup?: boolean; }; /** @@ -2522,9 +2553,7 @@ export type NextRunAssetsData = { dagId: string; }; -export type NextRunAssetsResponse = { - [key: string]: unknown; -}; +export type NextRunAssetsResponse2 = NextRunAssetsResponse; export type ListBackfillsData = { dagId: string; @@ -4487,9 +4516,7 @@ export type $OpenApiTs = { /** * Successful Response */ - 200: { - [key: string]: unknown; - }; + 200: NextRunAssetsResponse; /** * Validation Error */ diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx index 9b2a8e3b4f2b8..c8c0382fd7326 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetExpression.tsx @@ -21,16 +21,18 @@ import { Fragment } from "react"; import { useTranslation } from "react-i18next"; import { TbLogicOr } from "react-icons/tb"; +import type { NextRunAssetEventResponse } from "openapi/requests/types.gen"; + import { AndGateNode } from "./AndGateNode"; import { AssetNode } from "./AssetNode"; import { OrGateNode } from "./OrGateNode"; -import type { ExpressionType, NextRunEvent } from "./types"; +import type { ExpressionType } from "./types"; export const AssetExpression = ({ events, expression, }: { - readonly events?: Array; + readonly events?: Array; readonly expression: ExpressionType | undefined; }) => { const { t: translate } = useTranslation("common"); @@ -39,6 +41,14 @@ export const AssetExpression = ({ return undefined; } + // A bare asset or alias at the top level (no all/any wrapper) — render it directly. 
+ if ("asset" in expression) { + return ev.id === expression.asset.id)} />; + } + if ("alias" in expression) { + return ; + } + return ( <> {"any" in expression ? ( diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx index 6289309ec59f4..84dfb9a459389 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/AssetNode.tsx @@ -16,46 +16,131 @@ * specific language governing permissions and limitations * under the License. */ -import { Box, Text, HStack, Link } from "@chakra-ui/react"; -import { FiDatabase } from "react-icons/fi"; +import { Box, Button, HStack, Link, Text, VStack } from "@chakra-ui/react"; +import { FiCheck, FiDatabase, FiMinus } from "react-icons/fi"; import { PiRectangleDashed } from "react-icons/pi"; import { Link as RouterLink } from "react-router-dom"; +import type { NextRunAssetEventResponse } from "openapi/requests/types.gen"; +import { Popover } from "src/components/ui"; + import Time from "../Time"; -import type { AssetSummary, NextRunEvent } from "./types"; +import type { AssetSummary } from "./types"; + +const RollupKeyChecklistPopover = ({ + receivedCount, + receivedKeys, + requiredCount, + requiredKeys, +}: { + readonly receivedCount: number; + readonly receivedKeys: Array; + readonly requiredCount: number; + readonly requiredKeys: Array; +}) => { + const receivedKeySet = new Set(receivedKeys); + + return ( + // eslint-disable-next-line jsx-a11y/no-autofocus + + + + + + + + + {requiredKeys.map((key) => { + const isReceived = receivedKeySet.has(key); + + return ( + + {isReceived ? ( + + ) : ( + + )} + + {key} + + + ); + })} + + + + + ); +}; export const AssetNode = ({ asset, event, }: { readonly asset: AssetSummary; - readonly event?: NextRunEvent; -}) => ( - - - {"asset" in asset ? : } - {"alias" in asset ? 
( - {asset.alias.name} - ) : ( - - {asset.asset.name} - - )} - - {event?.lastUpdate === undefined ? undefined : ( - - - )} - -); + readonly event?: NextRunAssetEventResponse; +}) => { + const isFullyReceived = Boolean(event?.last_update); + const isPartial = + !isFullyReceived && + (event?.received_count ?? 0) > 0 && + (event?.received_count ?? 0) < (event?.required_count ?? 1); + // In a partitioned Dag with a pending partition, `last_update` is the last + // asset event, not the pending partition key's arrival — so suppress it for + // non-rollup partitioned nodes. We detect that case via `required_keys`, + // which only the partitioned + pending-APDR branch of `next_run_assets` + // populates. Plain non-partitioned events leave `required_keys` empty and + // must keep showing their timestamp. + const isPartitionedNonRollup = event?.is_rollup === false && (event.required_keys?.length ?? 0) > 0; + const showTime = isFullyReceived && !isPartitionedNonRollup; + const showRollupChecklist = + (event?.is_rollup ?? false) && (event?.required_keys?.length ?? 0) > 0 && (isPartial || isFullyReceived); + + return ( + + + {"asset" in asset ? : } + {"alias" in asset ? ( + {asset.alias.name} + ) : ( + + {asset.asset.name} + + )} + + {showRollupChecklist ? ( + + ) : showTime ? ( + + + ) : isPartial ? 
( + + {event?.received_count} / {event?.required_count} + + ) : undefined} + + ); +}; diff --git a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts index 7328dcfcc0eec..cae9bfa524f04 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts +++ b/airflow-core/src/airflow/ui/src/components/AssetExpression/types.ts @@ -33,8 +33,6 @@ type Alias = { }; }; -export type NextRunEvent = { id: number; lastUpdate: string | null; name: string | null; uri: string }; - export type AssetSummary = Alias | Asset; export type ExpressionType = diff --git a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx index 7d305a03fb765..50676016e2869 100644 --- a/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx +++ b/airflow-core/src/airflow/ui/src/components/AssetProgressCell.tsx @@ -16,13 +16,13 @@ * specific language governing permissions and limitations * under the License. 
*/ -import { Button } from "@chakra-ui/react"; -import { FiDatabase } from "react-icons/fi"; +import { Button, HStack, Link, Text, VStack } from "@chakra-ui/react"; +import { FiCheck, FiDatabase, FiMinus } from "react-icons/fi"; +import { Link as RouterLink } from "react-router-dom"; import { usePartitionedDagRunServiceGetPendingPartitionedDagRun } from "openapi/queries"; -import type { PartitionedDagRunAssetResponse } from "openapi/requests/types.gen"; +import type { NextRunAssetEventResponse, PartitionedDagRunAssetResponse } from "openapi/requests/types.gen"; import { AssetExpression, type ExpressionType } from "src/components/AssetExpression"; -import type { NextRunEvent } from "src/components/AssetExpression/types"; import { Popover } from "src/components/ui"; type Props = { @@ -38,12 +38,17 @@ export const AssetProgressCell = ({ dagId, partitionKey, totalReceived, totalReq const assetExpression = data?.asset_expression as ExpressionType | undefined; const assets: Array = data?.assets ?? []; - const events: Array = assets - .filter((ak: PartitionedDagRunAssetResponse) => ak.received) + const hasRollup = assets.some((ak) => ak.is_rollup); + + const events: Array = assets + .filter((ak: PartitionedDagRunAssetResponse) => ak.received_count > 0) .map((ak: PartitionedDagRunAssetResponse) => ({ id: ak.asset_id, - lastUpdate: "received", + is_rollup: ak.is_rollup, + last_update: ak.received ? "received" : null, name: ak.asset_name, + received_count: ak.received_count, + required_count: ak.required_count, uri: ak.asset_uri, })); @@ -59,7 +64,41 @@ export const AssetProgressCell = ({ dagId, partitionKey, totalReceived, totalReq - + {hasRollup ? ( + + {assets + .filter((ak) => ak.is_rollup) + .map((ak) => { + const receivedKeySet = new Set(ak.received_keys); + + return ( + + + {ak.asset_name} + + {ak.required_keys.map((key) => { + const isReceived = receivedKeySet.has(key); + + return ( + + {isReceived ? 
( + + ) : ( + + )} + + {key} + + + ); + })} + + ); + })} + + ) : ( + + )} diff --git a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx index 78536cc1c1a46..d2477e11caae7 100644 --- a/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx +++ b/airflow-core/src/airflow/ui/src/pages/DagsList/AssetSchedule.tsx @@ -16,16 +16,16 @@ * specific language governing permissions and limitations * under the License. */ -import { Button, HStack, Link, Text } from "@chakra-ui/react"; +import { Button, HStack, Link, Text, VStack } from "@chakra-ui/react"; import dayjs from "dayjs"; import { useState } from "react"; import { useTranslation } from "react-i18next"; -import { FiDatabase } from "react-icons/fi"; +import { FiCheck, FiDatabase, FiMinus } from "react-icons/fi"; import { Link as RouterLink } from "react-router-dom"; import { useAssetServiceGetDagAssetQueuedEvents, useAssetServiceNextRunAssets } from "openapi/queries"; +import type { NextRunAssetEventResponse } from "openapi/requests/types.gen"; import { AssetExpression, type ExpressionType } from "src/components/AssetExpression"; -import type { NextRunEvent } from "src/components/AssetExpression/types"; import { TruncatedText } from "src/components/TruncatedText"; import { Popover } from "src/components/ui"; @@ -69,7 +69,7 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti { enabled: !timetablePartitioned }, ); - const nextRunEvents = (nextRun?.events ?? []) as Array; + const nextRunEvents: Array = nextRun?.events ?? []; const queuedAssetEvents = new Map(); if (!timetablePartitioned) { @@ -83,15 +83,35 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti } } + // Fully satisfied assets (used for the button count label). const pendingEvents = nextRunEvents.flatMap((event) => { if (timetablePartitioned) { - return event.lastUpdate === null ? 
[] : [event]; + return event.last_update === null ? [] : [event]; } const queuedAt = queuedAssetEvents.get(event.id); - return queuedAt === undefined ? [] : [{ ...event, lastUpdate: event.lastUpdate ?? queuedAt }]; + return queuedAt === undefined ? [] : [{ ...event, last_update: event.last_update ?? queuedAt }]; }); + + // For partitioned Dags, also include partially-received assets in the popover visualization. + const popoverEvents = timetablePartitioned + ? nextRunEvents.filter((event) => (event.received_count ?? (event.last_update === null ? 0 : 1)) > 0) + : pendingEvents; + + // For partitioned Dags (which may use rollup mappers), compute event-level totals so the + // button label reflects received/required partition-key events, not just asset counts. + // For non-partitioned Dags, fall back to asset counts (existing behaviour). + const scheduledCount = timetablePartitioned + ? nextRunEvents.reduce( + (sum, event) => sum + Math.min(event.received_count ?? 0, event.required_count ?? 1), + 0, + ) + : pendingEvents.length; + const scheduledTotal = timetablePartitioned + ? nextRunEvents.reduce((sum, event) => sum + (event.required_count ?? 1), 0) + : nextRunEvents.length; + const isLoading = isNextRunLoading || (!timetablePartitioned && isQueuedEventsLoading); if (!nextRunEvents.length) { @@ -104,7 +124,7 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti } if (timetablePartitioned) { - const pendingCount = (nextRun?.pending_partition_count as number | undefined) ?? 0; + const pendingCount = nextRun?.pending_partition_count ?? 0; if (pendingCount === 0) { return ( @@ -125,6 +145,62 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti const [asset] = nextRunEvents; if (nextRunEvents.length === 1 && asset !== undefined) { + const requiredCount = asset.required_count ?? 1; + const receivedCount = asset.received_count ?? 0; + const requiredKeys = asset.required_keys ?? 
[]; + + if (asset.is_rollup && requiredKeys.length > 0) { + const receivedKeySet = new Set(asset.received_keys ?? []); + + return ( + + + + + + + + {/* eslint-disable-next-line jsx-a11y/no-autofocus */} + + + + + + + + + {requiredKeys.map((key) => { + const isReceived = receivedKeySet.has(key); + + return ( + + {isReceived ? ( + + ) : ( + + )} + + {key} + + + ); + })} + + + + + + ); + } + return ( @@ -143,15 +219,15 @@ export const AssetSchedule = ({ assetExpression, dagId, timetablePartitioned, ti diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 4caa9901bfb95..d863e8433d05f 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -116,7 +116,7 @@ class MappedClassProtocol(Protocol): "3.1.0": "cc92b33c6709", "3.1.8": "509b94a1042d", "3.2.0": "1d6611b6ab7c", - "3.3.0": "a7f3b2c1d4e5", + "3.3.0": "c20871fbf23a", } # Prefix used to identify tables holding data moved during migration. diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py index eaf3e5031fc48..431d2271f4ae6 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py @@ -77,8 +77,19 @@ def test_should_response_200(self, test_client, dag_maker): ] }, "events": [ - {"id": mock.ANY, "uri": "s3://bucket/next-run-asset/1", "name": "asset1", "lastUpdate": None} + { + "id": mock.ANY, + "uri": "s3://bucket/next-run-asset/1", + "name": "asset1", + "last_update": None, + "received_count": 0, + "required_count": 1, + "received_keys": [], + "required_keys": [], + "is_rollup": False, + } ], + "pending_partition_count": None, } def test_should_respond_401(self, unauthenticated_test_client): @@ -141,9 +152,30 @@ def test_should_set_last_update_only_for_queued_and_hide_flag(self, test_client, }, # events are ordered by uri "events": [ - 
{"id": mock.ANY, "uri": "s3://bucket/A", "name": "A", "lastUpdate": mock.ANY}, - {"id": mock.ANY, "uri": "s3://bucket/B", "name": "B", "lastUpdate": None}, + { + "id": mock.ANY, + "uri": "s3://bucket/A", + "name": "A", + "last_update": mock.ANY, + "received_count": 0, + "required_count": 1, + "received_keys": [], + "required_keys": [], + "is_rollup": False, + }, + { + "id": mock.ANY, + "uri": "s3://bucket/B", + "name": "B", + "last_update": None, + "received_count": 0, + "required_count": 1, + "received_keys": [], + "required_keys": [], + "is_rollup": False, + }, ], + "pending_partition_count": None, } def test_last_update_respects_latest_run_filter(self, test_client, dag_maker, session): @@ -169,7 +201,7 @@ def test_last_update_respects_latest_run_filter(self, test_client, dag_maker, se resp = test_client.get("/next_run_assets/filter_run") assert resp.status_code == 200 ev = resp.json()["events"][0] - assert ev["lastUpdate"] is not None + assert ev["last_update"] is not None assert "queued" not in ev @pytest.mark.parametrize( @@ -221,4 +253,4 @@ def test_partitioned_dag_last_update( resp = test_client.get("/next_run_assets/part_dag") assert resp.status_code == 200 ev = resp.json()["events"][0] - assert (ev["lastUpdate"] is not None) == expect_last_update + assert (ev["last_update"] is not None) == expect_last_update diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py index a2a86da327b86..10ca5304961eb 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_partitioned_dag_runs.py @@ -23,6 +23,9 @@ from sqlalchemy import select from airflow.models.asset import AssetEvent, AssetModel, AssetPartitionDagRun, PartitionedAssetKeyLog +from airflow.partition_mappers.base import RollupMapper +from airflow.partition_mappers.temporal import 
StartOfWeekMapper +from airflow.partition_mappers.window import WeekWindow from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.definitions.asset import Asset from airflow.sdk.definitions.timetables.assets import PartitionedAssetTimetable @@ -146,7 +149,7 @@ def test_should_response_200( ) session.commit() - with assert_queries_count(3): + with assert_queries_count(5): resp = test_client.get( f"/partitioned_dag_runs?dag_id=list_dag" f"&has_created_dag_run_id={str(has_created_dag_run_id).lower()}" @@ -238,6 +241,144 @@ def test_partitioned_dag_runs_filters_unreadable_dags(self, _, test_client, dag_ dag_ids = {r["dag_id"] for r in body["partitioned_dag_runs"]} assert "restricted_dag" not in dag_ids + def test_duplicate_events_count_as_one(self, test_client, dag_maker, session): + """Multiple log entries for the same asset count as 1 received, not N.""" + asset_def = Asset(uri="s3://bucket/dup0", name="dup0") + with dag_maker( + dag_id="dup_dag", + schedule=PartitionedAssetTimetable(assets=asset_def), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + asset = session.scalar(select(AssetModel).where(AssetModel.uri == "s3://bucket/dup0")) + pdr = AssetPartitionDagRun(target_dag_id="dup_dag", partition_key="2024-06-01") + session.add(pdr) + session.flush() + + # Log 3 events for the same asset — all should collapse to total_received = 1. 
+ for _ in range(3): + event = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(event) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=event.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key="2024-06-01", + target_dag_id="dup_dag", + target_partition_key="2024-06-01", + ) + ) + session.commit() + + resp = test_client.get("/partitioned_dag_runs?dag_id=dup_dag&has_created_dag_run_id=false") + assert resp.status_code == 200 + pdr_resp = resp.json()["partitioned_dag_runs"][0] + assert pdr_resp["total_required"] == 1 + assert pdr_resp["total_received"] == 1 + + def test_non_rollup_any_event_counts_as_one(self, test_client, dag_maker, session): + """For non-rollup, an event with a different source_partition_key still counts as 1.""" + asset_def = Asset(uri="s3://bucket/nr0", name="nr0") + with dag_maker( + dag_id="nr_dag", + schedule=PartitionedAssetTimetable(assets=asset_def), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + asset = session.scalar(select(AssetModel).where(AssetModel.uri == "s3://bucket/nr0")) + pdr = AssetPartitionDagRun(target_dag_id="nr_dag", partition_key="2024-06-01") + session.add(pdr) + session.flush() + + # Log an event whose source_partition_key differs from the APDR partition_key. 
+ event = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(event) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=event.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key="different-key", + target_dag_id="nr_dag", + target_partition_key="2024-06-01", + ) + ) + session.commit() + + resp = test_client.get("/partitioned_dag_runs?dag_id=nr_dag&has_created_dag_run_id=false") + assert resp.status_code == 200 + pdr_resp = resp.json()["partitioned_dag_runs"][0] + assert pdr_resp["total_required"] == 1 + assert pdr_resp["total_received"] == 1 + + def test_rollup_mapper_counts_received_upstream_keys(self, test_client, dag_maker, session): + """For a rollup mapper, only upstream keys in to_upstream() are counted.""" + asset_def = Asset(uri="s3://bucket/daily", name="daily") + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d"), + window=WeekWindow(), + ) + with dag_maker( + dag_id="rollup_dag", + schedule=PartitionedAssetTimetable(assets=asset_def, partition_mapper_config={asset_def: mapper}), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + asset = session.scalar(select(AssetModel).where(AssetModel.uri == "s3://bucket/daily")) + # Week starting 2024-06-03 (Monday) needs 7 daily keys. + pdr = AssetPartitionDagRun(target_dag_id="rollup_dag", partition_key="2024-06-03") + session.add(pdr) + session.flush() + + # Receive 2 of the 7 required upstream daily keys. 
+ for day in ("2024-06-03", "2024-06-04"): + event = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(event) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=event.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key=day, + target_dag_id="rollup_dag", + target_partition_key="2024-06-03", + ) + ) + # Also log a key outside the required week — it must not inflate the count. + stray = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(stray) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=stray.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key="2024-05-27", # previous week + target_dag_id="rollup_dag", + target_partition_key="2024-06-03", + ) + ) + session.commit() + + resp = test_client.get("/partitioned_dag_runs?dag_id=rollup_dag&has_created_dag_run_id=false") + assert resp.status_code == 200 + pdr_resp = resp.json()["partitioned_dag_runs"][0] + assert pdr_resp["total_required"] == 7 + assert pdr_resp["total_received"] == 2 # only the 2 in-week keys count + class TestGetPendingPartitionedDagRun: def test_should_response_401(self, unauthenticated_test_client): @@ -354,3 +495,76 @@ def test_should_response_200(self, test_client, dag_maker, session, num_assets, received_uris = {a["asset_uri"] for a in body["assets"] if a["received"]} assert received_uris == set(uris[:received_count]) + + def test_is_rollup_false_for_non_rollup_asset(self, test_client, dag_maker, session): + """is_rollup is False for assets that use the identity (non-rollup) mapper.""" + asset_def = Asset(uri="s3://bucket/nr1", name="nr1") + with dag_maker( + dag_id="nr_detail_dag", + schedule=PartitionedAssetTimetable(assets=asset_def), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + session.add(AssetPartitionDagRun(target_dag_id="nr_detail_dag", partition_key="2024-07-01")) 
+ session.commit() + + resp = test_client.get("/pending_partitioned_dag_run/nr_detail_dag/2024-07-01") + assert resp.status_code == 200 + assets = resp.json()["assets"] + assert len(assets) == 1 + assert assets[0]["is_rollup"] is False + + def test_is_rollup_true_for_rollup_asset(self, test_client, dag_maker, session): + """is_rollup is True for assets that use a RollupMapper, and keys are populated.""" + asset_def = Asset(uri="s3://bucket/weekly", name="weekly") + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d"), + window=WeekWindow(), + ) + with dag_maker( + dag_id="rollup_detail_dag", + schedule=PartitionedAssetTimetable(assets=asset_def, partition_mapper_config={asset_def: mapper}), + serialized=True, + ): + EmptyOperator(task_id="t") + dag_maker.create_dagrun() + dag_maker.sync_dagbag_to_db() + + asset = session.scalar(select(AssetModel).where(AssetModel.uri == "s3://bucket/weekly")) + pdr = AssetPartitionDagRun(target_dag_id="rollup_detail_dag", partition_key="2024-06-03") + session.add(pdr) + session.flush() + + # Receive one upstream daily key. 
+ event = AssetEvent(asset_id=asset.id, timestamp=pendulum.now()) + session.add(event) + session.flush() + session.add( + PartitionedAssetKeyLog( + asset_id=asset.id, + asset_event_id=event.id, + asset_partition_dag_run_id=pdr.id, + source_partition_key="2024-06-03", + target_dag_id="rollup_detail_dag", + target_partition_key="2024-06-03", + ) + ) + session.commit() + + resp = test_client.get("/pending_partitioned_dag_run/rollup_detail_dag/2024-06-03") + assert resp.status_code == 200 + body = resp.json() + assert body["total_required"] == 7 + assert body["total_received"] == 1 + assets = body["assets"] + assert len(assets) == 1 + a = assets[0] + assert a["is_rollup"] is True + assert a["required_count"] == 7 + assert a["received_count"] == 1 + assert len(a["required_keys"]) == 7 + assert "2024-06-03" in a["required_keys"] + assert a["received_keys"] == ["2024-06-03"] diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py index 42be7792381fb..943012b79c578 100644 --- a/airflow-core/tests/unit/dag_processing/test_collection.py +++ b/airflow-core/tests/unit/dag_processing/test_collection.py @@ -54,9 +54,13 @@ from airflow.models.errors import ParseImportError from airflow.models.serialized_dag import SerializedDagModel from airflow.models.trigger import Trigger +from airflow.partition_mappers.base import RollupMapper +from airflow.partition_mappers.temporal import StartOfDayMapper +from airflow.partition_mappers.window import DayWindow from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.triggers.file import FileDeleteTrigger from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher +from airflow.sdk.definitions.timetables.assets import PartitionedAssetTimetable from airflow.serialization.definitions.assets import SerializedAsset from airflow.serialization.encoders import encode_trigger, ensure_serialized_asset from 
airflow.serialization.serialized_objects import LazyDeserializedDAG @@ -1193,3 +1197,63 @@ def test_update_dag_tags(self, testing_dag_bundle, session, initial_tags, new_ta session.commit() assert {t.name for t in dag_model.tags} == expected_tags + + +@pytest.mark.db_test +class TestPartitionMapperInfoSync: + """Verify partition_mapper_info is populated on DagModel during Dag sync.""" + + @pytest.fixture(autouse=True) + def clean_db_around_test(self) -> Generator: + def reset() -> None: + clear_db_dags() + clear_db_assets() + clear_db_serialized_dags() + + reset() + yield + reset() + + def test_partitioned_dag_with_rollup_mapper(self, dag_maker, session): + """Cover regular Asset, name ref, and uri ref entries in partition_mapper_info.""" + rollup_asset = Asset(uri="s3://bucket/rollup", name="rollup") + name_ref = Asset.ref(name="ref_by_name") + uri_ref = Asset.ref(uri="s3://ref") + rollup_mapper = RollupMapper(source_mapper=StartOfDayMapper(), window=DayWindow()) + + with dag_maker( + dag_id="partitioned_with_rollup", + schedule=PartitionedAssetTimetable( + assets=rollup_asset, + partition_mapper_config={ + rollup_asset: rollup_mapper, + name_ref: rollup_mapper, + uri_ref: rollup_mapper, + }, + ), + serialized=True, + ): + EmptyOperator(task_id="t") + + dag_model = session.get(DagModel, "partitioned_with_rollup") + assert dag_model.partition_mapper_info == [ + {"name": "rollup", "uri": "s3://bucket/rollup", "is_rollup": True}, + {"name": "ref_by_name", "is_rollup": True}, + {"uri": "s3://ref", "is_rollup": True}, + ] + assert dag_model.has_rollup_mappers is True + assert dag_model.is_rollup_asset(name="rollup", uri="s3://bucket/rollup") is True + assert dag_model.is_rollup_asset(name="ref_by_name", uri="") is True + assert dag_model.is_rollup_asset(name="", uri="s3://ref") is True + + def test_non_partitioned_dag_leaves_info_empty(self, dag_maker, session): + with dag_maker( + dag_id="non_partitioned_dag", + schedule=[Asset(uri="s3://bucket/A", name="A")], + 
serialized=True, + ): + EmptyOperator(task_id="t") + + dag_model = session.get(DagModel, "non_partitioned_dag") + assert dag_model.partition_mapper_info == [] + assert dag_model.has_rollup_mappers is False diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 56e69459104ea..b8a684be29c13 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -92,7 +92,9 @@ AssetAlias, AssetWatcher, CronPartitionTimetable, + HourWindow, IdentityMapper, + RollupMapper, StartOfHourMapper, task, ) @@ -9927,6 +9929,127 @@ def test_consumer_dag_listen_to_two_partitioned_asset( assert asset_event.source_run_id == "test" +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_dag_run_rollup_holds_until_window_complete( + dag_maker: DagMaker, + session: Session, +): + """A rollup APDR stays pending until every upstream key in the window has arrived.""" + asset_1 = Asset(name="asset-1") + with dag_maker( + dag_id="rollup-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=RollupMapper( + source_mapper=StartOfHourMapper(), + window=HourWindow(), + ), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + # First minute key arrives — only 1 / 60 upstream keys, so the APDR must + # not fire yet. 
+ apdr = _produce_and_register_asset_event( + dag_id="rollup-producer-0", + asset=asset_1, + partition_key="2024-01-01T00:00:00", + session=session, + dag_maker=dag_maker, + expected_partition_key="2024-01-01T00", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is None + assert partition_dags == set() + + # Send the remaining 59 minute keys — once all 60 are present the rollup is + # satisfied and the APDR creates its Dag run on the next tick. + for minute in range(1, 60): + _produce_and_register_asset_event( + dag_id=f"rollup-producer-{minute}", + asset=asset_1, + partition_key=f"2024-01-01T00:{minute:02d}:00", + session=session, + dag_maker=dag_maker, + expected_partition_key="2024-01-01T00", + ) + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + session.refresh(apdr) + assert apdr.created_dag_run_id is not None + assert partition_dags == {"rollup-consumer"} + + +@pytest.mark.need_serialized_dag +@pytest.mark.usefixtures("clear_asset_partition_rows") +def test_partitioned_dag_run_rollup_treats_mapper_exception_as_not_satisfied( + dag_maker: DagMaker, + session: Session, +): + """ + A misconfigured rollup mapper that raises during status evaluation must not crash + the scheduler tick — the asset is treated as not-yet-satisfied, the APDR remains + pending, and an audit log entry surfaces the reason to the operator. 
+ """ + session.execute(delete(Log)) + asset_1 = Asset(name="asset-1") + with dag_maker( + dag_id="rollup-consumer", + schedule=PartitionedAssetTimetable( + assets=asset_1, + default_partition_mapper=RollupMapper( + source_mapper=StartOfHourMapper(), + window=HourWindow(), + ), + ), + session=session, + ): + EmptyOperator(task_id="hi") + session.commit() + + runner = SchedulerJobRunner( + job=Job(job_type=SchedulerJobRunner.job_type), executors=[MockExecutor(do_update=False)] + ) + + apdr = _produce_and_register_asset_event( + dag_id="rollup-producer-0", + asset=asset_1, + partition_key="2024-01-01T00:00:00", + session=session, + dag_maker=dag_maker, + expected_partition_key="2024-01-01T00", + ) + + with mock.patch.object( + SchedulerJobRunner, + "_check_rollup_asset_status", + side_effect=RuntimeError("misconfigured rollup mapper"), + ): + partition_dags = runner._create_dagruns_for_partitioned_asset_dags(session=session) + + session.refresh(apdr) + assert apdr.created_dag_run_id is None + assert partition_dags == set() + + # Audit log is added via session.add inside the scheduler tick and only + # visible to a subsequent read after a flush. 
+ session.flush() + audit_log = session.scalar(select(Log).where(Log.event == "failed to evaluate rollup status")) + assert audit_log is not None + assert audit_log.dag_id == "rollup-consumer" + assert audit_log.extra is not None + assert "misconfigured rollup mapper" in audit_log.extra + assert "asset-1" in audit_log.extra + assert "2024-01-01T00" in audit_log.extra + + @pytest.mark.need_serialized_dag @pytest.mark.usefixtures("clear_asset_partition_rows") def test_consumer_dag_listen_to_two_partitioned_asset_with_key_1_mapper( diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index b34ab12dd4aef..c959a8a624f74 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -2961,6 +2961,72 @@ def test_get_dag_id_to_team_name_mapping(self, testing_team): } +class TestDagModelPartitionMapperInfo: + """Pure-function tests for DagModel.is_rollup_asset / has_rollup_mappers.""" + + @pytest.mark.parametrize( + ("info", "name", "uri", "expected"), + [ + pytest.param([], "a", "s3://a", False, id="empty-info"), + pytest.param( + [ + {"name": "asset", "is_rollup": True}, + {"uri": "s3://asset", "is_rollup": False}, + ], + "asset", + "s3://asset", + True, + id="name-match-wins-over-uri", + ), + pytest.param( + [{"uri": "s3://asset", "is_rollup": True}], + "asset", + "s3://asset", + True, + id="uri-match-fallback", + ), + pytest.param( + [{"name": "other", "is_rollup": True}], + "asset", + "s3://asset", + False, + id="unknown-asset", + ), + ], + ) + def test_is_rollup_asset(self, info, name, uri, expected): + dm = DagModel(dag_id="d") + dm.partition_mapper_info = info + assert dm.is_rollup_asset(name=name, uri=uri) is expected + + @pytest.mark.parametrize( + ("info", "expected"), + [ + pytest.param([], False, id="empty"), + pytest.param( + [ + {"name": "a", "is_rollup": False}, + {"uri": "s3://b", "is_rollup": False}, + ], + False, + id="no-rollup-entries", + ), + pytest.param( + [ 
+ {"name": "a", "is_rollup": False}, + {"uri": "s3://b", "is_rollup": True}, + ], + True, + id="at-least-one-rollup", + ), + ], + ) + def test_has_rollup_mappers(self, info, expected): + dm = DagModel(dag_id="d") + dm.partition_mapper_info = info + assert dm.has_rollup_mappers is expected + + class TestQueries: def setup_method(self) -> None: clear_db_runs() diff --git a/airflow-core/tests/unit/partition_mappers/test_temporal.py b/airflow-core/tests/unit/partition_mappers/test_temporal.py index c54ca8a51f9a7..6d61c8b38e72c 100644 --- a/airflow-core/tests/unit/partition_mappers/test_temporal.py +++ b/airflow-core/tests/unit/partition_mappers/test_temporal.py @@ -125,3 +125,31 @@ def test_to_downstream_input_timezone_differs_from_mapper_timezone(self): # 2026-02-11T06:00:00+00:00 UTC == 2026-02-11T01:00:00-05:00 New York # → start-of-day in New York is 2026-02-11 assert pm.to_downstream("2026-02-11T06:00:00+0000") == "2026-02-11" + + +class TestOutputFormatValidation: + """Malformed output_format must fail fast at mapper construction, not as an opaque re.error inside the scheduler.""" + + @pytest.mark.parametrize( + ("output_format", "match"), + [ + ("a{}b", "invalid placeholder"), + ("%Y-{1bad}", "invalid placeholder"), + ("%Y-%Y-%m-%d", "reuses directive"), + ("%Y-Q{quarter}-{quarter}", "reuses placeholder"), + ], + ) + def test_quarter_mapper_rejects_malformed_format(self, output_format, match): + with pytest.raises(ValueError, match=match): + StartOfQuarterMapper(output_format=output_format) + + @pytest.mark.parametrize( + ("output_format", "match"), + [ + ("week-{}", "invalid placeholder"), + ("%Y-%m-%d-%Y", "reuses directive"), + ], + ) + def test_week_mapper_rejects_malformed_format(self, output_format, match): + with pytest.raises(ValueError, match=match): + StartOfWeekMapper(output_format=output_format) diff --git a/airflow-core/tests/unit/partition_mappers/test_window.py b/airflow-core/tests/unit/partition_mappers/test_window.py new file mode 100644 index 
0000000000000..f3b40caf685a3 --- /dev/null +++ b/airflow-core/tests/unit/partition_mappers/test_window.py @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +import pytest + +from airflow.partition_mappers.base import RollupMapper +from airflow.partition_mappers.temporal import ( + StartOfDayMapper, + StartOfHourMapper, + StartOfQuarterMapper, + StartOfWeekMapper, + StartOfYearMapper, +) +from airflow.partition_mappers.window import ( + DayWindow, + HourWindow, + MonthWindow, + QuarterWindow, + WeekWindow, + YearWindow, +) + + +class TestHourWindow: + def test_yields_60_minute_period_starts(self): + period_start = datetime(2024, 6, 10, 14) + members = list(HourWindow().to_upstream(period_start)) + assert len(members) == 60 + assert members[0] == datetime(2024, 6, 10, 14, 0) + assert members[-1] == datetime(2024, 6, 10, 14, 59) + + +class TestDayWindow: + def test_yields_24_hourly_period_starts(self): + period_start = datetime(2024, 6, 10) + members = list(DayWindow().to_upstream(period_start)) + assert len(members) == 24 + assert members[0] == datetime(2024, 6, 10, 0) + assert members[-1] == datetime(2024, 6, 10, 23) + + +class TestWeekWindow: + 
def test_yields_seven_daily_period_starts(self): + # 2024-06-10 is a Monday. + period_start = datetime(2024, 6, 10) + members = list(WeekWindow().to_upstream(period_start)) + assert members == [ + datetime(2024, 6, 10), + datetime(2024, 6, 11), + datetime(2024, 6, 12), + datetime(2024, 6, 13), + datetime(2024, 6, 14), + datetime(2024, 6, 15), + datetime(2024, 6, 16), + ] + + +class TestMonthWindow: + def test_yields_all_days_in_february_leap_year(self): + members = list(MonthWindow().to_upstream(datetime(2024, 2, 1))) + assert len(members) == 29 + assert members[0] == datetime(2024, 2, 1) + assert members[-1] == datetime(2024, 2, 29) + + def test_honours_non_calendar_month_offset(self): + # A custom source mapper might decode a period start mid-month; + # the window must walk forward to the same day in the next month. + members = list(MonthWindow().to_upstream(datetime(2024, 6, 15))) + assert len(members) == 30 + assert members[0] == datetime(2024, 6, 15) + assert members[-1] == datetime(2024, 7, 14) + + def test_crosses_year_boundary(self): + members = list(MonthWindow().to_upstream(datetime(2024, 12, 1))) + assert len(members) == 31 + assert members[-1] == datetime(2024, 12, 31) + + +class TestQuarterWindow: + def test_yields_three_monthly_period_starts(self): + # Q2 starts at 2024-04-01. + members = list(QuarterWindow().to_upstream(datetime(2024, 4, 1))) + assert members == [datetime(2024, 4, 1), datetime(2024, 5, 1), datetime(2024, 6, 1)] + + def test_wraps_year_for_non_calendar_quarter_start(self): + # A custom source mapper might decode a quarter starting in November + # (e.g. fiscal Q4); the window must wrap into the next year. 
+ members = list(QuarterWindow().to_upstream(datetime(2024, 11, 1))) + assert members == [datetime(2024, 11, 1), datetime(2024, 12, 1), datetime(2025, 1, 1)] + + +class TestYearWindow: + def test_yields_twelve_monthly_period_starts(self): + members = list(YearWindow().to_upstream(datetime(2024, 1, 1))) + assert members == [datetime(2024, m, 1) for m in range(1, 13)] + + def test_wraps_year_for_fiscal_year_start(self): + # A fiscal year starting in April rolls forward 12 months from April, + # crossing the calendar boundary into the following year. + members = list(YearWindow().to_upstream(datetime(2024, 4, 1))) + assert members[0] == datetime(2024, 4, 1) + assert members[8] == datetime(2024, 12, 1) + assert members[9] == datetime(2025, 1, 1) + assert members[-1] == datetime(2025, 3, 1) + + +class TestRollupMapperComposition: + def test_to_downstream_delegates_to_source_mapper(self): + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(), + window=WeekWindow(), + ) + # Wednesday 2024-06-12 belongs to the week starting Monday 2024-06-10. 
+ assert mapper.to_downstream("2024-06-12T14:30:00") == "2024-06-10 (W24)" + + def test_is_rollup_flag(self): + mapper = RollupMapper(source_mapper=StartOfWeekMapper(), window=WeekWindow()) + assert mapper.is_rollup is True + + def test_hour_rollup_to_upstream_keys(self): + mapper = RollupMapper( + source_mapper=StartOfHourMapper(input_format="%Y-%m-%dT%H:%M", output_format="%Y-%m-%dT%H"), + window=HourWindow(), + ) + upstream = mapper.to_upstream("2024-06-10T14") + assert upstream == frozenset(f"2024-06-10T14:{m:02d}" for m in range(60)) + + def test_day_rollup_to_upstream_keys(self): + mapper = RollupMapper( + source_mapper=StartOfDayMapper(input_format="%Y-%m-%dT%H", output_format="%Y-%m-%d"), + window=DayWindow(), + ) + upstream = mapper.to_upstream("2024-06-10") + assert upstream == frozenset(f"2024-06-10T{h:02d}" for h in range(24)) + + def test_day_rollup_honours_timezone_in_encode(self): + mapper = RollupMapper( + source_mapper=StartOfDayMapper( + input_format="%Y-%m-%dT%H%z", + output_format="%Y-%m-%d", + timezone="Asia/Taipei", + ), + window=DayWindow(), + ) + upstream = mapper.to_upstream("2024-06-10") + assert "2024-06-10T00+0800" in upstream + assert "2024-06-10T23+0800" in upstream + + def test_week_rollup_with_default_format(self): + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d"), + window=WeekWindow(), + ) + upstream = mapper.to_upstream("2024-06-10") + assert upstream == frozenset( + { + "2024-06-10", + "2024-06-11", + "2024-06-12", + "2024-06-13", + "2024-06-14", + "2024-06-15", + "2024-06-16", + } + ) + + def test_week_rollup_accepts_custom_output_format(self): + # decode_downstream lives on the mapper, so any format embedding the date works. 
+ mapper = RollupMapper( + source_mapper=StartOfWeekMapper( + input_format="%Y-%m-%d", + output_format="Week-of-%Y-%m-%d", + ), + window=WeekWindow(), + ) + upstream = mapper.to_upstream("Week-of-2024-06-10") + assert len(upstream) == 7 + assert "2024-06-10" in upstream + + def test_week_rollup_raises_when_downstream_key_has_no_date(self): + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d"), + window=WeekWindow(), + ) + with pytest.raises(ValueError, match="StartOfWeekMapper.decode_downstream could not parse"): + mapper.to_upstream("week-24") + + def test_quarter_rollup_to_upstream_keys(self): + mapper = RollupMapper( + source_mapper=StartOfQuarterMapper(input_format="%Y-%m"), + window=QuarterWindow(), + ) + assert mapper.to_upstream("2024-Q2") == frozenset({"2024-04", "2024-05", "2024-06"}) + + def test_quarter_rollup_raises_when_marker_missing(self): + mapper = RollupMapper( + source_mapper=StartOfQuarterMapper(input_format="%Y-%m"), + window=QuarterWindow(), + ) + with pytest.raises(ValueError, match="StartOfQuarterMapper.decode_downstream could not parse"): + mapper.to_upstream("2024-06") + + def test_quarter_rollup_accepts_reordered_output_format(self): + # ``{quarter}`` and ``%Y`` can appear in any order; the format-derived + # regex pulls out both via named groups. 
+ mapper = RollupMapper( + source_mapper=StartOfQuarterMapper(input_format="%Y-%m", output_format="Q{quarter}/%Y"), + window=QuarterWindow(), + ) + assert mapper.to_upstream("Q2/2024") == frozenset({"2024-04", "2024-05", "2024-06"}) + + def test_year_rollup_to_upstream_keys(self): + mapper = RollupMapper( + source_mapper=StartOfYearMapper(input_format="%Y-%m"), + window=YearWindow(), + ) + assert mapper.to_upstream("2024") == frozenset(f"2024-{m:02d}" for m in range(1, 13)) + + def test_serialize_round_trip(self): + from airflow.serialization.decoders import decode_partition_mapper + from airflow.serialization.encoders import encode_partition_mapper + + mapper = RollupMapper( + source_mapper=StartOfWeekMapper(input_format="%Y-%m-%d", output_format="%Y-%m-%d"), + window=WeekWindow(), + ) + restored = decode_partition_mapper(encode_partition_mapper(mapper)) + assert isinstance(restored, RollupMapper) + assert isinstance(restored.source_mapper, StartOfWeekMapper) + assert isinstance(restored.window, WeekWindow) + assert restored.to_upstream("2024-06-10") == mapper.to_upstream("2024-06-10") + + @pytest.mark.parametrize( + ("source_factory", "window", "downstream_key"), + [ + pytest.param( + lambda: StartOfHourMapper(input_format="%Y-%m-%dT%H:%M", output_format="%Y-%m-%dT%H"), + HourWindow(), + "2024-06-10T14", + id="hour", + ), + pytest.param( + lambda: StartOfDayMapper(input_format="%Y-%m-%dT%H", output_format="%Y-%m-%d"), + DayWindow(), + "2024-06-10", + id="day", + ), + pytest.param( + lambda: StartOfQuarterMapper(input_format="%Y-%m"), + QuarterWindow(), + "2024-Q2", + id="quarter", + ), + pytest.param( + lambda: StartOfYearMapper(input_format="%Y-%m"), + YearWindow(), + "2024", + id="year", + ), + ], + ) + def test_window_serialize_round_trip(self, source_factory, window, downstream_key): + from airflow.serialization.decoders import decode_partition_mapper + from airflow.serialization.encoders import encode_partition_mapper + + mapper = 
RollupMapper(source_mapper=source_factory(), window=window) + restored = decode_partition_mapper(encode_partition_mapper(mapper)) + assert isinstance(restored, RollupMapper) + assert isinstance(restored.window, type(window)) + assert restored.to_upstream(downstream_key) == mapper.to_upstream(downstream_key) diff --git a/airflow-core/tests/unit/timetables/test_partitioned_timetable.py b/airflow-core/tests/unit/timetables/test_partitioned_timetable.py index b1befe9f5d2cd..0a66c124808b1 100644 --- a/airflow-core/tests/unit/timetables/test_partitioned_timetable.py +++ b/airflow-core/tests/unit/timetables/test_partitioned_timetable.py @@ -25,7 +25,10 @@ import pytest from airflow._shared.module_loading import qualname +from airflow.partition_mappers.base import RollupMapper from airflow.partition_mappers.identity import IdentityMapper as IdentityMapper +from airflow.partition_mappers.temporal import StartOfDayMapper +from airflow.partition_mappers.window import DayWindow from airflow.sdk import Asset from airflow.serialization.definitions.assets import SerializedAsset from airflow.serialization.encoders import ensure_serialized_asset @@ -110,6 +113,44 @@ def test_get_partition_mapper_with_mapping(self, asset_obj): assert isinstance(timetable.get_partition_mapper(name="test_1"), Key1Mapper) assert isinstance(timetable.get_partition_mapper(uri="test_1"), Key1Mapper) + def test_partition_mapper_info_empty(self): + timetable = PartitionedAssetTimetable(assets=Asset("a")) + assert timetable.partition_mapper_info == [] + + def test_partition_mapper_info_mixes_rollup_and_non_rollup(self): + non_rollup = ensure_serialized_asset(Asset(name="non_rollup_name", uri="s3://bucket/non_rollup")) + rollup = ensure_serialized_asset(Asset(name="rollup_name", uri="s3://bucket/rollup")) + timetable = PartitionedAssetTimetable( + assets=non_rollup, + partition_mapper_config={ + non_rollup: IdentityMapper(), + rollup: RollupMapper(source_mapper=StartOfDayMapper(), window=DayWindow()), + }, + 
) + + info = timetable.partition_mapper_info + assert info == [ + {"name": "non_rollup_name", "uri": "s3://bucket/non_rollup", "is_rollup": False}, + {"name": "rollup_name", "uri": "s3://bucket/rollup", "is_rollup": True}, + ] + + def test_partition_mapper_info_handles_asset_refs(self): + timetable = PartitionedAssetTimetable( + assets=Asset(name="x", uri="x"), + partition_mapper_config={ + ensure_serialized_asset(Asset.ref(name="ref_by_name")): RollupMapper( + source_mapper=StartOfDayMapper(), window=DayWindow() + ), + ensure_serialized_asset(Asset.ref(uri="s3://ref")): IdentityMapper(), + }, + ) + + info = timetable.partition_mapper_info + assert info == [ + {"name": "ref_by_name", "is_rollup": True}, + {"uri": "s3://ref", "is_rollup": False}, + ] + def test_serialize(self): ser_asset = ensure_serialized_asset(Asset("test")) timetable = PartitionedAssetTimetable( diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b4460b21e7563..12469172aca48 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1374,6 +1374,7 @@ Roadmap roadmap roc RoleBinding +rollup rom romeoandjuliet rootcss diff --git a/scripts/in_container/run_migration_round_trip.py b/scripts/in_container/run_migration_round_trip.py index 20e15c672c3db..f8b12d3dc9454 100755 --- a/scripts/in_container/run_migration_round_trip.py +++ b/scripts/in_container/run_migration_round_trip.py @@ -105,6 +105,7 @@ "timetable_type": "'cron'", "timetable_partitioned": "0", "timetable_periodic": "0", + "partition_mapper_info": "'[]'", }, "dag_version": { "id": f"'{SEED_DAG_VERSION_ID}'", diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index cb9789f5bb69f..d76aa4b872b18 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -233,10 +233,29 @@ Partition Mapper .. autoapiclass:: airflow.sdk.StartOfYearMapper +.. autoapiclass:: airflow.sdk.RollupMapper + .. autoapiclass:: airflow.sdk.ProductMapper .. 
autoapiclass:: airflow.sdk.AllowedKeyMapper +Rollup Windows +~~~~~~~~~~~~~~ + +.. autoapiclass:: airflow.sdk.Window + +.. autoapiclass:: airflow.sdk.HourWindow + +.. autoapiclass:: airflow.sdk.DayWindow + +.. autoapiclass:: airflow.sdk.WeekWindow + +.. autoapiclass:: airflow.sdk.MonthWindow + +.. autoapiclass:: airflow.sdk.QuarterWindow + +.. autoapiclass:: airflow.sdk.YearWindow + I/O Helpers ----------- .. autoapiclass:: airflow.sdk.ObjectStoragePath diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index f304b068237b3..4077bfa5bf41e 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -45,6 +45,7 @@ "CronPartitionTimetable", "DAG", "DagRunState", + "DayWindow", "DeadlineAlert", "DeadlineReference", "DeltaDataIntervalTimetable", @@ -52,9 +53,11 @@ "EdgeModifier", "EventsTimetable", "ExceptionRetryPolicy", + "HourWindow", "IdentityMapper", "Label", "Metadata", + "MonthWindow", "MultipleCronTriggerTimetable", "ObjectStoragePath", "Param", @@ -63,10 +66,12 @@ "PartitionMapper", "PokeReturnValue", "ProductMapper", + "QuarterWindow", "RetryAction", "RetryDecision", "RetryPolicy", "RetryRule", + "RollupMapper", "SkipMixin", "SyncCallback", "StartOfDayMapper", @@ -80,8 +85,11 @@ "TaskInstanceState", "TriggerRule", "Variable", + "WeekWindow", "WeightRule", + "Window", "XComArg", + "YearWindow", "asset", "chain", "chain_linear", @@ -131,7 +139,7 @@ from airflow.sdk.definitions.edges import EdgeModifier, Label from airflow.sdk.definitions.param import Param, ParamsDict from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper - from airflow.sdk.definitions.partition_mappers.base import PartitionMapper + from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper from airflow.sdk.definitions.partition_mappers.chain import ChainMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from 
airflow.sdk.definitions.partition_mappers.product import ProductMapper @@ -143,6 +151,15 @@ StartOfWeekMapper, StartOfYearMapper, ) + from airflow.sdk.definitions.partition_mappers.window import ( + DayWindow, + HourWindow, + MonthWindow, + QuarterWindow, + WeekWindow, + Window, + YearWindow, + ) from airflow.sdk.definitions.retry_policy import ( ExceptionRetryPolicy, RetryAction, @@ -201,6 +218,7 @@ "CronPartitionTimetable": ".definitions.timetables.trigger", "DAG": ".definitions.dag", "DagRunState": ".api.datamodels._generated", + "DayWindow": ".definitions.partition_mappers.window", "DeadlineAlert": ".definitions.deadline", "DeadlineReference": ".definitions.deadline", "DeltaDataIntervalTimetable": ".definitions.timetables.interval", @@ -208,9 +226,11 @@ "EdgeModifier": ".definitions.edges", "EventsTimetable": ".definitions.timetables.events", "ExceptionRetryPolicy": ".definitions.retry_policy", + "HourWindow": ".definitions.partition_mappers.window", "IdentityMapper": ".definitions.partition_mappers.identity", "Label": ".definitions.edges", "Metadata": ".definitions.asset.metadata", + "MonthWindow": ".definitions.partition_mappers.window", "MultipleCronTriggerTimetable": ".definitions.timetables.trigger", "ObjectStoragePath": ".io.path", "Param": ".definitions.param", @@ -219,10 +239,12 @@ "PartitionMapper": ".definitions.partition_mappers.base", "PokeReturnValue": ".bases.sensor", "ProductMapper": ".definitions.partition_mappers.product", + "QuarterWindow": ".definitions.partition_mappers.window", "RetryAction": ".definitions.retry_policy", "RetryDecision": ".definitions.retry_policy", "RetryPolicy": ".definitions.retry_policy", "RetryRule": ".definitions.retry_policy", + "RollupMapper": ".definitions.partition_mappers.base", "SecretCache": ".execution_time.cache", "SkipMixin": ".bases.skipmixin", "SyncCallback": ".definitions.callback", @@ -237,8 +259,11 @@ "TaskInstanceState": ".api.datamodels._generated", "TriggerRule": ".api.datamodels._generated", 
"Variable": ".definitions.variable", + "WeekWindow": ".definitions.partition_mappers.window", "WeightRule": ".api.datamodels._generated", + "Window": ".definitions.partition_mappers.window", "XComArg": ".definitions.xcom_arg", + "YearWindow": ".definitions.partition_mappers.window", "asset": ".definitions.asset.decorators", "chain": ".bases.operator", "chain_linear": ".bases.operator", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 7e6d211674eba..2db7e57aa671a 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -63,7 +63,7 @@ from airflow.sdk.definitions.decorators.task_group import task_group as task_gro from airflow.sdk.definitions.edges import EdgeModifier as EdgeModifier, Label as Label from airflow.sdk.definitions.param import Param as Param from airflow.sdk.definitions.partition_mappers.allowed_key import AllowedKeyMapper -from airflow.sdk.definitions.partition_mappers.base import PartitionMapper +from airflow.sdk.definitions.partition_mappers.base import PartitionMapper, RollupMapper from airflow.sdk.definitions.partition_mappers.chain import ChainMapper from airflow.sdk.definitions.partition_mappers.identity import IdentityMapper from airflow.sdk.definitions.partition_mappers.product import ProductMapper @@ -75,6 +75,15 @@ from airflow.sdk.definitions.partition_mappers.temporal import ( StartOfWeekMapper, StartOfYearMapper, ) +from airflow.sdk.definitions.partition_mappers.window import ( + DayWindow, + HourWindow, + MonthWindow, + QuarterWindow, + WeekWindow, + Window, + YearWindow, +) from airflow.sdk.definitions.retry_policy import ( ExceptionRetryPolicy as ExceptionRetryPolicy, RetryAction as RetryAction, @@ -133,14 +142,17 @@ __all__ = [ "CronPartitionTimetable", "DAG", "DagRunState", + "DayWindow", "DeltaDataIntervalTimetable", "DeltaTriggerTimetable", "EdgeModifier", "EventsTimetable", "ExceptionRetryPolicy", + "HourWindow", "IdentityMapper", "Label", 
"Metadata", + "MonthWindow", "MultipleCronTriggerTimetable", "ObjectStoragePath", "Param", @@ -148,10 +160,12 @@ __all__ = [ "PartitionedAssetTimetable", "PartitionMapper", "ProductMapper", + "QuarterWindow", "RetryAction", "RetryDecision", "RetryPolicy", "RetryRule", + "RollupMapper", "SecretCache", "SkipMixin", "StartOfDayMapper", @@ -164,8 +178,11 @@ __all__ = [ "TaskInstanceState", "TriggerRule", "Variable", + "WeekWindow", "WeightRule", + "Window", "XComArg", + "YearWindow", "asset", "chain", "chain_linear", diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py index 728332d506fc7..78383b0b34dda 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/base.py @@ -16,6 +16,11 @@ # under the License. from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.sdk.definitions.partition_mappers.window import Window + class PartitionMapper: """ @@ -23,3 +28,22 @@ class PartitionMapper: Maps keys from asset events to target dag run partitions. """ + + is_rollup: bool = False + + +class RollupMapper(PartitionMapper): + """ + Partition mapper that rolls up many upstream keys into one downstream key. + + Compose a ``source_mapper`` (which normalizes each upstream key to the + downstream granularity) with a ``window`` that declares the full set of + upstream keys required for a given downstream key. The scheduler holds + the Dag run until every upstream key in the window has arrived. 
+ """ + + is_rollup: bool = True + + def __init__(self, *, source_mapper: PartitionMapper, window: Window) -> None: + self.source_mapper = source_mapper + self.window = window diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py index 60ca18f5044f3..eb1fdd049f598 100644 --- a/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py @@ -44,13 +44,13 @@ class StartOfDayMapper(_BaseTemporalMapper): class StartOfWeekMapper(_BaseTemporalMapper): - """Map a time-based partition key to week.""" + """Map a time-based partition key to the start of its week.""" default_output_format = "%Y-%m-%d (W%V)" class StartOfMonthMapper(_BaseTemporalMapper): - """Map a time-based partition key to month.""" + """Map a time-based partition key to the start of its month.""" default_output_format = "%Y-%m" diff --git a/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py new file mode 100644 index 0000000000000..92b389a326251 --- /dev/null +++ b/task-sdk/src/airflow/sdk/definitions/partition_mappers/window.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +class Window: + """ + Describes a rollup window: which upstream keys make up one downstream key. + + Paired with a ``source_mapper`` :class:`PartitionMapper` inside a + :class:`RollupMapper`. The source_mapper normalizes upstream keys to the + downstream granularity; the window enumerates the complete set of + upstream keys that roll up into one downstream key. Runtime logic + lives in ``airflow.partition_mappers.window`` on the scheduler side. + """ + + +class HourWindow(Window): + """Sixty consecutive minute keys making up one hour.""" + + +class DayWindow(Window): + """Twenty-four consecutive hourly keys making up one day.""" + + +class WeekWindow(Window): + """All daily keys in the month-long span starting at the decoded period start (calendar or mid-month).""" + + +class MonthWindow(Window): + """All daily keys in the month-long span starting at the decoded period start (calendar or mid-month).""" + + +class QuarterWindow(Window): + """Three consecutive monthly keys making up one quarter (calendar or fiscal).""" + + +class YearWindow(Window): + """Twelve consecutive monthly keys making up one year (calendar or fiscal)."""