Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion providers/openai/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ PIP package Version required
========================================== ==================
``apache-airflow`` ``>=2.11.0``
``apache-airflow-providers-common-compat`` ``>=1.12.0``
``openai[datalib]`` ``>=1.66.0``
``openai[datalib]`` ``>=2.37.0``
========================================== ==================

Cross provider package dependencies
Expand Down
36 changes: 36 additions & 0 deletions providers/openai/docs/connections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,39 @@ Extra (optional)
"api_key": "YOUR_API_KEY"
}
}

Authentication
--------------

The authentication mechanism is selected with the ``auth_type`` key in the ``extra`` field. It
defaults to ``"api_key"``, which uses the ``api_key`` entry of ``openai_client_kwargs`` when one is
set and otherwise the connection password (backward compatible with existing connections).

Set ``auth_type`` to ``"workload_identity"`` to authenticate with short-lived identity tokens
instead of a long-lived API key. The hook exchanges a token minted by your environment's identity
provider for API access, so no API key needs to be stored in the connection.
``identity_provider_id`` and ``service_account_id`` are always required, and the token source is
chosen with the required ``workload_identity_provider`` key:

* ``"kubernetes"`` -- a Kubernetes service account token read from ``token_file_path`` (defaults to
  the in-cluster path ``/var/run/secrets/kubernetes.io/serviceaccount/token``).
* ``"azure"`` -- an Azure managed identity. Optional keys: ``resource``, ``client_id``,
  ``object_id``, ``msi_res_id``, ``api_version``.
* ``"gcp"`` -- a Google Cloud ID token for the given ``audience``.
* ``"custom"`` -- import the required ``token_provider`` (a dotted path to a
  ``Callable[[], str]``) and use it as the token source. The callable is imported and invoked in
  the process that runs the hook, so point it only at trusted code. Optional ``token_type``
  (``"jwt"`` or ``"id"``, defaults to ``"jwt"``); ``token_type`` applies only to the ``custom``
  source.

The optional ``refresh_buffer_seconds`` controls how long before expiry the token is refreshed.

For example, to authenticate from a Kubernetes pod:

.. code-block:: json

    {
        "auth_type": "workload_identity",
        "workload_identity_provider": "kubernetes",
        "identity_provider_id": "idp-123",
        "service_account_id": "sa-456"
    }
2 changes: 1 addition & 1 deletion providers/openai/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ PIP package Version required
========================================== ==================
``apache-airflow`` ``>=2.11.0``
``apache-airflow-providers-common-compat`` ``>=1.12.0``
``openai[datalib]`` ``>=1.66.0``
``openai[datalib]`` ``>=2.37.0``
========================================== ==================

Cross provider package dependencies
Expand Down
2 changes: 1 addition & 1 deletion providers/openai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ requires-python = ">=3.10"
dependencies = [
"apache-airflow>=2.11.0",
"apache-airflow-providers-common-compat>=1.12.0",
"openai[datalib]>=1.66.0",
"openai[datalib]>=2.37.0",
]

[dependency-groups]
Expand Down
103 changes: 94 additions & 9 deletions providers/openai/src/airflow/providers/openai/hooks/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@
from typing import TYPE_CHECKING, Any, BinaryIO, Literal

from openai import OpenAI
from openai.auth import (
azure_managed_identity_token_provider,
gcp_id_token_provider,
k8s_service_account_token_provider,
)

if TYPE_CHECKING:
from openai.auth import SubjectTokenProvider, WorkloadIdentity
from openai.types import (
FileDeleted,
FileObject,
Expand All @@ -43,6 +49,7 @@
ChatCompletionUserMessageParam,
)
from openai.types.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted
from airflow.providers.common.compat.module_loading import import_string
from airflow.providers.common.compat.sdk import BaseHook
from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout

Expand Down Expand Up @@ -108,18 +115,96 @@ def conn(self) -> OpenAI:
return self.get_conn()

def get_conn(self) -> OpenAI:
    """
    Return an OpenAI connection object.

    The authentication mechanism is selected with the ``auth_type`` key in the connection
    ``extra`` (default ``"api_key"``):

    * ``"api_key"`` -- use the API key from ``openai_client_kwargs`` or, when that is not set,
      the connection password.
    * ``"workload_identity"`` -- exchange a short-lived identity token for API access using the
      OpenAI client's workload identity support. See :meth:`_build_workload_identity`.

    ``base_url`` is taken from ``openai_client_kwargs`` or, failing that, from the connection
    host. Every remaining entry of ``openai_client_kwargs`` is forwarded to the client unchanged.

    :return: a configured ``OpenAI`` client.
    :raises ValueError: if ``auth_type`` is neither ``"api_key"`` nor ``"workload_identity"``.
    """
    conn = self.get_connection(self.conn_id)
    extras = conn.extra_dejson
    # Work on a copy so the pops below never mutate the dict nested inside ``extras``;
    # ``or {}`` also tolerates an explicit JSON ``null``.
    openai_client_kwargs = dict(extras.get("openai_client_kwargs") or {})
    base_url = openai_client_kwargs.pop("base_url", None) or conn.host or None
    # Pop api_key for every path so it is never forwarded alongside ``workload_identity``
    # (the OpenAI client rejects both being set at once).
    api_key = openai_client_kwargs.pop("api_key", None)
    auth_type = extras.get("auth_type", "api_key")

    if auth_type == "api_key":
        return OpenAI(api_key=api_key or conn.password, base_url=base_url, **openai_client_kwargs)

    if auth_type == "workload_identity":
        return OpenAI(
            workload_identity=self._build_workload_identity(extras),
            base_url=base_url,
            **openai_client_kwargs,
        )

    raise ValueError(
        f"Unsupported auth_type {auth_type!r} for OpenAI connection {self.conn_id!r}; "
        "expected 'api_key' or 'workload_identity'."
    )

def _build_workload_identity(self, extras: dict[str, Any]) -> WorkloadIdentity:
    """
    Build the OpenAI ``workload_identity`` config from the connection ``extra``.

    Returns the ``workload_identity`` mapping (``identity_provider_id``, ``service_account_id``,
    ``provider``, and optional ``refresh_buffer_seconds``) for the token source selected by
    ``workload_identity_provider``. See :ref:`howto/connection:openai` for the full key reference.

    :param extras: the parsed connection ``extra`` dictionary.
    :return: the mapping to pass as ``OpenAI(workload_identity=...)``.
    :raises ValueError: when a required key is missing or empty, the token source is unknown,
        the custom ``token_type`` is not ``"jwt"`` or ``"id"``, or the custom ``token_provider``
        does not resolve to a callable.
    """
    for key in ("identity_provider_id", "service_account_id"):
        # An empty or null value is as unusable as an absent key, so reject both up front.
        if not extras.get(key):
            raise ValueError(
                f"Missing required {key!r} for workload_identity auth on OpenAI connection "
                f"{self.conn_id!r}."
            )

    def present(*keys: str) -> dict[str, Any]:
        """Return only the optional ``keys`` that the connection extra actually sets."""
        return {key: extras[key] for key in keys if key in extras}

    provider_name = extras.get("workload_identity_provider")
    provider: SubjectTokenProvider
    if provider_name == "kubernetes":
        provider = k8s_service_account_token_provider(**present("token_file_path"))
    elif provider_name == "azure":
        provider = azure_managed_identity_token_provider(
            **present("resource", "object_id", "client_id", "msi_res_id", "api_version")
        )
    elif provider_name == "gcp":
        provider = gcp_id_token_provider(**present("audience"))
    elif provider_name == "custom":
        if "token_provider" not in extras:
            raise ValueError(
                f"Missing required 'token_provider' for custom workload_identity auth on OpenAI "
                f"connection {self.conn_id!r}."
            )
        token_type = extras.get("token_type", "jwt")
        # Validate before importing so a typo fails fast without running any user code.
        if token_type not in ("jwt", "id"):
            raise ValueError(
                f"Unsupported token_type {token_type!r} for custom workload_identity auth on "
                f"OpenAI connection {self.conn_id!r}; expected 'jwt' or 'id'."
            )
        get_token = import_string(extras["token_provider"])
        if not callable(get_token):
            raise ValueError(
                f"'token_provider' {extras['token_provider']!r} for OpenAI connection "
                f"{self.conn_id!r} does not resolve to a callable."
            )
        provider = {"token_type": token_type, "get_token": get_token}
    else:
        raise ValueError(
            f"Unsupported workload_identity_provider {provider_name!r} for OpenAI connection "
            f"{self.conn_id!r}; expected one of 'kubernetes', 'azure', 'gcp', 'custom'."
        )

    workload_identity: WorkloadIdentity = {
        "identity_provider_id": extras["identity_provider_id"],
        "service_account_id": extras["service_account_id"],
        "provider": provider,
    }
    if "refresh_buffer_seconds" in extras:
        workload_identity["refresh_buffer_seconds"] = extras["refresh_buffer_seconds"]
    return workload_identity

def create_chat_completion(
self,
messages: list[
Expand All @@ -129,7 +214,7 @@ def create_chat_completion(
| ChatCompletionToolMessageParam
| ChatCompletionFunctionMessageParam
],
model: str = "gpt-3.5-turbo",
model: str = "gpt-4o-mini",
**kwargs: Any,
) -> list[ChatCompletionMessage]:
"""
Expand All @@ -141,7 +226,7 @@ def create_chat_completion(
response = self.conn.chat.completions.create(model=model, messages=messages, **kwargs)
return response.choices

def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> Assistant:
def create_assistant(self, model: str = "gpt-4o-mini", **kwargs: Any) -> Assistant:
"""
Create an OpenAI assistant using the given model.

Expand Down Expand Up @@ -297,7 +382,7 @@ def modify_run(self, thread_id: str, run_id: str, **kwargs: Any) -> Run:
def create_embeddings(
self,
text: str | list[str] | list[int] | list[list[int]],
model: str = "text-embedding-ada-002",
model: str = "text-embedding-3-small",
**kwargs: Any,
) -> list[float]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
self,
conn_id: str,
input_text: str | list[str] | list[int] | list[list[int]],
model: str = "text-embedding-ada-002",
model: str = "text-embedding-3-small",
embedding_kwargs: dict | None = None,
**kwargs: Any,
):
Expand Down
6 changes: 3 additions & 3 deletions providers/openai/tests/system/openai/example_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def task_to_store_input_text_in_xcom():
task_id="embedding_using_xcom_data",
conn_id="openai_default",
input_text=task_to_store_input_text_in_xcom(),
model="text-embedding-ada-002",
model="text-embedding-3-small",
)

OpenAIEmbeddingOperator(
Expand All @@ -90,13 +90,13 @@ def task_to_store_input_text_in_xcom():
input_kwarg1="input_kwarg1_value",
input_kwarg2="input_kwarg2_value",
),
model="text-embedding-ada-002",
model="text-embedding-3-small",
)
OpenAIEmbeddingOperator(
task_id="embedding_using_text",
conn_id="openai_default",
input_text=texts,
model="text-embedding-ada-002",
model="text-embedding-3-small",
)
# [END howto_operator_openai_embedding]

Expand Down
Loading
Loading