Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 15 additions & 24 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,7 @@
from json import JSONDecodeError
from math import ceil
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypeAlias,
TypeVar,
cast,
)
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar, cast
from urllib.parse import urlparse

import certifi
Expand All @@ -34,10 +27,7 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
)
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand Down Expand Up @@ -96,13 +86,7 @@
is_basemodel_subclass,
)
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self

Expand Down Expand Up @@ -465,8 +449,9 @@ class BaseChatOpenAI(BaseChatModel):
"""What sampling temperature to use."""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: SecretStr | None = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
openai_api_key: SecretStr | None | Callable[[], str] = Field(
alias="api_key",
default_factory=secret_from_env("OPENAI_API_KEY", default=None),
)
openai_api_base: str | None = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # noqa: E501
Expand Down Expand Up @@ -776,10 +761,16 @@ def validate_environment(self) -> Self:
):
self.stream_usage = True

# Resolve API key from SecretStr or Callable
api_key_value = None
if self.openai_api_key is not None:
if isinstance(self.openai_api_key, SecretStr):
api_key_value = self.openai_api_key.get_secret_value()
elif callable(self.openai_api_key):
api_key_value = self.openai_api_key()

client_params: dict = {
"api_key": (
self.openai_api_key.get_secret_value() if self.openai_api_key else None
),
"api_key": api_key_value,
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
Expand Down
20 changes: 14 additions & 6 deletions libs/partners/openai/langchain_openai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
import warnings
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any, Literal, cast

import openai
Expand Down Expand Up @@ -189,8 +189,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
)
embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: SecretStr | None = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
openai_api_key: SecretStr | None | Callable[[], str] = Field(
alias="api_key",
default_factory=secret_from_env("OPENAI_API_KEY", default=None),
)
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_organization: str | None = Field(
Expand Down Expand Up @@ -292,10 +293,17 @@ def validate_environment(self) -> Self:
"If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
)
raise ValueError(msg)

# Resolve API key from SecretStr or Callable
api_key_value = None
if self.openai_api_key is not None:
if isinstance(self.openai_api_key, SecretStr):
api_key_value = self.openai_api_key.get_secret_value()
elif callable(self.openai_api_key):
api_key_value = self.openai_api_key()

client_params: dict = {
"api_key": (
self.openai_api_key.get_secret_value() if self.openai_api_key else None
),
"api_key": api_key_value,
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
Expand Down
19 changes: 13 additions & 6 deletions libs/partners/openai/langchain_openai/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
import sys
from collections.abc import AsyncIterator, Collection, Iterator, Mapping
from collections.abc import AsyncIterator, Callable, Collection, Iterator, Mapping
from typing import Any, Literal

import openai
Expand Down Expand Up @@ -186,8 +186,9 @@ class BaseOpenAI(BaseLLM):
"""Generates best_of completions server-side and returns the "best"."""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: SecretStr | None = Field(
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
openai_api_key: SecretStr | None | Callable[[], str] = Field(
alias="api_key",
default_factory=secret_from_env("OPENAI_API_KEY", default=None),
)
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: str | None = Field(
Expand Down Expand Up @@ -276,10 +277,16 @@ def validate_environment(self) -> Self:
msg = "Cannot stream results when best_of > 1."
raise ValueError(msg)

# Resolve API key from SecretStr or Callable
api_key_value = None
if self.openai_api_key is not None:
if isinstance(self.openai_api_key, SecretStr):
api_key_value = self.openai_api_key.get_secret_value()
elif callable(self.openai_api_key):
api_key_value = self.openai_api_key()

client_params: dict = {
"api_key": (
self.openai_api_key.get_secret_value() if self.openai_api_key else None
),
"api_key": api_key_value,
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
Expand Down
12 changes: 12 additions & 0 deletions libs/partners/openai/tests/unit_tests/test_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> No
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"


@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_api_key_accepts_callable(model_class: type) -> None:
"""Test that the API key can be passed as a callable."""

def get_api_key() -> str:
return "secret-api-key-from-callable"

model = model_class(openai_api_key=get_api_key)
assert callable(model.openai_api_key)
assert model.openai_api_key() == "secret-api-key-from-callable"


@pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
def test_azure_serialized_secrets(model_class: type) -> None:
"""Test that the actual secret value is correctly retrieved."""
Expand Down