Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions libs/partners/mistralai/langchain_mistralai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
import warnings
from collections.abc import Iterable
from collections.abc import Callable, Iterable

import httpx
from httpx import Response
Expand Down Expand Up @@ -57,6 +57,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
api_key:
The API key for the MistralAI API. If not provided, it will be read from the
environment variable `MISTRAL_API_KEY`.
max_concurrent_requests: int
max_retries:
The number of times to retry a request if it fails.
timeout:
Expand Down Expand Up @@ -133,9 +134,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
)
endpoint: str = "https://api.mistral.ai/v1/"
max_retries: int = 5
max_retries: int | None = 5
timeout: int = 120
wait_time: int = 30
wait_time: int | None = 30
max_concurrent_requests: int = 64
tokenizer: Tokenizer = Field(default=None)

Expand Down Expand Up @@ -212,6 +213,18 @@ def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
if batch:
yield batch

def _retry(self, func: Callable) -> Callable:
if self.max_retries is None or self.wait_time is None:
return func

return retry(
retry=retry_if_exception_type(
(httpx.TimeoutException, httpx.HTTPStatusError)
),
wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries),
)(func)

def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of document texts.

Expand All @@ -225,13 +238,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
try:
batch_responses = []

@retry(
retry=retry_if_exception_type(
(httpx.TimeoutException, httpx.HTTPStatusError)
),
wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries),
)
@self._retry
def _embed_batch(batch: list[str]) -> Response:
response = self.client.post(
url="/embeddings",
Expand Down Expand Up @@ -266,13 +273,7 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""
try:

@retry(
retry=retry_if_exception_type(
(httpx.TimeoutException, httpx.HTTPStatusError)
),
wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries),
)
@self._retry
async def _aembed_batch(batch: list[str]) -> Response:
response = await self.async_client.post(
url="/embeddings",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def test_mistralai_embedding_documents_async() -> None:
assert len(output[0]) == 1024


async def test_mistralai_embedding_documents_http_error_async() -> None:
async def test_mistralai_embedding_documents_tenacity_error_async() -> None:
"""Test MistralAI embeddings for documents."""
documents = ["foo bar", "test document"]
embedding = MistralAIEmbeddings(max_retries=0)
Expand All @@ -50,6 +50,21 @@ async def test_mistralai_embedding_documents_http_error_async() -> None:
await embedding.aembed_documents(documents)


async def test_mistralai_embedding_documents_http_error_async() -> None:
"""Test MistralAI embeddings for documents."""
documents = ["foo bar", "test document"]
embedding = MistralAIEmbeddings(max_retries=None)
mock_response = httpx.Response(
status_code=400,
request=httpx.Request("POST", url=embedding.async_client.base_url),
)
with (
patch.object(embedding.async_client, "post", return_value=mock_response),
pytest.raises(httpx.HTTPStatusError),
):
await embedding.aembed_documents(documents)


async def test_mistralai_embedding_query_async() -> None:
"""Test MistralAI embeddings for query."""
document = "foo bar"
Expand Down