diff --git a/.dockerignore b/.dockerignore
index df08c066ce3b4..368d437dcfd1a 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -39,6 +39,7 @@
!task-sdk/
!airflow-ctl/
!go-sdk/
+!sdk/
# Add all "test" distributions
!tests
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e84a272c5ff35..243796d4bdac8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -492,6 +492,12 @@ repos:
language: python
pass_filenames: false
files: ^dev/registry/registry_tools/types\.py$|^registry/src/_data/types\.json$
+ - id: check-task-instance-dto-sync
+ name: Check BaseTaskInstanceDTO duplicate is in sync between core and task-sdk
+ entry: ./scripts/ci/prek/check_task_instance_dto_sync.py
+ language: python
+ pass_filenames: false
+ files: ^airflow-core/src/airflow/executors/workloads/task\.py$|^task-sdk/src/airflow/sdk/execution_time/workloads/task\.py$
- id: ruff
name: Run 'ruff' for extremely fast Python linting
description: "Run 'ruff' for extremely fast Python linting"
diff --git a/airflow-core/docs/extra-packages-ref.rst b/airflow-core/docs/extra-packages-ref.rst
index 2646b0a7c3079..9fb579c9b08ec 100644
--- a/airflow-core/docs/extra-packages-ref.rst
+++ b/airflow-core/docs/extra-packages-ref.rst
@@ -178,6 +178,17 @@ all the ``airflow`` packages together - similarly to what happened in Airflow 2.
``airflow-task-sdk`` separately, if you want to install providers, you need to install them separately as
``apache-airflow-providers-*`` distribution packages.
+Multi-Language extras
+=====================
+
+These extras add dependencies needed for integration with other language runtimes. Currently there is only a Java SDK extra, but more extras for other language runtimes may be added in the future.
+
++----------+------------------------------------------+------------------------------------------------------------------+
+| extra | install command | enables |
++==========+==========================================+==================================================================+
+| sdk.java | ``pip install apache-airflow[sdk.java]`` | JavaCoordinator for both dag processing and workload execution. |
++----------+------------------------------------------+------------------------------------------------------------------+
+
Apache Software extras
======================
diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml
index 8a060007978ee..a5a4ff9aa2157 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -1967,6 +1967,43 @@ workers:
type: integer
example: ~
default: "60"
+sdk:
+ description: Settings for non-Python SDK runtime coordination
+ options:
+ coordinators:
+ description: |
+ JSON list of runtime coordinator entries.
+
+ Each entry is an object with ``name``, ``classpath`` and optional
+ ``kwargs``. ``classpath`` is resolved via ``import_string`` and
+ constructed with ``kwargs`` once per process. Entries are
+ independent instances, so the same ``classpath`` can be configured
+ multiple times with different ``kwargs`` (for example, two
+ ``JavaCoordinator`` instances pinned to different JDK versions).
+ version_added: 3.1.7
+ type: string
+ example: |
+ [
+ {
+ "name": "jdk-17",
+ "classpath": "airflow.sdk.coordinators.java.JavaCoordinator",
+ "kwargs": {"java_executable": "/usr/lib/jvm/java-17-openjdk/bin/java", "jvm_args": ["-Xmx1024m"]}
+ }
+ ]
+ default: ~
+ queue_to_coordinator:
+ description: |
+ JSON mapping of queue names to coordinator ``name`` from
+ ``[sdk] coordinators``.
+
+ When a task's ``language`` field is not set, this mapping is checked
+ to route the task to a configured coordinator instance based on its
+ queue. This is useful when queues are used as environment or
+ isolation identifiers (e.g. ``legacy-java``, ``modern-java``).
+ version_added: 3.1.7
+ type: string
+ example: '{"legacy-java": "jdk-11", "modern-java": "jdk-17"}'
+ default: ~
api_auth:
description: Settings relating to authentication on the Airflow APIs
options:
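The two ``[sdk]`` options above define a small resolution contract: parse the JSON list, import each ``classpath`` once per process, construct it with its ``kwargs``, and route queues to named instances. A minimal sketch of that flow, assuming hypothetical helper names (the actual ``CoordinatorManager`` is not part of this hunk):

```python
# Hypothetical sketch of resolving [sdk] coordinators / queue_to_coordinator
# as described above; the helper names are illustrative, not Airflow API.
import json

from airflow.configuration import conf
from airflow.utils.module_loading import import_string


def load_coordinators() -> dict[str, object]:
    raw = conf.get("sdk", "coordinators", fallback=None)
    if not raw:
        return {}
    instances: dict[str, object] = {}
    for entry in json.loads(raw):
        cls = import_string(entry["classpath"])
        # Entries are independent instances, so the same classpath can appear
        # twice with different kwargs (e.g. a jdk-11 and a jdk-17 instance).
        instances[entry["name"]] = cls(**entry.get("kwargs", {}))
    return instances


def coordinator_for_queue(queue: str) -> object | None:
    mapping = json.loads(conf.get("sdk", "queue_to_coordinator", fallback="{}") or "{}")
    name = mapping.get(queue)
    return load_coordinators().get(name) if name is not None else None
```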
diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py
index 8d497ca7508e3..1221487818387 100644
--- a/airflow-core/src/airflow/dag_processing/manager.py
+++ b/airflow-core/src/airflow/dag_processing/manager.py
@@ -267,6 +267,9 @@ class DagFileProcessorManager(LoggingMixin):
factory=_config_get_factory("dag_processor", "file_parsing_sort_mode")
)
+ _runtime_file_extensions: tuple[str, ...] | None = attrs.field(default=None, init=False)
+ """File extensions registered by runtime coordinators (e.g. ".jar"). Lazily populated."""
+
_api_server: InProcessExecutionAPI = attrs.field(init=False, factory=InProcessExecutionAPI)
"""API server to interact with Metadata DB"""
@@ -815,6 +818,16 @@ def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[Path]:
return rel_paths
+ def _get_runtime_file_extensions(self) -> tuple[str, ...]:
+ """Collect file extensions from configured runtime coordinators (cached after first call)."""
+ if self._runtime_file_extensions is not None:
+ return self._runtime_file_extensions
+
+ from airflow.sdk.execution_time.coordinator import get_coordinator_manager
+
+ self._runtime_file_extensions = get_coordinator_manager().file_extensions()
+ return self._runtime_file_extensions
+
def _get_observed_filelocs(self, present: set[DagFileInfo]) -> set[str]:
"""
Return observed DAG source paths for bundle entries.
@@ -822,7 +835,11 @@ def _get_observed_filelocs(self, present: set[DagFileInfo]) -> set[str]:
For regular files this includes the relative file path.
For ZIP archives this includes DAG-like inner paths such as
``archive.zip/dag.py``.
+
+ Runtime coordinator file extensions (e.g. ``.jar``) are treated as
+ opaque files rather than ZIP archives.
"""
+ runtime_extensions = self._get_runtime_file_extensions()
def find_zipped_dags(abs_path: os.PathLike) -> Iterator[str]:
"""Yield absolute paths for DAG-like files inside a ZIP archive."""
@@ -837,7 +854,7 @@ def find_zipped_dags(abs_path: os.PathLike) -> Iterator[str]:
observed_filelocs: set[str] = set()
for info in present:
abs_path = str(info.absolute_path)
- if abs_path.endswith(".py") or not zipfile.is_zipfile(abs_path):
+ if abs_path.endswith((".py", *runtime_extensions)) or not zipfile.is_zipfile(abs_path):
observed_filelocs.add(str(info.rel_path))
else:
if TYPE_CHECKING:
diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index aa9f07411f87d..00d17bd390c26 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import contextlib
+import functools
import importlib
import logging
import os
@@ -76,8 +77,6 @@
from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
- from socket import socket
-
from structlog.typing import FilteringBoundLogger
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
@@ -86,6 +85,7 @@
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.mappedoperator import MappedOperator
+ from airflow.sdk.execution_time.supervisor import SelectorCallback
from airflow.typing_compat import Self
@@ -553,7 +553,14 @@ def start( # type: ignore[override]
) -> Self:
logger = kwargs["logger"]
- _pre_import_airflow_modules(os.fspath(path), logger)
+ # Check if a configured runtime coordinator should handle this file
+ logger.debug("Checking for runtime coordinator entrypoint for file", path=path)
+ resolved_target = cls._resolve_processor_target(path, bundle_name, bundle_path, logger)
+ if resolved_target is not None:
+ target = resolved_target
+ logger.debug("Resolved runtime coordinator entrypoint for file", path=path)
+ else:
+ _pre_import_airflow_modules(os.fspath(path), logger)
proc: Self = super().start(
target=target,
@@ -566,6 +573,35 @@ def start( # type: ignore[override]
proc._on_child_started(callbacks, path, bundle_path, bundle_name)
return proc
+ @staticmethod
+ def _resolve_processor_target(
+ path: str | os.PathLike[str],
+ bundle_name: str,
+ bundle_path: Path,
+ log: FilteringBoundLogger,
+ ) -> Callable[[], None] | None:
+ """
+ Return the entrypoint of the first runtime coordinator that can handle *path*.
+
+ The returned callable is a ``functools.partial`` that binds *path*, *bundle_name*
+ and *bundle_path* so the supervisor can pass it as a no-arg ``target`` to
+ ``WatchedSubprocess.start``.
+ """
+ from airflow.sdk.execution_time.coordinator import get_coordinator_manager
+
+ coordinator = get_coordinator_manager().for_dag_file(bundle_name, path)
+ if coordinator is None:
+ log.debug("No runtime coordinator found for file %s, using default processor", path)
+ return None
+
+ log.debug("Using runtime coordinator %s for file %s", type(coordinator).__qualname__, path)
+ return functools.partial(
+ coordinator.run_dag_parsing,
+ path=os.fspath(path),
+ bundle_name=bundle_name,
+ bundle_path=os.fspath(bundle_path),
+ )
+
def _on_child_started(
self,
callbacks: list[CallbackRequest],
@@ -591,7 +627,7 @@ def _get_target_loggers(self) -> tuple[FilteringBoundLogger, ...]:
def _create_log_forwarder(
self, loggers: tuple[FilteringBoundLogger, ...], name: str, log_level: int = logging.INFO
- ) -> Callable[[socket], bool]:
+ ) -> SelectorCallback:
return super()._create_log_forwarder(loggers, name.replace("task.", "dag_processor.", 1), log_level)
def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None:
diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py
index 9c9487c5377bc..bced160a5f4c7 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -650,11 +650,10 @@ def run_workload(
if isinstance(workload, ExecuteTask):
from airflow.sdk.execution_time.supervisor import supervise_task
+ from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO as SDKTaskInstanceDTO
- # workload.ti is a TaskInstanceDTO which duck-types as TaskInstance.
- # TODO: Create a protocol for this.
return supervise_task(
- ti=workload.ti, # type: ignore[arg-type]
+ ti=SDKTaskInstanceDTO.model_validate(workload.ti, from_attributes=True),
bundle_info=workload.bundle_info,
dag_rel_path=workload.dag_rel_path,
token=workload.token,
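The ``model_validate(..., from_attributes=True)`` call is standard Pydantic v2: it reads fields off an arbitrary attribute-bearing object, which is what lets the executor's ``TaskInstanceDTO`` convert into the structurally matching SDK model and replaces the previous duck-typing ``type: ignore``. A toy standalone illustration (these models stand in for the real DTOs):

```python
from __future__ import annotations

from pydantic import BaseModel


class CoreTI(BaseModel):
    id: str
    queue: str = "default"
    external_executor_id: str | None = None  # executor-only field


class SDKTI(BaseModel):
    id: str
    queue: str = "default"


core = CoreTI(id="abc123", external_executor_id="celery-42")
# from_attributes=True reads matching fields off the CoreTI instance;
# the executor-only field simply has no counterpart on the SDK model.
sdk = SDKTI.model_validate(core, from_attributes=True)
print(sdk)  # id='abc123' queue='default'
```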
diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py
index d05affe433096..9af3f33c10efd 100644
--- a/airflow-core/src/airflow/executors/workloads/task.py
+++ b/airflow-core/src/airflow/executors/workloads/task.py
@@ -33,8 +33,14 @@
from airflow.models.taskinstancekey import TaskInstanceKey
-class TaskInstanceDTO(BaseModel):
- """Schema for TaskInstance with minimal required fields needed for Executors and Task SDK."""
+class BaseTaskInstanceDTO(BaseModel):
+ """
+ Base schema for TaskInstance with the minimal fields shared by Executors and the Task SDK.
+
+ This definition is duplicated in :mod:`airflow.sdk.execution_time.workloads.task`
+ and the two are kept in sync by the ``check-task-instance-dto-sync`` prek
+ hook. Update both files together.
+ """
id: uuid.UUID
dag_version_id: uuid.UUID
@@ -48,11 +54,16 @@ class TaskInstanceDTO(BaseModel):
queue: str
priority_weight: int
executor_config: dict | None = Field(default=None, exclude=True)
- external_executor_id: str | None = Field(default=None, exclude=True)
parent_context_carrier: dict | None = None
context_carrier: dict | None = None
+
+class TaskInstanceDTO(BaseTaskInstanceDTO):
+ """TaskInstanceDTO with executor-specific ``external_executor_id`` field and ``key`` property."""
+
+ external_executor_id: str | None = Field(default=None, exclude=True)
+
# TODO: Task-SDK: Can we replace TaskInstanceKey with just the uuid across the codebase?
@property
def key(self) -> TaskInstanceKey:
diff --git a/airflow-core/src/airflow/models/dagcode.py b/airflow-core/src/airflow/models/dagcode.py
index 60ee91c8b59b5..cdec5baa95717 100644
--- a/airflow-core/src/airflow/models/dagcode.py
+++ b/airflow-core/src/airflow/models/dagcode.py
@@ -119,6 +119,14 @@ def code(cls, dag_id, session: Session = NEW_SESSION) -> str:
@staticmethod
def get_code_from_file(fileloc):
+        # Try the configured runtime coordinators first.
+ from airflow.sdk.execution_time.coordinator import get_coordinator_manager
+
+ coordinator = get_coordinator_manager().for_dag_file("", fileloc)
+ if coordinator is not None:
+ return coordinator.get_code_from_file(fileloc)
+
+        # Then fall back to reading the file natively in Python.
try:
with open_maybe_zipped(fileloc, "r") as f:
code = f.read()
diff --git a/airflow-core/src/airflow/serialization/definitions/baseoperator.py b/airflow-core/src/airflow/serialization/definitions/baseoperator.py
index 6bafc5891235a..9eaf9cc3ed906 100644
--- a/airflow-core/src/airflow/serialization/definitions/baseoperator.py
+++ b/airflow-core/src/airflow/serialization/definitions/baseoperator.py
@@ -195,6 +195,7 @@ def get_serialized_fields(cls):
"ignore_first_depends_on_past",
"inlets",
"is_setup",
+ "sdk",
"is_teardown",
"map_index_template",
"max_active_tis_per_dag",
diff --git a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py
index 2f376d98e0158..19d787f2967b6 100644
--- a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py
+++ b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py
@@ -111,6 +111,7 @@
("scripts", "/opt/airflow/scripts"),
("uv.lock", "/opt/airflow/uv.lock"),
("scripts/docker/entrypoint_ci.sh", "/entrypoint"),
+ ("sdk", "/opt/airflow/sdk"),
("shared", "/opt/airflow/shared"),
("task-sdk", "/opt/airflow/task-sdk"),
]
diff --git a/devel-common/src/docs/provider_conf.py b/devel-common/src/docs/provider_conf.py
index 6bc9da15f5f61..b730e8f20a417 100644
--- a/devel-common/src/docs/provider_conf.py
+++ b/devel-common/src/docs/provider_conf.py
@@ -151,7 +151,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
-empty_subpackages = ["apache", "atlassian", "common", "cncf", "dbt", "microsoft"]
+empty_subpackages = ["apache", "atlassian", "common", "cncf", "dbt", "microsoft", "sdk"]
exclude_patterns = [
"operators/_partials",
"_api/airflow/index.rst",
diff --git a/devel-common/src/sphinx_exts/includes/sections-and-options.rst b/devel-common/src/sphinx_exts/includes/sections-and-options.rst
index e04383c8c5582..b0d84a1bd8a5a 100644
--- a/devel-common/src/sphinx_exts/includes/sections-and-options.rst
+++ b/devel-common/src/sphinx_exts/includes/sections-and-options.rst
@@ -65,7 +65,7 @@
{% if default and "\n" in default %}
.. code-block::
- {{ default }}
+ {{ default | indent(width=8) }}
{% else %}
``{{ "''" if default == "" else default }}``
{% endif %}
@@ -85,7 +85,7 @@
{% if "\n" in example %}
.. code-block::
- {{ example }}
+ {{ example | indent(width=8) }}
{% else %}
``{{ example }}``
{% endif %}
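This works because Jinja's built-in ``indent`` filter pads every line *after* the first by default: the template's literal indentation covers line one, and the filter keeps the continuation lines inside the ``code-block``. A quick standalone check:

```python
from jinja2 import Template

example = '[\n  {"name": "jdk-17"}\n]'
# The first line keeps the template's own 8-space indentation; indent()
# pads the remaining lines so the RST code-block stays intact.
print(Template("        {{ example | indent(width=8) }}").render(example=example))
```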
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index 98a871ebb12d2..ed3547ccbfa84 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -2510,7 +2510,6 @@ def execute(self, context):
from uuid6 import uuid7
from airflow.sdk import DAG
- from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
from airflow.timetables.base import TimeRestriction
@@ -2538,6 +2537,15 @@ def _create_task_instance(
should_retry: bool | None = None,
max_tries: int | None = None,
) -> RuntimeTaskInstance:
+ from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+
+ if AIRFLOW_V_3_3_PLUS:
+ from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+ else:
+ from airflow.sdk.api.datamodels._generated import ( # type: ignore[no-redef,assignment]
+ TaskInstance as TaskInstanceDTO,
+ )
+
from airflow.sdk.api.datamodels._generated import DagRun, DagRunState, TIRunContext
from airflow.utils.types import DagRunType
@@ -2615,14 +2623,17 @@ def _create_task_instance(
}
startup_details = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=ti_id,
task_id=task.task_id,
dag_id=dag_id,
run_id=run_id,
try_number=try_number,
- map_index=map_index,
+ map_index=map_index, # type: ignore[arg-type]
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="",
bundle_info=BundleInfo(name="anything", version="any"),
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 8d6413a288594..713c8bb3284b0 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -857,6 +857,7 @@ iTerm
iterm
itertools
Jarek
+JavaCoordinator
javascript
jaydebeapi
Jdbc
@@ -894,6 +895,7 @@ jsonl
juli
Jupyter
jupyter
+jvm
jwks
JWT
jwt
@@ -1125,6 +1127,7 @@ openai
openapi
openfaas
OpenID
+openjdk
openlineage
OpenSearch
opensearch
@@ -1852,6 +1855,7 @@ XComs
Xiaodong
xlarge
xml
+Xmx
xpath
XSS
xyz
diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
index 01c8149d1dad8..c777d1f1bec97 100644
--- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
+++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml
@@ -1257,8 +1257,8 @@ components:
- queue
- priority_weight
title: TaskInstanceDTO
- description: Schema for TaskInstance with minimal required fields needed for
- Executors and Task SDK.
+ description: TaskInstanceDTO with executor-specific ``external_executor_id``
+ field and ``key`` property.
TaskInstanceState:
type: string
enum:
diff --git a/providers/standard/src/airflow/providers/standard/decorators/stub.py b/providers/standard/src/airflow/providers/standard/decorators/stub.py
index f29d123c740c1..a5e63d925f795 100644
--- a/providers/standard/src/airflow/providers/standard/decorators/stub.py
+++ b/providers/standard/src/airflow/providers/standard/decorators/stub.py
@@ -85,7 +85,6 @@ def stub(
Stub tasks exist in the Dag graph only, but the execution must happen in an external
environment via the Task Execution Interface.
-
"""
return task_decorator_factory(
decorated_operator_class=_StubOperator,
diff --git a/pyproject.toml b/pyproject.toml
index 4a988db2d3ea7..5cb138ad5e88f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1813,6 +1813,7 @@ members = [
"helm-tests",
"kubernetes-tests",
"task-sdk",
+ "sdk/coordinators/java",
"providers-summary-docs",
"docker-stack-docs",
"shared/configuration",
diff --git a/scripts/ci/docker-compose/local.yml b/scripts/ci/docker-compose/local.yml
index 185f2f385c344..a59fd4fd8e9cd 100644
--- a/scripts/ci/docker-compose/local.yml
+++ b/scripts/ci/docker-compose/local.yml
@@ -126,6 +126,9 @@ services:
- type: bind
source: ../../../scripts/docker/entrypoint_ci.sh
target: /entrypoint
+ - type: bind
+ source: ../../../sdk
+ target: /opt/airflow/sdk
- type: bind
source: ../../../shared
target: /opt/airflow/shared
diff --git a/scripts/ci/docker-compose/remove-sources.yml b/scripts/ci/docker-compose/remove-sources.yml
index a2f7d3a035766..cf78e9258f39f 100644
--- a/scripts/ci/docker-compose/remove-sources.yml
+++ b/scripts/ci/docker-compose/remove-sources.yml
@@ -107,6 +107,7 @@ services:
- ../../../empty:/opt/airflow/providers/redis/src
- ../../../empty:/opt/airflow/providers/salesforce/src
- ../../../empty:/opt/airflow/providers/samba/src
+ - ../../../empty:/opt/airflow/sdk/coordinators/java/src
- ../../../empty:/opt/airflow/providers/segment/src
- ../../../empty:/opt/airflow/providers/sendgrid/src
- ../../../empty:/opt/airflow/providers/sftp/src
diff --git a/scripts/ci/docker-compose/tests-sources.yml b/scripts/ci/docker-compose/tests-sources.yml
index 9c02d1c271412..67b2590f69b61 100644
--- a/scripts/ci/docker-compose/tests-sources.yml
+++ b/scripts/ci/docker-compose/tests-sources.yml
@@ -120,6 +120,7 @@ services:
- ../../../providers/redis/tests:/opt/airflow/providers/redis/tests
- ../../../providers/salesforce/tests:/opt/airflow/providers/salesforce/tests
- ../../../providers/samba/tests:/opt/airflow/providers/samba/tests
+ - ../../../sdk/coordinators/java/tests:/opt/airflow/sdk/coordinators/java/tests
- ../../../providers/segment/tests:/opt/airflow/providers/segment/tests
- ../../../providers/sendgrid/tests:/opt/airflow/providers/sendgrid/tests
- ../../../providers/sftp/tests:/opt/airflow/providers/sftp/tests
diff --git a/scripts/ci/prek/check_task_instance_dto_sync.py b/scripts/ci/prek/check_task_instance_dto_sync.py
new file mode 100755
index 0000000000000..689d35a4d15e3
--- /dev/null
+++ b/scripts/ci/prek/check_task_instance_dto_sync.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Verify that the duplicate ``BaseTaskInstanceDTO`` definitions in airflow-core
+and task-sdk stay structurally identical.
+
+``BaseTaskInstanceDTO`` is duplicated (not shared) in:
+
+- ``airflow-core/src/airflow/executors/workloads/task.py``
+- ``task-sdk/src/airflow/sdk/execution_time/workloads/task.py``
+
+This hook compares the *fields* (annotated assignments) and bases of both
+``BaseTaskInstanceDTO`` classes. The concrete ``TaskInstanceDTO`` subclasses
+in each file are allowed to differ (airflow-core adds an executor-specific
+``key`` property that depends on ``airflow.models``, which the Task SDK
+does not have access to).
+"""
+
+from __future__ import annotations
+
+import ast
+import sys
+from pathlib import Path
+
+AIRFLOW_ROOT = Path(__file__).parents[3].resolve()
+CORE_FILE = AIRFLOW_ROOT / "airflow-core" / "src" / "airflow" / "executors" / "workloads" / "task.py"
+SDK_FILE = AIRFLOW_ROOT / "task-sdk" / "src" / "airflow" / "sdk" / "execution_time" / "workloads" / "task.py"
+CLASS_NAME = "BaseTaskInstanceDTO"
+
+
+def _find_class(tree: ast.AST, class_name: str) -> ast.ClassDef | None:
+ for node in ast.walk(tree):
+ if isinstance(node, ast.ClassDef) and node.name == class_name:
+ return node
+ return None
+
+
+def _field_signature(class_node: ast.ClassDef) -> list[tuple[str, str, str | None]]:
+ """Return a normalized list of ``(name, annotation, default)`` for each field."""
+ fields: list[tuple[str, str, str | None]] = []
+ for stmt in class_node.body:
+ if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
+ name = stmt.target.id
+ annotation = ast.unparse(stmt.annotation)
+ default = ast.unparse(stmt.value) if stmt.value is not None else None
+ fields.append((name, annotation, default))
+ return fields
+
+
+def _bases(class_node: ast.ClassDef) -> list[str]:
+ return [ast.unparse(base) for base in class_node.bases]
+
+
+def _extract(file_path: Path) -> tuple[list[str], list[tuple[str, str, str | None]]]:
+ source = file_path.read_text()
+ tree = ast.parse(source, filename=str(file_path))
+ class_node = _find_class(tree, CLASS_NAME)
+ if class_node is None:
+ print(f"ERROR: Could not find class {CLASS_NAME} in {file_path}", file=sys.stderr)
+ sys.exit(1)
+ return _bases(class_node), _field_signature(class_node)
+
+
+def main() -> None:
+ core_bases, core_fields = _extract(CORE_FILE)
+ sdk_bases, sdk_fields = _extract(SDK_FILE)
+
+ if core_bases == sdk_bases and core_fields == sdk_fields:
+ sys.exit(0)
+
+ print(
+ f"\nERROR: {CLASS_NAME} definitions in airflow-core and task-sdk are out of sync!",
+ file=sys.stderr,
+ )
+ print(f"\n airflow-core: {CORE_FILE.relative_to(AIRFLOW_ROOT)}", file=sys.stderr)
+ print(f" task-sdk: {SDK_FILE.relative_to(AIRFLOW_ROOT)}", file=sys.stderr)
+
+ if core_bases != sdk_bases:
+ print("\nClass bases differ:", file=sys.stderr)
+ print(f" airflow-core: {core_bases}", file=sys.stderr)
+ print(f" task-sdk: {sdk_bases}", file=sys.stderr)
+
+ if core_fields != sdk_fields:
+        core_map = {f[0]: f for f in core_fields}
+        sdk_map = {f[0]: f for f in sdk_fields}
+        only_in_core = sorted(set(core_map) - set(sdk_map))
+        only_in_sdk = sorted(set(sdk_map) - set(core_map))
+        differing = sorted(name for name in set(core_map) & set(sdk_map) if core_map[name] != sdk_map[name])
+ if only_in_core:
+ print(f"\n Fields only in airflow-core: {only_in_core}", file=sys.stderr)
+ if only_in_sdk:
+ print(f"\n Fields only in task-sdk: {only_in_sdk}", file=sys.stderr)
+ for name in differing:
+ print(
+ f"\n Field {name!r} differs:"
+            f"\n    airflow-core: {core_map[name]}"
+            f"\n    task-sdk: {sdk_map[name]}",
+ file=sys.stderr,
+ )
+
+ print(
+ f"\nUpdate both files together so the two {CLASS_NAME} definitions stay in sync.",
+ file=sys.stderr,
+ )
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
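The comparison works because ``ast.unparse`` renders each annotation and default back to a canonical string, so cosmetic differences between the two files (spacing, line wrapping) do not trip the check. A standalone illustration of the field-extraction step on a toy class:

```python
import ast

src = """
class BaseTaskInstanceDTO(BaseModel):
    id: uuid.UUID
    queue: str
    executor_config: dict | None = Field(default=None, exclude=True)
"""
class_node = ast.parse(src).body[0]
for stmt in class_node.body:
    if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
        default = ast.unparse(stmt.value) if stmt.value is not None else None
        print((stmt.target.id, ast.unparse(stmt.annotation), default))
# ('id', 'uuid.UUID', None)
# ('queue', 'str', None)
# ('executor_config', 'dict | None', 'Field(default=None, exclude=True)')
```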
diff --git a/scripts/in_container/java_sdk_setup.sh b/scripts/in_container/java_sdk_setup.sh
new file mode 100644
index 0000000000000..b3437b7fc4200
--- /dev/null
+++ b/scripts/in_container/java_sdk_setup.sh
@@ -0,0 +1,80 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+# Check that Java >= 11 is available; install OpenJDK 11 if it is not.
+check_java() {
+check_java() {
+ local java_bin="/files/openjdk/bin/java"
+ local version_output
+
+ # First check if the locally installed OpenJDK exists and works.
+ if [ -x "$java_bin" ] && version_output=$("$java_bin" -version 2>&1); then
+ echo "Found existing OpenJDK at $java_bin. OK."
+ return
+ fi
+
+ # On macOS, /usr/bin/java exists as a shim even without a JDK installed,
+ # so we must test with `java -version` directly.
+ if ! version_output=$(java -version 2>&1); then
+ echo "Java is not installed."
+ install_java
+ return
+ fi
+
+ local java_version
+ java_version=$(echo "$version_output" | head -n1 | sed -E 's/.*"([0-9]+)(\.[0-9]+)*.*/\1/')
+
+ if ! [[ "$java_version" =~ ^[0-9]+$ ]]; then
+ echo "Could not determine Java version."
+ install_java
+ return
+ fi
+
+ if [ "$java_version" -ge 11 ]; then
+ echo "Java $java_version detected. OK."
+ else
+ echo "Java version $java_version found, but >= 11 is required."
+ install_java
+ fi
+}
+
+
+install_java() {
+    echo "Installing OpenJDK 11 in Breeze..."
+
+    # Pick the Temurin asset that matches the container architecture.
+    local arch
+    case "$(uname -m)" in
+        aarch64|arm64) arch="aarch64" ;;
+        *) arch="x64" ;;
+    esac
+
+    curl -L -o "/files/openjdk-11-${arch}.tar.gz" \
+        "https://github.com/adoptium/temurin11-binaries/releases/download/jdk-11.0.30+7/OpenJDK11U-jdk_${arch}_linux_hotspot_11.0.30_7.tar.gz"
+
+    rm -rf /files/openjdk && mkdir -p /files/openjdk && \
+        tar -xzf "/files/openjdk-11-${arch}.tar.gz" --strip-components=1 -C /files/openjdk
+
+ /files/openjdk/bin/java -version
+ echo ""
+}
+
+check_java
+# Install the Java coordinator distribution (not a regular provider).
+pip install -e /opt/airflow/sdk/coordinators/java/
diff --git a/sdk/coordinators/java/.gitignore b/sdk/coordinators/java/.gitignore
new file mode 100644
index 0000000000000..bff2d7629604d
--- /dev/null
+++ b/sdk/coordinators/java/.gitignore
@@ -0,0 +1 @@
+*.iml
diff --git a/sdk/coordinators/java/LICENSE b/sdk/coordinators/java/LICENSE
new file mode 100644
index 0000000000000..11069edd79019
--- /dev/null
+++ b/sdk/coordinators/java/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/sdk/coordinators/java/NOTICE b/sdk/coordinators/java/NOTICE
new file mode 100644
index 0000000000000..a51bd9390d030
--- /dev/null
+++ b/sdk/coordinators/java/NOTICE
@@ -0,0 +1,5 @@
+Apache Airflow
+Copyright 2016-2026 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
diff --git a/sdk/coordinators/java/README.rst b/sdk/coordinators/java/README.rst
new file mode 100644
index 0000000000000..63f19caa412bb
--- /dev/null
+++ b/sdk/coordinators/java/README.rst
@@ -0,0 +1,51 @@
+
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Package ``apache-airflow-coordinators-java``
+============================================
+
+Java runtime coordinator for the Apache Airflow Task SDK.
+
+This distribution contributes the ``airflow.sdk.coordinators.java.JavaCoordinator``
+class, which spawns a JVM subprocess to parse Java DAG bundles (``.jar``)
+and execute Java tasks. It is loaded via the ``[sdk] coordinators`` configuration
+and is *not* a standard Airflow provider — it does not register hooks, operators,
+or any other provider-managed resources.
+
+Configure it in ``airflow.cfg``::
+
+ [sdk]
+ coordinators = [
+ {
+ "name": "jdk-17",
+ "classpath": "airflow.sdk.coordinators.java.JavaCoordinator",
+ "kwargs": {
+ "java_executable": "/usr/lib/jvm/java-17-openjdk/bin/java",
+ "jvm_args": ["-Xmx1024m"],
+ "bundles_folder": "~/airflow/java-bundles"
+ }
+ }
+ ]
+ queue_to_coordinator = {"java-queue": "jdk-17"}
+
+Installation
+------------
+
+::
+
+ pip install apache-airflow-coordinators-java
diff --git a/sdk/coordinators/java/pyproject.toml b/sdk/coordinators/java/pyproject.toml
new file mode 100644
index 0000000000000..f4e0dd31f5284
--- /dev/null
+++ b/sdk/coordinators/java/pyproject.toml
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[build-system]
+requires = ["hatchling==1.29.0"]
+build-backend = "hatchling.build"
+
+[project]
+name = "apache-airflow-coordinators-java"
+version = "0.1.0"
+description = "Java runtime coordinator for the Apache Airflow Task SDK"
+readme = "README.rst"
+license = "Apache-2.0"
+license-files = ["LICENSE", "NOTICE"]
+authors = [
+ {name="Apache Software Foundation", email="dev@airflow.apache.org"},
+]
+maintainers = [
+ {name="Apache Software Foundation", email="dev@airflow.apache.org"},
+]
+keywords = ["airflow", "coordinator", "java", "sdk"]
+classifiers = [
+ "Development Status :: 4 - Beta",
+ "Environment :: Console",
+ "Framework :: Apache Airflow",
+ "Intended Audience :: Developers",
+ "Intended Audience :: System Administrators",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
+ "Topic :: System :: Monitoring",
+]
+requires-python = ">=3.10,<3.15"
+
+dependencies = [
+ "apache-airflow-task-sdk>=1.3.0",
+ "PyYAML>=6.0.2",
+]
+
+[dependency-groups]
+dev = [
+ "apache-airflow",
+ "apache-airflow-task-sdk",
+ "apache-airflow-devel-common",
+]
+
+docs = [
+ "apache-airflow-devel-common[docs]"
+]
+
+[tool.uv.sources]
+apache-airflow = {workspace = true}
+apache-airflow-devel-common = {workspace = true}
+apache-airflow-task-sdk = {workspace = true}
+
+[project.urls]
+"Documentation" = "https://airflow.apache.org/docs/apache-airflow-coordinators-java/0.1.0"
+"Bug Tracker" = "https://github.com/apache/airflow/issues"
+"Source Code" = "https://github.com/apache/airflow"
+"Slack Chat" = "https://s.apache.org/airflow-slack"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/airflow"]
+# Do not ship the airflow / airflow.sdk / airflow.sdk.coordinators package roots
+# -- those are owned by airflow-core / task-sdk. This distribution only contributes
+# the airflow.sdk.coordinators.java sub-package.
+exclude = [
+ "src/airflow/__init__.py",
+ "src/airflow/sdk/__init__.py",
+ "src/airflow/sdk/coordinators/__init__.py",
+]
+
+[tool.hatch.build.targets.sdist]
+include = [
+ "src/airflow",
+ "docs",
+ "tests",
+ "LICENSE",
+ "NOTICE",
+ "README.rst",
+]
+
+[tool.ruff]
+extend = "../../../pyproject.toml"
+src = ["src"]
+namespace-packages = ["src/airflow"]
+
+[tool.ruff.lint.per-file-ignores]
+# Ignore Doc rules et al for anything outside of src
+"!src/*" = ["D", "TID253", "S101", "TRY002"]
+# Ignore the pytest rules outside the tests folder
+"!tests/*" = ["PT"]
diff --git a/sdk/coordinators/java/src/airflow/sdk/coordinators/java/__init__.py b/sdk/coordinators/java/src/airflow/sdk/coordinators/java/__init__.py
new file mode 100644
index 0000000000000..daf8fce338d23
--- /dev/null
+++ b/sdk/coordinators/java/src/airflow/sdk/coordinators/java/__init__.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Java runtime coordinator for the Apache Airflow Task SDK."""
+
+from __future__ import annotations
+
+from airflow.sdk.coordinators.java.coordinator import JavaCoordinator
+
+__all__ = ["JavaCoordinator", "__version__"]
+
+__version__ = "0.1.0"
diff --git a/sdk/coordinators/java/src/airflow/sdk/coordinators/java/bundle_scanner.py b/sdk/coordinators/java/src/airflow/sdk/coordinators/java/bundle_scanner.py
new file mode 100644
index 0000000000000..9e2c1c1ab46fc
--- /dev/null
+++ b/sdk/coordinators/java/src/airflow/sdk/coordinators/java/bundle_scanner.py
@@ -0,0 +1,218 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Scan directories for Airflow Java SDK bundle JARs.
+
+Mirrors the Java SDK's ``BundleScanner`` -- checks each JAR's manifest for
+``Airflow-Java-SDK-Metadata``, reads the embedded metadata YAML, and
+resolves the main class and classpath needed to launch the bundle process.
+"""
+
+from __future__ import annotations
+
+import email
+import os
+import zipfile
+from pathlib import Path
+from typing import NamedTuple
+
+import yaml
+
+MANIFEST_PATH = "META-INF/MANIFEST.MF"
+METADATA_MANIFEST_KEY = "Airflow-Java-SDK-Metadata"
+SDK_VERSION_MANIFEST_KEY = "Airflow-Java-SDK-Version"
+DAG_CODE_MANIFEST_KEY = "Airflow-Java-SDK-Dag-Code"
+MAIN_CLASS_MANIFEST_KEY = "Main-Class"
+
+
+class ResolvedJarBundle(NamedTuple):
+ """A resolved Java DAG bundle: everything needed to start the bundle process."""
+
+ main_class: str
+ classpath: str
+
+
+class BundleScanner:
+ """
+ Locate Airflow Java SDK bundles inside a directory tree.
+
+ Supports two directory layouts:
+
+ - **Nested** -- each immediate subdirectory of *bundles_dir* is a bundle home.
+ - **Flat** -- *bundles_dir* itself contains the bundle JARs.
+
+ Within a bundle home the JVM convention of a ``lib/`` subdirectory for
+ dependency JARs is respected automatically.
+ """
+
+ def __init__(self, bundles_dir: Path) -> None:
+ self._bundles_dir = bundles_dir
+
+ def resolve(self, dag_id: str) -> ResolvedJarBundle:
+ """
+ Find the bundle whose metadata YAML lists *dag_id*.
+
+ :raises FileNotFoundError: if no matching bundle is found.
+ """
+ for bundle_home in self._candidate_homes():
+ jars = _jar_files(bundle_home)
+ if not jars:
+ continue
+
+ for jar_path in jars:
+ result = _read_bundle_jar(jar_path)
+ if result is None:
+ continue
+ main_class, dag_ids = result
+ if dag_id in dag_ids:
+ classpath = os.pathsep.join(str(j.resolve()) for j in jars)
+ return ResolvedJarBundle(main_class=main_class, classpath=classpath)
+
+ raise FileNotFoundError(f"No JAR bundle containing dag_id={dag_id!r} found in {self._bundles_dir}")
+
+ @staticmethod
+ def resolve_jar(jar_path: Path) -> str:
+ """
+ Read ``Main-Class`` from a single bundle JAR, validating SDK attributes.
+
+ :raises FileNotFoundError: if the JAR is not a valid Airflow Java SDK bundle.
+ """
+ result = _read_bundle_jar(jar_path)
+ if result is None:
+ raise FileNotFoundError(
+ f"Not a valid Airflow Java SDK bundle: {jar_path} "
+ f"(requires {METADATA_MANIFEST_KEY} and {MAIN_CLASS_MANIFEST_KEY})"
+ )
+ return result[0]
+
+ def _candidate_homes(self) -> list[Path]:
+ """Return normalised bundle-home directories to inspect."""
+ candidates: list[Path] = []
+
+ if self._bundles_dir.is_dir():
+ for child in sorted(self._bundles_dir.iterdir()):
+ if child.is_dir():
+ candidates.append(_normalize_bundle_home(child))
+
+ candidates.append(_normalize_bundle_home(self._bundles_dir))
+ return candidates
+
+
+def _jar_files(directory: Path) -> list[Path]:
+ """List all ``.jar`` files in *directory*, sorted by name."""
+ if not directory.is_dir():
+ return []
+ return sorted(p for p in directory.iterdir() if p.is_file() and p.suffix == ".jar")
+
+
+def _normalize_bundle_home(path: Path) -> Path:
+ """
+ Normalize a bundle path to the directory containing JARs.
+
+ Handles the common JVM distribution layout where dependency JARs
+ live in a ``lib/`` subdirectory (Gradle ``application`` plugin,
+ Maven Assembly, sbt-native-packager, etc.).
+
+ - If *path* points to a JAR file, use its parent directory.
+ - If the directory has a ``lib/`` subdirectory containing JARs, use that.
+ - Otherwise, return the directory as-is.
+ """
+ normalized = path.resolve()
+ if normalized.is_file() and normalized.suffix == ".jar":
+ return normalized.parent
+ lib = normalized / "lib"
+ if lib.is_dir() and any(p.suffix == ".jar" for p in lib.iterdir()):
+ return lib
+ return normalized
+
+
+def _read_bundle_jar(jar_path: Path) -> tuple[str, set[str]] | None:
+ """
+ Read ``Main-Class`` and dag IDs from a JAR's manifest and embedded metadata.
+
+ Returns ``(main_class, dag_ids)`` when the JAR carries valid
+ ``Airflow-Java-SDK-Metadata`` and ``Main-Class`` manifest attributes
+ and the referenced metadata YAML contains at least one dag ID.
+ Returns ``None`` otherwise.
+ """
+ try:
+ with zipfile.ZipFile(jar_path) as zf:
+ try:
+ with zf.open(MANIFEST_PATH) as f:
+ manifest = email.message_from_binary_file(f)
+ except KeyError:
+ return None
+
+ metadata_file = manifest.get(METADATA_MANIFEST_KEY)
+ if not metadata_file:
+ return None
+
+ main_class = manifest.get(MAIN_CLASS_MANIFEST_KEY)
+ if not main_class:
+ return None
+
+ try:
+ with zf.open(metadata_file) as f:
+ content = f.read().decode()
+ except KeyError:
+ return None
+ except zipfile.BadZipFile:
+ return None
+
+ dag_ids = _parse_dag_ids_from_metadata(content)
+ if not dag_ids:
+ return None
+
+ return main_class, dag_ids
+
+
+def read_dag_code(jar_path: Path) -> str | None:
+ """
+ Read the DAG source code embedded in a JAR bundle.
+
+ Returns the source code string when the JAR carries a valid
+ ``Airflow-Java-SDK-Dag-Code`` manifest attribute pointing to an
+ embedded source file. Returns ``None`` otherwise.
+ """
+ try:
+ with zipfile.ZipFile(jar_path) as zf:
+ try:
+ with zf.open(MANIFEST_PATH) as f:
+ manifest = email.message_from_binary_file(f)
+ except KeyError:
+ return None
+
+ dag_code_path = manifest.get(DAG_CODE_MANIFEST_KEY)
+ if not dag_code_path:
+ return None
+
+ try:
+ with zf.open(dag_code_path) as f:
+ return f.read().decode()
+ except KeyError:
+ return None
+ except zipfile.BadZipFile:
+ return None
+
+
+def _parse_dag_ids_from_metadata(yaml_content: str) -> set[str]:
+ """Parse dag IDs from an ``airflow-metadata.yaml`` content string."""
+ data = yaml.safe_load(yaml_content)
+ if not isinstance(data, dict) or "dags" not in data:
+ return set()
+ return set(data["dags"].keys())
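Since a bundle JAR is just a ZIP archive carrying the manifest keys defined above, the scanner can be exercised end to end against a toy bundle built with ``zipfile``. A usage sketch (the manifest values and dag id are made up for illustration):

```python
import tempfile
import zipfile
from pathlib import Path

from airflow.sdk.coordinators.java.bundle_scanner import BundleScanner

with tempfile.TemporaryDirectory() as tmp:
    jar = Path(tmp) / "example-bundle.jar"
    with zipfile.ZipFile(jar, "w") as zf:
        # Minimal manifest carrying the attributes the scanner looks for.
        zf.writestr(
            "META-INF/MANIFEST.MF",
            "Manifest-Version: 1.0\n"
            "Main-Class: com.example.ExampleDags\n"
            "Airflow-Java-SDK-Metadata: airflow-metadata.yaml\n",
        )
        # Embedded metadata YAML listing the dag ids the bundle provides.
        zf.writestr("airflow-metadata.yaml", "dags:\n  my_java_dag: {}\n")

    resolved = BundleScanner(Path(tmp)).resolve(dag_id="my_java_dag")
    print(resolved.main_class)  # com.example.ExampleDags
    print(resolved.classpath)   # absolute path of example-bundle.jar
```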
diff --git a/sdk/coordinators/java/src/airflow/sdk/coordinators/java/coordinator.py b/sdk/coordinators/java/src/airflow/sdk/coordinators/java/coordinator.py
new file mode 100644
index 0000000000000..4f93fcc93fd88
--- /dev/null
+++ b/sdk/coordinators/java/src/airflow/sdk/coordinators/java/coordinator.py
@@ -0,0 +1,161 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Java runtime coordinator that launches a JVM subprocess for Dag file processing and task execution."""
+
+from __future__ import annotations
+
+import contextlib
+import os
+import zipfile
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from airflow.sdk.coordinators.java.bundle_scanner import BundleScanner, read_dag_code
+from airflow.sdk.execution_time.coordinator import BaseCoordinator
+
+if TYPE_CHECKING:
+ from airflow.sdk.api.datamodels._generated import BundleInfo
+ from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+
+
+class JavaCoordinator(BaseCoordinator):
+ """
+ Coordinator that launches a JVM subprocess for DAG parsing and task execution.
+
+ Configuration is taken from the ``[sdk] coordinators`` entry that constructs
+ this instance::
+
+ {
+ "name": "jdk-17",
+ "classpath": "airflow.sdk.coordinators.java.JavaCoordinator",
+ "kwargs": {
+ "java_executable": "/usr/lib/jvm/java-17-openjdk/bin/java",
+ "jvm_args": ["-Xmx1024m"],
+ "bundles_folder": "~/airflow/java-bundles",
+ },
+ }
+
+ :param java_executable: Path to the ``java`` binary (defaults to ``"java"``,
+ which relies on ``$PATH``).
+ :param jvm_args: Extra arguments passed to the JVM (e.g. ``["-Xmx512m"]``).
+ :param bundles_folder: Directory scanned for JAR bundles when a Python
+ stub DAG delegates task execution to Java. Required for the stub-DAG
+ flow; unused for pure-Java DAGs.
+ """
+
+ sdk = "java"
+ file_extension = ".jar"
+
+ def __init__(
+ self,
+ *,
+ java_executable: str = "java",
+ jvm_args: list[str] | None = None,
+ bundles_folder: str | None = None,
+ ) -> None:
+ self.java_executable = java_executable
+ self.jvm_args = list(jvm_args) if jvm_args else []
+ self.bundles_folder = bundles_folder
+
+ def can_handle_dag_file(self, bundle_name: str, path: str | os.PathLike[str]) -> bool:
+ """Return ``True`` when *path* is a JAR with valid Airflow Java SDK manifest attributes."""
+ if not os.fspath(path).endswith(self.file_extension):
+ return False
+ with contextlib.suppress(FileNotFoundError, NotADirectoryError, zipfile.BadZipFile, KeyError):
+ return BundleScanner.resolve_jar(Path(path)) is not None
+ return False
+
+ def get_code_from_file(self, fileloc: str) -> str:
+ """Read embedded DAG source code from a JAR bundle."""
+ code = read_dag_code(Path(fileloc))
+ if code is None:
+ raise FileNotFoundError(f"No DAG source code found in JAR: {fileloc}")
+ return code
+
+ def dag_parsing_cmd(
+ self,
+ *,
+ dag_file_path: str,
+ bundle_name: str,
+ bundle_path: str,
+ comm_addr: str,
+ logs_addr: str,
+ ) -> list[str]:
+ """Build the ``java`` command for parsing a JAR bundle."""
+ jar_path = Path(dag_file_path)
+ # Java bundles are typically thin JARs: the main JAR only contains
+ # the bundle's own classes while its dependencies (the Airflow Java
+ # SDK, logging libraries, etc.) are separate JARs that live alongside
+        # it. Using ``{bundle_path}/*`` lets the JVM load every JAR in the
+        # directory.
+ classpath = f"{bundle_path}/*"
+ return [
+ self.java_executable,
+ *self.jvm_args,
+ "-classpath",
+ classpath,
+ BundleScanner.resolve_jar(jar_path),
+ f"--comm={comm_addr}",
+ f"--logs={logs_addr}",
+ ]
+
+ def task_execution_cmd(
+ self,
+ *,
+ what: TaskInstanceDTO,
+ dag_file_path: str,
+ bundle_path: str,
+ bundle_info: BundleInfo,
+ comm_addr: str,
+ logs_addr: str,
+ ) -> list[str]:
+ """Build the ``java`` command for executing a task in a JAR bundle."""
+ if dag_file_path.endswith(".jar"):
+ # Case 1: Pure Java Dag -- the dag_file_path points directly to a
+ # bundle JAR inside the Airflow Core Dag Bundle.
+ jar_path = Path(dag_file_path)
+ classpath = f"{bundle_path}/*"
+ return [
+ self.java_executable,
+ *self.jvm_args,
+ "-classpath",
+ classpath,
+ BundleScanner.resolve_jar(jar_path),
+ f"--comm={comm_addr}",
+ f"--logs={logs_addr}",
+ ]
+
+ # Case 2: Python Stub Dag -- the dag_file_path is a Python file but
+ # the task delegates to a Java runtime. The actual JAR bundle lives
+ # in ``bundles_folder`` (passed to __init__ from the [sdk] coordinators
+ # config entry).
+ if not self.bundles_folder:
+ raise ValueError(
+ "JavaCoordinator: bundles_folder kwarg must be set for Python stub DAGs "
+ "that delegate to Java task execution."
+ )
+
+ resolved = BundleScanner(Path(self.bundles_folder)).resolve(dag_id=what.dag_id)
+ return [
+ self.java_executable,
+ *self.jvm_args,
+ "-classpath",
+ resolved.classpath,
+ resolved.main_class,
+ f"--comm={comm_addr}",
+ f"--logs={logs_addr}",
+ ]
diff --git a/sdk/coordinators/java/tests/__init__.py b/sdk/coordinators/java/tests/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/sdk/coordinators/java/tests/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/sdk/coordinators/java/tests/unit/__init__.py b/sdk/coordinators/java/tests/unit/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/sdk/coordinators/java/tests/unit/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/sdk/coordinators/java/tests/unit/coordinators/__init__.py b/sdk/coordinators/java/tests/unit/coordinators/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/sdk/coordinators/java/tests/unit/coordinators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/sdk/coordinators/java/tests/unit/coordinators/java/__init__.py b/sdk/coordinators/java/tests/unit/coordinators/java/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/sdk/coordinators/java/tests/unit/coordinators/java/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/sdk/coordinators/java/tests/unit/coordinators/java/test_bundle_scanner.py b/sdk/coordinators/java/tests/unit/coordinators/java/test_bundle_scanner.py
new file mode 100644
index 0000000000000..93457aa1a9755
--- /dev/null
+++ b/sdk/coordinators/java/tests/unit/coordinators/java/test_bundle_scanner.py
@@ -0,0 +1,332 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+import zipfile
+from pathlib import Path
+
+import pytest
+import yaml
+
+from airflow.sdk.coordinators.java.bundle_scanner import (
+ DAG_CODE_MANIFEST_KEY,
+ MAIN_CLASS_MANIFEST_KEY,
+ MANIFEST_PATH,
+ METADATA_MANIFEST_KEY,
+ SDK_VERSION_MANIFEST_KEY,
+ BundleScanner,
+ ResolvedJarBundle,
+ _jar_files,
+ _normalize_bundle_home,
+ _parse_dag_ids_from_metadata,
+ _read_bundle_jar,
+ read_dag_code,
+)
+
+METADATA_YAML_PATH = "META-INF/airflow-metadata.yaml"
+DAG_CODE_PATH = "JavaExample.java"
+TEST_MAIN_CLASS = "com.example.MyDag"
+TEST_SDK_VERSION = "1.0.0"
+
+
+def _make_manifest(
+ *,
+ main_class: str | None = TEST_MAIN_CLASS,
+ metadata_path: str | None = METADATA_YAML_PATH,
+ sdk_version: str | None = TEST_SDK_VERSION,
+ dag_code_path: str | None = None,
+) -> str:
+ lines = ["Manifest-Version: 1.0"]
+ if main_class:
+ lines.append(f"{MAIN_CLASS_MANIFEST_KEY}: {main_class}")
+ if metadata_path:
+ lines.append(f"{METADATA_MANIFEST_KEY}: {metadata_path}")
+ if sdk_version:
+ lines.append(f"{SDK_VERSION_MANIFEST_KEY}: {sdk_version}")
+ if dag_code_path:
+ lines.append(f"{DAG_CODE_MANIFEST_KEY}: {dag_code_path}")
+ return "\n".join(lines) + "\n"
+
+
+def _make_metadata_yaml(dag_ids: list[str]) -> str:
+ return yaml.dump({"dags": {dag_id: {} for dag_id in dag_ids}})
+
+
+def _create_bundle_jar(
+ jar_path: Path,
+ *,
+ dag_ids: list[str] | None = None,
+ main_class: str | None = TEST_MAIN_CLASS,
+ include_metadata: bool = True,
+ include_manifest: bool = True,
+ dag_code: str | None = None,
+) -> Path:
+ """Create a minimal JAR (zip) file with Airflow Java SDK manifest attributes."""
+ with zipfile.ZipFile(jar_path, "w") as zf:
+ if include_manifest:
+ dag_code_path = DAG_CODE_PATH if dag_code else None
+ manifest = _make_manifest(
+ main_class=main_class,
+ metadata_path=METADATA_YAML_PATH if include_metadata else None,
+ dag_code_path=dag_code_path,
+ )
+ zf.writestr(MANIFEST_PATH, manifest)
+
+ if include_metadata and dag_ids is not None:
+ zf.writestr(METADATA_YAML_PATH, _make_metadata_yaml(dag_ids))
+
+ if dag_code:
+ zf.writestr(DAG_CODE_PATH, dag_code)
+ return jar_path
+
+
+class TestJarFiles:
+ def test_lists_jar_files_sorted(self, tmp_path: Path):
+ (tmp_path / "b.jar").touch()
+ (tmp_path / "a.jar").touch()
+ (tmp_path / "c.txt").touch()
+ result = _jar_files(tmp_path)
+ assert result == [tmp_path / "a.jar", tmp_path / "b.jar"]
+
+ def test_returns_empty_for_nonexistent_directory(self, tmp_path: Path):
+ assert _jar_files(tmp_path / "nonexistent") == []
+
+ def test_returns_empty_for_directory_with_no_jars(self, tmp_path: Path):
+ (tmp_path / "readme.txt").touch()
+ assert _jar_files(tmp_path) == []
+
+ def test_ignores_jar_directories(self, tmp_path: Path):
+ (tmp_path / "fake.jar").mkdir()
+ assert _jar_files(tmp_path) == []
+
+
+class TestNormalizeBundleHome:
+ def test_jar_file_returns_parent(self, tmp_path: Path):
+ jar = tmp_path / "bundle.jar"
+ jar.touch()
+ assert _normalize_bundle_home(jar) == tmp_path.resolve()
+
+ def test_dir_with_lib_containing_jars(self, tmp_path: Path):
+ lib = tmp_path / "lib"
+ lib.mkdir()
+ (lib / "dep.jar").touch()
+ assert _normalize_bundle_home(tmp_path) == lib.resolve()
+
+ def test_dir_with_empty_lib(self, tmp_path: Path):
+ lib = tmp_path / "lib"
+ lib.mkdir()
+ assert _normalize_bundle_home(tmp_path) == tmp_path.resolve()
+
+ def test_plain_directory(self, tmp_path: Path):
+ assert _normalize_bundle_home(tmp_path) == tmp_path.resolve()
+
+
+class TestParseDagIdsFromMetadata:
+ def test_parses_dag_ids(self):
+ content = yaml.dump({"dags": {"dag_a": {}, "dag_b": {"key": "val"}}})
+ assert _parse_dag_ids_from_metadata(content) == {"dag_a", "dag_b"}
+
+ @pytest.mark.parametrize(
+ "yaml_content",
+ [
+ pytest.param(yaml.dump({"other": 1}), id="missing_dags_key"),
+ pytest.param("just a string", id="non_dict"),
+ pytest.param(yaml.dump({"dags": {}}), id="empty_dags"),
+ ],
+ )
+ def test_returns_empty_set(self, yaml_content):
+ assert _parse_dag_ids_from_metadata(yaml_content) == set()
+
+
+class TestReadBundleJar:
+ def test_valid_jar(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "valid.jar", dag_ids=["my_dag"])
+ result = _read_bundle_jar(jar)
+ assert result is not None
+ main_class, dag_ids = result
+ assert main_class == TEST_MAIN_CLASS
+ assert dag_ids == {"my_dag"}
+
+ def test_returns_none_for_missing_manifest(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "no_manifest.jar", include_manifest=False)
+ assert _read_bundle_jar(jar) is None
+
+ def test_returns_none_for_missing_metadata_key(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "no_meta.jar", include_metadata=False)
+ assert _read_bundle_jar(jar) is None
+
+ def test_returns_none_for_missing_main_class(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "no_main.jar", dag_ids=["d"], main_class=None)
+ assert _read_bundle_jar(jar) is None
+
+ def test_returns_none_for_missing_metadata_file(self, tmp_path: Path):
+ """Manifest references a metadata file that does not exist inside the JAR."""
+ jar = tmp_path / "missing_meta_file.jar"
+ with zipfile.ZipFile(jar, "w") as zf:
+ manifest = _make_manifest(metadata_path="nonexistent.yaml")
+ zf.writestr(MANIFEST_PATH, manifest)
+ assert _read_bundle_jar(jar) is None
+
+ def test_returns_none_for_bad_zip(self, tmp_path: Path):
+ bad = tmp_path / "bad.jar"
+ bad.write_text("not a zip file")
+ assert _read_bundle_jar(bad) is None
+
+ def test_returns_none_for_empty_dag_ids(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "empty_dags.jar", dag_ids=[])
+ assert _read_bundle_jar(jar) is None
+
+ def test_multiple_dag_ids(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "multi.jar", dag_ids=["dag_1", "dag_2", "dag_3"])
+ result = _read_bundle_jar(jar)
+ assert result is not None
+ _, dag_ids = result
+ assert dag_ids == {"dag_1", "dag_2", "dag_3"}
+
+
+class TestReadDagCode:
+ def test_reads_embedded_dag_code(self, tmp_path: Path):
+ code = "public class MyDag {}"
+ jar = _create_bundle_jar(tmp_path / "with_code.jar", dag_ids=["d"], dag_code=code)
+ assert read_dag_code(jar) == code
+
+ def test_returns_none_for_missing_dag_code_key(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "no_code.jar", dag_ids=["d"])
+ assert read_dag_code(jar) is None
+
+ def test_returns_none_for_missing_manifest(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "no_manifest.jar", include_manifest=False)
+ assert read_dag_code(jar) is None
+
+ def test_returns_none_for_bad_zip(self, tmp_path: Path):
+ bad = tmp_path / "bad.jar"
+ bad.write_text("not a zip")
+ assert read_dag_code(bad) is None
+
+ def test_returns_none_when_code_file_missing(self, tmp_path: Path):
+ """Manifest references a dag code file that does not exist inside the JAR."""
+ jar = tmp_path / "broken_code.jar"
+ with zipfile.ZipFile(jar, "w") as zf:
+ manifest = _make_manifest(dag_code_path="missing_source.py")
+ zf.writestr(MANIFEST_PATH, manifest)
+ assert read_dag_code(jar) is None
+
+
+class TestBundleScannerResolveJar:
+ def test_returns_main_class(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "bundle.jar", dag_ids=["d"])
+ assert BundleScanner.resolve_jar(jar) == TEST_MAIN_CLASS
+
+ def test_raises_for_invalid_jar(self, tmp_path: Path):
+ jar = tmp_path / "not_bundle.jar"
+ jar.write_text("not a zip")
+ with pytest.raises(FileNotFoundError, match="Not a valid Airflow Java SDK bundle"):
+ BundleScanner.resolve_jar(jar)
+
+
+class TestBundleScannerCandidateHomes:
+ def test_nested_layout(self, tmp_path: Path):
+ sub_a = tmp_path / "bundle_a"
+ sub_a.mkdir()
+ (sub_a / "app.jar").touch()
+
+ sub_b = tmp_path / "bundle_b"
+ sub_b.mkdir()
+ (sub_b / "app.jar").touch()
+
+ scanner = BundleScanner(tmp_path)
+ homes = scanner._candidate_homes()
+ assert len(homes) == 3
+ assert sub_a.resolve() in homes
+ assert sub_b.resolve() in homes
+ assert tmp_path.resolve() in homes
+
+ def test_flat_layout(self, tmp_path: Path):
+ (tmp_path / "app.jar").touch()
+ scanner = BundleScanner(tmp_path)
+ homes = scanner._candidate_homes()
+ assert homes == [tmp_path.resolve()]
+
+ def test_nested_with_lib_subdir(self, tmp_path: Path):
+ sub = tmp_path / "my_bundle"
+ sub.mkdir()
+ lib = sub / "lib"
+ lib.mkdir()
+ (lib / "dep.jar").touch()
+
+ scanner = BundleScanner(tmp_path)
+ homes = scanner._candidate_homes()
+ assert lib.resolve() in homes
+
+
+class TestBundleScannerResolve:
+ def test_finds_matching_dag(self, tmp_path: Path):
+ bundle_dir = tmp_path / "my_bundle"
+ bundle_dir.mkdir()
+ _create_bundle_jar(bundle_dir / "app.jar", dag_ids=["target_dag"])
+
+ scanner = BundleScanner(tmp_path)
+ result = scanner.resolve("target_dag")
+ assert isinstance(result, ResolvedJarBundle)
+ assert result.main_class == TEST_MAIN_CLASS
+ assert str((bundle_dir / "app.jar").resolve()) in result.classpath
+
+ def test_raises_when_no_match(self, tmp_path: Path):
+ bundle_dir = tmp_path / "my_bundle"
+ bundle_dir.mkdir()
+ _create_bundle_jar(bundle_dir / "app.jar", dag_ids=["other_dag"])
+
+ scanner = BundleScanner(tmp_path)
+ with pytest.raises(FileNotFoundError, match="No JAR bundle containing dag_id='missing'"):
+ scanner.resolve("missing")
+
+ def test_classpath_includes_all_jars(self, tmp_path: Path):
+ bundle_dir = tmp_path / "my_bundle"
+ bundle_dir.mkdir()
+ _create_bundle_jar(bundle_dir / "app.jar", dag_ids=["my_dag"])
+ with zipfile.ZipFile(bundle_dir / "dep.jar", "w") as zf:
+ zf.writestr("placeholder.class", b"")
+
+ scanner = BundleScanner(tmp_path)
+ result = scanner.resolve("my_dag")
+ parts = result.classpath.split(os.pathsep)
+ assert len(parts) == 2
+
+ def test_flat_layout_resolve(self, tmp_path: Path):
+ _create_bundle_jar(tmp_path / "app.jar", dag_ids=["flat_dag"])
+
+ scanner = BundleScanner(tmp_path)
+ result = scanner.resolve("flat_dag")
+ assert result.main_class == TEST_MAIN_CLASS
+
+ def test_skips_non_bundle_jars(self, tmp_path: Path):
+ bundle_dir = tmp_path / "my_bundle"
+ bundle_dir.mkdir()
+ with zipfile.ZipFile(bundle_dir / "plain.jar", "w") as zf:
+ zf.writestr("placeholder.class", b"")
+ _create_bundle_jar(bundle_dir / "real.jar", dag_ids=["real_dag"])
+
+ scanner = BundleScanner(tmp_path)
+ result = scanner.resolve("real_dag")
+ assert result.main_class == TEST_MAIN_CLASS
+
+ def test_empty_bundles_dir(self, tmp_path: Path):
+ scanner = BundleScanner(tmp_path)
+ with pytest.raises(FileNotFoundError):
+ scanner.resolve("any_dag")
diff --git a/sdk/coordinators/java/tests/unit/coordinators/java/test_coordinator.py b/sdk/coordinators/java/tests/unit/coordinators/java/test_coordinator.py
new file mode 100644
index 0000000000000..3a101f30cba88
--- /dev/null
+++ b/sdk/coordinators/java/tests/unit/coordinators/java/test_coordinator.py
@@ -0,0 +1,268 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import uuid
+import zipfile
+from pathlib import Path
+
+import pytest
+import yaml
+
+from airflow.sdk.api.datamodels._generated import BundleInfo
+from airflow.sdk.coordinators.java.bundle_scanner import (
+ MAIN_CLASS_MANIFEST_KEY,
+ MANIFEST_PATH,
+ METADATA_MANIFEST_KEY,
+ SDK_VERSION_MANIFEST_KEY,
+)
+from airflow.sdk.coordinators.java.coordinator import JavaCoordinator
+from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+
+if not AIRFLOW_V_3_3_PLUS:
+ pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0", allow_module_level=True)
+
+METADATA_YAML_PATH = "META-INF/airflow-metadata.yaml"
+DAG_CODE_PATH = "dag_source.py"
+TEST_MAIN_CLASS = "com.example.MyBundle"
+
+
+def _make_manifest(
+ *,
+ main_class: str | None = TEST_MAIN_CLASS,
+ metadata_path: str | None = METADATA_YAML_PATH,
+ dag_code_path: str | None = None,
+) -> str:
+ lines = ["Manifest-Version: 1.0"]
+ if main_class:
+ lines.append(f"{MAIN_CLASS_MANIFEST_KEY}: {main_class}")
+ if metadata_path:
+ lines.append(f"{METADATA_MANIFEST_KEY}: {metadata_path}")
+ lines.append(f"{SDK_VERSION_MANIFEST_KEY}: 1.0.0")
+ if dag_code_path:
+ lines.append(f"Airflow-Java-SDK-Dag-Code: {dag_code_path}")
+ return "\n".join(lines) + "\n"
+
+
+def _create_bundle_jar(
+ jar_path: Path,
+ *,
+ dag_ids: list[str] | None = None,
+ dag_code: str | None = None,
+) -> Path:
+ with zipfile.ZipFile(jar_path, "w") as zf:
+ dag_code_path = DAG_CODE_PATH if dag_code else None
+ manifest = _make_manifest(dag_code_path=dag_code_path)
+ zf.writestr(MANIFEST_PATH, manifest)
+ if dag_ids is not None:
+ metadata = yaml.dump({"dags": {d: {} for d in dag_ids}})
+ zf.writestr(METADATA_YAML_PATH, metadata)
+ if dag_code:
+ zf.writestr(DAG_CODE_PATH, dag_code)
+ return jar_path
+
+
+def _make_ti(dag_id: str = "test_dag") -> TaskInstanceDTO:
+ return TaskInstanceDTO(
+ id=uuid.uuid4(),
+ dag_version_id=uuid.uuid4(),
+ task_id="task_1",
+ dag_id=dag_id,
+ run_id="run_1",
+ try_number=1,
+ map_index=-1,
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
+ )
+
+
+class TestJavaCoordinatorAttributes:
+ def test_sdk(self):
+ assert JavaCoordinator.sdk == "java"
+
+ def test_file_extension(self):
+ assert JavaCoordinator.file_extension == ".jar"
+
+ def test_default_kwargs(self):
+ coordinator = JavaCoordinator()
+ assert coordinator.java_executable == "java"
+ assert coordinator.jvm_args == []
+ assert coordinator.bundles_folder is None
+
+ def test_custom_kwargs(self):
+ coordinator = JavaCoordinator(
+ java_executable="/opt/java/bin/java",
+ jvm_args=["-Xmx512m", "-Xms256m"],
+ bundles_folder="/airflow/java-bundles",
+ )
+ assert coordinator.java_executable == "/opt/java/bin/java"
+ assert coordinator.jvm_args == ["-Xmx512m", "-Xms256m"]
+ assert coordinator.bundles_folder == "/airflow/java-bundles"
+
+
+class TestCanHandleDagFile:
+ def test_valid_jar_returns_true(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "valid.jar", dag_ids=["d"])
+ assert JavaCoordinator().can_handle_dag_file("bundle", str(jar)) is True
+
+ def test_non_jar_file_returns_false(self, tmp_path: Path):
+ py_file = tmp_path / "dag.py"
+ py_file.write_text("from airflow import DAG")
+ assert JavaCoordinator().can_handle_dag_file("bundle", str(py_file)) is False
+
+ def test_missing_file_returns_false(self, tmp_path: Path):
+ assert JavaCoordinator().can_handle_dag_file("bundle", str(tmp_path / "missing.jar")) is False
+
+ def test_bad_zip_returns_false(self, tmp_path: Path):
+ bad = tmp_path / "bad.jar"
+ bad.write_text("not a zip")
+ assert JavaCoordinator().can_handle_dag_file("bundle", str(bad)) is False
+
+ def test_jar_without_sdk_manifest_returns_false(self, tmp_path: Path):
+ jar = tmp_path / "plain.jar"
+ with zipfile.ZipFile(jar, "w") as zf:
+ zf.writestr("placeholder.class", b"")
+ assert JavaCoordinator().can_handle_dag_file("bundle", str(jar)) is False
+
+
+class TestGetCodeFromFile:
+ def test_returns_embedded_code(self, tmp_path: Path):
+ code = "from airflow import DAG\ndag = DAG('my_dag')"
+ jar = _create_bundle_jar(tmp_path / "with_code.jar", dag_ids=["d"], dag_code=code)
+ assert JavaCoordinator().get_code_from_file(str(jar)) == code
+
+ def test_raises_when_no_code(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "no_code.jar", dag_ids=["d"])
+ with pytest.raises(FileNotFoundError, match="No DAG source code found in JAR"):
+ JavaCoordinator().get_code_from_file(str(jar))
+
+
+class TestDagParsingCmd:
+ def test_builds_default_java_command(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "app.jar", dag_ids=["d"])
+ bundle_path = str(tmp_path)
+ cmd = JavaCoordinator().dag_parsing_cmd(
+ dag_file_path=str(jar),
+ bundle_name="my_bundle",
+ bundle_path=bundle_path,
+ comm_addr="localhost:1234",
+ logs_addr="localhost:5678",
+ )
+ assert cmd == [
+ "java",
+ "-classpath",
+ f"{bundle_path}/*",
+ TEST_MAIN_CLASS,
+ "--comm=localhost:1234",
+ "--logs=localhost:5678",
+ ]
+
+ def test_uses_custom_executable_and_jvm_args(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "app.jar", dag_ids=["d"])
+ bundle_path = str(tmp_path)
+ coordinator = JavaCoordinator(
+ java_executable="/opt/jdk-17/bin/java",
+ jvm_args=["-Xmx1024m", "-Xms256m"],
+ )
+ cmd = coordinator.dag_parsing_cmd(
+ dag_file_path=str(jar),
+ bundle_name="my_bundle",
+ bundle_path=bundle_path,
+ comm_addr="localhost:1234",
+ logs_addr="localhost:5678",
+ )
+ assert cmd == [
+ "/opt/jdk-17/bin/java",
+ "-Xmx1024m",
+ "-Xms256m",
+ "-classpath",
+ f"{bundle_path}/*",
+ TEST_MAIN_CLASS,
+ "--comm=localhost:1234",
+ "--logs=localhost:5678",
+ ]
+
+
+class TestTaskExecutionCmd:
+ def test_pure_java_dag(self, tmp_path: Path):
+ jar = _create_bundle_jar(tmp_path / "app.jar", dag_ids=["test_dag"])
+ bundle_path = str(tmp_path)
+ ti = _make_ti()
+ bundle_info = BundleInfo(name="my_bundle")
+
+ cmd = JavaCoordinator().task_execution_cmd(
+ what=ti, # type: ignore[arg-type]
+ dag_file_path=str(jar),
+ bundle_path=bundle_path,
+ bundle_info=bundle_info,
+ comm_addr="localhost:1234",
+ logs_addr="localhost:5678",
+ )
+ assert cmd == [
+ "java",
+ "-classpath",
+ f"{bundle_path}/*",
+ TEST_MAIN_CLASS,
+ "--comm=localhost:1234",
+ "--logs=localhost:5678",
+ ]
+
+ def test_python_stub_dag_uses_bundles_folder_kwarg(self, tmp_path: Path):
+ bundles_folder = tmp_path / "java_bundles"
+ bundle_sub = bundles_folder / "my_bundle"
+ bundle_sub.mkdir(parents=True)
+ _create_bundle_jar(bundle_sub / "app.jar", dag_ids=["stub_dag"])
+
+ ti = _make_ti(dag_id="stub_dag")
+ bundle_info = BundleInfo(name="my_bundle")
+
+ coordinator = JavaCoordinator(bundles_folder=str(bundles_folder))
+ cmd = coordinator.task_execution_cmd(
+ what=ti, # type: ignore[arg-type]
+ dag_file_path="/dags/stub_dag.py",
+ bundle_path="/some/bundle/path",
+ bundle_info=bundle_info,
+ comm_addr="localhost:1234",
+ logs_addr="localhost:5678",
+ )
+
+ assert cmd == [
+ "java",
+ "-classpath",
+ f"{bundles_folder}/my_bundle/app.jar",
+ TEST_MAIN_CLASS,
+ "--comm=localhost:1234",
+ "--logs=localhost:5678",
+ ]
+
+ def test_python_stub_dag_without_bundles_folder_raises(self):
+ ti = _make_ti()
+ bundle_info = BundleInfo(name="my_bundle")
+
+ with pytest.raises(ValueError, match="bundles_folder kwarg must be set"):
+ JavaCoordinator().task_execution_cmd(
+ what=ti, # type: ignore[arg-type]
+ dag_file_path="/dags/stub_dag.py",
+ bundle_path="/some/bundle/path",
+ bundle_info=bundle_info,
+ comm_addr="localhost:1234",
+ logs_addr="localhost:5678",
+ )
diff --git a/task-sdk/.pre-commit-config.yaml b/task-sdk/.pre-commit-config.yaml
index 100a6e6490849..c1d4498a623fc 100644
--- a/task-sdk/.pre-commit-config.yaml
+++ b/task-sdk/.pre-commit-config.yaml
@@ -43,6 +43,7 @@ repos:
^src/airflow/sdk/definitions/deadline\.py$|
^src/airflow/sdk/definitions/dag\.py$|
^src/airflow/sdk/definitions/_internal/types\.py$|
+ ^src/airflow/sdk/execution_time/coordinator\.py$|
^src/airflow/sdk/execution_time/execute_workload\.py$|
^src/airflow/sdk/execution_time/secrets_masker\.py$|
^src/airflow/sdk/execution_time/callback_supervisor\.py$|
diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py
index f304b068237b3..ab6c10f417fc8 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -16,6 +16,11 @@
# under the License.
from __future__ import annotations
+# Make ``airflow.sdk`` a namespace-extending package so sibling distributions
+# (e.g. ``apache-airflow-coordinators-java`` shipping
+# ``airflow/sdk/coordinators/java/``) can contribute sub-packages.
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
from typing import TYPE_CHECKING
__all__ = [
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 1e11e9636e56f..b407e71a9891f 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -79,7 +79,6 @@
PreviousTIResponse,
PrevSuccessfulDagRunResponse,
TaskBreadcrumbsResponse,
- TaskInstance,
TaskInstanceState,
TaskStatesResponse,
TIDeferredStatePayload,
@@ -96,6 +95,9 @@
XComSequenceSliceResponse,
)
from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.execution_time.workloads.task import (
+ TaskInstanceDTO, # noqa: TC001 -- Pydantic needs this at runtime
+)
try:
from socket import recv_fds
@@ -332,7 +334,7 @@ def _get_response(self) -> ReceiveMsgType | None:
class StartupDetails(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
- ti: TaskInstance
+ ti: TaskInstanceDTO
dag_rel_path: str
bundle_info: BundleInfo
start_date: datetime
diff --git a/task-sdk/src/airflow/sdk/execution_time/coordinator.py b/task-sdk/src/airflow/sdk/execution_time/coordinator.py
new file mode 100644
index 0000000000000..7e8e0f685a349
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/execution_time/coordinator.py
@@ -0,0 +1,552 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Runtime coordinator for non-Python DAG file processing and task execution.
+
+Provides :class:`BaseCoordinator`, the base class for
+SDK-specific coordinators that bridge subprocess I/O between the
+Airflow supervisor and an external-SDK runtime (Java, Go, Rust, etc.),
+and :class:`CoordinatorManager`, the registry that loads coordinator
+instances from the ``[sdk] coordinators`` configuration.
+
+The coordinator's :meth:`~BaseCoordinator.run_dag_parsing` and
+:meth:`~BaseCoordinator.run_task_execution` methods handle the full lifecycle:
+
+1. Creates TCP servers for comm and logs channels, and a socketpair for stderr.
+2. Calls :meth:`~BaseCoordinator.dag_parsing_cmd` or
+ :meth:`~BaseCoordinator.task_execution_cmd` (provided by the subclass) to
+ obtain the subprocess command.
+3. Spawns the subprocess and accepts TCP connections from it.
+4. Runs a selector-based bridge that transparently forwards bytes
+ between fd 0 (supervisor) and the subprocess comm socket, and
+ re-emits the subprocess's log and stderr output through structlog.
+
+I/O multiplexing uses the same selector-based loop as
+:class:`~airflow.sdk.execution_time.supervisor.WatchedSubprocess`,
+driven by :func:`~airflow.sdk.execution_time.selector_loop.service_selector`.
+"""
+
+from __future__ import annotations
+
+import contextlib
+import functools
+import os
+import selectors
+import socket
+import subprocess
+import time
+from typing import TYPE_CHECKING, ClassVar, NamedTuple
+
+from airflow.sdk._shared.module_loading import import_string
+
+if TYPE_CHECKING:
+ from structlog.typing import FilteringBoundLogger
+ from typing_extensions import Self
+
+ from airflow.sdk.api.datamodels._generated import BundleInfo
+ from airflow.sdk.execution_time.comms import StartupDetails
+ from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+
+
+def _start_server() -> socket.socket:
+ """Create a TCP server socket bound to a random port on localhost."""
+ server = socket.socket()
+ server.bind(("127.0.0.1", 0))
+ server.setblocking(True)
+ server.listen(1)
+ return server
+
+
+def _send_startup_details(runtime_comm: socket.socket, startup_details: StartupDetails) -> None:
+ """
+ Re-encode and send the ``StartupDetails`` frame to the runtime subprocess.
+
+ In the task execution flow, ``task_runner.main()`` consumes the
+ ``StartupDetails`` message from fd 0 (to determine routing) before
+ delegating to the runtime coordinator. This function re-serializes
+ the message and writes it to the runtime subprocess's comm socket so
+ the subprocess receives it as if it came directly from the supervisor.
+ """
+ from airflow.sdk.execution_time.comms import _ResponseFrame
+
+ # Use mode="json" so that datetime, UUID, and other complex Python
+ # types are serialized as plain strings/numbers in msgpack -- avoiding
+ # msgpack extension types (e.g. Timestamp) that non-Python decoders
+ # may not support.
+ frame = _ResponseFrame(id=0, body=startup_details.model_dump(mode="json"))
+ runtime_comm.sendall(frame.as_bytes())
+
+
+def _bridge(
+ supervisor_comm: socket.socket,
+ runtime_comm: socket.socket,
+ runtime_logs: socket.socket,
+ runtime_stderr: socket.socket,
+ proc: subprocess.Popen,
+ log: FilteringBoundLogger,
+) -> None:
+ """
+ Multiplex I/O between the supervisor and a runtime subprocess.
+
+ Four channels are registered with the selector:
+
+ - ``supervisor_comm`` -> ``runtime_comm`` (raw byte forwarding)
+ - ``runtime_comm`` -> ``supervisor_comm`` (raw byte forwarding)
+ - ``runtime_logs`` -> structlog (line-buffered JSON logs)
+ - ``runtime_stderr`` -> structlog (line-buffered stderr output)
+
+ Uses the same ``(handler, on_close)`` callback contract as
+ :class:`~airflow.sdk.execution_time.supervisor.WatchedSubprocess`,
+ driven by :func:`~airflow.sdk.execution_time.selector_loop.service_selector`.
+ """
+    import logging
+
+    from airflow.sdk.execution_time.selector_loop import (
+ make_buffered_socket_reader,
+ make_raw_forwarder,
+ service_selector,
+ )
+ from airflow.sdk.execution_time.supervisor import (
+ forward_to_log,
+ process_log_messages_from_subprocess,
+ )
+
+ sel = selectors.DefaultSelector()
+
+ def on_close(sock: socket.socket) -> None:
+ with contextlib.suppress(KeyError):
+ sel.unregister(sock)
+
+ target_loggers = (log,)
+
+ # Comm: bidirectional raw byte forwarding.
+ sel.register(supervisor_comm, selectors.EVENT_READ, make_raw_forwarder(runtime_comm, on_close))
+ sel.register(runtime_comm, selectors.EVENT_READ, make_raw_forwarder(supervisor_comm, on_close))
+
+ # TCP logs channel: line-buffered JSON from the runtime SDK's LogSender,
+ # processed with the same handler as WatchedSubprocess (level mapping,
+ # timestamp parsing, exception extraction).
+ sel.register(
+ runtime_logs,
+ selectors.EVENT_READ,
+ make_buffered_socket_reader(process_log_messages_from_subprocess(target_loggers), on_close),
+ )
+ # stderr: plain-text output from the runtime process's logging framework
+ # (e.g. SLF4J simple logger). Use forward_to_log which handles raw
+ # text lines, not process_log_messages_from_subprocess which expects JSON.
+ sel.register(
+ runtime_stderr,
+ selectors.EVENT_READ,
+ make_buffered_socket_reader(
+ forward_to_log(target_loggers, logger="task.stderr", level=logging.ERROR), on_close
+ ),
+ )
+
+ # Event loop -- runs until the subprocess exits and all sockets are drained.
+ while sel.get_map():
+ service_selector(sel, timeout=1.0)
+ if proc.poll() is not None:
+ # Subprocess has exited -- drain remaining data with a short deadline.
+ deadline = time.monotonic() + 5.0
+ while sel.get_map() and time.monotonic() < deadline:
+ service_selector(sel, timeout=0.5)
+ break
+
+ sel.close()
+ for sock in (supervisor_comm, runtime_comm, runtime_logs, runtime_stderr):
+ with contextlib.suppress(OSError):
+ sock.close()
+
+
+class BaseCoordinator:
+ """
+ Base coordinator for runtime-specific DAG file processing and task execution.
+
+ Coordinators are instantiated from the ``[sdk] coordinators`` configuration
+ (see :class:`CoordinatorManager`) — each entry's ``classpath`` is resolved
+ via :func:`~airflow.sdk._shared.module_loading.import_string` and
+ constructed with the entry's ``kwargs``.
+
+ Subclasses represent a specific SDK runtime (Java, Go, etc.) and only
+ need to implement :meth:`can_handle_dag_file`, :meth:`dag_parsing_cmd`
+ and :meth:`task_execution_cmd`. The base class owns the entire bridge
+ lifecycle: TCP servers, subprocess management, selector-based I/O loop,
+ and cleanup.
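+
+    A minimal sketch of a hypothetical subclass (the ``GoCoordinator`` name,
+    the ``go-runner`` binary and its flags are illustrative, not part of any
+    shipped SDK; ``task_execution_cmd`` follows the same pattern)::
+
+        class GoCoordinator(BaseCoordinator):
+            sdk = "go"
+            file_extension = ".gobundle"
+
+            def can_handle_dag_file(self, bundle_name, path):
+                return os.fspath(path).endswith(self.file_extension)
+
+            def dag_parsing_cmd(self, *, dag_file_path, bundle_name,
+                                bundle_path, comm_addr, logs_addr):
+                return ["go-runner", dag_file_path,
+                        f"--comm={comm_addr}", f"--logs={logs_addr}"]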
+ """
+
+ sdk: ClassVar[str]
+ file_extension: ClassVar[str]
+
+ class DagParsingInfo(NamedTuple):
+ """Information needed for runtime Dag parsing."""
+
+ dag_file_path: str
+ bundle_name: str
+ bundle_path: str
+ mode: str = "dag-parsing"
+
+ class TaskExecutionInfo(NamedTuple):
+ """Information needed for runtime task execution."""
+
+ what: TaskInstanceDTO
+ dag_rel_path: str | os.PathLike[str]
+ bundle_info: BundleInfo
+ startup_details: StartupDetails
+ mode: str = "task-execution"
+
+ def can_handle_dag_file(self, bundle_name: str, path: str | os.PathLike[str]) -> bool:
+ """
+ Return ``True`` if this coordinator should handle DAG-file parsing for *path*.
+
+ Called by :meth:`DagFileProcessorProcess._resolve_processor_target` to
+ decide whether to delegate parsing to this coordinator's
+ :meth:`run_dag_parsing` instead of the default Python entrypoint.
+
+ The default implementation returns ``False``; subclasses must override.
+ """
+ return False
+
+ def get_code_from_file(self, fileloc: str) -> str:
+ """
+ Return the human-readable source code for a DAG file managed by this coordinator.
+
+ Called by :class:`~airflow.models.dagcode.DagCode` when persisting DAG
+ source to the metadata database. The default Python path reads ``.py``
+ files directly; runtime coordinators must override this to extract source
+ from their native packaging format (e.g. reading an embedded ``.java``
+ file from a JAR bundle).
+
+        :param fileloc: Absolute path to the DAG file (e.g. ``/path/to/example.jar``).
+ :return: The source code as a string.
+ :raises FileNotFoundError: If source code cannot be retrieved from *fileloc*.
+ """
+ raise NotImplementedError
+
+ def dag_parsing_cmd(
+ self,
+ *,
+ dag_file_path: str,
+ bundle_name: str,
+ bundle_path: str,
+ comm_addr: str,
+ logs_addr: str,
+ ) -> list[str]:
+ """
+ Return the subprocess command for DAG file parsing.
+
+ :param dag_file_path: Absolute path to the DAG file to parse.
+ :param bundle_name: Name of the DAG bundle.
+ :param bundle_path: Root path of the DAG bundle.
+ :param comm_addr: ``host:port`` the subprocess must connect to
+ for the bidirectional msgpack comm channel.
+ :param logs_addr: ``host:port`` the subprocess must connect to
+ for the structured JSON log channel.
+ :returns: Full command list (e.g. ``["java", "-cp", "...", ...]`` based on each runtime).
+ """
+ raise NotImplementedError
+
+ def task_execution_cmd(
+ self,
+ *,
+ what: TaskInstanceDTO,
+ dag_file_path: str,
+ bundle_path: str,
+ bundle_info: BundleInfo,
+ comm_addr: str,
+ logs_addr: str,
+ ) -> list[str]:
+ """
+ Return the subprocess command for task execution.
+
+ :param what: The task instance to execute.
+ :param dag_file_path: Absolute path to the DAG file.
+ :param bundle_path: Root path of the DAG bundle.
+ :param bundle_info: Bundle metadata.
+ :param comm_addr: ``host:port`` the subprocess must connect to
+ for the bidirectional msgpack comm channel.
+ :param logs_addr: ``host:port`` the subprocess must connect to
+ for the structured JSON log channel.
+ :returns: Full command list.
+ """
+ raise NotImplementedError
+
+ def run_dag_parsing(self, *, path: str, bundle_name: str, bundle_path: str) -> None:
+ """Entry point for running runtime-specific Dag File Processing."""
+ self._runtime_subprocess_entrypoint(
+ self.DagParsingInfo(
+ dag_file_path=path,
+ bundle_name=bundle_name,
+ bundle_path=bundle_path,
+ )
+ )
+
+ def run_task_execution(
+ self,
+ *,
+ what: TaskInstanceDTO,
+ dag_rel_path: str | os.PathLike[str],
+ bundle_info: BundleInfo,
+ startup_details: StartupDetails,
+ ) -> None:
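+        """Entry point for running runtime-specific task execution."""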
+ self._runtime_subprocess_entrypoint(
+ self.TaskExecutionInfo(
+ what=what,
+ dag_rel_path=dag_rel_path,
+ bundle_info=bundle_info,
+ startup_details=startup_details,
+ )
+ )
+
+ def _runtime_subprocess_entrypoint(self, entrypoint_info: DagParsingInfo | TaskExecutionInfo) -> None:
+ """
+ Spawn the runtime subprocess and bridge I/O with the supervisor.
+
+ This is called inside the forked child process where fd 0 is the
+ bidirectional comms socket to the supervisor. The method:
+
+ 1. Creates TCP servers for comm and logs.
+ 2. Calls :meth:`dag_parsing_cmd` or :meth:`task_execution_cmd` to get the command.
+ 3. Spawns the subprocess with ``stdin=/dev/null`` and stderr
+ captured via a socketpair.
+ 4. Runs the selector-based bridge until the subprocess exits.
+
+ Two distinct IPC mechanisms are used because each channel has a
+ different initiator:
+
+ - The runtime subprocess actively *connects* to the comm and logs
+ TCP servers using ``host:port`` strings passed via the command line
+ -- portable across every language's stdlib socket API.
+ - stderr is *inherited*: the subprocess writes to fd 2 transparently
+ (its native logging framework targets stderr by default), so we
+ replace fd 2 with one end of a socketpair instead of teaching the
+ runtime about an address. ``subprocess.PIPE`` would not work
+ because :func:`make_buffered_socket_reader` requires a real socket.
+
+ fd layout of *this* coordinator process (set up by
+ ``_reopen_std_io_handles`` before this runs):
+
+ - fd 0 -- bidirectional comms socket to the supervisor
+ (``DagFileParseRequest`` <-> ``DagFileParsingResult``,
+ length-prefixed msgpack frames)
+ - fd 1 -- stdout socket to the supervisor
+ - fd 2 -- stderr socket to the supervisor
+ - fd N -- structured JSON log channel (``log_fd``, configured by
+ ``_configure_logs_over_json_channel`` -> structlog)
+
+ The runtime subprocess gets ``stdin=DEVNULL``, inherits fd 1 (so its
+ stdout flows straight to the supervisor), and has its fd 2 replaced
+ by the coordinator-owned end of the stderr socketpair.
+ """
+ os.environ["_AIRFLOW_PROCESS_CONTEXT"] = "client"
+
+ import structlog
+
+ log = structlog.get_logger(logger_name="task")
+ log.info(
+ "Starting runtime subprocess",
+ sdk=self.sdk,
+ mode=entrypoint_info.mode,
+ )
+
+ # TCP servers for the runtime subprocess to connect to.
+ comm_server = _start_server()
+ logs_server = _start_server()
+ comm_host, comm_port = comm_server.getsockname()
+ logs_host, logs_port = logs_server.getsockname()
+
+ comm_addr = f"{comm_host}:{comm_port}"
+ logs_addr = f"{logs_host}:{logs_port}"
+
+ # stderr uses a socketpair (instead of ``subprocess.PIPE``) so it
+ # is a real socket compatible with ``make_buffered_socket_reader``.
+ child_stderr, read_stderr = socket.socketpair()
+
+ # For task execution, hold a BundleVersionLock for the entire
+ # subprocess lifetime to prevent the bundle version from being
+ # garbage-collected while the runtime process is still running.
+ bundle_version_lock: contextlib.AbstractContextManager = contextlib.nullcontext()
+
+ if isinstance(entrypoint_info, self.DagParsingInfo):
+ cmd = self.dag_parsing_cmd(
+ dag_file_path=entrypoint_info.dag_file_path,
+ bundle_name=entrypoint_info.bundle_name,
+ bundle_path=entrypoint_info.bundle_path,
+ comm_addr=comm_addr,
+ logs_addr=logs_addr,
+ )
+ elif isinstance(entrypoint_info, self.TaskExecutionInfo):
+ from airflow.dag_processing.bundles.base import BundleVersionLock
+ from airflow.sdk.execution_time.task_runner import resolve_bundle
+
+ bundle_instance = resolve_bundle(entrypoint_info.bundle_info, log)
+ resolved_dag_file_path = bundle_instance.path / entrypoint_info.dag_rel_path
+
+ cmd = self.task_execution_cmd(
+ what=entrypoint_info.what,
+ dag_file_path=os.fspath(resolved_dag_file_path),
+ bundle_path=os.fspath(bundle_instance.path),
+ bundle_info=entrypoint_info.bundle_info,
+ comm_addr=comm_addr,
+ logs_addr=logs_addr,
+ )
+ bundle_version_lock = BundleVersionLock(
+ bundle_name=entrypoint_info.bundle_info.name,
+ bundle_version=entrypoint_info.bundle_info.version,
+ )
+ else:
+ raise ValueError(f"Unknown entrypoint_info type: {type(entrypoint_info)}")
+
+ with bundle_version_lock:
+ # stdin redirected to /dev/null so the subprocess does not inherit
+ # fd 0 (the comms socket).
+ proc = subprocess.Popen(
+ cmd,
+ stdin=subprocess.DEVNULL,
+ stderr=child_stderr.fileno(),
+ )
+ child_stderr.close()
+
+ # Wait for the subprocess to connect to both servers.
+ runtime_comm, _ = comm_server.accept()
+ runtime_logs, _ = logs_server.accept()
+ comm_server.close()
+ logs_server.close()
+
+ # For task execution the supervisor already sent ``StartupDetails``
+ # on fd 0 and ``task_runner.main()`` consumed it before delegating
+ # here. Re-encode and forward it to the runtime subprocess so it
+ # knows which task to execute.
+ if isinstance(entrypoint_info, self.TaskExecutionInfo):
+ _send_startup_details(runtime_comm, entrypoint_info.startup_details)
+
+ # fd 0 is the bidirectional comms socket to the supervisor.
+ supervisor_comm = socket.socket(fileno=os.dup(0))
+
+ _bridge(supervisor_comm, runtime_comm, runtime_logs, read_stderr, proc, log)
+
+
+class CoordinatorManager:
+ """
+ Registry of coordinator instances loaded from the ``[sdk] coordinators`` config.
+
+ Each entry in the JSON list takes the form::
+
+ {
+ "name": "jdk-11",
+ "classpath": "airflow.sdk.coordinators.java.JavaCoordinator",
+ "kwargs": {"java_executable": "/usr/lib/jvm/jdk-11/bin/java", ...}
+ }
+
+ The ``classpath`` is resolved via
+ :func:`~airflow.sdk._shared.module_loading.import_string` (no
+ :class:`ProvidersManager` involvement) and constructed with ``kwargs``.
+
+ The ``[sdk] queue_to_coordinator`` config maps queue names to a coordinator
+ ``name`` from that list, which lets users reuse existing queue assignments
+ to route tasks to a specific coordinator instance (for example, a
+ ``"legacy-java"`` queue routed to a JDK 11 coordinator and a
+ ``"modern-java"`` queue routed to a JDK 17 coordinator).
+ """
+
+ def __init__(
+ self,
+ instances_by_name: dict[str, BaseCoordinator],
+ queue_to_coordinator: dict[str, str],
+ ) -> None:
+ self._instances_by_name = instances_by_name
+ self._queue_to_coordinator = queue_to_coordinator
+
+ @classmethod
+ def from_config(cls) -> Self:
+ """Load coordinator instances from the ``[sdk]`` configuration."""
+ from airflow.sdk.configuration import conf
+
+ entries = conf.getjson("sdk", "coordinators", fallback=[])
+ if not isinstance(entries, list):
+ entries = []
+
+ instances: dict[str, BaseCoordinator] = {}
+ for entry in entries:
+ if not isinstance(entry, dict):
+ continue
+ name = entry.get("name")
+ classpath = entry.get("classpath")
+ if not name or not classpath:
+ continue
+ kwargs = entry.get("kwargs") or {}
+ coordinator_cls = import_string(classpath)
+ instances[name] = coordinator_cls(**kwargs)
+
+ queue_mapping = conf.getjson("sdk", "queue_to_coordinator", fallback={})
+ if not isinstance(queue_mapping, dict):
+ queue_mapping = {}
+
+ return cls(instances, queue_mapping)
+
+ def all(self) -> list[BaseCoordinator]:
+ """Return all loaded coordinator instances, sorted by configured name."""
+ return [self._instances_by_name[name] for name in sorted(self._instances_by_name)]
+
+ def get(self, name: str) -> BaseCoordinator | None:
+ """Return the coordinator instance registered under *name*, or ``None``."""
+ return self._instances_by_name.get(name)
+
+ def for_queue(self, queue: str) -> BaseCoordinator | None:
+ """Return the coordinator instance routed to *queue*, or ``None``."""
+ name = self._queue_to_coordinator.get(queue)
+ if name is None:
+ return None
+ return self._instances_by_name.get(name)
+
+ def for_dag_file(self, bundle_name: str, path: str | os.PathLike[str]) -> BaseCoordinator | None:
+ """Return the first coordinator whose ``can_handle_dag_file`` matches *path*."""
+ for instance in self.all():
+ try:
+ if instance.can_handle_dag_file(bundle_name, path):
+ return instance
+ except Exception:
+ continue
+ return None
+
+ def file_extensions(self) -> tuple[str, ...]:
+ """Return the file extensions registered by all loaded coordinators."""
+ extensions: list[str] = []
+ for instance in self.all():
+ ext = getattr(type(instance), "file_extension", None)
+ if ext:
+ extensions.append(ext)
+ return tuple(extensions)
+
+
+@functools.cache
+def get_coordinator_manager() -> CoordinatorManager:
+ """Return the process-wide :class:`CoordinatorManager`, loaded from config on first use."""
+ return CoordinatorManager.from_config()
+
+
+def reset_coordinator_manager() -> None:
+ """Clear the cached :class:`CoordinatorManager` (test helper)."""
+ get_coordinator_manager.cache_clear()
+
+
+__all__ = [
+ "BaseCoordinator",
+ "CoordinatorManager",
+ "get_coordinator_manager",
+ "reset_coordinator_manager",
+]
diff --git a/task-sdk/src/airflow/sdk/execution_time/selector_loop.py b/task-sdk/src/airflow/sdk/execution_time/selector_loop.py
new file mode 100644
index 0000000000000..d67014ad1b418
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/execution_time/selector_loop.py
@@ -0,0 +1,159 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Selector-based I/O loop utilities shared across subprocess monitors.
+
+Both :class:`~airflow.sdk.execution_time.supervisor.WatchedSubprocess`
+(supervisor-side) and runtime coordinator bridges such as
+:class:`~airflow.sdk.execution_time.coordinator.BaseCoordinator` (child-side)
+use these building blocks to multiplex socket I/O without threads.
+
+The common contract for every callback registered with the selector:
+
+* The selector stores a ``(handler, on_close)`` tuple as ``key.data``.
+* ``handler(fileobj) -> bool`` — read available data and return
+ ``True`` to keep listening, ``False`` on EOF / error.
+* ``on_close(fileobj)`` — called when the handler returns ``False``;
+ must unregister the fileobj from the selector.
+* :func:`service_selector` drives one iteration of this protocol.
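+
+A minimal sketch of the contract (``sock`` and ``dest`` are illustrative,
+already-connected sockets)::
+
+    sel = selectors.DefaultSelector()
+    sel.register(sock, selectors.EVENT_READ,
+                 make_raw_forwarder(dest, on_close=sel.unregister))
+    while sel.get_map():
+        service_selector(sel, timeout=1.0)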
+"""
+
+from __future__ import annotations
+
+import selectors
+from contextlib import suppress
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from collections.abc import Callable, Generator
+ from socket import socket
+
+ # (handler, on_close) — stored as ``selector.register(..., data=cb)``
+ SelectorCallback = tuple[Callable[[socket], bool], Callable[[socket], None]]
+
+
+# Sockets -- even when wrapped via ``.makefile()`` -- do not line-buffer reads
+# correctly: if a received chunk contains no newline, ``.readline()`` returns
+# the partial chunk as-is.
+#
+# This returns a callback suitable for attaching to a ``selector`` that reads
+# into a buffer and yields complete lines to a (sync) generator.
+def make_buffered_socket_reader(
+ gen: Generator[None, bytes | bytearray, None],
+ on_close: Callable[[socket], None],
+ buffer_size: int = 4096,
+) -> SelectorCallback:
+ """
+ Create a selector callback that line-buffers socket data into a generator.
+
+ Bytes are accumulated until a newline is found; each
+ complete line is sent to *gen* via ``gen.send(line)``. On EOF the
+ remainder of the buffer (if any) is flushed.
+
+ Returns a ``(handler, on_close)`` tuple suitable for
+ ``selector.register(..., data=...)``.
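+
+    A minimal sketch of a consuming generator (illustrative; real callers pass
+    generators such as ``process_log_messages_from_subprocess`` from the
+    supervisor)::
+
+        def lines():
+            while True:
+                line = yield
+                print(line.decode(errors="replace"), end="")
+
+        sel.register(sock, selectors.EVENT_READ,
+                     make_buffered_socket_reader(lines(), on_close=sel.unregister))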
+ """
+ buffer = bytearray() # This will hold our accumulated binary data
+ read_buffer = bytearray(buffer_size) # Temporary buffer for each read
+
+ # We need to start up the generator to get it to the point it's at waiting on the yield
+ next(gen)
+
+    def cb(sock: socket) -> bool:
+ nonlocal buffer, read_buffer
+ # Read up to `buffer_size` bytes of data from the socket
+ n_received = sock.recv_into(read_buffer)
+
+ if not n_received:
+ # If no data is returned, the connection is closed. Return whatever is left in the buffer
+ if len(buffer):
+ with suppress(StopIteration):
+ gen.send(buffer)
+ return False
+
+ buffer.extend(read_buffer[:n_received])
+
+ # We could have read multiple lines in one go, yield them all
+ while (newline_pos := buffer.find(b"\n")) != -1:
+ line = buffer[: newline_pos + 1]
+ try:
+ gen.send(line)
+ except StopIteration:
+ return False
+ buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data
+
+ return True
+
+ return cb, on_close
+
+
+def make_raw_forwarder(
+ dest: socket,
+ on_close: Callable[[socket], None],
+) -> SelectorCallback:
+ """
+ Create a selector callback that forwards raw bytes to *dest*.
+
+ Used for transparent protocol bridges where bytes must be shuttled
+ between two sockets without interpretation (e.g. length-prefixed
+ msgpack frames between a supervisor and a Java subprocess).
+ """
+
+ def cb(sock: socket) -> bool:
+ data = sock.recv(65536)
+ if not data:
+ return False
+ try:
+ dest.sendall(data)
+ except (BrokenPipeError, ConnectionResetError, OSError):
+ return False
+ return True
+
+ return cb, on_close
+
+
+def service_selector(selector: selectors.BaseSelector, timeout: float = 1.0) -> None:
+ """
+ Process one round of selector events.
+
+ For each ready socket whose handler returns ``False`` (EOF / error),
+ the socket's *on_close* callback is invoked and the socket is closed.
+ """
+ # Ensure minimum timeout to prevent CPU spike with tight loop when timeout is 0 or negative
+ timeout = max(0.01, timeout)
+ events = selector.select(timeout=timeout)
+ for key, _ in events:
+ # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr)
+ socket_handler, on_close = key.data
+
+ # Example of handler behavior:
+ # If the subprocess writes "Hello, World!" to stdout:
+ # - `socket_handler` reads and processes the message.
+ # - If EOF is reached, the handler returns False to signal no more reads are expected.
+ # - BrokenPipeError should be caught and treated as if the handler returned false, similar
+ # to EOF case
+ try:
+ need_more = socket_handler(key.fileobj)
+ except (BrokenPipeError, ConnectionResetError):
+ need_more = False
+
+ # If the handler signals that the file object is no longer needed (EOF, closed, etc.)
+ # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors
+ # are removed.
+ if not need_more:
+ sock: socket = key.fileobj # type: ignore[assignment]
+ on_close(sock)
+ sock.close()
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index cd25d9279571b..75ac2ce784673 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -126,6 +126,7 @@
handle_get_variable,
handle_mask_secret,
)
+from airflow.sdk.execution_time.selector_loop import make_buffered_socket_reader, service_selector
try:
from socket import send_fds
@@ -144,6 +145,8 @@
from airflow.executors.workloads import BundleInfo
from airflow.sdk.bases.secrets_backend import BaseSecretsBackend
from airflow.sdk.definitions.connection import Connection
+ from airflow.sdk.execution_time.selector_loop import SelectorCallback
+ from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise", "supervise_task"]
@@ -695,7 +698,7 @@ def _get_target_loggers(self) -> tuple[FilteringBoundLogger, ...]:
target_loggers += (log,)
return target_loggers
- def _create_log_forwarder(self, loggers, name, log_level=logging.INFO) -> Callable[[socket], bool]:
+ def _create_log_forwarder(self, loggers, name, log_level=logging.INFO) -> SelectorCallback:
"""Create a socket handler that forwards logs to a logger."""
loggers = tuple(
reconfigure_logger(
@@ -888,41 +891,15 @@ def _service_subprocess(
"""
Service subprocess events by processing socket activity and checking for process exit.
- This method:
- - Waits for activity on the registered file objects (via `self.selector.select`).
- - Processes any events triggered on these file objects.
- - Checks if the subprocess has exited during the wait.
+    Delegates the selector event loop to :func:`service_selector` (shared
+    with runtime coordinator bridges), then checks the subprocess status.
:param max_wait_time: Maximum time to block while waiting for events, in seconds.
:param raise_on_timeout: If True, raise an exception if the subprocess does not exit within the timeout.
:param expect_signal: Signal not to log if the task exits with this code.
:returns: The process exit code, or None if it's still alive
"""
- # Ensure minimum timeout to prevent CPU spike with tight loop when timeout is 0 or negative
- timeout = max(0.01, max_wait_time)
- events = self.selector.select(timeout=timeout)
- for key, _ in events:
- # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr)
- socket_handler, on_close = key.data
-
- # Example of handler behavior:
- # If the subprocess writes "Hello, World!" to stdout:
- # - `socket_handler` reads and processes the message.
- # - If EOF is reached, the handler returns False to signal no more reads are expected.
- # - BrokenPipeError should be caught and treated as if the handler returned false, similar
- # to EOF case
- try:
- need_more = socket_handler(key.fileobj)
- except (BrokenPipeError, ConnectionResetError):
- need_more = False
-
- # If the handler signals that the file object is no longer needed (EOF, closed, etc.)
- # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors
- # are removed.
- if not need_more:
- sock: socket = key.fileobj # type: ignore[assignment]
- on_close(sock)
- sock.close()
+ service_selector(self.selector, timeout=max_wait_time)
# Check if the subprocess has exited
return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal)
@@ -1132,7 +1109,7 @@ class ActivitySubprocess(WatchedSubprocess):
def start( # type: ignore[override]
cls,
*,
- what: TaskInstance,
+ what: TaskInstanceDTO,
dag_rel_path: str | os.PathLike[str],
bundle_info,
client: Client,
@@ -1161,7 +1138,7 @@ def start( # type: ignore[override]
def _on_child_started(
self,
*,
- ti: TaskInstance,
+ ti: TaskInstanceDTO,
dag_rel_path: str | os.PathLike[str],
bundle_info,
sentry_integration: str,
@@ -1932,50 +1909,6 @@ def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult:
return InProcessTestSupervisor.start(what=ti, task=task)
-# Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read
-# and it doesn't contain a new line character, `.readline()` will just return the chunk as is.
-#
-# This returns a callback suitable for attaching to a `selector` that reads in to a buffer, and yields lines
-# to a (sync) generator
-def make_buffered_socket_reader(
- gen: Generator[None, bytes | bytearray, None],
- on_close: Callable[[socket], None],
- buffer_size: int = 4096,
-):
- buffer = bytearray() # This will hold our accumulated binary data
- read_buffer = bytearray(buffer_size) # Temporary buffer for each read
-
- # We need to start up the generator to get it to the point it's at waiting on the yield
- next(gen)
-
- def cb(sock: socket):
- nonlocal buffer, read_buffer
- # Read up to `buffer_size` bytes of data from the socket
- n_received = sock.recv_into(read_buffer)
-
- if not n_received:
- # If no data is returned, the connection is closed. Return whatever is left in the buffer
- if len(buffer):
- with suppress(StopIteration):
- gen.send(buffer)
- return False
-
- buffer.extend(read_buffer[:n_received])
-
- # We could have read multiple lines in one go, yield them all
- while (newline_pos := buffer.find(b"\n")) != -1:
- line = buffer[: newline_pos + 1]
- try:
- gen.send(line)
- except StopIteration:
- return False
- buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data
-
- return True
-
- return cb, on_close
-
-
def length_prefixed_frame_reader(
gen: Generator[None, _RequestFrame, None], on_close: Callable[[socket], None]
):
@@ -2150,7 +2083,7 @@ def _configure_logging(log_path: str, client: Client) -> tuple[FilteringBoundLog
def supervise_task(
*,
- ti: TaskInstance,
+ ti: TaskInstanceDTO,
bundle_info: BundleInfo,
dag_rel_path: str | os.PathLike[str],
token: str,
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 56ba8343c648b..2917689261a1f 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -47,6 +47,7 @@
from airflow.sdk.api.client import get_hostname, getuser
from airflow.sdk.api.datamodels._generated import (
AssetProfile,
+ BundleInfo,
DagRun,
PreviousTIResponse,
TaskInstance,
@@ -778,12 +779,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
bundle_info = what.bundle_info
bundle_prepare_start = time.monotonic()
- bundle_instance = DagBundlesManager().get_bundle(
- name=bundle_info.name,
- version=bundle_info.version,
- )
- bundle_instance.initialize()
- _verify_bundle_access(bundle_instance, log)
+ bundle_instance = resolve_bundle(bundle_info, log)
bundle_prepare_ms = int((time.monotonic() - bundle_prepare_start) * 1000)
dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
@@ -909,6 +905,22 @@ def _verify_bundle_access(bundle_instance: BaseDagBundle, log: Logger) -> None:
)
+def resolve_bundle(bundle_info: BundleInfo, log: Logger) -> BaseDagBundle:
+ """
+ Resolve, initialize, and verify access to a DAG bundle.
+
+    Used by both the standard Python task execution path and the non-Python
+    runtime coordinators (Java, Go, etc.) to obtain a ready-to-use bundle instance.
+ """
+ bundle_instance = DagBundlesManager().get_bundle(
+ name=bundle_info.name,
+ version=bundle_info.version,
+ )
+ bundle_instance.initialize()
+ _verify_bundle_access(bundle_instance, log)
+ return bundle_instance
+
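+# For example (illustrative values), a coordinator resolves the bundle before
+# locating the DAG file inside it:
+#
+#     bundle = resolve_bundle(BundleInfo(name="my-bundle", version=None), log)
+#     dag_path = os.fspath(Path(bundle.path, "dags/example.jar"))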
+
def get_startup_details() -> StartupDetails:
# The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
# in response to us sending a request.
@@ -1954,6 +1966,72 @@ def flush_spans():
provider.force_flush(timeout_millis=timeout_millis)
+def _resolve_runtime_entrypoint(startup_details: StartupDetails, log: Logger) -> Callable[[], None] | None:
+ """
+ Check configured runtime coordinators for a runtime-specific entrypoint.
+
+ Resolution order:
+
+ 1. **Queue mapping** -- the ``[sdk] queue_to_coordinator`` config maps
+ the task's ``queue`` to a coordinator name from ``[sdk] coordinators``.
+       Used by the Python-stub pattern, where users set the queue explicitly.
+ 2. **DAG file extension** -- if no queue mapping matches, the DAG file's
+ extension (e.g. ``.jar``) is compared against each coordinator's
+ ``file_extension``. Used by the pure-runtime DAG pattern where the
+ entire DAG is authored in a non-Python language.
+
+ Returns a no-arg callable that bridges fd 0 to the runtime subprocess,
+ or ``None`` to fall through to the standard Python execution path.
+ """
+ import functools
+
+ from airflow.sdk.execution_time.coordinator import get_coordinator_manager
+
+ manager = get_coordinator_manager()
+
+ def _build(coordinator) -> Callable[[], None]:
+ return functools.partial(
+ coordinator.run_task_execution,
+ what=startup_details.ti,
+ dag_rel_path=startup_details.dag_rel_path,
+ bundle_info=startup_details.bundle_info,
+ startup_details=startup_details,
+ )
+
+ # Step 1: queue-to-coordinator mapping.
+ queue = startup_details.ti.queue
+ if (coordinator := manager.for_queue(queue)) is not None:
+ log.debug(
+ "Resolved coordinator for task via queue mapping",
+ coordinator=type(coordinator).__qualname__,
+ queue=queue,
+ task_id=startup_details.ti.task_id,
+ )
+ return _build(coordinator)
+
+    # Step 2: DAG file extension fallback (pure-runtime DAGs).
+ dag_rel_path = startup_details.dag_rel_path
+ for coordinator in manager.all():
+ ext = getattr(type(coordinator), "file_extension", None)
+ if not ext or not dag_rel_path.endswith(ext):
+ continue
+ log.debug(
+ "Resolved coordinator for task via DAG file extension",
+ coordinator=type(coordinator).__qualname__,
+ dag_rel_path=dag_rel_path,
+ task_id=startup_details.ti.task_id,
+ )
+ return _build(coordinator)
+
+ log.debug(
+ "No runtime coordinator matched, using standard Python execution path",
+ queue=queue,
+ dag_rel_path=dag_rel_path,
+ task_id=startup_details.ti.task_id,
+ )
+ return None
+
+
@flush_spans()
def main():
log = structlog.get_logger(logger_name="task")
@@ -1980,6 +2058,14 @@ def main():
# startup message as a ResendLoggingFD response.
if os.environ.pop("_AIRFLOW_FORK_EXEC", None) == "1":
reinit_supervisor_comms()
+ # Check if a configured runtime coordinator should handle this
+ # task (e.g. Java, Go) instead of the standard Python
+ # execution path.
+ log.debug("Checking for runtime-specific entrypoint")
+ runtime_entrypoint = _resolve_runtime_entrypoint(startup_details, log)
+ if runtime_entrypoint is not None:
+ runtime_entrypoint()
+ return
span = _make_task_span(msg=startup_details)
stack.enter_context(span)
ti, context, log = startup(msg=startup_details)
diff --git a/task-sdk/src/airflow/sdk/execution_time/workloads/__init__.py b/task-sdk/src/airflow/sdk/execution_time/workloads/__init__.py
new file mode 100644
index 0000000000000..cdf955e742dc0
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/execution_time/workloads/__init__.py
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Workload schemas for Task SDK execution-time communication."""
+
+from __future__ import annotations
+
+from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
+
+__all__ = ["TaskInstanceDTO"]
diff --git a/task-sdk/src/airflow/sdk/execution_time/workloads/task.py b/task-sdk/src/airflow/sdk/execution_time/workloads/task.py
new file mode 100644
index 0000000000000..ceff200856f06
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/execution_time/workloads/task.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task workload schemas for Task SDK execution-time communication."""
+
+from __future__ import annotations
+
+import uuid
+
+from pydantic import BaseModel, Field
+
+
+class BaseTaskInstanceDTO(BaseModel):
+ """
+ Base schema for TaskInstance with the minimal fields shared by Executors and the Task SDK.
+
+ This class is duplicated in :mod:`airflow.executors.workloads.task` and the
+ two definitions are kept in sync by the ``check-task-instance-dto-sync``
+ prek hook. Update both files together.
+ """
+
+ id: uuid.UUID
+ dag_version_id: uuid.UUID
+ task_id: str
+ dag_id: str
+ run_id: str
+ try_number: int
+ map_index: int = -1
+
+ pool_slots: int
+ queue: str
+ priority_weight: int
+ executor_config: dict | None = Field(default=None, exclude=True)
+
+ parent_context_carrier: dict | None = None
+ context_carrier: dict | None = None
+
+
+class TaskInstanceDTO(BaseTaskInstanceDTO):
+ """Task SDK TaskInstanceDTO."""
diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
index 2b34fac6ea0f9..93c5cc19aed47 100644
--- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
+++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
@@ -680,14 +680,14 @@ def mock_comms_response(msg):
("tg.t2", 0): ["a", "b"],
("tg.t2", 1): [4],
("tg.t2", 2): ["z"],
- ("t3", None): [["a", "b"], [4], ["z"]],
+ ("t3", -1): [["a", "b"], [4], ["z"]],
}
# We hard-code the number of expansions here as the server is in charge of that.
expansion_per_task_id = {
"tg.t1": range(3),
"tg.t2": range(3),
- "t3": [None],
+ "t3": [-1],
}
for task in dag.tasks:
for map_index in expansion_per_task_id[task.task_id]:
diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py
index af487851b07cb..6014d26f2208b 100644
--- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py
+++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py
@@ -344,7 +344,7 @@ def xcom_get(msg):
mock_supervisor_comms.send.side_effect = xcom_get
# Run "pull_one" and "pull_all".
- assert run_ti(dag, "pull_all", None) == TaskInstanceState.SUCCESS
+ assert run_ti(dag, "pull_all", -1) == TaskInstanceState.SUCCESS
assert all_results == ["a", "b", "c", 1, 2]
states = [run_ti(dag, "pull_one", map_index) for map_index in range(5)]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py
index 5c6d88439250c..37a91dd0ecc28 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_comms.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py
@@ -86,6 +86,9 @@ def test_recv_StartupDetails(self):
"run_id": "b",
"dag_id": "c",
"dag_version_id": uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"),
+ "pool_slots": 1,
+ "queue": "default",
+ "priority_weight": 1,
},
"ti_context": {
"dag_run": {
diff --git a/task-sdk/tests/task_sdk/execution_time/test_coordinator.py b/task-sdk/tests/task_sdk/execution_time/test_coordinator.py
new file mode 100644
index 0000000000000..7c11022755d9f
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_coordinator.py
@@ -0,0 +1,650 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import contextlib
+import json
+import os
+import socket
+import subprocess
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.sdk.execution_time.coordinator import (
+ BaseCoordinator,
+ CoordinatorManager,
+ _bridge,
+ _send_startup_details,
+ _start_server,
+ get_coordinator_manager,
+ reset_coordinator_manager,
+)
+
+
+class TestStartServer:
+ def test_binds_to_localhost(self):
+ server = _start_server()
+ try:
+ host, port = server.getsockname()
+ assert host == "127.0.0.1"
+ assert port > 0
+ finally:
+ server.close()
+
+ def test_assigns_random_port(self):
+ s1 = _start_server()
+ s2 = _start_server()
+ try:
+ _, port1 = s1.getsockname()
+ _, port2 = s2.getsockname()
+ assert port1 != port2
+ finally:
+ s1.close()
+ s2.close()
+
+ def test_accepts_connection(self):
+ server = _start_server()
+ try:
+ addr = server.getsockname()
+ client = socket.socket()
+ client.connect(addr)
+ conn, _ = server.accept()
+ conn.sendall(b"ping")
+ assert client.recv(4) == b"ping"
+ conn.close()
+ client.close()
+ finally:
+ server.close()
+
+
+class TestSendStartupDetails:
+ def test_sends_frame_bytes_to_socket(self):
+ mock_startup = MagicMock()
+ mock_startup.model_dump.return_value = {"type": "StartupDetails", "ti": {}}
+
+ mock_socket = MagicMock(spec=socket.socket)
+
+ _send_startup_details(mock_socket, mock_startup)
+
+ mock_startup.model_dump.assert_called_once_with(mode="json")
+ mock_socket.sendall.assert_called_once()
+
+ sent_bytes = mock_socket.sendall.call_args[0][0]
+ assert len(sent_bytes) > 4
+ length = int.from_bytes(sent_bytes[:4], "big")
+ assert length == len(sent_bytes) - 4
+
+ def test_frame_contains_response_id_zero(self):
+ import msgpack
+
+ mock_startup = MagicMock()
+ mock_startup.model_dump.return_value = {"type": "StartupDetails"}
+
+ mock_socket = MagicMock(spec=socket.socket)
+
+ _send_startup_details(mock_socket, mock_startup)
+
+ sent_bytes = mock_socket.sendall.call_args[0][0]
+ frame = msgpack.unpackb(sent_bytes[4:])
+ assert frame[0] == 0
+
+ def test_frame_body_matches_model_dump(self):
+ import msgpack
+
+ body = {"type": "StartupDetails", "ti": {"task_id": "t1"}, "dag_rel_path": "test.jar"}
+ mock_startup = MagicMock()
+ mock_startup.model_dump.return_value = body
+
+ mock_socket = MagicMock(spec=socket.socket)
+
+ _send_startup_details(mock_socket, mock_startup)
+
+ sent_bytes = mock_socket.sendall.call_args[0][0]
+ frame = msgpack.unpackb(sent_bytes[4:])
+ assert frame[1] == body
+
+ def test_real_socket_roundtrip(self):
+ import msgpack
+
+ server = socket.socket()
+ server.bind(("127.0.0.1", 0))
+ server.listen(1)
+ addr = server.getsockname()
+
+ client = socket.socket()
+ client.connect(addr)
+ conn, _ = server.accept()
+
+ try:
+ body = {"type": "StartupDetails", "value": 42}
+ mock_startup = MagicMock()
+ mock_startup.model_dump.return_value = body
+
+ _send_startup_details(conn, mock_startup)
+
+ length_bytes = client.recv(4)
+ length = int.from_bytes(length_bytes, "big")
+
+ data = client.recv(length)
+ frame = msgpack.unpackb(data)
+ assert frame[0] == 0
+ assert frame[1] == body
+ finally:
+ conn.close()
+ client.close()
+ server.close()
+
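+# Frame layout exercised by the tests above: a 4-byte big-endian length prefix
+# followed by a msgpack-encoded ``[response_id, body]`` pair. A reader sketch
+# under that assumption (a robust reader would loop until ``length`` bytes have
+# arrived instead of relying on a single recv):
+#
+#     import msgpack
+#
+#     def read_frame(sock):
+#         length = int.from_bytes(sock.recv(4), "big")
+#         response_id, body = msgpack.unpackb(sock.recv(length))
+#         return response_id, body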
+
+class TestBaseCoordinatorDefaults:
+ def test_can_handle_dag_file_returns_false(self):
+ assert BaseCoordinator().can_handle_dag_file("bundle", "/path/to/dag.py") is False
+
+ def test_get_code_from_file_raises_not_implemented(self):
+ with pytest.raises(NotImplementedError):
+ BaseCoordinator().get_code_from_file("/path/to/dag.jar")
+
+ def test_dag_parsing_cmd_raises_not_implemented(self):
+ with pytest.raises(NotImplementedError):
+ BaseCoordinator().dag_parsing_cmd(
+ dag_file_path="/dag.jar",
+ bundle_name="b",
+ bundle_path="/path",
+ comm_addr="127.0.0.1:1234",
+ logs_addr="127.0.0.1:1235",
+ )
+
+ def test_task_execution_cmd_raises_not_implemented(self):
+ with pytest.raises(NotImplementedError):
+ BaseCoordinator().task_execution_cmd(
+ what=MagicMock(),
+ dag_file_path="/dag.jar",
+ bundle_path="/path",
+ bundle_info=MagicMock(),
+ comm_addr="127.0.0.1:1234",
+ logs_addr="127.0.0.1:1235",
+ )
+
+
+class TestCoordinatorNamedTuples:
+ def test_dag_parsing_info_defaults(self):
+ info = BaseCoordinator.DagParsingInfo(
+ dag_file_path="/dag.jar",
+ bundle_name="my-bundle",
+ bundle_path="/bundles/my-bundle",
+ )
+ assert info.mode == "dag-parsing"
+ assert info.dag_file_path == "/dag.jar"
+ assert info.bundle_name == "my-bundle"
+ assert info.bundle_path == "/bundles/my-bundle"
+
+ def test_task_execution_info_defaults(self):
+ mock_ti = MagicMock()
+ mock_bundle = MagicMock()
+ mock_startup = MagicMock()
+ info = BaseCoordinator.TaskExecutionInfo(
+ what=mock_ti,
+ dag_rel_path="dags/example.jar",
+ bundle_info=mock_bundle,
+ startup_details=mock_startup,
+ )
+ assert info.mode == "task-execution"
+ assert info.what is mock_ti
+ assert info.dag_rel_path == "dags/example.jar"
+
+
+class TestBridge:
+ def test_bridge_forwards_comm_bidirectionally(self):
+ sup_send, sup_recv = socket.socketpair()
+ rt_send, rt_recv = socket.socketpair()
+ log_send, log_recv = socket.socketpair()
+ stderr_send, stderr_recv = socket.socketpair()
+
+ mock_proc = MagicMock(spec=subprocess.Popen)
+ mock_proc.poll.return_value = 0
+ mock_log = MagicMock()
+
+ try:
+ sup_send.sendall(b"from_supervisor")
+ rt_send.sendall(b"from_runtime")
+ log_send.sendall(b'{"event":"hello","level":"info"}\n')
+ stderr_send.sendall(b"stderr line\n")
+
+ sup_send.close()
+ rt_send.close()
+ log_send.close()
+ stderr_send.close()
+
+ _bridge(sup_recv, rt_recv, log_recv, stderr_recv, mock_proc, mock_log)
+ finally:
+ for s in (sup_send, rt_send, log_send, stderr_send, sup_recv, rt_recv, log_recv, stderr_recv):
+ with contextlib.suppress(OSError):
+ s.close()
+
+ def test_bridge_drains_after_process_exit(self):
+ sup_local, sup_remote = socket.socketpair()
+ rt_local, rt_remote = socket.socketpair()
+ log_local, log_remote = socket.socketpair()
+ stderr_local, stderr_remote = socket.socketpair()
+
+ mock_proc = MagicMock(spec=subprocess.Popen)
+ mock_proc.poll.side_effect = [None, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ mock_log = MagicMock()
+
+ try:
+ stderr_local.sendall(b"error output\n")
+ stderr_local.close()
+ sup_local.close()
+ rt_local.close()
+ log_local.close()
+
+ _bridge(sup_remote, rt_remote, log_remote, stderr_remote, mock_proc, mock_log)
+ finally:
+ for s in (
+ sup_local,
+ sup_remote,
+ rt_local,
+ rt_remote,
+ log_local,
+ log_remote,
+ stderr_local,
+ stderr_remote,
+ ):
+ with contextlib.suppress(OSError):
+ s.close()
+
+ def test_bridge_closes_all_sockets(self):
+ sup = MagicMock(spec=socket.socket)
+ rt = MagicMock(spec=socket.socket)
+ logs = MagicMock(spec=socket.socket)
+ stderr = MagicMock(spec=socket.socket)
+
+ mock_proc = MagicMock(spec=subprocess.Popen)
+ mock_proc.poll.return_value = 0
+ mock_log = MagicMock()
+
+ with (
+ patch("airflow.sdk.execution_time.coordinator.selectors.DefaultSelector") as mock_sel_cls,
+ patch("airflow.sdk.execution_time.selector_loop.service_selector"),
+ ):
+ mock_sel = MagicMock()
+ mock_sel_cls.return_value = mock_sel
+ mock_sel.get_map.return_value = {}
+
+ _bridge(sup, rt, logs, stderr, mock_proc, mock_log)
+
+ sup.close.assert_called()
+ rt.close.assert_called()
+ logs.close.assert_called()
+ stderr.close.assert_called()
+ mock_sel.close.assert_called_once()
+
+
+class _StubCoordinator(BaseCoordinator):
+ sdk = "test"
+ file_extension = ".test"
+
+ def __init__(self, *, parse_cmd: list[str] | None = None, exec_cmd: list[str] | None = None):
+ self._parse_cmd = parse_cmd or ["test-runtime", "--parse"]
+ self._exec_cmd = exec_cmd or ["test-runtime", "--execute"]
+
+ def dag_parsing_cmd(self, *, dag_file_path, **_):
+ return [*self._parse_cmd, dag_file_path]
+
+ def task_execution_cmd(self, *, dag_file_path, **_):
+ return [*self._exec_cmd, dag_file_path]
+
+
+class TestRunDagParsing:
+ @patch.object(BaseCoordinator, "_runtime_subprocess_entrypoint")
+ def test_run_dag_parsing_creates_dag_parsing_info(self, mock_entrypoint):
+ coordinator = _StubCoordinator()
+ coordinator.run_dag_parsing(
+ path="/bundles/my-bundle/dags/example.jar",
+ bundle_name="my-bundle",
+ bundle_path="/bundles/my-bundle",
+ )
+
+ mock_entrypoint.assert_called_once()
+ info = mock_entrypoint.call_args[0][0]
+ assert isinstance(info, BaseCoordinator.DagParsingInfo)
+ assert info.dag_file_path == "/bundles/my-bundle/dags/example.jar"
+ assert info.bundle_name == "my-bundle"
+ assert info.bundle_path == "/bundles/my-bundle"
+ assert info.mode == "dag-parsing"
+
+
+class TestRunTaskExecution:
+ @patch.object(BaseCoordinator, "_runtime_subprocess_entrypoint")
+ def test_run_task_execution_creates_task_execution_info(self, mock_entrypoint):
+ mock_ti = MagicMock()
+ mock_bundle_info = MagicMock()
+ mock_startup = MagicMock()
+
+ coordinator = _StubCoordinator()
+ coordinator.run_task_execution(
+ what=mock_ti,
+ dag_rel_path="dags/example.jar",
+ bundle_info=mock_bundle_info,
+ startup_details=mock_startup,
+ )
+
+ mock_entrypoint.assert_called_once()
+ info = mock_entrypoint.call_args[0][0]
+ assert isinstance(info, BaseCoordinator.TaskExecutionInfo)
+ assert info.what is mock_ti
+ assert info.dag_rel_path == "dags/example.jar"
+ assert info.bundle_info is mock_bundle_info
+ assert info.startup_details is mock_startup
+ assert info.mode == "task-execution"
+
+
+class TestRuntimeSubprocessEntrypoint:
+ @pytest.fixture(autouse=True)
+ def _restore_process_context_env(self):
+ old = os.environ.get("_AIRFLOW_PROCESS_CONTEXT")
+ try:
+ yield
+ finally:
+ if old is None:
+ os.environ.pop("_AIRFLOW_PROCESS_CONTEXT", None)
+ else:
+ os.environ["_AIRFLOW_PROCESS_CONTEXT"] = old
+
+ def test_unknown_entrypoint_info_type_raises(self):
+ coordinator = _StubCoordinator()
+ fake_info = MagicMock()
+ fake_info.mode = "unknown"
+
+ with pytest.raises(ValueError, match="Unknown entrypoint_info type"):
+ coordinator._runtime_subprocess_entrypoint(fake_info) # type: ignore[arg-type]
+
+ @patch("airflow.sdk.execution_time.coordinator._bridge")
+ @patch("airflow.sdk.execution_time.coordinator._send_startup_details")
+ @patch("subprocess.Popen", autospec=True)
+ @patch("airflow.sdk.execution_time.coordinator._start_server")
+ @patch("os.dup", return_value=99)
+ def test_dag_parsing_flow(self, mock_dup, mock_start_server, mock_popen, mock_send_startup, mock_bridge):
+ comm_server = MagicMock(spec=socket.socket)
+ comm_server.getsockname.return_value = ("127.0.0.1", 5000)
+ logs_server = MagicMock(spec=socket.socket)
+ logs_server.getsockname.return_value = ("127.0.0.1", 5001)
+ mock_start_server.side_effect = [comm_server, logs_server]
+
+ runtime_comm = MagicMock(spec=socket.socket)
+ runtime_logs = MagicMock(spec=socket.socket)
+ comm_server.accept.return_value = (runtime_comm, ("127.0.0.1", 9000))
+ logs_server.accept.return_value = (runtime_logs, ("127.0.0.1", 9001))
+
+ child_stderr = MagicMock(spec=socket.socket)
+ read_stderr = MagicMock(spec=socket.socket)
+ child_stderr.fileno.return_value = 10
+
+ supervisor_comm = MagicMock(spec=socket.socket)
+
+ coordinator = _StubCoordinator(parse_cmd=["test-runtime", "--parse"])
+ info = BaseCoordinator.DagParsingInfo(
+ dag_file_path="/dag.test",
+ bundle_name="test-bundle",
+ bundle_path="/bundles/test-bundle",
+ )
+
+ with (
+ patch("socket.socketpair", return_value=(child_stderr, read_stderr)),
+ patch("airflow.sdk.execution_time.coordinator.socket.socket", return_value=supervisor_comm),
+ ):
+ coordinator._runtime_subprocess_entrypoint(info)
+
+ mock_popen.assert_called_once()
+ cmd = mock_popen.call_args[0][0]
+ assert cmd == ["test-runtime", "--parse", "/dag.test"]
+
+ comm_server.accept.assert_called_once()
+ logs_server.accept.assert_called_once()
+ comm_server.close.assert_called_once()
+ logs_server.close.assert_called_once()
+
+ child_stderr.close.assert_called_once()
+ mock_send_startup.assert_not_called()
+
+ mock_bridge.assert_called_once()
+ assert mock_bridge.call_args[0][0] is supervisor_comm
+
+ @patch("airflow.sdk.execution_time.coordinator._bridge")
+ @patch("airflow.sdk.execution_time.coordinator._send_startup_details")
+ @patch("subprocess.Popen", autospec=True)
+ @patch("airflow.sdk.execution_time.coordinator._start_server")
+ @patch("os.dup", return_value=99)
+ @patch("airflow.sdk.execution_time.task_runner.resolve_bundle")
+ @patch("airflow.dag_processing.bundles.base.BundleVersionLock", autospec=True)
+ def test_task_execution_flow(
+ self,
+ mock_bundle_lock,
+ mock_resolve_bundle,
+ mock_dup,
+ mock_start_server,
+ mock_popen,
+ mock_send_startup,
+ mock_bridge,
+ ):
+ comm_server = MagicMock(spec=socket.socket)
+ comm_server.getsockname.return_value = ("127.0.0.1", 6000)
+ logs_server = MagicMock(spec=socket.socket)
+ logs_server.getsockname.return_value = ("127.0.0.1", 6001)
+ mock_start_server.side_effect = [comm_server, logs_server]
+
+ runtime_comm = MagicMock(spec=socket.socket)
+ runtime_logs = MagicMock(spec=socket.socket)
+ comm_server.accept.return_value = (runtime_comm, ("127.0.0.1", 9000))
+ logs_server.accept.return_value = (runtime_logs, ("127.0.0.1", 9001))
+
+ child_stderr = MagicMock(spec=socket.socket)
+ read_stderr = MagicMock(spec=socket.socket)
+ child_stderr.fileno.return_value = 10
+
+ mock_bundle_instance = MagicMock()
+ mock_bundle_instance.path = Path("/resolved/bundles/test-bundle")
+ mock_resolve_bundle.return_value = mock_bundle_instance
+
+ mock_lock_instance = MagicMock()
+ mock_bundle_lock.return_value = mock_lock_instance
+ mock_lock_instance.__enter__ = MagicMock(return_value=mock_lock_instance)
+ mock_lock_instance.__exit__ = MagicMock(return_value=False)
+
+ mock_ti = MagicMock()
+ mock_bundle_info = MagicMock()
+ mock_bundle_info.name = "test-bundle"
+ mock_bundle_info.version = "v1"
+ mock_startup = MagicMock()
+
+ coordinator = _StubCoordinator(exec_cmd=["test-runtime", "--execute"])
+ info = BaseCoordinator.TaskExecutionInfo(
+ what=mock_ti,
+ dag_rel_path="dags/example.test",
+ bundle_info=mock_bundle_info,
+ startup_details=mock_startup,
+ )
+
+ supervisor_comm = MagicMock(spec=socket.socket)
+
+ with (
+ patch("socket.socketpair", return_value=(child_stderr, read_stderr)),
+ patch("airflow.sdk.execution_time.coordinator.socket.socket", return_value=supervisor_comm),
+ ):
+ coordinator._runtime_subprocess_entrypoint(info)
+
+ mock_resolve_bundle.assert_called_once()
+ mock_bundle_lock.assert_called_once_with(bundle_name="test-bundle", bundle_version="v1")
+
+ mock_popen.assert_called_once()
+ cmd = mock_popen.call_args[0][0]
+ assert cmd == ["test-runtime", "--execute", "/resolved/bundles/test-bundle/dags/example.test"]
+
+ mock_send_startup.assert_called_once_with(runtime_comm, mock_startup)
+ mock_bridge.assert_called_once()
+
+ @patch("airflow.sdk.execution_time.coordinator._bridge")
+ @patch("subprocess.Popen", autospec=True)
+ @patch("airflow.sdk.execution_time.coordinator._start_server")
+ @patch("os.dup", return_value=99)
+ def test_sets_process_context_env_var(self, mock_dup, mock_start_server, mock_popen, mock_bridge):
+ comm_server = MagicMock(spec=socket.socket)
+ comm_server.getsockname.return_value = ("127.0.0.1", 7000)
+ logs_server = MagicMock(spec=socket.socket)
+ logs_server.getsockname.return_value = ("127.0.0.1", 7001)
+ mock_start_server.side_effect = [comm_server, logs_server]
+
+ runtime_comm = MagicMock(spec=socket.socket)
+ runtime_logs = MagicMock(spec=socket.socket)
+ comm_server.accept.return_value = (runtime_comm, ("127.0.0.1", 9000))
+ logs_server.accept.return_value = (runtime_logs, ("127.0.0.1", 9001))
+
+ child_stderr = MagicMock(spec=socket.socket)
+ read_stderr = MagicMock(spec=socket.socket)
+ child_stderr.fileno.return_value = 10
+
+ coordinator = _StubCoordinator(parse_cmd=["echo", "test"])
+ info = BaseCoordinator.DagParsingInfo(
+ dag_file_path="/dag.test",
+ bundle_name="b",
+ bundle_path="/path",
+ )
+
+ supervisor_comm = MagicMock(spec=socket.socket)
+
+ old_val = os.environ.get("_AIRFLOW_PROCESS_CONTEXT")
+ try:
+ with (
+ patch("socket.socketpair", return_value=(child_stderr, read_stderr)),
+ patch("airflow.sdk.execution_time.coordinator.socket.socket", return_value=supervisor_comm),
+ ):
+ coordinator._runtime_subprocess_entrypoint(info)
+ assert os.environ["_AIRFLOW_PROCESS_CONTEXT"] == "client"
+ finally:
+ if old_val is None:
+ os.environ.pop("_AIRFLOW_PROCESS_CONTEXT", None)
+ else:
+ os.environ["_AIRFLOW_PROCESS_CONTEXT"] = old_val
+
+
+class _CoordinatorA(BaseCoordinator):
+ sdk = "a"
+ file_extension = ".a"
+
+ def __init__(self, *, label: str = "a"):
+ self.label = label
+
+ def can_handle_dag_file(self, bundle_name, path):
+ return os.fspath(path).endswith(".a")
+
+
+class _CoordinatorB(BaseCoordinator):
+ sdk = "b"
+ file_extension = ".b"
+
+ def can_handle_dag_file(self, bundle_name, path):
+ return os.fspath(path).endswith(".b")
+
+
+class TestCoordinatorManager:
+ @pytest.fixture(autouse=True)
+ def _reset_cache(self):
+ reset_coordinator_manager()
+ yield
+ reset_coordinator_manager()
+
+ def test_from_config_loads_instances(self, monkeypatch):
+ coordinators_json = json.dumps(
+ [
+ {
+ "name": "alpha",
+ "classpath": f"{_CoordinatorA.__module__}._CoordinatorA",
+ "kwargs": {"label": "alpha-label"},
+ },
+ {
+ "name": "beta",
+ "classpath": f"{_CoordinatorB.__module__}._CoordinatorB",
+ },
+ ]
+ )
+ queue_json = json.dumps({"queue-a": "alpha"})
+
+ monkeypatch.setenv("AIRFLOW__SDK__COORDINATORS", coordinators_json)
+ monkeypatch.setenv("AIRFLOW__SDK__QUEUE_TO_COORDINATOR", queue_json)
+
+ from airflow.sdk.configuration import conf
+
+ conf.invalidate_cache()
+
+ manager = CoordinatorManager.from_config()
+
+ alpha = manager.get("alpha")
+ beta = manager.get("beta")
+ assert isinstance(alpha, _CoordinatorA)
+ assert isinstance(beta, _CoordinatorB)
+ assert alpha.label == "alpha-label"
+ assert {type(c) for c in manager.all()} == {_CoordinatorA, _CoordinatorB}
+
+ def test_from_config_empty(self, monkeypatch):
+ monkeypatch.delenv("AIRFLOW__SDK__COORDINATORS", raising=False)
+ monkeypatch.delenv("AIRFLOW__SDK__QUEUE_TO_COORDINATOR", raising=False)
+
+ from airflow.sdk.configuration import conf
+
+ conf.invalidate_cache()
+
+ manager = CoordinatorManager.from_config()
+ assert manager.all() == []
+ assert manager.get("missing") is None
+
+ def test_for_queue_resolves_via_mapping(self):
+ coordinator_a = _CoordinatorA()
+ coordinator_b = _CoordinatorB()
+ manager = CoordinatorManager(
+ {"alpha": coordinator_a, "beta": coordinator_b},
+ {"queue-a": "alpha", "queue-b": "beta"},
+ )
+
+ assert manager.for_queue("queue-a") is coordinator_a
+ assert manager.for_queue("queue-b") is coordinator_b
+ assert manager.for_queue("queue-missing") is None
+
+ def test_for_dag_file_picks_first_match(self):
+ coordinator_a = _CoordinatorA()
+ coordinator_b = _CoordinatorB()
+ manager = CoordinatorManager({"alpha": coordinator_a, "beta": coordinator_b}, {})
+
+ assert manager.for_dag_file("bundle", "dag.a") is coordinator_a
+ assert manager.for_dag_file("bundle", "dag.b") is coordinator_b
+ assert manager.for_dag_file("bundle", "dag.py") is None
+
+ def test_file_extensions(self):
+ manager = CoordinatorManager({"a": _CoordinatorA(), "b": _CoordinatorB()}, {})
+ assert set(manager.file_extensions()) == {".a", ".b"}
+
+ def test_get_coordinator_manager_is_cached(self, monkeypatch):
+ monkeypatch.delenv("AIRFLOW__SDK__COORDINATORS", raising=False)
+
+ from airflow.sdk.configuration import conf
+
+ conf.invalidate_cache()
+
+ m1 = get_coordinator_manager()
+ m2 = get_coordinator_manager()
+ assert m1 is m2
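+
+
+# Putting it together (the classpath below is illustrative): a coordinator is
+# registered via config and then resolved through the cached manager, e.g.
+#
+#     AIRFLOW__SDK__COORDINATORS='[{"name": "alpha", "classpath": "my_pkg.MyCoordinator"}]'
+#
+# after which ``get_coordinator_manager().for_queue(...)`` and
+# ``for_dag_file(...)`` resolve instances exactly as exercised above.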
diff --git a/task-sdk/tests/task_sdk/execution_time/test_selector_loop.py b/task-sdk/tests/task_sdk/execution_time/test_selector_loop.py
new file mode 100644
index 0000000000000..efbfa83adecf8
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_selector_loop.py
@@ -0,0 +1,479 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import selectors
+import socket
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.sdk.execution_time.selector_loop import (
+ make_buffered_socket_reader,
+ make_raw_forwarder,
+ service_selector,
+)
+
+
+def _make_generator():
+ """Return a generator that collects sent lines into a list."""
+ received: list[bytes | bytearray] = []
+
+ def gen():
+ while True:
+ line = yield
+ received.append(bytes(line))
+
+ g = gen()
+ return g, received
+
+
+def _make_socket_pair():
+ """Create a connected TCP socket pair on localhost."""
+ server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ server.bind(("127.0.0.1", 0))
+ server.listen(1)
+ addr = server.getsockname()
+
+ client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ client.connect(addr)
+ conn, _ = server.accept()
+ server.close()
+ return client, conn
+
+
+class TestMakeBufferedSocketReader:
+ def test_single_complete_line(self):
+ gen, received = _make_generator()
+ on_close = MagicMock()
+ handler, returned_on_close = make_buffered_socket_reader(gen, on_close)
+
+ sock = MagicMock(spec=socket.socket)
+ # recv_into writes data and returns count
+ data = b"hello world\n"
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, data)
+
+ result = handler(sock)
+
+ assert result is True
+ assert received == [b"hello world\n"]
+ assert returned_on_close is on_close
+
+ def test_multiple_lines_in_single_recv(self):
+ gen, received = _make_generator()
+ on_close = MagicMock()
+ handler, _ = make_buffered_socket_reader(gen, on_close)
+
+ sock = MagicMock(spec=socket.socket)
+ data = b"line1\nline2\nline3\n"
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, data)
+
+ result = handler(sock)
+
+ assert result is True
+ assert received == [b"line1\n", b"line2\n", b"line3\n"]
+
+ def test_partial_line_accumulated_across_calls(self):
+ gen, received = _make_generator()
+ on_close = MagicMock()
+ handler, _ = make_buffered_socket_reader(gen, on_close)
+
+ sock = MagicMock(spec=socket.socket)
+
+ # First call: partial line (no newline)
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"hell")
+ result = handler(sock)
+ assert result is True
+ assert received == []
+
+ # Second call: rest of the line
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"o\n")
+ result = handler(sock)
+ assert result is True
+ assert received == [b"hello\n"]
+
+ def test_eof_flushes_remaining_buffer(self):
+ gen, received = _make_generator()
+ on_close = MagicMock()
+ handler, _ = make_buffered_socket_reader(gen, on_close)
+
+ sock = MagicMock(spec=socket.socket)
+
+ # Send partial data (no newline)
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"leftover")
+ handler(sock)
+ assert received == []
+
+ # EOF (recv_into returns 0) — clear side_effect so return_value takes effect
+ sock.recv_into.side_effect = None
+ sock.recv_into.return_value = 0
+ result = handler(sock)
+
+ assert result is False
+ assert received == [b"leftover"]
+
+ def test_eof_with_empty_buffer(self):
+ gen, received = _make_generator()
+ on_close = MagicMock()
+ handler, _ = make_buffered_socket_reader(gen, on_close)
+
+ sock = MagicMock(spec=socket.socket)
+ sock.recv_into.return_value = 0
+
+ result = handler(sock)
+
+ assert result is False
+ assert received == []
+
+ def test_generator_stop_iteration_returns_false(self):
+ """If the generator is exhausted, handler returns False."""
+
+ def limited_gen():
+ yield # startup
+ yield # receive one line, then stop
+
+ gen = limited_gen()
+ on_close = MagicMock()
+ handler, _ = make_buffered_socket_reader(gen, on_close)
+
+ sock = MagicMock(spec=socket.socket)
+ # First line succeeds
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"line1\n")
+ result = handler(sock)
+ assert result is True
+
+ # Second line triggers StopIteration in the generator
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"line2\n")
+ result = handler(sock)
+ assert result is False
+
+ def test_mixed_complete_and_partial_lines(self):
+ gen, received = _make_generator()
+ on_close = MagicMock()
+ handler, _ = make_buffered_socket_reader(gen, on_close)
+
+ sock = MagicMock(spec=socket.socket)
+ # Data contains one complete line and a partial line
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"complete\npart")
+ handler(sock)
+ assert received == [b"complete\n"]
+
+ # Finish the partial line
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"ial\n")
+ handler(sock)
+ assert received == [b"complete\n", b"partial\n"]
+
+ def test_custom_buffer_size(self):
+ gen, received = _make_generator()
+ on_close = MagicMock()
+ handler, _ = make_buffered_socket_reader(gen, on_close, buffer_size=8)
+
+ sock = MagicMock(spec=socket.socket)
+ # Data larger than buffer_size — recv_into only reads buffer_size bytes
+ full_data = b"abcdefghijklmnop\n"
+ # Simulate chunked reads
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, full_data[: len(buf)])
+ handler(sock)
+ # Only first 8 bytes read, no newline yet
+ assert received == []
+
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, full_data[8:16])
+ handler(sock)
+ assert received == []
+
+ sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, full_data[16:])
+ handler(sock)
+ assert received == [b"abcdefghijklmnop\n"]
+
+
+def _fill_buffer(buf: bytearray, data: bytes) -> int:
+ """Helper to simulate socket.recv_into by filling the buffer."""
+ n = min(len(data), len(buf))
+ buf[:n] = data[:n]
+ return n
+
+
+class TestMakeRawForwarder:
+ def test_forwards_data_to_dest(self):
+ on_close = MagicMock()
+ dest = MagicMock(spec=socket.socket)
+ handler, returned_on_close = make_raw_forwarder(dest, on_close)
+
+ src = MagicMock(spec=socket.socket)
+ src.recv.return_value = b"hello"
+
+ result = handler(src)
+
+ assert result is True
+ dest.sendall.assert_called_once_with(b"hello")
+ assert returned_on_close is on_close
+
+ def test_eof_returns_false(self):
+ on_close = MagicMock()
+ dest = MagicMock(spec=socket.socket)
+ handler, _ = make_raw_forwarder(dest, on_close)
+
+ src = MagicMock(spec=socket.socket)
+ src.recv.return_value = b""
+
+ result = handler(src)
+
+ assert result is False
+ dest.sendall.assert_not_called()
+
+ @pytest.mark.parametrize(
+ "exception",
+ [BrokenPipeError, ConnectionResetError, OSError],
+ ids=["broken_pipe", "connection_reset", "os_error"],
+ )
+ def test_sendall_exception_returns_false(self, exception):
+ on_close = MagicMock()
+ dest = MagicMock(spec=socket.socket)
+ dest.sendall.side_effect = exception
+ handler, _ = make_raw_forwarder(dest, on_close)
+
+ src = MagicMock(spec=socket.socket)
+ src.recv.return_value = b"data"
+
+ result = handler(src)
+
+ assert result is False
+
+ def test_multiple_forwards(self):
+ on_close = MagicMock()
+ dest = MagicMock(spec=socket.socket)
+ handler, _ = make_raw_forwarder(dest, on_close)
+
+ src = MagicMock(spec=socket.socket)
+
+ for chunk in [b"chunk1", b"chunk2", b"chunk3"]:
+ src.recv.return_value = chunk
+ assert handler(src) is True
+
+ assert dest.sendall.call_count == 3
+
+
+class TestServiceSelector:
+ def test_calls_handler_for_ready_sockets(self):
+ sel = MagicMock(spec=selectors.DefaultSelector)
+ handler = MagicMock(return_value=True)
+ on_close = MagicMock()
+ sock = MagicMock(spec=socket.socket)
+
+ key = MagicMock()
+ key.data = (handler, on_close)
+ key.fileobj = sock
+
+ sel.select.return_value = [(key, selectors.EVENT_READ)]
+
+ service_selector(sel, timeout=1.0)
+
+ handler.assert_called_once_with(sock)
+ on_close.assert_not_called()
+ sock.close.assert_not_called()
+
+ def test_on_close_and_sock_close_when_handler_returns_false(self):
+ sel = MagicMock(spec=selectors.DefaultSelector)
+ handler = MagicMock(return_value=False)
+ on_close = MagicMock()
+ sock = MagicMock(spec=socket.socket)
+
+ key = MagicMock()
+ key.data = (handler, on_close)
+ key.fileobj = sock
+
+ sel.select.return_value = [(key, selectors.EVENT_READ)]
+
+ service_selector(sel, timeout=1.0)
+
+ handler.assert_called_once_with(sock)
+ on_close.assert_called_once_with(sock)
+ sock.close.assert_called_once()
+
+ @pytest.mark.parametrize(
+ "exception",
+ [BrokenPipeError, ConnectionResetError],
+ ids=["broken_pipe", "connection_reset"],
+ )
+ def test_pipe_errors_treated_as_eof(self, exception):
+ sel = MagicMock(spec=selectors.DefaultSelector)
+ handler = MagicMock(side_effect=exception)
+ on_close = MagicMock()
+ sock = MagicMock(spec=socket.socket)
+
+ key = MagicMock()
+ key.data = (handler, on_close)
+ key.fileobj = sock
+
+ sel.select.return_value = [(key, selectors.EVENT_READ)]
+
+ service_selector(sel, timeout=1.0)
+
+ on_close.assert_called_once_with(sock)
+ sock.close.assert_called_once()
+
+ def test_empty_selector_no_events(self):
+ sel = MagicMock(spec=selectors.DefaultSelector)
+ sel.select.return_value = []
+
+ # Should not raise
+ service_selector(sel, timeout=1.0)
+
+ @pytest.mark.parametrize(
+ ("input_timeout", "expected_min"),
+ [
+ (0.0, 0.01),
+ (-1.0, 0.01),
+ (-100.0, 0.01),
+ (0.5, 0.5),
+ (2.0, 2.0),
+ ],
+ ids=["zero", "negative", "very_negative", "positive_half", "positive_two"],
+ )
+ def test_timeout_clamped_to_minimum(self, input_timeout, expected_min):
+ sel = MagicMock(spec=selectors.DefaultSelector)
+ sel.select.return_value = []
+
+ service_selector(sel, timeout=input_timeout)
+
+ sel.select.assert_called_once()
+ actual_timeout = sel.select.call_args[1].get("timeout") or sel.select.call_args[0][0]
+ assert actual_timeout == pytest.approx(expected_min)
+
+ def test_multiple_ready_sockets(self):
+ sel = MagicMock(spec=selectors.DefaultSelector)
+
+ handler1 = MagicMock(return_value=True)
+ on_close1 = MagicMock()
+ sock1 = MagicMock(spec=socket.socket)
+ key1 = MagicMock()
+ key1.data = (handler1, on_close1)
+ key1.fileobj = sock1
+
+ handler2 = MagicMock(return_value=False)
+ on_close2 = MagicMock()
+ sock2 = MagicMock(spec=socket.socket)
+ key2 = MagicMock()
+ key2.data = (handler2, on_close2)
+ key2.fileobj = sock2
+
+ sel.select.return_value = [(key1, selectors.EVENT_READ), (key2, selectors.EVENT_READ)]
+
+ service_selector(sel, timeout=1.0)
+
+ # First socket: handler returns True, stays open
+ handler1.assert_called_once_with(sock1)
+ on_close1.assert_not_called()
+ sock1.close.assert_not_called()
+
+ # Second socket: handler returns False, closed
+ handler2.assert_called_once_with(sock2)
+ on_close2.assert_called_once_with(sock2)
+ sock2.close.assert_called_once()
+
+
+class TestSelectorLoopIntegration:
+ def test_buffered_reader_with_real_sockets(self):
+ """End-to-end: send lines through real sockets and verify buffered reading."""
+ gen, received = _make_generator()
+ sender, reader = _make_socket_pair()
+ try:
+ sel = selectors.DefaultSelector()
+
+ def on_close(sock):
+ sel.unregister(sock)
+
+ sel.register(reader, selectors.EVENT_READ, make_buffered_socket_reader(gen, on_close))
+
+ sender.sendall(b"first line\nsecond line\n")
+
+ service_selector(sel, timeout=1.0)
+
+ assert b"first line\n" in received
+ assert b"second line\n" in received
+
+ # Close sender, then drain
+ sender.close()
+ sender = None
+
+ service_selector(sel, timeout=0.5)
+
+ sel.close()
+ finally:
+ if sender:
+ sender.close()
+ reader.close()
+
+ def test_raw_forwarder_with_real_sockets(self):
+ """End-to-end: forward raw bytes between real socket pairs."""
+ src_send, src_recv = _make_socket_pair()
+ # Use socketpair for the destination so reads/writes are symmetric
+ dst_write, dst_read = socket.socketpair()
+ try:
+ sel = selectors.DefaultSelector()
+
+ def on_close(sock):
+ sel.unregister(sock)
+
+ sel.register(src_recv, selectors.EVENT_READ, make_raw_forwarder(dst_write, on_close))
+
+ src_send.sendall(b"raw data payload")
+
+ service_selector(sel, timeout=1.0)
+
+ dst_read.setblocking(False)
+ forwarded = dst_read.recv(4096)
+
+ assert forwarded == b"raw data payload"
+
+ sel.close()
+ finally:
+ for s in (src_send, src_recv, dst_write, dst_read):
+ s.close()
+
+ def test_eof_triggers_on_close_with_real_sockets(self):
+ """When the sender closes, the selector callback chain fires on_close."""
+ gen, received = _make_generator()
+ sender, reader = _make_socket_pair()
+ closed_sockets: list[socket.socket] = []
+ try:
+ sel = selectors.DefaultSelector()
+
+ def on_close(sock):
+ sel.unregister(sock)
+ closed_sockets.append(sock)
+
+ sel.register(reader, selectors.EVENT_READ, make_buffered_socket_reader(gen, on_close))
+
+ # Send data then close
+ sender.sendall(b"final\n")
+ service_selector(sel, timeout=1.0)
+ assert received == [b"final\n"]
+
+ sender.close()
+ sender = None
+ service_selector(sel, timeout=0.5)
+
+ # on_close should have been called, and socket closed by service_selector
+ assert len(closed_sockets) == 1
+
+ sel.close()
+ finally:
+ if sender:
+ sender.close()
+ reader.close()
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index f0d8e1a0b65d6..332dd1b537d8f 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -64,7 +64,6 @@
DagRunState,
DagRunType,
PreviousTIResponse,
- TaskInstance,
TaskInstanceState,
)
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType, TaskAlreadyRunningError
@@ -151,6 +150,7 @@
supervise_task,
)
from airflow.sdk.execution_time.task_runner import run
+from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
from tests_common.test_utils.config import conf_vars
@@ -211,13 +211,16 @@ def test_supervise(
"""
Test that the supervisor validates server URL and dry_run parameter combinations correctly.
"""
- ti = TaskInstance(
+ ti = TaskInstanceDTO(
id=uuid7(),
task_id="async",
dag_id="super_basic_deferred_run",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
)
bundle_info = BundleInfo(name="my-bundle", version=None)
@@ -304,13 +307,16 @@ def subprocess_main():
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
+ what=TaskInstanceDTO(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -379,13 +385,16 @@ def subprocess_main():
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
+ what=TaskInstanceDTO(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -476,13 +485,16 @@ def on_kill(self) -> None:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
+ what=TaskInstanceDTO(
id=ti_id,
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
@@ -505,13 +517,16 @@ def subprocess_main():
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
+ what=TaskInstanceDTO(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -540,8 +555,16 @@ def subprocess_main():
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
- id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
+ what=TaskInstanceDTO(
+ id=uuid7(),
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=mock_client,
target=subprocess_main,
@@ -579,13 +602,16 @@ def test_resume_start_date_from_context(self, mocker, make_ti_context, start_dat
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
+ what=TaskInstanceDTO(
id=uuid7(),
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=mock_client,
target=lambda: None,
@@ -622,8 +648,16 @@ def subprocess_main():
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
- id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
+ what=TaskInstanceDTO(
+ id=ti_id,
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=sdk_client.Client(base_url="", dry_run=True, token=""),
target=subprocess_main,
@@ -659,8 +693,16 @@ def _on_child_started(self, *args, **kwargs):
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
- id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
+ what=TaskInstanceDTO(
+ id=ti_id,
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=sdk_client.Client(base_url="", dry_run=True, token=""),
target=subprocess_main,
@@ -675,13 +717,16 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker
time_machine.move_to(instant, tick=False)
dagfile_path = test_dags_dir
- ti = TaskInstance(
+ ti = TaskInstanceDTO(
id=uuid7(),
task_id="hello",
dag_id="super_basic_run",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
)
bundle_info = BundleInfo(name="my-bundle", version=None)
@@ -716,13 +761,16 @@ def test_supervise_handles_deferred_task(
"""
instant = timezone.datetime(2024, 11, 7, 12, 34, 56, 0)
- ti = TaskInstance(
+ ti = TaskInstanceDTO(
id=uuid7(),
task_id="async",
dag_id="super_basic_deferred_run",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
)
# Create a mock client to assert calls to the client
@@ -843,8 +891,16 @@ def handle_request(request: httpx.Request) -> httpx.Response:
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
- what=TaskInstance(
- id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
+ what=TaskInstanceDTO(
+ id=ti_id,
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
@@ -921,8 +977,16 @@ def subprocess_main():
ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
- id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
+ what=TaskInstanceDTO(
+ id=ti_id,
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
@@ -1126,13 +1190,16 @@ def subprocess_main():
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
+ what=TaskInstanceDTO(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -1289,8 +1356,16 @@ def _handler(sig, frame):
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
- id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
+ what=TaskInstanceDTO(
+ id=ti_id,
+ task_id="b",
+ dag_id="c",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
@@ -3304,13 +3379,16 @@ def subprocess_main():
proc = ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
- what=TaskInstance(
+ what=TaskInstanceDTO(
id="4d828a62-a417-4936-a7a6-2b3fabacecab",
task_id="b",
dag_id="c",
run_id="d",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
client=client_with_ti_start,
target=subprocess_main,
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 630aff9094ed1..476db10184713 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -146,6 +146,7 @@
run,
startup,
)
+from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO
from airflow.sdk.execution_time.xcom import XCom
from airflow.sdk.serde import deserialize
from airflow.triggers.base import BaseEventTrigger, BaseTrigger, TriggerEvent
@@ -177,13 +178,16 @@ def execute(self, context):
def test_parse(test_dags_dir: Path, make_ti_context):
"""Test that checks parsing of a basic dag with an un-mocked parse."""
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="a",
dag_id="super_basic",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="super_basic.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -224,13 +228,16 @@ def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context):
mock_dag.task_dict = {"a": mock_task}
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="a",
dag_id="super_basic",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="super_basic.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -284,13 +291,16 @@ def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context):
def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id, expected_error):
"""Check for nice error messages on dag not found."""
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id=task_id,
dag_id=dag_id,
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="super_basic.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -330,13 +340,16 @@ def test_parse_not_found_does_not_reschedule_when_max_attempts_reached(test_dags
and should surface as a hard failure (SystemExit in the task runner process).
"""
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="a",
dag_id="madeup_dag_id",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="super_basic.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -391,13 +404,16 @@ def test_main_sends_reschedule_task_when_startup_reschedules(
mock_comms_instance.socket = None
mock_comms_decoder_cls.__getitem__.return_value.return_value = mock_comms_instance
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="my_task",
dag_id="test_dag",
run_id="test_run",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
context_carrier={},
),
dag_rel_path="",
@@ -484,13 +500,16 @@ def test_task_span_is_child_of_dag_run_span(make_ti_context):
# Step 3: build StartupDetails with ti.context_carrier = ti_carrier.
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="my_task",
dag_id="test_dag",
run_id="test_run",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
context_carrier=ti_carrier,
),
dag_rel_path="",
@@ -552,13 +571,16 @@ def test_task_span_no_parent_when_no_context_carrier(make_ti_context):
provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter))
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="standalone_task",
dag_id="test_dag",
run_id="test_run",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
context_carrier=None,
),
dag_rel_path="",
@@ -593,13 +615,16 @@ def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context):
dag1_path.write_text(textwrap.dedent(dag1_code))
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="a",
dag_id="dag_name",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="path_test.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -1040,13 +1065,16 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm
)
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="templated_task",
dag_id="basic_templated_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
bundle_info=FAKE_BUNDLE,
dag_rel_path="",
@@ -1156,13 +1184,16 @@ def execute(self, context):
instant = timezone.datetime(2024, 12, 3, 10, 0)
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="templated_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,
@@ -1204,13 +1235,16 @@ def execute(self, context):
instant = timezone.datetime(2024, 12, 3, 10, 0)
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="impersonation_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,
@@ -1252,13 +1286,16 @@ def execute(self, context):
instant = timezone.datetime(2024, 12, 3, 10, 0)
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="impersonation_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,
@@ -1292,13 +1329,16 @@ def execute(self, context):
instant = timezone.datetime(2024, 12, 3, 10, 0)
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="impersonation_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,
@@ -1465,8 +1505,16 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch
task_id = "conditional_task"
what = StartupDetails(
- ti=TaskInstance(
- id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1, dag_version_id=uuid7()
+ ti=TaskInstanceDTO(
+ id=uuid7(),
+ task_id=task_id,
+ dag_id=dag_id,
+ run_id="c",
+ try_number=1,
+ dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="dag_parsing_context.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -2061,8 +2109,10 @@ def execute(self, context):
test_task_id = "pull_task"
task = CustomOperator(task_id=test_task_id)
- # In case of the specific map_index or None we should check it is passed to TI
- extra_for_ti = {"map_index": map_indexes} if map_indexes in (1, None) else {}
+        # When a specific map_index is given, check that it is passed through to the TI.
+        # ``None`` is not a valid TaskInstanceDTO.map_index value, but xcom_pull's
+        # behaviour with ``map_indexes=None`` is independent of the TI's own map_index.
+ extra_for_ti = {"map_index": map_indexes} if isinstance(map_indexes, int) else {}
runtime_ti = create_runtime_ti(task=task, **extra_for_ti)
ser_value = BaseXCom.serialize_value(xcom_values)
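The rewritten comment draws a real distinction: ``None`` cannot be stored as the TI's ``map_index``, but it remains a legal ``map_indexes`` argument to ``xcom_pull``, so only integer parametrizations are forwarded to the TI. A self-contained illustration of the forwarding rule:

```python
# Self-contained illustration of the forwarding rule used in the hunk above.
from typing import Any


def extra_fields_for_ti(map_indexes: Any) -> dict[str, int]:
    """Forward map_indexes to the TI only when it is a concrete int."""
    # TaskInstanceDTO.map_index must be an int, so a None parametrization is
    # never set on the TI; it only changes how xcom_pull is called later.
    return {"map_index": map_indexes} if isinstance(map_indexes, int) else {}


assert extra_fields_for_ti(1) == {"map_index": 1}
assert extra_fields_for_ti(None) == {}
```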
@@ -3883,13 +3933,16 @@ def execute(self, context):
task_id="test_task_runner_calls_listeners", do_xcom_push=True, multiple_outputs=True
)
what = StartupDetails(
- ti=TaskInstance(
+ ti=TaskInstanceDTO(
id=uuid7(),
task_id="templated_task",
dag_id="basic_dag",
run_id="c",
try_number=1,
dag_version_id=uuid7(),
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
),
dag_rel_path="",
bundle_info=FAKE_BUNDLE,
@@ -4491,7 +4544,8 @@ class CustomOperator(BaseOperator):
class TestTriggerDagRunOperator:
"""Tests to verify various aspects of TriggerDagRunOperator"""
- @time_machine.travel("2025-01-01 00:00:00", tick=False)
+    # Make the time-travel destination timezone-aware so the frozen instant
+    # does not depend on the host's local timezone.
+ @time_machine.travel(datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), tick=False)
def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms):
"""Test that TriggerDagRunOperator (with default args) sends the correct message to the Supervisor"""
from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator
@@ -4539,7 +4593,7 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms):
(False, TaskInstanceState.FAILED),
],
)
- @time_machine.travel("2025-01-01 00:00:00", tick=False)
+ @time_machine.travel(datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), tick=False)
def test_handle_trigger_dag_run_conflict(
self, skip_when_already_exists, expected_state, create_runtime_ti, mock_supervisor_comms
):
@@ -4583,7 +4637,7 @@ def test_handle_trigger_dag_run_conflict(
([DagRunState.SUCCESS], None, DagRunState.FAILED, DagRunState.FAILED),
],
)
- @time_machine.travel("2025-01-01 00:00:00", tick=False)
+ @time_machine.travel(datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), tick=False)
def test_handle_trigger_dag_run_wait_for_completion(
self,
allowed_states,
@@ -4704,7 +4758,7 @@ def test_handle_trigger_dag_run_deferred(
assert state == intermediate_state
- @time_machine.travel("2025-01-01 00:00:00", tick=False)
+ @time_machine.travel(datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), tick=False)
def test_handle_trigger_dag_run_deferred_with_reset_uses_run_id_only(
self, create_runtime_ti, mock_supervisor_comms
):
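On the ``time_machine.travel`` changes above: ``time_machine`` interprets a naive destination (such as the old ``"2025-01-01 00:00:00"`` string) in the host's local timezone, so pinning ``tzinfo=timezone.utc`` makes the frozen instant machine-independent (assumption: that is the motivation behind the "timezone-aware" comment; the diff itself does not elaborate). A self-contained before/after sketch of the pattern:

```python
# Minimal sketch of the pattern applied above, using the stdlib timezone.
from datetime import datetime, timezone

import time_machine


# Before: naive destination -- the frozen instant depends on the host's local TZ.
# @time_machine.travel("2025-01-01 00:00:00", tick=False)

# After: aware destination -- the same UTC instant on every machine.
@time_machine.travel(datetime(2025, 1, 1, tzinfo=timezone.utc), tick=False)
def test_frozen_instant() -> None:
    # With tick=False the clock is frozen exactly at the destination.
    assert datetime.now(timezone.utc) == datetime(2025, 1, 1, tzinfo=timezone.utc)
```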
diff --git a/uv.lock b/uv.lock
index ddebdf7a493e8..9758c3b7811f5 100644
--- a/uv.lock
+++ b/uv.lock
@@ -158,6 +158,7 @@ apache-airflow-docker-stack = false
members = [
"apache-airflow",
"apache-airflow-breeze",
+ "apache-airflow-coordinators-java",
"apache-airflow-core",
"apache-airflow-ctl",
"apache-airflow-ctl-tests",
@@ -1791,6 +1792,39 @@ requires-dist = [
{ name = "twine", specifier = ">=4.0.2" },
]
+[[package]]
+name = "apache-airflow-coordinators-java"
+version = "0.1.0"
+source = { editable = "sdk/coordinators/java" }
+dependencies = [
+ { name = "apache-airflow-task-sdk" },
+ { name = "pyyaml" },
+]
+
+[package.dev-dependencies]
+dev = [
+ { name = "apache-airflow" },
+ { name = "apache-airflow-devel-common" },
+ { name = "apache-airflow-task-sdk" },
+]
+docs = [
+ { name = "apache-airflow-devel-common", extra = ["docs"] },
+]
+
+[package.metadata]
+requires-dist = [
+ { name = "apache-airflow-task-sdk", editable = "task-sdk" },
+ { name = "pyyaml", specifier = ">=6.0.2" },
+]
+
+[package.metadata.requires-dev]
+dev = [
+ { name = "apache-airflow", editable = "." },
+ { name = "apache-airflow-devel-common", editable = "devel-common" },
+ { name = "apache-airflow-task-sdk", editable = "task-sdk" },
+]
+docs = [{ name = "apache-airflow-devel-common", extras = ["docs"], editable = "devel-common" }]
+
[[package]]
name = "apache-airflow-core"
version = "3.3.0"