From 0f9ac5c9e8517f2841e253488719f18b8aec5235 Mon Sep 17 00:00:00 2001 From: Shivam Rastogi <6463385+shivaam@users.noreply.github.com> Date: Sat, 18 Apr 2026 18:34:55 -0700 Subject: [PATCH 1/2] Add ExecuteCallback support to AWS ECS Executor Enables the ECS executor to dispatch ExecuteCallback workloads (deadline alerts) alongside regular ExecuteTask workloads. Builds on #65392 which widened BaseExecutor signatures to accept WorkloadKey. - supports_callbacks = True (gated on AIRFLOW_V_3_3_PLUS) - Widen key types to WorkloadKey throughout EcsQueuedTask / EcsTaskCollection - Branch _process_workloads on ExecuteTask vs ExecuteCallback - Add AIRFLOW_V_3_3_PLUS to version_compat.py - Unit tests for queueing, processing, serialization, sync, mixed keys --- .../amazon/aws/executors/ecs/ecs_executor.py | 82 +++++++------ .../amazon/aws/executors/ecs/utils.py | 48 ++++---- .../aws/executors/ecs/test_ecs_executor.py | 116 +++++++++++++++++- 3 files changed, 185 insertions(+), 61 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index c35ba5c0fa2e0..3a1fc19f21a3c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -46,15 +46,14 @@ exponential_backoff_retry, ) from airflow.providers.amazon.aws.hooks.ecs import EcsHook -from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone from airflow.utils.helpers import merge_dicts from airflow.utils.state import State if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.executors import workloads + from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.providers.amazon.aws.executors.ecs.utils import ( CommandType, @@ -92,6 +91,9 @@ class AwsEcsExecutor(BaseExecutor): supports_multi_team: bool = True + if AIRFLOW_V_3_3_PLUS: + supports_callbacks: bool = True + # AWS limits the maximum number of ARNs in the describe_tasks function. 
DESCRIBE_TASKS_BATCH_SIZE = 99 @@ -131,30 +133,30 @@ def __init__(self, *args, **kwargs): fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS], ) - def queue_workload(self, workload: workloads.All, session: Session | None) -> None: + def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: from airflow.executors import workloads - if not isinstance(workload, workloads.ExecuteTask): - raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") - ti = workload.ti - self.queued_tasks[ti.key] = workload + for workload in workload_items: + if isinstance(workload, workloads.ExecuteTask): + command = [workload] + key = workload.ti.key + queue = workload.ti.queue + executor_config = workload.ti.executor_config or {} - def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: - from airflow.executors.workloads import ExecuteTask + del self.queued_tasks[key] + self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type] + self.running.add(key) - # Airflow V3 version - for w in workloads: - if not isinstance(w, ExecuteTask): - raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}") + elif AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback): + command = [workload] # type: ignore[list-item] + key = workload.callback.id # type: ignore[assignment] - command = [w] - key = w.ti.key - queue = w.ti.queue - executor_config = w.ti.executor_config or {} + del self.queued_callbacks[key] # type: ignore[arg-type] + self.execute_async(key=key, command=command, queue=None) # type: ignore[arg-type] + self.running.add(key) - del self.queued_tasks[key] - self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type] - self.running.add(key) + else: + raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") def start(self): """Call this when the Executor is run for the first time by the scheduler.""" @@ -262,10 +264,10 @@ def sync(self): self.log.exception("Failed to sync %s", self.__class__.__name__) def sync_running_tasks(self): - """Check and update state on all running tasks.""" + """Check and update state on all running workloads (tasks and callbacks).""" all_task_arns = self.active_workers.get_all_arns() if not all_task_arns: - self.log.debug("No active Airflow tasks, skipping sync.") + self.log.debug("No active Airflow workloads, skipping sync.") return describe_tasks_response = self.__describe_tasks(all_task_arns) @@ -292,7 +294,7 @@ def __update_running_task(self, task): self.__handle_failed_task(task.task_arn, task.stopped_reason) elif task_state == State.SUCCESS: self.log.debug( - "Airflow task %s marked as %s after running on ECS Task (arn) %s", + "Airflow workload %s marked as %s after running on ECS Task (arn) %s", task_key, task_state, task.task_arn, @@ -346,7 +348,7 @@ def __handle_failed_task(self, task_arn: str, reason: str): failure_count = self.active_workers.failure_count_by_key(task_key) if int(failure_count) < int(self.max_run_task_attempts): self.log.warning( - "Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.", + "Airflow workload %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.", task_key, reason, failure_count, @@ -365,7 +367,7 @@ def __handle_failed_task(self, task_arn: str, reason: str): ) else: self.log.error( - "Airflow task %s has failed a maximum of %s times. 
Marking as failed", + "Airflow workload %s has failed a maximum of %s times. Marking as failed", task_key, failure_count, ) @@ -430,7 +432,7 @@ def attempt_task_runs(self): else: reasons_str = ", ".join(failure_reasons) self.log.error( - "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s", + "ECS workload %s has failed a maximum of %s times. Marking as failed. Reasons: %s", task_key, attempt_number, reasons_str, @@ -460,7 +462,11 @@ def attempt_task_runs(self): self.running_state(task_key, task.task_arn) def _run_task( - self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType + self, + task_id: WorkloadKey, + cmd: CommandType, + queue: str | None, + exec_config: ExecutorConfigType, ): """ Run a queued-up Airflow task. @@ -475,7 +481,11 @@ def _run_task( return run_task_response def _run_task_kwargs( - self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType + self, + task_id: WorkloadKey, + cmd: CommandType, + queue: str | None, + exec_config: ExecutorConfigType, ) -> dict: """ Update the Airflow command by modifying container overrides for task-specific kwargs. @@ -494,14 +504,16 @@ def _run_task_kwargs( return run_task_kwargs - def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None): - """Save the task to be executed in the next sync by inserting the commands into a queue.""" + def execute_async(self, key: WorkloadKey, command: CommandType, queue=None, executor_config=None): + """Save the workload to be executed in the next sync by inserting the commands into a queue.""" if executor_config and ("name" in executor_config or "command" in executor_config): raise ValueError('Executor Config should never override "name" or "command"') if len(command) == 1: - from airflow.executors.workloads import ExecuteTask + from airflow.executors import workloads - if isinstance(command[0], ExecuteTask): + if isinstance(command[0], workloads.ExecuteTask) or ( + AIRFLOW_V_3_3_PLUS and isinstance(command[0], workloads.ExecuteCallback) + ): command = self._serialize_workload_to_command(command[0]) else: raise ValueError( @@ -567,9 +579,9 @@ def get_container(self, container_list): @staticmethod def _serialize_workload_to_command(workload) -> CommandType: """ - Serialize an ExecuteTask workload into a command for the Task SDK. + Serialize a workload into a command for the Task SDK. - :param workload: ExecuteTask workload to serialize + :param workload: ExecuteTask or ExecuteCallback workload to serialize :return: Command as list of strings for Task SDK execution """ return [ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py index f8d7f58062189..592a61b3603ce 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py @@ -35,7 +35,7 @@ from airflow.utils.state import State if TYPE_CHECKING: - from airflow.models.taskinstance import TaskInstanceKey + from airflow.executors.workloads.types import WorkloadKey CommandType = Sequence[str] ExecutorConfigFunctionType = Callable[[CommandType], dict] @@ -57,11 +57,11 @@ @dataclass class EcsQueuedTask: - """Represents an ECS task that is queued. The task will be run in the next heartbeat.""" + """Represents a queued ECS workload (task or callback). 
The workload will be run in the next heartbeat."""
 
-    key: TaskInstanceKey
+    key: WorkloadKey
     command: CommandType
-    queue: str
+    queue: str | None
     executor_config: ExecutorConfigType
     attempt_number: int
     next_attempt_time: datetime.datetime
@@ -72,7 +72,7 @@ class EcsTaskInfo:
     """Contains information about a currently running ECS task."""
 
     cmd: CommandType
-    queue: str
+    queue: str | None
     config: ExecutorConfigType
 
 
@@ -156,20 +156,20 @@ def __repr__(self):
 
 
 class EcsTaskCollection:
-    """A five-way dictionary between Airflow task ids, Airflow cmds, ECS ARNs, and ECS task objects."""
+    """A five-way mapping between Airflow workload keys, ECS ARNs, ECS task objects, failure counts, and task info."""
 
     def __init__(self):
-        self.key_to_arn: dict[TaskInstanceKey, str] = {}
-        self.arn_to_key: dict[str, TaskInstanceKey] = {}
+        self.key_to_arn: dict[WorkloadKey, str] = {}
+        self.arn_to_key: dict[str, WorkloadKey] = {}
         self.tasks: dict[str, EcsExecutorTask] = {}
-        self.key_to_failure_counts: dict[TaskInstanceKey, int] = defaultdict(int)
-        self.key_to_task_info: dict[TaskInstanceKey, EcsTaskInfo] = {}
+        self.key_to_failure_counts: dict[WorkloadKey, int] = defaultdict(int)
+        self.key_to_task_info: dict[WorkloadKey, EcsTaskInfo] = {}
 
     def add_task(
         self,
         task: EcsExecutorTask,
-        airflow_task_key: TaskInstanceKey,
-        queue: str,
+        airflow_task_key: WorkloadKey,
+        queue: str | None,
         airflow_cmd: CommandType,
         exec_config: ExecutorConfigType,
         attempt_number: int,
@@ -186,8 +186,8 @@ def update_task(self, task: EcsExecutorTask):
         """Update the state of the given task based on task ARN."""
         self.tasks[task.task_arn] = task
 
-    def task_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
-        """Get a task by Airflow Instance Key."""
+    def task_by_key(self, task_key: WorkloadKey) -> EcsExecutorTask:
+        """Get a task by Airflow workload key."""
         arn = self.key_to_arn[task_key]
         return self.task_by_arn(arn)
 
@@ -195,8 +195,8 @@ def task_by_arn(self, arn) -> EcsExecutorTask:
         """Get a task by AWS ARN."""
         return self.tasks[arn]
 
-    def pop_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
-        """Delete task from collection based off of Airflow Task Instance Key."""
+    def pop_by_key(self, task_key: WorkloadKey) -> EcsExecutorTask:
+        """Delete a task from the collection based on its Airflow workload key."""
         arn = self.key_to_arn[task_key]
         task = self.tasks[arn]
         del self.key_to_arn[task_key]
@@ -211,20 +211,20 @@ def get_all_arns(self) -> list[str]:
         """Get all AWS ARNs in collection."""
         return list(self.key_to_arn.values())
 
-    def get_all_task_keys(self) -> list[TaskInstanceKey]:
-        """Get all Airflow Task Keys in collection."""
+    def get_all_task_keys(self) -> list[WorkloadKey]:
+        """Get all Airflow workload keys in collection."""
         return list(self.key_to_arn.keys())
 
-    def failure_count_by_key(self, task_key: TaskInstanceKey) -> int:
-        """Get the number of times a task has failed given an Airflow Task Key."""
+    def failure_count_by_key(self, task_key: WorkloadKey) -> int:
+        """Get the number of times a workload has failed given an Airflow workload key."""
         return self.key_to_failure_counts[task_key]
 
-    def increment_failure_count(self, task_key: TaskInstanceKey):
-        """Increment the failure counter given an Airflow Task Key."""
+    def increment_failure_count(self, task_key: WorkloadKey):
+        """Increment the failure counter given an Airflow workload key."""
        self.key_to_failure_counts[task_key] += 1
 
-    def info_by_key(self, task_key: TaskInstanceKey) -> EcsTaskInfo:
-        """Get the Airflow Command given an Airflow task key."""
+    def info_by_key(self, 
task_key: WorkloadKey) -> EcsTaskInfo: + """Get the task info given an Airflow workload key.""" return self.key_to_task_info[task_key] def __getitem__(self, value): diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py index f57cd13d28aab..a9732d3089488 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py @@ -52,6 +52,7 @@ parse_assign_public_ip, ) from airflow.providers.amazon.aws.hooks.ecs import EcsHook +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS from airflow.providers.common.compat.sdk import AirflowException, conf from airflow.utils.helpers import convert_camel_to_snake from airflow.utils.state import State, TaskInstanceState @@ -779,7 +780,7 @@ def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog mock_executor.sync_running_tasks() for i in range(2): assert ( - f"Airflow task {airflow_keys[i]} failed due to {describe_tasks[i]['stoppedReason']}. Failure 1 out of 2" + f"Airflow workload {airflow_keys[i]} failed due to {describe_tasks[i]['stoppedReason']}. Failure 1 out of 2" in caplog.messages[i] ) @@ -803,7 +804,7 @@ def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog mock_executor.sync_running_tasks() for i in range(2): assert ( - f"Airflow task {airflow_keys[i]} has failed a maximum of 2 times. Marking as failed" + f"Airflow workload {airflow_keys[i]} has failed a maximum of 2 times. Marking as failed" in caplog.messages[i] ) @@ -1976,3 +1977,114 @@ def test_short_import_path(self): from airflow.providers.amazon.aws.executors.ecs import AwsEcsExecutor as AwsEcsExecutorShortPath assert AwsEcsExecutor is AwsEcsExecutorShortPath + + +class TestEcsExecutorCallbackSupport: + """Tests for ExecuteCallback support in the ECS Executor.""" + + @pytest.fixture + def callback_workload(self): + """Create a mock ExecuteCallback workload for testing.""" + from airflow.executors.workloads import ExecuteCallback + from airflow.executors.workloads.base import BundleInfo + from airflow.executors.workloads.callback import CallbackDTO, CallbackFetchMethod + + callback_data = CallbackDTO( + id="12345678-1234-5678-1234-567812345678", + fetch_method=CallbackFetchMethod.IMPORT_PATH, + data={"path": "test.module.alert_func", "kwargs": {}}, + ) + return ExecuteCallback( + callback=callback_data, + dag_rel_path="test.py", + bundle_info=BundleInfo(name="test_bundle", version="1.0"), + token="test_token", + log_path="test.log", + ) + + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") + def test_supports_callbacks_attribute(self, mock_executor): + """Verify that the ECS executor declares callback support.""" + assert mock_executor.supports_callbacks is True + + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") + def test_queue_callback_workload(self, mock_executor, callback_workload): + """Test that queue_workload correctly stores ExecuteCallback in queued_callbacks.""" + mock_executor.queue_workload(callback_workload, session=None) + + assert len(mock_executor.queued_callbacks) == 1 + assert callback_workload.callback.id in mock_executor.queued_callbacks + assert mock_executor.queued_callbacks[callback_workload.callback.id] is callback_workload + + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") + def 
test_process_callback_workload(self, mock_executor, callback_workload):
+        """Test that _process_workloads handles ExecuteCallback correctly."""
+        callback_key = callback_workload.callback.id
+        mock_executor.queued_callbacks[callback_key] = callback_workload
+
+        mock_executor._process_workloads([callback_workload])
+
+        # Callback should be removed from queued_callbacks
+        assert callback_key not in mock_executor.queued_callbacks
+        # Callback should be added to running set
+        assert callback_key in mock_executor.running
+        # Callback should be added to pending_tasks for execution
+        assert len(mock_executor.pending_tasks) == 1
+        queued = mock_executor.pending_tasks[0]
+        assert queued.key == callback_key
+        assert queued.queue is None
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+")
+    def test_execute_async_callback_workload(self, mock_executor, callback_workload):
+        """Test that execute_async serializes ExecuteCallback workloads correctly."""
+        callback_key = callback_workload.callback.id
+        mock_executor.execute_async(key=callback_key, command=[callback_workload], queue=None)
+
+        assert len(mock_executor.pending_tasks) == 1
+        queued = mock_executor.pending_tasks[0]
+        assert queued.key == callback_key
+        # Command should be serialized to the execute_workload entrypoint
+        assert queued.command[0] == "python"
+        assert queued.command[2] == "airflow.sdk.execution_time.execute_workload"
+        assert queued.command[3] == "--json-string"
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+")
+    def test_callback_sync_running_success(self, mock_executor, callback_workload):
+        """Test that sync_running_tasks correctly handles successful callback ECS tasks."""
+        callback_key = callback_workload.callback.id
+        ecs_task = mock_task(ARN1, State.SUCCESS)
+        mock_cmd = _generate_mock_cmd()
+        mock_executor.active_workers.add_task(ecs_task, callback_key, None, mock_cmd, {}, 1)
+
+        mock_executor.ecs.describe_tasks.return_value = {
+            "tasks": [
+                {
+                    "taskArn": ARN1,
+                    "lastStatus": "STOPPED",
+                    "desiredStatus": "STOPPED",
+                    "containers": [{"name": "container-name", "exitCode": 0, "lastStatus": "STOPPED"}],
+                    "startedAt": "2024-01-01T00:00:00Z",
+                }
+            ],
+            "failures": [],
+        }
+
+        mock_executor.sync_running_tasks()
+
+        # Callback should be removed from active workers after success
+        assert len(mock_executor.active_workers) == 0
+
+    @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+")
+    def test_collection_mixed_key_types(self):
+        """Test that EcsTaskCollection works with both TaskInstanceKey and callback string keys."""
+        collection = EcsTaskCollection()
+        mock_cmd = _generate_mock_cmd()
+        task_key = mock.Mock(spec=tuple)
+        callback_key = "12345678-1234-5678-1234-567812345678"
+
+        collection.add_task(mock_task(ARN1), task_key, "default", mock_cmd, {}, 1)
+        collection.add_task(mock_task(ARN2), callback_key, None, mock_cmd, {}, 1)
+
+        assert len(collection) == 2
+        assert collection.key_to_arn[task_key] == ARN1
+        assert collection.key_to_arn[callback_key] == ARN2

From 233ec6bd644eeca209be23fb94095953a6bdabc0 Mon Sep 17 00:00:00 2001
From: Shivam Rastogi <6463385+shivaam@users.noreply.github.com>
Date: Sat, 18 Apr 2026 19:43:25 -0700
Subject: [PATCH 2/2] Rename task-named methods/attrs and fix older-Airflow compat
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Renames (mirrors the merged Lambda callback PR, which did a straight
rename with no shim on this executor-internal surface; here the old
names are additionally kept as deprecation shims):

sync_running_tasks -> 
sync_running_workloads attempt_task_runs -> attempt_workload_runs pending_tasks (attr) -> pending_workloads __update_running_task -> __update_running_workload __handle_failed_task -> __handle_failed_workload Fix CI on older Airflow compat tests: - Restore queue_workload() override. Airflow 3.3+ BaseExecutor routes ExecuteCallback natively, but pre-3.3 raises ValueError for anything not ExecuteTask. Override works across versions. - Import AIRFLOW_V_3_3_PLUS from tests_common (main bumped to 3.3). check-airflow-v-imports-in-tests hook disallows provider-internal version_compat imports from test files. --- .../amazon/aws/executors/ecs/ecs_executor.py | 98 +++++++--- .../amazon/aws/executors/ecs/utils.py | 5 +- .../aws/executors/ecs/test_ecs_executor.py | 169 +++++++++--------- 3 files changed, 163 insertions(+), 109 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index 3a1fc19f21a3c..764c8ddb21eaa 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -24,13 +24,15 @@ from __future__ import annotations import time +import warnings from collections import defaultdict, deque from collections.abc import Sequence from copy import deepcopy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeAlias from botocore.exceptions import ClientError, NoCredentialsError +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoDescribeTasksSchema, BotoRunTaskSchema from airflow.providers.amazon.aws.executors.ecs.utils import ( @@ -52,14 +54,22 @@ from airflow.utils.state import State if TYPE_CHECKING: + from sqlalchemy.orm import Session + from airflow.executors import workloads - from airflow.executors.workloads.types import WorkloadKey from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.providers.amazon.aws.executors.ecs.utils import ( CommandType, ExecutorConfigType, ) + if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.types import WorkloadKey as _EcsWorkloadKey + + WorkloadKey: TypeAlias = _EcsWorkloadKey + else: + WorkloadKey: TypeAlias = TaskInstanceKey # type: ignore[no-redef, misc] + INVALID_CREDENTIALS_EXCEPTIONS = [ "ExpiredTokenException", "InvalidClientTokenId", @@ -105,7 +115,7 @@ class AwsEcsExecutor(BaseExecutor): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.active_workers: EcsTaskCollection = EcsTaskCollection() - self.pending_tasks: deque = deque() + self.pending_workloads: deque = deque() # Check if self has the ExecutorConf set on the self.conf attribute, and if not, set it to the global # configuration object. This allows the changes to be backwards compatible with older versions of @@ -133,6 +143,18 @@ def __init__(self, *args, **kwargs): fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS], ) + # TODO: Remove this once the minimum supported version is 3.3+, and defer to BaseExecutor.queue_workload. 
+ def queue_workload(self, workload: workloads.All, session: Session | None) -> None: + from airflow.executors import workloads + + if isinstance(workload, workloads.ExecuteTask): + self.queued_tasks[workload.ti.key] = workload + return + if AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback): + self.queued_callbacks[workload.callback.key] = workload + return + raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") + def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: from airflow.executors import workloads @@ -149,7 +171,7 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: elif AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback): command = [workload] # type: ignore[list-item] - key = workload.callback.id # type: ignore[assignment] + key = workload.callback.key # type: ignore[assignment] del self.queued_callbacks[key] # type: ignore[arg-type] self.execute_async(key=key, command=command, queue=None) # type: ignore[arg-type] @@ -248,8 +270,8 @@ def sync(self): if not self.IS_BOTO_CONNECTION_HEALTHY: return try: - self.sync_running_tasks() - self.attempt_task_runs() + self.sync_running_workloads() + self.attempt_workload_runs() except (ClientError, NoCredentialsError) as error: error_code = error.response["Error"]["Code"] if error_code in INVALID_CREDENTIALS_EXCEPTIONS: @@ -263,7 +285,7 @@ def sync(self): # up and kill the scheduler process self.log.exception("Failed to sync %s", self.__class__.__name__) - def sync_running_tasks(self): + def sync_running_workloads(self): """Check and update state on all running workloads (tasks and callbacks).""" all_task_arns = self.active_workers.get_all_arns() if not all_task_arns: @@ -276,13 +298,13 @@ def sync_running_tasks(self): if describe_tasks_response["failures"]: for failure in describe_tasks_response["failures"]: - self.__handle_failed_task(failure["arn"], failure["reason"]) + self.__handle_failed_workload(failure["arn"], failure["reason"]) updated_tasks = describe_tasks_response["tasks"] for task in updated_tasks: - self.__update_running_task(task) + self.__update_running_workload(task) - def __update_running_task(self, task): + def __update_running_workload(self, task): self.active_workers.update_task(task) # Get state of current task. task_state = task.get_task_state() @@ -291,7 +313,7 @@ def __update_running_task(self, task): # Mark finished tasks as either a success/failure. if task_state == State.FAILED or task_state == State.REMOVED: self.__log_container_failures(task_arn=task.task_arn) - self.__handle_failed_task(task.task_arn, task.stopped_reason) + self.__handle_failed_workload(task.task_arn, task.stopped_reason) elif task_state == State.SUCCESS: self.log.debug( "Airflow workload %s marked as %s after running on ECS Task (arn) %s", @@ -331,13 +353,13 @@ def __log_container_failures(self, task_arn: str): "The ECS task failed due to the following containers failing:\n%s", "\n".join(reasons) ) - def __handle_failed_task(self, task_arn: str, reason: str): + def __handle_failed_workload(self, task_arn: str, reason: str): """ If an API failure occurs, the task is rescheduled. This function will determine whether the task has been attempted the appropriate number of times, and determine whether the task should be marked failed or not. 
The task will
-        be removed active_workers, and marked as FAILED, or set into pending_tasks depending on
+        be removed from active_workers and marked as FAILED, or placed back into pending_workloads, depending on
         how many times it has been retried.
         """
         task_key = self.active_workers.arn_to_key[task_arn]
@@ -355,7 +377,7 @@
                 self.max_run_task_attempts,
                 task_arn,
             )
-            self.pending_tasks.append(
+            self.pending_workloads.append(
                 EcsQueuedTask(
                     task_key,
                     task_cmd,
@@ -374,9 +396,9 @@
             self.fail(task_key)
             self.active_workers.pop_by_key(task_key)
 
-    def attempt_task_runs(self):
+    def attempt_workload_runs(self):
         """
-        Take tasks from the pending_tasks queue, and attempts to find an instance to run it on.
+        Take workloads from the pending_workloads queue and attempt to find an instance to run each on.
 
         If the launch type is EC2, this will attempt to place tasks on empty EC2 instances.
         If there are no EC2 instances available, no task is placed and this function will be
@@ -384,10 +406,10 @@
         If the launch type is FARGATE, this will run the tasks on new AWS Fargate instances.
         """
-        queue_len = len(self.pending_tasks)
+        queue_len = len(self.pending_workloads)
         failure_reasons = defaultdict(int)
         for _ in range(queue_len):
-            ecs_task = self.pending_tasks.popleft()
+            ecs_task = self.pending_workloads.popleft()
             task_key = ecs_task.key
             cmd = ecs_task.command
             queue = ecs_task.queue
@@ -395,17 +417,17 @@
             attempt_number = ecs_task.attempt_number
             failure_reasons = []
             if timezone.utcnow() < ecs_task.next_attempt_time:
-                self.pending_tasks.append(ecs_task)
+                self.pending_workloads.append(ecs_task)
                 continue
             try:
                 run_task_response = self._run_task(task_key, cmd, queue, exec_config)
             except NoCredentialsError:
-                self.pending_tasks.append(ecs_task)
+                self.pending_workloads.append(ecs_task)
                 raise
             except ClientError as e:
                 error_code = e.response["Error"]["Code"]
                 if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
-                    self.pending_tasks.append(ecs_task)
+                    self.pending_workloads.append(ecs_task)
                     raise
                 failure_reasons.append(str(e))
             except Exception as e:
@@ -428,7 +450,7 @@
                     ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
                         attempt_number
                     )
-                    self.pending_tasks.append(ecs_task)
+                    self.pending_workloads.append(ecs_task)
                 else:
                     reasons_str = ", ".join(failure_reasons)
                     self.log.error(
@@ -520,7 +542,7 @@ def execute_async(self, key: WorkloadKey, command: CommandType, queue=None, exec
                 f"EcsExecutor doesn't know how to handle workload of type: {type(command[0])}"
             )
 
-        self.pending_tasks.append(
+        self.pending_workloads.append(
             EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
         )
 
@@ -646,3 +668,33 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
         not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
 
         return not_adopted_tis
+
+    # ── Back-compat shims for renamed methods/attrs ────────────────────────
+
+    @property
+    def pending_tasks(self) -> deque:
+        """Use pending_workloads as pending_tasks is deprecated."""
+        warnings.warn(
+            "pending_tasks is deprecated, use pending_workloads instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.pending_workloads
+
+    def sync_running_tasks(self):
+        """Use sync_running_workloads as sync_running_tasks is deprecated."""
+        warnings.warn(
+            "sync_running_tasks is deprecated, use 
sync_running_workloads instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + return self.sync_running_workloads() + + def attempt_task_runs(self): + """Use attempt_workload_runs as attempt_task_runs is deprecated.""" + warnings.warn( + "attempt_task_runs is deprecated, use attempt_workload_runs instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + return self.attempt_workload_runs() diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py index 592a61b3603ce..24e05ee2acb49 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/utils.py @@ -35,7 +35,10 @@ from airflow.utils.state import State if TYPE_CHECKING: - from airflow.executors.workloads.types import WorkloadKey + from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS + + if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.types import WorkloadKey CommandType = Sequence[str] ExecutorConfigFunctionType = Callable[[CommandType], dict] diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py index a9732d3089488..cc09a3faaa9d2 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py @@ -52,7 +52,6 @@ parse_assign_public_ip, ) from airflow.providers.amazon.aws.hooks.ecs import EcsHook -from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS from airflow.providers.common.compat.sdk import AirflowException, conf from airflow.utils.helpers import convert_camel_to_snake from airflow.utils.state import State, TaskInstanceState @@ -61,7 +60,7 @@ from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3])) @@ -400,11 +399,11 @@ def test_execute(self, change_state_mock, mock_airflow_key, mock_executor, mock_ "failures": [], } - assert len(mock_executor.pending_tasks) == 0 + assert len(mock_executor.pending_workloads) == 0 mock_executor.execute_async(airflow_key, mock_cmd) - assert len(mock_executor.pending_tasks) == 1 + assert len(mock_executor.pending_workloads) == 1 - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() mock_executor.ecs.run_task.assert_called_once() # Task is stored in active worker. 
@@ -443,14 +442,14 @@ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock } assert mock_executor.queued_tasks[workload.ti.key] == workload - assert len(mock_executor.pending_tasks) == 0 + assert len(mock_executor.pending_workloads) == 0 assert len(mock_executor.running) == 0 mock_executor._process_workloads([workload]) assert len(mock_executor.queued_tasks) == 0 assert len(mock_executor.running) == 1 assert workload.ti.key in mock_executor.running - assert len(mock_executor.pending_tasks) == 1 - assert mock_executor.pending_tasks[0].command == [ + assert len(mock_executor.pending_workloads) == 1 + assert mock_executor.pending_workloads[0].command == [ "python", "-m", "airflow.sdk.execution_time.execute_workload", @@ -458,9 +457,9 @@ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock '{"test_key": "test_value"}', ] - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() mock_executor.ecs.run_task.assert_called_once() - assert len(mock_executor.pending_tasks) == 0 + assert len(mock_executor.pending_workloads) == 0 mock_executor.ecs.run_task.assert_called_once_with( cluster="some-cluster", count=1, @@ -525,13 +524,13 @@ def test_success_execute_api_exception(self, mock_backoff, mock_executor, mock_c # Fail 2 times for _ in range(expected_retry_count): - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() # Task is not stored in active workers. assert len(mock_executor.active_workers) == 0 # Pass in last attempt - mock_executor.attempt_task_runs() - assert len(mock_executor.pending_tasks) == 0 + mock_executor.attempt_workload_runs() + assert len(mock_executor.pending_workloads) == 0 assert ARN1 in mock_executor.active_workers.get_all_arns() assert mock_backoff.call_count == expected_retry_count for attempt_number in range(1, expected_retry_count): @@ -544,7 +543,7 @@ def test_failed_execute_api_exception(self, mock_executor, mock_cmd): # No matter what, don't schedule until run_task becomes successful. for _ in range(int(mock_executor.max_run_task_attempts) * 2): - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() # Task is not stored in active workers. assert len(mock_executor.active_workers) == 0 @@ -561,12 +560,12 @@ def test_failed_execute_api(self, mock_executor, mock_cmd): # No matter what, don't schedule until run_task becomes successful. for _ in range(int(mock_executor.max_run_task_attempts) * 2): - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() # Task is not stored in active workers. assert len(mock_executor.active_workers) == 0 @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) - def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor): + def test_attempt_workload_runs_attempts_when_tasks_fail(self, _, mock_executor): """ Test case when all tasks fail to run. 
@@ -587,36 +586,36 @@ def test_attempt_workload_runs_attempts_when_tasks_fail(self, _, mock_executor):
         mock_executor.execute_async(airflow_keys[0], commands[0])
         mock_executor.execute_async(airflow_keys[1], commands[1])
 
-        assert len(mock_executor.pending_tasks) == 2
+        assert len(mock_executor.pending_workloads) == 2
         assert len(mock_executor.active_workers.get_all_arns()) == 0
 
         mock_executor.ecs.run_task.side_effect = failures
-        mock_executor.attempt_task_runs()
+        mock_executor.attempt_workload_runs()
 
         for i in range(2):
             RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i]
             assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
-        assert len(mock_executor.pending_tasks) == 2
+        assert len(mock_executor.pending_workloads) == 2
         assert len(mock_executor.active_workers.get_all_arns()) == 0
 
         mock_executor.ecs.run_task.call_args_list.clear()
 
         mock_executor.ecs.run_task.side_effect = failures
-        mock_executor.attempt_task_runs()
+        mock_executor.attempt_workload_runs()
 
         for i in range(2):
             RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i]
             assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
-        assert len(mock_executor.pending_tasks) == 2
+        assert len(mock_executor.pending_workloads) == 2
         assert len(mock_executor.active_workers.get_all_arns()) == 0
 
         mock_executor.ecs.run_task.call_args_list.clear()
 
         mock_executor.ecs.run_task.side_effect = failures
-        mock_executor.attempt_task_runs()
+        mock_executor.attempt_workload_runs()
 
         assert len(mock_executor.active_workers.get_all_arns()) == 0
-        assert len(mock_executor.pending_tasks) == 0
+        assert len(mock_executor.pending_workloads) == 0
 
         if airflow_version >= (2, 10, 0):
             events = [(x.event, x.task_id, x.try_number) for x in mock_executor._task_event_logs]
@@ -626,7 +625,7 @@ def test_attempt_workload_runs_attempts_when_tasks_fail(self, _, mock_executor):
         ]
 
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor):
+    def test_attempt_workload_runs_attempts_when_some_tasks_fail(self, _, mock_executor):
         """
         Test case when one task fail to run, and a new task gets queued.
 
@@ -655,16 +654,16 @@ def test_attempt_workload_runs_attempts_when_some_tasks_fail(self, _, mock_executor):
         mock_executor.execute_async(airflow_keys[0], airflow_commands[0])
         mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
 
-        assert len(mock_executor.pending_tasks) == 2
+        assert len(mock_executor.pending_workloads) == 2
 
         mock_executor.ecs.run_task.side_effect = responses
-        mock_executor.attempt_task_runs()
+        mock_executor.attempt_workload_runs()
 
         for i in range(2):
             RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[i]
             assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
 
-        assert len(mock_executor.pending_tasks) == 1
+        assert len(mock_executor.pending_workloads) == 1
         assert len(mock_executor.active_workers.get_all_arns()) == 1
 
         mock_executor.ecs.run_task.call_args_list.clear()
@@ -674,29 +673,29 @@ def test_attempt_workload_runs_attempts_when_some_tasks_fail(self, _, mock_executor):
         airflow_commands[1] = _generate_mock_cmd()
        mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
 
-        assert len(mock_executor.pending_tasks) == 2
+        assert len(mock_executor.pending_workloads) == 2
         # assert that the order of pending tasks is preserved i.e. the first task is 1st etc.
- assert mock_executor.pending_tasks[0].key == airflow_keys[0] - assert mock_executor.pending_tasks[0].command == airflow_commands[0] + assert mock_executor.pending_workloads[0].key == airflow_keys[0] + assert mock_executor.pending_workloads[0].command == airflow_commands[0] task["taskArn"] = ARN2 success_response = {"tasks": [task], "failures": []} responses = [Exception("Failure 1"), success_response] mock_executor.ecs.run_task.side_effect = responses - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() for i in range(2): RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[i] assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS - assert len(mock_executor.pending_tasks) == 1 + assert len(mock_executor.pending_workloads) == 1 assert len(mock_executor.active_workers.get_all_arns()) == 2 mock_executor.ecs.run_task.call_args_list.clear() responses = [Exception("Failure 1")] mock_executor.ecs.run_task.side_effect = responses - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[0] assert mock_executor.ecs.run_task.call_args_list[0].kwargs == RUN_TASK_KWARGS @@ -716,7 +715,7 @@ def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog mock_executor.execute_async(airflow_keys[0], airflow_commands[0]) mock_executor.execute_async(airflow_keys[1], airflow_commands[1]) - assert len(mock_executor.pending_tasks) == 2 + assert len(mock_executor.pending_workloads) == 2 caplog.set_level("WARNING") describe_tasks = [ @@ -768,16 +767,16 @@ def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog ] mock_executor.ecs.describe_tasks.side_effect = [{"tasks": describe_tasks, "failures": []}] - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() for i in range(2): RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[i] assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS - assert len(mock_executor.pending_tasks) == 0 + assert len(mock_executor.pending_workloads) == 0 assert len(mock_executor.active_workers.get_all_arns()) == 2 - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() for i in range(2): assert ( f"Airflow workload {airflow_keys[i]} failed due to {describe_tasks[i]['stoppedReason']}. Failure 1 out of 2" @@ -793,15 +792,15 @@ def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog ] mock_executor.ecs.describe_tasks.side_effect = [{"tasks": describe_tasks, "failures": []}] - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() for i in range(2): RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[i] assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() for i in range(2): assert ( f"Airflow workload {airflow_keys[i]} has failed a maximum of 2 times. Marking as failed" @@ -814,7 +813,7 @@ def test_sync(self, success_mock, fail_mock, mock_executor): """Test sync from end-to-end.""" self._mock_sync(mock_executor) - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() mock_executor.ecs.describe_tasks.assert_called_once() # Task is not stored in active workers. 
@@ -829,7 +828,7 @@ def test_sync(self, success_mock, fail_mock, mock_executor): def test_sync_short_circuits_with_no_arns(self, _, success_mock, fail_mock, mock_executor): self._mock_sync(mock_executor) - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() mock_executor.ecs.describe_tasks.assert_not_called() fail_mock.assert_not_called() @@ -858,7 +857,7 @@ def test_removed_sync(self, fail_mock, success_mock, mock_executor): mock_executor.max_run_task_attempts = "1" self._mock_sync(mock_executor, expected_state=State.REMOVED, set_task_state=State.REMOVED) - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() # Task is not stored in active workers. assert len(mock_executor.active_workers) == 0 @@ -884,11 +883,11 @@ def test_failed_sync_cumulative_fail( task_key = mock_airflow_key() mock_executor.execute_async(task_key, mock_cmd) for _ in range(2): - assert len(mock_executor.pending_tasks) == 1 - keys = [task.key for task in mock_executor.pending_tasks] + assert len(mock_executor.pending_workloads) == 1 + keys = [task.key for task in mock_executor.pending_workloads] assert task_key in keys - mock_executor.attempt_task_runs() - assert len(mock_executor.pending_tasks) == 1 + mock_executor.attempt_workload_runs() + assert len(mock_executor.pending_workloads) == 1 mock_executor.ecs.run_task.return_value = { "tasks": [ @@ -901,8 +900,8 @@ def test_failed_sync_cumulative_fail( ], "failures": [], } - mock_executor.attempt_task_runs() - assert len(mock_executor.pending_tasks) == 0 + mock_executor.attempt_workload_runs() + assert len(mock_executor.pending_workloads) == 0 assert ARN1 in mock_executor.active_workers.get_all_arns() mock_executor.ecs.describe_tasks.return_value = { @@ -912,19 +911,19 @@ def test_failed_sync_cumulative_fail( ], } - # Call sync_running_tasks and attempt_task_runs 2 times with failures. + # Call sync_running_workloads and attempt_workload_runs 2 times with failures. for _ in range(2): - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() # Ensure task gets removed from active_workers. assert ARN1 not in mock_executor.active_workers.get_all_arns() - # Ensure task gets back on the pending_tasks queue - assert len(mock_executor.pending_tasks) == 1 - keys = [task.key for task in mock_executor.pending_tasks] + # Ensure task gets back on the pending_workloads queue + assert len(mock_executor.pending_workloads) == 1 + keys = [task.key for task in mock_executor.pending_workloads] assert task_key in keys - mock_executor.attempt_task_runs() - assert len(mock_executor.pending_tasks) == 0 + mock_executor.attempt_workload_runs() + assert len(mock_executor.pending_workloads) == 0 assert ARN1 in mock_executor.active_workers.get_all_arns() # Task is neither failed nor succeeded. @@ -938,7 +937,7 @@ def test_failed_sync_cumulative_fail( # 2 run_task failures + 2 describe_task failures = 4 failures # Last call should fail the task. 
- mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() assert ARN1 not in mock_executor.active_workers.get_all_arns() fail_mock.assert_called() success_mock.assert_not_called() @@ -958,7 +957,7 @@ def test_failed_sync_api(self, _, success_mock, fail_mock, mock_executor, mock_c """Test what happens when ECS sync fails for certain tasks repeatedly.""" airflow_key = "test-key" mock_executor.execute_async(airflow_key, mock_cmd) - assert len(mock_executor.pending_tasks) == 1 + assert len(mock_executor.pending_workloads) == 1 run_task_ret_val = { "taskArn": ARN1, @@ -980,35 +979,35 @@ def test_failed_sync_api(self, _, success_mock, fail_mock, mock_executor, mock_c ], } mock_executor.ecs.describe_tasks.return_value = describe_tasks_ret_value - mock_executor.attempt_task_runs() - assert len(mock_executor.pending_tasks) == 0 + mock_executor.attempt_workload_runs() + assert len(mock_executor.pending_workloads) == 0 assert len(mock_executor.active_workers.get_all_arns()) == 1 task_key = mock_executor.active_workers.arn_to_key[ARN1] # Call Sync 2 times with failures. The task can only fail max_run_task_attempts times. for check_count in range(1, int(mock_executor.max_run_task_attempts)): - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() assert mock_executor.ecs.describe_tasks.call_count == check_count # Ensure task gets removed from active_workers. assert ARN1 not in mock_executor.active_workers.get_all_arns() - # Ensure task gets back on the pending_tasks queue - assert len(mock_executor.pending_tasks) == 1 - keys = [task.key for task in mock_executor.pending_tasks] + # Ensure task gets back on the pending_workloads queue + assert len(mock_executor.pending_workloads) == 1 + keys = [task.key for task in mock_executor.pending_workloads] assert task_key in keys # Task is neither failed nor succeeded. fail_mock.assert_not_called() success_mock.assert_not_called() - mock_executor.attempt_task_runs() + mock_executor.attempt_workload_runs() - assert len(mock_executor.pending_tasks) == 0 + assert len(mock_executor.pending_workloads) == 0 assert len(mock_executor.active_workers.get_all_arns()) == 1 assert ARN1 in mock_executor.active_workers.get_all_arns() task_key = mock_executor.active_workers.arn_to_key[ARN1] # Last call should fail the task. 
- mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() assert ARN1 not in mock_executor.active_workers.get_all_arns() fail_mock.assert_called() success_mock.assert_not_called() @@ -1102,7 +1101,7 @@ def test_executor_config_exceptions(self, bad_config, mock_executor, mock_cmd): with pytest.raises(ValueError, match='Executor Config should never override "name" or "command"'): mock_executor.execute_async(mock_airflow_key, mock_cmd, executor_config=bad_config) - assert len(mock_executor.pending_tasks) == 0 + assert len(mock_executor.pending_workloads) == 0 @mock.patch.object(ecs_executor_config, "build_task_kwargs") def test_container_not_found(self, mock_build_task_kwargs, mock_executor): @@ -1116,7 +1115,7 @@ def test_container_not_found(self, mock_build_task_kwargs, mock_executor): '"overrides[containerOverrides][containers][x][command]"' ) ) - assert len(mock_executor.pending_tasks) == 0 + assert len(mock_executor.pending_workloads) == 0 def _mock_sync( self, @@ -1125,7 +1124,7 @@ def _mock_sync( set_task_state=TaskInstanceState.RUNNING, ) -> None: """Mock ECS to the expected state.""" - executor.pending_tasks.clear() + executor.pending_workloads.clear() self._add_mock_task(executor, ARN1, set_task_state) response_task_json = { @@ -1185,7 +1184,7 @@ def test_update_running_tasks( ], } mock_executor.ecs.describe_tasks.return_value = {"tasks": [test_response_task_json], "failures": []} - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() if expected_status != State.REMOVED: assert mock_executor.active_workers.tasks["arn1"].get_task_state() == expected_status # The task is not removed from active_workers in these states @@ -1214,7 +1213,7 @@ def test_update_running_tasks_success(self, mock_executor): ) mock_success_function = patcher.start() mock_executor.ecs.describe_tasks.return_value = {"tasks": [test_response_task_json], "failures": []} - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() assert len(mock_executor.active_workers) == 0 mock_success_function.assert_called_once() @@ -1243,7 +1242,7 @@ def test_update_running_tasks_failed(self, mock_executor, caplog): ) mock_failed_function = patcher.start() mock_executor.ecs.describe_tasks.return_value = {"tasks": [test_response_task_json], "failures": []} - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() assert len(mock_executor.active_workers) == 0 mock_failed_function.assert_called_once() assert ( @@ -2013,13 +2012,13 @@ def test_queue_callback_workload(self, mock_executor, callback_workload): mock_executor.queue_workload(callback_workload, session=None) assert len(mock_executor.queued_callbacks) == 1 - assert callback_workload.callback.id in mock_executor.queued_callbacks - assert mock_executor.queued_callbacks[callback_workload.callback.id] is callback_workload + assert callback_workload.callback.key in mock_executor.queued_callbacks + assert mock_executor.queued_callbacks[callback_workload.callback.key] is callback_workload @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") def test_process_callback_workload(self, mock_executor, callback_workload): """Test that _process_workloads handles ExecuteCallback correctly.""" - callback_key = callback_workload.callback.id + callback_key = callback_workload.callback.key mock_executor.queued_callbacks[callback_key] = callback_workload mock_executor._process_workloads([callback_workload]) @@ -2028,20 +2027,20 @@ def test_process_callback_workload(self, mock_executor, 
callback_workload): assert callback_key not in mock_executor.queued_callbacks # Callback should be added to running set assert callback_key in mock_executor.running - # Callback should be added to pending_tasks for execution - assert len(mock_executor.pending_tasks) == 1 - queued = mock_executor.pending_tasks[0] + # Callback should be added to pending_workloads for execution + assert len(mock_executor.pending_workloads) == 1 + queued = mock_executor.pending_workloads[0] assert queued.key == callback_key assert queued.queue is None @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") def test_execute_async_callback_workload(self, mock_executor, callback_workload): """Test that execute_async serializes ExecuteCallback workloads correctly.""" - callback_key = callback_workload.callback.id + callback_key = callback_workload.callback.key mock_executor.execute_async(key=callback_key, command=[callback_workload], queue=None) - assert len(mock_executor.pending_tasks) == 1 - queued = mock_executor.pending_tasks[0] + assert len(mock_executor.pending_workloads) == 1 + queued = mock_executor.pending_workloads[0] assert queued.key == callback_key # Command should be serialized to the execute_workload entrypoint assert queued.command[0] == "python" @@ -2050,8 +2049,8 @@ def test_execute_async_callback_workload(self, mock_executor, callback_workload) @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") def test_callback_sync_running_success(self, mock_executor, callback_workload): - """Test that sync_running_tasks correctly handles successful callback ECS tasks.""" - callback_key = callback_workload.callback.id + """Test that sync_running_workloads correctly handles successful callback ECS tasks.""" + callback_key = callback_workload.callback.key ecs_task = mock_task(ARN1, State.SUCCESS) mock_cmd = _generate_mock_cmd() mock_executor.active_workers.add_task(ecs_task, callback_key, None, mock_cmd, {}, 1) @@ -2069,7 +2068,7 @@ def test_callback_sync_running_success(self, mock_executor, callback_workload): "failures": [], } - mock_executor.sync_running_tasks() + mock_executor.sync_running_workloads() # Callback should be removed from active workers after success assert len(mock_executor.active_workers) == 0
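
A minimal sketch of what the PATCH 2/2 deprecation shims mean for existing
call sites, assuming a started, fully configured AwsEcsExecutor is bound to
`executor`; the `drain_once` helper is hypothetical and not part of either
patch:

    import warnings

    from airflow.exceptions import AirflowProviderDeprecationWarning

    def drain_once(executor):
        # The old spellings still work through the shims, but each call
        # raises AirflowProviderDeprecationWarning at the call site.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            executor.sync_running_tasks()  # delegates to sync_running_workloads()
            executor.attempt_task_runs()   # delegates to attempt_workload_runs()
        assert any(
            issubclass(w.category, AirflowProviderDeprecationWarning) for w in caught
        )

        # The new spellings are the non-deprecated path.
        executor.sync_running_workloads()
        executor.attempt_workload_runs()

Callers that touch pending_tasks, sync_running_tasks, or attempt_task_runs
keep working for now, but should migrate to the workload-named equivalents
before the shims are removed.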