Skip to content

Commit 72e2ff1

Browse files
committed
updating chat, embeddings, llms files and adding test
1 parent 0c8cbfb commit 72e2ff1

File tree

5 files changed

+68
-49
lines changed

5 files changed

+68
-49
lines changed

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,7 @@
1616
from json import JSONDecodeError
1717
from math import ceil
1818
from operator import itemgetter
19-
from typing import (
20-
TYPE_CHECKING,
21-
Any,
22-
Literal,
23-
TypeAlias,
24-
TypeVar,
25-
cast,
26-
)
19+
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar, cast
2720
from urllib.parse import urlparse
2821

2922
import certifi
@@ -34,10 +27,7 @@
3427
CallbackManagerForLLMRun,
3528
)
3629
from langchain_core.language_models import LanguageModelInput
37-
from langchain_core.language_models.chat_models import (
38-
BaseChatModel,
39-
LangSmithParams,
40-
)
30+
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
4131
from langchain_core.messages import (
4232
AIMessage,
4333
AIMessageChunk,
@@ -55,9 +45,9 @@
5545
ToolCall,
5646
ToolMessage,
5747
ToolMessageChunk,
58-
is_data_content_block,
5948
)
6049
from langchain_core.messages import content as types
50+
from langchain_core.messages import is_data_content_block
6151
from langchain_core.messages.ai import (
6252
InputTokenDetails,
6353
OutputTokenDetails,
@@ -96,16 +86,6 @@
9686
is_basemodel_subclass,
9787
)
9888
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
99-
from pydantic import (
100-
BaseModel,
101-
ConfigDict,
102-
Field,
103-
SecretStr,
104-
model_validator,
105-
)
106-
from pydantic.v1 import BaseModel as BaseModelV1
107-
from typing_extensions import Self
108-
10989
from langchain_openai.chat_models._client_utils import (
11090
_get_default_async_httpx_client,
11191
_get_default_httpx_client,
@@ -115,6 +95,9 @@
11595
_convert_from_v1_to_responses,
11696
_convert_to_v03_ai_message,
11797
)
98+
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
99+
from pydantic.v1 import BaseModel as BaseModelV1
100+
from typing_extensions import Self
118101

119102
if TYPE_CHECKING:
120103
from openai.types.responses import Response
@@ -465,8 +448,11 @@ class BaseChatOpenAI(BaseChatModel):
465448
"""What sampling temperature to use."""
466449
model_kwargs: dict[str, Any] = Field(default_factory=dict)
467450
"""Holds any model parameters valid for `create` call not explicitly specified."""
468-
openai_api_key: SecretStr | None = Field(
469-
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
451+
openai_api_key: SecretStr | None | Callable[[], str] = Field(
452+
alias="api_key",
453+
default_factory=secret_from_env(
454+
["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None
455+
),
470456
)
471457
openai_api_base: str | None = Field(default=None, alias="base_url")
472458
"""Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # noqa: E501
@@ -776,10 +762,16 @@ def validate_environment(self) -> Self:
776762
):
777763
self.stream_usage = True
778764

765+
# Resolve API key from SecretStr or Callable
766+
api_key_value = None
767+
if self.openai_api_key is not None:
768+
if isinstance(self.openai_api_key, SecretStr):
769+
api_key_value = self.openai_api_key.get_secret_value()
770+
elif callable(self.openai_api_key):
771+
api_key_value = self.openai_api_key()
772+
779773
client_params: dict = {
780-
"api_key": (
781-
self.openai_api_key.get_secret_value() if self.openai_api_key else None
782-
),
774+
"api_key": api_key_value,
783775
"organization": self.openai_organization,
784776
"base_url": self.openai_api_base,
785777
"timeout": self.request_timeout,

libs/partners/openai/langchain_openai/embeddings/base.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
import warnings
7-
from collections.abc import Iterable, Mapping, Sequence
7+
from collections.abc import Callable, Iterable, Mapping, Sequence
88
from typing import Any, Literal, cast
99

1010
import openai
@@ -189,8 +189,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
189189
)
190190
embedding_ctx_length: int = 8191
191191
"""The maximum number of tokens to embed at once."""
192-
openai_api_key: SecretStr | None = Field(
193-
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
192+
openai_api_key: SecretStr | None | Callable[[], str] = Field(
193+
alias="api_key",
194+
default_factory=secret_from_env(
195+
["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None
196+
),
194197
)
195198
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
196199
openai_organization: str | None = Field(
@@ -292,10 +295,17 @@ def validate_environment(self) -> Self:
292295
"If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
293296
)
294297
raise ValueError(msg)
298+
299+
# Resolve API key from SecretStr or Callable
300+
api_key_value = None
301+
if self.openai_api_key is not None:
302+
if isinstance(self.openai_api_key, SecretStr):
303+
api_key_value = self.openai_api_key.get_secret_value()
304+
elif callable(self.openai_api_key):
305+
api_key_value = self.openai_api_key()
306+
295307
client_params: dict = {
296-
"api_key": (
297-
self.openai_api_key.get_secret_value() if self.openai_api_key else None
298-
),
308+
"api_key": api_key_value,
299309
"organization": self.openai_organization,
300310
"base_url": self.openai_api_base,
301311
"timeout": self.request_timeout,

libs/partners/openai/langchain_openai/llms/base.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
import sys
7-
from collections.abc import AsyncIterator, Collection, Iterator, Mapping
7+
from collections.abc import AsyncIterator, Callable, Collection, Iterator, Mapping
88
from typing import Any, Literal
99

1010
import openai
@@ -186,8 +186,11 @@ class BaseOpenAI(BaseLLM):
186186
"""Generates best_of completions server-side and returns the "best"."""
187187
model_kwargs: dict[str, Any] = Field(default_factory=dict)
188188
"""Holds any model parameters valid for `create` call not explicitly specified."""
189-
openai_api_key: SecretStr | None = Field(
190-
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
189+
openai_api_key: SecretStr | None | Callable[[], str] = Field(
190+
alias="api_key",
191+
default_factory=secret_from_env(
192+
["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None
193+
),
191194
)
192195
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
193196
openai_api_base: str | None = Field(
@@ -276,10 +279,16 @@ def validate_environment(self) -> Self:
276279
msg = "Cannot stream results when best_of > 1."
277280
raise ValueError(msg)
278281

282+
# Resolve API key from SecretStr or Callable
283+
api_key_value = None
284+
if self.openai_api_key is not None:
285+
if isinstance(self.openai_api_key, SecretStr):
286+
api_key_value = self.openai_api_key.get_secret_value()
287+
elif callable(self.openai_api_key):
288+
api_key_value = self.openai_api_key()
289+
279290
client_params: dict = {
280-
"api_key": (
281-
self.openai_api_key.get_secret_value() if self.openai_api_key else None
282-
),
291+
"api_key": api_key_value,
283292
"organization": self.openai_organization,
284293
"base_url": self.openai_api_base,
285294
"timeout": self.request_timeout,

libs/partners/openai/tests/unit_tests/test_secrets.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
import pytest
44
from langchain_core.load import dumpd
5-
from pydantic import SecretStr
6-
from pytest import CaptureFixture, MonkeyPatch
7-
85
from langchain_openai import (
96
AzureChatOpenAI,
107
AzureOpenAI,
@@ -13,6 +10,8 @@
1310
OpenAI,
1411
OpenAIEmbeddings,
1512
)
13+
from pydantic import SecretStr
14+
from pytest import CaptureFixture, MonkeyPatch
1615

1716
AZURE_AD_TOKEN = "secret-api-key" # noqa: S105
1817

@@ -187,6 +186,18 @@ def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> No
187186
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"
188187

189188

189+
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
190+
def test_openai_api_key_accepts_callable(model_class: type) -> None:
191+
"""Test that the API key can be passed as a callable."""
192+
193+
def get_api_key() -> str:
194+
return "secret-api-key-from-callable"
195+
196+
model = model_class(openai_api_key=get_api_key)
197+
assert callable(model.openai_api_key)
198+
assert model.openai_api_key() == "secret-api-key-from-callable"
199+
200+
190201
@pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
191202
def test_azure_serialized_secrets(model_class: type) -> None:
192203
"""Test that the actual secret value is correctly retrieved."""

libs/partners/openai/uv.lock

Lines changed: 3 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)