Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,15 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey,

return workloads_to_schedule

def _process_workloads(self, workloads: Sequence[ExecutorWorkload]) -> None:
def _process_workloads(self, workload_items: Sequence[ExecutorWorkload]) -> None:
"""
Process the given workloads.

This method must be implemented by subclasses to define how they handle
the execution of workloads (e.g., queuing them to workers, submitting to
external systems, etc.).

:param workloads: List of workloads to process
:param workload_items: List of workloads to process
"""
raise NotImplementedError(f"{type(self).__name__} must implement _process_workloads()")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

"""AWS Batch Executor. Each Airflow task gets delegated out to an AWS Batch Job."""
"""AWS Batch Executor. Each Airflow workload gets delegated out to an AWS Batch Job."""

from __future__ import annotations

Expand All @@ -33,7 +33,7 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
from airflow.utils.helpers import merge_dicts

Expand All @@ -42,6 +42,8 @@

from airflow.executors import workloads
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey


from airflow.providers.amazon.aws.executors.batch.boto_schema import (
BatchDescribeJobsResponseSchema,
BatchSubmitJobResponseSchema,
Expand Down Expand Up @@ -88,6 +90,8 @@ class AwsBatchExecutor(BaseExecutor):
"""

supports_multi_team: bool = True
if AIRFLOW_V_3_3_PLUS:
supports_callbacks: bool = True

# AWS only allows a maximum number of JOBs in the describe_jobs function
DESCRIBE_JOBS_BATCH_SIZE = 99
Expand Down Expand Up @@ -127,26 +131,44 @@ def __init__(self, *args, **kwargs):
def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
    """
    Store a workload so a later scheduling pass can submit it.

    ``ExecuteTask`` workloads are stored in ``queued_tasks`` under their task-instance key.
    When ``AIRFLOW_V_3_3_PLUS`` is true, ``ExecuteCallback`` workloads are stored in
    ``queued_callbacks`` under their callback key.

    :param workload: The workload to enqueue.
    :param session: Not used by this implementation.
    :raises RuntimeError: If the workload is of a type this executor does not handle.
    """
    from airflow.executors import workloads

    if isinstance(workload, workloads.ExecuteTask):
        self.queued_tasks[workload.ti.key] = workload
        return
    # The version flag is tested first so ``workloads.ExecuteCallback`` is never
    # looked up when the flag is false.
    if AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback):
        self.queued_callbacks[workload.callback.key] = workload
        return
    raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")

def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None:
    """
    Hand each workload to ``execute_async`` and mark it as running.

    For every item the matching entry is removed from ``queued_tasks`` (tasks) or
    ``queued_callbacks`` (callbacks, only when ``AIRFLOW_V_3_3_PLUS`` is true), the
    workload is passed to ``execute_async`` wrapped in a one-element list, and its key
    is added to ``running``.

    :param workload_items: Workloads to process. The parameter is deliberately not named
        ``workloads`` so it does not collide with the ``workloads`` module imported below.
    :raises RuntimeError: If an item is of a type this executor does not handle.
    :raises KeyError: If an item's key is not present in the corresponding queue dict.
    """
    from airflow.executors import workloads

    # Airflow V3 version
    for w in workload_items:
        if isinstance(w, workloads.ExecuteTask):
            task_command = [w]
            task_key = w.ti.key
            queue = w.ti.queue
            executor_config = w.ti.executor_config or {}

            del self.queued_tasks[task_key]
            self.execute_async(
                key=task_key,
                command=task_command,  # type: ignore[arg-type]
                queue=queue,
                executor_config=executor_config,
            )
            self.running.add(task_key)
        elif AIRFLOW_V_3_3_PLUS and isinstance(w, workloads.ExecuteCallback):
            callback_command = [w]
            callback_key = w.callback.key
            # A callback only carries a queue when its data is a dict with a "queue" entry.
            callback_data = w.callback.data
            queue = callback_data.get("queue") if isinstance(callback_data, dict) else None

            del self.queued_callbacks[callback_key]
            self.execute_async(key=callback_key, command=callback_command, queue=queue)  # type: ignore[arg-type]
            self.running.add(callback_key)
        else:
            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")

def check_health(self):
"""Make a test API call to check the health of the Batch Executor."""
Expand Down Expand Up @@ -235,7 +257,7 @@ def sync(self):
def sync_running_jobs(self):
all_job_ids = self.active_workers.get_all_jobs()
if not all_job_ids:
self.log.debug("No active Airflow tasks, skipping sync")
self.log.debug("No active Airflow workloads, skipping sync")
return
describe_job_response = self._describe_jobs(all_job_ids)

Expand All @@ -245,8 +267,8 @@ def sync_running_jobs(self):
if job.get_job_state() == State.FAILED:
self._handle_failed_job(job)
elif job.get_job_state() == State.SUCCESS:
task_key = self.active_workers.pop_by_id(job.job_id)
self.success(task_key)
workload_key = self.active_workers.pop_by_id(job.job_id)
self.success(workload_key)

def _handle_failed_job(self, job):
"""
Expand All @@ -263,15 +285,15 @@ def _handle_failed_job(self, job):
# responsibility for ensuring the process started. Failures in the DAG will be caught by
# Airflow, which will be handled separately.
job_info = self.active_workers.id_to_job_info[job.job_id]
task_key = self.active_workers.id_to_key[job.job_id]
task_cmd = job_info.cmd
workload_key = self.active_workers.id_to_key[job.job_id]
workload_cmd = job_info.cmd
queue = job_info.queue
exec_info = job_info.config
failure_count = self.active_workers.failure_count_by_id(job_id=job.job_id)
if int(failure_count) < int(self.max_submit_job_attempts):
self.log.warning(
"Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
task_key,
"Airflow workload %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
workload_key,
job.status_reason,
failure_count,
self.max_submit_job_attempts,
Expand All @@ -281,8 +303,8 @@ def _handle_failed_job(self, job):
self.active_workers.pop_by_id(job.job_id)
self.pending_jobs.append(
BatchQueuedJob(
task_key,
task_cmd,
workload_key,
workload_cmd,
queue,
exec_info,
failure_count + 1,
Expand All @@ -291,12 +313,12 @@ def _handle_failed_job(self, job):
)
else:
self.log.error(
"Airflow task %s has failed a maximum of %s times. Marking as failed",
task_key,
"Airflow workload %s has failed a maximum of %s times. Marking as failed",
workload_key,
failure_count,
)
self.active_workers.pop_by_id(job.job_id)
self.fail(task_key)
self.fail(workload_key)

def attempt_submit_jobs(self):
"""
Expand All @@ -309,8 +331,8 @@ def attempt_submit_jobs(self):
"""
for _ in range(len(self.pending_jobs)):
batch_job = self.pending_jobs.popleft()
key = batch_job.key
cmd = batch_job.command
workload_key = batch_job.key
workload_cmd = batch_job.command
queue = batch_job.queue
exec_config = batch_job.executor_config
attempt_number = batch_job.attempt_number
Expand All @@ -319,7 +341,7 @@ def attempt_submit_jobs(self):
self.pending_jobs.append(batch_job)
continue
try:
submit_job_response = self._submit_job(key, cmd, queue, exec_config or {})
submit_job_response = self._submit_job(workload_key, workload_cmd, queue, exec_config or {})
except NoCredentialsError:
self.pending_jobs.append(batch_job)
raise
Expand All @@ -337,18 +359,18 @@ def attempt_submit_jobs(self):
self.log.error(
(
"This job has been unsuccessfully attempted too many times (%s). "
"Dropping the task. Reason: %s"
"Dropping the workload. Reason: %s"
),
attempt_number,
failure_reason,
)
self.log_task_event(
event="batch job submit failure",
extra=f"This job has been unsuccessfully attempted too many times ({attempt_number}). "
f"Dropping the task. Reason: {failure_reason}",
ti_key=key,
f"Dropping the workload. Reason: {failure_reason}",
ti_key=workload_key,
)
self.fail(key=key)
self.fail(key=workload_key)
Comment thread
ferruzzi marked this conversation as resolved.
else:
batch_job.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
attempt_number
Expand All @@ -360,35 +382,39 @@ def attempt_submit_jobs(self):
job_id = submit_job_response["job_id"]
self.active_workers.add_job(
job_id=job_id,
airflow_task_key=key,
airflow_cmd=cmd,
airflow_workload_key=workload_key,
airflow_cmd=workload_cmd,
queue=queue,
exec_config=exec_config,
attempt_number=attempt_number,
)
self.running_state(key, job_id)
self.running_state(workload_key, job_id)

def _describe_jobs(self, job_ids) -> list[BatchJob]:
    """
    Describe the given Batch jobs, querying in chunks of ``DESCRIBE_JOBS_BATCH_SIZE``.

    Each chunk is sent to ``self.batch.describe_jobs`` exactly once; the raw response is
    loaded through ``BatchDescribeJobsResponseSchema`` and its ``"jobs"`` entries are
    accumulated.

    :param job_ids: Sliceable sequence of Batch job IDs to look up.
    :return: The jobs from every chunk, in request order.
    """
    all_jobs = []
    batch_size = self.__class__.DESCRIBE_JOBS_BATCH_SIZE
    for i in range(0, len(job_ids), batch_size):
        batched_job_ids = job_ids[i : i + batch_size]
        if not batched_job_ids:
            continue
        boto_describe_workloads = self.batch.describe_jobs(jobs=batched_job_ids)

        describe_workloads_response = BatchDescribeJobsResponseSchema().load(boto_describe_workloads)
        all_jobs.extend(describe_workloads_response["jobs"])
    return all_jobs

def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
"""Save the task to be executed in the next sync using Boto3's RunTask API."""
def execute_async(
self, key: TaskInstanceKey | str, command: CommandType, queue=None, executor_config=None
):
"""Save the workload to be executed in the next sync using Boto3's RunTask API."""
if executor_config and "command" in executor_config:
raise ValueError('Executor Config should never override "command"')

if len(command) == 1:
from airflow.executors.workloads import ExecuteTask
from airflow.executors import workloads

if isinstance(command[0], ExecuteTask):
if isinstance(command[0], workloads.ExecuteTask) or (
AIRFLOW_V_3_3_PLUS and isinstance(command[0], workloads.ExecuteCallback)
):
workload = command[0]
ser_input = workload.model_dump_json()
command = [
Expand Down Expand Up @@ -433,7 +459,7 @@ def _submit_job_kwargs(
self, key: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
) -> dict:
"""
Override the Airflow command to update the container overrides so kwargs are specific to this task.
Override the Airflow command to update the container overrides so kwargs are specific to this workload.

One last chance to modify Boto3's "submit_job" kwarg params before it gets passed into the Boto3
client. For the latest kwarg parameters:
Expand All @@ -450,7 +476,7 @@ def _submit_job_kwargs(
return submit_job_api

def end(self, heartbeat_interval=10):
"""Wait for all currently running tasks to end and prevent any new jobs from running."""
"""Wait for all currently running workloads to end and prevent any new jobs from running."""
try:
while True:
self.sync()
Expand Down Expand Up @@ -500,7 +526,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
ti = next(ti for ti in tis if ti.external_executor_id == batch_job.job_id)
self.active_workers.add_job(
job_id=batch_job.job_id,
airflow_task_key=ti.key,
airflow_workload_key=ti.key,
airflow_cmd=ti.command_as_list(),
queue=ti.queue,
exec_config=ti.executor_config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,22 @@
import datetime
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeAlias

from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS

if AIRFLOW_V_3_3_PLUS:
from airflow.executors.workloads.types import WorkloadKey

BatchJobWorkloadKey: TypeAlias = WorkloadKey
else:
BatchJobWorkloadKey: TypeAlias = TaskInstanceKey # type: ignore[no-redef, misc]


CommandType = list[str]
ExecutorConfigType = dict[str, Any]
Expand All @@ -43,9 +52,9 @@
class BatchQueuedJob:
"""Represents a Batch job that is queued. The job will be run in the next heartbeat."""

key: TaskInstanceKey
key: BatchJobWorkloadKey
command: CommandType
queue: str
queue: str | None
executor_config: ExecutorConfigType
attempt_number: int
next_attempt_time: datetime.datetime
Expand Down Expand Up @@ -91,33 +100,33 @@ class BatchJobCollection:
"""A collection to manage running Batch Jobs."""

def __init__(self):
self.key_to_id: dict[TaskInstanceKey, str] = {}
self.id_to_key: dict[str, TaskInstanceKey] = {}
self.key_to_id: dict[BatchJobWorkloadKey, str] = {}
self.id_to_key: dict[str, BatchJobWorkloadKey] = {}
self.id_to_failure_counts: dict[str, int] = defaultdict(int)
self.id_to_job_info: dict[str, BatchJobInfo] = {}

def add_job(
    self,
    job_id: str,
    airflow_workload_key: BatchJobWorkloadKey,
    airflow_cmd: CommandType,
    queue: str | None,
    exec_config: ExecutorConfigType,
    attempt_number: int,
):
    """
    Add a job to the collection.

    :param job_id: Batch job ID the other arguments are recorded under.
    :param airflow_workload_key: Key of the Airflow workload the job runs.
    :param airflow_cmd: Command stored in the job's ``BatchJobInfo``.
    :param queue: Queue stored in the job's ``BatchJobInfo``; ``None`` is accepted
        because queued jobs may carry no queue.
    :param exec_config: Executor config stored in the job's ``BatchJobInfo``.
    :param attempt_number: Value recorded as the job's failure count.
    """
    self.key_to_id[airflow_workload_key] = job_id
    self.id_to_key[job_id] = airflow_workload_key
    self.id_to_failure_counts[job_id] = attempt_number
    self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd, queue=queue, config=exec_config)

def pop_by_id(self, job_id: str) -> BatchJobWorkloadKey:
    """
    Delete job from collection based off of Batch Job ID.

    :param job_id: Batch job ID to remove.
    :return: The workload key that was mapped to ``job_id``.
    :raises KeyError: If ``job_id`` is not tracked.
    """
    workload_key = self.id_to_key[job_id]
    del self.key_to_id[workload_key]
    del self.id_to_key[job_id]
    del self.id_to_failure_counts[job_id]
    # NOTE(review): ``id_to_job_info[job_id]`` is not removed here, so that entry
    # outlives the job. Confirm no caller reads it after popping before adding cleanup.
    return workload_key

def failure_count_by_id(self, job_id: str) -> int:
"""Get the number of times a job has failed given a Batch Job Id."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
# Each flag is true when the base Airflow version tuple is at least the version in its name.
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_1_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 1)
AIRFLOW_V_3_1_8_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 8)
AIRFLOW_V_3_3_PLUS: bool = get_base_airflow_version_tuple() >= (3, 3, 0)

try:
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
Expand All @@ -58,6 +59,7 @@ def is_arg_set(value): # type: ignore[misc,no-redef]
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_1_1_PLUS",
"AIRFLOW_V_3_1_8_PLUS",
"AIRFLOW_V_3_3_PLUS",
"NOTSET",
"ArgNotSet",
"is_arg_set",
Expand Down
Loading
Loading