diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 6caafa23d2bd5..f96802ae3411c 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -38,6 +38,7 @@ get_sig_validation_args, get_signing_args, ) +from airflow.process_context import override_process_context if TYPE_CHECKING: import httpx @@ -236,7 +237,7 @@ def replace_any_of_with_one_of(spec): return openapi_schema -def create_task_execution_api_app() -> FastAPI: +def create_task_execution_api_app() -> CadwynWithOpenAPICustomization: """Create FastAPI app for task execution API.""" from airflow.api_fastapi.execution_api.routes import execution_api_router from airflow.api_fastapi.execution_api.versions import bundle @@ -300,6 +301,17 @@ def get_extra_schemas() -> dict[str, dict]: } +class _RequestScopedServerContextApp: + """Wrap an ASGI app so in-process requests behave like server-side API handling.""" + + def __init__(self, app: FastAPI) -> None: + self.app = app + + async def __call__(self, scope: Any, receive: Any, send: Any) -> None: + with override_process_context("server"): + await self.app(scope, receive, send) + + @attrs.define() class InProcessExecutionAPI: """ @@ -309,11 +321,11 @@ class InProcessExecutionAPI: needed so that we can use the sync httpx client """ - _app: FastAPI | None = None + _app: CadwynWithOpenAPICustomization | None = None _cm: AsyncExitStack | None = None @cached_property - def app(self): + def app(self) -> CadwynWithOpenAPICustomization: if not self._app: from airflow.api_fastapi.common.dagbag import create_dag_bag from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, TIToken @@ -343,6 +355,10 @@ async def always_allow(request: Request): return self._app + @cached_property + def request_scoped_app(self) -> _RequestScopedServerContextApp: + return _RequestScopedServerContextApp(self.app) + @cached_property def transport(self) -> httpx.WSGITransport: import asyncio @@ -350,7 +366,7 @@ def transport(self) -> httpx.WSGITransport: import httpx from a2wsgi import ASGIMiddleware - middleware = ASGIMiddleware(self.app) + middleware = ASGIMiddleware(self.request_scoped_app) # https://github.com/abersheeran/a2wsgi/discussions/64 async def start_lifespan(cm: AsyncExitStack, app: FastAPI): @@ -365,4 +381,4 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI): def atransport(self) -> httpx.ASGITransport: import httpx - return httpx.ASGITransport(app=self.app) + return httpx.ASGITransport(app=self.request_scoped_app) diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 3032636f1b17c..a4eafbc6d2e11 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -20,7 +20,6 @@ import json import logging import re -import sys import warnings from contextlib import suppress from json import JSONDecodeError @@ -35,6 +34,7 @@ from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet +from airflow.process_context import should_use_task_sdk_api_path from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session @@ -507,7 +507,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None) # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): from airflow.sdk import Connection as TaskSDKConnection from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType @@ -593,7 +593,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s @classmethod def from_json(cls, value, conn_id=None) -> Connection: - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): from airflow.sdk import Connection as TaskSDKConnection warnings.warn( diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 9bc2ba9be9649..925a0bb39c1e9 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -20,7 +20,6 @@ import contextlib import json import logging -import sys import warnings from typing import TYPE_CHECKING, Any @@ -32,6 +31,7 @@ from airflow.configuration import conf, ensure_secrets_loaded from airflow.models.base import ID_LEN, Base from airflow.models.crypto import get_fernet +from airflow.process_context import should_use_task_sdk_api_path from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -181,7 +181,7 @@ def get( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.get from `airflow.models` is deprecated." "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -241,7 +241,7 @@ def set( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.set from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -329,7 +329,7 @@ def update( # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.update from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.", @@ -395,7 +395,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if should_use_task_sdk_api_path(): warnings.warn( "Using Variable.delete from `airflow.models` is deprecated." "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead", diff --git a/airflow-core/src/airflow/process_context.py b/airflow-core/src/airflow/process_context.py new file mode 100644 index 0000000000000..1655d6e81fbb8 --- /dev/null +++ b/airflow-core/src/airflow/process_context.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import sys +from collections.abc import Generator +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Literal + +__all__ = [ + "get_process_context", + "override_process_context", + "should_use_task_sdk_api_path", +] + +_PROCESS_CONTEXT_OVERRIDE: ContextVar[str | None] = ContextVar( + "_AIRFLOW_PROCESS_CONTEXT_OVERRIDE", + default=None, +) + + +def get_process_context() -> str | None: + """Return the current process context, preferring request-scoped overrides.""" + return _PROCESS_CONTEXT_OVERRIDE.get() or os.environ.get("_AIRFLOW_PROCESS_CONTEXT") + + +@contextmanager +def override_process_context(context: Literal["server", "client"]) -> Generator[None, None, None]: + """Temporarily override the current process context for the active execution flow.""" + token = _PROCESS_CONTEXT_OVERRIDE.set(context) + try: + yield + finally: + _PROCESS_CONTEXT_OVERRIDE.reset(token) + + +def should_use_task_sdk_api_path() -> bool: + """Return True when execution-context helpers should route through Task SDK APIs.""" + if get_process_context() == "server": + return False + + task_runner_module = sys.modules.get("airflow.sdk.execution_time.task_runner") + return bool(getattr(task_runner_module, "SUPERVISOR_COMMS", None)) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py index b0cb1d85c2e33..858296301566b 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py @@ -16,10 +16,16 @@ # under the License. from __future__ import annotations +import sys +from unittest import mock + +import httpx import pytest +from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance from airflow.api_fastapi.execution_api.versions import bundle +from airflow.process_context import get_process_context, override_process_context pytestmark = pytest.mark.db_test @@ -111,3 +117,32 @@ def test_multiple_requests_with_different_correlation_ids(self, client): # Verify they didn't interfere with each other assert correlation_id_1 != correlation_id_2 + + +def test_in_process_execution_api_uses_request_scoped_server_context(monkeypatch): + api = InProcessExecutionAPI() + fake_task_runner = mock.Mock() + fake_task_runner.SUPERVISOR_COMMS = object() + + monkeypatch.setenv("AIRFLOW_VAR_KEY1", "VALUE") + + with ( + override_process_context("client"), + mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Variable.get", + side_effect=AssertionError( + "In-process Execution API requests should not route through Task SDK Variable.get" + ), + ), + httpx.Client(transport=api.transport, base_url="http://in-process.invalid") as client, + ): + assert get_process_context() == "client" + + response = client.get("/variables/key1") + + assert response.status_code == 200 + assert response.json() == {"key": "key1", "value": "VALUE"} + assert get_process_context() == "client" + + assert get_process_context() is None diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py index a2e3cb51fab32..19e87bafd14c1 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connections.py @@ -17,6 +17,7 @@ from __future__ import annotations +import sys from unittest import mock import pytest @@ -105,6 +106,40 @@ def test_connection_get_from_env_var(self, client, session): "extra": '{"headers": "header"}', } + @mock.patch.dict( + "os.environ", + { + "AIRFLOW_CONN_TEST_CONN_SERVER": '{"uri": "http://root:admin@localhost:8080/https?headers=header"}', + "_AIRFLOW_PROCESS_CONTEXT": "server", + }, + ) + def test_connection_get_uses_server_path_when_supervisor_comms_exists(self, client): + fake_task_runner = mock.Mock() + fake_task_runner.SUPERVISOR_COMMS = object() + + with ( + mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Connection.get", + side_effect=AssertionError( + "Execution API should not route through Task SDK Connection.get in server context" + ), + ), + ): + response = client.get("/execution/connections/test_conn_server") + + assert response.status_code == 200 + assert response.json() == { + "conn_id": "test_conn_server", + "conn_type": "http", + "host": "localhost", + "login": "root", + "password": "admin", + "schema": "https", + "port": 8080, + "extra": '{"headers": "header"}', + } + def test_connection_get_not_found(self, client): response = client.get("/execution/connections/non_existent_test_conn") diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py index fe7611636358d..67290edde423e 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py @@ -96,6 +96,28 @@ def test_variable_get_from_env_var(self, client, session): assert response.status_code == 200 assert response.json() == {"key": "key1", "value": "VALUE"} + @mock.patch.dict( + "os.environ", + {"AIRFLOW_VAR_KEY1": "VALUE", "_AIRFLOW_PROCESS_CONTEXT": "server"}, + ) + def test_variable_get_uses_server_path_when_supervisor_comms_exists(self, client): + fake_task_runner = mock.Mock() + fake_task_runner.SUPERVISOR_COMMS = object() + + with ( + mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Variable.get", + side_effect=AssertionError( + "Execution API should not route through Task SDK Variable.get in server context" + ), + ), + ): + response = client.get("/execution/variables/key1") + + assert response.status_code == 200 + assert response.json() == {"key": "key1", "value": "VALUE"} + @pytest.mark.parametrize( "key", [ @@ -158,6 +180,31 @@ def test_should_create_variable(self, client, key, payload, session): if "description" in payload: assert var_from_db.description == payload["description"] + @mock.patch.dict( + "os.environ", + {"_AIRFLOW_PROCESS_CONTEXT": "server"}, + ) + def test_variable_put_uses_server_path_when_supervisor_comms_exists(self, client, session): + fake_task_runner = mock.Mock() + fake_task_runner.SUPERVISOR_COMMS = object() + + with ( + mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Variable.set", + side_effect=AssertionError( + "Execution API should not route through Task SDK Variable.set in server context" + ), + ), + ): + response = client.put("/execution/variables/var_server_only", json={"value": "server_value"}) + + assert response.status_code == 201 + assert response.json()["message"] == "Variable successfully set" + var_from_db = session.scalars(select(Variable).where(Variable.key == "var_server_only")).first() + assert var_from_db is not None + assert var_from_db.val == "server_value" + @pytest.mark.parametrize( ("key", "payload", "error_type"), [ @@ -278,3 +325,29 @@ def test_should_not_delete_variable(self, client, session): vars = session.scalars(select(Variable)).all() assert len(vars) == 1 + + @mock.patch.dict( + "os.environ", + {"_AIRFLOW_PROCESS_CONTEXT": "server"}, + ) + def test_variable_delete_uses_server_path_when_supervisor_comms_exists(self, client, session): + Variable.set(key="var_server_delete", value="to_delete", session=session) + session.commit() + + fake_task_runner = mock.Mock() + fake_task_runner.SUPERVISOR_COMMS = object() + + with ( + mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Variable.delete", + side_effect=AssertionError( + "Execution API should not route through Task SDK Variable.delete in server context" + ), + ), + ): + response = client.delete("/execution/variables/var_server_delete") + + assert response.status_code == 204 + session.expire_all() + assert session.scalar(select(Variable).where(Variable.key == "var_server_delete")) is None diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py index 94cabe5e4daf4..1fd0d755cae7d 100644 --- a/airflow-core/tests/unit/models/test_connection.py +++ b/airflow-core/tests/unit/models/test_connection.py @@ -49,6 +49,12 @@ def clear_fernet_cache(self): yield get_fernet.cache_clear() + @pytest.fixture(autouse=True) + def clear_process_context(self, monkeypatch): + """Isolate tests from process-wide execution context left behind by other imports.""" + monkeypatch.delenv("_AIRFLOW_PROCESS_CONTEXT", raising=False) + monkeypatch.delitem(sys.modules, "airflow.sdk.execution_time.task_runner", raising=False) + @pytest.mark.parametrize( ( "uri", @@ -455,6 +461,28 @@ def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_conn with pytest.raises(AirflowNotFoundException): Connection.get_connection_from_secrets("test_conn") + @mock.patch.dict("os.environ", {"_AIRFLOW_PROCESS_CONTEXT": "server"}) + def test_connection_from_json_uses_core_path_when_server_context(self): + """Server context should prefer core Connection.from_json even if comms exist.""" + fake_task_runner = mock.MagicMock() + fake_task_runner.SUPERVISOR_COMMS = True + + with ( + mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": fake_task_runner}), + mock.patch( + "airflow.sdk.Connection.from_json", + side_effect=AssertionError( + "Connection.from_json should not route through Task SDK in server context" + ), + ), + ): + result = Connection.from_json('{"conn_type": "http", "host": "localhost"}', conn_id="test_conn") + + assert isinstance(result, Connection) + assert result.conn_id == "test_conn" + assert result.conn_type == "http" + assert result.host == "localhost" + @mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": None}) @mock.patch("airflow.sdk.Connection") @mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection") diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index f93abe1a72240..07637edf71c72 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -250,12 +250,16 @@ def mask_secret(secret: JsonValue, name: str | None = None) -> None: they're masked in both the task subprocess AND supervisor's log output. Works safely in both sync and async contexts. """ + import os from contextlib import suppress from airflow.sdk._shared.secrets_masker import _secrets_masker _secrets_masker().add_mask(secret, name) + if os.environ.get("PYTHON_OPERATORS_VIRTUAL_ENV_MODE"): + return + with suppress(Exception): # Try to tell supervisor (only if in task execution context) from airflow.sdk.execution_time import task_runner diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py index 3717f834735ff..38f7af350b47e 100644 --- a/task-sdk/tests/task_sdk/definitions/test_variables.py +++ b/task-sdk/tests/task_sdk/definitions/test_variables.py @@ -18,6 +18,8 @@ from __future__ import annotations import json +import queue +import threading from unittest import mock from unittest.mock import patch @@ -160,6 +162,49 @@ def test_get_variable_env_var(self, mock_env_get, mock_supervisor_comms): Variable.get(key="fake_var_key") mock_env_get.assert_called_once_with(key="fake_var_key") + def test_get_variable_env_var_in_virtualenv_does_not_wait_for_supervisor_comms(self, monkeypatch): + """Regression test for Variable.get() hanging in PythonVirtualenvOperator child processes.""" + from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.cache import SecretCache + + events = queue.Queue() + release_supervisor_comms = threading.Event() + + class BlockingSupervisorComms: + def send(self, *args, **kwargs): + events.put("supervisor_comms") + release_supervisor_comms.wait(timeout=5) + + SecretCache.reset() + monkeypatch.setenv("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", "1") + monkeypatch.setenv("AIRFLOW_VAR_DEMO_MESSAGE", "hello from env") + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", BlockingSupervisorComms(), raising=False) + + result = {} + error = {} + + def get_variable(): + try: + result["value"] = Variable.get(key="DEMO_MESSAGE") + except BaseException as exc: + error["exception"] = exc + finally: + events.put("done") + + thread = threading.Thread(target=get_variable, daemon=True) + thread.start() + first_event = events.get(timeout=5) + + release_supervisor_comms.set() + thread.join(timeout=5) + + assert first_event == "done", ( + "Variable.get() should not wait for supervisor comms when an env var backend returns the value " + "inside a PythonVirtualenvOperator child process." + ) + assert error == {} + assert result == {"value": "hello from env"} + @conf_vars( { ("workers", "secrets_backend"): "airflow.secrets.local_filesystem.LocalFilesystemBackend",