32 changes: 32 additions & 0 deletions providers/edge3/docs/edge_executor.rst
@@ -160,6 +160,38 @@ Here is an example setting pool_slots for a task:

task_with_template()

.. _edge_executor:execute_callback:

Support ExecuteCallback in Worker
---------------------------------

In addition to executing tasks, the EdgeExecutor can also dispatch executor-level
callbacks (``ExecuteCallback`` workloads, e.g. deadline callbacks) to edge workers.
When the scheduler hands an ``ExecuteCallback`` to ``EdgeExecutor.queue_workload``,
it is enqueued into the same job queue (``EdgeJobModel``) that is used for task
workloads, so an edge worker picks it up alongside regular tasks without any
additional configuration.

Callback jobs share the ``EdgeJobModel`` table with task jobs. They are
distinguished by reserved values in the identifier columns:

- ``dag_id`` is set to the constant tag ``ExecuteCallback``.
- ``task_id`` is set to the callback key (the callback ID).
- ``run_id`` is set to ``ExecuteCallback-<callback_key>``.
- ``map_index`` is fixed to ``-1`` and ``try_number`` to ``0``.
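
For illustration, a callback job row therefore ends up looking roughly like the
following sketch (field values mirror the reserved identifiers above; ``workload``
stands for the ``ExecuteCallback`` instance handed to ``queue_workload``):

.. code-block:: python

    from airflow.providers.edge3.models.edge_job import EdgeJobModel
    from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG  # == "ExecuteCallback"
    from airflow.utils.state import TaskInstanceState

    # Sketch only: "workload" is the ExecuteCallback passed to EdgeExecutor.queue_workload.
    callback_key = workload.callback.key
    job = EdgeJobModel(
        dag_id=EXECUTE_CALLBACK_TAG,  # reserved tag instead of a real DAG id
        task_id=callback_key,  # the callback ID
        run_id=f"{EXECUTE_CALLBACK_TAG}-{callback_key}",
        map_index=-1,
        try_number=0,
        state=TaskInstanceState.QUEUED,
        concurrency_slots=1,
        command=workload.model_dump_json(),  # serialized ExecuteCallback payload
    )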

When the worker fetches such a job through the worker API, the command payload is
deserialized back into an ``ExecuteCallback`` workload (instead of an
``ExecuteTask``) based on these identifiers. The worker then runs the callback
through ``BaseExecutor.run_workload`` rather than the task supervisor flow used for
normal tasks.
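
On Airflow 3.3 and newer, the worker API's dispatch can be sketched roughly as
follows (a simplified version of the provider's ``parse_command`` helper; the
reserved identifiers described above are the only discriminator):

.. code-block:: python

    from airflow.executors.workloads import ExecuteCallback, ExecuteTask

    EXECUTE_CALLBACK_TAG = "ExecuteCallback"


    def parse_command(command: str, dag_id: str, run_id: str):
        # Callback jobs carry the reserved tag in dag_id and run_id; everything else is a task.
        if dag_id == EXECUTE_CALLBACK_TAG and run_id.startswith(EXECUTE_CALLBACK_TAG):
            return ExecuteCallback.model_validate_json(command)
        return ExecuteTask.model_validate_json(command)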

.. note::

This feature requires Airflow 3.3 or newer. On earlier Airflow versions the
EdgeExecutor only handles ``ExecuteTask`` workloads, and any ``ExecuteCallback``
is rejected with a ``TypeError``.
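
A rough sketch of that version gate, mirroring the ``is_callback_execute`` helper
in this provider's ``models/types.py``:

.. code-block:: python

    from airflow.providers.edge3.version_compat import AIRFLOW_V_3_3_PLUS


    def is_callback_execute(workload) -> bool:
        # ExecuteCallback only exists on Airflow 3.3+, so import it lazily.
        if AIRFLOW_V_3_3_PLUS:
            from airflow.executors.workloads import ExecuteCallback

            return isinstance(workload, ExecuteCallback)
        # On older versions the executor never sees callbacks; queue_workload
        # rejects anything that is not an ExecuteTask with a TypeError.
        return False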

Current Limitations Edge Executor
---------------------------------

53 changes: 30 additions & 23 deletions providers/edge3/src/airflow/providers/edge3/cli/worker.py
@@ -58,13 +58,13 @@
EdgeWorkerState,
EdgeWorkerVersionException,
)
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.utils.net import getfqdn
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.configuration import AirflowConfigParser
from airflow.executors.workloads import ExecuteTask
from airflow.providers.edge3.models.types import ExecuteTypeBody

logger = logging.getLogger(__name__)

@@ -260,40 +260,47 @@ def _get_state(self) -> EdgeWorkerState:
return EdgeWorkerState.MAINTENANCE_MODE
return EdgeWorkerState.IDLE

def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) -> int:
def _run_job_via_supervisor(self, workload: ExecuteTypeBody, results_queue: Queue) -> int:
_reset_parent_signal_state()

from airflow.sdk.execution_time.supervisor import supervise

# Ignore ctrl-c in this process -- we don't want to kill _this_ one. we let tasks run to completion

os.setpgrp()

logger.info("Worker starting up pid=%d", os.getpid())
ti = workload.ti
setproctitle(
"airflow edge supervisor: "
f"dag_id={ti.dag_id} task_id={ti.task_id} run_id={ti.run_id} map_index={ti.map_index} "
f"try_number={ti.try_number}"
)

try:
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
# Same like in airflow/executors/local_executor.py:_execute_workload()
ti=ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=self._execution_api_server_url,
log_path=workload.log_path,
)
if AIRFLOW_V_3_3_PLUS:
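# On Airflow 3.3+, BaseExecutor.run_workload handles both ExecuteTask and ExecuteCallback workloads.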
from airflow.executors.base_executor import BaseExecutor

BaseExecutor.run_workload(workload=workload, server=self._execution_api_server_url)
else:
from airflow.sdk.execution_time.supervisor import supervise

ti = workload.ti
setproctitle(
"airflow edge supervisor executing task: "
f"dag_id={ti.dag_id} task_id={ti.task_id} run_id={ti.run_id} map_index={ti.map_index} "
f"try_number={ti.try_number}"
)

supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
# Same like in airflow/executors/local_executor.py:_execute_workload()
ti=ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=self._execution_api_server_url,
log_path=workload.log_path,
)
return 0
except Exception as e:
logger.exception("Task execution failed")
results_queue.put(e)
return 1

def _launch_job(self, workload: ExecuteTask) -> tuple[Process, Queue[Exception]]:
def _launch_job(self, workload: ExecuteTypeBody) -> tuple[Process, Queue[Exception]]:
# Improvement: Use frozen GC to prevent child process from copying unnecessary memory
# See _spawn_workers_with_gc_freeze() in airflow-core/src/airflow/executors/local_executor.py
results_queue: Queue[Exception] = Queue()
@@ -421,7 +428,7 @@ async def fetch_and_run_job(self) -> None:

logger.info("Received job: %s", edge_job.identifier)

workload: ExecuteTask = edge_job.command
workload: ExecuteTypeBody = edge_job.command
process, results_queue = self._launch_job(workload)
if TYPE_CHECKING:
assert workload.log_path # We need to assume this is defined in here
@@ -33,6 +33,7 @@
from airflow.providers.edge3.models.edge_job import EdgeJobModel
from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState, reset_metrics
from airflow.providers.edge3.models.types import is_callback_execute
from airflow.utils.db import DBLocks, create_global_lock
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
@@ -102,44 +103,72 @@ def queue_workload(
session: Session = NEW_SESSION,
) -> None:
"""Put new workload to queue. Airflow 3 entry point to execute a task."""
if not isinstance(workload, workloads.ExecuteTask):
raise TypeError(f"Don't know how to queue workload of type {type(workload).__name__}")
if is_callback_execute(workload):
from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG

existing_job = session.scalars(
select(EdgeJobModel).where(
EdgeJobModel.dag_id == EXECUTE_CALLBACK_TAG,
EdgeJobModel.task_id == workload.callback.key,
EdgeJobModel.run_id == f"{EXECUTE_CALLBACK_TAG}-{workload.callback.key}",
)
).first()

if existing_job:
existing_job.state = TaskInstanceState.QUEUED
existing_job.command = workload.model_dump_json()
else:
session.add(
EdgeJobModel(
dag_id=EXECUTE_CALLBACK_TAG,
task_id=workload.callback.key,
run_id=f"{EXECUTE_CALLBACK_TAG}-{workload.callback.key}",
map_index=-1,
try_number=0,
queue=self.conf.get_mandatory_value("operators", "default_queue"),
concurrency_slots=1,
state=TaskInstanceState.QUEUED,
command=workload.model_dump_json(),
)
)

task_instance = workload.ti
key = task_instance.key

# Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
existing_job = session.scalars(
select(EdgeJobModel).where(
EdgeJobModel.dag_id == key.dag_id,
EdgeJobModel.task_id == key.task_id,
EdgeJobModel.run_id == key.run_id,
EdgeJobModel.map_index == key.map_index,
EdgeJobModel.try_number == key.try_number,
)
).first()

if existing_job:
existing_job.state = TaskInstanceState.QUEUED
existing_job.queue = task_instance.queue
existing_job.concurrency_slots = task_instance.pool_slots
existing_job.command = workload.model_dump_json()
existing_job.team_name = self.team_name
else:
session.add(
EdgeJobModel(
dag_id=key.dag_id,
task_id=key.task_id,
run_id=key.run_id,
map_index=key.map_index,
try_number=key.try_number,
state=TaskInstanceState.QUEUED,
queue=task_instance.queue,
concurrency_slots=task_instance.pool_slots,
command=workload.model_dump_json(),
team_name=self.team_name,
elif isinstance(workload, workloads.ExecuteTask):
task_instance = workload.ti
key = task_instance.key

# Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
existing_job = session.scalars(
select(EdgeJobModel).where(
EdgeJobModel.dag_id == key.dag_id,
EdgeJobModel.task_id == key.task_id,
EdgeJobModel.run_id == key.run_id,
EdgeJobModel.map_index == key.map_index,
EdgeJobModel.try_number == key.try_number,
)
)
).first()

if existing_job:
existing_job.state = TaskInstanceState.QUEUED
existing_job.queue = task_instance.queue
existing_job.concurrency_slots = task_instance.pool_slots
existing_job.command = workload.model_dump_json()
else:
session.add(
EdgeJobModel(
dag_id=key.dag_id,
task_id=key.task_id,
run_id=key.run_id,
map_index=key.map_index,
try_number=key.try_number,
state=TaskInstanceState.QUEUED,
queue=task_instance.queue,
concurrency_slots=task_instance.pool_slots,
command=workload.model_dump_json(),
)
)

else:
raise TypeError(f"Don't know how to queue workload of type {type(workload).__name__}")

def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
"""
46 changes: 46 additions & 0 deletions providers/edge3/src/airflow/providers/edge3/models/types.py
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, TypeAlias, TypeGuard

from airflow.executors.workloads import ExecuteTask
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_3_PLUS

if TYPE_CHECKING:
from airflow.executors import workloads
from airflow.executors.workloads import ExecuteCallback

if not AIRFLOW_V_3_3_PLUS:
ExecuteTypeBody: TypeAlias = ExecuteTask
else:
from airflow.executors.workloads import ExecutorWorkload

ExecuteTypeBody: TypeAlias = ExecutorWorkload # type: ignore[no-redef,misc]


def is_callback_execute(workload: workloads.All) -> TypeGuard[ExecuteCallback]:
if AIRFLOW_V_3_3_PLUS:
from airflow.executors.workloads import ExecuteCallback

return isinstance(workload, ExecuteCallback)
return False


# This is the key used to identify execute_callback jobs.
# Changing this value may break compatibility with existing data in the edge_job table.
EXECUTE_CALLBACK_TAG = "ExecuteCallback"
@@ -34,8 +34,10 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:

AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
AIRFLOW_V_3_3_PLUS = get_base_airflow_version_tuple() >= (3, 3, 0)

__all__ = [
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_2_PLUS",
"AIRFLOW_V_3_3_PLUS",
]
@@ -25,9 +25,9 @@
from fastapi import Path
from pydantic import BaseModel, Field

from airflow.executors.workloads import ExecuteTask # noqa: TCH001
from airflow.providers.common.compat.sdk import TaskInstanceKey
from airflow.providers.edge3.models.edge_worker import EdgeWorkerState # noqa: TCH001
from airflow.providers.edge3.models.types import ExecuteTypeBody # noqa: TCH001


class WorkerApiDocs:
@@ -91,7 +91,7 @@ class EdgeJobFetched(EdgeJobBase):
"""Job that is to be executed on the edge worker."""

command: Annotated[
ExecuteTask,
ExecuteTypeBody,
Contributor comment on this line:

Sorry, maybe coming a bit late during the second pass: so far the API was always backwards compatible, meaning that an Edge worker running a previous version would happily continue to run and softly "drain" if there is a provider or Airflow core version mismatch.

Now the return type for job fetching changes here, meaning that if a worker with an old version attempts to fetch a job, the job would be assigned but the worker (probably?) fails while deserializing the content, probably before it drains.

Have you tested this, and how does it behave?
(Compared to the Task SDK there is currently no Cadwyn layer for versioning; this is in the backlog...)

If this is a problem, can we somehow make it compatible, or add a parallel endpoint, so that old clients do not fail until they drain and jobs are not pulled and then never executed?

Field(
title="Command",
description="Command line to use to execute the job in Airflow",
@@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Body, Depends, status
from sqlalchemy import select, update
@@ -32,6 +32,7 @@
from airflow.executors.workloads import ExecuteTask
from airflow.providers.common.compat.sdk import timezone
from airflow.providers.edge3.models.edge_job import EdgeJobModel
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_3_PLUS
from airflow.providers.edge3.worker_api.auth import jwt_token_authorization_rest
from airflow.providers.edge3.worker_api.datamodels import (
EdgeJobFetched,
Expand All @@ -40,10 +41,20 @@
)
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.providers.edge3.models.types import ExecuteTypeBody

jobs_router = AirflowRouter(tags=["Jobs"], prefix="/jobs")


def parse_command(command: str) -> ExecuteTask:
def parse_command(command: str, dag_id: str, run_id: str) -> ExecuteTypeBody:
if AIRFLOW_V_3_3_PLUS:
from airflow.executors.workloads import ExecuteCallback
from airflow.providers.edge3.models.types import EXECUTE_CALLBACK_TAG

if dag_id == EXECUTE_CALLBACK_TAG and run_id.startswith(EXECUTE_CALLBACK_TAG):
return ExecuteCallback.model_validate_json(command) # type: ignore[return-value]

return ExecuteTask.model_validate_json(command)


@@ -104,7 +115,7 @@ def fetch(
run_id=job.run_id,
map_index=job.map_index,
try_number=job.try_number,
command=parse_command(job.command),
command=parse_command(job.command, job.dag_id, job.run_id),
concurrency_slots=job.concurrency_slots,
)
