Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow-core/docs/authoring-and-scheduling/assets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ Declaring an ``@asset`` automatically creates:
* A ``DAG`` with *dag_id* set to the function name.
* A task inside the ``DAG`` with *task_id* set to the function name, and *outlet* to the created ``Asset``.

The parameter names ``self``, ``context``, and ``outlet_events`` are **reserved** in an ``@asset`` function: they are populated by Airflow at runtime (with the asset itself, the execution context, and the outlet event accessor respectively) and are never treated as inlet asset references.

Attaching extra information to an emitting asset event
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
7 changes: 6 additions & 1 deletion airflow-core/src/airflow/serialization/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from airflow.sdk.definitions.partition_mappers.temporal import StartOfHourMapper
from airflow.sdk.definitions.timetables.assets import (
AssetTriggeredTimetable,
PartitionAtRuntime,
PartitionedAssetTimetable,
)
from airflow.sdk.definitions.timetables.simple import ContinuousTimetable, NullTimetable, OnceTimetable
Expand Down Expand Up @@ -295,6 +296,7 @@ class _Serializer:
MultipleCronTriggerTimetable: "airflow.timetables.trigger.MultipleCronTriggerTimetable",
NullTimetable: "airflow.timetables.simple.NullTimetable",
OnceTimetable: "airflow.timetables.simple.OnceTimetable",
PartitionAtRuntime: "airflow.timetables.simple.PartitionAtRuntime",
PartitionedAssetTimetable: "airflow.timetables.simple.PartitionedAssetTimetable",
}

Expand All @@ -320,7 +322,10 @@ def serialize_timetable(self, timetable: BaseTimetable | CoreTimetable) -> dict[
@serialize_timetable.register(ContinuousTimetable)
@serialize_timetable.register(NullTimetable)
@serialize_timetable.register(OnceTimetable)
def _(self, timetable: ContinuousTimetable | NullTimetable | OnceTimetable) -> dict[str, Any]:
@serialize_timetable.register(PartitionAtRuntime)
def _(
self, timetable: ContinuousTimetable | NullTimetable | OnceTimetable | PartitionAtRuntime
) -> dict[str, Any]:
return {}

@serialize_timetable.register
Expand Down
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/timetables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ class Timetable(Protocol):
instead of the traditional logic based on logical dates and data intervals.
"""

partitioned_at_runtime: bool = False
"""Whether this timetable defers partition selection to task runtime.

*True* for :class:`~airflow.timetables.simple.PartitionAtRuntime`;
downstream code can branch on this flag instead of using ``isinstance``.
"""

@classmethod
def deserialize(cls, data: dict[str, Any]) -> Timetable:
"""
Expand Down
16 changes: 16 additions & 0 deletions airflow-core/src/airflow/timetables/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class NullTimetable(_TrivialTimetable):
"""

can_be_scheduled = False # TODO (GH-52141): Find a way to keep this and one in Core in sync.
partitioned_at_runtime = False
description: str = "Never, external triggers only"

@property
Expand Down Expand Up @@ -183,6 +184,21 @@ def next_dagrun_info(
return DagRunInfo.interval(start, end)


class PartitionAtRuntime(NullTimetable):
    """
    Timetable for assets whose partition keys are only known at task runtime.

    Like :class:`NullTimetable` it never creates scheduled runs; it exists so
    a DAG can declare ``schedule=PartitionAtRuntime()`` and have downstream
    code detect the runtime-partitioning intent via ``partitioned_at_runtime``.
    """

    # Lets callers branch on this flag instead of an ``isinstance`` check.
    partitioned_at_runtime = True
    description: str = "Never, partition key(s) set at runtime"

    @property
    def summary(self) -> str:
        # Short label surfaced in the UI's schedule column.
        return "PartitionAtRuntime"


class AssetTriggeredTimetable(_TrivialTimetable):
"""
Timetable that never schedules anything.
Expand Down
2 changes: 2 additions & 0 deletions task-sdk/docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ Timetables

.. autoapiclass:: airflow.sdk.MultipleCronTriggerTimetable

.. autoapiclass:: airflow.sdk.PartitionAtRuntime

.. autoapiclass:: airflow.sdk.PartitionedAssetTimetable


Expand Down
11 changes: 10 additions & 1 deletion task-sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
"ObjectStoragePath",
"Param",
"ParamsDict",
"PartitionAtRuntime",
"PartitionedAssetTimetable",
"PartitionMapper",
"PokeReturnValue",
Expand Down Expand Up @@ -118,7 +119,13 @@
from airflow.sdk.bases.skipmixin import SkipMixin
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.configuration import AirflowSDKConfigParser
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAll,
AssetAny,
AssetWatcher,
)
from airflow.sdk.definitions.asset.decorators import asset
from airflow.sdk.definitions.asset.metadata import Metadata
from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
Expand Down Expand Up @@ -154,6 +161,7 @@
from airflow.sdk.definitions.template import literal
from airflow.sdk.definitions.timetables.assets import (
AssetOrTimeSchedule,
PartitionAtRuntime,
PartitionedAssetTimetable,
)
from airflow.sdk.definitions.timetables.events import EventsTimetable
Expand Down Expand Up @@ -215,6 +223,7 @@
"ObjectStoragePath": ".io.path",
"Param": ".definitions.param",
"ParamsDict": ".definitions.param",
"PartitionAtRuntime": ".definitions.timetables.assets",
"PartitionedAssetTimetable": ".definitions.timetables.assets",
"PartitionMapper": ".definitions.partition_mappers.base",
"PokeReturnValue": ".bases.sensor",
Expand Down
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ from airflow.sdk.definitions.taskgroup import TaskGroup as TaskGroup
from airflow.sdk.definitions.template import literal as literal
from airflow.sdk.definitions.timetables.assets import (
AssetOrTimeSchedule,
PartitionAtRuntime,
PartitionedAssetTimetable,
)
from airflow.sdk.definitions.timetables.events import EventsTimetable
Expand Down Expand Up @@ -145,6 +146,7 @@ __all__ = [
"ObjectStoragePath",
"Param",
"PokeReturnValue",
"PartitionAtRuntime",
"PartitionedAssetTimetable",
"PartitionMapper",
"ProductMapper",
Expand Down
8 changes: 8 additions & 0 deletions task-sdk/src/airflow/sdk/bases/timetable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ class BaseTimetable:

asset_condition: BaseAsset | None = None

partitioned_at_runtime: bool = False
"""
Whether this timetable defers partition selection to task runtime.

*True* for :class:`~airflow.sdk.PartitionAtRuntime`; downstream code can
branch on this flag instead of using ``isinstance``.
"""

def validate(self) -> None:
"""
Validate the timetable is correctly specified.
Expand Down
8 changes: 7 additions & 1 deletion task-sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from airflow.triggers.base import BaseTrigger


_INVALID_INLET_ASSET_NAMES = ("self", "context", "outlet_events")


def _validate_asset_function_arguments(f: Callable) -> None:
for name, param in inspect.signature(f).parameters.items():
if param.kind == inspect.Parameter.VAR_POSITIONAL:
Expand All @@ -62,7 +65,8 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) ->
inlets=[
Asset.ref(name=inlet_asset_name)
for inlet_asset_name, param in inspect.signature(definition._function).parameters.items()
if inlet_asset_name not in ("self", "context") and param.default is inspect.Parameter.empty
if inlet_asset_name not in _INVALID_INLET_ASSET_NAMES
and param.default is inspect.Parameter.empty
],
outlets=list(definition.iter_outlets()),
python_callable=definition._function,
Expand All @@ -89,6 +93,8 @@ def _fetch_asset(name: str) -> Asset:
value = _fetch_asset(self._definition_name)
elif key == "context":
value = context
elif key == "outlet_events":
value = context["outlet_events"]
else:
value = _fetch_asset(key)
yield key, value
Expand Down
7 changes: 7 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/timetables/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ class PartitionedAssetTimetable(AssetTriggeredTimetable):
default_partition_mapper: PartitionMapper = IdentityMapper()


class PartitionAtRuntime(BaseTimetable):
    """Marker timetable indicating that partition key(s) are determined at runtime."""

    # Flag checked by downstream code instead of an ``isinstance`` test.
    partitioned_at_runtime = True
    # Never eligible for scheduler-created runs.
    can_be_scheduled = False


def _coerce_assets(o: Collection[Asset] | BaseAsset) -> BaseAsset:
if isinstance(o, BaseAsset):
return o
Expand Down
7 changes: 7 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,13 @@ class OutletEventAccessor(_AssetRefResolutionMixin):
key: BaseAssetUniqueKey
extra: dict[str, JsonValue] = attrs.Factory(dict)
asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list)
partition_keys: set[str] = attrs.field(factory=set)

def add_partitions(self, keys: str | list[str]) -> None:
"""Add one or more partition keys to :attr:`partition_keys`."""
if isinstance(keys, str):
keys = [keys]
self.partition_keys.update(keys)

def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None) -> None:
"""Add an AssetEvent to an existing Asset."""
Expand Down
18 changes: 16 additions & 2 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@
from airflow.sdk.configuration import conf
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetNameRef,
AssetUniqueKey,
AssetUriRef,
)
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.param import process_params
from airflow.sdk.exceptions import (
Expand Down Expand Up @@ -1145,7 +1151,15 @@ def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[d
# Further filtering will be done in the API server.
for key, accessor in events._dict.items():
if isinstance(key, AssetUniqueKey):
yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra}
if accessor.partition_keys:
for partition_key in accessor.partition_keys:
yield {
"dest_asset_key": attrs.asdict(key),
"extra": accessor.extra,
"partition_key": partition_key,
}
else:
yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra}
for alias_event in accessor.asset_alias_events:
yield attrs.asdict(alias_event)

Expand Down
11 changes: 10 additions & 1 deletion task-sdk/src/airflow/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@
TaskInstanceState,
)
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetRef,
BaseAssetUniqueKey,
)
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.execution_time.comms import DagResult
Expand Down Expand Up @@ -215,15 +221,18 @@ class OutletEventAccessorProtocol(Protocol):
key: BaseAssetUniqueKey
extra: dict[str, JsonValue]
asset_alias_events: list[AssetAliasEvent]
partition_keys: set[str]

def __init__(
self,
*,
key: BaseAssetUniqueKey,
extra: dict[str, JsonValue],
asset_alias_events: list[AssetAliasEvent],
partition_keys: set[str] = ...,
) -> None: ...
def add(self, asset: Asset, extra: dict[str, JsonValue] | None = None) -> None: ...
def add_partitions(self, keys: str | list[str]) -> None: ...


class OutletEventAccessorsProtocol(Protocol):
Expand Down
68 changes: 54 additions & 14 deletions task-sdk/tests/task_sdk/definitions/test_asset_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset
from airflow.sdk.definitions.decorators import task
from airflow.sdk.execution_time.comms import AssetResult, GetAssetByName
from airflow.sdk.execution_time.context import OutletEventAccessors


@pytest.fixture
Expand Down Expand Up @@ -77,6 +78,15 @@ def _example_asset_func(
return _example_asset_func


@pytest.fixture
def example_asset_func_with_outlet_events(func_fixer):
    """Asset function whose only parameters are the reserved ``self`` and ``outlet_events``.

    Used to verify that ``outlet_events`` is injected from the context and is
    never treated as an inlet asset reference.
    """

    # NOTE(review): the inner function's __name__ appears to feed the asset
    # machinery via func_fixer, so it must not be renamed — confirm against
    # the func_fixer fixture definition.
    @func_fixer
    def _example_asset_func(self, outlet_events):
        return "This is example_asset"

    return _example_asset_func


class TestAssetDecorator:
def test_without_uri(self, example_asset_func):
asset_definition = asset(schedule=None)(example_asset_func)
Expand Down Expand Up @@ -392,17 +402,18 @@ def test_determine_kwargs(
python_callable=example_asset_func_with_valid_arg_as_inlet_asset,
definition_name="example_asset_func",
)
assert op.determine_kwargs(context={"k": "v"}) == {
"self": Asset(
name="example_asset_func",
uri="s3://bucket/object",
group="MLModel",
extra={"k": "v"},
),
"context": {"k": "v"},
"inlet_asset_1": Asset(name="inlet_asset_1", uri="s3://bucket/object1"),
"inlet_asset_2": Asset(name="inlet_asset_2"),
}
outlet_events = OutletEventAccessors()
context = {"k": "v", "outlet_events": outlet_events}
kwargs = op.determine_kwargs(context=context)
assert kwargs["self"] == Asset(
name="example_asset_func",
uri="s3://bucket/object",
group="MLModel",
extra={"k": "v"},
)
assert kwargs["context"] is context
assert kwargs["inlet_asset_1"] == Asset(name="inlet_asset_1", uri="s3://bucket/object1")
assert kwargs["inlet_asset_2"] == Asset(name="inlet_asset_2")

assert mock_supervisor_comms.mock_calls == [
mock.call.send(GetAssetByName(name="example_asset_func")),
Expand Down Expand Up @@ -453,10 +464,39 @@ def example_asset_func(self):
AssetResult(name="custom_name", uri="s3://bucket/object1", group="Asset")
]

assert op.determine_kwargs(context={}) == {
"self": Asset(name="custom_name", uri="s3://bucket/object1", group="Asset")
}
kwargs = op.determine_kwargs(context={"outlet_events": OutletEventAccessors()})
assert list(kwargs) == ["self"]
assert kwargs["self"] == Asset(name="custom_name", uri="s3://bucket/object1", group="Asset")

assert mock_supervisor_comms.mock_calls == [
mock.call.send(GetAssetByName(name="custom_name", uri="s3://bucket/object1", group="Asset"))
]


class TestOutletEventsKwarg:
    """The reserved ``outlet_events`` parameter is injected, never resolved as an inlet."""

    def test_determine_kwargs_injects_outlet_events(
        self, mock_supervisor_comms, example_asset_func_with_outlet_events
    ):
        # The supervisor responds once, with the asset backing the "self" parameter.
        mock_supervisor_comms.send.side_effect = [
            AssetResult(name="example_asset_func", uri="example_asset_func", group="asset"),
        ]

        asset_definition = asset(schedule=None)(example_asset_func_with_outlet_events)
        accessors = OutletEventAccessors()
        main_op = _AssetMainOperator(
            task_id="example_asset_func",
            inlets=[],
            outlets=[asset_definition],
            python_callable=example_asset_func_with_outlet_events,
            definition_name="example_asset_func",
        )

        resolved = main_op.determine_kwargs(context={"outlet_events": accessors})
        # Identity check: the exact accessor object from the context is passed through.
        assert resolved["outlet_events"] is accessors

    def test_from_definition_excludes_outlet_events_from_inlets(self, example_asset_func_with_outlet_events):
        # "outlet_events" is reserved, so it must not become an Asset.ref inlet.
        asset_definition = asset(schedule=None)(example_asset_func_with_outlet_events)
        built = _AssetMainOperator.from_definition(asset_definition)
        assert built.inlets == []
Loading
Loading