From e5ace314f3253a7165c5d910130ff0b8bad41b5b Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 26 Jun 2026 23:55:41 +0100 Subject: [PATCH 1/2] Update OpenAI provider to OpenAI Python SDK 2.x Raise the openai[datalib] floor to >=2.37.0 so the provider builds against the current 2.x SDK line, and refresh stale default models (gpt-3.5-turbo -> gpt-4o-mini, text-embedding-ada-002 -> text-embedding-3-small). --- providers/openai/README.rst | 2 +- providers/openai/docs/index.rst | 2 +- providers/openai/pyproject.toml | 2 +- .../openai/src/airflow/providers/openai/hooks/openai.py | 6 +++--- .../openai/src/airflow/providers/openai/operators/openai.py | 2 +- providers/openai/tests/system/openai/example_openai.py | 6 +++--- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/providers/openai/README.rst b/providers/openai/README.rst index a3c4ba53ac55d..5d78147fb2dce 100644 --- a/providers/openai/README.rst +++ b/providers/openai/README.rst @@ -57,7 +57,7 @@ PIP package Version required ========================================== ================== ``apache-airflow`` ``>=2.11.0`` ``apache-airflow-providers-common-compat`` ``>=1.12.0`` -``openai[datalib]`` ``>=1.66.0`` +``openai[datalib]`` ``>=2.37.0`` ========================================== ================== Cross provider package dependencies diff --git a/providers/openai/docs/index.rst b/providers/openai/docs/index.rst index 70ba43a0a9bf3..7883d87efdd8b 100644 --- a/providers/openai/docs/index.rst +++ b/providers/openai/docs/index.rst @@ -98,7 +98,7 @@ PIP package Version required ========================================== ================== ``apache-airflow`` ``>=2.11.0`` ``apache-airflow-providers-common-compat`` ``>=1.12.0`` -``openai[datalib]`` ``>=1.66.0`` +``openai[datalib]`` ``>=2.37.0`` ========================================== ================== Cross provider package dependencies diff --git a/providers/openai/pyproject.toml b/providers/openai/pyproject.toml index 489e9c69c4973..5e2f97bd29e83 
100644 --- a/providers/openai/pyproject.toml +++ b/providers/openai/pyproject.toml @@ -61,7 +61,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.11.0", "apache-airflow-providers-common-compat>=1.12.0", - "openai[datalib]>=1.66.0", + "openai[datalib]>=2.37.0", ] [dependency-groups] diff --git a/providers/openai/src/airflow/providers/openai/hooks/openai.py b/providers/openai/src/airflow/providers/openai/hooks/openai.py index c245bd10ca06a..dd98b23984865 100644 --- a/providers/openai/src/airflow/providers/openai/hooks/openai.py +++ b/providers/openai/src/airflow/providers/openai/hooks/openai.py @@ -129,7 +129,7 @@ def create_chat_completion( | ChatCompletionToolMessageParam | ChatCompletionFunctionMessageParam ], - model: str = "gpt-3.5-turbo", + model: str = "gpt-4o-mini", **kwargs: Any, ) -> list[ChatCompletionMessage]: """ @@ -141,7 +141,7 @@ def create_chat_completion( response = self.conn.chat.completions.create(model=model, messages=messages, **kwargs) return response.choices - def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> Assistant: + def create_assistant(self, model: str = "gpt-4o-mini", **kwargs: Any) -> Assistant: """ Create an OpenAI assistant using the given model. 
@@ -297,7 +297,7 @@ def modify_run(self, thread_id: str, run_id: str, **kwargs: Any) -> Run: def create_embeddings( self, text: str | list[str] | list[int] | list[list[int]], - model: str = "text-embedding-ada-002", + model: str = "text-embedding-3-small", **kwargs: Any, ) -> list[float]: """ diff --git a/providers/openai/src/airflow/providers/openai/operators/openai.py b/providers/openai/src/airflow/providers/openai/operators/openai.py index 574d7fb85d857..192986ab456f8 100644 --- a/providers/openai/src/airflow/providers/openai/operators/openai.py +++ b/providers/openai/src/airflow/providers/openai/operators/openai.py @@ -54,7 +54,7 @@ def __init__( self, conn_id: str, input_text: str | list[str] | list[int] | list[list[int]], - model: str = "text-embedding-ada-002", + model: str = "text-embedding-3-small", embedding_kwargs: dict | None = None, **kwargs: Any, ): diff --git a/providers/openai/tests/system/openai/example_openai.py b/providers/openai/tests/system/openai/example_openai.py index f8df2bd60248f..6cbbe2b47f32c 100644 --- a/providers/openai/tests/system/openai/example_openai.py +++ b/providers/openai/tests/system/openai/example_openai.py @@ -78,7 +78,7 @@ def task_to_store_input_text_in_xcom(): task_id="embedding_using_xcom_data", conn_id="openai_default", input_text=task_to_store_input_text_in_xcom(), - model="text-embedding-ada-002", + model="text-embedding-3-small", ) OpenAIEmbeddingOperator( @@ -90,13 +90,13 @@ def task_to_store_input_text_in_xcom(): input_kwarg1="input_kwarg1_value", input_kwarg2="input_kwarg2_value", ), - model="text-embedding-ada-002", + model="text-embedding-3-small", ) OpenAIEmbeddingOperator( task_id="embedding_using_text", conn_id="openai_default", input_text=texts, - model="text-embedding-ada-002", + model="text-embedding-3-small", ) # [END howto_operator_openai_embedding] From 7fba8e0b0536c7ca8a4cca1539d45038664c0ab3 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 27 Jun 2026 00:29:48 +0100 Subject: [PATCH 2/2] Add 
Workload Identity authentication to OpenAI provider Authenticate OpenAI connections with short-lived identity tokens instead of a long-lived API key, selected with a new auth_type key in the connection extra (defaults to api_key, so existing connections are unchanged). Supports Kubernetes, Azure managed identity, GCP, and a custom token provider. --- providers/openai/docs/connections.rst | 36 ++++ .../airflow/providers/openai/hooks/openai.py | 97 ++++++++++- .../tests/unit/openai/hooks/test_openai.py | 154 ++++++++++++++++++ 3 files changed, 281 insertions(+), 6 deletions(-) diff --git a/providers/openai/docs/connections.rst b/providers/openai/docs/connections.rst index 8ef7ee456b1c9..a0a6e1a83b4cf 100644 --- a/providers/openai/docs/connections.rst +++ b/providers/openai/docs/connections.rst @@ -52,3 +52,39 @@ Extra (optional) "api_key": "YOUR_API_KEY" } } + +Authentication +-------------- + +The authentication mechanism is selected with the ``auth_type`` key in the ``extra`` field. It +defaults to ``"api_key"``, which uses the ``api_key`` entry of ``openai_client_kwargs`` when set and +otherwise the connection password (backward compatible with existing connections). + +Set ``auth_type`` to ``"workload_identity"`` to authenticate with short-lived identity tokens +instead of a long-lived API key. This exchanges a token minted by your environment's identity +provider, so no API key is stored in the connection. ``identity_provider_id`` and +``service_account_id`` are always required, and the token source is chosen with +``workload_identity_provider``: + +* ``"kubernetes"`` -- a Kubernetes service account token read from ``token_file_path`` (defaults to + the in-cluster path ``/var/run/secrets/kubernetes.io/serviceaccount/token``). +* ``"azure"`` -- an Azure managed identity. Optional keys: ``resource``, ``client_id``, + ``object_id``, ``msi_res_id``, ``api_version``. +* ``"gcp"`` -- a Google Cloud ID token for the given ``audience``. 
+* ``"custom"`` -- import ``token_provider`` (a dotted path to a ``Callable[[], str]``) and use it as + the token source. The callable is imported and invoked in the process that runs the hook, so point + it only at trusted code. Optional ``token_type`` (``"jwt"`` or ``"id"``, defaults to ``"jwt"``); + ``token_type`` applies only to the ``custom`` source. + +The optional ``refresh_buffer_seconds`` controls how long before expiry the token is refreshed. + +For example, to authenticate from a Kubernetes pod: + +.. code-block:: json + + { + "auth_type": "workload_identity", + "workload_identity_provider": "kubernetes", + "identity_provider_id": "idp-123", + "service_account_id": "sa-456" + } diff --git a/providers/openai/src/airflow/providers/openai/hooks/openai.py b/providers/openai/src/airflow/providers/openai/hooks/openai.py index dd98b23984865..ca80b4d3761be 100644 --- a/providers/openai/src/airflow/providers/openai/hooks/openai.py +++ b/providers/openai/src/airflow/providers/openai/hooks/openai.py @@ -23,8 +23,14 @@ from typing import TYPE_CHECKING, Any, BinaryIO, Literal from openai import OpenAI +from openai.auth import ( + azure_managed_identity_token_provider, + gcp_id_token_provider, + k8s_service_account_token_provider, +) if TYPE_CHECKING: + from openai.auth import SubjectTokenProvider, WorkloadIdentity from openai.types import ( FileDeleted, FileObject, @@ -43,6 +49,7 @@ ChatCompletionUserMessageParam, ) from openai.types.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted +from airflow.providers.common.compat.module_loading import import_string from airflow.providers.common.compat.sdk import BaseHook from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout @@ -108,18 +115,96 @@ def conn(self) -> OpenAI: return self.get_conn() def get_conn(self) -> OpenAI: - """Return an OpenAI connection object.""" + """ + Return an OpenAI connection object. 
+ + The authentication mechanism is selected with the ``auth_type`` key in the connection + ``extra`` (default ``"api_key"``): + + * ``"api_key"`` -- use the API key from the connection password (or ``openai_client_kwargs``). + * ``"workload_identity"`` -- exchange a short-lived identity token for API access using the + OpenAI client's workload identity support. See :meth:`_build_workload_identity`. + """ conn = self.get_connection(self.conn_id) extras = conn.extra_dejson openai_client_kwargs = extras.get("openai_client_kwargs", {}) - api_key = openai_client_kwargs.pop("api_key", None) or conn.password base_url = openai_client_kwargs.pop("base_url", None) or conn.host or None - return OpenAI( - api_key=api_key, - base_url=base_url, - **openai_client_kwargs, + # Pop api_key for every path so it is never forwarded alongside ``workload_identity`` + # (the OpenAI client rejects both being set at once). + api_key = openai_client_kwargs.pop("api_key", None) + auth_type = extras.get("auth_type", "api_key") + + if auth_type == "api_key": + return OpenAI(api_key=api_key or conn.password, base_url=base_url, **openai_client_kwargs) + + if auth_type == "workload_identity": + return OpenAI( + workload_identity=self._build_workload_identity(extras), + base_url=base_url, + **openai_client_kwargs, + ) + + raise ValueError( + f"Unsupported auth_type {auth_type!r} for OpenAI connection {self.conn_id!r}; " + "expected 'api_key' or 'workload_identity'." ) + def _build_workload_identity(self, extras: dict[str, Any]) -> WorkloadIdentity: + """ + Build the OpenAI ``workload_identity`` config from the connection ``extra``. + + Returns the ``workload_identity`` mapping (``identity_provider_id``, ``service_account_id``, + ``provider``, and optional ``refresh_buffer_seconds``) for the token source selected by + ``workload_identity_provider``. Raises ``ValueError`` when a required key is missing or the + source is unknown. See :ref:`howto/connection:openai` for the full key reference. 
+ """ + for key in ("identity_provider_id", "service_account_id"): + if key not in extras: + raise ValueError( + f"Missing required {key!r} for workload_identity auth on OpenAI connection " + f"{self.conn_id!r}." + ) + + provider_name = extras.get("workload_identity_provider") + provider: SubjectTokenProvider + if provider_name == "kubernetes": + kwargs = {key: extras[key] for key in ("token_file_path",) if key in extras} + provider = k8s_service_account_token_provider(**kwargs) + elif provider_name == "azure": + kwargs = { + key: extras[key] + for key in ("resource", "object_id", "client_id", "msi_res_id", "api_version") + if key in extras + } + provider = azure_managed_identity_token_provider(**kwargs) + elif provider_name == "gcp": + kwargs = {key: extras[key] for key in ("audience",) if key in extras} + provider = gcp_id_token_provider(**kwargs) + elif provider_name == "custom": + if "token_provider" not in extras: + raise ValueError( + f"Missing required 'token_provider' for custom workload_identity auth on OpenAI " + f"connection {self.conn_id!r}." + ) + provider = { + "token_type": extras.get("token_type", "jwt"), + "get_token": import_string(extras["token_provider"]), + } + else: + raise ValueError( + f"Unsupported workload_identity_provider {provider_name!r} for OpenAI connection " + f"{self.conn_id!r}; expected one of 'kubernetes', 'azure', 'gcp', 'custom'." 
+ ) + + workload_identity: WorkloadIdentity = { + "identity_provider_id": extras["identity_provider_id"], + "service_account_id": extras["service_account_id"], + "provider": provider, + } + if "refresh_buffer_seconds" in extras: + workload_identity["refresh_buffer_seconds"] = extras["refresh_buffer_seconds"] + return workload_identity + def create_chat_completion( self, messages: list[ diff --git a/providers/openai/tests/unit/openai/hooks/test_openai.py b/providers/openai/tests/unit/openai/hooks/test_openai.py index f2bd2a5d9f4b0..6cee1244e2c44 100644 --- a/providers/openai/tests/unit/openai/hooks/test_openai.py +++ b/providers/openai/tests/unit/openai/hooks/test_openai.py @@ -641,3 +641,157 @@ def test_get_conn_with_openai_client_kwargs(mock_client): base_url=None, organization="organization_in_extra", ) + + +def _make_workload_identity_conn(conn_id, extra): + conn = Connection( + conn_id=conn_id, + conn_type="openai", + extra={ + "auth_type": "workload_identity", + "identity_provider_id": "idp-123", + "service_account_id": "sa-456", + **extra, + }, + ) + os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri() + return conn + + +@pytest.mark.parametrize( + ("conn_id", "extra", "factory_attr", "expected_kwargs"), + [ + ( + "wi_k8s", + {"workload_identity_provider": "kubernetes", "token_file_path": "/var/run/token"}, + "k8s_service_account_token_provider", + {"token_file_path": "/var/run/token"}, + ), + ( + "wi_k8s_default", + {"workload_identity_provider": "kubernetes"}, + "k8s_service_account_token_provider", + {}, + ), + ( + "wi_azure", + { + "workload_identity_provider": "azure", + "resource": "https://cognitiveservices.azure.com/", + "client_id": "client-789", + }, + "azure_managed_identity_token_provider", + {"resource": "https://cognitiveservices.azure.com/", "client_id": "client-789"}, + ), + ( + "wi_gcp", + {"workload_identity_provider": "gcp", "audience": "https://api.openai.com/v1"}, + "gcp_id_token_provider", + {"audience": 
"https://api.openai.com/v1"}, + ), + ], +) +@patch("airflow.providers.openai.hooks.openai.OpenAI") +def test_get_conn_workload_identity_token_source(mock_client, conn_id, extra, factory_attr, expected_kwargs): + provider = object() + with patch( + f"airflow.providers.openai.hooks.openai.{factory_attr}", return_value=provider + ) as mock_factory: + conn = _make_workload_identity_conn(conn_id, extra) + OpenAIHook(conn_id=conn.conn_id).get_conn() + mock_factory.assert_called_once_with(**expected_kwargs) + mock_client.assert_called_once_with( + workload_identity={ + "identity_provider_id": "idp-123", + "service_account_id": "sa-456", + "provider": provider, + }, + base_url=None, + ) + + +@patch("airflow.providers.openai.hooks.openai.OpenAI") +@patch("airflow.providers.openai.hooks.openai.import_string") +def test_get_conn_workload_identity_custom(mock_import_string, mock_client): + def get_token(): + return "token" + + mock_import_string.return_value = get_token + conn = _make_workload_identity_conn( + "wi_custom", + { + "workload_identity_provider": "custom", + "token_provider": "my.module.get_token", + "token_type": "id", + "refresh_buffer_seconds": 300, + }, + ) + OpenAIHook(conn_id=conn.conn_id).get_conn() + mock_import_string.assert_called_once_with("my.module.get_token") + mock_client.assert_called_once_with( + workload_identity={ + "identity_provider_id": "idp-123", + "service_account_id": "sa-456", + "provider": {"token_type": "id", "get_token": get_token}, + "refresh_buffer_seconds": 300, + }, + base_url=None, + ) + + +@patch("airflow.providers.openai.hooks.openai.k8s_service_account_token_provider") +@patch("airflow.providers.openai.hooks.openai.OpenAI") +def test_get_conn_workload_identity_ignores_stray_api_key(mock_client, mock_provider): + provider = object() + mock_provider.return_value = provider + conn = _make_workload_identity_conn( + "wi_stray_api_key", + { + "workload_identity_provider": "kubernetes", + "openai_client_kwargs": {"api_key": 
"leftover-key"}, + }, + ) + OpenAIHook(conn_id=conn.conn_id).get_conn() + # ``api_key`` is popped for every path, so it is never forwarded alongside ``workload_identity``. + mock_client.assert_called_once_with( + workload_identity={ + "identity_provider_id": "idp-123", + "service_account_id": "sa-456", + "provider": provider, + }, + base_url=None, + ) + + +def test_get_conn_invalid_auth_type(): + conn = Connection(conn_id="bad_auth", conn_type="openai", extra={"auth_type": "oauth"}) + os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri() + with pytest.raises(ValueError, match="Unsupported auth_type 'oauth'"): + OpenAIHook(conn_id=conn.conn_id).get_conn() + + +def test_get_conn_invalid_workload_identity_provider(): + conn = _make_workload_identity_conn("bad_wi_provider", {"workload_identity_provider": "saml"}) + with pytest.raises(ValueError, match="Unsupported workload_identity_provider 'saml'"): + OpenAIHook(conn_id=conn.conn_id).get_conn() + + +def test_get_conn_workload_identity_missing_required_key(): + conn = Connection( + conn_id="wi_missing_key", + conn_type="openai", + extra={ + "auth_type": "workload_identity", + "workload_identity_provider": "kubernetes", + "service_account_id": "sa-456", + }, + ) + os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri() + with pytest.raises(ValueError, match="Missing required 'identity_provider_id'"): + OpenAIHook(conn_id=conn.conn_id).get_conn() + + +def test_get_conn_workload_identity_custom_missing_token_provider(): + conn = _make_workload_identity_conn("wi_custom_missing", {"workload_identity_provider": "custom"}) + with pytest.raises(ValueError, match="Missing required 'token_provider'"): + OpenAIHook(conn_id=conn.conn_id).get_conn()