Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
from __future__ import annotations

import time
import warnings
from collections import defaultdict, deque
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeAlias

from botocore.exceptions import ClientError, NoCredentialsError

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.executors.base_executor import BaseExecutor
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoDescribeTasksSchema, BotoRunTaskSchema
from airflow.providers.amazon.aws.executors.ecs.utils import (
Expand All @@ -46,7 +48,7 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.state import State
Expand All @@ -61,6 +63,13 @@
ExecutorConfigType,
)

# Key type used throughout this executor to index queued/running workloads.
# On Airflow 3.3+ the executor can also run callbacks, so keys are the generic
# WorkloadKey; on older versions only task instances run here, keyed by
# TaskInstanceKey.
# NOTE(review): on <3.3 this alias points at TaskInstanceKey, which appears to
# be imported only under TYPE_CHECKING above — confirm it resolves at runtime.
if AIRFLOW_V_3_3_PLUS:
    from airflow.executors.workloads.types import WorkloadKey as _EcsWorkloadKey

    WorkloadKey: TypeAlias = _EcsWorkloadKey
else:
    WorkloadKey: TypeAlias = TaskInstanceKey  # type: ignore[no-redef, misc]

INVALID_CREDENTIALS_EXCEPTIONS = [
"ExpiredTokenException",
"InvalidClientTokenId",
Expand Down Expand Up @@ -92,6 +101,9 @@ class AwsEcsExecutor(BaseExecutor):

supports_multi_team: bool = True

if AIRFLOW_V_3_3_PLUS:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should not need this? You should be able to declare this variable regardless o the Airflow version?

supports_callbacks: bool = True

# AWS limits the maximum number of ARNs in the describe_tasks function.
DESCRIBE_TASKS_BATCH_SIZE = 99

Expand All @@ -103,7 +115,7 @@ class AwsEcsExecutor(BaseExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.active_workers: EcsTaskCollection = EcsTaskCollection()
self.pending_tasks: deque = deque()
self.pending_workloads: deque = deque()

# Check if self has the ExecutorConf set on the self.conf attribute, and if not, set it to the global
# configuration object. This allows the changes to be backwards compatible with older versions of
Expand Down Expand Up @@ -131,30 +143,42 @@ def __init__(self, *args, **kwargs):
fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
)

# TODO: Remove this once the minimum supported version is 3.3+, and defer to BaseExecutor.queue_workload.
def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
Comment thread
shivaam marked this conversation as resolved.
from airflow.executors import workloads

if not isinstance(workload, workloads.ExecuteTask):
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
ti = workload.ti
self.queued_tasks[ti.key] = workload
if isinstance(workload, workloads.ExecuteTask):
self.queued_tasks[workload.ti.key] = workload
return
if AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback):
self.queued_callbacks[workload.callback.key] = workload
return
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
from airflow.executors.workloads import ExecuteTask
def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None:
Comment thread
shivaam marked this conversation as resolved.
from airflow.executors import workloads

# Airflow V3 version
for w in workloads:
if not isinstance(w, ExecuteTask):
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
for workload in workload_items:
if isinstance(workload, workloads.ExecuteTask):
command = [workload]
key = workload.ti.key
queue = workload.ti.queue
executor_config = workload.ti.executor_config or {}

command = [w]
key = w.ti.key
queue = w.ti.queue
executor_config = w.ti.executor_config or {}
del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type]
self.running.add(key)

del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) # type: ignore[arg-type]
self.running.add(key)
elif AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback):
command = [workload] # type: ignore[list-item]
key = workload.callback.key # type: ignore[assignment]

del self.queued_callbacks[key] # type: ignore[arg-type]
self.execute_async(key=key, command=command, queue=None) # type: ignore[arg-type]
self.running.add(key)

else:
raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

def start(self):
"""Call this when the Executor is run for the first time by the scheduler."""
Expand Down Expand Up @@ -246,8 +270,8 @@ def sync(self):
if not self.IS_BOTO_CONNECTION_HEALTHY:
return
try:
self.sync_running_tasks()
self.attempt_task_runs()
self.sync_running_workloads()
self.attempt_workload_runs()
except (ClientError, NoCredentialsError) as error:
error_code = error.response["Error"]["Code"]
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
Expand All @@ -261,11 +285,11 @@ def sync(self):
# up and kill the scheduler process
self.log.exception("Failed to sync %s", self.__class__.__name__)

def sync_running_tasks(self):
"""Check and update state on all running tasks."""
def sync_running_workloads(self):
Comment thread
shivaam marked this conversation as resolved.
"""Check and update state on all running workloads (tasks and callbacks)."""
all_task_arns = self.active_workers.get_all_arns()
if not all_task_arns:
self.log.debug("No active Airflow tasks, skipping sync.")
self.log.debug("No active Airflow workloads, skipping sync.")
return

describe_tasks_response = self.__describe_tasks(all_task_arns)
Expand All @@ -274,13 +298,13 @@ def sync_running_tasks(self):

if describe_tasks_response["failures"]:
for failure in describe_tasks_response["failures"]:
self.__handle_failed_task(failure["arn"], failure["reason"])
self.__handle_failed_workload(failure["arn"], failure["reason"])

updated_tasks = describe_tasks_response["tasks"]
for task in updated_tasks:
self.__update_running_task(task)
self.__update_running_workload(task)

def __update_running_task(self, task):
def __update_running_workload(self, task):
self.active_workers.update_task(task)
# Get state of current task.
task_state = task.get_task_state()
Expand All @@ -289,10 +313,10 @@ def __update_running_task(self, task):
# Mark finished tasks as either a success/failure.
if task_state == State.FAILED or task_state == State.REMOVED:
self.__log_container_failures(task_arn=task.task_arn)
self.__handle_failed_task(task.task_arn, task.stopped_reason)
self.__handle_failed_workload(task.task_arn, task.stopped_reason)
elif task_state == State.SUCCESS:
self.log.debug(
"Airflow task %s marked as %s after running on ECS Task (arn) %s",
"Airflow workload %s marked as %s after running on ECS Task (arn) %s",
task_key,
task_state,
task.task_arn,
Expand Down Expand Up @@ -329,13 +353,13 @@ def __log_container_failures(self, task_arn: str):
"The ECS task failed due to the following containers failing:\n%s", "\n".join(reasons)
)

def __handle_failed_task(self, task_arn: str, reason: str):
def __handle_failed_workload(self, task_arn: str, reason: str):
"""
If an API failure occurs, the task is rescheduled.

This function will determine whether the task has been attempted the appropriate number
of times, and determine whether the task should be marked failed or not. The task will
be removed active_workers, and marked as FAILED, or set into pending_tasks depending on
be removed active_workers, and marked as FAILED, or set into pending_workloads depending on
how many times it has been retried.
"""
task_key = self.active_workers.arn_to_key[task_arn]
Expand All @@ -346,14 +370,14 @@ def __handle_failed_task(self, task_arn: str, reason: str):
failure_count = self.active_workers.failure_count_by_key(task_key)
if int(failure_count) < int(self.max_run_task_attempts):
self.log.warning(
"Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
"Airflow workload %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
task_key,
reason,
failure_count,
self.max_run_task_attempts,
task_arn,
)
self.pending_tasks.append(
self.pending_workloads.append(
EcsQueuedTask(
task_key,
task_cmd,
Expand All @@ -365,45 +389,45 @@ def __handle_failed_task(self, task_arn: str, reason: str):
)
else:
self.log.error(
"Airflow task %s has failed a maximum of %s times. Marking as failed",
"Airflow workload %s has failed a maximum of %s times. Marking as failed",
task_key,
failure_count,
)
self.fail(task_key)
self.active_workers.pop_by_key(task_key)

def attempt_task_runs(self):
def attempt_workload_runs(self):
"""
Take tasks from the pending_tasks queue, and attempts to find an instance to run it on.
Take tasks from the pending_workloads queue, and attempts to find an instance to run it on.

If the launch type is EC2, this will attempt to place tasks on empty EC2 instances. If
there are no EC2 instances available, no task is placed and this function will be
called again in the next heart-beat.

If the launch type is FARGATE, this will run the tasks on new AWS Fargate instances.
"""
queue_len = len(self.pending_tasks)
queue_len = len(self.pending_workloads)
failure_reasons = defaultdict(int)
for _ in range(queue_len):
ecs_task = self.pending_tasks.popleft()
ecs_task = self.pending_workloads.popleft()
task_key = ecs_task.key
cmd = ecs_task.command
queue = ecs_task.queue
exec_config = ecs_task.executor_config
attempt_number = ecs_task.attempt_number
failure_reasons = []
if timezone.utcnow() < ecs_task.next_attempt_time:
self.pending_tasks.append(ecs_task)
self.pending_workloads.append(ecs_task)
continue
try:
run_task_response = self._run_task(task_key, cmd, queue, exec_config)
except NoCredentialsError:
self.pending_tasks.append(ecs_task)
self.pending_workloads.append(ecs_task)
raise
except ClientError as e:
error_code = e.response["Error"]["Code"]
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
self.pending_tasks.append(ecs_task)
self.pending_workloads.append(ecs_task)
raise
failure_reasons.append(str(e))
except Exception as e:
Expand All @@ -426,11 +450,11 @@ def attempt_task_runs(self):
ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
attempt_number
)
self.pending_tasks.append(ecs_task)
self.pending_workloads.append(ecs_task)
else:
reasons_str = ", ".join(failure_reasons)
self.log.error(
"ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
"ECS workload %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
task_key,
attempt_number,
reasons_str,
Expand Down Expand Up @@ -460,7 +484,11 @@ def attempt_task_runs(self):
self.running_state(task_key, task.task_arn)

def _run_task(
self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
self,
task_id: WorkloadKey,
cmd: CommandType,
queue: str | None,
exec_config: ExecutorConfigType,
):
"""
Run a queued-up Airflow task.
Expand All @@ -475,7 +503,11 @@ def _run_task(
return run_task_response

def _run_task_kwargs(
self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
self,
task_id: WorkloadKey,
cmd: CommandType,
queue: str | None,
exec_config: ExecutorConfigType,
) -> dict:
"""
Update the Airflow command by modifying container overrides for task-specific kwargs.
Expand All @@ -494,21 +526,23 @@ def _run_task_kwargs(

return run_task_kwargs

def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
"""Save the task to be executed in the next sync by inserting the commands into a queue."""
def execute_async(self, key: WorkloadKey, command: CommandType, queue=None, executor_config=None):
"""Save the workload to be executed in the next sync by inserting the commands into a queue."""
if executor_config and ("name" in executor_config or "command" in executor_config):
raise ValueError('Executor Config should never override "name" or "command"')
if len(command) == 1:
from airflow.executors.workloads import ExecuteTask
from airflow.executors import workloads

if isinstance(command[0], ExecuteTask):
if isinstance(command[0], workloads.ExecuteTask) or (
AIRFLOW_V_3_3_PLUS and isinstance(command[0], workloads.ExecuteCallback)
):
command = self._serialize_workload_to_command(command[0])
else:
raise ValueError(
f"EcsExecutor doesn't know how to handle workload of type: {type(command[0])}"
)

self.pending_tasks.append(
self.pending_workloads.append(
EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
)

Expand Down Expand Up @@ -567,9 +601,9 @@ def get_container(self, container_list):
@staticmethod
def _serialize_workload_to_command(workload) -> CommandType:
"""
Serialize an ExecuteTask workload into a command for the Task SDK.
Serialize a workload into a command for the Task SDK.

:param workload: ExecuteTask workload to serialize
:param workload: ExecuteTask or ExecuteCallback workload to serialize
:return: Command as list of strings for Task SDK execution
"""
return [
Expand Down Expand Up @@ -634,3 +668,33 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis

# ── Back-compat shims for renamed methods/attrs ────────────────────────

@property
def pending_tasks(self) -> deque:
    """Deprecated alias for :attr:`pending_workloads`.

    Emits an ``AirflowProviderDeprecationWarning`` and returns the
    ``pending_workloads`` deque unchanged, so old callers keep working.
    """
    message = "pending_tasks is deprecated, use pending_workloads instead."
    warnings.warn(message, AirflowProviderDeprecationWarning, stacklevel=2)
    return self.pending_workloads

def sync_running_tasks(self):
    """Deprecated alias for :meth:`sync_running_workloads`.

    Emits an ``AirflowProviderDeprecationWarning``, then delegates directly
    to the renamed method and returns its result.
    """
    message = "sync_running_tasks is deprecated, use sync_running_workloads instead."
    warnings.warn(message, AirflowProviderDeprecationWarning, stacklevel=2)
    return self.sync_running_workloads()

def attempt_task_runs(self):
    """Deprecated alias for :meth:`attempt_workload_runs`.

    Emits an ``AirflowProviderDeprecationWarning``, then delegates directly
    to the renamed method and returns its result.
    """
    message = "attempt_task_runs is deprecated, use attempt_workload_runs instead."
    warnings.warn(message, AirflowProviderDeprecationWarning, stacklevel=2)
    return self.attempt_workload_runs()
Loading
Loading