Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow-core/src/airflow/dag_processing/bundles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,12 @@ def __init__(
name: str,
refresh_interval: int = conf.getint("dag_processor", "refresh_interval"),
version: str | None = None,
version_data: dict[str, Any] | None = None,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth updating docstring to add this param on L295

view_url_template: str | None = None,
) -> None:
self.name = name
self.version = version
self.version_data = version_data
self.refresh_interval = refresh_interval
self.is_initialized: bool = False

Expand Down
9 changes: 7 additions & 2 deletions airflow-core/src/airflow/dag_processing/bundles/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,19 +324,24 @@ def _extract_template_params(bundle_instance: BaseDagBundle) -> dict:

return params

def get_bundle(self, name: str, version: str | None = None) -> BaseDagBundle:
def get_bundle(
self, name: str, version: str | None = None, version_data: dict | None = None
) -> BaseDagBundle:
"""
Get a DAG bundle by name.

:param name: The name of the DAG bundle.
:param version: The version of the DAG bundle you need (optional). If not provided, ``tracking_ref`` will be used instead.
:param version_data: Optional structured data associated with this version (e.g., S3 manifest).

:return: The DAG bundle.
"""
cfg_bundle = self._bundle_config.get(name)
if not cfg_bundle:
raise ValueError(f"Requested bundle '{name}' is not configured.")
return cfg_bundle.bundle_class(name=name, version=version, **cfg_bundle.kwargs)
return cfg_bundle.bundle_class(
name=name, version=version, version_data=version_data, **cfg_bundle.kwargs
)

def get_all_dag_bundles(self) -> Iterable[BaseDagBundle]:
"""
Expand Down
3 changes: 2 additions & 1 deletion airflow-core/src/airflow/executors/workloads/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
from abc import ABC, abstractmethod
from collections.abc import Hashable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, ConfigDict, Field

Expand Down Expand Up @@ -66,6 +66,7 @@ class BundleInfo(BaseModel):

name: str
version: str | None = None
version_data: dict[str, Any] | None = None


class BaseWorkloadSchema(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/executors/workloads/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,13 @@ def make(

ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True)
if not bundle_info:
version_data = None
if ti.dag_version is not None:
version_data = ti.dag_version.version_data
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this read off ti, while the other things just below are ti.dag_model.bundle*?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth flagging that even off the same ti, version and version_data can disagree. Two cases:

  1. Unpinned runs (disable_bundle_versioning=True): dag_run.bundle_version is None, but ti.dag_version.version_data may still carry a manifest -- so version=None, version_data={...}.
  2. After _verify_integrity_if_dag_changed (scheduler_job_runner.py:2521-2530): TI's dag_version_id is bumped to the latest version while dag_run.bundle_version is left untouched, so version_data describes a newer version than version reports.

The scheduler picks a deliberate rule for bundle_version at scheduler_job_runner.py:1438-1442; worth deciding the equivalent rule for version_data here (e.g., is it valid to expose version_data when version is None?).

bundle_info = BundleInfo(
name=ti.dag_model.bundle_name,
version=ti.dag_run.bundle_version,
version_data=version_data,
)
fname = log_filename_template_renderer()(ti=ti)

Expand Down
4 changes: 1 addition & 3 deletions airflow-core/src/airflow/executors/workloads/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,4 @@
def state_class_for_key(key: WorkloadKey) -> type[TaskInstanceState] | type[CallbackState]:
if isinstance(key, TaskInstanceKey):
return TaskInstanceState
if isinstance(key, CallbackKey):
return CallbackState
raise TypeError(f"Unknown workload key type: {type(key)!r}")
return CallbackState
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated change? Bad rebase?

1 change: 1 addition & 0 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
ranked_query.c.map_index_for_ordering,
)
.options(selectinload(TI.dag_model))
.options(selectinload(TI.dag_version))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, joins aren't free, and this isn't used in most places.

I'm wondering if this needs to be based on what the bundle backend needs somehow?

)

query = query.limit(max_tis)
Expand Down
13 changes: 13 additions & 0 deletions airflow-core/tests/unit/dag_processing/bundles/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,16 @@ def test_bundle_version_inequality(self):
bv1 = BundleVersion(version="abc", data={"key": "val"})
bv2 = BundleVersion(version="abc", data={"key": "other"})
assert bv1 != bv2


def test_version_data_stored_on_bundle():
"""Test that version_data passed to a bundle constructor is stored on the instance."""
manifest = {"schema_version": 1, "files": {"dags/my_dag.py": "S3VersionId123"}}
bundle = BasicBundle(name="test", version="abc", version_data=manifest)
assert bundle.version_data == manifest


def test_version_data_defaults_to_none():
"""Test that version_data defaults to None when not provided."""
bundle = BasicBundle(name="test")
assert bundle.version_data is None
9 changes: 5 additions & 4 deletions airflow-core/tests/unit/executors/test_workloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ def test_callback_key_is_not_a_string():
assert not isinstance(key, str)


def test_state_class_for_key_raises_on_unknown_type():
"""state_class_for_key should raise TypeError for unrecognized key types."""
def test_state_class_for_key_falls_back_to_callback_state():
"""state_class_for_key should fall back to CallbackState for non-TaskInstanceKey types."""
from airflow.utils.state import CallbackState

with pytest.raises(TypeError, match="Unknown workload key type"):
state_class_for_key("bare-string-is-not-a-key") # type: ignore[arg-type]
result = state_class_for_key("bare-string-is-not-a-key") # type: ignore[arg-type]
assert result is CallbackState


def test_callback_dto_key_returns_callback_key_instance():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,7 @@ def test_try_adopt_task_instances(self, mock_executor):
task.dag_model = mock.Mock()
task.dag_model.bundle_name = "test_bundle"
task.dag_model.relative_fileloc = "test_dag.py"
task.dag_version = mock.Mock(version_data=None)
task.dag_run = mock.Mock()
task.dag_run.bundle_version = "1.0.0"
task.dag_run.context_carrier = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ def test_celery_tasks_registered_on_import():
)


@pytest.mark.skipif(not AIRFLOW_V_3_2_PLUS, reason="ExecuteCallback requires Airflow 3.2+")
@pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="CallbackKey dataclass requires Airflow 3.3+")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated change to this PR.

@pytest.mark.parametrize(
("callback_data", "expected_queue"),
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,12 @@ components:
- type: string
- type: 'null'
title: Version
version_data:
anyOf:
- additionalProperties: true
type: object
- type: 'null'
Comment on lines +962 to +966
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the field is present it should be an object.

(I.e. it's either not sent, or it's an {...} object.)

title: Version Data
type: object
required:
- name
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ class BundleInfo(BaseModel):

name: Annotated[str, Field(title="Name")]
version: Annotated[str | None, Field(title="Version")] = None
version_data: Annotated[dict[str, Any] | None, Field(title="Version Data")] = None


class TerminalTIState(str, Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def _target():
bundle = DagBundlesManager().get_bundle(
name=bundle_info.name,
version=bundle_info.version,
version_data=bundle_info.version_data,
)
bundle.initialize()
if (bundle_path := str(bundle.path)) not in sys.path:
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
bundle_instance = DagBundlesManager().get_bundle(
name=bundle_info.name,
version=bundle_info.version,
version_data=bundle_info.version_data,
)
bundle_instance.initialize()
_verify_bundle_access(bundle_instance, log)
Expand Down
Loading