Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 46 additions & 38 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,10 +1476,20 @@ def register_asset_changes_in_db(
SerializedAssetUriRef,
)

# TODO: AIP-76 should we provide an interface to override this, so that the task can
# tell the truth if for some reason it touches a different partition?
# https://github.com/apache/airflow/issues/58474
partition_key = ti.dag_run.partition_key
events_by_asset: dict[SerializedAssetUniqueKey, list[tuple[dict, str | None]]] = defaultdict(list)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be easier to read this way

Suggested change
events_by_asset: dict[SerializedAssetUniqueKey, list[tuple[dict, str | None]]] = defaultdict(list)
payloads_by_asset: dict[SerializedAssetUniqueKey, list[OutletEventPayload]] = defaultdict(list)

class OutletEventPayload(NamedTuple):
    extra: dict
    partition_key: str | None

for outlet_event in outlet_events:
if "source_alias_name" in outlet_event:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if "source_alias_name" in outlet_event:
# Alias-emitted events are handled separately further down via
# register_asset_change_for_alias, which uses the DagRun-level
# partition_key. Per-emission partition keys do not fan out through
# the alias path — emission via an alias produces one event per
# resolved asset, all carrying the same dag_run_partition_key.
if "source_alias_name" in outlet_event:

continue
asset_key = SerializedAssetUniqueKey(**outlet_event["dest_asset_key"])
events_by_asset[asset_key].append((outlet_event["extra"], outlet_event.get("partition_key")))

runtime_pks: set[str] = {
pk for events in events_by_asset.values() for _, pk in events if pk is not None
}
if len(runtime_pks) == 1 and ti.dag_run.partition_key is None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's also add a comment here to explain what we're doing

ti.dag_run.partition_key = next(iter(runtime_pks))
dag_run_partition_key = ti.dag_run.partition_key

asset_keys = {
SerializedAssetUniqueKey(o.name, o.uri)
for o in task_outlets
Expand All @@ -1506,11 +1516,25 @@ def register_asset_changes_in_db(
)
}

asset_event_extras: dict[SerializedAssetUniqueKey, dict] = {
SerializedAssetUniqueKey(**event["dest_asset_key"]): event["extra"]
for event in outlet_events
if "source_alias_name" not in event
}
def _register(am: AssetModel, key: SerializedAssetUniqueKey) -> None:
    """Create AssetEvent rows for *key* from the task-emitted outlet events.

    Each emitted event may carry its own partition key; events without one —
    and the single fallback event created when the task emitted nothing for
    this asset — use the DagRun-level key instead.
    """
    # No emissions for this asset -> register exactly one key-less event.
    for extra, event_pk in events_by_asset.get(key) or [(None, None)]:
        asset_manager.register_asset_change(
            task_instance=ti,
            asset=am,
            extra=extra,
            partition_key=dag_run_partition_key if event_pk is None else event_pk,
            session=session,
        )

for key in asset_keys:
try:
Expand All @@ -1523,52 +1547,36 @@ def register_asset_changes_in_db(
)
continue
ti.log.debug("register event for asset %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras.get(key),
partition_key=partition_key,
session=session,
)
_register(am, key)

if asset_name_refs:
asset_models_by_name = {key.name: am for key, am in asset_models.items()}
asset_event_extras_by_name = {key.name: extra for key, extra in asset_event_extras.items()}
asset_models_by_name: dict[str, tuple[SerializedAssetUniqueKey, AssetModel]] = {
key.name: (key, am) for key, am in asset_models.items()
}
for nref in asset_name_refs:
try:
am = asset_models_by_name[nref.name]
key, am = asset_models_by_name[nref.name]
except KeyError:
ti.log.warning(
'Task has inactive assets "Asset.ref(name=%s)" in inlets or outlets', nref.name
)
continue
ti.log.debug("register event for asset name ref %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras_by_name.get(nref.name),
partition_key=partition_key,
session=session,
)
_register(am, key)
if asset_uri_refs:
asset_models_by_uri = {key.uri: am for key, am in asset_models.items()}
asset_event_extras_by_uri = {key.uri: extra for key, extra in asset_event_extras.items()}
asset_models_by_uri: dict[str, tuple[SerializedAssetUniqueKey, AssetModel]] = {
key.uri: (key, am) for key, am in asset_models.items()
}
for uref in asset_uri_refs:
try:
am = asset_models_by_uri[uref.uri]
key, am = asset_models_by_uri[uref.uri]
except KeyError:
ti.log.warning(
'Task has inactive assets "Asset.ref(uri=%s)" in inlets or outlets', uref.uri
)
continue
ti.log.debug("register event for asset uri ref %s", am)
asset_manager.register_asset_change(
task_instance=ti,
asset=am,
extra=asset_event_extras_by_uri.get(uref.uri),
partition_key=partition_key,
session=session,
)
_register(am, key)

def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, str, str], set[str]]:
d = defaultdict(set)
Expand Down Expand Up @@ -1607,7 +1615,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s
asset=asset,
source_alias_names=event_aliase_names,
extra=asset_event_extra,
partition_key=partition_key,
partition_key=dag_run_partition_key,
session=session,
)
if event is None:
Expand All @@ -1619,7 +1627,7 @@ def _asset_event_extras_from_aliases() -> dict[tuple[SerializedAssetUniqueKey, s
asset=asset,
source_alias_names=event_aliase_names,
extra=asset_event_extra,
partition_key=partition_key,
partition_key=dag_run_partition_key,
session=session,
)

Expand Down
157 changes: 157 additions & 0 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3513,6 +3513,163 @@ def test_when_dag_run_has_partition_and_downstreams_listening_then_tables_popula
assert pakl.target_dag_id == "asset_event_listener"


def test_runtime_partition_key_backfills_dag_run_when_none(dag_maker, session):
    """Single runtime key on a PartitionAtRuntime-style run (dag_run.partition_key=None) back-fills the run."""
    outlet = Asset(name="hello")
    with dag_maker(dag_id="rt_pk_backfill", schedule=None) as dag:
        EmptyOperator(task_id="hi", outlets=[outlet])
    run = dag_maker.create_dagrun(session=session)
    # A run created without a scheduler-assigned partition starts out key-less.
    assert run.partition_key is None
    (ti,) = run.get_task_instances(session=session)

    TaskInstance.register_asset_changes_in_db(
        ti=ti,
        task_outlets=[ensure_serialized_asset(outlet).asprofile()],
        outlet_events=[
            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
        ],
        session=session,
    )

    # The single emitted key lands on the AssetEvent and back-fills the DagRun.
    produced = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
    assert produced.partition_key == "us"
    session.refresh(run)
    assert run.partition_key == "us"


def test_runtime_partition_key_does_not_overwrite_scheduler_partition(dag_maker, session):
    """Task-emitted key lands on the AssetEvent but does NOT overwrite a scheduler-set DagRun.partition_key."""
    outlet = Asset(name="hello")
    with dag_maker(dag_id="rt_pk_no_overwrite", schedule=None) as dag:
        EmptyOperator(task_id="hi", outlets=[outlet])
    run = dag_maker.create_dagrun(partition_key="scheduler-key", session=session)
    (ti,) = run.get_task_instances(session=session)

    TaskInstance.register_asset_changes_in_db(
        ti=ti,
        task_outlets=[ensure_serialized_asset(outlet).asprofile()],
        outlet_events=[
            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "task-key"},
        ],
        session=session,
    )

    # The event carries the task-level key ...
    produced = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
    assert produced.partition_key == "task-key"
    # ... while the scheduler-assigned run key is left untouched.
    session.refresh(run)
    assert run.partition_key == "scheduler-key"


def test_runtime_partition_keys_fan_out_to_one_event_per_key(dag_maker, session):
    """Multiple distinct runtime keys produce one AssetEvent each; DagRun.partition_key stays None."""
    outlet = Asset(name="hello")
    with dag_maker(dag_id="rt_pk_fanout", schedule=None) as dag:
        EmptyOperator(task_id="hi", outlets=[outlet])
    run = dag_maker.create_dagrun(session=session)
    assert run.partition_key is None
    (ti,) = run.get_task_instances(session=session)

    TaskInstance.register_asset_changes_in_db(
        ti=ti,
        task_outlets=[ensure_serialized_asset(outlet).asprofile()],
        outlet_events=[
            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "eu"},
        ],
        session=session,
    )

    # One event per distinct runtime key.
    produced = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)).all()
    assert {e.partition_key for e in produced} == {"us", "eu"}
    # Multiple (ambiguous) keys must not back-fill the run-level key.
    session.refresh(run)
    assert run.partition_key is None


def test_runtime_partition_key_falls_back_to_dag_run_when_event_has_no_key(dag_maker, session):
    """An outlet event without partition_key falls back to dag_run.partition_key (backward compat)."""
    outlet = Asset(name="hello")
    with dag_maker(dag_id="rt_pk_fallback", schedule=None) as dag:
        EmptyOperator(task_id="hi", outlets=[outlet])
    run = dag_maker.create_dagrun(partition_key="from-run", session=session)
    (ti,) = run.get_task_instances(session=session)

    TaskInstance.register_asset_changes_in_db(
        ti=ti,
        task_outlets=[ensure_serialized_asset(outlet).asprofile()],
        outlet_events=[
            # No "partition_key" entry at all — the run-level key must be used.
            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {"x": 1}},
        ],
        session=session,
    )

    produced = session.scalar(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id))
    assert produced.partition_key == "from-run"


def test_runtime_partition_key_mixed_events_for_same_asset(dag_maker, session):
    """One event with partition_key + one without produce two AssetEvents (with/without override)."""
    outlet = Asset(name="hello")
    with dag_maker(dag_id="rt_pk_mixed", schedule=None) as dag:
        EmptyOperator(task_id="hi", outlets=[outlet])
    run = dag_maker.create_dagrun(partition_key="from-run", session=session)
    (ti,) = run.get_task_instances(session=session)

    TaskInstance.register_asset_changes_in_db(
        ti=ti,
        task_outlets=[ensure_serialized_asset(outlet).asprofile()],
        outlet_events=[
            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}},
        ],
        session=session,
    )

    # The key-less event inherits the run key; the keyed one keeps its own.
    produced = session.scalars(select(AssetEvent).where(AssetEvent.source_dag_id == dag.dag_id)).all()
    assert {e.partition_key for e in produced} == {"us", "from-run"}
    session.refresh(run)
    assert run.partition_key == "from-run"


def test_when_runtime_partition_keys_and_downstreams_listening_then_tables_populated(
    dag_maker,
    session,
):
    """Runtime-emitted fan-out populates PartitionedAssetKeyLog + AssetPartitionDagRun per key."""
    outlet = Asset(name="hello")
    with dag_maker(dag_id="rt_producer", schedule=None, session=session) as dag:
        EmptyOperator(task_id="hi", outlets=[outlet])
    producer_dag_id = dag.dag_id
    run = dag_maker.create_dagrun(session=session)
    assert run.partition_key is None
    (ti,) = run.get_task_instances(session=session)
    session.commit()

    # A downstream DAG partitioned on the same asset, so emissions fan out into it.
    with dag_maker(
        dag_id="rt_consumer",
        schedule=PartitionedAssetTimetable(
            assets=Asset(name="hello"), default_partition_mapper=IdentityMapper()
        ),
        session=session,
    ):
        EmptyOperator(task_id="hi")
    session.commit()

    TaskInstance.register_asset_changes_in_db(
        ti=ti,
        task_outlets=[ensure_serialized_asset(outlet).asprofile()],
        outlet_events=[
            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "us"},
            {"dest_asset_key": {"name": "hello", "uri": "hello"}, "extra": {}, "partition_key": "eu"},
        ],
        session=session,
    )
    session.commit()

    produced = session.scalars(
        select(AssetEvent).where(AssetEvent.source_dag_id == producer_dag_id)
    ).all()
    assert {e.partition_key for e in produced} == {"us", "eu"}

    # Both partition-tracking tables receive one row per emitted key, targeting the consumer.
    key_logs = session.scalars(select(PartitionedAssetKeyLog)).all()
    partition_runs = session.scalars(select(AssetPartitionDagRun)).all()
    assert {p.source_partition_key for p in key_logs} == {"us", "eu"}
    assert {p.target_partition_key for p in key_logs} == {"us", "eu"}
    assert {p.target_dag_id for p in key_logs} == {"rt_consumer"}
    assert {a.partition_key for a in partition_runs} == {"us", "eu"}
    assert {a.target_dag_id for a in partition_runs} == {"rt_consumer"}


async def empty_callback_for_deadline():
    """Used in deadline tests to confirm that Deadlines and DeadlineAlerts function correctly."""
    # Intentionally a no-op: the tests only care that the coroutine is awaitable.
    return None
Expand Down
Loading