Commit c1ebabf

jscheffl authored and got686-yandex committed
Support Task execution interface (AIP-72) in Airflow 3 in EdgeExecutor (apache#44982)
* Extend the queue_workload() call with an ORM session
* Support Task execution interface (AIP-72) in Airflow 3
* Fix CI errors, typos and static checks
* Fix pytest backcompat
* Apply Copilot feedback from other PR
* Review Feedback
* Add missing session to CLI call
* Review Feedback
1 parent e912651 commit c1ebabf

File tree: 17 files changed (+263, -48 lines)

airflow/cli/commands/remote_commands/task_command.py (+2 -1)

@@ -286,7 +286,8 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None:
         from airflow.executors import workloads
 
         workload = workloads.ExecuteTask.make(ti, dag_path=dag.relative_fileloc)
-        executor.queue_workload(workload)
+        with create_session() as session:
+            executor.queue_workload(workload, session)
     else:
         executor.queue_task_instance(
             ti,
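
For context, create_session() is Airflow's context manager that yields an ORM session, commits it when the block exits cleanly, and rolls back if the block raises. A minimal sketch of the calling pattern used above, assuming an executor object that implements the new two-argument signature:

    from airflow.utils.session import create_session

    def queue_with_session(executor, workload) -> None:
        """Sketch: open a session, hand it to the executor, commit on exit."""
        with create_session() as session:
            # Anything the executor adds to this session is committed when
            # the block exits without error, and rolled back otherwise.
            executor.queue_workload(workload, session)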

airflow/executors/base_executor.py (+3 -1)

@@ -46,6 +46,8 @@
     import argparse
     from datetime import datetime
 
+    from sqlalchemy.orm import Session
+
     from airflow.callbacks.base_callback_sink import BaseCallbackSink
     from airflow.callbacks.callback_requests import CallbackRequest
     from airflow.cli.cli_config import GroupCommand
@@ -171,7 +173,7 @@ def queue_command(
         else:
             self.log.error("could not queue task %s", task_instance.key)
 
-    def queue_workload(self, workload: workloads.All) -> None:
+    def queue_workload(self, workload: workloads.All, session: Session) -> None:
         raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}")
 
     def queue_task_instance(

airflow/executors/local_executor.py (+5 -1)

@@ -37,9 +37,12 @@
 
 from airflow import settings
 from airflow.executors.base_executor import PARALLELISM, BaseExecutor
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
     from airflow.executors import workloads
 
 TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, Optional[Exception]]
@@ -239,7 +242,8 @@ def end(self) -> None:
     def terminate(self):
         """Terminate the executor is not doing anything."""
 
-    def queue_workload(self, workload: workloads.All):
+    @provide_session
+    def queue_workload(self, workload: workloads.All, session: Session = NEW_SESSION):
         self.activity_queue.put(workload)
         with self._unread_messages:
             self._unread_messages.value += 1
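
@provide_session is Airflow's standard decorator for making the session parameter optional: if the caller passes a session it is used as-is; otherwise the decorator creates one, commits it after the call, and closes it. A minimal sketch of the pattern on a hypothetical class (the workload here is assumed to be an ORM-mapped row so session.add() is meaningful):

    from sqlalchemy.orm import Session

    from airflow.utils.session import NEW_SESSION, provide_session

    class DemoExecutor:
        @provide_session
        def queue_workload(self, workload, session: Session = NEW_SESSION) -> None:
            # NEW_SESSION is only a sentinel default; at runtime the decorator
            # swaps in a real session unless the caller supplied one.
            session.add(workload)

    # Both call styles work:
    #   DemoExecutor().queue_workload(w)              # decorator manages the session
    #   DemoExecutor().queue_workload(w, session=s)   # caller-managed session reused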

airflow/jobs/scheduler_job_runner.py (+1 -1)

@@ -650,7 +650,7 @@ def _enqueue_task_instances_with_queued_state(
             # Has a real queue_activity implemented
             if executor.queue_workload.__func__ is not BaseExecutor.queue_workload:  # type: ignore[attr-defined]
                 workload = workloads.ExecuteTask.make(ti)
-                executor.queue_workload(workload)
+                executor.queue_workload(workload, session=session)
                 continue
 
             command = ti.command_as_list(
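
The scheduler's guard works by method-identity comparison: executor.queue_workload.__func__ unwraps the bound method, and that function object differs from BaseExecutor.queue_workload only when a subclass overrides it. A standalone illustration of the idiom, with toy classes standing in for the executors:

    class Base:
        def queue_workload(self, workload, session):
            raise ValueError("not supported")

    class Old(Base):
        pass  # inherits the stub

    class New(Base):
        def queue_workload(self, workload, session):
            print("queued")

    for executor in (Old(), New()):
        # `__func__` unwraps the bound method; identity differs only if overridden.
        overridden = executor.queue_workload.__func__ is not Base.queue_workload
        print(type(executor).__name__, "overrides queue_workload:", overridden)
    # -> Old overrides queue_workload: False
    # -> New overrides queue_workload: True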

docs/apache-airflow-providers-edge/edge_executor.rst (+2 -5)

@@ -294,11 +294,8 @@ The following features are known missing and will be implemented in increments:
 
 - Scaling test - Check and define boundaries of workers/jobs
 - Load tests - impact of scaled execution and code optimization
-- Airflow 3 / AIP-72 Migration
-
-  - Thin deployment based on Task SDK
-  - DAG Code push (no need to GIT Sync)
-  - Implicit with AIP-72: Move task context generation from Remote to Executor
+- Incremental logs during task execution can be served w/o shared log disk
+- Host name of worker is applied as job runner host name as well
 
 - Documentation

providers/src/airflow/providers/edge/CHANGELOG.rst (+11 -3)

@@ -27,15 +27,23 @@
 Changelog
 ---------
 
+0.10.0pre0
+..........
+
+Feature
+~~~~~~~
+
+* ``Support Task execution interface (AIP-72) in Airflow 3. Experimental with ongoing development as AIP-72 is also under development.``
+
 0.9.7pre0
 .........
 
-* ``Make API retries configurable via ENV. Connection loss is sustained for 5min by default.``
-* ``Align retry handling logic and tooling with Task SDK, via retryhttp.``
-
 Misc
 ~~~~
 
+* ``Make API retries configurable via ENV. Connection loss is sustained for 5min by default.``
+* ``Align retry handling logic and tooling with Task SDK, via retryhttp.``
+
 0.9.6pre0
 .........

providers/src/airflow/providers/edge/__init__.py (+1 -1)

@@ -29,7 +29,7 @@
 
 __all__ = ["__version__"]
 
-__version__ = "0.9.7pre0"
+__version__ = "0.10.0pre0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.10.0"

providers/src/airflow/providers/edge/cli/edge_command.py (+87 -17)

@@ -24,6 +24,7 @@
 from dataclasses import dataclass
 from datetime import datetime
 from http import HTTPStatus
+from multiprocessing import Process
 from pathlib import Path
 from subprocess import Popen
 from time import sleep
@@ -82,12 +83,6 @@ def force_use_internal_api_on_edge_worker():
         os.environ["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
         os.environ["AIRFLOW_ENABLE_AIP_44"] = "True"
     if "airflow" in sys.argv[0] and sys.argv[1:3] == ["edge", "worker"]:
-        if AIRFLOW_V_3_0_PLUS:
-            # Obvious TODO Make EdgeWorker compatible with Airflow 3 (again)
-            raise SystemExit(
-                "Error: EdgeWorker is currently broken on Airflow 3/main due to removal of AIP-44, rework for AIP-72."
-            )
-
         api_url = conf.get("edge", "api_url")
         if not api_url:
             raise SystemExit("Error: API URL is not configured, please correct configuration.")
@@ -138,11 +133,26 @@ class _Job:
     """Holds all information for a task/job to be executed as bundle."""
 
     edge_job: EdgeJobFetched
-    process: Popen
+    process: Popen | Process
     logfile: Path
     logsize: int
     """Last size of log file, point of last chunk push."""
 
+    @property
+    def is_running(self) -> bool:
+        """Check if the job is still running."""
+        if isinstance(self.process, Popen):
+            self.process.poll()
+            return self.process.returncode is None
+        return self.process.exitcode is None
+
+    @property
+    def is_success(self) -> bool:
+        """Check if the job was successful."""
+        if isinstance(self.process, Popen):
+            return self.process.returncode == 0
+        return self.process.exitcode == 0
+
 
 class _EdgeWorkerCli:
     """Runner instance which executes the Edge Worker."""
@@ -191,6 +201,73 @@ def _get_sysinfo(self) -> dict:
             "free_concurrency": self.free_concurrency,
         }
 
+    def _launch_job_af3(self, edge_job: EdgeJobFetched) -> tuple[Process, Path]:
+        if TYPE_CHECKING:
+            from airflow.executors.workloads import ExecuteTask
+
+        def _run_job_via_supervisor(
+            workload: ExecuteTask,
+        ) -> int:
+            from setproctitle import setproctitle
+
+            from airflow.sdk.execution_time.supervisor import supervise
+
+            # Ignore ctrl-c in this process -- we don't want to kill _this_ one. we let tasks run to completion
+            signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+            logger.info("Worker starting up pid=%d", os.getpid())
+            setproctitle(f"airflow edge worker: {workload.ti.key}")
+
+            try:
+                supervise(
+                    # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
+                    # Same like in airflow/executors/local_executor.py:_execute_work()
+                    ti=workload.ti,  # type: ignore[arg-type]
+                    dag_path=workload.dag_path,
+                    token=workload.token,
+                    server=conf.get(
+                        "workers", "execution_api_server_url", fallback="http://localhost:9091/execution/"
+                    ),
+                    log_path=workload.log_path,
+                )
+                return 0
+            except Exception as e:
+                logger.exception("Task execution failed: %s", e)
+                return 1
+
+        workload: ExecuteTask = edge_job.command
+        process = Process(
+            target=_run_job_via_supervisor,
+            kwargs={"workload": workload},
+        )
+        process.start()
+        base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
+        if TYPE_CHECKING:
+            assert workload.log_path  # We need to assume this is defined in here
+        logfile = Path(base_log_folder, workload.log_path)
+        return process, logfile
+
+    def _launch_job_af2_10(self, edge_job: EdgeJobFetched) -> tuple[Popen, Path]:
+        """Compatibility for Airflow 2.10 Launch."""
+        env = os.environ.copy()
+        env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True"
+        env["AIRFLOW__CORE__INTERNAL_API_URL"] = conf.get("edge", "api_url")
+        env["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
+        command: list[str] = edge_job.command  # type: ignore[assignment]
+        process = Popen(command, close_fds=True, env=env, start_new_session=True)
+        logfile = logs_logfile_path(edge_job.key)
+        return process, logfile
+
+    def _launch_job(self, edge_job: EdgeJobFetched):
+        """Get the received job executed."""
+        process: Popen | Process
+        if AIRFLOW_V_3_0_PLUS:
+            process, logfile = self._launch_job_af3(edge_job)
+        else:
+            # Airflow 2.10
+            process, logfile = self._launch_job_af2_10(edge_job)
+        self.jobs.append(_Job(edge_job, process, logfile, 0))
+
     def start(self):
         """Start the execution in a loop until terminated."""
         try:
@@ -239,13 +316,7 @@ def fetch_job(self) -> bool:
         edge_job = jobs_fetch(self.hostname, self.queues, self.free_concurrency)
         if edge_job:
             logger.info("Received job: %s", edge_job)
-            env = os.environ.copy()
-            env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True"
-            env["AIRFLOW__CORE__INTERNAL_API_URL"] = conf.get("edge", "api_url")
-            env["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1"
-            process = Popen(edge_job.command, close_fds=True, env=env, start_new_session=True)
-            logfile = logs_logfile_path(edge_job.key)
-            self.jobs.append(_Job(edge_job, process, logfile, 0))
+            self._launch_job(edge_job)
             jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
             return True
@@ -257,10 +328,9 @@ def check_running_jobs(self) -> None:
         used_concurrency = 0
         for i in range(len(self.jobs) - 1, -1, -1):
             job = self.jobs[i]
-            job.process.poll()
-            if job.process.returncode is not None:
+            if not job.is_running:
                 self.jobs.remove(job)
-                if job.process.returncode == 0:
+                if job.is_success:
                     logger.info("Job completed: %s", job.edge_job)
                     jobs_set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
                 else:
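
The is_running/is_success properties in _Job paper over a real API difference: subprocess.Popen only refreshes returncode after an explicit poll(), while multiprocessing.Process exposes exitcode directly, and both report None while the child is alive. A standalone sketch of the same duck typing:

    from __future__ import annotations

    import sys
    from multiprocessing import Process
    from subprocess import Popen

    def is_running(process: Popen | Process) -> bool:
        """A None exit status means the child has not finished yet."""
        if isinstance(process, Popen):
            process.poll()  # refreshes returncode; Popen does not update it on its own
            return process.returncode is None
        return process.exitcode is None

    def _noop() -> None:
        pass

    if __name__ == "__main__":
        popen_job = Popen([sys.executable, "-c", "pass"])
        proc_job = Process(target=_noop)
        proc_job.start()
        popen_job.wait()
        proc_job.join()
        print(is_running(popen_job), is_running(proc_job))  # False False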

providers/src/airflow/providers/edge/executors/edge_executor.py (+29 -1)

@@ -108,7 +108,7 @@ def execute_async(
         executor_config: Any | None = None,
         session: Session = NEW_SESSION,
     ) -> None:
-        """Execute asynchronously."""
+        """Execute asynchronously. Airflow 2.10 entry point to execute a task."""
         # Use of a temponary trick to get task instance, will be changed with Airflow 3.0.0
         # code works together with _process_tasks overwrite to get task instance.
         task_instance = self.edge_queued_tasks[key][3]  # TaskInstance in fourth element
@@ -129,6 +129,34 @@ def execute_async(
             )
         )
 
+    @provide_session
+    def queue_workload(
+        self,
+        workload: Any,  # Note actually "airflow.executors.workloads.All" but not existing in Airflow 2.10
+        session: Session = NEW_SESSION,
+    ) -> None:
+        """Put new workload to queue. Airflow 3 entry point to execute a task."""
+        from airflow.executors import workloads
+
+        if not isinstance(workload, workloads.ExecuteTask):
+            raise TypeError(f"Don't know how to queue workload of type {type(workload).__name__}")
+
+        task_instance = workload.ti
+        key = task_instance.key
+        session.add(
+            EdgeJobModel(
+                dag_id=key.dag_id,
+                task_id=key.task_id,
+                run_id=key.run_id,
+                map_index=key.map_index,
+                try_number=key.try_number,
+                state=TaskInstanceState.QUEUED,
+                queue=DEFAULT_QUEUE,  # TODO Queues to be added once implemented in AIP-72
+                concurrency_slots=1,  # TODO Pool slots to be added once implemented in AIP-72
+                command=workload.model_dump_json(),
+            )
+        )
+
     def _check_worker_liveness(self, session: Session) -> bool:
         """Reset worker state if heartbeat timed out."""
         changed = False
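
Note how queue_workload() persists the entire workload in the command column via workload.model_dump_json(); the worker side later reverses this with model_validate_json (see _v2_compat.py below). A sketch of that round trip, using a hypothetical stand-in for the pydantic workload model:

    from pydantic import BaseModel

    class DemoWorkload(BaseModel):
        """Hypothetical stand-in for airflow.executors.workloads.ExecuteTask."""

        dag_id: str
        task_id: str
        try_number: int

    workload = DemoWorkload(dag_id="sample_dag", task_id="sample_task", try_number=1)

    # Executor side: store the workload as a JSON string in the job table.
    command = workload.model_dump_json()

    # Worker side: rebuild an identical model from the stored string.
    assert DemoWorkload.model_validate_json(command) == workload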

providers/src/airflow/providers/edge/provider.yaml (+1 -1)

@@ -27,7 +27,7 @@ source-date-epoch: 1729683247
 
 # note that those versions are maintained by release manager - do not update them manually
 versions:
-  - 0.9.7pre0
+  - 0.10.0pre0
 
 dependencies:
   - apache-airflow>=2.10.0

providers/src/airflow/providers/edge/worker_api/auth.py (+4 -1)

@@ -112,4 +112,7 @@ def jwt_token_authorization_rest(
     request: Request, authorization: str = Header(description="JWT Authorization Token")
 ):
     """Check if the JWT token is correct for REST API requests."""
-    jwt_token_authorization(request.url.path, authorization)
+    PREFIX = "/edge_worker/v1/"
+    path = request.url.path
+    method_path = path[path.find(PREFIX) + len(PREFIX) :] if PREFIX in path else path
+    jwt_token_authorization(method_path, authorization)
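
The effect of the change is that the JWT check now receives the method path relative to the /edge_worker/v1/ mount point, so it stays stable even when the API is served behind a proxy prefix. The slicing logic in isolation, with illustrative paths:

    PREFIX = "/edge_worker/v1/"

    def method_path(path: str) -> str:
        """Return the portion after the API prefix, or the path unchanged."""
        return path[path.find(PREFIX) + len(PREFIX):] if PREFIX in path else path

    print(method_path("/proxy/edge_worker/v1/jobs/fetch/host1"))  # jobs/fetch/host1
    print(method_path("/health"))                                 # /health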

providers/src/airflow/providers/edge/worker_api/datamodels.py (+6 -2)

@@ -26,7 +26,7 @@
 
 from airflow.models.taskinstancekey import TaskInstanceKey
 from airflow.providers.edge.models.edge_worker import EdgeWorkerState  # noqa: TCH001
-from airflow.providers.edge.worker_api.routes._v2_compat import Path
+from airflow.providers.edge.worker_api.routes._v2_compat import ExecuteTask, Path
 
 
 class WorkerApiDocs:
@@ -90,7 +90,11 @@ class EdgeJobFetched(EdgeJobBase):
     """Job that is to be executed on the edge worker."""
 
     command: Annotated[
-        list[str], Field(title="Command", description="Command line to use to execute the job.")
+        ExecuteTask,
+        Field(
+            title="Command",
+            description="Command line to use to execute the job in Airflow 2. Task definition in Airflow 3",
+        ),
     ]
     concurrency_slots: Annotated[int, Field(description="Number of concurrency slots the job requires.")]
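
Annotated[..., Field(...)] is the pydantic v2 idiom at work here: the field keeps its name and runtime type while the Field() metadata feeds the title and description into the generated schema. A minimal sketch (model and values are illustrative):

    from typing import Annotated

    from pydantic import BaseModel, Field

    class DemoJob(BaseModel):
        command: Annotated[
            list[str],
            Field(title="Command", description="Command line to use to execute the job."),
        ]

    schema = DemoJob.model_json_schema()
    print(schema["properties"]["command"]["title"])  # Command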

providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py (+15)

@@ -27,6 +27,12 @@
     from airflow.api_fastapi.common.db.common import SessionDep
     from airflow.api_fastapi.common.router import AirflowRouter
     from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
+
+    # In Airflow 3 with AIP-72 we get workload addressed by ExecuteTask
+    from airflow.executors.workloads import ExecuteTask
+
+    def parse_command(command: str) -> ExecuteTask:
+        return ExecuteTask.model_validate_json(command)
 else:
     # Mock the external dependnecies
     from typing import Callable
@@ -118,3 +124,12 @@ def decorator(func: Callable) -> Callable:
             return func
 
         return decorator
+
+    # In Airflow 3 with AIP-72 we get workload addressed by ExecuteTask
+    # But in Airflow 2.10 it is a command line array
+    ExecuteTask = list[str]  # type: ignore[no-redef,assignment,misc]
+
+    def parse_command(command: str) -> ExecuteTask:
+        from ast import literal_eval
+
+        return literal_eval(command)
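
So parse_command() keeps one name with two implementations: on Airflow 3 the stored command string is JSON for an ExecuteTask model, while on Airflow 2.10 it is the textual repr of a list[str] that ast.literal_eval can safely reverse. The 2.10 branch in isolation, with an illustrative command line:

    from ast import literal_eval

    # Airflow 2.10 stores the command as the string form of a list[str] ...
    stored = str(["airflow", "tasks", "run", "sample_dag", "sample_task", "run1"])

    # ... and literal_eval() turns it back into the original list without eval().
    command: list[str] = literal_eval(stored)
    assert command[:3] == ["airflow", "tasks", "run"]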

providers/src/airflow/providers/edge/worker_api/routes/jobs.py (+2 -2)

@@ -17,7 +17,6 @@
 
 from __future__ import annotations
 
-from ast import literal_eval
 from typing import Annotated
 
 from sqlalchemy import select, update
@@ -35,6 +34,7 @@
     Depends,
     SessionDep,
     create_openapi_http_exception_doc,
+    parse_command,
     status,
 )
 from airflow.utils import timezone
@@ -91,7 +91,7 @@ def fetch(
         run_id=job.run_id,
         map_index=job.map_index,
         try_number=job.try_number,
-        command=literal_eval(job.command),
+        command=parse_command(job.command),
         concurrency_slots=job.concurrency_slots,
     )
