diff --git a/.github/ISSUE_TEMPLATE/1-airflow_bug_report.yml b/.github/ISSUE_TEMPLATE/1-airflow_bug_report.yml index 633ef433f4b63..e8ef6d7af228e 100644 --- a/.github/ISSUE_TEMPLATE/1-airflow_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/1-airflow_bug_report.yml @@ -192,6 +192,7 @@ body: - redis - salesforce - samba + - sdk-java - segment - sendgrid - sftp diff --git a/.github/boring-cyborg.yml b/.github/boring-cyborg.yml index 21b44a9fb840f..3f5477f4512c5 100644 --- a/.github/boring-cyborg.yml +++ b/.github/boring-cyborg.yml @@ -189,6 +189,9 @@ labelPRBasedOnFilePath: provider:keycloak: - providers/keycloak/** + provider:sdk-java: + - providers/sdk/java/** + provider:microsoft-azure: - providers/microsoft/azure/** diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 295688d115f4c..6605d28649497 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -463,6 +463,12 @@ repos: language: python pass_filenames: false files: ^dev/registry/registry_tools/types\.py$|^registry/src/_data/types\.json$ + - id: check-task-instance-dto-sync + name: Check BaseTaskInstanceDTO duplicate is in sync between core and task-sdk + entry: ./scripts/ci/prek/check_task_instance_dto_sync.py + language: python + pass_filenames: false + files: ^airflow-core/src/airflow/executors/workloads/task\.py$|^task-sdk/src/airflow/sdk/execution_time/workloads/task\.py$ - id: ruff name: Run 'ruff' for extremely fast Python linting description: "Run 'ruff' for extremely fast Python linting" diff --git a/airflow-core/docs/extra-packages-ref.rst b/airflow-core/docs/extra-packages-ref.rst index 2646b0a7c3079..9fb579c9b08ec 100644 --- a/airflow-core/docs/extra-packages-ref.rst +++ b/airflow-core/docs/extra-packages-ref.rst @@ -178,6 +178,17 @@ all the ``airflow`` packages together - similarly to what happened in Airflow 2. ``airflow-task-sdk`` separately, if you want to install providers, you need to install them separately as ``apache-airflow-providers-*`` distribution packages. +Multi-Language extras +===================== + +These are extras that add dependencies needed for integration with other languages runtimes. Currently we have only Java SDK related extra, but in the future we might add more extras related to other languages runtimes. + ++----------+------------------------------------------+------------------------------------------------------------------+ +| extra | install command | enables | ++==========+==========================================+==================================================================+ +| sdk.java | ``pip install apache-airflow[sdk.java]`` | JavaCoordinator for both dag processing and workload execution. | ++----------+------------------------------------------+------------------------------------------------------------------+ + Apache Software extras ====================== diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index c37989b9b7486..e173080b69062 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1967,6 +1967,21 @@ workers: type: integer example: ~ default: "60" +sdk: + description: Settings for non-Python SDK runtime coordination + options: + queue_to_sdk: + description: | + JSON mapping of queue names to SDK runtime coordinator names. + + When a task's ``language`` field is not set, this mapping is checked + to route the task to a non-Python runtime coordinator based on its + queue. 
This is useful when queues are used as environment or + isolation identifiers (e.g. ``foo``, ``bar``). + version_added: 3.1.7 + type: string + example: '{"foo": "java", "bar": "java", "go-queue": "go"}' + default: ~ api_auth: description: Settings relating to authentication on the Airflow APIs options: diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index 8d497ca7508e3..c2dab6e3fe6d5 100644 --- a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -66,7 +66,7 @@ from airflow.sdk import SecretCache from airflow.sdk.log import init_log_file, logging_processors from airflow.typing_compat import assert_never -from airflow.utils.file import list_py_file_paths, might_contain_dag +from airflow.utils.file import might_contain_dag from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname from airflow.utils.process_utils import ( @@ -88,6 +88,9 @@ from airflow.sdk.api.client import Client +log = logging.getLogger(__name__) + + class DagParsingStat(NamedTuple): """Information on processing progress.""" @@ -158,6 +161,62 @@ def utc_epoch() -> datetime: return result +def discover_dag_file_paths( + directory: str | os.PathLike[str] | None, + bundle_name: str = "", + safe_mode: bool = conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE", fallback=True), +) -> list[str]: + """ + Discover paths of DAG files within a directory. + + Walks ``directory`` (honouring ``.airflowignore``) and returns each file that is + either a Python DAG candidate (``.py`` source or ZIP archive that passes + :func:`~airflow.utils.file.might_contain_dag`) or accepted by a registered coordinator's + :meth:`~airflow.sdk.execution_time.coordinator.BaseCoordinator.can_handle_dag_file` + (e.g. a ``.jar`` for the Java SDK, a self-contained executable for the Go SDK). + + Coordinator handling takes precedence over the generic ZIP heuristic so that, + for example, a ``.jar`` is delegated to its coordinator rather than being + scanned for embedded ``.py`` modules. + + :param directory: Directory to scan, or a single file path. ``None`` returns + an empty list. A single file is returned as-is without filtering. + :param bundle_name: Bundle name forwarded to ``can_handle_dag_file``. + :param safe_mode: Whether to apply the Python DAG heuristic; see + :func:`~airflow.utils.file.might_contain_dag`. + :return: Absolute paths discovered as DAG sources. 
+ """ + if directory is None: + return [] + if os.path.isfile(directory): + return [str(directory)] + if not os.path.isdir(directory): + return [] + + from airflow._shared.module_loading.file_discovery import find_path_from_directory + from airflow.providers_manager import ProvidersManager + + coordinators = ProvidersManager().coordinators + ignore_file_syntax = conf.get_mandatory_value("core", "DAG_IGNORE_FILE_SYNTAX", fallback="glob") + + file_paths: list[str] = [] + for file_path in find_path_from_directory(directory, ".airflowignore", ignore_file_syntax): + path = Path(file_path) + try: + if not path.is_file(): + continue + if path.suffix == ".py": + if might_contain_dag(file_path, safe_mode): + file_paths.append(file_path) + elif any(c.can_handle_dag_file(bundle_name, file_path) for c in coordinators): + file_paths.append(file_path) + elif zipfile.is_zipfile(path) and might_contain_dag(file_path, safe_mode): + file_paths.append(file_path) + except Exception: + log.exception("Error while examining %s", file_path) + return file_paths + + class _StubSelector(selectors.BaseSelector): """ Stub to stand in until the real selector is created. @@ -808,9 +867,11 @@ def _refresh_dag_bundles(self, known_files: dict[str, set[DagFileInfo]]): def _find_files_in_bundle(self, bundle: BaseDagBundle) -> list[Path]: """Get relative paths for dag files from bundle dir.""" - # Build up a list of Python files that could contain DAGs self.log.info("Searching for files in %s at %s", bundle.name, bundle.path) - rel_paths = [Path(x).relative_to(bundle.path) for x in list_py_file_paths(bundle.path)] + rel_paths = [ + Path(x).relative_to(bundle.path) + for x in discover_dag_file_paths(bundle.path, bundle_name=bundle.name) + ] self.log.info("Found %s files for bundle %s", len(rel_paths), bundle.name) return rel_paths @@ -822,7 +883,13 @@ def _get_observed_filelocs(self, present: set[DagFileInfo]) -> set[str]: For regular files this includes the relative file path. For ZIP archives this includes DAG-like inner paths such as ``archive.zip/dag.py``. + + Files claimed by a registered runtime coordinator (e.g. ``.jar``) + are treated as opaque files rather than ZIP archives. 
""" + from airflow.providers_manager import ProvidersManager + + coordinators = ProvidersManager().coordinators def find_zipped_dags(abs_path: os.PathLike) -> Iterator[str]: """Yield absolute paths for DAG-like files inside a ZIP archive.""" @@ -837,7 +904,10 @@ def find_zipped_dags(abs_path: os.PathLike) -> Iterator[str]: observed_filelocs: set[str] = set() for info in present: abs_path = str(info.absolute_path) - if abs_path.endswith(".py") or not zipfile.is_zipfile(abs_path): + handled_by_coordinator = any( + c.can_handle_dag_file(info.bundle_name, abs_path) for c in coordinators + ) + if abs_path.endswith(".py") or handled_by_coordinator or not zipfile.is_zipfile(abs_path): observed_filelocs.add(str(info.rel_path)) else: if TYPE_CHECKING: diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 30ad827ede798..b1fb4d48d9cb4 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -17,6 +17,7 @@ from __future__ import annotations import contextlib +import functools import importlib import logging import os @@ -75,8 +76,6 @@ from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - from socket import socket - from structlog.typing import FilteringBoundLogger from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI @@ -85,6 +84,7 @@ from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.sdk.execution_time.supervisor import SelectorCallback from airflow.typing_compat import Self @@ -552,7 +552,14 @@ def start( # type: ignore[override] ) -> Self: logger = kwargs["logger"] - _pre_import_airflow_modules(os.fspath(path), logger) + # Check if a provider-registered runtime coordinator should handle this file + logger.debug("Checking for provider-registered runtime coordinator entrypoint for file", path=path) + resolved_target = cls._resolve_processor_target(path, bundle_name, bundle_path, logger) + if resolved_target is not None: + target = resolved_target + logger.debug("Resolved provider-registered runtime coordinator entrypoint for file", path=path) + else: + _pre_import_airflow_modules(os.fspath(path), logger) proc: Self = super().start( target=target, @@ -565,6 +572,53 @@ def start( # type: ignore[override] proc._on_child_started(callbacks, path, bundle_path, bundle_name) return proc + @staticmethod + def _resolve_processor_target( + path: str | os.PathLike[str], + bundle_name: str, + bundle_path: Path, + log: FilteringBoundLogger, + ) -> Callable[[], None] | None: + """ + Return the entrypoint of the first provider runtime coordinator that can handle *path*. + + The returned callable is a ``functools.partial`` that binds *path*, *bundle_name* + and *bundle_path* so the supervisor can pass it as a no-arg ``target`` to + ``WatchedSubprocess.start``. 
+ """ + from airflow.providers_manager import ProvidersManager + + for coordinator_cls in ProvidersManager().coordinators: + try: + log.debug( + "Checking runtime coordinator %s for file %s", + coordinator_cls, + path, + ) + if coordinator_cls.can_handle_dag_file(bundle_name, path): + log.debug( + "Using runtime coordinator %s for file %s", + coordinator_cls, + path, + ) + return functools.partial( + coordinator_cls.run_dag_parsing, + path=os.fspath(path), + bundle_name=bundle_name, + bundle_path=os.fspath(bundle_path), + ) + log.debug( + "Runtime coordinator %s cannot handle file %s with bundle name %s", + coordinator_cls, + path, + bundle_name, + ) + except Exception: + log.warning("Failed to check runtime coordinator %s", coordinator_cls, exc_info=True) + + log.debug("No runtime coordinator found for file %s, using default processor", path) + return None + def _on_child_started( self, callbacks: list[CallbackRequest], @@ -590,7 +644,7 @@ def _get_target_loggers(self) -> tuple[FilteringBoundLogger, ...]: def _create_log_forwarder( self, loggers: tuple[FilteringBoundLogger, ...], name: str, log_level: int = logging.INFO - ) -> Callable[[socket], bool]: + ) -> SelectorCallback: return super()._create_log_forwarder(loggers, name.replace("task.", "dag_processor.", 1), log_level) def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None: diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 9c9487c5377bc..bced160a5f4c7 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -650,11 +650,10 @@ def run_workload( if isinstance(workload, ExecuteTask): from airflow.sdk.execution_time.supervisor import supervise_task + from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO as SDKTaskInstanceDTO - # workload.ti is a TaskInstanceDTO which duck-types as TaskInstance. - # TODO: Create a protocol for this. return supervise_task( - ti=workload.ti, # type: ignore[arg-type] + ti=SDKTaskInstanceDTO.model_validate(workload.ti, from_attributes=True), bundle_info=workload.bundle_info, dag_rel_path=workload.dag_rel_path, token=workload.token, diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py index d05affe433096..9af3f33c10efd 100644 --- a/airflow-core/src/airflow/executors/workloads/task.py +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -33,8 +33,14 @@ from airflow.models.taskinstancekey import TaskInstanceKey -class TaskInstanceDTO(BaseModel): - """Schema for TaskInstance with minimal required fields needed for Executors and Task SDK.""" +class BaseTaskInstanceDTO(BaseModel): + """ + Base schema for TaskInstance with the minimal fields shared by Executors and the Task SDK. + + This definition is duplicated in :mod:`airflow.sdk.execution_time.workloads.task` + and the two are kept in sync by the ``check-task-instance-dto-sync`` prek + hook. Update both files together. 
+ """ id: uuid.UUID dag_version_id: uuid.UUID @@ -48,11 +54,16 @@ class TaskInstanceDTO(BaseModel): queue: str priority_weight: int executor_config: dict | None = Field(default=None, exclude=True) - external_executor_id: str | None = Field(default=None, exclude=True) parent_context_carrier: dict | None = None context_carrier: dict | None = None + +class TaskInstanceDTO(BaseTaskInstanceDTO): + """TaskInstanceDTO with executor-specific ``external_executor_id`` field and ``key`` property.""" + + external_executor_id: str | None = Field(default=None, exclude=True) + # TODO: Task-SDK: Can we replace TaskInstanceKey with just the uuid across the codebase? @property def key(self) -> TaskInstanceKey: diff --git a/airflow-core/src/airflow/models/dagcode.py b/airflow-core/src/airflow/models/dagcode.py index 60ee91c8b59b5..528859f4cd311 100644 --- a/airflow-core/src/airflow/models/dagcode.py +++ b/airflow-core/src/airflow/models/dagcode.py @@ -119,6 +119,16 @@ def code(cls, dag_id, session: Session = NEW_SESSION) -> str: @staticmethod def get_code_from_file(fileloc): + # Try from runtime coordinator first (classes are pre-loaded by ProvidersManager) + from airflow.providers_manager import ProvidersManager + + for coordinator_cls in ProvidersManager().coordinators: + # TODO: Perhaps the `can_handle_dag_file` interface should just accept `path` only? + # Or maybe we can have different granularity for this. that 1 with bundle + path, another with just path + if coordinator_cls.can_handle_dag_file("", fileloc): + return coordinator_cls.get_code_from_file(fileloc) + + # Then fallback to python native try: with open_maybe_zipped(fileloc, "r") as f: code = f.read() diff --git a/airflow-core/src/airflow/provider.yaml.schema.json b/airflow-core/src/airflow/provider.yaml.schema.json index 5714b8db658c5..1c41b906289cf 100644 --- a/airflow-core/src/airflow/provider.yaml.schema.json +++ b/airflow-core/src/airflow/provider.yaml.schema.json @@ -624,6 +624,13 @@ } } }, + "coordinators": { + "type": "array", + "description": "Runtime Coordinator class names (BaseCoordinator subclasses)", + "items": { + "type": "string" + } + }, "source-date-epoch": { "type": "integer", "description": "Source date epoch - seconds since epoch (gmtime) when the release documentation was prepared. 
Used to generate reproducible package builds with flint.", diff --git a/airflow-core/src/airflow/provider_info.schema.json b/airflow-core/src/airflow/provider_info.schema.json index 86fc726a05168..92601fc58af74 100644 --- a/airflow-core/src/airflow/provider_info.schema.json +++ b/airflow-core/src/airflow/provider_info.schema.json @@ -446,6 +446,13 @@ "type": "string" } } + }, + "coordinators": { + "type": "array", + "description": "Runtime Coordinator class names (BaseCoordinator subclasses)", + "items": { + "type": "string" + } } }, "definitions": { diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py index 6fefcbc39b06d..8945589b4b046 100644 --- a/airflow-core/src/airflow/providers_manager.py +++ b/airflow-core/src/airflow/providers_manager.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from airflow.cli.cli_config import CLICommand + from airflow.sdk.execution_time.coordinator import BaseCoordinator log = logging.getLogger(__name__) @@ -448,6 +449,7 @@ def __init__(self): ) # Set of plugins contained in providers self._plugins_set: set[PluginInfo] = set() + self._coordinators: list[type[BaseCoordinator]] = [] self._init_airflow_core_hooks() self._runtime_manager = None @@ -625,6 +627,12 @@ def initialize_providers_configuration(self): self.initialize_providers_list() self._discover_config() + @provider_info_cache("coordinators") + def initialize_providers_coordinators(self): + """Lazy initialization of providers runtime coordinators.""" + self.initialize_providers_list() + self._discover_coordinators() + @provider_info_cache("plugins") def initialize_providers_plugins(self): self.initialize_providers_list() @@ -1280,6 +1288,19 @@ def _discover_config(self) -> None: if provider.data.get("config"): self._provider_configs[provider_package] = provider.data.get("config") # type: ignore[assignment] + def _discover_coordinators(self) -> None: + """Retrieve and pre-load all coordinators defined in the providers.""" + seen: set[str] = set() + for provider_package, provider in self._provider_dict.items(): + for coordinator_class_path in provider.data.get("coordinators", []): + if coordinator_class_path in seen: + continue + coordinator_cls = _correctness_check(provider_package, coordinator_class_path, provider) + if coordinator_cls: + seen.add(coordinator_class_path) + self._coordinators.append(coordinator_cls) + self._coordinators = sorted(self._coordinators, key=lambda c: c.__qualname__) + def _discover_plugins(self) -> None: """Retrieve all plugins defined in the providers.""" for provider_package, provider in self._provider_dict.items(): @@ -1477,6 +1498,12 @@ def db_managers(self) -> list[str]: self.initialize_providers_db_managers() return sorted(self._db_manager_class_name_set) + @property + def coordinators(self) -> list[type[BaseCoordinator]]: + """Returns pre-loaded coordinator classes available in providers.""" + self.initialize_providers_coordinators() + return self._coordinators + @property def filesystem_module_names(self) -> list[str]: self.initialize_providers_filesystems() @@ -1548,6 +1575,7 @@ def _cleanup(self): self._trigger_info_set.clear() self._notification_info_set.clear() self._plugins_set.clear() + self._coordinators.clear() self._cli_command_functions_set.clear() self._cli_command_provider_name_set.clear() diff --git a/airflow-core/src/airflow/serialization/definitions/baseoperator.py b/airflow-core/src/airflow/serialization/definitions/baseoperator.py index 6bafc5891235a..9eaf9cc3ed906 100644 --- 
a/airflow-core/src/airflow/serialization/definitions/baseoperator.py +++ b/airflow-core/src/airflow/serialization/definitions/baseoperator.py @@ -195,6 +195,7 @@ def get_serialized_fields(cls): "ignore_first_depends_on_past", "inlets", "is_setup", + "sdk", "is_teardown", "map_index_template", "max_active_tis_per_dag", diff --git a/airflow-core/src/airflow/utils/file.py b/airflow-core/src/airflow/utils/file.py index c614cfff0ad96..25e191cdccfd8 100644 --- a/airflow-core/src/airflow/utils/file.py +++ b/airflow-core/src/airflow/utils/file.py @@ -19,7 +19,6 @@ import ast import hashlib -import logging import os import re import zipfile @@ -30,8 +29,6 @@ from airflow.configuration import conf -log = logging.getLogger(__name__) - MODIFIED_DAG_MODULE_NAME = "unusual_prefix_{path_hash}_{module_name}" @@ -74,49 +71,6 @@ def open_maybe_zipped(fileloc, mode="r"): return open(fileloc, mode=mode) -def list_py_file_paths( - directory: str | os.PathLike[str] | None, - safe_mode: bool = conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE", fallback=True), -) -> list[str]: - """ - Traverse a directory and look for Python files. - - :param directory: the directory to traverse - :param safe_mode: whether to use a heuristic to determine whether a file - contains Airflow DAG definitions. If not provided, use the - core.DAG_DISCOVERY_SAFE_MODE configuration setting. If not set, default - to safe. - :return: a list of paths to Python files in the specified directory - """ - file_paths: list[str] = [] - if directory is None: - file_paths = [] - elif os.path.isfile(directory): - file_paths = [str(directory)] - elif os.path.isdir(directory): - file_paths.extend(find_dag_file_paths(directory, safe_mode)) - return file_paths - - -def find_dag_file_paths(directory: str | os.PathLike[str], safe_mode: bool) -> list[str]: - """Find file paths of all DAG files.""" - from airflow._shared.module_loading.file_discovery import find_path_from_directory - - file_paths = [] - ignore_file_syntax = conf.get_mandatory_value("core", "DAG_IGNORE_FILE_SYNTAX", fallback="glob") - - for file_path in find_path_from_directory(directory, ".airflowignore", ignore_file_syntax): - path = Path(file_path) - try: - if path.is_file() and (path.suffix == ".py" or zipfile.is_zipfile(path)): - if might_contain_dag(file_path, safe_mode): - file_paths.append(file_path) - except Exception: - log.exception("Error while examining %s", file_path) - - return file_paths - - COMMENT_PATTERN = re.compile(r"\s*#.*") diff --git a/airflow-core/tests/unit/always/test_providers_manager.py b/airflow-core/tests/unit/always/test_providers_manager.py index afa473e80a4f0..b13930c98d1c2 100644 --- a/airflow-core/tests/unit/always/test_providers_manager.py +++ b/airflow-core/tests/unit/always/test_providers_manager.py @@ -258,6 +258,34 @@ def test_dialects(self): assert len(dialect_class_names) == 3 assert dialect_class_names == ["default", "mssql", "postgresql"] + @patch("airflow.providers_manager.import_string") + def test_coordinators(self, mock_import_string): + class ACoordinator: + pass + + class ZCoordinator: + pass + + mock_import_string.side_effect = lambda path: { + "airflow.providers.sdk.java.coordinator.ACoordinator": ACoordinator, + "airflow.providers.sdk.java.coordinator.ZCoordinator": ZCoordinator, + }[path] + providers_manager = ProvidersManager() + providers_manager._provider_dict = LazyDictWithCache() + providers_manager._provider_dict["apache-airflow-providers-sdk-java"] = ProviderInfo( + version="0.0.1", + data={ + "coordinators": [ + 
"airflow.providers.sdk.java.coordinator.ZCoordinator", + "airflow.providers.sdk.java.coordinator.ACoordinator", + "airflow.providers.sdk.java.coordinator.ZCoordinator", + ] + }, + ) + + with patch.object(providers_manager, "initialize_providers_list"): + assert providers_manager.coordinators == [ACoordinator, ZCoordinator] + class TestWithoutCheckProviderManager: @pytest.fixture(autouse=True) diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index c10e73473bfef..281a378b26c1e 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -31,6 +31,7 @@ from collections import defaultdict, deque from datetime import datetime, timedelta from pathlib import Path +from pprint import pformat from socket import socket, socketpair from unittest import mock from unittest.mock import MagicMock @@ -41,6 +42,7 @@ from sqlalchemy import func, select from uuid6 import uuid7 +from airflow._shared.module_loading import find_path_from_directory from airflow._shared.timezones import timezone from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.dag_processing.bundles.base import BaseDagBundle @@ -51,6 +53,7 @@ DagFileInfo, DagFileProcessorManager, DagFileStat, + discover_dag_file_paths, ) from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess from airflow.models import DagModel, DbCallbackRequest @@ -60,6 +63,7 @@ from airflow.models.dagcode import DagCode from airflow.models.serialized_dag import SerializedDagModel from airflow.models.team import Team +from airflow.utils import file as file_utils from airflow.utils.net import get_hostname from airflow.utils.session import create_session @@ -85,6 +89,11 @@ DEFAULT_DATE = timezone.datetime(2016, 1, 1) +def might_contain_dag(file_path: str, zip_file: zipfile.ZipFile | None = None): + """Custom callable injected via conf_vars in TestDagFileDiscovery.test_might_contain_dag.""" + return False + + def _get_file_infos(files: list[str | Path]) -> list[DagFileInfo]: return [DagFileInfo(bundle_name="testing", bundle_path=TEST_DAGS_FOLDER, rel_path=Path(f)) for f in files] @@ -114,6 +123,34 @@ def encode_mtime_in_filename(val): return out +class _FakeCoordinator: + """Test double recording every can_handle_dag_file call and matching by extension.""" + + file_extension: str = ".fakeext" + invocations: list[tuple[str, str]] = [] + + @classmethod + def reset(cls) -> None: + cls.invocations = [] + + @classmethod + def can_handle_dag_file(cls, bundle_name: str, path) -> bool: + cls.invocations.append((bundle_name, str(path))) + return str(path).endswith(cls.file_extension) + + +@pytest.fixture +def fake_coordinator(): + """Inject a fake coordinator into ProvidersManager.coordinators for the duration of a test.""" + _FakeCoordinator.reset() + with mock.patch( + "airflow.providers_manager.ProvidersManager.coordinators", + new_callable=mock.PropertyMock, + return_value=[_FakeCoordinator], + ): + yield _FakeCoordinator + + def _create_zip_bundle_with_valid_and_broken_dags(zip_path: Path) -> None: with zipfile.ZipFile(zip_path, "w") as zf: zf.writestr( @@ -285,6 +322,43 @@ def test_get_observed_filelocs_expands_zip_inner_paths(self, tmp_path): "test_zip.zip/broken_dag.py", } + def test_get_observed_filelocs_treats_coordinator_handled_zip_as_opaque(self, tmp_path, fake_coordinator): + """A coordinator-claimed file that happens to be a ZIP must NOT be expanded into inner 
paths.""" + # Coordinator handles ".fakeext"; the file is a real ZIP archive so + # without the coordinator check it would be enumerated like a dag-zip. + bundle_file = tmp_path / "bundle.fakeext" + _create_zip_bundle_with_valid_and_broken_dags(bundle_file) + + manager = DagFileProcessorManager(max_runs=1) + observed_filelocs = manager._get_observed_filelocs( + { + DagFileInfo( + bundle_name="testing", + rel_path=Path("bundle.fakeext"), + bundle_path=tmp_path, + ) + } + ) + + assert observed_filelocs == {"bundle.fakeext"} + + def test_get_observed_filelocs_forwards_bundle_name_to_coordinator(self, tmp_path, fake_coordinator): + bundle_file = tmp_path / "bundle.fakeext" + bundle_file.write_bytes(b"opaque payload") + + manager = DagFileProcessorManager(max_runs=1) + manager._get_observed_filelocs( + { + DagFileInfo( + bundle_name="my_bundle", + rel_path=Path("bundle.fakeext"), + bundle_path=tmp_path, + ) + } + ) + + assert fake_coordinator.invocations == [("my_bundle", str(bundle_file))] + @pytest.mark.usefixtures("clear_parse_import_errors") def test_refresh_dag_bundles_keeps_zip_inner_file_errors(self, session, tmp_path, configure_dag_bundles): bundle_path = tmp_path / "bundleone" @@ -2462,3 +2536,162 @@ def test_refresh_dag_bundles_update_bundle_state_failure_still_scans_files(self) # _bundle_versions must NOT advance — DB still holds the old version, so the next # iteration will see a version mismatch and re-refresh rather than skip incorrectly assert "mock_bundle" not in manager._bundle_versions + + +class TestDagFileDiscovery: + def test_find_path_from_directory_regex_ignore(self): + should_ignore = [ + "test_invalid_cron.py", + "test_invalid_param.py", + "test_ignore_this.py", + ] + files = find_path_from_directory(TEST_DAGS_FOLDER, ".airflowignore") + + assert files + assert all(os.path.basename(file) not in should_ignore for file in files) + + def test_find_path_from_directory_glob_ignore(self): + should_ignore = { + "should_ignore_this.py", + "test_explicit_ignore.py", + "test_invalid_cron.py", + "test_invalid_param.py", + "test_ignore_this.py", + "test_prev_dagrun_dep.py", + "test_nested_dag.py", + ".airflowignore", + } + should_not_ignore = { + "test_on_kill.py", + "test_negate_ignore.py", + "test_dont_ignore_this.py", + "test_nested_negate_ignore.py", + "test_explicit_dont_ignore.py", + } + actual_files = list(find_path_from_directory(TEST_DAGS_FOLDER, ".airflowignore_glob", "glob")) + + assert actual_files + assert all(os.path.basename(file) not in should_ignore for file in actual_files) + actual_included_filenames = { + os.path.basename(f) for f in actual_files if os.path.basename(f) in should_not_ignore + } + assert actual_included_filenames == should_not_ignore, ( + f"actual_included_filenames: {pformat(actual_included_filenames)}\nexpected_included_filenames: {pformat(should_not_ignore)}" + ) + + def test_might_contain_dag_with_default_callable(self): + file_path_with_dag = os.path.join(TEST_DAGS_FOLDER, "test_scheduler_dags.py") + + assert file_utils.might_contain_dag(file_path=file_path_with_dag, safe_mode=True) + + @conf_vars({("core", "might_contain_dag_callable"): "unit.dag_processing.test_manager.might_contain_dag"}) + def test_might_contain_dag(self): + """Test might_contain_dag_callable""" + file_path_with_dag = os.path.join(TEST_DAGS_FOLDER, "test_scheduler_dags.py") + + # There is a DAG defined in the file_path_with_dag, however, the might_contain_dag_callable + # returns False no matter what, which is used to test might_contain_dag_callable actually + # overrides 
the default function + assert not file_utils.might_contain_dag(file_path=file_path_with_dag, safe_mode=True) + + # With safe_mode is False, the user defined callable won't be invoked + assert file_utils.might_contain_dag(file_path=file_path_with_dag, safe_mode=False) + + def test_get_modules(self): + file_path = os.path.join(TEST_DAGS_FOLDER, "test_imports.py") + + modules = list(file_utils.iter_airflow_imports(file_path)) + + assert len(modules) == 4 + assert "airflow.utils" in modules + assert "airflow.decorators" in modules + assert "airflow.models" in modules + assert "airflow.sensors" in modules + # this one is a local import, we don't want it. + assert "airflow.local_import" not in modules + # this one is in a comment, we don't want it + assert "airflow.in_comment" not in modules + # we don't want imports under conditions + assert "airflow.if_branch" not in modules + assert "airflow.else_branch" not in modules + + def test_get_modules_from_invalid_file(self): + file_path = os.path.join(TEST_DAGS_FOLDER, "README.md") # just getting a non-python file + + # should not error + modules = list(file_utils.iter_airflow_imports(file_path)) + + assert len(modules) == 0 + + def test_discover_dag_file_paths(self, test_zip_path): + expected_files = set() + # No_dags is empty, _invalid_ is ignored by .airflowignore + ignored_files = { + "no_dags.py", + "should_ignore_this.py", + "test_explicit_ignore.py", + "test_invalid_cron.py", + "test_invalid_dup_task.py", + "test_ignore_this.py", + "test_invalid_param.py", + "test_invalid_param2.py", + "test_invalid_param3.py", + "test_invalid_param4.py", + "test_nested_dag.py", + "test_imports.py", + "test_nested_negate_ignore.py", + "file_no_airflow_dag.py", # no_dag test case in test_zip folder + "test.py", # no_dag test case in test_zip_module folder + "__init__.py", + } + for root, _, files in os.walk(TEST_DAGS_FOLDER): + for file_name in files: + if file_name.endswith((".py", ".zip")): + if file_name not in ignored_files: + expected_files.add(f"{root}/{file_name}") + detected_files = set(discover_dag_file_paths(str(TEST_DAGS_FOLDER))) + assert detected_files == expected_files, ( + f"Detected files mismatched expected files:\ndetected_files: {pformat(detected_files)}\nexpected_files: {pformat(expected_files)}" + ) + + def test_discover_returns_empty_for_none(self): + assert discover_dag_file_paths(None) == [] + + def test_discover_returns_empty_for_missing_path(self, tmp_path): + assert discover_dag_file_paths(tmp_path / "does_not_exist") == [] + + def test_discover_returns_single_file_as_is(self, tmp_path): + single = tmp_path / "anything.bin" + single.write_bytes(b"opaque") + assert discover_dag_file_paths(single) == [str(single)] + + def test_discover_includes_coordinator_handled_files(self, tmp_path, fake_coordinator): + coord_file = tmp_path / "bundle.fakeext" + coord_file.write_bytes(b"opaque payload") + py_file = tmp_path / "dag.py" + py_file.write_text("from airflow.sdk import DAG\nDAG('d')") + + assert set(discover_dag_file_paths(tmp_path)) == {str(coord_file), str(py_file)} + + def test_discover_coordinator_takes_precedence_over_zip_heuristic(self, tmp_path, fake_coordinator): + """A coordinator-claimed file that is also a ZIP must NOT also be included via the generic ZIP path.""" + coord_zip = tmp_path / "bundle.fakeext" + _create_zip_bundle_with_valid_and_broken_dags(coord_zip) + + # File appears exactly once: claimed by coordinator, generic zip branch skipped. 
+ assert discover_dag_file_paths(tmp_path) == [str(coord_zip)] + + def test_discover_forwards_bundle_name_to_coordinator(self, tmp_path, fake_coordinator): + coord_file = tmp_path / "bundle.fakeext" + coord_file.write_bytes(b"opaque payload") + + discover_dag_file_paths(tmp_path, bundle_name="my_bundle") + + # Only one non-.py file, so exactly one coordinator invocation, with the bundle name. + assert fake_coordinator.invocations == [("my_bundle", str(coord_file))] + + def test_discover_skips_non_matching_unknown_file(self, tmp_path, fake_coordinator): + """A file no coordinator claims and that isn't .py / a ZIP must not appear in results.""" + (tmp_path / "random.bin").write_bytes(b"unknown payload") + + assert discover_dag_file_paths(tmp_path) == [] diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index b34ab12dd4aef..5666ad9dd3212 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -41,6 +41,7 @@ from airflow._shared.timezones.timezone import datetime as datetime_tz from airflow.configuration import conf from airflow.dag_processing.dagbag import BundleDagBag, DagBag +from airflow.dag_processing.manager import discover_dag_file_paths from airflow.exceptions import AirflowException from airflow.models.asset import ( AssetAliasModel, @@ -91,7 +92,6 @@ NullTimetable, OnceTimetable, ) -from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -1135,7 +1135,7 @@ def test_dag_is_deactivated_upon_dagfile_deletion(self, dag_maker): DagModel.deactivate_deleted_dags( bundle_name=orm_dag.bundle_name, - rel_filelocs=list_py_file_paths(settings.DAGS_FOLDER), + rel_filelocs=discover_dag_file_paths(settings.DAGS_FOLDER), ) orm_dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id)) diff --git a/airflow-core/tests/unit/utils/test_file.py b/airflow-core/tests/unit/utils/test_file.py index cc55c1ac0632e..a1e14d22a45f0 100644 --- a/airflow-core/tests/unit/utils/test_file.py +++ b/airflow-core/tests/unit/utils/test_file.py @@ -18,29 +18,18 @@ from __future__ import annotations import os -import zipfile -from pprint import pformat from unittest import mock import pytest -from airflow._shared.module_loading import find_path_from_directory from airflow.utils import file as file_utils from airflow.utils.file import ( correct_maybe_zipped, - list_py_file_paths, open_maybe_zipped, ) -from tests_common.test_utils.config import conf_vars from unit.models import TEST_DAGS_FOLDER -TEST_DAG_FOLDER = os.environ["AIRFLOW__CORE__DAGS_FOLDER"] - - -def might_contain_dag(file_path: str, zip_file: zipfile.ZipFile | None = None): - return False - class TestCorrectMaybeZipped: @mock.patch("zipfile.is_zipfile") @@ -95,124 +84,6 @@ def test_open_maybe_zipped_archive(self, test_zip_path): assert isinstance(content, str) -class TestListPyFilesPath: - def test_find_path_from_directory_regex_ignore(self): - should_ignore = [ - "test_invalid_cron.py", - "test_invalid_param.py", - "test_ignore_this.py", - ] - files = find_path_from_directory(TEST_DAGS_FOLDER, ".airflowignore") - - assert files - assert all(os.path.basename(file) not in should_ignore for file in files) - - def test_find_path_from_directory_glob_ignore(self): - should_ignore = { - "should_ignore_this.py", - "test_explicit_ignore.py", - "test_invalid_cron.py", - 
"test_invalid_param.py", - "test_ignore_this.py", - "test_prev_dagrun_dep.py", - "test_nested_dag.py", - ".airflowignore", - } - should_not_ignore = { - "test_on_kill.py", - "test_negate_ignore.py", - "test_dont_ignore_this.py", - "test_nested_negate_ignore.py", - "test_explicit_dont_ignore.py", - } - actual_files = list(find_path_from_directory(TEST_DAGS_FOLDER, ".airflowignore_glob", "glob")) - - assert actual_files - assert all(os.path.basename(file) not in should_ignore for file in actual_files) - actual_included_filenames = set( - [os.path.basename(f) for f in actual_files if os.path.basename(f) in should_not_ignore] - ) - assert actual_included_filenames == should_not_ignore, ( - f"actual_included_filenames: {pformat(actual_included_filenames)}\nexpected_included_filenames: {pformat(should_not_ignore)}" - ) - - def test_might_contain_dag_with_default_callable(self): - file_path_with_dag = os.path.join(TEST_DAGS_FOLDER, "test_scheduler_dags.py") - - assert file_utils.might_contain_dag(file_path=file_path_with_dag, safe_mode=True) - - @conf_vars({("core", "might_contain_dag_callable"): "unit.utils.test_file.might_contain_dag"}) - def test_might_contain_dag(self): - """Test might_contain_dag_callable""" - file_path_with_dag = os.path.join(TEST_DAGS_FOLDER, "test_scheduler_dags.py") - - # There is a DAG defined in the file_path_with_dag, however, the might_contain_dag_callable - # returns False no matter what, which is used to test might_contain_dag_callable actually - # overrides the default function - assert not file_utils.might_contain_dag(file_path=file_path_with_dag, safe_mode=True) - - # With safe_mode is False, the user defined callable won't be invoked - assert file_utils.might_contain_dag(file_path=file_path_with_dag, safe_mode=False) - - def test_get_modules(self): - file_path = os.path.join(TEST_DAGS_FOLDER, "test_imports.py") - - modules = list(file_utils.iter_airflow_imports(file_path)) - - assert len(modules) == 4 - assert "airflow.utils" in modules - assert "airflow.decorators" in modules - assert "airflow.models" in modules - assert "airflow.sensors" in modules - # this one is a local import, we don't want it. 
- assert "airflow.local_import" not in modules - # this one is in a comment, we don't want it - assert "airflow.in_comment" not in modules - # we don't want imports under conditions - assert "airflow.if_branch" not in modules - assert "airflow.else_branch" not in modules - - def test_get_modules_from_invalid_file(self): - file_path = os.path.join(TEST_DAGS_FOLDER, "README.md") # just getting a non-python file - - # should not error - modules = list(file_utils.iter_airflow_imports(file_path)) - - assert len(modules) == 0 - - def test_list_py_file_paths(self, test_zip_path): - detected_files = set() - expected_files = set() - # No_dags is empty, _invalid_ is ignored by .airflowignore - ignored_files = { - "no_dags.py", - "should_ignore_this.py", - "test_explicit_ignore.py", - "test_invalid_cron.py", - "test_invalid_dup_task.py", - "test_ignore_this.py", - "test_invalid_param.py", - "test_invalid_param2.py", - "test_invalid_param3.py", - "test_invalid_param4.py", - "test_nested_dag.py", - "test_imports.py", - "test_nested_negate_ignore.py", - "file_no_airflow_dag.py", # no_dag test case in test_zip folder - "test.py", # no_dag test case in test_zip_module folder - "__init__.py", - } - for root, _, files in os.walk(TEST_DAG_FOLDER): - for file_name in files: - if file_name.endswith((".py", ".zip")): - if file_name not in ignored_files: - expected_files.add(f"{root}/{file_name}") - detected_files = set(list_py_file_paths(TEST_DAG_FOLDER)) - assert detected_files == expected_files, ( - f"Detected files mismatched expected files:\ndetected_files: {pformat(detected_files)}\nexpected_files: {pformat(expected_files)}" - ) - - @pytest.mark.parametrize( ("edge_filename", "expected_modification"), [ diff --git a/dev/breeze/doc/images/output_build-docs.svg b/dev/breeze/doc/images/output_build-docs.svg index 1858bbb097e91..2a0812c05163e 100644 --- a/dev/breeze/doc/images/output_build-docs.svg +++ b/dev/breeze/doc/images/output_build-docs.svg @@ -240,8 +240,8 @@ hashicorp | helm-chart | http | imap | influxdb | informatica | jdbc | jenkins | keycloak | microsoft.azure |        microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai | openfaas | openlineage |  opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres | presto | qdrant | redis |    -salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh | standard |    -tableau | task-sdk | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk]...             +salesforce | samba | sdk.java | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh |    +standard | tableau | task-sdk | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk]...  Build documents. 
diff --git a/dev/breeze/doc/images/output_build-docs.txt b/dev/breeze/doc/images/output_build-docs.txt index 54d8d4e3f39bb..247bee9e56ff6 100644 --- a/dev/breeze/doc/images/output_build-docs.txt +++ b/dev/breeze/doc/images/output_build-docs.txt @@ -1 +1 @@ -c5f2067ec852773089ed0ca7b8d1d533 +b4c249b4d1f7605a443774262109694a diff --git a/dev/breeze/doc/images/output_release-management_add-back-references.svg b/dev/breeze/doc/images/output_release-management_add-back-references.svg index f17f7f47ed43b..37e9086660253 100644 --- a/dev/breeze/doc/images/output_release-management_add-back-references.svg +++ b/dev/breeze/doc/images/output_release-management_add-back-references.svg @@ -155,8 +155,8 @@ hashicorp | helm-chart | http | imap | influxdb | informatica | jdbc | jenkins | keycloak | microsoft.azure |        microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai | openfaas | openlineage |  opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres | presto | qdrant | redis |    -salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh | standard |    -tableau | task-sdk | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk]...             +salesforce | samba | sdk.java | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh |    +standard | tableau | task-sdk | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk]...  Command to add back references for documentation to make it backward compatible. diff --git a/dev/breeze/doc/images/output_release-management_add-back-references.txt b/dev/breeze/doc/images/output_release-management_add-back-references.txt index ffc7eeea6018b..a43ec033fc2a6 100644 --- a/dev/breeze/doc/images/output_release-management_add-back-references.txt +++ b/dev/breeze/doc/images/output_release-management_add-back-references.txt @@ -1 +1 @@ -3df401aef0085547b08fe896a9a65381 +a44de0a6fcf0ad832e0b2a73a883f0a0 diff --git a/dev/breeze/doc/images/output_release-management_generate-issue-content-providers.svg b/dev/breeze/doc/images/output_release-management_generate-issue-content-providers.svg index 8fe24cdf434e6..6566b6c97716f 100644 --- a/dev/breeze/doc/images/output_release-management_generate-issue-content-providers.svg +++ b/dev/breeze/doc/images/output_release-management_generate-issue-content-providers.svg @@ -149,9 +149,9 @@ github | google | grpc | hashicorp | http | imap | influxdb | informatica | jdbc | jenkins | keycloak |                microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai |         openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres |     -presto | qdrant | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake |    -sqlite | ssh | standard | tableau | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb |          -zendesk]...                                                                                                            +presto | qdrant | redis | salesforce | samba | sdk.java | segment | sendgrid | sftp | singularity | slack | smtp |     +snowflake | sqlite | ssh | standard | tableau | telegram | teradata | trino | vertica | vespa | weaviate | yandex |    +ydb | zendesk]...                                                                                                
      Generates content for issue to test the release. diff --git a/dev/breeze/doc/images/output_release-management_generate-issue-content-providers.txt b/dev/breeze/doc/images/output_release-management_generate-issue-content-providers.txt index c6189be26338f..0c327de82828f 100644 --- a/dev/breeze/doc/images/output_release-management_generate-issue-content-providers.txt +++ b/dev/breeze/doc/images/output_release-management_generate-issue-content-providers.txt @@ -1 +1 @@ -a85c889b710aa347eb6c47fc36b11720 +ee99c790838efb1d5e5a3b06e6c49846 diff --git a/dev/breeze/doc/images/output_release-management_generate-providers-metadata.svg b/dev/breeze/doc/images/output_release-management_generate-providers-metadata.svg index 867b9fedc0357..742e316f5a754 100644 --- a/dev/breeze/doc/images/output_release-management_generate-providers-metadata.svg +++ b/dev/breeze/doc/images/output_release-management_generate-providers-metadata.svg @@ -1,4 +1,4 @@ - + keycloak | microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql  | neo4j | odbc | openai | openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty |  papermill | pgvector | pinecone | postgres | presto | qdrant | redis | salesforce | samba |  -segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh | standard |  -tableau | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk) ---provider-versionProvider version to generate metadata for. Only used when --provider-id is specified. Limits     -running metadata generation to only this version of the provider. (TEXT) -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -╭─ Common options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮ ---dry-run-DIf dry-run is set, commands are only printed, not executed. ---verbose-vPrint verbose information about performed steps. ---help   -hShow this message and exit. -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +sdk.java | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh |  +standard | tableau | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb |  +zendesk) +--provider-versionProvider version to generate metadata for. Only used when --provider-id is specified. Limits     +running metadata generation to only this version of the provider. (TEXT) +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ +╭─ Common options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮ +--dry-run-DIf dry-run is set, commands are only printed, not executed. +--verbose-vPrint verbose information about performed steps. +--help   -hShow this message and exit. 
+╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ diff --git a/dev/breeze/doc/images/output_release-management_generate-providers-metadata.txt b/dev/breeze/doc/images/output_release-management_generate-providers-metadata.txt index 3615848d57819..6e5c7fd64de14 100644 --- a/dev/breeze/doc/images/output_release-management_generate-providers-metadata.txt +++ b/dev/breeze/doc/images/output_release-management_generate-providers-metadata.txt @@ -1 +1 @@ -fdfdca32a5248d3b91cb29e14cc538b4 +de007da2573c2e6066fd2b0d26d14874 diff --git a/dev/breeze/doc/images/output_release-management_prepare-provider-distributions.svg b/dev/breeze/doc/images/output_release-management_prepare-provider-distributions.svg index 3661b47f2a46d..dbabcc063abf1 100644 --- a/dev/breeze/doc/images/output_release-management_prepare-provider-distributions.svg +++ b/dev/breeze/doc/images/output_release-management_prepare-provider-distributions.svg @@ -197,9 +197,9 @@ github | google | grpc | hashicorp | http | imap | influxdb | informatica | jdbc | jenkins | keycloak |                microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai |         openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres |     -presto | qdrant | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake |    -sqlite | ssh | standard | tableau | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb |          -zendesk]...                                                                                                            +presto | qdrant | redis | salesforce | samba | sdk.java | segment | sendgrid | sftp | singularity | slack | smtp |     +snowflake | sqlite | ssh | standard | tableau | telegram | teradata | trino | vertica | vespa | weaviate | yandex |    +ydb | zendesk]...                                                                                                      Prepare sdist/whl distributions of Airflow Providers. Each provider directory is wiped with `git clean -fdx (preserving .venv, .idea, .vscode) before build to keep in-tree generated files out of the artifact. 
See dev/breeze  diff --git a/dev/breeze/doc/images/output_release-management_prepare-provider-distributions.txt b/dev/breeze/doc/images/output_release-management_prepare-provider-distributions.txt index f10fd70bd89fa..aa4a21a6dcebd 100644 --- a/dev/breeze/doc/images/output_release-management_prepare-provider-distributions.txt +++ b/dev/breeze/doc/images/output_release-management_prepare-provider-distributions.txt @@ -1 +1 @@ -18d45fa2bec60ab0557f04fb4427b35e +71c54d02659478978d0aa40b2baf4fef diff --git a/dev/breeze/doc/images/output_release-management_prepare-provider-documentation.svg b/dev/breeze/doc/images/output_release-management_prepare-provider-documentation.svg index c4454038e4a78..ae36d4fde676e 100644 --- a/dev/breeze/doc/images/output_release-management_prepare-provider-documentation.svg +++ b/dev/breeze/doc/images/output_release-management_prepare-provider-documentation.svg @@ -218,9 +218,9 @@ github | google | grpc | hashicorp | http | imap | influxdb | informatica | jdbc | jenkins | keycloak |                microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai |         openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres |     -presto | qdrant | redis | salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake |    -sqlite | ssh | standard | tableau | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb |          -zendesk]...                                                                                                            +presto | qdrant | redis | salesforce | samba | sdk.java | segment | sendgrid | sftp | singularity | slack | smtp |     +snowflake | sqlite | ssh | standard | tableau | telegram | teradata | trino | vertica | vespa | weaviate | yandex |    +ydb | zendesk]...                                                                                                      Prepare CHANGELOG, README and COMMITS information for providers. diff --git a/dev/breeze/doc/images/output_release-management_prepare-provider-documentation.txt b/dev/breeze/doc/images/output_release-management_prepare-provider-documentation.txt index 5586a29b8136c..68e5927948e4f 100644 --- a/dev/breeze/doc/images/output_release-management_prepare-provider-documentation.txt +++ b/dev/breeze/doc/images/output_release-management_prepare-provider-documentation.txt @@ -1 +1 @@ -622441d283775edefeda685820e7169a +542fd516d5584cf5bc1b6aa945338a8c diff --git a/dev/breeze/doc/images/output_release-management_publish-docs.svg b/dev/breeze/doc/images/output_release-management_publish-docs.svg index d119da2013d75..291f8b0d144e3 100644 --- a/dev/breeze/doc/images/output_release-management_publish-docs.svg +++ b/dev/breeze/doc/images/output_release-management_publish-docs.svg @@ -194,8 +194,8 @@ hashicorp | helm-chart | http | imap | influxdb | informatica | jdbc | jenkins | keycloak | microsoft.azure |        microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai | openfaas | openlineage |  opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres | presto | qdrant | redis |    -salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh | standard |    -tableau | task-sdk | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk]...             
+salesforce | samba | sdk.java | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh |    +standard | tableau | task-sdk | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk]...  Command to publish generated documentation to airflow-site diff --git a/dev/breeze/doc/images/output_release-management_publish-docs.txt b/dev/breeze/doc/images/output_release-management_publish-docs.txt index c73c7846664c8..487f7d9fef5f8 100644 --- a/dev/breeze/doc/images/output_release-management_publish-docs.txt +++ b/dev/breeze/doc/images/output_release-management_publish-docs.txt @@ -1 +1 @@ -4521ec02334b8909f66e82c460a69446 +6a7fed8b89fffc1e9d8856bf1a2d5f2d diff --git a/dev/breeze/doc/images/output_sbom_generate-providers-requirements.svg b/dev/breeze/doc/images/output_sbom_generate-providers-requirements.svg index fd62a65b513d4..ec6cd73739017 100644 --- a/dev/breeze/doc/images/output_sbom_generate-providers-requirements.svg +++ b/dev/breeze/doc/images/output_sbom_generate-providers-requirements.svg @@ -189,9 +189,9 @@ | grpc | hashicorp | http | imap | influxdb | informatica | jdbc | jenkins | keycloak |  microsoft.azure | microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j |  odbc | openai | openfaas | openlineage | opensearch | opsgenie | oracle | pagerduty | papermill  -| pgvector | pinecone | postgres | presto | qdrant | redis | salesforce | samba | segment |  -sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh | standard | tableau |  -telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk) +| pgvector | pinecone | postgres | presto | qdrant | redis | salesforce | samba | sdk.java |  +segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh | standard |  +tableau | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk) --provider-versionProvider version to generate the requirements for i.e `2.1.0`. `latest` is also a supported      value to account for the most recent version of the provider (TEXT) --force           Force update providers requirements even if they already exist. diff --git a/dev/breeze/doc/images/output_sbom_generate-providers-requirements.txt b/dev/breeze/doc/images/output_sbom_generate-providers-requirements.txt index a7761ea29d68a..8a0c324836340 100644 --- a/dev/breeze/doc/images/output_sbom_generate-providers-requirements.txt +++ b/dev/breeze/doc/images/output_sbom_generate-providers-requirements.txt @@ -1 +1 @@ -fa98bbcd73f9160c29eff1b6779a23bc +c4babe6a19ea7748ed3488c930187a8e diff --git a/dev/breeze/doc/images/output_workflow-run_publish-docs.svg b/dev/breeze/doc/images/output_workflow-run_publish-docs.svg index 511790e79d721..0b42e92e917d3 100644 --- a/dev/breeze/doc/images/output_workflow-run_publish-docs.svg +++ b/dev/breeze/doc/images/output_workflow-run_publish-docs.svg @@ -200,8 +200,8 @@ hashicorp | helm-chart | http | imap | influxdb | informatica | jdbc | jenkins | keycloak | microsoft.azure |        microsoft.mssql | microsoft.psrp | microsoft.winrm | mongo | mysql | neo4j | odbc | openai | openfaas | openlineage |  opensearch | opsgenie | oracle | pagerduty | papermill | pgvector | pinecone | postgres | presto | qdrant | redis |    -salesforce | samba | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh | standard |    -tableau | task-sdk | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk]...             
+salesforce | samba | sdk.java | segment | sendgrid | sftp | singularity | slack | smtp | snowflake | sqlite | ssh |    +standard | tableau | task-sdk | telegram | teradata | trino | vertica | vespa | weaviate | yandex | ydb | zendesk]...  Trigger publish docs to S3 workflow diff --git a/dev/breeze/doc/images/output_workflow-run_publish-docs.txt b/dev/breeze/doc/images/output_workflow-run_publish-docs.txt index 6a433f7935a96..cbb67ce0a1df9 100644 --- a/dev/breeze/doc/images/output_workflow-run_publish-docs.txt +++ b/dev/breeze/doc/images/output_workflow-run_publish-docs.txt @@ -1 +1 @@ -6ff7091e58988c6273e51f372bb8a1a6 +a4876e7e49973aad884a0270de53885a diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index 4deaa3bf598b9..ce35be9d26a3a 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -789,25 +789,25 @@ def get_airflow_extras(): { "python-version": "3.10", "airflow-version": "2.11.1", - "remove-providers": "common.messaging edge3 fab git keycloak informatica common.ai opensearch", + "remove-providers": "common.messaging edge3 fab git keycloak informatica common.ai opensearch sdk.java", "run-unit-tests": "true", }, { "python-version": "3.10", "airflow-version": "3.0.6", - "remove-providers": "", + "remove-providers": "sdk.java", "run-unit-tests": "true", }, { "python-version": "3.10", "airflow-version": "3.1.8", - "remove-providers": "", + "remove-providers": "sdk.java", "run-unit-tests": "true", }, { "python-version": "3.10", "airflow-version": "3.2.1", - "remove-providers": "", + "remove-providers": "sdk.java", "run-unit-tests": "true", }, ] diff --git a/dev/registry/extract_metadata.py b/dev/registry/extract_metadata.py index 463a9c408082b..5d9f635f74ea3 100644 --- a/dev/registry/extract_metadata.py +++ b/dev/registry/extract_metadata.py @@ -46,7 +46,7 @@ try: import tomllib # Python 3.11+ stdlib except ModuleNotFoundError: # pragma: no cover -- Python 3.10 fallback - import tomli as tomllib + import tomli as tomllib # type: ignore[no-redef] import yaml from registry_contract_models import validate_providers_catalog diff --git a/dev/registry/extract_versions.py b/dev/registry/extract_versions.py index d9dc4e166dcf1..2908b22b32e6a 100644 --- a/dev/registry/extract_versions.py +++ b/dev/registry/extract_versions.py @@ -49,7 +49,7 @@ try: import tomllib # Python 3.11+ stdlib except ModuleNotFoundError: # pragma: no cover -- Python 3.10 fallback - import tomli as tomllib + import tomli as tomllib # type: ignore[no-redef] from registry_contract_models import validate_provider_version_metadata try: diff --git a/devel-common/src/docs/provider_conf.py b/devel-common/src/docs/provider_conf.py index 6bc9da15f5f61..b730e8f20a417 100644 --- a/devel-common/src/docs/provider_conf.py +++ b/devel-common/src/docs/provider_conf.py @@ -151,7 +151,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. 
-empty_subpackages = ["apache", "atlassian", "common", "cncf", "dbt", "microsoft"] +empty_subpackages = ["apache", "atlassian", "common", "cncf", "dbt", "microsoft", "sdk"] exclude_patterns = [ "operators/_partials", "_api/airflow/index.rst", diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 98a871ebb12d2..ed3547ccbfa84 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2510,7 +2510,6 @@ def execute(self, context): from uuid6 import uuid7 from airflow.sdk import DAG - from airflow.sdk.api.datamodels._generated import TaskInstance from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails from airflow.timetables.base import TimeRestriction @@ -2538,6 +2537,15 @@ def _create_task_instance( should_retry: bool | None = None, max_tries: int | None = None, ) -> RuntimeTaskInstance: + from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS + + if AIRFLOW_V_3_3_PLUS: + from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO + else: + from airflow.sdk.api.datamodels._generated import ( # type: ignore[no-redef,assignment] + TaskInstance as TaskInstanceDTO, + ) + from airflow.sdk.api.datamodels._generated import DagRun, DagRunState, TIRunContext from airflow.utils.types import DagRunType @@ -2615,14 +2623,17 @@ def _create_task_instance( } startup_details = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=ti_id, task_id=task.task_id, dag_id=dag_id, run_id=run_id, try_number=try_number, - map_index=map_index, + map_index=map_index, # type: ignore[arg-type] dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="", bundle_info=BundleInfo(name="anything", version="any"), diff --git a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml index b9388b6b6c9d6..2dd3391f007aa 100644 --- a/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml +++ b/providers/edge3/src/airflow/providers/edge3/worker_api/v2-edge-generated.yaml @@ -1257,8 +1257,8 @@ components: - queue - priority_weight title: TaskInstanceDTO - description: Schema for TaskInstance with minimal required fields needed for - Executors and Task SDK. + description: TaskInstanceDTO with executor-specific ``external_executor_id`` + field and ``key`` property. TaskInstanceState: type: string enum: diff --git a/providers/sdk/java/.gitignore b/providers/sdk/java/.gitignore new file mode 100644 index 0000000000000..bff2d7629604d --- /dev/null +++ b/providers/sdk/java/.gitignore @@ -0,0 +1 @@ +*.iml diff --git a/providers/sdk/java/LICENSE b/providers/sdk/java/LICENSE new file mode 100644 index 0000000000000..11069edd79019 --- /dev/null +++ b/providers/sdk/java/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/providers/sdk/java/NOTICE b/providers/sdk/java/NOTICE new file mode 100644 index 0000000000000..a51bd9390d030 --- /dev/null +++ b/providers/sdk/java/NOTICE @@ -0,0 +1,5 @@ +Apache Airflow +Copyright 2016-2026 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). diff --git a/providers/sdk/java/README.rst b/providers/sdk/java/README.rst new file mode 100644 index 0000000000000..ba3081bb6cb53 --- /dev/null +++ b/providers/sdk/java/README.rst @@ -0,0 +1,60 @@ + +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + +.. IF YOU WANT TO MODIFY TEMPLATE FOR THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + ``PROVIDER_README_TEMPLATE.rst.jinja2`` IN the ``dev/breeze/src/airflow_breeze/templates`` DIRECTORY + +Package ``apache-airflow-providers-sdk-java`` + +Release: ``0.1.0`` + + +Java Coordinator + + +Provider package +---------------- + +This is a provider package for ``sdk.java`` provider. All classes for this provider package +are in ``airflow.providers.sdk.java`` python package. + +You can find package information and changelog for the provider +in the `documentation `_. + +Installation +------------ + +You can install this package on top of an existing Airflow installation (see ``Requirements`` below +for the minimum Airflow version supported) via +``pip install apache-airflow-providers-sdk-java`` + +The package supports the following python versions: 3.10,3.11,3.12,3.13,3.14 + +Requirements +------------ + +================== ================== +PIP package Version required +================== ================== +``apache-airflow`` ``>=3.3.0`` +================== ================== + +The changelog for the provider package can be found in the +`changelog `_. diff --git a/providers/sdk/java/docs/.latest-doc-only-change.txt b/providers/sdk/java/docs/.latest-doc-only-change.txt new file mode 100644 index 0000000000000..2c1ab461a9c8e --- /dev/null +++ b/providers/sdk/java/docs/.latest-doc-only-change.txt @@ -0,0 +1 @@ +da9caffdbbeab1917e1cec5726e50af5f14a5206 diff --git a/providers/sdk/java/docs/changelog.rst b/providers/sdk/java/docs/changelog.rst new file mode 100644 index 0000000000000..c5aa1ad337ef8 --- /dev/null +++ b/providers/sdk/java/docs/changelog.rst @@ -0,0 +1,40 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + ..
http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +.. NOTE TO CONTRIBUTORS: + Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes + and you want to add an explanation to the users on how they are supposed to deal with them. + The changelog is updated and maintained semi-automatically by release manager. + + +``apache-airflow-providers-sdk-java`` + + +Changelog +--------- + +0.1.0 +..... + +Features +~~~~~~~~ + +* ``Add the initial Java coordinator interface`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): diff --git a/providers/sdk/java/docs/commits.rst b/providers/sdk/java/docs/commits.rst new file mode 100644 index 0000000000000..6b84d751e94e3 --- /dev/null +++ b/providers/sdk/java/docs/commits.rst @@ -0,0 +1,35 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + .. NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + + .. IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE + `PROVIDER_COMMITS_TEMPLATE.rst.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + .. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN! + +Package apache-airflow-providers-sdk-java +------------------------------------------------------ + +Java Coordinator + + +This is detailed commit list of changes for versions provider package: ``sdk.java``. +For high-level changelog, see :doc:`package information including changelog `. + +.. airflow-providers-commits:: diff --git a/providers/sdk/java/docs/conf.py b/providers/sdk/java/docs/conf.py new file mode 100644 index 0000000000000..596c5b5c7b5f3 --- /dev/null +++ b/providers/sdk/java/docs/conf.py @@ -0,0 +1,27 @@ +# Disable Flake8 because of all the sphinx imports +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Configuration of Providers docs building.""" + +from __future__ import annotations + +import os + +os.environ["AIRFLOW_PACKAGE_NAME"] = "apache-airflow-providers-sdk-java" + +from docs.provider_conf import * # noqa: F403 diff --git a/providers/sdk/java/docs/configurations-ref.rst b/providers/sdk/java/docs/configurations-ref.rst new file mode 100644 index 0000000000000..ea8e668d75793 --- /dev/null +++ b/providers/sdk/java/docs/configurations-ref.rst @@ -0,0 +1,19 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: /../../../../devel-common/src/sphinx_exts/includes/providers-configurations-ref.rst +.. include:: /../../../../devel-common/src/sphinx_exts/includes/sections-and-options.rst diff --git a/providers/sdk/java/docs/index.rst b/providers/sdk/java/docs/index.rst new file mode 100644 index 0000000000000..77e8b1e22d80e --- /dev/null +++ b/providers/sdk/java/docs/index.rst @@ -0,0 +1,123 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +``apache-airflow-providers-sdk-java`` +=========================================== + +The SDK: Java provider registers Java-specific task coordinator and DAG file processor classes for Apache Airflow. + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Basics + + Home + Changelog + Security + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Guides + + Configuration + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: References + + Python API <_api/airflow/providers/sdk/java/index> + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Resources + + PyPI Repository + Installing from sources + +.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. 
IT WILL BE OVERWRITTEN AT RELEASE TIME! + + +.. toctree:: + :hidden: + :maxdepth: 1 + :caption: Commits + + Detailed list of commits + + +apache-airflow-providers-sdk-java package +------------------------------------------------------ + +Java SDK support for Apache Airflow runtime coordinators. + + +Release: 0.1.0 + +Provider package +---------------- + +This package is for the ``sdk.java`` provider. +All classes for this package are included in the ``airflow.providers.sdk.java`` python package. + +Installation +------------ + +You can install this package on top of an existing Airflow installation via +``pip install apache-airflow-providers-sdk-java``. +For the minimum Airflow version supported, see ``Requirements`` below. + +Requirements +------------ + +The minimum Apache Airflow version supported by this provider distribution is ``3.3.0``. + +================== ================== +PIP package Version required +================== ================== +``apache-airflow`` ``>=3.3.0`` +================== ================== + +Cross provider package dependencies +----------------------------------- + +Those are dependencies that might be needed in order to use all the features of the package. +You need to install the specified provider distributions in order to use them. + +You can install such cross-provider dependencies when installing from PyPI. For example: + +.. code-block:: bash + + pip install apache-airflow-providers-sdk-java[common.compat] + + +================================================================================================================== ================= +Dependent package Extra +================================================================================================================== ================= +`apache-airflow-providers-common-compat `_ ``common.compat`` +================================================================================================================== ================= + +Downloading official packages +----------------------------- + +You can download officially released packages and verify their checksums and signatures from the +`Official Apache Download site `_ + +* `The apache-airflow-providers-sdk-java 0.1.0 sdist package `_ (`asc `__, `sha512 `__) +* `The apache-airflow-providers-sdk-java 0.1.0 wheel package `_ (`asc `__, `sha512 `__) diff --git a/providers/sdk/java/docs/installing-providers-from-sources.rst b/providers/sdk/java/docs/installing-providers-from-sources.rst new file mode 100644 index 0000000000000..fdbb17d017579 --- /dev/null +++ b/providers/sdk/java/docs/installing-providers-from-sources.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. 
include:: /../../../../devel-common/src/sphinx_exts/includes/installing-providers-from-sources.rst diff --git a/providers/sdk/java/docs/security.rst b/providers/sdk/java/docs/security.rst new file mode 100644 index 0000000000000..351ff007ebf2f --- /dev/null +++ b/providers/sdk/java/docs/security.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: /../../../../devel-common/src/sphinx_exts/includes/security.rst diff --git a/providers/sdk/java/provider.yaml b/providers/sdk/java/provider.yaml new file mode 100644 index 0000000000000..d10f841962034 --- /dev/null +++ b/providers/sdk/java/provider.yaml @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-sdk-java +name: "SDK: Java" +description: | + Java SDK support for Apache Airflow runtime coordinators. + +state: ready +lifecycle: incubation +source-date-epoch: 1775631151 +# Note that those versions are maintained by release manager - do not update them manually +# with the exception of case where other provider in sources has >= new provider version. +# In such case adding >= NEW_VERSION and bumping to NEW_VERSION in a provider have +# to be done in the same PR +versions: + - 0.1.0 + +integrations: + - integration-name: Java + external-doc-url: https://openjdk.org/ + tags: + - software + +config: + java: + description: "Options for the Java SDK provider." + options: + bundles_folder: + description: | + Path to the directory containing Java DAG bundle JARs. + When using Python stub DAGs that delegate task execution to Java, + the coordinator scans this directory to find the JAR bundle matching + the target dag_id. Each immediate subdirectory is treated as a + separate bundle home, and the directory itself is also checked + (flat layout). 
+ type: string + version_added: ~ + example: ~/airflow/java-bundles + default: "" + +coordinators: + - airflow.providers.sdk.java.coordinator.JavaCoordinator diff --git a/providers/sdk/java/pyproject.toml b/providers/sdk/java/pyproject.toml new file mode 100644 index 0000000000000..6baca6f81fdd4 --- /dev/null +++ b/providers/sdk/java/pyproject.toml @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! + +# IF YOU WANT TO MODIFY THIS FILE EXCEPT DEPENDENCIES, YOU SHOULD MODIFY THE TEMPLATE +# `pyproject_TEMPLATE.toml.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +[build-system] +requires = ["flit_core==3.12.0"] +build-backend = "flit_core.buildapi" + +[project] +name = "apache-airflow-providers-sdk-java" +version = "0.1.0" +description = "Provider package apache-airflow-providers-sdk-java for Apache Airflow" +readme = "README.rst" +license = "Apache-2.0" +license-files = ['LICENSE', 'NOTICE'] +authors = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +maintainers = [ + {name="Apache Software Foundation", email="dev@airflow.apache.org"}, +] +keywords = [ "airflow-provider", "sdk.java", "airflow", "integration" ] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "Framework :: Apache Airflow", + "Framework :: Apache Airflow :: Provider", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Topic :: System :: Monitoring", +] +requires-python = ">=3.10" + +# The dependencies should be modified in place in the generated file. 
+# Any change in the dependencies is preserved when the file is regenerated +# Make sure to run ``prek update-providers-dependencies --all-files`` +# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` +dependencies = [ + "apache-airflow>=3.3.0", +] + +# The optional dependencies should be modified in place in the generated file +# Any change in the dependencies is preserved when the file is regenerated +[project.optional-dependencies] +"common.compat" = [ + "apache-airflow-providers-common-compat" +] + +[dependency-groups] +dev = [ + "apache-airflow", + "apache-airflow-task-sdk", + "apache-airflow-devel-common", + "apache-airflow-providers-common-compat", + # Additional devel dependencies (do not remove this line and add extra development dependencies) +] + +# To build docs: +# +# uv run --group docs build-docs +# +# To enable auto-refreshing build with server: +# +# uv run --group docs build-docs --autobuild +# +# To see more options: +# +# uv run --group docs build-docs --help +# +docs = [ + "apache-airflow-devel-common[docs]" +] + +[tool.uv.sources] +# These names must match the names as defined in the pyproject.toml of the workspace items, +# *not* the workspace folder paths +apache-airflow = {workspace = true} +apache-airflow-devel-common = {workspace = true} +apache-airflow-task-sdk = {workspace = true} +apache-airflow-providers-common-sql = {workspace = true} +apache-airflow-providers-standard = {workspace = true} + +[project.urls] +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-sdk-java/0.1.0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-sdk-java/0.1.0/changelog.html" +"Bug Tracker" = "https://github.com/apache/airflow/issues" +"Source Code" = "https://github.com/apache/airflow" +"Slack Chat" = "https://s.apache.org/airflow-slack" +"Mastodon" = "https://fosstodon.org/@airflow" +"YouTube" = "https://www.youtube.com/channel/UCSXwxpWZQ7XZ1WL3wqevChA/" + +[project.entry-points."apache_airflow_provider"] +provider_info = "airflow.providers.sdk.java.get_provider_info:get_provider_info" + +[tool.flit.module] +name = "airflow.providers.sdk.java" + +# Explicit sdist contents so the build does not rely on VCS information +# (flit 4.0 makes --no-use-vcs the default — see https://github.com/pypa/flit/pull/782). +[tool.flit.sdist] +include = [ + "docs/", + "provider.yaml", + "src/airflow/__init__.py", + "src/airflow/providers/__init__.py", + "src/airflow/providers/sdk/__init__.py", + "tests/", +] diff --git a/providers/sdk/java/src/airflow/__init__.py b/providers/sdk/java/src/airflow/__init__.py new file mode 100644 index 0000000000000..5966d6b1d5261 --- /dev/null +++ b/providers/sdk/java/src/airflow/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/providers/sdk/java/src/airflow/providers/__init__.py b/providers/sdk/java/src/airflow/providers/__init__.py new file mode 100644 index 0000000000000..5966d6b1d5261 --- /dev/null +++ b/providers/sdk/java/src/airflow/providers/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/providers/sdk/java/src/airflow/providers/sdk/__init__.py b/providers/sdk/java/src/airflow/providers/sdk/__init__.py new file mode 100644 index 0000000000000..5966d6b1d5261 --- /dev/null +++ b/providers/sdk/java/src/airflow/providers/sdk/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/providers/sdk/java/src/airflow/providers/sdk/java/__init__.py b/providers/sdk/java/src/airflow/providers/sdk/java/__init__.py new file mode 100644 index 0000000000000..1c942bc68df44 --- /dev/null +++ b/providers/sdk/java/src/airflow/providers/sdk/java/__init__.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# NOTE! 
THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE +# OVERWRITTEN WHEN PREPARING DOCUMENTATION FOR THE PACKAGES. +# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `PROVIDER__INIT__PY_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY +# +from __future__ import annotations + +import packaging.version + +from airflow import __version__ as airflow_version + +__all__ = ["__version__"] + +__version__ = "0.1.0" + +if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( + "3.3.0" +): + raise RuntimeError( + f"The package `apache-airflow-providers-sdk-java:{__version__}` needs Apache Airflow 3.3.0+" + ) diff --git a/providers/sdk/java/src/airflow/providers/sdk/java/bundle_scanner.py b/providers/sdk/java/src/airflow/providers/sdk/java/bundle_scanner.py new file mode 100644 index 0000000000000..87bbf518b8e5b --- /dev/null +++ b/providers/sdk/java/src/airflow/providers/sdk/java/bundle_scanner.py @@ -0,0 +1,220 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Scan directories for Airflow Java SDK bundle JARs. + +Mirrors the Java SDK's ``BundleScanner`` — checks each JAR's manifest for +``Airflow-Java-SDK-Metadata``, reads the embedded metadata YAML, and +resolves the main class and classpath needed to launch the bundle process. +""" + +from __future__ import annotations + +import email +import os +import zipfile +from pathlib import Path +from typing import NamedTuple + +import yaml + +MANIFEST_PATH = "META-INF/MANIFEST.MF" +METADATA_MANIFEST_KEY = "Airflow-Java-SDK-Metadata" +SDK_VERSION_MANIFEST_KEY = "Airflow-Java-SDK-Version" +DAG_CODE_MANIFEST_KEY = "Airflow-Java-SDK-Dag-Code" +MAIN_CLASS_MANIFEST_KEY = "Main-Class" + + +class ResolvedJarBundle(NamedTuple): + """A resolved Java DAG bundle: everything needed to start the bundle process.""" + + main_class: str + classpath: str + + +class BundleScanner: + """ + Locate Airflow Java SDK bundles inside a directory tree. + + Supports two directory layouts: + + - **Nested** - each immediate subdirectory of *bundles_dir* is a bundle home. + - **Flat** — *bundles_dir* itself contains the bundle JARs. + + Within a bundle home the JVM convention of a ``lib/`` subdirectory for + dependency JARs is respected automatically. + """ + + def __init__(self, bundles_dir: Path) -> None: + self._bundles_dir = bundles_dir + + def resolve(self, dag_id: str) -> ResolvedJarBundle: + """ + Find the bundle whose metadata YAML lists *dag_id*. + + :raises FileNotFoundError: if no matching bundle is found. 
+ """ + for bundle_home in self._candidate_homes(): + jars = _jar_files(bundle_home) + if not jars: + continue + + for jar_path in jars: + result = _read_bundle_jar(jar_path) + if result is None: + continue + main_class, dag_ids = result + if dag_id in dag_ids: + classpath = os.pathsep.join(str(j.resolve()) for j in jars) + return ResolvedJarBundle(main_class=main_class, classpath=classpath) + + raise FileNotFoundError(f"No JAR bundle containing dag_id={dag_id!r} found in {self._bundles_dir}") + + @staticmethod + def resolve_jar(jar_path: Path) -> str: + """ + Read ``Main-Class`` from a single bundle JAR, validating SDK attributes. + + :raises FileNotFoundError: if the JAR is not a valid Airflow Java SDK bundle. + """ + result = _read_bundle_jar(jar_path) + if result is None: + raise FileNotFoundError( + f"Not a valid Airflow Java SDK bundle: {jar_path} " + f"(requires {METADATA_MANIFEST_KEY} and {MAIN_CLASS_MANIFEST_KEY})" + ) + return result[0] + + def _candidate_homes(self) -> list[Path]: + """Return normalised bundle-home directories to inspect.""" + candidates: list[Path] = [] + + # Each subdirectory is a potential bundle home (nested layout). + if self._bundles_dir.is_dir(): + for child in sorted(self._bundles_dir.iterdir()): + if child.is_dir(): + candidates.append(_normalize_bundle_home(child)) + + # The directory itself (flat layout). + candidates.append(_normalize_bundle_home(self._bundles_dir)) + return candidates + + +def _jar_files(directory: Path) -> list[Path]: + """List all ``.jar`` files in *directory*, sorted by name.""" + if not directory.is_dir(): + return [] + return sorted(p for p in directory.iterdir() if p.is_file() and p.suffix == ".jar") + + +def _normalize_bundle_home(path: Path) -> Path: + """ + Normalize a bundle path to the directory containing JARs. + + Handles the common JVM distribution layout where dependency JARs + live in a ``lib/`` subdirectory (Gradle ``application`` plugin, + Maven Assembly, sbt-native-packager, etc.). + + - If *path* points to a JAR file, use its parent directory. + - If the directory has a ``lib/`` subdirectory containing JARs, use that. + - Otherwise, return the directory as-is. + """ + normalized = path.resolve() + if normalized.is_file() and normalized.suffix == ".jar": + return normalized.parent + lib = normalized / "lib" + if lib.is_dir() and any(p.suffix == ".jar" for p in lib.iterdir()): + return lib + return normalized + + +def _read_bundle_jar(jar_path: Path) -> tuple[str, set[str]] | None: + """ + Read ``Main-Class`` and dag IDs from a JAR's manifest and embedded metadata. + + Returns ``(main_class, dag_ids)`` when the JAR carries valid + ``Airflow-Java-SDK-Metadata`` and ``Main-Class`` manifest attributes + and the referenced metadata YAML contains at least one dag ID. + Returns ``None`` otherwise. 
+ """ + try: + with zipfile.ZipFile(jar_path) as zf: + try: + with zf.open(MANIFEST_PATH) as f: + manifest = email.message_from_binary_file(f) + except KeyError: + return None + + metadata_file = manifest.get(METADATA_MANIFEST_KEY) + if not metadata_file: + return None + + main_class = manifest.get(MAIN_CLASS_MANIFEST_KEY) + if not main_class: + return None + + try: + with zf.open(metadata_file) as f: + content = f.read().decode() + except KeyError: + return None + except zipfile.BadZipFile: + return None + + dag_ids = _parse_dag_ids_from_metadata(content) + if not dag_ids: + return None + + return main_class, dag_ids + + +def read_dag_code(jar_path: Path) -> str | None: + """ + Read the DAG source code embedded in a JAR bundle. + + Returns the source code string when the JAR carries a valid + ``Airflow-Java-SDK-Dag-Code`` manifest attribute pointing to an + embedded source file. Returns ``None`` otherwise. + """ + try: + with zipfile.ZipFile(jar_path) as zf: + try: + with zf.open(MANIFEST_PATH) as f: + manifest = email.message_from_binary_file(f) + except KeyError: + return None + + dag_code_path = manifest.get(DAG_CODE_MANIFEST_KEY) + if not dag_code_path: + return None + + try: + with zf.open(dag_code_path) as f: + return f.read().decode() + except KeyError: + return None + except zipfile.BadZipFile: + return None + + +def _parse_dag_ids_from_metadata(yaml_content: str) -> set[str]: + """Parse dag IDs from an ``airflow-metadata.yaml`` content string.""" + data = yaml.safe_load(yaml_content) + if not isinstance(data, dict) or "dags" not in data: + return set() + return set(data["dags"].keys()) diff --git a/providers/sdk/java/src/airflow/providers/sdk/java/coordinator.py b/providers/sdk/java/src/airflow/providers/sdk/java/coordinator.py new file mode 100644 index 0000000000000..11833f166ce4b --- /dev/null +++ b/providers/sdk/java/src/airflow/providers/sdk/java/coordinator.py @@ -0,0 +1,131 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Java runtime coordinator that launches a JVM subprocess for Dag file processing and task execution.""" + +from __future__ import annotations + +import contextlib +import os +import zipfile +from pathlib import Path +from typing import TYPE_CHECKING + +from airflow.providers.sdk.java.bundle_scanner import BundleScanner, read_dag_code +from airflow.sdk.execution_time.coordinator import BaseCoordinator + +if TYPE_CHECKING: + from airflow.sdk.api.datamodels._generated import BundleInfo + from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO + + +class JavaCoordinator(BaseCoordinator): + """Coordinator that launches a JVM subprocess for DAG parsing and task execution.""" + + sdk = "java" + file_extension = ".jar" + + @classmethod + def can_handle_dag_file(cls, bundle_name: str, path: str | os.PathLike[str]) -> bool: + """Return ``True`` when *path* is a JAR with valid Airflow Java SDK manifest attributes.""" + if not os.fspath(path).endswith(cls.file_extension): + return False + with contextlib.suppress(FileNotFoundError, NotADirectoryError, zipfile.BadZipFile, KeyError): + return BundleScanner.resolve_jar(Path(path)) is not None + return False + + @classmethod + def get_code_from_file(cls, fileloc: str) -> str: + """Read embedded DAG source code from a JAR bundle.""" + code = read_dag_code(Path(fileloc)) + if code is None: + raise FileNotFoundError(f"No DAG source code found in JAR: {fileloc}") + return code + + @classmethod + def dag_parsing_cmd( + cls, + *, + dag_file_path: str, + bundle_name: str, + bundle_path: str, + comm_addr: str, + logs_addr: str, + ) -> list[str]: + """Build the ``java`` command for parsing a JAR bundle.""" + jar_path = Path(dag_file_path) + # Java bundles are typically thin JARs: the main JAR only contains + # the bundle's own classes while its dependencies (the Airflow Java + # SDK, logging libraries, etc.) are separate JARs that live alongside + # it. Using ``/*`` lets the JVM load every JAR in the directory. + classpath = f"{bundle_path}/*" + return [ + "java", + "-classpath", + classpath, + BundleScanner.resolve_jar(jar_path), + f"--comm={comm_addr}", + f"--logs={logs_addr}", + ] + + @classmethod + def task_execution_cmd( + cls, + *, + what: TaskInstanceDTO, + dag_file_path: str, + bundle_path: str, + bundle_info: BundleInfo, + comm_addr: str, + logs_addr: str, + ) -> list[str]: + """Build the ``java`` command for executing a task in a JAR bundle.""" + if dag_file_path.endswith(".jar"): + # Case 1: Pure Java Dag — the dag_file_path points directly to a + # bundle JAR inside the Airflow Core Dag Bundle. + jar_path = Path(dag_file_path) + classpath = f"{bundle_path}/*" + return [ + "java", + "-classpath", + classpath, + BundleScanner.resolve_jar(jar_path), + f"--comm={comm_addr}", + f"--logs={logs_addr}", + ] + + # Case 2: Python Stub Dag — the dag_file_path is a Python file but + # the task delegates to a Java runtime. The actual JAR bundle lives + # in the provider's configured ``[java] bundles_folder``. + from airflow.providers.common.compat.sdk import conf + + bundles_folder = conf.get("java", "bundles_folder", fallback=None) + if not bundles_folder: + raise ValueError( + "The [java] bundles_folder config must be set for Python stub DAGs " + "that delegate to Java task execution." 
+ ) + + resolved = BundleScanner(Path(bundles_folder)).resolve(dag_id=what.dag_id) + return [ + "java", + "-classpath", + resolved.classpath, + resolved.main_class, + f"--comm={comm_addr}", + f"--logs={logs_addr}", + ] diff --git a/providers/sdk/java/src/airflow/providers/sdk/java/get_provider_info.py b/providers/sdk/java/src/airflow/providers/sdk/java/get_provider_info.py new file mode 100644 index 0000000000000..89df45102b732 --- /dev/null +++ b/providers/sdk/java/src/airflow/providers/sdk/java/get_provider_info.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN! +# +# IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE +# `get_provider_info_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY + + +def get_provider_info(): + return { + "package-name": "apache-airflow-providers-sdk-java", + "name": "SDK: Java", + "description": "Java SDK support for Apache Airflow runtime coordinators.\n", + "integrations": [ + {"integration-name": "Java", "external-doc-url": "https://openjdk.org/", "tags": ["software"]} + ], + "config": { + "java": { + "description": "Options for the Java SDK provider.", + "options": { + "bundles_folder": { + "description": "Path to the directory containing Java DAG bundle JARs.\nWhen using Python stub DAGs that delegate task execution to Java,\nthe coordinator scans this directory to find the JAR bundle matching\nthe target dag_id. Each immediate subdirectory is treated as a\nseparate bundle home, and the directory itself is also checked\n(flat layout).\n", + "type": "string", + "version_added": None, + "example": "~/airflow/java-bundles", + "default": "", + } + }, + } + }, + "coordinators": ["airflow.providers.sdk.java.coordinator.JavaCoordinator"], + } diff --git a/providers/sdk/java/tests/conftest.py b/providers/sdk/java/tests/conftest.py new file mode 100644 index 0000000000000..f56ccce0a3f69 --- /dev/null +++ b/providers/sdk/java/tests/conftest.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +pytest_plugins = "tests_common.pytest_plugin" diff --git a/providers/sdk/java/tests/unit/__init__.py b/providers/sdk/java/tests/unit/__init__.py new file mode 100644 index 0000000000000..5966d6b1d5261 --- /dev/null +++ b/providers/sdk/java/tests/unit/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/providers/sdk/java/tests/unit/sdk/__init__.py b/providers/sdk/java/tests/unit/sdk/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/sdk/java/tests/unit/sdk/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/sdk/java/tests/unit/sdk/java/__init__.py b/providers/sdk/java/tests/unit/sdk/java/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/sdk/java/tests/unit/sdk/java/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/providers/sdk/java/tests/unit/sdk/java/test_bundle_scanner.py b/providers/sdk/java/tests/unit/sdk/java/test_bundle_scanner.py new file mode 100644 index 0000000000000..5c042036143da --- /dev/null +++ b/providers/sdk/java/tests/unit/sdk/java/test_bundle_scanner.py @@ -0,0 +1,337 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import zipfile +from pathlib import Path + +import pytest +import yaml + +from airflow.providers.sdk.java.bundle_scanner import ( + DAG_CODE_MANIFEST_KEY, + MAIN_CLASS_MANIFEST_KEY, + MANIFEST_PATH, + METADATA_MANIFEST_KEY, + SDK_VERSION_MANIFEST_KEY, + BundleScanner, + ResolvedJarBundle, + _jar_files, + _normalize_bundle_home, + _parse_dag_ids_from_metadata, + _read_bundle_jar, + read_dag_code, +) + +METADATA_YAML_PATH = "META-INF/airflow-metadata.yaml" +DAG_CODE_PATH = "JavaExample.java" +TEST_MAIN_CLASS = "com.example.MyDag" +TEST_SDK_VERSION = "1.0.0" + + +def _make_manifest( + *, + main_class: str | None = TEST_MAIN_CLASS, + metadata_path: str | None = METADATA_YAML_PATH, + sdk_version: str | None = TEST_SDK_VERSION, + dag_code_path: str | None = None, +) -> str: + lines = ["Manifest-Version: 1.0"] + if main_class: + lines.append(f"{MAIN_CLASS_MANIFEST_KEY}: {main_class}") + if metadata_path: + lines.append(f"{METADATA_MANIFEST_KEY}: {metadata_path}") + if sdk_version: + lines.append(f"{SDK_VERSION_MANIFEST_KEY}: {sdk_version}") + if dag_code_path: + lines.append(f"{DAG_CODE_MANIFEST_KEY}: {dag_code_path}") + return "\n".join(lines) + "\n" + + +def _make_metadata_yaml(dag_ids: list[str]) -> str: + return yaml.dump({"dags": {dag_id: {} for dag_id in dag_ids}}) + + +def _create_bundle_jar( + jar_path: Path, + *, + dag_ids: list[str] | None = None, + main_class: str | None = TEST_MAIN_CLASS, + include_metadata: bool = True, + include_manifest: bool = True, + dag_code: str | None = None, +) -> Path: + """Create a minimal JAR (zip) file with Airflow Java SDK manifest attributes.""" + with zipfile.ZipFile(jar_path, "w") as zf: + if include_manifest: + dag_code_path = DAG_CODE_PATH if dag_code else None + manifest = _make_manifest( + main_class=main_class, + metadata_path=METADATA_YAML_PATH if include_metadata else None, + dag_code_path=dag_code_path, + ) + zf.writestr(MANIFEST_PATH, manifest) + + if include_metadata and dag_ids is not None: + zf.writestr(METADATA_YAML_PATH, _make_metadata_yaml(dag_ids)) + + if dag_code: + zf.writestr(DAG_CODE_PATH, dag_code) + return jar_path + + +class TestJarFiles: + def test_lists_jar_files_sorted(self, tmp_path: Path): + (tmp_path / "b.jar").touch() + (tmp_path / "a.jar").touch() + (tmp_path / "c.txt").touch() + result = _jar_files(tmp_path) + assert result == [tmp_path / "a.jar", tmp_path / "b.jar"] + + def 
test_returns_empty_for_nonexistent_directory(self, tmp_path: Path): + assert _jar_files(tmp_path / "nonexistent") == [] + + def test_returns_empty_for_directory_with_no_jars(self, tmp_path: Path): + (tmp_path / "readme.txt").touch() + assert _jar_files(tmp_path) == [] + + def test_ignores_jar_directories(self, tmp_path: Path): + (tmp_path / "fake.jar").mkdir() + assert _jar_files(tmp_path) == [] + + +class TestNormalizeBundleHome: + def test_jar_file_returns_parent(self, tmp_path: Path): + jar = tmp_path / "bundle.jar" + jar.touch() + assert _normalize_bundle_home(jar) == tmp_path.resolve() + + def test_dir_with_lib_containing_jars(self, tmp_path: Path): + lib = tmp_path / "lib" + lib.mkdir() + (lib / "dep.jar").touch() + assert _normalize_bundle_home(tmp_path) == lib.resolve() + + def test_dir_with_empty_lib(self, tmp_path: Path): + lib = tmp_path / "lib" + lib.mkdir() + assert _normalize_bundle_home(tmp_path) == tmp_path.resolve() + + def test_plain_directory(self, tmp_path: Path): + assert _normalize_bundle_home(tmp_path) == tmp_path.resolve() + + +class TestParseDagIdsFromMetadata: + def test_parses_dag_ids(self): + content = yaml.dump({"dags": {"dag_a": {}, "dag_b": {"key": "val"}}}) + assert _parse_dag_ids_from_metadata(content) == {"dag_a", "dag_b"} + + @pytest.mark.parametrize( + "yaml_content", + [ + pytest.param(yaml.dump({"other": 1}), id="missing_dags_key"), + pytest.param("just a string", id="non_dict"), + pytest.param(yaml.dump({"dags": {}}), id="empty_dags"), + ], + ) + def test_returns_empty_set(self, yaml_content): + assert _parse_dag_ids_from_metadata(yaml_content) == set() + + +class TestReadBundleJar: + def test_valid_jar(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "valid.jar", dag_ids=["my_dag"]) + result = _read_bundle_jar(jar) + assert result is not None + main_class, dag_ids = result + assert main_class == TEST_MAIN_CLASS + assert dag_ids == {"my_dag"} + + def test_returns_none_for_missing_manifest(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "no_manifest.jar", include_manifest=False) + assert _read_bundle_jar(jar) is None + + def test_returns_none_for_missing_metadata_key(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "no_meta.jar", include_metadata=False) + assert _read_bundle_jar(jar) is None + + def test_returns_none_for_missing_main_class(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "no_main.jar", dag_ids=["d"], main_class=None) + assert _read_bundle_jar(jar) is None + + def test_returns_none_for_missing_metadata_file(self, tmp_path: Path): + """Manifest references a metadata file that does not exist inside the JAR.""" + jar = tmp_path / "missing_meta_file.jar" + with zipfile.ZipFile(jar, "w") as zf: + manifest = _make_manifest(metadata_path="nonexistent.yaml") + zf.writestr(MANIFEST_PATH, manifest) + assert _read_bundle_jar(jar) is None + + def test_returns_none_for_bad_zip(self, tmp_path: Path): + bad = tmp_path / "bad.jar" + bad.write_text("not a zip file") + assert _read_bundle_jar(bad) is None + + def test_returns_none_for_empty_dag_ids(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "empty_dags.jar", dag_ids=[]) + assert _read_bundle_jar(jar) is None + + def test_multiple_dag_ids(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "multi.jar", dag_ids=["dag_1", "dag_2", "dag_3"]) + result = _read_bundle_jar(jar) + assert result is not None + _, dag_ids = result + assert dag_ids == {"dag_1", "dag_2", "dag_3"} + + +class TestReadDagCode: + def 
test_reads_embedded_dag_code(self, tmp_path: Path): + code = "public class MyDag {}" + jar = _create_bundle_jar(tmp_path / "with_code.jar", dag_ids=["d"], dag_code=code) + assert read_dag_code(jar) == code + + def test_returns_none_for_missing_dag_code_key(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "no_code.jar", dag_ids=["d"]) + assert read_dag_code(jar) is None + + def test_returns_none_for_missing_manifest(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "no_manifest.jar", include_manifest=False) + assert read_dag_code(jar) is None + + def test_returns_none_for_bad_zip(self, tmp_path: Path): + bad = tmp_path / "bad.jar" + bad.write_text("not a zip") + assert read_dag_code(bad) is None + + def test_returns_none_when_code_file_missing(self, tmp_path: Path): + """Manifest references a dag code file that does not exist inside the JAR.""" + jar = tmp_path / "broken_code.jar" + with zipfile.ZipFile(jar, "w") as zf: + manifest = _make_manifest(dag_code_path="missing_source.py") + zf.writestr(MANIFEST_PATH, manifest) + assert read_dag_code(jar) is None + + +class TestBundleScannerResolveJar: + def test_returns_main_class(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "bundle.jar", dag_ids=["d"]) + assert BundleScanner.resolve_jar(jar) == TEST_MAIN_CLASS + + def test_raises_for_invalid_jar(self, tmp_path: Path): + jar = tmp_path / "not_bundle.jar" + jar.write_text("not a zip") + with pytest.raises(FileNotFoundError, match="Not a valid Airflow Java SDK bundle"): + BundleScanner.resolve_jar(jar) + + +class TestBundleScannerCandidateHomes: + def test_nested_layout(self, tmp_path: Path): + sub_a = tmp_path / "bundle_a" + sub_a.mkdir() + (sub_a / "app.jar").touch() + + sub_b = tmp_path / "bundle_b" + sub_b.mkdir() + (sub_b / "app.jar").touch() + + scanner = BundleScanner(tmp_path) + homes = scanner._candidate_homes() + # Nested subdirs + the bundles_dir itself + assert len(homes) == 3 + assert sub_a.resolve() in homes + assert sub_b.resolve() in homes + assert tmp_path.resolve() in homes + + def test_flat_layout(self, tmp_path: Path): + (tmp_path / "app.jar").touch() + scanner = BundleScanner(tmp_path) + homes = scanner._candidate_homes() + # Only the directory itself (no subdirectories) + assert homes == [tmp_path.resolve()] + + def test_nested_with_lib_subdir(self, tmp_path: Path): + sub = tmp_path / "my_bundle" + sub.mkdir() + lib = sub / "lib" + lib.mkdir() + (lib / "dep.jar").touch() + + scanner = BundleScanner(tmp_path) + homes = scanner._candidate_homes() + # _normalize_bundle_home should redirect to lib/ + assert lib.resolve() in homes + + +class TestBundleScannerResolve: + def test_finds_matching_dag(self, tmp_path: Path): + bundle_dir = tmp_path / "my_bundle" + bundle_dir.mkdir() + _create_bundle_jar(bundle_dir / "app.jar", dag_ids=["target_dag"]) + + scanner = BundleScanner(tmp_path) + result = scanner.resolve("target_dag") + assert isinstance(result, ResolvedJarBundle) + assert result.main_class == TEST_MAIN_CLASS + assert str((bundle_dir / "app.jar").resolve()) in result.classpath + + def test_raises_when_no_match(self, tmp_path: Path): + bundle_dir = tmp_path / "my_bundle" + bundle_dir.mkdir() + _create_bundle_jar(bundle_dir / "app.jar", dag_ids=["other_dag"]) + + scanner = BundleScanner(tmp_path) + with pytest.raises(FileNotFoundError, match="No JAR bundle containing dag_id='missing'"): + scanner.resolve("missing") + + def test_classpath_includes_all_jars(self, tmp_path: Path): + bundle_dir = tmp_path / "my_bundle" + bundle_dir.mkdir() + 
_create_bundle_jar(bundle_dir / "app.jar", dag_ids=["my_dag"]) + # Create a dependency JAR (no SDK metadata, just a plain JAR) + with zipfile.ZipFile(bundle_dir / "dep.jar", "w") as zf: + zf.writestr("dummy.class", b"") + + scanner = BundleScanner(tmp_path) + result = scanner.resolve("my_dag") + parts = result.classpath.split(os.pathsep) + assert len(parts) == 2 + + def test_flat_layout_resolve(self, tmp_path: Path): + _create_bundle_jar(tmp_path / "app.jar", dag_ids=["flat_dag"]) + + scanner = BundleScanner(tmp_path) + result = scanner.resolve("flat_dag") + assert result.main_class == TEST_MAIN_CLASS + + def test_skips_non_bundle_jars(self, tmp_path: Path): + bundle_dir = tmp_path / "my_bundle" + bundle_dir.mkdir() + # Non-bundle JAR (no manifest) + with zipfile.ZipFile(bundle_dir / "plain.jar", "w") as zf: + zf.writestr("dummy.class", b"") + _create_bundle_jar(bundle_dir / "real.jar", dag_ids=["real_dag"]) + + scanner = BundleScanner(tmp_path) + result = scanner.resolve("real_dag") + assert result.main_class == TEST_MAIN_CLASS + + def test_empty_bundles_dir(self, tmp_path: Path): + scanner = BundleScanner(tmp_path) + with pytest.raises(FileNotFoundError): + scanner.resolve("any_dag") diff --git a/providers/sdk/java/tests/unit/sdk/java/test_coordinator.py b/providers/sdk/java/tests/unit/sdk/java/test_coordinator.py new file mode 100644 index 0000000000000..d5dc053a822bf --- /dev/null +++ b/providers/sdk/java/tests/unit/sdk/java/test_coordinator.py @@ -0,0 +1,242 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import uuid +import zipfile +from pathlib import Path +from unittest.mock import patch + +import pytest +import yaml + +from airflow.providers.sdk.java.bundle_scanner import ( + MAIN_CLASS_MANIFEST_KEY, + MANIFEST_PATH, + METADATA_MANIFEST_KEY, + SDK_VERSION_MANIFEST_KEY, +) +from airflow.providers.sdk.java.coordinator import JavaCoordinator +from airflow.sdk.api.datamodels._generated import BundleInfo +from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS + +if not AIRFLOW_V_3_3_PLUS: + pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0", allow_module_level=True) + +METADATA_YAML_PATH = "META-INF/airflow-metadata.yaml" +DAG_CODE_PATH = "dag_source.py" +TEST_MAIN_CLASS = "com.example.MyBundle" + + +def _make_manifest( + *, + main_class: str | None = TEST_MAIN_CLASS, + metadata_path: str | None = METADATA_YAML_PATH, + dag_code_path: str | None = None, +) -> str: + lines = ["Manifest-Version: 1.0"] + if main_class: + lines.append(f"{MAIN_CLASS_MANIFEST_KEY}: {main_class}") + if metadata_path: + lines.append(f"{METADATA_MANIFEST_KEY}: {metadata_path}") + lines.append(f"{SDK_VERSION_MANIFEST_KEY}: 1.0.0") + if dag_code_path: + lines.append(f"Airflow-Java-SDK-Dag-Code: {dag_code_path}") + return "\n".join(lines) + "\n" + + +def _create_bundle_jar( + jar_path: Path, + *, + dag_ids: list[str] | None = None, + dag_code: str | None = None, +) -> Path: + with zipfile.ZipFile(jar_path, "w") as zf: + dag_code_path = DAG_CODE_PATH if dag_code else None + manifest = _make_manifest(dag_code_path=dag_code_path) + zf.writestr(MANIFEST_PATH, manifest) + if dag_ids is not None: + metadata = yaml.dump({"dags": {d: {} for d in dag_ids}}) + zf.writestr(METADATA_YAML_PATH, metadata) + if dag_code: + zf.writestr(DAG_CODE_PATH, dag_code) + return jar_path + + +def _make_ti(dag_id: str = "test_dag") -> TaskInstanceDTO: + return TaskInstanceDTO( + id=uuid.uuid4(), + dag_version_id=uuid.uuid4(), + task_id="task_1", + dag_id=dag_id, + run_id="run_1", + try_number=1, + map_index=-1, + pool_slots=1, + queue="default", + priority_weight=1, + ) + + +class TestJavaCoordinatorAttributes: + def test_sdk(self): + assert JavaCoordinator.sdk == "java" + + def test_file_extension(self): + assert JavaCoordinator.file_extension == ".jar" + + +class TestCanHandleDagFile: + def test_valid_jar_returns_true(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "valid.jar", dag_ids=["d"]) + assert JavaCoordinator.can_handle_dag_file("bundle", str(jar)) is True + + def test_non_jar_file_returns_false(self, tmp_path: Path): + py_file = tmp_path / "dag.py" + py_file.write_text("from airflow import DAG") + assert JavaCoordinator.can_handle_dag_file("bundle", str(py_file)) is False + + def test_missing_file_returns_false(self, tmp_path: Path): + assert JavaCoordinator.can_handle_dag_file("bundle", str(tmp_path / "missing.jar")) is False + + def test_bad_zip_returns_false(self, tmp_path: Path): + bad = tmp_path / "bad.jar" + bad.write_text("not a zip") + assert JavaCoordinator.can_handle_dag_file("bundle", str(bad)) is False + + def test_jar_without_sdk_manifest_returns_false(self, tmp_path: Path): + jar = tmp_path / "plain.jar" + with zipfile.ZipFile(jar, "w") as zf: + zf.writestr("dummy.class", b"") + assert JavaCoordinator.can_handle_dag_file("bundle", str(jar)) is False + + +class TestGetCodeFromFile: + def test_returns_embedded_code(self, tmp_path: Path): + code = "from airflow import 
DAG\ndag = DAG('my_dag')" + jar = _create_bundle_jar(tmp_path / "with_code.jar", dag_ids=["d"], dag_code=code) + assert JavaCoordinator.get_code_from_file(str(jar)) == code + + def test_raises_when_no_code(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "no_code.jar", dag_ids=["d"]) + with pytest.raises(FileNotFoundError, match="No DAG source code found in JAR"): + JavaCoordinator.get_code_from_file(str(jar)) + + +class TestDagParsingCmd: + def test_builds_java_command(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "app.jar", dag_ids=["d"]) + bundle_path = str(tmp_path) + cmd = JavaCoordinator.dag_parsing_cmd( + dag_file_path=str(jar), + bundle_name="my_bundle", + bundle_path=bundle_path, + comm_addr="localhost:1234", + logs_addr="localhost:5678", + ) + assert cmd == [ + "java", + "-classpath", + f"{bundle_path}/*", + TEST_MAIN_CLASS, + "--comm=localhost:1234", + "--logs=localhost:5678", + ] + + +class TestTaskExecutionCmd: + def test_pure_java_dag(self, tmp_path: Path): + jar = _create_bundle_jar(tmp_path / "app.jar", dag_ids=["test_dag"]) + bundle_path = str(tmp_path) + ti = _make_ti() + bundle_info = BundleInfo(name="my_bundle") + + cmd = JavaCoordinator.task_execution_cmd( + what=ti, # type: ignore[arg-type] + dag_file_path=str(jar), + bundle_path=bundle_path, + bundle_info=bundle_info, + comm_addr="localhost:1234", + logs_addr="localhost:5678", + ) + assert cmd == [ + "java", + "-classpath", + f"{bundle_path}/*", + TEST_MAIN_CLASS, + "--comm=localhost:1234", + "--logs=localhost:5678", + ] + + def test_python_stub_dag_with_bundles_folder(self, tmp_path: Path): + bundles_folder = tmp_path / "java_bundles" + bundle_sub = bundles_folder / "my_bundle" + bundle_sub.mkdir(parents=True) + _create_bundle_jar(bundle_sub / "app.jar", dag_ids=["stub_dag"]) + + ti = _make_ti(dag_id="stub_dag") + bundle_info = BundleInfo(name="my_bundle") + + with patch( + "airflow.providers.common.compat.sdk.conf.get", + return_value=str(bundles_folder), + ): + cmd = JavaCoordinator.task_execution_cmd( + what=ti, # type: ignore[arg-type] + dag_file_path="/dags/stub_dag.py", + bundle_path="/some/bundle/path", + bundle_info=bundle_info, + comm_addr="localhost:1234", + logs_addr="localhost:5678", + ) + + assert cmd == [ + "java", + "-classpath", + f"{bundles_folder}/my_bundle/app.jar", + TEST_MAIN_CLASS, + "--comm=localhost:1234", + "--logs=localhost:5678", + ] + + @pytest.mark.parametrize( + "config_value", + [ + pytest.param(None, id="none"), + pytest.param("", id="empty_string"), + ], + ) + def test_python_stub_dag_invalid_config_raises(self, config_value): + ti = _make_ti() + bundle_info = BundleInfo(name="my_bundle") + + with patch( + "airflow.providers.common.compat.sdk.conf.get", + return_value=config_value, + ): + with pytest.raises(ValueError, match="bundles_folder config must be set"): + JavaCoordinator.task_execution_cmd( + what=ti, # type: ignore[arg-type] + dag_file_path="/dags/stub_dag.py", + bundle_path="/some/bundle/path", + bundle_info=bundle_info, + comm_addr="localhost:1234", + logs_addr="localhost:5678", + ) diff --git a/providers/sdk/java/tests/unit/sdk/java/test_java_provider.py b/providers/sdk/java/tests/unit/sdk/java/test_java_provider.py new file mode 100644 index 0000000000000..e0489ada7cc17 --- /dev/null +++ b/providers/sdk/java/tests/unit/sdk/java/test_java_provider.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.sdk.java.coordinator import JavaCoordinator +from airflow.providers.sdk.java.get_provider_info import get_provider_info + +from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS + +if not AIRFLOW_V_3_3_PLUS: + pytest.skip("Coordinator is only compatible with Airflow >= 3.3.0", allow_module_level=True) + + +def test_get_provider_info_exposes_java_runtime_components(): + assert get_provider_info() == { + "package-name": "apache-airflow-providers-sdk-java", + "name": "SDK: Java", + "description": "Java SDK support for Apache Airflow runtime coordinators.\n", + "integrations": [ + {"integration-name": "Java", "external-doc-url": "https://openjdk.org/", "tags": ["software"]} + ], + "config": { + "java": { + "description": "Options for the Java SDK provider.", + "options": { + "bundles_folder": { + "description": "Path to the directory containing Java DAG bundle JARs.\nWhen using Python stub DAGs that delegate task execution to Java,\nthe coordinator scans this directory to find the JAR bundle matching\nthe target dag_id. Each immediate subdirectory is treated as a\nseparate bundle home, and the directory itself is also checked\n(flat layout).\n", + "type": "string", + "version_added": None, + "example": "~/airflow/java-bundles", + "default": "", + } + }, + } + }, + "coordinators": ["airflow.providers.sdk.java.coordinator.JavaCoordinator"], + } + + +def test_java_provider_entrypoints_are_importable(): + assert JavaCoordinator.sdk == "java" diff --git a/providers/standard/src/airflow/providers/standard/decorators/stub.py b/providers/standard/src/airflow/providers/standard/decorators/stub.py index f29d123c740c1..a5e63d925f795 100644 --- a/providers/standard/src/airflow/providers/standard/decorators/stub.py +++ b/providers/standard/src/airflow/providers/standard/decorators/stub.py @@ -85,7 +85,6 @@ def stub( Stub tasks exist in the Dag graph only, but the execution must happen in an external environment via the Task Execution Interface. 
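+
+    A minimal sketch of a stub task (illustrative only: it assumes the decorator
+    is exposed as ``@task.stub`` and that the target queue is mapped to a
+    runtime coordinator via ``[sdk] queue_to_sdk``; adapt names to your setup)::
+
+        from airflow.sdk import dag, task
+
+        @dag(schedule=None)
+        def delegated_dag():
+            @task.stub(queue="java-11")
+            def run_in_jvm(): ...
+
+            run_in_jvm()
+
+        delegated_dag()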
- """ return task_decorator_factory( decorated_operator_class=_StubOperator, diff --git a/pyproject.toml b/pyproject.toml index 4a988db2d3ea7..2c733873cd796 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -338,6 +338,9 @@ apache-airflow = "airflow.__main__:main" "samba" = [ "apache-airflow-providers-samba>=4.9.0" ] +"sdk.java" = [ + "apache-airflow-providers-sdk-java>=0.1.0" +] "segment" = [ "apache-airflow-providers-segment>=3.7.0" ] @@ -481,6 +484,7 @@ apache-airflow = "airflow.__main__:main" "apache-airflow-providers-redis>=4.0.0", "apache-airflow-providers-salesforce>=5.9.0", "apache-airflow-providers-samba>=4.9.0", + "apache-airflow-providers-sdk-java>=0.1.0", "apache-airflow-providers-segment>=3.7.0", "apache-airflow-providers-sendgrid>=4.0.0", "apache-airflow-providers-sftp>=5.0.0", @@ -1218,6 +1222,8 @@ mypy_path = [ "$MYPY_CONFIG_FILE_DIR/providers/salesforce/tests", "$MYPY_CONFIG_FILE_DIR/providers/samba/src", "$MYPY_CONFIG_FILE_DIR/providers/samba/tests", + "$MYPY_CONFIG_FILE_DIR/providers/sdk/java/src", + "$MYPY_CONFIG_FILE_DIR/providers/sdk/java/tests", "$MYPY_CONFIG_FILE_DIR/providers/segment/src", "$MYPY_CONFIG_FILE_DIR/providers/segment/tests", "$MYPY_CONFIG_FILE_DIR/providers/sendgrid/src", @@ -1461,6 +1467,7 @@ apache-airflow-providers-qdrant = false apache-airflow-providers-redis = false apache-airflow-providers-salesforce = false apache-airflow-providers-samba = false +apache-airflow-providers-sdk-java = false apache-airflow-providers-segment = false apache-airflow-providers-sendgrid = false apache-airflow-providers-sftp = false @@ -1612,6 +1619,7 @@ apache-airflow-providers-qdrant = false apache-airflow-providers-redis = false apache-airflow-providers-salesforce = false apache-airflow-providers-samba = false +apache-airflow-providers-sdk-java = false apache-airflow-providers-segment = false apache-airflow-providers-sendgrid = false apache-airflow-providers-sftp = false @@ -1773,6 +1781,7 @@ apache-airflow-providers-qdrant = { workspace = true } apache-airflow-providers-redis = { workspace = true } apache-airflow-providers-salesforce = { workspace = true } apache-airflow-providers-samba = { workspace = true } +apache-airflow-providers-sdk-java = { workspace = true } apache-airflow-providers-segment = { workspace = true } apache-airflow-providers-sendgrid = { workspace = true } apache-airflow-providers-sftp = { workspace = true } @@ -1910,6 +1919,7 @@ members = [ "providers/redis", "providers/salesforce", "providers/samba", + "providers/sdk/java", "providers/segment", "providers/sendgrid", "providers/sftp", diff --git a/scripts/ci/docker-compose/remove-sources.yml b/scripts/ci/docker-compose/remove-sources.yml index a2f7d3a035766..24ca15bbb0c47 100644 --- a/scripts/ci/docker-compose/remove-sources.yml +++ b/scripts/ci/docker-compose/remove-sources.yml @@ -107,6 +107,7 @@ services: - ../../../empty:/opt/airflow/providers/redis/src - ../../../empty:/opt/airflow/providers/salesforce/src - ../../../empty:/opt/airflow/providers/samba/src + - ../../../empty:/opt/airflow/providers/sdk/java/src - ../../../empty:/opt/airflow/providers/segment/src - ../../../empty:/opt/airflow/providers/sendgrid/src - ../../../empty:/opt/airflow/providers/sftp/src diff --git a/scripts/ci/docker-compose/tests-sources.yml b/scripts/ci/docker-compose/tests-sources.yml index 9c02d1c271412..de736d60237ae 100644 --- a/scripts/ci/docker-compose/tests-sources.yml +++ b/scripts/ci/docker-compose/tests-sources.yml @@ -120,6 +120,7 @@ services: - 
../../../providers/redis/tests:/opt/airflow/providers/redis/tests - ../../../providers/salesforce/tests:/opt/airflow/providers/salesforce/tests - ../../../providers/samba/tests:/opt/airflow/providers/samba/tests + - ../../../providers/sdk/java/tests:/opt/airflow/providers/sdk/java/tests - ../../../providers/segment/tests:/opt/airflow/providers/segment/tests - ../../../providers/sendgrid/tests:/opt/airflow/providers/sendgrid/tests - ../../../providers/sftp/tests:/opt/airflow/providers/sftp/tests diff --git a/scripts/ci/prek/check_task_instance_dto_sync.py b/scripts/ci/prek/check_task_instance_dto_sync.py new file mode 100755 index 0000000000000..689d35a4d15e3 --- /dev/null +++ b/scripts/ci/prek/check_task_instance_dto_sync.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Verify that the duplicate ``BaseTaskInstanceDTO`` definitions in airflow-core +and task-sdk stay structurally identical. + +``BaseTaskInstanceDTO`` is duplicated (not shared) in: + +- ``airflow-core/src/airflow/executors/workloads/task.py`` +- ``task-sdk/src/airflow/sdk/execution_time/workloads/task.py`` + +This hook compares the *fields* (annotated assignments) and bases of both +``BaseTaskInstanceDTO`` classes. The concrete ``TaskInstanceDTO`` subclasses +in each file are allowed to differ (airflow-core adds an executor-specific +``key`` property that depends on ``airflow.models``, which the Task SDK +does not have access to). 
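+
+As a rough illustration of the comparison: a field declared as
+``map_index: int = -1`` is normalized (via ``ast.unparse``) into the tuple
+``("map_index", "int", "-1")``; the hook fails when these tuples (including
+their order) or the class bases differ between the two files.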
+""" + +from __future__ import annotations + +import ast +import sys +from pathlib import Path + +AIRFLOW_ROOT = Path(__file__).parents[3].resolve() +CORE_FILE = AIRFLOW_ROOT / "airflow-core" / "src" / "airflow" / "executors" / "workloads" / "task.py" +SDK_FILE = AIRFLOW_ROOT / "task-sdk" / "src" / "airflow" / "sdk" / "execution_time" / "workloads" / "task.py" +CLASS_NAME = "BaseTaskInstanceDTO" + + +def _find_class(tree: ast.AST, class_name: str) -> ast.ClassDef | None: + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + return node + return None + + +def _field_signature(class_node: ast.ClassDef) -> list[tuple[str, str, str | None]]: + """Return a normalized list of ``(name, annotation, default)`` for each field.""" + fields: list[tuple[str, str, str | None]] = [] + for stmt in class_node.body: + if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name): + name = stmt.target.id + annotation = ast.unparse(stmt.annotation) + default = ast.unparse(stmt.value) if stmt.value is not None else None + fields.append((name, annotation, default)) + return fields + + +def _bases(class_node: ast.ClassDef) -> list[str]: + return [ast.unparse(base) for base in class_node.bases] + + +def _extract(file_path: Path) -> tuple[list[str], list[tuple[str, str, str | None]]]: + source = file_path.read_text() + tree = ast.parse(source, filename=str(file_path)) + class_node = _find_class(tree, CLASS_NAME) + if class_node is None: + print(f"ERROR: Could not find class {CLASS_NAME} in {file_path}", file=sys.stderr) + sys.exit(1) + return _bases(class_node), _field_signature(class_node) + + +def main() -> None: + core_bases, core_fields = _extract(CORE_FILE) + sdk_bases, sdk_fields = _extract(SDK_FILE) + + if core_bases == sdk_bases and core_fields == sdk_fields: + sys.exit(0) + + print( + f"\nERROR: {CLASS_NAME} definitions in airflow-core and task-sdk are out of sync!", + file=sys.stderr, + ) + print(f"\n airflow-core: {CORE_FILE.relative_to(AIRFLOW_ROOT)}", file=sys.stderr) + print(f" task-sdk: {SDK_FILE.relative_to(AIRFLOW_ROOT)}", file=sys.stderr) + + if core_bases != sdk_bases: + print("\nClass bases differ:", file=sys.stderr) + print(f" airflow-core: {core_bases}", file=sys.stderr) + print(f" task-sdk: {sdk_bases}", file=sys.stderr) + + if core_fields != sdk_fields: + core_set = {f[0]: f for f in core_fields} + sdk_set = {f[0]: f for f in sdk_fields} + only_in_core = sorted(set(core_set) - set(sdk_set)) + only_in_sdk = sorted(set(sdk_set) - set(core_set)) + differing = sorted(name for name in set(core_set) & set(sdk_set) if core_set[name] != sdk_set[name]) + if only_in_core: + print(f"\n Fields only in airflow-core: {only_in_core}", file=sys.stderr) + if only_in_sdk: + print(f"\n Fields only in task-sdk: {only_in_sdk}", file=sys.stderr) + for name in differing: + print( + f"\n Field {name!r} differs:" + f"\n airflow-core: {core_set[name]}" + f"\n task-sdk: {sdk_set[name]}", + file=sys.stderr, + ) + + print( + f"\nUpdate both files together so the two {CLASS_NAME} definitions stay in sync.", + file=sys.stderr, + ) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/in_container/install_airflow_and_providers.py b/scripts/in_container/install_airflow_and_providers.py index c8223f3eeff10..84847ea3e3041 100755 --- a/scripts/in_container/install_airflow_and_providers.py +++ b/scripts/in_container/install_airflow_and_providers.py @@ -1064,6 +1064,7 @@ def install_airflow_and_providers( "apache-airflow-providers-common-messaging", 
"apache-airflow-providers-git", "apache-airflow-providers-edge3", + "apache-airflow-providers-sdk-java", ] run_command( ["uv", "pip", "uninstall", *providers_to_uninstall_for_airflow_2], diff --git a/scripts/in_container/java_sdk_setup.sh b/scripts/in_container/java_sdk_setup.sh new file mode 100644 index 0000000000000..b3437b7fc4200 --- /dev/null +++ b/scripts/in_container/java_sdk_setup.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + + +# 1. Check Java +check_java() { + local java_bin="/files/openjdk/bin/java" + local version_output + + # First check if the locally installed OpenJDK exists and works. + if [ -x "$java_bin" ] && version_output=$("$java_bin" -version 2>&1); then + echo "Found existing OpenJDK at $java_bin. OK." + return + fi + + # On macOS, /usr/bin/java exists as a shim even without a JDK installed, + # so we must test with `java -version` directly. + if ! version_output=$(java -version 2>&1); then + echo "Java is not installed." + install_java + return + fi + + local java_version + java_version=$(echo "$version_output" | head -n1 | sed -E 's/.*"([0-9]+)(\.[0-9]+)*.*/\1/') + + if ! [[ "$java_version" =~ ^[0-9]+$ ]]; then + echo "Could not determine Java version." + install_java + return + fi + + if [ "$java_version" -ge 11 ]; then + echo "Java $java_version detected. OK." + else + echo "Java version $java_version found, but >= 11 is required." + install_java + fi +} + + +install_java() { + echo "Installing OpenJDK 11 in Breeze..." 
+ + curl -L -o /files/openjdk-11-aarch64.tar.gz \ + https://github.com/adoptium/temurin11-binaries/releases/download/jdk-11.0.30+7/OpenJDK11U-jdk_aarch64_linux_hotspot_11.0.30_7.tar.gz + + rm -rf /files/openjdk && mkdir -p /files/openjdk && \ + tar -xzf /files/openjdk-11-aarch64.tar.gz --strip-components=1 -C /files/openjdk + + /files/openjdk/bin/java -version + echo "" +} + +check_java +# Install Java Provider +pip install -e /opt/airflow/providers/languages/java/ diff --git a/task-sdk/.pre-commit-config.yaml b/task-sdk/.pre-commit-config.yaml index 100a6e6490849..c1d4498a623fc 100644 --- a/task-sdk/.pre-commit-config.yaml +++ b/task-sdk/.pre-commit-config.yaml @@ -43,6 +43,7 @@ repos: ^src/airflow/sdk/definitions/deadline\.py$| ^src/airflow/sdk/definitions/dag\.py$| ^src/airflow/sdk/definitions/_internal/types\.py$| + ^src/airflow/sdk/execution_time/coordinator\.py$| ^src/airflow/sdk/execution_time/execute_workload\.py$| ^src/airflow/sdk/execution_time/secrets_masker\.py$| ^src/airflow/sdk/execution_time/callback_supervisor\.py$| diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 87c7881333ad4..2268a0d1ad486 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -79,7 +79,6 @@ PreviousTIResponse, PrevSuccessfulDagRunResponse, TaskBreadcrumbsResponse, - TaskInstance, TaskInstanceState, TaskStatesResponse, TIDeferredStatePayload, @@ -96,6 +95,9 @@ XComSequenceSliceResponse, ) from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time.workloads.task import ( + TaskInstanceDTO, # noqa: TC001 -- Pydantic needs this at runtime +) try: from socket import recv_fds @@ -316,7 +318,7 @@ def _get_response(self) -> ReceiveMsgType | None: class StartupDetails(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - ti: TaskInstance + ti: TaskInstanceDTO dag_rel_path: str bundle_info: BundleInfo start_date: datetime diff --git a/task-sdk/src/airflow/sdk/execution_time/coordinator.py b/task-sdk/src/airflow/sdk/execution_time/coordinator.py new file mode 100644 index 0000000000000..5f411e643ed5f --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/coordinator.py @@ -0,0 +1,462 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Runtime coordinator for non-Python DAG file processing and task execution. + +Provides :class:`BaseCoordinator`, the base class for +SDK-specific coordinators that bridge subprocess I/O between the +Airflow supervisor and an external-SDK runtime (Java, Go, Rust, etc.). + +The coordinator's :meth:`~BaseCoordinator.run_dag_parsing` method +handles the full lifecycle: + +1. Creates TCP servers for comm and logs channels. +2. 
Calls :meth:`~BaseCoordinator.dag_parsing_cmd` (provided + by the subclass) to obtain the subprocess command. +3. Spawns the subprocess and accepts TCP connections from it. +4. Runs a selector-based bridge that transparently forwards bytes + between fd 0 (supervisor) and the subprocess comm socket, and + re-emits the subprocess's log output through structlog. + +I/O multiplexing uses the same selector-based loop as +:class:`~airflow.sdk.execution_time.supervisor.WatchedSubprocess`, +driven by :func:`~airflow.sdk.execution_time.selector_loop.service_selector`. +""" + +from __future__ import annotations + +import contextlib +import os +import selectors +import socket +import subprocess +import time +from typing import TYPE_CHECKING, NamedTuple + +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger + from typing_extensions import Self + + from airflow.sdk.api.datamodels._generated import BundleInfo + from airflow.sdk.execution_time.comms import StartupDetails + from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO + + +def _start_server() -> socket.socket: + """Create a TCP server socket bound to a random port on localhost.""" + server = socket.socket() + server.bind(("127.0.0.1", 0)) + server.setblocking(True) + server.listen(1) + return server + + +def _send_startup_details(runtime_comm: socket.socket, startup_details: StartupDetails) -> None: + """ + Re-encode and send the ``StartupDetails`` frame to the runtime subprocess. + + In the task execution flow, ``task_runner.main()`` consumes the + ``StartupDetails`` message from fd 0 (to determine routing) before + delegating to the runtime coordinator. This function re-serializes + the message and writes it to the runtime subprocess's comm socket so + the subprocess receives it as if it came directly from the supervisor. + """ + from airflow.sdk.execution_time.comms import _ResponseFrame + + # Use mode="json" so that datetime, UUID, and other complex Python + # types are serialized as plain strings/numbers in msgpack — avoiding + # msgpack extension types (e.g. Timestamp) that non-Python decoders + # may not support. + frame = _ResponseFrame(id=0, body=startup_details.model_dump(mode="json")) + runtime_comm.sendall(frame.as_bytes()) + + +def _bridge( + supervisor_comm: socket.socket, + runtime_comm: socket.socket, + runtime_logs: socket.socket, + runtime_stderr: socket.socket, + proc: subprocess.Popen, + log: FilteringBoundLogger, +) -> None: + """ + Multiplex I/O between the supervisor and a runtime subprocess. + + Four channels are registered with the selector: + + - ``supervisor_comm`` -> ``runtime_comm`` (raw byte forwarding) + - ``runtime_comm`` -> ``supervisor_comm`` (raw byte forwarding) + - ``runtime_logs`` -> structlog (line-buffered JSON logs) + - ``runtime_stderr`` -> structlog (line-buffered stderr output) + + Uses the same ``(handler, on_close)`` callback contract as + :class:`~airflow.sdk.execution_time.supervisor.WatchedSubprocess`, + driven by :func:`~airflow.sdk.execution_time.selector_loop.service_selector`. + """ + from airflow.sdk.execution_time.selector_loop import ( + make_buffered_socket_reader, + make_raw_forwarder, + service_selector, + ) + from airflow.sdk.execution_time.supervisor import ( + forward_to_log, + process_log_messages_from_subprocess, + ) + + sel = selectors.DefaultSelector() + + def on_close(sock: socket.socket) -> None: + with contextlib.suppress(KeyError): + sel.unregister(sock) + + target_loggers = (log,) + + # Comm: bidirectional raw byte forwarding. 
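+    # Each ``register`` call below stores the ``(handler, on_close)`` tuple
+    # returned by the ``make_*`` factory as ``key.data``; ``service_selector``
+    # unpacks it and invokes the handler whenever the socket is readable.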
+ sel.register(supervisor_comm, selectors.EVENT_READ, make_raw_forwarder(runtime_comm, on_close)) + sel.register(runtime_comm, selectors.EVENT_READ, make_raw_forwarder(supervisor_comm, on_close)) + + # TCP logs channel: line-buffered JSON from the runtime SDK's LogSender, + # processed with the same handler as WatchedSubprocess (level mapping, + # timestamp parsing, exception extraction). + sel.register( + runtime_logs, + selectors.EVENT_READ, + make_buffered_socket_reader(process_log_messages_from_subprocess(target_loggers), on_close), + ) + # stderr: plain-text output from the runtime process's logging framework + # (e.g. SLF4J simple logger). Use forward_to_log which handles raw + # text lines, not process_log_messages_from_subprocess which expects JSON. + import logging + + sel.register( + runtime_stderr, + selectors.EVENT_READ, + make_buffered_socket_reader( + forward_to_log(target_loggers, logger="task.stderr", level=logging.ERROR), on_close + ), + ) + + # Event loop -- runs until the subprocess exits and all sockets are drained. + while sel.get_map(): + service_selector(sel, timeout=1.0) + if proc.poll() is not None: + # Subprocess has exited -- drain remaining data with a short deadline. + deadline = time.monotonic() + 5.0 + while sel.get_map() and time.monotonic() < deadline: + service_selector(sel, timeout=0.5) + break + + sel.close() + for sock in (supervisor_comm, runtime_comm, runtime_logs, runtime_stderr): + with contextlib.suppress(OSError): + sock.close() + + +class BaseCoordinator: + """ + Base coordinator for runtime-specific DAG file processing and task execution. + + Providers register subclasses in their ``provider.yaml`` under + ``coordinators``. Both :class:`ProvidersManager` (airflow-core) + and :class:`ProvidersManagerTaskRuntime` (task-sdk) discover registered + coordinators through this single extension point. + + Subclasses represent a specific SDK runtime (Java, Go, etc.) and + only need to implement :meth:`can_handle_dag_file`, + :meth:`dag_parsing_cmd` and :meth:`task_execution_cmd`. + The base class owns the entire bridge lifecycle: TCP servers, + subprocess management, selector-based I/O loop, and cleanup. + """ + + sdk: str + file_extension: str + + class DagParsingInfo(NamedTuple): + """Information needed for runtime Dag parsing.""" + + dag_file_path: str + bundle_name: str + bundle_path: str + mode: str = "dag-parsing" + + class TaskExecutionInfo(NamedTuple): + """Information needed for runtime task execution.""" + + what: TaskInstanceDTO + dag_rel_path: str | os.PathLike[str] + bundle_info: BundleInfo + startup_details: StartupDetails + mode: str = "task-execution" + + @classmethod + def can_handle_dag_file(cls, bundle_name: str, path: str | os.PathLike[str]) -> bool: + """ + Return ``True`` if this coordinator should handle DAG-file parsing for *path*. + + Called by :meth:`DagFileProcessorProcess._resolve_processor_target` to + decide whether to delegate parsing to this coordinator's + :meth:`run_dag_parsing` instead of the default Python entrypoint. + + The default implementation returns ``False``; subclasses must override. + """ + return False + + @classmethod + def get_code_from_file(cls, fileloc: str) -> str: + """ + Return the human-readable source code for a DAG file managed by this coordinator. + + Called by :class:`~airflow.models.dagcode.DagCode` when persisting DAG + source to the metadata database. 
The default Python path reads ``.py`` + files directly; runtime coordinators must override this to extract source + from their native packaging format (e.g. reading an embedded ``.java`` + file from a JAR bundle). + + :param fileloc: Absolute path to the DAG file (e.g. a ``/path/to/example.jar``). + :return: The source code as a string. + :raises FileNotFoundError: If source code cannot be retrieved from *fileloc*. + """ + raise NotImplementedError + + @classmethod + def dag_parsing_cmd( + cls, + *, + dag_file_path: str, + bundle_name: str, + bundle_path: str, + comm_addr: str, + logs_addr: str, + ) -> list[str]: + """ + Return the subprocess command for DAG file parsing. + + :param dag_file_path: Absolute path to the DAG file to parse. + :param bundle_name: Name of the DAG bundle. + :param bundle_path: Root path of the DAG bundle. + :param comm_addr: ``host:port`` the subprocess must connect to + for the bidirectional msgpack comm channel. + :param logs_addr: ``host:port`` the subprocess must connect to + for the structured JSON log channel. + :returns: Full command list (e.g. ``["java", "-cp", "...", ...]`` based on each runtime). + """ + raise NotImplementedError + + @classmethod + def task_execution_cmd( + cls, + *, + what: TaskInstanceDTO, + dag_file_path: str, + bundle_path: str, + bundle_info: BundleInfo, + comm_addr: str, + logs_addr: str, + ) -> list[str]: + """ + Return the subprocess command for task execution. + + :param what: The task instance to execute. + :param dag_file_path: Absolute path to the DAG file. + :param bundle_path: Root path of the DAG bundle. + :param bundle_info: Bundle metadata. + :param comm_addr: ``host:port`` the subprocess must connect to + for the bidirectional msgpack comm channel. + :param logs_addr: ``host:port`` the subprocess must connect to + for the structured JSON log channel. + :returns: Full command list. + """ + raise NotImplementedError + + @classmethod + def run_dag_parsing(cls, *, path: str, bundle_name: str, bundle_path: str) -> None: + """Entry point for running runtime-specific Dag File Processing.""" + cls._runtime_subprocess_entrypoint( + cls.DagParsingInfo( + dag_file_path=path, + bundle_name=bundle_name, + bundle_path=bundle_path, + ) + ) + + @classmethod + def run_task_execution( + cls, + *, + what: TaskInstanceDTO, + dag_rel_path: str | os.PathLike[str], + bundle_info: BundleInfo, + startup_details: StartupDetails, + ) -> None: + cls._runtime_subprocess_entrypoint( + cls.TaskExecutionInfo( + what=what, + dag_rel_path=dag_rel_path, + bundle_info=bundle_info, + startup_details=startup_details, + ) + ) + + @classmethod + def _runtime_subprocess_entrypoint(cls, entrypoint_info: DagParsingInfo | TaskExecutionInfo) -> None: + """ + Spawn the runtime subprocess and bridge I/O with the supervisor. + + This is called inside the forked child process where fd 0 is the + bidirectional comms socket to the supervisor. The method: + + 1. Creates TCP servers for comm and logs. + 2. Calls :meth:`dag_parsing_cmd` or :meth:`task_execution_cmd` to get the command. + 3. Spawns the subprocess with ``stdin=/dev/null`` and stderr + captured via a socketpair. + 4. Runs the selector-based bridge until the subprocess exits. 
+ + fd layout (set up by ``_reopen_std_io_handles`` before this runs): + + - fd 0 -- bidirectional comms socket to the supervisor + (``DagFileParseRequest`` <-> ``DagFileParsingResult``, + length-prefixed msgpack frames) + - fd 1 -- stdout socket to the supervisor + - fd 2 -- stderr socket to the supervisor + - fd N -- structured JSON log channel (``log_fd``, configured by + ``_configure_logs_over_json_channel`` -> structlog) + """ + os.environ["_AIRFLOW_PROCESS_CONTEXT"] = "client" + + import structlog + + log = structlog.get_logger(logger_name="task") + log.info( + "Starting runtime subprocess", + sdk=cls.sdk, + mode=entrypoint_info.mode, + ) + + # TCP servers for the runtime subprocess to connect to. + comm_server = _start_server() + logs_server = _start_server() + comm_host, comm_port = comm_server.getsockname() + logs_host, logs_port = logs_server.getsockname() + + comm_addr = f"{comm_host}:{comm_port}" + logs_addr = f"{logs_host}:{logs_port}" + + # stderr uses a socketpair (instead of ``subprocess.PIPE``) so it + # is a real socket compatible with ``make_buffered_socket_reader``. + child_stderr, read_stderr = socket.socketpair() + + # For task execution, hold a BundleVersionLock for the entire + # subprocess lifetime to prevent the bundle version from being + # garbage-collected while the runtime process is still running. + bundle_version_lock: contextlib.AbstractContextManager = contextlib.nullcontext() + + if isinstance(entrypoint_info, cls.DagParsingInfo): + cmd = cls.dag_parsing_cmd( + dag_file_path=entrypoint_info.dag_file_path, + bundle_name=entrypoint_info.bundle_name, + bundle_path=entrypoint_info.bundle_path, + comm_addr=comm_addr, + logs_addr=logs_addr, + ) + elif isinstance(entrypoint_info, cls.TaskExecutionInfo): + from airflow.dag_processing.bundles.base import BundleVersionLock + from airflow.sdk.execution_time.task_runner import resolve_bundle + + bundle_instance = resolve_bundle(entrypoint_info.bundle_info, log) + resolved_dag_file_path = bundle_instance.path / entrypoint_info.dag_rel_path + + cmd = cls.task_execution_cmd( + what=entrypoint_info.what, + dag_file_path=os.fspath(resolved_dag_file_path), + bundle_path=os.fspath(bundle_instance.path), + bundle_info=entrypoint_info.bundle_info, + comm_addr=comm_addr, + logs_addr=logs_addr, + ) + bundle_version_lock = BundleVersionLock( + bundle_name=entrypoint_info.bundle_info.name, + bundle_version=entrypoint_info.bundle_info.version, + ) + else: + raise ValueError(f"Unknown entrypoint_info type: {type(entrypoint_info)}") + + with bundle_version_lock: + # stdin redirected to /dev/null so the subprocess does not inherit + # fd 0 (the comms socket). + proc = subprocess.Popen( + cmd, + stdin=subprocess.DEVNULL, + stderr=child_stderr.fileno(), + ) + child_stderr.close() + + # Wait for the subprocess to connect to both servers. + runtime_comm, _ = comm_server.accept() + runtime_logs, _ = logs_server.accept() + comm_server.close() + logs_server.close() + + # For task execution the supervisor already sent ``StartupDetails`` + # on fd 0 and ``task_runner.main()`` consumed it before delegating + # here. Re-encode and forward it to the runtime subprocess so it + # knows which task to execute. + if isinstance(entrypoint_info, cls.TaskExecutionInfo): + _send_startup_details(runtime_comm, entrypoint_info.startup_details) + + # fd 0 is the bidirectional comms socket to the supervisor. 
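+            # ``os.dup(0)`` gives the wrapper socket its own descriptor, so
+            # closing ``supervisor_comm`` when the bridge tears down does not
+            # close the original fd 0 inherited from the supervisor.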
+            supervisor_comm = socket.socket(fileno=os.dup(0))
+
+            _bridge(supervisor_comm, runtime_comm, runtime_logs, read_stderr, proc, log)
+
+
+class QueueToCoordinatorMapper:
+    """
+    Map queue names to coordinator names.
+
+    Users often use queues as environment/isolation identifiers (e.g. ``"java-11"``,
+    ``"java-12"``). This mapper lets them reuse existing queue assignments to route
+    tasks to the correct coordinator.
+
+    The mapping is read from the ``[sdk] queue_to_sdk``
+    configuration option, which is a JSON dict of ``queue -> sdk``.
+
+    Example configuration::
+
+        [sdk]
+        queue_to_sdk = {"java-11": "java", "java-12": "java"}
+    """
+
+    def __init__(self, mapping: dict[str, str]) -> None:
+        self._mapping = mapping
+
+    @classmethod
+    def from_config(cls) -> Self:
+        """Load the queue-to-runtime mapping from airflow configuration."""
+        from airflow.sdk.configuration import conf
+
+        mapping = conf.getjson("sdk", "queue_to_sdk", fallback={})
+        if not isinstance(mapping, dict):
+            return cls({})
+        return cls(mapping)
+
+    def resolve(self, queue: str) -> str | None:
+        """Return the runtime coordinator name for *queue*, or ``None`` if unmapped."""
+        return self._mapping.get(queue)
+
+
+__all__ = ["BaseCoordinator", "QueueToCoordinatorMapper"]
diff --git a/task-sdk/src/airflow/sdk/execution_time/selector_loop.py b/task-sdk/src/airflow/sdk/execution_time/selector_loop.py
new file mode 100644
index 0000000000000..d67014ad1b418
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/execution_time/selector_loop.py
@@ -0,0 +1,159 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Selector-based I/O loop utilities shared across subprocess monitors.
+
+Both :class:`~airflow.sdk.execution_time.supervisor.WatchedSubprocess`
+(supervisor-side) and provider-registered bridges such as the
+:class:`~airflow.sdk.execution_time.coordinator.BaseCoordinator` bridge
+(child-side) use these building blocks to multiplex socket I/O without threads.
+
+The common contract for every callback registered with the selector:
+
+* The selector stores a ``(handler, on_close)`` tuple as ``key.data``.
+* ``handler(fileobj) -> bool`` — read available data and return
+  ``True`` to keep listening, ``False`` on EOF / error.
+* ``on_close(fileobj)`` — called when the handler returns ``False``;
+  must unregister the fileobj from the selector.
+* :func:`service_selector` drives one iteration of this protocol.
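+
+A minimal sketch of how the pieces fit together (mirroring what the coordinator
+bridge does; ``src`` and ``dst`` stand in for already-connected sockets)::
+
+    sel = selectors.DefaultSelector()
+
+    def on_close(sock):
+        sel.unregister(sock)
+
+    # make_raw_forwarder returns the (handler, on_close) tuple that
+    # service_selector expects to find in ``key.data``.
+    sel.register(src, selectors.EVENT_READ, make_raw_forwarder(dst, on_close))
+
+    while sel.get_map():
+        service_selector(sel, timeout=1.0)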
+""" + +from __future__ import annotations + +import selectors +from contextlib import suppress +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + from socket import socket + + # (handler, on_close) — stored as ``selector.register(..., data=cb)`` + SelectorCallback = tuple[Callable[[socket], bool], Callable[[socket], None]] + + +# Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read +# and it doesn't contain a new line character, `.readline()` will just return the chunk as is. +# +# This returns a callback suitable for attaching to a `selector` that reads in to a buffer, and yields lines +# to a (sync) generator +def make_buffered_socket_reader( + gen: Generator[None, bytes | bytearray, None], + on_close: Callable[[socket], None], + buffer_size: int = 4096, +) -> SelectorCallback: + """ + Create a selector callback that line-buffers socket data into a generator. + + Bytes are accumulated until a newline is found; each + complete line is sent to *gen* via ``gen.send(line)``. On EOF the + remainder of the buffer (if any) is flushed. + + Returns a ``(handler, on_close)`` tuple suitable for + ``selector.register(..., data=...)``. + """ + buffer = bytearray() # This will hold our accumulated binary data + read_buffer = bytearray(buffer_size) # Temporary buffer for each read + + # We need to start up the generator to get it to the point it's at waiting on the yield + next(gen) + + def cb(sock: socket): + nonlocal buffer, read_buffer + # Read up to `buffer_size` bytes of data from the socket + n_received = sock.recv_into(read_buffer) + + if not n_received: + # If no data is returned, the connection is closed. Return whatever is left in the buffer + if len(buffer): + with suppress(StopIteration): + gen.send(buffer) + return False + + buffer.extend(read_buffer[:n_received]) + + # We could have read multiple lines in one go, yield them all + while (newline_pos := buffer.find(b"\n")) != -1: + line = buffer[: newline_pos + 1] + try: + gen.send(line) + except StopIteration: + return False + buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data + + return True + + return cb, on_close + + +def make_raw_forwarder( + dest: socket, + on_close: Callable[[socket], None], +) -> SelectorCallback: + """ + Create a selector callback that forwards raw bytes to *dest*. + + Used for transparent protocol bridges where bytes must be shuttled + between two sockets without interpretation (e.g. length-prefixed + msgpack frames between a supervisor and a Java subprocess). + """ + + def cb(sock: socket) -> bool: + data = sock.recv(65536) + if not data: + return False + try: + dest.sendall(data) + except (BrokenPipeError, ConnectionResetError, OSError): + return False + return True + + return cb, on_close + + +def service_selector(selector: selectors.BaseSelector, timeout: float = 1.0) -> None: + """ + Process one round of selector events. + + For each ready socket whose handler returns ``False`` (EOF / error), + the socket's *on_close* callback is invoked and the socket is closed. + """ + # Ensure minimum timeout to prevent CPU spike with tight loop when timeout is 0 or negative + timeout = max(0.01, timeout) + events = selector.select(timeout=timeout) + for key, _ in events: + # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr) + socket_handler, on_close = key.data + + # Example of handler behavior: + # If the subprocess writes "Hello, World!" 
to stdout: + # - `socket_handler` reads and processes the message. + # - If EOF is reached, the handler returns False to signal no more reads are expected. + # - BrokenPipeError should be caught and treated as if the handler returned false, similar + # to EOF case + try: + need_more = socket_handler(key.fileobj) + except (BrokenPipeError, ConnectionResetError): + need_more = False + + # If the handler signals that the file object is no longer needed (EOF, closed, etc.) + # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors + # are removed. + if not need_more: + sock: socket = key.fileobj # type: ignore[assignment] + on_close(sock) + sock.close() diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 375c5a9e30b8e..9913a2c19f60b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -126,6 +126,7 @@ handle_get_variable, handle_mask_secret, ) +from airflow.sdk.execution_time.selector_loop import make_buffered_socket_reader, service_selector try: from socket import send_fds @@ -139,6 +140,8 @@ from airflow.executors.workloads import BundleInfo from airflow.sdk.bases.secrets_backend import BaseSecretsBackend from airflow.sdk.definitions.connection import Connection + from airflow.sdk.execution_time.selector_loop import SelectorCallback + from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI __all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise", "supervise_task"] @@ -690,7 +693,7 @@ def _get_target_loggers(self) -> tuple[FilteringBoundLogger, ...]: target_loggers += (log,) return target_loggers - def _create_log_forwarder(self, loggers, name, log_level=logging.INFO) -> Callable[[socket], bool]: + def _create_log_forwarder(self, loggers, name, log_level=logging.INFO) -> SelectorCallback: """Create a socket handler that forwards logs to a logger.""" loggers = tuple( reconfigure_logger( @@ -874,41 +877,15 @@ def _service_subprocess( """ Service subprocess events by processing socket activity and checking for process exit. - This method: - - Waits for activity on the registered file objects (via `self.selector.select`). - - Processes any events triggered on these file objects. - - Checks if the subprocess has exited during the wait. + Delegates the selector event loop to :func:`service_selector` (shared + with provider-registered bridges), then checks the subprocess status. :param max_wait_time: Maximum time to block while waiting for events, in seconds. :param raise_on_timeout: If True, raise an exception if the subprocess does not exit within the timeout. :param expect_signal: Signal not to log if the task exits with this code. :returns: The process exit code, or None if it's still alive """ - # Ensure minimum timeout to prevent CPU spike with tight loop when timeout is 0 or negative - timeout = max(0.01, max_wait_time) - events = self.selector.select(timeout=timeout) - for key, _ in events: - # Retrieve the handler responsible for processing this file object (e.g., stdout, stderr) - socket_handler, on_close = key.data - - # Example of handler behavior: - # If the subprocess writes "Hello, World!" to stdout: - # - `socket_handler` reads and processes the message. - # - If EOF is reached, the handler returns False to signal no more reads are expected. 
- # - BrokenPipeError should be caught and treated as if the handler returned false, similar - # to EOF case - try: - need_more = socket_handler(key.fileobj) - except (BrokenPipeError, ConnectionResetError): - need_more = False - - # If the handler signals that the file object is no longer needed (EOF, closed, etc.) - # unregister it from the selector to stop monitoring; `wait()` blocks until all selectors - # are removed. - if not need_more: - sock: socket = key.fileobj # type: ignore[assignment] - on_close(sock) - sock.close() + service_selector(self.selector, timeout=max_wait_time) # Check if the subprocess has exited return self._check_subprocess_exit(raise_on_timeout=raise_on_timeout, expect_signal=expect_signal) @@ -1118,7 +1095,7 @@ class ActivitySubprocess(WatchedSubprocess): def start( # type: ignore[override] cls, *, - what: TaskInstance, + what: TaskInstanceDTO, dag_rel_path: str | os.PathLike[str], bundle_info, client: Client, @@ -1147,7 +1124,7 @@ def start( # type: ignore[override] def _on_child_started( self, *, - ti: TaskInstance, + ti: TaskInstanceDTO, dag_rel_path: str | os.PathLike[str], bundle_info, sentry_integration: str, @@ -1918,50 +1895,6 @@ def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: return InProcessTestSupervisor.start(what=ti, task=task) -# Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read -# and it doesn't contain a new line character, `.readline()` will just return the chunk as is. -# -# This returns a callback suitable for attaching to a `selector` that reads in to a buffer, and yields lines -# to a (sync) generator -def make_buffered_socket_reader( - gen: Generator[None, bytes | bytearray, None], - on_close: Callable[[socket], None], - buffer_size: int = 4096, -): - buffer = bytearray() # This will hold our accumulated binary data - read_buffer = bytearray(buffer_size) # Temporary buffer for each read - - # We need to start up the generator to get it to the point it's at waiting on the yield - next(gen) - - def cb(sock: socket): - nonlocal buffer, read_buffer - # Read up to `buffer_size` bytes of data from the socket - n_received = sock.recv_into(read_buffer) - - if not n_received: - # If no data is returned, the connection is closed. 
Return whatever is left in the buffer
-            if len(buffer):
-                with suppress(StopIteration):
-                    gen.send(buffer)
-            return False
-
-        buffer.extend(read_buffer[:n_received])
-
-        # We could have read multiple lines in one go, yield them all
-        while (newline_pos := buffer.find(b"\n")) != -1:
-            line = buffer[: newline_pos + 1]
-            try:
-                gen.send(line)
-            except StopIteration:
-                return False
-            buffer = buffer[newline_pos + 1 :]  # Update the buffer with remaining data
-
-        return True
-
-    return cb, on_close
-
-
 def length_prefixed_frame_reader(
     gen: Generator[None, _RequestFrame, None], on_close: Callable[[socket], None]
 ):
@@ -2136,7 +2069,7 @@ def _configure_logging(log_path: str, client: Client) -> tuple[FilteringBoundLog
 
 def supervise_task(
     *,
-    ti: TaskInstance,
+    ti: TaskInstanceDTO,
     bundle_info: BundleInfo,
     dag_rel_path: str | os.PathLike[str],
     token: str,
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 56ba8343c648b..68a5240386af0 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -47,6 +47,7 @@ from airflow.sdk.api.client import get_hostname, getuser
 from airflow.sdk.api.datamodels._generated import (
     AssetProfile,
+    BundleInfo,
     DagRun,
     PreviousTIResponse,
     TaskInstance,
@@ -778,12 +779,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
     bundle_info = what.bundle_info
 
     bundle_prepare_start = time.monotonic()
-    bundle_instance = DagBundlesManager().get_bundle(
-        name=bundle_info.name,
-        version=bundle_info.version,
-    )
-    bundle_instance.initialize()
-    _verify_bundle_access(bundle_instance, log)
+    bundle_instance = resolve_bundle(bundle_info, log)
     bundle_prepare_ms = int((time.monotonic() - bundle_prepare_start) * 1000)
 
     dag_absolute_path = os.fspath(Path(bundle_instance.path, what.dag_rel_path))
@@ -909,6 +905,22 @@ def _verify_bundle_access(bundle_instance: BaseDagBundle, log: Logger) -> None:
         )
 
 
+def resolve_bundle(bundle_info: BundleInfo, log: Logger) -> BaseDagBundle:
+    """
+    Resolve, initialize, and verify access to a DAG bundle.
+
+    Used by both the standard Python task execution path and runtime
+    coordinators (Java, Go, etc.) to obtain a ready-to-use bundle instance.
+    """
+    bundle_instance = DagBundlesManager().get_bundle(
+        name=bundle_info.name,
+        version=bundle_info.version,
+    )
+    bundle_instance.initialize()
+    _verify_bundle_access(bundle_instance, log)
+    return bundle_instance
+
+
 def get_startup_details() -> StartupDetails:
     # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent
     # in response to us sending a request.
@@ -1954,6 +1966,96 @@ def flush_spans():
             provider.force_flush(timeout_millis=timeout_millis)
 
 
+def _resolve_runtime_entrypoint(startup_details: StartupDetails, log: Logger) -> Callable[[], None] | None:
+    """
+    Check provider-registered runtime coordinators for a runtime-specific entrypoint.
+
+    Resolution order:
+
+    1. **Queue mapping** -- the ``[sdk] queue_to_sdk`` config maps
+       the task's ``queue`` to a runtime coordinator name (e.g. ``"java-queue" -> "java"``).
+       Used by the python-stub pattern where users set ``queue="java-queue"`` explicitly.
+    2. **DAG file extension** -- if no queue mapping matches, the DAG file's extension
+       (e.g. ``.jar``) is compared against each coordinator's ``file_extension`` attribute.
+       Used by the pure-Java (and similar) pattern where the entire DAG is authored
+       in a non-Python language.
+
+    Returns a no-arg callable that bridges fd 0 to the runtime subprocess,
+    or ``None`` to fall through to the standard Python execution path.
+    """
+    import functools
+
+    from airflow.sdk.execution_time.coordinator import QueueToCoordinatorMapper
+    from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
+
+    coordinators = ProvidersManagerTaskRuntime().coordinators
+
+    # Step 1: queue-to-runtime mapping.
+    queue = startup_details.ti.queue
+    if (sdk := QueueToCoordinatorMapper.from_config().resolve(queue)) is not None:
+        for coordinator_cls in coordinators:
+            if not hasattr(coordinator_cls, "run_task_execution"):
+                continue
+            if getattr(coordinator_cls, "sdk", None) != sdk:
+                continue
+
+            log.debug(
+                "Resolved sdk-specific entrypoint for task via queue mapping",
+                coordinator=coordinator_cls,
+                sdk=sdk,
+                queue=queue,
+                task_id=startup_details.ti.task_id,
+            )
+            return functools.partial(
+                coordinator_cls.run_task_execution,
+                what=startup_details.ti,
+                dag_rel_path=startup_details.dag_rel_path,
+                bundle_info=startup_details.bundle_info,
+                startup_details=startup_details,
+            )
+
+        log.warning(
+            "No coordinator found for sdk",
+            sdk=sdk,
+            queue=queue,
+            task_id=startup_details.ti.task_id,
+        )
+        return None
+
+    # Step 2: DAG file extension fallback (pure non-Python DAGs).
+    dag_rel_path = startup_details.dag_rel_path
+    for coordinator_cls in coordinators:
+        # TODO: Use `can_handle_dag_file` method instead of file_extension attribute for better maintainability.
+        ext = getattr(coordinator_cls, "file_extension", None)
+        if not ext or not dag_rel_path.endswith(ext):
+            continue
+        if not hasattr(coordinator_cls, "run_task_execution"):
+            continue
+
+        log.debug(
+            "Resolved runtime-specific entrypoint for task via DAG file extension",
+            coordinator=coordinator_cls,
+            sdk=getattr(coordinator_cls, "sdk", None),
+            dag_rel_path=dag_rel_path,
+            task_id=startup_details.ti.task_id,
+        )
+        return functools.partial(
+            coordinator_cls.run_task_execution,
+            what=startup_details.ti,
+            dag_rel_path=startup_details.dag_rel_path,
+            bundle_info=startup_details.bundle_info,
+            startup_details=startup_details,
+        )
+
+    log.debug(
+        "No runtime coordinator matched, using standard Python execution path",
+        queue=queue,
+        dag_rel_path=dag_rel_path,
+        task_id=startup_details.ti.task_id,
+    )
+    return None
+
+
 @flush_spans()
 def main():
     log = structlog.get_logger(logger_name="task")
@@ -1980,6 +2082,14 @@ def main():
         # startup message as a ResendLoggingFD response.
         if os.environ.pop("_AIRFLOW_FORK_EXEC", None) == "1":
             reinit_supervisor_comms()
+        # Check if a provider-registered runtime coordinator should
+        # handle this task (e.g. Java, Go) instead of the standard
+        # Python execution path.
+        log.debug("Checking for runtime-specific entrypoint")
+        runtime_entrypoint = _resolve_runtime_entrypoint(startup_details, log)
+        if runtime_entrypoint is not None:
+            runtime_entrypoint()
+            return
         span = _make_task_span(msg=startup_details)
         stack.enter_context(span)
         ti, context, log = startup(msg=startup_details)
diff --git a/task-sdk/src/airflow/sdk/execution_time/workloads/__init__.py b/task-sdk/src/airflow/sdk/execution_time/workloads/__init__.py
new file mode 100644
index 0000000000000..cdf955e742dc0
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/execution_time/workloads/__init__.py
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Workload schemas for Task SDK execution-time communication.""" + +from __future__ import annotations + +from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO + +__all__ = ["TaskInstanceDTO"] diff --git a/task-sdk/src/airflow/sdk/execution_time/workloads/task.py b/task-sdk/src/airflow/sdk/execution_time/workloads/task.py new file mode 100644 index 0000000000000..ceff200856f06 --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/workloads/task.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Task workload schemas for Task SDK execution-time communication.""" + +from __future__ import annotations + +import uuid + +from pydantic import BaseModel, Field + + +class BaseTaskInstanceDTO(BaseModel): + """ + Base schema for TaskInstance with the minimal fields shared by Executors and the Task SDK. + + This class is duplicated in :mod:`airflow.executors.workloads.task` and the + two definitions are kept in sync by the ``check-task-instance-dto-sync`` + prek hook. Update both files together. + """ + + id: uuid.UUID + dag_version_id: uuid.UUID + task_id: str + dag_id: str + run_id: str + try_number: int + map_index: int = -1 + + pool_slots: int + queue: str + priority_weight: int + executor_config: dict | None = Field(default=None, exclude=True) + + parent_context_carrier: dict | None = None + context_carrier: dict | None = None + + +class TaskInstanceDTO(BaseTaskInstanceDTO): + """Task SDK TaskInstanceDTO.""" diff --git a/task-sdk/src/airflow/sdk/providers_manager_runtime.py b/task-sdk/src/airflow/sdk/providers_manager_runtime.py index e28ed3fe14a83..63c8c97f816ef 100644 --- a/task-sdk/src/airflow/sdk/providers_manager_runtime.py +++ b/task-sdk/src/airflow/sdk/providers_manager_runtime.py @@ -51,6 +51,7 @@ from airflow.sdk import BaseHook from airflow.sdk.bases.decorator import TaskDecorator from airflow.sdk.definitions.asset import Asset + from airflow.sdk.execution_time.coordinator import BaseCoordinator log = structlog.getLogger(__name__) @@ -150,6 +151,7 @@ def __init__(self): # Keeps dict of hooks keyed by connection type. 
They are lazy evaluated at access time self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache() self._plugins_set: set[PluginInfo] = set() + self._coordinators: list[type[BaseCoordinator]] = [] self._provider_schema_validator = _create_provider_info_schema_validator() self._init_airflow_core_hooks() # Populated by initialize_provider_configs(); holds provider-contributed config sections. @@ -220,6 +222,12 @@ def initialize_providers_taskflow_decorator(self): self.initialize_providers_list() self._discover_taskflow_decorators() + @provider_info_cache("coordinators") + def initialize_providers_coordinators(self): + """Lazy initialization of providers runtime coordinators.""" + self.initialize_providers_list() + self._discover_coordinators() + @provider_info_cache("provider_configs") def initialize_provider_configs(self): """Lazy initialization of provider configuration metadata and merge it into SDK ``conf``.""" @@ -464,6 +472,19 @@ def _import_hook( connection_testable=hasattr(hook_class, "test_connection"), ) + def _discover_coordinators(self) -> None: + """Retrieve and pre-load all coordinators defined in the providers.""" + seen: set[str] = set() + for provider_package, provider in self._provider_dict.items(): + for coordinator_class_path in provider.data.get("coordinators", []): + if coordinator_class_path in seen: + continue + coordinator_cls = _correctness_check(provider_package, coordinator_class_path, provider) + if coordinator_cls: + seen.add(coordinator_class_path) + self._coordinators.append(coordinator_cls) + self._coordinators = sorted(self._coordinators, key=lambda c: c.__qualname__) + def _discover_filesystems(self) -> None: """Retrieve all filesystems defined in the providers.""" for provider_package, provider in self._provider_dict.items(): @@ -611,6 +632,12 @@ def plugins(self) -> list[PluginInfo]: self.initialize_providers_plugins() return sorted(self._plugins_set, key=lambda x: x.plugin_class) + @property + def coordinators(self) -> list[type[BaseCoordinator]]: + """Returns pre-loaded runtime coordinator classes available in providers.""" + self.initialize_providers_coordinators() + return self._coordinators + @property def provider_configs(self) -> list[tuple[str, dict[str, Any]]]: self.initialize_provider_configs() @@ -643,6 +670,7 @@ def _cleanup(self): self._asset_uri_handlers.clear() self._asset_factories.clear() self._asset_to_openlineage_converters.clear() + self._coordinators.clear() self._provider_configs.clear() # Imported lazily to preserve SDK conf lazy initialization and avoid a configuration/runtime cycle. diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index 2b34fac6ea0f9..93c5cc19aed47 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -680,14 +680,14 @@ def mock_comms_response(msg): ("tg.t2", 0): ["a", "b"], ("tg.t2", 1): [4], ("tg.t2", 2): ["z"], - ("t3", None): [["a", "b"], [4], ["z"]], + ("t3", -1): [["a", "b"], [4], ["z"]], } # We hard-code the number of expansions here as the server is in charge of that. 
expansion_per_task_id = { "tg.t1": range(3), "tg.t2": range(3), - "t3": [None], + "t3": [-1], } for task in dag.tasks: for map_index in expansion_per_task_id[task.task_id]: diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index af487851b07cb..6014d26f2208b 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -344,7 +344,7 @@ def xcom_get(msg): mock_supervisor_comms.send.side_effect = xcom_get # Run "pull_one" and "pull_all". - assert run_ti(dag, "pull_all", None) == TaskInstanceState.SUCCESS + assert run_ti(dag, "pull_all", -1) == TaskInstanceState.SUCCESS assert all_results == ["a", "b", "c", 1, 2] states = [run_ti(dag, "pull_one", map_index) for map_index in range(5)] diff --git a/task-sdk/tests/task_sdk/execution_time/test_comms.py b/task-sdk/tests/task_sdk/execution_time/test_comms.py index 5c6d88439250c..37a91dd0ecc28 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_comms.py +++ b/task-sdk/tests/task_sdk/execution_time/test_comms.py @@ -86,6 +86,9 @@ def test_recv_StartupDetails(self): "run_id": "b", "dag_id": "c", "dag_version_id": uuid.UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"), + "pool_slots": 1, + "queue": "default", + "priority_weight": 1, }, "ti_context": { "dag_run": { diff --git a/task-sdk/tests/task_sdk/execution_time/test_coordinator.py b/task-sdk/tests/task_sdk/execution_time/test_coordinator.py new file mode 100644 index 0000000000000..082cfaf6051b9 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_coordinator.py @@ -0,0 +1,598 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
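# The wire format exercised by the framing tests below is a 4-byte big-endian
# length prefix followed by a msgpack-encoded [id, body, error] array. A minimal
# sketch, assuming msgpack is importable; the helper names and the None error
# slot are illustrative, not part of this change:
import msgpack


def _encode_frame(body: dict, frame_id: int = 0) -> bytes:
    payload = msgpack.packb([frame_id, body, None])  # [id, body, error]
    return len(payload).to_bytes(4, "big") + payload


def _decode_frame(data: bytes) -> tuple:
    length = int.from_bytes(data[:4], "big")
    return tuple(msgpack.unpackb(data[4 : 4 + length]))


_raw = _encode_frame({"type": "StartupDetails", "value": 42})
assert _decode_frame(_raw) == (0, {"type": "StartupDetails", "value": 42}, None)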
+from __future__ import annotations + +import contextlib +import os +import socket +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.sdk.execution_time.coordinator import ( + BaseCoordinator, + _bridge, + _send_startup_details, + _start_server, +) + + +class TestStartServer: + def test_binds_to_localhost(self): + server = _start_server() + try: + host, port = server.getsockname() + assert host == "127.0.0.1" + assert port > 0 + finally: + server.close() + + def test_assigns_random_port(self): + s1 = _start_server() + s2 = _start_server() + try: + _, port1 = s1.getsockname() + _, port2 = s2.getsockname() + # Two servers should get different ports + assert port1 != port2 + finally: + s1.close() + s2.close() + + def test_accepts_connection(self): + server = _start_server() + try: + addr = server.getsockname() + client = socket.socket() + client.connect(addr) + conn, _ = server.accept() + conn.sendall(b"ping") + assert client.recv(4) == b"ping" + conn.close() + client.close() + finally: + server.close() + + +class TestSendStartupDetails: + def test_sends_frame_bytes_to_socket(self): + """Verify _send_startup_details calls sendall with a length-prefixed msgpack frame.""" + mock_startup = MagicMock() + mock_startup.model_dump.return_value = {"type": "StartupDetails", "ti": {}} + + mock_socket = MagicMock(spec=socket.socket) + + _send_startup_details(mock_socket, mock_startup) + + mock_startup.model_dump.assert_called_once_with(mode="json") + mock_socket.sendall.assert_called_once() + + sent_bytes = mock_socket.sendall.call_args[0][0] + # First 4 bytes are the big-endian length prefix + assert len(sent_bytes) > 4 + length = int.from_bytes(sent_bytes[:4], "big") + assert length == len(sent_bytes) - 4 + + def test_frame_contains_response_id_zero(self): + """The frame should use id=0.""" + import msgpack + + mock_startup = MagicMock() + mock_startup.model_dump.return_value = {"type": "StartupDetails"} + + mock_socket = MagicMock(spec=socket.socket) + + _send_startup_details(mock_socket, mock_startup) + + sent_bytes = mock_socket.sendall.call_args[0][0] + # Frame is encoded as [id, body, error] + frame = msgpack.unpackb(sent_bytes[4:]) + assert frame[0] == 0 + + def test_frame_body_matches_model_dump(self): + """The frame body should be the model_dump(mode='json') output.""" + import msgpack + + body = {"type": "StartupDetails", "ti": {"task_id": "t1"}, "dag_rel_path": "test.jar"} + mock_startup = MagicMock() + mock_startup.model_dump.return_value = body + + mock_socket = MagicMock(spec=socket.socket) + + _send_startup_details(mock_socket, mock_startup) + + sent_bytes = mock_socket.sendall.call_args[0][0] + # Frame is encoded as [id, body, error] + frame = msgpack.unpackb(sent_bytes[4:]) + assert frame[1] == body + + def test_real_socket_roundtrip(self): + """Send through real sockets and verify the frame is receivable.""" + import msgpack + + server = socket.socket() + server.bind(("127.0.0.1", 0)) + server.listen(1) + addr = server.getsockname() + + client = socket.socket() + client.connect(addr) + conn, _ = server.accept() + + try: + body = {"type": "StartupDetails", "value": 42} + mock_startup = MagicMock() + mock_startup.model_dump.return_value = body + + _send_startup_details(conn, mock_startup) + + # Read the length prefix + length_bytes = client.recv(4) + length = int.from_bytes(length_bytes, "big") + + # Read the payload — frame is [id, body, error] + data = client.recv(length) + frame = msgpack.unpackb(data) + assert 
frame[0] == 0 + assert frame[1] == body + finally: + conn.close() + client.close() + server.close() + + +class TestBaseCoordinatorDefaults: + def test_can_handle_dag_file_returns_false(self): + assert BaseCoordinator.can_handle_dag_file("bundle", "/path/to/dag.py") is False + + def test_get_code_from_file_raises_not_implemented(self): + with pytest.raises(NotImplementedError): + BaseCoordinator.get_code_from_file("/path/to/dag.jar") + + def test_dag_parsing_cmd_raises_not_implemented(self): + with pytest.raises(NotImplementedError): + BaseCoordinator.dag_parsing_cmd( + dag_file_path="/dag.jar", + bundle_name="b", + bundle_path="/path", + comm_addr="127.0.0.1:1234", + logs_addr="127.0.0.1:1235", + ) + + def test_task_execution_cmd_raises_not_implemented(self): + with pytest.raises(NotImplementedError): + BaseCoordinator.task_execution_cmd( + what=MagicMock(), + dag_file_path="/dag.jar", + bundle_path="/path", + bundle_info=MagicMock(), + comm_addr="127.0.0.1:1234", + logs_addr="127.0.0.1:1235", + ) + + +class TestCoordinatorNamedTuples: + def test_dag_parsing_info_defaults(self): + info = BaseCoordinator.DagParsingInfo( + dag_file_path="/dag.jar", + bundle_name="my-bundle", + bundle_path="/bundles/my-bundle", + ) + assert info.mode == "dag-parsing" + assert info.dag_file_path == "/dag.jar" + assert info.bundle_name == "my-bundle" + assert info.bundle_path == "/bundles/my-bundle" + + def test_task_execution_info_defaults(self): + mock_ti = MagicMock() + mock_bundle = MagicMock() + mock_startup = MagicMock() + info = BaseCoordinator.TaskExecutionInfo( + what=mock_ti, + dag_rel_path="dags/example.jar", + bundle_info=mock_bundle, + startup_details=mock_startup, + ) + assert info.mode == "task-execution" + assert info.what is mock_ti + assert info.dag_rel_path == "dags/example.jar" + + +class TestBridge: + def test_bridge_forwards_comm_bidirectionally(self): + """Verify _bridge sets up bidirectional forwarding and processes all channels.""" + # Use real socketpairs for the 4 channels + sup_send, sup_recv = socket.socketpair() + rt_send, rt_recv = socket.socketpair() + log_send, log_recv = socket.socketpair() + stderr_send, stderr_recv = socket.socketpair() + + mock_proc = MagicMock(spec=subprocess.Popen) + # Make the process "exit" immediately so the bridge drains and stops + mock_proc.poll.return_value = 0 + mock_log = MagicMock() + + try: + # Send data before starting the bridge + sup_send.sendall(b"from_supervisor") + rt_send.sendall(b"from_runtime") + log_send.sendall(b'{"event":"hello","level":"info"}\n') + stderr_send.sendall(b"stderr line\n") + + # Close sending sides so the bridge will see EOF + sup_send.close() + rt_send.close() + log_send.close() + stderr_send.close() + + _bridge(sup_recv, rt_recv, log_recv, stderr_recv, mock_proc, mock_log) + + # If we got here without hanging, the bridge correctly processed all channels + finally: + for s in (sup_send, rt_send, log_send, stderr_send, sup_recv, rt_recv, log_recv, stderr_recv): + with contextlib.suppress(OSError): + s.close() + + def test_bridge_drains_after_process_exit(self): + """Verify _bridge drains remaining data after the subprocess exits.""" + sup_local, sup_remote = socket.socketpair() + rt_local, rt_remote = socket.socketpair() + log_local, log_remote = socket.socketpair() + stderr_local, stderr_remote = socket.socketpair() + + mock_proc = MagicMock(spec=subprocess.Popen) + # First poll: still running; subsequent: exited + mock_proc.poll.side_effect = [None, 0, 0, 0, 0, 0, 0, 0, 0, 0] + mock_log = MagicMock() + + try: + # 
Send data after bridge starts its first iteration + stderr_local.sendall(b"error output\n") + stderr_local.close() + sup_local.close() + rt_local.close() + log_local.close() + + _bridge(sup_remote, rt_remote, log_remote, stderr_remote, mock_proc, mock_log) + finally: + for s in ( + sup_local, + sup_remote, + rt_local, + rt_remote, + log_local, + log_remote, + stderr_local, + stderr_remote, + ): + with contextlib.suppress(OSError): + s.close() + + def test_bridge_closes_all_sockets(self): + """Verify _bridge closes all four sockets when done.""" + sup = MagicMock(spec=socket.socket) + rt = MagicMock(spec=socket.socket) + logs = MagicMock(spec=socket.socket) + stderr = MagicMock(spec=socket.socket) + + mock_proc = MagicMock(spec=subprocess.Popen) + mock_proc.poll.return_value = 0 + mock_log = MagicMock() + + # Patch the selector to avoid real I/O; service_selector is imported inside + # _bridge so we patch it on the selector_loop module + with ( + patch("airflow.sdk.execution_time.coordinator.selectors.DefaultSelector") as mock_sel_cls, + patch("airflow.sdk.execution_time.selector_loop.service_selector"), + ): + mock_sel = MagicMock() + mock_sel_cls.return_value = mock_sel + # Empty selector map so the while loop exits immediately + mock_sel.get_map.return_value = {} + + _bridge(sup, rt, logs, stderr, mock_proc, mock_log) + + sup.close.assert_called() + rt.close.assert_called() + logs.close.assert_called() + stderr.close.assert_called() + mock_sel.close.assert_called_once() + + +class TestRunDagParsing: + @patch.object(BaseCoordinator, "_runtime_subprocess_entrypoint") + def test_run_dag_parsing_creates_dag_parsing_info(self, mock_entrypoint): + BaseCoordinator.run_dag_parsing( + path="/bundles/my-bundle/dags/example.jar", + bundle_name="my-bundle", + bundle_path="/bundles/my-bundle", + ) + + mock_entrypoint.assert_called_once() + info = mock_entrypoint.call_args[0][0] + assert isinstance(info, BaseCoordinator.DagParsingInfo) + assert info.dag_file_path == "/bundles/my-bundle/dags/example.jar" + assert info.bundle_name == "my-bundle" + assert info.bundle_path == "/bundles/my-bundle" + assert info.mode == "dag-parsing" + + +class TestRunTaskExecution: + @patch.object(BaseCoordinator, "_runtime_subprocess_entrypoint") + def test_run_task_execution_creates_task_execution_info(self, mock_entrypoint): + mock_ti = MagicMock() + mock_bundle_info = MagicMock() + mock_startup = MagicMock() + + BaseCoordinator.run_task_execution( + what=mock_ti, + dag_rel_path="dags/example.jar", + bundle_info=mock_bundle_info, + startup_details=mock_startup, + ) + + mock_entrypoint.assert_called_once() + info = mock_entrypoint.call_args[0][0] + assert isinstance(info, BaseCoordinator.TaskExecutionInfo) + assert info.what is mock_ti + assert info.dag_rel_path == "dags/example.jar" + assert info.bundle_info is mock_bundle_info + assert info.startup_details is mock_startup + assert info.mode == "task-execution" + + +class TestRuntimeSubprocessEntrypoint: + @pytest.fixture(autouse=True) + def _restore_process_context_env(self): + """``_runtime_subprocess_entrypoint`` runs inside a forked child in production + and sets ``_AIRFLOW_PROCESS_CONTEXT`` for the runtime subprocess. 
When tests + invoke it in-process, the env var leaks into other tests — restore it.""" + old = os.environ.get("_AIRFLOW_PROCESS_CONTEXT") + try: + yield + finally: + if old is None: + os.environ.pop("_AIRFLOW_PROCESS_CONTEXT", None) + else: + os.environ["_AIRFLOW_PROCESS_CONTEXT"] = old + + def test_unknown_entrypoint_info_type_raises(self): + class TestCoordinator(BaseCoordinator): + sdk = "test" + file_extension = ".test" + + # Needs a 'mode' attribute (accessed during logging) but must not be + # an instance of DagParsingInfo or TaskExecutionInfo. + fake_info = MagicMock() + fake_info.mode = "unknown" + + with pytest.raises(ValueError, match="Unknown entrypoint_info type"): + TestCoordinator._runtime_subprocess_entrypoint(fake_info) # type: ignore[arg-type] + + @patch("airflow.sdk.execution_time.coordinator._bridge") + @patch("airflow.sdk.execution_time.coordinator._send_startup_details") + @patch("subprocess.Popen", autospec=True) + @patch("airflow.sdk.execution_time.coordinator._start_server") + @patch("os.dup", return_value=99) + def test_dag_parsing_flow(self, mock_dup, mock_start_server, mock_popen, mock_send_startup, mock_bridge): + """Verify the dag-parsing entrypoint wires up servers, spawns subprocess, and bridges.""" + # Set up mock servers + comm_server = MagicMock(spec=socket.socket) + comm_server.getsockname.return_value = ("127.0.0.1", 5000) + logs_server = MagicMock(spec=socket.socket) + logs_server.getsockname.return_value = ("127.0.0.1", 5001) + mock_start_server.side_effect = [comm_server, logs_server] + + # The runtime connects back + runtime_comm = MagicMock(spec=socket.socket) + runtime_logs = MagicMock(spec=socket.socket) + comm_server.accept.return_value = (runtime_comm, ("127.0.0.1", 9000)) + logs_server.accept.return_value = (runtime_logs, ("127.0.0.1", 9001)) + + # Mock socketpair for stderr + child_stderr = MagicMock(spec=socket.socket) + read_stderr = MagicMock(spec=socket.socket) + child_stderr.fileno.return_value = 10 + + # Mock supervisor_comm created from os.dup(0) + supervisor_comm = MagicMock(spec=socket.socket) + + class TestCoordinator(BaseCoordinator): + sdk = "test" + file_extension = ".test" + + @classmethod + def dag_parsing_cmd(cls, **kwargs): + return ["test-runtime", "--parse", kwargs["dag_file_path"]] + + info = BaseCoordinator.DagParsingInfo( + dag_file_path="/dag.test", + bundle_name="test-bundle", + bundle_path="/bundles/test-bundle", + ) + + with ( + patch("socket.socketpair", return_value=(child_stderr, read_stderr)), + patch("airflow.sdk.execution_time.coordinator.socket.socket", return_value=supervisor_comm), + ): + TestCoordinator._runtime_subprocess_entrypoint(info) + + # Subprocess spawned + mock_popen.assert_called_once() + cmd = mock_popen.call_args[0][0] + assert cmd == ["test-runtime", "--parse", "/dag.test"] + + # Servers accepted and closed + comm_server.accept.assert_called_once() + logs_server.accept.assert_called_once() + comm_server.close.assert_called_once() + logs_server.close.assert_called_once() + + # stderr child side closed after Popen + child_stderr.close.assert_called_once() + + # _send_startup_details NOT called for dag parsing + mock_send_startup.assert_not_called() + + # _bridge called with the supervisor_comm socket + mock_bridge.assert_called_once() + assert mock_bridge.call_args[0][0] is supervisor_comm + + @patch("airflow.sdk.execution_time.coordinator._bridge") + @patch("airflow.sdk.execution_time.coordinator._send_startup_details") + @patch("subprocess.Popen", autospec=True) + 
@patch("airflow.sdk.execution_time.coordinator._start_server") + @patch("os.dup", return_value=99) + @patch("airflow.sdk.execution_time.task_runner.resolve_bundle") + @patch("airflow.dag_processing.bundles.base.BundleVersionLock", autospec=True) + def test_task_execution_flow( + self, + mock_bundle_lock, + mock_resolve_bundle, + mock_dup, + mock_start_server, + mock_popen, + mock_send_startup, + mock_bridge, + ): + """Verify the task-execution entrypoint resolves bundle, sends startup details, and bridges.""" + # Mock servers + comm_server = MagicMock(spec=socket.socket) + comm_server.getsockname.return_value = ("127.0.0.1", 6000) + logs_server = MagicMock(spec=socket.socket) + logs_server.getsockname.return_value = ("127.0.0.1", 6001) + mock_start_server.side_effect = [comm_server, logs_server] + + runtime_comm = MagicMock(spec=socket.socket) + runtime_logs = MagicMock(spec=socket.socket) + comm_server.accept.return_value = (runtime_comm, ("127.0.0.1", 9000)) + logs_server.accept.return_value = (runtime_logs, ("127.0.0.1", 9001)) + + child_stderr = MagicMock(spec=socket.socket) + read_stderr = MagicMock(spec=socket.socket) + child_stderr.fileno.return_value = 10 + + # Mock resolved bundle + mock_bundle_instance = MagicMock() + mock_bundle_instance.path = Path("/resolved/bundles/test-bundle") + mock_resolve_bundle.return_value = mock_bundle_instance + + # BundleVersionLock as context manager + mock_lock_instance = MagicMock() + mock_bundle_lock.return_value = mock_lock_instance + mock_lock_instance.__enter__ = MagicMock(return_value=mock_lock_instance) + mock_lock_instance.__exit__ = MagicMock(return_value=False) + + mock_ti = MagicMock() + mock_bundle_info = MagicMock() + mock_bundle_info.name = "test-bundle" + mock_bundle_info.version = "v1" + mock_startup = MagicMock() + + class TestCoordinator(BaseCoordinator): + sdk = "test" + file_extension = ".test" + + @classmethod + def task_execution_cmd(cls, **kwargs): + return ["test-runtime", "--execute", kwargs["dag_file_path"]] + + info = BaseCoordinator.TaskExecutionInfo( + what=mock_ti, + dag_rel_path="dags/example.test", + bundle_info=mock_bundle_info, + startup_details=mock_startup, + ) + + supervisor_comm = MagicMock(spec=socket.socket) + + with ( + patch("socket.socketpair", return_value=(child_stderr, read_stderr)), + patch("airflow.sdk.execution_time.coordinator.socket.socket", return_value=supervisor_comm), + ): + TestCoordinator._runtime_subprocess_entrypoint(info) + + # Bundle resolved + mock_resolve_bundle.assert_called_once() + + # BundleVersionLock used + mock_bundle_lock.assert_called_once_with(bundle_name="test-bundle", bundle_version="v1") + + # Subprocess spawned with resolved path + mock_popen.assert_called_once() + cmd = mock_popen.call_args[0][0] + assert cmd == ["test-runtime", "--execute", "/resolved/bundles/test-bundle/dags/example.test"] + + # StartupDetails forwarded to the runtime subprocess + mock_send_startup.assert_called_once_with(runtime_comm, mock_startup) + + # _bridge called + mock_bridge.assert_called_once() + + @patch("airflow.sdk.execution_time.coordinator._bridge") + @patch("subprocess.Popen", autospec=True) + @patch("airflow.sdk.execution_time.coordinator._start_server") + @patch("os.dup", return_value=99) + def test_sets_process_context_env_var(self, mock_dup, mock_start_server, mock_popen, mock_bridge): + """Verify _AIRFLOW_PROCESS_CONTEXT is set to 'client'.""" + comm_server = MagicMock(spec=socket.socket) + comm_server.getsockname.return_value = ("127.0.0.1", 7000) + logs_server = 
MagicMock(spec=socket.socket) + logs_server.getsockname.return_value = ("127.0.0.1", 7001) + mock_start_server.side_effect = [comm_server, logs_server] + + runtime_comm = MagicMock(spec=socket.socket) + runtime_logs = MagicMock(spec=socket.socket) + comm_server.accept.return_value = (runtime_comm, ("127.0.0.1", 9000)) + logs_server.accept.return_value = (runtime_logs, ("127.0.0.1", 9001)) + + child_stderr = MagicMock(spec=socket.socket) + read_stderr = MagicMock(spec=socket.socket) + child_stderr.fileno.return_value = 10 + + class TestCoordinator(BaseCoordinator): + sdk = "test" + file_extension = ".test" + + @classmethod + def dag_parsing_cmd(cls, **kwargs): + return ["echo", "test"] + + info = BaseCoordinator.DagParsingInfo( + dag_file_path="/dag.test", + bundle_name="b", + bundle_path="/path", + ) + + supervisor_comm = MagicMock(spec=socket.socket) + + old_val = os.environ.get("_AIRFLOW_PROCESS_CONTEXT") + try: + with ( + patch("socket.socketpair", return_value=(child_stderr, read_stderr)), + patch("airflow.sdk.execution_time.coordinator.socket.socket", return_value=supervisor_comm), + ): + TestCoordinator._runtime_subprocess_entrypoint(info) + assert os.environ["_AIRFLOW_PROCESS_CONTEXT"] == "client" + finally: + if old_val is None: + os.environ.pop("_AIRFLOW_PROCESS_CONTEXT", None) + else: + os.environ["_AIRFLOW_PROCESS_CONTEXT"] = old_val diff --git a/task-sdk/tests/task_sdk/execution_time/test_selector_loop.py b/task-sdk/tests/task_sdk/execution_time/test_selector_loop.py new file mode 100644 index 0000000000000..efbfa83adecf8 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_selector_loop.py @@ -0,0 +1,479 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
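# For orientation before the tests below: make_raw_forwarder() shuttles bytes
# between two sockets without interpreting them, which is how the coordinator
# bridge relays length-prefixed frames. A minimal sketch; the socket names are
# illustrative, not part of this change:
import selectors
import socket

from airflow.sdk.execution_time.selector_loop import make_raw_forwarder, service_selector

sup_side, bridge_in = socket.socketpair()       # "supervisor" <-> bridge
bridge_out, runtime_side = socket.socketpair()  # bridge <-> "runtime"

sel = selectors.DefaultSelector()
sel.register(
    bridge_in,
    selectors.EVENT_READ,
    make_raw_forwarder(bridge_out, on_close=lambda sock: sel.unregister(sock)),
)

sup_side.sendall(b"\x00\x00\x00\x05hello")  # bytes pass through untouched
service_selector(sel, timeout=1.0)
assert runtime_side.recv(4096) == b"\x00\x00\x00\x05hello"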
+from __future__ import annotations + +import selectors +import socket +from unittest.mock import MagicMock + +import pytest + +from airflow.sdk.execution_time.selector_loop import ( + make_buffered_socket_reader, + make_raw_forwarder, + service_selector, +) + + +def _make_generator(): + """Return a generator that collects sent lines into a list.""" + received: list[bytes | bytearray] = [] + + def gen(): + while True: + line = yield + received.append(bytes(line)) + + g = gen() + return g, received + + +def _make_socket_pair(): + """Create a connected TCP socket pair on localhost.""" + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.bind(("127.0.0.1", 0)) + server.listen(1) + addr = server.getsockname() + + client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client.connect(addr) + conn, _ = server.accept() + server.close() + return client, conn + + +class TestMakeBufferedSocketReader: + def test_single_complete_line(self): + gen, received = _make_generator() + on_close = MagicMock() + handler, returned_on_close = make_buffered_socket_reader(gen, on_close) + + sock = MagicMock(spec=socket.socket) + # recv_into writes data and returns count + data = b"hello world\n" + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, data) + + result = handler(sock) + + assert result is True + assert received == [b"hello world\n"] + assert returned_on_close is on_close + + def test_multiple_lines_in_single_recv(self): + gen, received = _make_generator() + on_close = MagicMock() + handler, _ = make_buffered_socket_reader(gen, on_close) + + sock = MagicMock(spec=socket.socket) + data = b"line1\nline2\nline3\n" + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, data) + + result = handler(sock) + + assert result is True + assert received == [b"line1\n", b"line2\n", b"line3\n"] + + def test_partial_line_accumulated_across_calls(self): + gen, received = _make_generator() + on_close = MagicMock() + handler, _ = make_buffered_socket_reader(gen, on_close) + + sock = MagicMock(spec=socket.socket) + + # First call: partial line (no newline) + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"hell") + result = handler(sock) + assert result is True + assert received == [] + + # Second call: rest of the line + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"o\n") + result = handler(sock) + assert result is True + assert received == [b"hello\n"] + + def test_eof_flushes_remaining_buffer(self): + gen, received = _make_generator() + on_close = MagicMock() + handler, _ = make_buffered_socket_reader(gen, on_close) + + sock = MagicMock(spec=socket.socket) + + # Send partial data (no newline) + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"leftover") + handler(sock) + assert received == [] + + # EOF (recv_into returns 0) — clear side_effect so return_value takes effect + sock.recv_into.side_effect = None + sock.recv_into.return_value = 0 + result = handler(sock) + + assert result is False + assert received == [b"leftover"] + + def test_eof_with_empty_buffer(self): + gen, received = _make_generator() + on_close = MagicMock() + handler, _ = make_buffered_socket_reader(gen, on_close) + + sock = MagicMock(spec=socket.socket) + sock.recv_into.return_value = 0 + + result = handler(sock) + + assert result is False + assert received == [] + + def test_generator_stop_iteration_returns_false(self): + """If the generator is exhausted, handler returns False.""" + + def limited_gen(): + yield # startup + yield # receive one line, then stop + + gen = 
limited_gen() + on_close = MagicMock() + handler, _ = make_buffered_socket_reader(gen, on_close) + + sock = MagicMock(spec=socket.socket) + # First line succeeds + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"line1\n") + result = handler(sock) + assert result is True + + # Second line triggers StopIteration in the generator + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"line2\n") + result = handler(sock) + assert result is False + + def test_mixed_complete_and_partial_lines(self): + gen, received = _make_generator() + on_close = MagicMock() + handler, _ = make_buffered_socket_reader(gen, on_close) + + sock = MagicMock(spec=socket.socket) + # Data contains one complete line and a partial line + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"complete\npart") + handler(sock) + assert received == [b"complete\n"] + + # Finish the partial line + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, b"ial\n") + handler(sock) + assert received == [b"complete\n", b"partial\n"] + + def test_custom_buffer_size(self): + gen, received = _make_generator() + on_close = MagicMock() + handler, _ = make_buffered_socket_reader(gen, on_close, buffer_size=8) + + sock = MagicMock(spec=socket.socket) + # Data larger than buffer_size — recv_into only reads buffer_size bytes + full_data = b"abcdefghijklmnop\n" + # Simulate chunked reads + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, full_data[: len(buf)]) + handler(sock) + # Only first 8 bytes read, no newline yet + assert received == [] + + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, full_data[8:16]) + handler(sock) + assert received == [] + + sock.recv_into.side_effect = lambda buf: _fill_buffer(buf, full_data[16:]) + handler(sock) + assert received == [b"abcdefghijklmnop\n"] + + +def _fill_buffer(buf: bytearray, data: bytes) -> int: + """Helper to simulate socket.recv_into by filling the buffer.""" + n = min(len(data), len(buf)) + buf[:n] = data[:n] + return n + + +class TestMakeRawForwarder: + def test_forwards_data_to_dest(self): + on_close = MagicMock() + dest = MagicMock(spec=socket.socket) + handler, returned_on_close = make_raw_forwarder(dest, on_close) + + src = MagicMock(spec=socket.socket) + src.recv.return_value = b"hello" + + result = handler(src) + + assert result is True + dest.sendall.assert_called_once_with(b"hello") + assert returned_on_close is on_close + + def test_eof_returns_false(self): + on_close = MagicMock() + dest = MagicMock(spec=socket.socket) + handler, _ = make_raw_forwarder(dest, on_close) + + src = MagicMock(spec=socket.socket) + src.recv.return_value = b"" + + result = handler(src) + + assert result is False + dest.sendall.assert_not_called() + + @pytest.mark.parametrize( + "exception", + [BrokenPipeError, ConnectionResetError, OSError], + ids=["broken_pipe", "connection_reset", "os_error"], + ) + def test_sendall_exception_returns_false(self, exception): + on_close = MagicMock() + dest = MagicMock(spec=socket.socket) + dest.sendall.side_effect = exception + handler, _ = make_raw_forwarder(dest, on_close) + + src = MagicMock(spec=socket.socket) + src.recv.return_value = b"data" + + result = handler(src) + + assert result is False + + def test_multiple_forwards(self): + on_close = MagicMock() + dest = MagicMock(spec=socket.socket) + handler, _ = make_raw_forwarder(dest, on_close) + + src = MagicMock(spec=socket.socket) + + for chunk in [b"chunk1", b"chunk2", b"chunk3"]: + src.recv.return_value = chunk + assert handler(src) is True + + assert 
dest.sendall.call_count == 3 + + +class TestServiceSelector: + def test_calls_handler_for_ready_sockets(self): + sel = MagicMock(spec=selectors.DefaultSelector) + handler = MagicMock(return_value=True) + on_close = MagicMock() + sock = MagicMock(spec=socket.socket) + + key = MagicMock() + key.data = (handler, on_close) + key.fileobj = sock + + sel.select.return_value = [(key, selectors.EVENT_READ)] + + service_selector(sel, timeout=1.0) + + handler.assert_called_once_with(sock) + on_close.assert_not_called() + sock.close.assert_not_called() + + def test_on_close_and_sock_close_when_handler_returns_false(self): + sel = MagicMock(spec=selectors.DefaultSelector) + handler = MagicMock(return_value=False) + on_close = MagicMock() + sock = MagicMock(spec=socket.socket) + + key = MagicMock() + key.data = (handler, on_close) + key.fileobj = sock + + sel.select.return_value = [(key, selectors.EVENT_READ)] + + service_selector(sel, timeout=1.0) + + handler.assert_called_once_with(sock) + on_close.assert_called_once_with(sock) + sock.close.assert_called_once() + + @pytest.mark.parametrize( + "exception", + [BrokenPipeError, ConnectionResetError], + ids=["broken_pipe", "connection_reset"], + ) + def test_pipe_errors_treated_as_eof(self, exception): + sel = MagicMock(spec=selectors.DefaultSelector) + handler = MagicMock(side_effect=exception) + on_close = MagicMock() + sock = MagicMock(spec=socket.socket) + + key = MagicMock() + key.data = (handler, on_close) + key.fileobj = sock + + sel.select.return_value = [(key, selectors.EVENT_READ)] + + service_selector(sel, timeout=1.0) + + on_close.assert_called_once_with(sock) + sock.close.assert_called_once() + + def test_empty_selector_no_events(self): + sel = MagicMock(spec=selectors.DefaultSelector) + sel.select.return_value = [] + + # Should not raise + service_selector(sel, timeout=1.0) + + @pytest.mark.parametrize( + ("input_timeout", "expected_min"), + [ + (0.0, 0.01), + (-1.0, 0.01), + (-100.0, 0.01), + (0.5, 0.5), + (2.0, 2.0), + ], + ids=["zero", "negative", "very_negative", "positive_half", "positive_two"], + ) + def test_timeout_clamped_to_minimum(self, input_timeout, expected_min): + sel = MagicMock(spec=selectors.DefaultSelector) + sel.select.return_value = [] + + service_selector(sel, timeout=input_timeout) + + sel.select.assert_called_once() + actual_timeout = sel.select.call_args[1].get("timeout") or sel.select.call_args[0][0] + assert actual_timeout == pytest.approx(expected_min) + + def test_multiple_ready_sockets(self): + sel = MagicMock(spec=selectors.DefaultSelector) + + handler1 = MagicMock(return_value=True) + on_close1 = MagicMock() + sock1 = MagicMock(spec=socket.socket) + key1 = MagicMock() + key1.data = (handler1, on_close1) + key1.fileobj = sock1 + + handler2 = MagicMock(return_value=False) + on_close2 = MagicMock() + sock2 = MagicMock(spec=socket.socket) + key2 = MagicMock() + key2.data = (handler2, on_close2) + key2.fileobj = sock2 + + sel.select.return_value = [(key1, selectors.EVENT_READ), (key2, selectors.EVENT_READ)] + + service_selector(sel, timeout=1.0) + + # First socket: handler returns True, stays open + handler1.assert_called_once_with(sock1) + on_close1.assert_not_called() + sock1.close.assert_not_called() + + # Second socket: handler returns False, closed + handler2.assert_called_once_with(sock2) + on_close2.assert_called_once_with(sock2) + sock2.close.assert_called_once() + + +class TestSelectorLoopIntegration: + def test_buffered_reader_with_real_sockets(self): + """End-to-end: send lines through real sockets and 
verify buffered reading.""" + gen, received = _make_generator() + sender, reader = _make_socket_pair() + try: + sel = selectors.DefaultSelector() + + def on_close(sock): + sel.unregister(sock) + + sel.register(reader, selectors.EVENT_READ, make_buffered_socket_reader(gen, on_close)) + + sender.sendall(b"first line\nsecond line\n") + + service_selector(sel, timeout=1.0) + + assert b"first line\n" in received + assert b"second line\n" in received + + # Close sender, then drain + sender.close() + sender = None + + service_selector(sel, timeout=0.5) + + sel.close() + finally: + if sender: + sender.close() + reader.close() + + def test_raw_forwarder_with_real_sockets(self): + """End-to-end: forward raw bytes between real socket pairs.""" + src_send, src_recv = _make_socket_pair() + # Use socketpair for the destination so reads/writes are symmetric + dst_write, dst_read = socket.socketpair() + try: + sel = selectors.DefaultSelector() + + def on_close(sock): + sel.unregister(sock) + + sel.register(src_recv, selectors.EVENT_READ, make_raw_forwarder(dst_write, on_close)) + + src_send.sendall(b"raw data payload") + + service_selector(sel, timeout=1.0) + + dst_read.setblocking(False) + forwarded = dst_read.recv(4096) + + assert forwarded == b"raw data payload" + + sel.close() + finally: + for s in (src_send, src_recv, dst_write, dst_read): + s.close() + + def test_eof_triggers_on_close_with_real_sockets(self): + """When the sender closes, the selector callback chain fires on_close.""" + gen, received = _make_generator() + sender, reader = _make_socket_pair() + closed_sockets: list[socket.socket] = [] + try: + sel = selectors.DefaultSelector() + + def on_close(sock): + sel.unregister(sock) + closed_sockets.append(sock) + + sel.register(reader, selectors.EVENT_READ, make_buffered_socket_reader(gen, on_close)) + + # Send data then close + sender.sendall(b"final\n") + service_selector(sel, timeout=1.0) + assert received == [b"final\n"] + + sender.close() + sender = None + service_selector(sel, timeout=0.5) + + # on_close should have been called, and socket closed by service_selector + assert len(closed_sockets) == 1 + + sel.close() + finally: + if sender: + sender.close() + reader.close() diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 3695af1fff592..1b66c17ed47c1 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -60,7 +60,6 @@ DagRunState, DagRunType, PreviousTIResponse, - TaskInstance, TaskInstanceState, ) from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType, TaskAlreadyRunningError @@ -147,6 +146,7 @@ supervise_task, ) from airflow.sdk.execution_time.task_runner import run +from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO from tests_common.test_utils.config import conf_vars @@ -207,13 +207,16 @@ def test_supervise( """ Test that the supervisor validates server URL and dry_run parameter combinations correctly. 
""" - ti = TaskInstance( + ti = TaskInstanceDTO( id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="d", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ) bundle_info = BundleInfo(name="my-bundle", version=None) @@ -300,13 +303,16 @@ def subprocess_main(): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( + what=TaskInstanceDTO( id="4d828a62-a417-4936-a7a6-2b3fabacecab", task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=client_with_ti_start, target=subprocess_main, @@ -375,13 +381,16 @@ def subprocess_main(): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( + what=TaskInstanceDTO( id="4d828a62-a417-4936-a7a6-2b3fabacecab", task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=client_with_ti_start, target=subprocess_main, @@ -472,13 +481,16 @@ def on_kill(self) -> None: proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( + what=TaskInstanceDTO( id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=make_client(transport=httpx.MockTransport(handle_request)), target=subprocess_main, @@ -501,13 +513,16 @@ def subprocess_main(): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( + what=TaskInstanceDTO( id="4d828a62-a417-4936-a7a6-2b3fabacecab", task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=client_with_ti_start, target=subprocess_main, @@ -536,8 +551,16 @@ def subprocess_main(): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( - id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + what=TaskInstanceDTO( + id=uuid7(), + task_id="b", + dag_id="c", + run_id="d", + try_number=1, + dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=mock_client, target=subprocess_main, @@ -575,13 +598,16 @@ def test_resume_start_date_from_context(self, mocker, make_ti_context, start_dat proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( + what=TaskInstanceDTO( id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=mock_client, target=lambda: None, @@ -618,8 +644,16 @@ def subprocess_main(): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( - id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + what=TaskInstanceDTO( + id=ti_id, + task_id="b", + dag_id="c", + run_id="d", + try_number=1, + dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=sdk_client.Client(base_url="", dry_run=True, token=""), target=subprocess_main, @@ -655,8 +689,16 @@ def _on_child_started(self, *args, **kwargs): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( - id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + 
what=TaskInstanceDTO( + id=ti_id, + task_id="b", + dag_id="c", + run_id="d", + try_number=1, + dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=sdk_client.Client(base_url="", dry_run=True, token=""), target=subprocess_main, @@ -671,13 +713,16 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine, mocker time_machine.move_to(instant, tick=False) dagfile_path = test_dags_dir - ti = TaskInstance( + ti = TaskInstanceDTO( id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ) bundle_info = BundleInfo(name="my-bundle", version=None) @@ -712,13 +757,16 @@ def test_supervise_handles_deferred_task( """ instant = timezone.datetime(2024, 11, 7, 12, 34, 56, 0) - ti = TaskInstance( + ti = TaskInstanceDTO( id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="d", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ) # Create a mock client to assert calls to the client @@ -839,8 +887,16 @@ def handle_request(request: httpx.Request) -> httpx.Response: proc = ActivitySubprocess.start( dag_rel_path=os.devnull, - what=TaskInstance( - id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + what=TaskInstanceDTO( + id=ti_id, + task_id="b", + dag_id="c", + run_id="d", + try_number=1, + dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=make_client(transport=httpx.MockTransport(handle_request)), target=subprocess_main, @@ -917,8 +973,16 @@ def subprocess_main(): ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( - id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + what=TaskInstanceDTO( + id=ti_id, + task_id="b", + dag_id="c", + run_id="d", + try_number=1, + dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=make_client(transport=httpx.MockTransport(handle_request)), target=subprocess_main, @@ -1122,13 +1186,16 @@ def subprocess_main(): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( + what=TaskInstanceDTO( id="4d828a62-a417-4936-a7a6-2b3fabacecab", task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=client_with_ti_start, target=subprocess_main, @@ -1285,8 +1352,16 @@ def _handler(sig, frame): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( - id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7() + what=TaskInstanceDTO( + id=ti_id, + task_id="b", + dag_id="c", + run_id="d", + try_number=1, + dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=client_with_ti_start, target=subprocess_main, @@ -3300,13 +3375,16 @@ def subprocess_main(): proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, - what=TaskInstance( + what=TaskInstanceDTO( id="4d828a62-a417-4936-a7a6-2b3fabacecab", task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), client=client_with_ti_start, target=subprocess_main, diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 
630aff9094ed1..476db10184713 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -146,6 +146,7 @@ run, startup, ) +from airflow.sdk.execution_time.workloads.task import TaskInstanceDTO from airflow.sdk.execution_time.xcom import XCom from airflow.sdk.serde import deserialize from airflow.triggers.base import BaseEventTrigger, BaseTrigger, TriggerEvent @@ -177,13 +178,16 @@ def execute(self, context): def test_parse(test_dags_dir: Path, make_ti_context): """Test that checks parsing of a basic dag with an un-mocked parse.""" what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), @@ -224,13 +228,16 @@ def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context): mock_dag.task_dict = {"a": mock_task} what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="a", dag_id="super_basic", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), @@ -284,13 +291,16 @@ def test_parse_dag_bag(mock_dagbag, test_dags_dir: Path, make_ti_context): def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id, expected_error): """Check for nice error messages on dag not found.""" what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), @@ -330,13 +340,16 @@ def test_parse_not_found_does_not_reschedule_when_max_attempts_reached(test_dags and should surface as a hard failure (SystemExit in the task runner process). """ what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="a", dag_id="madeup_dag_id", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="super_basic.py", bundle_info=BundleInfo(name="my-bundle", version=None), @@ -391,13 +404,16 @@ def test_main_sends_reschedule_task_when_startup_reschedules( mock_comms_instance.socket = None mock_comms_decoder_cls.__getitem__.return_value.return_value = mock_comms_instance what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="my_task", dag_id="test_dag", run_id="test_run", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, context_carrier={}, ), dag_rel_path="", @@ -484,13 +500,16 @@ def test_task_span_is_child_of_dag_run_span(make_ti_context): # Step 3: build StartupDetails with ti.context_carrier = ti_carrier. 
what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="my_task", dag_id="test_dag", run_id="test_run", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, context_carrier=ti_carrier, ), dag_rel_path="", @@ -552,13 +571,16 @@ def test_task_span_no_parent_when_no_context_carrier(make_ti_context): provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter)) what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="standalone_task", dag_id="test_dag", run_id="test_run", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, context_carrier=None, ), dag_rel_path="", @@ -593,13 +615,16 @@ def test_parse_module_in_bundle_root(tmp_path: Path, make_ti_context): dag1_path.write_text(textwrap.dedent(dag1_code)) what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="a", dag_id="dag_name", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="path_test.py", bundle_info=BundleInfo(name="my-bundle", version=None), @@ -1040,13 +1065,16 @@ def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comm ) what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), bundle_info=FAKE_BUNDLE, dag_rel_path="", @@ -1156,13 +1184,16 @@ def execute(self, context): instant = timezone.datetime(2024, 12, 3, 10, 0) what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="", bundle_info=FAKE_BUNDLE, @@ -1204,13 +1235,16 @@ def execute(self, context): instant = timezone.datetime(2024, 12, 3, 10, 0) what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="impersonation_task", dag_id="basic_dag", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="", bundle_info=FAKE_BUNDLE, @@ -1252,13 +1286,16 @@ def execute(self, context): instant = timezone.datetime(2024, 12, 3, 10, 0) what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="impersonation_task", dag_id="basic_dag", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="", bundle_info=FAKE_BUNDLE, @@ -1292,13 +1329,16 @@ def execute(self, context): instant = timezone.datetime(2024, 12, 3, 10, 0) what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="impersonation_task", dag_id="basic_dag", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="", bundle_info=FAKE_BUNDLE, @@ -1465,8 +1505,16 @@ def test_dag_parsing_context(make_ti_context, mock_supervisor_comms, monkeypatch task_id = "conditional_task" what = StartupDetails( - ti=TaskInstance( - id=uuid7(), task_id=task_id, dag_id=dag_id, run_id="c", try_number=1, dag_version_id=uuid7() + ti=TaskInstanceDTO( + id=uuid7(), + task_id=task_id, + dag_id=dag_id, + run_id="c", + try_number=1, + dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="dag_parsing_context.py", 
bundle_info=BundleInfo(name="my-bundle", version=None), @@ -2061,8 +2109,10 @@ def execute(self, context): test_task_id = "pull_task" task = CustomOperator(task_id=test_task_id) - # In case of the specific map_index or None we should check it is passed to TI - extra_for_ti = {"map_index": map_indexes} if map_indexes in (1, None) else {} + # In case of the specific map_index we should check it is passed to TI. + # ``None`` is not a valid TaskInstanceDTO.map_index value, but xcom_pull's + # behaviour with ``map_indexes=None`` is independent of the TI's own map_index. + extra_for_ti = {"map_index": map_indexes} if isinstance(map_indexes, int) else {} runtime_ti = create_runtime_ti(task=task, **extra_for_ti) ser_value = BaseXCom.serialize_value(xcom_values) @@ -3883,13 +3933,16 @@ def execute(self, context): task_id="test_task_runner_calls_listeners", do_xcom_push=True, multiple_outputs=True ) what = StartupDetails( - ti=TaskInstance( + ti=TaskInstanceDTO( id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1, dag_version_id=uuid7(), + pool_slots=1, + queue="default", + priority_weight=1, ), dag_rel_path="", bundle_info=FAKE_BUNDLE, @@ -4491,7 +4544,8 @@ class CustomOperator(BaseOperator): class TestTriggerDagRunOperator: """Tests to verify various aspects of TriggerDagRunOperator""" - @time_machine.travel("2025-01-01 00:00:00", tick=False) + # make timetravel timezone-aware + @time_machine.travel(datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), tick=False) def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): """Test that TriggerDagRunOperator (with default args) sends the correct message to the Supervisor""" from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator @@ -4539,7 +4593,7 @@ def test_handle_trigger_dag_run(self, create_runtime_ti, mock_supervisor_comms): (False, TaskInstanceState.FAILED), ], ) - @time_machine.travel("2025-01-01 00:00:00", tick=False) + @time_machine.travel(datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), tick=False) def test_handle_trigger_dag_run_conflict( self, skip_when_already_exists, expected_state, create_runtime_ti, mock_supervisor_comms ): @@ -4583,7 +4637,7 @@ def test_handle_trigger_dag_run_conflict( ([DagRunState.SUCCESS], None, DagRunState.FAILED, DagRunState.FAILED), ], ) - @time_machine.travel("2025-01-01 00:00:00", tick=False) + @time_machine.travel(datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), tick=False) def test_handle_trigger_dag_run_wait_for_completion( self, allowed_states, @@ -4704,7 +4758,7 @@ def test_handle_trigger_dag_run_deferred( assert state == intermediate_state - @time_machine.travel("2025-01-01 00:00:00", tick=False) + @time_machine.travel(datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), tick=False) def test_handle_trigger_dag_run_deferred_with_reset_uses_run_id_only( self, create_runtime_ti, mock_supervisor_comms ): diff --git a/task-sdk/tests/task_sdk/test_providers_manager_runtime.py b/task-sdk/tests/task_sdk/test_providers_manager_runtime.py index 1cae21d53c764..6e775f790be89 100644 --- a/task-sdk/tests/task_sdk/test_providers_manager_runtime.py +++ b/task-sdk/tests/task_sdk/test_providers_manager_runtime.py @@ -243,6 +243,33 @@ def test_already_initialized_provider_configs_emits_deprecation_warning(self): with pytest.warns(DeprecationWarning, match="already_initialized_provider_configs.*deprecated"): pm.already_initialized_provider_configs + @patch("airflow.sdk.providers_manager_runtime.import_string") + def 
test_coordinators(self, mock_import_string): + class ACoordinator: + pass + + class ZCoordinator: + pass + + mock_import_string.side_effect = lambda path: { + "airflow.providers.sdk.java.coordinator.ACoordinator": ACoordinator, + "airflow.providers.sdk.java.coordinator.ZCoordinator": ZCoordinator, + }[path] + providers_manager = ProvidersManagerTaskRuntime() + providers_manager._provider_dict["apache-airflow-providers-sdk-java"] = ProviderInfo( + version="0.0.1", + data={ + "coordinators": [ + "airflow.providers.sdk.java.coordinator.ZCoordinator", + "airflow.providers.sdk.java.coordinator.ACoordinator", + "airflow.providers.sdk.java.coordinator.ZCoordinator", + ] + }, + ) + + with patch.object(providers_manager, "initialize_providers_list"): + assert providers_manager.coordinators == [ACoordinator, ZCoordinator] + def test_initialize_provider_configs_can_reload_sdk_conf(self): from airflow.sdk.configuration import conf diff --git a/uv.lock b/uv.lock index b31aab16b3b71..58edbb2f959c0 100644 --- a/uv.lock +++ b/uv.lock @@ -80,8 +80,8 @@ apache-airflow-providers-salesforce = false apache-airflow-providers-ssh = false apache-airflow-providers-papermill = false apache-airflow-providers-google = false -apache-airflow-providers-microsoft-psrp = false apache-airflow-providers-vertica = false +apache-airflow-providers-microsoft-psrp = false apache-airflow-providers-apache-hdfs = false apache-airflow-shared-template-rendering = false apache-airflow-mypy = false @@ -152,6 +152,7 @@ apache-airflow-providers-smtp = false apache-airflow-providers-dingding = false apache-airflow-providers-apache-kylin = false apache-airflow-providers-cloudant = false +apache-airflow-providers-sdk-java = false apache-airflow-docker-stack = false [manifest] @@ -249,6 +250,7 @@ members = [ "apache-airflow-providers-redis", "apache-airflow-providers-salesforce", "apache-airflow-providers-samba", + "apache-airflow-providers-sdk-java", "apache-airflow-providers-segment", "apache-airflow-providers-sendgrid", "apache-airflow-providers-sftp", @@ -1025,6 +1027,7 @@ all = [ { name = "apache-airflow-providers-redis" }, { name = "apache-airflow-providers-salesforce" }, { name = "apache-airflow-providers-samba" }, + { name = "apache-airflow-providers-sdk-java" }, { name = "apache-airflow-providers-segment" }, { name = "apache-airflow-providers-sendgrid" }, { name = "apache-airflow-providers-sftp" }, @@ -1344,6 +1347,9 @@ salesforce = [ samba = [ { name = "apache-airflow-providers-samba" }, ] +sdk-java = [ + { name = "apache-airflow-providers-sdk-java" }, +] segment = [ { name = "apache-airflow-providers-segment" }, ] @@ -1635,6 +1641,8 @@ requires-dist = [ { name = "apache-airflow-providers-salesforce", marker = "extra == 'salesforce'", editable = "providers/salesforce" }, { name = "apache-airflow-providers-samba", marker = "extra == 'all'", editable = "providers/samba" }, { name = "apache-airflow-providers-samba", marker = "extra == 'samba'", editable = "providers/samba" }, + { name = "apache-airflow-providers-sdk-java", marker = "extra == 'all'", editable = "providers/sdk/java" }, + { name = "apache-airflow-providers-sdk-java", marker = "extra == 'sdk-java'", editable = "providers/sdk/java" }, { name = "apache-airflow-providers-segment", marker = "extra == 'all'", editable = "providers/segment" }, { name = "apache-airflow-providers-segment", marker = "extra == 'segment'", editable = "providers/segment" }, { name = "apache-airflow-providers-sendgrid", marker = "extra == 'all'", editable = "providers/sendgrid" }, @@ -1685,7 +1693,7 
@@ requires-dist = [ { name = "sentry-sdk", marker = "extra == 'sentry'", specifier = ">=2.30.0" }, { name = "uv", marker = "extra == 'uv'", specifier = ">=0.11.8" }, ] -provides-extras = ["all-core", "async", "graphviz", "gunicorn", "kerberos", "memray", "otel", "statsd", "all-task-sdk", "airbyte", "akeyless", "alibaba", "amazon", "apache-cassandra", "apache-drill", "apache-druid", "apache-flink", "apache-hdfs", "apache-hive", "apache-iceberg", "apache-impala", "apache-kafka", "apache-kylin", "apache-livy", "apache-pig", "apache-pinot", "apache-spark", "apache-tinkerpop", "apprise", "arangodb", "asana", "atlassian-jira", "celery", "cloudant", "cncf-kubernetes", "cohere", "common-ai", "common-compat", "common-io", "common-messaging", "common-sql", "databricks", "datadog", "dbt-cloud", "dingding", "discord", "docker", "edge3", "elasticsearch", "exasol", "fab", "facebook", "ftp", "git", "github", "google", "grpc", "hashicorp", "http", "imap", "influxdb", "informatica", "jdbc", "jenkins", "keycloak", "microsoft-azure", "microsoft-mssql", "microsoft-psrp", "microsoft-winrm", "mongo", "mysql", "neo4j", "odbc", "openai", "openfaas", "openlineage", "opensearch", "opsgenie", "oracle", "pagerduty", "papermill", "pgvector", "pinecone", "postgres", "presto", "qdrant", "redis", "salesforce", "samba", "segment", "sendgrid", "sftp", "singularity", "slack", "smtp", "snowflake", "sqlite", "ssh", "standard", "tableau", "telegram", "teradata", "trino", "vertica", "vespa", "weaviate", "yandex", "ydb", "zendesk", "all", "aiobotocore", "apache-atlas", "apache-webhdfs", "amazon-aws-auth", "cloudpickle", "github-enterprise", "google-auth", "ldap", "pandas", "polars", "rabbitmq", "sentry", "s3fs", "uv"] +provides-extras = ["all-core", "async", "graphviz", "gunicorn", "kerberos", "memray", "otel", "statsd", "all-task-sdk", "airbyte", "akeyless", "alibaba", "amazon", "apache-cassandra", "apache-drill", "apache-druid", "apache-flink", "apache-hdfs", "apache-hive", "apache-iceberg", "apache-impala", "apache-kafka", "apache-kylin", "apache-livy", "apache-pig", "apache-pinot", "apache-spark", "apache-tinkerpop", "apprise", "arangodb", "asana", "atlassian-jira", "celery", "cloudant", "cncf-kubernetes", "cohere", "common-ai", "common-compat", "common-io", "common-messaging", "common-sql", "databricks", "datadog", "dbt-cloud", "dingding", "discord", "docker", "edge3", "elasticsearch", "exasol", "fab", "facebook", "ftp", "git", "github", "google", "grpc", "hashicorp", "http", "imap", "influxdb", "informatica", "jdbc", "jenkins", "keycloak", "microsoft-azure", "microsoft-mssql", "microsoft-psrp", "microsoft-winrm", "mongo", "mysql", "neo4j", "odbc", "openai", "openfaas", "openlineage", "opensearch", "opsgenie", "oracle", "pagerduty", "papermill", "pgvector", "pinecone", "postgres", "presto", "qdrant", "redis", "salesforce", "samba", "sdk-java", "segment", "sendgrid", "sftp", "singularity", "slack", "smtp", "snowflake", "sqlite", "ssh", "standard", "tableau", "telegram", "teradata", "trino", "vertica", "vespa", "weaviate", "yandex", "ydb", "zendesk", "all", "aiobotocore", "apache-atlas", "apache-webhdfs", "amazon-aws-auth", "cloudpickle", "github-enterprise", "google-auth", "ldap", "pandas", "polars", "rabbitmq", "sentry", "s3fs", "uv"] [package.metadata.requires-dev] dev = [ @@ -7027,6 +7035,46 @@ dev = [ ] docs = [{ name = "apache-airflow-devel-common", extras = ["docs"], editable = "devel-common" }] +[[package]] +name = "apache-airflow-providers-sdk-java" +version = "0.1.0" +source = { editable = "providers/sdk/java" } 
+dependencies = [ + { name = "apache-airflow" }, +] + +[package.optional-dependencies] +common-compat = [ + { name = "apache-airflow-providers-common-compat" }, +] + +[package.dev-dependencies] +dev = [ + { name = "apache-airflow" }, + { name = "apache-airflow-devel-common" }, + { name = "apache-airflow-providers-common-compat" }, + { name = "apache-airflow-task-sdk" }, +] +docs = [ + { name = "apache-airflow-devel-common", extra = ["docs"] }, +] + +[package.metadata] +requires-dist = [ + { name = "apache-airflow", editable = "." }, + { name = "apache-airflow-providers-common-compat", marker = "extra == 'common-compat'", editable = "providers/common/compat" }, +] +provides-extras = ["common-compat"] + +[package.metadata.requires-dev] +dev = [ + { name = "apache-airflow", editable = "." }, + { name = "apache-airflow-devel-common", editable = "devel-common" }, + { name = "apache-airflow-providers-common-compat", editable = "providers/common/compat" }, + { name = "apache-airflow-task-sdk", editable = "task-sdk" }, +] +docs = [{ name = "apache-airflow-devel-common", extras = ["docs"], editable = "devel-common" }] + [[package]] name = "apache-airflow-providers-segment" version = "3.9.4"