Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
get_sig_validation_args,
get_signing_args,
)
from airflow.process_context import override_process_context

if TYPE_CHECKING:
import httpx
Expand Down Expand Up @@ -236,7 +237,7 @@ def replace_any_of_with_one_of(spec):
return openapi_schema


def create_task_execution_api_app() -> FastAPI:
def create_task_execution_api_app() -> CadwynWithOpenAPICustomization:
"""Create FastAPI app for task execution API."""
from airflow.api_fastapi.execution_api.routes import execution_api_router
from airflow.api_fastapi.execution_api.versions import bundle
Expand Down Expand Up @@ -300,6 +301,17 @@ def get_extra_schemas() -> dict[str, dict]:
}


class _RequestScopedServerContextApp:
    """ASGI wrapper that runs every in-process request under the "server" process context."""

    def __init__(self, app: FastAPI) -> None:
        # Hold the wrapped ASGI application; it is invoked per request in __call__.
        self.app = app

    async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
        wrapped_app = self.app
        # Force the "server" context for the duration of this request so that
        # in-process handling follows the same code paths as a real API server.
        with override_process_context("server"):
            await wrapped_app(scope, receive, send)


@attrs.define()
class InProcessExecutionAPI:
"""
Expand All @@ -309,11 +321,11 @@ class InProcessExecutionAPI:
needed so that we can use the sync httpx client
"""

_app: FastAPI | None = None
_app: CadwynWithOpenAPICustomization | None = None
_cm: AsyncExitStack | None = None

@cached_property
def app(self):
def app(self) -> CadwynWithOpenAPICustomization:
if not self._app:
from airflow.api_fastapi.common.dagbag import create_dag_bag
from airflow.api_fastapi.execution_api.datamodels.token import TIClaims, TIToken
Expand Down Expand Up @@ -343,14 +355,18 @@ async def always_allow(request: Request):

return self._app

@cached_property
def request_scoped_app(self) -> _RequestScopedServerContextApp:
return _RequestScopedServerContextApp(self.app)

@cached_property
def transport(self) -> httpx.WSGITransport:
import asyncio

import httpx
from a2wsgi import ASGIMiddleware

middleware = ASGIMiddleware(self.app)
middleware = ASGIMiddleware(self.request_scoped_app)

# https://github.com/abersheeran/a2wsgi/discussions/64
async def start_lifespan(cm: AsyncExitStack, app: FastAPI):
Expand All @@ -365,4 +381,4 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI):
def atransport(self) -> httpx.ASGITransport:
import httpx

return httpx.ASGITransport(app=self.app)
return httpx.ASGITransport(app=self.request_scoped_app)
6 changes: 3 additions & 3 deletions airflow-core/src/airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import json
import logging
import re
import sys
import warnings
from contextlib import suppress
from json import JSONDecodeError
Expand All @@ -35,6 +34,7 @@
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.process_context import should_use_task_sdk_api_path
from airflow.utils.helpers import prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -507,7 +507,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None)

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
from airflow.sdk import Connection as TaskSDKConnection
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType

Expand Down Expand Up @@ -593,7 +593,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s

@classmethod
def from_json(cls, value, conn_id=None) -> Connection:
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
from airflow.sdk import Connection as TaskSDKConnection

warnings.warn(
Expand Down
10 changes: 5 additions & 5 deletions airflow-core/src/airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import contextlib
import json
import logging
import sys
import warnings
from typing import TYPE_CHECKING, Any

Expand All @@ -32,6 +31,7 @@
from airflow.configuration import conf, ensure_secrets_loaded
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.process_context import should_use_task_sdk_api_path
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
Expand Down Expand Up @@ -181,7 +181,7 @@ def get(

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
warnings.warn(
"Using Variable.get from `airflow.models` is deprecated."
"Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down Expand Up @@ -241,7 +241,7 @@ def set(

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
warnings.warn(
"Using Variable.set from `airflow.models` is deprecated."
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down Expand Up @@ -329,7 +329,7 @@ def update(

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
warnings.warn(
"Using Variable.update from `airflow.models` is deprecated."
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.",
Expand Down Expand Up @@ -395,7 +395,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if should_use_task_sdk_api_path():
warnings.warn(
"Using Variable.delete from `airflow.models` is deprecated."
"Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down
59 changes: 59 additions & 0 deletions airflow-core/src/airflow/process_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
import sys
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Literal

# Public API of this module.
__all__ = [
    "get_process_context",
    "override_process_context",
    "should_use_task_sdk_api_path",
]

# Flow-scoped override of the process context. When set (via
# override_process_context) it takes precedence over the
# _AIRFLOW_PROCESS_CONTEXT environment variable read by get_process_context.
_PROCESS_CONTEXT_OVERRIDE: ContextVar[str | None] = ContextVar(
    "_AIRFLOW_PROCESS_CONTEXT_OVERRIDE",
    default=None,
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem like a good place for these utilities. Shouldn't models/* be something related to "Data Model" (or a Table in Metadata DB more specific).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem like a good place for these utilities. Shouldn't models/* be something related to "Data Model" (or a Table in Metadata DB more specific).

Good point on the directory structure. I'm also concerned that models/ isn't the right fit since these aren't DB models. Where would you suggest moving them? I previously thought about utils/, but the CI checks didn't allow that, and putting them in _shared might be overkill.


def get_process_context() -> str | None:
    """Return the current process context, preferring request-scoped overrides.

    A context set via :func:`override_process_context` wins; otherwise fall
    back to the ``_AIRFLOW_PROCESS_CONTEXT`` environment variable (or None).
    """
    overridden = _PROCESS_CONTEXT_OVERRIDE.get()
    if overridden:
        return overridden
    return os.environ.get("_AIRFLOW_PROCESS_CONTEXT")


@contextmanager
def override_process_context(context: Literal["server", "client"]) -> Generator[None, None, None]:
    """Temporarily override the current process context for the active execution flow.

    The override is restored on exit even if the body raises.
    """
    reset_token = _PROCESS_CONTEXT_OVERRIDE.set(context)
    try:
        yield
    finally:
        # Always undo the override so nested/concurrent flows are unaffected.
        _PROCESS_CONTEXT_OVERRIDE.reset(reset_token)


def should_use_task_sdk_api_path() -> bool:
    """Return True when execution-context helpers should route through Task SDK APIs."""
    # Explicit server context always means: go straight to the metadata DB.
    if get_process_context() == "server":
        return False

    # Otherwise route via the Task SDK only when the task-runner module is
    # loaded and its SUPERVISOR_COMMS channel is actually set (truthy).
    task_runner = sys.modules.get("airflow.sdk.execution_time.task_runner")
    supervisor_comms = getattr(task_runner, "SUPERVISOR_COMMS", None)
    return bool(supervisor_comms)
35 changes: 35 additions & 0 deletions airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@
# under the License.
from __future__ import annotations

import sys
from unittest import mock

import httpx
import pytest

from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
from airflow.api_fastapi.execution_api.datamodels.taskinstance import TaskInstance
from airflow.api_fastapi.execution_api.versions import bundle
from airflow.process_context import get_process_context, override_process_context

pytestmark = pytest.mark.db_test

Expand Down Expand Up @@ -111,3 +117,32 @@ def test_multiple_requests_with_different_correlation_ids(self, client):

# Verify they didn't interfere with each other
assert correlation_id_1 != correlation_id_2


def test_in_process_execution_api_uses_request_scoped_server_context(monkeypatch):
    """In-process Execution API requests must run under a request-scoped "server" context.

    Even when the surrounding flow is in "client" context and SUPERVISOR_COMMS
    is present, a request through the in-process transport must not route
    through Task SDK Variable.get — and the outer context must be untouched.
    """
    api = InProcessExecutionAPI()

    # Simulate a loaded task-runner module with an active supervisor channel.
    fake_task_runner = mock.Mock()
    fake_task_runner.SUPERVISOR_COMMS = object()

    monkeypatch.setenv("AIRFLOW_VAR_KEY1", "VALUE")

    # Guard: any call to the SDK path fails the test immediately.
    sdk_variable_guard = mock.patch(
        "airflow.sdk.Variable.get",
        side_effect=AssertionError(
            "In-process Execution API requests should not route through Task SDK Variable.get"
        ),
    )

    with override_process_context("client"):
        with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": fake_task_runner}):
            with sdk_variable_guard:
                with httpx.Client(transport=api.transport, base_url="http://in-process.invalid") as client:
                    assert get_process_context() == "client"

                    response = client.get("/variables/key1")

                    assert response.status_code == 200
                    assert response.json() == {"key": "key1", "value": "VALUE"}
                    # The "server" override is request-scoped only.
                    assert get_process_context() == "client"

    assert get_process_context() is None
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import sys
from unittest import mock

import pytest
Expand Down Expand Up @@ -105,6 +106,40 @@ def test_connection_get_from_env_var(self, client, session):
"extra": '{"headers": "header"}',
}

@mock.patch.dict(
    "os.environ",
    {
        "AIRFLOW_CONN_TEST_CONN_SERVER": '{"uri": "http://root:admin@localhost:8080/https?headers=header"}',
        "_AIRFLOW_PROCESS_CONTEXT": "server",
    },
)
def test_connection_get_uses_server_path_when_supervisor_comms_exists(self, client):
    """Server process context must win even when SUPERVISOR_COMMS is present."""
    # Simulate a loaded task-runner module with an active supervisor channel.
    fake_task_runner = mock.Mock()
    fake_task_runner.SUPERVISOR_COMMS = object()

    # Guard: routing through the Task SDK Connection.get fails the test.
    sdk_connection_guard = mock.patch(
        "airflow.sdk.Connection.get",
        side_effect=AssertionError(
            "Execution API should not route through Task SDK Connection.get in server context"
        ),
    )

    with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": fake_task_runner}):
        with sdk_connection_guard:
            response = client.get("/execution/connections/test_conn_server")

    expected_connection = {
        "conn_id": "test_conn_server",
        "conn_type": "http",
        "host": "localhost",
        "login": "root",
        "password": "admin",
        "schema": "https",
        "port": 8080,
        "extra": '{"headers": "header"}',
    }
    assert response.status_code == 200
    assert response.json() == expected_connection

def test_connection_get_not_found(self, client):
response = client.get("/execution/connections/non_existent_test_conn")

Expand Down
Loading
Loading