112 changes: 73 additions & 39 deletions airflow-core/src/airflow/models/taskinstance.py
@@ -26,7 +26,7 @@
from collections import defaultdict
from collections.abc import Collection, Iterable
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NamedTuple
from urllib.parse import quote
from uuid import UUID

@@ -129,6 +129,13 @@
PAST_DEPENDS_MET = "past_depends_met"


class OutletEventPayload(NamedTuple):
"""A single outlet emission carrying its ``extra`` payload and optional per-emission ``partition_key``."""

extra: dict
partition_key: str | None


@provide_session
def _add_log(
event,
@@ -1476,10 +1483,37 @@ def register_asset_changes_in_db(
SerializedAssetUriRef,
)

# TODO: AIP-76 should we provide an interface to override this, so that the task can
# tell the truth if for some reason it touches a different partition?
# https://github.com/apache/airflow/issues/58474
partition_key = ti.dag_run.partition_key
payloads_by_asset: dict[SerializedAssetUniqueKey, list[OutletEventPayload]] = defaultdict(list)
for outlet_event in outlet_events:
# Alias-emitted events are handled separately further down via
# register_asset_change_for_alias, which uses the DagRun-level
# partition_key. Per-emission partition keys do not fan out through
# the alias path — emission via an alias produces one event per
# resolved asset, all carrying the same dag_run_partition_key.
if "source_alias_name" in outlet_event:
continue
asset_key = SerializedAssetUniqueKey(**outlet_event["dest_asset_key"])
payloads_by_asset[asset_key].append(
OutletEventPayload(
extra=outlet_event["extra"], partition_key=outlet_event.get("partition_key")
)
)

# Back-fill DagRun.partition_key from the task emission when the task
# emitted exactly one distinct partition_key across all outlet events
# and the DagRun did not already have one set. This lets a task that
# discovers the partition at runtime (rather than via params) act as
# the source of truth for the DagRun-level key.
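# For example (values taken from the tests below): emitting only "us"
# back-fills a run whose partition_key is None, while emitting both "us"
# and "eu" leaves the run key unset and fans out one AssetEvent per key.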
runtime_pks: set[str] = {
payload.partition_key
for payloads in payloads_by_asset.values()
for payload in payloads
if payload.partition_key is not None
}
if len(runtime_pks) == 1 and ti.dag_run.partition_key is None:
ti.dag_run.partition_key = next(iter(runtime_pks))
dag_run_partition_key = ti.dag_run.partition_key

asset_keys = {
SerializedAssetUniqueKey(o.name, o.uri)
for o in task_outlets
@@ -1506,11 +1540,27 @@ def register_asset_changes_in_db(
)
}

asset_event_extras: dict[SerializedAssetUniqueKey, dict] = {
SerializedAssetUniqueKey(**event["dest_asset_key"]): event["extra"]
for event in outlet_events
if "source_alias_name" not in event
}
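# Emit one AssetEvent per recorded payload for this asset; if the task
# recorded none (a plain outlet with no per-emission metadata), emit a
# single event that inherits the DagRun-level partition key.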
def _register(am: AssetModel, key: SerializedAssetUniqueKey) -> None:
payloads_for_asset = payloads_by_asset.get(key, [])
if not payloads_for_asset:
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=None,
partition_key=dag_run_partition_key,
session=session,
)
return
for payload in payloads_for_asset:
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=payload.extra,
partition_key=payload.partition_key
if payload.partition_key is not None
else dag_run_partition_key,
session=session,
)

for key in asset_keys:
try:
@@ -1523,52 +1573,36 @@ def register_asset_changes_in_db(
)
continue
ti.log.debug("register event for asset %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras.get(key),
partition_key=partition_key,
session=session,
)
_register(am, key)

if asset_name_refs:
asset_models_by_name = {key.name: am for key, am in asset_models.items()}
asset_event_extras_by_name = {key.name: extra for key, extra in asset_event_extras.items()}
asset_models_by_name: dict[str, tuple[SerializedAssetUniqueKey, AssetModel]] = {
key.name: (key, am) for key, am in asset_models.items()
}
for nref in asset_name_refs:
try:
am = asset_models_by_name[nref.name]
key, am = asset_models_by_name[nref.name]
except KeyError:
ti.log.warning(
'Task has inactive assets "Asset.ref(name=%s)" in inlets or outlets', nref.name
)
continue
ti.log.debug("register event for asset name ref %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras_by_name.get(nref.name),
partition_key=partition_key,
session=session,
)
_register(am, key)
if asset_uri_refs:
asset_models_by_uri = {key.uri: am for key, am in asset_models.items()}
asset_event_extras_by_uri = {key.uri: extra for key, extra in asset_event_extras.items()}
asset_models_by_uri: dict[str, tuple[SerializedAssetUniqueKey, AssetModel]] = {
key.uri: (key, am) for key, am in asset_models.items()
}
for uref in asset_uri_refs:
try:
am = asset_models_by_uri[uref.uri]
key, am = asset_models_by_uri[uref.uri]
except KeyError:
ti.log.warning(
'Task has inactive assets "Asset.ref(uri=%s)" in inlets or outlets', uref.uri
)
continue
ti.log.debug("register event for asset uri ref %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras_by_uri.get(uref.uri),
partition_key=partition_key,
session=session,
)
_register(am, key)

def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, str, str], set[str]]:
d = defaultdict(set)
@@ -1607,7 +1641,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s
asset=asset,
source_alias_names=event_aliase_names,
extra=asset_event_extra,
partition_key=partition_key,
partition_key=dag_run_partition_key,
session=session,
)
if event is None:
Expand All @@ -1619,7 +1653,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s
asset=asset,
source_alias_names=event_aliase_names,
extra=asset_event_extra,
partition_key=partition_key,
partition_key=dag_run_partition_key,
session=session,
)

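For orientation, here is a minimal standalone sketch of the serialized `outlet_events` shape consumed above and how the new grouping loop buckets it. The dict shapes are copied from the test payloads below; the names `outlet_events` and `grouped`, and the alias value `"a"`, are illustrative only.

```python
from collections import defaultdict

outlet_events = [
    # Plain emission carrying a per-emission partition key.
    {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
    # Emission without a key: falls back to DagRun.partition_key at register time.
    {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {"x": 1}},
    # Alias-emitted event: skipped by the grouping loop and handled via
    # register_asset_change_for_alias with the DagRun-level key instead.
    {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "source_alias_name": "a"},
]

# Grouping as in register_asset_changes_in_db, with asset keys simplified to tuples.
grouped: dict[tuple[str, str], list[tuple[dict, str | None]]] = defaultdict(list)
for event in outlet_events:
    if "source_alias_name" in event:
        continue
    dest = event["dest_asset_key"]
    grouped[(dest["name"], dest["uri"])].append((event["extra"], event.get("partition_key")))

assert grouped[("hello", "hello")] == [({}, "us"), ({"x": 1}, None)]
```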
157 changes: 157 additions & 0 deletions airflow-core/tests/unit/models/test_taskinstance.py
@@ -3513,6 +3513,163 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula
assert pakl.target_dag_id == "asset_event_listener"


def test_runtime_partition_key_backfills_dag_run_when_none(dag_maker, session):
"""Single runtime key on a PartitionAtRuntime-style run (dag_run.partition_key=None) back-fills the run."""
asset = Asset(name="hello")
with dag_maker(dag_id="rt_pk_backfill", schedule=None) as dag:
EmptyOperator(task_id="hi", outlets=[asset])
dr = dag_maker.create_dagrun(session=session)
assert dr.partition_key is None
[ti] = dr.get_task_instances(session=session)

TaskInstance.register_asset_changes_in_db(
ti=ti,
task_outlets=[ensure_serialized_asset(asset).asprofile()],
outlet_events=[
{"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
],
session=session,
)
event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
assert event.partition_key == "us"
session.refresh(dr)
assert dr.partition_key == "us"


def test_runtime_partition_key_does_not_overwrite_scheduler_partition(dag_maker, session):
"""Task-emitted key lands on the AssetEvent but does NOT overwrite a scheduler-set DagRun.partition_key."""
asset = Asset(name="hello")
with dag_maker(dag_id="rt_pk_no_overwrite", schedule=None) as dag:
EmptyOperator(task_id="hi", outlets=[asset])
dr = dag_maker.create_dagrun(partition_key="scheduler-key", session=session)
[ti] = dr.get_task_instances(session=session)

TaskInstance.register_asset_changes_in_db(
ti=ti,
task_outlets=[ensure_serialized_asset(asset).asprofile()],
outlet_events=[
{"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "task-key"},
],
session=session,
)
event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
assert event.partition_key == "task-key"
session.refresh(dr)
assert dr.partition_key == "scheduler-key"


def test_runtime_partition_keys_fan_out_to_one_event_per_key(dag_maker, session):
"""Multiple distinct runtime keys produce one AssetEvent each; DagRun.partition_key stays None."""
asset = Asset(name="hello")
with dag_maker(dag_id="rt_pk_fanout", schedule=None) as dag:
EmptyOperator(task_id="hi", outlets=[asset])
dr = dag_maker.create_dagrun(session=session)
assert dr.partition_key is None
[ti] = dr.get_task_instances(session=session)

TaskInstance.register_asset_changes_in_db(
ti=ti,
task_outlets=[ensure_serialized_asset(asset).asprofile()],
outlet_events=[
{"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
{"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "eu"},
],
session=session,
)
events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)).all()
assert {e.partition_key for e in events} == {"us", "eu"}
session.refresh(dr)
assert dr.partition_key is None


def test_runtime_partition_key_falls_back_to_dag_run_when_event_has_no_key(dag_maker, session):
"""An outlet event without partition_key falls back to dag_run.partition_key (backward compat)."""
asset = Asset(name="hello")
with dag_maker(dag_id="rt_pk_fallback", schedule=None) as dag:
EmptyOperator(task_id="hi", outlets=[asset])
dr = dag_maker.create_dagrun(partition_key="from-run", session=session)
[ti] = dr.get_task_instances(session=session)

TaskInstance.register_asset_changes_in_db(
ti=ti,
task_outlets=[ensure_serialized_asset(asset).asprofile()],
outlet_events=[
{"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {"x": 1}},
],
session=session,
)
event = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
assert event.partition_key == "from-run"


def test_runtime_partition_key_mixed_events_for_same_asset(dag_maker, session):
"""One event with partition_key + one without produce two AssetEvents (with/without override)."""
asset = Asset(name="hello")
with dag_maker(dag_id="rt_pk_mixed", schedule=None) as dag:
EmptyOperator(task_id="hi", outlets=[asset])
dr = dag_maker.create_dagrun(partition_key="from-run", session=session)
[ti] = dr.get_task_instances(session=session)

TaskInstance.register_asset_changes_in_db(
ti=ti,
task_outlets=[ensure_serialized_asset(asset).asprofile()],
outlet_events=[
{"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
{"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}},
],
session=session,
)
events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)).all()
assert {e.partition_key for e in events} == {"us", "from-run"}
session.refresh(dr)
assert dr.partition_key == "from-run"


def test_when_runtime_partition_keys_and_downstreams_listening_then_tables_populated(
dag_maker,
session,
):
"""Runtime-emitted fan-out populates PartitionedAssetKeyLog + AssetPartitionDagRun per key."""
asset = Asset(name="hello")
with dag_maker(dag_id="rt_producer", schedule=None, session=session) as dag:
EmptyOperator(task_id="hi", outlets=[asset])
producer_dag_id = dag.dag_id
dr = dag_maker.create_dagrun(session=session)
assert dr.partition_key is None
[ti] = dr.get_task_instances(session=session)
session.commit()

with dag_maker(
dag_id="rt_consumer",
schedule=PartitionedAssetTimetable(
assets=Asset(name="hello"), default_partition_mapper=IdentityMapper()
),
session=session,
):
EmptyOperator(task_id="hi")
session.commit()

TaskInstance.register_asset_changes_in_db(
ti=ti,
task_outlets=[ensure_serialized_asset(asset).asprofile()],
outlet_events=[
{"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
{"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "eu"},
],
session=session,
)
session.commit()
events = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == producer_dag_id)).all()
assert {e.partition_key for e in events} == {"us", "eu"}
pakls = session.scalars(select(PartitionedAssetKeyLog)).all()
apdrs = session.scalars(select(AssetPartitionDagRun)).all()
assert {p.source_partition_key for p in pakls} == {"us", "eu"}
assert {p.target_partition_key for p in pakls} == {"us", "eu"}
assert {p.target_dag_id for p in pakls} == {"rt_consumer"}
assert {a.partition_key for a in apdrs} == {"us", "eu"}
assert {a.target_dag_id for a in apdrs} == {"rt_consumer"}


async def empty_callback_for_deadline():
"""Used in deadline tests to confirm that Deadlines and DeadlineAlerts function correctly."""
pass
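Taken together, these tests pin down a simple precedence for partition keys. A compact standalone restatement, assuming nothing beyond the behavior asserted above (the function names are illustrative, not part of the PR):

```python
def resolve_partition_key(event_key: str | None, dag_run_key: str | None) -> str | None:
    """A per-emission key wins; otherwise fall back to the DagRun-level key."""
    return event_key if event_key is not None else dag_run_key


def backfill_dag_run_key(runtime_keys: set[str], dag_run_key: str | None) -> str | None:
    """Back-fill only when the task emitted exactly one distinct key and the run had none."""
    if len(runtime_keys) == 1 and dag_run_key is None:
        return next(iter(runtime_keys))
    return dag_run_key


assert resolve_partition_key("task-key", "scheduler-key") == "task-key"  # override wins
assert resolve_partition_key(None, "from-run") == "from-run"  # fallback to the run key
assert backfill_dag_run_key({"us"}, None) == "us"  # single runtime key back-fills the run
assert backfill_dag_run_key({"us", "eu"}, None) is None  # fan-out: run key stays unset
assert backfill_dag_run_key({"task-key"}, "scheduler-key") == "scheduler-key"  # never overwrites
```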