Support Task execution interface (AIP-72) in Airflow 3 in EdgeExecutor (#44982)

* Extend the queue_workload() call with an ORM session (a usage sketch follows this list)

* Support Task execution interface (AIP-72) in Airflow 3

* Fix CI errors, typos and static checks

* Fix pytest backcompat

* Apply Copilot feedback from other PR

* Review Feedback

* Add missing session to CLI call

* Review Feedback
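For orientation, a minimal sketch (editor's illustration, not part of the commit; the wrapper function name is made up) of the central API change: callers now hand an ORM session to `queue_workload()`, exactly as the CLI change below does.

```python
from airflow.executors import workloads
from airflow.utils.session import create_session


def queue_for_execution(executor, ti, dag) -> None:
    """Illustrative helper: build an ExecuteTask workload and queue it with a session."""
    workload = workloads.ExecuteTask.make(ti, dag_path=dag.relative_fileloc)
    with create_session() as session:
        executor.queue_workload(workload, session)
```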
jscheffl authored Jan 7, 2025
1 parent 4c5d85a commit 0399381
Showing 17 changed files with 263 additions and 48 deletions.
3 changes: 2 additions & 1 deletion airflow/cli/commands/remote_commands/task_command.py
@@ -286,7 +286,8 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None:
from airflow.executors import workloads

workload = workloads.ExecuteTask.make(ti, dag_path=dag.relative_fileloc)
executor.queue_workload(workload)
with create_session() as session:
executor.queue_workload(workload, session)
else:
executor.queue_task_instance(
ti,
4 changes: 3 additions & 1 deletion airflow/executors/base_executor.py
@@ -46,6 +46,8 @@
import argparse
from datetime import datetime

from sqlalchemy.orm import Session

from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.cli.cli_config import GroupCommand
@@ -171,7 +173,7 @@ def queue_command(
else:
self.log.error("could not queue task %s", task_instance.key)

def queue_workload(self, workload: workloads.All) -> None:
def queue_workload(self, workload: workloads.All, session: Session) -> None:
raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}")

def queue_task_instance(
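To see the new contract from the executor side, here is a minimal sketch of a subclass override; the executor class is hypothetical and only illustrates the signature added above (session passed in by the caller, or created by `@provide_session` when omitted).

```python
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.executors.base_executor import BaseExecutor
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.executors import workloads


class BufferingExecutor(BaseExecutor):
    """Hypothetical executor that just buffers incoming workloads in memory."""

    def __init__(self, parallelism: int = 32):
        super().__init__(parallelism=parallelism)
        self._buffer: list = []

    @provide_session
    def queue_workload(self, workload: workloads.All, session: Session = NEW_SESSION) -> None:
        # The scheduler passes its own session; an executor that persists queue
        # state (as EdgeExecutor does further down) can reuse it instead of
        # opening a new one. This toy version only buffers the workload.
        self._buffer.append(workload)
```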
6 changes: 5 additions & 1 deletion airflow/executors/local_executor.py
@@ -37,9 +37,12 @@

from airflow import settings
from airflow.executors.base_executor import PARALLELISM, BaseExecutor
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.executors import workloads

TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, Optional[Exception]]
@@ -239,7 +242,8 @@ def end(self) -> None:
def terminate(self):
"""Terminate the executor is not doing anything."""

def queue_workload(self, workload: workloads.All):
@provide_session
def queue_workload(self, workload: workloads.All, session: Session = NEW_SESSION):
self.activity_queue.put(workload)
with self._unread_messages:
self._unread_messages.value += 1
2 changes: 1 addition & 1 deletion airflow/jobs/scheduler_job_runner.py
@@ -650,7 +650,7 @@ def _enqueue_task_instances_with_queued_state(
# Has a real queue_workload implemented
if executor.queue_workload.__func__ is not BaseExecutor.queue_workload: # type: ignore[attr-defined]
workload = workloads.ExecuteTask.make(ti)
executor.queue_workload(workload)
executor.queue_workload(workload, session=session)
continue

command = ti.command_as_list(
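The dispatch above relies on detecting whether an executor actually overrides `queue_workload`. A small sketch of that check in isolation (the helper name is illustrative):

```python
from airflow.executors.base_executor import BaseExecutor


def supports_workloads(executor: BaseExecutor) -> bool:
    # True only when a subclass provides its own queue_workload; the base
    # implementation just raises ValueError for any workload kind. Comparing
    # the underlying function objects detects an override even when the method
    # is wrapped, e.g. by @provide_session.
    return executor.queue_workload.__func__ is not BaseExecutor.queue_workload
```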
7 changes: 2 additions & 5 deletions docs/apache-airflow-providers-edge/edge_executor.rst
@@ -294,11 +294,8 @@ The following features are known missing and will be implemented in increments:

- Scaling test - Check and define boundaries of workers/jobs
- Load tests - impact of scaled execution and code optimization
- Airflow 3 / AIP-72 Migration

- Thin deployment based on Task SDK
- DAG Code push (no need to GIT Sync)
- Implicit with AIP-72: Move task context generation from Remote to Executor
- Incremental logs during task execution can be served w/o shared log disk
- Host name of worker is applied as job runner host name as well

- Documentation

14 changes: 11 additions & 3 deletions providers/src/airflow/providers/edge/CHANGELOG.rst
@@ -27,15 +27,23 @@
Changelog
---------

0.10.0pre0
..........

Feature
~~~~~~~

* ``Support Task execution interface (AIP-72) in Airflow 3. Experimental with ongoing development as AIP-72 is also under development.``

0.9.7pre0
.........

* ``Make API retries configurable via ENV. Connection loss is sustained for 5min by default.``
* ``Align retry handling logic and tooling with Task SDK, via retryhttp.``

Misc
~~~~

* ``Make API retries configurable via ENV. Connection loss is sustained for 5min by default.``
* ``Align retry handling logic and tooling with Task SDK, via retryhttp.``

0.9.6pre0
.........

2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@

__all__ = ["__version__"]

__version__ = "0.9.7pre0"
__version__ = "0.10.0pre0"

if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.10.0"
104 changes: 87 additions & 17 deletions providers/src/airflow/providers/edge/cli/edge_command.py
@@ -24,6 +24,7 @@
from dataclasses import dataclass
from datetime import datetime
from http import HTTPStatus
from multiprocessing import Process
from pathlib import Path
from subprocess import Popen
from time import sleep
@@ -82,12 +83,6 @@ def force_use_internal_api_on_edge_worker():
os.environ["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
os.environ["AIRFLOW_ENABLE_AIP_44"] = "True"
if "airflow" in sys.argv[0] and sys.argv[1:3] == ["edge", "worker"]:
if AIRFLOW_V_3_0_PLUS:
# Obvious TODO Make EdgeWorker compatible with Airflow 3 (again)
raise SystemExit(
"Error: EdgeWorker is currently broken on Airflow 3/main due to removal of AIP-44, rework for AIP-72."
)

api_url = conf.get("edge", "api_url")
if not api_url:
raise SystemExit("Error: API URL is not configured, please correct configuration.")
@@ -138,11 +133,26 @@ class _Job:
"""Holds all information for a task/job to be executed as bundle."""

edge_job: EdgeJobFetched
process: Popen
process: Popen | Process
logfile: Path
logsize: int
"""Last size of log file, point of last chunk push."""

@property
def is_running(self) -> bool:
"""Check if the job is still running."""
if isinstance(self.process, Popen):
self.process.poll()
return self.process.returncode is None
return self.process.exitcode is None

@property
def is_success(self) -> bool:
"""Check if the job was successful."""
if isinstance(self.process, Popen):
return self.process.returncode == 0
return self.process.exitcode == 0


class _EdgeWorkerCli:
"""Runner instance which executes the Edge Worker."""
@@ -191,6 +201,73 @@ def _get_sysinfo(self) -> dict:
"free_concurrency": self.free_concurrency,
}

def _launch_job_af3(self, edge_job: EdgeJobFetched) -> tuple[Process, Path]:
if TYPE_CHECKING:
from airflow.executors.workloads import ExecuteTask

def _run_job_via_supervisor(
workload: ExecuteTask,
) -> int:
from setproctitle import setproctitle

from airflow.sdk.execution_time.supervisor import supervise

# Ignore ctrl-c in this process -- we don't want to kill _this_ one. we let tasks run to completion
signal.signal(signal.SIGINT, signal.SIG_IGN)

logger.info("Worker starting up pid=%d", os.getpid())
setproctitle(f"airflow edge worker: {workload.ti.key}")

try:
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
# Same as in airflow/executors/local_executor.py:_execute_work()
ti=workload.ti, # type: ignore[arg-type]
dag_path=workload.dag_path,
token=workload.token,
server=conf.get(
"workers", "execution_api_server_url", fallback="http://localhost:9091/execution/"
),
log_path=workload.log_path,
)
return 0
except Exception as e:
logger.exception("Task execution failed: %s", e)
return 1

workload: ExecuteTask = edge_job.command
process = Process(
target=_run_job_via_supervisor,
kwargs={"workload": workload},
)
process.start()
base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
if TYPE_CHECKING:
assert workload.log_path # We need to assume this is defined in here
logfile = Path(base_log_folder, workload.log_path)
return process, logfile

def _launch_job_af2_10(self, edge_job: EdgeJobFetched) -> tuple[Popen, Path]:
"""Compatibility for Airflow 2.10 Launch."""
env = os.environ.copy()
env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True"
env["AIRFLOW__CORE__INTERNAL_API_URL"] = conf.get("edge", "api_url")
env["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
command: list[str] = edge_job.command # type: ignore[assignment]
process = Popen(command, close_fds=True, env=env, start_new_session=True)
logfile = logs_logfile_path(edge_job.key)
return process, logfile

def _launch_job(self, edge_job: EdgeJobFetched):
"""Get the received job executed."""
process: Popen | Process
if AIRFLOW_V_3_0_PLUS:
process, logfile = self._launch_job_af3(edge_job)
else:
# Airflow 2.10
process, logfile = self._launch_job_af2_10(edge_job)
self.jobs.append(_Job(edge_job, process, logfile, 0))

def start(self):
"""Start the execution in a loop until terminated."""
try:
@@ -239,13 +316,7 @@ def fetch_job(self) -> bool:
edge_job = jobs_fetch(self.hostname, self.queues, self.free_concurrency)
if edge_job:
logger.info("Received job: %s", edge_job)
env = os.environ.copy()
env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True"
env["AIRFLOW__CORE__INTERNAL_API_URL"] = conf.get("edge", "api_url")
env["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
process = Popen(edge_job.command, close_fds=True, env=env, start_new_session=True)
logfile = logs_logfile_path(edge_job.key)
self.jobs.append(_Job(edge_job, process, logfile, 0))
self._launch_job(edge_job)
jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
return True

@@ -257,10 +328,9 @@ def check_running_jobs(self) -> None:
used_concurrency = 0
for i in range(len(self.jobs) - 1, -1, -1):
job = self.jobs[i]
job.process.poll()
if job.process.returncode is not None:
if not job.is_running:
self.jobs.remove(job)
if job.process.returncode == 0:
if job.is_success:
logger.info("Job completed: %s", job.edge_job)
jobs_set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
else:
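The `is_running`/`is_success` properties added to `_Job` exist because `subprocess.Popen` (Airflow 2.10 path) and `multiprocessing.Process` (Airflow 3 path) report completion differently: `returncode` after a `poll()` versus `exitcode`. A standalone illustration of that difference, independent of the worker code:

```python
import multiprocessing
import subprocess
import sys
import time


def _noop() -> None:
    time.sleep(0.1)


if __name__ == "__main__":
    popen = subprocess.Popen([sys.executable, "-c", "pass"])
    proc = multiprocessing.Process(target=_noop)
    proc.start()

    popen.wait()
    proc.join()

    popen.poll()  # Popen: poll() refreshes returncode; None means still running.
    print("Popen finished OK:", popen.returncode == 0)
    print("Process finished OK:", proc.exitcode == 0)  # Process: exitcode is None while running.
```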
30 changes: 29 additions & 1 deletion providers/src/airflow/providers/edge/executors/edge_executor.py
@@ -108,7 +108,7 @@ def execute_async(
executor_config: Any | None = None,
session: Session = NEW_SESSION,
) -> None:
"""Execute asynchronously."""
"""Execute asynchronously. Airflow 2.10 entry point to execute a task."""
# Use of a temporary trick to get the task instance; will be changed with Airflow 3.0.0
# code works together with _process_tasks overwrite to get task instance.
task_instance = self.edge_queued_tasks[key][3] # TaskInstance in fourth element
@@ -129,6 +129,34 @@ def execute_async(
)
)

@provide_session
def queue_workload(
self,
workload: Any,  # Actually "airflow.executors.workloads.All", but that type does not exist in Airflow 2.10
session: Session = NEW_SESSION,
) -> None:
"""Put new workload to queue. Airflow 3 entry point to execute a task."""
from airflow.executors import workloads

if not isinstance(workload, workloads.ExecuteTask):
raise TypeError(f"Don't know how to queue workload of type {type(workload).__name__}")

task_instance = workload.ti
key = task_instance.key
session.add(
EdgeJobModel(
dag_id=key.dag_id,
task_id=key.task_id,
run_id=key.run_id,
map_index=key.map_index,
try_number=key.try_number,
state=TaskInstanceState.QUEUED,
queue=DEFAULT_QUEUE, # TODO Queues to be added once implemented in AIP-72
concurrency_slots=1, # TODO Pool slots to be added once implemented in AIP-72
command=workload.model_dump_json(),
)
)

def _check_worker_liveness(self, session: Session) -> bool:
"""Reset worker state if heartbeat timed out."""
changed = False
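A sketch of the resulting Airflow 3 round-trip (helper names are illustrative): `queue_workload()` stores the workload as JSON in `EdgeJobModel.command`, and the worker API later rebuilds the `ExecuteTask` from that column.

```python
from airflow.executors.workloads import ExecuteTask


def store_command(workload: ExecuteTask) -> str:
    # Executor side: serialize the workload for the EdgeJobModel.command column.
    return workload.model_dump_json()


def load_command(command: str) -> ExecuteTask:
    # Worker API side: rebuild the workload, as parse_command() does on Airflow 3.
    return ExecuteTask.model_validate_json(command)
```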
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/provider.yaml
@@ -27,7 +27,7 @@ source-date-epoch: 1729683247

# note that those versions are maintained by release manager - do not update them manually
versions:
- 0.9.7pre0
- 0.10.0pre0

dependencies:
- apache-airflow>=2.10.0
5 changes: 4 additions & 1 deletion providers/src/airflow/providers/edge/worker_api/auth.py
@@ -112,4 +112,7 @@ def jwt_token_authorization_rest(
request: Request, authorization: str = Header(description="JWT Authorization Token")
):
"""Check if the JWT token is correct for REST API requests."""
jwt_token_authorization(request.url.path, authorization)
PREFIX = "/edge_worker/v1/"
path = request.url.path
method_path = path[path.find(PREFIX) + len(PREFIX) :] if PREFIX in path else path
jwt_token_authorization(method_path, authorization)
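A standalone sketch of the path normalization introduced above; the example paths are made up, but the slicing mirrors the code so the JWT check receives the method path relative to the `/edge_worker/v1/` mount point.

```python
PREFIX = "/edge_worker/v1/"


def method_path_from_url(path: str) -> str:
    """Strip everything up to and including the worker API prefix, if present."""
    return path[path.find(PREFIX) + len(PREFIX):] if PREFIX in path else path


assert method_path_from_url("/edge_worker/v1/jobs/fetch/my-host") == "jobs/fetch/my-host"
assert method_path_from_url("/some/other/path") == "/some/other/path"
```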
8 changes: 6 additions & 2 deletions providers/src/airflow/providers/edge/worker_api/datamodels.py
@@ -26,7 +26,7 @@

from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.edge.models.edge_worker import EdgeWorkerState # noqa: TCH001
from airflow.providers.edge.worker_api.routes._v2_compat import Path
from airflow.providers.edge.worker_api.routes._v2_compat import ExecuteTask, Path


class WorkerApiDocs:
@@ -90,7 +90,11 @@ class EdgeJobFetched(EdgeJobBase):
"""Job that is to be executed on the edge worker."""

command: Annotated[
list[str], Field(title="Command", description="Command line to use to execute the job.")
ExecuteTask,
Field(
title="Command",
description="Command line to use to execute the job in Airflow 2. Task definition in Airflow 3",
),
]
concurrency_slots: Annotated[int, Field(description="Number of concurrency slots the job requires.")]

providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py
@@ -27,6 +27,12 @@
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc

# In Airflow 3 with AIP-72 the workload arrives as an ExecuteTask
from airflow.executors.workloads import ExecuteTask

def parse_command(command: str) -> ExecuteTask:
return ExecuteTask.model_validate_json(command)
else:
# Mock the external dependencies
from typing import Callable
@@ -118,3 +124,12 @@ def decorator(func: Callable) -> Callable:
return func

return decorator

# In Airflow 3 with AIP-72 the workload arrives as an ExecuteTask,
# but in Airflow 2.10 it is a command-line array
ExecuteTask = list[str] # type: ignore[no-redef,assignment,misc]

def parse_command(command: str) -> ExecuteTask:
from ast import literal_eval

return literal_eval(command)
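For the Airflow 2.10 branch, `parse_command()` reverses the stored `repr()` of the command list via `literal_eval`. A small self-contained illustration (the function and sample values below are made up):

```python
from ast import literal_eval


def parse_command_af2(command: str) -> list[str]:
    # Airflow 2.10: the DB column presumably holds the repr() of a command-line list,
    # which is why literal_eval() can reconstruct it.
    return literal_eval(command)


stored = str(["airflow", "tasks", "run", "example_dag", "example_task", "manual__2025-01-07"])
assert parse_command_af2(stored) == [
    "airflow", "tasks", "run", "example_dag", "example_task", "manual__2025-01-07"
]
```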
providers/src/airflow/providers/edge/worker_api/routes/jobs.py
@@ -17,7 +17,6 @@

from __future__ import annotations

from ast import literal_eval
from typing import Annotated

from sqlalchemy import select, update
@@ -35,6 +34,7 @@
Depends,
SessionDep,
create_openapi_http_exception_doc,
parse_command,
status,
)
from airflow.utils import timezone
@@ -91,7 +91,7 @@ def fetch(
run_id=job.run_id,
map_index=job.map_index,
try_number=job.try_number,
command=literal_eval(job.command),
command=parse_command(job.command),
concurrency_slots=job.concurrency_slots,
)
