diff --git a/providers/edge3/src/airflow/providers/edge3/cli/dataclasses.py b/providers/edge3/src/airflow/providers/edge3/cli/dataclasses.py index 63e12f6f81092..a45561c63991f 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/dataclasses.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/dataclasses.py @@ -17,12 +17,16 @@ from __future__ import annotations import json +import subprocess +import traceback from dataclasses import asdict, dataclass from multiprocessing import Process from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: + from multiprocessing.queues import Queue + from airflow.providers.edge3.models.edge_worker import EdgeWorkerState from airflow.providers.edge3.worker_api.datamodels import EdgeJobFetched @@ -72,17 +76,64 @@ class Job: """Holds all information for a task/job to be executed as bundle.""" edge_job: EdgeJobFetched - process: Process + # Process can be either a subprocess.Popen (for the spawn path) or a + # multiprocessing.Process (for the fork path) + process: subprocess.Popen | Process logfile: Path logsize: int = 0 """Last size of log file, point of last chunk push.""" + results_queue: Queue | None = None + """Queue for child process to push results to parent, if using fork-based execution model.""" + stderr_file_path: Path | None = None + """Path to file where stderr is being redirected, if using spawn-based execution model.""" @property def is_running(self) -> bool: """Check if the job is still running.""" + if isinstance(self.process, subprocess.Popen): + return self.process.poll() is None return self.process.is_alive() @property def is_success(self) -> bool: """Check if the job was successful.""" + if isinstance(self.process, subprocess.Popen): + return self.process.returncode == 0 return self.process.exitcode == 0 + + @property + def should_poll_logs(self) -> bool: + """Check if logs should be pushed while waiting for job completion.""" + # Fork path: keep pushing logs while the child is running and has not sent a result yet. + # Subprocess path: keep pushing logs while the child is running; status comes from Popen. + if not self.is_running: + return False + return self.results_queue is None or self.results_queue.empty() + + def drain_result(self) -> object | None: + """Read the child result if the execution model provides one.""" + if self.results_queue is None or self.results_queue.empty(): + return None + return self.results_queue.get() + + def failure_details(self, result: object | None) -> str: + """Format execution-model-specific failure details.""" + if isinstance(self.process, subprocess.Popen): + stderr_output = "" + if self.stderr_file_path: + stderr_output = self.stderr_file_path.read_bytes().decode(errors="backslashreplace").strip() + ex_txt = f"Task subprocess exited with code {self.process.returncode}" + if stderr_output: + ex_txt = f"{ex_txt}\n{stderr_output}" + return ex_txt + + return ( + "\n".join(traceback.format_exception(result)) + if isinstance(result, Exception) + else "(Unknown error, no exception details available)" + ) + + def cleanup(self) -> None: + """Remove transient files owned by this job.""" + if self.stderr_file_path: + self.stderr_file_path.unlink(missing_ok=True) diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py index 5eff1a9cac850..077f68023294c 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py @@ -19,9 +19,10 @@ import logging import os import signal +import subprocess import sys +import tempfile import time -import traceback from asyncio import Task, create_task, gather, get_running_loop, sleep from collections.abc import Awaitable, Callable from contextlib import suppress @@ -66,6 +67,7 @@ if TYPE_CHECKING: from airflow.configuration import AirflowConfigParser from airflow.executors.workloads import ExecuteTask + from airflow.providers.edge3.worker_api.datamodels import EdgeJobFetched logger = logging.getLogger(__name__) @@ -414,6 +416,7 @@ def _get_state(self) -> EdgeWorkerState: return EdgeWorkerState.IDLE def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) -> int: + """Run a task by calling the supervisor directly (executes inside a forked child process).""" _reset_parent_signal_state() from airflow.sdk.execution_time.supervisor import supervise @@ -447,17 +450,80 @@ def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) - results_queue.put(e) return 1 - def _launch_job(self, workload: ExecuteTask) -> tuple[Process, Queue[Exception]]: + def _launch_job_subprocess(self, workload: ExecuteTask) -> tuple[subprocess.Popen, Path]: + """Launch workload via a fresh Python interpreter (subprocess.Popen).""" + env = os.environ.copy() + if self._execution_api_server_url: + env["AIRFLOW__CORE__EXECUTION_API_SERVER_URL"] = self._execution_api_server_url + + # Keep stderr off a PIPE: the worker only inspects stderr after the task finishes, + # so a verbose child could otherwise fill the pipe buffer and block forever. Also keep + # it task-scoped instead of inheriting the worker's stderr/stdout; supervisor startup + # failures should be pushed to the task log, not only the worker/container log. + with tempfile.NamedTemporaryFile( + prefix="airflow-edge-task-stderr-", suffix=".log", delete=False + ) as stderr_file: + stderr_file_path = Path(stderr_file.name) + try: + process = subprocess.Popen( + [ + sys.executable, + "-m", + "airflow.sdk.execution_time.execute_workload", + "--json-string", + workload.model_dump_json(), + ], + env=env, + start_new_session=True, + stderr=stderr_file, + ) + except Exception: + stderr_file_path.unlink(missing_ok=True) + raise + logger.info( + "Launched task subprocess pid=%d for %s", + process.pid, + workload.ti.id, + ) + return process, stderr_file_path + + def _launch_job_fork(self, workload: ExecuteTask) -> tuple[Process, Queue]: + """Launch workload by forking the current process (multiprocessing.Process).""" # Improvement: Use frozen GC to prevent child process from copying unnecessary memory # See _spawn_workers_with_gc_freeze() in airflow-core/src/airflow/executors/local_executor.py - results_queue: Queue[Exception] = Queue() + results_queue: Queue = Queue() process = Process( target=self._run_job_via_supervisor, kwargs={"workload": workload, "results_queue": results_queue}, ) process.start() + logger.info("Launched task fork pid=%d for %s", process.pid, workload.ti.id) return process, results_queue + def _launch_job(self, edge_job: EdgeJobFetched, workload: ExecuteTask, logfile: Path) -> Job: + """ + Launch a task process. + + Uses ``subprocess.Popen`` (fresh interpreter) when + ``core.execute_tasks_new_python_interpreter`` is ``True`` or when + ``os.fork`` is unavailable (e.g. Windows). Falls back to + ``multiprocessing.Process`` (fork) otherwise — preserving the + original behaviour for existing deployments. + """ + use_new_interpreter = not hasattr(os, "fork") or self.conf.getboolean( + "core", + "execute_tasks_new_python_interpreter", + fallback=False, + ) + if use_new_interpreter: + # Fresh subprocess path: spawn a new Python interpreter; no shared memory with parent + # Technically safer and more robust, but with more overhead + subprocess_process, stderr_file_path = self._launch_job_subprocess(workload) + return Job(edge_job, subprocess_process, logfile, stderr_file_path=stderr_file_path) + # Fork path: clone the current process; child inherits parent memory + fork_process, results_queue = self._launch_job_fork(workload) + return Job(edge_job, fork_process, logfile, results_queue=results_queue) + async def _push_logs_in_chunks(self, job: Job): aio_logfile = anyio.Path(job.logfile) if self.push_logs and await aio_logfile.exists() and (await aio_logfile.stat()).st_size > job.logsize: @@ -581,11 +647,10 @@ async def fetch_and_run_job(self) -> None: logger.info("Received job: %s", edge_job.identifier) workload: ExecuteTask = edge_job.command - process, results_queue = self._launch_job(workload) if TYPE_CHECKING: assert workload.log_path # We need to assume this is defined in here logfile = Path(self.base_log_folder, workload.log_path) - job = Job(edge_job, process, logfile) + job = self._launch_job(edge_job, workload, logfile) self.jobs.append(job) await jobs_set_state(edge_job.key, TaskInstanceState.RUNNING) @@ -595,37 +660,37 @@ async def fetch_and_run_job(self) -> None: self.background_tasks.add(task) task.add_done_callback(self.background_tasks.discard) - while job.is_running and results_queue.empty(): + while job.should_poll_logs: await self._push_logs_in_chunks(job) for _ in range(0, self.job_poll_interval * 10): await sleep(0.1) if not job.is_running: break await self._push_logs_in_chunks(job) - supervisor_msg = ( - "(Unknown error, no exception details available)" - if results_queue.empty() - else results_queue.get() - ) - # Ensure that supervisor really ended after we grabbed results from queue - while True: - if not job.is_running: - break + # Fork path: drain the result queue BEFORE waiting for the child to fully exit. + # A large exception travels through multiprocessing's pipe-backed queue; reading it + # here lets the child's feeder thread flush and avoids deadlocking on process exit. + # Fresh-interpreter subprocesses do not share Python exception objects with the parent. + result = job.drain_result() + # Wait for the child process to fully exit (fork path: queue is already drained above). + while job.is_running: # noqa: ASYNC110 await sleep(0.1) self.jobs.remove(job) if job.is_success: logger.info("Job completed: %s", job.edge_job.identifier) await jobs_set_state(job.edge_job.key, TaskInstanceState.SUCCESS) + job.cleanup() else: - if isinstance(supervisor_msg, Exception): - supervisor_msg = "\n".join(traceback.format_exception(supervisor_msg)) - logger.error("Job failed: %s with:\n%s", job.edge_job.identifier, supervisor_msg) + ex_txt = job.failure_details(result) + job.cleanup() + logger.error("Job failed: %s with:\n%s", job.edge_job.identifier, ex_txt) + # Push it upwards to logs for better diagnostic as well await logs_push( task=job.edge_job.key, log_chunk_time=timezone.utcnow(), - log_chunk_data=f"Error executing job:\n{supervisor_msg}", + log_chunk_data=f"Error executing job:\n{ex_txt}", ) await jobs_set_state(job.edge_job.key, TaskInstanceState.FAILED) diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py index 5ac72b84867f1..7babb2a9f06a2 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_worker.py +++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py @@ -22,11 +22,15 @@ import json import logging import multiprocessing +import os import signal +import subprocess +import sys from datetime import datetime from io import StringIO -from multiprocessing import Process, Queue +from multiprocessing import Process from pathlib import Path +from typing import cast from unittest import mock from unittest.mock import call, patch @@ -52,6 +56,7 @@ WorkerRegistrationReturn, WorkerSetStateReturn, ) +from airflow.utils.state import TaskInstanceState from tests_common.test_utils.config import conf_vars from tests_common.test_utils.version_compat import AIRFLOW_V_3_2_PLUS @@ -118,6 +123,15 @@ def join(self, timeout=None): pass +class _MockPopen(subprocess.Popen): + def __init__(self, returncode: int | None = None, pid: int = 1234): + self.returncode = returncode + self.pid = pid + + def poll(self): + return self.returncode + + class TestEdgeWorker: @pytest.fixture(autouse=True) def setup_parser(self): @@ -234,6 +248,100 @@ def test_execution_api_server_url( url = test_worker._execution_api_server_url assert url == expected_url + @pytest.mark.parametrize( + ("has_fork", "use_new_interpreter", "expected_launch_method"), + [ + pytest.param(True, False, "fork", id="fork_available_config_false"), + pytest.param(True, True, "subprocess", id="fork_available_config_true"), + pytest.param(False, False, "subprocess", id="fork_unavailable_config_false"), + pytest.param(False, True, "subprocess", id="fork_unavailable_config_true"), + ], + ) + def test_launch_job_honors_execute_tasks_new_python_interpreter( + self, + has_fork, + use_new_interpreter, + expected_launch_method, + monkeypatch, + tmp_path: Path, + worker_with_job: EdgeWorker, + ): + if not has_fork: + monkeypatch.delattr(os, "fork", raising=False) + worker_with_job.conf = mock.MagicMock() + worker_with_job.conf.getboolean.return_value = use_new_interpreter + edge_job = worker_with_job.jobs[0].edge_job + workload = edge_job.command + logfile = tmp_path / "mock.log" + subprocess_process = _MockPopen(returncode=None) + stderr_file_path = tmp_path / "stderr.log" + fork_process = _MockProcess() + results_queue = mock.MagicMock() + + with ( + patch.object( + worker_with_job, "_launch_job_subprocess", return_value=(subprocess_process, stderr_file_path) + ) as mock_launch_subprocess, + patch.object( + worker_with_job, "_launch_job_fork", return_value=(fork_process, results_queue) + ) as mock_launch_fork, + ): + job = worker_with_job._launch_job(edge_job, workload, logfile) + + if has_fork: + worker_with_job.conf.getboolean.assert_called_once_with( + "core", "execute_tasks_new_python_interpreter", fallback=False + ) + else: + worker_with_job.conf.getboolean.assert_not_called() + if expected_launch_method == "subprocess": + assert job.process is subprocess_process + assert job.results_queue is None + assert job.stderr_file_path == stderr_file_path + mock_launch_subprocess.assert_called_once_with(workload) + mock_launch_fork.assert_not_called() + else: + assert job.process is fork_process + assert job.results_queue is results_queue + assert job.stderr_file_path is None + mock_launch_fork.assert_called_once_with(workload) + mock_launch_subprocess.assert_not_called() + + @patch("airflow.providers.edge3.cli.worker.subprocess.Popen") + def test_launch_job_subprocess_uses_fresh_interpreter_and_spools_stderr( + self, + mock_popen, + worker_with_job: EdgeWorker, + ): + process = _MockPopen(returncode=None, pid=4321) + mock_popen.return_value = process + worker_with_job.__dict__["_execution_api_server_url"] = "https://mock-server/execution" + workload = worker_with_job.jobs[0].edge_job.command + stderr_file_path = None + + try: + returned_process, stderr_file_path = worker_with_job._launch_job_subprocess(workload) + assert returned_process is process + + popen_args, popen_kwargs = mock_popen.call_args + assert popen_args[0] == [ + sys.executable, + "-m", + "airflow.sdk.execution_time.execute_workload", + "--json-string", + workload.model_dump_json(), + ] + assert ( + popen_kwargs["env"]["AIRFLOW__CORE__EXECUTION_API_SERVER_URL"] + == "https://mock-server/execution" + ) + assert popen_kwargs["start_new_session"] is True + assert popen_kwargs["stderr"] is not subprocess.PIPE + assert Path(popen_kwargs["stderr"].name) == stderr_file_path + finally: + if stderr_file_path: + stderr_file_path.unlink(missing_ok=True) + @patch("airflow.sdk.execution_time.supervisor.supervise") @pytest.mark.asyncio async def test_supervise_launch( @@ -265,7 +373,7 @@ async def test_supervise_launch_fail( q.put.assert_called_once() @patch("airflow.providers.edge3.cli.worker.jobs_fetch") - @patch("airflow.providers.edge3.cli.worker.EdgeWorker._launch_job", return_value=(Process(), Queue())) + @patch("airflow.providers.edge3.cli.worker.EdgeWorker._launch_job") @pytest.mark.asyncio async def test_fetch_and_run_job_no_job( self, @@ -282,7 +390,7 @@ async def test_fetch_and_run_job_no_job( mock_launch_job.assert_not_called() @patch("airflow.providers.edge3.cli.worker.jobs_fetch") - @patch("airflow.providers.edge3.cli.worker.EdgeWorker._launch_job", return_value=(Process(), Queue())) + @patch("airflow.providers.edge3.cli.worker.EdgeWorker._launch_job") @patch("airflow.providers.edge3.cli.worker.jobs_set_state") @patch("airflow.providers.edge3.cli.worker.EdgeWorker._push_logs_in_chunks") @patch("airflow.providers.edge3.cli.worker.logs_push") @@ -296,20 +404,20 @@ async def test_fetch_and_run_job_one_job( mock_jobs_set_state, mock_launch_job, mock_jobs_fetch, + tmp_path: Path, worker_with_job: EdgeWorker, ): - mock_jobs_fetch.side_effect = [ - EdgeJobFetched( - dag_id="test", - task_id="test", - run_id="test", - map_index=-1, - try_number=1, - concurrency_slots=1, - command=MOCK_COMMAND, # type: ignore[arg-type] - ), - None, - ] + edge_job = EdgeJobFetched( + dag_id="test", + task_id="test", + run_id="test", + map_index=-1, + try_number=1, + concurrency_slots=1, + command=MOCK_COMMAND, # type: ignore[arg-type] + ) + mock_jobs_fetch.side_effect = [edge_job, None] + mock_launch_job.return_value = Job(edge_job, _MockProcess(), tmp_path / "mock.log") worker_with_job.concurrency = 1 # only one job at a time assert worker_with_job.free_concurrency == 0 @@ -318,7 +426,9 @@ async def test_fetch_and_run_job_one_job( mock_jobs_fetch.assert_called_once() fetch_args = mock_jobs_fetch.call_args assert fetch_args.args[3] is None # team_name should be None - mock_launch_job.assert_called_once() + mock_launch_job.assert_called_once_with( + edge_job, edge_job.command, Path(worker_with_job.base_log_folder, "mock.log") + ) assert mock_jobs_set_state.call_count == 2 mock_push_log_chunks.assert_called_once() assert len(worker_with_job.jobs) == 1 # no new job added (was removed at the end...) @@ -346,26 +456,26 @@ async def test_fetch_and_run_job_possible_deadlock( the deadlock condition. Forking a small top-level target sidesteps that and reproduces the actual queue-feeder/pipe-buffer deadlock the fix targets. """ - mock_jobs_fetch.side_effect = [ - EdgeJobFetched( - dag_id="test", - task_id="test", - run_id="test", - map_index=-1, - try_number=1, - concurrency_slots=1, - command=MOCK_COMMAND, # type: ignore[arg-type] - ), - None, - ] + edge_job = EdgeJobFetched( + dag_id="test", + task_id="test", + run_id="test", + map_index=-1, + try_number=1, + concurrency_slots=1, + command=MOCK_COMMAND, # type: ignore[arg-type] + ) + mock_jobs_fetch.side_effect = [edge_job, None] worker_with_job.concurrency = 1 # only one job at a time assert worker_with_job.free_concurrency == 0 ctx = multiprocessing.get_context("fork") results_queue = ctx.Queue() - process = ctx.Process(target=_emit_large_exception_target, args=(results_queue,)) + process = cast("Process", ctx.Process(target=_emit_large_exception_target, args=(results_queue,))) + + launched_job = Job(edge_job, process, worker_with_job.jobs[0].logfile, results_queue=results_queue) - with patch.object(EdgeWorker, "_launch_job", return_value=(process, results_queue)): + with patch.object(EdgeWorker, "_launch_job", return_value=launched_job): process.start() try: await asyncio.wait_for(worker_with_job.fetch_and_run_job(), timeout=10.0) @@ -391,48 +501,89 @@ async def test_fetch_and_run_job_possible_deadlock( assert len(worker_with_job.jobs) <= 1 # new job removed, original fixture job still there @patch("airflow.providers.edge3.cli.worker.jobs_fetch") - @patch("airflow.providers.edge3.cli.worker.EdgeWorker._launch_job", return_value=(Process(), Queue())) + @patch("airflow.providers.edge3.cli.worker.EdgeWorker._launch_job") @patch("airflow.providers.edge3.cli.worker.jobs_set_state") @patch("airflow.providers.edge3.cli.worker.EdgeWorker._push_logs_in_chunks") @patch("airflow.providers.edge3.cli.worker.logs_push") @patch.object(Job, "is_running", property(lambda _: False)) @patch.object(Job, "is_success", property(lambda _: False)) - @patch("traceback.format_exception", return_value=[]) @pytest.mark.asyncio async def test_fetch_and_run_job_one_job_fail( self, - mock_traceback, mock_logs_push, mock_push_log_chunks, mock_jobs_set_state, mock_launch_job, mock_jobs_fetch, + tmp_path: Path, worker_with_job: EdgeWorker, ): - mock_jobs_fetch.side_effect = [ - EdgeJobFetched( - dag_id="test", - task_id="test", - run_id="test", - map_index=-1, - try_number=1, - concurrency_slots=1, - command=MOCK_COMMAND, # type: ignore[arg-type] - ), - None, - ] + edge_job = EdgeJobFetched( + dag_id="test", + task_id="test", + run_id="test", + map_index=-1, + try_number=1, + concurrency_slots=1, + command=MOCK_COMMAND, # type: ignore[arg-type] + ) + mock_jobs_fetch.side_effect = [edge_job, None] + mock_launch_job.return_value = Job(edge_job, _MockProcess(), tmp_path / "mock.log") worker_with_job.concurrency = 1 # only one job at a time assert worker_with_job.free_concurrency == 0 await worker_with_job.fetch_and_run_job() mock_jobs_fetch.assert_called_once() - mock_launch_job.assert_called_once() + mock_launch_job.assert_called_once_with( + edge_job, edge_job.command, Path(worker_with_job.base_log_folder, "mock.log") + ) assert mock_jobs_set_state.call_count == 2 mock_push_log_chunks.assert_called_once() assert len(worker_with_job.jobs) == 1 # no new job added (was removed at the end...) mock_logs_push.assert_called_once() + @patch("airflow.providers.edge3.cli.worker.jobs_fetch") + @patch("airflow.providers.edge3.cli.worker.jobs_set_state") + @patch("airflow.providers.edge3.cli.worker.EdgeWorker._push_logs_in_chunks") + @patch("airflow.providers.edge3.cli.worker.logs_push") + @pytest.mark.asyncio + async def test_fetch_and_run_job_subprocess_failure_pushes_stderr_to_logs( + self, + mock_logs_push, + mock_push_log_chunks, + mock_jobs_set_state, + mock_jobs_fetch, + tmp_path: Path, + worker_with_job: EdgeWorker, + ): + edge_job = EdgeJobFetched( + dag_id="test", + task_id="test", + run_id="test", + map_index=-1, + try_number=1, + concurrency_slots=1, + command=MOCK_COMMAND, # type: ignore[arg-type] + ) + mock_jobs_fetch.return_value = edge_job + worker_with_job.concurrency = 1 + process = _MockPopen(returncode=1, pid=5678) + stderr_file_path = tmp_path / "subprocess-stderr.log" + stderr_file_path.write_text("ModuleNotFoundError: No module named 'common'\n") + launched_job = Job(edge_job, process, tmp_path / "mock.log", stderr_file_path=stderr_file_path) + + with patch.object(worker_with_job, "_launch_job", return_value=launched_job): + await worker_with_job.fetch_and_run_job() + + mock_jobs_fetch.assert_called_once() + mock_push_log_chunks.assert_called_once() + assert mock_jobs_set_state.call_args_list[-1].args[1] == TaskInstanceState.FAILED + log_chunk_data = mock_logs_push.call_args.kwargs["log_chunk_data"] + assert "Task subprocess exited with code 1" in log_chunk_data + assert "ModuleNotFoundError: No module named 'common'" in log_chunk_data + assert not stderr_file_path.exists() + @time_machine.travel(datetime.now(), tick=False) @patch("airflow.providers.edge3.cli.worker.logs_push") @pytest.mark.asyncio