diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index f4bba7c65ea3d..da5f994451865 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -26,7 +26,7 @@
 from collections import defaultdict
 from collections.abc import Collection, Iterable
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, NamedTuple
 from urllib.parse import quote
 from uuid import UUID
 
@@ -129,6 +129,13 @@
 PAST_DEPENDS_MET = "past_depends_met"
 
 
+class OutletEventPayload(NamedTuple):
+    """A single outlet emission carrying its ``extra`` payload and optional per-emission ``partition_key``."""
+
+    extra: dict
+    partition_key: str | None
+
+
 @provide_session
 def _add_log(
     event,
@@ -1476,10 +1483,37 @@ def register_asset_changes_in_db(
             SerializedAssetUriRef,
         )
 
-        # TODO: AIP-76 should we provide an interface to override this, so that the task can
-        # tell the truth if for some reason it touches a different partition?
-        # https://github.com/apache/airflow/issues/58474
-        partition_key = ti.dag_run.partition_key
+        payloads_by_asset: dict[SerializedAssetUniqueKey, list[OutletEventPayload]] = defaultdict(list)
+        for outlet_event in outlet_events:
+            # Alias-emitted events are handled separately further down via
+            # register_asset_change_for_alias, which uses the DagRun-level
+            # partition_key. Per-emission partition keys do not fan out through
+            # the alias path — emission via an alias produces one event per
+            # resolved asset, all carrying the same dag_run_partition_key.
+            if "source_alias_name" in outlet_event:
+                continue
+            asset_key = SerializedAssetUniqueKey(**outlet_event["dest_asset_key"])
+            payloads_by_asset[asset_key].append(
+                OutletEventPayload(
+                    extra=outlet_event["extra"], partition_key=outlet_event.get("partition_key")
+                )
+            )
+
+        # Back-fill DagRun.partition_key from the task emission when the task
+        # emitted exactly one distinct partition_key across all outlet events
+        # and the DagRun did not already have one set. This lets a task that
+        # discovers the partition at runtime (rather than via params) act as
+        # the source of truth for the DagRun-level key.
+        runtime_pks: set[str] = {
+            payload.partition_key
+            for payloads in payloads_by_asset.values()
+            for payload in payloads
+            if payload.partition_key is not None
+        }
+        if len(runtime_pks) == 1 and ti.dag_run.partition_key is None:
+            ti.dag_run.partition_key = next(iter(runtime_pks))
+        dag_run_partition_key = ti.dag_run.partition_key
+
         asset_keys = {
             SerializedAssetUniqueKey(o.name, o.uri)
             for o in task_outlets
@@ -1506,11 +1540,27 @@ def register_asset_changes_in_db(
             )
         }
 
-        asset_event_extras: dict[SerializedAssetUniqueKey, dict] = {
-            SerializedAssetUniqueKey(**event["dest_asset_key"]): event["extra"]
-            for event in outlet_events
-            if "source_alias_name" not in event
-        }
+        def _register(am: AssetModel, key: SerializedAssetUniqueKey) -> None:
+            payloads_for_asset = payloads_by_asset.get(key, [])
+            if not payloads_for_asset:
+                asset_manager.register_asset_change(
+                    task_instance=ti,
+                    asset=am,
+                    extra=None,
+                    partition_key=dag_run_partition_key,
+                    session=session,
+                )
+                return
+            for payload in payloads_for_asset:
+                asset_manager.register_asset_change(
+                    task_instance=ti,
+                    asset=am,
+                    extra=payload.extra,
+                    partition_key=payload.partition_key
+                    if payload.partition_key is not None
+                    else dag_run_partition_key,
+                    session=session,
+                )
 
         for key in asset_keys:
             try:
@@ -1523,52 +1573,36 @@
                 )
                 continue
             ti.log.debug("register event for asset %s", am)
-            asset_manager.register_asset_change(
-                task_instance=ti,
-                asset=am,
-                extra=asset_event_extras.get(key),
-                partition_key=partition_key,
-                session=session,
-            )
+            _register(am, key)
 
         if asset_name_refs:
-            asset_models_by_name = {key.name: am for key, am in asset_models.items()}
-            asset_event_extras_by_name = {key.name: extra for key, extra in asset_event_extras.items()}
+            asset_models_by_name: dict[str, tuple[SerializedAssetUniqueKey, AssetModel]] = {
+                key.name: (key, am) for key, am in asset_models.items()
+            }
             for nref in asset_name_refs:
                 try:
-                    am = asset_models_by_name[nref.name]
+                    key, am = asset_models_by_name[nref.name]
                 except KeyError:
                     ti.log.warning(
                         'Task has inactive assets "Asset.ref(name=%s)" in inlets or outlets', nref.name
                     )
                     continue
                 ti.log.debug("register event for asset name ref %s", am)
-                asset_manager.register_asset_change(
-                    task_instance=ti,
-                    asset=am,
-                    extra=asset_event_extras_by_name.get(nref.name),
-                    partition_key=partition_key,
-                    session=session,
-                )
+                _register(am, key)
         if asset_uri_refs:
-            asset_models_by_uri = {key.uri: am for key, am in asset_models.items()}
-            asset_event_extras_by_uri = {key.uri: extra for key, extra in asset_event_extras.items()}
+            asset_models_by_uri: dict[str, tuple[SerializedAssetUniqueKey, AssetModel]] = {
+                key.uri: (key, am) for key, am in asset_models.items()
+            }
            for uref in asset_uri_refs:
                 try:
-                    am = asset_models_by_uri[uref.uri]
+                    key, am = asset_models_by_uri[uref.uri]
                 except KeyError:
                     ti.log.warning(
                         'Task has inactive assets "Asset.ref(uri=%s)" in inlets or outlets', uref.uri
                     )
                     continue
                 ti.log.debug("register event for asset uri ref %s", am)
-                asset_manager.register_asset_change(
-                    task_instance=ti,
-                    asset=am,
-                    extra=asset_event_extras_by_uri.get(uref.uri),
-                    partition_key=partition_key,
-                    session=session,
-                )
+                _register(am, key)
 
         def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, str, str], set[str]]:
             d = defaultdict(set)
@@ -1607,7 +1641,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s
                 asset=asset,
                 source_alias_names=event_aliase_names,
                 extra=asset_event_extra,
-                partition_key=partition_key,
+                partition_key=dag_run_partition_key,
                 session=session,
             )
             if event is None:
@@ -1619,7 +1653,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s
                 asset=asset,
                 source_alias_names=event_aliase_names,
                 extra=asset_event_extra,
-                partition_key=partition_key,
+                partition_key=dag_run_partition_key,
                 session=session,
             )
 
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index b9e5452855ff1..b6f5e5425b7e3 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -3513,6 +3513,163 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula
     assert pakl.target_dag_id == "asset_event_listener"
 
 
+def test_runtime_partition_key_backfills_dag_run_when_none(dag_maker, session):
+    """Single runtime key on a PartitionAtRuntime-style run (dag_run.partition_key=None) back-fills the run."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_backfill", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(session=session)
+    assert dr.partition_key is None
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
+        ],
+        session=session,
+    )
+    event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
+    assert event.partition_key == "us"
+    session.refresh(dr)
+    assert dr.partition_key == "us"
+
+
+def test_runtime_partition_key_does_not_overwrite_scheduler_partition(dag_maker, session):
+    """Task-emitted key lands on the AssetEvent but does NOT overwrite a scheduler-set DagRun.partition_key."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_no_overwrite", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(partition_key="scheduler-key", session=session)
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "task-key"},
+        ],
+        session=session,
+    )
+    event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
+    assert event.partition_key == "task-key"
+    session.refresh(dr)
+    assert dr.partition_key == "scheduler-key"
+
+
+def test_runtime_partition_keys_fan_out_to_one_event_per_key(dag_maker, session):
+    """Multiple distinct runtime keys produce one AssetEvent each; DagRun.partition_key stays None."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_fanout", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(session=session)
+    assert dr.partition_key is None
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "eu"},
+        ],
+        session=session,
+    )
+    events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)).all()
+    assert {e.partition_key for e in events} == {"us", "eu"}
+    session.refresh(dr)
+    assert dr.partition_key is None
+
+
+def test_runtime_partition_key_falls_back_to_dag_run_when_event_has_no_key(dag_maker, session):
+    """An outlet event without partition_key falls back to dag_run.partition_key (backward compat)."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_fallback", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(partition_key="from-run", session=session)
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {"x": 1}},
+        ],
+        session=session,
+    )
+    event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
+    assert event.partition_key == "from-run"
+
+
+def test_runtime_partition_key_mixed_events_for_same_asset(dag_maker, session):
+    """One event with partition_key + one without produce two AssetEvents (with/without override)."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_pk_mixed", schedule=None) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    dr = dag_maker.create_dagrun(partition_key="from-run", session=session)
+    [ti] = dr.get_task_instances(session=session)
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}},
+        ],
+        session=session,
+    )
+    events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)).all()
+    assert {e.partition_key for e in events} == {"us", "from-run"}
+    session.refresh(dr)
+    assert dr.partition_key == "from-run"
+
+
+def test_when_runtime_partition_keys_and_downstreams_listening_then_tables_populated(
+    dag_maker,
+    session,
+):
+    """Runtime-emitted fan-out populates PartitionedAssetKeyLog + AssetPartitionDagRun per key."""
+    asset = Asset(name="hello")
+    with dag_maker(dag_id="rt_producer", schedule=None, session=session) as dag:
+        EmptyOperator(task_id="hi", outlets=[asset])
+    producer_dag_id = dag.dag_id
+    dr = dag_maker.create_dagrun(session=session)
+    assert dr.partition_key is None
+    [ti] = dr.get_task_instances(session=session)
+    session.commit()
+
+    with dag_maker(
+        dag_id="rt_consumer",
+        schedule=PartitionedAssetTimetable(
+            assets=Asset(name="hello"), default_partition_mapper=IdentityMapper()
+        ),
+        session=session,
+    ):
+        EmptyOperator(task_id="hi")
+    session.commit()
+
+    TaskInstance.register_asset_changes_in_db(
+        ti=ti,
+        task_outlets=[ensure_serialized_asset(asset).asprofile()],
+        outlet_events=[
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
+            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "eu"},
+        ],
+        session=session,
+    )
+    session.commit()
+    events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == producer_dag_id)).all()
+    assert {e.partition_key for e in events} == {"us", "eu"}
+    pakls = session.scalars(select(PartitionedAssetKeyLog)).all()
+    apdrs = session.scalars(select(AssetPartitionDagRun)).all()
+    assert {p.source_partition_key for p in pakls} == {"us", "eu"}
+    assert {p.target_partition_key for p in pakls} == {"us", "eu"}
+    assert {p.target_dag_id for p in pakls} == {"rt_consumer"}
+    assert {a.partition_key for a in apdrs} == {"us", "eu"}
+    assert {a.target_dag_id for a in apdrs} == {"rt_consumer"}
+
+
 async def empty_callback_for_deadline():
     """Used in deadline tests to confirm that Deadlines and DeadlineAlerts function correctly."""
     pass