From 1dd0e1491780d618f602f3a8da01e55973b160a1 Mon Sep 17 00:00:00 2001 From: Marlene <57748216+marlenezw@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:48:21 +0000 Subject: [PATCH 01/12] feat(openai): add callable support for openai_api_key parameter --- .../langchain_openai/chat_models/base.py | 41 ++++++++----------- .../langchain_openai/embeddings/base.py | 22 +++++++--- .../openai/langchain_openai/llms/base.py | 21 +++++++--- .../openai/tests/unit_tests/test_secrets.py | 12 ++++++ 4 files changed, 60 insertions(+), 36 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 8e8a3569c5b51..9ef7ad2bec502 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -16,14 +16,7 @@ from json import JSONDecodeError from math import ceil from operator import itemgetter -from typing import ( - TYPE_CHECKING, - Any, - Literal, - TypeAlias, - TypeVar, - cast, -) +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar, cast from urllib.parse import urlparse import certifi @@ -34,10 +27,7 @@ CallbackManagerForLLMRun, ) from langchain_core.language_models import LanguageModelInput -from langchain_core.language_models.chat_models import ( - BaseChatModel, - LangSmithParams, -) +from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -96,13 +86,7 @@ is_basemodel_subclass, ) from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env -from pydantic import ( - BaseModel, - ConfigDict, - Field, - SecretStr, - model_validator, -) +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self @@ -465,8 +449,11 @@ class BaseChatOpenAI(BaseChatModel): """What sampling 
temperature to use.""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - openai_api_key: SecretStr | None = Field( - alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) + openai_api_key: SecretStr | None | Callable[[], str] = Field( + alias="api_key", + default_factory=secret_from_env( + ["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None + ), ) openai_api_base: str | None = Field(default=None, alias="base_url") """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # noqa: E501 @@ -776,10 +763,16 @@ def validate_environment(self) -> Self: ): self.stream_usage = True + # Resolve API key from SecretStr or Callable + api_key_value = None + if self.openai_api_key is not None: + if isinstance(self.openai_api_key, SecretStr): + api_key_value = self.openai_api_key.get_secret_value() + elif callable(self.openai_api_key): + api_key_value = self.openai_api_key() + client_params: dict = { - "api_key": ( - self.openai_api_key.get_secret_value() if self.openai_api_key else None - ), + "api_key": api_key_value, "organization": self.openai_organization, "base_url": self.openai_api_base, "timeout": self.request_timeout, diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 5f92002ad72ab..6291717742795 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -4,7 +4,7 @@ import logging import warnings -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from typing import Any, Literal, cast import openai @@ -189,8 +189,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) embedding_ctx_length: int = 8191 """The maximum number of tokens to embed at once.""" - openai_api_key: SecretStr | None 
= Field( - alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) + openai_api_key: SecretStr | None | Callable[[], str] = Field( + alias="api_key", + default_factory=secret_from_env( + ["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None + ), ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" openai_organization: str | None = Field( @@ -292,10 +295,17 @@ def validate_environment(self) -> Self: "If you are using Azure, please use the `AzureOpenAIEmbeddings` class." ) raise ValueError(msg) + + # Resolve API key from SecretStr or Callable + api_key_value = None + if self.openai_api_key is not None: + if isinstance(self.openai_api_key, SecretStr): + api_key_value = self.openai_api_key.get_secret_value() + elif callable(self.openai_api_key): + api_key_value = self.openai_api_key() + client_params: dict = { - "api_key": ( - self.openai_api_key.get_secret_value() if self.openai_api_key else None - ), + "api_key": api_key_value, "organization": self.openai_organization, "base_url": self.openai_api_base, "timeout": self.request_timeout, diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index bdffd5eff7bef..ff55e9335884a 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -4,7 +4,7 @@ import logging import sys -from collections.abc import AsyncIterator, Collection, Iterator, Mapping +from collections.abc import AsyncIterator, Callable, Collection, Iterator, Mapping from typing import Any, Literal import openai @@ -186,8 +186,11 @@ class BaseOpenAI(BaseLLM): """Generates best_of completions server-side and returns the "best".""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - openai_api_key: SecretStr | None = Field( - alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", 
default=None) + openai_api_key: SecretStr | None | Callable[[], str] = Field( + alias="api_key", + default_factory=secret_from_env( + ["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None + ), ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" openai_api_base: str | None = Field( @@ -276,10 +279,16 @@ def validate_environment(self) -> Self: msg = "Cannot stream results when best_of > 1." raise ValueError(msg) + # Resolve API key from SecretStr or Callable + api_key_value = None + if self.openai_api_key is not None: + if isinstance(self.openai_api_key, SecretStr): + api_key_value = self.openai_api_key.get_secret_value() + elif callable(self.openai_api_key): + api_key_value = self.openai_api_key() + client_params: dict = { - "api_key": ( - self.openai_api_key.get_secret_value() if self.openai_api_key else None - ), + "api_key": api_key_value, "organization": self.openai_organization, "base_url": self.openai_api_base, "timeout": self.request_timeout, diff --git a/libs/partners/openai/tests/unit_tests/test_secrets.py b/libs/partners/openai/tests/unit_tests/test_secrets.py index aa1484058e0e2..27d69bed92ce1 100644 --- a/libs/partners/openai/tests/unit_tests/test_secrets.py +++ b/libs/partners/openai/tests/unit_tests/test_secrets.py @@ -187,6 +187,18 @@ def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> No assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key" +@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings]) +def test_openai_api_key_accepts_callable(model_class: type) -> None: + """Test that the API key can be passed as a callable.""" + + def get_api_key() -> str: + return "secret-api-key-from-callable" + + model = model_class(openai_api_key=get_api_key) + assert callable(model.openai_api_key) + assert model.openai_api_key() == "secret-api-key-from-callable" + + @pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI]) def 
test_azure_serialized_secrets(model_class: type) -> None: """Test that the actual secret value is correctly retrieved.""" From 801fbe19c415de1229bb4de1a263805402c26279 Mon Sep 17 00:00:00 2001 From: Marlene <57748216+marlenezw@users.noreply.github.com> Date: Fri, 17 Oct 2025 20:58:47 +0000 Subject: [PATCH 02/12] removing azure from key factory --- libs/partners/openai/langchain_openai/chat_models/base.py | 4 +--- libs/partners/openai/langchain_openai/embeddings/base.py | 4 +--- libs/partners/openai/langchain_openai/llms/base.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 29545eac718ad..a10e9ce3191fd 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -451,9 +451,7 @@ class BaseChatOpenAI(BaseChatModel): """Holds any model parameters valid for `create` call not explicitly specified.""" openai_api_key: SecretStr | None | Callable[[], str] = Field( alias="api_key", - default_factory=secret_from_env( - ["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None - ), + default_factory=secret_from_env("OPENAI_API_KEY", default=None), ) openai_api_base: str | None = Field(default=None, alias="base_url") """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # noqa: E501 diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 8d6c20b6281c9..30d47c6d28b10 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -191,9 +191,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """The maximum number of tokens to embed at once.""" openai_api_key: SecretStr | None | Callable[[], str] = Field( alias="api_key", - default_factory=secret_from_env( - 
["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None - ), + default_factory=secret_from_env("OPENAI_API_KEY", default=None), ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" openai_organization: str | None = Field( diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index 9de9c2c272932..9de2b5a10b021 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -188,9 +188,7 @@ class BaseOpenAI(BaseLLM): """Holds any model parameters valid for `create` call not explicitly specified.""" openai_api_key: SecretStr | None | Callable[[], str] = Field( alias="api_key", - default_factory=secret_from_env( - ["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None - ), + default_factory=secret_from_env("OPENAI_API_KEY", default=None), ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" openai_api_base: str | None = Field( From 542f769ea29ee79bf65f986eca7c485afb2a2ad2 Mon Sep 17 00:00:00 2001 From: Marlene <57748216+marlenezw@users.noreply.github.com> Date: Mon, 20 Oct 2025 17:30:31 +0000 Subject: [PATCH 03/12] dont call callable until required --- .../langchain_openai/chat_models/base.py | 91 ++++++------------- .../langchain_openai/embeddings/base.py | 2 +- .../openai/langchain_openai/llms/base.py | 8 +- 3 files changed, 32 insertions(+), 69 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index a10e9ce3191fd..1b0fb93a6b74c 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -22,84 +22,49 @@ import certifi import openai import tiktoken -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) +from langchain_core.callbacks import (AsyncCallbackManagerForLLMRun, + 
CallbackManagerForLLMRun) from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams -from langchain_core.messages import ( - AIMessage, - AIMessageChunk, - BaseMessage, - BaseMessageChunk, - ChatMessage, - ChatMessageChunk, - FunctionMessage, - FunctionMessageChunk, - HumanMessage, - HumanMessageChunk, - InvalidToolCall, - SystemMessage, - SystemMessageChunk, - ToolCall, - ToolMessage, - ToolMessageChunk, - is_data_content_block, -) +from langchain_core.messages import (AIMessage, AIMessageChunk, BaseMessage, + BaseMessageChunk, ChatMessage, ChatMessageChunk, + FunctionMessage, FunctionMessageChunk, + HumanMessage, HumanMessageChunk, InvalidToolCall, + SystemMessage, SystemMessageChunk, ToolCall, + ToolMessage, ToolMessageChunk) from langchain_core.messages import content as types -from langchain_core.messages.ai import ( - InputTokenDetails, - OutputTokenDetails, - UsageMetadata, -) +from langchain_core.messages import is_data_content_block +from langchain_core.messages.ai import (InputTokenDetails, OutputTokenDetails, + UsageMetadata) from langchain_core.messages.block_translators.openai import ( - _convert_from_v03_ai_message, - convert_to_openai_data_block, -) + _convert_from_v03_ai_message, convert_to_openai_data_block) from langchain_core.messages.tool import tool_call_chunk from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser -from langchain_core.output_parsers.openai_tools import ( - JsonOutputKeyToolsParser, - PydanticToolsParser, - make_invalid_tool_call, - parse_tool_call, -) +from langchain_core.output_parsers.openai_tools import (JsonOutputKeyToolsParser, + PydanticToolsParser, + make_invalid_tool_call, + parse_tool_call) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.runnables import ( - Runnable, - RunnableLambda, - RunnableMap, - RunnablePassthrough, -) +from 
langchain_core.runnables import (Runnable, RunnableLambda, RunnableMap, + RunnablePassthrough) from langchain_core.runnables.config import run_in_executor from langchain_core.tools import BaseTool from langchain_core.tools.base import _stringify from langchain_core.utils import get_pydantic_field_names -from langchain_core.utils.function_calling import ( - convert_to_openai_function, - convert_to_openai_tool, -) -from langchain_core.utils.pydantic import ( - PydanticBaseModel, - TypeBaseModel, - is_basemodel_subclass, -) +from langchain_core.utils.function_calling import (convert_to_openai_function, + convert_to_openai_tool) +from langchain_core.utils.pydantic import (PydanticBaseModel, TypeBaseModel, + is_basemodel_subclass) from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env +from langchain_openai.chat_models._client_utils import (_get_default_async_httpx_client, + _get_default_httpx_client) +from langchain_openai.chat_models._compat import (_convert_from_v1_to_chat_completions, + _convert_from_v1_to_responses, + _convert_to_v03_ai_message) from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self -from langchain_openai.chat_models._client_utils import ( - _get_default_async_httpx_client, - _get_default_httpx_client, -) -from langchain_openai.chat_models._compat import ( - _convert_from_v1_to_chat_completions, - _convert_from_v1_to_responses, - _convert_to_v03_ai_message, -) - if TYPE_CHECKING: from openai.types.responses import Response @@ -767,7 +732,7 @@ def validate_environment(self) -> Self: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key() + api_key_value = self.openai_api_key client_params: dict = { "api_key": api_key_value, diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py 
b/libs/partners/openai/langchain_openai/embeddings/base.py index 30d47c6d28b10..d346bcbc8c3bc 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -300,7 +300,7 @@ def validate_environment(self) -> Self: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key() + api_key_value = self.openai_api_key client_params: dict = { "api_key": api_key_value, diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index 9de2b5a10b021..872f20c455c1e 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -9,10 +9,8 @@ import openai import tiktoken -from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun, -) +from langchain_core.callbacks import (AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun) from langchain_core.language_models.llms import BaseLLM from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.utils import get_pydantic_field_names @@ -283,7 +281,7 @@ def validate_environment(self) -> Self: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key() + api_key_value = self.openai_api_key client_params: dict = { "api_key": api_key_value, From 328c45886e5b0098ba408eb11931ed26cd1907a7 Mon Sep 17 00:00:00 2001 From: Marlene <57748216+marlenezw@users.noreply.github.com> Date: Mon, 20 Oct 2025 18:00:52 +0000 Subject: [PATCH 04/12] adding typing so mypy doesnt complain --- libs/partners/openai/langchain_openai/chat_models/base.py | 4 ++-- libs/partners/openai/langchain_openai/embeddings/base.py | 4 ++-- 
libs/partners/openai/langchain_openai/llms/base.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 1b0fb93a6b74c..6b5a092f98559 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -727,12 +727,12 @@ def validate_environment(self) -> Self: self.stream_usage = True # Resolve API key from SecretStr or Callable - api_key_value = None + api_key_value: str | Callable[[], str] | None = None if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key + api_key_value = self.openai_api_key # type: ignore[assignment] client_params: dict = { "api_key": api_key_value, diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index d346bcbc8c3bc..50e126b40b57f 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -295,12 +295,12 @@ def validate_environment(self) -> Self: raise ValueError(msg) # Resolve API key from SecretStr or Callable - api_key_value = None + api_key_value: str | Callable[[], str] | None = None if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key + api_key_value = self.openai_api_key # type: ignore[assignment] client_params: dict = { "api_key": api_key_value, diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index 872f20c455c1e..c7c17cde807f3 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ 
b/libs/partners/openai/langchain_openai/llms/base.py @@ -276,12 +276,12 @@ def validate_environment(self) -> Self: raise ValueError(msg) # Resolve API key from SecretStr or Callable - api_key_value = None + api_key_value: str | Callable[[], str] | None = None if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key + api_key_value = self.openai_api_key # type: ignore[assignment] client_params: dict = { "api_key": api_key_value, From 22e384e550e100796b342cd2f67fc8b41cecb256 Mon Sep 17 00:00:00 2001 From: Marlene <57748216+marlenezw@users.noreply.github.com> Date: Mon, 20 Oct 2025 18:06:23 +0000 Subject: [PATCH 05/12] linting --- .../langchain_openai/chat_models/base.py | 89 +++++++++++++------ .../openai/langchain_openai/llms/base.py | 6 +- 2 files changed, 66 insertions(+), 29 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 6b5a092f98559..654036c238a90 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -22,49 +22,84 @@ import certifi import openai import tiktoken -from langchain_core.callbacks import (AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun) +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams -from langchain_core.messages import (AIMessage, AIMessageChunk, BaseMessage, - BaseMessageChunk, ChatMessage, ChatMessageChunk, - FunctionMessage, FunctionMessageChunk, - HumanMessage, HumanMessageChunk, InvalidToolCall, - SystemMessage, SystemMessageChunk, ToolCall, - ToolMessage, ToolMessageChunk) +from 
langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + FunctionMessage, + FunctionMessageChunk, + HumanMessage, + HumanMessageChunk, + InvalidToolCall, + SystemMessage, + SystemMessageChunk, + ToolCall, + ToolMessage, + ToolMessageChunk, + is_data_content_block, +) from langchain_core.messages import content as types -from langchain_core.messages import is_data_content_block -from langchain_core.messages.ai import (InputTokenDetails, OutputTokenDetails, - UsageMetadata) +from langchain_core.messages.ai import ( + InputTokenDetails, + OutputTokenDetails, + UsageMetadata, +) from langchain_core.messages.block_translators.openai import ( - _convert_from_v03_ai_message, convert_to_openai_data_block) + _convert_from_v03_ai_message, + convert_to_openai_data_block, +) from langchain_core.messages.tool import tool_call_chunk from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser -from langchain_core.output_parsers.openai_tools import (JsonOutputKeyToolsParser, - PydanticToolsParser, - make_invalid_tool_call, - parse_tool_call) +from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, + make_invalid_tool_call, + parse_tool_call, +) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.runnables import (Runnable, RunnableLambda, RunnableMap, - RunnablePassthrough) +from langchain_core.runnables import ( + Runnable, + RunnableLambda, + RunnableMap, + RunnablePassthrough, +) from langchain_core.runnables.config import run_in_executor from langchain_core.tools import BaseTool from langchain_core.tools.base import _stringify from langchain_core.utils import get_pydantic_field_names -from langchain_core.utils.function_calling import (convert_to_openai_function, - convert_to_openai_tool) -from langchain_core.utils.pydantic import (PydanticBaseModel, TypeBaseModel, - 
is_basemodel_subclass) +from langchain_core.utils.function_calling import ( + convert_to_openai_function, + convert_to_openai_tool, +) +from langchain_core.utils.pydantic import ( + PydanticBaseModel, + TypeBaseModel, + is_basemodel_subclass, +) from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env -from langchain_openai.chat_models._client_utils import (_get_default_async_httpx_client, - _get_default_httpx_client) -from langchain_openai.chat_models._compat import (_convert_from_v1_to_chat_completions, - _convert_from_v1_to_responses, - _convert_to_v03_ai_message) from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self +from langchain_openai.chat_models._client_utils import ( + _get_default_async_httpx_client, + _get_default_httpx_client, +) +from langchain_openai.chat_models._compat import ( + _convert_from_v1_to_chat_completions, + _convert_from_v1_to_responses, + _convert_to_v03_ai_message, +) + if TYPE_CHECKING: from openai.types.responses import Response diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index c7c17cde807f3..63bb947ef02c7 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -9,8 +9,10 @@ import openai import tiktoken -from langchain_core.callbacks import (AsyncCallbackManagerForLLMRun, - CallbackManagerForLLMRun) +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain_core.language_models.llms import BaseLLM from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.utils import get_pydantic_field_names From 24d9e1e33ed2817e94d4af0c86c2ff80ff9089d0 Mon Sep 17 00:00:00 2001 From: Marlene <57748216+marlenezw@users.noreply.github.com> Date: Mon, 20 Oct 2025 18:35:54 +0000 Subject: 
[PATCH 06/12] updating code to try make more performant --- libs/partners/openai/langchain_openai/chat_models/base.py | 3 ++- libs/partners/openai/langchain_openai/embeddings/base.py | 3 ++- libs/partners/openai/langchain_openai/llms/base.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 654036c238a90..6dd61f0a7bcc0 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -766,7 +766,8 @@ def validate_environment(self) -> Self: if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() - elif callable(self.openai_api_key): + else: + # Must be a Callable if not SecretStr and not None api_key_value = self.openai_api_key # type: ignore[assignment] client_params: dict = { diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 50e126b40b57f..113ae2512b01e 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -299,7 +299,8 @@ def validate_environment(self) -> Self: if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() - elif callable(self.openai_api_key): + else: + # Must be a Callable if not SecretStr and not None api_key_value = self.openai_api_key # type: ignore[assignment] client_params: dict = { diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index 63bb947ef02c7..ea7c7db1ffb7e 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -282,7 +282,8 @@ def validate_environment(self) -> Self: 
if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() - elif callable(self.openai_api_key): + else: + # Must be a Callable if not SecretStr and not None api_key_value = self.openai_api_key # type: ignore[assignment] client_params: dict = { From 1c6ee6279145433a302e19dbc69cb3ebf977d284 Mon Sep 17 00:00:00 2001 From: Marlene <57748216+marlenezw@users.noreply.github.com> Date: Mon, 20 Oct 2025 18:56:34 +0000 Subject: [PATCH 07/12] reverting to explicit checking of callable type --- libs/partners/openai/langchain_openai/chat_models/base.py | 3 +-- libs/partners/openai/langchain_openai/embeddings/base.py | 3 +-- libs/partners/openai/langchain_openai/llms/base.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 6dd61f0a7bcc0..654036c238a90 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -766,8 +766,7 @@ def validate_environment(self) -> Self: if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() - else: - # Must be a Callable if not SecretStr and not None + elif callable(self.openai_api_key): api_key_value = self.openai_api_key # type: ignore[assignment] client_params: dict = { diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 113ae2512b01e..50e126b40b57f 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -299,8 +299,7 @@ def validate_environment(self) -> Self: if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() - else: - # 
Must be a Callable if not SecretStr and not None + elif callable(self.openai_api_key): api_key_value = self.openai_api_key # type: ignore[assignment] client_params: dict = { diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index ea7c7db1ffb7e..63bb947ef02c7 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -282,8 +282,7 @@ def validate_environment(self) -> Self: if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() - else: - # Must be a Callable if not SecretStr and not None + elif callable(self.openai_api_key): api_key_value = self.openai_api_key # type: ignore[assignment] client_params: dict = { From 109a7e714bfd0dec83008b025a244e6811cc8716 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 20 Oct 2025 15:10:08 -0400 Subject: [PATCH 08/12] minimize changes --- .../langchain_openai/chat_models/base.py | 27 ++++++++++++++----- .../langchain_openai/embeddings/base.py | 5 ++-- .../openai/langchain_openai/llms/base.py | 5 ++-- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 654036c238a90..97d6749fb668a 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -16,7 +16,14 @@ from json import JSONDecodeError from math import ceil from operator import itemgetter -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Literal, + TypeAlias, + TypeVar, + cast, +) from urllib.parse import urlparse import certifi @@ -27,7 +34,10 @@ CallbackManagerForLLMRun, ) from langchain_core.language_models import LanguageModelInput -from langchain_core.language_models.chat_models 
import BaseChatModel, LangSmithParams +from langchain_core.language_models.chat_models import ( + BaseChatModel, + LangSmithParams, +) from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -86,7 +96,13 @@ is_basemodel_subclass, ) from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env -from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + SecretStr, + model_validator, +) from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self @@ -450,8 +466,7 @@ class BaseChatOpenAI(BaseChatModel): model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" openai_api_key: SecretStr | None | Callable[[], str] = Field( - alias="api_key", - default_factory=secret_from_env("OPENAI_API_KEY", default=None), + alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) openai_api_base: str | None = Field(default=None, alias="base_url") """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # noqa: E501 @@ -767,7 +782,7 @@ def validate_environment(self) -> Self: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key # type: ignore[assignment] + api_key_value = self.openai_api_key client_params: dict = { "api_key": api_key_value, diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 50e126b40b57f..e676ceb136b41 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -190,8 +190,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): embedding_ctx_length: int = 8191 """The maximum number of tokens to embed at once.""" 
openai_api_key: SecretStr | None | Callable[[], str] = Field( - alias="api_key", - default_factory=secret_from_env("OPENAI_API_KEY", default=None), + alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" openai_organization: str | None = Field( @@ -300,7 +299,7 @@ def validate_environment(self) -> Self: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key # type: ignore[assignment] + api_key_value = self.openai_api_key client_params: dict = { "api_key": api_key_value, diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index 63bb947ef02c7..f983f3982390d 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -187,8 +187,7 @@ class BaseOpenAI(BaseLLM): model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" openai_api_key: SecretStr | None | Callable[[], str] = Field( - alias="api_key", - default_factory=secret_from_env("OPENAI_API_KEY", default=None), + alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" openai_api_base: str | None = Field( @@ -283,7 +282,7 @@ def validate_environment(self) -> Self: if isinstance(self.openai_api_key, SecretStr): api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key # type: ignore[assignment] + api_key_value = self.openai_api_key client_params: dict = { "api_key": api_key_value, From f13e521dec9ea429cfb61759c446203ca81422fb Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Mon, 20 Oct 2025 16:55:16 -0400 Subject: 
[PATCH 09/12] start on async support --- .../langchain_openai/chat_models/base.py | 94 ++++++++++++++----- 1 file changed, 70 insertions(+), 24 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 97d6749fb668a..ed1812c53436d 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -2,7 +2,9 @@ from __future__ import annotations +import asyncio import base64 +import inspect import json import logging import os @@ -10,7 +12,14 @@ import ssl import sys import warnings -from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence +from collections.abc import ( + AsyncIterator, + Awaitable, + Callable, + Iterator, + Mapping, + Sequence, +) from functools import partial from io import BytesIO from json import JSONDecodeError @@ -465,7 +474,9 @@ class BaseChatOpenAI(BaseChatModel): """What sampling temperature to use.""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - openai_api_key: SecretStr | None | Callable[[], str] = Field( + openai_api_key: ( + SecretStr | None | Callable[[], str] | Callable[[], Awaitable[str]] + ) = Field( alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) openai_api_base: str | None = Field(default=None, alias="base_url") @@ -777,15 +788,28 @@ def validate_environment(self) -> Self: self.stream_usage = True # Resolve API key from SecretStr or Callable - api_key_value: str | Callable[[], str] | None = None + sync_api_key_value: str | Callable[[], str] | None = None + async_api_key_value: str | Callable[[], Awaitable[str]] | None = None + if self.openai_api_key is not None: if isinstance(self.openai_api_key, SecretStr): - api_key_value = self.openai_api_key.get_secret_value() + sync_api_key_value = 
self.openai_api_key.get_secret_value() + async_api_key_value = self.openai_api_key.get_secret_value() elif callable(self.openai_api_key): - api_key_value = self.openai_api_key + if inspect.iscoroutinefunction(self.openai_api_key): + async_api_key_value = self.openai_api_key + sync_api_key_value = None + else: + sync_api_key_value = cast(Callable, self.openai_api_key) + + async def async_api_key_wrapper() -> str: + return await asyncio.get_event_loop().run_in_executor( + None, cast(Callable, self.openai_api_key) + ) + + async_api_key_value = async_api_key_wrapper client_params: dict = { - "api_key": api_key_value, "organization": self.openai_organization, "base_url": self.openai_api_base, "timeout": self.request_timeout, @@ -806,24 +830,32 @@ def validate_environment(self) -> Self: ) raise ValueError(msg) if not self.client: - if self.openai_proxy and not self.http_client: - try: - import httpx - except ImportError as e: - msg = ( - "Could not import httpx python package. " - "Please install it with `pip install httpx`." + if sync_api_key_value is None: + # No valid sync API key, leave client as None + self.client = None + self.root_client = None + else: + if self.openai_proxy and not self.http_client: + try: + import httpx + except ImportError as e: + msg = ( + "Could not import httpx python package. " + "Please install it with `pip install httpx`." 
+ ) + raise ImportError(msg) from e + self.http_client = httpx.Client( + proxy=self.openai_proxy, verify=global_ssl_context ) - raise ImportError(msg) from e - self.http_client = httpx.Client( - proxy=self.openai_proxy, verify=global_ssl_context - ) - sync_specific = { - "http_client": self.http_client - or _get_default_httpx_client(self.openai_api_base, self.request_timeout) - } - self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type] - self.client = self.root_client.chat.completions + sync_specific = { + "http_client": self.http_client + or _get_default_httpx_client( + self.openai_api_base, self.request_timeout + ), + "api_key": sync_api_key_value, + } + self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type] + self.client = self.root_client.chat.completions if not self.async_client: if self.openai_proxy and not self.http_async_client: try: @@ -841,7 +873,8 @@ def validate_environment(self) -> Self: "http_client": self.http_async_client or _get_default_async_httpx_client( self.openai_api_base, self.request_timeout - ) + ), + "api_key": async_api_key_value, } self.root_async_client = openai.AsyncOpenAI( **client_params, @@ -971,6 +1004,16 @@ def _convert_chunk_to_generation_chunk( message=message_chunk, generation_info=generation_info or None ) + def _ensure_sync_client_available(self) -> None: + """Check that sync client is available, raise error if not.""" + if self.client is None: + msg = ( + "Sync client is not available. This happens when an async callable " + "was provided for the API key. Use async methods (ainvoke, astream) " + "instead." 
+ ) + raise ValueError(msg) + def _stream_responses( self, messages: list[BaseMessage], @@ -978,6 +1021,7 @@ def _stream_responses( run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + self._ensure_sync_client_available() kwargs["stream"] = True payload = self._get_request_payload(messages, stop=stop, **kwargs) if self.include_response_headers: @@ -1107,6 +1151,7 @@ def _stream( stream_usage: bool | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + self._ensure_sync_client_available() kwargs["stream"] = True stream_usage = self._should_stream_usage(stream_usage, **kwargs) if stream_usage: @@ -1175,6 +1220,7 @@ def _generate( run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: + self._ensure_sync_client_available() payload = self._get_request_payload(messages, stop=stop, **kwargs) generation_info = None raw_response = None From 87aba18383a5ee6a57ab330b35b7152e349a26dc Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Tue, 21 Oct 2025 10:28:59 -0400 Subject: [PATCH 10/12] nits --- .../chat_models/_client_utils.py | 35 ++++++++- .../langchain_openai/chat_models/base.py | 75 ++++++++++++++----- 2 files changed, 89 insertions(+), 21 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py index 3eba5c7309be5..4a0efce9e1bfe 100644 --- a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py +++ b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py @@ -9,11 +9,14 @@ from __future__ import annotations import asyncio +import inspect import os +from collections.abc import Awaitable, Callable from functools import lru_cache -from typing import Any +from typing import Any, cast import openai +from pydantic import SecretStr class _SyncHttpxClientWrapper(openai.DefaultHttpxClient): @@ -107,3 +110,33 @@ def _get_default_async_httpx_client( return 
_build_async_httpx_client(base_url, timeout) else: return _cached_async_httpx_client(base_url, timeout) + + +def _resolve_sync_and_async_api_keys( +    api_key: SecretStr | Callable[[], str] | Callable[[], Awaitable[str]], +) -> tuple[str | None | Callable[[], str], str | Callable[[], Awaitable[str]]]: +    """Resolve sync and async API key values. + +    Because OpenAI and AsyncOpenAI clients support either sync or async callables for +    the API key, we need to resolve separate values here. +    """ +    if isinstance(api_key, SecretStr): +        sync_api_key_value: str | None | Callable[[], str] = api_key.get_secret_value() +        async_api_key_value: str | Callable[[], Awaitable[str]] = ( +            api_key.get_secret_value() +        ) +    elif callable(api_key): +        if inspect.iscoroutinefunction(api_key): +            async_api_key_value = api_key +            sync_api_key_value = None +        else: +            sync_api_key_value = cast(Callable, api_key) + +            async def async_api_key_wrapper() -> str: +                return await asyncio.get_running_loop().run_in_executor( +                    None, cast(Callable, api_key) +                ) + +            async_api_key_value = async_api_key_wrapper + +    return sync_api_key_value, async_api_key_value diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index ed1812c53436d..8303242ebbc8c 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -2,9 +2,7 @@ from __future__ import annotations -import asyncio import base64 -import inspect import json import logging import os @@ -118,6 +116,7 @@ from langchain_openai.chat_models._client_utils import ( _get_default_async_httpx_client, _get_default_httpx_client, +    _resolve_sync_and_async_api_keys, ) from langchain_openai.chat_models._compat import ( _convert_from_v1_to_chat_completions, @@ -479,6 +478,52 @@ class BaseChatOpenAI(BaseChatModel): openai_api_key: ( SecretStr | None | Callable[[], str] | Callable[[], Awaitable[str]] ) = Field( alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) +    """API key to 
use. + + Can be inferred from the `OPENAI_API_KEY` environment variable, or specified as a + string, or sync or async callable that returns a string. + + ??? example "Specify with environment variable" + + ```bash + export OPENAI_API_KEY=... + ``` + ```python + from langchain_openai import ChatOpenAI + + model = ChatOpenAI(model="gpt-5-nano") + ``` + + ??? example "Specify with a string" + + ```python + from langchain_openai import ChatOpenAI + + model = ChatOpenAI(model="gpt-5-nano", api_key="...") + ``` + + ??? example "Specify with a sync callable" + ```python + from langchain_openai import ChatOpenAI + + def get_api_key() -> str: + # Custom logic to retrieve API key + return "..." + + model = ChatOpenAI(model="gpt-5-nano", api_key=get_api_key) + ``` + + ??? example "Specify with an async callable" + ```python + from langchain_openai import ChatOpenAI + + async def get_api_key() -> str: + # Custom async logic to retrieve API key + return "..." + + model = ChatOpenAI(model="gpt-5-nano", api_key=get_api_key) + ``` + """ openai_api_base: str | None = Field(default=None, alias="base_url") """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # noqa: E501 openai_organization: str | None = Field(default=None, alias="organization") @@ -792,22 +837,11 @@ def validate_environment(self) -> Self: async_api_key_value: str | Callable[[], Awaitable[str]] | None = None if self.openai_api_key is not None: - if isinstance(self.openai_api_key, SecretStr): - sync_api_key_value = self.openai_api_key.get_secret_value() - async_api_key_value = self.openai_api_key.get_secret_value() - elif callable(self.openai_api_key): - if inspect.iscoroutinefunction(self.openai_api_key): - async_api_key_value = self.openai_api_key - sync_api_key_value = None - else: - sync_api_key_value = cast(Callable, self.openai_api_key) - - async def async_api_key_wrapper() -> str: - return await asyncio.get_event_loop().run_in_executor( - None, cast(Callable, 
self.openai_api_key) - ) - - async_api_key_value = async_api_key_wrapper + # Because OpenAI and AsyncOpenAI clients support either sync or async + # callables for the API key, we need to resolve separate values here. + sync_api_key_value, async_api_key_value = _resolve_sync_and_async_api_keys( + self.openai_api_key + ) client_params: dict = { "organization": self.openai_organization, @@ -831,7 +865,8 @@ async def async_api_key_wrapper() -> str: raise ValueError(msg) if not self.client: if sync_api_key_value is None: - # No valid sync API key, leave client as None + # No valid sync API key, leave client as None and raise informative + # error on invocation. self.client = None self.root_client = None else: @@ -1010,7 +1045,7 @@ def _ensure_sync_client_available(self) -> None: msg = ( "Sync client is not available. This happens when an async callable " "was provided for the API key. Use async methods (ainvoke, astream) " - "instead." + "instead, or provide a string or sync callable for the API key." 
) raise ValueError(msg) From 4b07dca2461ad6572cb0175053455e0facc2a8f7 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Tue, 21 Oct 2025 11:03:54 -0400 Subject: [PATCH 11/12] add test --- .../chat_models/test_base.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 945e83124d324..6c0193de58f94 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -2,6 +2,7 @@ import base64 import json +import os from collections.abc import AsyncIterator from pathlib import Path from textwrap import dedent @@ -64,6 +65,57 @@ def test_chat_openai_model() -> None: assert chat.model_name == "bar" +def test_callable_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + original_key = os.environ["OPENAI_API_KEY"] + + calls = {"sync": 0} + + def get_openai_api_key() -> str: + calls["sync"] += 1 + return original_key + + monkeypatch.delenv("OPENAI_API_KEY") + + model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_openai_api_key) + response = model.invoke("hello") + assert isinstance(response, AIMessage) + assert calls["sync"] == 1 + + +async def test_callable_api_key_async(monkeypatch: pytest.MonkeyPatch) -> None: + original_key = os.environ["OPENAI_API_KEY"] + + calls = {"sync": 0, "async": 0} + + def get_openai_api_key() -> str: + calls["sync"] += 1 + return original_key + + async def get_openai_api_key_async() -> str: + calls["async"] += 1 + return original_key + + monkeypatch.delenv("OPENAI_API_KEY") + + model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_openai_api_key) + response = model.invoke("hello") + assert isinstance(response, AIMessage) + assert calls["sync"] == 1 + + response = await model.ainvoke("hello") + assert isinstance(response, AIMessage) + assert calls["sync"] == 2 + + model = 
ChatOpenAI(model="gpt-4.1-mini", api_key=get_openai_api_key_async) + async_response = await model.ainvoke("hello") + assert isinstance(async_response, AIMessage) + assert calls["async"] == 1 + + with pytest.raises(ValueError): + # We do not create a sync callable from an async one + _ = model.invoke("hello") + + @pytest.mark.parametrize("use_responses_api", [False, True]) def test_chat_openai_system_message(use_responses_api: bool) -> None: """Test ChatOpenAI wrapper with system message.""" From 954f29498261d5a4a9af96d83fb487104a15777a Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Tue, 21 Oct 2025 11:04:31 -0400 Subject: [PATCH 12/12] fix async case for embeddings --- .../langchain_openai/embeddings/base.py | 71 +++++++++++++------ .../integration_tests/embeddings/test_base.py | 56 +++++++++++++++ 2 files changed, 106 insertions(+), 21 deletions(-) diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index e676ceb136b41..f53640b02ed40 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -4,7 +4,7 @@ import logging import warnings -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence from typing import Any, Literal, cast import openai @@ -15,6 +15,8 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self +from langchain_openai.chat_models._client_utils import _resolve_sync_and_async_api_keys + logger = logging.getLogger(__name__) @@ -189,7 +191,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) embedding_ctx_length: int = 8191 """The maximum number of tokens to embed at once.""" - openai_api_key: SecretStr | None | Callable[[], str] = Field( + openai_api_key: ( + SecretStr | None | Callable[[], str] | Callable[[], Awaitable[str]] + ) = Field( 
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" @@ -294,15 +298,17 @@ def validate_environment(self) -> Self: raise ValueError(msg) # Resolve API key from SecretStr or Callable - api_key_value: str | Callable[[], str] | None = None + sync_api_key_value: str | Callable[[], str] | None = None + async_api_key_value: str | Callable[[], Awaitable[str]] | None = None + if self.openai_api_key is not None: - if isinstance(self.openai_api_key, SecretStr): - api_key_value = self.openai_api_key.get_secret_value() - elif callable(self.openai_api_key): - api_key_value = self.openai_api_key + # Because OpenAI and AsyncOpenAI clients support either sync or async + # callables for the API key, we need to resolve separate values here. + sync_api_key_value, async_api_key_value = _resolve_sync_and_async_api_keys( + self.openai_api_key + ) client_params: dict = { - "api_key": api_key_value, "organization": self.openai_organization, "base_url": self.openai_api_base, "timeout": self.request_timeout, @@ -322,18 +328,26 @@ def validate_environment(self) -> Self: ) raise ValueError(msg) if not self.client: - if self.openai_proxy and not self.http_client: - try: - import httpx - except ImportError as e: - msg = ( - "Could not import httpx python package. " - "Please install it with `pip install httpx`." - ) - raise ImportError(msg) from e - self.http_client = httpx.Client(proxy=self.openai_proxy) - sync_specific = {"http_client": self.http_client} - self.client = openai.OpenAI(**client_params, **sync_specific).embeddings # type: ignore[arg-type] + if sync_api_key_value is None: + # No valid sync API key, leave client as None and raise informative + # error on invocation. + self.client = None + else: + if self.openai_proxy and not self.http_client: + try: + import httpx + except ImportError as e: + msg = ( + "Could not import httpx python package. 
" + "Please install it with `pip install httpx`." + ) + raise ImportError(msg) from e + self.http_client = httpx.Client(proxy=self.openai_proxy) + sync_specific = { + "http_client": self.http_client, + "api_key": sync_api_key_value, + } + self.client = openai.OpenAI(**client_params, **sync_specific).embeddings # type: ignore[arg-type] if not self.async_client: if self.openai_proxy and not self.http_async_client: try: @@ -345,7 +359,10 @@ def validate_environment(self) -> Self: ) raise ImportError(msg) from e self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy) - async_specific = {"http_client": self.http_async_client} + async_specific = { + "http_client": self.http_async_client, + "api_key": async_api_key_value, + } self.async_client = openai.AsyncOpenAI( **client_params, **async_specific, # type: ignore[arg-type] @@ -359,6 +376,16 @@ def _invocation_params(self) -> dict[str, Any]: params["dimensions"] = self.dimensions return params + def _ensure_sync_client_available(self) -> None: + """Check that sync client is available, raise error if not.""" + if self.client is None: + msg = ( + "Sync client is not available. This happens when an async callable " + "was provided for the API key. Use async methods (ainvoke, astream) " + "instead, or provide a string or sync callable for the API key." + ) + raise ValueError(msg) + def _tokenize( self, texts: list[str], chunk_size: int ) -> tuple[Iterable[int], list[list[int] | str], list[int]]: @@ -578,6 +605,7 @@ def embed_documents( Returns: List of embeddings, one for each text. """ + self._ensure_sync_client_available() chunk_size_ = chunk_size or self.chunk_size client_kwargs = {**self._invocation_params, **kwargs} if not self.check_embedding_ctx_length: @@ -642,6 +670,7 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: Returns: Embedding for the text. 
""" + self._ensure_sync_client_available() return self.embed_documents([text], **kwargs)[0] async def aembed_query(self, text: str, **kwargs: Any) -> list[float]: diff --git a/libs/partners/openai/tests/integration_tests/embeddings/test_base.py b/libs/partners/openai/tests/integration_tests/embeddings/test_base.py index 321edcfc0fb82..95a0385945581 100644 --- a/libs/partners/openai/tests/integration_tests/embeddings/test_base.py +++ b/libs/partners/openai/tests/integration_tests/embeddings/test_base.py @@ -1,7 +1,10 @@ """Test OpenAI embeddings.""" +import os + import numpy as np import openai +import pytest from langchain_openai.embeddings.base import OpenAIEmbeddings @@ -67,3 +70,56 @@ def test_langchain_openai_embeddings_dimensions_large_num() -> None: output = embedding.embed_documents(documents) assert len(output) == 2000 assert len(output[0]) == 128 + + +def test_callable_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + original_key = os.environ["OPENAI_API_KEY"] + + calls = {"sync": 0} + + def get_openai_api_key() -> str: + calls["sync"] += 1 + return original_key + + monkeypatch.delenv("OPENAI_API_KEY") + + model = OpenAIEmbeddings( + model="text-embedding-3-small", dimensions=128, api_key=get_openai_api_key + ) + _ = model.embed_query("hello") + assert calls["sync"] == 1 + + +async def test_callable_api_key_async(monkeypatch: pytest.MonkeyPatch) -> None: + original_key = os.environ["OPENAI_API_KEY"] + + calls = {"sync": 0, "async": 0} + + def get_openai_api_key() -> str: + calls["sync"] += 1 + return original_key + + async def get_openai_api_key_async() -> str: + calls["async"] += 1 + return original_key + + monkeypatch.delenv("OPENAI_API_KEY") + + model = OpenAIEmbeddings( + model="text-embedding-3-small", dimensions=128, api_key=get_openai_api_key + ) + _ = model.embed_query("hello") + assert calls["sync"] == 1 + + _ = await model.aembed_query("hello") + assert calls["sync"] == 2 + + model = OpenAIEmbeddings( + model="text-embedding-3-small", 
dimensions=128, api_key=get_openai_api_key_async + ) + _ = await model.aembed_query("hello") + assert calls["async"] == 1 + + with pytest.raises(ValueError): + # We do not create a sync callable from an async one + _ = model.embed_query("hello")