diff --git a/airflow-core/docs/authoring-and-scheduling/assets.rst b/airflow-core/docs/authoring-and-scheduling/assets.rst index a764bc33449e6..6f1e9401bb099 100644 --- a/airflow-core/docs/authoring-and-scheduling/assets.rst +++ b/airflow-core/docs/authoring-and-scheduling/assets.rst @@ -188,6 +188,8 @@ Declaring an ``@asset`` automatically creates: * A ``DAG`` with *dag_id* set to the function name. * A task inside the ``DAG`` with *task_id* set to the function name, and *outlet* to the created ``Asset``. +The parameter names ``self``, ``context``, and ``outlet_events`` are **reserved** in an ``@asset`` function: they are populated by Airflow at runtime (with a proxy for the asset itself on which ``partition_keys`` can be set, the execution context, and the outlet event accessors, respectively) and are never treated as inlet asset references. + Attaching extra information to an emitting asset event ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_event.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_event.py index f6c3ce826698b..9daa7ce6d90d0 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_event.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_event.py @@ -19,6 +19,7 @@ from datetime import datetime +from pydantic import Field from pydantic.types import JsonValue from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel @@ -54,6 +55,7 @@ class AssetEventResponse(BaseModel): source_run_id: str | None = None source_map_index: int | None = None partition_key: str | None = None + partition_keys: list[str] = Field(default_factory=list) class AssetEventsResponse(BaseModel): diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index dfa27f53ebd91..ccbc8c8fe5269 100644 --- 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -32,6 +32,7 @@ AddDagEndpoint, AddDagRunDetailEndpoint, AddNoteField, + AddOutletPartitionKeysField, AddPartitionKeyField, AddRunAfterField, AddTaskInstanceStartDateField, @@ -54,6 +55,7 @@ Version( "2026-04-06", AddPartitionKeyField, + AddOutletPartitionKeysField, MovePreviousRunEndpoint, AddDagRunDetailEndpoint, MakeDagRunStartDateNullable, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py index 5a3b0f0f5fc2d..4029bed6eb39d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py @@ -61,6 +61,23 @@ def remove_partition_key_from_asset_events(response: ResponseInfo) -> None: # t elem.pop("partition_key", None) +class AddOutletPartitionKeysField(VersionChange): + """Add the `partition_keys` field to AssetEventResponse for runtime-partitioned outlet events.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(AssetEventResponse).field("partition_keys").didnt_exist, + ) + + @convert_response_to_previous_version_for(AssetEventsResponse) # type: ignore[arg-type] + def remove_partition_keys_from_asset_events(response: ResponseInfo) -> None: # type: ignore[misc] + """Remove the `partition_keys` field from each asset event when converting to the previous version.""" + events = response.body["asset_events"] + for elem in events: + elem.pop("partition_keys", None) + + class MovePreviousRunEndpoint(VersionChange): """Add new previous-run endpoint and migrate old endpoint.""" diff --git a/airflow-core/src/airflow/serialization/encoders.py b/airflow-core/src/airflow/serialization/encoders.py index 87d7ef1f137e9..e543a28dfe525 100644 --- 
a/airflow-core/src/airflow/serialization/encoders.py +++ b/airflow-core/src/airflow/serialization/encoders.py @@ -55,6 +55,7 @@ from airflow.sdk.definitions.partition_mappers.temporal import StartOfHourMapper from airflow.sdk.definitions.timetables.assets import ( AssetTriggeredTimetable, + PartitionAtRuntime, PartitionedAssetTimetable, ) from airflow.sdk.definitions.timetables.simple import ContinuousTimetable, NullTimetable, OnceTimetable @@ -293,6 +294,7 @@ class _Serializer: MultipleCronTriggerTimetable: "airflow.timetables.trigger.MultipleCronTriggerTimetable", NullTimetable: "airflow.timetables.simple.NullTimetable", OnceTimetable: "airflow.timetables.simple.OnceTimetable", + PartitionAtRuntime: "airflow.timetables.simple.PartitionAtRuntime", PartitionedAssetTimetable: "airflow.timetables.simple.PartitionedAssetTimetable", } @@ -318,7 +320,10 @@ def serialize_timetable(self, timetable: BaseTimetable | CoreTimetable) -> dict[ @serialize_timetable.register(ContinuousTimetable) @serialize_timetable.register(NullTimetable) @serialize_timetable.register(OnceTimetable) - def _(self, timetable: ContinuousTimetable | NullTimetable | OnceTimetable) -> dict[str, Any]: + @serialize_timetable.register(PartitionAtRuntime) + def _( + self, timetable: ContinuousTimetable | NullTimetable | OnceTimetable | PartitionAtRuntime + ) -> dict[str, Any]: return {} @serialize_timetable.register diff --git a/airflow-core/src/airflow/timetables/base.py b/airflow-core/src/airflow/timetables/base.py index a92bcd7a1f85d..b90377cb295a8 100644 --- a/airflow-core/src/airflow/timetables/base.py +++ b/airflow-core/src/airflow/timetables/base.py @@ -218,6 +218,13 @@ class Timetable(Protocol): instead of the traditional logic based on logical dates and data intervals. """ + partitioned_at_runtime: bool = False + """Whether this timetable defers partition selection to task runtime. 
+ + *True* for :class:`~airflow.timetables.simple.PartitionAtRuntime`; + downstream code can branch on this flag instead of using ``isinstance``. + """ + @classmethod def deserialize(cls, data: dict[str, Any]) -> Timetable: """ diff --git a/airflow-core/src/airflow/timetables/simple.py b/airflow-core/src/airflow/timetables/simple.py index 01fb12f81dd0c..086e1153d618a 100644 --- a/airflow-core/src/airflow/timetables/simple.py +++ b/airflow-core/src/airflow/timetables/simple.py @@ -93,6 +93,7 @@ class NullTimetable(_TrivialTimetable): """ can_be_scheduled = False # TODO (GH-52141): Find a way to keep this and one in Core in sync. + partitioned_at_runtime = False description: str = "Never, external triggers only" @property @@ -183,6 +184,21 @@ def next_dagrun_info( return DagRunInfo.interval(start, end) +class PartitionAtRuntime(NullTimetable): + """ + Timetable that never schedules anything; partition keys are set at runtime. + + This corresponds to ``schedule=PartitionAtRuntime()``. + """ + + description: str = "Never, partition key(s) set at runtime" + partitioned_at_runtime = True + + @property + def summary(self) -> str: + return "PartitionAtRuntime" + + class AssetTriggeredTimetable(_TrivialTimetable): """ Timetable that never schedules anything. 
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py index e3839f19eafe8..915db1dfa929d 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_events.py @@ -124,6 +124,7 @@ def test_get_by_asset(self, uri, name, client): }, "timestamp": "2021-01-01T00:00:00Z", "partition_key": None, + "partition_keys": [], }, { "id": 2, @@ -141,6 +142,7 @@ def test_get_by_asset(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-02T00:00:00Z", "partition_key": None, + "partition_keys": [], }, { "id": 3, @@ -158,6 +160,7 @@ def test_get_by_asset(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-03T00:00:00Z", "partition_key": None, + "partition_keys": [], }, ] } @@ -195,6 +198,7 @@ def test_get_by_asset_with_after_filter(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-02T00:00:00Z", "partition_key": None, + "partition_keys": [], }, { "id": 3, @@ -212,6 +216,7 @@ def test_get_by_asset_with_after_filter(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-03T00:00:00Z", "partition_key": None, + "partition_keys": [], }, ] } @@ -249,6 +254,7 @@ def test_get_by_asset_with_before_filter(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-01T00:00:00Z", "partition_key": None, + "partition_keys": [], }, { "id": 2, @@ -266,6 +272,7 @@ def test_get_by_asset_with_before_filter(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-02T00:00:00Z", "partition_key": None, + "partition_keys": [], }, ] } @@ -308,6 +315,7 @@ def test_get_by_asset_with_before_and_after_filters(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-02T00:00:00Z", "partition_key": None, + "partition_keys": [], }, ] } @@ -345,6 +353,7 @@ 
def test_get_by_asset_with_descending_order(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-03T00:00:00Z", "partition_key": None, + "partition_keys": [], }, { "id": 2, @@ -362,6 +371,7 @@ def test_get_by_asset_with_descending_order(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-02T00:00:00Z", "partition_key": None, + "partition_keys": [], }, { "id": 1, @@ -379,6 +389,7 @@ def test_get_by_asset_with_descending_order(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-01T00:00:00Z", "partition_key": None, + "partition_keys": [], }, ] } @@ -416,6 +427,7 @@ def test_get_by_asset_get_first(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-01T00:00:00Z", "partition_key": None, + "partition_keys": [], }, ] } @@ -453,6 +465,7 @@ def test_get_by_asset_get_last(self, uri, name, client): "created_dagruns": [], "timestamp": "2021-01-03T00:00:00Z", "partition_key": None, + "partition_keys": [], }, ] } @@ -484,6 +497,7 @@ def test_get_by_asset(self, client): "created_dagruns": [], "timestamp": "2021-01-01T00:00:00Z", "partition_key": None, + "partition_keys": [], }, { "id": 2, @@ -501,6 +515,7 @@ def test_get_by_asset(self, client): "created_dagruns": [], "timestamp": "2021-01-02T00:00:00Z", "partition_key": None, + "partition_keys": [], }, { "id": 3, @@ -518,6 +533,7 @@ def test_get_by_asset(self, client): "created_dagruns": [], "timestamp": "2021-01-03T00:00:00Z", "partition_key": None, + "partition_keys": [], }, ] } diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst index cb9789f5bb69f..9aa0fd66770c1 100644 --- a/task-sdk/docs/api.rst +++ b/task-sdk/docs/api.rst @@ -209,6 +209,8 @@ Timetables .. autoapiclass:: airflow.sdk.MultipleCronTriggerTimetable +.. autoapiclass:: airflow.sdk.PartitionAtRuntime + .. 
autoapiclass:: airflow.sdk.PartitionedAssetTimetable diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index f304b068237b3..be4d16f3de8bb 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -59,6 +59,7 @@ "ObjectStoragePath", "Param", "ParamsDict", + "PartitionAtRuntime", "PartitionedAssetTimetable", "PartitionMapper", "PokeReturnValue", @@ -118,7 +119,13 @@ from airflow.sdk.bases.skipmixin import SkipMixin from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.configuration import AirflowSDKConfigParser - from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher + from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAll, + AssetAny, + AssetWatcher, + ) from airflow.sdk.definitions.asset.decorators import asset from airflow.sdk.definitions.asset.metadata import Metadata from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback @@ -154,6 +161,7 @@ from airflow.sdk.definitions.template import literal from airflow.sdk.definitions.timetables.assets import ( AssetOrTimeSchedule, + PartitionAtRuntime, PartitionedAssetTimetable, ) from airflow.sdk.definitions.timetables.events import EventsTimetable @@ -215,6 +223,7 @@ "ObjectStoragePath": ".io.path", "Param": ".definitions.param", "ParamsDict": ".definitions.param", + "PartitionAtRuntime": ".definitions.timetables.assets", "PartitionedAssetTimetable": ".definitions.timetables.assets", "PartitionMapper": ".definitions.partition_mappers.base", "PokeReturnValue": ".bases.sensor", diff --git a/task-sdk/src/airflow/sdk/__init__.pyi b/task-sdk/src/airflow/sdk/__init__.pyi index 7e6d211674eba..d0b4af5d9e5bd 100644 --- a/task-sdk/src/airflow/sdk/__init__.pyi +++ b/task-sdk/src/airflow/sdk/__init__.pyi @@ -86,6 +86,7 @@ from airflow.sdk.definitions.taskgroup import TaskGroup as TaskGroup from airflow.sdk.definitions.template import literal as literal from 
airflow.sdk.definitions.timetables.assets import ( AssetOrTimeSchedule, + PartitionAtRuntime, PartitionedAssetTimetable, ) from airflow.sdk.definitions.timetables.events import EventsTimetable @@ -145,6 +146,7 @@ __all__ = [ "ObjectStoragePath", "Param", "PokeReturnValue", + "PartitionAtRuntime", "PartitionedAssetTimetable", "PartitionMapper", "ProductMapper", diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index b5b100154c389..a61436fbfea4c 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -660,6 +660,7 @@ class AssetEventResponse(BaseModel): source_run_id: Annotated[str | None, Field(title="Source Run Id")] = None source_map_index: Annotated[int | None, Field(title="Source Map Index")] = None partition_key: Annotated[str | None, Field(title="Partition Key")] = None + partition_keys: Annotated[list[str] | None, Field(title="Partition Keys")] = None class AssetEventsResponse(BaseModel): diff --git a/task-sdk/src/airflow/sdk/bases/timetable.py b/task-sdk/src/airflow/sdk/bases/timetable.py index e732566f1533f..bd37a6a0a7955 100644 --- a/task-sdk/src/airflow/sdk/bases/timetable.py +++ b/task-sdk/src/airflow/sdk/bases/timetable.py @@ -47,6 +47,14 @@ class BaseTimetable: asset_condition: BaseAsset | None = None + partitioned_at_runtime: bool = False + """ + Whether this timetable defers partition selection to task runtime. + + *True* for :class:`~airflow.sdk.PartitionAtRuntime`; downstream code can + branch on this flag instead of using ``isinstance``. + """ + def validate(self) -> None: """ Validate the timetable is correctly specified. 
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index 26205f0c58335..a91a70d565e4f 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -36,9 +36,33 @@ from airflow.sdk.bases.decorator import _TaskDecorator from airflow.sdk.definitions.dag import DagStateChangeCallback, ScheduleArg from airflow.sdk.definitions.param import ParamsDict + from airflow.sdk.types import OutletEventAccessorsProtocol from airflow.triggers.base import BaseTrigger +_INVALID_INLET_ASSET_NAMES = ("self", "context", "outlet_events") + + +class _AssetSelfProxy: + """Proxy for ``self`` in ``@asset`` functions; intercepts ``partition_keys`` writes and forwards them to the outlet event accessor.""" + + def __init__(self, asset: Asset, outlet_events: OutletEventAccessorsProtocol) -> None: + object.__setattr__(self, "_asset", asset) + object.__setattr__(self, "_outlet_events", outlet_events) + + def __getattr__(self, name: str) -> Any: + if name == "partition_keys": + return self._outlet_events[self._asset].partition_keys + return getattr(self._asset, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name != "partition_keys": + raise AttributeError( + f"Cannot set {name!r} on @asset self; only 'partition_keys' is settable at runtime" + ) + self._outlet_events[self._asset].partition_keys = value + + def _validate_asset_function_arguments(f: Callable) -> None: for name, param in inspect.signature(f).parameters.items(): if param.kind == inspect.Parameter.VAR_POSITIONAL: @@ -62,7 +86,8 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> inlets=[ Asset.ref(name=inlet_asset_name) for inlet_asset_name, param in inspect.signature(definition._function).parameters.items() - if inlet_asset_name not in ("self", "context") and param.default is inspect.Parameter.empty + if inlet_asset_name not in 
_INVALID_INLET_ASSET_NAMES + and param.default is inspect.Parameter.empty ], outlets=list(definition.iter_outlets()), python_callable=definition._function, @@ -86,9 +111,11 @@ def _fetch_asset(name: str) -> Asset: if param.default is not inspect.Parameter.empty: value = param.default elif key == "self": - value = _fetch_asset(self._definition_name) + value = _AssetSelfProxy(_fetch_asset(self._definition_name), context["outlet_events"]) elif key == "context": value = context + elif key == "outlet_events": + value = context["outlet_events"] else: value = _fetch_asset(key) yield key, value diff --git a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py index e6bb683ebcadb..a0c6493692572 100644 --- a/task-sdk/src/airflow/sdk/definitions/timetables/assets.py +++ b/task-sdk/src/airflow/sdk/definitions/timetables/assets.py @@ -54,6 +54,13 @@ class PartitionedAssetTimetable(AssetTriggeredTimetable): default_partition_mapper: PartitionMapper = IdentityMapper() +class PartitionAtRuntime(BaseTimetable): + """Marker timetable indicating that partition key(s) are determined at runtime.""" + + can_be_scheduled = False + partitioned_at_runtime = True + + def _coerce_assets(o: Collection[Asset] | BaseAsset) -> BaseAsset: if isinstance(o, BaseAsset): return o diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 66c1f3aa8b7eb..3af763ae95bb2 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -488,6 +488,13 @@ class OutletEventAccessor(_AssetRefResolutionMixin): key: BaseAssetUniqueKey extra: dict[str, JsonValue] = attrs.Factory(dict) asset_alias_events: list[AssetAliasEvent] = attrs.field(factory=list) + partition_keys: list[str] = attrs.field(factory=list) + + def add_partitions(self, keys: str | list[str]) -> None: + """Append one or more partition keys to :attr:`partition_keys`.""" + 
if isinstance(keys, str): + keys = [keys] + self.partition_keys.extend(keys) def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None) -> None: """Add an AssetEvent to an existing Asset.""" diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 56ba8343c648b..082e0c3720b28 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -58,7 +58,13 @@ from airflow.sdk.configuration import conf from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set -from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef +from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetNameRef, + AssetUniqueKey, + AssetUriRef, +) from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import process_params from airflow.sdk.exceptions import ( @@ -1138,7 +1144,11 @@ def _serialize_outlet_events(events: OutletEventAccessorsProtocol) -> Iterator[d # Further filtering will be done in the API server. 
for key, accessor in events._dict.items(): if isinstance(key, AssetUniqueKey): - yield {"dest_asset_key": attrs.asdict(key), "extra": accessor.extra} + yield { + "dest_asset_key": attrs.asdict(key), + "extra": accessor.extra, + "partition_keys": list(accessor.partition_keys), + } for alias_event in accessor.asset_alias_events: yield attrs.asdict(alias_event) diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 262e1206a30aa..e1181a531fad2 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -42,7 +42,13 @@ TaskInstanceState, ) from airflow.sdk.bases.operator import BaseOperator - from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey + from airflow.sdk.definitions.asset import ( + Asset, + AssetAlias, + AssetAliasEvent, + AssetRef, + BaseAssetUniqueKey, + ) from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.execution_time.comms import DagResult @@ -215,6 +221,7 @@ class OutletEventAccessorProtocol(Protocol): key: BaseAssetUniqueKey extra: dict[str, JsonValue] asset_alias_events: list[AssetAliasEvent] + partition_keys: list[str] def __init__( self, @@ -222,8 +229,10 @@ def __init__( key: BaseAssetUniqueKey, extra: dict[str, JsonValue], asset_alias_events: list[AssetAliasEvent], + partition_keys: list[str] = ..., ) -> None: ... def add(self, asset: Asset, extra: dict[str, JsonValue] | None = None) -> None: ... + def add_partitions(self, keys: str | list[str]) -> None: ... 
class OutletEventAccessorsProtocol(Protocol): diff --git a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py index c0b3b617fc068..3cda0760ab901 100644 --- a/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py +++ b/task-sdk/tests/task_sdk/definitions/test_asset_decorators.py @@ -21,9 +21,10 @@ import pytest from airflow.sdk.definitions.asset import Asset -from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, asset +from airflow.sdk.definitions.asset.decorators import _AssetMainOperator, _AssetSelfProxy, asset from airflow.sdk.definitions.decorators import task from airflow.sdk.execution_time.comms import AssetResult, GetAssetByName +from airflow.sdk.execution_time.context import OutletEventAccessors @pytest.fixture @@ -77,6 +78,15 @@ def _example_asset_func( return _example_asset_func +@pytest.fixture +def example_asset_func_with_outlet_events(func_fixer): + @func_fixer + def _example_asset_func(self, outlet_events): + return "This is example_asset" + + return _example_asset_func + + class TestAssetDecorator: def test_without_uri(self, example_asset_func): asset_definition = asset(schedule=None)(example_asset_func) @@ -392,17 +402,19 @@ def test_determine_kwargs( python_callable=example_asset_func_with_valid_arg_as_inlet_asset, definition_name="example_asset_func", ) - assert op.determine_kwargs(context={"k": "v"}) == { - "self": Asset( - name="example_asset_func", - uri="s3://bucket/object", - group="MLModel", - extra={"k": "v"}, - ), - "context": {"k": "v"}, - "inlet_asset_1": Asset(name="inlet_asset_1", uri="s3://bucket/object1"), - "inlet_asset_2": Asset(name="inlet_asset_2"), - } + outlet_events = OutletEventAccessors() + context = {"k": "v", "outlet_events": outlet_events} + kwargs = op.determine_kwargs(context=context) + assert isinstance(kwargs["self"], _AssetSelfProxy) + assert kwargs["self"]._asset == Asset( + name="example_asset_func", + 
uri="s3://bucket/object", + group="MLModel", + extra={"k": "v"}, + ) + assert kwargs["context"] is context + assert kwargs["inlet_asset_1"] == Asset(name="inlet_asset_1", uri="s3://bucket/object1") + assert kwargs["inlet_asset_2"] == Asset(name="inlet_asset_2") assert mock_supervisor_comms.mock_calls == [ mock.call.send(GetAssetByName(name="example_asset_func")), @@ -453,10 +465,71 @@ def example_asset_func(self): AssetResult(name="custom_name", uri="s3://bucket/object1", group="Asset") ] - assert op.determine_kwargs(context={}) == { - "self": Asset(name="custom_name", uri="s3://bucket/object1", group="Asset") - } + kwargs = op.determine_kwargs(context={"outlet_events": OutletEventAccessors()}) + assert list(kwargs) == ["self"] + assert isinstance(kwargs["self"], _AssetSelfProxy) + assert kwargs["self"]._asset == Asset(name="custom_name", uri="s3://bucket/object1", group="Asset") assert mock_supervisor_comms.mock_calls == [ mock.call.send(GetAssetByName(name="custom_name", uri="s3://bucket/object1", group="Asset")) ] + + +class TestAssetSelfProxy: + @pytest.fixture + def asset(self): + return Asset(name="a", uri="s3://bucket/a") + + @pytest.fixture + def outlet_events(self): + return OutletEventAccessors() + + def test_read_forwards_to_asset(self, asset, outlet_events): + proxy = _AssetSelfProxy(asset, outlet_events) + assert proxy.name == "a" + assert proxy.uri == "s3://bucket/a" + + def test_partition_keys_read_forwards_to_accessor(self, asset, outlet_events): + outlet_events[asset].partition_keys = ["us"] + proxy = _AssetSelfProxy(asset, outlet_events) + assert proxy.partition_keys == ["us"] + + def test_partition_keys_write_forwards_to_accessor(self, asset, outlet_events): + proxy = _AssetSelfProxy(asset, outlet_events) + proxy.partition_keys = ["us", "eu"] + assert outlet_events[asset].partition_keys == ["us", "eu"] + + @pytest.mark.parametrize("name", ["name", "uri", "extra", "group"]) + def test_setting_other_attributes_raises(self, asset, outlet_events, 
name): + proxy = _AssetSelfProxy(asset, outlet_events) + with pytest.raises(AttributeError, match="only 'partition_keys' is settable at runtime"): + setattr(proxy, name, "anything") + + +class TestOutletEventsKwarg: + def test_determine_kwargs_injects_outlet_events( + self, mock_supervisor_comms, example_asset_func_with_outlet_events + ): + definition = asset(schedule=None)(example_asset_func_with_outlet_events) + outlet_events = OutletEventAccessors() + context = {"outlet_events": outlet_events} + + mock_supervisor_comms.send.side_effect = [ + AssetResult(name="example_asset_func", uri="example_asset_func", group="asset"), + ] + + op = _AssetMainOperator( + task_id="example_asset_func", + inlets=[], + outlets=[definition], + python_callable=example_asset_func_with_outlet_events, + definition_name="example_asset_func", + ) + + kwargs = op.determine_kwargs(context=context) + assert kwargs["outlet_events"] is outlet_events + + def test_from_definition_excludes_outlet_events_from_inlets(self, example_asset_func_with_outlet_events): + definition = asset(schedule=None)(example_asset_func_with_outlet_events) + op = _AssetMainOperator.from_definition(definition) + assert op.inlets == [] diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py b/task-sdk/tests/task_sdk/definitions/test_dag.py index a7d1bfbd926a2..c42eb8dfc80a1 100644 --- a/task-sdk/tests/task_sdk/definitions/test_dag.py +++ b/task-sdk/tests/task_sdk/definitions/test_dag.py @@ -24,10 +24,12 @@ import pytest -from airflow.sdk import Context, Label, TaskGroup +from airflow.sdk import Context, Label, PartitionAtRuntime, TaskGroup from airflow.sdk.bases.operator import BaseOperator +from airflow.sdk.bases.timetable import BaseTimetable from airflow.sdk.definitions.dag import DAG, dag as dag_decorator from airflow.sdk.definitions.param import DagParam, Param, ParamsDict +from airflow.sdk.definitions.timetables import assets, events, interval, simple, trigger # noqa: F401 from airflow.sdk.exceptions import 
AirflowDagCycleException, DuplicateTaskIdFound, RemovedInAirflow4Warning from airflow.utils.types import DagRunType @@ -437,6 +439,17 @@ def test_continuous_schedule_linmits_max_active_runs(self): with pytest.raises(ValueError, match="ContinuousTimetable requires max_active_runs <= 1"): dag = DAG("continuous", start_date=DEFAULT_DATE, schedule="@continuous", max_active_runs=25) + def test_only_partition_at_runtime_has_partitioned_at_runtime_flag(self): + """Regression guard: across every BaseTimetable subclass, only PartitionAtRuntime sets partitioned_at_runtime=True.""" + + def all_subclasses(cls): + for sub in cls.__subclasses__(): + yield sub + yield from all_subclasses(sub) + + flagged = {c for c in all_subclasses(BaseTimetable) if c.partitioned_at_runtime} + assert flagged == {PartitionAtRuntime} + def test_dag_add_task_checks_trigger_rule(self): # A non fail stop dag should allow any trigger rule from airflow.sdk import TriggerRule diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 1b0be13ab3ad9..a199d40493bf6 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -434,6 +434,29 @@ def test_add_with_db(self, add_args, key, asset_alias_events, mock_supervisor_co assert outlet_event_accessor.asset_alias_events == asset_alias_events +class TestOutletEventAccessorPartitionKeys: + @pytest.fixture + def accessor(self) -> OutletEventAccessor: + return OutletEventAccessor(key=AssetUniqueKey.from_asset(Asset("a"))) + + def test_default_is_empty(self, accessor): + assert accessor.partition_keys == [] + + def test_direct_assignment(self, accessor): + accessor.partition_keys = ["us", "eu"] + assert accessor.partition_keys == ["us", "eu"] + + def test_add_partitions(self, accessor): + accessor.add_partitions("us") + assert accessor.partition_keys == ["us"] + + def test_add_partitions_appends(self, accessor): + 
accessor.add_partitions("us") + accessor.add_partitions("eu") + accessor.add_partitions("apac") + assert accessor.partition_keys == ["us", "eu", "apac"] + + class TestTriggeringAssetEventsAccessor: @pytest.fixture(autouse=True) def clear_cache(self): diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 630aff9094ed1..1d1eed93d03e4 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -139,6 +139,7 @@ _execute_task, _make_task_span, _push_xcom_if_needed, + _serialize_outlet_events, _xcom_push, finalize, get_startup_details, @@ -1716,6 +1717,34 @@ def test_function(ti): assert ti.rendered_map_index == "Label: test_task" +class TestSerializeOutletEvents: + """Tests for the wire format produced by ``_serialize_outlet_events``.""" + + def test_emits_empty_partition_keys_when_none_set(self): + accessors = OutletEventAccessors() + accessors[Asset(name="a")].extra = {"x": 1} + + events = list(_serialize_outlet_events(accessors)) + + assert events == [ + {"dest_asset_key": {"name": "a", "uri": "a"}, "extra": {"x": 1}, "partition_keys": []} + ] + + def test_emits_partition_keys_from_strings(self): + accessors = OutletEventAccessors() + accessors[Asset(name="a")].partition_keys = ["us", "eu"] + + events = list(_serialize_outlet_events(accessors)) + + assert events == [ + { + "dest_asset_key": {"name": "a", "uri": "a"}, + "extra": {}, + "partition_keys": ["us", "eu"], + } + ] + + class TestRuntimeTaskInstance: def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context): """Test get_template_context without ti_context_from_server."""