Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions providers/openai/docs/operators/openai.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ For example, to create a conversation and continue it across responses:
conversation = hook.create_conversation()
hook.create_response(input="Hello", conversation=conversation.id)

.. note::

The Assistants/Threads hook methods (``create_assistant``, ``create_thread``, ``create_run`` and
related) are deprecated, mirroring OpenAI's deprecation of the Assistants API. Migrate to the
Responses and Conversations methods above.

.. _howto/operator:OpenAITriggerBatchOperator:

OpenAITriggerBatchOperator
Expand Down
27 changes: 27 additions & 0 deletions providers/openai/src/airflow/providers/openai/hooks/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, BinaryIO, Literal

from deprecated import deprecated
from openai import OpenAI
from openai.auth import (
azure_managed_identity_token_provider,
Expand Down Expand Up @@ -51,10 +52,20 @@
from openai.types.conversations import Conversation, ConversationDeletedResource
from openai.types.responses import Response
from openai.types.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.compat.module_loading import import_string
from airflow.providers.common.compat.sdk import BaseHook
from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout

#: The OpenAI Assistants API (``beta.assistants``/``beta.threads``) is deprecated by OpenAI. The hook
#: methods wrapping it warn and point at the Responses and Conversations APIs (``create_response`` /
#: ``create_conversation``); the removal date is stated once, in the reason string below.
_ASSISTANTS_DEPRECATION_REASON = (
"The OpenAI Assistants API is deprecated and will be removed by OpenAI on 2026-08-26. "
"Use the Responses API (create_response) and Conversations API (create_conversation) instead. "
"See https://platform.openai.com/docs/guides/migrate-to-responses."
)


class BatchStatus(str, Enum):
"""Enum for the status of a batch."""
Expand Down Expand Up @@ -290,6 +301,7 @@ def delete_conversation(self, conversation_id: str) -> ConversationDeletedResour
"""
return self.conn.conversations.delete(conversation_id)

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def create_assistant(self, model: str = "gpt-4o-mini", **kwargs: Any) -> Assistant:
"""
Create an OpenAI assistant using the given model.
Expand All @@ -299,6 +311,7 @@ def create_assistant(self, model: str = "gpt-4o-mini", **kwargs: Any) -> Assista
assistant = self.conn.beta.assistants.create(model=model, **kwargs)
return assistant

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def get_assistant(self, assistant_id: str) -> Assistant:
"""
Get an OpenAI assistant.
Expand All @@ -308,11 +321,13 @@ def get_assistant(self, assistant_id: str) -> Assistant:
assistant = self.conn.beta.assistants.retrieve(assistant_id=assistant_id)
return assistant

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def get_assistants(self, **kwargs: Any) -> list[Assistant]:
"""Get a list of Assistant objects."""
assistants = self.conn.beta.assistants.list(**kwargs)
return assistants.data

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant:
"""
Modify an existing Assistant object.
Expand All @@ -322,6 +337,7 @@ def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant:
assistant = self.conn.beta.assistants.update(assistant_id=assistant_id, **kwargs)
return assistant

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def delete_assistant(self, assistant_id: str) -> AssistantDeleted:
"""
Delete an OpenAI Assistant for a given ID.
Expand All @@ -331,11 +347,13 @@ def delete_assistant(self, assistant_id: str) -> AssistantDeleted:
response = self.conn.beta.assistants.delete(assistant_id=assistant_id)
return response

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def create_thread(self, **kwargs: Any) -> Thread:
"""Create an OpenAI thread."""
thread = self.conn.beta.threads.create(**kwargs)
return thread

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def modify_thread(self, thread_id: str, metadata: dict[str, Any]) -> Thread:
"""
Modify an existing Thread object.
Expand All @@ -346,6 +364,7 @@ def modify_thread(self, thread_id: str, metadata: dict[str, Any]) -> Thread:
thread = self.conn.beta.threads.update(thread_id=thread_id, metadata=metadata)
return thread

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def delete_thread(self, thread_id: str) -> ThreadDeleted:
"""
Delete an OpenAI thread for a given thread_id.
Expand All @@ -355,6 +374,7 @@ def delete_thread(self, thread_id: str) -> ThreadDeleted:
response = self.conn.beta.threads.delete(thread_id=thread_id)
return response

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def create_message(
self, thread_id: str, role: Literal["user", "assistant"], content: str, **kwargs: Any
) -> Message:
Expand All @@ -370,6 +390,7 @@ def create_message(
)
return thread_message

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]:
"""
Return a list of messages for a given Thread.
Expand All @@ -379,6 +400,7 @@ def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]:
messages = self.conn.beta.threads.messages.list(thread_id=thread_id, **kwargs)
return messages.data

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message:
"""
Modify an existing message for a given Thread.
Expand All @@ -391,6 +413,7 @@ def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message:
)
return thread_message

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run:
"""
Create a run for a given thread and assistant.
Expand All @@ -401,6 +424,7 @@ def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run:
run = self.conn.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id, **kwargs)
return run

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def create_run_and_poll(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run:
"""
Create a run for a given thread and assistant and then polls until completion.
Expand All @@ -414,6 +438,7 @@ def create_run_and_poll(self, thread_id: str, assistant_id: str, **kwargs: Any)
)
return run

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def get_run(self, thread_id: str, run_id: str) -> Run:
"""
Retrieve a run for a given thread and run.
Expand All @@ -424,6 +449,7 @@ def get_run(self, thread_id: str, run_id: str) -> Run:
run = self.conn.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id)
return run

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def get_runs(self, thread_id: str, **kwargs: Any) -> list[Run]:
"""
Return a list of runs belonging to a thread.
Expand All @@ -433,6 +459,7 @@ def get_runs(self, thread_id: str, **kwargs: Any) -> list[Run]:
runs = self.conn.beta.threads.runs.list(thread_id=thread_id, **kwargs)
return runs.data

@deprecated(reason=_ASSISTANTS_DEPRECATION_REASON, category=AirflowProviderDeprecationWarning)
def modify_run(self, thread_id: str, run_id: str, **kwargs: Any) -> Run:
"""
Modify a run on a given thread.
Expand Down
55 changes: 37 additions & 18 deletions providers/openai/tests/unit/openai/hooks/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from openai.types.chat import ChatCompletion
from openai.types.vector_stores import VectorStoreFile, VectorStoreFileBatch, VectorStoreFileDeleted

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import Connection
from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout
from airflow.providers.openai.hooks.openai import OpenAIHook
Expand Down Expand Up @@ -363,82 +364,95 @@ def test_delete_conversation(mock_openai_hook):

def test_create_assistant(mock_openai_hook, mock_assistant):
mock_openai_hook.conn.beta.assistants.create.return_value = mock_assistant
assistant = mock_openai_hook.create_assistant(
name=ASSISTANT_NAME, model=MODEL, instructions=ASSISTANT_INSTRUCTIONS
)
with pytest.warns(AirflowProviderDeprecationWarning):
assistant = mock_openai_hook.create_assistant(
name=ASSISTANT_NAME, model=MODEL, instructions=ASSISTANT_INSTRUCTIONS
)
assert assistant.name == ASSISTANT_NAME
assert assistant.model == MODEL
assert assistant.instructions == ASSISTANT_INSTRUCTIONS


def test_get_assistant(mock_openai_hook, mock_assistant):
mock_openai_hook.conn.beta.assistants.retrieve.return_value = mock_assistant
assistant = mock_openai_hook.get_assistant(assistant_id=ASSISTANT_ID)
with pytest.warns(AirflowProviderDeprecationWarning):
assistant = mock_openai_hook.get_assistant(assistant_id=ASSISTANT_ID)
assert assistant.name == ASSISTANT_NAME
assert assistant.model == MODEL
assert assistant.instructions == ASSISTANT_INSTRUCTIONS


def test_get_assistants(mock_openai_hook, mock_assistant_list):
mock_openai_hook.conn.beta.assistants.list.return_value = mock_assistant_list
assistants = mock_openai_hook.get_assistants()
with pytest.warns(AirflowProviderDeprecationWarning):
assistants = mock_openai_hook.get_assistants()
assert isinstance(assistants, list)


def test_modify_assistant(mock_openai_hook, mock_assistant):
new_assistant_name = "New Test Assistant"
mock_assistant.name = new_assistant_name
mock_openai_hook.conn.beta.assistants.update.return_value = mock_assistant
assistant = mock_openai_hook.modify_assistant(assistant_id=ASSISTANT_ID, name=new_assistant_name)
with pytest.warns(AirflowProviderDeprecationWarning):
assistant = mock_openai_hook.modify_assistant(assistant_id=ASSISTANT_ID, name=new_assistant_name)
assert assistant.name == new_assistant_name


def test_delete_assistant(mock_openai_hook):
delete_response = AssistantDeleted(id=ASSISTANT_ID, object="assistant.deleted", deleted=True)
mock_openai_hook.conn.beta.assistants.delete.return_value = delete_response
assistant_deleted = mock_openai_hook.delete_assistant(assistant_id=ASSISTANT_ID)
with pytest.warns(AirflowProviderDeprecationWarning):
assistant_deleted = mock_openai_hook.delete_assistant(assistant_id=ASSISTANT_ID)
assert assistant_deleted.deleted


def test_create_thread(mock_openai_hook, mock_thread):
mock_openai_hook.conn.beta.threads.create.return_value = mock_thread
thread = mock_openai_hook.create_thread()
with pytest.warns(AirflowProviderDeprecationWarning):
thread = mock_openai_hook.create_thread()
assert thread.id == THREAD_ID


def test_modify_thread(mock_openai_hook, mock_thread):
mock_thread.metadata = METADATA
mock_openai_hook.conn.beta.threads.update.return_value = mock_thread
thread = mock_openai_hook.modify_thread(thread_id=THREAD_ID, metadata=METADATA)
with pytest.warns(AirflowProviderDeprecationWarning):
thread = mock_openai_hook.modify_thread(thread_id=THREAD_ID, metadata=METADATA)
assert thread.metadata.get("modified") == "true"
assert thread.metadata.get("user") == "abc123"


def test_delete_thread(mock_openai_hook):
delete_response = ThreadDeleted(id=THREAD_ID, object="thread.deleted", deleted=True)
mock_openai_hook.conn.beta.threads.delete.return_value = delete_response
thread_deleted = mock_openai_hook.delete_thread(thread_id=THREAD_ID)
with pytest.warns(AirflowProviderDeprecationWarning):
thread_deleted = mock_openai_hook.delete_thread(thread_id=THREAD_ID)
assert thread_deleted.deleted


def test_create_message(mock_openai_hook, mock_message):
role = "user"
content = "Tell me something interesting."
mock_openai_hook.conn.beta.threads.messages.create.return_value = mock_message
message = mock_openai_hook.create_message(thread_id=THREAD_ID, content=content, role=role)
with pytest.warns(AirflowProviderDeprecationWarning):
message = mock_openai_hook.create_message(thread_id=THREAD_ID, content=content, role=role)
assert message.id == MESSAGE_ID


def test_get_messages(mock_openai_hook, mock_message_list):
mock_openai_hook.conn.beta.threads.messages.list.return_value = mock_message_list
messages = mock_openai_hook.get_messages(thread_id=THREAD_ID)
with pytest.warns(AirflowProviderDeprecationWarning):
messages = mock_openai_hook.get_messages(thread_id=THREAD_ID)
assert isinstance(messages, list)


def test_modify_messages(mock_openai_hook, mock_message):
mock_message.metadata = METADATA
mock_openai_hook.conn.beta.threads.messages.update.return_value = mock_message
message = mock_openai_hook.modify_message(thread_id=THREAD_ID, message_id=MESSAGE_ID, metadata=METADATA)
with pytest.warns(AirflowProviderDeprecationWarning):
message = mock_openai_hook.modify_message(
thread_id=THREAD_ID, message_id=MESSAGE_ID, metadata=METADATA
)
assert message.metadata.get("modified") == "true"
assert message.metadata.get("user") == "abc123"

Expand All @@ -447,34 +461,39 @@ def test_create_run(mock_openai_hook, mock_run):
thread_id = THREAD_ID
assistant_id = ASSISTANT_ID
mock_openai_hook.conn.beta.threads.runs.create.return_value = mock_run
run = mock_openai_hook.create_run(thread_id=thread_id, assistant_id=assistant_id)
with pytest.warns(AirflowProviderDeprecationWarning):
run = mock_openai_hook.create_run(thread_id=thread_id, assistant_id=assistant_id)
assert run.id == RUN_ID


def test_create_run_and_poll(mock_openai_hook, mock_run):
thread_id = THREAD_ID
assistant_id = ASSISTANT_ID
mock_openai_hook.conn.beta.threads.runs.create_and_poll.return_value = mock_run
run = mock_openai_hook.create_run_and_poll(thread_id=thread_id, assistant_id=assistant_id)
with pytest.warns(AirflowProviderDeprecationWarning):
run = mock_openai_hook.create_run_and_poll(thread_id=thread_id, assistant_id=assistant_id)
assert run.id == RUN_ID


def test_get_runs(mock_openai_hook, mock_run_list):
mock_openai_hook.conn.beta.threads.runs.list.return_value = mock_run_list
runs = mock_openai_hook.get_runs(thread_id=THREAD_ID)
with pytest.warns(AirflowProviderDeprecationWarning):
runs = mock_openai_hook.get_runs(thread_id=THREAD_ID)
assert isinstance(runs, list)


def test_get_run_with_run_id(mock_openai_hook, mock_run):
mock_openai_hook.conn.beta.threads.runs.retrieve.return_value = mock_run
run = mock_openai_hook.get_run(thread_id=THREAD_ID, run_id=RUN_ID)
with pytest.warns(AirflowProviderDeprecationWarning):
run = mock_openai_hook.get_run(thread_id=THREAD_ID, run_id=RUN_ID)
assert run.id == RUN_ID


def test_modify_run(mock_openai_hook, mock_run):
mock_run.metadata = METADATA
mock_openai_hook.conn.beta.threads.runs.update.return_value = mock_run
message = mock_openai_hook.modify_run(thread_id=THREAD_ID, run_id=RUN_ID, metadata=METADATA)
with pytest.warns(AirflowProviderDeprecationWarning):
message = mock_openai_hook.modify_run(thread_id=THREAD_ID, run_id=RUN_ID, metadata=METADATA)
assert message.metadata.get("modified") == "true"
assert message.metadata.get("user") == "abc123"

Expand Down
Loading