diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index a3eee17aa5899..ed1812c53436d 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
+import asyncio
 import base64
+import inspect
 import json
 import logging
 import os
@@ -10,7 +12,14 @@
 import ssl
 import sys
 import warnings
-from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
+from collections.abc import (
+    AsyncIterator,
+    Awaitable,
+    Callable,
+    Iterator,
+    Mapping,
+    Sequence,
+)
 from functools import partial
 from io import BytesIO
 from json import JSONDecodeError
@@ -465,7 +474,9 @@ class BaseChatOpenAI(BaseChatModel):
     """What sampling temperature to use."""
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
-    openai_api_key: SecretStr | None = Field(
+    openai_api_key: (
+        SecretStr | None | Callable[[], str] | Callable[[], Awaitable[str]]
+    ) = Field(
         alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
     )
     openai_api_base: str | None = Field(default=None, alias="base_url")
@@ -776,10 +787,29 @@ def validate_environment(self) -> Self:
         ):
             self.stream_usage = True
 
+        # Resolve API key from SecretStr or Callable
+        sync_api_key_value: str | Callable[[], str] | None = None
+        async_api_key_value: str | Callable[[], Awaitable[str]] | None = None
+
+        if self.openai_api_key is not None:
+            if isinstance(self.openai_api_key, SecretStr):
+                sync_api_key_value = self.openai_api_key.get_secret_value()
+                async_api_key_value = self.openai_api_key.get_secret_value()
+            elif callable(self.openai_api_key):
+                if inspect.iscoroutinefunction(self.openai_api_key):
+                    async_api_key_value = self.openai_api_key
+                    sync_api_key_value = None
+                else:
+                    sync_api_key_value = cast(Callable, self.openai_api_key)
+
+                    async def async_api_key_wrapper() -> str:
+                        return await asyncio.get_event_loop().run_in_executor(
+                            None, cast(Callable, self.openai_api_key)
+                        )
+
+                    async_api_key_value = async_api_key_wrapper
+
         client_params: dict = {
-            "api_key": (
-                self.openai_api_key.get_secret_value() if self.openai_api_key else None
-            ),
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
@@ -800,24 +830,32 @@ def validate_environment(self) -> Self:
             )
             raise ValueError(msg)
         if not self.client:
-            if self.openai_proxy and not self.http_client:
-                try:
-                    import httpx
-                except ImportError as e:
-                    msg = (
-                        "Could not import httpx python package. "
-                        "Please install it with `pip install httpx`."
+            if sync_api_key_value is None:
+                # No valid sync API key, leave client as None
+                self.client = None
+                self.root_client = None
+            else:
+                if self.openai_proxy and not self.http_client:
+                    try:
+                        import httpx
+                    except ImportError as e:
+                        msg = (
+                            "Could not import httpx python package. "
+                            "Please install it with `pip install httpx`."
+                        )
+                        raise ImportError(msg) from e
+                    self.http_client = httpx.Client(
+                        proxy=self.openai_proxy, verify=global_ssl_context
                     )
-                    raise ImportError(msg) from e
-                self.http_client = httpx.Client(
-                    proxy=self.openai_proxy, verify=global_ssl_context
-                )
-            sync_specific = {
-                "http_client": self.http_client
-                or _get_default_httpx_client(self.openai_api_base, self.request_timeout)
-            }
-            self.root_client = openai.OpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
-            self.client = self.root_client.chat.completions
+                sync_specific = {
+                    "http_client": self.http_client
+                    or _get_default_httpx_client(
+                        self.openai_api_base, self.request_timeout
+                    ),
+                    "api_key": sync_api_key_value,
+                }
+                self.root_client = openai.OpenAI(**client_params, **sync_specific)  # type: ignore[arg-type]
+                self.client = self.root_client.chat.completions
         if not self.async_client:
             if self.openai_proxy and not self.http_async_client:
                 try:
@@ -835,7 +873,8 @@ def validate_environment(self) -> Self:
                 "http_client": self.http_async_client
                 or _get_default_async_httpx_client(
                     self.openai_api_base, self.request_timeout
-                )
+                ),
+                "api_key": async_api_key_value,
             }
             self.root_async_client = openai.AsyncOpenAI(
                 **client_params,
@@ -965,6 +1004,16 @@ def _convert_chunk_to_generation_chunk(
             message=message_chunk, generation_info=generation_info or None
         )
 
+    def _ensure_sync_client_available(self) -> None:
+        """Check that sync client is available, raise error if not."""
+        if self.client is None:
+            msg = (
+                "Sync client is not available. This happens when an async callable "
+                "was provided for the API key. Use async methods (ainvoke, astream) "
+                "instead."
+            )
+            raise ValueError(msg)
+
     def _stream_responses(
         self,
         messages: list[BaseMessage],
@@ -972,6 +1021,7 @@
         run_manager: CallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+        self._ensure_sync_client_available()
        kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         if self.include_response_headers:
@@ -1101,6 +1151,7 @@ def _stream(
         stream_usage: bool | None = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+        self._ensure_sync_client_available()
         kwargs["stream"] = True
         stream_usage = self._should_stream_usage(stream_usage, **kwargs)
         if stream_usage:
@@ -1169,6 +1220,7 @@ def _generate(
         run_manager: CallbackManagerForLLMRun | None = None,
         **kwargs: Any,
     ) -> ChatResult:
+        self._ensure_sync_client_available()
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         generation_info = None
         raw_response = None
diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py
index b0b5c9f770d75..e676ceb136b41 100644
--- a/libs/partners/openai/langchain_openai/embeddings/base.py
+++ b/libs/partners/openai/langchain_openai/embeddings/base.py
@@ -4,7 +4,7 @@
 
 import logging
 import warnings
-from collections.abc import Iterable, Mapping, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from typing import Any, Literal, cast
 
 import openai
@@ -189,7 +189,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     )
     embedding_ctx_length: int = 8191
     """The maximum number of tokens to embed at once."""
-    openai_api_key: SecretStr | None = Field(
+    openai_api_key: SecretStr | None | Callable[[], str] = Field(
         alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
     )
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
@@ -292,10 +292,17 @@ def validate_environment(self) -> Self:
                 "If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
             )
             raise ValueError(msg)
+
+        # Resolve API key from SecretStr or Callable
+        api_key_value: str | Callable[[], str] | None = None
+        if self.openai_api_key is not None:
+            if isinstance(self.openai_api_key, SecretStr):
+                api_key_value = self.openai_api_key.get_secret_value()
+            elif callable(self.openai_api_key):
+                api_key_value = self.openai_api_key
+
         client_params: dict = {
-            "api_key": (
-                self.openai_api_key.get_secret_value() if self.openai_api_key else None
-            ),
+            "api_key": api_key_value,
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py
index ed456ec89df39..f983f3982390d 100644
--- a/libs/partners/openai/langchain_openai/llms/base.py
+++ b/libs/partners/openai/langchain_openai/llms/base.py
@@ -4,7 +4,7 @@
 
 import logging
 import sys
-from collections.abc import AsyncIterator, Collection, Iterator, Mapping
+from collections.abc import AsyncIterator, Callable, Collection, Iterator, Mapping
 from typing import Any, Literal
 
 import openai
@@ -186,7 +186,7 @@ class BaseOpenAI(BaseLLM):
     """Generates best_of completions server-side and returns the "best"."""
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
-    openai_api_key: SecretStr | None = Field(
+    openai_api_key: SecretStr | None | Callable[[], str] = Field(
         alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
     )
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
@@ -276,10 +276,16 @@ def validate_environment(self) -> Self:
             msg = "Cannot stream results when best_of > 1."
             raise ValueError(msg)
 
+        # Resolve API key from SecretStr or Callable
+        api_key_value: str | Callable[[], str] | None = None
+        if self.openai_api_key is not None:
+            if isinstance(self.openai_api_key, SecretStr):
+                api_key_value = self.openai_api_key.get_secret_value()
+            elif callable(self.openai_api_key):
+                api_key_value = self.openai_api_key
+
         client_params: dict = {
-            "api_key": (
-                self.openai_api_key.get_secret_value() if self.openai_api_key else None
-            ),
+            "api_key": api_key_value,
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
diff --git a/libs/partners/openai/tests/unit_tests/test_secrets.py b/libs/partners/openai/tests/unit_tests/test_secrets.py
index aa1484058e0e2..27d69bed92ce1 100644
--- a/libs/partners/openai/tests/unit_tests/test_secrets.py
+++ b/libs/partners/openai/tests/unit_tests/test_secrets.py
@@ -187,6 +187,18 @@ def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> No
     assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"
 
 
+@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
+def test_openai_api_key_accepts_callable(model_class: type) -> None:
+    """Test that the API key can be passed as a callable."""
+
+    def get_api_key() -> str:
+        return "secret-api-key-from-callable"
+
+    model = model_class(openai_api_key=get_api_key)
+    assert callable(model.openai_api_key)
+    assert model.openai_api_key() == "secret-api-key-from-callable"
+
+
 @pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
 def test_azure_serialized_secrets(model_class: type) -> None:
     """Test that the actual secret value is correctly retrieved."""
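
For reviewers, a minimal usage sketch of the behavior this diff adds. This is a hedged illustration, not documentation: it assumes this branch is installed, and the model name and key strings are placeholders.

```python
from langchain_openai import ChatOpenAI

# Sync callable: resolved when validate_environment builds the clients.
def get_api_key() -> str:
    # Placeholder: e.g. fetch a short-lived key from a secrets manager.
    return "sk-placeholder"

llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=get_api_key)
# llm.invoke("hello")  # sync and async entry points both work

# Async callable: only the async client is constructed, so per
# _ensure_sync_client_available the sync entry points (invoke, stream)
# raise ValueError; use ainvoke/astream instead.
async def aget_api_key() -> str:
    return "sk-placeholder"

async_llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=aget_api_key)
# await async_llm.ainvoke("hello")
```

Note the asymmetry in the field annotations above: `BaseChatOpenAI` accepts `Callable[[], str] | Callable[[], Awaitable[str]]`, while `BaseOpenAI` and `OpenAIEmbeddings` accept only the sync `Callable[[], str]`.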