Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow-core/src/airflow/dag_processing/bundles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,12 @@ def __init__(
name: str,
refresh_interval: int = conf.getint("dag_processor", "refresh_interval"),
version: str | None = None,
version_data: dict[str, Any] | None = None,
view_url_template: str | None = None,
) -> None:
self.name = name
self.version = version
self.version_data = version_data
self.refresh_interval = refresh_interval
self.is_initialized: bool = False

Expand Down
9 changes: 7 additions & 2 deletions airflow-core/src/airflow/dag_processing/bundles/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,19 +324,24 @@ def _extract_template_params(bundle_instance: BaseDagBundle) -> dict:

return params

def get_bundle(
    self, name: str, version: str | None = None, version_data: dict | None = None
) -> BaseDagBundle:
    """
    Get a DAG bundle by name.

    The bundle is built by calling the configured bundle class with ``name``,
    ``version``, ``version_data`` and the keyword arguments stored in the
    bundle's configuration entry.

    :param name: The name of the DAG bundle.
    :param version: The version of the DAG bundle you need (optional). If not provided, ``tracking_ref`` will be used instead.
    :param version_data: Optional structured data associated with this version (e.g., S3 manifest).

    :return: The DAG bundle.
    :raises ValueError: If no bundle named ``name`` is configured.
    """
    cfg_bundle = self._bundle_config.get(name)
    if not cfg_bundle:
        raise ValueError(f"Requested bundle '{name}' is not configured.")
    # version_data is forwarded explicitly so the bundle instance can store it.
    return cfg_bundle.bundle_class(
        name=name, version=version, version_data=version_data, **cfg_bundle.kwargs
    )

def get_all_dag_bundles(self) -> Iterable[BaseDagBundle]:
"""
Expand Down
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/executors/workloads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Hashable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, ConfigDict, Field

Expand Down Expand Up @@ -66,6 +66,7 @@ class BundleInfo(BaseModel):

name: str
version: str | None = None
version_data: dict[str, Any] | None = None


class BaseWorkloadSchema(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/executors/workloads/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,13 @@ def make(

ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True)
if not bundle_info:
version_data = None
if ti.dag_version is not None:
version_data = ti.dag_version.version_data
bundle_info = BundleInfo(
name=ti.dag_model.bundle_name,
version=ti.dag_run.bundle_version,
version_data=version_data,
)
fname = log_filename_template_renderer()(ti=ti)

Expand Down
4 changes: 1 addition & 3 deletions airflow-core/src/airflow/executors/workloads/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,4 @@
def state_class_for_key(key: WorkloadKey) -> type[TaskInstanceState] | type[CallbackState]:
    """
    Return the state class that corresponds to the type of a workload key.

    :param key: The workload key to inspect.
    :return: ``TaskInstanceState`` for a ``TaskInstanceKey``, ``CallbackState`` for a ``CallbackKey``.
    :raises TypeError: If ``key`` is neither a ``TaskInstanceKey`` nor a ``CallbackKey``.
    """
    if isinstance(key, TaskInstanceKey):
        return TaskInstanceState
    if isinstance(key, CallbackKey):
        return CallbackState
    # Reject anything else explicitly rather than silently treating it as a callback key.
    raise TypeError(f"Unknown workload key type: {type(key)!r}")
1 change: 1 addition & 0 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
ranked_query.c.map_index_for_ordering,
)
.options(selectinload(TI.dag_model))
.options(selectinload(TI.dag_version))
)

query = query.limit(max_tis)
Expand Down
13 changes: 13 additions & 0 deletions airflow-core/tests/unit/dag_processing/bundles/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,16 @@ def test_bundle_version_inequality(self):
bv1 = BundleVersion(version="abc", data={"key": "val"})
bv2 = BundleVersion(version="abc", data={"key": "other"})
assert bv1 != bv2


def test_version_data_stored_on_bundle():
    """A ``version_data`` mapping given to the bundle constructor is kept on the instance."""
    expected = {"schema_version": 1, "files": {"dags/my_dag.py": "S3VersionId123"}}
    stored = BasicBundle(name="test", version="abc", version_data=expected).version_data
    assert stored == expected


def test_version_data_defaults_to_none():
    """Constructing a bundle without ``version_data`` leaves the attribute as ``None``."""
    assert BasicBundle(name="test").version_data is None
9 changes: 5 additions & 4 deletions airflow-core/tests/unit/executors/test_workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ def test_callback_key_is_not_a_string():
assert not isinstance(key, str)


def test_state_class_for_key_raises_on_unknown_type():
    """state_class_for_key should raise TypeError for unrecognized key types."""
    # A bare string is neither a TaskInstanceKey nor a CallbackKey, so it must be rejected.
    with pytest.raises(TypeError, match="Unknown workload key type"):
        state_class_for_key("bare-string-is-not-a-key")  # type: ignore[arg-type]


def test_callback_dto_key_returns_callback_key_instance():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,7 @@ def test_try_adopt_task_instances(self, mock_executor):
task.dag_model = mock.Mock()
task.dag_model.bundle_name = "test_bundle"
task.dag_model.relative_fileloc = "test_dag.py"
task.dag_version = mock.Mock(version_data=None)
task.dag_run = mock.Mock()
task.dag_run.bundle_version = "1.0.0"
task.dag_run.context_carrier = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ def test_celery_tasks_registered_on_import():
)


@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecuteCallback requires Airflow 3.2+")
@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="CallbackKey dataclass requires Airflow 3.3+")
@pytest.mark.parametrize(
("callback_data", "expected_queue"),
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,12 @@ components:
- type: string
- type: 'null'
title: Version
version_data:
anyOf:
- additionalProperties: true
type: object
- type: 'null'
title: Version Data
type: object
required:
- name
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ class BundleInfo(BaseModel):

name: Annotated[str, Field(title="Name")]
version: Annotated[str | None, Field(title="Version")] = None
version_data: Annotated[dict[str, Any] | None, Field(title="Version Data")] = None


class TerminalTIState(str, Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def _target():
bundle = DagBundlesManager().get_bundle(
name=bundle_info.name,
version=bundle_info.version,
version_data=bundle_info.version_data,
)
bundle.initialize()
if (bundle_path := str(bundle.path)) not in sys.path:
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
bundle_instance = DagBundlesManager().get_bundle(
name=bundle_info.name,
version=bundle_info.version,
version_data=bundle_info.version_data,
)
bundle_instance.initialize()
_verify_bundle_access(bundle_instance, log)
Expand Down
Loading